smithblack-0 commited on
Commit
dff3eb9
·
verified ·
1 Parent(s): a15e620

Update architecture and tokenizer

Browse files
Files changed (1) hide show
  1. huggingface.py +11 -9
huggingface.py CHANGED
@@ -1736,16 +1736,17 @@ gated residual connections around both sublayers:
1736
 
1737
  normed_attn = RMSNorm(x)
1738
  attn_out, router_diagnostics = SHRAMHybridLayer(normed_attn, ...)
1739
- h = x + residual_gate * attn_out
1740
 
1741
  normed_mlp = RMSNorm(h)
1742
  mlp_out = SwiGLUMLP(normed_mlp)
1743
- out = h + residual_gate * mlp_out
1744
 
1745
- A single shared residual_gate vector (shape: embedding_width, init: zeros) gates
1746
- both sublayer contributions. At initialisation the layer is a pure identity, which
1747
- prevents variance explosion through depth regardless of how HuggingFace initialises
1748
- the projection weights. The gate is a trainable parameter and opens during training.
 
1749
 
1750
  Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
1751
  through unnormalised residuals at depth, and each sublayer receives a stable,
@@ -3746,7 +3747,8 @@ class DecoderLayer(nn.Module):
3746
  self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
3747
  self.attention = SHRAMHybridLayer(config)
3748
  self.mlp = SwiGLUMLP(config)
3749
- self.residual_gate = nn.Parameter(1e-6*torch.randn([config.embedding_width]))
 
3750
  def num_mosrah_parameters(self) -> int:
3751
  """Return the total number of trainable MoSRAH parameters in this decoder layer."""
3752
  return self.attention.num_mosrah_parameters()
@@ -3780,8 +3782,8 @@ class DecoderLayer(nn.Module):
3780
  active_mask=active_mask,
3781
  cache=cache,
3782
  )
3783
- hidden_states = x + self.residual_gate*attn_out
3784
- output = hidden_states + self.residual_gate*self.mlp(self.mlp_norm(hidden_states))
3785
  return output, router_diagnostics
3786
 
3787
 
 
1736
 
1737
  normed_attn = RMSNorm(x)
1738
  attn_out, router_diagnostics = SHRAMHybridLayer(normed_attn, ...)
1739
+ h = x + attn_residual_gate * attn_out
1740
 
1741
  normed_mlp = RMSNorm(h)
1742
  mlp_out = SwiGLUMLP(normed_mlp)
1743
+ out = h + mlp_residual_gate * mlp_out
1744
 
1745
+ Two independent residual gate vectors (shape: embedding_width, init: near-zero) gate
1746
+ the attention and MLP sublayer contributions separately. At initialisation the layer is
1747
+ a pure identity. The gates are independent trainable parameters so gradients from the
1748
+ two sublayers never accumulate into a shared parameter, preventing norm explosion at
1749
+ depth.
1750
 
1751
  Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
1752
  through unnormalised residuals at depth, and each sublayer receives a stable,
 
3747
  self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
3748
  self.attention = SHRAMHybridLayer(config)
3749
  self.mlp = SwiGLUMLP(config)
3750
+ self.attn_residual_gate = nn.Parameter(1e-6*torch.randn([config.embedding_width]))
3751
+ self.mlp_residual_gate = nn.Parameter(1e-6*torch.randn([config.embedding_width]))
3752
  def num_mosrah_parameters(self) -> int:
3753
  """Return the total number of trainable MoSRAH parameters in this decoder layer."""
3754
  return self.attention.num_mosrah_parameters()
 
3782
  active_mask=active_mask,
3783
  cache=cache,
3784
  )
3785
+ hidden_states = x + self.attn_residual_gate*attn_out
3786
+ output = hidden_states + self.mlp_residual_gate*self.mlp(self.mlp_norm(hidden_states))
3787
  return output, router_diagnostics
3788
 
3789