---
library_name: transformers
tags: []
---
Foundation Neural-Network Quantum State trained on the transverse-field Ising model on a chain of \\(L=100\\) sites.
The system is described by the following Hamiltonian (with periodic boundary conditions):
$$
\hat{H} = -J\sum_{i=1}^{L} \hat{S}_i^z \hat{S}_{i+1}^z - h \sum_{i=1}^{L} \hat{S}_i^x \ ,
$$
where \\(\hat{S}_i^x\\) and \\(\hat{S}_i^z\\) are spin-\\(1/2\\) operators acting on site \\(i\\).
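As a quick sanity check on the conventions above, the Hamiltonian can be built explicitly for a small chain and diagonalized exactly. This is a minimal NumPy sketch, not part of the model code; the chain length \\(L=8\\) is illustrative only (the trained model targets \\(L=100\\)):

```python
import numpy as np

def tfim_hamiltonian(L, J=1.0, h=1.0):
    """Dense spin-1/2 TFIM Hamiltonian on a periodic chain of L sites."""
    sx = 0.5 * np.array([[0.0, 1.0], [1.0, 0.0]])   # S^x = sigma^x / 2
    sz = 0.5 * np.array([[1.0, 0.0], [0.0, -1.0]])  # S^z = sigma^z / 2
    ident = np.eye(2)

    def site_op(op, i):
        # Kronecker product placing `op` on site i and identities elsewhere
        out = np.array([[1.0]])
        for j in range(L):
            out = np.kron(out, op if j == i else ident)
        return out

    H = np.zeros((2**L, 2**L))
    for i in range(L):
        H -= J * site_op(sz, i) @ site_op(sz, (i + 1) % L)  # periodic bond
        H -= h * site_op(sx, i)                             # transverse field
    return H

H = tfim_hamiltonian(L=8, J=1.0, h=1.0)
E0 = np.linalg.eigvalsh(H)[0]
print("Ground-state energy per site:", E0 / 8)
```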
The model has been trained on \\(R=6000\\) values of the field \\(h\\), equally spaced in the interval \\(h \in [0.8, 1.2]\\),
using a total batch size of \\(M=12000\\) samples.
The computation was distributed over four A100-64GB GPUs and took a few hours.
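The training grid of field values can be reproduced in one line; this is a sketch under the assumption that the \\(R=6000\\) equally spaced points include both endpoints of the interval:

```python
import numpy as np

# R = 6000 equally spaced field values in [0.8, 1.2] (endpoints assumed included)
R = 6000
fields = np.linspace(0.8, 1.2, R)
print(fields[0], fields[-1], fields[1] - fields[0])
```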
## How to Get Started with the Model
Use the code below to get started with the model. In particular, we sample the model for a fixed value of the external field \\(h\\) using NetKet.
```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
import netket as nk
import flax
from flax.training import checkpoints
flax.config.update('flax_use_orbax_checkpointing', False)
lattice = nk.graph.Hypercube(length=100, n_dim=1, pbc=True)
revision = "main"
h = 1.0  # fix the value of the external field
assert 0.8 <= h <= 1.2  # the model has been trained on this interval
from transformers import FlaxAutoModel
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True)
N_params = nk.jax.tree_size(wf.params)
print('Number of parameters = ', N_params, flush=True)
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)
hamiltonian = nk.operator.IsingJax(hilbert=hilbert, graph=lattice, h=h, J=-1.0)
action = nk.sampler.rules.LocalRule()
sampler = nk.sampler.MetropolisSampler(hilbert=hilbert,
                                       rule=action,
                                       n_chains=12000,
                                       n_sweeps=lattice.n_nodes)
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key, 2)
vstate = nk.vqs.MCState(sampler=sampler,
                        apply_fun=partial(wf.__call__, coups=h),
                        sampler_seed=subkey,
                        n_samples=12000,
                        n_discard_per_chain=0,
                        variables=wf.params,
                        chunk_size=12000)
# start from thermalized configurations
from huggingface_hub import hf_hub_download
path = hf_hub_download(repo_id="nqs-models/ising_fnqs", filename="spins", revision=revision)
samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None)
samples = jnp.array(samples, dtype='int8')
vstate.sampler_state = vstate.sampler_state.replace(σ = samples)
import time
# Sample the model
for _ in range(10):
    start = time.time()
    E = vstate.expect(hamiltonian)
    vstate.sample()
    print("Mean: ", E.mean.real / lattice.n_nodes, "\t time =", time.time() - start)
```
The time per sweep is 3.5 s, measured on a single A100-40GB GPU.
### Extract hidden representation
The hidden representation associated with an input batch of configurations can be extracted as follows:
```python
# reload the model, asking it to return the hidden representation z
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True, return_z=True)
z = wf(wf.params, samples, h)  # `samples` and `h` as defined in the snippet above
```
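For downstream analysis one might, for example, project the hidden features onto their leading principal components. The sketch below uses a random stand-in array for `z` (the actual shape returned by the model is an assumption here), so it only illustrates the mechanics:

```python
import numpy as np

# Stand-in for the hidden representation z: (batch, feature) shape is hypothetical
rng = np.random.default_rng(0)
z_np = rng.normal(size=(128, 72))

# PCA via SVD of the centered features
zc = z_np - z_np.mean(axis=0)
U, S, Vt = np.linalg.svd(zc, full_matrices=False)
z_pca = zc @ Vt[:2].T  # projection onto the first two principal components
print(z_pca.shape)
```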
## Training Hyperparameters
- Number of layers: 6
- Embedding dimension: 72
- Hidden dimension: 144
- Number of heads: 12
- Patch size: 4
- Total number of parameters: 198288
## Model Card Contact
- Riccardo Rende (rrende@sissa.it)
- Luciano Loris Viteritti (luciano.viteritti@epfl.ch)