rrende committed on
Commit
a461c60
·
verified ·
1 Parent(s): c88a493

Upload model

Browse files
Files changed (1) hide show
  1. transformer_fnqs.py +2 -2
transformer_fnqs.py CHANGED
@@ -57,8 +57,8 @@ class EncoderBlock(nn.Module):
57
  self.layer_norm_2 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
58
 
59
  self.ff = nn.Sequential([
60
- nn.Dense(2*self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
61
- nn.relu,
62
  nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
63
  ])
64
 
 
57
  self.layer_norm_2 = nn.LayerNorm(dtype=jnp.float64, param_dtype=jnp.float64)
58
 
59
  self.ff = nn.Sequential([
60
+ nn.Dense(4*self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
61
+ nn.gelu,
62
  nn.Dense(self.d_model, kernel_init=nn.initializers.xavier_uniform(), param_dtype=jnp.float64, dtype=jnp.float64),
63
  ])
64