KitsuVp committed on
Commit
8d7ff55
·
verified ·
1 Parent(s): 2d2e45d

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +60 -53
modeling_neollm.py CHANGED
@@ -12,7 +12,6 @@ import torch
12
  import torch.nn.functional as F
13
  from torch import nn
14
  from cut_cross_entropy import linear_cross_entropy
15
- from .configuration_neollm import NeoLLMConfig
16
 
17
  from transformers.activations import ACT2FN
18
  from transformers.generation import GenerationMixin
@@ -29,8 +28,8 @@ from transformers.utils.import_utils import (
29
  is_causal_conv1d_available,
30
  is_flash_linear_attention_available,
31
  )
 
32
 
33
- from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
34
 
35
  if is_causal_conv1d_available():
36
  from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
@@ -43,6 +42,7 @@ if is_flash_linear_attention_available():
43
  else:
44
  chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
45
  FusedRMSNormGated = None
 
46
 
47
  logger = logging.get_logger(__name__)
48
 
@@ -737,25 +737,28 @@ class PolyNorm(torch.nn.Module):
737
 
738
  def forward(self, x):
739
  return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
740
-
741
  class NeoLLMMLP(nn.Module):
742
  def __init__(self, config):
743
  super().__init__()
744
  self.config = config
745
  self.hidden_size = config.hidden_size
746
  self.intermediate_size = config.intermediate_size
747
- self.linear1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
748
- self.linear2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 
 
 
749
  self.act_fn = PolyNorm()
750
 
751
  # Dropout for MLP hidden layer
752
  self.dropout = nn.Dropout(config.dropout_rate)
753
 
754
  def forward(self, x):
755
- hidden = self.act_fn(self.linear1(x))
756
- hidden = self.dropout(hidden) # Apply dropout after activation
757
- return self.linear2(hidden)
758
-
 
759
 
760
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
761
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
@@ -931,46 +934,44 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
931
  if attention_mask is not None and torch.all(attention_mask == 1):
932
  linear_attn_mask = None
933
  return linear_attn_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
934
 
935
  class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
936
  _tied_weights_keys = ["lm_head.weight"]
937
-
938
  def __init__(self, config):
939
  super().__init__(config)
940
  self.model = NeoLLMModel(config)
941
  self.vocab_size = config.vocab_size
942
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
943
-
944
- # Initialize weights and apply final processing
945
  self.post_init()
946
 
947
- @torch.compiler.disable
948
- def _compute_cce_loss(self, hidden_states, labels):
949
- """
950
- CCE loss computation excluded from compilation.
951
- Preprocesses labels to eliminate torch.compile warnings.
952
- """
953
- # Ensure labels are on the correct device
954
- processed_labels = labels.to(hidden_states.device)
955
-
956
- # Handle pad tokens: convert pad_token_id to -100 for proper masking
957
- if self.config.pad_token_id is not None:
958
- processed_labels = torch.where(
959
- processed_labels == self.config.pad_token_id,
960
- torch.tensor(-100, dtype=processed_labels.dtype, device=processed_labels.device),
961
- processed_labels
962
- )
963
-
964
- return linear_cross_entropy(
965
- hidden_states,
966
- self.lm_head.weight,
967
- processed_labels, # Use preprocessed labels
968
- bias=getattr(self.lm_head, 'bias', None),
969
- shift=1,
970
- impl="cce",
971
- reduction="mean"
972
- )
973
-
974
  def forward(
975
  self,
976
  input_ids: Optional[torch.LongTensor] = None,
@@ -981,14 +982,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
981
  logits_to_keep: Union[int, torch.Tensor] = 0,
982
  **kwargs: Unpack[TransformersKwargs],
983
  ) -> CausalLMOutputWithPast:
984
- r"""
985
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
986
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
987
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
988
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
989
- """
990
-
991
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
992
  outputs: BaseModelOutputWithPast = self.model(
993
  input_ids=input_ids,
994
  attention_mask=attention_mask,
@@ -996,19 +989,25 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
996
  inputs_embeds=inputs_embeds,
997
  **kwargs,
998
  )
999
-
1000
  hidden_states = outputs.last_hidden_state
1001
-
1002
  # CCE Loss computation for training
1003
  if labels is not None:
1004
- loss = self._compute_cce_loss(hidden_states, labels)
1005
- logits = None # CCE doesn't return logits to save memory
 
 
 
 
 
 
1006
  else:
1007
  # Inference mode - compute logits normally
1008
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1009
  logits = self.lm_head(hidden_states[:, slice_indices, :])
1010
  loss = None
1011
-
1012
  return CausalLMOutputWithPast(
1013
  loss=loss,
1014
  logits=logits,
@@ -1016,9 +1015,17 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1016
  hidden_states=outputs.hidden_states,
1017
  attentions=outputs.attentions,
1018
  )
1019
-
1020
  # ==================== AUTOMODEL REGISTRATION ====================
1021
 
 
 
 
 
 
 
 
 
 
1022
  # Register the configuration and model for AutoClass support
1023
  AutoConfig.register("neollm", NeoLLMConfig)
