KitsuVp committed on
Commit
eb219fb
·
verified ·
1 Parent(s): 0b14653

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +1647 -56
modeling_neollm.py CHANGED
@@ -77,10 +77,34 @@ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_u
77
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
78
  from transformers.processing_utils import Unpack
79
  from transformers.utils import TransformersKwargs, logging
80
- from .configuration_neollm import NeoLLMConfig
81
 
82
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  logger = logging.get_logger(__name__)
85
 
86
 
@@ -339,6 +363,22 @@ class JTokMAnalysis:
339
  lns_scale: Optional[float] = None # 1/√(2β„“) scaling factor
340
 
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  @dataclass
343
  class AttnResAnalysis:
344
  """
@@ -350,6 +390,44 @@ class AttnResAnalysis:
350
  sources_count: Optional[int] = None # number of sources including partial
351
 
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  @dataclass
354
  class LayerAnalysis:
355
  """
@@ -378,8 +456,12 @@ class LayerAnalysis:
378
  gpas_mlp: Optional[GPASAnalysis] = None # GPAS after MLP residual
379
 
380
  # Optional components (None when inactive)
381
- jtokm: Optional[JTokMAnalysis] = None # if use_jtokm
382
- attn_res: Optional[AttnResAnalysis] = None # if use_attn_res
 
 
 
 
383
 
384
 
385
  @dataclass
@@ -444,6 +526,7 @@ class AnalysisState:
444
  layers: Optional[List[LayerAnalysis]] = None
445
  jtokm_aux_stats: Optional[list] = None
446
  attn_res_sources_final: Optional[list] = None
 
447
  logits: Optional[torch.Tensor] = None
448
 
449
  class ScalarMultiplier(nn.Module):
@@ -2363,6 +2446,8 @@ class NeoLLMAttention(nn.Module):
2363
  first_layer_fan: Optional[torch.Tensor] = None,
2364
  attn_analysis: Optional[AttentionAnalysis] = None,
2365
  repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
 
 
2366
  **kwargs: Unpack[FlashAttentionKwargs],
2367
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
2368
  input_shape = hidden_states.shape[:-1]
@@ -2373,6 +2458,14 @@ class NeoLLMAttention(nn.Module):
2373
  h_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan
2374
  current_layer_fan = h_fan.clone()
2375
 
 
 
 
 
 
 
 
 
2376
  query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
2377
  kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)
2378
 
@@ -2387,8 +2480,8 @@ class NeoLLMAttention(nn.Module):
2387
  attn_analysis.gate_raw = gate.detach()
2388
 
2389
  q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2)
2390
- k = self.k_norm(self.k_proj(h_fan).view(kv_shape)).transpose(1, 2)
2391
- v = self.v_proj(h_fan).view(kv_shape).transpose(1, 2)
2392
 
2393
  if attn_analysis is not None:
2394
  attn_analysis.q_post_norm = q.detach()
@@ -3065,6 +3158,1036 @@ class NeoLLMMLP(nn.Module):
3065
  return result
3066
 
3067
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3068
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3069
  """
3070
  Decoder layer with standard residual connections, optional JTok-M injection.
@@ -3087,7 +4210,23 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3087
  self.layer_idx = layer_idx
3088
  self.use_jtokm = config.use_jtokm
3089
 
3090
- self.self_attn = NeoLLMAttention(config, layer_idx)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3091
  self.mlp = (
3092
  VersatileFFN(config)
3093
  if getattr(config, "use_versatile_ffn", False)
@@ -3120,10 +4259,78 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3120
  self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size))
3121
  self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size))
3122
  self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
 
 
3123
  else:
3124
  self.attn_res_query_attn = None
3125
  self.attn_res_query_mlp = None
3126
  self.attn_res_norm = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3127
 
3128
  def _attn_res(
3129
  self,
@@ -3173,6 +4380,10 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3173
  B_vals: Optional[torch.Tensor] = None,
3174
  attn_res_sources: Optional[list] = None,
3175
  attn_res_partial: Optional[torch.Tensor] = None,
 
 
 
 
3176
  layer_analysis: Optional[LayerAnalysis] = None,
3177
  output_attentions: Optional[bool] = False,
3178
  repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
@@ -3182,6 +4393,63 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3182
  if layer_analysis is not None:
3183
  layer_analysis.hidden_states_input = hidden_states.detach()
3184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3185
  # ── Attention Residuals: compute pre-attention input ──────────────
3186
  # When active, the input to the attention sublayer is no longer the
3187
  # raw hidden_states (accumulated residual) but a softmax-weighted
@@ -3195,10 +4463,19 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3195
  attn_res_sources, attn_res_partial, self.attn_res_query_attn,
3196
  ar_analysis, "attn",
3197
  )
 
 
 
 
 
 
 
 
3198
  residual_attn = attn_res_partial
3199
  else:
3200
- h_attn = hidden_states
3201
- residual_attn = hidden_states
 
3202
 
3203
  # ── Attention block ───────────────────────────────────────────────
3204
  sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
@@ -3207,21 +4484,49 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3207
  if layer_analysis is not None:
3208
  layer_analysis.lns_attn_output = h_lns.detach()
3209
 
3210
- hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
3211
- hidden_states=h_lns,
3212
- attention_mask=attention_mask,
3213
- position_embeddings=position_embeddings,
3214
- first_layer_fan=first_layer_fan,
3215
- attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
3216
- repo_rope_args=repo_rope_args,
3217
- **kwargs,
3218
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3219
 
3220
  if layer_analysis is not None:
3221
  layer_analysis.attn_contribution = hidden_states.detach()
3222
 
3223
  gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
3224
- h_tilde = self.gpas_attn(residual_attn + hidden_states, analysis=gpas_attn_a)
 
 
 
 
 
 
 
 
 
 
 
3225
 
3226
  if layer_analysis is not None:
3227
  layer_analysis.h_tilde = h_tilde.detach()
@@ -3257,8 +4562,11 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3257
  if layer_analysis is not None:
3258
  layer_analysis.mlp_contribution = delta_m.detach()
3259
 
3260
- # ── JTok-M injection (additive alongside MLP residual) ────────────
3261
- aux_stats = None
 
 
 
3262
  if self.use_jtokm and z_tilde is not None and B_vals is not None:
3263
  orig_shape = h_tilde.shape
3264
  h_flat = h_tilde.reshape(-1, self.hidden_size)
@@ -3269,11 +4577,21 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3269
  delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
3270
  delta_r = delta_r.reshape(orig_shape)
3271
 
3272
- gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
3273
- hidden_states = self.gpas_mlp(residual_mlp + delta_m + delta_r, analysis=gpas_mlp_a)
 
 
 
 
 
3274
  else:
3275
- gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
3276
- hidden_states = self.gpas_mlp(residual_mlp + delta_m, analysis=gpas_mlp_a)
 
 
 
 
 
3277
 
3278
  if layer_analysis is not None:
3279
  layer_analysis.hidden_states_output = hidden_states.detach()
@@ -3285,6 +4603,9 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3285
  outputs += (aux_stats,)
3286
  if versatile_aux is not None:
3287
  outputs += (versatile_aux,)
 
 
 
3288
  return outputs
3289
 
3290
 
@@ -3612,6 +4933,9 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
3612
  if hasattr(module, "alpha_ma"):
3613
  module.alpha_ma.zero_()
3614
 
 
 
 
3615
  elif isinstance(module, GPAS):
3616
  module.alpha.data.fill_(0.0)
3617
 
@@ -3668,6 +4992,45 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
3668
  module.attn_res_query_attn.data.zero_()
3669
  module.attn_res_query_mlp.data.zero_()
3670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3671
  elif isinstance(module, SpellingBeeEmbedding):
3672
  # byte_emb initialised identically to token embeddings: std=1/√d.
3673
  # Ensures E[‖e_byte‖²] ≈ 1 at init, matching etok, so the
@@ -3737,8 +5100,99 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3737
  self.gradient_checkpointing = False
3738
  self.first_layer_fan = None
3739
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3740
  self.post_init()
3741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3742
  def get_input_embeddings(self):
3743
  if self.config.use_token_generator:
3744
  return self.token_generator
@@ -3764,7 +5218,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3764
  getattr(cfg, "use_repo", False)
3765
  and layer_idx >= getattr(cfg, "repo_start_layer", cfg.num_hidden_layers // 3)
3766
  )
3767
- _versatile = getattr(cfg, "use_versatile_ffn", False)
 
 
3768
  return LayerAnalysis(
3769
  seednorm_pre_attn = SeeDNormAnalysis(),
3770
  seednorm_post_attn = SeeDNormAnalysis(),
@@ -3778,10 +5234,14 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3778
  polynorm = PolyNormAnalysis() if not _versatile else None,
3779
  versatile = VersatileFFNAnalysis() if _versatile else None,
3780
  ),
3781
- gpas_attn = GPASAnalysis(),
3782
- gpas_mlp = GPASAnalysis(),
3783
- jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
3784
- attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
 
 
 
 
3785
  )
3786
 
3787
  def forward(
@@ -3883,6 +5343,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3883
  if getattr(self.config, "use_repo", False) else None
3884
  )
3885
 
 
 
 
 
3886
  # ── Attention Residuals state ──────────────────────────────────────
3887
  # Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
3888
  # decoder layer — all previous outputs are kept, max N=num_layers+1.
@@ -3897,13 +5361,57 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3897
  if use_attn_res:
3898
  attn_res_sources = [hidden_states] # b_0 = token embedding
3899
  attn_res_partial = hidden_states # initial partial sum
3900
-
3901
- num_blocks = getattr(self.config, 'attn_res_num_blocks', 0)
3902
- block_size = (
3903
- max(self.config.num_hidden_layers // num_blocks, 1)
3904
- if num_blocks > 0
3905
- else 1 # Full AttnRes: every layer is its own "block"
3906
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3907
 
3908
  # Pre-allocate per-layer analysis list when analysis is active
3909
  if analysis_state is not None:
@@ -3913,17 +5421,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3913
  if output_hidden_states:
3914
  all_hidden_states = all_hidden_states + (hidden_states,)
3915
 
3916
- # ── Block AttnRes: boundary handling ──────────────────────────
3917
- # At each block boundary (excluding layer 0): append the current
3918
- # partial sum to sources as a completed block summary, then reset
3919
- # partial to None so the new block builds from scratch — matching
3920
- # the paper's pseudocode exactly.
3921
- # For Full AttnRes (block_size=1): every layer is a boundary, so
3922
- # partial is appended and reset after every layer. The partial is
3923
- # re-seeded from the previous hidden_states below.
3924
- if use_attn_res and layer_idx > 0 and layer_idx % block_size == 0:
3925
- attn_res_sources = attn_res_sources + [attn_res_partial]
3926
- attn_res_partial = hidden_states # start new block from current output
3927
 
3928
  # Build per-layer analysis container (only in eval + analysis mode)
3929
  layer_analysis = None
@@ -3932,15 +5436,26 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3932
  layer_analysis.layer_idx = layer_idx
3933
  analysis_state.layers.append(layer_analysis)
3934
 
 
 
 
 
 
 
 
3935
  layer_outputs = decoder_layer(
3936
  hidden_states,
3937
  position_embeddings=position_embeddings,
3938
- attention_mask=causal_mask,
3939
  first_layer_fan=self.first_layer_fan,
3940
  z_tilde=z_tilde,
3941
  B_vals=B_vals,
3942
  attn_res_sources=attn_res_sources,
3943
  attn_res_partial=attn_res_partial if use_attn_res else None,
 
 
 
 
3944
  layer_analysis=layer_analysis,
3945
  output_attentions=output_attentions,
3946
  repo_rope_args=repo_rope_args,
@@ -3948,23 +5463,76 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3948
  )
3949
  hidden_states = layer_outputs[0]
3950
 
 
 
 
 
 
 
3951
  # Update AttnRes partial sum — the new partial is the layer output
3952
  if use_attn_res:
3953
  attn_res_partial = hidden_states
3954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3955
  if output_attentions:
3956
  all_attentions = all_attentions + (layer_outputs[1],)
3957
 
3958
- # Collect JTok-M aux stats (last element if present)
3959
- if self.config.use_jtokm and len(layer_outputs) > (2 if output_attentions else 1):
3960
- all_aux_stats.append(layer_outputs[-1])
 
 
 
 
3961
 
3962
- # Collect VersatileFFN aux stats (second-to-last if jtokm also present,
3963
- # or last if jtokm is absent). Only non-None during training.
3964
  if getattr(self.config, "use_versatile_ffn", False):
3965
- for item in layer_outputs[1:]:
3966
  if isinstance(item, tuple) and len(item) == 3:
3967
- # (p_sum, f_sum, N_tokens) signature
3968
  all_aux_stats.append(("versatile", item))
3969
  break
3970
 
@@ -3972,6 +5540,16 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3972
  and hasattr(decoder_layer, "current_layer_fan")):
3973
  self.first_layer_fan = decoder_layer.current_layer_fan
3974
 
 
 
 
 
 
 
 
 
 
 
3975
  hidden_states = self.norm(hidden_states)
3976
 
3977
  if output_hidden_states:
@@ -3984,6 +5562,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
3984
  analysis_state.attn_res_sources_final = (
3985
  attn_res_sources if use_attn_res else None
3986
  )
 
 
 
3987
 
3988
  if not return_dict:
3989
  return tuple(
@@ -4124,6 +5705,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
4124
  layers = None, # filled by NeoLLMModel.forward
4125
  jtokm_aux_stats = [] if cfg.use_jtokm else None,
4126
  attn_res_sources_final = [] if getattr(cfg, "use_attn_res", False) else None,
 
4127
  )
4128
 
4129
  # ── Standard model API ────────────────────────────────────────────────
@@ -4261,6 +5843,12 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
4261
  # ==================== AUTOMODEL REGISTRATION ====================
4262
 
4263
  __all__ = [
 
 
 
 
 
 
4264
  "NeoLLMForCausalLM",
4265
  "NeoLLMModel",
4266
  "NeoLLMPreTrainedModel",
@@ -4278,7 +5866,7 @@ __all__ = [
4278
  "REPOModule",
4279
  "VersatileFFN",
4280
  "compute_versatile_aux_loss",
4281
- # Analysis dataclasses — exported so external tools can type-hint against them
4282
  "AnalysisState",
4283
  "LayerAnalysis",
4284
  "AttentionAnalysis",
@@ -4292,6 +5880,9 @@ __all__ = [
4292
  "VersatileFFNAnalysis",
4293
  "JTokMAnalysis",
4294
  "AttnResAnalysis",
 
 
 
4295
  "GeneratorAnalysis",
4296
  ]
4297
 
 
77
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
78
  from transformers.processing_utils import Unpack
79
  from transformers.utils import TransformersKwargs, logging
80
+ from configuration_neollm import NeoLLMConfig
81
 
82
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
83
 
84
+ # ── Optional fast-path dependencies (GatedDeltaNet linear attention) ─────────
85
+ try:
86
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update as _causal_conv1d_update
87
+ _causal_conv1d_available = True
88
+ except ImportError:
89
+ causal_conv1d_fn = None
90
+ _causal_conv1d_update = None
91
+ _causal_conv1d_available = False
92
+
93
+ try:
94
+ from fla.modules import FusedRMSNormGated
95
+ from fla.ops.gated_delta_rule import (
96
+ chunk_gated_delta_rule,
97
+ fused_recurrent_gated_delta_rule,
98
+ )
99
+ _fla_available = True
100
+ except ImportError:
101
+ FusedRMSNormGated = None
102
+ chunk_gated_delta_rule = None
103
+ fused_recurrent_gated_delta_rule = None
104
+ _fla_available = False
105
+
106
+ is_linear_attn_fast_path = _causal_conv1d_available and _fla_available
107
+
108
  logger = logging.get_logger(__name__)
109
 
110
 
 
363
  lns_scale: Optional[float] = None # 1/√(2β„“) scaling factor
364
 
365
 
366
@dataclass
class DCAAnalysis:
    """
    Diagnostics recorded from one GRN-v3 depth-wise aggregation
    (DeepCrossAttention). Left unset unless use_dca=True.

    grn_depth_weights: the dynamic, softmax-free aggregate scalars that
        weight each source layer. Shape [3, y, B, S] for the Q/K/V GRN,
        where 3 = Q/K/V streams, y = selected stack depth (at most
        2*dca_k), B = batch and S = sequence length.
        The tensor is captured *before* the static bias is added, so it
        shows which layers the dynamic component selectively suppresses
        (the ReLU clamps negative entries to zero).
    """
    grn_depth_weights: Optional[torch.Tensor] = None  # [3, y, B, S], detached copy
380
+
381
+
382
  @dataclass
383
  class AttnResAnalysis:
384
  """
 
