KitsuVp commited on
Commit
f38e7cd
·
verified ·
1 Parent(s): a76060b

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +370 -743
modeling_neollm.py CHANGED
@@ -1,8 +1,8 @@
1
  #!/usr/bin/env python3
2
  """
3
  NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
4
- SeeDNorm (Self-Rescaled Dynamic Normalization), and ResFormer Value Residual Learning
5
- for enhanced information flow through deep layers.
6
 
7
  Updated to include:
8
  - Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
@@ -10,7 +10,8 @@ Updated to include:
10
  - SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
11
  - Dropout regularization at strategic locations
12
  - ResFormer: Feature residual connections from first layer (applied before projections)
13
- - PoPE (Polar Coordinate Position Embedding): Decouples 'what' and 'where' for superior length extrapolation
 
14
  """
15
 
16
  import math
@@ -27,33 +28,130 @@ from transformers.masking_utils import create_causal_mask
27
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
28
  from transformers.modeling_layers import GradientCheckpointingLayer
29
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
30
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
31
  from transformers.processing_utils import Unpack
32
  from transformers.utils import TransformersKwargs, logging
33
  from transformers.utils.generic import check_model_inputs
34
- from transformers.utils.import_utils import (
35
- is_causal_conv1d_available,
36
- is_flash_linear_attention_available,
37
- )
38
- from .configuration_neollm import NeoLLMConfig
39
-
40
-
41
- if is_causal_conv1d_available():
42
- from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
43
- else:
44
- causal_conv1d_update, causal_conv1d_fn = None, None
45
-
46
- if is_flash_linear_attention_available():
47
- from fla.modules import FusedRMSNormGated
48
- from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
49
- else:
50
- chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
51
- FusedRMSNormGated = None
52
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
53
 
54
  logger = logging.get_logger(__name__)
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  class FANLayer(nn.Module):
58
  """
59
  Fourier Analysis Network (FAN) layer for effective periodicity modeling.
@@ -217,200 +315,67 @@ class SeeDNorm(nn.Module):
217
  return f"dim={self.dim}, eps={self.eps}"
218
 
219
 
220
- class NeoLLMRMSNormGated(nn.Module):
221
- """
222
- Gated RMSNorm variant used in specific contexts.
223
- """
224
- def __init__(self, hidden_size, eps=1e-6, **kwargs):
225
- super().__init__()
226
- self.weight = nn.Parameter(torch.ones(hidden_size))
227
- self.variance_epsilon = eps
228
 
229
- def forward(self, hidden_states, gate=None):
230
- input_dtype = hidden_states.dtype
231
- hidden_states = hidden_states.to(torch.float32)
232
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
233
- # Norm before gate
234
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
235
- hidden_states = self.weight * hidden_states.to(input_dtype)
236
- hidden_states = hidden_states * F.silu(gate.to(torch.float32))
 
237
 
238
- return hidden_states.to(input_dtype)
 
239
 
240
- class PolarPositionalEmbedding(nn.Module):
241
- """
242
- Polar Coordinate Position Embedding (PoPE) - FlashAttention2-compatible implementation
243
-
244
- From "Decoupling the 'What' and 'Where' with Polar Coordinate Positional Embedding":
245
-
246
- THEORETICAL FORMULATION (from paper):
247
- - Magnitudes: μ_q̃tc = softplus(qtc), μ_k̃sc = softplus(ksc) (content only)
248
- - Phases: φ_q̃tc = t*θc, φ_k̃sc = s*θc (position only)
249
- - Attention score: a^PoPE_ts = Re[q̃^H @ k̃] = Σ (x_q * x_k + y_q * y_k)
250
-
251
- Where x = μ*cos(φ), y = μ*sin(φ) are Cartesian coordinates.
252
-
253
- PRACTICAL IMPLEMENTATION (this code):
254
- To enable FlashAttention2 compatibility without custom kernels, we use the
255
- mathematically equivalent formulation:
256
-
257
- Q' = [x_q; y_q] ∈ ℝ^(2d) (concatenation of real and imaginary parts)
258
- K' = [x_k; y_k] ∈ ℝ^(2d)
259
-
260
- This doubles head_dim (d → 2d) but allows:
261
- - Standard FlashAttention2 kernel usage
262
- - Q'·K' = Σ(x_q*x_k + y_q*y_k) = a^PoPE_ts (mathematically equivalent)
263
- - ~2× overhead in attention computation (acceptable tradeoff vs custom kernels)
264
-
265
- Benefits retained:
266
- - Superior length extrapolation without fine-tuning
267
- - Decoupled 'what' and 'where' information
268
- - Better performance on content/position independent matching tasks
269
-
270
- Args:
271
- dim: Original dimension per attention head (will be doubled to 2d internally)
272
- max_position_embeddings: Maximum sequence length
273
- base: Base wavelength (theta) for frequency components
274
- device: Device to place tensors on
275
- """
276
-
277
- def __init__(
278
- self,
279
- dim: int,
280
- max_position_embeddings: int = 2048,
281
- base: float = 10000.0,
282
- device=None
283
- ):
284
- super().__init__()
285
- self.dim = dim # Original head_dim (d)
286
- self.max_position_embeddings = max_position_embeddings
287
- self.base = base
288
-
289
- # Compute frequency components: θc = base^(-(c-1)/d) for c = 1, ..., d
290
- # PoPE uses d frequencies (not d/2 like RoPE)
291
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 1, dtype=torch.float32) / self.dim))
292
  self.register_buffer("inv_freq", inv_freq, persistent=False)
293
-
294
- def forward(
295
- self,
296
- q: torch.Tensor,
297
- k: torch.Tensor,
298
- position_ids: torch.LongTensor,
299
- ) -> tuple[torch.Tensor, torch.Tensor]:
300
- """
301
- Apply PoPE transformation with concatenation for FlashAttention2 compatibility.
302
-
303
- Args:
304
- q: Query tensor of shape (batch, num_heads, seq_len, head_dim)
305
- k: Key tensor of shape (batch, num_kv_heads, seq_len, head_dim)
306
- position_ids: Position indices of shape (batch, seq_len)
307
-
308
- Returns:
309
- Tuple of (Q', K') with doubled head_dim:
310
- - Q': shape (batch, num_heads, seq_len, 2*head_dim) = [x_q; y_q]
311
- - K': shape (batch, num_kv_heads, seq_len, 2*head_dim) = [x_k; y_k]
312
- """
313
- # Step 1: Apply softplus to get magnitudes (Equation 3 from paper)
314
- # μ_q̃tc = softplus(qtc), μ_k̃sc = softplus(ksc)
315
- mu_q = F.softplus(q)
316
- mu_k = F.softplus(k)
317
-
318
- # Step 2: Compute phase angles (Equation 4 from paper)
319
- # φ_q̃tc = t*θc, φ_k̃sc = s*θc
320
- # freqs shape: (batch, 1, seq_len, head_dim)
321
- inv_freq_expanded = self.inv_freq[None, None, None, :].to(q.device)
322
- position_ids_expanded = position_ids[:, None, :, None].float()
323
- freqs = position_ids_expanded * inv_freq_expanded
324
-
325
- # Step 3: Convert to Cartesian coordinates (Equations 7-8 from paper)
326
- # x = μ * cos(φ), y = μ * sin(φ)
327
- # Note: Compute trigonometric functions in float32 for precision, then convert
328
- # to input dtype (fp8/fp16/bf16) to maintain efficiency in subsequent operations
329
- cos_freqs = torch.cos(freqs).to(q.dtype)
330
- sin_freqs = torch.sin(freqs).to(q.dtype)
331
-
332
- q_real = mu_q * cos_freqs # x_q component
333
- q_imag = mu_q * sin_freqs # y_q component
334
- k_real = mu_k * cos_freqs # x_k component
335
- k_imag = mu_k * sin_freqs # y_k component
336
-
337
- # Step 4: Concatenate [real; imag] to create 2d dimensional vectors
338
- # This enables Q'·K' = Σ(x_q*x_k + y_q*y_k) via standard dot product
339
- q_pope = torch.cat([q_real, q_imag], dim=-1) # (batch, num_heads, seq_len, 2*head_dim)
340
- k_pope = torch.cat([k_real, k_imag], dim=-1) # (batch, num_kv_heads, seq_len, 2*head_dim)
341
-
342
- return q_pope, k_pope
343
-
344
- def apply_pope_embedding(
345
- q_pope: torch.Tensor,
346
- k_pope: torch.Tensor,
347
- delta_bias: Optional[torch.Tensor] = None,
348
- num_key_value_groups: int = 1
349
- ) -> tuple[torch.Tensor, torch.Tensor]:
350
- """
351
- Apply learnable phase bias δc to PoPE embeddings (Equation 6 from paper).
352
-
353
- With phase bias: a^PoPE_ts = Σ μ_q μ_k cos((s-t)θc + δc)
354
-
355
- This is implemented by rotating k by exp(i*δ) in the concatenated representation.
356
-
357
- Args:
358
- q_pope: Query with PoPE applied, shape (batch, num_heads, seq_len, 2*head_dim)
359
- Format: [x_q; y_q] where first head_dim is real, second head_dim is imaginary
360
- k_pope: Key with PoPE applied, shape (batch, num_kv_heads, seq_len, 2*head_dim)
361
- Format: [x_k; y_k]
362
- delta_bias: Learnable phase bias per head/dim, shape (num_attention_heads, head_dim)
363
- Bounded to [-2π, 0] as per paper. Applied only to keys.
364
- num_key_value_groups: Number of query groups per key/value head for GQA
365
-
366
- Returns:
367
- Tuple of (q_out, k_out) with delta_bias applied:
368
- - q_out: Query unchanged (phase bias only affects keys)
369
- - k_out: Key rotated by delta_bias
370
- Both maintain shape with 2*head_dim
371
- """
372
- # Query passes through unchanged (phase bias only affects keys)
373
- q_out = q_pope
374
-
375
- # Apply learnable phase bias to key if provided
376
- if delta_bias is not None:
377
- # Get head_dim (original dimension, half of current last dim)
378
- head_dim = k_pope.shape[-1] // 2
379
-
380
- # Split k into real and imaginary components
381
- k_real, k_imag = k_pope[..., :head_dim], k_pope[..., head_dim:]
382
-
383
- # Clamp delta_bias to [-2π, 0] as specified in paper Section 3
384
- delta_clamped = torch.clamp(delta_bias, min=-2*math.pi, max=0)
385
-
386
- # Adapt delta_bias for GQA: (num_attention_heads, head_dim) -> (num_kv_heads, head_dim)
387
- # Group the attention heads' biases by averaging/selecting
388
- if num_key_value_groups > 1:
389
- # Reshape: (num_attention_heads, head_dim) -> (num_kv_heads, num_key_value_groups, head_dim)
390
- num_kv_heads = delta_clamped.shape[0] // num_key_value_groups
391
- delta_clamped = delta_clamped.view(num_kv_heads, num_key_value_groups, head_dim)
392
- # Average across the groups to get one bias per kv_head
393
- delta_clamped = delta_clamped.mean(dim=1) # (num_kv_heads, head_dim)
394
-
395
- # Reshape for broadcasting: (num_kv_heads, head_dim) -> (1, num_kv_heads, 1, head_dim)
396
- delta_clamped = delta_clamped.unsqueeze(0).unsqueeze(2)
397
-
398
- # Compute rotation components: exp(i*δ) = cos(δ) + i*sin(δ)
399
- cos_delta = torch.cos(delta_clamped)
400
- sin_delta = torch.sin(delta_clamped)
401
-
402
- # Apply complex multiplication: k * exp(i*δ)
403
- # Real part: k_real*cos(δ) - k_imag*sin(δ)
404
- # Imag part: k_real*sin(δ) + k_imag*cos(δ)
405
- k_real_rotated = k_real * cos_delta - k_imag * sin_delta
406
- k_imag_rotated = k_real * sin_delta + k_imag * cos_delta
407
-
408
- # Recombine into concatenated form [real; imag]
409
- k_out = torch.cat([k_real_rotated, k_imag_rotated], dim=-1)
410
- else:
411
- k_out = k_pope
412
-
413
- return q_out, k_out
414
 
415
 
416
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -435,18 +400,10 @@ def eager_attention_forward(
435
  dropout: float = 0.0,
436
  **kwargs: Unpack[TransformersKwargs],
437
  ):
438
- """
439
- Standard eager attention implementation for PoPE.
440
-
441
- Note: query and key have 2*head_dim due to PoPE concatenation [real; imag].
442
- Value is padded to match this dimension for kernel compatibility.
443
- """
444
  key_states = repeat_kv(key, module.num_key_value_groups)
445
  value_states = repeat_kv(value, module.num_key_value_groups)
446
 
447
- # Standard attention computation
448
  attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
449
-
450
  if attention_mask is not None:
451
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
452
  attn_weights = attn_weights + causal_mask
@@ -462,14 +419,16 @@ def eager_attention_forward(
462
  class NeoLLMAttention(nn.Module):
463
  """
