KitsuVp committed on
Commit
9a60a23
·
verified ·
1 Parent(s): d2e1abb

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +313 -95
modeling_neollm.py CHANGED
@@ -2,7 +2,9 @@
2
  """
3
  NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
4
  SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning,
5
- and Learnable Multipliers for enhanced scale adaptation and information flow through deep layers.
 
 
6
  Updated to include:
7
  - Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
8
  - FAN layer in FFN for featural periodicity modeling (complementary coverage)
@@ -10,16 +12,20 @@ Updated to include:
10
  - Dropout regularization at strategic locations
11
  - ResFormer: Feature residual connections from first layer (applied before projections)
12
  - Learnable Multipliers: Frees weight matrix scale from WD-noise equilibrium for data-adaptive scaling
 
13
  - Full Attention only (linear attention removed)
14
  """
15
 
16
  import math
17
- from typing import Any, Callable, Optional, Union, Tuple
18
 
19
  import torch
20
  import torch.nn.functional as F
21
  from torch import nn
22
  from cut_cross_entropy import linear_cross_entropy
 
 
 
23
 
24
  from transformers.activations import ACT2FN
25
  from transformers.generation import GenerationMixin
@@ -296,7 +302,7 @@ class SeeDNorm(nn.Module):
296
  Normalized and dynamically scaled tensor of same shape
297
  """
298
 
299
- x_for_dynamic = F.dropout(x, p=self.dropout_input)
300
  rescale_factor = torch.tanh(torch.sum(x_for_dynamic * self.beta,
301
  dim=-1, keepdim=True))
302
 
@@ -306,7 +312,7 @@ class SeeDNorm(nn.Module):
306
  # Apply RMS normalization on ORIGINAL input (not dropped version)
307
  x_normalized = self._rms_norm(x.float())
308
 
309
- x_normalized = F.dropout(x_normalized, p=self.dropout_hidden)
310
 
311
  # Apply dynamic scaling
312
  output = x_normalized * dynamic_scale.float()
@@ -317,6 +323,189 @@ class SeeDNorm(nn.Module):
317
  return (f"dim={self.dim}, eps={self.eps}, "
318
  f"dropout_input={self.dropout_input}, dropout_hidden={self.dropout_hidden}")
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  class NeoLLMRotaryEmbedding(nn.Module):
321
  inv_freq: torch.Tensor # fix linting for `register_buffer`
322
 
@@ -424,7 +613,7 @@ class NeoLLMAttention(nn.Module):
424
  ResFormer feature residual connections, and Learnable Multipliers for enhanced
425
  information flow and scale adaptation.
426
 
427
- ResFormer enhancement: Applies learnable feature residual connections from the first layer
428
  BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
429
 
430
  Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C):
@@ -486,33 +675,43 @@ class NeoLLMAttention(nn.Module):
486
  self.dropout = nn.Dropout(config.dropout_rate)
487
 
488
  # ResFormer: learnable feature residual parameters (initialized to 0.5)
489
- self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1 (first layer features)
490
- self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n (current layer features)
491
 
492
  def forward(
493
  self,
494
  hidden_states: torch.Tensor,
495
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
496
- attention_mask: Optional[torch.Tensor],
497
  first_layer_fan: Optional[torch.Tensor] = None,
498
  **kwargs: Unpack[FlashAttentionKwargs],
499
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
 
 
 
 
 
 
 
 
 
 
 
 
500
  input_shape = hidden_states.shape[:-1]
501
 
502
- # Apply FANformer transformation first
503
  hidden_states_fan = self.fan_layer(hidden_states)
504
 
505
  # ResFormer: Apply feature residual connection BEFORE projections
506
- # This ensures dimensional compatibility across all layer types
507
  if first_layer_fan is not None:
508
  hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
509
 
510
- # Store current FAN features for potential use as first_layer_fan in subsequent layers
511
  current_layer_fan = hidden_states_fan.clone()
512
 
513
  hidden_shape = (*input_shape, -1, self.head_dim)
514
 
515
- # Use FAN-transformed features (with residual applied) for projections
516
  # Q projection with learnable row multipliers
517
  query_states, gate = torch.chunk(
518
  self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
@@ -633,17 +832,19 @@ class NeoLLMMLP(nn.Module):
633
  hidden = self.dropout(hidden)
634
  return self.down_proj(hidden)
635
 
 
636
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
637
  """
638
- Decoder layer with standard residual connections.
639
 
640
- Arquitectura:
641
- 1. Pre-norm (SeeDNorm) → LNS scaling → Self-Attention con ResFormer y Learnable Multipliers
642
- 2. Standard Residual Connection (suma simple)
643
  3. GPAS activation scaling
644
- 4. Pre-norm (SeeDNorm) → LNS scaling → MLP con FANformer y Learnable Multipliers
645
- 5. Standard Residual Connection (suma simple)
646
  6. GPAS activation scaling
 
