| from transformers import FlaxPreTrainedModel |
| import jax.numpy as jnp |
| from .transformer import ViT |
| from .vitnqs_config import ViTNQSConfig |
|
|
|
|
class ViTNQSModel(FlaxPreTrainedModel):
    """Hugging Face wrapper around the ViT neural-quantum-state network.

    Bridges the project-local Flax ``ViT`` module to the
    ``FlaxPreTrainedModel`` API so the network can be saved/loaded with the
    standard ``from_pretrained`` / ``save_pretrained`` machinery. The wrapper
    forwards configuration fields from :class:`ViTNQSConfig` into the ``ViT``
    constructor and delegates both the forward pass and weight
    initialization to that module.
    """

    config_class = ViTNQSConfig

    def __init__(
        self,
        config: ViTNQSConfig,
        # NOTE(review): a jnp array default is evaluated at import time;
        # jnp arrays are immutable so this is safe, but a plain shape tuple
        # would be the more conventional HF default — confirm before changing.
        input_shape=jnp.zeros((1, 100)),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float64,
        _do_init: bool = True,
        **kwargs,
    ):
        """Build the inner ViT module from ``config`` and register with HF.

        Args:
            config: Hyperparameters for the ViT network.
            input_shape: Dummy input used by HF to initialize parameters.
            seed: PRNG seed forwarded to ``FlaxPreTrainedModel``.
            dtype: Parameter dtype (float64 requires JAX x64 mode enabled).
            _do_init: Whether HF should initialize weights eagerly.
            **kwargs: Optional ``return_z`` flag — when True, the forward
                pass also returns the intermediate representation ``z``.
        """
        self.model = ViT(
            L_eff=config.L_eff,
            num_layers=config.num_layers,
            d_model=config.d_model,
            heads=config.heads,
            b=config.b,
            transl_invariant=config.tras_inv,
            two_dimensional=config.two_dim,
        )
        # dict.get replaces the original `if not "return_z" in kwargs` dance.
        self.return_z = kwargs.get("return_z", False)

        # Fix: pass the module *instance* (self.model), not the class `ViT`.
        # FlaxPreTrainedModel stores this as `self.module` and expects an
        # instantiated flax.linen.Module there.
        super().__init__(config, self.model, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def __call__(self, params, spins):
        """Evaluate the network: log-amplitudes for ``spins`` under ``params``.

        ``self.return_z`` is threaded through so the inner module can
        optionally return its intermediate representation as well.
        """
        return self.model.apply(params, spins, self.return_z)

    def init_weights(self, rng, input_shape):
        """Initialize and return the ViT parameter PyTree.

        Args:
            rng: A ``jax.random`` PRNG key.
            input_shape: A dummy input batch (an array, per the ``__init__``
                default), traced by ``flax.linen.Module.init``.
        """
        return self.model.init(rng, input_shape)