KitsuVp commited on
Commit
44526de
·
verified ·
1 Parent(s): ac61dcd

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +327 -504
modeling_neollm.py CHANGED
@@ -1,18 +1,7 @@
1
  #!/usr/bin/env python3
2
  """
3
- NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
4
- SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning,
5
- Learnable Multipliers for enhanced scale adaptation and information flow through deep layers,
6
- and StackMemory for hierarchical pattern modeling.
7
- Updated to include:
8
- - Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
9
- - FAN layer in FFN for featural periodicity modeling (complementary coverage)
10
- - SeeDNorm: Dynamic normalization with input-dependent scaling for better adaptability
11
- - Dropout regularization at strategic locations
12
- - ResFormer: Feature residual connections from first layer (applied before projections)
13
- - Learnable Multipliers: Frees weight matrix scale from WD-noise equilibrium for data-adaptive scaling
14
- - StackMemory: Differentiable hidden state stack for modeling Chomsky hierarchy grammars
15
- - Full Attention only (linear attention removed)
16
  """
17
 
18
  import math
@@ -36,7 +25,6 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
36
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
  from transformers.processing_utils import Unpack
38
  from transformers.utils import TransformersKwargs, logging
39
- from transformers.utils.generic import check_model_inputs
40
  from configuration_neollm import NeoLLMConfig
41
 
42
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
@@ -259,6 +247,7 @@ class SeeDNorm(nn.Module):
259
  Self-Rescaled Dynamic Normalization (SeeDNorm) with dual dropout regularization.
260
 
261
  SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x)
 
262
 
263
  Args:
264
  dim: Hidden dimension size
@@ -300,7 +289,7 @@ class SeeDNorm(nn.Module):
300
  Normalized and dynamically scaled tensor of same shape