390
  sources_count: Optional[int] = None # number of sources including partial
391
 
392
 
393
@dataclass
class StackMemoryAnalysis:
    """
    Snapshot of the intermediate tensors of one StackMemory forward pass.
    Filled in only when use_stacktrans=True AND the model runs in
    eval + analysis mode.

    Reference: Zhang, K. et al. (NeurIPS 2025). "Recursive Transformer:
    Boosting Reasoning Ability with State Stack."

    Shape legend: B = batch, S = sequence length, H = stack heads,
    slots = stack depth, ds = per-head slot width.

    action_probs:   [B, S, H, 3] softmax over [push, pop, no-op] for each
                    head and token position. Plotting it across layers
                    exposes the push-heavy early layers and pop-heavy
                    later layers described in the paper (§B.2).
    stack_in:       [B, H, slots, ds] stack state entering this layer,
                    i.e. the output of the previous layer's StackMemory.
                    None for layer 0 (which starts from all-zeros).
    stack_out:      [B, H, slots, ds] stack state after this sequence has
                    been processed. This is new_stack[:, -1] — the stack
                    at the final sequence position, handed to the next
                    layer as its stack_in.
    mask_out:       [B, H, slots] validity mask belonging to stack_out.
                    Values near 1 mark active slots; near 0 = empty.
    gate_weights:   [B, S, H, slots] softmax attention weights of the
                    global read. A high weight on slot i at position t
                    means the model retrieved from slot i there.
    memory_output:  [B, S, stack_d_model] weighted stack readout, taken
                    before up_proj.
    residual_scale: value of the learnable res_weight scalar at this step.
    """
    # Field order is part of the constructor signature — keep it stable.
    action_probs: Optional[torch.Tensor] = None    # [B, S, H, 3]
    stack_in: Optional[torch.Tensor] = None        # [B, H, slots, ds] entering the layer
    stack_out: Optional[torch.Tensor] = None       # [B, H, slots, ds] leaving the layer
    mask_out: Optional[torch.Tensor] = None        # [B, H, slots]
    gate_weights: Optional[torch.Tensor] = None    # [B, S, H, slots]
    memory_output: Optional[torch.Tensor] = None   # [B, S, stack_d_model]
    residual_scale: Optional[float] = None         # res_weight scalar
429
+
430
+
431
  @dataclass
432
  class LayerAnalysis:
