Update modeling_neollm.py
Browse files- modeling_neollm.py +16 -97
modeling_neollm.py
CHANGED
|
@@ -1,12 +1,8 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
NeoLLM Model with FANformer Integration, Dropout Regularization, Selective Self-Attention (SSA)
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
Updated to include:
|
| 7 |
-
- Fourier Analysis Network (FAN) layer for effective periodicity modeling
|
| 8 |
-
- Dropout regularization at strategic locations
|
| 9 |
-
- ResFormer: Feature residual connections from first layer (applied before projections)
|
| 10 |
"""
|
| 11 |
|
| 12 |
import math
|
|
@@ -32,7 +28,7 @@ from transformers.utils.import_utils import (
|
|
| 32 |
is_causal_conv1d_available,
|
| 33 |
is_flash_linear_attention_available,
|
| 34 |
)
|
| 35 |
-
from
|
| 36 |
|
| 37 |
|
| 38 |
if is_causal_conv1d_available():
|
|
@@ -49,8 +45,6 @@ else:
|
|
| 49 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 50 |
|
| 51 |
logger = logging.get_logger(__name__)
|
| 52 |
-
|
| 53 |
-
|
| 54 |
class FANLayer(nn.Module):
|
| 55 |
"""
|
| 56 |
Fourier Analysis Network (FAN) layer for effective periodicity modeling.
|
|
@@ -289,13 +283,7 @@ def eager_attention_forward(
|
|
| 289 |
|
| 290 |
|
| 291 |
class NeoLLMAttention(nn.Module):
|
| 292 |
-
"""
|
| 293 |
-
Multi-headed attention with FANformer integration, Selective Self-Attention for periodicity modeling,
|
| 294 |
-
and ResFormer feature residual connections for enhanced information flow.
|
| 295 |
-
|
| 296 |
-
ResFormer enhancement: Applies learnable feature residual connections from the first layer
|
| 297 |
-
BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
|
| 298 |
-
"""
|
| 299 |
|
| 300 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
| 301 |
super().__init__()
|
|
@@ -334,35 +322,22 @@ class NeoLLMAttention(nn.Module):
|
|
| 334 |
|
| 335 |
# Dropout for attention output
|
| 336 |
self.dropout = nn.Dropout(config.dropout_rate)
|
| 337 |
-
|
| 338 |
-
# ResFormer: learnable feature residual parameters (initialized to 0.5)
|
| 339 |
-
self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 (first layer features)
|
| 340 |
-
self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n (current layer features)
|
| 341 |
|
| 342 |
def forward(
|
| 343 |
self,
|
| 344 |
hidden_states: torch.Tensor,
|
| 345 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 346 |
attention_mask: Optional[torch.Tensor],
|
| 347 |
-
first_layer_fan: Optional[torch.Tensor] = None,
|
| 348 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 349 |
-
) -> tuple[torch.Tensor, Optional[torch.Tensor]
|
| 350 |
input_shape = hidden_states.shape[:-1]
|
| 351 |
|
| 352 |
# Apply FANformer transformation first
|
| 353 |
hidden_states_fan = self.fan_layer(hidden_states)
|
| 354 |
|
| 355 |
-
# ResFormer: Apply feature residual connection BEFORE projections
|
| 356 |
-
# This ensures dimensional compatibility across all layer types
|
| 357 |
-
if first_layer_fan is not None:
|
| 358 |
-
hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
|
| 359 |
-
|
| 360 |
-
# Store current FAN features for potential use as first_layer_fan in subsequent layers
|
| 361 |
-
current_layer_fan = hidden_states_fan.clone()
|
| 362 |
-
|
| 363 |
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 364 |
|
| 365 |
-
# Use FAN-transformed features
|
| 366 |
query_states, gate = torch.chunk(
|
| 367 |
self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
|
| 368 |
)
|
|
@@ -394,9 +369,8 @@ class NeoLLMAttention(nn.Module):
|
|
| 394 |
attn_output = attn_output * torch.sigmoid(gate)
|
| 395 |
|
| 396 |
attn_output = self.o_proj(attn_output)
|
| 397 |
-
attn_output = self.dropout(attn_output)
|
| 398 |
-
|
| 399 |
-
return attn_output, attn_weights, current_layer_fan
|
| 400 |
|
| 401 |
|
| 402 |
def apply_mask_to_padding_states(hidden_states, attention_mask):
|
|
@@ -560,15 +534,8 @@ def torch_recurrent_gated_delta_rule(
|
|
| 560 |
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
| 561 |
return core_attn_out, last_recurrent_state
|
| 562 |
|
| 563 |
-
|
| 564 |
class NeoLLMGatedDeltaNet(nn.Module):
|
| 565 |
-
"""
|
| 566 |
-
Linear attention with FANformer integration, Selective Self-Attention for periodicity modeling,
|
| 567 |
-
and ResFormer feature residual connections for enhanced information flow.
|
| 568 |
-
|
| 569 |
-
ResFormer enhancement: Applies learnable feature residual connections from the first layer
|
| 570 |
-
BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
|
| 571 |
-
"""
|
| 572 |
|
| 573 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
| 574 |
super().__init__()
|
|
@@ -643,10 +610,6 @@ class NeoLLMGatedDeltaNet(nn.Module):
|
|
| 643 |
self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
|
| 644 |
self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
|
| 645 |
|
| 646 |
-
# ResFormer: learnable feature residual parameters (initialized to 0.5)
|
| 647 |
-
self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 (first layer features)
|
| 648 |
-
self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n (current layer features)
|
| 649 |
-
|
| 650 |
if not is_fast_path_available:
|
| 651 |
logger.warning_once(
|
| 652 |
"The fast path is not available because one of the required library is not installed. Falling back to "
|
|
@@ -686,8 +649,7 @@ class NeoLLMGatedDeltaNet(nn.Module):
|
|
| 686 |
self,
|
| 687 |
hidden_states: torch.Tensor,
|
| 688 |
attention_mask: Optional[torch.Tensor] = None,
|
| 689 |
-
|
| 690 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 691 |
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 692 |
|
| 693 |
# Set up dimensions for reshapes later
|
|
@@ -696,15 +658,7 @@ class NeoLLMGatedDeltaNet(nn.Module):
|
|
| 696 |
# Apply FANformer transformation first
|
| 697 |
hidden_states_fan = self.fan_layer(hidden_states)
|
| 698 |
|
| 699 |
-
# ResFormer: Apply feature residual connection BEFORE projections
|
| 700 |
-
# This ensures dimensional compatibility across all layer types
|
| 701 |
-
if first_layer_fan is not None:
|
| 702 |
-
hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
|
| 703 |
-
|
| 704 |
-
# Store current FAN features for potential use as first_layer_fan in subsequent layers
|
| 705 |
-
current_layer_fan = hidden_states_fan.clone()
|
| 706 |
-
|
| 707 |
-
# Use FAN-transformed features (with residual applied) for projections
|
| 708 |
projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
|
| 709 |
projected_states_ba = self.in_proj_ba(hidden_states_fan)
|
| 710 |
query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
|
|
@@ -768,9 +722,7 @@ class NeoLLMGatedDeltaNet(nn.Module):
|
|
| 768 |
|
| 769 |
output = self.out_proj(core_attn_out)
|
| 770 |
output = self.dropout(output) # Apply dropout after output projection
|
| 771 |
-
|
| 772 |
-
return output, current_layer_fan
|
| 773 |
-
|
| 774 |
|
| 775 |
class PolyNorm(torch.nn.Module):
|
| 776 |
def __init__(self, eps=1e-6):
|
|
@@ -785,7 +737,6 @@ class PolyNorm(torch.nn.Module):
|
|
| 785 |
def forward(self, x):
|
| 786 |
return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
|
| 787 |
|
| 788 |
-
|
| 789 |
class NeoLLMMLP(nn.Module):
|
| 790 |
def __init__(self, config):
|
| 791 |
super().__init__()
|
|
@@ -809,7 +760,6 @@ class NeoLLMMLP(nn.Module):
|
|
| 809 |
hidden = self.dropout(hidden)
|
| 810 |
return self.down_proj(hidden)
|
| 811 |
|
| 812 |
-
|
| 813 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 814 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
| 815 |
super().__init__()
|
|
@@ -836,16 +786,12 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 836 |
# GPAS (Gradient-Preserving Activation Scaling) - applied after residual connections
|
| 837 |
self.gpas_attn = GPAS(config.hidden_size)
|
| 838 |
self.gpas_mlp = GPAS(config.hidden_size)
|
| 839 |
-
|
| 840 |
-
# ResFormer: storage for current layer's FAN features
|
| 841 |
-
self.current_layer_fan = None
|
| 842 |
|
| 843 |
def forward(
|
| 844 |
self,
|
| 845 |
hidden_states: torch.Tensor,
|
| 846 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 847 |
attention_mask: Optional[torch.Tensor] = None,
|
| 848 |
-
first_layer_fan: Optional[torch.Tensor] = None,
|
| 849 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 850 |
) -> torch.FloatTensor:
|
| 851 |
residual = hidden_states
|
|
@@ -856,20 +802,18 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
|
| 856 |
# Apply LNS scaling after normalization
|
| 857 |
hidden_states = self.lns_attn(hidden_states)
|
| 858 |
|
| 859 |
-
# Token Mixer
|
| 860 |
if self.layer_type == "linear_attention":
|
| 861 |
-
hidden_states
|
| 862 |
hidden_states=hidden_states,
|
| 863 |
attention_mask=attention_mask,
|
| 864 |
-
first_layer_fan=first_layer_fan,
|
| 865 |
)
|
| 866 |
elif self.layer_type == "full_attention":
|
| 867 |
# Self Attention
|
| 868 |
-
hidden_states, _
|
| 869 |
hidden_states=hidden_states,
|
| 870 |
attention_mask=attention_mask,
|
| 871 |
position_embeddings=position_embeddings,
|
| 872 |
-
first_layer_fan=first_layer_fan,
|
| 873 |
**kwargs,
|
| 874 |
)
|
| 875 |
|
|
@@ -911,17 +855,6 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
|
|
| 911 |
if isinstance(module, NeoLLMGatedDeltaNet):
|
| 912 |
module.dt_bias.data.fill_(1.0)
|
| 913 |
module.A_log.data.uniform_(0, 16).log_()
|
| 914 |
-
# ResFormer: initialize lambda parameters for linear attention
|
| 915 |
-
if hasattr(module, 'lambda_1'):
|
| 916 |
-
module.lambda_1.data.fill_(0.5)
|
| 917 |
-
if hasattr(module, 'lambda_2'):
|
| 918 |
-
module.lambda_2.data.fill_(0.5)
|
| 919 |
-
elif isinstance(module, NeoLLMAttention):
|
| 920 |
-
# ResFormer: initialize lambda parameters for full attention
|
| 921 |
-
if hasattr(module, 'lambda_1'):
|
| 922 |
-
module.lambda_1.data.fill_(0.5)
|
| 923 |
-
if hasattr(module, 'lambda_2'):
|
| 924 |
-
module.lambda_2.data.fill_(0.5)
|
| 925 |
elif isinstance(module, GPAS):
|
| 926 |
# Initialize GPAS alpha to 0 as per paper
|
| 927 |
module.alpha.data.fill_(0.0)
|
|
@@ -942,10 +875,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 942 |
self.norm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 943 |
self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
|
| 944 |
self.gradient_checkpointing = False
|
| 945 |
-
|
| 946 |
-
# ResFormer: storage for first layer's FAN features (H_fan_1)
|
| 947 |
-
self.first_layer_fan = None
|
| 948 |
-
|
| 949 |
# Initialize weights and apply final processing
|
| 950 |
self.post_init()
|
| 951 |
|
|
@@ -981,9 +910,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 981 |
# create position embeddings to be shared across the decoder layers
|
| 982 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 983 |
|
| 984 |
-
# ResFormer: reset first_layer_fan at the start of each forward pass
|
| 985 |
-
self.first_layer_fan = None
|
| 986 |
-
|
| 987 |
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 988 |
layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
|
| 989 |
|
|
@@ -991,13 +917,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 991 |
hidden_states,
|
| 992 |
position_embeddings=position_embeddings,
|
| 993 |
attention_mask=layer_mask,
|
| 994 |
-
first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
|
| 995 |
**kwargs,
|
| 996 |
)
|
| 997 |
-
|
| 998 |
-
# ResFormer: capture H_fan_1 from the first layer
|
| 999 |
-
if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
|
| 1000 |
-
self.first_layer_fan = decoder_layer.current_layer_fan
|
| 1001 |
|
| 1002 |
hidden_states = self.norm(hidden_states)
|
| 1003 |
|
|
@@ -1016,7 +937,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1016 |
linear_attn_mask = None
|
| 1017 |
return linear_attn_mask
|
| 1018 |
|
| 1019 |
-
|
| 1020 |
@torch.compiler.disable
|
| 1021 |
def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
|
| 1022 |
"""
|
|
@@ -1099,7 +1019,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1099 |
attentions=outputs.attentions,
|
| 1100 |
)
|
| 1101 |
|
| 1102 |
-
|
| 1103 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 1104 |
|
| 1105 |
__all__ = [
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
NeoLLM Model with FANformer Integration, Dropout Regularization, and Selective Self-Attention (SSA)
|
| 4 |
+
Updated to include Fourier Analysis Network (FAN) layer for effective periodicity modeling,
|
| 5 |
+
and dropout regularization at strategic locations.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
|
| 8 |
import math
|
|
|
|
| 28 |
is_causal_conv1d_available,
|
| 29 |
is_flash_linear_attention_available,
|
| 30 |
)
|
| 31 |
+
from configuration_neollm import NeoLLMConfig
|
| 32 |
|
| 33 |
|
| 34 |
if is_causal_conv1d_available():
|
|
|
|
| 45 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 46 |
|
| 47 |
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
| 48 |
class FANLayer(nn.Module):
|
| 49 |
"""
|
| 50 |
Fourier Analysis Network (FAN) layer for effective periodicity modeling.
|
|
|
|
| 283 |
|
| 284 |
|
| 285 |
class NeoLLMAttention(nn.Module):
|
| 286 |
+
"""Multi-headed attention with FANformer integration and Selective Self-Attention for periodicity modeling"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
| 289 |
super().__init__()
|
|
|
|
| 322 |
|
| 323 |
# Dropout for attention output
|
| 324 |
self.dropout = nn.Dropout(config.dropout_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
def forward(
|
| 327 |
self,
|
| 328 |
hidden_states: torch.Tensor,
|
| 329 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 330 |
attention_mask: Optional[torch.Tensor],
|
|
|
|
| 331 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 332 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 333 |
input_shape = hidden_states.shape[:-1]
|
| 334 |
|
| 335 |
# Apply FANformer transformation first
|
| 336 |
hidden_states_fan = self.fan_layer(hidden_states)
|
| 337 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 339 |
|
| 340 |
+
# Use FAN-transformed features directly for projections
|
| 341 |
query_states, gate = torch.chunk(
|
| 342 |
self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
|
| 343 |
)
|
|
|
|
| 369 |
attn_output = attn_output * torch.sigmoid(gate)
|
| 370 |
|
| 371 |
attn_output = self.o_proj(attn_output)
|
| 372 |
+
attn_output = self.dropout(attn_output) # Apply dropout after output projection
|
| 373 |
+
return attn_output, attn_weights
|
|
|
|
| 374 |
|
| 375 |
|
| 376 |
def apply_mask_to_padding_states(hidden_states, attention_mask):
|
|
|
|
| 534 |
core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
|
| 535 |
return core_attn_out, last_recurrent_state
|
| 536 |
|
|
|
|
| 537 |
class NeoLLMGatedDeltaNet(nn.Module):
|
| 538 |
+
"""Linear attention with FANformer integration and Selective Self-Attention for periodicity modeling"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 539 |
|
| 540 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
| 541 |
super().__init__()
|
|
|
|
| 610 |
self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
|
| 611 |
self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
|
| 612 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
if not is_fast_path_available:
|
| 614 |
logger.warning_once(
|
| 615 |
"The fast path is not available because one of the required library is not installed. Falling back to "
|
|
|
|
| 649 |
self,
|
| 650 |
hidden_states: torch.Tensor,
|
| 651 |
attention_mask: Optional[torch.Tensor] = None,
|
| 652 |
+
):
|
|
|
|
| 653 |
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
| 654 |
|
| 655 |
# Set up dimensions for reshapes later
|
|
|
|
| 658 |
# Apply FANformer transformation first
|
| 659 |
hidden_states_fan = self.fan_layer(hidden_states)
|
| 660 |
|
| 661 |
+
# Use FAN-transformed features directly for projections
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 662 |
projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
|
| 663 |
projected_states_ba = self.in_proj_ba(hidden_states_fan)
|
| 664 |
query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
|
|
|
|
| 722 |
|
| 723 |
output = self.out_proj(core_attn_out)
|
| 724 |
output = self.dropout(output) # Apply dropout after output projection
|
| 725 |
+
return output
|
|
|
|
|
|
|
| 726 |
|
| 727 |
class PolyNorm(torch.nn.Module):
|
| 728 |
def __init__(self, eps=1e-6):
|
|
|
|
| 737 |
def forward(self, x):
|
| 738 |
return self.weight[0] * self._norm(x**3) + self.weight[1] * self._norm(x**2) + self.weight[2] * self._norm(x) + self.bias
|
| 739 |
|
|
|
|
| 740 |
class NeoLLMMLP(nn.Module):
|
| 741 |
def __init__(self, config):
|
| 742 |
super().__init__()
|
|
|
|
| 760 |
hidden = self.dropout(hidden)
|
| 761 |
return self.down_proj(hidden)
|
| 762 |
|
|
|
|
| 763 |
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
|
| 764 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
| 765 |
super().__init__()
|
|
|
|
| 786 |
# GPAS (Gradient-Preserving Activation Scaling) - applied after residual connections
|
| 787 |
self.gpas_attn = GPAS(config.hidden_size)
|
| 788 |
self.gpas_mlp = GPAS(config.hidden_size)
|
|
|
|
|
|
|
|
|
|
| 789 |
|
| 790 |
def forward(
|
| 791 |
self,
|
| 792 |
hidden_states: torch.Tensor,
|
| 793 |
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 794 |
attention_mask: Optional[torch.Tensor] = None,
|
|
|
|
| 795 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 796 |
) -> torch.FloatTensor:
|
| 797 |
residual = hidden_states
|
|
|
|
| 802 |
# Apply LNS scaling after normalization
|
| 803 |
hidden_states = self.lns_attn(hidden_states)
|
| 804 |
|
| 805 |
+
# Token Mixer
|
| 806 |
if self.layer_type == "linear_attention":
|
| 807 |
+
hidden_states = self.linear_attn(
|
| 808 |
hidden_states=hidden_states,
|
| 809 |
attention_mask=attention_mask,
|
|
|
|
| 810 |
)
|
| 811 |
elif self.layer_type == "full_attention":
|
| 812 |
# Self Attention
|
| 813 |
+
hidden_states, _ = self.self_attn(
|
| 814 |
hidden_states=hidden_states,
|
| 815 |
attention_mask=attention_mask,
|
| 816 |
position_embeddings=position_embeddings,
|
|
|
|
| 817 |
**kwargs,
|
| 818 |
)
|
| 819 |
|
|
|
|
| 855 |
if isinstance(module, NeoLLMGatedDeltaNet):
|
| 856 |
module.dt_bias.data.fill_(1.0)
|
| 857 |
module.A_log.data.uniform_(0, 16).log_()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 858 |
elif isinstance(module, GPAS):
|
| 859 |
# Initialize GPAS alpha to 0 as per paper
|
| 860 |
module.alpha.data.fill_(0.0)
|
|
|
|
| 875 |
self.norm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 876 |
self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
|
| 877 |
self.gradient_checkpointing = False
|
|
|
|
|
|
|
|
|
|
|
|
|
| 878 |
# Initialize weights and apply final processing
|
| 879 |
self.post_init()
|
| 880 |
|
|
|
|
| 910 |
# create position embeddings to be shared across the decoder layers
|
| 911 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 912 |
|
|
|
|
|
|
|
|
|
|
| 913 |
for decoder_layer in self.layers[: self.config.num_hidden_layers]:
|
| 914 |
layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
|
| 915 |
|
|
|
|
| 917 |
hidden_states,
|
| 918 |
position_embeddings=position_embeddings,
|
| 919 |
attention_mask=layer_mask,
|
|
|
|
| 920 |
**kwargs,
|
| 921 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 922 |
|
| 923 |
hidden_states = self.norm(hidden_states)
|
| 924 |
|
|
|
|
| 937 |
linear_attn_mask = None
|
| 938 |
return linear_attn_mask
|
| 939 |
|
|
|
|
| 940 |
@torch.compiler.disable
|
| 941 |
def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
|
| 942 |
"""
|
|
|
|
| 1019 |
attentions=outputs.attentions,
|
| 1020 |
)
|
| 1021 |
|
|
|
|
| 1022 |
# ==================== AUTOMODEL REGISTRATION ====================
|
| 1023 |
|
| 1024 |
__all__ = [
|