464
  Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
465
- PoPE for positional encoding, and ResFormer feature residual connections.
 
466
 
467
  ResFormer enhancement: Applies learnable feature residual connections from the first layer
468
  BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
469
 
470
- PoPE enhancement: Decouples 'what' and 'where' via polar coordinates for superior
471
- length extrapolation and content/position independent matching. Uses concatenated
472
- [real; imag] representation for FlashAttention2 compatibility ( head_dim overhead).
 
473
  """
474
 
475
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
@@ -477,11 +436,7 @@ class NeoLLMAttention(nn.Module):
477
  self.config = config
478
  self.layer_idx = layer_idx
479
  self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
480
- self.num_attention_heads = config.num_attention_heads
481
- self.num_key_value_heads = config.num_key_value_heads
482
- self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
483
-
484
- # PoPE uses original head_dim for scaling (not 2*head_dim)
485
  self.scaling = self.head_dim**-0.5
486
  self.attention_dropout = config.attention_dropout
487
  self.is_causal = True
@@ -495,36 +450,36 @@ class NeoLLMAttention(nn.Module):
495
  # Calculate the output dimension after FAN transformation
496
  fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.125))
497
 
498
- # QKV projections operate on FAN-transformed features
499
- self.q_proj = nn.Linear(
500
- fan_output_dim, self.num_attention_heads * self.head_dim * 2, bias=config.attention_bias
 
 
 
 
501
  )
 
 
502
  self.k_proj = nn.Linear(
503
- fan_output_dim, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
504
  )
505
  self.v_proj = nn.Linear(
506
- fan_output_dim, self.num_key_value_heads * self.head_dim, bias=config.attention_bias
507
  )
508
- self.o_proj = nn.Linear(
509
- self.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
 
 
 
 
 
 
510
  )
511
 
512
  # SeeDNorm for Q/K normalization (replaces RMSNorm)
513
  self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
514
  self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
515
 
516
- # PoPE: Learnable phase bias δc for each head and dimension
517
- # Initialized based on pope_bias_init config: 'zero' or 'uniform'
518
- pope_bias_init = getattr(config, 'pope_bias_init', 'zero')
519
- if pope_bias_init == 'uniform':
520
- # Uniform initialization in [-2π, 0]
521
- delta_init = torch.empty(self.num_attention_heads, self.head_dim).uniform_(-2 * math.pi, 0)
522
- else:
523
- # Zero initialization (better for length extrapolation)
524
- delta_init = torch.zeros(self.num_attention_heads, self.head_dim)
525
-
526
- self.delta_bias = nn.Parameter(delta_init)
527
-
528
  # Dropout for attention output
529
  self.dropout = nn.Dropout(config.dropout_rate)
530
 
@@ -541,61 +496,35 @@ class NeoLLMAttention(nn.Module):
541
  **kwargs: Unpack[FlashAttentionKwargs],
542
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
543
  input_shape = hidden_states.shape[:-1]
544
- batch_size, seq_len = input_shape
545
 
546
  # Apply FANformer transformation first
547
  hidden_states_fan = self.fan_layer(hidden_states)
548
 
549
  # ResFormer: Apply feature residual connection BEFORE projections
 
550
  if first_layer_fan is not None:
551
  hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
552
 
553
  # Store current FAN features for potential use as first_layer_fan in subsequent layers
554
  current_layer_fan = hidden_states_fan.clone()
555
 
556
- # Project to Q, K, V
 
 
 
557
  query_states, gate = torch.chunk(
558
- self.q_proj(hidden_states_fan).view(batch_size, seq_len, self.num_attention_heads, self.head_dim * 2),
559
- 2, dim=-1
560
- )
561
- gate = gate.reshape(batch_size, seq_len, -1)
562
-
563
- key_states = self.k_proj(hidden_states_fan).view(
564
- batch_size, seq_len, self.num_key_value_heads, self.head_dim
565
- )
566
- value_states = self.v_proj(hidden_states_fan).view(
567
- batch_size, seq_len, self.num_key_value_heads, self.head_dim
568
  )
569
-
570
- # Apply SeeDNorm to Q and K before PoPE
571
- query_states = self.q_norm(query_states)
572
- key_states = self.k_norm(key_states)
573
-
574
- # Transpose to (batch, num_heads, seq_len, head_dim)
575
- query_states = query_states.transpose(1, 2)
576
- key_states = key_states.transpose(1, 2)
577
- value_states = value_states.transpose(1, 2)
578
-
579
- # Apply PoPE: position_embeddings is (pope_emb, position_ids)
580
- pope_emb, position_ids = position_embeddings
581
-
582
- # Get PoPE embeddings with concatenated [real; imag] representation
583
- # Returns Q', K' with shape (..., 2*head_dim)
584
- query_states, key_states = pope_emb(query_states, key_states, position_ids)
585
-
586
- # Apply learnable phase bias δc
587
- # Apply learnable phase bias δc
588
- query_states, key_states = apply_pope_embedding(
589
- query_states,
590
- key_states,
591
- self.delta_bias,
592
- num_key_value_groups=self.num_key_value_groups # AGREGAR ESTE PARÁMETRO
593
- )
594
- # Pad value to 2*head_dim for dimension compatibility
595
- # Only first head_dim components are used in output
596
- value_states = F.pad(value_states, (0, self.head_dim), value=0.0)
597
-
598
- # Call attention with doubled head_dim
599
  attention_interface: Callable = eager_attention_forward
600
  if self.config._attn_implementation != "eager":
601
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
@@ -611,391 +540,16 @@ class NeoLLMAttention(nn.Module):
611
  **kwargs,
612
  )
613
 
614
- # Extract only the first head_dim components (discard padding)
615
- attn_output = attn_output[..., :self.head_dim]
616
-
617
- attn_output = attn_output.reshape(batch_size, seq_len, -1).contiguous()
618
  attn_output = attn_output * torch.sigmoid(gate)
619
 
 
620
  attn_output = self.o_proj(attn_output)
621
  attn_output = self.dropout(attn_output)
622
 
623
  return attn_output, attn_weights, current_layer_fan
624
 
625
 
626
- def apply_mask_to_padding_states(hidden_states, attention_mask):
627
- """
628
- Tunes out the hidden states for padding tokens
629
- """
630
- if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
631
- dtype = hidden_states.dtype
632
- hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
633
-
634
- return hidden_states
635
-
636
-
637
- is_fast_path_available = all(
638
- (causal_conv1d_fn, causal_conv1d_update, chunk_gated_delta_rule, fused_recurrent_gated_delta_rule)
639
- )
640
-
641
-
642
- def torch_causal_conv1d_update(
643
- hidden_states,
644
- conv_state,
645
- weight,
646
- bias=None,
647
- activation=None,
648
- ):
649
- _, hidden_size, seq_len = hidden_states.shape
650
- state_len = conv_state.shape[-1]
651
-
652
- hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
653
- conv_state.copy_(hidden_states_new[:, :, -state_len:])
654
- out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
655
- out = F.silu(out[:, :, -seq_len:])
656
- out = out.to(hidden_states.dtype)
657
- return out
658
-
659
-
660
- def l2norm(x: torch.FloatTensor, dim: int = -1, eps: float = 1e-6):
661
- """This function is intended to align with the l2norm implementation in the FLA library."""
662
- inv_norm = 1 / torch.sqrt((x * x).sum(dim=dim, keepdim=True) + eps)
663
- return x * inv_norm
664
-
665
-
666
- def torch_chunk_gated_delta_rule(
667
- query,
668
- key,
669
- value,
670
- g,
671
- beta,
672
- chunk_size=64,
673
- initial_state=None,
674
- output_final_state=False,
675
- use_qk_l2norm_in_kernel=False,
676
- ):
677
- initial_dtype = query.dtype
678
- if use_qk_l2norm_in_kernel:
679
- query = l2norm(query, dim=-1, eps=1e-6)
680
- key = l2norm(key, dim=-1, eps=1e-6)
681
- query, key, value, beta, g = [
682
- x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
683
- ]
684
-
685
- batch_size, sequence_length, num_heads, k_head_dim = key.shape
686
- v_head_dim = value.shape[-1]
687
- pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
688
- query = F.pad(query, (0, 0, 0, pad_size))
689
- key = F.pad(key, (0, 0, 0, pad_size))
690
- value = F.pad(value, (0, 0, 0, pad_size))
691
- beta = F.pad(beta, (0, pad_size))
692
- g = F.pad(g, (0, pad_size))
693
- tot_heads = num_heads + pad_size
694
- scale = 1 / (query.shape[-1] ** 0.5)
695
- query = query * scale
696
-
697
- v_beta = value * beta.unsqueeze(-1)
698
- k_beta = key * beta.unsqueeze(-1)
699
- # reshape to chunks
700
- query, key, value, k_beta, v_beta = [
701
- x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
702
- ]
703
- g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
704
- mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
705
-
706
- # chunk decay
707
- g = g.cumsum(dim=-1)
708
- decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
709
- attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
710
- for i in range(1, chunk_size):
711
- row = attn[..., i, :i].clone()
712
- sub = attn[..., :i, :i].clone()
713
- attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
714
- attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
715
- value = attn @ v_beta
716
- k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
717
- last_recurrent_state = (
718
- torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
719
- if initial_state is None
720
- else initial_state.to(value)
721
- )
722
- core_attn_out = torch.zeros_like(value)
723
- mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
724
-
725
- # for each chunk
726
- for i in range(0, tot_heads // chunk_size):
727
- q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
728
- attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
729
- v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
730
- v_new = v_i - v_prime
731
- attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
732
- core_attn_out[:, :, i] = attn_inter + attn @ v_new
733
- last_recurrent_state = (
734
- last_recurrent_state * g[:, :, i, -1, None, None].exp()
735
- + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
736
- )
737
-
738
- if not output_final_state:
739
- last_recurrent_state = None
740
- core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
741
- core_attn_out = core_attn_out[:, :, :num_heads]
742
- core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
743
- return core_attn_out, last_recurrent_state
744
-
745
-
746
- def torch_recurrent_gated_delta_rule(
747
- query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
748
- ):
749
- initial_dtype = query.dtype
750
- if use_qk_l2norm_in_kernel:
751
- query = l2norm(query, dim=-1, eps=1e-6)
752
- key = l2norm(key, dim=-1, eps=1e-6)
753
- query, key, value, beta, g = [
754
- x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
755
- ]
756
-
757
- batch_size, sequence_length, num_heads, k_head_dim = key.shape
758
- v_head_dim = value.shape[-1]
759
- scale = 1 / (query.shape[-1] ** 0.5)
760
- query = query * scale
761
-
762
- core_attn_out = torch.zeros(batch_size, sequence_length, num_heads, v_head_dim).to(value)
763
- last_recurrent_state = (
764
- torch.zeros(batch_size, sequence_length, k_head_dim, v_head_dim).to(value)
765
- if initial_state is None
766
- else initial_state.to(value)
767
- )
768
-
769
- for i in range(num_heads):
770
- q_t = query[:, :, i]
771
- k_t = key[:, :, i]
772
- v_t = value[:, :, i]
773
- g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
774
- beta_t = beta[:, :, i].unsqueeze(-1)
775
-
776
- last_recurrent_state = last_recurrent_state * g_t
777
- kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
778
- delta = (v_t - kv_mem) * beta_t
779
- last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
780
- core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
781
-
782
- if not output_final_state:
783
- last_recurrent_state = None
784
- core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
785
- return core_attn_out, last_recurrent_state
786
-
787
-
788
- class NeoLLMGatedDeltaNet(nn.Module):
789
- """
790
- Linear attention with FANformer integration, SeeDNorm for normalization,
791
- and ResFormer feature residual connections for enhanced information flow.
792
-
793
- ResFormer enhancement: Applies learnable feature residual connections from the first layer
794
- BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
795
- """
796
-
797
- def __init__(self, config: NeoLLMConfig, layer_idx: int):
798
- super().__init__()
799
- self.hidden_size = config.hidden_size
800
- self.num_v_heads = config.linear_num_value_heads
801
- self.num_k_heads = config.linear_num_key_heads
802
- self.head_k_dim = config.linear_key_head_dim
803
- self.head_v_dim = config.linear_value_head_dim
804
- self.key_dim = self.head_k_dim * self.num_k_heads
805
- self.value_dim = self.head_v_dim * self.num_v_heads
806
-
807
- self.conv_kernel_size = config.linear_conv_kernel_dim
808
- self.layer_idx = layer_idx
809
- self.activation = config.hidden_act
810
- self.act = ACT2FN[config.hidden_act]
811
- self.layer_norm_epsilon = config.rms_norm_eps
812
-
813
- # FANformer integration: FAN layer before projections
814
- self.fan_layer = FANLayer(
815
- hidden_size=config.hidden_size,
816
- fan_ratio=getattr(config, 'fan_ratio', 0.125)
817
- )
818
-
819
- # Calculate the output dimension after FAN transformation
820
- fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.125))
821
-
822
- # QKV - operates on FAN-transformed features
823
- self.conv_dim = self.key_dim * 2 + self.value_dim
824
- self.conv1d = nn.Conv1d(
825
- in_channels=self.conv_dim,
826
- out_channels=self.conv_dim,
827
- bias=False,
828
- kernel_size=self.conv_kernel_size,
829
- groups=self.conv_dim,
830
- padding=self.conv_kernel_size - 1,
831
- )
832
-
833
- # projection of the FAN-transformed hidden states
834
- projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2
835
- projection_size_ba = self.num_v_heads * 2
836
- self.in_proj_qkvz = nn.Linear(fan_output_dim, projection_size_qkvz, bias=False)
837
- self.in_proj_ba = nn.Linear(fan_output_dim, projection_size_ba, bias=False)
838
-
839
- # time step projection (discretization)
840
- self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
841
-
842
- A = torch.empty(self.num_v_heads).uniform_(0, 16)
843
- self.A_log = nn.Parameter(torch.log(A))
844
-
845
- # FLA compatibility: use "silu" for FusedRMSNormGated, original activation elsewhere
846
- fla_compatible_activation = "silu" if self.activation not in ['swish', 'silu', 'sigmoid'] else self.activation
847
-
848
- self.norm = (
849
- NeoLLMRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
850
- if FusedRMSNormGated is None
851
- else FusedRMSNormGated(
852
- self.head_v_dim,
853
- eps=self.layer_norm_epsilon,
854
- activation=fla_compatible_activation,
855
- device=torch.cuda.current_device(),
856
- dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
857
- )
858
- )
859
-
860
- self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
861
-
862
- # Dropout for attention output
863
- self.dropout = nn.Dropout(config.dropout_rate)
864
-
865
- self.causal_conv1d_fn = causal_conv1d_fn
866
- self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
867
- self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
868
- self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
869
-
870
- # ResFormer: learnable feature residual parameters (initialized to 0.5)
871
- self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 (first layer features)
872
- self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n (current layer features)
873
-
874
- if not is_fast_path_available:
875
- logger.warning_once(
876
- "The fast path is not available because one of the required library is not installed. Falling back to "
877
- "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
878
- " https://github.com/Dao-AILab/causal-conv1d"
879
- )
880
-
881
- def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba):
882
- """
883
- Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`.
884
- """
885
- new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + (
886
- self.num_k_heads,
887
- 2 * self.head_k_dim + 2 * self.head_v_dim * self.num_v_heads // self.num_k_heads,
888
- )
889
- new_tensor_shape_ba = mixed_ba.size()[:-1] + (self.num_k_heads, 2 * self.num_v_heads // self.num_k_heads)
890
-
891
- mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz)
892
- mixed_ba = mixed_ba.view(*new_tensor_shape_ba)
893
- split_arg_list_qkvz = [
894
- self.head_k_dim,
895
- self.head_k_dim,
896
- (self.num_v_heads // self.num_k_heads * self.head_v_dim),
897
- (self.num_v_heads // self.num_k_heads * self.head_v_dim),
898
- ]
899
- split_arg_list_ba = [self.num_v_heads // self.num_k_heads, self.num_v_heads // self.num_k_heads]
900
- query, key, value, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=3)
901
- b, a = torch.split(mixed_ba, split_arg_list_ba, dim=3)
902
- # [b, sq, ng, np/ng * hn] -> [b, sq, np, hn]
903
- value = value.reshape(value.size(0), value.size(1), -1, self.head_v_dim)
904
- z = z.reshape(z.size(0), z.size(1), -1, self.head_v_dim)
905
- b = b.reshape(b.size(0), b.size(1), self.num_v_heads)
906
- a = a.reshape(a.size(0), a.size(1), self.num_v_heads)
907
- return query, key, value, z, b, a
908
-
909
- def forward(
910
- self,
911
- hidden_states: torch.Tensor,
912
- attention_mask: Optional[torch.Tensor] = None,
913
- first_layer_fan: Optional[torch.Tensor] = None,
914
- ) -> tuple[torch.Tensor, torch.Tensor]:
915
- hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
916
-
917
- # Set up dimensions for reshapes later
918
- batch_size, seq_len, _ = hidden_states.shape
919
-
920
- # Apply FANformer transformation first
921
- hidden_states_fan = self.fan_layer(hidden_states)
922
-
923
- # ResFormer: Apply feature residual connection BEFORE projections
924
- # This ensures dimensional compatibility across all layer types
925
- if first_layer_fan is not None:
926
- hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
927
-
928
- # Store current FAN features for potential use as first_layer_fan in subsequent layers
929
- current_layer_fan = hidden_states_fan.clone()
930
-
931
- # Use FAN-transformed features (with residual applied) for projections
932
- projected_states_qkvz = self.in_proj_qkvz(hidden_states_fan)
933
- projected_states_ba = self.in_proj_ba(hidden_states_fan)
934
- query, key, value, z, b, a = self.fix_query_key_value_ordering(projected_states_qkvz, projected_states_ba)
935
- query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value))
936
-
937
- mixed_qkv = torch.cat((query, key, value), dim=-1)
938
- mixed_qkv = mixed_qkv.transpose(1, 2)
939
-
940
- # Simple convolution without cache
941
- if self.causal_conv1d_fn is not None:
942
- mixed_qkv = self.causal_conv1d_fn(
943
- x=mixed_qkv,
944
- weight=self.conv1d.weight.squeeze(1),
945
- bias=self.conv1d.bias,
946
- activation="silu", # Keep original activation for conv1d
947
- seq_idx=None,
948
- )
949
- else:
950
- mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
951
-
952
- mixed_qkv = mixed_qkv.transpose(1, 2)
953
- query, key, value = torch.split(
954
- mixed_qkv,
955
- [
956
- self.key_dim,
957
- self.key_dim,
958
- self.value_dim,
959
- ],
960
- dim=-1,
961
- )
962
- query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim)
963
- key = key.reshape(key.shape[0], key.shape[1], -1, self.head_k_dim)
964
- value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim)
965
-
966
- beta = b.sigmoid()
967
- # If the model is loaded in fp16, without the .float() here, A might be -inf
968
- g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
969
- if self.num_v_heads // self.num_k_heads > 1:
970
- query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
971
- key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
972
-
973
- # Use chunk-based implementation without cache
974
- core_attn_out, _ = self.chunk_gated_delta_rule(
975
- query,
976
- key,
977
- value,
978
- g=g,
979
- beta=beta,
980
- initial_state=None,
981
- output_final_state=False,
982
- use_qk_l2norm_in_kernel=True,
983
- )
984
-
985
- z_shape_og = z.shape
986
- # reshape input data into 2D tensor
987
- core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
988
- z = z.reshape(-1, z.shape[-1])
989
- core_attn_out = self.norm(core_attn_out, z)
990
- core_attn_out = core_attn_out.reshape(z_shape_og)
991
- core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1)
992
-
993
- output = self.out_proj(core_attn_out)
994
- output = self.dropout(output) # Apply dropout after output projection
995
-
996
- return output, current_layer_fan
997
-
998
-
999
  class PolyNorm(torch.nn.Module):
1000
  def __init__(self, eps=1e-6):
1001
  super(PolyNorm, self).__init__()
@@ -1012,11 +566,17 @@ class PolyNorm(torch.nn.Module):
1012
 
1013
  class NeoLLMMLP(nn.Module):
1014
  """
1015
- MLP with FANformer integration for featural periodicity modeling.
 
1016
 
1017
  This captures periodicities in the feature space (semantic/embedding dimensions)
1018
  complementary to the relational periodicities captured by attention mechanisms.
1019
  Works in conjunction with ResFormer for comprehensive information flow.
 
 
 
 
 
1020
  """
1021
  def __init__(self, config):
1022
  super().__init__()
@@ -1024,7 +584,7 @@ class NeoLLMMLP(nn.Module):
1024
  self.hidden_size = config.hidden_size
1025
  self.intermediate_size = config.intermediate_size
1026
 
1027
- # NEW: FANformer integration for featural space periodicity
1028
  self.fan_layer = FANLayer(
1029
  hidden_size=config.hidden_size,
1030
  fan_ratio=getattr(config, 'fan_ratio_ffn', 0.0625) # Half of attention's fan_ratio
@@ -1033,17 +593,35 @@ class NeoLLMMLP(nn.Module):
1033
  # Calculate the output dimension after FAN transformation
1034
  fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio_ffn', 0.0625))
1035
 
1036
- # SwiGLU/Gated architecture - now operates on FAN-transformed features
1037
- self.gate_proj = nn.Linear(fan_output_dim, self.intermediate_size, bias=False)
 
 
 
 
 
 
 
 
 
1038
  self.up_proj = nn.Linear(fan_output_dim, self.intermediate_size, bias=False)
1039
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 
 
 
 
 
 
 
 
 
1040
  self.act_fn = PolyNorm()
1041
 
1042
  # Dropout for MLP hidden layer
1043
  self.dropout = nn.Dropout(config.dropout_rate)
1044
 
1045
  def forward(self, x):
1046
- # NEW: Apply FAN transformation before projections
1047
  x_fan = self.fan_layer(x)
1048
 
1049
  # Use FAN-transformed features for gate and up projections
@@ -1055,19 +633,27 @@ class NeoLLMMLP(nn.Module):
1055
 
1056
 
1057
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
 
 
 
 
 
 
 
 
 
 
 
 
1058
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
1059
  super().__init__()
1060
  self.hidden_size = config.hidden_size
1061
  self.layer_idx = layer_idx
1062
 
1063
- # token mixer
1064
- self.layer_type = config.layer_types[layer_idx]
1065
- if self.layer_type == "linear_attention":
1066
- self.linear_attn = NeoLLMGatedDeltaNet(config, layer_idx)
1067
- elif self.layer_type == "full_attention":
1068
- self.self_attn = NeoLLMAttention(config, layer_idx)
1069
 
1070
- # MLP with FANformer integration
1071
  self.mlp = NeoLLMMLP(config)
1072
 
1073
  # SeeDNorm for input and post-attention normalization (replaces RMSNorm)
@@ -1093,6 +679,9 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
1093
  first_layer_fan: Optional[torch.Tensor] = None,
1094
  **kwargs: Unpack[FlashAttentionKwargs],
1095
  ) -> torch.FloatTensor:
 
 
 
1096
  residual = hidden_states
1097
 
1098
  # Apply SeeDNorm normalization
@@ -1101,22 +690,14 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
1101
  # Apply LNS scaling after normalization
1102
  hidden_states = self.lns_attn(hidden_states)
1103
 
1104
- # Token Mixer with ResFormer feature residual connections
1105
- if self.layer_type == "linear_attention":
1106
- hidden_states, self.current_layer_fan = self.linear_attn(
1107
- hidden_states=hidden_states,
1108
- attention_mask=attention_mask,
1109
- first_layer_fan=first_layer_fan,
1110
- )
1111
- elif self.layer_type == "full_attention":
1112
- # Self Attention
1113
- hidden_states, _, self.current_layer_fan = self.self_attn(
1114
- hidden_states=hidden_states,
1115
- attention_mask=attention_mask,
1116
- position_embeddings=position_embeddings,
1117
- first_layer_fan=first_layer_fan,
1118
- **kwargs,
1119
- )
1120
 
1121
  # Standard residual connection
1122
  hidden_states = residual + hidden_states
@@ -1124,14 +705,16 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
1124
  # Apply GPAS after attention residual connection
1125
  hidden_states = self.gpas_attn(hidden_states)
1126
 
1127
- # Fully Connected with FANformer
 
 
1128
  residual = hidden_states
1129
  hidden_states = self.post_attention_layernorm(hidden_states)
1130
 
1131
  # Apply LNS scaling after normalization
1132
  hidden_states = self.lns_mlp(hidden_states)
1133
 
1134
- # MLP now includes FAN transformation internally
1135
  hidden_states = self.mlp(hidden_states)
1136
 
1137
  # Standard residual connection
@@ -1144,6 +727,16 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
1144
 
1145
 
1146
  class NeoLLMPreTrainedModel(PreTrainedModel):
 
 
 
 
 
 
 
 
 
 
1147
  config: NeoLLMConfig
1148
  base_model_prefix = "model"
1149
  supports_gradient_checkpointing = True
@@ -1153,59 +746,88 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
1153
  _is_stateful = True
1154
 
1155
  def _init_weights(self, module):
 
 
 
 
 
 
 
 
1156
  super()._init_weights(module)
1157
- if isinstance(module, NeoLLMGatedDeltaNet):
1158
- module.dt_bias.data.fill_(1.0)
1159
- module.A_log.data.uniform_(0, 16).log_()
1160
- # ResFormer: initialize lambda parameters for linear attention
1161
- if hasattr(module, 'lambda_1'):
1162
- module.lambda_1.data.fill_(0.5)
1163
- if hasattr(module, 'lambda_2'):
1164
- module.lambda_2.data.fill_(0.5)
1165
- elif isinstance(module, NeoLLMAttention):
1166
  # ResFormer: initialize lambda parameters for full attention
 
 
1167
  if hasattr(module, 'lambda_1'):
1168
  module.lambda_1.data.fill_(0.5)
1169
  if hasattr(module, 'lambda_2'):
1170
  module.lambda_2.data.fill_(0.5)
1171
- # PoPE delta_bias already initialized in __init__
1172
  elif isinstance(module, GPAS):
1173
  # Initialize GPAS alpha to 0 as per paper
 
1174
  module.alpha.data.fill_(0.0)
 
1175
  elif isinstance(module, FANLayer):
1176
- # FANLayer initialization is handled within the class
 
1177
  pass
 
1178
  elif isinstance(module, SeeDNorm):
1179
- # SeeDNorm initialization:
1180
- # gamma (γ) initialized to 1 (default in Parameter definition)
1181
- # beta (β) initialized to 0 (default in Parameter definition)
1182
- # alpha (α) initialized to 1 (default in Parameter definition)
1183
  pass
1184
- elif isinstance(module, PolarPositionalEmbedding):
1185
- # PoPE frequency initialization handled in __init__
1186
- pass
1187
-
 
 
 
1188
 
1189
  class NeoLLMModel(NeoLLMPreTrainedModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1190
  def __init__(self, config: NeoLLMConfig):
1191
  super().__init__(config)
 
 
 
 
1192
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
1193
 
1194
  # Each layer creates its own components (no shared parameters)
1195
  self.layers = nn.ModuleList(
1196
  [NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
1197
  )
 
1198
  # SeeDNorm for final output normalization (replaces RMSNorm)
1199
  self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
1200
-
1201
- # PoPE positional embedding (replaces RoPE)
1202
- head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
1203
- self.pope_emb = PolarPositionalEmbedding(
1204
- dim=head_dim,
1205
- max_position_embeddings=config.max_position_embeddings,
1206
- base=getattr(config, 'rope_theta', 10000.0), # Use rope_theta for backward compatibility
1207
- )
1208
-
1209
  self.gradient_checkpointing = False
1210
 
1211
  # ResFormer: storage for first layer's FAN features (H_fan_1)
@@ -1226,6 +848,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1226
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1227
 
1228
  if inputs_embeds is None:
 
 
 
 
1229
  inputs_embeds = self.embed_tokens(input_ids)
1230
 
1231
  if position_ids is None:
@@ -1239,24 +865,20 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1239
  past_key_values=None,
1240
  position_ids=position_ids,
1241
  )
1242
- linear_attn_mask = self._update_linear_attn_mask(attention_mask, position_ids.squeeze(0))
1243
 
1244
  hidden_states = inputs_embeds
1245
 
1246
- # Create position embeddings for PoPE
1247
- # position_embeddings is a tuple of (pope_emb, position_ids)
1248
- position_embeddings = (self.pope_emb, position_ids)
1249
 
1250
  # ResFormer: reset first_layer_fan at the start of each forward pass
1251
  self.first_layer_fan = None
1252
 
1253
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
1254
- layer_mask = linear_attn_mask if decoder_layer.layer_type == "linear_attention" else causal_mask
1255
-
1256
  hidden_states = decoder_layer(
1257
  hidden_states,
1258
  position_embeddings=position_embeddings,
1259
- attention_mask=layer_mask,
1260
  first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
1261
  **kwargs,
1262
  )
@@ -1273,16 +895,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1273
  past_key_values=None,
1274
  )
1275
 
1276
- def _update_linear_attn_mask(self, attention_mask, cache_position):
1277
- """
1278
- NOTE: Left-padding is used for linear attention mask.
1279
- No need for zeroing states when attending to all inputs
1280
- """
1281
- linear_attn_mask = attention_mask
1282
- if attention_mask is not None and torch.all(attention_mask == 1):
1283
- linear_attn_mask = None
1284
- return linear_attn_mask
1285
-
1286
 
1287
  @torch.compiler.disable
1288
  def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
@@ -1313,13 +925,26 @@ def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, p
1313
 
1314
 
1315
  class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
 
 
 
 
 
 
 
 
 
1316
  _tied_weights_keys = ["lm_head.weight"]
1317
 
1318
  def __init__(self, config):
1319
  super().__init__(config)
1320
  self.model = NeoLLMModel(config)
1321
  self.vocab_size = config.vocab_size
 
 
 
1322
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
1323
  self.post_init()
1324
 
1325
  def forward(
@@ -1376,7 +1001,9 @@ __all__ = [
1376
  "NeoLLMConfig",
1377
  "FANLayer",
1378
  "SeeDNorm",
1379
- "PolarPositionalEmbedding",
 
 
1380
  ]
1381
 
1382
  # Register the configuration and model for AutoClass support
 
1
  #!/usr/bin/env python3
2
  """
3
  NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
4
+ SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning,
5
+ and Learnable Multipliers for enhanced scale adaptation and information flow through deep layers.
6
 
7
  Updated to include:
8
  - Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
 
10
  - SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
11
  - Dropout regularization at strategic locations
12
  - ResFormer: Feature residual connections from first layer (applied before projections)
13
+ - Learnable Multipliers: Frees weight matrix scale from WD-noise equilibrium for data-adaptive scaling
14
+ - Full Attention only (linear attention removed)
15
  """
16
 
17
  import math
 
28
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
29
  from transformers.modeling_layers import GradientCheckpointingLayer
30
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
31
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
32
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
33
  from transformers.processing_utils import Unpack
34
  from transformers.utils import TransformersKwargs, logging
35
  from transformers.utils.generic import check_model_inputs
36
+ from configuration_neollm import NeoLLMConfig
37
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
39
 
40
  logger = logging.get_logger(__name__)
41
 
42
 
43
+ # ==================== LEARNABLE MULTIPLIERS ====================
44
+
45
class ScalarMultiplier(nn.Module):
    """
    Scalar learnable multiplier: W̃ = s·W.

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix
    Layers": letting the effective norm ||W̃|| = s·||W|| adapt to the data
    frees it from the WD-noise equilibrium that pins ||W|| ∝ √(η/λ).

    Args:
        initial_value: Starting value of the multiplier (1.0 = identity map).
    """

    def __init__(self, initial_value: float = 1.0):
        super().__init__()
        # Single learnable scalar; broadcast over the entire input tensor.
        self.multiplier = nn.Parameter(torch.tensor(initial_value))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by the learned scalar."""
        return x * self.multiplier
62
+
63
+
64
class VectorMultiplier(nn.Module):
    """
    Vector learnable multipliers: W̃ = diag(r)·W·diag(c).

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix
    Layers": frees not only the overall matrix norm but also individual
    row/column norms from the WD-noise equilibrium, enabling richer feature
    scale diversity.

    Args:
        dim: Size of the multiplier vector (must match the last dim of inputs).
        multiplier_type: Either "row" or "column". Kept for introspection only;
            the row/column distinction lives in WHERE the caller applies this
            module (after vs. before the matmul), not in the math here.
        initial_value: Initial multiplier value (default: 1.0 = identity).
    """

    def __init__(self, dim: int, multiplier_type: str = "row", initial_value: float = 1.0):
        super().__init__()
        self.multiplier_type = multiplier_type
        self.multiplier = nn.Parameter(torch.full((dim,), float(initial_value)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scale the last dimension of ``x`` elementwise by the learned vector.

        Fix: the original branched on ``multiplier_type`` but both branches
        executed the identical expression; a single broadcasted multiply is
        equivalent and removes the dead conditional.
        """
        return x * self.multiplier
95
+
96
+
97
class LinearWithMultipliers(nn.Module):
    """
    Linear layer with optional row and/or column learnable multipliers.

    Computes ``y = (r ⊙ (W @ (c ⊙ x))) + b`` where ``r`` and ``c`` are
    learnable vectors and ``W`` is the base weight matrix.

    From "Learnable Multipliers: Freeing the Scale of Language Model Matrix
    Layers": the base matrix W stays under the WD-noise equilibrium with
    ||W|| ∝ √(η/λ), while r and c learn freely to adapt the effective scale.

    Args:
        in_features: Input feature dimension.
        out_features: Output feature dimension.
        bias: Whether the base linear layer has a bias term.
        use_row_multiplier: Enable the row (output-side) multiplier ``r``.
        use_column_multiplier: Enable the column (input-side) multiplier ``c``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        use_row_multiplier: bool = False,
        use_column_multiplier: bool = False
    ):
        super().__init__()
        self.use_row_multiplier = use_row_multiplier
        self.use_column_multiplier = use_column_multiplier

        # Base weight matrix — this part remains subject to weight decay.
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        # Multiplier submodules exist only when enabled (not subject to WD).
        if use_column_multiplier:
            self.column_multiplier = VectorMultiplier(in_features, multiplier_type="column")
        if use_row_multiplier:
            self.row_multiplier = VectorMultiplier(out_features, multiplier_type="row")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply column multiplier → base linear → row multiplier."""
        out = self.column_multiplier(x) if self.use_column_multiplier else x
        out = self.linear(out)
        if self.use_row_multiplier:
            out = self.row_multiplier(out)
        return out
151
+
152
+
153
+ # ==================== ORIGINAL COMPONENTS ====================
154
+
155
  class FANLayer(nn.Module):
156
  """
157
  Fourier Analysis Network (FAN) layer for effective periodicity modeling.
 
315
  return f"dim={self.dim}, eps={self.eps}"
316
 
317
 
318
class NeoLLMRotaryEmbedding(nn.Module):
    """
    Rotary position embedding (RoPE) table generator for NeoLLM attention.

    The inverse-frequency buffer is produced by the rope-type-specific init
    function looked up in ``ROPE_INIT_FUNCTIONS``; ``forward`` turns position
    ids into cos/sin tables that ``apply_rotary_pos_emb`` consumes.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: NeoLLMConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        # Rope-type-specific initializer (default, linear, dynamic, yarn, ...).
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: recomputed from config on load rather than stored.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic-rope updates can restore the untouched frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` tables for *position_ids*, cast to ``x.dtype``.

        ``x`` is used only for its device and dtype; positions are expected
        as a (batch, seq_len) integer tensor.
        """
        # (dim/2,) -> (batch, dim/2, 1) so it can matmul with positions.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # autocast is disabled below; "mps" lacks autocast so tag as "cpu".
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate frequencies so cos/sin cover the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
352
+
353
+
354
def rotate_half(x):
    """Map the two last-dim halves (a, b) of ``x`` to (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
359
+
360
+
361
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to the query and key tensors.

    Supports partial rotary: only the first ``cos.shape[-1]`` channels of q/k
    are rotated; any remaining channels pass through unchanged and are
    concatenated back, so the output shapes match the inputs.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _rotate(t):
        # Split into the rotated prefix and the untouched remainder.
        rotated, passthrough = t[..., :rotary_dim], t[..., rotary_dim:]
        rotated = (rotated * cos) + (rotate_half(rotated) * sin)
        return torch.cat([rotated, passthrough], dim=-1)

    return _rotate(q), _rotate(k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379
 
380
 
381
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
400
  dropout: float = 0.0,
401
  **kwargs: Unpack[TransformersKwargs],
402
  ):
 
 
 
 
 
 
403
  key_states = repeat_kv(key, module.num_key_value_groups)
404
  value_states = repeat_kv(value, module.num_key_value_groups)
405
 
 
406
  attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
 
407
  if attention_mask is not None:
408
  causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
409
  attn_weights = attn_weights + causal_mask
 
419
  class NeoLLMAttention(nn.Module):
420
  """
421
  Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
422
+ ResFormer feature residual connections, and Learnable Multipliers for enhanced
423
+ information flow and scale adaptation.
424
 
425
  ResFormer enhancement: Applies learnable feature residual connections from the first layer
426
  BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
427
 
428
+ Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C):
429
+ - Q projection: row multipliers only (enables per-head attention scaling in GQA)
430
+ - K, V projections: no multipliers (avoids redundancy with Q multipliers)
431
+ - Output projection: row + column multipliers (maximally expressive without symmetries)
432
  """
433
 
434
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
 
436
  self.config = config
437
  self.layer_idx = layer_idx
438
  self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
439
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
 
 
 
 
440
  self.scaling = self.head_dim**-0.5
441
  self.attention_dropout = config.attention_dropout
442
  self.is_causal = True
 
450
  # Calculate the output dimension after FAN transformation
451
  fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.125))