433
  """
 
456
  gpas_mlp: Optional[GPASAnalysis] = None # GPAS after MLP residual
457
 
458
  # Optional components (None when inactive)
459
+ jtokm: Optional[JTokMAnalysis] = None # if use_jtokm
460
+ attn_res: Optional[AttnResAnalysis] = None # if use_attn_res
461
+ dca: Optional[DCAAnalysis] = None # if use_dca
462
+ stack: Optional[StackMemoryAnalysis] = None # if use_stacktrans
463
+ laurel_attn: Optional["LAuReLAnalysis"] = None # if use_laurel (attention residual)
464
+ laurel_mlp: Optional["LAuReLAnalysis"] = None # if use_laurel (MLP residual)
465
 
466
 
467
  @dataclass
 
526
  layers: Optional[List[LayerAnalysis]] = None
527
  jtokm_aux_stats: Optional[list] = None
528
  attn_res_sources_final: Optional[list] = None
529
+ dca_all_tokens_final: Optional[list] = None
530
  logits: Optional[torch.Tensor] = None
531
 
532
  class ScalarMultiplier(nn.Module):
 
2446
  first_layer_fan: Optional[torch.Tensor] = None,
2447
  attn_analysis: Optional[AttentionAnalysis] = None,
2448
  repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
2449
+ mudd_xk: Optional[torch.Tensor] = None,
2450
+ mudd_xv: Optional[torch.Tensor] = None,
2451
  **kwargs: Unpack[FlashAttentionKwargs],
2452
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
2453
  input_shape = hidden_states.shape[:-1]
 
2458
  h_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan
2459
  current_layer_fan = h_fan.clone()
2460
 
2461
+ # ── MUDD: separate K/V FAN paths ─────────────────────────────────
2462
+ # When mudd_xk/mudd_xv are provided (MUDD qkvr mode), they have already
2463
+ # been normalized by the decoder layer's K/V norm chain. Here they go
2464
+ # through their own FAN transform before k_proj/v_proj, keeping the
2465
+ # FANformer periodicity modeling orthogonally intact per stream.
2466
+ h_fan_k = self.fan_layer(mudd_xk) if mudd_xk is not None else h_fan
2467
+ h_fan_v = self.fan_layer(mudd_xv) if mudd_xv is not None else h_fan
2468
+
2469
  query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
2470
  kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)
2471
 
 
2480
  attn_analysis.gate_raw = gate.detach()
2481
 
2482
  q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2)
2483
+ k = self.k_norm(self.k_proj(h_fan_k).view(kv_shape)).transpose(1, 2)
2484
+ v = self.v_proj(h_fan_v).view(kv_shape).transpose(1, 2)
2485
 
2486
  if attn_analysis is not None:
2487
  attn_analysis.q_post_norm = q.detach()
 
3158
  return result
3159
 
3160
 
3161
+ class NeoLLMMUDDModule(nn.Module):
3162
+ """
3163
+ Multiway Dynamic Dense (MUDD) Depth-wise Aggregate module.
3164
+
3165
+ Generates per-position, per-stream connection weights over all preceding
3166
+ layer outputs (and the token embedding) and produces up to C=4 aggregated
3167
+ streams (Q, K, V, R) for the next Transformer block.
3168
+
3169
+ Architecture (Xiao et al., 2025, arXiv:2502.12170):
3170
+ dw = GELU(RMSNorm(x) @ W1) @ W2 + a # [B, T, C*(lidx+2)]
3171
+ dw = reshape to [C, B, T, (lidx+2)]
3172
+ stream_c = Ξ£_j dw[c, :, :, j] * hiddens[j] for c in range(C)
3173
+
3174
+ W1 ~ N(0, 1/D), W2 = 0, a = identity on last index β†’ reduces to standard
3175
+ Transformer at init (dynamic part is zero, static bias selects Xi).
3176
+
3177
+ Args:
3178
+ hidden_size: model dimension D
3179
+ lidx: layer index (0-based); history has lidx+2 entries
3180
+ num_ways: C, number of output streams (4 for "qkvr", 1 for "l")
3181
+ is_last: whether this is the last layer (controls expand_last)
3182
+ expand_last: multiply hid_dim by 4 for the final layer's DA module
3183
+ round64: round hid_dim up to the nearest multiple of 64
3184
+ """
3185
+
3186
+ def __init__(
3187
+ self,
3188
+ hidden_size: int,
3189
+ lidx: int,
3190
+ num_ways: int = 4,
3191
+ is_last: bool = False,
3192
+ expand_last: bool = False,
3193
+ round64: bool = False,
3194
+ ) -> None:
3195
+ super().__init__()
3196
+ self.lidx = lidx
3197
+ self.num_ways = num_ways
3198
+ l = lidx + 2 # history length: embedding + lidx layers
3199
+ hid_dim = l * num_ways
3200
+ out_dim = l * num_ways
3201
+ if is_last and expand_last:
3202
+ hid_dim *= 4
3203
+ if round64:
3204
+ hid_dim = (hid_dim // 64 + 1) * 64
3205
+ # RMSNorm without learnable scale (paper uses RMSnormNoscale)
3206
+ self.norm = nn.RMSNorm(hidden_size, elementwise_affine=False,
3207
+ eps=1e-6)
3208
+ self.w1 = nn.Linear(hidden_size, hid_dim, bias=False)
3209
+ self.act = nn.GELU()
3210
+ self.w2 = nn.Linear(hid_dim, out_dim, bias=False)
3211
+ self._reset_mudd_parameters(hidden_size)
3212
+
3213
+ def _reset_mudd_parameters(self, D: int) -> None:
3214
+ # W1 ~ N(0, 1/D); W2 = 0 β†’ dynamic part starts at zero
3215
+ nn.init.normal_(self.w1.weight, mean=0.0, std=1.0 / D)
3216
+ nn.init.zeros_(self.w2.weight)
3217
+
3218
+ def forward(
3219
+ self,
3220
+ x: torch.Tensor, # [B, T, D] β€” current layer output (Xi)
3221
+ hiddens: list, # list of lidx+2 tensors [B, T, D]
3222
+ static_bias: torch.Tensor, # [C, lidx+2] β€” learnable static prior
3223
+ ) -> tuple:
3224
+ """
3225
+ Returns:
3226
+ Tuple of num_ways tensors, each [B, T, D] β€” the aggregated streams.
3227
+ """
3228
+ B, T, D = x.shape
3229
+ # Dynamic weight generation: [B, T, C*(lidx+2)]
3230
+ dw = self.w2(self.act(self.w1(self.norm(x))))
3231
+ # Add static bias (broadcast over B and T)
3232
+ # static_bias: [C, L] β†’ [1, 1, C*L] via reshape
3233
+ C, L = static_bias.shape
3234
+ dw = dw + static_bias.reshape(1, 1, C * L).to(dw.dtype)
3235
+ # Reshape to [C, B, T, L]
3236
+ dw = dw.view(B, T, C, L).permute(2, 0, 1, 3) # [C, B, T, L]
3237
+ # Stack history: [L, B, T, D]
3238
+ stacked = torch.stack(hiddens, dim=0) # [L, B, T, D]
3239
+ # Aggregate: Ξ£_j dw[c, :, :, j] * hiddens[j]
3240
+ # einsum "cbtl, lbtd -> cbtd"
3241
+ streams = torch.einsum('cbtl,lbtd->cbtd', dw, stacked) # [C, B, T, D]
3242
+ return tuple(streams[c] for c in range(C))
3243
+
3244
+
3245
def dca_select_layers(stacked: torch.Tensor, k: int) -> torch.Tensor:
    """
    k-DCA layer selection (Heddes et al., 2025, §3.1).

    Keeps only the first k and the last k entries of the depth stack, which
    caps the number of retained layer representations at 2k regardless of
    depth. When the stack has <= 2k entries it is returned unchanged (the
    same tensor object, no copy).

    Args:
        stacked: [y, B, S, D] — stack of all layer outputs so far.
        k:       number of first/last layers to retain; must be >= 0.
    Returns:
        [min(y, 2k), B, S, D]
    Raises:
        ValueError: if k is negative.
    """
    if k < 0:
        raise ValueError(f"dca_select_layers expects k >= 0, got {k}")

    depth = stacked.shape[0]
    if depth <= 2 * k:
        return stacked

    # Use an explicit start index for the tail: ``stacked[-k:]`` would return
    # the *whole* stack when k == 0 (because -0 == 0), violating the
    # documented min(y, 2k) output depth.
    return torch.cat([stacked[:k], stacked[depth - k:]], dim=0)
3263
+
3264
+
3265
+ class NeoLLMGRN(nn.Module):
3266
+ """
3267
+ Generalized Residual Network v3 (GRN-v3) from DeepCrossAttention
3268
+ (Heddes et al., 2025, arXiv:2502.06785, Β§3.1).
3269
+
3270
+ Produces `num_outputs` aggregated streams from a depth-wise stack of
3271
+ layer representations. Weights are simultaneously:
3272
+
3273
+ - **Input-dependent** (dynamic): a two-layer mapping
3274
+ ``wΜ„ = ReLU(RMSNorm(G) @ W)`` produces one scalar per
3275
+ (output-stream, depth-position, batch-token). ``W`` is initialized
3276
+ to zero so the dynamic contribution starts neutral.
3277
+ - **Dimension-dependent** (static): a learnable bias ``b`` of shape
3278
+ ``[num_outputs, num_stack_layers, hidden_size]`` initialized to ones
3279
+ provides a per-dimension, per-layer prior. At initialization the
3280
+ dynamic part is zero and the static bias sums to an equal-weight
3281
+ average over all stack entries, reducing to a standard residual mean.
3282
+
3283
+ The combined weight for output stream ``o``, stack position ``y``,
3284
+ batch ``b``, token ``n``, feature ``d`` is::
3285
+
3286
+ weight[o, y, b, n, d] = ReLU(dynamic[y, b, n, o]) + bias[o, y, d]
3287
+
3288
+ Output ``o`` is then the weighted sum over depth::
3289
+
3290
+ out[o, b, n, d] = Ξ£_y stack[y, b, n, d] * weight[o, y, b, n, d]
3291
+
3292
+ Reference:
3293
+ Heddes, M. et al. (2025). *DeepCrossAttention: Supercharging
3294
+ Transformer Residual Connections.* arXiv:2502.06785.
3295
+
3296
+ Args:
3297
+ hidden_size: model dimension D.
3298
+ num_stack_layers: number of depth entries this GRN will receive
3299
+ (= min(layer_idx+1, 2*dca_k)).
3300
+ num_outputs: number of output streams (3 for DCA Q/K/V,
3301
+ 1 for the final aggregation GRN).
3302
+ eps: epsilon for the internal RMSNorm.
3303
+ """
3304
+
3305
+ def __init__(
3306
+ self,
3307
+ hidden_size: int,
3308
+ num_stack_layers: int,
3309
+ num_outputs: int = 3,
3310
+ eps: float = 1e-6,
3311
+ ) -> None:
3312
+ super().__init__()
3313
+ self.num_outputs = num_outputs
3314
+ self.num_stack_layers = num_stack_layers
3315
+
3316
+ # Dynamic component: RMSNorm(no scale) β†’ Linear β†’ ReLU
3317
+ # Linear maps D β†’ num_outputs; init zeros so dynamic part = 0 at step 0.
3318
+ _linear = nn.Linear(hidden_size, num_outputs, bias=False)
3319
+ nn.init.zeros_(_linear.weight)
3320
+ self.norm_noscale = nn.RMSNorm(
3321
+ hidden_size, eps=eps, elementwise_affine=False
3322
+ )
3323
+ self.to_dynamic = nn.Sequential(_linear, nn.ReLU())
3324
+
3325
+ # Static bias: [num_outputs, num_stack_layers, hidden_size], init ones.
3326
+ # At init: weight = 0 + bias = 1 per entry β†’ equal-weight average β†’ residual.
3327
+ self.bias = nn.Parameter(
3328
+ torch.ones(num_outputs, num_stack_layers, hidden_size)
3329
+ )
3330
+
3331
def forward(
    self,
    stack: torch.Tensor,
    analysis: Optional["DCAAnalysis"] = None,
) -> tuple:
    """
    Aggregate a depth stack into ``num_outputs`` tensors.

    Args:
        stack: tensor of shape [y, B, S, D]; ``y`` must equal
            ``self.num_stack_layers``.
        analysis: optional container; when given, the detached dynamic
            weights ([o, y, B, S]) are stored on ``grn_depth_weights``.

    Returns:
        A tuple of ``num_outputs`` tensors of shape [B, S, D].  When
        ``num_outputs == 1`` a single [B, S, D] tensor is returned instead
        of a 1-tuple.
    """
    y, _b, _s, _d = stack.shape
    assert y == self.num_stack_layers, (
        f"NeoLLMGRN expected stack depth {self.num_stack_layers}, got {y}"
    )

    # Per-entry scalar score for every output stream:
    # [y, B, S, D] -> [y, B, S, o] -> [o, y, B, S]
    scores = self.to_dynamic(self.norm_noscale(stack))
    dynamic = scores.permute(3, 0, 1, 2)

    if analysis is not None:
        analysis.grn_depth_weights = dynamic.detach()

    # Broadcast the scalar dynamic weight over features and the static
    # per-feature bias over batch / sequence: result is [o, y, B, S, D].
    static = self.bias[:, :, None, None, :]
    weights = dynamic[..., None] + static

    # Weighted sum over the depth axis -> [o, B, S, D]
    output = (stack[None] * weights).sum(dim=1)

    if self.num_outputs == 1:
        return output.squeeze(0)
    return tuple(output.unbind(dim=0))
3372
+
3373
+
3374
class StackMemory(nn.Module):
    """
    Differentiable multi-head stack memory with soft push / pop / no-op updates.

    The file attributes the design to the StackTrans module of Zhang et al.
    (NeurIPS 2025), "Recursive Transformer: Boosting Reasoning Ability with
    State Stack."

    One ``forward`` call processes a whole sequence in parallel:

      1. ``down_proj``  : [B, S, D] -> [B, S, stack_d_model]
      2. ``action_head``: per-head softmax over (push, pop, no-op) -> [B, S, H, 3]
      3. the projected features, viewed as [B, S, H, ds], are the push values
      4. ``_vectorized_update`` blends the pushed / popped / unchanged stacks
         with the action probabilities.  Every position starts from the
         *same* incoming stack, so positions never see each other's updates.
      5. ``gate_proj`` scores each slot; a softmax over slots (biased by the
         soft validity mask) produces a weighted read of the stack
      6. ``up_proj``    : [B, S, stack_d_model] -> [B, S, D]
      7. ``output = up_proj_out * res_weight + hidden_states``

    ``forward`` returns the stack and mask found at the last sequence
    position so the caller can hand them to the next consumer.

    Args:
        config: configuration object.  Reads ``hidden_size``,
            ``stacktrans_num_heads`` (H), ``stacktrans_stack_slots``,
            ``stacktrans_stack_d_model`` (H * ds) and, optionally,
            ``stacktrans_forward_bs`` (cache batch capacity, default 1) and
            ``cache_size`` (cache length in tokens, default 2048).
    """

    def __init__(self, config: "NeoLLMConfig"):
        super().__init__()
        self.num_stack_heads = config.stacktrans_num_heads
        self.stack_slots = config.stacktrans_stack_slots
        self.stack_d_model = config.stacktrans_stack_d_model
        self.head_dim = self.stack_d_model // self.num_stack_heads

        # Dimension reduction / expansion around the stack.
        self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=True)
        self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=True)

        # Action logits: one (push, pop, no-op) triple per head.
        self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)

        # Read gate: one scalar score per stack slot, shared across heads.
        self.gate_proj = nn.Linear(self.head_dim, 1, bias=True)

        # Learnable scalar applied to the memory read before the residual add.
        self.res_weight = nn.Parameter(torch.ones(1))

        # Generation cache: per-token push values and action probabilities.
        # Only written when ``enable_cache`` is True.
        _fbs = getattr(config, "stacktrans_forward_bs", 1)
        _cs = getattr(config, "cache_size", 2048)
        self.register_buffer(
            "k_cache",
            torch.zeros(_fbs, _cs, self.num_stack_heads, self.head_dim),
        )
        self.register_buffer(
            "action_cache",
            torch.zeros(_fbs, _cs, self.num_stack_heads, 3),
        )
        self.cache_position = 0
        self.enable_cache = False

    # -- Cache helpers -----------------------------------------------------

    def reset_cache(self) -> None:
        """Rewind the cache write cursor (buffer contents are left as-is)."""
        self.cache_position = 0

    def _update_cache(
        self,
        k_values: torch.Tensor,  # [B, S, H, ds]
        actions: torch.Tensor,   # [B, S, H, 3]
    ) -> None:
        """
        Append ``k_values`` / ``actions`` at the current cache position.

        Only the first ``B`` batch rows of the buffers are written, matching
        the ``[:batch_size]`` slice that ``step()`` reads back.  Writing the
        full batch axis (as before) failed with a shape error whenever ``B``
        was neither 1 nor the configured cache batch capacity.

        If the new tokens do not fit, the cache is rewound to position 0 and
        the incoming values are dropped (pre-existing behaviour, kept).
        A batch larger than the buffer capacity still raises a shape error.
        """
        batch_size, seq_len = k_values.shape[:2]
        start = self.cache_position
        end = start + seq_len
        if end > self.k_cache.shape[1]:
            self.reset_cache()
            return
        self.k_cache[:batch_size, start:end] = k_values
        self.action_cache[:batch_size, start:end] = actions
        self.cache_position = end

    # -- Core stack update -------------------------------------------------

    def _vectorized_update(
        self,
        stack: torch.Tensor,     # [B, H, slots, ds] or [B, S, H, slots, ds]
        mask: torch.Tensor,      # [B, H, slots]     or [B, S, H, slots]
        actions: torch.Tensor,   # [B, S, H, 3]
        k_values: torch.Tensor,  # [B, S, H, ds]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Soft push / pop / no-op update applied independently per position.

        A 4-D ``stack`` (and 3-D ``mask``) is broadcast along the sequence
        axis, so every position is updated from the same initial state.

        Returns:
            new_stack [B, S, H, slots, ds]
            new_mask  [B, S, H, slots]
        """
        batch_size, seq_len = actions.shape[:2]

        # Broadcast a per-batch initial state along the sequence dimension.
        if stack.dim() == 4:
            stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
            mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)

        # Push: new value at slot 0, everything shifts down, last slot dropped.
        push_stack = torch.cat([k_values.unsqueeze(3), stack[:, :, :, :-1]], dim=3)
        push_mask = torch.cat(
            [torch.ones_like(mask[:, :, :, :1]), mask[:, :, :, :-1]], dim=3
        )

        # Pop: everything shifts up, a zero slot enters at the bottom.
        pop_stack = torch.cat(
            [stack[:, :, :, 1:], torch.zeros_like(stack[:, :, :, :1])], dim=3
        )
        pop_mask = torch.cat(
            [mask[:, :, :, 1:], torch.zeros_like(mask[:, :, :, :1])], dim=3
        )

        # Convex combination of the three candidates by action probability.
        aw = actions.unsqueeze(-1).unsqueeze(-1)                      # [B,S,H,3,1,1]
        stacks = torch.stack([push_stack, pop_stack, stack], dim=3)   # [B,S,H,3,slots,ds]
        masks = torch.stack([push_mask, pop_mask, mask], dim=3)       # [B,S,H,3,slots]

        new_stack = (stacks * aw).sum(dim=3)                          # [B,S,H,slots,ds]
        new_mask = (masks * aw.squeeze(-1)).sum(dim=3)                # [B,S,H,slots]
        return new_stack, new_mask

    # -- Full-sequence forward ---------------------------------------------

    def forward(
        self,
        hidden_states: torch.Tensor,
        stack: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        analysis: Optional["StackMemoryAnalysis"] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Process a whole sequence.

        Args:
            hidden_states: [B, S, D]
            stack: [B, H, slots, ds] incoming stack, or None for an all-zero
                stack.
            mask: [B, H, slots] soft validity mask, or None for all zeros.
            analysis: optional container that receives detached internals
                (``stack_in``, ``action_probs``, ``stack_out``, ``mask_out``,
                ``gate_weights``, ``memory_output``, ``residual_scale``).

        Returns:
            (output, new_stack, new_mask)
                output    [B, S, D]
                new_stack [B, H, slots, ds]  stack at the last position
                new_mask  [B, H, slots]      mask at the last position
        """
        batch_size, seq_len, _ = hidden_states.shape
        device = hidden_states.device

        # Record the incoming stack before it is replaced below.
        if analysis is not None:
            analysis.stack_in = stack.detach() if stack is not None else None

        if stack is None:
            stack = torch.zeros(
                batch_size, self.num_stack_heads, self.stack_slots, self.head_dim,
                device=device, dtype=hidden_states.dtype,
            )
        if mask is None:
            mask = torch.zeros(
                batch_size, self.num_stack_heads, self.stack_slots,
                device=device, dtype=hidden_states.dtype,
            )

        # 1. Project down.
        h_proj = self.down_proj(hidden_states)                         # [B,S,stack_d_model]

        # 2. Action probabilities (logits scaled by 1/sqrt(head_dim)).
        action_logits = self.action_head(h_proj) / math.sqrt(self.head_dim)
        actions = F.softmax(
            action_logits.view(batch_size, seq_len, self.num_stack_heads, 3), dim=-1
        )                                                              # [B,S,H,3]

        # 3. Values to push.
        k_values = h_proj.view(batch_size, seq_len, self.num_stack_heads, self.head_dim)

        # 4. Stack update.
        new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)

        # 5. Read: softmax over slots; slots with low validity are pushed
        #    towards -1e9 so they receive (near-)zero weight.
        gate_scores = self.gate_proj(new_stack).squeeze(-1)            # [B,S,H,slots]
        gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
        memory_out = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
        memory_out = memory_out.view(batch_size, seq_len, self.stack_d_model)

        # 6. Project back up.
        memory_out_proj = self.up_proj(memory_out)                     # [B,S,D]

        # 7. Residual.
        output = memory_out_proj * self.res_weight + hidden_states

        # 8. Record per-token values for generation (skipped unless enabled).
        if self.enable_cache:
            self._update_cache(k_values.detach(), actions.detach())

        if analysis is not None:
            analysis.action_probs = actions.detach()
            analysis.stack_out = new_stack[:, -1].detach()
            analysis.mask_out = new_mask[:, -1].detach()
            analysis.gate_weights = gate_weights.detach()
            analysis.memory_output = memory_out.detach()
            analysis.residual_scale = self.res_weight.item()

        return output, new_stack[:, -1], new_mask[:, -1]

    # -- Single-token forward ----------------------------------------------

    def step(
        self,
        hidden_state: torch.Tensor,  # [B, D]
        stack: torch.Tensor,         # [B, H, slots, ds]
        mask: torch.Tensor,          # [B, H, slots]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Process one token.

        With ``enable_cache=False`` this is ``forward`` on a length-1
        sequence with the sequence axis squeezed out of the output.

        With ``enable_cache=True`` the current token is concatenated with the
        cached per-token values, the vectorized update is replayed over the
        whole history, and only the last position is kept.

        NOTE(review): ``_vectorized_update`` starts every position from the
        same ``stack``, so the last position depends only on the current
        token's action / value and on ``stack``.  The replay over cached
        history therefore adds O(T) work without changing the result;
        confirm whether a sequential scan was intended.

        Returns:
            (output, new_stack, new_mask)
                output    [B, D]
                new_stack [B, H, slots, ds]
                new_mask  [B, H, slots]
        """
        if not self.enable_cache:
            out, new_stack, new_mask = self.forward(
                hidden_state.unsqueeze(1), stack, mask
            )
            return out.squeeze(1), new_stack, new_mask

        batch_size = hidden_state.shape[0]

        # Features of the current token.
        h_proj = self.down_proj(hidden_state)                          # [B, stack_d_model]
        a_logits = self.action_head(h_proj) / math.sqrt(self.head_dim)
        cur_act = F.softmax(
            a_logits.view(batch_size, 1, self.num_stack_heads, 3), dim=-1
        )                                                              # [B,1,H,3]
        cur_k = h_proj.view(batch_size, 1, self.num_stack_heads, self.head_dim)

        # Prepend cached history, if any.
        if self.cache_position > 0:
            k_values = torch.cat(
                [self.k_cache[:batch_size, :self.cache_position], cur_k], dim=1
            )
            actions = torch.cat(
                [self.action_cache[:batch_size, :self.cache_position], cur_act], dim=1
            )
        else:
            k_values = cur_k
            actions = cur_act

        new_stack_seq, new_mask_seq = self._vectorized_update(stack, mask, actions, k_values)
        new_stack = new_stack_seq[:, -1]                               # [B,H,slots,ds]
        new_mask = new_mask_seq[:, -1]                                 # [B,H,slots]

        # Read from the updated stack.
        gate_scores = self.gate_proj(new_stack).squeeze(-1)            # [B,H,slots]
        gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
        memory_out = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=2)
        memory_out = memory_out.view(batch_size, self.stack_d_model)

        memory_out_proj = self.up_proj(memory_out)                     # [B,D]
        output = memory_out_proj * self.res_weight + hidden_state

        # Cache detached copies, consistent with forward(), so the buffers
        # never become part of an autograd graph.
        self._update_cache(cur_k.detach(), cur_act.detach())
        return output, new_stack, new_mask
