Upload model
Browse files- transformer_fnqs.py +2 -2
transformer_fnqs.py
CHANGED
|
@@ -57,8 +57,8 @@ class EncoderBlock(nn.Module):
|
|
| 57 |
self.layer_norm_2 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
|
| 58 |
|
| 59 |
self.ff = nn.Sequential([
|
| 60 |
-
nn.Dense(
|
| 61 |
-
nn.
|
| 62 |
nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
|
| 63 |
])
|
| 64 |
|
|
|
|
| 57 |
self.layer_norm_2 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
|
| 58 |
|
| 59 |
self.ff = nn.Sequential([
|
| 60 |
+
nn.Dense(4*self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
|
| 61 |
+
nn.gelu,
|
| 62 |
nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
|
| 63 |
])
|
| 64 |
|