KitsuVp committed on
Commit
7f7b8f1
·
verified ·
1 Parent(s): f9b2ab8

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +16 -97
modeling_neollm.py CHANGED
@@ -1,12 +1,8 @@
1
  #!/usr/bin/env python3
2
  """
3
- NeoLLM Model with FANformer Integration, Dropout Regularization, Selective Self-Attention (SSA),
4
- and ResFormer Value Residual Learning for enhanced information flow through deep layers.
5
-
6
- Updated to include:
7
- - Fourier Analysis Network (FAN) layer for effective periodicity modeling
8
- - Dropout regularization at strategic locations
9
- - ResFormer: Feature residual connections from first layer (applied before projections)
10
  """
11
 
12
  import math
@@ -32,7 +28,7 @@ from transformers.utils.import_utils import (
32
  is_causal_conv1d_available,
33
  is_flash_linear_attention_available,
34
  )
35
- from .configuration_neollm import NeoLLMConfig
36
 
37
 
38
  if is_causal_conv1d_available():
@@ -49,8 +45,6 @@ else:
49
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
50
 
51
  logger = logging.get_logger(__name__)
52
-
53
-
54
  class FANLayer(nn.Module):
55
  """
56
  Fourier Analysis Network (FAN) layer for effective periodicity modeling.
@@ -289,13 +283,7 @@ def eager_attention_forward(
289
 
290
 
291
  class NeoLLMAttention(nn.Module):
292
- """
293
- Multi-headed attention with FANformer integration, Selective Self-Attention for periodicity modeling,
294
- and ResFormer feature residual connections for enhanced information flow.
295
-
296
- ResFormer enhancement: Applies learnable feature residual connections from the first layer
297
- BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
298
- """
299
 
300
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
301
  super().__init__()
@@ -334,35 +322,22 @@ class NeoLLMAttention(nn.Module):
334
 
335
  # Dropout for attention output
336
  self.dropout = nn.Dropout(config.dropout_rate)
337
-
338
- # ResFormer: learnable feature residual parameters (initialized to 0.5)
339
- self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 (first layer features)
340
- self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n (current layer features)
341
 
342
  def forward(
343
  self,
344
  hidden_states: torch.Tensor,
345
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
346
  attention_mask: Optional[torch.Tensor],
347
- first_layer_fan: Optional[torch.Tensor] = None,
348
  **kwargs: Unpack[FlashAttentionKwargs],
349
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
350
  input_shape = hidden_states.shape[:-1]
351
 
352
  # Apply FANformer transformation first
353
  hidden_states_fan = self.fan_layer(hidden_states)
354
 
355
- # ResFormer: Apply feature residual connection BEFORE projections
356
- # This ensures dimensional compatibility across all layer types
357
- if first_layer_fan is not None:
358
- hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
359
-
360
- # Store current FAN features for potential use as first_layer_fan in subsequent layers
361
- current_layer_fan = hidden_states_fan.clone()
362
-
363
  hidden_shape = (*input_shape, -1, self.head_dim)
364
 
365
- # Use FAN-transformed features (with residual applied) for projections
366
  query_states, gate = torch.chunk(
367
  self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
368
  )
@@ -394,9 +369,8 @@ class NeoLLMAttention(nn.Module):
394
  attn_output = attn_output * torch.sigmoid(gate)
395
 
396
  attn_output = self.o_proj(attn_output)
397
- attn_output = self.dropout(attn_output)
398
-
399
- return attn_output, attn_weights, current_layer_fan
400
 
401
 
402
  def apply_mask_to_padding_states(hidden_states, attention_mask):
@@ -560,15 +534,8 @@ def torch_recurrent_gated_delta_rule(
560
  core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
561
  return core_attn_out, last_recurrent_state
562
 
563
-
564
  class NeoLLMGatedDeltaNet(nn.Module):
565
- """
566
- Linear attention with FANformer integration, Selective Self-Attention for periodicity modeling,
567
- and ResFormer feature residual connections for enhanced information flow.
568
-
569
- ResFormer enhancement: Applies learnable feature residual connections from the first layer
570
- BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
571
- """
572
 
573
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
574
  super().__init__()
@@ -643,10 +610,6 @@ class NeoLLMGatedDeltaNet(nn.Module):
643
  self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
644
  self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
645
 
646
- # ResFormer: learnable feature residual parameters (initialized to 0.5)
647
- self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 (first layer features)
648
- self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n (current layer features)
649
-
650
  if not is_fast_path_available:
651
  logger.warning_once(
652
  "The fast path is not available because one of the required library is not installed. Falling back to "
@@ -686,8 +649,7 @@ class NeoLLMGatedDeltaNet(nn.Module):
686
  self,
687
  hidden_states: torch.Tensor,
688
  attention_mask: Optional[torch.Tensor] = None,
689
- first_layer_fan: Optional[torch.Tensor] = None,
690
- ) -> tuple[torch.Tensor, torch.Tensor]:
691
  hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
692
 
693
  # Set up dimensions for reshapes later
@@ -696,15 +658,7 @@ class NeoLLMGatedDeltaNet(nn.Module):
696
  # Apply FANformer transformation first
697
  hidden_states_fan = self.fan_layer(hidden_states)
698
 
699
- # ResFormer: Apply feature residual connection BEFORE projections
700
- # This ensures dimensional compatibility across all layer types
701
- if first_layer_fan is not None:
702
- hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
703
-
704
- # Store current FAN features for potential use as first_layer_fan in subsequent layers
705
- current_layer_fan = hidden_states_fan.clone()
706
-
707
- # Use FAN-transformed features (with residual applied) for projections
708
  projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
709
  projected_states_ba = self.in_proj_ba(hidden_states_fan)
710
  query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
@@ -768,9 +722,7 @@ class NeoLLMGatedDeltaNet(nn.Module):
768
 
769
  output = self.out_proj(core_attn_out)
770
  output = self.dropout(output) # Apply dropout after output projection
771
-
772
- return output, current_layer_fan
773
-
774
 
775
  class PolyNorm(torch.nn.Module):
776
  def __init__(self, eps=1e-6):
@@ -785,7 +737,6 @@ class PolyNorm(torch.nn.Module):
785
  def forward(self, x):
786
  return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
787
 
788
-
789
  class NeoLLMMLP(nn.Module):
790
  def __init__(self, config):
791
  super().__init__()
@@ -809,7 +760,6 @@ class NeoLLMMLP(nn.Module):
809
  hidden = self.dropout(hidden)
810
  return self.down_proj(hidden)
811
 
812
-
813
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
814
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
815
  super().__init__()
@@ -836,16 +786,12 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
836
  # GPAS (Gradient-Preserving Activation Scaling) - applied after residual connections
837
  self.gpas_attn = GPAS(config.hidden_size)
838
  self.gpas_mlp = GPAS(config.hidden_size)
839
-
840
- # ResFormer: storage for current layer's FAN features
841
- self.current_layer_fan = None
842
 
843
  def forward(
844
  self,
845
  hidden_states: torch.Tensor,
846
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
847
  attention_mask: Optional[torch.Tensor] = None,
848
- first_layer_fan: Optional[torch.Tensor] = None,
849
  **kwargs: Unpack[FlashAttentionKwargs],
850
  ) -> torch.FloatTensor:
851
  residual = hidden_states
@@ -856,20 +802,18 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
856
  # Apply LNS scaling after normalization
857
  hidden_states = self.lns_attn(hidden_states)
858
 
859
- # Token Mixer with ResFormer feature residual connections
860
  if self.layer_type == "linear_attention":
861
- hidden_states, self.current_layer_fan = self.linear_attn(
862
  hidden_states=hidden_states,
863
  attention_mask=attention_mask,
864
- first_layer_fan=first_layer_fan,
865
  )
866
  elif self.layer_type == "full_attention":
867
  # Self Attention
868
- hidden_states, _, self.current_layer_fan = self.self_attn(
869
  hidden_states=hidden_states,
870
  attention_mask=attention_mask,
871
  position_embeddings=position_embeddings,
872
- first_layer_fan=first_layer_fan,
873
  **kwargs,
874
  )
875
 
@@ -911,17 +855,6 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
911
  if isinstance(module, NeoLLMGatedDeltaNet):
912
  module.dt_bias.data.fill_(1.0)
913
  module.A_log.data.uniform_(0, 16).log_()
914
- # ResFormer: initialize lambda parameters for linear attention
915
- if hasattr(module, 'lambda_1'):
916
- module.lambda_1.data.fill_(0.5)
917
- if hasattr(module, 'lambda_2'):
918
- module.lambda_2.data.fill_(0.5)
919
- elif isinstance(module, NeoLLMAttention):
920
- # ResFormer: initialize lambda parameters for full attention
921
- if hasattr(module, 'lambda_1'):
922
- module.lambda_1.data.fill_(0.5)
923
- if hasattr(module, 'lambda_2'):
924
- module.lambda_2.data.fill_(0.5)
925
  elif isinstance(module, GPAS):
926
  # Initialize GPAS alpha to 0 as per paper
927
  module.alpha.data.fill_(0.0)
@@ -942,10 +875,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
942
  self.norm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
943
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
944
  self.gradient_checkpointing = False
945
-
946
- # ResFormer: storage for first layer's FAN features (H_fan_1)
947
- self.first_layer_fan = None
948
-
949
  # Initialize weights and apply final processing
950
  self.post_init()
951
 
@@ -981,9 +910,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
981
  # create position embeddings to be shared across the decoder layers
982
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
983
 
984
- # ResFormer: reset first_layer_fan at the start of each forward pass
985
- self.first_layer_fan = None
986
-
987
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
988
  layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
989
 
@@ -991,13 +917,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
991
  hidden_states,
992
  position_embeddings=position_embeddings,
993
  attention_mask=layer_mask,
994
- first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
995
  **kwargs,
996
  )
997
-
998
- # ResFormer: capture H_fan_1 from the first layer
999
- if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
1000
- self.first_layer_fan = decoder_layer.current_layer_fan
1001
 
1002
  hidden_states = self.norm(hidden_states)
1003
 
@@ -1016,7 +937,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1016
  linear_attn_mask = None
1017
  return linear_attn_mask
1018
 
1019
-
1020
  @torch.compiler.disable
1021
  def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
1022
  """
