KitsuVp commited on
Commit
8cc7157
·
verified ·
1 Parent(s): c8bde25

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration_neollm.py +87 -0
  2. modeling_neollm.py +1034 -0
configuration_neollm.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ==================== configuration_neollm.py ====================
2
+
3
+ from transformers.configuration_utils import PretrainedConfig
4
+ from transformers.modeling_rope_utils import rope_config_validation
5
+ from transformers.utils import logging
6
+
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+
11
class NeoLLMConfig(PretrainedConfig):
    r"""
    Configuration class for a [`NeoLLMModel`]. Instantiating it with the given
    arguments defines the model architecture (hybrid full/linear attention with
    FANformer-style projections).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs.
    """

    model_type = "neollm"
    keys_to_ignore_at_inference = []

    def __init__(
        self,
        vocab_size=151665,
        hidden_size=512,
        intermediate_size=1024,
        num_hidden_layers=12,
        num_attention_heads=8,
        num_key_value_heads=2,
        hidden_act="xielu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        partial_rotary_factor=0.25,
        attention_bias=False,
        attention_dropout=0.1,
        head_dim=64,
        linear_conv_kernel_dim=4,
        linear_key_head_dim=64,
        linear_value_head_dim=64,
        linear_num_key_heads=8,
        linear_num_value_heads=8,
        layer_types=None,
        fan_ratio=0.125,
        dropout_rate=0.1,
        **kwargs,
    ):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Core transformer geometry.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps

        # Rotary-embedding settings; checked by `rope_config_validation` below,
        # which reads the rope_* / partial_rotary_factor attributes just set.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.partial_rotary_factor = partial_rotary_factor
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim
        rope_config_validation(self)

        # Default layer schedule: every `full_attention_interval`-th layer
        # (1-based) uses full softmax attention, all others linear attention.
        if layer_types is None:
            interval = kwargs.get("full_attention_interval", 4)
            layer_types = [
                "full_attention" if (i + 1) % interval == 0 else "linear_attention"
                for i in range(num_hidden_layers)
            ]
        self.layer_types = layer_types

        # Linear (gated delta-net) attention hyperparameters.
        self.linear_conv_kernel_dim = linear_conv_kernel_dim
        self.linear_key_head_dim = linear_key_head_dim
        self.linear_value_head_dim = linear_value_head_dim
        self.linear_num_key_heads = linear_num_key_heads
        self.linear_num_value_heads = linear_num_value_heads
        self.fan_ratio = fan_ratio
        self.dropout_rate = dropout_rate
