| from transformers import FlaxPreTrainedModel | |
| import jax.numpy as jnp | |
| from .transformer_fnqs import ViTFNQS | |
| from .vit_fnqs_config import ViTFNQSConfig | |
| class ViTFNQSModel(FlaxPreTrainedModel): | |
| config_class = ViTFNQSConfig | |
| def __init__( | |
| self, | |
| config: ViTFNQSConfig, | |
| input_shape = (jnp.zeros((1, 100)), jnp.zeros((1, 1))), | |
| seed: int = 0, | |
| dtype: jnp.dtype = jnp.float64, | |
| _do_init: bool = True, | |
| **kwargs, | |
| ): | |
| self.model = ViTFNQS(L_eff=config.L_eff, | |
| num_layers=config.num_layers, | |
| d_model=config.d_model, | |
| heads=config.heads, | |
| b=config.b, | |
| complex=config.complex, | |
| disorder=config.disorder, | |
| transl_invariant=config.tras_inv, | |
| two_dimensional=config.two_dim, | |
| ) | |
| if not "return_z" in kwargs: | |
| self.return_z = False | |
| else: | |
| self.return_z = kwargs["return_z"] | |
| super().__init__(config, ViTFNQS, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) | |
| def __call__(self, params, spins, coups): | |
| return self.model.apply(params, spins, coups, return_z=self.return_z) | |
| def init_weights(self, rng, input_shape): | |
| return self.model.init(rng, *input_shape) |