rrende commited on
Commit
bdb1b41
·
verified ·
1 Parent(s): 23cc2be

Update transformer_fnqs.py

Browse files
Files changed (1) hide show
  1. transformer_fnqs.py +13 -10
transformer_fnqs.py CHANGED
@@ -99,19 +99,22 @@ class OuputHead(nn.Module):
99
  self.output_layer0 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
100
  self.output_layer1 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
101
 
102
- def __call__(self, x):
103
 
104
- x = self.out_layer_norm(x.sum(axis=1))
105
 
106
- amp = self.norm0(self.output_layer0(x))
 
 
 
107
 
108
  if self.complex:
109
- sign = self.norm1(self.output_layer1(x))
110
- z = amp + 1j*sign
111
  else:
112
- z = amp
113
 
114
- return jnp.sum(log_cosh(z), axis=-1)
115
 
116
  class ViTFNQS(nn.Module):
117
  num_layers: int
@@ -135,7 +138,7 @@ class ViTFNQS(nn.Module):
135
 
136
  self.output = OuputHead(self.d_model, complex=self.complex)
137
 
138
- def __call__(self, spins, coups):
139
  x = jnp.atleast_2d(spins)
140
 
141
  if self.disorder:
@@ -154,6 +157,6 @@ class ViTFNQS(nn.Module):
154
 
155
  x = self.encoder(x)
156
 
157
- z = self.output(x)
158
 
159
- return z
 
99
  self.output_layer0 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
100
  self.output_layer1 = nn.Dense(self.d_model, param_dtype=jnp.float64, dtype=jnp.float64, kernel_init=nn.initializers.xavier_uniform(), bias_init=jax.nn.initializers.zeros)
101
 
102
+ def __call__(self, x, return_z=False):
103
 
104
+ z = self.out_layer_norm(x.sum(axis=1))
105
 
106
+ if return_z:
107
+ return z
108
+
109
+ amp = self.norm0(self.output_layer0(z))
110
 
111
  if self.complex:
112
+ sign = self.norm1(self.output_layer1(z))
113
+ out = amp + 1j*sign
114
  else:
115
+ out = amp
116
 
117
+ return jnp.sum(log_cosh(out), axis=-1)
118
 
119
  class ViTFNQS(nn.Module):
120
  num_layers: int
 
138
 
139
  self.output = OuputHead(self.d_model, complex=self.complex)
140
 
141
+ def __call__(self, spins, coups, return_z=False):
142
  x = jnp.atleast_2d(spins)
143
 
144
  if self.disorder:
 
157
 
158
  x = self.encoder(x)
159
 
160
+ out = self.output(x, return_z=return_z)
161
 
162
+ return out