452
 
453
+ # Q projection with row multipliers (per-head scaling capability)
454
+ self.q_proj = LinearWithMultipliers(
455
+ fan_output_dim,
456
+ config.num_attention_heads * self.head_dim * 2,
457
+ bias=config.attention_bias,
458
+ use_row_multiplier=True,
459
+ use_column_multiplier=False
460
  )
461
+
462
+ # K, V projections without multipliers (avoids Q-K symmetry)
463
  self.k_proj = nn.Linear(
464
+ fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
465
  )
466
  self.v_proj = nn.Linear(
467
+ fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
468
  )
469
+
470
+ # Output projection with row + column multipliers (maximally expressive)
471
+ self.o_proj = LinearWithMultipliers(
472
+ config.num_attention_heads * self.head_dim,
473
+ config.hidden_size,
474
+ bias=config.attention_bias,
475
+ use_row_multiplier=True,
476
+ use_column_multiplier=True
477
  )
478
 
479
  # SeeDNorm for Q/K normalization (replaces RMSNorm)
480
  self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
481
  self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
482
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  # Dropout for attention output
484
  self.dropout = nn.Dropout(config.dropout_rate)
485
 
 
496
  **kwargs: Unpack[FlashAttentionKwargs],
497
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
498
  input_shape = hidden_states.shape[:-1]
 