647
  """
648
 
649
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
@@ -657,7 +858,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
657
  # MLP with FANformer integration and learnable multipliers
658
  self.mlp = NeoLLMMLP(config)
659
 
660
- # SeeDNorm for input and post-attention normalization (replaces RMSNorm)
661
  self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
662
  self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
663
 
@@ -665,10 +866,15 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
665
  self.lns_attn = LNS(layer_idx)
666
  self.lns_mlp = LNS(layer_idx)
667
 
668
- # GPAS (Gradient-Preserving Activation Scaling) - applied after residual connections
669
  self.gpas_attn = GPAS(config.hidden_size)
670
  self.gpas_mlp = GPAS(config.hidden_size)
671
 
 
 
 
 
 
672
  # ResFormer: storage for current layer's FAN features
673
  self.current_layer_fan = None
674
 
@@ -678,11 +884,28 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
678
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
679
  attention_mask: Optional[torch.Tensor] = None,
680
  first_layer_fan: Optional[torch.Tensor] = None,
 
 
681
  output_attentions: Optional[bool] = False,
682
  **kwargs: Unpack[FlashAttentionKwargs],
683
- ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684
  # ============================================================
685
- # Attention Block with standard residual connection
686
  # ============================================================
687
  residual = hidden_states
688
 
@@ -692,24 +915,23 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
692
  # Apply LNS scaling after normalization
693
  hidden_states = self.lns_attn(hidden_states)
694
 
695
- # Self Attention with ResFormer feature residual connections and learnable multipliers
696
- # We capture attn_weights here instead of ignoring them
697
- hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
698
  hidden_states=hidden_states,
699
- attention_mask=attention_mask,
700
  position_embeddings=position_embeddings,
 
701
  first_layer_fan=first_layer_fan,
702
  **kwargs,
703
  )
704
 
705
- # Standard residual connection
706
- hidden_states = residual + hidden_states
707
 
708
- # Apply GPAS after attention residual connection
709
  hidden_states = self.gpas_attn(hidden_states)
710
 
711
  # ============================================================
712
- # MLP Block with standard residual connection
713
  # ============================================================
714
  residual = hidden_states
715
  hidden_states = self.post_attention_layernorm(hidden_states)
@@ -717,20 +939,27 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
717
  # Apply LNS scaling after normalization
718
  hidden_states = self.lns_mlp(hidden_states)
719
 
720
- # MLP now includes FAN transformation and learnable multipliers internally
721
- hidden_states = self.mlp(hidden_states)
722
 
723
- # Standard residual connection
724
- hidden_states = residual + hidden_states
725
 
726
- # Apply GPAS after MLP residual connection
727
  hidden_states = self.gpas_mlp(hidden_states)
728
 
729
- outputs = (hidden_states,)
730
- if output_attentions:
731
- outputs += (attn_weights,)
 
 
 
 
732
 
733
- return outputs
 
 
 
734
 
735
 
736
  class NeoLLMPreTrainedModel(PreTrainedModel):
@@ -743,6 +972,7 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
743
  - FANLayer (Fourier Analysis Network)
744
  - SeeDNorm (Self-Rescaled Dynamic Normalization)
745
  - Learnable Multipliers (ScalarMultiplier, VectorMultiplier)
 
746
  """
747
  config: NeoLLMConfig
748
  base_model_prefix = "model"
@@ -755,76 +985,58 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
755
  def _init_weights(self, module):
756
  """
757
  Initialize weights for all custom modules in NeoLLM.
758
-
759
- Strategy:
760
- - Standard layers (Linear, Embedding): handled by parent class
761
- - Custom modules: specialized initialization per component
762
- - Learnable Multipliers: initialized to 1.0 for identity transformation
763
  """
764
  super()._init_weights(module)
765
 
766
  if isinstance(module, NeoLLMAttention):
767
- # ResFormer: initialize lambda parameters for full attention
768
- # Lambda values control the interpolation between first layer and current layer features
769
- # Starting at 0.5 provides balanced contribution from both sources
770
  if hasattr(module, 'lambda_1'):
771
  module.lambda_1.data.fill_(0.5)
772
  if hasattr(module, 'lambda_2'):
773
  module.lambda_2.data.fill_(0.5)
774
 
775
  elif isinstance(module, GPAS):
776
- # Initialize GPAS alpha to 0 as per paper
777
- # This starts with no activation scaling, allowing the model to learn gradually
778
  module.alpha.data.fill_(0.0)
779
 
780
- elif isinstance(module, FANLayer):
781
- # FANLayer initialization is handled within the class __init__
782
- # Uses normal initialization with std=0.02 for weights
783
- pass
784
-
785
- elif isinstance(module, SeeDNorm):
786
- # SeeDNorm initialization (parameters already initialized correctly in __init__):
787
- # gamma (γ) initialized to 1 (static scaling component, like RMSNorm)
788
- # beta (β) initialized to 0 (self-rescaling starts disabled)
789
- # alpha (α) initialized to 1 (dynamic modulation at full strength)
790
- pass
791
-
792
  elif isinstance(module, (ScalarMultiplier, VectorMultiplier)):
