Update README.md
Browse files
README.md
CHANGED
|
@@ -18,6 +18,61 @@ The computation has been distributed over 4 A100-64GB GPUs for few hours.
|
|
| 18 |
Use the code below to get started with the model. In particular, we sample the model for a fixed value of the externa field \\(h\\) using NetKet.
|
| 19 |
|
| 20 |
```python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
```
|
| 23 |
|
|
|
|
| 18 |
Use the code below to get started with the model. In particular, we sample the model for a fixed value of the externa field \\(h\\) using NetKet.
|
| 19 |
|
| 20 |
```python
|
| 21 |
+
from functools import partial
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
import jax
|
| 25 |
+
import jax.numpy as jnp
|
| 26 |
+
import netket as nk
|
| 27 |
+
|
| 28 |
+
import flax
|
| 29 |
+
from flax.training import checkpoints
|
| 30 |
+
|
| 31 |
+
flax.config.update('flax_use_orbax_checkpointing', False)
|
| 32 |
+
|
| 33 |
+
lattice = nk.graph.Hypercube(length=100, n_dim=1, pbc=True)
|
| 34 |
+
|
| 35 |
+
revision = "main"
|
| 36 |
+
h = 1.0 #* fix the value of the external field
|
| 37 |
+
|
| 38 |
+
from transformers import FlaxAutoModel
|
| 39 |
+
wf = FlaxAutoModel.from_pretrained("nqs-models/ising_fnqs", trust_remote_code=True)
|
| 40 |
+
N_params = nk.jax.tree_size(wf.params)
|
| 41 |
+
print('Number of parameters = ', N_params, flush=True)
|
| 42 |
+
|
| 43 |
+
hilbert = nk.hilbert.Spin(s=1/2, N=lattice.n_nodes)
|
| 44 |
+
hamiltonian = nk.operator.IsingJax(hilbert=hilbert, graph=lattice, h=h, J=-1.0)
|
| 45 |
+
|
| 46 |
+
action = nk.sampler.rules.LocalRule()
|
| 47 |
+
sampler = nk.sampler.MetropolisSampler(hilbert=hilbert,
|
| 48 |
+
rule=action,
|
| 49 |
+
n_chains=12000,
|
| 50 |
+
n_sweeps=lattice.n_nodes)
|
| 51 |
+
|
| 52 |
+
key = jax.random.PRNGKey(0)
|
| 53 |
+
key, subkey = jax.random.split(key, 2)
|
| 54 |
+
vstate = nk.vqs.MCState(sampler=sampler,
|
| 55 |
+
apply_fun=partial(wf.__call__, coups=h),
|
| 56 |
+
sampler_seed=subkey,
|
| 57 |
+
n_samples=12000,
|
| 58 |
+
n_discard_per_chain=0,
|
| 59 |
+
variables=wf.params,
|
| 60 |
+
chunk_size=12000)
|
| 61 |
+
|
| 62 |
+
from huggingface_hub import hf_hub_download
|
| 63 |
+
path = hf_hub_download(repo_id="nqs-models/ising_fnqs", filename="spins", revision=revision)
|
| 64 |
+
samples = checkpoints.restore_checkpoint(path, prefix="spins", target=None)
|
| 65 |
+
samples = jnp.array(samples, dtype='int8')
|
| 66 |
+
vstate.sampler_state = vstate.sampler_state.replace(σ = samples)
|
| 67 |
+
|
| 68 |
+
import time
|
| 69 |
+
# Sample the model
|
| 70 |
+
for _ in range(10):
|
| 71 |
+
start = time.time()
|
| 72 |
+
E = vstate.expect(hamiltonian)
|
| 73 |
+
vstate.sample()
|
| 74 |
+
|
| 75 |
+
print("Mean: ", E.mean.real / lattice.n_nodes, "\t time=", time.time()-start)
|
| 76 |
|
| 77 |
```
|
| 78 |
|