smithblack-0 commited on
Commit
102f1bb
·
verified ·
1 Parent(s): dff3eb9

Update architecture and tokenizer

Browse files
Files changed (4) hide show
  1. README.md +1 -0
  2. config.json +2 -1
  3. configuration.py +6 -0
  4. huggingface.py +26 -11
README.md CHANGED
@@ -95,6 +95,7 @@ contains no weights. All values are overridable via kwargs.
95
  | `tie_word_embeddings` | False |
96
  | `training_sequence_length` | 1024 |
97
  | `use_cache` | True |
 
98
  | `vocab_size` | 50277 |
99
  | `window_size` | 128 |
100
 
 
95
  | `tie_word_embeddings` | False |
96
  | `training_sequence_length` | 1024 |
97
  | `use_cache` | True |
98
+ | `use_residual_gate` | True |
99
  | `vocab_size` | 50277 |
100
  | `window_size` | 128 |
101
 
config.json CHANGED
@@ -21,8 +21,9 @@
21
  "rope_mode": "main_sequence",
22
  "tie_word_embeddings": false,
23
  "training_sequence_length": 1024,
24
- "transformers_version": "5.12.0",
25
  "use_cache": true,
 
26
  "vocab_size": 50277,
27
  "window_size": 128
28
  }
 
21
  "rope_mode": "main_sequence",
22
  "tie_word_embeddings": false,
23
  "training_sequence_length": 1024,
24
+ "transformers_version": "5.12.1",
25
  "use_cache": true,
26
+ "use_residual_gate": true,
27
  "vocab_size": 50277,
28
  "window_size": 128
29
  }
configuration.py CHANGED
@@ -79,6 +79,10 @@ class ShramConfig(PretrainedConfig):
79
  use_cache: Whether to return past_key_values for KV caching.
80
  output_hidden_states: Whether to return hidden states after each layer.
81
  tie_word_embeddings: Whether input embedding and LM head share weights.
 
 
 
 
82
  """
83
 
84
  model_type = "shram"
@@ -111,6 +115,7 @@ class ShramConfig(PretrainedConfig):
111
  use_cache: bool = True,
112
  output_hidden_states: bool = False,
113
  tie_word_embeddings: bool = False,
 
114
  **kwargs
115
  ):
116
  if head_dim % 2 != 0:
@@ -167,6 +172,7 @@ class ShramConfig(PretrainedConfig):
167
  self.beta = beta
168
  self.attention_dropout = attention_dropout
169
  self.use_cache = use_cache
 
170
 
171
  super().__init__(
172
  tie_word_embeddings=tie_word_embeddings,
 
79
  use_cache: Whether to return past_key_values for KV caching.
80
  output_hidden_states: Whether to return hidden states after each layer.
81
  tie_word_embeddings: Whether input embedding and LM head share weights.
82
+ use_residual_gate: When True, each DecoderLayer gates its residual contributions
83
+ with a learnable scalar parameter (init: zero). When False, uses a fixed
84
+ ``1/√num_decoder_layers`` scale instead, which preserves O(1) residual
85
+ variance at depth with no learnable gate. Default True.
86
  """
87
 
88
  model_type = "shram"
 
115
  use_cache: bool = True,
116
  output_hidden_states: bool = False,
117
  tie_word_embeddings: bool = False,
118
+ use_residual_gate: bool = True,
119
  **kwargs
120
  ):
121
  if head_dim % 2 != 0:
 
172
  self.beta = beta
173
  self.attention_dropout = attention_dropout
174
  self.use_cache = use_cache
175
+ self.use_residual_gate = use_residual_gate
176
 