86
+
87
+ __all__ = ["NeoLLMConfig"]
modeling_neollm.py ADDED
@@ -0,0 +1,1034 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ NeoLLM Model with FANformer Integration and Dropout Regularization
4
+ Updated to include Fourier Analysis Network (FAN) layer for effective periodicity modeling
5
+ and dropout regularization at strategic locations
6
+ """
7
+
8
+ import math
9
+ from typing import Any, Callable, Optional, Union
10
+
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from torch import nn
14
+ from cut_cross_entropy import linear_cross_entropy
15
+
16
+ from transformers.activations import ACT2FN
17
+ from transformers.generation import GenerationMixin
18
+ from transformers.masking_utils import create_causal_mask
19
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
20
+ from transformers.modeling_layers import GradientCheckpointingLayer
21
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
22
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
23
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
24
+ from transformers.processing_utils import Unpack
25
+ from transformers.utils import TransformersKwargs, logging
26
+ from transformers.utils.generic import check_model_inputs
27
+ from transformers.utils.import_utils import (
28
+ is_causal_conv1d_available,
29
+ is_flash_linear_attention_available,
30
+ )
31
+ from configuration_neollm import NeoLLMConfig
32
+
33
+
34
+ if is_causal_conv1d_available():
35
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
36
+ else:
37
+ causal_conv1d_update, causal_conv1d_fn = None, None
38
+
39
+ if is_flash_linear_attention_available():
40
+ from fla.modules import FusedRMSNormGated
41
+ from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
42
+ else:
43
+ chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
44
+ FusedRMSNormGated = None
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
class FANLayer(nn.Module):
    """
    Fourier Analysis Network (FAN) layer for periodicity modeling.

    Implements the activation-free variant from "FANformer: Improving Large
    Language Models Through Effective Periodicity Modeling":

        FANLayer'(X) = [cos(WpX) || sin(WpX) || (Wp_bar X + Bp_bar)]

    The output feature dimension is 2 * periodic_dim + non_periodic_dim,
    i.e. hidden_size + periodic_dim.
    """

    def __init__(self, hidden_size: int, fan_ratio: float = 0.25):
        super().__init__()
        self.hidden_size = hidden_size
        self.fan_ratio = fan_ratio

        # Split the width between periodic and non-periodic branches.
        self.periodic_dim = int(hidden_size * fan_ratio)
        self.non_periodic_dim = hidden_size - self.periodic_dim

        # Wp feeds cos/sin; Wp_bar is a plain affine branch.
        self.Wp = nn.Linear(hidden_size, self.periodic_dim, bias=False)
        self.Wp_bar = nn.Linear(hidden_size, self.non_periodic_dim, bias=True)

        self._init_weights()

    def _init_weights(self):
        """Normal(0, 0.02) weights, zero bias, per the paper's recommendation."""
        nn.init.normal_(self.Wp.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.Wp_bar.weight, mean=0.0, std=0.02)
        if self.Wp_bar.bias is not None:
            nn.init.zeros_(self.Wp_bar.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the FAN transformation.

        Args:
            x: Input of shape (batch, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch, seq_len, hidden_size + periodic_dim) with
            cos, sin, and linear components concatenated along the last dim.
        """
        periodic = self.Wp(x)
        linear_part = self.Wp_bar(x)
        return torch.cat((periodic.cos(), periodic.sin(), linear_part), dim=-1)
108
+
109
+
110
class LNS(nn.Module):
    """
    LayerNorm Scaling (LNS) from "The Curse of Depth in Large Language Models".

    Multiplies activations by 1/sqrt(depth), where depth is the 1-based layer
    index, to prevent exponential variance growth in deeper layers.
    """

    def __init__(self, layer_idx: int):
        super().__init__()
        # Incoming index is 0-based; clamp so depth is never below 1.
        self.layer_idx = max(layer_idx + 1, 1)
        self.scale = 1.0 / math.sqrt(self.layer_idx)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pure elementwise scaling; no learnable parameters.
        return x * self.scale
128
+
129
+
130
class GPAS(nn.Module):
    """
    Gradient-Preserving Activation Scaling (GPAS).

    Shrinks the forward activation by a learned factor silu(alpha) applied to a
    stop-gradient copy, so the backward pass through `x` itself is untouched.
    Intended for Pre-Norm use: after the sub-layer output, before the residual
    sum. With alpha initialized at 0, silu(0) = 0 and the layer starts as the
    identity.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # detach() stops gradients flowing through the scaled term only.
        return x - F.silu(self.alpha) * x.detach()
148
+
149
+
150
class NeoLLMRMSNormGated(nn.Module):
    """RMSNorm with an optional SiLU gate applied after normalization.

    Fallback for FLA's `FusedRMSNormGated`. The norm is computed in float32 and
    cast back to the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        # **kwargs absorbs FusedRMSNormGated-only arguments (e.g. activation,
        # device, dtype) so both implementations can be constructed identically.
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, gate=None):
        """Normalize `hidden_states`; if `gate` is given, multiply by silu(gate).

        Bug fix: the original dereferenced `gate` unconditionally, crashing for
        the default `gate=None`. Gating is now skipped when no gate is provided.
        """
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # Norm before gate
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight * hidden_states.to(input_dtype)
        if gate is not None:
            hidden_states = hidden_states * F.silu(gate.to(torch.float32))

        return hidden_states.to(input_dtype)
166
+
167
+
168
class NeoLLMRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin table generator.

    Selects an init function from `ROPE_INIT_FUNCTIONS` based on the config's
    `rope_scaling` dict (falling back to "default"), caches the inverse
    frequencies as a non-persistent buffer, and produces per-position cos/sin
    tensors in `forward`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: NeoLLMConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: recomputed from config instead of stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic-rope updates can restore the unscaled frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # (batch, dim/2, 1) @ (batch, 1, seq) -> per-position rotation angles.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Angles are computed outside autocast to keep them in float32.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Frequencies are duplicated so cos/sin cover the full rotary width.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        # Cast back to x's dtype for use inside the attention layers.
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
202
+
203
+
204
class NeoLLMRMSNorm(nn.Module):
    """Zero-centered RMSNorm.

    The learnable weight is stored as an offset from 1 (initialized to zeros),
    so the effective scale is (1 + weight). Normalization runs in float32 and
    the result is cast back to the input dtype at the end — i.e. (x * w) is
    computed before the downcast, unlike Llama's (x downcast) * w.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # Divide by the root-mean-square over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        normed = self._norm(x.float()) * (1.0 + self.weight.float())
        return normed.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
221
+
222
+
223
def rotate_half(x):
    """Rotate the last dim by halves: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
228
+
229
+
230
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply (possibly partial) rotary position embeddings to q and k.

    Only the first `cos.shape[-1]` channels of each head are rotated; the
    remaining channels pass through unchanged (partial rotary). `position_ids`
    is unused and kept for interface compatibility; `unsqueeze_dim` selects the
    broadcast axis for cos/sin (1 for (b, h, s, d) layouts).
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _rotate(t):
        # Split into the rotated slice and the pass-through remainder.
        rot, passthrough = t[..., :rotary_dim], t[..., rotary_dim:]
        # rotate_half inlined: (a, b) -> (-b, a) over the rotated slice.
        a, b = rot[..., : rotary_dim // 2], rot[..., rotary_dim // 2 :]
        half_rotated = torch.cat((-b, a), dim=-1)
        return torch.cat(((rot * cos) + (half_rotated * sin), passthrough), dim=-1)

    return _rotate(q), _rotate(k)
248
+
249
+
250
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value heads for grouped-query attention.

    Equivalent to torch.repeat_interleave(hidden_states, n_rep, dim=1), taking
    (batch, num_key_value_heads, seq, head_dim) to
    (batch, num_key_value_heads * n_rep, seq, head_dim), but via a zero-copy
    expand followed by one reshape.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
260
+
261
+
262
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference (non-fused) scaled-dot-product attention with GQA support.

    Returns (attn_output, attn_weights); attn_output is transposed back to
    (batch, seq, heads, head_dim) and made contiguous.
    """
    # Expand KV heads to match the query head count (grouped-query attention).
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    scores = (query @ key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Mask is additive and sliced to the key length.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax in float32 for stability, then cast back to the query dtype.
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = F.dropout(probs, p=dropout, training=module.training)
    attn_output = (probs @ value_states).transpose(1, 2).contiguous()

    return attn_output, probs
286
+
287
+
288
class NeoLLMAttention(nn.Module):
    """Multi-headed attention with FANformer integration for periodicity modeling.

    The hidden states are first passed through a FAN layer (adding periodic
    cos/sin features), then projected to Q/K/V. The query projection is doubled
    in width: one half becomes the queries, the other a per-channel sigmoid
    gate applied to the attention output before `o_proj`.
    """

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        # FANformer integration: FAN layer before QKV projections
        self.fan_layer = FANLayer(
            hidden_size=config.hidden_size,
            fan_ratio=getattr(config, 'fan_ratio', 0.25)
        )

        # Calculate the output dimension after FAN transformation
        # (FANLayer emits hidden_size + periodic_dim features).
        fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.25))

        # QKV projections operate on FAN-transformed features.
        # q_proj width is doubled: [queries | output gate] per head.
        self.q_proj = nn.Linear(
            fan_output_dim, config.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # Per-head RMS normalization of queries and keys (QK-norm).
        self.q_norm = NeoLLMRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = NeoLLMRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Dropout for attention output
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_shape = hidden_states.shape[:-1]

        # Apply FANformer transformation first
        hidden_states_fan = self.fan_layer(hidden_states)
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Split the doubled q projection into queries and a per-channel gate.
        query_states, gate = torch.chunk(
            self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
        )
        gate = gate.reshape(*input_shape, -1)

        # QK-norm, then move to (batch, heads, seq, head_dim) layout.
        query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Pick the attention kernel configured for the model (eager fallback).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            # Attention-probability dropout is active only in training mode.
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        # Sigmoid output gating with the gate produced by the doubled q_proj.
        attn_output = attn_output * torch.sigmoid(gate)

        attn_output = self.o_proj(attn_output)
        attn_output = self.dropout(attn_output)  # Apply dropout after output projection
        return attn_output, attn_weights
375
+
376
+
377
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """Zero out hidden states at padding positions.

    The mask is only applied when it is 2-D over a real batch (both the batch
    and time dimensions exceed 1); otherwise the input is returned untouched.
    """
    if attention_mask is None:
        return hidden_states
    if attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        orig_dtype = hidden_states.dtype
        # Broadcast (batch, seq) -> (batch, seq, 1) over the feature dim.
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(orig_dtype)
    return hidden_states
386
+
387
+
388
# True only when both optional CUDA kernel packages imported above are present
# (causal-conv1d and flash-linear-attention); otherwise the pure-torch
# fallbacks defined below are used.
is_fast_path_available = all(
    (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
)
391
+
392
+
393
def torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    """Pure-torch fallback for the `causal_conv1d_update` kernel.

    Appends the incoming timesteps to the rolling convolution state (which is
    updated in place), runs a depthwise 1-D convolution over the combined
    window, and returns SiLU-activated outputs for the new timesteps only.
    `activation` is accepted for interface parity; SiLU is always applied.
    """
    _, channels, new_steps = hidden_states.shape
    window = conv_state.shape[-1]

    # Slide the window: previous state followed by the new inputs.
    combined = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    conv_state.copy_(combined[:, :, -window:])

    # Depthwise conv: one filter per channel (groups == channels).
    conv_out = F.conv1d(combined, weight.unsqueeze(1), bias, padding=0, groups=channels)
    result = F.silu(conv_out[:, :, -new_steps:])
    return result.to(hidden_states.dtype)
409
+
410
+
411
def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
    """L2-normalize `x` along `dim`, matching the FLA library's l2norm
    (eps is added to the squared norm before the square root)."""
    scale = 1 / torch.sqrt(x.pow(2).sum(dim=dim, keepdim=True) + eps)
    return x * scale
415
+
416
+
417
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """Pure-torch chunked fallback for FLA's `chunk_gated_delta_rule` kernel.

    Processes the (padded) sequence in chunks of `chunk_size`, carrying a
    recurrent (k_head_dim x v_head_dim) state between chunks, with per-step
    decay `exp(g)` and delta-rule writes scaled by `beta`. Returns the
    attention output and, when `output_final_state` is True, the final
    recurrent state (otherwise None).

    NOTE(review): after the transpose below the unpacked names look swapped —
    assuming inputs arrive as (batch, seq, heads, dim), `sequence_length` ends
    up holding the head count while `num_heads`/`tot_heads` hold the (padded)
    sequence length, and the chunk loop walks the time dimension. This mirrors
    the upstream reference implementation; confirm against the caller.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # Compute in float32 for numerical stability; cast back at the end.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]

    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    # Pad the chunked dimension up to a multiple of chunk_size.
    pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    tot_heads = num_heads + pad_size
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    # beta-weighted keys/values used by the delta-rule update.
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)

    # chunk decay
    g = g.cumsum(dim=-1)
    # Pairwise decay factors exp(g_i - g_j) within a chunk, lower-triangular.
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    # Build the intra-chunk transition matrix by forward substitution
    # (row-by-row inversion of the implicit lower-triangular system).
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    last_recurrent_state = (
        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    # Strictly-upper mask enforces causality within a chunk.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)

    # for each chunk
    for i in range(0, tot_heads // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        # Intra-chunk causal attention with decay.
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        # Subtract what the carried state already predicts for this chunk.
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        # Inter-chunk contribution read from the carried recurrent state.
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        # Decay the state to the end of the chunk, then add this chunk's update.
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
        )

    if not output_final_state:
        last_recurrent_state = None
    # Un-chunk, drop the padding, and restore the original layout/dtype.
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :num_heads]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
495
+
496
+
497
def torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    """Pure-torch stepwise fallback for FLA's `fused_recurrent_gated_delta_rule`.

    Runs the gated delta rule one step at a time: decay the recurrent state by
    exp(g), compute the delta between the new value and what the state already
    stores for this key (scaled by beta), write it back, then read the output
    with the query. Returns the output and, when `output_final_state` is True,
    the final recurrent state (otherwise None).

    NOTE(review): as in the chunked variant, after the transpose the unpacked
    names appear swapped — the `range(num_heads)` loop presumably walks the
    time dimension (a sequential recurrence over steps), assuming inputs are
    (batch, seq, heads, dim); confirm against the caller.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # float32 for numerical stability; restore the original dtype at the end.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]

    batch_size, sequence_length, num_heads, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale

    core_attn_out = torch.zeros(batch_size, sequence_length, num_heads, v_head_dim).to(value)
    last_recurrent_state = (
        torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )

    for i in range(num_heads):
        q_t = query[:, :, i]
        k_t = key[:, :, i]
        v_t = value[:, :, i]
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, i].unsqueeze(-1)

        # Decay the state, form the delta-rule update, then read out with q.
        last_recurrent_state = last_recurrent_state * g_t
        kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
        delta = (v_t - kv_mem) * beta_t
        last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)

    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
