from transformers import FlaxPreTrainedModel
import jax.numpy as jnp

from .transformer_fnqs import ViTFNQS
from .vit_fnqs_config import ViTFNQSConfig


class ViTFNQSModel(FlaxPreTrainedModel):
    """Hugging Face ``FlaxPreTrainedModel`` wrapper around the ``ViTFNQS`` Flax module.

    The wrapper builds the underlying network from a :class:`ViTFNQSConfig`
    and delegates both the forward pass (``__call__`` -> ``model.apply``)
    and weight initialization (``init_weights`` -> ``model.init``) to it.
    """

    config_class = ViTFNQSConfig

    def __init__(
        self,
        config: ViTFNQSConfig,
        # Default inputs used only to trace/initialize parameter shapes:
        # a (1, 100) spin configuration and a (1, 1) coupling. jnp arrays
        # are immutable, so the mutable-default-argument pitfall does not apply.
        input_shape=(jnp.zeros((1, 100)), jnp.zeros((1, 1))),
        seed: int = 0,
        # NOTE(review): float64 takes effect only if jax x64 mode is enabled
        # (jax.config.update("jax_enable_x64", True)) — confirm caller does this.
        dtype: jnp.dtype = jnp.float64,
        _do_init: bool = True,
        **kwargs,
    ):
        """Build the ViTFNQS module from ``config`` and register it with HF.

        Args:
            config: Model hyperparameters (layers, width, heads, symmetry flags).
            input_shape: Tuple of example inputs used by ``init_weights``.
            seed: PRNG seed forwarded to the base class.
            dtype: Computation dtype forwarded to the base class.
            _do_init: Whether the base class should initialize parameters.
            **kwargs: May contain ``return_z`` (bool, default ``False``),
                forwarded to ``model.apply`` on every call.
        """
        self.model = ViTFNQS(
            L_eff=config.L_eff,
            num_layers=config.num_layers,
            d_model=config.d_model,
            heads=config.heads,
            b=config.b,
            complex=config.complex,
            disorder=config.disorder,
            transl_invariant=config.tras_inv,
            two_dimensional=config.two_dim,
        )

        # Idiomatic replacement for `if not "return_z" in kwargs: ... else: ...`.
        self.return_z = kwargs.get("return_z", False)

        # NOTE(review): the HF convention is to pass the *module instance*
        # (self.model) here, not the class. Passing the class works in this
        # codebase only because __call__ and init_weights are overridden to use
        # self.model directly — confirm nothing relies on ``self.module``.
        super().__init__(
            config,
            ViTFNQS,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def __call__(self, params, spins, coups):
        """Evaluate the wavefunction network on ``spins``/``coups`` with ``params``."""
        return self.model.apply(params, spins, coups, return_z=self.return_z)

    def init_weights(self, rng, input_shape):
        """Initialize module parameters by tracing the example ``input_shape`` inputs."""
        return self.model.init(rng, *input_shape)