from functools import partial

import jax
import jax.numpy as jnp
from einops import rearrange
from flax import linen as nn

|
def roll(J, shift, axis=-1):
    """Periodically shift the coupling vector J by `shift` sites along `axis`."""
    return jnp.roll(J, shift, axis=axis)
|
|
@partial(jax.vmap, in_axes=(None, 0, None), out_axes=1)
@partial(jax.vmap, in_axes=(None, None, 0), out_axes=1)
def roll2d(spins, i, j):
    """Stack all (i, j) periodic translations of a square lattice.

    `spins` has shape (h, L_eff); the last axis is viewed as a
    (side, side) grid, rolled by (i, j) with periodic wrapping, and
    flattened again. The two vmaps stack every translation, yielding
    an array of shape (h, len(i), len(j), L_eff).
    """
    side = int(spins.shape[-1] ** 0.5)
    spins = spins.reshape(spins.shape[0], side, side)
    spins = jnp.roll(jnp.roll(spins, i, axis=-2), j, axis=-1)
    return spins.reshape(spins.shape[0], -1)
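
# Shape sketch (hypothetical numbers, not from the original file): with h = 1 and
# a 2x2 lattice, roll2d stacks all four translations of a length-4 coupling row,
#
#   J0 = jnp.arange(4.0).reshape(1, 4)
#   roll2d(J0, jnp.arange(2), jnp.arange(2)).shape  # (1, 2, 2, 4)
#
# and reshaping to (1, 4, 4) gives one translated row per lattice site, which is
# how setup() below assembles the translationally invariant attention matrix.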

class FMHA(nn.Module):
    """Multi-head attention whose attention matrix J is a learned parameter,
    independent of the input (no query/key projections)."""

    d_model: int
    h: int
    L_eff: int
    transl_invariant: bool = True
    two_dimensional: bool = False
|
|
    def setup(self):
        self.v = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(),
                          param_dtype=jnp.float64, dtype=jnp.float64)
        if self.transl_invariant:
            # One learnable coupling row per head; the full (L_eff, L_eff)
            # attention matrix is assembled from its periodic translations.
            self.J = self.param("J", nn.initializers.xavier_uniform(), (self.h, self.L_eff), jnp.float64)
            if self.two_dimensional:
                sq_L_eff = int(self.L_eff ** 0.5)
                assert sq_L_eff * sq_L_eff == self.L_eff, "L_eff must be a perfect square on a 2D lattice"
                self.J = roll2d(self.J, jnp.arange(sq_L_eff), jnp.arange(sq_L_eff))
                self.J = self.J.reshape(self.h, -1, self.L_eff)
            else:
                self.J = jax.vmap(roll, in_axes=(None, 0), out_axes=1)(self.J, jnp.arange(self.L_eff))
        else:
            # Without translational invariance, learn a full matrix per head.
            self.J = self.param("J", nn.initializers.xavier_uniform(), (self.h, self.L_eff, self.L_eff), jnp.float64)
|
|
        self.W = nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(),
                          param_dtype=jnp.float64, dtype=jnp.float64)
|
|
    def __call__(self, x):
        # Project to values and split the heads:
        # (batch, L_eff, d_model) -> (batch, h, L_eff, d_eff).
        v = self.v(x)
        v = rearrange(v, 'batch L_eff (h d_eff) -> batch h L_eff d_eff', h=self.h)

        # Mix the sites of each head with the fixed attention weights:
        # (h, L_eff, L_eff) @ (batch, h, L_eff, d_eff) -> (batch, h, L_eff, d_eff).
        x = jnp.matmul(self.J, v)

        # Merge the heads back and apply the output projection.
        x = rearrange(x, 'batch h L_eff d_eff -> batch L_eff (h d_eff)')

        return self.W(x)
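
# A minimal usage sketch (an assumption, not part of the original module): build
# the layer on a hypothetical 4x4 lattice and check shapes. x64 mode is enabled
# first, since the parameters above are declared as float64.
if __name__ == "__main__":
    jax.config.update("jax_enable_x64", True)

    L_eff = 16  # 4x4 lattice, so two_dimensional=True is valid
    model = FMHA(d_model=8, h=2, L_eff=L_eff, transl_invariant=True, two_dimensional=True)

    x = jnp.ones((3, L_eff, 8))                    # (batch, L_eff, d_model) dummy input
    params = model.init(jax.random.PRNGKey(0), x)  # initialize J, v, W
    y = model.apply(params, x)
    assert y.shape == (3, L_eff, 8)                # the layer preserves the input shape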
|
|