3685
+
3686
+
3687
@dataclass
class LAuReLAnalysis:
    """
    Record of one LAuReL residual-connection forward pass.

    Filled in by ``LAuReLLayer.forward`` when an instance is passed as its
    ``analysis`` argument; fields of inactive sub-variants stay ``None``.

    Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
    *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.

    Combined form (both sub-variants active)::

        x_{i+1} = alpha * f(x_i) + beta * (A(B x_i) + x_i)
        [alpha, beta] = softmax([a, b])

    Fields:
        alpha_rw: softmax weight on f(x_i), as a Python float.
                  None when the RW variant is off.
        beta_rw:  softmax weight on the (enriched) residual, as a float.
                  None when the RW variant is off.
        lr_term:  low-rank augmentation A(B x_res), shape [B, S, D].
                  None when the LR variant is off.
        output:   combined tensor returned by the layer, shape [B, S, D].
    """

    alpha_rw: Optional[float] = None         # softmax weight on f(x)
    beta_rw: Optional[float] = None          # softmax weight on g(x)
    lr_term: Optional[torch.Tensor] = None   # A(Bx) low-rank augmentation [B,S,D]
    output: Optional[torch.Tensor] = None    # combined output [B,S,D]
3719
+
3720
+
3721
class LAuReLLayer(nn.Module):
    """
    LAuReL: Learned Augmented Residual Layer.

    Replaces the plain residual sum ``f(x) + x`` with a learned blend of the
    sub-layer output ``f(x)`` and a (optionally low-rank-augmented) residual.

    Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
    *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.

    Sub-variants (selected by config flags; any combination is accepted):

      RW only (use_laurel_rw=True, use_laurel_lr=False)::

          x_{i+1} = alpha * f(x_i) + beta * x_i
          [alpha, beta] = softmax([a, b])          (2 parameters)

      LR only (use_laurel_rw=False, use_laurel_lr=True)::

          x_{i+1} = f(x_i) + A(B x_i) + x_i
          B: D -> r (``lr_down``),  A: r -> D (``lr_up``)

      RW + LR (both True)::

          x_{i+1} = alpha * f(x_i) + beta * (A(B x_i) + x_i)

      With both flags False the layer reduces to ``f(x_i) + x_i``.

    Initialisation:
      * RW logits are zeros, so alpha = beta = 0.5 at step 0.
      * ``lr_up`` is zero-initialised in the constructor, so the low-rank
        term is exactly 0 at step 0 and the LR-only variant starts as the
        standard residual.
      * NOTE(review): earlier documentation also called for an orthogonal
        init of ``lr_down``; that is not applied here (default
        ``nn.Linear`` init is kept) - confirm whether the model-level weight
        initialiser handles it.

    Args:
        config: configuration object.  Reads ``hidden_size`` and, optionally,
            ``use_laurel_rw`` (default True), ``use_laurel_lr`` (default
            True) and ``laurel_lr_rank`` (default 32).
    """

    def __init__(self, config: "NeoLLMConfig"):
        super().__init__()
        self.use_rw = getattr(config, "use_laurel_rw", True)
        self.use_lr = getattr(config, "use_laurel_lr", True)
        D = config.hidden_size
        r = getattr(config, "laurel_lr_rank", 32)

        if self.use_rw:
            # Raw logits [a, b]; a single 2-vector so the softmax is one op.
            self.rw_logits = nn.Parameter(torch.zeros(2))

        if self.use_lr:
            self.lr_down = nn.Linear(D, r, bias=False)   # B: D -> r
            self.lr_up = nn.Linear(r, D, bias=False)     # A: r -> D
            # Zero the up-projection so A(Bx) = 0 at step 0 (LoRA-style).
            # Previously the default Linear init was left in place, which
            # contradicted the documented "pure residual at init" contract.
            nn.init.zeros_(self.lr_up.weight)

    def forward(
        self,
        f_out: torch.Tensor,   # output of f(x): attention or MLP  [B,S,D]
        x_res: torch.Tensor,   # residual (skip connection)        [B,S,D]
        analysis: Optional["LAuReLAnalysis"] = None,
    ) -> torch.Tensor:
        """
        Combine a sub-layer output with its residual.

        Args:
            f_out: output of the sub-layer f(x).
            x_res: residual tensor, same shape as ``f_out``.
            analysis: optional container; receives ``alpha_rw`` / ``beta_rw``
                (when RW is on), ``lr_term`` (when LR is on) and ``output``.

        Returns:
            Combined tensor with the shape of the inputs.
        """
        # LR component: enrich the residual with A(B x_res).
        lr_term = None
        if self.use_lr:
            lr_term = self.lr_up(self.lr_down(x_res))
            g_res = lr_term + x_res
        else:
            g_res = x_res

        # RW component: learned convex blend; otherwise a plain sum.
        if self.use_rw:
            weights = torch.softmax(self.rw_logits, dim=0)   # [2]
            alpha = weights[0]
            beta = weights[1]
            out = alpha * f_out + beta * g_res
        else:
            out = f_out + g_res

        if analysis is not None:
            if self.use_rw:
                analysis.alpha_rw = alpha.item()
                analysis.beta_rw = beta.item()
            if self.use_lr:
                analysis.lr_term = lr_term.detach()
            analysis.output = out.detach()

        return out