301
  """
302
 
303
- x_for_dynamic = F.dropout(x, p=self.dropout_input, training=self.training)
304
  rescale_factor = torch.tanh(torch.sum(x_for_dynamic * self.beta,
305
  dim=-1, keepdim=True))
306
 
@@ -310,7 +299,7 @@ class SeeDNorm(nn.Module):
310
  # Apply RMS normalization on ORIGINAL input (not dropped version)
311
  x_normalized = self._rms_norm(x.float())
312
 
313
- x_normalized = F.dropout(x_normalized, p=self.dropout_hidden, training=self.training)
314
 
315
  # Apply dynamic scaling
316
  output = x_normalized * dynamic_scale.float()
@@ -320,263 +309,6 @@ class SeeDNorm(nn.Module):
320
  def extra_repr(self) -> str:
321
  return (f"dim={self.dim}, eps={self.eps}, "
322
  f"dropout_input={self.dropout_input}, dropout_hidden={self.dropout_hidden}")
323
-
324
-
325
- # ==================== STACK MEMORY MODULE ====================
326
- class StackMemory(nn.Module):
327
- """
328
- From "Improving Formal Reasoning of Transformer with State Stack":
329
- Implements a multi-head differentiable stack with soft push, pop, and no-op operations.
330
- Each head maintains its own stack and mask, which are updated based on learned action
331
- probabilities. Global reading is performed via query-over-stack attention.
332
-
333
- This module is inserted between Transformer layers to augment information flow with
334
- stack-like memory operations, enabling the model to better capture hierarchical and
335
- recursive patterns characteristic of regular expressions and context-free grammars.
336
-
337
- Note: StackMemory uses standard nn.Linear to maintain architectural
338
- independence and avoid introducing additional complexity in the memory operations.
339
-
340
- Args:
341
- config: Model configuration containing stack-related hyperparameters
342
- """
343
-
344
- def __init__(self, config: NeoLLMConfig):
345
- super().__init__()
346
- self.config = config
347
- self.num_stack_heads = getattr(config, 'num_stack_heads', 4)
348
- self.stack_slots = getattr(config, 'stack_slots', 24)
349
- self.stack_d_model = getattr(config, 'stack_d_model', 128)
350
-
351
- self.head_dim = self.stack_d_model // self.num_stack_heads
352
-
353
- # Dimension reduction projections for efficiency
354
- # Uses standard nn.Linear
355
- self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=True)
356
- self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=True)
357
-
358
- # Action prediction: generates push/pop/no-op probabilities for each head
359
- self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
360
-
361
- # Query projection for global reading (one per head)
362
- self.gate_proj = nn.Linear(self.head_dim, 1, bias=True)
363
-
364
- # Residual weight for gating stack contribution
365
- self.res_weight = nn.Parameter(torch.ones(1))
366
-
367
- # Cache for autoregressive generation (matches OLMo reference)
368
- self.cache_size = getattr(config, "cache_size", 2048)
369
- # Initialization fix: Register buffers for cache
370
- # Default to batch_size=1 if forward_bs is not in config (standard inference)
371
- forward_bs = getattr(config, 'forward_bs', 1)
372
- self.register_buffer("k_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, self.head_dim))
373
- self.register_buffer("action_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, 3))
374
-
375
- self.cache_position = 0
376
- self.enable_cache = False
377
-
378
- def reset_cache(self):
379
- self.cache_position = 0
380
-
381
- def _vectorized_update(
382
- self,
383
- stack: torch.Tensor,
384
- mask: torch.Tensor,
385
- actions: torch.Tensor,
386
- k_values: torch.Tensor
387
- ) -> Tuple[torch.Tensor, torch.Tensor]:
388
- """
389
- Vectorized stack update mechanism applying soft push/pop/no-op operations.
390
-
391
- Implements the differentiable stack operations from the paper:
392
- - Push: shifts all elements down and places k_values at top
393
- - Pop: shifts all elements up and removes top
394
- - No-op: maintains current stack state
395
-
396
- Args:
397
- stack: Current stack state [batch, seq, num_heads, stack_slots, head_dim]
398
- mask: Current stack mask [batch, seq, num_heads, stack_slots]
399
- actions: Action probabilities [batch, seq, num_heads, 3] (push/pop/no-op)
400
- k_values: New values to push [batch, seq, num_heads, head_dim]
401
-
402
- Returns:
403
- Tuple of (updated_stack, updated_mask)
404
- """
405
- batch_size, seq_len = actions.shape[:2]
406
-
407
- # Expand stack and mask along sequence dimension for parallel processing
408
- # Only expand if checking against initial state dimensions (4D)
409
- if stack.dim() == 4:
410
- stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
411
- mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)
412
-
413
- # Generate pushed stack: new value at top, shift others down
414
- push_stack = torch.cat([
415
- k_values.unsqueeze(3), # New value at position 0
416
- stack[:, :, :, :-1] # Shift existing elements down
417
- ], dim=3)
418
- push_mask = torch.cat([
419
- torch.ones_like(mask[:, :, :, :1]),
420
- mask[:, :, :, :-1]
421
- ], dim=3)
422
-
423
- # Generate popped stack: shift all up, zero at bottom
424
- pop_stack = torch.cat([
425
- stack[:, :, :, 1:],
426
- torch.zeros_like(stack[:, :, :, :1])
427
- ], dim=3)
428
- pop_mask = torch.cat([
429
- mask[:, :, :, 1:],
430
- torch.zeros_like(mask[:, :, :, :1])
431
- ], dim=3)
432
-
433
- # Combine operations weighted by action probabilities
434
- action_weights = actions.unsqueeze(-1).unsqueeze(-1) # [batch, seq, heads, 3, 1, 1]
435
- stacks = torch.stack([push_stack, pop_stack, stack], dim=3) # [batch, seq, heads, 3, slots, dim]
436
- masks = torch.stack([push_mask, pop_mask, mask], dim=3) # [batch, seq, heads, 3, slots]
437
-
438
- # Weighted combination of all operations
439
- new_stack = (stacks * action_weights).sum(dim=3)
440
- new_mask = (masks * action_weights.squeeze(-1)).sum(dim=3)
441
-
442
- return new_stack, new_mask
443
-
444
- def forward(
445
- self,
446
- hidden_states: torch.Tensor,
447
- stack: Optional[torch.Tensor] = None,
448
- mask: Optional[torch.Tensor] = None
449
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
450
- """
451
- Apply differentiable stack operations to hidden states.
452
-
453
- Args:
454
- hidden_states: Input hidden states [batch, seq, hidden_size]
455
- stack: Previous stack state [batch, num_heads, stack_slots, head_dim] or None
456
- mask: Previous stack mask [batch, num_heads, stack_slots] or None
457
-
458
- Returns:
459
- Tuple of (output_hidden_states, updated_stack, updated_mask)
460
- """
461
- batch_size, seq_len, _ = hidden_states.shape
462
- device = hidden_states.device
463
-
464
- # Initialize stack and mask if not provided
465
- if stack is None:
466
- stack = torch.zeros(
467
- batch_size, self.num_stack_heads, self.stack_slots, self.head_dim,
468
- device=device, dtype=hidden_states.dtype
469
- )
470
- if mask is None:
471
- mask = torch.zeros(
472
- batch_size, self.num_stack_heads, self.stack_slots,
473
- device=device, dtype=hidden_states.dtype
474
- )
475
-
476
- # Project to lower dimension for efficiency
477
- new_hidden_states = self.down_proj(hidden_states)
478
-
479
- # Generate action probabilities: [batch, seq, num_heads, 3]
480
- action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim)
481
- actions = F.softmax(
482
- action_logits.view(batch_size, seq_len, self.num_stack_heads, 3),
483
- dim=-1
484
- )
485
-
486
- # Prepare values to push (split into heads)
487
- k_values = new_hidden_states.view(batch_size, seq_len, self.num_stack_heads, self.head_dim)
488
-
489
- # Update stack and mask using vectorized operations
490
- new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
491
-
492
- # Global reading via query-over-stack attention
493
- gate_scores = self.gate_proj(new_stack).squeeze(-1) # [batch, seq, heads, slots]
494
-
495
- gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
496
-
497
- # Weighted sum over stack slots
498
- memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
499
- memory_output = memory_output.view(batch_size, seq_len, -1)
500
-
501
- memory_output = self.up_proj(memory_output)
502
-
503
- # Residual Connection
504
- output = memory_output * self.res_weight + hidden_states
505
-
506
- # Update Cache Logic
507
- if self.enable_cache:
508
- self._update_cache(k_values.detach(), actions.detach())
509
-
510
- return output, new_stack[:, -1], new_mask[:, -1]
511
-
512
- def _update_cache(self, k_values: torch.Tensor, actions: torch.Tensor):
513
- seq_len = k_values.shape[1]
514
- if self.cache_position + seq_len <= self.cache_size:
515
- # Assumes standard batch processing for inference (usually batch_size=1)
516
- self.k_cache[:, self.cache_position:self.cache_position+seq_len] = k_values
517
- self.action_cache[:, self.cache_position:self.cache_position+seq_len] = actions
518
- self.cache_position += seq_len
519
- else:
520
- self.reset_cache()
521
-
522
- def step(self, hidden_state: torch.Tensor, stack: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
523
- if not self.enable_cache:
524
- return self.forward(hidden_state.unsqueeze(1), stack, mask)
525
-
526
- batch_size = hidden_state.shape[0]
527
-
528
- # Compute features for current token
529
- new_hidden_states = self.down_proj(hidden_state)
530
-
531
- action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim)
532
- current_actions = F.softmax(
533
- action_logits.view(batch_size, 1, self.num_stack_heads, 3),
534
- dim=-1
535
- )
536
-
537
- current_k = new_hidden_states.view(batch_size, 1, self.num_stack_heads, self.head_dim)
538
-
539
- # Reconstruct History
540
- if self.cache_position > 0:
541
- cached_k = self.k_cache[:, :self.cache_position]
542
- cached_actions = self.action_cache[:, :self.cache_position]
543
-
544
- k_values = torch.cat([cached_k, current_k], dim=1)
545
- actions = torch.cat([cached_actions, current_actions], dim=1)
546
- else:
547
- k_values = current_k
548
- actions = current_actions
549
-
550
- # Dimension Fix: Pass sequences directly without unsqueeze(0)
551
- # k_values is [batch, seq_len_total, heads, dim]
552
- # actions is [batch, seq_len_total, heads, 3]
553
-
554
- new_stack_seq, new_mask_seq = self._vectorized_update(
555
- stack, # Initial stack [batch, heads, slots, dim]
556
- mask,
557
- actions,
558
- k_values
559
- )
560
-
561
- # Extract last step
562
- current_stack = new_stack_seq[:, -1]
563
- current_mask = new_mask_seq[:, -1]
564
-
565
- gate_scores = self.gate_proj(current_stack).squeeze(-1)
566
- gate_weights = F.softmax(gate_scores + (1 - current_mask) * -1e9, dim=-1)
567
-
568
- memory_output = (current_stack * gate_weights.unsqueeze(-1)).sum(dim=2)
569
- memory_output = memory_output.view(batch_size, -1)
570
-
571
- memory_output_proj = self.up_proj(memory_output)
572
-
573
- self._update_cache(current_k, current_actions)
574
-
575
- return (
576
- memory_output_proj * self.res_weight + hidden_state,
577
- current_stack,
578
- current_mask
579
- )
580
  # ==================== ROTARY EMBEDDING ====================
581
  class NeoLLMRotaryEmbedding(nn.Module):
582
  inv_freq: torch.Tensor # fix linting for `register_buffer`
@@ -662,9 +394,6 @@ class NeoLLMRotaryEmbedding(nn.Module):
662
  sin = emb.sin() * self.attention_scaling
663
 
664
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
665
-
666
-
667
-
668
  def rotate_half(x):
669
  """Rotates half the hidden dims of the input."""
670
  x1 = x[..., : x.shape[-1] // 2]
@@ -677,16 +406,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
677
  cos = cos.unsqueeze(unsqueeze_dim)
678
  sin = sin.unsqueeze(unsqueeze_dim)
679
 
680
- # Keep half or full tensor for later concatenation
681
  rotary_dim = cos.shape[-1]
682
  q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
683
  k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
684
 
685
- # Apply rotary embeddings on the first half or full tensor
686
  q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
687
  k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
688
 
689
- # Concatenate back to full shape
690
  q_embed = torch.cat([q_embed, q_pass], dim=-1)
691
  k_embed = torch.cat([k_embed, k_pass], dim=-1)
692
  return q_embed, k_embed
@@ -704,6 +430,98 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
704
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
705
 
706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707
  def eager_attention_forward(
708
  module: nn.Module,
709
  query: torch.Tensor,
@@ -732,17 +550,9 @@ def eager_attention_forward(
732
 
733
  class NeoLLMAttention(nn.Module):
734
  """