793
- # Learnable Multipliers: initialize to 1.0 for identity transformation
794
- # This allows the model to start from the standard behavior and learn
795
- # scale adaptations from data without initial bias
796
  if hasattr(module, 'multiplier'):
797
  module.multiplier.data.fill_(1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
798
 
799
  class NeoLLMModel(NeoLLMPreTrainedModel):
800
  """
801
  NeoLLM base model with transformer decoder architecture.
802
 
 
 
 
803
  Note on embeddings and weight tying: This model uses weight tying between
804
  embed_tokens and lm_head (shared weights). Following "Learnable Multipliers"
805
  paper analysis, we do NOT add multipliers to embeddings because:
806
 
807
- 1. Weight tying creates conflicting gradient paths: multipliers would scale
808
- gradients from embedding lookup but not from lm_head projection, causing
809
- the multiplier to receive incomplete optimization signals.
810
-
811
- 2. The paper explicitly warns against multipliers in lm_head (creates shortcuts
812
- for learning marginal token distribution), and with weight tying this
813
- restriction propagates to embeddings.
814
-
815
- 3. Compensating mechanisms provide scale adaptation immediately after embedding:
816
- - First layer attention has multipliers in Q/O projections
817
- - FANformer transforms the representation space
818
- - SeeDNorm provides input-dependent dynamic scaling
819
- - ResFormer propagates first-layer features with learnable scaling
820
  """
821
 
822
  def __init__(self, config: NeoLLMConfig):
823
  super().__init__(config)
824
 
825
  # Standard embedding without learnable multipliers
826
- # Due to weight tying with lm_head, multipliers would create
827
- # conflicting optimization dynamics (see class docstring)
828
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
829
 
830
  # Each layer creates its own components (no shared parameters)
@@ -837,7 +1049,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
837
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
838
  self.gradient_checkpointing = False
839
 
840
- # ResFormer: storage for first layer's FAN features (H_fan_1)
 
 
 
841
  self.first_layer_fan = None
842
 
843
  # Initialize weights and apply final processing
@@ -868,10 +1083,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
868
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
869
 
870
  if inputs_embeds is None:
871
- # Standard embedding lookup without multipliers
872
- # Scale adaptation occurs in subsequent layers via:
873
- # (1) First layer attention multipliers, (2) FANformer transformation,
874
- # (3) SeeDNorm dynamic scaling, (4) ResFormer feature propagation
875
  inputs_embeds = self.embed_tokens(input_ids)
876
 
877
  if position_ids is None:
@@ -890,13 +1101,15 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
890
  all_hidden_states = () if output_hidden_states else None
891
  all_attentions = () if output_attentions else None
892
 
893
- # create position embeddings to be shared across the decoder layers
894
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
895
 
896
- # ResFormer: reset first_layer_fan at the start of each forward pass
897
  self.first_layer_fan = None
898
-
899
- for decoder_layer in self.layers[: self.config.num_hidden_layers]:
 
 
900
  if output_hidden_states:
901
  all_hidden_states = all_hidden_states + (hidden_states,)
902
 
@@ -904,7 +1117,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
904
  hidden_states,
905
  position_embeddings=position_embeddings,
906
  attention_mask=causal_mask,
907
- first_layer_fan=self.first_layer_fan, # Pass H_fan_1 to all layers
 
 
908
  output_attentions=output_attentions,
909
  **kwargs,
910
  )
@@ -914,6 +1129,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
914
  if output_attentions:
915
  all_attentions = all_attentions + (layer_outputs[1],)
916
 
 
 
 
 
917
  # ResFormer: capture H_fan_1 from the first layer
918
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
919
  self.first_layer_fan = decoder_layer.current_layer_fan
@@ -967,11 +1186,10 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
967
  """
968
  Causal Language Model with NeoLLM architecture.
969
 
 
 
970
  Note on LM head: Following "Learnable Multipliers" paper recommendations,
971
- the output projection (lm_head) does NOT include learnable multipliers because:
972
- 1. The preceding RMSNorm (self.model.norm) already acts as column multipliers
973
- 2. Adding row multipliers to lm_head can create shortcuts where the model
974
- learns marginal token distribution without updating internal features
975
  """
976
  _tied_weights_keys = ["lm_head.weight"]
977
 
@@ -981,7 +1199,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
981
  self.vocab_size = config.vocab_size
982
 
983
  # LM head without learnable multipliers (standard linear layer)
984
- # Preceding norm layer provides sufficient scale adaptation
985
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
986
 
987
  self.post_init()
