Commit: Update modeling_neollm.py
Changed file: modeling_neollm.py (+60 lines added, −53 lines removed)
|
@@ -12,7 +12,6 @@ import torch
|
|
| 12 |
import torch.nn.functional as F
|
| 13 |
from torch import nn
|
| 14 |
from cut_cross_entropy import linear_cross_entropy
|
| 15 |
-
from .configuration_neollm import NeoLLMConfig
|
| 16 |
|
| 17 |
from transformers.activations import ACT2FN
|
| 18 |
from transformers.generation import GenerationMixin
|
|
@@ -29,8 +28,8 @@ from transformers.utils.import_utils import (
|
|
| 29 |
is_causal_conv1d_available,
|
| 30 |
is_flash_linear_attention_available,
|
| 31 |
)
|
|
|
|
| 32 |
|
| 33 |
-
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 34 |
|
| 35 |
if is_causal_conv1d_available():
|
| 36 |
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
|
@@ -43,6 +42,7 @@ if is_flash_linear_attention_available():
|
|
| 43 |
else:
|
| 44 |
chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
|
| 45 |
FusedRMSNormGated = None
|
|
|
|
| 46 |
|
| 47 |
logger = logging.get_logger(__name__)
|
| 48 |
|
|
@@ -737,25 +737,28 @@ class PolyNorm(torch.nn.Module):
|
|
| 737 |
|
| 738 |
def forward(self, x):
|
| 739 |
return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
|
| 740 |
-
|
| 741 |
class NeoLLMMLP(nn.Module):
|
| 742 |
def __init__(self, config):
|
| 743 |
super().__init__()
|
| 744 |
self.config = config
|
| 745 |
self.hidden_size = config.hidden_size
|
| 746 |
self.intermediate_size = config.intermediate_size
|
| 747 |
-
|
| 748 |
-
|
|
|
|
|
|
|
|
|
|
| 749 |
self.act_fn = PolyNorm()
|
| 750 |
|
| 751 |
# Dropout for MLP hidden layer
|
| 752 |
self.dropout = nn.Dropout(config.dropout_rate)
|
| 753 |
|
| 754 |
def forward(self, x):
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
|
|
|
| 759 |
|
| 760 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 761 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
|
@@ -931,46 +934,44 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 931 |
if attention_mask is not None and torch.all(attention_mask == 1):
|
| 932 |
linear_attn_mask = None
|
| 933 |
return linear_attn_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 934 |
|
| 935 |
class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
| 936 |
_tied_weights_keys = ["lm_head.weight"]
|
| 937 |
-
|
| 938 |
def __init__(self, config):
|
| 939 |
super().__init__(config)
|
| 940 |
self.model = NeoLLMModel(config)
|
| 941 |
self.vocab_size = config.vocab_size
|
| 942 |
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 943 |
-
|
| 944 |
-
# Initialize weights and apply final processing
|
| 945 |
self.post_init()
|
| 946 |
|
| 947 |
-
@torch.compiler.disable
|
| 948 |
-
def _compute_cce_loss(self, hidden_states, labels):
|
| 949 |
-
"""
|
| 950 |
-
CCE loss computation excluded from compilation.
|
| 951 |
-
Preprocesses labels to eliminate torch.compile warnings.
|
| 952 |
-
"""
|
| 953 |
-
# Ensure labels are on the correct device
|
| 954 |
-
processed_labels = labels.to(hidden_states.device)
|
| 955 |
-
|
| 956 |
-
# Handle pad tokens: convert pad_token_id to -100 for proper masking
|
| 957 |
-
if self.config.pad_token_id is not None:
|
| 958 |
-
processed_labels = torch.where(
|
| 959 |
-
processed_labels == self.config.pad_token_id,
|
| 960 |
-
torch.tensor(-100, dtype=processed_labels.dtype, device=processed_labels.device),
|
| 961 |
-
processed_labels
|
| 962 |
-
)
|
| 963 |
-
|
| 964 |
-
return linear_cross_entropy(
|
| 965 |
-
hidden_states,
|
| 966 |
-
self.lm_head.weight,
|
| 967 |
-
processed_labels, # Use preprocessed labels
|
| 968 |
-
bias=getattr(self.lm_head, 'bias', None),
|
| 969 |
-
shift=1,
|
| 970 |
-
impl="cce",
|
| 971 |
-
reduction="mean"
|
| 972 |
-
)
|
| 973 |
-
|
| 974 |
def forward(
|
| 975 |
self,
|
| 976 |
input_ids: Optional[torch.LongTensor] = None,
|
|
@@ -981,14 +982,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 981 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 982 |
**kwargs: Unpack[TransformersKwargs],
|
| 983 |
) -> CausalLMOutputWithPast:
|
| 984 |
-
r"""
|
| 985 |
-
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 986 |
-
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
| 987 |
-
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
| 988 |
-
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
| 989 |
-
"""
|
| 990 |
-
|
| 991 |
-
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
| 992 |
outputs: BaseModelOutputWithPast = self.model(
|
| 993 |
input_ids=input_ids,
|
| 994 |
attention_mask=attention_mask,
|
|
@@ -996,19 +989,25 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 996 |
inputs_embeds=inputs_embeds,
|
| 997 |
**kwargs,
|
| 998 |
)
|
| 999 |
-
|
| 1000 |
hidden_states = outputs.last_hidden_state
|
| 1001 |
-
|
| 1002 |
# CCE Loss computation for training
|
| 1003 |
if labels is not None:
|
| 1004 |
-
loss =
|
| 1005 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1006 |
else:
|
| 1007 |
# Inference mode - compute logits normally
|
| 1008 |
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1009 |
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1010 |
loss = None
|
| 1011 |
-
|
| 1012 |
return CausalLMOutputWithPast(
|
| 1013 |
loss=loss,
|
| 1014 |
logits=logits,
|
|
@@ -1016,9 +1015,17 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1016 |
hidden_states=outputs.hidden_states,
|
| 1017 |
attentions=outputs.attentions,
|
| 1018 |
)
|
| 1019 |
-
|
| 1020 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 1021 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1022 |
# Register the configuration and model for AutoClass support
|
| 1023 |
AutoConfig.register("neollm", NeoLLMConfig)
|
| 1024 |
AutoModel.register(NeoLLMConfig, NeoLLMModel)
|
|
|
|
| 12 |
import torch.nn.functional as F
|
| 13 |
from torch import nn
|
| 14 |
from cut_cross_entropy import linear_cross_entropy
|
|
|
|
| 15 |
|
| 16 |
from transformers.activations import ACT2FN
|
| 17 |
from transformers.generation import GenerationMixin
|
|
|
|
| 28 |
is_causal_conv1d_available,
|
| 29 |
is_flash_linear_attention_available,
|
| 30 |
)
|
| 31 |
+
from .configuration_neollm import NeoLLMConfig
|
| 32 |
|
|
|
|
| 33 |
|
| 34 |
if is_causal_conv1d_available():
|
| 35 |
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
|
|
|
|
| 42 |
else:
|
| 43 |
chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
|
| 44 |
FusedRMSNormGated = None
|
| 45 |
+
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 46 |
|
| 47 |
logger = logging.get_logger(__name__)
|
| 48 |
|
|
|
|
| 737 |
|
| 738 |
def forward(self, x):
|
| 739 |
return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
|
|
|
|
| 740 |
class NeoLLMMLP(nn.Module):
    """Gated feed-forward block (SwiGLU-style, as in Motif) with a PolyNorm gate.

    The hidden state is projected through two parallel bias-free branches
    (gate and up), the activated gate is multiplied elementwise with the up
    branch, dropout is applied to the gated activations, and the result is
    projected back down to the model width.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Gated (SwiGLU-like) projections, without bias as in the original model.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # PolyNorm plays the role the gate activation (e.g. SiLU) usually does.
        self.act_fn = PolyNorm()

        # Dropout on the gated intermediate activations.
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        """Apply the gated MLP: down_proj(dropout(act_fn(gate_proj(x)) * up_proj(x)))."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(gated))
|
| 762 |
|
| 763 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 764 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
|
|
|
| 934 |
if attention_mask is not None and torch.all(attention_mask == 1):
|
| 935 |
linear_attn_mask = None
|
| 936 |
return linear_attn_mask
|
| 937 |
+
@torch.compiler.disable
def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
    """Compute the Cut Cross-Entropy (CCE) loss, excluded from torch.compile.

    Labels are first moved to the device of ``hidden_states``, and any
    positions equal to ``pad_token_id`` are remapped to the ignore index
    (-100) so the loss skips padding; doing this up front also eliminates
    torch.compile warnings about label preprocessing.

    Args:
        hidden_states: final hidden states fed to the LM head.
        labels: integer target token ids.
        lm_head_weight: weight matrix of the LM head.
        lm_head_bias: optional LM-head bias passed through to the loss.
        pad_token_id: optional pad id to mask out of the targets.

    Returns:
        Mean next-token cross-entropy loss (scalar tensor).
    """
    targets = labels.to(hidden_states.device)

    # Replace padding positions with the ignore index before computing the loss.
    if pad_token_id is not None:
        targets = targets.masked_fill(targets == pad_token_id, -100)

    # shift=1 -> next-token prediction; impl="cce" selects the fused
    # linear + cross-entropy kernel so full logits are never materialized.
    return linear_cross_entropy(
        hidden_states,
        lm_head_weight,
        targets,
        bias=lm_head_bias,
        shift=1,
        impl="cce",
        reduction="mean",
    )