1024
  AutoModel.register(NeoLLMConfig, NeoLLMModel)
 
12
  import torch.nn.functional as F
13
  from torch import nn
14
  from cut_cross_entropy import linear_cross_entropy
 
15
 
16
  from transformers.activations import ACT2FN
17
  from transformers.generation import GenerationMixin
 
28
  is_causal_conv1d_available,
29
  is_flash_linear_attention_available,
30
  )
31
+ from .configuration_neollm import NeoLLMConfig
32
 
 
33
 
34
  if is_causal_conv1d_available():
35
  from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 
42
  else:
43
  chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
44
  FusedRMSNormGated = None
45
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
46
 
47
  logger = logging.get_logger(__name__)
48
 
 
737
 
738
  def forward(self, x):
739
  return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
 
740
class NeoLLMMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block with a PolyNorm gate activation.

    Projects ``hidden_size -> intermediate_size`` through parallel gate/up
    branches, combines them multiplicatively, applies dropout, and projects
    the result back down to ``hidden_size``. All projections are bias-free.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # SwiGLU/gated architecture (Motif-style) -- no bias, matching the original.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = PolyNorm()

        # Dropout applied to the gated hidden activations before down-projection.
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        # gate(x) * up(x), then dropout, then the down projection.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(gated))
762
 
763
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
764
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
 
934
  if attention_mask is not None and torch.all(attention_mask == 1):
935
  linear_attn_mask = None
936
  return linear_attn_mask
937
# NOTE(review): the diff places this right after a NeoLLMModel method; it takes no
# `self`, so it is treated here as a module-level helper -- confirm its indentation
# in the full file.
@torch.compiler.disable
def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
    """Cut-cross-entropy (CCE) loss, excluded from torch.compile.

    Preprocesses the labels so that ``linear_cross_entropy`` sees a clean
    target tensor (eliminates torch.compile graph-break warnings).

    Args:
        hidden_states: Final hidden states fed to the LM head.
        labels: Target token ids; ``-100`` entries are ignored by the loss.
        lm_head_weight: Weight matrix of the LM head projection.
        lm_head_bias: Optional LM head bias, forwarded to the CCE kernel.
        pad_token_id: If given, pad positions are remapped to ``-100`` so they
            are masked out of the loss.

    Returns:
        The mean CCE loss with a one-token shift (next-token prediction).
    """
    # Ensure labels live on the same device as the hidden states.
    targets = labels.to(hidden_states.device)

    # Remap pad tokens to the ignore index (-100) so they do not contribute.
    if pad_token_id is not None:
        targets = targets.masked_fill(targets == pad_token_id, -100)

    return linear_cross_entropy(
        hidden_states,
        lm_head_weight,
        targets,
        bias=lm_head_bias,
        shift=1,
        impl="cce",
        reduction="mean",
    )
963
+
964
 
965
  class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
966
  _tied_weights_keys = ["lm_head.weight"]
967
+
968
  def __init__(self, config):
969
  super().__init__(config)
970
  self.model = NeoLLMModel(config)
971
  self.vocab_size = config.vocab_size
972
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
 
973
  self.post_init()
974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
975
  def forward(
976
  self,
977
  input_ids: Optional[torch.LongTensor] = None,
 
982
  logits_to_keep: Union[int, torch.Tensor] = 0,
983
  **kwargs: Unpack[TransformersKwargs],
984
  ) -> CausalLMOutputWithPast:
 
 
 
 
 
 
 
 
985
  outputs: BaseModelOutputWithPast = self.model(
986
  input_ids=input_ids,
987
  attention_mask=attention_mask,
 
989
  inputs_embeds=inputs_embeds,
990
  **kwargs,
991
  )
992
+
993
  hidden_states = outputs.last_hidden_state
994
+
995
  # CCE Loss computation for training
996
  if labels is not None:
997
+ loss = compute_cce_loss(
998
+ hidden_states,
999
+ labels,
1000
+ self.lm_head.weight,
1001
+ getattr(self.lm_head, 'bias', None),
1002
+ self.config.pad_token_id
1003
+ )
1004
+ logits = None
1005
  else:
1006
  # Inference mode - compute logits normally
1007
  slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1008
  logits = self.lm_head(hidden_states[:, slice_indices, :])
1009
  loss = None
1010
+
1011
  return CausalLMOutputWithPast(
1012
  loss=loss,
1013
  logits=logits,
 
1015
  hidden_states=outputs.hidden_states,
1016
  attentions=outputs.attentions,
1017
  )
 
1018
  # ==================== AUTOMODEL REGISTRATION ====================
1019
 
1020
+ __all__ = [
1021
+ "NeoLLMForCausalLM",
1022
+ "NeoLLMModel",
1023
+ "NeoLLMPreTrainedModel",
1024
+ "NeoLLMConfig",
1025
+ "FANLayer",
1026
+ ]
1027
+
1028
+
1029
  # Register the configuration and model for AutoClass support
1030
  AutoConfig.register("neollm", NeoLLMConfig)
1031
  AutoModel.register(NeoLLMConfig, NeoLLMModel)