537
+
538
+ class NeoLLMGatedDeltaNet(nn.Module):
539
+ """Linear attention with FANformer integration for periodicity modeling"""
540
+
541
+ def __init__(self, config: NeoLLMConfig, layer_idx: int):
542
+ super().__init__()
543
+ self.hidden_size = config.hidden_size
544
+ self.num_v_heads = config.linear_num_value_heads
545
+ self.num_k_heads = config.linear_num_key_heads
546
+ self.head_k_dim = config.linear_key_head_dim
547
+ self.head_v_dim = config.linear_value_head_dim
548
+ self.key_dim = self.head_k_dim * self.num_k_heads
549
+ self.value_dim = self.head_v_dim * self.num_v_heads
550
+
551
+ self.conv_kernel_size = config.linear_conv_kernel_dim
552
+ self.layer_idx = layer_idx
553
+ self.activation = config.hidden_act
554
+ self.act = ACT2FN[config.hidden_act]
555
+ self.layer_norm_epsilon = config.rms_norm_eps
556
+
557
+ # FANformer integration: FAN layer before projections
558
+ self.fan_layer = FANLayer(
559
+ hidden_size=config.hidden_size,
560
+ fan_ratio=getattr(config, 'fan_ratio', 0.25)
561
+ )
562
+
563
+ # Calculate the output dimension after FAN transformation
564
+ fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.25))
565
+
566
+ # QKV - operates on FAN-transformed features
567
+ self.conv_dim = self.key_dim * 2 + self.value_dim
568
+ self.conv1d = nn.Conv1d(
569
+ in_channels=self.conv_dim,
570
+ out_channels=self.conv_dim,
571
+ bias=False,
572
+ kernel_size=self.conv_kernel_size,
573
+ groups=self.conv_dim,
574
+ padding=self.conv_kernel_size - 1,
575
+ )
576
+
577
+ # projection of the FAN-transformed hidden states
578
+ projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
579
+ projection_size_ba = self.num_v_heads * 2
580
+ self.in_proj_qkvz = nn.Linear(fan_output_dim, projection_size_qkvz, bias=False)
581
+ self.in_proj_ba = nn.Linear(fan_output_dim, projection_size_ba, bias=False)
582
+
583
+ # time step projection (discretization)
584
+ self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
585
+
586
+ A = torch.empty(self.num_v_heads).uniform_(0, 16)
587
+ self.A_log = nn.Parameter(torch.log(A))
588
+
589
+ # FLA compatibility: use "silu" for FusedRMSNormGated, original activation elsewhere
590
+ fla_compatible_activation = "silu" if self.activation not in ['swish', 'silu', 'sigmoid'] else self.activation
591
+
592
+ self.norm = (
593
+ NeoLLMRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
594
+ if FusedRMSNormGated is None
595
+ else FusedRMSNormGated(
596
+ self.head_v_dim,
597
+ eps=self.layer_norm_epsilon,
598
+ activation=fla_compatible_activation, # Use FLA-compatible activation
599
+ device=torch.cuda.current_device(),
600
+ dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
601
+ )
602
+ )
603
+
604
+ self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
605
+
606
+ # Dropout for attention output
607
+ self.dropout = nn.Dropout(config.dropout_rate)
608
+
609
+ self.causal_conv1d_fn = causal_conv1d_fn
610
+ self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
611
+ self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
612
+ self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
613
+
614
+ if not is_fast_path_available:
615
+ logger.warning_once(
616
+ "The fast path is not available because one of the required library is not installed. Falling back to "
617
+ "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
618
+ " https://github.com/Dao-AILab/causal-conv1d"
619
+ )
620
+
621
    def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
        """
        Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`.

        Per key head, `mixed_qkvz` packs [q | k | value-group | z-group], where each
        group holds num_v_heads // num_k_heads value-sized heads.  `mixed_ba` packs
        the per-value-head gate inputs b and a, grouped the same way.

        Returns:
            query, key: (batch, seq, num_k_heads, head_k_dim)
            value, z:   (batch, seq, num_v_heads, head_v_dim)
            b, a:       (batch, seq, num_v_heads)
        """
        # Unfold the flat projection into (..., num_k_heads, per-key-head width).
        new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
            self.num_k_heads,
            2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads,
        )
        new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads)

        mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
        mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
        # Split widths within each key head: q, k, then the value and z groups.
        split_arg_list_qkvz = [
            self.head_k_dim,
            self.head_k_dim,
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
            (self.num_v_heads // self.num_k_heads * self.head_v_dim),
        ]
        split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads]
        query, key, value, z, b, a = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3)
        b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3)
        # Collapse the (num_k_heads, group) axes into num_v_heads:
        # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
        value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim)
        z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)
        b = b.reshape(b.size(0), b.size(1), self.num_v_heads)
        a = a.reshape(a.size(0), a.size(1), self.num_v_heads)
        return query, key, value, z, b, a
