KitsuVp committed on
Commit
299d86a
·
verified ·
1 Parent(s): 38a192b

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +94 -33
modeling_neollm.py CHANGED
@@ -1,13 +1,13 @@
1
- # ==================== modeling_neollm.py ====================
2
  #!/usr/bin/env python3
3
  """
4
  NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
5
- and ResFormer Value Residual Learning for enhanced information
6
- flow through deep layers.
7
 
8
  Updated to include:
9
  - Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
10
- - NEW: FAN layer in FFN for featural periodicity modeling (complementary coverage)
 
11
  - Dropout regularization at strategic locations
12
  - ResFormer: Feature residual connections from first layer (applied before projections)
13
  """
@@ -35,7 +35,7 @@ from transformers.utils.import_utils import (
35
  is_causal_conv1d_available,
36
  is_flash_linear_attention_available,
37
  )
38
- from .configuration_neollm import NeoLLMConfig
39
 
40
 
41
  if is_causal_conv1d_available():
@@ -153,7 +153,74 @@ class GPAS(nn.Module):
153
  return x_scaled
154
 
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  class NeoLLMRMSNormGated(nn.Module):
 
 
 
157
  def __init__(self, hidden_size, eps=1e-6, **kwargs):
158
  super().__init__()
159
  self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -207,25 +274,6 @@ class NeoLLMRotaryEmbedding(nn.Module):
207
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
208
 
209
 
210
- class NeoLLMRMSNorm(nn.Module):
211
- def __init__(self, dim: int, eps: float = 1e-6):
212
- super().__init__()
213
- self.eps = eps
214
- self.weight = nn.Parameter(torch.zeros(dim))
215
-
216
- def _norm(self, x):
217
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
218
-
219
- def forward(self, x):
220
- output = self._norm(x.float())
221
- # Llama does x.to(float16) * w whilst NeoLLM is (x * w).to(float16)
222
- output = output * (1.0 + self.weight.float())
223
- return output.type_as(x)
224
-
225
- def extra_repr(self):
226
- return f"{tuple(self.weight.shape)}, eps={self.eps}"
227
-
228
-
229
  def rotate_half(x):
230
  """Rotates half the hidden dims of the input."""