499
 
500
  # Apply FANformer transformation first
501
  hidden_states_fan = self.fan_layer(hidden_states)
502
 
503
  # ResFormer: Apply feature residual connection BEFORE projections
504
+ # This ensures dimensional compatibility across all layer types
505
  if first_layer_fan is not None:
506
  hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
507
 
508
  # Store current FAN features for potential use as first_layer_fan in subsequent layers
509
  current_layer_fan = hidden_states_fan.clone()
510
 
511
+ hidden_shape = (*input_shape, -1, self.head_dim)
512
+
513
+ # Use FAN-transformed features (with residual applied) for projections
514
+ # Q projection with learnable row multipliers
515
  query_states, gate = torch.chunk(
516
+ self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
 
 
 
 
 
 
 
 
 
517
  )
518
+ gate = gate.reshape(*input_shape, -1)
519
+
520
+ # Apply SeeDNorm to Q and K
521
+ query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
522
+ key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
523
+ value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)
524
+
525
+ cos, sin = position_embeddings
526
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
527
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
  attention_interface: Callable = eager_attention_forward
529
  if self.config._attn_implementation != "eager":
530
  attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
540
  **kwargs,
541
  )
542
 
543
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
 
 
 
544
  attn_output = attn_output * torch.sigmoid(gate)