@@ -1046,6 +1263,7 @@ __all__ = [
1046
  "ScalarMultiplier",
1047
  "VectorMultiplier",
1048
  "LinearWithMultipliers",
 
1049
  ]
1050
 
1051
  # Register the configuration and model for AutoClass support
 
2
  """
3
  NeoLLM Model with FANformer Integration in both Attention and FFN, Dropout Regularization,
4
  SeeDNorm (Self-Rescaled Dynamic Normalization), ResFormer Value Residual Learning,
5
+ Learnable Multipliers for enhanced scale adaptation and information flow through deep layers,
6
+ and StackMemory for hierarchical pattern modeling.
7
+
8
  Updated to include:
9
  - Fourier Analysis Network (FAN) layer for effective periodicity modeling in attention (relational space)
10
  - FAN layer in FFN for featural periodicity modeling (complementary coverage)
 
12
  - Dropout regularization at strategic locations
13
  - ResFormer: Feature residual connections from first layer (applied before projections)
14
  - Learnable Multipliers: Frees weight matrix scale from WD-noise equilibrium for data-adaptive scaling
15
+ - StackMemory: Differentiable hidden state stack for modeling Chomsky hierarchy grammars
16
  - Full Attention only (linear attention removed)
17
  """
18
 
19
  import math
20
+ from typing import Any, Callable, Optional, Union, Tuple, List
21
 
22
  import torch
23
  import torch.nn.functional as F
24
  from torch import nn
25
  from cut_cross_entropy import linear_cross_entropy
26
+ import torch.nn.functional as F
27
+ from torch.utils.checkpoint import checkpoint
28
+ from typing import Optional, Tuple
29
 
30
  from transformers.activations import ACT2FN
31
  from transformers.generation import GenerationMixin
 
302
  Normalized and dynamically scaled tensor of same shape