@@ -1099,7 +1019,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1099
  attentions=outputs.attentions,
1100
  )
1101
 
1102
-
1103
  # ==================== AUTOMODEL REGISTRATION ====================
1104
 
1105
  __all__ = [
 
1
  #!/usr/bin/env python3
2
  """
3
+ NeoLLM Model with FANformer Integration, Dropout Regularization, and Selective Self-Attention (SSA)
4
+ Updated to include Fourier Analysis Network (FAN) layer for effective periodicity modeling,
5
+ and dropout regularization at strategic locations.
 
 
 
 
6
  """
7
 
8
  import math
 
28
  is_causal_conv1d_available,
29
  is_flash_linear_attention_available,
30
  )
31
+ from configuration_neollm import NeoLLMConfig
32
 
33
 
34
  if is_causal_conv1d_available():
 
45
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
46
 
47
  logger = logging.get_logger(__name__)
 
 
48
  class FANLayer(nn.Module):
49
  """
50
  Fourier Analysis Network (FAN) layer for effective periodicity modeling.
 
283
 
284
 
285
  class NeoLLMAttention(nn.Module):
286
+ """Multi-headed attention with FANformer integration and Selective Self-Attention for periodicity modeling"""
 
 
 
 
 
 
287
 
288
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
289
  super().__init__()
 
322
 
323
  # Dropout for attention output
324
  self.dropout = nn.Dropout(config.dropout_rate)
 
 
 
 
325
 
326
  def forward(
327
  self,
328
  hidden_states: torch.Tensor,
329
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
330
  attention_mask: Optional[torch.Tensor],
 
331
  **kwargs: Unpack[FlashAttentionKwargs],
332
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
333
  input_shape = hidden_states.shape[:-1]
334
 
335
  # Apply FANformer transformation first
336
  hidden_states_fan = self.fan_layer(hidden_states)
337
 
 
 
 
 
 
 
 
 
338
  hidden_shape = (*input_shape, -1, self.head_dim)
339
 
340
+ # Use FAN-transformed features directly for projections
341
  query_states, gate = torch.chunk(
342
  self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
343
  )
 
369
  attn_output = attn_output * torch.sigmoid(gate)
370
 
371
  attn_output = self.o_proj(attn_output)
372
+ attn_output = self.dropout(attn_output) # Apply dropout after output projection
373
+ return attn_output, attn_weights
 
374
 
375
 
376
  def apply_mask_to_padding_states(hidden_states, attention_mask):
 
534
  core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
535
  return core_attn_out, last_recurrent_state
536
 
 
537
  class NeoLLMGatedDeltaNet(nn.Module):
538
+ """Linear attention with FANformer integration and Selective Self-Attention for periodicity modeling"""
 
 
 
 
 
 
539
 
540
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
541
  super().__init__()
 
610
  self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
611
  self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
612
 
 
 
 
 
613
  if not is_fast_path_available:
614
  logger.warning_once(
615
  "The fast path is not available because one of the required library is not installed. Falling back to "
 
649
  self,
650
  hidden_states: torch.Tensor,
651
  attention_mask: Optional[torch.Tensor] = None,
652
+ ):
 
653
  hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
654
 
655
  # Set up dimensions for reshapes later
 
658
  # Apply FANformer transformation first
659
  hidden_states_fan = self.fan_layer(hidden_states)
660
 
661
+ # Use FAN-transformed features directly for projections
 
 
 
 
 
 
 
 
662
  projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
663
  projected_states_ba = self.in_proj_ba(hidden_states_fan)
664
  query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
 
722
 
723
  output = self.out_proj(core_attn_out)
724
  output = self.dropout(output) # Apply dropout after output projection
725
+ return output
 
 
726
 
727
  class PolyNorm(torch.nn.Module):
728
  def __init__(self, eps=1e-6):
 
737
  def forward(self, x):
738
  return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
739
 
 
740
  class NeoLLMMLP(nn.Module):
741
  def __init__(self, config):
742
  super().__init__()
 
760
  hidden = self.dropout(hidden)
761
  return self.down_proj(hidden)
762
 
 
763
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
764
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
765
  super().__init__()
 
786
  # GPAS (Gradient-Preserving Activation Scaling) - applied after residual connections
787
  self.gpas_attn = GPAS(config.hidden_size)
788
  self.gpas_mlp = GPAS(config.hidden_size)
 
 
 
789
 
790
  def forward(
791
  self,
792
  hidden_states: torch.Tensor,
793
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
794
  attention_mask: Optional[torch.Tensor] = None,
 
795
  **kwargs: Unpack[FlashAttentionKwargs],
796
  ) -> torch.FloatTensor:
797
  residual = hidden_states
 
802
  # Apply LNS scaling after normalization
803
  hidden_states = self.lns_attn(hidden_states)
804
 
805
+ # Token Mixer
806
  if self.layer_type == "linear_attention":
807
+ hidden_states = self.linear_attn(
808
  hidden_states=hidden_states,
809
  attention_mask=attention_mask,
 
810
  )
811
  elif self.layer_type == "full_attention":
812
  # Self Attention
813
+ hidden_states, _ = self.self_attn(
814
  hidden_states=hidden_states,
815
  attention_mask=attention_mask,
816
  position_embeddings=position_embeddings,
 
817
  **kwargs,
818
  )
819
 
 
855
  if isinstance(module, NeoLLMGatedDeltaNet):
856
  module.dt_bias.data.fill_(1.0)
857
  module.A_log.data.uniform_(0, 16).log_()
 
 
 
 
 
 
 
 
 
 
 
858
  elif isinstance(module, GPAS):
859
  # Initialize GPAS alpha to 0 as per paper
860
  module.alpha.data.fill_(0.0)
 
875
  self.norm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
876
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
877
  self.gradient_checkpointing = False
 
 
 
 
878
  # Initialize weights and apply final processing
879
  self.post_init()
880
 
 
910
  # create position embeddings to be shared across the decoder layers
911
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
912
 
 
 
 
913
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
914
  layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
915
 
 
917
  hidden_states,
918
  position_embeddings=position_embeddings,
919
  attention_mask=layer_mask,
 
920
  **kwargs,
921
  )
 
 
 
 
922
 
923
  hidden_states = self.norm(hidden_states)
924
 
 
937
  linear_attn_mask = None
938
  return linear_attn_mask
939
 
 
940
  @torch.compiler.disable
941
  def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
942
  """
 
1019
  attentions=outputs.attentions,
1020
  )
1021
 
 
1022
  # ==================== AUTOMODEL REGISTRATION ====================
1023
 
1024
  __all__ = [