545
 
546
+ # Output projection with learnable row + column multipliers
547
  attn_output = self.o_proj(attn_output)
548
  attn_output = self.dropout(attn_output)
549
 
550
  return attn_output, attn_weights, current_layer_fan
551
 
552
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
  class PolyNorm(torch.nn.Module):
554
  def __init__(self, eps=1e-6):
555
  super(PolyNorm, self).__init__()
 
566
 
567
  class NeoLLMMLP(nn.Module):
568
  """
569
+ MLP with FANformer integration for featural periodicity modeling and
570
+ Learnable Multipliers for adaptive scale control.
571
 
572
  This captures periodicities in the feature space (semantic/embedding dimensions)
573
  complementary to the relational periodicities captured by attention mechanisms.
574
  Works in conjunction with ResFormer for comprehensive information flow.
575
+
576
+ Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C):
577
+ - gate_proj: row multipliers only (controls gating mechanism scale)
578
+ - up_proj: no multipliers (avoids redundancy with down_proj)
579
+ - down_proj: row + column multipliers (maximally expressive output scaling)
580
  """
581
  def __init__(self, config):
582
  super().__init__()
 
584
  self.hidden_size = config.hidden_size
585
  self.intermediate_size = config.intermediate_size
586
 
587
+ # FANformer integration for featural space periodicity
588
  self.fan_layer = FANLayer(
589
  hidden_size=config.hidden_size,
590
  fan_ratio=getattr(config, 'fan_ratio_ffn', 0.0625) # Half of attention's fan_ratio
 
593
  # Calculate the output dimension after FAN transformation
594
  fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio_ffn', 0.0625))
595
 
596
+ # SwiGLU/Gated architecture with learnable multipliers
597
+ # gate_proj: row multipliers for gating scale control
598
+ self.gate_proj = LinearWithMultipliers(
599
+ fan_output_dim,
600
+ self.intermediate_size,
601
+ bias=False,
602
+ use_row_multiplier=True,
603
+ use_column_multiplier=False
604
+ )
605
+
606
+ # up_proj: no multipliers (avoids redundancy)
607
  self.up_proj = nn.Linear(fan_output_dim, self.intermediate_size, bias=False)
608
+
609
+ # down_proj: row + column multipliers (maximally expressive)
610
+ self.down_proj = LinearWithMultipliers(
611
+ self.intermediate_size,
612
+ self.hidden_size,
613
+ bias=False,
614
+ use_row_multiplier=True,
615
+ use_column_multiplier=True
616
+ )
617
+
618
  self.act_fn = PolyNorm()
619
 
620
  # Dropout for MLP hidden layer
621
  self.dropout = nn.Dropout(config.dropout_rate)
622
 
623
  def forward(self, x):
624
+ # Apply FAN transformation before projections
625
  x_fan = self.fan_layer(x)
626
 
627
  # Use FAN-transformed features for gate and up projections
 
633
 
634
 
635
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
636
    """
    Decoder layer with standard residual connections.

    Architecture:
    1. Pre-norm (SeeDNorm) → LNS scaling → Self-Attention with ResFormer and Learnable Multipliers
    2. Standard residual connection (plain addition)
    3. GPAS activation scaling
    4. Pre-norm (SeeDNorm) → LNS scaling → MLP with FANformer and Learnable Multipliers
    5. Standard residual connection (plain addition)
    6. GPAS activation scaling
    """
647
+
648
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
649
  super().__init__()
650
  self.hidden_size = config.hidden_size
651
  self.layer_idx = layer_idx
652
 
653
+ # Full attention with learnable multipliers
654
+ self.self_attn = NeoLLMAttention(config, layer_idx)
 
 
 
 
655
 
656
+ # MLP with FANformer integration and learnable multipliers
657
  self.mlp = NeoLLMMLP(config)
658
 
659
  # SeeDNorm for input and post-attention normalization (replaces RMSNorm)
 
679
  first_layer_fan: Optional[torch.Tensor] = None,
680
  **kwargs: Unpack[FlashAttentionKwargs],
681
  ) -> torch.FloatTensor:
