Update transformer_fnqs.py
Browse files - transformer_fnqs.py +13 -10
transformer_fnqs.py
CHANGED
|
@@ -99,19 +99,22 @@ class OuputHead(nn.Module):
|
|
| 99 |
self.output_layer0 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
|
| 100 |
self.output_layer1 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
|
| 101 |
|
| 102 |
-
def __call__(self, x):
|
| 103 |
|
| 104 |
-
|
| 105 |
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
if self.complex:
|
| 109 |
-
sign = self.norm1(self.output_layer1(
|
| 110 |
-
|
| 111 |
else:
|
| 112 |
-
|
| 113 |
|
| 114 |
-
return jnp.sum(log_cosh(
|
| 115 |
|
| 116 |
class ViTFNQS(nn.Module):
|
| 117 |
num_layers: int
|
|
@@ -135,7 +138,7 @@ class ViTFNQS(nn.Module):
|
|
| 135 |
|
| 136 |
self.output = OuputHead(self.d_model, complex=self.complex)
|
| 137 |
|
| 138 |
-
def __call__(self, spins, coups):
|
| 139 |
x = jnp.atleast_2d(spins)
|
| 140 |
|
| 141 |
if self.disorder:
|
|
@@ -154,6 +157,6 @@ class ViTFNQS(nn.Module):
|
|
| 154 |
|
| 155 |
x = self.encoder(x)
|
| 156 |
|
| 157 |
-
|
| 158 |
|
| 159 |
-
return
|
|
|
|
| 99 |
self.output_layer0 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
|
| 100 |
self.output_layer1 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
|
| 101 |
|
| 102 |
+
def __call__(self, x, return_z=False):
|
| 103 |
|
| 104 |
+
z = self.out_layer_norm(x.sum(axis=1))
|
| 105 |
|
| 106 |
+
if return_z:
|
| 107 |
+
return z
|
| 108 |
+
|
| 109 |
+
amp = self.norm0(self.output_layer0(z))
|
| 110 |
|
| 111 |
if self.complex:
|
| 112 |
+
sign = self.norm1(self.output_layer1(z))
|
| 113 |
+
out = amp + 1j*sign
|
| 114 |
else:
|
| 115 |
+
out = amp
|
| 116 |
|
| 117 |
+
return jnp.sum(log_cosh(out), axis=-1)
|
| 118 |
|
| 119 |
class ViTFNQS(nn.Module):
|
| 120 |
num_layers: int
|
|
|
|
| 138 |
|
| 139 |
self.output = OuputHead(self.d_model, complex=self.complex)
|
| 140 |
|
| 141 |
+
def __call__(self, spins, coups, return_z=False):
|
| 142 |
x = jnp.atleast_2d(spins)
|
| 143 |
|
| 144 |
if self.disorder:
|
|
|
|
| 157 |
|
| 158 |
x = self.encoder(x)
|
| 159 |
|
| 160 |
+
out = self.output(x, return_z=return_z)
|
| 161 |
|
| 162 |
+
return out
|