735
- Multi-headed attention with FANformer integration, SeeDNorm for Q/K normalization,
736
- ResFormer feature residual connections, and Learnable Multipliers for enhanced
737
- information flow and scale adaptation.
738
-
739
- ResFormer enhancement: Applies learnable feature residual connections from first layer
740
- BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
741
-
742
- Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C):
743
- - Q projection: row multipliers only (enables per-head attention scaling in GQA)
744
- - K, V projections: no multipliers (avoids redundancy with Q multipliers)
745
- - Output projection: row + column multipliers (maximally expressive without symmetries)
746
  """
747
 
748
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
@@ -752,54 +562,141 @@ class NeoLLMAttention(nn.Module):
752
  self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
753
  self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
754
  self.scaling = self.head_dim**-0.5
 
755
  self.attention_dropout = config.attention_dropout
756
  self.is_causal = True
757
-
758
- # FANformer integration: FAN layer before QKV projections
 
 
 
 
 
 
 
 
 
759
  self.fan_layer = FANLayer(
760
- hidden_size=config.hidden_size,
761
- fan_ratio=getattr(config, 'fan_ratio', 0.125)
762
  )
763
-
764
- # Calculate the output dimension after FAN transformation
765
- fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, 'fan_ratio', 0.125))
766
-
767
- # Q projection with row multipliers (per-head scaling capability)
768
  self.q_proj = LinearWithMultipliers(
769
- fan_output_dim,
770
- config.num_attention_heads * self.head_dim * 2,
771
  bias=config.attention_bias,
772
  use_row_multiplier=True,
773
- use_column_multiplier=False
 
 
 
774
  )
775
-
776
- # K, V projections without multipliers (avoids Q-K symmetry)
777
  self.k_proj = nn.Linear(
778
- fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
779
  )
780
  self.v_proj = nn.Linear(
781
- fan_output_dim, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
782
  )
783
-
784
- # Output projection with row + column multipliers (maximally expressive)
785
  self.o_proj = LinearWithMultipliers(
786
  config.num_attention_heads * self.head_dim,
787
  config.hidden_size,
788
  bias=config.attention_bias,
789
  use_row_multiplier=True,
790
- use_column_multiplier=True
791
  )
792
-
793
- # SeeDNorm for Q/K normalization (replaces RMSNorm)
794
  self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
795
  self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
796
-
797
- # Dropout for attention output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
798
  self.dropout = nn.Dropout(config.dropout_rate)
799
-
800
- # ResFormer: learnable feature residual parameters (initialized to 0.5)
801
- self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1
802
- self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
803
 
804
  def forward(
805
  self,
@@ -809,45 +706,31 @@ class NeoLLMAttention(nn.Module):
809
  first_layer_fan: Optional[torch.Tensor] = None,
810
  **kwargs: Unpack[FlashAttentionKwargs],
811
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
812
- """
813
- Forward pass with ResFormer feature residual connections.
814
-
815
- Args:
816
- hidden_states: Current layer input [batch, seq, hidden_size]
817
- position_embeddings: Tuple of (cos, sin) for RoPE
818
- attention_mask: Causal attention mask
819
- first_layer_fan: First layer FAN features (for ResFormer)
820
-
821
- Returns:
822
- Tuple of (attn_output, attn_weights, current_layer_fan)
823
- """
824
  input_shape = hidden_states.shape[:-1]
825
-
826
- # Apply FANformer transformation
827
  hidden_states_fan = self.fan_layer(hidden_states)
828
-
829
- # ResFormer: Apply feature residual connection BEFORE projections
830
  if first_layer_fan is not None:
831
  hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
832
-
833
- # Store current FAN features for ResFormer
834
  current_layer_fan = hidden_states_fan.clone()
835
-
836
- hidden_shape = (*input_shape, -1, self.head_dim)
837
 
838
- # Q projection with learnable row multipliers
839
  query_states, gate = torch.chunk(
840
- self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
841
  )
842
  gate = gate.reshape(*input_shape, -1)
843
 
844
- # Apply SeeDNorm to Q and K
845
- query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
846
- key_states = self.k_norm(self.k_proj(hidden_states_fan).view(hidden_shape)).transpose(1, 2)
847
- value_states = self.v_proj(hidden_states_fan).view(hidden_shape).transpose(1, 2)
848
 
849
  cos, sin = position_embeddings
850
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
 
 
851
 
852
  attention_interface: Callable = eager_attention_forward
853
  if self.config._attn_implementation != "eager":
@@ -864,15 +747,14 @@ class NeoLLMAttention(nn.Module):
864
  **kwargs,
865
  )
866
 
 
 
867
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
868
  attn_output = attn_output * torch.sigmoid(gate)
869
-
870
- # Output projection with learnable row + column multipliers
871
  attn_output = self.o_proj(attn_output)
872
  attn_output = self.dropout(attn_output)
873
-
874
- return attn_output, attn_weights, current_layer_fan
875
 
 
876
 
877
  class PolyNorm(torch.nn.Module):
878
  def __init__(self, eps=1e-6):
@@ -957,16 +839,15 @@ class NeoLLMMLP(nn.Module):
957
 
958
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
959
  """
960
- Decoder layer with standard residual connections and optional StackMemory.
961
 
962
- Architecture (Updated Flow):
963
- 1. Optional: StackMemory module (Pre-processing context injection)
964
- 2. Pre-norm (SeeDNorm) LNS scaling → Self-Attention with ResFormer and Learnable Multipliers
965
- 3. Standard Residual Connection
966
- 4. GPAS activation scaling
967
- 5. Pre-norm (SeeDNorm) LNS scaling → MLP with FANformer and Learnable Multipliers
968
- 6. Standard Residual Connection
969
- 7. GPAS activation scaling
970
  """
971
 
972
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
@@ -980,7 +861,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
980
  # MLP with FANformer integration and learnable multipliers
981
  self.mlp = NeoLLMMLP(config)
982
 
983
- # SeeDNorm for input and post-attention normalization
984
  self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
985
  self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
986
 
@@ -988,15 +869,10 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
988
  self.lns_attn = LNS(layer_idx)
989
  self.lns_mlp = LNS(layer_idx)
990
 
991
- # GPAS (Gradient-Preserving Activation Scaling)
992
  self.gpas_attn = GPAS(config.hidden_size)
993
  self.gpas_mlp = GPAS(config.hidden_size)
994
 
995
- # StackMemory: Differentiable hidden state stack
996
- self.use_stack = getattr(config, 'use_stack', False)
997
- if self.use_stack:
998
- self.stack_memory = StackMemory(config)
999
-
1000
  # ResFormer: storage for current layer's FAN features
1001
  self.current_layer_fan = None
1002
 
@@ -1006,39 +882,11 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
1006
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
1007
  attention_mask: Optional[torch.Tensor] = None,
1008
  first_layer_fan: Optional[torch.Tensor] = None,
1009
- stack_state: Optional[torch.Tensor] = None,
1010
- stack_mask: Optional[torch.Tensor] = None,
1011
  output_attentions: Optional[bool] = False,
1012
  **kwargs: Unpack[FlashAttentionKwargs],
1013
- ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
1014
- """
1015
- Forward pass with ResFormer and optional StackMemory.
1016
-
1017
- Args:
1018
- hidden_states: Current layer input [batch, seq, hidden_size]
1019
- position_embeddings: Tuple of (cos, sin) for RoPE
1020
- attention_mask: Causal attention mask
1021
- first_layer_fan: First layer FAN features (for ResFormer)
1022
- stack_state: StackMemory state (optional)
1023
- stack_mask: StackMemory mask (optional)
1024
- output_attentions: Whether to return attention weights
1025
-
1026
- Returns:
1027
- Tuple of (hidden_states, attn_weights, stack_state, stack_mask)
1028
- """
1029
-
1030
  # ============================================================
1031
- # 1. Stack Memory Module (MOVED TO START)
1032
- # ============================================================
1033
- # We process memory first so the Attention layer can "see" the
1034
- # retrieved context. This eliminates the 1-layer lag.
1035
- if self.use_stack:
1036
- hidden_states, stack_state, stack_mask = self.stack_memory(
1037
- hidden_states, stack_state, stack_mask
1038
- )
1039
-
1040
- # ============================================================
1041
- # 2. Attention Block with Standard Residual Connection
1042
  # ============================================================
1043
  residual = hidden_states
1044
 
@@ -1048,23 +896,24 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
1048
  # Apply LNS scaling after normalization
1049
  hidden_states = self.lns_attn(hidden_states)
1050
 
1051
- # Self Attention with ResFormer
1052
- attn_output, attn_weights, self.current_layer_fan = self.self_attn(
 
1053
  hidden_states=hidden_states,
1054
- position_embeddings=position_embeddings,
1055
  attention_mask=attention_mask,
 
1056
  first_layer_fan=first_layer_fan,
1057
  **kwargs,
1058
  )
1059
 
1060
- # Standard Residual Connection
1061
- hidden_states = residual + attn_output
1062
 
1063
- # Apply GPAS after residual connection
1064
  hidden_states = self.gpas_attn(hidden_states)
1065
 
1066
  # ============================================================
1067
- # 3. MLP Block with Standard Residual Connection
1068
  # ============================================================
1069
  residual = hidden_states
1070
  hidden_states = self.post_attention_layernorm(hidden_states)
@@ -1072,20 +921,20 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
1072
  # Apply LNS scaling after normalization
1073
  hidden_states = self.lns_mlp(hidden_states)
1074
 
1075
- # MLP with FANformer
1076
- mlp_output = self.mlp(hidden_states)
1077
 
1078
- # Standard Residual Connection
1079
- hidden_states = residual + mlp_output
1080
 
1081
- # Apply GPAS after residual connection
1082
  hidden_states = self.gpas_mlp(hidden_states)
1083
 
1084
- # Return tuple matching the expected signature
1085
- if self.use_stack:
1086
- return (hidden_states, attn_weights, stack_state, stack_mask)
1087
- else:
1088
- return (hidden_states, attn_weights, None, None)
1089
 
1090
 
1091
  class NeoLLMPreTrainedModel(PreTrainedModel):
@@ -1098,7 +947,6 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
1098
  - FANLayer (Fourier Analysis Network)
1099
  - SeeDNorm (Self-Rescaled Dynamic Normalization)
1100
  - Learnable Multipliers (ScalarMultiplier, VectorMultiplier)
1101
- - StackMemory (Differentiable Hidden State Stack)
1102
  """