682
+ # ============================================================
683
+ # Attention Block with standard residual connection
684
+ # ============================================================
685
  residual = hidden_states
686
 
687
  # Apply SeeDNorm normalization
 
690
  # Apply LNS scaling after normalization
691
  hidden_states = self.lns_attn(hidden_states)
692
 
693
+ # Self Attention with ResFormer feature residual connections and learnable multipliers
694
+ hidden_states, _, self.current_layer_fan = self.self_attn(
695
+ hidden_states=hidden_states,
696
+ attention_mask=attention_mask,
697
+ position_embeddings=position_embeddings,
698
+ first_layer_fan=first_layer_fan,
699
+ **kwargs,
700
+ )
 
 
 
 
 
 
 
 
701
 
702
  # Standard residual connection
703
  hidden_states = residual + hidden_states
 
705
  # Apply GPAS after attention residual connection
706
  hidden_states = self.gpas_attn(hidden_states)
707
 
708
+ # ============================================================
709
+ # MLP Block with standard residual connection
710
+ # ============================================================
711
  residual = hidden_states
712
  hidden_states = self.post_attention_layernorm(hidden_states)
713
 
714
  # Apply LNS scaling after normalization
715
  hidden_states = self.lns_mlp(hidden_states)
716
 
717
+ # MLP now includes FAN transformation and learnable multipliers internally
718
  hidden_states = self.mlp(hidden_states)