231
  x1 = x[..., : x.shape[-1] // 2]
@@ -293,7 +341,7 @@ def eager_attention_forward(
293
 
294
  class NeoLLMAttention(nn.Module):
295
  """
296
- Multi-headed attention with FANformer integration, Selective Self-Attention for periodicity modeling,
297
  and ResFormer feature residual connections for enhanced information flow.
298
 
299
  ResFormer enhancement: Applies learnable feature residual connections from the first layer
@@ -332,8 +380,10 @@ class NeoLLMAttention(nn.Module):
332
  self.o_proj = nn.Linear(
333
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
334
  )
335
- self.q_norm = NeoLLMRMSNorm(self.head_dim, eps=config.rms_norm_eps)
336
- self.k_norm = NeoLLMRMSNorm(self.head_dim, eps=config.rms_norm_eps)
 
 
337
 
338
  # Dropout for attention output
339
  self.dropout = nn.Dropout(config.dropout_rate)
@@ -371,6 +421,7 @@ class NeoLLMAttention(nn.Module):
371
  )
372
  gate = gate.reshape(*input_shape, -1)
373
 
 
374
  query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
375
  key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
376
  value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)
@@ -566,7 +617,7 @@ def torch_recurrent_gated_delta_rule(
566
 
567
  class NeoLLMGatedDeltaNet(nn.Module):
568
  """
569
- Linear attention with FANformer integration, Selective Self-Attention for periodicity modeling,
570
  and ResFormer feature residual connections for enhanced information flow.
571
 
572
  ResFormer enhancement: Applies learnable feature residual connections from the first layer
@@ -630,7 +681,7 @@ class NeoLLMGatedDeltaNet(nn.Module):
630
  else FusedRMSNormGated(
631
  self.head_v_dim,
632
  eps=self.layer_norm_epsilon,
633
- activation=fla_compatible_activation, # Use FLA-compatible activation
634
  device=torch.cuda.current_device(),
635
  dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
636
  )
@@ -849,8 +900,9 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
849
  # MLP with FANformer integration
850
  self.mlp = NeoLLMMLP(config)
851
 
852
- self.input_layernorm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
853
- self.post_attention_layernorm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
854
 
855
  # LNS (LayerNorm Scaling) - applies 1/√ℓ scaling
856
  self.lns_attn = LNS(layer_idx)
@@ -873,7 +925,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
873
  ) -> torch.FloatTensor:
874
  residual = hidden_states
875
 
876
- # Apply layer normalization
877
  hidden_states = self.input_layernorm(hidden_states)
878
 
879
  # Apply LNS scaling after normalization
@@ -952,6 +1004,12 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
952
  elif isinstance(module, FANLayer):
953
  # FANLayer initialization is handled within the class
954
  pass
 
 
 
 
 
 
955
 
956
 
957
  class NeoLLMModel(NeoLLMPreTrainedModel):
@@ -963,7 +1021,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
963
  self.layers = nn.ModuleList(
964
  [NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
965
  )
966
- self.norm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
967
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
968
  self.gradient_checkpointing = False
969
 
@@ -1023,6 +1082,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1023
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
1024
  self.first_layer_fan = decoder_layer.current_layer_fan
1025
 
 
1026
  hidden_states = self.norm(hidden_states)
1027
 
1028
  return BaseModelOutputWithPast(
@@ -1132,6 +1192,7 @@ __all__ = [
1132
  "NeoLLMPreTrainedModel",
1133
  "NeoLLMConfig",
1134
  "FANLayer",
 
1135
  ]
1136
 
1137
  # Register the configuration and model for AutoClass support
 
 
1
  #!/usr/bin/env python3
2
  """
3
  NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
4
+ SeeDNorm (Self-Rescaled Dynamic Normalization), and ResFormer Value Residual Learning
5
+ for enhanced information flow through deep layers.
6
 
7
  Updated to include:
8
  - Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
9
+ - FAN layer in FFN for featural periodicity modeling (complementary coverage)
10
+ - SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
11
  - Dropout regularization at strategic locations
12
  - ResFormer: Feature residual connections from first layer (applied before projections)
13
  """
 
35
  is_causal_conv1d_available,
36
  is_flash_linear_attention_available,
37
  )
38
+ from configuration_neollm import NeoLLMConfig
39
 
40
 
41
  if is_causal_conv1d_available():
 
153
  return x_scaled
154
 
155
 
156
class SeeDNorm(nn.Module):
    """
    Self-Rescaled Dynamic Normalization (SeeDNorm).

    From "SeeDNorm: Self-Rescaled Dynamic Normalization":

        SeeDNorm(x) = [sigma(x . beta^T) * alpha + gamma] * x / RMS(x)

    Dynamically adjusts the scaling coefficient based on the current input,
    preserving input-norm information and enabling data-dependent normalization.

    Learnable parameters:
        gamma: static scaling factor (RMSNorm-like), initialized to 1
        beta:  self-rescaling parameter, initialized to 0
        alpha: dynamic modulation parameter, initialized to 1
    sigma is tanh, constraining the dynamic scaling range to [-1, 1].

    With the default initialization (beta == 0) the module reduces exactly to
    RMSNorm with unit weight, so it is a drop-in replacement at init time.

    Args:
        dim: Hidden dimension size (last axis of the input).
        eps: Small constant for numerical stability in the RMS denominator.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps

        # Learnable parameters (see class docstring for roles/initialization).
        self.gamma = nn.Parameter(torch.ones(dim))   # static scaling (RMSNorm-like)
        self.beta = nn.Parameter(torch.zeros(dim))   # self-rescaling parameter
        self.alpha = nn.Parameter(torch.ones(dim))   # dynamic modulation parameter

    def _rms_norm(self, x: torch.Tensor) -> torch.Tensor:
        """Compute RMS normalization: x / RMS(x)."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply Self-Rescaled Dynamic Normalization.

        Args:
            x: Input tensor of shape (..., dim).

        Returns:
            Normalized and dynamically scaled tensor of the same shape and
            dtype as the input.
        """
        # Do the whole computation in float32: the sum over the hidden dim can
        # overflow / lose precision in fp16/bf16, and the RMS path already
        # upcasts, so the rescale factor must match for consistent numerics.
        x_f = x.float()

        # Input-dependent rescaling sigma(x . beta^T): a scalar per token via
        # a dot product over the hidden dimension, squashed by tanh.
        rescale_factor = torch.tanh(
            torch.sum(x_f * self.beta.float(), dim=-1, keepdim=True)
        )

        # Dynamic scaling coefficient: sigma(x . beta^T) * alpha + gamma.
        dynamic_scale = rescale_factor * self.alpha.float() + self.gamma.float()

        # RMS-normalize, apply the dynamic scale, and cast back to input dtype.
        output = self._rms_norm(x_f) * dynamic_scale
        return output.type_as(x)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, eps={self.eps}"
219
+
220
  class NeoLLMRMSNormGated(nn.Module):
221
+ """
222
+ Gated RMSNorm variant used in specific contexts.
223
+ """
224
  def __init__(self, hidden_size, eps=1e-6, **kwargs):
225
  super().__init__()
226
  self.weight = nn.Parameter(torch.ones(hidden_size))
 
274
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
275
 
276
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
  def rotate_half(x):
278
  """Rotates half the hidden dims of the input."""
279
  x1 = x[..., : x.shape[-1] // 2]
 
341
 
342
  class NeoLLMAttention(nn.Module):
343
  """
344
+ Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
345
  and ResFormer feature residual connections for enhanced information flow.
346
 
347
  ResFormer enhancement: Applies learnable feature residual connections from the first layer
 
380
  self.o_proj = nn.Linear(
381
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
382
  )
383
+
384
+ # SeeDNorm for Q/K normalization (replaces RMSNorm)
385
+ self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
386
+ self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
387
 
388
  # Dropout for attention output
389
  self.dropout = nn.Dropout(config.dropout_rate)
 
421
  )
422
  gate = gate.reshape(*input_shape, -1)
423
 
424
+ # Apply SeeDNorm to Q and K
425
  query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
426
  key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
427
  value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)
 