|
| 963 |
+
|
| 964 |
|
| 965 |
class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
| 966 |
_tied_weights_keys = ["lm_head.weight"]
|
| 967 |
+
|
| 968 |
def __init__(self, config):
    """Build the causal-LM model: decoder backbone plus a bias-free LM head.

    NOTE(review): ``lm_head.weight`` appears in ``_tied_weights_keys``, so
    ``post_init`` may tie it to the input embeddings depending on the config
    — confirm the tying behavior against the transformers base class.
    """
    super().__init__(config)
    self.vocab_size = config.vocab_size
    # Decoder backbone producing the final hidden states.
    self.model = NeoLLMModel(config)
    # Projection from hidden size to vocabulary logits (no bias).
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
    # Weight initialization and any weight tying are handled here.
    self.post_init()
|
| 974 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 975 |
def forward(
|
| 976 |
self,
|
| 977 |
input_ids: Optional[torch.LongTensor] = None,
|
|
|
|
| 982 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 983 |
**kwargs: Unpack[TransformersKwargs],
|
| 984 |
) -> CausalLMOutputWithPast:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 985 |
outputs: BaseModelOutputWithPast = self.model(
|
| 986 |
input_ids=input_ids,
|
| 987 |
attention_mask=attention_mask,
|
|
|
|
| 989 |
inputs_embeds=inputs_embeds,
|
| 990 |
**kwargs,
|
| 991 |
)
|
| 992 |
+
|
| 993 |
hidden_states = outputs.last_hidden_state
|
| 994 |
+
|
| 995 |
# CCE Loss computation for training
|
| 996 |
if labels is not None:
|
| 997 |
+
loss = compute_cce_loss(
|
| 998 |
+
hidden_states,
|
| 999 |
+
labels,
|
| 1000 |
+
self.lm_head.weight,
|
| 1001 |
+
getattr(self.lm_head, 'bias', None),
|
| 1002 |
+
self.config.pad_token_id
|
| 1003 |
+
)
|
| 1004 |
+
logits = None
|
| 1005 |
else:
|
| 1006 |
# Inference mode - compute logits normally
|
| 1007 |
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 1008 |
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
| 1009 |
loss = None
|
| 1010 |
+
|
| 1011 |
return CausalLMOutputWithPast(
|
| 1012 |
loss=loss,
|
| 1013 |
logits=logits,
|
|
|
|
| 1015 |
hidden_states=outputs.hidden_states,
|
| 1016 |
attentions=outputs.attentions,
|
| 1017 |
)
|
|
|
|
| 1018 |
# ==================== AUTOMODEL REGISTRATION ====================

# Explicit public API of this module.
# NOTE(review): "FANLayer" is exported here but its definition is not visible
# in this chunk — confirm it exists elsewhere in the file.
__all__ = [
    "NeoLLMForCausalLM",
    "NeoLLMModel",
    "NeoLLMPreTrainedModel",
    "NeoLLMConfig",
    "FANLayer",
]


# Register the config and model under the "neollm" model type so
# AutoConfig / AutoModel can resolve this architecture by name.
AutoConfig.register("neollm", NeoLLMConfig)
AutoModel.register(NeoLLMConfig, NeoLLMModel)