719
 
720
  # Standard residual connection
 
727
 
728
 
729
  class NeoLLMPreTrainedModel(PreTrainedModel):
730
+ """
731
+ Base class for NeoLLM models with custom weight initialization.
732
+
733
+ Handles initialization for:
734
+ - NeoLLMAttention (ResFormer lambda parameters)
735
+ - GPAS (Gradient-Preserving Activation Scaling)
736
+ - FANLayer (Fourier Analysis Network)
737
+ - SeeDNorm (Self-Rescaled Dynamic Normalization)
738
+ - Learnable Multipliers (ScalarMultiplier, VectorMultiplier)
739
+ """
740
  config: NeoLLMConfig
741
  base_model_prefix = "model"
742
  supports_gradient_checkpointing = True
 
746
  _is_stateful = True
747
 
748
  def _init_weights(self, module):
749
+ """
750
+ Initialize weights for all custom modules in NeoLLM.
751
+
752
+ Strategy:
753
+ - Standard layers (Linear, Embedding): handled by parent class
754
+ - Custom modules: specialized initialization per component
755
+ - Learnable Multipliers: initialized to 1.0 for identity transformation
756
+ """
757
  super()._init_weights(module)
758
+
759
+ if isinstance(module, NeoLLMAttention):
 
 
 
 
 
 
 