617
 
618
  class NeoLLMGatedDeltaNet(nn.Module):
619
  """
620
+ Linear attention with FANformer integration, SeeDNorm for normalization,
621
  and ResFormer feature residual connections for enhanced information flow.
622
 
623
  ResFormer enhancement: Applies learnable feature residual connections from the first layer
 
681
  else FusedRMSNormGated(
682
  self.head_v_dim,
683
  eps=self.layer_norm_epsilon,
684
+ activation=fla_compatible_activation,
685
  device=torch.cuda.current_device(),
686
  dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
687
  )
 
900
  # MLP with FANformer integration
901
  self.mlp = NeoLLMMLP(config)
902
 
903
+ # SeeDNorm for input and post-attention normalization (replaces RMSNorm)
904
+ self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
905
+ self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
906
 
907
  # LNS (LayerNorm Scaling) - applies 1/√ℓ scaling
908
  self.lns_attn = LNS(layer_idx)
 
925
  ) -> torch.FloatTensor:
926
  residual = hidden_states
927
 
928
+ # Apply SeeDNorm normalization
929
  hidden_states = self.input_layernorm(hidden_states)
930
 
931
  # Apply LNS scaling after normalization
 
1004
  elif isinstance(module, FANLayer):
1005
  # FANLayer initialization is handled within the class
1006
  pass
1007
+ elif isinstance(module, SeeDNorm):
1008
+ # SeeDNorm initialization:
1009
+ # gamma (γ) initialized to 1 (default in Parameter definition)
1010
+ # beta (β) initialized to 0 (default in Parameter definition)
1011
+ # alpha (α) initialized to 1 (default in Parameter definition)
1012
+ pass
1013
 
1014
 
1015
  class NeoLLMModel(NeoLLMPreTrainedModel):
 
1021
  self.layers = nn.ModuleList(
1022
  [NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1023
  )
1024
+ # SeeDNorm for final output normalization (replaces RMSNorm)
1025
+ self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
1026
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
1027
  self.gradient_checkpointing = False
1028
 
 
1082
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
1083
  self.first_layer_fan = decoder_layer.current_layer_fan
1084
 
1085
+ # Apply SeeDNorm for final normalization
1086
  hidden_states = self.norm(hidden_states)
1087
 
1088
  return BaseModelOutputWithPast(
 
1192
  "NeoLLMPreTrainedModel",
1193
  "NeoLLMConfig",
1194
  "FANLayer",
1195
+ "SeeDNorm",
1196
  ]
1197
 
1198
  # Register the configuration and model for AutoClass support