KitsuVp committed
Commit a0f735e · verified · 1 Parent(s): cb51961

Update modeling_neollm.py

Files changed (1)
  1. modeling_neollm.py +61 -31
modeling_neollm.py CHANGED
@@ -33,7 +33,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from transformers.processing_utils import Unpack
  from transformers.utils import TransformersKwargs, logging
  from transformers.utils.generic import check_model_inputs
- from .configuration_neollm import NeoLLMConfig
+ from configuration_neollm import NeoLLMConfig

  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
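The only change in this hunk is the NeoLLMConfig import: the package-relative form (from .configuration_neollm) resolves only when modeling_neollm.py is loaded as part of a package, while the new top-level form requires the directory that holds configuration_neollm.py to be on sys.path. A minimal sketch of the latter, with a placeholder path:

import sys
from pathlib import Path

# Placeholder path: wherever configuration_neollm.py and modeling_neollm.py were downloaded to.
repo_dir = Path("./NeoLLM")
sys.path.insert(0, str(repo_dir))

# The top-level import used in this commit now resolves without a package context.
from configuration_neollm import NeoLLMConfig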
 
@@ -953,42 +953,72 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         logits_to_keep: Union[int, torch.Tensor] = 0,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
          **kwargs: Unpack[TransformersKwargs],
-     ) -> CausalLMOutputWithPast:
-         outputs: BaseModelOutputWithPast = self.model(
-             input_ids=input_ids,
+     ) -> BaseModelOutputWithPast:
+
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None
+             else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         if position_ids is None:
+             position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
+
+         causal_mask = create_causal_mask(
+             config=self.config,
+             input_embeds=inputs_embeds,
              attention_mask=attention_mask,
+             cache_position=position_ids.squeeze(0),
+             past_key_values=None,
              position_ids=position_ids,
-             inputs_embeds=inputs_embeds,
-             **kwargs,
          )
+
+         hidden_states = inputs_embeds

-         hidden_states = outputs.last_hidden_state
-
-         # CCE Loss computation for training
-         if labels is not None:
-             loss = compute_cce_loss(
-                 hidden_states,
-                 labels,
-                 self.lm_head.weight,
-                 getattr(self.lm_head, 'bias', None),
-                 self.config.pad_token_id
+         all_hidden_states = () if output_hidden_states else None
+
+         position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+         self.first_layer_fan = None
+
+         for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+
+             if output_hidden_states:
+                 all_hidden_states = all_hidden_states + (hidden_states,)
+
+             hidden_states = decoder_layer(
+                 hidden_states,
+                 position_embeddings=position_embeddings,
+                 attention_mask=causal_mask,
+                 first_layer_fan=self.first_layer_fan,
+                 **kwargs,
              )
-             logits = None
-         else:
-             # Inference mode - compute logits normally
-             slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
-             logits = self.lm_head(hidden_states[:, slice_indices, :])
-             loss = None
-
-         return CausalLMOutputWithPast(
-             loss=loss,
-             logits=logits,
+
+             if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
+                 self.first_layer_fan = decoder_layer.current_layer_fan
+
+         hidden_states = self.norm(hidden_states)
+
+         if output_hidden_states:
+             all_hidden_states = all_hidden_states + (hidden_states,)
+
+         if not return_dict:
+             return tuple(v for v in [hidden_states, None, all_hidden_states] if v is not None)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
              past_key_values=None,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
+             hidden_states=all_hidden_states,
+             attentions=None,
          )

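In the removed branch above, training computed the loss with compute_cce_loss (defined elsewhere in this file) directly from the final hidden states and the lm_head weights, while inference materialized logits only for the requested slice. For reference, a naive, unfused equivalent of that loss call might look like the sketch below; the next-token shift and the use of pad_token_id as the ignore index are assumptions inferred from the arguments being passed, not from the helper's actual implementation.

import torch
import torch.nn.functional as F

def naive_causal_lm_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
    # Materialize the full vocabulary logits (the step compute_cce_loss presumably fuses away),
    # shift so that position t predicts token t + 1, and average cross-entropy over kept positions.
    logits = F.linear(hidden_states, lm_head_weight, lm_head_bias)
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    ignore_index = pad_token_id if pad_token_id is not None else -100
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )

The replacement forward no longer computes a loss or logits at all: it runs the embedding, mask, rotary-embedding, and decoder-layer steps itself and returns a BaseModelOutputWithPast whose last_hidden_state callers must project through the lm_head themselves.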
 
@@ -1009,4 +1039,4 @@ __all__ = [
  # Register the configuration and model for AutoClass support
  AutoConfig.register("neollm", NeoLLMConfig)
  AutoModel.register(NeoLLMConfig, NeoLLMModel)
- AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)
+ AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM)
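The register() calls above run when modeling_neollm.py is imported and make the "neollm" model type reachable through the Auto classes. A hedged end-to-end sketch of how that might be used, assuming NeoLLMConfig can be constructed from its defaults and that forward behaves as in the previous hunk; the sequence length and the vocab_size / num_hidden_layers attribute names follow the usual transformers conventions and are not verified against configuration_neollm.py:

import torch
from transformers import AutoModelForCausalLM

from configuration_neollm import NeoLLMConfig
import modeling_neollm  # importing the module executes the register() calls above

config = NeoLLMConfig()                            # assumption: sensible defaults exist
model = AutoModelForCausalLM.from_config(config)   # resolved via the NeoLLMConfig -> NeoLLMForCausalLM mapping
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))
with torch.no_grad():
    outputs = model(input_ids=input_ids, output_hidden_states=True)

# Per the forward in this commit, the result is a BaseModelOutputWithPast:
print(outputs.last_hidden_state.shape)   # (1, 16, hidden_size)
print(len(outputs.hidden_states))        # num_hidden_layers + 1: each layer's input plus the final norm output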
 