760
  # ResFormer: initialize lambda parameters for full attention
761
+ # Lambda values control the interpolation between first layer and current layer features
762
+ # Starting at 0.5 provides balanced contribution from both sources
763
  if hasattr(module, 'lambda_1'):
764
  module.lambda_1.data.fill_(0.5)
765
  if hasattr(module, 'lambda_2'):
766
  module.lambda_2.data.fill_(0.5)
767
+
768
  elif isinstance(module, GPAS):
769
  # Initialize GPAS alpha to 0 as per paper
770
+ # This starts with no activation scaling, allowing the model to learn gradually
771
  module.alpha.data.fill_(0.0)
772
+
773
  elif isinstance(module, FANLayer):
774
+ # FANLayer initialization is handled within the class __init__
775
+ # Uses normal initialization with std=0.02 for weights
776
  pass
777
+
778
  elif isinstance(module, SeeDNorm):
779
+ # SeeDNorm initialization (parameters already initialized correctly in __init__):
780
+ # gamma (γ) initialized to 1 (static scaling component, like RMSNorm)
781
+ # beta (β) initialized to 0 (self-rescaling starts disabled)
782
+ # alpha (α) initialized to 1 (dynamic modulation at full strength)
783
  pass
784
+
785
+ elif isinstance(module, (ScalarMultiplier, VectorMultiplier)):
786
+ # Learnable Multipliers: initialize to 1.0 for identity transformation
787
+ # This allows the model to start from the standard behavior and learn
788
+ # scale adaptations from data without initial bias
789
+ if hasattr(module, 'multiplier'):
790
+ module.multiplier.data.fill_(1.0)
791
 
792
  class NeoLLMModel(NeoLLMPreTrainedModel):
793
+ """
794
+ NeoLLM base model with transformer decoder architecture.
795
+
796
+ Note on embeddings and weight tying: This model uses weight tying between
797
+ embed_tokens and lm_head (shared weights). Following "Learnable Multipliers"
798
+ paper analysis, we do NOT add multipliers to embeddings because:
799
+
800
+ 1. Weight tying creates conflicting gradient paths: multipliers would scale
801
+ gradients from embedding lookup but not from lm_head projection, causing
802
+ the multiplier to receive incomplete optimization signals.
803
+
804
+ 2. The paper explicitly warns against multipliers in lm_head (creates shortcuts
805
+ for learning marginal token distribution), and with weight tying this
806
+ restriction propagates to embeddings.
807
+
808
+ 3. Compensating mechanisms provide scale adaptation immediately after embedding:
809
+ - First layer attention has multipliers in Q/O projections
810
+ - FANformer transforms the representation space
811
+ - SeeDNorm provides input-dependent dynamic scaling
812
+ - ResFormer propagates first-layer features with learnable scaling
813
+ """
814
+
815
  def __init__(self, config: NeoLLMConfig):
816
  super().__init__(config)
817
+
818
+ # Standard embedding without learnable multipliers
819
+ # Due to weight tying with lm_head, multipliers would create
820
+ # conflicting optimization dynamics (see class docstring)
821
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
822
 
823
  # Each layer creates its own components (no shared parameters)
824
  self.layers = nn.ModuleList(
825
  [NeoLLMDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
826
  )
827
+
828
  # SeeDNorm for final output normalization (replaces RMSNorm)
829
  self.norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
830
+ self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
 
 
 
 
 
 
 
 
831
  self.gradient_checkpointing = False
832
 
833
  # ResFormer: storage for first layer's FAN features (H_fan_1)
 
848
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
849
 
850
  if inputs_embeds is None:
851
+ # Standard embedding lookup without multipliers
852
+ # Scale adaptation occurs in subsequent layers via:
853
+ # (1) First layer attention multipliers, (2) FANformer transformation,
854
+ # (3) SeeDNorm dynamic scaling, (4) ResFormer feature propagation
855
  inputs_embeds = self.embed_tokens(input_ids)
856
 
857
  if position_ids is None:
 
865
  past_key_values=None,
866
  position_ids=position_ids,
867
  )
 
868
 
869
  hidden_states = inputs_embeds
870
 
871
+ # create position embeddings to be shared across the decoder layers
872
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
 
873
 
874
  # ResFormer: reset first_layer_fan at the start of each forward pass
875
  self.first_layer_fan = None
876
 
877
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
 
 
878
  hidden_states = decoder_layer(
879
  hidden_states,
880
  position_embeddings=position_embeddings,
881
+ attention_mask=causal_mask,
882
  first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
883
  **kwargs,
884
  )
 
895
  past_key_values=None,
896
  )
897
 
 
 
 
 
 
 
 
 
 
 
898
 
899
  @torch.compiler.disable
900
  def compute_cce_loss(hidden_states, labels, lm_head_weight, lm_head_bias=None, pad_token_id=None):
 
925
 
926
 
927
  class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
928
+ """
929
+ Causal Language Model with NeoLLM architecture.
930
+
931
+ Note on LM head: Following "Learnable Multipliers" paper recommendations,
932
+ the output projection (lm_head) does NOT include learnable multipliers because:
933
+ 1. The preceding RMSNorm (self.model.norm) already acts as column multipliers
934
+ 2. Adding row multipliers to lm_head can create shortcuts where the model
935
+ learns marginal token distribution without updating internal features
936
+ """
937
  _tied_weights_keys = ["lm_head.weight"]
938
 
939
  def __init__(self, config):
940
  super().__init__(config)
941
  self.model = NeoLLMModel(config)
942
  self.vocab_size = config.vocab_size
943
+
944
+ # LM head without learnable multipliers (standard linear layer)
945
+ # Preceding norm layer provides sufficient scale adaptation
946
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
947
+
948
  self.post_init()
949
 
950
  def forward(
 
1001
  "NeoLLMConfig",
1002
  "FANLayer",
1003
  "SeeDNorm",
1004
+ "ScalarMultiplier",
1005
+ "VectorMultiplier",
1006
+ "LinearWithMultipliers",
1007
  ]
1008
 
1009
  # Register the configuration and model for AutoClass support