---
library_name: transformers
tags: []
---

Foundation Neural-Network Quantum State trained on the transverse-field Ising model on a chain with \\(L=100\\) sites. The system is described by the following Hamiltonian, with periodic boundary conditions (\\(\hat{S}_{L+1} \equiv \hat{S}_1\\)):

$$
\hat{H} = -J\sum_{i=1}^{L} \hat{S}_i^z \hat{S}_{i+1}^z - h \sum_{i=1}^{L} \hat{S}_i^x \ ,
$$

where \\(\hat{S}_i^x\\) and \\(\hat{S}_i^z\\) are spin-\\(1/2\\) operators on site \\(i\\).

The model has been trained on \\(R=6000\\) different values of the field \\(h\\), equispaced in the interval \\(h \in [0.8, 1.2]\\), using a total batch size of \\(M=12000\\) samples. The computation was distributed over 4 A100-64GB GPUs and took a few hours.

## How to Get Started with the Model

Use the code below to get started with the model. In particular, it samples the model at a fixed value of the external field \\(h\\) using NetKet.

```python
import time
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
import netket as nk
import flax
from flax.training import checkpoints
from huggingface_hub import hf_hub_download
from transformers import FlaxAutoModel

flax.config.update('flax_use_orbax_checkpointing', False)

lattice = nk.graph.Hypercube(length=100, n_dim=1, pbc=True)

revision = "main"
h = 1.0  # fix the value of the external field
assert 0.8 <= h <= 1.2  # the model has been trained on this interval

wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)

hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)
# NetKet's Ising operator reads H = -h sum_i X_i + J sum_<ij> Z_i Z_j,
# so J=-1.0 corresponds to a ferromagnetic coupling
hamiltonian = nk.operator.IsingJax(hilbert=hilbert, graph=lattice, h=h, J=-1.0)

action = nk.sampler.rules.LocalRule()
sampler = nk.sampler.MetropolisSampler(hilbert=hilbert, rule=action, n_chains=12000, n_sweeps=lattice.n_nodes)

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=partial(wf.__call__, coups=h),
                        sampler_seed=subkey,
                        n_samples=12000,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=12000)

# start from thermalized configurations
path = hf_hub_download(repo_id="nqs-models/ising_fnqs", filename="spins", revision=revision)
samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ=samples)

# Sample the model and print the energy per site
for _ in range(10):
    start = time.time()
    E = vstate.expect(hamiltonian)
    vstate.sample()

    print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.time() - start)
```

The time per sweep is 3.5 s, measured on a single A100-40GB GPU.

### Extract hidden representation

The hidden representation associated with an input batch of configurations can be extracted as:

```python
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True, return_z=True)
z = wf(wf.params, samples, h)
```

#### Training Hyperparameters

- Number of layers: 6
- Embedding dimension: 72
- Hidden dimension: 144
- Number of heads: 12
- Patch size: 4
- Total number of parameters: 198288

## Model Card Contact

- Riccardo Rende (rrende@sissa.it)
- Luciano Loris Viteritti (luciano.viteritti@epfl.ch)
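To judge how close the energy per site printed by the sampling loop is to the true ground state, one can compare it with the exact solution of the transverse-field Ising chain, obtained by mapping it to free fermions with a Jordan-Wigner transformation. The sketch below is not part of the released code: the function name `tfim_exact_energy_per_site` is hypothetical, and it assumes the Pauli-matrix convention of NetKet's `Ising` operator (\\(\hat{H} = -h\sum_i \sigma_i^x + J\sum_{\langle ij \rangle} \sigma_i^z \sigma_j^z\\), so that `J=-1.0` in the sampling code is a unit ferromagnetic coupling), periodic boundary conditions with even \\(L\\), and the standard result that the ground state lies in the even-parity sector.

```python
import math


def tfim_exact_energy_per_site(h: float, L: int = 100, J: float = 1.0) -> float:
    """Exact ground-state energy per site of H = -J sum_i Z_i Z_{i+1} - h sum_i X_i.

    Assumes Pauli matrices, periodic boundary conditions and an even chain
    length L. The ground state is taken in the even-parity sector, where the
    Jordan-Wigner fermions obey antiperiodic boundary conditions, i.e.
    momenta k = (2n + 1) * pi / L for n = 0, ..., L - 1.
    """
    if L % 2 != 0:
        raise ValueError("L must be even")
    total = 0.0
    for n in range(L):
        k = (2 * n + 1) * math.pi / L
        # Each mode contributes -eps_k / 2, with eps_k = 2 * sqrt(J^2 + h^2 - 2 J h cos k)
        total -= math.sqrt(J**2 + h**2 - 2 * J * h * math.cos(k))
    return total / L


# Reference values inside the training interval h in [0.8, 1.2]
for h in (0.8, 1.0, 1.2):
    print(f"h = {h:.1f}   E0/L = {tfim_exact_energy_per_site(h):.6f}")
```

At \\(h=1\\) the sum has the closed form \\(E_0/L = -2/[L \sin(\pi/2L)]\\), which approaches \\(-4/\pi\\) as \\(L \to \infty\\); the gap between the sampled energy and these values gives a rough accuracy check.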
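The hyperparameters list a patch size of 4, which suggests that the \\(L=100\\) spins are grouped into \\(100/4 = 25\\) tokens before entering the attention layers. The sketch below illustrates only that tokenization step, assuming contiguous, non-overlapping patches and a single linear embedding shared by all patches; the helper `extract_patches` and the random matrix `w_embed` are hypothetical stand-ins, not the model's actual layers or trained weights.

```python
import numpy as np

L = 100        # chain length (from the model card)
PATCH = 4      # patch size (from the hyperparameters)
D_EMBED = 72   # embedding dimension (from the hyperparameters)


def extract_patches(spins: np.ndarray, patch: int = PATCH) -> np.ndarray:
    """Reshape a batch of configurations (batch, L) into (batch, L // patch, patch)."""
    batch, n_sites = spins.shape
    if n_sites % patch != 0:
        raise ValueError("chain length must be divisible by the patch size")
    # Contiguous, non-overlapping patches: sites 0-3, 4-7, ...
    return spins.reshape(batch, n_sites // patch, patch)


rng = np.random.default_rng(0)
# Random +/-1 configurations standing in for samples from the Markov chains
spins = rng.choice(np.array([-1, 1], dtype=np.int8), size=(8, L))
patches = extract_patches(spins)

# Hypothetical linear embedding shared by all patches (random weights)
w_embed = rng.normal(size=(PATCH, D_EMBED)) / np.sqrt(PATCH)
tokens = patches.astype(np.float32) @ w_embed

print(patches.shape, tokens.shape)  # (8, 25, 4) (8, 25, 72)
```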