303
  """
304
 
305
+ x_for_dynamic = F.dropout(x, p=self.dropout_input, training=self.training)
306
  rescale_factor = torch.tanh(torch.sum(x_for_dynamic * self.beta,
307
  dim=-1, keepdim=True))
308
 
 
312
  # Apply RMS normalization on ORIGINAL input (not dropped version)
313
  x_normalized = self._rms_norm(x.float())
314
 
315
+ x_normalized = F.dropout(x_normalized, p=self.dropout_hidden, training=self.training)
316
 
317
  # Apply dynamic scaling
318
  output = x_normalized * dynamic_scale.float()
 
323
  return (f"dim={self.dim}, eps={self.eps}, "
324
  f"dropout_input={self.dropout_input}, dropout_hidden={self.dropout_hidden}")
325
 
326
+
327
# ==================== STACK MEMORY MODULE ====================


class StackMemory(nn.Module):
    """
    Differentiable Hidden State Stack for modeling Chomsky hierarchy grammars.

    From "Improving Formal Reasoning of Transformer with State Stack":
    implements a multi-head differentiable stack with soft push, pop, and
    no-op operations. Each head maintains its own stack and validity mask,
    updated from learned action probabilities; a global read over the stack
    slots is performed with attention and added back to the hidden states
    through a gated residual connection.

    NOTE(review): the update is a one-step *parallel* approximation — every
    sequence position applies its action to the same incoming stack state
    (see the ``expand`` in ``_vectorized_update``), not to the previous
    position's updated stack. The last position's state is carried forward.

    Note: StackMemory uses standard nn.Linear to maintain architectural
    independence and avoid introducing additional complexity in the memory
    operations.

    Args:
        config: Model configuration. Reads ``hidden_size`` plus the optional
            hyperparameters ``num_stack_heads`` (default 4), ``stack_slots``
            (default 24) and ``stack_d_model`` (default 128).

    Raises:
        ValueError: if ``stack_d_model`` is not divisible by
            ``num_stack_heads`` (the per-head split in ``forward`` requires
            an exact division).
    """

    def __init__(self, config: "NeoLLMConfig"):
        super().__init__()
        self.config = config
        self.num_stack_heads = getattr(config, 'num_stack_heads', 4)
        self.stack_slots = getattr(config, 'stack_slots', 24)
        self.stack_d_model = getattr(config, 'stack_d_model', 128)

        # Fail fast with a clear message instead of a cryptic .view() error
        # later in forward() when splitting into heads.
        if self.stack_d_model % self.num_stack_heads != 0:
            raise ValueError(
                f"stack_d_model ({self.stack_d_model}) must be divisible by "
                f"num_stack_heads ({self.num_stack_heads})"
            )
        self.head_dim = self.stack_d_model // self.num_stack_heads

        # Dimension reduction projections for efficiency: the stack operates
        # in a reduced space of size stack_d_model. Uses standard nn.Linear.
        self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=False)
        self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=False)

        # Action prediction: (push, pop, no-op) logits for each head.
        self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)

        # Scores each stack slot for the global read attention (per head).
        self.gate_proj = nn.Linear(self.head_dim, 1, bias=False)

        # Learnable gate on the stack's residual contribution.
        self.res_weight = nn.Parameter(torch.ones(1))

    def _vectorized_update(
        self,
        stack: torch.Tensor,
        mask: torch.Tensor,
        actions: torch.Tensor,
        k_values: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Vectorized stack update applying soft push/pop/no-op operations.

        Implements the differentiable stack operations from the paper:
        - Push: shifts all elements down and places k_values at the top
        - Pop: shifts all elements up and removes the top
        - No-op: keeps the current stack state

        The three candidate stacks are blended with the action probabilities,
        so the update stays fully differentiable.

        Args:
            stack: Current stack state [batch, num_heads, stack_slots, head_dim]
            mask: Current stack validity mask [batch, num_heads, stack_slots]
            actions: Action probabilities [batch, seq, num_heads, 3]
            k_values: New values to push [batch, seq, num_heads, head_dim]

        Returns:
            Tuple of (updated_stack, updated_mask) with a leading seq axis:
            [batch, seq, num_heads, stack_slots, head_dim] and
            [batch, seq, num_heads, stack_slots].
        """
        batch_size, seq_len = actions.shape[:2]

        # Broadcast the single incoming state across the sequence dimension so
        # all positions are processed in parallel (expand creates views; the
        # subsequent torch.cat materializes new tensors).
        stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
        mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)

        # Pushed candidate: new value at slot 0, existing elements shifted down.
        push_stack = torch.cat([
            k_values.unsqueeze(3),  # New value at position 0
            stack[:, :, :, :-1]     # Shift existing elements down
        ], dim=3)
        push_mask = torch.cat([
            torch.ones_like(mask[:, :, :, :1]),
            mask[:, :, :, :-1]
        ], dim=3)

        # Popped candidate: everything shifted up, zeros fill the bottom.
        pop_stack = torch.cat([
            stack[:, :, :, 1:],
            torch.zeros_like(stack[:, :, :, :1])
        ], dim=3)
        pop_mask = torch.cat([
            mask[:, :, :, 1:],
            torch.zeros_like(mask[:, :, :, :1])
        ], dim=3)

        # Blend the three candidates, weighted by the action probabilities.
        action_weights = actions.unsqueeze(-1).unsqueeze(-1)          # [batch, seq, heads, 3, 1, 1]
        stacks = torch.stack([push_stack, pop_stack, stack], dim=3)   # [batch, seq, heads, 3, slots, dim]
        masks = torch.stack([push_mask, pop_mask, mask], dim=3)       # [batch, seq, heads, 3, slots]

        new_stack = (stacks * action_weights).sum(dim=3)
        new_mask = (masks * action_weights.squeeze(-1)).sum(dim=3)

        return new_stack, new_mask

    def forward(
        self,
        hidden_states: torch.Tensor,
        stack: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Apply differentiable stack operations to hidden states.

        Args:
            hidden_states: Input hidden states [batch, seq, hidden_size]
            stack: Previous stack state
                [batch, num_heads, stack_slots, head_dim] or None
            mask: Previous stack mask [batch, num_heads, stack_slots] or None

        Returns:
            Tuple of (output_hidden_states, updated_stack, updated_mask),
            where the returned state/mask are taken from the last sequence
            position so they can be fed into the next call.
        """
        batch_size, seq_len, _ = hidden_states.shape
        device = hidden_states.device

        # Fresh (empty) stack and mask when no state is carried in.
        if stack is None:
            stack = torch.zeros(
                batch_size, self.num_stack_heads, self.stack_slots, self.head_dim,
                device=device, dtype=hidden_states.dtype
            )
        if mask is None:
            mask = torch.zeros(
                batch_size, self.num_stack_heads, self.stack_slots,
                device=device, dtype=hidden_states.dtype
            )

        # Project to the reduced stack dimension.
        new_hidden_states = self.down_proj(hidden_states)

        # Action probabilities per head: [batch, seq, num_heads, 3].
        # Logits are temperature-scaled by sqrt(head_dim).
        action_logits = self.action_head(new_hidden_states) / math.sqrt(self.head_dim)
        actions = F.softmax(
            action_logits.view(batch_size, seq_len, self.num_stack_heads, 3),
            dim=-1
        )

        # Values to push, split across heads.
        k_values = new_hidden_states.view(batch_size, seq_len, self.num_stack_heads, self.head_dim)

        # Update stack and mask using vectorized operations.
        new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)

        # Global reading via query-over-stack attention; zero out invalid
        # slots before scoring.
        masked_stack = new_stack * new_mask.unsqueeze(-1)

        # Attention scores for each slot: [batch, seq, heads, slots].
        gate_scores = self.gate_proj(masked_stack).squeeze(-1)

        # Penalize invalid positions with the most negative representable
        # value for the current dtype. A fixed -1e9 overflows to -inf under
        # fp16/bf16, and 0 * -inf then poisons the scores with NaN;
        # torch.finfo(...).min stays finite in every dtype.
        gate_scores = gate_scores + (1 - new_mask) * torch.finfo(gate_scores.dtype).min

        # Softmax over the slots to obtain read weights.
        gate_weights = F.softmax(gate_scores, dim=-1)

        # Weighted sum over stack slots, then merge heads back together.
        memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
        memory_output = memory_output.view(batch_size, seq_len, -1)

        # Project back to the model dimension.
        memory_output = self.up_proj(memory_output)

        # Gated residual connection.
        output = memory_output * self.res_weight + hidden_states

        # Carry forward the last position's state (see class NOTE on the
        # parallel one-step approximation).
        return output, new_stack[:, -1], new_mask[:, -1]
+
507
+ # ==================== ROTARY EMBEDDING ====================
508
+
509
  class NeoLLMRotaryEmbedding(nn.Module):
510
  inv_freq: torch.Tensor # fix linting for `register_buffer`
511
 
 
613
  ResFormer feature residual connections, and Learnable Multipliers for enhanced
614
  information flow and scale adaptation.
615
 
616
+ ResFormer enhancement: Applies learnable feature residual connections from first layer
617
  BEFORE QKV projections: H'_fan_n = λ_1 * H_fan_1 + λ_2 * H_fan_n
618
 
619
  Learnable Multipliers placement (from "Learnable Multipliers" paper Appendix C):
 
675
  self.dropout = nn.Dropout(config.dropout_rate)
676
 
677
  # ResFormer: learnable feature residual parameters (initialized to 0.5)
678
+ self.lambda_1 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_1
679
+ self.lambda_2 = nn.Parameter(torch.tensor(0.5)) # Weight for H_fan_n
680
 
681
  def forward(
682
  self,
683
  hidden_states: torch.Tensor,
684
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
685
+ attention_mask: Optional[torch.Tensor] = None,
686
  first_layer_fan: Optional[torch.Tensor] = None,
687
  **kwargs: Unpack[FlashAttentionKwargs],
688
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
689
+ """
690
+ Forward pass with ResFormer feature residual connections.
691
+
692
+ Args:
693
+ hidden_states: Current layer input [batch, seq, hidden_size]
694
+ position_embeddings: Tuple of (cos, sin) for RoPE
695
+ attention_mask: Causal attention mask
696
+ first_layer_fan: First layer FAN features (for ResFormer)
697
+
698
+ Returns:
699
+ Tuple of (attn_output, attn_weights, current_layer_fan)
700
+ """
701
  input_shape = hidden_states.shape[:-1]
702
 
703
+ # Apply FANformer transformation
704
  hidden_states_fan = self.fan_layer(hidden_states)
705
 
706
  # ResFormer: Apply feature residual connection BEFORE projections
 
707
  if first_layer_fan is not None:
708
  hidden_states_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * hidden_states_fan
709
 
710
+ # Store current FAN features for ResFormer
711
  current_layer_fan = hidden_states_fan.clone()
712
 
713
  hidden_shape = (*input_shape, -1, self.head_dim)
714
 
 
715
  # Q projection with learnable row multipliers
716
  query_states, gate = torch.chunk(
717
  self.q_proj(hidden_states_fan).view(*input_shape, -1, self.head_dim * 2), 2, dim=-1
 
832
  hidden = self.dropout(hidden)
833
  return self.down_proj(hidden)
834
 
835
+
836
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
837
  """
838
+ Decoder layer with standard residual connections and optional StackMemory.
839
 
840
+ Architecture:
841
+ 1. Pre-norm (SeeDNorm) → LNS scaling → Self-Attention with ResFormer and Learnable Multipliers
842
+ 2. Standard Residual Connection
843
  3. GPAS activation scaling
844
+ 4. Pre-norm (SeeDNorm) → LNS scaling → MLP with FANformer and Learnable Multipliers
845
+ 5. Standard Residual Connection
846
  6. GPAS activation scaling
847
+ 7. Optional: StackMemory module
848
  """
849
 
850
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
 
858
  # MLP with FANformer integration and learnable multipliers
859
  self.mlp = NeoLLMMLP(config)
860
 
861
+ # SeeDNorm for input and post-attention normalization
862
  self.input_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
863
  self.post_attention_layernorm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
864
 
 
866
  self.lns_attn = LNS(layer_idx)
867
  self.lns_mlp = LNS(layer_idx)
868
 
869
+ # GPAS (Gradient-Preserving Activation Scaling)
870
  self.gpas_attn = GPAS(config.hidden_size)
871
  self.gpas_mlp = GPAS(config.hidden_size)
872
 
873
+ # StackMemory: Differentiable hidden state stack
874
+ self.use_stack = getattr(config, 'use_stack', False)
875
+ if self.use_stack:
876
+ self.stack_memory = StackMemory(config)
877
+
878
  # ResFormer: storage for current layer's FAN features
879
  self.current_layer_fan = None
880
 
 
884
  position_embeddings: tuple[torch.Tensor, torch.Tensor],
885
  attention_mask: Optional[torch.Tensor] = None,
886
  first_layer_fan: Optional[torch.Tensor] = None,
887
+ stack_state: Optional[torch.Tensor] = None,
888
+ stack_mask: Optional[torch.Tensor] = None,
889
  output_attentions: Optional[bool] = False,
890
  **kwargs: Unpack[FlashAttentionKwargs],
891
+ ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
892
+ """
893
+ Forward pass with ResFormer and optional StackMemory.
894
+
895
+ Args:
896
+ hidden_states: Current layer input [batch, seq, hidden_size]
897
+ position_embeddings: Tuple of (cos, sin) for RoPE
898
+ attention_mask: Causal attention mask
899
+ first_layer_fan: First layer FAN features (for ResFormer)
900
+ stack_state: StackMemory state (optional)
901
+ stack_mask: StackMemory mask (optional)
902
+ output_attentions: Whether to return attention weights
903
+
904
+ Returns:
905
+ Tuple of (hidden_states, attn_weights, stack_state, stack_mask)
906
+ """
907
  # ============================================================
908
+ # Attention Block with Standard Residual Connection
909
  # ============================================================
910
  residual = hidden_states
911
 
 
915
  # Apply LNS scaling after normalization
916
  hidden_states = self.lns_attn(hidden_states)
917
 
918
+ # Self Attention with ResFormer
919
+ attn_output, attn_weights, self.current_layer_fan = self.self_attn(
 
920
  hidden_states=hidden_states,
 
921
  position_embeddings=position_embeddings,
922
+ attention_mask=attention_mask,
923
  first_layer_fan=first_layer_fan,
924
  **kwargs,
925
  )
926
 
927
+ # Standard Residual Connection
928
+ hidden_states = residual + attn_output
929
 
930
+ # Apply GPAS after residual connection
931
  hidden_states = self.gpas_attn(hidden_states)
932
 
933
  # ============================================================
934
+ # MLP Block with Standard Residual Connection
935
  # ============================================================
936
  residual = hidden_states
937
  hidden_states = self.post_attention_layernorm(hidden_states)
 
939
  # Apply LNS scaling after normalization
940
  hidden_states = self.lns_mlp(hidden_states)
941
 
942
+ # MLP with FANformer
943
+ mlp_output = self.mlp(hidden_states)
944
 
945
+ # Standard Residual Connection
946
+ hidden_states = residual + mlp_output
947
 
948
+ # Apply GPAS after residual connection
949
  hidden_states = self.gpas_mlp(hidden_states)
950
 
951
+ # ============================================================
952
+ # Stack Memory Module
953
+ # ============================================================
954
+ if self.use_stack:
955
+ hidden_states, stack_state, stack_mask = self.stack_memory(
956
+ hidden_states, stack_state, stack_mask
957
+ )
958
 
959
+ if self.use_stack:
960
+ return (hidden_states, attn_weights, stack_state, stack_mask)
961
+ else:
962
+ return (hidden_states, attn_weights, None, None)
963
 
964
 
965
  class NeoLLMPreTrainedModel(PreTrainedModel):
 
972
  - FANLayer (Fourier Analysis Network)
973
  - SeeDNorm (Self-Rescaled Dynamic Normalization)
974
  - Learnable Multipliers (ScalarMultiplier, VectorMultiplier)
975
+ - StackMemory (Differentiable Hidden State Stack)
976
  """
977
  config: NeoLLMConfig
978
  base_model_prefix = "model"
 
985
  def _init_weights(self, module):
986
  """
987
  Initialize weights for all custom modules in NeoLLM.
 
 
 
 
 
988
  """
989
  super()._init_weights(module)
990
 
991
  if isinstance(module, NeoLLMAttention):
 
 
 
992
  if hasattr(module, 'lambda_1'):
993
  module.lambda_1.data.fill_(0.5)
994
  if hasattr(module, 'lambda_2'):
995
  module.lambda_2.data.fill_(0.5)
996
 
997
  elif isinstance(module, GPAS):
 
 
998
  module.alpha.data.fill_(0.0)
999
 
 
 
 
 
 
 
 
 
 
 
 
 
1000
  elif isinstance(module, (ScalarMultiplier, VectorMultiplier)):
 
 
 
1001
  if hasattr(module, 'multiplier'):
1002
  module.multiplier.data.fill_(1.0)
1003
+
1004
+ elif isinstance(module, StackMemory):
1005
+ std = self.config.initializer_range if hasattr(self.config, 'initializer_range') else 0.02
1006
+ if hasattr(module, 'down_proj'):
1007
+ module.down_proj.weight.data.normal_(mean=0.0, std=std)
1008
+ if hasattr(module, 'up_proj'):
1009
+ module.up_proj.weight.data.normal_(mean=0.0, std=std)
1010
+ if hasattr(module, 'action_head'):
1011
+ module.action_head.weight.data.normal_(mean=0.0, std=std)
1012
+ if module.action_head.bias is not None:
1013
+ module.action_head.bias.data.zero_()
1014
+ if hasattr(module, 'gate_proj'):
1015
+ module.gate_proj.weight.data.normal_(mean=0.0, std=std)
1016
+ if hasattr(module, 'res_weight'):
1017
+ module.res_weight.data.fill_(1.0)
1018
+
1019
 
1020
  class NeoLLMModel(NeoLLMPreTrainedModel):
1021
  """
1022
  NeoLLM base model with transformer decoder architecture.
1023
 
1024
+ Uses ResFormer for first-layer feature propagation with standard residual connections
1025
+ and optional StackMemory for hierarchical pattern modeling.
1026
+
1027
  Note on embeddings and weight tying: This model uses weight tying between
1028
  embed_tokens and lm_head (shared weights). Following "Learnable Multipliers"
1029
  paper analysis, we do NOT add multipliers to embeddings because:
1030
 
1031
+ 1. Weight tying creates conflicting gradient paths
1032
+ 2. The paper explicitly warns against multipliers in lm_head
1033
+ 3. Compensating mechanisms provide scale adaptation immediately after embedding
 
 
 
 
 
 
 
 
 
 
1034
  """
1035
 
1036
  def __init__(self, config: NeoLLMConfig):
1037
  super().__init__(config)
1038
 
1039
  # Standard embedding without learnable multipliers
 
 
1040
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
1041
 
1042
  # Each layer creates its own components (no shared parameters)
 
1049
  self.rotary_emb = NeoLLMRotaryEmbedding(config=config)
1050
  self.gradient_checkpointing = False
1051
 
1052
+ # Configuration
1053
+ self.use_stack = getattr(config, 'use_stack', False)
1054
+
1055
+ # ResFormer: storage for first layer's FAN features
1056
  self.first_layer_fan = None
1057
 
1058
  # Initialize weights and apply final processing
 
1083
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1084
 
1085
  if inputs_embeds is None:
 
 
 
 
1086
  inputs_embeds = self.embed_tokens(input_ids)
1087
 
1088
  if position_ids is None:
 
1101
  all_hidden_states = () if output_hidden_states else None
1102
  all_attentions = () if output_attentions else None
1103
 
1104
+ # Create position embeddings to be shared across the decoder layers
1105
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
1106
 
1107
+ # ResFormer with first-layer feature propagation
1108
  self.first_layer_fan = None
1109
+ stack_state = None
1110
+ stack_mask = None
1111
+
1112
+ for decoder_layer in self.layers:
1113
  if output_hidden_states:
1114
  all_hidden_states = all_hidden_states + (hidden_states,)
1115
 
 
1117
  hidden_states,
1118
  position_embeddings=position_embeddings,
1119
  attention_mask=causal_mask,
1120
+ first_layer_fan=self.first_layer_fan,
1121
+ stack_state=stack_state,
1122
+ stack_mask=stack_mask,
1123
  output_attentions=output_attentions,
1124
  **kwargs,
1125
  )
 
