KitsuVp committed on
Commit
13c7457
·
verified ·
1 Parent(s): a4f976b

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +25 -6
modeling_neollm.py CHANGED
@@ -631,7 +631,6 @@ class NeoLLMMLP(nn.Module):
631
  hidden = self.dropout(hidden)
632
  return self.down_proj(hidden)
633
 
634
-
635
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
636
  """
637
  Decoder layer with standard residual connections.
@@ -677,8 +676,9 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
677
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
678
  attention_mask: Optional[torch.Tensor] = None,
679
  first_layer_fan: Optional[torch.Tensor] = None,
 
680
  **kwargs: Unpack[FlashAttentionKwargs],
681
- ) -> torch.FloatTensor:
682
  # ============================================================
683
  # Attention Block with standard residual connection
684
  # ============================================================
@@ -691,7 +691,8 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
691
  hidden_states = self.lns_attn(hidden_states)
692
 
693
  # Self Attention with ResFormer feature residual connections and learnable multipliers
694
- hidden_states, _, self.current_layer_fan = self.self_attn(
 
695
  hidden_states=hidden_states,
696
  attention_mask=attention_mask,
697
  position_embeddings=position_embeddings,
@@ -723,7 +724,11 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
723
  # Apply GPAS after MLP residual connection
724
  hidden_states = self.gpas_mlp(hidden_states)
725
 
726
- return hidden_states
 
 
 
 
727
 
728
 
729
  class NeoLLMPreTrainedModel(PreTrainedModel):
@@ -788,6 +793,7 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
788
  # scale adaptations from data without initial bias
789
  if hasattr(module, 'multiplier'):
790
  module.multiplier.data.fill_(1.0)
 
791
  class NeoLLMModel(NeoLLMPreTrainedModel):
792
  """
793
  NeoLLM base model with transformer decoder architecture.
@@ -842,6 +848,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
842
  position_ids: Optional[torch.LongTensor] = None,
843
  inputs_embeds: Optional[torch.FloatTensor] = None,
844
  output_hidden_states: Optional[bool] = None,
 
845
  return_dict: Optional[bool] = None,
846
  **kwargs: Unpack[TransformersKwargs],
847
  ) -> BaseModelOutputWithPast:
@@ -849,6 +856,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
849
  output_hidden_states if output_hidden_states is not None
850
  else self.config.output_hidden_states
851
  )
 
 
 
 
852
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
853
 
854
  if (input_ids is None) ^ (inputs_embeds is not None):
@@ -875,6 +886,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
875
 
876
  hidden_states = inputs_embeds
877
  all_hidden_states = () if output_hidden_states else None
 
878
 
879
  # create position embeddings to be shared across the decoder layers
880
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
@@ -886,14 +898,20 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
886
  if output_hidden_states:
887
  all_hidden_states = all_hidden_states + (hidden_states,)
888
 
889
- hidden_states = decoder_layer(
890
  hidden_states,
891
  position_embeddings=position_embeddings,
892
  attention_mask=causal_mask,
893
  first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
 
894
  **kwargs,
895
  )
896
 
 
 
 
 
 
897
  # ResFormer: capture H_fan_1 from the first layer
898
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
899
  self.first_layer_fan = decoder_layer.current_layer_fan
@@ -905,12 +923,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
905
  all_hidden_states = all_hidden_states + (hidden_states,)
906
 
907
  if not return_dict:
908
- return tuple(v for v in [hidden_states, None, all_hidden_states] if v is not None)
909
 
910
  return BaseModelOutputWithPast(
911
  last_hidden_state=hidden_states,
912
  past_key_values=None,
913
  hidden_states=all_hidden_states,
 
914
  )
915
 
916
 
 
631
  hidden = self.dropout(hidden)
632
  return self.down_proj(hidden)
633
 
 
634
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
635
  """
636
  Decoder layer with standard residual connections.
 
676
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
677
  attention_mask: Optional[torch.Tensor] = None,
678
  first_layer_fan: Optional[torch.Tensor] = None,
679
+ output_attentions: Optional[bool] = False,
680
  **kwargs: Unpack[FlashAttentionKwargs],
681
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
682
  # ============================================================
683
  # Attention Block with standard residual connection
684
  # ============================================================
 
691
  hidden_states = self.lns_attn(hidden_states)
692
 
693
  # Self Attention with ResFormer feature residual connections and learnable multipliers
694
+ # We capture attn_weights here instead of ignoring them
695
+ hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
696
  hidden_states=hidden_states,
697
  attention_mask=attention_mask,
698
  position_embeddings=position_embeddings,
 
724
  # Apply GPAS after MLP residual connection
725
  hidden_states = self.gpas_mlp(hidden_states)
726
 
727
+ outputs = (hidden_states,)
728
+ if output_attentions:
729
+ outputs += (attn_weights,)
730
+
731
+ return outputs
732
 
733
 
734
  class NeoLLMPreTrainedModel(PreTrainedModel):
 
793
  # scale adaptations from data without initial bias
794
  if hasattr(module, 'multiplier'):
795
  module.multiplier.data.fill_(1.0)
796
+
797
  class NeoLLMModel(NeoLLMPreTrainedModel):
798
  """
799
  NeoLLM base model with transformer decoder architecture.
 
848
  position_ids: Optional[torch.LongTensor] = None,
849
  inputs_embeds: Optional[torch.FloatTensor] = None,
850
  output_hidden_states: Optional[bool] = None,
851
+ output_attentions: Optional[bool] = None,
852
  return_dict: Optional[bool] = None,
853
  **kwargs: Unpack[TransformersKwargs],
854
  ) -> BaseModelOutputWithPast:
 
856
  output_hidden_states if output_hidden_states is not None
857
  else self.config.output_hidden_states
858
  )
859
+ output_attentions = (
860
+ output_attentions if output_attentions is not None
861
+ else self.config.output_attentions
862
+ )
863
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
864
 
865
  if (input_ids is None) ^ (inputs_embeds is not None):
 
886
 
887
  hidden_states = inputs_embeds
888
  all_hidden_states = () if output_hidden_states else None
889
+ all_attentions = () if output_attentions else None
890
 
891
  # create position embeddings to be shared across the decoder layers
892
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
898
  if output_hidden_states:
899
  all_hidden_states = all_hidden_states + (hidden_states,)
900
 
901
+ layer_outputs = decoder_layer(
902
  hidden_states,
903
  position_embeddings=position_embeddings,
904
  attention_mask=causal_mask,
905
  first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
906
+ output_attentions=output_attentions,
907
  **kwargs,
908
  )
909
 
910
+ hidden_states = layer_outputs[0]
911
+
912
+ if output_attentions:
913
+ all_attentions = all_attentions + (layer_outputs[1],)
914
+
915
  # ResFormer: capture H_fan_1 from the first layer
916
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
917
  self.first_layer_fan = decoder_layer.current_layer_fan
 
923
  all_hidden_states = all_hidden_states + (hidden_states,)
924
 
925
  if not return_dict:
926
+ return tuple(v for v in [hidden_states, None, all_hidden_states, all_attentions] if v is not None)
927
 
928
  return BaseModelOutputWithPast(
929
  last_hidden_state=hidden_states,
930
  past_key_values=None,
931
  hidden_states=all_hidden_states,
932
+ attentions=all_attentions,
933
  )
934
 
935