648
+
649
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Gated DeltaNet token mixer over the full sequence (no KV/conv cache).

        Pipeline: FAN transform -> packed QKV(+z)/BA projections -> short causal
        depthwise conv -> chunked gated delta rule -> gated RMS norm -> output
        projection with dropout.

        Args:
            hidden_states: (batch, seq_len, hidden_size) input activations.
            attention_mask: optional padding mask; padded positions are zeroed
                before mixing so they cannot leak into the recurrent state.
        """
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)

        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape

        # Apply FANformer transformation first
        hidden_states_fan = self.fan_layer(hidden_states)

        projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
        projected_states_ba = self.in_proj_ba(hidden_states_fan)
        query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
        # Flatten the head axes so q/k/v can share one depthwise conv over time.
        query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))

        mixed_qkv = torch.cat((query, key, value), dim=-1)
        mixed_qkv = mixed_qkv.transpose(1, 2)

        # Simple convolution without cache; fused CUDA kernel when available,
        # otherwise plain conv1d truncated back to seq_len (causal padding).
        if self.causal_conv1d_fn is not None:
            mixed_qkv = self.causal_conv1d_fn(
                x=mixed_qkv,
                weight=self.conv1d.weight.squeeze(1),
                bias=self.conv1d.bias,
                activation="silu",  # Keep original activation for conv1d
                seq_idx=None,
            )
        else:
            mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])

        mixed_qkv = mixed_qkv.transpose(1, 2)
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim,
                self.key_dim,
                self.value_dim,
            ],
            dim=-1,
        )
        # Restore per-head layout after the conv.
        query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
        key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
        value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)

        # beta in (0, 1): per-head write strength for the delta rule.
        beta = b.sigmoid()
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        if self.num_v_heads // self.num_k_heads > 1:
            # Grouped heads: replicate q/k so every value head has a q/k pair.
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        # Use chunk-based implementation without cache
        core_attn_out, _ = self.chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
        )

        z_shape_og = z.shape
        # reshape input data into 2D tensor (the gated norm operates row-wise)
        core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
        z = z.reshape(-1, z.shape[-1])
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(z_shape_og)
        core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)

        output = self.out_proj(core_attn_out)
        output = self.dropout(output)  # Apply dropout after output projection
        return output
726
+
727
class PolyNorm(torch.nn.Module):
    """Polynomial-norm activation.

    Computes a learnable combination of RMS-normalized powers of the input:

        PolyNorm(x) = w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b

    where ``norm`` is RMS normalization over the last dimension.  Weights
    start at 1/3 each and the bias at 0, so the initial output is the
    average of the three normalized powers.

    Args:
        eps: numerical-stability floor added inside the rsqrt.
    """

    def __init__(self, eps=1e-6):
        # Modern zero-argument super() (same behavior as super(PolyNorm, self)).
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def _norm(self, x):
        """RMS-normalize ``x`` over its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return (
            self.weight[0] * self._norm(x**3)
            + self.weight[1] * self._norm(x**2)
            + self.weight[2] * self._norm(x)
            + self.bias
        )
