KitsuVp committed on
Commit
965d461
verified
1 Parent(s): e8269b7

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +97 -16
modeling_neollm.py CHANGED
@@ -1,8 +1,12 @@
1
  #!/usr/bin/env python3
2
  """
3
- NeoLLM Model with FANformer Integration, Dropout Regularization, and Selective Self-Attention (SSA)
4
- Updated to include Fourier Analysis Network (FAN) layer for effective periodicity modeling,
5
- dropout regularization at strategic locations
 
 
 
 
6
  """
7
 
8
  import math
@@ -28,7 +32,7 @@ from transformers.utils.import_utils import (
28
  is_causal_conv1d_available,
29
  is_flash_linear_attention_available,
30
  )
31
- from .configuration_neollm import NeoLLMConfig
32
 
33
 
34
  if is_causal_conv1d_available():
@@ -45,6 +49,8 @@ else:
45
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
46
 
47
  logger = logging.get_logger(__name__)
 
 
48
  class FANLayer(nn.Module):
49
  """
50
  Fourier Analysis Network (FAN) layer for effective periodicity modeling.
@@ -283,7 +289,13 @@ def eager_attention_forward(
283
 
284
 
285
  class NeoLLMAttention(nn.Module):
286
- """Multi-headed attention with FANformer integration and Selective Self-Attention for periodicity modeling"""
 
 
 
 
 
 
287
 
288
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
289
  super().__init__()
@@ -322,22 +334,35 @@ class NeoLLMAttention(nn.Module):
322
 
323
  # Dropout for attention output
324
  self.dropout = nn.Dropout(config.dropout_rate)
 
 
 
 
325
 
326
  def forward(
327
  self,
328
  hidden_states: torch.Tensor,
329
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
330
  attention_mask: Optional[torch.Tensor],
 
331
  **kwargs: Unpack[FlashAttentionKwargs],
332
- ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
333
  input_shape = hidden_states.shape[:-1]
334
 
335
  # Apply FANformer transformation first
336
  hidden_states_fan = self.fan_layer(hidden_states)
337
 
 
 
 
 
 
 
 
 
338
  hidden_shape = (*input_shape, -1, self.head_dim)
339
 
340
- # Use FAN-transformed features directly for projections
341
  query_states, gate = torch.chunk(
342
  self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
343
  )
@@ -369,8 +394,9 @@ class NeoLLMAttention(nn.Module):
369
  attn_output = attn_output * torch.sigmoid(gate)
370
 
371
  attn_output = self.o_proj(attn_output)
372
- attn_output = self.dropout(attn_output) # Apply dropout after output projection
373
- return attn_output, attn_weights
 
374
 
375
 
376
  def apply_mask_to_padding_states(hidden_states, attention_mask):
@@ -534,8 +560,15 @@ def torch_recurrent_gated_delta_rule(
534
  core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
535
  return core_attn_out, last_recurrent_state
536
 
 
537
  class NeoLLMGatedDeltaNet(nn.Module):
538
- """Linear attention with FANformer integration and Selective Self-Attention for periodicity modeling"""
 
 
 
 
 
 
539
 
540
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
541
  super().__init__()
@@ -610,6 +643,10 @@ class NeoLLMGatedDeltaNet(nn.Module):
610
  self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
611
  self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
612
 
 
 
 
 
613
  if not is_fast_path_available:
614
  logger.warning_once(
615
  "The fast path is not available because one of the required library is not installed. Falling back to "
@@ -649,7 +686,8 @@ class NeoLLMGatedDeltaNet(nn.Module):
649
  self,
650
  hidden_states: torch.Tensor,
651
  attention_mask: Optional[torch.Tensor] = None,
652
- ):
 
653
  hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
654
 
655
  # Set up dimensions for reshapes later
@@ -658,7 +696,15 @@ class NeoLLMGatedDeltaNet(nn.Module):
658
  # Apply FANformer transformation first
659
  hidden_states_fan = self.fan_layer(hidden_states)
660
 
661
- # Use FAN-transformed features directly for projections
 
 
 
 
 
 
 
 
662
  projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
663
  projected_states_ba = self.in_proj_ba(hidden_states_fan)
664
  query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
@@ -722,7 +768,9 @@ class NeoLLMGatedDeltaNet(nn.Module):
722
 
723
  output = self.out_proj(core_attn_out)
724
  output = self.dropout(output) # Apply dropout after output projection
725
- return output
 
 
726
 
727
  class PolyNorm(torch.nn.Module):
728
  def __init__(self, eps=1e-6):
@@ -737,6 +785,7 @@ class PolyNorm(torch.nn.Module):
737
  def forward(self, x):
738
  return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
739
 
 
740
  class NeoLLMMLP(nn.Module):
741
  def __init__(self, config):
742
  super().__init__()
@@ -760,6 +809,7 @@ class NeoLLMMLP(nn.Module):
760
  hidden = self.dropout(hidden)
761
  return self.down_proj(hidden)
762
 
 
763
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
764
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
765
  super().__init__()
@@ -786,12 +836,16 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
786
  # GPAS (Gradient-Preserving Activation Scaling) - applied after residual connections
787
  self.gpas_attn = GPAS(config.hidden_size)
788
  self.gpas_mlp = GPAS(config.hidden_size)
 
 
 
789
 
790
  def forward(
791
  self,
792
  hidden_states: torch.Tensor,
793
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
794
  attention_mask: Optional[torch.Tensor] = None,
 
795
  **kwargs: Unpack[FlashAttentionKwargs],
796
  ) -> torch.FloatTensor:
797
  residual = hidden_states
@@ -802,18 +856,20 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
802
  # Apply LNS scaling after normalization
803
  hidden_states = self.lns_attn(hidden_states)
804
 
805
- # Token Mixer
806
  if self.layer_type == "linear_attention":
807
- hidden_states = self.linear_attn(
808
  hidden_states=hidden_states,
809
  attention_mask=attention_mask,
 
810
  )
811
  elif self.layer_type == "full_attention":
812
  # Self Attention
813
- hidden_states, _ = self.self_attn(
814
  hidden_states=hidden_states,
815
  attention_mask=attention_mask,
816
  position_embeddings=position_embeddings,
 
817
  **kwargs,
818
  )
819
 
@@ -855,6 +911,17 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
855
  if isinstance(module, NeoLLMGatedDeltaNet):
856
  module.dt_bias.data.fill_(1.0)
857
  module.A_log.data.uniform_(0, 16).log_()
 
 
 
 
 
 
 
 
 
 
 
858
  elif isinstance(module, GPAS):
859
  # Initialize GPAS alpha to 0 as per paper
860
  module.alpha.data.fill_(0.0)
@@ -875,6 +942,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
875
  self.norm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
876
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
877
  self.gradient_checkpointing = False
 
 
 
 
878
  # Initialize weights and apply final processing
879
  self.post_init()
880
 
@@ -910,6 +981,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
910
  # create position embeddings to be shared across the decoder layers
911
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
912
 
 
 
 
913
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
914
  layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
915
 
@@ -917,8 +991,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
917
  hidden_states,
918
  position_embeddings=position_embeddings,
919
  attention_mask=layer_mask,
 
920
  **kwargs,
921
  )
 
 
 
 
922
 
923
  hidden_states = self.norm(hidden_states)
924
 
@@ -937,6 +1016,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
937
  linear_attn_mask = None
938
  return linear_attn_mask
939
 
 
940
  @torch.compiler.disable
941
  def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
942
  """
