Update modeling_neollm.py
Browse files- modeling_neollm.py +52 -61
modeling_neollm.py
CHANGED
|
@@ -788,7 +788,6 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 788 |
# scale adaptations from data without initial bias
|
| 789 |
if hasattr(module, 'multiplier'):
|
| 790 |
module.multiplier.data.fill_(1.0)
|
| 791 |
-
|
| 792 |
class NeoLLMModel(NeoLLMPreTrainedModel):
|
| 793 |
"""
|
| 794 |
NeoLLM base model with transformer decoder architecture.
|
|
@@ -842,8 +841,16 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 842 |
attention_mask: Optional[torch.Tensor] = None,
|
| 843 |
position_ids: Optional[torch.LongTensor] = None,
|
| 844 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
|
|
|
| 845 |
**kwargs: Unpack[TransformersKwargs],
|
| 846 |
) -> BaseModelOutputWithPast:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 848 |
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 849 |
|
|
@@ -867,6 +874,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 867 |
)
|
| 868 |
|
| 869 |
hidden_states = inputs_embeds
|
|
|
|
| 870 |
|
| 871 |
# create position embeddings to be shared across the decoder layers
|
| 872 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
|
@@ -875,6 +883,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 875 |
self.first_layer_fan = None
|
| 876 |
|
| 877 |
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
|
|
|
|
|
|
|
|
|
| 878 |
hidden_states = decoder_layer(
|
| 879 |
hidden_states,
|
| 880 |
position_embeddings=position_embeddings,
|
|
@@ -890,9 +901,16 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 890 |
# Apply SeeDNorm for final normalization
|
| 891 |
hidden_states = self.norm(hidden_states)
|
| 892 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 893 |
return BaseModelOutputWithPast(
|
| 894 |
last_hidden_state=hidden_states,
|
| 895 |
past_key_values=None,
|
|
|
|
| 896 |
)
|
| 897 |
|
| 898 |
|
|
@@ -953,75 +971,48 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 953 |
attention_mask: Optional[torch.Tensor] = None,
|
| 954 |
position_ids: Optional[torch.LongTensor] = None,
|
| 955 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 956 |
-
|
| 957 |
-
|
|
|
|
|
|
|
| 958 |
**kwargs: Unpack[TransformersKwargs],
|
| 959 |
-
) ->
|
| 960 |
-
|
| 961 |
-
|
| 962 |
-
output_hidden_states if output_hidden_states is not None
|
| 963 |
-
else self.config.output_hidden_states
|
| 964 |
-
)
|
| 965 |
-
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 966 |
-
|
| 967 |
-
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 968 |
-
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 969 |
-
|
| 970 |
-
if inputs_embeds is None:
|
| 971 |
-
inputs_embeds = self.embed_tokens(input_ids)
|
| 972 |
-
|
| 973 |
-
if position_ids is None:
|
| 974 |
-
position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
|
| 975 |
-
|
| 976 |
-
causal_mask = create_causal_mask(
|
| 977 |
-
config=self.config,
|
| 978 |
-
input_embeds=inputs_embeds,
|
| 979 |
attention_mask=attention_mask,
|
| 980 |
-
cache_position=position_ids.squeeze(0),
|
| 981 |
-
past_key_values=None,
|
| 982 |
position_ids=position_ids,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 983 |
)
|
| 984 |
-
|
| 985 |
-
hidden_states = inputs_embeds
|
| 986 |
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
hidden_states = decoder_layer(
|
| 999 |
-
hidden_states,
|
| 1000 |
-
position_embeddings=position_embeddings,
|
| 1001 |
-
attention_mask=causal_mask,
|
| 1002 |
-
first_layer_fan=self.first_layer_fan,
|
| 1003 |
-
**kwargs,
|
| 1004 |
)
|
| 1005 |
-
|
| 1006 |
-
|
| 1007 |
-
|
| 1008 |
-
|
| 1009 |
-
|
| 1010 |
-
|
| 1011 |
-
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
|
| 1015 |
-
return tuple(v for v in [hidden_states, None, all_hidden_states] if v is not None)
|
| 1016 |
-
|
| 1017 |
-
return BaseModelOutputWithPast(
|
| 1018 |
-
last_hidden_state=hidden_states,
|
| 1019 |
past_key_values=None,
|
| 1020 |
-
hidden_states=
|
| 1021 |
-
attentions=
|
| 1022 |
)
|
| 1023 |
|
| 1024 |
-
|
| 1025 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 1026 |
|
| 1027 |
__all__ = [
|
|
|
|
| 788 |
# scale adaptations from data without initial bias
|
| 789 |
if hasattr(module, 'multiplier'):
|
| 790 |
module.multiplier.data.fill_(1.0)
|
|
|
|
| 791 |
class NeoLLMModel(NeoLLMPreTrainedModel):
|
| 792 |
"""
|
| 793 |
NeoLLM base model with transformer decoder architecture.
|
|
|
|
| 841 |
attention_mask: Optional[torch.Tensor] = None,
|
| 842 |
position_ids: Optional[torch.LongTensor] = None,
|
| 843 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 844 |
+
output_hidden_states: Optional[bool] = None,
|
| 845 |
+
return_dict: Optional[bool] = None,
|
| 846 |
**kwargs: Unpack[TransformersKwargs],
|
| 847 |
) -> BaseModelOutputWithPast:
|
| 848 |
+
output_hidden_states = (
|
| 849 |
+
output_hidden_states if output_hidden_states is not None
|
| 850 |
+
else self.config.output_hidden_states
|
| 851 |
+
)
|
| 852 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 853 |
+
|
| 854 |
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 855 |
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 856 |
|
|
|
|
| 874 |
)
|
| 875 |
|
| 876 |
hidden_states = inputs_embeds
|
| 877 |
+
all_hidden_states = () if output_hidden_states else None
|
| 878 |
|
| 879 |
# create position embeddings to be shared across the decoder layers
|
| 880 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
|
|
|
| 883 |
self.first_layer_fan = None
|
| 884 |
|
| 885 |
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 886 |
+
if output_hidden_states:
|
| 887 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 888 |
+
|
| 889 |
hidden_states = decoder_layer(
|
| 890 |
hidden_states,
|
| 891 |
position_embeddings=position_embeddings,
|
|
|
|
| 901 |
# Apply SeeDNorm for final normalization
|
| 902 |
hidden_states = self.norm(hidden_states)
|
| 903 |
|
| 904 |
+
if output_hidden_states:
|
| 905 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 906 |
+
|
| 907 |
+
if not return_dict:
|
| 908 |
+
return tuple(v for v in [hidden_states, None, all_hidden_states] if v is not None)
|
| 909 |
+
|
| 910 |
return BaseModelOutputWithPast(
|
| 911 |
last_hidden_state=hidden_states,
|
| 912 |
past_key_values=None,
|
| 913 |
+
hidden_states=all_hidden_states,
|
| 914 |
)
|
| 915 |
|
| 916 |
|
|
|
|
| 971 |
attention_mask: Optional[torch.Tensor] = None,
|
| 972 |
position_ids: Optional[torch.LongTensor] = None,
|
| 973 |
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 974 |
+
labels: Optional[torch.LongTensor] = None,
|
| 975 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 976 |
+
output_hidden_states: Optional[bool] = None,
|
| 977 |
+
return_dict: Optional[bool] = None,
|
| 978 |
**kwargs: Unpack[TransformersKwargs],
|
| 979 |
+
) -> CausalLMOutputWithPast:
|
| 980 |
+
outputs: BaseModelOutputWithPast = self.model(
|
| 981 |
+
input_ids=input_ids,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 982 |
attention_mask=attention_mask,
|
|
|
|
|
|
|
| 983 |
position_ids=position_ids,
|
| 984 |
+
inputs_embeds=inputs_embeds,
|
| 985 |
+
output_hidden_states=output_hidden_states,
|
| 986 |
+
return_dict=return_dict,
|
| 987 |
+
**kwargs,
|
| 988 |
)
|
|
|
|
|
|
|
| 989 |
|
| 990 |
+
hidden_states = outputs.last_hidden_state
|
| 991 |
+
|
| 992 |
+
# CCE Loss computation for training
|
| 993 |
+
if labels is not None:
|
| 994 |
+
loss = compute_cce_loss(
|
| 995 |
+
hidden_states,
|
| 996 |
+
labels,
|
| 997 |
+
self.lm_head.weight,
|
| 998 |
+
getattr(self.lm_head, 'bias', None),
|
| 999 |
+
self.config.pad_token_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1000 |
)
|
| 1001 |
+
logits = None
|
| 1002 |
+
else:
|
| 1003 |
+
# Inference mode - compute logits normally
|
| 1004 |
+
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1005 |
+
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1006 |
+
loss = None
|
| 1007 |
+
|
| 1008 |
+
return CausalLMOutputWithPast(
|
| 1009 |
+
loss=loss,
|
| 1010 |
+
logits=logits,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1011 |
past_key_values=None,
|
| 1012 |
+
hidden_states=outputs.hidden_states,
|
| 1013 |
+
attentions=outputs.attentions,
|
| 1014 |
)
|
| 1015 |
|
|
|
|
| 1016 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 1017 |
|
| 1018 |
__all__ = [
|