Update modeling_neollm.py
Browse files- modeling_neollm.py +25 -6
modeling_neollm.py
CHANGED
|
@@ -631,7 +631,6 @@ class NeoLLMMLP(nn.Module):
|
|
| 631 |
hidden = self.dropout(hidden)
|
| 632 |
return self.down_proj(hidden)
|
| 633 |
|
| 634 |
-
|
| 635 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 636 |
"""
|
| 637 |
Decoder layer with standard residual connections.
|
|
@@ -677,8 +676,9 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 677 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 678 |
attention_mask: Optional[torch.Tensor] = None,
|
| 679 |
first_layer_fan: Optional[torch.Tensor] = None,
|
|
|
|
| 680 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 681 |
-
) -> torch.FloatTensor:
|
| 682 |
# ============================================================
|
| 683 |
# Attention Block with standard residual connection
|
| 684 |
# ============================================================
|
|
@@ -691,7 +691,8 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 691 |
hidden_states = self.lns_attn(hidden_states)
|
| 692 |
|
| 693 |
# Self Attention with ResFormer feature residual connections and learnable multipliers
|
| 694 |
-
|
|
|
|
| 695 |
hidden_states=hidden_states,
|
| 696 |
attention_mask=attention_mask,
|
| 697 |
position_embeddings=position_embeddings,
|
|
@@ -723,7 +724,11 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 723 |
# Apply GPAS after MLP residual connection
|
| 724 |
hidden_states = self.gpas_mlp(hidden_states)
|
| 725 |
|
| 726 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 727 |
|
| 728 |
|
| 729 |
class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
@@ -788,6 +793,7 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 788 |
# scale adaptations from data without initial bias
|
| 789 |
if hasattr(module, 'multiplier'):
|
| 790 |
module.multiplier.data.fill_(1.0)
|
|
|
|
| 791 |
class NeoLLMModel(NeoLLMPreTrainedModel):
|
| 792 |
"""
|
| 793 |
NeoLLM base model with transformer decoder architecture.
|
|
@@ -842,6 +848,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 842 |
position_ids: Optional[torch.LongTensor] = None,
|
| 843 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 844 |
output_hidden_states: Optional[bool] = None,
|
|
|
|
| 845 |
return_dict: Optional[bool] = None,
|
| 846 |
**kwargs: Unpack[TransformersKwargs],
|
| 847 |
) -> BaseModelOutputWithPast:
|
|
@@ -849,6 +856,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 849 |
output_hidden_states if output_hidden_states is not None
|
| 850 |
else self.config.output_hidden_states
|
| 851 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 852 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 853 |
|
| 854 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
@@ -875,6 +886,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 875 |
|
| 876 |
hidden_states = inputs_embeds
|
| 877 |
all_hidden_states = () if output_hidden_states else None
|
|
|
|
| 878 |
|
| 879 |
# create position embeddings to be shared across the decoder layers
|
| 880 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
|
@@ -886,14 +898,20 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 886 |
if output_hidden_states:
|
| 887 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 888 |
|
| 889 |
-
|
| 890 |
hidden_states,
|
| 891 |
position_embeddings=position_embeddings,
|
| 892 |
attention_mask=causal_mask,
|
| 893 |
first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
|
|
|
|
| 894 |
**kwargs,
|
| 895 |
)
|
| 896 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 897 |
# ResFormer: capture H_fan_1 from the first layer
|
| 898 |
if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
|
| 899 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
|
@@ -905,12 +923,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 905 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 906 |
|
| 907 |
if not return_dict:
|
| 908 |
-
return tuple(v for v in [hidden_states, None, all_hidden_states] if v is not None)
|
| 909 |
|
| 910 |
return BaseModelOutputWithPast(
|
| 911 |
last_hidden_state=hidden_states,
|
| 912 |
past_key_values=None,
|
| 913 |
hidden_states=all_hidden_states,
|
|
|
|
| 914 |
)
|
| 915 |
|
| 916 |
|
|
|
|
| 631 |
hidden = self.dropout(hidden)
|
| 632 |
return self.down_proj(hidden)
|
| 633 |
|
|
|
|
| 634 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 635 |
"""
|
| 636 |
Decoder layer with standard residual connections.
|
|
|
|
| 676 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 677 |
attention_mask: Optional[torch.Tensor] = None,
|
| 678 |
first_layer_fan: Optional[torch.Tensor] = None,
|
| 679 |
+
output_attentions: Optional[bool] = False,
|
| 680 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 681 |
+
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
|
| 682 |
# ============================================================
|
| 683 |
# Attention Block with standard residual connection
|
| 684 |
# ============================================================
|
|
|
|
| 691 |
hidden_states = self.lns_attn(hidden_states)
|
| 692 |
|
| 693 |
# Self Attention with ResFormer feature residual connections and learnable multipliers
|
| 694 |
+
# We capture attn_weights here instead of ignoring them
|
| 695 |
+
hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
|
| 696 |
hidden_states=hidden_states,
|
| 697 |
attention_mask=attention_mask,
|
| 698 |
position_embeddings=position_embeddings,
|
|
|
|
| 724 |
# Apply GPAS after MLP residual connection
|
| 725 |
hidden_states = self.gpas_mlp(hidden_states)
|
| 726 |
|
| 727 |
+
outputs = (hidden_states,)
|
| 728 |
+
if output_attentions:
|
| 729 |
+
outputs += (attn_weights,)
|
| 730 |
+
|
| 731 |
+
return outputs
|
| 732 |
|
| 733 |
|
| 734 |
class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
|
|
| 793 |
# scale adaptations from data without initial bias
|
| 794 |
if hasattr(module, 'multiplier'):
|
| 795 |
module.multiplier.data.fill_(1.0)
|
| 796 |
+
|
| 797 |
class NeoLLMModel(NeoLLMPreTrainedModel):
|
| 798 |
"""
|
| 799 |
NeoLLM base model with transformer decoder architecture.
|
|
|
|
| 848 |
position_ids: Optional[torch.LongTensor] = None,
|
| 849 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 850 |
output_hidden_states: Optional[bool] = None,
|
| 851 |
+
output_attentions: Optional[bool] = None,
|
| 852 |
return_dict: Optional[bool] = None,
|
| 853 |
**kwargs: Unpack[TransformersKwargs],
|
| 854 |
) -> BaseModelOutputWithPast:
|
|
|
|
| 856 |
output_hidden_states if output_hidden_states is not None
|
| 857 |
else self.config.output_hidden_states
|
| 858 |
)
|
| 859 |
+
output_attentions = (
|
| 860 |
+
output_attentions if output_attentions is not None
|
| 861 |
+
else self.config.output_attentions
|
| 862 |
+
)
|
| 863 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 864 |
|
| 865 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
|
|
|
| 886 |
|
| 887 |
hidden_states = inputs_embeds
|
| 888 |
all_hidden_states = () if output_hidden_states else None
|
| 889 |
+
all_attentions = () if output_attentions else None
|
| 890 |
|
| 891 |
# create position embeddings to be shared across the decoder layers
|
| 892 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
|
|
|
| 898 |
if output_hidden_states:
|
| 899 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 900 |
|
| 901 |
+
layer_outputs = decoder_layer(
|
| 902 |
hidden_states,
|
| 903 |
position_embeddings=position_embeddings,
|
| 904 |
attention_mask=causal_mask,
|
| 905 |
first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
|
| 906 |
+
output_attentions=output_attentions,
|
| 907 |
**kwargs,
|
| 908 |
)
|
| 909 |
|
| 910 |
+
hidden_states = layer_outputs[0]
|
| 911 |
+
|
| 912 |
+
if output_attentions:
|
| 913 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
| 914 |
+
|
| 915 |
# ResFormer: capture H_fan_1 from the first layer
|
| 916 |
if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
|
| 917 |
self.first_layer_fan = decoder_layer.current_layer_fan
|
|
|
|
| 923 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 924 |
|
| 925 |
if not return_dict:
|
| 926 |
+
return tuple(v for v in [hidden_states, None, all_hidden_states, all_attentions] if v is not None)
|
| 927 |
|
| 928 |
return BaseModelOutputWithPast(
|
| 929 |
last_hidden_state=hidden_states,
|
| 930 |
past_key_values=None,
|
| 931 |
hidden_states=all_hidden_states,
|
| 932 |
+
attentions=all_attentions,
|
| 933 |
)
|
| 934 |
|
| 935 |
|