1129
  if output_attentions:
1130
  all_attentions = all_attentions + (layer_outputs[1],)
1131
 
1132
+ if self.use_stack:
1133
+ stack_state = layer_outputs[2]
1134
+ stack_mask = layer_outputs[3]
1135
+
1136
  # ResFormer: capture H_fan_1 from the first layer
1137
  if self.first_layer_fan is None and hasattr(decoder_layer, 'current_layer_fan'):
1138
  self.first_layer_fan = decoder_layer.current_layer_fan
 
1186
  """
1187
  Causal Language Model with NeoLLM architecture.
1188
 
1189
+ Supports ResFormer with standard residuals and optional StackMemory.
1190
+
1191
  Note on LM head: Following "Learnable Multipliers" paper recommendations,
1192
+ the output projection (lm_head) does NOT include learnable multipliers.
 
 
 
1193
  """
1194
  _tied_weights_keys = ["lm_head.weight"]
1195
 
 
1199
  self.vocab_size = config.vocab_size
1200
 
1201
  # LM head without learnable multipliers (standard linear layer)
 
1202
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1203
 
1204
  self.post_init()
 
1263
  "ScalarMultiplier",
1264
  "VectorMultiplier",
1265
  "LinearWithMultipliers",
1266
+ "StackMemory",
1267
  ]
1268
 
1269
  # Register the configuration and model for AutoClass support