KitsuVp committed · Commit a4f976b · verified · 1 Parent(s): 5f5bbd2

Update modeling_neollm.py

Files changed (1):
  modeling_neollm.py  +52 -61
modeling_neollm.py CHANGED
@@ -788,7 +788,6 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
         # scale adaptations from data without initial bias
         if hasattr(module, 'multiplier'):
             module.multiplier.data.fill_(1.0)
-
 class NeoLLMModel(NeoLLMPreTrainedModel):
     """
     NeoLLM base model with transformer decoder architecture.
@@ -842,8 +841,16 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> BaseModelOutputWithPast:
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
@@ -867,6 +874,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
         )
 
         hidden_states = inputs_embeds
+        all_hidden_states = () if output_hidden_states else None
 
         # create position embeddings to be shared across the decoder layers
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
@@ -875,6 +883,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
         self.first_layer_fan = None
 
         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
             hidden_states = decoder_layer(
                 hidden_states,
                 position_embeddings=position_embeddings,
@@ -890,9 +901,16 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
         # Apply SeeDNorm for final normalization
         hidden_states = self.norm(hidden_states)
 
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, None, all_hidden_states] if v is not None)
+
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
             past_key_values=None,
+            hidden_states=all_hidden_states,
         )
 
 
@@ -953,75 +971,48 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
         **kwargs: Unpack[TransformersKwargs],
-    ) -> BaseModelOutputWithPast:
-
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        if position_ids is None:
-            position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
-
-        causal_mask = create_causal_mask(
-            config=self.config,
-            input_embeds=inputs_embeds,
+    ) -> CausalLMOutputWithPast:
+        outputs: BaseModelOutputWithPast = self.model(
+            input_ids=input_ids,
             attention_mask=attention_mask,
-            cache_position=position_ids.squeeze(0),
-            past_key_values=None,
             position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs,
         )
-
-        hidden_states = inputs_embeds
 
-        all_hidden_states = () if output_hidden_states else None
-
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-        self.first_layer_fan = None
-
-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
-
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
-            hidden_states = decoder_layer(
-                hidden_states,
-                position_embeddings=position_embeddings,
-                attention_mask=causal_mask,
-                first_layer_fan=self.first_layer_fan,
-                **kwargs,
+        hidden_states = outputs.last_hidden_state
+
+        # CCE Loss computation for training
+        if labels is not None:
+            loss = compute_cce_loss(
+                hidden_states,
+                labels,
+                self.lm_head.weight,
+                getattr(self.lm_head, 'bias', None),
+                self.config.pad_token_id
             )
-
-            if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
-                self.first_layer_fan = decoder_layer.current_layer_fan
-
-        hidden_states = self.norm(hidden_states)
-
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, None, all_hidden_states] if v is not None)
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
+            logits = None
+        else:
+            # Inference mode - compute logits normally
+            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+            logits = self.lm_head(hidden_states[:, slice_indices, :])
+            loss = None
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
             past_key_values=None,
-            hidden_states=all_hidden_states,
-            attentions=None,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
         )
 
-
 # ==================== AUTOMODEL REGISTRATION ====================
 
 __all__ = [
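
For reference, a minimal usage sketch of the two refactored forward paths follows. It is not part of the commit: the checkpoint id is hypothetical, the shapes and vocab size are illustrative, and it assumes the AUTOMODEL REGISTRATION section of the file makes NeoLLMForCausalLM loadable with trust_remote_code.

import torch
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint id; any repo shipping this modeling_neollm.py would do.
model = AutoModelForCausalLM.from_pretrained("org/neollm-checkpoint", trust_remote_code=True)

input_ids = torch.randint(0, 32000, (2, 16))  # vocab size and shape are illustrative

# Training path: with labels, the loss comes from compute_cce_loss and no logits
# tensor is materialized (outputs.logits is None).
train_out = model(input_ids=input_ids, labels=input_ids)
print(train_out.loss, train_out.logits)

# Inference path: without labels, lm_head produces logits. logits_to_keep=1
# projects only the last position; the default of 0 keeps every position.
infer_out = model(input_ids=input_ids, logits_to_keep=1)
print(infer_out.logits.shape)  # (2, 1, vocab_size)

# output_hidden_states is now handled inside NeoLLMModel.forward, so requesting it
# returns one tensor per decoder-layer input plus the final SeeDNorm output.
hs_out = model(input_ids=input_ids, output_hidden_states=True)
print(len(hs_out.hidden_states))  # num_hidden_layers + 1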
 
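compute_cce_loss itself is not shown in this diff. Judging from the call site, it takes the final hidden states, the labels, the lm_head weight, an optional bias, and pad_token_id, and returns a scalar loss so that no full logits tensor has to be kept in the output. The sketch below is an assumption about those semantics, not the repository's implementation; a real cut cross-entropy (CCE) kernel would additionally avoid materializing the [batch, seq, vocab] logits at all.

import torch.nn.functional as F

def reference_cce_loss(hidden_states, labels, weight, bias=None, pad_token_id=None):
    # Assumed convention: tokens < n predict token n, so shift by one position.
    shift_hidden = hidden_states[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    # Dense projection to the vocabulary; a fused CCE kernel would do this in chunks.
    logits = F.linear(shift_hidden, weight, bias)
    # Assumed: pad positions are excluded from the loss via ignore_index.
    ignore_index = pad_token_id if pad_token_id is not None else -100
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)).float(),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )

Under these assumptions, reference_cce_loss(hidden_states, labels, self.lm_head.weight, getattr(self.lm_head, 'bias', None), self.config.pad_token_id) mirrors the call made in the new training branch.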