| | --- |
| | library_name: transformers |
| | tags: [] |
| | --- |
| | |
| | Foundation Neural-Network Quantum State trained on the transverse-field Ising model on a chain with \\(L=100\\) sites. |
| | The system is described by the following Hamiltonian (with periodic boundary conditions): |
| |
|
| | $$ |
| | \hat{H} = -J\sum_{i=1}^L \hat{S}_i^z \hat{S}_{i+1}^z - h \sum_{i=1}^L \hat{S}_i^x \ , |
| | $$ |
| | |
| | where \\(\hat{S}_i^x\\) and \\(\hat{S}_i^z\\) are spin-\\(1/2\\) operators on site \\(i\\). |
| |
|
| |
|
| | The model was trained on \\(R=6000\\) values of the field \\(h\\), equispaced in the interval \\([0.8, 1.2]\\), |
| | using a total batch size of \\(M=12000\\) samples. |
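
As a rough illustration of the numbers above, the field grid can be sketched in plain Python. Two things here are assumptions that this card does not state: that both endpoints of the interval are included in the grid, and that the batch is shared evenly among the field values. The variable names are illustrative only.

```python
R = 6000   # number of field values (from the text above)
M = 12000  # total batch size (from the text above)
h_min, h_max = 0.8, 1.2

# Assumption: equispaced grid that includes both endpoints.
h_values = [h_min + (h_max - h_min) * r / (R - 1) for r in range(R)]
spacing = (h_max - h_min) / (R - 1)

# Assumption: samples are shared evenly among the field values.
samples_per_field = M // R

print(f"spacing between field values: {spacing:.3e}")
print(f"samples per field value: {samples_per_field}")
```

Under these assumptions each field value contributes only a couple of samples per step, so the statistics come from the batch as a whole rather than from any single value of \\(h\\).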
| |
|
| | The computation was distributed over 4 A100-64GB GPUs and took a few hours. |
| |
|
| |
|
| | ## How to Get Started with the Model |
| |
|
| | Use the code below to get started with the model. In particular, we sample the model for a fixed value of the external field \\(h\\) using NetKet. |
| |
|
| | ```python |
| | from functools import partial |
| | import numpy as np |
| | |
| | import jax |
| | import jax.numpy as jnp |
| | import netket as nk |
| | |
| | import flax |
| | from flax.training import checkpoints |
| | |
| | flax.config.update('flax_use_orbax_checkpointing', False) |
| | |
| | lattice = nk.graph.Hypercube(length=100, n_dim=1, pbc=True) |
| | |
| | revision = "main" |
| | h = 1.0 #* fix the value of the external field |
| | |
| | assert 0.8 <= h <= 1.2 #* the model has been trained on this interval |
| | |
| | from transformers import FlaxAutoModel |
| | wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True) |
| | N_params = nk.jax.tree_size(wf.params) |
| | print('Number of parameters = ', N_params, flush=True) |
| | |
| | hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes) |
| | hamiltonian = nk.operator.IsingJax(hilbert=hilbert, graph=lattice, h=h, J=-1.0) |
| | |
| | action = nk.sampler.rules.LocalRule() |
| | sampler = nk.sampler.MetropolisSampler( |
| |     hilbert=hilbert, |
| |     rule=action, |
| |     n_chains=12000, |
| |     n_sweeps=lattice.n_nodes, |
| | ) |
| | |
| | key = jax.random.PRNGKey(0) |
| | key, subkey = jax.random.split(key, 2) |
| | vstate = nk.vqs.MCState( |
| |     sampler=sampler, |
| |     apply_fun=partial(wf.__call__, coups=h), |
| |     sampler_seed=subkey, |
| |     n_samples=12000, |
| |     n_discard_per_chain=0, |
| |     variables=wf.params, |
| |     chunk_size=12000, |
| | ) |
| | |
| | # start from thermalized configurations |
| | from huggingface_hub import hf_hub_download |
| | path = hf_hub_download(repo_id="nqs-models/ising_fnqs", filename="spins", revision=revision) |
| | samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None) |
| | samples = jnp.array(samples, dtype='int8') |
| | vstate.sampler_state = vstate.sampler_state.replace(σ = samples) |
| | |
| | import time |
| | # Sample the model |
| | for _ in range(10): |
| |     start = time.time() |
| |     E = vstate.expect(hamiltonian)  # energy estimate on the current samples |
| |     vstate.sample()                 # draw a new set of samples |
| | |
| |     print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.time()-start) |
| | |
| | ``` |
| |
|
| | Each sweep takes 3.5 s on a single A100-40GB GPU. |
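
To sanity-check the energies per site printed by the loop above, one option is to compare them with the exact free-fermion result for the periodic transverse-field Ising chain. The sketch below assumes the Pauli-matrix convention used by `nk.operator.IsingJax` in the snippet (with `J=-1.0`, i.e. \\(H = -\sum_i \sigma^z_i \sigma^z_{i+1} - h \sum_i \sigma^x_i\\)), which can differ by constant factors from a spin-\\(1/2\\) operator convention. The function name `exact_energy_per_site` is hypothetical and not part of NetKet or this repository.

```python
import math


def exact_energy_per_site(L: int, h: float, J: float = 1.0) -> float:
    """Exact ground-state energy per site of
    H = -J * sum_i sz_i sz_{i+1} - h * sum_i sx_i
    (Pauli matrices, periodic boundary conditions, even L, J > 0, h >= 0).
    """
    total = 0.0
    for n in range(L):
        # Momenta of the even-parity fermionic sector: k = (2n + 1) * pi / L
        k = (2 * n + 1) * math.pi / L
        # Half of the single-particle energy 2 * sqrt(J^2 + h^2 - 2 J h cos k)
        total += math.sqrt(max(J * J + h * h - 2.0 * J * h * math.cos(k), 0.0))
    return -total / L


# J = 1 here corresponds to J = -1.0 in the NetKet sign convention.
print(exact_energy_per_site(L=100, h=1.0))
```

For \\(L=100\\) and \\(h=1\\) this evaluates to roughly \\(-1.2733\\), close to the thermodynamic-limit value \\(-4/\pi\\).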
| |
|
| | ### Extract hidden representation |
| |
|
| | The hidden representation associated with the input batch of configurations can be extracted as: |
| |
|
| | ```python |
| | wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True, return_z=True) |
| | |
| | z = wf(wf.params, samples, h) |
| | ``` |
| |
|
| | #### Training Hyperparameters |
| |
|
| | - Number of layers: 6 |
| | - Embedding dimension: 72 |
| | - Hidden dimension: 144 |
| | - Number of heads: 12 |
| | - Patch size: 4 |
| |
|
| | Total number of parameters: 198288 |
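
A few derived quantities follow from these hyperparameters by simple arithmetic. The sketch below assumes a standard patch-based Transformer in which the chain is split into non-overlapping patches and the embedding is divided evenly among the attention heads. The card does not spell out the architecture, so the variable names and their interpretation are illustrative only.

```python
n_sites = 100     # chain length L
patch_size = 4
embed_dim = 72
hidden_dim = 144
n_heads = 12

# Assumption: non-overlapping patches, so the chain length must be divisible.
assert n_sites % patch_size == 0
n_patches = n_sites // patch_size      # input sequence length seen by the model

# Assumption: the embedding is split evenly among the heads.
assert embed_dim % n_heads == 0
head_dim = embed_dim // n_heads

# Ratio between the hidden and embedding dimensions.
hidden_ratio = hidden_dim / embed_dim

print(f"patches per configuration: {n_patches}")
print(f"dimension per head: {head_dim}")
print(f"hidden/embedding ratio: {hidden_ratio}")
```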
| |
|
| |
|
| | ## Model Card Contact |
| |
|
| | Riccardo Rende (rrende@sissa.it) |
| | Luciano Loris Viteritti (luciano.viteritti@epfl.ch) |