1103
  config: NeoLLMConfig
1104
  base_model_prefix = "model"
@@ -1111,58 +959,90 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
1111
  def _init_weights(self, module):
1112
  """
1113
  Initialize weights for all custom modules in NeoLLM.
 
 
 
 
 
1114
  """
1115
  super()._init_weights(module)
1116
 
1117
  if isinstance(module, NeoLLMAttention):
 
 
 
1118
  if hasattr(module, 'lambda_1'):
1119
  module.lambda_1.data.fill_(0.5)
1120
  if hasattr(module, 'lambda_2'):
1121
  module.lambda_2.data.fill_(0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1122
 
1123
  elif isinstance(module, GPAS):
 
 
1124
  module.alpha.data.fill_(0.0)
1125
 
 
 
 
 
 
 
 
 
 
 
 
 
1126
  elif isinstance(module, (ScalarMultiplier, VectorMultiplier)):
 
 
 
1127
  if hasattr(module, 'multiplier'):
1128
  module.multiplier.data.fill_(1.0)
1129
-
1130
- elif isinstance(module, StackMemory):
1131
- std = self.config.initializer_range if hasattr(self.config, 'initializer_range') else 0.02
1132
- if hasattr(module, 'down_proj'):
1133
- module.down_proj.weight.data.normal_(mean=0.0, std=std)
1134
- if hasattr(module, 'up_proj'):
1135
- module.up_proj.weight.data.normal_(mean=0.0, std=std)
1136
- if hasattr(module, 'action_head'):
1137
- module.action_head.weight.data.normal_(mean=0.0, std=std)
1138
- if module.action_head.bias is not None:
1139
- module.action_head.bias.data.zero_()
1140
- if hasattr(module, 'gate_proj'):
1141
- module.gate_proj.weight.data.normal_(mean=0.0, std=std)
1142
- if hasattr(module, 'res_weight'):
1143
- module.res_weight.data.fill_(1.0)
1144
-
1145
 
1146
  class NeoLLMModel(NeoLLMPreTrainedModel):
1147
  """
1148
  NeoLLM base model with transformer decoder architecture.
1149
 
1150
- Uses ResFormer for first-layer feature propagation with standard residual connections
1151
- and optional StackMemory for hierarchical pattern modeling.
1152
-
1153
  Note on embeddings and weight tying: This model uses weight tying between
1154
  embed_tokens and lm_head (shared weights). Following "Learnable Multipliers"
1155
  paper analysis, we do NOT add multipliers to embeddings because:
1156
 
1157
- 1. Weight tying creates conflicting gradient paths
1158
- 2. The paper explicitly warns against multipliers in lm_head
1159
- 3. Compensating mechanisms provide scale adaptation immediately after embedding
 
 
 
 
 
 
 
 
 
 
1160
  """
1161
 
1162
  def __init__(self, config: NeoLLMConfig):
1163
  super().__init__(config)
1164
 
1165
  # Standard embedding without learnable multipliers
 
 
1166
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
1167
 
1168
  # Each layer creates its own components (no shared parameters)
@@ -1175,10 +1055,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1175
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
1176
  self.gradient_checkpointing = False
1177
 
1178
- # Configuration
1179
- self.use_stack = getattr(config, 'use_stack', False)
1180
-
1181
- # ResFormer: storage for first layer's FAN features
1182
  self.first_layer_fan = None
1183
 
1184
  # Initialize weights and apply final processing
@@ -1193,8 +1070,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1193
  output_hidden_states: Optional[bool] = None,
1194
  output_attentions: Optional[bool] = None,
1195
  return_dict: Optional[bool] = None,
1196
- past_key_values: Optional[List[torch.FloatTensor]] = None,
1197
- use_cache: Optional[bool] = None,
1198
  **kwargs: Unpack[TransformersKwargs],
1199
  ) -> BaseModelOutputWithPast:
1200
  output_hidden_states = (
@@ -1211,6 +1086,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1211
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1212
 
1213
  if inputs_embeds is None:
 
 
 
 
1214
  inputs_embeds = self.embed_tokens(input_ids)
1215
 
1216
  if position_ids is None:
@@ -1226,29 +1105,16 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1226
  )
1227
 
1228
  hidden_states = inputs_embeds
1229
- next_decoder_cache = None
1230
  all_hidden_states = () if output_hidden_states else None
1231
  all_attentions = () if output_attentions else None
1232
 
1233
- # Create position embeddings to be shared across the decoder layers
1234
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
1235
 
1236
- # ResFormer with first-layer feature propagation
1237
  self.first_layer_fan = None
1238
-
1239
- # Initialize Stack states (always None at start of forward, rebuilt via cache step or vertical flow)
1240
- stack_state = None
1241
- stack_mask = None
1242
-
1243
- # Propagate use_cache and reset if starting a new sequence
1244
- if self.use_stack:
1245
- for layer in self.layers:
1246
- if hasattr(layer, 'stack_memory'):
1247
- layer.stack_memory.enable_cache = use_cache if use_cache is not None else False
1248
- if past_key_values is None:
1249
- layer.stack_memory.reset_cache()
1250
-
1251
- for decoder_layer in self.layers:
1252
  if output_hidden_states:
1253
  all_hidden_states = all_hidden_states + (hidden_states,)
1254
 
@@ -1256,9 +1122,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1256
  hidden_states,
1257
  position_embeddings=position_embeddings,
1258
  attention_mask=causal_mask,
1259
- first_layer_fan=self.first_layer_fan,
1260
- stack_state=stack_state,
1261
- stack_mask=stack_mask,
1262
  output_attentions=output_attentions,
1263
  **kwargs,
1264
  )
@@ -1268,15 +1132,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1268
  if output_attentions:
1269
  all_attentions = all_attentions + (layer_outputs[1],)
1270
 
1271
- if self.use_stack:
1272
- # Vertical memory logic:
1273
- # The layer returns updated stack for the next layer to use (Vertical passing)
1274
- # But we do NOT persist it temporally here. The Module's internal cache handles temporal.
1275
- stack_state = layer_outputs[2]
1276
- stack_mask = layer_outputs[3]
1277
-
1278
  # ResFormer: capture H_fan_1 from the first layer
1279
- # Dynamically capture for the current pass
1280
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
1281
  self.first_layer_fan = decoder_layer.current_layer_fan
1282
 
@@ -1287,11 +1143,11 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
1287
  all_hidden_states = all_hidden_states + (hidden_states,)
1288
 
1289
  if not return_dict:
1290
- return tuple(v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_attentions] if v is not None)
1291
 
1292
  return BaseModelOutputWithPast(
1293
  last_hidden_state=hidden_states,
1294
- past_key_values=next_decoder_cache,
1295
  hidden_states=all_hidden_states,
1296
  attentions=all_attentions,
1297
  )
@@ -1346,37 +1202,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1346
 
1347
  self.post_init()
1348
 