177
  super().__init__(
178
  tie_word_embeddings=tie_word_embeddings,
huggingface.py CHANGED
@@ -165,6 +165,10 @@ class ShramConfig(PretrainedConfig):
165
  use_cache: Whether to return past_key_values for KV caching.
166
  output_hidden_states: Whether to return hidden states after each layer.
167
  tie_word_embeddings: Whether input embedding and LM head share weights.
 
 
 
 
168
  """
169
 
170
  model_type = "shram"
@@ -197,6 +201,7 @@ class ShramConfig(PretrainedConfig):
197
  use_cache: bool = True,
198
  output_hidden_states: bool = False,
199
  tie_word_embeddings: bool = False,
 
200
  **kwargs
201
  ):
202
  if head_dim % 2 != 0:
@@ -253,6 +258,7 @@ class ShramConfig(PretrainedConfig):
253
  self.beta = beta
254
  self.attention_dropout = attention_dropout
255
  self.use_cache = use_cache
 
256
 
257
  super().__init__(
258
  tie_word_embeddings=tie_word_embeddings,
@@ -1736,17 +1742,19 @@ gated residual connections around both sublayers:
1736
 
1737
  normed_attn = RMSNorm(x)
1738
  attn_out, router_diagnostics = SHRAMHybridLayer(normed_attn, ...)
1739
- h = x + attn_residual_gate * attn_out
1740
 
1741
  normed_mlp = RMSNorm(h)
1742
  mlp_out = SwiGLUMLP(normed_mlp)
1743
- out = h + mlp_residual_gate * mlp_out
1744
 
1745
- Two independent residual gate vectors (shape: embedding_width, init: near-zero) gate
1746
- the attention and MLP sublayer contributions separately. At initialisation the layer is
1747
- a pure identity. The gates are independent trainable parameters so gradients from the
1748
- two sublayers never accumulate into a shared parameter, preventing norm explosion at
1749
- depth.
 
 
1750
 
1751
  Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
1752
  through unnormalised residuals at depth, and each sublayer receives a stable,
@@ -1763,6 +1771,8 @@ subtraction, is faster than LayerNorm, and proved more stable at scale.
1763
 
1764
 
1765
 
 
 
1766
  # -----------
1767
  # Inlined from: shram.py
1768
  # -----------
@@ -3747,8 +3757,13 @@ class DecoderLayer(nn.Module):
3747
  self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
3748
  self.attention = SHRAMHybridLayer(config)
3749
  self.mlp = SwiGLUMLP(config)
3750
- self.attn_residual_gate = nn.Parameter(1e-6*torch.randn([config.embedding_width]))
3751
- self.mlp_residual_gate = nn.Parameter(1e-6*torch.randn([config.embedding_width]))
 
 
 
 
 
3752
  def num_mosrah_parameters(self) -> int:
3753
  """Return the total number of trainable MoSRAH parameters in this decoder layer."""
3754
  return self.attention.num_mosrah_parameters()
@@ -3782,8 +3797,8 @@ class DecoderLayer(nn.Module):
3782
  active_mask=active_mask,
3783
  cache=cache,
3784
  )
3785
- hidden_states = x + self.attn_residual_gate*attn_out
3786
- output = hidden_states + self.mlp_residual_gate*self.mlp(self.mlp_norm(hidden_states))
3787
  return output, router_diagnostics
3788
 
3789
 
 
165
  use_cache: Whether to return past_key_values for KV caching.
166
  output_hidden_states: Whether to return hidden states after each layer.
167
  tie_word_embeddings: Whether input embedding and LM head share weights.
168
+ use_residual_gate: When True, each DecoderLayer gates its residual contributions
169
+ with a learnable scalar parameter (init: zero). When False, uses a fixed
170
+ ``1/√num_decoder_layers`` scale instead, which preserves O(1) residual
171
+ variance at depth with no learnable gate. Default True.
172
  """
173
 
174
  model_type = "shram"
 
201
  use_cache: bool = True,
202
  output_hidden_states: bool = False,
203
  tie_word_embeddings: bool = False,
204
+ use_residual_gate: bool = True,
205
  **kwargs
206
  ):
207
  if head_dim % 2 != 0:
 
258
  self.beta = beta
259
  self.attention_dropout = attention_dropout
260
  self.use_cache = use_cache
261
+ self.use_residual_gate = use_residual_gate
262
 
263
  super().__init__(
264
  tie_word_embeddings=tie_word_embeddings,
 
1742
 
1743
  normed_attn = RMSNorm(x)
1744
  attn_out, router_diagnostics = SHRAMHybridLayer(normed_attn, ...)
1745
+ h = x + attn_residual_scale * attn_out
1746
 
1747
  normed_mlp = RMSNorm(h)
1748
  mlp_out = SwiGLUMLP(normed_mlp)
1749
+ out = h + mlp_residual_scale * mlp_out
1750
 
1751
+ ``attn_residual_scale`` and ``mlp_residual_scale`` are always present. Their nature
1752
+ depends on ``config.use_residual_gate``:
1753
+
1754
+ - ``True`` (default): learnable scalar ``nn.Parameter`` initialised to zero. The layer
1755
+ is a pure identity at initialisation and the scales open during training.
1756
+ - ``False``: fixed buffer ``1/√num_decoder_layers``. No learnable parameter; residual
1757
+ variance sums to O(1) across depth by construction.
1758
 
1759
  Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
1760
  through unnormalised residuals at depth, and each sublayer receives a stable,
 
1771
 
1772
 
1773
 
1774
+
1775
+
1776
  # -----------
1777
  # Inlined from: shram.py
1778
  # -----------
 
3757
  self.mlp_norm = nn.RMSNorm(config.embedding_width, eps=config.rms_norm_eps)
3758
  self.attention = SHRAMHybridLayer(config)
3759
  self.mlp = SwiGLUMLP(config)
3760
+ scale = 1.0 / math.sqrt(config.num_decoder_layers)
3761
+ if config.use_residual_gate:
3762
+ self.attn_residual_scale = nn.Parameter(torch.zeros(1))
3763
+ self.mlp_residual_scale = nn.Parameter(torch.zeros(1))
3764
+ else:
3765
+ self.register_buffer("attn_residual_scale", torch.full((1,), scale))
3766
+ self.register_buffer("mlp_residual_scale", torch.full((1,), scale))
3767
  def num_mosrah_parameters(self) -> int:
3768
  """Return the total number of trainable MoSRAH parameters in this decoder layer."""
3769
  return self.attention.num_mosrah_parameters()
 
3797
  active_mask=active_mask,
3798
  cache=cache,
3799
  )
3800
+ hidden_states = x + self.attn_residual_scale * attn_out
3801
+ output = hidden_states + self.mlp_residual_scale * self.mlp(self.mlp_norm(hidden_states))
3802
  return output, router_diagnostics
3803
 
3804