3837
+
3838
+
3839
# ==================== GATED DELTA NET (LINEAR ATTENTION) ====================
# Active when use_linear_attention=True. Replaces NeoLLMAttention every
# `linear_attention_every_n` layers (0-indexed pattern: layers 2, 5, 8, ...).
#
# References:
#   Yang et al. (2024). "Gated Delta Networks." arXiv:2412.06464.
#   Li et al. (2026). "REPO." arXiv:2512.14391.
3846
+
3847
+
3848
+ def _apply_mask_to_padding_states(
3849
+ hidden_states: torch.Tensor,
3850
+ attention_mask: Optional[torch.Tensor],
3851
+ ) -> torch.Tensor:
3852
+ if (
3853
+ attention_mask is not None
3854
+ and attention_mask.shape[1] > 1
3855
+ and attention_mask.shape[0] > 1
3856
+ ):
3857
+ hidden_states = (
3858
+ hidden_states * attention_mask[:, :, None]
3859
+ ).to(hidden_states.dtype)
3860
+ return hidden_states
3861
+
3862
+
3863
+ def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
3864
+ return x / torch.sqrt((x * x).sum(dim=dim, keepdim=True) + eps)
3865
+
3866
+
3867
+ def _torch_causal_conv1d_update(
3868
+ hidden_states, conv_state, weight, bias=None, activation=None
3869
+ ):
3870
+ _, hidden_size, seq_len = hidden_states.shape
3871
+ state_len = conv_state.shape[-1]
3872
+ combined = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
3873
+ conv_state.copy_(combined[:, :, -state_len:])
3874
+ out = F.conv1d(combined, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
3875
+ return F.silu(out[:, :, -seq_len:]).to(hidden_states.dtype)
3876
+
3877
+
3878
+ def _torch_chunk_gated_delta_rule(
3879
+ query, key, value, g, beta,
3880
+ chunk_size=64, initial_state=None, output_final_state=False,
3881
+ use_qk_l2norm_in_kernel=False,
3882
+ ):
3883
+ initial_dtype = query.dtype
3884
+ if use_qk_l2norm_in_kernel:
3885
+ query, key = _l2norm(query), _l2norm(key)
3886
+ query, key, value, beta, g = [
3887
+ x.transpose(1, 2).contiguous().to(torch.float32)
3888
+ for x in (query, key, value, beta, g)
3889
+ ]
3890
+ bs, seq, nh, kdim = key.shape
3891
+ vdim = value.shape[-1]
3892
+ pad = (chunk_size - nh % chunk_size) % chunk_size
3893
+ for t in (query, key, value):
3894
+ t = F.pad(t, (0, 0, 0, pad))
3895
+ query = F.pad(query, (0, 0, 0, pad))
3896
+ key = F.pad(key, (0, 0, 0, pad))
3897
+ value = F.pad(value, (0, 0, 0, pad))
3898
+ beta = F.pad(beta, (0, pad))
3899
+ g = F.pad(g, (0, pad))
3900
+ tot = nh + pad
3901
+ scale = query.shape[-1] ** -0.5
3902
+ query = query * scale
3903
+ vb = value * beta.unsqueeze(-1)
3904
+ kb = key * beta.unsqueeze(-1)
3905
+ query, key, value, kb, vb = [
3906
+ x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
3907
+ for x in (query, key, value, kb, vb)
3908
+ ]
3909
+ g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
3910
+ triu = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), 0)
3911
+ g = g.cumsum(-1)
3912
+ dm = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp()).tril()
3913
+ attn = -((kb @ key.transpose(-1, -2)) * dm).masked_fill(triu, 0)
3914
+ for i in range(1, chunk_size):
3915
+ r = attn[..., i, :i].clone(); s = attn[..., :i, :i].clone()
3916
+ attn[..., i, :i] = r + (r.unsqueeze(-1) * s).sum(-2)
3917
+ eye = torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
3918
+ attn = attn + eye
3919
+ value = attn @ vb
3920
+ kcd = attn @ (kb * g.exp().unsqueeze(-1))
3921
+ st = torch.zeros(bs, seq, kdim, vdim, dtype=value.dtype, device=value.device) if initial_state is None else initial_state.to(value)
3922
+ out = torch.zeros_like(value)
3923
+ triu2 = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), 1)
3924
+ for i in range(tot // chunk_size):
3925
+ qi, ki, vi = query[:,:,i], key[:,:,i], value[:,:,i]
3926
+ a = (qi @ ki.transpose(-1,-2) * dm[:,:,i]).masked_fill_(triu2, 0)
3927
+ vp = kcd[:,:,i] @ st
3928
+ vn = vi - vp
3929
+ out[:,:,i] = (qi * g[:,:,i,:,None].exp()) @ st + a @ vn
3930
+ st = st * g[:,:,i,-1,None,None].exp() + (ki * (g[:,:,i,-1,None]-g[:,:,i]).exp()[...,None]).transpose(-1,-2) @ vn
3931
+ if not output_final_state: st = None
3932
+ out = out.reshape(out.shape[0], out.shape[1], -1, out.shape[-1])[:,:,:nh]
3933
+ return out.transpose(1,2).contiguous().to(initial_dtype), st
3934
+
3935
+
3936
+ def _torch_recurrent_gated_delta_rule(
3937
+ query, key, value, g, beta, initial_state, output_final_state,
3938
+ use_qk_l2norm_in_kernel=False,
3939
+ ):
3940
+ initial_dtype = query.dtype
3941
+ if use_qk_l2norm_in_kernel:
3942
+ query, key = _l2norm(query), _l2norm(key)
3943
+ query, key, value, beta, g = [
3944
+ x.transpose(1,2).contiguous().to(torch.float32)
3945
+ for x in (query, key, value, beta, g)
3946
+ ]
3947
+ bs, seq, nh, kdim = key.shape
3948
+ vdim = value.shape[-1]
3949
+ query = query * (query.shape[-1] ** -0.5)
3950
+ out = torch.zeros(bs, seq, nh, vdim, dtype=value.dtype, device=value.device)
3951
+ st = torch.zeros(bs, seq, kdim, vdim, dtype=value.dtype, device=value.device) if initial_state is None else initial_state.to(value)
3952
+ for i in range(nh):
3953
+ qt, kt, vt = query[:,:,i], key[:,:,i], value[:,:,i]
3954
+ gt, bt = g[:,:,i].exp().unsqueeze(-1).unsqueeze(-1), beta[:,:,i].unsqueeze(-1)
3955
+ st = st * gt
3956
+ delta = (vt - (st * kt.unsqueeze(-1)).sum(-2)) * bt
3957
+ st = st + kt.unsqueeze(-1) * delta.unsqueeze(-2)
3958
+ out[:,:,i] = (st * qt.unsqueeze(-1)).sum(-2)
3959
+ if not output_final_state: st = None
3960
+ return out.transpose(1,2).contiguous().to(initial_dtype), st
3961
+
3962
+
3963
+ class _NeoLLMRMSNormGated(nn.Module):
3964
+ """Gated RMSNorm fallback when FLA unavailable."""
3965
+ def __init__(self, hidden_size, eps=1e-6, **kwargs):
3966
+ super().__init__()
3967
+ self.weight = nn.Parameter(torch.ones(hidden_size))
3968
+ self.eps = eps
3969
+ def forward(self, x, gate):
3970
+ dtype = x.dtype
3971
+ x = x.float()
3972
+ x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
3973
+ return (self.weight * x.to(dtype) * F.silu(gate.float())).to(dtype)
3974
+
3975
+
3976
+ class NeoLLMGatedDeltaNet(nn.Module):
3977
+ """
3978
+ GatedDeltaNet linear attention with FANformer integration.
3979
+
3980
+ Replaces NeoLLMAttention on every ``linear_attention_every_n``-th layer
3981
+ (0-indexed: layers 2, 5, 8 … for every_n=3).
3982
+
3983
+ REPO (use_repo=True AND use_repo_in_linear_attn=True):
3984
+ Applies continuous per-head positions to Q and K via _apply_repo_rope,
3985
+ matching the full-attention REPO path identically.
3986
+
3987
+ Without REPO the gated delta rule operates without explicit positional
3988
+ encoding (its recurrent state is implicitly position-aware).
3989
+
3990
+ References:
3991
+ Yang et al. (2024). arXiv:2412.06464.
3992
+ Li et al. (2026). arXiv:2512.14391.
3993
+ """
3994
+
3995
def __init__(self, config: NeoLLMConfig, layer_idx: int):
    """
    Build the gated-delta-net linear-attention block.

    Args:
        config: model configuration.  Reads ``hidden_size``,
            ``linear_num_value_heads``, ``linear_num_key_heads``,
            ``linear_key_head_dim``, ``linear_value_head_dim``,
            ``linear_conv_kernel_dim``, ``rms_norm_eps``, ``dropout_rate``,
            ``num_hidden_layers`` and, optionally, ``fan_ratio``, ``dtype``,
            ``use_repo``, ``use_repo_in_linear_attn``, ``repo_start_layer``
            and ``repo_d_p``.
        layer_idx: index of this layer; stored and compared against
            ``repo_start_layer`` to decide whether REPO is active.
    """
    super().__init__()
    self.hidden_size = config.hidden_size
    self.num_v_heads = config.linear_num_value_heads
    self.num_k_heads = config.linear_num_key_heads
    self.head_k_dim = config.linear_key_head_dim
    self.head_v_dim = config.linear_value_head_dim
    self.key_dim = self.head_k_dim * self.num_k_heads
    self.value_dim = self.head_v_dim * self.num_v_heads
    self.conv_kernel_size = config.linear_conv_kernel_dim
    self.layer_idx = layer_idx

    # -- FAN feature layer; the projections below consume its output, whose
    #    width is hidden_size + int(hidden_size * fan_ratio).
    self.fan_layer = FANLayer(
        hidden_size=config.hidden_size,
        fan_ratio=getattr(config, "fan_ratio", 0.125),
    )
    _fan_dim = config.hidden_size + int(
        config.hidden_size * getattr(config, "fan_ratio", 0.125)
    )

    # -- Depthwise causal conv1d over the concatenated Q/K/V channels.
    #    padding = kernel_size - 1 pads both ends; the forward pass keeps
    #    only the leading outputs it needs.
    self.conv_dim = self.key_dim * 2 + self.value_dim
    self.conv1d = nn.Conv1d(
        self.conv_dim, self.conv_dim, bias=False,
        kernel_size=self.conv_kernel_size,
        groups=self.conv_dim,
        padding=self.conv_kernel_size - 1,
    )

    # -- Fused projections: (q, k, v, z) and (b, a), both from FAN features.
    self.in_proj_qkvz = nn.Linear(
        _fan_dim, self.key_dim * 2 + self.value_dim * 2, bias=False
    )
    self.in_proj_ba = nn.Linear(
        _fan_dim, self.num_v_heads * 2, bias=False
    )

    # -- Delta-rule gating parameters (one per value head).
    self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
    A = torch.empty(self.num_v_heads).uniform_(0, 16)
    self.A_log = nn.Parameter(torch.log(A))

    # -- Output normalisation: fused kernel when available, else the
    #    pure-PyTorch fallback.  The device / dtype kwargs are only built
    #    (and CUDA only queried) on the fused path.
    _NormCls = FusedRMSNormGated if FusedRMSNormGated is not None else _NeoLLMRMSNormGated
    _norm_kw = (
        dict(activation="silu",
             device=torch.cuda.current_device(),
             dtype=getattr(config, "dtype", None) or torch.get_default_dtype())
        if FusedRMSNormGated is not None else {}
    )
    self.norm = _NormCls(self.head_v_dim, eps=config.rms_norm_eps, **_norm_kw)
    self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
    self.dropout = nn.Dropout(config.dropout_rate)

    # -- Kernel dispatch (fast path, else pure-PyTorch fallback).
    self._conv1d_fn = causal_conv1d_fn  # None if not installed
    self._chunk_fn = (chunk_gated_delta_rule
                      if chunk_gated_delta_rule is not None
                      else _torch_chunk_gated_delta_rule)
    self._recur_fn = (fused_recurrent_gated_delta_rule
                      if fused_recurrent_gated_delta_rule is not None
                      else _torch_recurrent_gated_delta_rule)

    if not is_linear_attn_fast_path:
        logger.warning_once(
            "NeoLLMGatedDeltaNet: causal_conv1d / flash-linear-attention "
            "not installed — using pure-PyTorch fallbacks. "
            "Install both packages for full performance."
        )

    # -- REPO (continuous per-head positions on Q and K): requires both
    #    config flags and layer_idx >= repo_start_layer (default: one third
    #    of the depth).
    self.use_repo = (
        getattr(config, "use_repo", False)
        and getattr(config, "use_repo_in_linear_attn", False)
        and layer_idx >= getattr(config, "repo_start_layer",
                                 config.num_hidden_layers // 3)
    )
    if self.use_repo:
        _d_p = getattr(config, "repo_d_p", config.hidden_size // 8)
        self.repo_module = REPOModule(
            hidden_size=config.hidden_size,
            d_p=_d_p,
            num_heads=self.num_v_heads,
        )
    else:
        self.repo_module = None
4085
+
4086
def _fix_qkvz(
    self,
    mixed_qkvz: torch.Tensor,
    mixed_ba: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
    """
    Unpack the two fused input projections into per-head tensors.

    Both inputs are first viewed with an extra ``num_k_heads`` axis in
    front of the feature axis, then cut along that feature axis (dim 3,
    so the inputs are expected to be 3-D before the view).

    Args:
        mixed_qkvz: fused q/k/v/z projection; its last dimension must equal
            ``num_k_heads * (2 * head_k_dim + 2 * ratio * head_v_dim)``
            with ``ratio = num_v_heads // num_k_heads``.
        mixed_ba: fused b/a projection; its last dimension must equal
            ``num_k_heads * 2 * ratio``.

    Returns:
        ``(q, k, v, z, b, a)`` where q and k keep ``num_k_heads`` heads of
        width ``head_k_dim``, v and z are reshaped to heads of width
        ``head_v_dim``, and b and a are flattened to ``num_v_heads``
        scalars per position.
    """
    key_heads = self.num_k_heads
    key_width = self.head_k_dim
    val_width = self.head_v_dim
    group = self.num_v_heads // key_heads   # value heads served by one key head
    grouped_val_width = group * val_width

    # Expose the key-head axis on both fused tensors.
    qkvz_grouped = mixed_qkvz.view(
        *mixed_qkvz.shape[:-1],
        key_heads,
        2 * key_width + 2 * grouped_val_width,
    )
    ba_grouped = mixed_ba.view(*mixed_ba.shape[:-1], key_heads, 2 * group)

    # Per key head the layout is [q | k | v (group heads) | z (group heads)].
    section_sizes = [key_width, key_width, grouped_val_width, grouped_val_width]
    q, k, v_grouped, z_grouped = qkvz_grouped.split(section_sizes, dim=3)
    b_grouped, a_grouped = ba_grouped.split(group, dim=3)

    # Fold (key head, group) into a single value-head axis.
    v = v_grouped.reshape(*v_grouped.shape[:2], -1, val_width)
    z = z_grouped.reshape(*z_grouped.shape[:2], -1, val_width)
    b = b_grouped.reshape(*b_grouped.shape[:2], self.num_v_heads)
    a = a_grouped.reshape(*a_grouped.shape[:2], self.num_v_heads)
    return q, k, v, z, b, a
4115
+
4116
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
) -> torch.Tensor:
    """
    Run the gated delta-rule linear-attention token mixer.

    Pipeline: padding mask -> FAN layer -> fused QKVz / ba projections ->
    causal conv1d + SiLU on the concatenated Q/K/V -> head expansion ->
    optional REPO positions on Q/K -> chunked gated delta rule ->
    gated norm with z -> output projection -> dropout.

    Args:
        hidden_states: 3-D input, unpacked as ``(B, S, _)``.
        attention_mask: forwarded to ``_apply_mask_to_padding_states``;
            may be None.
        position_embeddings: accepted for signature parity with the full
            attention layer; not read by this method.
        repo_rope_args: ``(inv_freq, attn_scaling)`` forwarded to
            ``_apply_repo_rope``. REPO is skipped when this is None or
            when the layer was built without a REPO module.

    Returns:
        Tensor produced by ``out_proj`` followed by dropout.
    """
    hidden_states = _apply_mask_to_padding_states(hidden_states, attention_mask)
    B, S, _ = hidden_states.shape

    # ── FANformer ─────────────────────────────────────────────────────
    h_fan = self.fan_layer(hidden_states)

    # ── QKVz and ba projections ───────────────────────────────────────
    q, k, v, z, b, a = self._fix_qkvz(
        self.in_proj_qkvz(h_fan), self.in_proj_ba(h_fan)
    )

    # ── Causal conv1d on flattened Q/K/V ─────────────────────────────
    qkv = torch.cat(
        [q.reshape(B, S, -1), k.reshape(B, S, -1), v.reshape(B, S, -1)], dim=-1
    ).transpose(1, 2)                               # [B, conv_dim, S]

    if self._conv1d_fn is not None:
        # Fused kernel: applies the SiLU activation itself.
        qkv = self._conv1d_fn(
            x=qkv, weight=self.conv1d.weight.squeeze(1),
            bias=self.conv1d.bias, activation="silu", seq_idx=None,
        )
    else:
        # nn.Conv1d fallback: keep only the first S output steps.
        qkv = F.silu(self.conv1d(qkv)[:, :, :S])
    qkv = qkv.transpose(1, 2)                       # [B, S, conv_dim]

    q_f, k_f, v_f = torch.split(
        qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1
    )
    q = q_f.reshape(B, S, -1, self.head_k_dim)
    k = k_f.reshape(B, S, -1, self.head_k_dim)
    v = v_f.reshape(B, S, -1, self.head_v_dim)

    # ── GQA-like head expansion ────────────────────────────────────────
    # Done BEFORE REPO: the REPO module is constructed with
    # num_heads=self.num_v_heads, so Q/K must already carry num_v_heads
    # heads when the per-head positions are applied. With ratio == 1 this
    # ordering is identical to applying REPO first.
    ratio = self.num_v_heads // self.num_k_heads
    if ratio > 1:
        q = q.repeat_interleave(ratio, dim=2)
        k = k.repeat_interleave(ratio, dim=2)

    # ── REPO: continuous per-head positions ───────────────────────────
    # Transpose to [B, H, S, dk] for _apply_repo_rope, then back.
    # NOTE(review): assumes _apply_repo_rope is applied per head and
    # needs z_pos and q/k to agree on the head axis — confirm against
    # its definition.
    if self.use_repo and self.repo_module is not None and repo_rope_args is not None:
        inv_freq, attn_scaling = repo_rope_args
        z_pos = self.repo_module(hidden_states)     # [B, H, S]
        q_t, k_t = q.transpose(1, 2), k.transpose(1, 2)
        q_t, k_t = _apply_repo_rope(q_t, k_t, z_pos, inv_freq, attn_scaling)
        q, k = q_t.transpose(1, 2), k_t.transpose(1, 2)

    # Per-head gates fed to the delta-rule kernel.
    beta = b.sigmoid()
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

    # ── Chunk gated delta rule (fused or fallback) ────────────────────
    # No recurrent state is carried in or out of this call.
    core_out, _ = self._chunk_fn(
        q, k, v, g=g, beta=beta,
        initial_state=None, output_final_state=False,
        use_qk_l2norm_in_kernel=True,
    )

    # ── Gated RMSNorm + output projection ─────────────────────────────
    # The norm is called on 2-D (rows, last_dim) views of core_out and z;
    # the result is then restored to z's shape and flattened per token.
    z_shape = z.shape
    core_out = core_out.reshape(-1, core_out.shape[-1])
    core_out = self.norm(core_out, z.reshape(-1, z.shape[-1]))
    core_out = core_out.reshape(z_shape).reshape(B, S, -1)
    return self.dropout(self.out_proj(core_out))
4186
+
4187
+
4188
+ # ==================== DECODER LAYER ==========================================
4189
+
4190
+
4191
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4192
  """
4193
  Decoder layer with standard residual connections, optional JTok-M injection.
 
4210
  self.layer_idx = layer_idx
4211
  self.use_jtokm = config.use_jtokm
4212
 
4213
+ # ── Token-mixer selection ─────────────────────────────────────────
4214
+ # use_linear_attention=True: replace full attention every
4215
+ # `linear_attention_every_n` layers (0-indexed pattern:
4216
+ # e.g. every_n=3 → layers 2, 5, 8, 11 …).
4217
+ # All other layers keep NeoLLMAttention unchanged.
4218
+ _every_n = getattr(config, "linear_attention_every_n", 3)
4219
+ self.is_linear_attn = (
4220
+ getattr(config, "use_linear_attention", False)
4221
+ and (layer_idx + 1) % _every_n == 0
4222
+ )
4223
+ if self.is_linear_attn:
4224
+ self.linear_attn = NeoLLMGatedDeltaNet(config, layer_idx)
4225
+ self.self_attn = None
4226
+ else:
4227
+ self.self_attn = NeoLLMAttention(config, layer_idx)
4228
+ self.linear_attn = None
4229
+
4230
  self.mlp = (
4231
  VersatileFFN(config)
4232
  if getattr(config, "use_versatile_ffn", False)
 
4259
  self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size))
4260
  self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size))
4261
  self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
4262
+ _num_blocks = getattr(config, 'attn_res_num_blocks', 0)
4263
+ self.attn_res_block_size = (
4264
+ max(config.num_hidden_layers // _num_blocks, 1) if _num_blocks > 0 else 1
4265
+ )
4266
  else:
4267
  self.attn_res_query_attn = None
4268
  self.attn_res_query_mlp = None
4269
  self.attn_res_norm = None
4270
+ self.attn_res_block_size = None
4271
+
4272
+ # ── MUDD: separate K/V LayerNorms for qkvr+sepln mode ──────────────
4273
+ # Only instantiated when both mudd_dense_type='qkvr' AND mudd_sepln=True.
4274
+ # The existing input_layernorm handles the Q stream (unchanged).
4275
+ # Separate norms for K and V allow each stream to rescale independently.
4276
+ _use_mudd = getattr(config, 'use_mudd', False)
4277
+ _mudd_qkvr = getattr(config, 'mudd_dense_type', 'qkvr') == 'qkvr'
4278
+ _mudd_sepln = getattr(config, 'mudd_sepln', False)
4279
+ if _use_mudd and _mudd_qkvr and _mudd_sepln:
4280
+ self.mudd_k_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
4281
+ self.mudd_v_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
4282
+ else:
4283
+ self.mudd_k_norm = None
4284
+ self.mudd_v_norm = None
4285
+
4286
+ # ── DCA (Heddes et al., 2025, arXiv:2502.06785) ───────────────────
4287
+ # GRN-v3 module that aggregates the k-selected depth stack into 3
4288
+ # independent streams (Q, K, V). Each stream has its own dimension-
4289
+ # and input-dependent weights, enabling richer cross-layer interactions.
4290
+ # K and V get their own SeeDNorm + LNS norm chain (same scheme as
4291
+ # MUDD sepln) since they now arrive from a different aggregation path.
4292
+ # The residual connection uses the Q stream output (xq) as its base,
4293
+ # matching the DCA paper's decoder block design (residual = q_input).
4294
+ self.use_dca = getattr(config, 'use_dca', False)
4295
+ if self.use_dca:
4296
+ _dca_k = getattr(config, 'dca_k', 2)
4297
+ _num_stack = min(layer_idx + 1, 2 * _dca_k)
4298
+ self.dca_grn = NeoLLMGRN(
4299
+ hidden_size = config.hidden_size,
4300
+ num_stack_layers = _num_stack,
4301
+ num_outputs = 3,
4302
+ eps = config.rms_norm_eps,
4303
+ )
4304
+ self.dca_k_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
4305
+ self.dca_v_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
4306
+ else:
4307
+ self.dca_grn = None
4308
+ self.dca_k_norm = None
4309
+ self.dca_v_norm = None
4310
+
4311
+ # ── StackTrans (Zhang et al., NeurIPS 2025) ───────────────────────
4312
+ # Differentiable multi-head hidden-state stack inserted at the very
4313
+ # beginning of the layer forward, before the attention sublayer.
4314
+ # Mutually exclusive with use_attn_res, use_mudd, use_dca.
4315
+ self.use_stacktrans = getattr(config, 'use_stacktrans', False)
4316
+ if self.use_stacktrans:
4317
+ self.stack_memory = StackMemory(config)
4318
+ else:
4319
+ self.stack_memory = None
4320
+
4321
+ # ── LAuReL (Menghani, Kumar & Kumar, ICML 2025) ───────────────────
4322
+ # Learned augmented residual connection replacing f(x)+x at both
4323
+ # the attention and MLP residual sums. Applied immediately before
4324
+ # GPAS, so GPAS still controls magnitude via stop-gradient scaling.
4325
+ # Two independent instances per layer (attention and MLP).
4326
+ # Compatible with use_stacktrans. Incompatible with MUDD/DCA/AttnRes.
4327
+ self.use_laurel = getattr(config, 'use_laurel', False)
4328
+ if self.use_laurel:
4329
+ self.laurel_attn = LAuReLLayer(config)
4330
+ self.laurel_mlp = LAuReLLayer(config)
4331
+ else:
4332
+ self.laurel_attn = None
4333
+ self.laurel_mlp = None
4334
 
4335
  def _attn_res(
4336
  self,
 
4380
  B_vals: Optional[torch.Tensor] = None,
4381
  attn_res_sources: Optional[list] = None,
4382
  attn_res_partial: Optional[torch.Tensor] = None,
4383
+ mudd_streams: Optional[tuple] = None,
4384
+ dca_stack: Optional[torch.Tensor] = None,
4385
+ stack_state: Optional[torch.Tensor] = None,
4386
+ stack_mask: Optional[torch.Tensor] = None,
4387
  layer_analysis: Optional[LayerAnalysis] = None,
4388
  output_attentions: Optional[bool] = False,
4389
  repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
 
4393
  if layer_analysis is not None:
4394
  layer_analysis.hidden_states_input = hidden_states.detach()
4395
 
4396
+ # ── StackTrans: hidden-state stack (pre-attention, pre-norm) ─────
4397
+ # Executed first so attention sees the stack-enriched representation.
4398
+ # stack_state / stack_mask carry the stack from the previous layer;
4399
+ # both are None for layer 0 (StackMemory initialises to zeros then).
4400
+ # Mutually exclusive with MUDD / DCA / AttnRes β€” those branches are
4401
+ # all skipped when use_stacktrans=True (enforced in NeoLLMConfig).
4402
+ if self.use_stacktrans and self.stack_memory is not None:
4403
+ st_analysis = layer_analysis.stack if layer_analysis is not None else None
4404
+ hidden_states, stack_state, stack_mask = self.stack_memory(
4405
+ hidden_states, stack_state, stack_mask, analysis=st_analysis
4406
+ )
4407
+
4408
+ # ── MUDD: unpack streams for Q/K/V/R (layer > 0 only) ────────────
4409
+ # mudd_streams is a 4-tuple (xq, xk, xv, xr) when use_mudd=True and
4410
+ # layer_idx > 0; None for layer 0 (standard residual there).
4411
+ # xr replaces hidden_states as the residual throughout this layer.
4412
+ # xq/xk/xv are the aggregated inputs for Q, K, V projections.
4413
+ # When mudd_dense_type='l' (single stream), all four are equal.
4414
+ # When mudd_sepln=True each stream has its own norm applied below.
4415
+ mudd_xk = None
4416
+ mudd_xv = None
4417
+ if mudd_streams is not None:
4418
+ xq_mudd, xk_mudd, xv_mudd, xr_mudd = mudd_streams
4419
+ # Replace hidden_states with xr for residual connections
4420
+ hidden_states = xr_mudd
4421
+ # Norm K and V streams β€” use separate SeeDNorm if sepln, else
4422
+ # they will share the main input_layernorm path via h_attn below
4423
+ if self.mudd_k_norm is not None:
4424
+ mudd_xk = self.lns_attn(self.mudd_k_norm(xk_mudd))
4425
+ mudd_xv = self.lns_attn(self.mudd_v_norm(xv_mudd))
4426
+ else:
4427
+ # No sepln: K/V also go through the Q-path norm chain
4428
+ mudd_xk = self.lns_attn(self.input_layernorm(xk_mudd))
4429
+ mudd_xv = self.lns_attn(self.input_layernorm(xv_mudd))
4430
+ # Override hidden_states for the Q path
4431
+ hidden_states_for_attn = xq_mudd
4432
+ else:
4433
+ hidden_states_for_attn = hidden_states
4434
+
4435
+ # ── DCA: GRN-v3 depth-wise aggregation ───────────────────────────
4436
+ # When active, runs the per-layer GRN on the k-selected depth stack
4437
+ # to produce three independent aggregated streams (Q, K, V).
4438
+ # xq replaces hidden_states as both the Q projection input AND the
4439
+ # post-attention residual (DCA paper: residual = q_input).
4440
+ # xk and xv go through separate SeeDNorm+LNS chains and are injected
4441
+ # into NeoLLMAttention via the existing mudd_xk/mudd_xv parameters.
4442
+ dca_residual = None
4443
+ dca_a = layer_analysis.dca if layer_analysis is not None else None
4444
+ if self.use_dca and dca_stack is not None:
4445
+ xq, xk, xv = self.dca_grn(dca_stack, analysis=dca_a)
4446
+ dca_residual = xq
4447
+ hidden_states_for_attn = xq
4448
+ # K and V streams: SeeDNorm + LNS before k_proj / v_proj
4449
+ # (reuses the mudd_xk/mudd_xv injection path in NeoLLMAttention)
4450
+ mudd_xk = self.lns_attn(self.dca_k_norm(xk))
4451
+ mudd_xv = self.lns_attn(self.dca_v_norm(xv))
4452
+
4453
  # ── Attention Residuals: compute pre-attention input ──────────────
4454
  # When active, the input to the attention sublayer is no longer the
4455
  # raw hidden_states (accumulated residual) but a softmax-weighted
 
4463
  attn_res_sources, attn_res_partial, self.attn_res_query_attn,
4464
  ar_analysis, "attn",
4465
  )
4466
+ # ── Block boundary fires HERE β€” after pre-attn, before attn sublayer ──
4467
+ # Paper pseudocode (Fig. 2) timing: the completed partial of the previous
4468
+ # block is pushed to sources AFTER the pre-attn AttnRes call, so the first
4469
+ # layer of a new block still sees the old partial as an intra-block source
4470
+ # (no duplicate) and the new intra-block accumulation starts from zeros.
4471
+ if self.layer_idx > 0 and self.layer_idx % self.attn_res_block_size == 0:
4472
+ attn_res_sources.append(attn_res_partial) # in-place; outer loop sees this
4473
+ attn_res_partial = torch.zeros_like(attn_res_partial) # fresh delta start
4474
  residual_attn = attn_res_partial
4475
  else:
4476
+ h_attn = hidden_states_for_attn # MUDD/DCA: xq stream or unchanged
4477
+ # DCA: residual is xq (the GRN Q-stream output), not raw hidden_states
4478
+ residual_attn = dca_residual if dca_residual is not None else hidden_states
4479
 
4480
  # ── Attention block ───────────────────────────────────────────────
4481
  sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
 
4484
  if layer_analysis is not None:
4485
  layer_analysis.lns_attn_output = h_lns.detach()
4486
 
4487
+ if self.is_linear_attn:
4488
+ # ── GatedDeltaNet linear attention path ───────────────────────
4489
+ # Does not use: first_layer_fan, mudd_xk/xv, attn_analysis.
4490
+ # attention_mask here is already the linear_attn_mask (no causal
4491
+ # bias, just padding) β€” NeoLLMModel.forward selects it per layer.
4492
+ hidden_states = self.linear_attn(
4493
+ hidden_states=h_lns,
4494
+ attention_mask=attention_mask,
4495
+ position_embeddings=position_embeddings,
4496
+ repo_rope_args=repo_rope_args,
4497
+ )
4498
+ attn_weights = None
4499
+ self.current_layer_fan = None
4500
+ else:
4501
+ # ── Standard full attention path ──────────────────────────────
4502
+ hidden_states, attn_weights, self.current_layer_fan = self.self_attn(
4503
+ hidden_states=h_lns,
4504
+ attention_mask=attention_mask,
4505
+ position_embeddings=position_embeddings,
4506
+ first_layer_fan=first_layer_fan,
4507
+ attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
4508
+ repo_rope_args=repo_rope_args,
4509
+ mudd_xk=mudd_xk,
4510
+ mudd_xv=mudd_xv,
4511
+ **kwargs,
4512
+ )
4513
 
4514
  if layer_analysis is not None:
4515
  layer_analysis.attn_contribution = hidden_states.detach()
4516
 
4517
  gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
4518
+
4519
+ # ── Attention residual sum ────────────────────────────────────────
4520
+ # Standard: GPAS(residual_attn + hidden_states)
4521
+ # LAuReL: GPAS(LAuReL(f_out=hidden_states, x_res=residual_attn))
4522
+ # Both paths feed into GPAS which applies stop-gradient scaling.
4523
+ if self.use_laurel and self.laurel_attn is not None:
4524
+ la_attn_a = layer_analysis.laurel_attn if layer_analysis is not None else None
4525
+ combined_attn = self.laurel_attn(hidden_states, residual_attn, analysis=la_attn_a)
4526
+ else:
4527
+ combined_attn = residual_attn + hidden_states
4528
+
4529
+ h_tilde = self.gpas_attn(combined_attn, analysis=gpas_attn_a)
4530
 
4531
  if layer_analysis is not None:
4532
  layer_analysis.h_tilde = h_tilde.detach()
 
4562
  if layer_analysis is not None:
4563
  layer_analysis.mlp_contribution = delta_m.detach()
4564
 
4565
+ # ── MLP residual sum ──────────────────────────────────────────────
4566
+ # LAuReL treats f(x) = delta_m [+ delta_r when JTok-M active] and
4567
+ # x_res = residual_mlp. JTok-M delta_r is additive alongside delta_m,
4568
+ # so the nonlinear component is delta_m + delta_r in that path.
4569
+ gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
4570
  if self.use_jtokm and z_tilde is not None and B_vals is not None:
4571
  orig_shape = h_tilde.shape
4572
  h_flat = h_tilde.reshape(-1, self.hidden_size)
 
4577
  delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
4578
  delta_r = delta_r.reshape(orig_shape)
4579
 
4580
+ f_mlp = delta_m + delta_r # combined nonlinear term
4581
+ if self.use_laurel and self.laurel_mlp is not None:
4582
+ la_mlp_a = layer_analysis.laurel_mlp if layer_analysis is not None else None
4583
+ combined_mlp = self.laurel_mlp(f_mlp, residual_mlp, analysis=la_mlp_a)
4584
+ else:
4585
+ combined_mlp = residual_mlp + f_mlp
4586
+ hidden_states = self.gpas_mlp(combined_mlp, analysis=gpas_mlp_a)
4587
  else:
4588
+ aux_stats = None
4589
+ if self.use_laurel and self.laurel_mlp is not None:
4590
+ la_mlp_a = layer_analysis.laurel_mlp if layer_analysis is not None else None
4591
+ combined_mlp = self.laurel_mlp(delta_m, residual_mlp, analysis=la_mlp_a)
4592
+ else:
4593
+ combined_mlp = residual_mlp + delta_m
4594
+ hidden_states = self.gpas_mlp(combined_mlp, analysis=gpas_mlp_a)
4595
 
4596
  if layer_analysis is not None:
4597
  layer_analysis.hidden_states_output = hidden_states.detach()
 
4603
  outputs += (aux_stats,)
4604
  if versatile_aux is not None:
4605
  outputs += (versatile_aux,)
4606
+ # StackTrans: always append stack state (None, None when inactive)
4607
+ # so NeoLLMModel.forward can extract them by position -2 and -1.
4608
+ outputs += (stack_state, stack_mask)
4609
  return outputs
4610
 
4611
 
 
4933
  if hasattr(module, "alpha_ma"):
4934
  module.alpha_ma.zero_()
4935
 
4936
+ elif isinstance(module, NeoLLMGatedDeltaNet):
4937
+ module.dt_bias.data.fill_(1.0)
4938
+ module.A_log.data.uniform_(0, 16).log_()
4939
  elif isinstance(module, GPAS):
4940
  module.alpha.data.fill_(0.0)
4941
 
 
4992
  module.attn_res_query_attn.data.zero_()
4993
  module.attn_res_query_mlp.data.zero_()
4994
 
4995
+ elif isinstance(module, StackMemory):
4996
+ # Truncated-normal for all Linear weights (matches NeoLLM convention).
4997
+ # Biases zeroed. res_weight starts at 1.0 so the stack readout
4998
+ # contributes equally to the residual from step 0.
4999
+ std = getattr(self.config, "initializer_range", 0.02)
5000
+ cutoff = getattr(self.config, "init_cutoff_factor", 3.0) * std
5001
+ for attr in ("down_proj", "up_proj", "action_head", "gate_proj"):
5002
+ layer = getattr(module, attr, None)
5003
+ if layer is not None and hasattr(layer, "weight"):
5004
+ nn.init.trunc_normal_(
5005
+ layer.weight, mean=0.0, std=std, a=-cutoff, b=cutoff
5006
+ )
5007
+ if layer.bias is not None:
5008
+ nn.init.zeros_(layer.bias)
5009
+ if hasattr(module, "res_weight"):
5010
+ module.res_weight.data.fill_(1.0)
5011
+
5012
+ elif isinstance(module, LAuReLLayer):
5013
+ # RW: raw logits initialised to zero → softmax([0,0]) = [0.5, 0.5].
5014
+ # The model quickly learns the optimal α,β weighting.
5015
+ # LR: lr_down (B, down-projection) — column orthogonal init,
5016
+ # as recommended by the LAuReL paper §3.3 for LLMs.
5017
+ # Column orthogonal preserves the L2 norm of the projected
5018
+ # representation, ensuring stable gradient magnitudes
5019
+ # through the low-rank bottleneck at init.
5020
+ # lr_up (A, up-projection) — zero init → lr_term = A·Bx = 0
5021
+ # at step 0, so the module starts as a standard residual.
5022
+ # Gradient flows back through lr_down immediately via
5023
+ # chain rule; A learns from step 1 onward.
5024
+ if hasattr(module, "rw_logits"):
5025
+ nn.init.zeros_(module.rw_logits)
5026
+ if hasattr(module, "lr_down"):
5027
+ # Column-orthogonal: each column of weight^T is orthonormal.
5028
+ # nn.init.orthogonal_ produces a row-orthogonal matrix (rows
5029
+ # are orthonormal). Transposing gives column-orthogonal.
5030
+ nn.init.orthogonal_(module.lr_down.weight)
5031
+ if hasattr(module, "lr_up"):
5032
+ nn.init.zeros_(module.lr_up.weight)
5033
+
5034
  elif isinstance(module, SpellingBeeEmbedding):
5035
  # byte_emb initialised identically to token embeddings: std=1/√d.
5036
  # Ensures E[β€–e_byteβ€–Β²] β‰ˆ 1 at init, matching etok, so the
 
5100
  self.gradient_checkpointing = False
5101
  self.first_layer_fan = None
5102
 
5103
+ # ── StackTrans state flag ─────────────────────────────────────────
5104
+ self.use_stacktrans = getattr(config, 'use_stacktrans', False)
5105
+
5106
+ # ── Residual-replacement mutex ────────────────────────────────────
5107
+ # AttnRes, MUDD, and DCA all replace the residual aggregation
5108
+ # mechanism β€” at most one can be active at a time.
5109
+ _use_mudd = getattr(config, 'use_mudd', False)
5110
+ _use_attn_res = getattr(config, 'use_attn_res', False)
5111
+ _use_dca = getattr(config, 'use_dca', False)
5112
+ _active_count = sum([_use_mudd, _use_attn_res, _use_dca])
5113
+ if _active_count > 1:
5114
+ active = [n for n, f in [('use_mudd', _use_mudd),
5115
+ ('use_attn_res', _use_attn_res),
5116
+ ('use_dca', _use_dca)] if f]
5117
+ raise ValueError(
5118
+ f"use_mudd, use_attn_res, and use_dca are mutually exclusive β€” "
5119
+ f"got {active} simultaneously active. Set exactly one to True."
5120
+ )
5121
+ if _use_mudd:
5122
+ _mudd_dense_type = getattr(config, 'mudd_dense_type', 'qkvr')
5123
+ _mudd_dynamic = getattr(config, 'mudd_dynamic_dense', True)
5124
+ _mudd_round64 = getattr(config, 'mudd_round64', False)
5125
+ _mudd_expand_last = getattr(config, 'mudd_expand_last', False)
5126
+ _C = 4 if _mudd_dense_type == 'qkvr' else 1
5127
+
5128
+ # Static bias: one [C, lidx+2] parameter per layer.
5129
+ # Initialized with 1 at index [c, lidx+1] (identity on Xi) so that
5130
+ # at init (W2=0) each DA output = Xi β€” reducing to standard Transformer.
5131
+ _static_list = []
5132
+ for lidx in range(config.num_hidden_layers):
5133
+ # Last layer always uses C=1: its DA output is the final
5134
+ # model representation fed to the norm and lm_head, collapsing
5135
+ # all history into a single stream (paper code, both files).
5136
+ _c = 1 if lidx == config.num_hidden_layers - 1 else _C
5137
+ a = torch.zeros(_c, lidx + 2)
5138
+ a[:, lidx + 1] = 1.0 # last entry = current layer = identity
5139
+ _static_list.append(nn.Parameter(a))
5140
+ self.mudd_static = nn.ParameterList(_static_list)
5141
+
5142
+ # Dynamic DA modules (one per layer)
5143
+ if _mudd_dynamic:
5144
+ self.mudd_dynamic = nn.ModuleList([
5145
+ NeoLLMMUDDModule(
5146
+ hidden_size = config.hidden_size,
5147
+ lidx = lidx,
5148
+ # Last layer: C=1 β€” collapses to single final repr
5149
+ num_ways = 1 if lidx == config.num_hidden_layers - 1 else _C,
5150
+ is_last = (lidx == config.num_hidden_layers - 1),
5151
+ expand_last = _mudd_expand_last,
5152
+ round64 = _mudd_round64,
5153
+ )
5154
+ for lidx in range(config.num_hidden_layers)
5155
+ ])
5156
+ else:
5157
+ self.mudd_dynamic = None
5158
+ else:
5159
+ self.mudd_static = None
5160
+ self.mudd_dynamic = None
5161
+
5162
+ # ── DCA final GRN (Heddes et al., 2025) ───────────────────────────
5163
+ # Applied once after all decoder layers to aggregate the full depth
5164
+ # stack into the final hidden representation before the output norm.
5165
+ # num_stack_layers = min(2*k, L+1) β€” same cap as per-layer GRNs.
5166
+ # num_outputs=1 collapses to a single [B, S, D] tensor.
5167
+ if _use_dca and getattr(config, 'dca_use_final_grn', True):
5168
+ _dca_k = getattr(config, 'dca_k', 2)
5169
+ _dca_eps = getattr(config, 'dca_grn_eps', config.rms_norm_eps)
5170
+ self.dca_final_grn = NeoLLMGRN(
5171
+ hidden_size = config.hidden_size,
5172
+ num_stack_layers = min(2 * _dca_k, config.num_hidden_layers + 1),
5173
+ num_outputs = 1,
5174
+ eps = _dca_eps,
5175
+ )
5176
+ else:
5177
+ self.dca_final_grn = None
5178
+
5179
  self.post_init()
5180
 
5181
def _update_linear_attn_mask(
    self,
    attention_mask: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """
    Pick the mask handed to GatedDeltaNet layers (padding only, no causal bias).

    Args:
        attention_mask: padding mask, or None.

    Returns:
        None when no mask was supplied or when every entry equals 1
        (nothing to mask out); otherwise the given tensor itself,
        unmodified.
    """
    nothing_to_mask = attention_mask is None or bool(torch.all(attention_mask == 1))
    return None if nothing_to_mask else attention_mask
5195
+
5196
  def get_input_embeddings(self):
5197
  if self.config.use_token_generator:
5198
  return self.token_generator
 
5218
  getattr(cfg, "use_repo", False)
5219
  and layer_idx >= getattr(cfg, "repo_start_layer", cfg.num_hidden_layers // 3)
5220
  )
5221
+ _versatile = getattr(cfg, "use_versatile_ffn", False)
5222
+ _use_stacktrans = getattr(cfg, "use_stacktrans", False)
5223
+ _use_laurel = getattr(cfg, "use_laurel", False)
5224
  return LayerAnalysis(
5225
  seednorm_pre_attn = SeeDNormAnalysis(),
5226
  seednorm_post_attn = SeeDNormAnalysis(),
 
5234
  polynorm = PolyNormAnalysis() if not _versatile else None,
5235
  versatile = VersatileFFNAnalysis() if _versatile else None,
5236
  ),
5237
+ gpas_attn = GPASAnalysis(),
5238
+ gpas_mlp = GPASAnalysis(),
5239
+ jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
5240
+ attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
5241
+ dca = DCAAnalysis() if getattr(cfg, "use_dca", False) else None,
5242
+ stack = StackMemoryAnalysis() if _use_stacktrans else None,
5243
+ laurel_attn = LAuReLAnalysis() if _use_laurel else None,
5244
+ laurel_mlp = LAuReLAnalysis() if _use_laurel else None,
5245
  )
5246
 
5247
  def forward(
 
5343
  if getattr(self.config, "use_repo", False) else None
5344
  )
5345
 
5346
+ # ── Linear attention mask ──────────────────────────────────────────
5347
+ # Computed once; each layer picks the appropriate mask below.
5348
+ linear_attn_mask = self._update_linear_attn_mask(attention_mask)
5349
+
5350
  # ── Attention Residuals state ──────────────────────────────────────
5351
  # Full AttnRes (attn_res_num_blocks=0): sources grows by one entry per
5352
  # decoder layer β€” all previous outputs are kept, max N=num_layers+1.
 
5361
  if use_attn_res:
5362
  attn_res_sources = [hidden_states] # b_0 = token embedding
5363
  attn_res_partial = hidden_states # initial partial sum
5364
+ # Block boundary handling now lives inside NeoLLMDecoderLayer.forward(),
5365
+ # firing after the pre-attn AttnRes call (paper Fig. 2 timing).
5366
+
5367
+ # ── MUDD state ────────────────────────────────────────────────────
5368
+ # hiddens[0] = token embedding; hiddens[i] = output of layer i-1.
5369
+ # After each layer, its output is appended so layer i receives a
5370
+ # history of length i+2 (embedding + i preceding layer outputs).
5371
+ # mudd_streams is None for layer 0 (standard residual path there)
5372
+ # and a C-tuple of [B,T,D] tensors for layers 1…L.
5373
+ use_mudd = getattr(self.config, 'use_mudd', False)
5374
+ mudd_hiddens = None
5375
+ mudd_streams = None
5376
+ if use_mudd:
5377
+ mudd_hiddens = [hidden_states] # b_0 = token embedding
5378
+
5379
+ # ── DCA state ─────────────────────────────────────────────────────
5380
+ # all_tokens[0] = token embedding; grows by one per decoder layer.
5381
+ # Before each layer, the stack is built and k-DCA selection applied,
5382
+ # capping memory at 2*dca_k stored tensors regardless of depth.
5383
+ # dca_stack is always non-None (even layer 0 gets [embedding]).
5384
+ use_dca = getattr(self.config, 'use_dca', False)
5385
+ _dca_k = getattr(self.config, 'dca_k', 2)
5386
+ dca_all_tokens = None
5387
+ dca_stack = None
5388
+ if use_dca:
5389
+ dca_all_tokens = [hidden_states] # [embedding]
5390
+
5391
+ # ── StackTrans state ──────────────────────────────────────────────
5392
+ # stack_state / stack_mask start as None for the first layer;
5393
+ # StackMemory initialises them to zeros internally on first call.
5394
+ # After each layer, the returned (new_stack, new_mask) are passed
5395
+ # to the next layer as its initial stack β€” this is "vertical" state
5396
+ # propagation: information flows depth-wise through the stack.
5397
+ #
5398
+ # Temporal accumulation across generation steps is handled by the
5399
+ # StackMemory internal k_cache / action_cache mechanism:
5400
+ # - enable_cache is set True when use_cache=True (inference)
5401
+ # - reset_cache() is called when past_key_values is None
5402
+ # (new sequence, not a continuation step)
5403
+ # This matches the OLMo reference implementation exactly.
5404
+ use_stacktrans = self.use_stacktrans
5405
+ stack_state = None
5406
+ stack_mask = None
5407
+ if use_stacktrans:
5408
+ use_cache_flag = kwargs.get("use_cache", False)
5409
+ past_kv_flag = kwargs.get("past_key_values", None)
5410
+ for layer in self.layers:
5411
+ if layer.stack_memory is not None:
5412
+ layer.stack_memory.enable_cache = bool(use_cache_flag)
5413
+ if past_kv_flag is None:
5414
+ layer.stack_memory.reset_cache()
5415
 
5416
  # Pre-allocate per-layer analysis list when analysis is active
5417
  if analysis_state is not None:
 
5421
  if output_hidden_states:
5422
  all_hidden_states = all_hidden_states + (hidden_states,)
5423
 
5424
+ # ── DCA: build k-selected stack for this layer ───────────────
5425
+ # Stack has layer_idx+1 entries before selection; after k-DCA
5426
+ # selection it has at most 2*dca_k entries (first k + last k).
5427
+ if use_dca:
5428
+ dca_stack = dca_select_layers(
5429
+ torch.stack(dca_all_tokens, dim=0), k=_dca_k
5430
+ )
 
 
 
 
5431
 
5432
  # Build per-layer analysis container (only in eval + analysis mode)
5433
  layer_analysis = None
 
5436
  layer_analysis.layer_idx = layer_idx
5437
  analysis_state.layers.append(layer_analysis)
5438
 
5439
+ # Select the appropriate mask: causal for full attention,
5440
+ # padding-only for GatedDeltaNet linear attention.
5441
+ _layer_mask = (
5442
+ linear_attn_mask
5443
+ if getattr(decoder_layer, "is_linear_attn", False)
5444
+ else causal_mask
5445
+ )
5446
  layer_outputs = decoder_layer(
5447
  hidden_states,
5448
  position_embeddings=position_embeddings,
5449
+ attention_mask=_layer_mask,
5450
  first_layer_fan=self.first_layer_fan,
5451
  z_tilde=z_tilde,
5452
  B_vals=B_vals,
5453
  attn_res_sources=attn_res_sources,
5454
  attn_res_partial=attn_res_partial if use_attn_res else None,
5455
+ mudd_streams=mudd_streams,
5456
+ dca_stack=dca_stack,
5457
+ stack_state=stack_state,
5458
+ stack_mask=stack_mask,
5459
  layer_analysis=layer_analysis,
5460
  output_attentions=output_attentions,
5461
  repo_rope_args=repo_rope_args,
 
5463
  )
5464
  hidden_states = layer_outputs[0]
5465
 
5466
+ # ── StackTrans: extract updated stack state for next layer ─────
5467
+ # layer_outputs always ends with (stack_state, stack_mask) β€”
5468
+ # both are None when use_stacktrans=False (zero cost).
5469
+ stack_state = layer_outputs[-2]
5470
+ stack_mask = layer_outputs[-1]
5471
+
5472
  # Update AttnRes partial sum β€” the new partial is the layer output
5473
  if use_attn_res:
5474
  attn_res_partial = hidden_states
5475
 
5476
+ # Append layer output to DCA history for next layer's stack
5477
+ if use_dca:
5478
+ dca_all_tokens.append(hidden_states)
5479
+
5480
+ # ── MUDD: append current output and compute DA for next layer ──
5481
+ # mudd_hiddens grows by 1 each iteration; at layer i it has i+2
5482
+ # entries (embedding + i outputs). The DA for layer i+1 takes this
5483
+ # full history and produces C streams via dynamic + static weights.
5484
+ # mudd_streams is passed to layer i+1 as its input streams.
5485
+ if use_mudd:
5486
+ mudd_hiddens.append(hidden_states)
5487
+ # Compute DA module output using the just-appended history
5488
+ # (mudd_hiddens now has layer_idx+2 entries)
5489
+ is_last_layer = (layer_idx == self.config.num_hidden_layers - 1)
5490
+ mudd_da_module = self.mudd_dynamic[layer_idx] if self.mudd_dynamic is not None else None
5491
+ if mudd_da_module is not None:
5492
+ raw_streams = mudd_da_module(
5493
+ hidden_states,
5494
+ mudd_hiddens,
5495
+ self.mudd_static[layer_idx],
5496
+ )
5497
+ else:
5498
+ # Static-only: apply weighted sum with learnable bias
5499
+ # stack history [L, B, T, D], weight by mudd_static
5500
+ stacked = torch.stack(mudd_hiddens, dim=0) # [L, B, T, D]
5501
+ a = self.mudd_static[layer_idx].to(hidden_states.dtype) # [C, L]
5502
+ raw_streams = tuple(
5503
+ torch.einsum('cl,lbtd->btd', a[c:c+1], stacked).squeeze(0)
5504
+ for c in range(a.shape[0])
5505
+ )
5506
+ if is_last_layer:
5507
+ # Last layer DA always produces C=1 → single final repr.
5508
+ # This is the MUDD-aggregated combination of all layer
5509
+ # histories weighted by the last layer's output as query.
5510
+ # Replace hidden_states so the final norm and lm_head
5511
+ # receive this aggregated representation, not the raw
5512
+ # last-layer output (paper forward loop: x = x[0] after loop).
5513
+ hidden_states = raw_streams[0]
5514
+ mudd_streams = None # no next layer
5515
+ elif len(raw_streams) == 1:
5516
+ # dense_type='l': broadcast to 4-tuple
5517
+ mudd_streams = (raw_streams[0],) * 4
5518
+ else:
5519
+ # 'qkvr': 4 streams → (xq, xk, xv, xr)
5520
+ mudd_streams = raw_streams
5521
+
5522
  if output_attentions:
5523
  all_attentions = all_attentions + (layer_outputs[1],)
5524
 
5525
+ # Collect JTok-M / VersatileFFN aux stats.
5526
+ # layer_outputs always ends with (stack_state, stack_mask) —
5527
+ # slice [1:-2] to skip the hidden states at index 0 and the two trailing stack slots.
5528
+ inner_outputs = layer_outputs[1:-2]
5529
+
5530
+ if self.config.use_jtokm and len(inner_outputs) > (1 if output_attentions else 0):
5531
+ all_aux_stats.append(inner_outputs[-1])
5532
 
 
 
5533
  if getattr(self.config, "use_versatile_ffn", False):
5534
+ for item in inner_outputs:
5535
  if isinstance(item, tuple) and len(item) == 3:
 
5536
  all_aux_stats.append(("versatile", item))
5537
  break
5538
 
 
5540
  and hasattr(decoder_layer, "current_layer_fan")):
5541
  self.first_layer_fan = decoder_layer.current_layer_fan
5542
 
5543
+ # ── DCA final GRN ──────────────────────────────────────────────────
5544
+ # Aggregates the full depth history (k-selected) into the final
5545
+ # hidden representation, matching the DCAGPT forward loop which
5546
+ # applies final_grn(stack(all_tokens)) before norm → lm_head.
5547
+ if use_dca and self.dca_final_grn is not None:
5548
+ final_stack = dca_select_layers(
5549
+ torch.stack(dca_all_tokens, dim=0), k=_dca_k
5550
+ )
5551
+ hidden_states = self.dca_final_grn(final_stack)
5552
+
5553
  hidden_states = self.norm(hidden_states)
5554
 
5555
  if output_hidden_states:
 
5562
  analysis_state.attn_res_sources_final = (
5563
  attn_res_sources if use_attn_res else None
5564
  )
5565
+ analysis_state.dca_all_tokens_final = (
5566
+ dca_all_tokens if use_dca else None
5567
+ )
5568
 
5569
  if not return_dict:
5570
  return tuple(
 
5705
  layers = None, # filled by NeoLLMModel.forward
5706
  jtokm_aux_stats = [] if cfg.use_jtokm else None,
5707
  attn_res_sources_final = [] if getattr(cfg, "use_attn_res", False) else None,
5708
+ dca_all_tokens_final = [] if getattr(cfg, "use_dca", False) else None,
5709
  )
5710
 
5711
  # ── Standard model API ────────────────────────────────────────────────
 
5843
  # ==================== AUTOMODEL REGISTRATION ====================
5844
 
5845
  __all__ = [
5846
+ "NeoLLMGatedDeltaNet",
5847
+ "StackMemory",
5848
+ "LAuReLLayer",
5849
+ "NeoLLMMUDDModule",
5850
+ "NeoLLMGRN",
5851
+ "dca_select_layers",
5852
  "NeoLLMForCausalLM",
5853
  "NeoLLMModel",
5854
  "NeoLLMPreTrainedModel",
 
5866
  "REPOModule",
5867
  "VersatileFFN",
5868
  "compute_versatile_aux_loss",
5869
+ # Analysis dataclasses
5870
  "AnalysisState",
5871
  "LayerAnalysis",
5872
  "AttentionAnalysis",
 
5880
  "VersatileFFNAnalysis",
5881
  "JTokMAnalysis",
5882
  "AttnResAnalysis",
5883
+ "DCAAnalysis",
5884
+ "StackMemoryAnalysis",
5885
+ "LAuReLAnalysis",
5886
  "GeneratorAnalysis",
5887
  ]
5888