1349
- def prepare_inputs_for_generation(
1350
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1351
- ):
1352
- if past_key_values:
1353
- past_length = past_key_values[0][0].shape[2]
1354
-
1355
- # If past_length > input_ids length, we are likely generating token by token
1356
- if input_ids.shape[1] > past_length:
1357
- remove_prefix_length = past_length
1358
- else:
1359
- # Default standard HF behavior
1360
- remove_prefix_length = input_ids.shape[1] - 1
1361
-
1362
- input_ids = input_ids[:, remove_prefix_length:]
1363
-
1364
- position_ids = kwargs.get("position_ids", None)
1365
- if attention_mask is not None and position_ids is None:
1366
- # create position_ids on the fly for batch generation
1367
- position_ids = attention_mask.long().cumsum(-1) - 1
1368
- position_ids.masked_fill_(attention_mask == 0, 1)
1369
- if past_key_values:
1370
- position_ids = position_ids[:, -input_ids.shape[1] :]
1371
-
1372
- return {
1373
- "input_ids": input_ids,
1374
- "past_key_values": past_key_values,
1375
- "use_cache": kwargs.get("use_cache"),
1376
- "position_ids": position_ids,
1377
- "attention_mask": attention_mask,
1378
- "inputs_embeds": inputs_embeds,
1379
- }
1380
 
1381
  def forward(
1382
  self,
@@ -1388,7 +1213,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1388
  logits_to_keep: Union[int, torch.Tensor] = 0,
1389
  output_hidden_states: Optional[bool] = None,
1390
  return_dict: Optional[bool] = None,
1391
-
1392
  **kwargs: Unpack[TransformersKwargs],
1393
  ) -> CausalLMOutputWithPast:
1394
  outputs: BaseModelOutputWithPast = self.model(
@@ -1398,7 +1222,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1398
  inputs_embeds=inputs_embeds,
1399
  output_hidden_states=output_hidden_states,
1400
  return_dict=return_dict,
1401
-
1402
  **kwargs,
1403
  )
1404
 
@@ -1423,7 +1246,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
1423
  return CausalLMOutputWithPast(
1424
  loss=loss,
1425
  logits=logits,
1426
- past_key_values=outputs.past_key_values,
1427
  hidden_states=outputs.hidden_states,
1428
  attentions=outputs.attentions,
1429
  )
@@ -1440,7 +1263,7 @@ __all__ = [
1440
  "ScalarMultiplier",
1441
  "VectorMultiplier",
1442
  "LinearWithMultipliers",
1443
- "StackMemory",
1444
  ]
1445
 
1446
  # Register the configuration and model for AutoClass support
 
1
  #!/usr/bin/env python3
2
  """
3
+ NeoLLM model with FANformer, SeeDNorm, ResFormer, Learnable Multipliers,
4
+ and full attention augmented with optional Momentum, MEA, and LUCID operators.
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
 
7
  import math
 
25
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
26
  from transformers.processing_utils import Unpack
27
  from transformers.utils import TransformersKwargs, logging
 
28
  from configuration_neollm import NeoLLMConfig
29
 
30
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 
247
  Self-Rescaled Dynamic Normalization (SeeDNorm) with dual dropout regularization.
248
 
249
  SeeDNorm(x) = [σ(x·β^T)·α + γ] ⊙ x/RMS(x)
250
+
251
 
252
  Args:
253
  dim: Hidden dimension size
 
289
  Normalized and dynamically scaled tensor of same shape
290
  """
291
 
292
+ x_for_dynamic = F.dropout(x, p=self.dropout_input)
293
  rescale_factor = torch.tanh(torch.sum(x_for_dynamic * self.beta,
294
  dim=-1, keepdim=True))
295
 
 
299
  # Apply RMS normalization on ORIGINAL input (not dropped version)
300
  x_normalized = self._rms_norm(x.float())
301
 
302
+ x_normalized = F.dropout(x_normalized, p=self.dropout_hidden)
303
 
304
  # Apply dynamic scaling
305
  output = x_normalized * dynamic_scale.float()
 
309
  def extra_repr(self) -> str:
310
  return (f"dim={self.dim}, eps={self.eps}, "
311
  f"dropout_input={self.dropout_input}, dropout_hidden={self.dropout_hidden}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  # ==================== ROTARY EMBEDDING ====================
313
  class NeoLLMRotaryEmbedding(nn.Module):
314
  inv_freq: torch.Tensor # fix linting for `register_buffer`
 
394
  sin = emb.sin() * self.attention_scaling
395
 
396
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
 
397
  def rotate_half(x):
398
  """Rotates half the hidden dims of the input."""
399
  x1 = x[..., : x.shape[-1] // 2]
 
406
  cos = cos.unsqueeze(unsqueeze_dim)
407
  sin = sin.unsqueeze(unsqueeze_dim)
408
 
 
409
  rotary_dim = cos.shape[-1]
410
  q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
411
  k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
412
 
 
413
  q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
414
  k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
415
 
 
416
  q_embed = torch.cat([q_embed, q_pass], dim=-1)
417
  k_embed = torch.cat([k_embed, k_pass], dim=-1)
418
  return q_embed, k_embed
 
430
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
431
 
432
 
433
def causal_first_difference(x: torch.Tensor) -> torch.Tensor:
    """Return the causal first difference of ``x`` along the sequence axis (dim -2).

    Position 0 is differenced against an implicit zero, so ``out[..., 0, :] ==
    x[..., 0, :]``; every later position holds ``x[t] - x[t-1]``. Fully
    vectorized — no Python loop over the sequence.
    """
    # Shift the sequence right by one step, filling the first slot with zeros.
    shifted = torch.zeros_like(x)
    shifted[..., 1:, :] = x[..., :-1, :]
    return x - shifted
437
+
438
+
439
def rms_key_unit_norm(x: torch.Tensor, eps: float) -> torch.Tensor:
    """RMS-style key normalization used by the LUCID preconditioner.

    L2-normalizes the last dimension (in float32) and rescales by sqrt(d),
    so each vector ends up with norm sqrt(d), i.e. unit RMS per component.
    ``eps`` floors the norm to avoid division by zero.
    """
    xf = x.float()
    # Equivalent to F.normalize: divide by the L2 norm clamped below by eps.
    norms = xf.norm(p=2, dim=-1, keepdim=True).clamp_min(eps)
    return (xf / norms) * math.sqrt(x.shape[-1])
443
+
444
+
445
def infer_key_validity(attention_mask: Optional[torch.Tensor], seq_len: int, num_heads: int) -> Optional[torch.Tensor]:
    """Infer which key positions are valid from a square additive attention mask.

    A key position is treated as valid when its diagonal mask entry is a finite
    zero (i.e. the token may attend to itself). Returns a boolean tensor of
    shape ``(batch, num_heads, seq_len)``, or ``None`` when the mask is absent
    or is not a square 4-D additive mask over ``seq_len`` keys.
    """
    if attention_mask is None or attention_mask.ndim != 4:
        return None
    if attention_mask.shape[-2:] != (seq_len, seq_len):
        return None

    diagonal = attention_mask.diagonal(dim1=-2, dim2=-1)
    validity = torch.isfinite(diagonal) & diagonal.eq(0)

    head_axis = validity.shape[1]
    if head_axis == 1 and num_heads != 1:
        # Broadcast a shared head mask across all heads.
        return validity.expand(-1, num_heads, -1)
    if head_axis != num_heads:
        # Head count mismatch: fall back to the first head's validity.
        return validity[:, :1, :].expand(-1, num_heads, -1)
    return validity
461
+
462
+
463
def head_linear_compose(hidden_states: torch.Tensor, mixing_matrix: torch.Tensor) -> torch.Tensor:
    """Linearly mix the head axis of a ``(batch, heads, time, dim)`` tensor.

    Computes ``out[b, k, t, d] = sum_h hidden_states[b, h, t, d] * mixing_matrix[h, k]``
    without any Python loop; the mixing matrix is moved to the input's device
    and dtype first.
    """
    weights = mixing_matrix.to(device=hidden_states.device, dtype=hidden_states.dtype)
    # Move heads to the trailing axis, contract against (H_in, H_out), restore layout.
    mixed = torch.matmul(hidden_states.permute(0, 2, 3, 1), weights)
    return mixed.permute(0, 3, 1, 2)
466
+
467
+
468
def build_mea_reconstruction_matrix(num_component_heads: int, num_output_heads: int) -> torch.Tensor:
    """Build an identity-preserving MEA reconstruction initializer.

    Output head ``j`` is assigned the single component head
    ``floor(j * C / O)``, yielding a 0/1 matrix in which every column selects
    exactly one component head (when ``C == O`` this is the identity).

    Args:
        num_component_heads: Number of component (source) heads ``C``; must be positive.
        num_output_heads: Number of output (target) heads ``O``; must be positive.

    Returns:
        Float32 tensor of shape ``(C, O)`` with exactly one ``1.0`` per column.

    Raises:
        ValueError: If either head count is not positive.
    """
    # Validate before allocating: otherwise negative sizes raise a RuntimeError
    # from torch.zeros instead of the intended ValueError.
    if num_component_heads <= 0 or num_output_heads <= 0:
        raise ValueError("MEA head counts must be positive")

    matrix = torch.zeros(num_component_heads, num_output_heads, dtype=torch.float32)
    output_indices = torch.arange(num_output_heads, dtype=torch.long)
    component_indices = torch.div(output_indices * num_component_heads, num_output_heads, rounding_mode="floor")
    matrix[component_indices, output_indices] = 1.0
    return matrix
478
+
479
+
480
class MEAHeadSeeDNorm(nn.Module):
    """GQA-aware MEA head-level normalization built on SeeDNorm.

    In grouped-query attention, all query heads that share one KV head receive
    identical keys/values and differ only through their Q projections, so they
    are structurally correlated. Rather than normalizing each query head
    independently (the plain-MHA assumption of the MEA paper), this module
    flattens every KV-coherent run of query heads into a single vector and
    applies one shared SeeDNorm over that flattened group.

    Example: with 8 attention heads and 2 KV heads (``num_kv_groups == 4``),
    there are 2 normalization groups, each spanning ``4 * head_dim`` features.
    SeeDNorm's input-dependent dynamic scale can then react to, e.g., LUCID
    decorrelation magnitude within each KV-coherent group while respecting
    the GQA structural dependency between heads.
    """

    def __init__(self, num_heads: int, head_dim: int, num_kv_groups: int, eps: float = 1e-6):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_groups = num_kv_groups
        # Number of KV-coherent groups equals the model's num_key_value_heads.
        self.num_kv_heads = num_heads // num_kv_groups
        # Flattened width of one group: all query heads that share a KV head.
        self.group_dim = num_kv_groups * head_dim
        # A single SeeDNorm shared by all groups, operating over group_dim.
        self.norm = SeeDNorm(self.group_dim, eps=eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch, seq_len, heads, dim = hidden_states.shape
        if (heads, dim) != (self.num_heads, self.head_dim):
            raise ValueError(
                f"MEAHeadSeeDNorm expected ({self.num_heads}, {self.head_dim}) heads, "
                f"received ({heads}, {dim})"
            )
        # Fold each KV-coherent run of query heads into one group vector —
        # heads within a group are contiguous after the attention transpose —
        # normalize per group, then restore the per-head layout.
        grouped = hidden_states.reshape(batch, seq_len, self.num_kv_heads, self.group_dim)
        normalized = self.norm(grouped)
        return normalized.reshape(batch, seq_len, heads, dim)
523
+
524
+
525
  def eager_attention_forward(
526
  module: nn.Module,
527
  query: torch.Tensor,
 
550
 
551
  class NeoLLMAttention(nn.Module):
552
  """
553
+ Full attention with FANformer, SeeDNorm, ResFormer, Learnable Multipliers,
554
+ optional post-RoPE Momentum attention, full MEA head-level composition over
555
+ K/V, and optional LUCID value preconditioning.
 
 
 
 
 
 
 
 
556
  """
557
 
558
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
 
562
  self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
563
  self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
564
  self.scaling = self.head_dim**-0.5
565
+ self.sqrt_head_dim = math.sqrt(self.head_dim)
566
  self.attention_dropout = config.attention_dropout
567
  self.is_causal = True
568
+
569
+ self.use_momentum_attention = getattr(config, "use_momentum_attention", False)
570
+ self.momentum_gamma = float(getattr(config, "momentum_gamma", 0.0))
571
+ self.use_mea_attention = getattr(config, "use_mea_attention", False)
572
+ self.mea_component_key_value_heads = int(
573
+ getattr(config, "mea_component_key_value_heads", config.num_key_value_heads)
574
+ )
575
+ self.mea_groupnorm_eps = float(getattr(config, "mea_groupnorm_eps", config.rms_norm_eps))
576
+ self.use_lucid_attention = getattr(config, "use_lucid_attention", False)
577
+ self.lucid_attention_eps = float(getattr(config, "lucid_attention_eps", config.rms_norm_eps))
578
+
579
  self.fan_layer = FANLayer(
580
+ hidden_size=config.hidden_size,
581
+ fan_ratio=getattr(config, "fan_ratio", 0.125),
582
  )
583
+
584
+ fan_output_dim = config.hidden_size + int(config.hidden_size * getattr(config, "fan_ratio", 0.125))
585
+
 
 
586
  self.q_proj = LinearWithMultipliers(
587
+ fan_output_dim,
588
+ config.num_attention_heads * self.head_dim * 2,
589
  bias=config.attention_bias,
590
  use_row_multiplier=True,
591
+ use_column_multiplier=False,
592
+ )
593
+ self.num_mea_component_heads = (
594
+ self.mea_component_key_value_heads if self.use_mea_attention else config.num_key_value_heads
595
  )
 
 
596
  self.k_proj = nn.Linear(
597
+ fan_output_dim, self.num_mea_component_heads * self.head_dim, bias=config.attention_bias
598
  )
599
  self.v_proj = nn.Linear(
600
+ fan_output_dim, self.num_mea_component_heads * self.head_dim, bias=config.attention_bias
601
  )
 
 
602
  self.o_proj = LinearWithMultipliers(
603
  config.num_attention_heads * self.head_dim,
604
  config.hidden_size,
605
  bias=config.attention_bias,
606
  use_row_multiplier=True,
607
+ use_column_multiplier=True,
608
  )
609
+
 
610
  self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
611
  self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
612
+
613
+ if self.use_mea_attention:
614
+ self.mea_key_mix = nn.Parameter(
615
+ build_mea_reconstruction_matrix(self.num_mea_component_heads, config.num_key_value_heads)
616
+ )
617
+ self.mea_value_mix = nn.Parameter(
618
+ build_mea_reconstruction_matrix(self.num_mea_component_heads, config.num_key_value_heads)
619
+ )
620
+ self.mea_output_norm = MEAHeadSeeDNorm(
621
+ num_heads=config.num_attention_heads,
622
+ head_dim=self.head_dim,
623
+ num_kv_groups=self.num_key_value_groups,
624
+ eps=self.mea_groupnorm_eps,
625
+ )
626
+ else:
627
+ self.mea_key_mix = None
628
+ self.mea_value_mix = None
629
+ self.mea_output_norm = None
630
+
631
  self.dropout = nn.Dropout(config.dropout_rate)
632
+ self.lambda_1 = nn.Parameter(torch.tensor(0.5))
633
+ self.lambda_2 = nn.Parameter(torch.tensor(0.5))
634
+
635
def _apply_momentum_attention(
    self,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Add a momentum (causal first-difference) shear to post-RoPE Q and K.

    No-op when momentum attention is disabled or gamma is exactly zero;
    V is deliberately left untouched.
    """
    gamma = self.momentum_gamma
    if not self.use_momentum_attention or gamma == 0.0:
        return query_states, key_states

    sheared_queries = query_states + gamma * causal_first_difference(query_states)
    sheared_keys = key_states + gamma * causal_first_difference(key_states)
    return sheared_queries, sheared_keys
647
+
648
def _apply_mea_head_mixing(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Mix MEA component KV heads into the model's KV heads.

    Runs before repeat_kv and the attention kernel; identity pass-through
    when MEA attention is disabled.
    """
    if not self.use_mea_attention:
        return key_states, value_states

    return (
        head_linear_compose(key_states, self.mea_key_mix).contiguous(),
        head_linear_compose(value_states, self.mea_value_mix).contiguous(),
    )
660
+
661
def _apply_lucid_preconditioner(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """Precondition V with a LUCID batched lower-triangular solve.

    Builds a lower-triangular key-similarity preconditioner (unit diagonal by
    construction: self-similarity logits are shifted to zero, so exp gives 1),
    optionally zeroes entries involving invalid keys inferred from the mask,
    then solves ``P @ V' = V`` in float32. Returns ``value_states`` unchanged
    when LUCID attention is disabled.
    """
    if not self.use_lucid_attention:
        return value_states

    # Key-similarity logits; subtracting sqrt(d) makes each self-similarity
    # logit 0, hence a unit diagonal after exp.
    normalized_keys = rms_key_unit_norm(key_states, eps=self.lucid_attention_eps)
    similarity = torch.matmul(normalized_keys, normalized_keys.transpose(-1, -2)) * self.scaling - self.sqrt_head_dim
    preconditioner = torch.tril(torch.exp(similarity))

    # Drop pairs that involve padded/invalid keys when the mask exposes them.
    validity = infer_key_validity(attention_mask, key_states.shape[-2], key_states.shape[1])
    if validity is not None:
        pair_validity = validity.unsqueeze(-1) & validity.unsqueeze(-2)
        preconditioner = preconditioner * pair_validity.to(preconditioner.dtype)

    # Force the diagonal back to exactly 1 (masking above may have zeroed it;
    # solve_triangular with unitriangular=True assumes a unit diagonal anyway).
    size = preconditioner.shape[-1]
    identity = torch.eye(
        size,
        device=preconditioner.device,
        dtype=preconditioner.dtype,
    ).view(1, 1, size, size)
    preconditioner = preconditioner * (1.0 - identity) + identity

    solved = torch.linalg.solve_triangular(
        preconditioner,
        value_states.float(),
        upper=False,
        unitriangular=True,
    )
    return solved.to(value_states.dtype).contiguous()
694
+
695
def _apply_mea_output_norm(self, attn_output: torch.Tensor) -> torch.Tensor:
    """Run the GQA-grouped SeeDNorm over the per-head attention output.

    Identity pass-through when MEA attention is disabled.
    """
    if self.use_mea_attention:
        return self.mea_output_norm(attn_output)
    return attn_output
700
 
701
  def forward(
702
  self,
 
706
  first_layer_fan: Optional[torch.Tensor] = None,
707
  **kwargs: Unpack[FlashAttentionKwargs],
708
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
709
+ """Forward pass for the full attention block."""
 
 
 
 
 
 
 
 
 
 
 
710
  input_shape = hidden_states.shape[:-1]
711
+
 
712
  hidden_states_fan = self.fan_layer(hidden_states)
 
 
713
  if first_layer_fan is not None:
714
  hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
715
+
 
716
  current_layer_fan = hidden_states_fan.clone()
717
+ query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
718
+ key_value_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)
719
 
 
720
  query_states, gate = torch.chunk(
721
+ self.q_proj(hidden_states_fan).view(*input_shape, self.config.num_attention_heads, self.head_dim * 2), 2, dim=-1
722
  )
723
  gate = gate.reshape(*input_shape, -1)
724
 
725
+ query_states = self.q_norm(query_states.view(query_shape)).transpose(1, 2)
726
+ key_states = self.k_norm(self.k_proj(hidden_states_fan).view(key_value_shape)).transpose(1, 2)
727
+ value_states = self.v_proj(hidden_states_fan).view(key_value_shape).transpose(1, 2)
 
728
 
729
  cos, sin = position_embeddings
730
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
731
+ query_states, key_states = self._apply_momentum_attention(query_states, key_states)
732
+ key_states, value_states = self._apply_mea_head_mixing(key_states, value_states)
733
+ value_states = self._apply_lucid_preconditioner(key_states, value_states, attention_mask)
734
 
735
  attention_interface: Callable = eager_attention_forward
736
  if self.config._attn_implementation != "eager":
 
747
  **kwargs,
748
  )
749
 
750
+ attn_output = attn_output.reshape(*input_shape, -1, self.head_dim)
751
+ attn_output = self._apply_mea_output_norm(attn_output)
752
  attn_output = attn_output.reshape(*input_shape, -1).contiguous()
753
  attn_output = attn_output * torch.sigmoid(gate)
 
 
754
  attn_output = self.o_proj(attn_output)
755
  attn_output = self.dropout(attn_output)
 
 
756
 
757
+ return attn_output, attn_weights, current_layer_fan
758
 
759
  class PolyNorm(torch.nn.Module):
760
  def __init__(self, eps=1e-6):
 
839
 
840
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
841
  """
842
+ Decoder layer with standard residual connections.
843
 
844
+ Architecture:
845
+ 1. Pre-norm (SeeDNorm) → LNS scaling → Self-Attention with ResFormer and Learnable Multipliers
846
+ 2. Standard Residual Connection (simple sum)
847
+ 3. GPAS activation scaling
848
+ 4. Pre-norm (SeeDNorm) → LNS scaling → MLP with FANformer and Learnable Multipliers
849
+ 5. Standard Residual Connection (simple sum)
850
+ 6. GPAS activation scaling
 
851
  """
852
 
853
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
 
861
  # MLP with FANformer integration and learnable multipliers
862
  self.mlp = NeoLLMMLP(config)
863
 
864
+ # SeeDNorm for input and post-attention normalization (replaces RMSNorm)
865
  self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
866
  self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
867
 
 
869
  self.lns_attn = LNS(layer_idx)
870
  self.lns_mlp = LNS(layer_idx)
871
 
872
+ # GPAS (Gradient-Preserving Activation Scaling) - applied after residual connections
873
  self.gpas_attn = GPAS(config.hidden_size)
874
  self.gpas_mlp = GPAS(config.hidden_size)
875
 
 
 
 
 
 
876
  # ResFormer: storage for current layer's FAN features
877
  self.current_layer_fan = None
878
 
 
882
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
883
  attention_mask: Optional[torch.Tensor] = None,
884
  first_layer_fan: Optional[torch.Tensor] = None,
 
 
885
  output_attentions: Optional[bool] = False,
886
  **kwargs: Unpack[FlashAttentionKwargs],
887
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
888
  # ============================================================
889
+ # Attention Block with standard residual connection
 
 
 
 
 
 
 
 
 
 
890
  # ============================================================
891
  residual = hidden_states
892
 
 
896
  # Apply LNS scaling after normalization
897
  hidden_states = self.lns_attn(hidden_states)
898
 
899
+ # Self Attention with ResFormer feature residual connections and learnable multipliers
900
+ # We capture attn_weights here instead of ignoring them
901
+ hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
902
  hidden_states=hidden_states,
 
903
  attention_mask=attention_mask,
904
+ position_embeddings=position_embeddings,
905
  first_layer_fan=first_layer_fan,
906
  **kwargs,
907
  )
908
 
909
+ # Standard residual connection
910
+ hidden_states = residual + hidden_states
911
 
912
+ # Apply GPAS after attention residual connection
913
  hidden_states = self.gpas_attn(hidden_states)
914
 
915
  # ============================================================
916
+ # MLP Block with standard residual connection
917
  # ============================================================
918
  residual = hidden_states
919
  hidden_states = self.post_attention_layernorm(hidden_states)
 
921
  # Apply LNS scaling after normalization
922
  hidden_states = self.lns_mlp(hidden_states)
923
 
924
+ # MLP now includes FAN transformation and learnable multipliers internally
925
+ hidden_states = self.mlp(hidden_states)
926
 
927
+ # Standard residual connection
928
+ hidden_states = residual + hidden_states
929
 
930
+ # Apply GPAS after MLP residual connection
931
  hidden_states = self.gpas_mlp(hidden_states)
932
 
933
+ outputs = (hidden_states,)
934
+ if output_attentions:
935
+ outputs += (attn_weights,)
936
+
937
+ return outputs
938
 
939
 
940
  class NeoLLMPreTrainedModel(PreTrainedModel):
 
947
  - FANLayer (Fourier Analysis Network)
948
  - SeeDNorm (Self-Rescaled Dynamic Normalization)
949
  - Learnable Multipliers (ScalarMultiplier, VectorMultiplier)
 
950
  """
951
  config: NeoLLMConfig
952
  base_model_prefix = "model"
 
959
def _init_weights(self, module):
    """
    Initialize weights for all custom modules in NeoLLM.

    Standard layers (Linear, Embedding) are handled by the parent class.
    Custom modules receive specialized initialization here; learnable
    multipliers start at 1.0 so training begins from the identity transform.
    """
    super()._init_weights(module)

    def reset_mix(mix) -> None:
        # Restore the identity-preserving MEA reconstruction initializer.
        if mix is not None:
            initializer = build_mea_reconstruction_matrix(mix.shape[0], mix.shape[1])
            mix.data.copy_(initializer.to(device=mix.device, dtype=mix.dtype))

    if isinstance(module, NeoLLMAttention):
        # ResFormer lambdas interpolate first-layer vs. current-layer FAN
        # features; 0.5/0.5 gives a balanced contribution at the start.
        for lambda_name in ('lambda_1', 'lambda_2'):
            if hasattr(module, lambda_name):
                getattr(module, lambda_name).data.fill_(0.5)
        if hasattr(module, 'mea_key_mix'):
            reset_mix(module.mea_key_mix)
        if hasattr(module, 'mea_value_mix'):
            reset_mix(module.mea_value_mix)

    elif isinstance(module, GPAS):
        # GPAS alpha = 0 per the paper: start with no activation scaling and
        # let the model learn it gradually.
        module.alpha.data.fill_(0.0)

    elif isinstance(module, (FANLayer, SeeDNorm)):
        # Both modules fully initialize their own parameters in __init__
        # (FANLayer: normal std=0.02 weights; SeeDNorm: gamma=1, beta=0,
        # alpha=1), so nothing to do here.
        pass

    elif isinstance(module, (ScalarMultiplier, VectorMultiplier)):
        # Identity multipliers: start from standard behavior and learn
        # scale adaptations from data without initial bias.
        if hasattr(module, 'multiplier'):
            module.multiplier.data.fill_(1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1016
 
1017
  class NeoLLMModel(NeoLLMPreTrainedModel):
1018
  """
1019
  NeoLLM base model with transformer decoder architecture.
1020
 
 
 
 
1021
  Note on embeddings and weight tying: This model uses weight tying between
1022
  embed_tokens and lm_head (shared weights). Following "Learnable Multipliers"
1023
  paper analysis, we do NOT add multipliers to embeddings because:
1024
 
1025
+ 1. Weight tying creates conflicting gradient paths: multipliers would scale
1026
+ gradients from embedding lookup but not from lm_head projection, causing
1027
+ the multiplier to receive incomplete optimization signals.
1028
+
1029
+ 2. The paper explicitly warns against multipliers in lm_head (creates shortcuts
1030
+ for learning marginal token distribution), and with weight tying this
1031
+ restriction propagates to embeddings.
1032
+
1033
+ 3. Compensating mechanisms provide scale adaptation immediately after embedding:
1034
+ - First layer attention has multipliers in Q/O projections
1035
+ - FANformer transforms the representation space
1036
+ - SeeDNorm provides input-dependent dynamic scaling
1037
+ - ResFormer propagates first-layer features with learnable scaling
1038
  """
1039
 
1040
  def __init__(self, config: NeoLLMConfig):
1041
  super().__init__(config)
1042
 
1043
  # Standard embedding without learnable multipliers
1044
+ # Due to weight tying with lm_head, multipliers would create
1045
+ # conflicting optimization dynamics (see class docstring)
1046
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
1047
 
1048
  # Each layer creates its own components (no shared parameters)
 
1055
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
1056
  self.gradient_checkpointing = False
1057
 
1058
+ # ResFormer: storage for first layer's FAN features (H_fan_1)
 
 
 
1059
  self.first_layer_fan = None
1060
 
1061
  # Initialize weights and apply final processing
 
1070
  output_hidden_states: Optional[bool] = None,
1071
  output_attentions: Optional[bool] = None,
1072
  return_dict: Optional[bool] = None,
 
 
1073
  **kwargs: Unpack[TransformersKwargs],
1074
  ) -> BaseModelOutputWithPast:
1075
  output_hidden_states = (
 
1086
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1087
 
1088
  if inputs_embeds is None:
1089
+ # Standard embedding lookup without multipliers
1090
+ # Scale adaptation occurs in subsequent layers via:
1091
+ # (1) First layer attention multipliers, (2) FANformer transformation,
1092
+ # (3) SeeDNorm dynamic scaling, (4) ResFormer feature propagation
1093
  inputs_embeds = self.embed_tokens(input_ids)
1094
 
1095
  if position_ids is None:
 
1105
  )
1106
 
1107
  hidden_states = inputs_embeds
 
1108
  all_hidden_states = () if output_hidden_states else None
1109
  all_attentions = () if output_attentions else None
1110
 
1111
+ # create position embeddings to be shared across the decoder layers
1112
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
1113
 
1114
+ # ResFormer: reset first_layer_fan at the start of each forward pass
1115
  self.first_layer_fan = None
1116
+
1117
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
 
 
 
 
 
 
 
 
 
 
 
 
1118
  if output_hidden_states:
1119
  all_hidden_states = all_hidden_states + (hidden_states,)
1120
 
 
1122
  hidden_states,
1123
  position_embeddings=position_embeddings,
1124
  attention_mask=causal_mask,
1125
+ first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
 
 
1126
  output_attentions=output_attentions,
1127
  **kwargs,
1128
  )
 
1132
  if output_attentions:
1133
  all_attentions = all_attentions + (layer_outputs[1],)
1134
 
 
 
 
 
 
 
 
1135
  # ResFormer: capture H_fan_1 from the first layer
 
1136
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
1137
  self.first_layer_fan = decoder_layer.current_layer_fan
1138
 
 
1143
  all_hidden_states = all_hidden_states + (hidden_states,)
1144
 
1145
  if not return_dict:
1146
+ return tuple(v for v in [hidden_states, None, all_hidden_states, all_attentions] if v is not None)
1147
 
1148
  return BaseModelOutputWithPast(
1149
  last_hidden_state=hidden_states,
1150
+ past_key_values=None,
1151
  hidden_states=all_hidden_states,
1152
  attentions=all_attentions,
1153
  )
 
1202
 
1203
  self.post_init()
1204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1205
 
1206
  def forward(
1207
  self,
 
1213
  logits_to_keep: Union[int, torch.Tensor] = 0,
1214
  output_hidden_states: Optional[bool] = None,
1215
  return_dict: Optional[bool] = None,
 
1216
  **kwargs: Unpack[TransformersKwargs],
1217
  ) -> CausalLMOutputWithPast:
1218
  outputs: BaseModelOutputWithPast = self.model(
 
1222
  inputs_embeds=inputs_embeds,
1223
  output_hidden_states=output_hidden_states,
1224
  return_dict=return_dict,
 
1225
  **kwargs,
1226
  )
1227
 
 
1246
  return CausalLMOutputWithPast(
1247
  loss=loss,
1248
  logits=logits,
1249
+ past_key_values=None,
1250
  hidden_states=outputs.hidden_states,
1251
  attentions=outputs.attentions,
1252
  )
 
1263
  "ScalarMultiplier",
1264
  "VectorMultiplier",
1265
  "LinearWithMultipliers",
1266
+ "MEAHeadRMSNorm",
1267
  ]
1268
 
1269
  # Register the configuration and model for AutoClass support