739
+
740
class NeoLLMMLP(nn.Module):
    """Feed-forward block: up-projection, PolyNorm activation, dropout, down-projection."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.linear1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.linear2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = PolyNorm()
        # Regularize the expanded representation before projecting back down.
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        activated = self.act_fn(self.linear1(x))
        return self.linear2(self.dropout(activated))
757
+
758
+
759
class NeoLLMDecoderLayer(GradientCheckpointingLayer):
    """One transformer block: a token mixer (linear or full attention) followed
    by an MLP, each wrapped in pre-norm + LNS scaling, a residual add, and GPAS."""

    def __init__(self, config: NeoLLMConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        # Token mixer: per-layer choice between the gated delta net
        # (linear attention) and softmax full attention, per config.layer_types.
        self.layer_type = config.layer_types[layer_idx]
        if self.layer_type == "linear_attention":
            self.linear_attn = NeoLLMGatedDeltaNet(config, layer_idx)
        elif self.layer_type == "full_attention":
            self.self_attn = NeoLLMAttention(config, layer_idx)

        # Channel mixer: plain dense MLP (no mixture-of-experts).
        self.mlp = NeoLLMMLP(config)

        self.input_layernorm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # LayerNorm Scaling — depth-dependent 1/sqrt(layer) factor after each norm.
        self.lns_attn = LNS(layer_idx)
        self.lns_mlp = LNS(layer_idx)

        # Gradient-Preserving Activation Scaling — applied after each residual add.
        self.gpas_attn = GPAS(config.hidden_size)
        self.gpas_mlp = GPAS(config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.FloatTensor:
        # --- token-mixer sub-block (pre-norm -> LNS -> mixer -> residual -> GPAS) ---
        mixer_input = self.lns_attn(self.input_layernorm(hidden_states))
        mixer_output = mixer_input  # unchanged if layer_type matches neither branch
        if self.layer_type == "linear_attention":
            mixer_output = self.linear_attn(
                hidden_states=mixer_input,
                attention_mask=attention_mask,
            )
        elif self.layer_type == "full_attention":
            mixer_output, _ = self.self_attn(
                hidden_states=mixer_input,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.gpas_attn(hidden_states + mixer_output)

        # --- feed-forward sub-block (pre-norm -> LNS -> MLP -> residual -> GPAS) ---
        mlp_input = self.lns_mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = self.gpas_mlp(hidden_states + self.mlp(mlp_input))

        return hidden_states
838
+
839
+
840
class NeoLLMPreTrainedModel(PreTrainedModel):
    """Base class hooking NeoLLM modules into the HF weight-init/save/load machinery."""

    config: NeoLLMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["NeoLLMDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _is_stateful = True

    def _init_weights(self, module):
        """Module-specific initialization layered on the default HF init."""
        super()._init_weights(module)
        if isinstance(module, NeoLLMGatedDeltaNet):
            # dt bias starts at 1 so softplus(a + dt_bias) begins well above 0.
            module.dt_bias.data.fill_(1.0)
            # A_log = log(U(0, 16)); the decay gate is g = -exp(A_log) * softplus(...).
            # NOTE(review): Mamba-style inits typically draw from U(1, 16); a draw
            # near 0 yields a very negative A_log (decay ~= 0) — confirm intended.
            module.A_log.data.uniform_(0, 16).log_()
        elif isinstance(module, GPAS):
            # Initialize GPAS alpha to 0 as per paper
            module.alpha.data.fill_(0.0)
        elif isinstance(module, FANLayer):
            # FANLayer initialization is handled within the class
            pass
860
+
861
+
862
class NeoLLMModel(NeoLLMPreTrainedModel):
    """Stack of NeoLLM decoder layers with token embeddings, rotary embeddings,
    and a final RMS norm.  Cache-free: always processes the full sequence."""

    def __init__(self, config: NeoLLMConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
        self.layers = nn.ModuleList(
            [NeoLLMDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
        )
        self.norm = NeoLLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if position_ids is None:
            position_ids = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)

        # Two mask variants: a causal mask for full-attention layers and a
        # padding-only mask for linear-attention layers.
        cache_position = position_ids.squeeze(0)
        full_attn_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=None,
            position_ids=position_ids,
        )
        linear_attn_mask = self._update_linear_attn_mask(attention_mask, cache_position)

        hidden_states = inputs_embeds
        # Rotary position embeddings are computed once and shared by all layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers[: self.config.num_hidden_layers]:
            mask = linear_attn_mask if layer.layer_type == "linear_attention" else full_attn_mask
            hidden_states = layer(
                hidden_states,
                position_embeddings=position_embeddings,
                attention_mask=mask,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=None,
        )

    def _update_linear_attn_mask(self, attention_mask, cache_position):
        """
        NOTE: Left-padding is used for linear attention mask.
        When every position is attended there is nothing to zero out,
        so the mask is dropped entirely.
        """
        if attention_mask is not None and torch.all(attention_mask == 1):
            return None
        return attention_mask
933
+
934
class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
    """Causal-LM head on top of NeoLLMModel.  Training uses the memory-efficient
    Cut Cross-Entropy (CCE) loss and skips materializing the full logits."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = NeoLLMModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    @torch.compiler.disable
    def _compute_cce_loss(self, hidden_states, labels):
        """
        CCE loss computation excluded from compilation.
        Preprocesses labels to eliminate torch.compile warnings.
        """
        target = labels.to(hidden_states.device)

        # Map pad tokens to -100 so CCE masks them out of the loss.
        if self.config.pad_token_id is not None:
            ignore_index = torch.tensor(-100, dtype=target.dtype, device=target.device)
            target = torch.where(target == self.config.pad_token_id, ignore_index, target)

        return linear_cross_entropy(
            hidden_states,
            self.lm_head.weight,
            target,
            bias=getattr(self.lm_head, 'bias', None),
            shift=1,
            impl="cce",
            reduction="mean"
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        """
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state

        loss = None
        logits = None
        if labels is not None:
            # Training path: CCE computes the loss without materializing logits.
            loss = self._compute_cce_loss(hidden_states, labels)
        else:
            # Inference path: project only the requested trailing positions.
            keep = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
            logits = self.lm_head(hidden_states[:, keep, :])

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
1018
+
1019
+
1020
__all__ = [
    "NeoLLMForCausalLM",
    "NeoLLMModel",
    "NeoLLMPreTrainedModel",
    "NeoLLMConfig",
    "FANLayer",
]

# ==================== AUTOMODEL REGISTRATION ====================
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

# Register with the Auto* factories so AutoModel.from_pretrained(...) resolves
# the "neollm" model_type.  exist_ok=True makes re-importing this module
# (notebook reloads, trust_remote_code loading it twice) idempotent instead of
# raising "model type ... is already used".
AutoConfig.register("neollm", NeoLLMConfig, exist_ok=True)
AutoModel.register(NeoLLMConfig, NeoLLMModel, exist_ok=True)
AutoModelForCausalLM.register(NeoLLMConfig, NeoLLMForCausalLM, exist_ok=True)