@@ -1019,6 +1099,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1019
  attentions=outputs.attentions,
1020
  )
1021
 
 
1022
  # ==================== AUTOMODEL REGISTRATION ====================
1023
 
1024
  __all__ = [
 
1
  #!/usr/bin/env python3
2
  """
3
+ NeoLLM Model with FANformer Integration, Dropout Regularization, Selective Self-Attention (SSA),
4
+ and ResFormer Value Residual Learning for enhanced information flow through deep layers.
5
+
6
+ Updated to include:
7
+ - Fourier Analysis Network (FAN) layer for effective periodicity modeling
8
+ - Dropout regularization at strategic locations
9
+ - ResFormer: Feature residual connections from first layer (applied before projections)
10
  """
11
 
12
  import math
 
32
  is_causal_conv1d_available,
33
  is_flash_linear_attention_available,
34
  )
35
+ from configuration_neollm import NeoLLMConfig
36
 
37
 
38
  if is_causal_conv1d_available():
 
49
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
50
 
51
  logger = logging.get_logger(__name__)
52
+
53
+
54
  class FANLayer(nn.Module):
55
  """
56
  Fourier Analysis Network (FAN) layer for effective periodicity modeling.
 
289
 
290
 
291
  class NeoLLMAttention(nn.Module):
292
+ """
293
+ Multi-headed attention with FANformer integration, Selective Self-Attention for periodicity modeling,
294
+ and ResFormer feature residual connections for enhanced information flow.
295
+
296
+ ResFormer enhancement: Applies learnable feature residual connections from the first layer
297
+ BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
298
+ """
299
 
300
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
301
  super().__init__()
 
334
 
335
  # Dropout for attention output
336
  self.dropout = nn.Dropout(config.dropout_rate)
337
+
338
+ # ResFormer: learnable feature residual parameters (initialized to 0.5)
339
+ self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 (first layer features)
340
+ self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n (current layer features)
341
 
342
  def forward(
343
  self,
344
  hidden_states: torch.Tensor,
345
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
346
  attention_mask: Optional[torch.Tensor],
347
+ first_layer_fan: Optional[torch.Tensor] = None,
348
  **kwargs: Unpack[FlashAttentionKwargs],
349
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
350
  input_shape = hidden_states.shape[:-1]
351
 
352
  # Apply FANformer transformation first
353
  hidden_states_fan = self.fan_layer(hidden_states)
354
 
355
+ # ResFormer: Apply feature residual connection BEFORE projections
356
+ # This ensures dimensional compatibility across all layer types
357
+ if first_layer_fan is not None:
358
+ hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
359
+
360
+ # Store current FAN features for potential use as first_layer_fan in subsequent layers
361
+ current_layer_fan = hidden_states_fan.clone()
362
+
363
  hidden_shape = (*input_shape, -1, self.head_dim)
364
 
365
+ # Use FAN-transformed features (with residual applied) for projections
366
  query_states, gate = torch.chunk(
367
  self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
368
  )
 
394
  attn_output = attn_output * torch.sigmoid(gate)
395
 
396
  attn_output = self.o_proj(attn_output)
397
+ attn_output = self.dropout(attn_output)
398
+
399
+ return attn_output, attn_weights, current_layer_fan
400
 
401
 
402
  def apply_mask_to_padding_states(hidden_states, attention_mask):
 
560
  core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
561
  return core_attn_out, last_recurrent_state
562
 
563
+
564
  class NeoLLMGatedDeltaNet(nn.Module):
565
+ """
566
+ Linear attention with FANformer integration, Selective Self-Attention for periodicity modeling,
567
+ and ResFormer feature residual connections for enhanced information flow.
568
+
569
+ ResFormer enhancement: Applies learnable feature residual connections from the first layer
570
+ BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
571
+ """
572
 
573
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
574
  super().__init__()
 
643
  self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
644
  self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
645
 
646
+ # ResFormer: learnable feature residual parameters (initialized to 0.5)
647
+ self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 (first layer features)
648
+ self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n (current layer features)
649
+
650
  if not is_fast_path_available:
651
  logger.warning_once(
652
  "The fast path is not available because one of the required library is not installed. Falling back to "
 
686
  self,
687
  hidden_states: torch.Tensor,
688
  attention_mask: Optional[torch.Tensor] = None,
689
+ first_layer_fan: Optional[torch.Tensor] = None,
690
+ ) -> tuple[torch.Tensor, torch.Tensor]:
691
  hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
692
 
693
  # Set up dimensions for reshapes later
 
696
  # Apply FANformer transformation first
697
  hidden_states_fan = self.fan_layer(hidden_states)
698
 
699
+ # ResFormer: Apply feature residual connection BEFORE projections
700
+ # This ensures dimensional compatibility across all layer types
701
+ if first_layer_fan is not None:
702
+ hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
703
+
704
+ # Store current FAN features for potential use as first_layer_fan in subsequent layers
705
+ current_layer_fan = hidden_states_fan.clone()
706
+
707
+ # Use FAN-transformed features (with residual applied) for projections
708
  projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
709
  projected_states_ba = self.in_proj_ba(hidden_states_fan)
710
  query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
 
768
 
769
  output = self.out_proj(core_attn_out)
770
  output = self.dropout(output) # Apply dropout after output projection
771
+
772
+ return output, current_layer_fan
773
+
774
 
775
  class PolyNorm(torch.nn.Module):
776
  def __init__(self, eps=1e-6):
 
785
  def forward(self, x):
786
  return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
787
 
788
+
789
  class NeoLLMMLP(nn.Module):
790
  def __init__(self, config):
791
  super().__init__()
 
809
  hidden = self.dropout(hidden)
810
  return self.down_proj(hidden)
811
 
812
+
813
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
814
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
815
  super().__init__()
 
836
  # GPAS (Gradient-Preserving Activation Scaling) - applied after residual connections
837
  self.gpas_attn = GPAS(config.hidden_size)
838
  self.gpas_mlp = GPAS(config.hidden_size)
839
+
840
+ # ResFormer: storage for current layer's FAN features
841
+ self.current_layer_fan = None
842
 
843
  def forward(
844
  self,
845
  hidden_states: torch.Tensor,
846
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
847
  attention_mask: Optional[torch.Tensor] = None,
848
+ first_layer_fan: Optional[torch.Tensor] = None,
849
  **kwargs: Unpack[FlashAttentionKwargs],
850
  ) -> torch.FloatTensor:
851
  residual = hidden_states
 
856
  # Apply LNS scaling after normalization
857
  hidden_states = self.lns_attn(hidden_states)
858
 
859
+ # Token Mixer with ResFormer feature residual connections
860
  if self.layer_type == "linear_attention":
861
+ hidden_states, self.current_layer_fan = self.linear_attn(
862
  hidden_states=hidden_states,
863
  attention_mask=attention_mask,
864
+ first_layer_fan=first_layer_fan,
865
  )
866
  elif self.layer_type == "full_attention":
867
  # Self Attention
868
+ hidden_states, _, self.current_layer_fan = self.self_attn(
869
  hidden_states=hidden_states,
870
  attention_mask=attention_mask,
871
  position_embeddings=position_embeddings,
872
+ first_layer_fan=first_layer_fan,
873
  **kwargs,
874
  )
875
 
 
911
  if isinstance(module, NeoLLMGatedDeltaNet):
912
  module.dt_bias.data.fill_(1.0)
913
  module.A_log.data.uniform_(0, 16).log_()
914
+ # ResFormer: initialize lambda parameters for linear attention
915
+ if hasattr(module, 'lambda_1'):
916
+ module.lambda_1.data.fill_(0.5)
917
+ if hasattr(module, 'lambda_2'):
918
+ module.lambda_2.data.fill_(0.5)
919
+ elif isinstance(module, NeoLLMAttention):
920
+ # ResFormer: initialize lambda parameters for full attention
921
+ if hasattr(module, 'lambda_1'):
922
+ module.lambda_1.data.fill_(0.5)
923
+ if hasattr(module, 'lambda_2'):
924
+ module.lambda_2.data.fill_(0.5)
925
  elif isinstance(module, GPAS):
926
  # Initialize GPAS alpha to 0 as per paper
927
  module.alpha.data.fill_(0.0)
 
942
  self.norm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
943
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
944
  self.gradient_checkpointing = False
945
+
946
+ # ResFormer: storage for first layer's FAN features (H_fan_1)
947
+ self.first_layer_fan = None
948
+
949
  # Initialize weights and apply final processing
950
  self.post_init()
951
 
 
981
  # create position embeddings to be shared across the decoder layers
982
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
983
 
984
+ # ResFormer: reset first_layer_fan at the start of each forward pass
985
+ self.first_layer_fan = None
986
+
987
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
988
  layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
989
 
 
991
  hidden_states,
992
  position_embeddings=position_embeddings,
993
  attention_mask=layer_mask,
994
+ first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
995
  **kwargs,
996
  )
997
+
998
+ # ResFormer: capture H_fan_1 from the first layer
999
+ if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
1000
+ self.first_layer_fan = decoder_layer.current_layer_fan
1001
 
1002
  hidden_states = self.norm(hidden_states)
1003
 
 
1016
  linear_attn_mask = None
1017
  return linear_attn_mask
1018
 
1019
+
1020
  @torch.compiler.disable
1021
  def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
1022
  """
 
1099
  attentions=outputs.attentions,
1100
  )
1101
 
1102
+
1103
  # ==================== AUTOMODEL REGISTRATION ====================
1104
 
1105
  __all__ = [