KitsuVp committed on
Commit
b88fefe
·
verified ·
1 Parent(s): 761a91d

Upload modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +45 -1200
modeling_neollm.py CHANGED
@@ -80,7 +80,7 @@ from transformers.utils import TransformersKwargs, logging
80
  from configuration_neollm import NeoLLMConfig
81
 
82
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
83
-
84
  logger = logging.get_logger(__name__)
85
 
86
 
@@ -339,22 +339,6 @@ class JTokMAnalysis:
339
  lns_scale: Optional[float] = None # 1/√(2ℓ) scaling factor
340
 
341
 
342
- @dataclass
343
- class DCAAnalysis:
344
- """
345
- GRN-v3 depth-wise aggregate weights from a DeepCrossAttention layer.
346
- Only populated when use_dca=True.
347
-
348
- grn_depth_weights: softmax-free aggregate scalars used to weight each
349
- source layer, shape [3, y, B, S] where 3 = Q/K/V streams,
350
- y = selected stack depth (at most 2*dca_k), B = batch, S = seq.
351
- These are the per-position, per-layer scalars *before* adding the
352
- static bias — useful to see which layers the dynamic component
353
- selectively suppresses (ReLU zeros out negative entries).
354
- """
355
- grn_depth_weights: Optional[torch.Tensor] = None # [3, y, B, S]
356
-
357
-
358
  @dataclass
359
  class AttnResAnalysis:
360
  """
@@ -366,78 +350,6 @@ class AttnResAnalysis:
366
  sources_count: Optional[int] = None # number of sources including partial
367
 
368
 
369
- @dataclass
370
- class StackMemoryAnalysis:
371
- """
372
- Internals of a StackMemory forward pass.
373
- Only populated when use_stacktrans=True AND model is in eval + analysis mode.
374
-
375
- Reference: Zhang, K. et al. (NeurIPS 2025). "Recursive Transformer:
376
- Boosting Reasoning Ability with State Stack."
377
-
378
- action_probs: softmax distribution [push, pop, no-op] per head and
379
- token position. Shape [B, S, H, 3]. Visualising this
380
- across layers reveals the push-heavy early layers and
381
- pop-heavy later layers described in the paper (§B.2).
382
- stack_in: stack state entering this layer (the output of the
383
- previous layer's StackMemory). Shape [B, H, slots, ds].
384
- None for layer 0 (starts as all-zeros).
385
- stack_out: updated stack state after processing this sequence.
386
- Shape [B, H, slots, ds]. This is new_stack[:, -1] —
387
- the stack at the final sequence position, passed to
388
- the next layer as stack_in.
389
- mask_out: validity mask for stack_out. Shape [B, H, slots].
390
- Values near 1 indicate active slots; near 0 = empty.
391
- gate_weights: softmax attention weights used for global reading.
392
- Shape [B, S, H, slots]. High weight on slot i at
393
- position t means the model retrieved from slot i there.
394
- memory_output: weighted stack readout before up_proj.
395
- Shape [B, S, stack_d_model].
396
- residual_scale: value of the learnable res_weight scalar at this step.
397
- """
398
- action_probs: Optional[torch.Tensor] = None # [B,S,H,3]
399
- stack_in: Optional[torch.Tensor] = None # [B,H,slots,ds] entering layer
400
- stack_out: Optional[torch.Tensor] = None # [B,H,slots,ds] leaving layer
401
- mask_out: Optional[torch.Tensor] = None # [B,H,slots]
402
- gate_weights: Optional[torch.Tensor] = None # [B,S,H,slots]
403
- memory_output: Optional[torch.Tensor] = None # [B,S,stack_d_model]
404
- residual_scale: Optional[float] = None # res_weight scalar
405
-
406
-
407
- @dataclass
408
- class LAuReLAnalysis:
409
- """
410
- Internals of one LAuReL residual connection forward pass.
411
- Only populated when use_laurel=True AND model is in eval + analysis mode.
412
- Instantiated twice per layer: once for the attention residual, once for MLP.
413
-
414
- Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
415
- *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.
416
-
417
- Math (combined RW+LR, both sub-variants active):
418
-
419
- x_{i+1} = α · f(x_i) + β · (A·(B·x_i) + x_i)
420
-
421
- where [α, β] = softmax([a, b]), a,b ∈ ℝ learnable (RW component),
422
- B ∈ ℝ^{r×D} column-orthogonal init, A ∈ ℝ^{D×r} zero init (LR component).
423
- At step 0: A=0 → lr_term=0, so x_{i+1} = 0.5·f(x) + 0.5·x_i (RW only)
424
- or x_{i+1} = f(x_i) + x_i (LR only, standard residual).
425
-
426
- Fields:
427
- alpha_rw: softmax(a) — weight on f(x_i). [scalar float]
428
- None when use_laurel_rw=False.
429
- beta_rw: softmax(b) — weight on g(x_i). [scalar float]
430
- None when use_laurel_rw=False.
431
- lr_term: A·(B·x_res) — the low-rank residual augmentation.
432
- Shape [B, S, D]. Zero at init. None when use_laurel_lr=False.
433
- output: Final combined tensor before GPAS. Shape [B, S, D].
434
- """
435
- alpha_rw: Optional[float] = None # softmax weight on f(x)
436
- beta_rw: Optional[float] = None # softmax weight on g(x)
437
- lr_term: Optional[torch.Tensor] = None # A(Bx) low-rank augmentation [B,S,D]
438
- output: Optional[torch.Tensor] = None # combined pre-GPAS [B,S,D]
439
-
440
-
441
  @dataclass
442
  class LayerAnalysis:
443
  """
@@ -466,12 +378,8 @@ class LayerAnalysis:
466
  gpas_mlp: Optional[GPASAnalysis] = None # GPAS after MLP residual
467
 
468
  # Optional components (None when inactive)
469
- jtokm: Optional[JTokMAnalysis] = None # if use_jtokm
470
- attn_res: Optional[AttnResAnalysis] = None # if use_attn_res
471
- dca: Optional[DCAAnalysis] = None # if use_dca
472
- stack: Optional[StackMemoryAnalysis] = None # if use_stacktrans
473
- laurel_attn: Optional[LAuReLAnalysis] = None # if use_laurel (attention residual)
474
- laurel_mlp: Optional[LAuReLAnalysis] = None # if use_laurel (MLP residual)
475
 
476
 
477
  @dataclass
@@ -536,7 +444,6 @@ class AnalysisState:
536
  layers: Optional[List[LayerAnalysis]] = None
537
  jtokm_aux_stats: Optional[list] = None
538
  attn_res_sources_final: Optional[list] = None
539
- dca_all_tokens_final: Optional[list] = None
540
  logits: Optional[torch.Tensor] = None
541
 
542
  class ScalarMultiplier(nn.Module):
@@ -2456,8 +2363,6 @@ class NeoLLMAttention(nn.Module):
2456
  first_layer_fan: Optional[torch.Tensor] = None,
2457
  attn_analysis: Optional[AttentionAnalysis] = None,
2458
  repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
2459
- mudd_xk: Optional[torch.Tensor] = None,
2460
- mudd_xv: Optional[torch.Tensor] = None,
2461
  **kwargs: Unpack[FlashAttentionKwargs],
2462
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
2463
  input_shape = hidden_states.shape[:-1]
@@ -2468,14 +2373,6 @@ class NeoLLMAttention(nn.Module):
2468
  h_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan
2469
  current_layer_fan = h_fan.clone()
2470
 
2471
- # ── MUDD: separate K/V FAN paths ─────────────────────────────────
2472
- # When mudd_xk/mudd_xv are provided (MUDD qkvr mode), they have already
2473
- # been normalized by the decoder layer's K/V norm chain. Here they go
2474
- # through their own FAN transform before k_proj/v_proj, keeping the
2475
- # FANformer periodicity modeling orthogonally intact per stream.
2476
- h_fan_k = self.fan_layer(mudd_xk) if mudd_xk is not None else h_fan
2477
- h_fan_v = self.fan_layer(mudd_xv) if mudd_xv is not None else h_fan
2478
-
2479
  query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
2480
  kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)
2481
 
@@ -2490,8 +2387,8 @@ class NeoLLMAttention(nn.Module):
2490
  attn_analysis.gate_raw = gate.detach()
2491
 
2492
  q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2)
2493
- k = self.k_norm(self.k_proj(h_fan_k).view(kv_shape)).transpose(1, 2)
2494
- v = self.v_proj(h_fan_v).view(kv_shape).transpose(1, 2)
2495
 
2496
  if attn_analysis is not None:
2497
  attn_analysis.q_post_norm = q.detach()
@@ -3168,651 +3065,6 @@ class NeoLLMMLP(nn.Module):
3168
  return result
3169
 
3170
 
3171
- class NeoLLMMUDDModule(nn.Module):
3172
- """
3173
- Multiway Dynamic Dense (MUDD) Depth-wise Aggregate module.
3174
-
3175
- Generates per-position, per-stream connection weights over all preceding
3176
- layer outputs (and the token embedding) and produces up to C=4 aggregated
3177
- streams (Q, K, V, R) for the next Transformer block.
3178
-
3179
- Architecture (Xiao et al., 2025, arXiv:2502.12170):
3180
- dw = GELU(RMSNorm(x) @ W1) @ W2 + a # [B, T, C*(lidx+2)]
3181
- dw = reshape to [C, B, T, (lidx+2)]
3182
- stream_c = Σ_j dw[c, :, :, j] * hiddens[j] for c in range(C)
3183
-
3184
- W1 ~ N(0, 1/D), W2 = 0, a = identity on last index → reduces to standard
3185
- Transformer at init (dynamic part is zero, static bias selects Xi).
3186
-
3187
- Args:
3188
- hidden_size: model dimension D
3189
- lidx: layer index (0-based); history has lidx+2 entries
3190
- num_ways: C, number of output streams (4 for "qkvr", 1 for "l")
3191
- is_last: whether this is the last layer (controls expand_last)
3192
- expand_last: multiply hid_dim by 4 for the final layer's DA module
3193
- round64: round hid_dim up to the nearest multiple of 64
3194
- """
3195
-
3196
- def __init__(
3197
- self,
3198
- hidden_size: int,
3199
- lidx: int,
3200
- num_ways: int = 4,
3201
- is_last: bool = False,
3202
- expand_last: bool = False,
3203
- round64: bool = False,
3204
- ) -> None:
3205
- super().__init__()
3206
- self.lidx = lidx
3207
- self.num_ways = num_ways
3208
- l = lidx + 2 # history length: embedding + lidx layers
3209
- hid_dim = l * num_ways
3210
- out_dim = l * num_ways
3211
- if is_last and expand_last:
3212
- hid_dim *= 4
3213
- if round64:
3214
- hid_dim = (hid_dim // 64 + 1) * 64
3215
- # RMSNorm without learnable scale (paper uses RMSnormNoscale)
3216
- self.norm = nn.RMSNorm(hidden_size, elementwise_affine=False,
3217
- eps=1e-6)
3218
- self.w1 = nn.Linear(hidden_size, hid_dim, bias=False)
3219
- self.act = nn.GELU()
3220
- self.w2 = nn.Linear(hid_dim, out_dim, bias=False)
3221
- self._reset_mudd_parameters(hidden_size)
3222
-
3223
- def _reset_mudd_parameters(self, D: int) -> None:
3224
- # W1 ~ N(0, 1/D); W2 = 0 → dynamic part starts at zero
3225
- nn.init.normal_(self.w1.weight, mean=0.0, std=1.0 / D)
3226
- nn.init.zeros_(self.w2.weight)
3227
-
3228
- def forward(
3229
- self,
3230
- x: torch.Tensor, # [B, T, D] — current layer output (Xi)
3231
- hiddens: list, # list of lidx+2 tensors [B, T, D]
3232
- static_bias: torch.Tensor, # [C, lidx+2] — learnable static prior
3233
- ) -> tuple:
3234
- """
3235
- Returns:
3236
- Tuple of num_ways tensors, each [B, T, D] — the aggregated streams.
3237
- """
3238
- B, T, D = x.shape
3239
- # Dynamic weight generation: [B, T, C*(lidx+2)]
3240
- dw = self.w2(self.act(self.w1(self.norm(x))))
3241
- # Add static bias (broadcast over B and T)
3242
- # static_bias: [C, L] → [1, 1, C*L] via reshape
3243
- C, L = static_bias.shape
3244
- dw = dw + static_bias.reshape(1, 1, C * L).to(dw.dtype)
3245
- # Reshape to [C, B, T, L]
3246
- dw = dw.view(B, T, C, L).permute(2, 0, 1, 3) # [C, B, T, L]
3247
- # Stack history: [L, B, T, D]
3248
- stacked = torch.stack(hiddens, dim=0) # [L, B, T, D]
3249
- # Aggregate: Σ_j dw[c, :, :, j] * hiddens[j]
3250
- # einsum "cbtl, lbtd -> cbtd"
3251
- streams = torch.einsum('cbtl,lbtd->cbtd', dw, stacked) # [C, B, T, D]
3252
- return tuple(streams[c] for c in range(C))
3253
-
3254
-
3255
- def dca_select_layers(stacked: torch.Tensor, k: int) -> torch.Tensor:
3256
- """
3257
- k-DCA layer selection (Heddes et al., 2025, §3.1).
3258
-
3259
- Keeps only the first k and last k tensors from the depth stack,
3260
- capping memory at 2k layer representations regardless of depth.
3261
- When the stack has <= 2k entries all are kept (early layers).
3262
-
3263
- Args:
3264
- stacked: [y, B, S, D] — stack of all layer outputs so far.
3265
- k: number of first/last layers to retain.
3266
- Returns:
3267
- [min(y, 2k), B, S, D]
3268
- """
3269
- y = stacked.shape[0]
3270
- if y <= k * 2:
3271
- return stacked
3272
- return torch.cat([stacked[:k], stacked[-k:]], dim=0)
3273
-
3274
-
3275
- class NeoLLMGRN(nn.Module):
3276
- """
3277
- Generalized Residual Network v3 (GRN-v3) from DeepCrossAttention
3278
- (Heddes et al., 2025, arXiv:2502.06785, §3.1).
3279
-
3280
- Produces `num_outputs` aggregated streams from a depth-wise stack of
3281
- layer representations. Weights are simultaneously:
3282
-
3283
- - **Input-dependent** (dynamic): a two-layer mapping
3284
- ``w̄ = ReLU(RMSNorm(G) @ W)`` produces one scalar per
3285
- (output-stream, depth-position, batch-token). ``W`` is initialized
3286
- to zero so the dynamic contribution starts neutral.
3287
- - **Dimension-dependent** (static): a learnable bias ``b`` of shape
3288
- ``[num_outputs, num_stack_layers, hidden_size]`` initialized to ones
3289
- provides a per-dimension, per-layer prior. At initialization the
3290
- dynamic part is zero and the static bias sums to an equal-weight
3291
- average over all stack entries, reducing to a standard residual mean.
3292
-
3293
- The combined weight for output stream ``o``, stack position ``y``,
3294
- batch ``b``, token ``n``, feature ``d`` is::
3295
-
3296
- weight[o, y, b, n, d] = ReLU(dynamic[y, b, n, o]) + bias[o, y, d]
3297
-
3298
- Output ``o`` is then the weighted sum over depth::
3299
-
3300
- out[o, b, n, d] = Σ_y stack[y, b, n, d] * weight[o, y, b, n, d]
3301
-
3302
- Reference:
3303
- Heddes, M. et al. (2025). *DeepCrossAttention: Supercharging
3304
- Transformer Residual Connections.* arXiv:2502.06785.
3305
-
3306
- Args:
3307
- hidden_size: model dimension D.
3308
- num_stack_layers: number of depth entries this GRN will receive
3309
- (= min(layer_idx+1, 2*dca_k)).
3310
- num_outputs: number of output streams (3 for DCA Q/K/V,
3311
- 1 for the final aggregation GRN).
3312
- eps: epsilon for the internal RMSNorm.
3313
- """
3314
-
3315
- def __init__(
3316
- self,
3317
- hidden_size: int,
3318
- num_stack_layers: int,
3319
- num_outputs: int = 3,
3320
- eps: float = 1e-6,
3321
- ) -> None:
3322
- super().__init__()
3323
- self.num_outputs = num_outputs
3324
- self.num_stack_layers = num_stack_layers
3325
-
3326
- # Dynamic component: RMSNorm(no scale) → Linear → ReLU
3327
- # Linear maps D → num_outputs; init zeros so dynamic part = 0 at step 0.
3328
- _linear = nn.Linear(hidden_size, num_outputs, bias=False)
3329
- nn.init.zeros_(_linear.weight)
3330
- self.norm_noscale = nn.RMSNorm(
3331
- hidden_size, eps=eps, elementwise_affine=False
3332
- )
3333
- self.to_dynamic = nn.Sequential(_linear, nn.ReLU())
3334
-
3335
- # Static bias: [num_outputs, num_stack_layers, hidden_size], init ones.
3336
- # At init: weight = 0 + bias = 1 per entry → equal-weight average → residual.
3337
- self.bias = nn.Parameter(
3338
- torch.ones(num_outputs, num_stack_layers, hidden_size)
3339
- )
3340
-
3341
- def forward(
3342
- self,
3343
- stack: torch.Tensor,
3344
- analysis: Optional["DCAAnalysis"] = None,
3345
- ) -> tuple:
3346
- """
3347
- Args:
3348
- stack: [y, B, S, D] — selected depth stack (y ≤ 2*dca_k).
3349
- analysis: optional DCAAnalysis to deposit grn_depth_weights.
3350
- Returns:
3351
- Tuple of num_outputs tensors each [B, S, D].
3352
- When num_outputs=1 returns a single [B, S, D] tensor directly.
3353
- """
3354
- y, B, S, D = stack.shape
3355
- assert y == self.num_stack_layers, (
3356
- f"NeoLLMGRN expected stack depth {self.num_stack_layers}, got {y}"
3357
- )
3358
-
3359
- # Dynamic aggregate: [y, B, S, D] → norm → [y, B, S, D]
3360
- # → to_dynamic → [y, B, S, num_outputs]
3361
- # → permute → [num_outputs, y, B, S]
3362
- normed = self.norm_noscale(stack) # [y, B, S, D]
3363
- dynamic = self.to_dynamic(normed) # [y, B, S, num_outputs]
3364
- dynamic = dynamic.permute(3, 0, 1, 2) # [o, y, B, S]
3365
-
3366
- if analysis is not None:
3367
- analysis.grn_depth_weights = dynamic.detach()
3368
-
3369
- # Combined weight: dynamic scalar + static bias per dimension
3370
- # dynamic: [o, y, B, S] → [o, y, B, S, 1]
3371
- # bias: [o, y, D] → [o, y, 1, 1, D]
3372
- weights = dynamic.unsqueeze(-1) + self.bias.unsqueeze(2).unsqueeze(3)
3373
- # weights: [o, y, B, S, D]
3374
-
3375
- # Weighted depth-sum: Σ_y stack[y] * weights[o, y]
3376
- # stack: [y, B, S, D] → [1, y, B, S, D]
3377
- output = (stack.unsqueeze(0) * weights).sum(dim=1) # [o, B, S, D]
3378
-
3379
- if self.num_outputs == 1:
3380
- return output.squeeze(0) # [B, S, D]
3381
- return tuple(output[i] for i in range(self.num_outputs))
3382
-
3383
-
3384
- class StackMemory(nn.Module):
3385
- """
3386
- Differentiable multi-head hidden-state stack for NeoLLM.
3387
-
3388
- Implements the StackTrans module from Zhang et al. (NeurIPS 2025):
3389
- "Recursive Transformer: Boosting Reasoning Ability with State Stack."
3390
-
3391
- Architecture (one forward call, covering the full sequence in parallel):
3392
-
3393
- 1. down_proj : [B,S,D] → [B,S,stack_d_model]
3394
- 2. action_head: → [B,S,H,3] softmax (push / pop / no-op)
3395
- 3. k_values : reshape to [B,S,H,ds]
3396
- 4. _vectorized_update: applies soft push/pop/no-op to each
3397
- (batch, head) stack in parallel across the sequence dim.
3398
- This is the training-parallelism approximation from §3.3:
3399
- every token sees the *same* initial stack, breaking strict
3400
- temporal ordering within a sequence in exchange for full
3401
- data-parallelism. Cross-token memory is recovered during
3402
- autoregressive generation via the step() / enable_cache path.
3403
- 5. gate_proj : global read — softmax over all stack slots
3404
- (paper §3.1: "query-over-stack attention"), masked by the
3405
- validity mask. Returns weighted sum of the stack.
3406
- 6. up_proj : [B,S,stack_d_model] → [B,S,D]
3407
- 7. residual : output = up_proj_out * res_weight + hidden_states
3408
-
3409
- Vertical passing (layer-to-layer):
3410
- Returns new_stack[:, -1] and new_mask[:, -1] — the stack state
3411
- at the last sequence position β€” which becomes the initial stack
3412
- for the next decoder layer. This propagates hierarchical context
3413
- depth-wise through the network.
3414
-
3415
- Temporal accumulation (generation):
3416
- During autoregressive decoding, enable_cache=True and step() is
3417
- used: k_cache and action_cache store previous-token values so the
3418
- update equation integrates the full generated history rather than
3419
- starting from zeros each step.
3420
-
3421
- Args:
3422
- config: NeoLLMConfig instance. Reads:
3423
- stacktrans_num_heads (H, number of stack heads)
3424
- stacktrans_stack_slots (S, stack depth)
3425
- stacktrans_stack_d_model (H×ds, low-rank dimension)
3426
- stacktrans_forward_bs (batch size for cache buffers)
3427
- """
3428
-
3429
- def __init__(self, config: NeoLLMConfig):
3430
- super().__init__()
3431
- self.num_stack_heads = config.stacktrans_num_heads
3432
- self.stack_slots = config.stacktrans_stack_slots
3433
- self.stack_d_model = config.stacktrans_stack_d_model
3434
- self.head_dim = self.stack_d_model // self.num_stack_heads
3435
-
3436
- # Dimension reduction / expansion (standard nn.Linear, no multipliers —
3437
- # StackMemory is architecturally independent per the paper §A)
3438
- self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=True)
3439
- self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=True)
3440
-
3441
- # Action prediction: push / pop / no-op probabilities, one triple per head
3442
- self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
3443
-
3444
- # Global read query: one scalar gate per stack slot per head
3445
- self.gate_proj = nn.Linear(self.head_dim, 1, bias=True)
3446
-
3447
- # Learnable residual gate (paper h'_t = g_h·h_t + R_t, g_h scalar)
3448
- self.res_weight = nn.Parameter(torch.ones(1))
3449
-
3450
- # ── Autoregressive generation cache ──────────────────────────────
3451
- # k_cache and action_cache hold per-token values from previous steps
3452
- # so step() can reconstruct the full sequence history. Only used when
3453
- # enable_cache=True (set by NeoLLMModel.forward when use_cache=True).
3454
- _fbs = getattr(config, "stacktrans_forward_bs", 1)
3455
- _cs = getattr(config, "cache_size", 2048)
3456
- self.register_buffer(
3457
- "k_cache",
3458
- torch.zeros(_fbs, _cs, self.num_stack_heads, self.head_dim),
3459
- )
3460
- self.register_buffer(
3461
- "action_cache",
3462
- torch.zeros(_fbs, _cs, self.num_stack_heads, 3),
3463
- )
3464
- self.cache_position = 0
3465
- self.enable_cache = False
3466
-
3467
- # ── Cache helpers ─────────────────────────────────────────────────────
3468
-
3469
- def reset_cache(self) -> None:
3470
- self.cache_position = 0
3471
-
3472
- def _update_cache(
3473
- self,
3474
- k_values: torch.Tensor, # [B,S,H,ds] detached
3475
- actions: torch.Tensor, # [B,S,H,3] detached
3476
- ) -> None:
3477
- seq_len = k_values.shape[1]
3478
- if self.cache_position + seq_len <= self.k_cache.shape[1]:
3479
- self.k_cache [:, self.cache_position:self.cache_position + seq_len] = k_values
3480
- self.action_cache[:, self.cache_position:self.cache_position + seq_len] = actions
3481
- self.cache_position += seq_len
3482
- else:
3483
- self.reset_cache()
3484
-
3485
- # ── Core stack update ─────────────────────────────────────────────────
3486
-
3487
- def _vectorized_update(
3488
- self,
3489
- stack: torch.Tensor, # [B, H, slots, ds] (4-D) or [B,S,H,slots,ds] (5-D)
3490
- mask: torch.Tensor, # [B, H, slots] (3-D) or [B,S,H,slots] (4-D)
3491
- actions: torch.Tensor, # [B, S, H, 3]
3492
- k_values: torch.Tensor, # [B, S, H, ds]
3493
- ) -> Tuple[torch.Tensor, torch.Tensor]:
3494
- """
3495
- Vectorized soft push/pop/no-op stack update.
3496
-
3497
- Every token position receives the *same* initial stack (the one
3498
- passed in from the previous layer), and operations are applied in
3499
- parallel across S. This is the §3.3 training-parallelism
3500
- approximation: strict sequential dependency within a sequence is
3501
- broken intentionally to allow full batch processing.
3502
-
3503
- Returns:
3504
- new_stack [B, S, H, slots, ds]
3505
- new_mask [B, S, H, slots]
3506
- """
3507
- batch_size, seq_len = actions.shape[:2]
3508
-
3509
- # Broadcast 4-D initial state along the sequence dimension
3510
- if stack.dim() == 4:
3511
- stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
3512
- mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)
3513
-
3514
- # Push: new value at top, shift everything down (overflow discarded)
3515
- push_stack = torch.cat([k_values.unsqueeze(3), stack[:, :, :, :-1]], dim=3)
3516
- push_mask = torch.cat([torch.ones_like(mask[:, :, :, :1]),
3517
- mask[:, :, :, :-1]], dim=3)
3518
-
3519
- # Pop: shift everything up, zero at bottom
3520
- pop_stack = torch.cat([stack[:, :, :, 1:],
3521
- torch.zeros_like(stack[:, :, :, :1])], dim=3)
3522
- pop_mask = torch.cat([mask[:, :, :, 1:],
3523
- torch.zeros_like(mask[:, :, :, :1])], dim=3)
3524
-
3525
- # Soft combination weighted by action probabilities
3526
- # actions: [B,S,H,3] → unsqueeze to [B,S,H,3,1,1] for stack broadcast
3527
- aw = actions.unsqueeze(-1).unsqueeze(-1) # [B,S,H,3,1,1]
3528
- stacks = torch.stack([push_stack, pop_stack, stack], dim=3) # [B,S,H,3,slots,ds]
3529
- masks = torch.stack([push_mask, pop_mask, mask], dim=3) # [B,S,H,3,slots]
3530
-
3531
- new_stack = (stacks * aw).sum(dim=3) # [B,S,H,slots,ds]
3532
- new_mask = (masks * aw.squeeze(-1)).sum(dim=3) # [B,S,H,slots]
3533
- return new_stack, new_mask
3534
-
3535
- # ── Training forward (full sequence) ─────────────────────────────────
3536
-
3537
- def forward(
3538
- self,
3539
- hidden_states: torch.Tensor,
3540
- stack: Optional[torch.Tensor] = None,
3541
- mask: Optional[torch.Tensor] = None,
3542
- analysis: Optional[StackMemoryAnalysis] = None,
3543
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
3544
- """
3545
- Full-sequence forward pass (training and prefill).
3546
-
3547
- Args:
3548
- hidden_states: [B, S, D]
3549
- stack: [B, H, slots, ds] — previous layer's stack state,
3550
- or None (initialised to zeros for layer 0).
3551
- mask: [B, H, slots] — validity mask for stack,
3552
- or None (initialised to zeros for layer 0).
3553
- analysis: StackMemoryAnalysis container; populated when
3554
- model is in eval + analysis mode.
3555
-
3556
- Returns:
3557
- (output, new_stack, new_mask)
3558
- output [B, S, D]
3559
- new_stack [B, H, slots, ds] — stack at final sequence position
3560
- new_mask [B, H, slots]
3561
- """
3562
- batch_size, seq_len, _ = hidden_states.shape
3563
- device = hidden_states.device
3564
-
3565
- # Capture incoming stack for analysis before it is updated
3566
- if analysis is not None:
3567
- analysis.stack_in = stack.detach() if stack is not None else None
3568
-
3569
- # Initialise empty stack / mask for layer 0
3570
- if stack is None:
3571
- stack = torch.zeros(
3572
- batch_size, self.num_stack_heads, self.stack_slots, self.head_dim,
3573
- device=device, dtype=hidden_states.dtype,
3574
- )
3575
- if mask is None:
3576
- mask = torch.zeros(
3577
- batch_size, self.num_stack_heads, self.stack_slots,
3578
- device=device, dtype=hidden_states.dtype,
3579
- )
3580
-
3581
- # 1. Project down
3582
- h_proj = self.down_proj(hidden_states) # [B,S,stack_d_model]
3583
-
3584
- # 2. Action probabilities
3585
- action_logits = self.action_head(h_proj) / math.sqrt(self.head_dim)
3586
- actions = F.softmax(
3587
- action_logits.view(batch_size, seq_len, self.num_stack_heads, 3), dim=-1
3588
- ) # [B,S,H,3]
3589
-
3590
- # 3. Values to push
3591
- k_values = h_proj.view(batch_size, seq_len, self.num_stack_heads, self.head_dim)
3592
-
3593
- # 4. Vectorized stack update
3594
- new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
3595
- # new_stack: [B,S,H,slots,ds], new_mask: [B,S,H,slots]
3596
-
3597
- # 5. Global read (query-over-stack attention, paper §3.1)
3598
- gate_scores = self.gate_proj(new_stack).squeeze(-1) # [B,S,H,slots]
3599
- gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
3600
- memory_out = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
3601
- # memory_out: [B,S,H,ds] → [B,S,stack_d_model]
3602
- memory_out = memory_out.view(batch_size, seq_len, self.stack_d_model)
3603
-
3604
- # 6. Project back up
3605
- memory_out_proj = self.up_proj(memory_out) # [B,S,D]
3606
-
3607
- # 7. Residual
3608
- output = memory_out_proj * self.res_weight + hidden_states
3609
-
3610
- # 8. Update generation cache (no-op during training)
3611
- if self.enable_cache:
3612
- self._update_cache(k_values.detach(), actions.detach())
3613
-
3614
- # Populate analysis fields
3615
- if analysis is not None:
3616
- analysis.action_probs = actions.detach()
3617
- analysis.stack_out = new_stack[:, -1].detach()
3618
- analysis.mask_out = new_mask[:, -1].detach()
3619
- analysis.gate_weights = gate_weights.detach()
3620
- analysis.memory_output = memory_out.detach()
3621
- analysis.residual_scale = self.res_weight.item()
3622
-
3623
- # Return output + last-position stack state for next layer
3624
- return output, new_stack[:, -1], new_mask[:, -1]
3625
-
3626
- # ── Autoregressive single-token forward ──────────────────────────────
3627
-
3628
- def step(
3629
- self,
3630
- hidden_state: torch.Tensor, # [B, D]
3631
- stack: torch.Tensor, # [B, H, slots, ds]
3632
- mask: torch.Tensor, # [B, H, slots]
3633
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
3634
- """
3635
- Single-token forward for autoregressive generation.
3636
-
3637
- When enable_cache=False (simple path used by NeoLLM generation):
3638
- Calls forward() with a length-1 sequence and unpacks the result.
3639
- The stack state passed in carries all history from previous tokens
3640
- (propagated by NeoLLMModel.forward across generation steps).
3641
-
3642
- When enable_cache=True (full-history reconstruction path):
3643
- Concatenates the current token with cached previous-token values
3644
- and replays the full vectorized update, extracting only the last
3645
- position. This gives a more accurate stack that sees full history
3646
- at the cost of O(T) computation per step.
3647
-
3648
- Returns:
3649
- (output, new_stack, new_mask)
3650
- output [B, D]
3651
- new_stack [B, H, slots, ds]
3652
- new_mask [B, H, slots]
3653
- """
3654
- if not self.enable_cache:
3655
- # Simple path: forward with seq_len=1, squeeze the sequence dim
3656
- out, new_stack, new_mask = self.forward(
3657
- hidden_state.unsqueeze(1), stack, mask
3658
- )
3659
- return out.squeeze(1), new_stack, new_mask
3660
-
3661
- batch_size = hidden_state.shape[0]
3662
-
3663
- # Compute features for the current token
3664
- h_proj = self.down_proj(hidden_state) # [B, stack_d_model]
3665
- a_logits = self.action_head(h_proj) / math.sqrt(self.head_dim)
3666
- cur_act = F.softmax(
3667
- a_logits.view(batch_size, 1, self.num_stack_heads, 3), dim=-1
3668
- ) # [B,1,H,3]
3669
- cur_k = h_proj.view(batch_size, 1, self.num_stack_heads, self.head_dim)
3670
-
3671
- # Prepend cached history (all previous tokens in this generation)
3672
- if self.cache_position > 0:
3673
- k_values = torch.cat([self.k_cache[:batch_size, :self.cache_position], cur_k], dim=1)
3674
- actions = torch.cat([self.action_cache[:batch_size, :self.cache_position], cur_act], dim=1)
3675
- else:
3676
- k_values = cur_k
3677
- actions = cur_act
3678
-
3679
- # Full vectorized update over history + current token; take last position
3680
- new_stack_seq, new_mask_seq = self._vectorized_update(stack, mask, actions, k_values)
3681
- new_stack = new_stack_seq[:, -1] # [B,H,slots,ds]
3682
- new_mask = new_mask_seq[:, -1] # [B,H,slots]
3683
-
3684
- # Global read on the new stack state
3685
- gate_scores = self.gate_proj(new_stack).squeeze(-1) # [B,H,slots]
3686
- gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
3687
- memory_out = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=2)
3688
- memory_out = memory_out.view(batch_size, self.stack_d_model)
3689
-
3690
- memory_out_proj = self.up_proj(memory_out) # [B,D]
3691
- output = memory_out_proj * self.res_weight + hidden_state
3692
-
3693
- self._update_cache(cur_k, cur_act)
3694
- return output, new_stack, new_mask
3695
-
3696
-
3697
- @dataclass
3698
- class LAuReLLayer(nn.Module):
3699
- """
3700
- LAuReL: Learned Augmented Residual Layer.
3701
-
3702
- A lightweight replacement for the canonical residual connection
3703
- that learns to blend the nonlinear sub-layer output f(x) with a
3704
- richer linear function of the residual x, optionally augmented by a
3705
- low-rank transformation.
3706
-
3707
- Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
3708
- *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.
3709
-
3710
- ── Sub-variants ────────────────────────────────────────────────────
3711
- Controlled by config flags; any combination is valid:
3712
-
3713
- **RW only** (use_laurel_rw=True, use_laurel_lr=False):
3714
-
3715
- x_{i+1} = α · f(x_i) + β · x_i
3716
- [α, β] = softmax([a, b]), a,b ∈ ℝ (2 params)
3717
-
3718
- **LR only** (use_laurel_rw=False, use_laurel_lr=True):
3719
-
3720
- x_{i+1} = f(x_i) + A·(B·x_i) + x_i
3721
- B ∈ ℝ^{r×D} column-orthogonal init (down-projection)
3722
- A ∈ ℝ^{D×r} zero init (up-projection)
3723
- Params: 2·r·D per layer.
3724
-
3725
- **RW + LR** (both True, paper recommendation):
3726
-
3727
- x_{i+1} = α · f(x_i) + β · (A·(B·x_i) + x_i)
3728
-
3729
- ── Initialisation ──────────────────────────────────────────────────
3730
- RW: raw logits [a, b] = [0, 0] → α=β=0.5 at step 0.
3731
- LR: A (up) = zeros → lr_term = 0 at step 0 → pure residual at init.
3732
- This ensures the model starts as a standard residual and smoothly
3733
- diverges as the gates and low-rank matrices are trained.
3734
-
3735
- ── Integration in NeoLLM ───────────────────────────────────────────
3736
- Applied immediately before GPAS at both residual sums per layer:
3737
-
3738
- h_tilde = GPAS( LAuReL(attn_out, residual_attn) )
3739
- output = GPAS( LAuReL(delta_m, residual_mlp) )
3740
-
3741
- GPAS then applies its stop-gradient scaling on the combined stream,
3742
- preserving gradient magnitudes across the depth of the network.
3743
- The two techniques are structurally orthogonal: LAuReL controls the
3744
- *mixing ratio* of f(x) and x at each residual junction; GPAS
3745
- controls the *magnitude* of the combined stream with a learned gate
3746
- and a stop-gradient operator that prevents gradient vanishing.
3747
-
3748
- Args:
3749
- config: NeoLLMConfig. Reads use_laurel_rw, use_laurel_lr,
3750
- laurel_lr_rank, hidden_size.
3751
- """
3752
-
3753
- def __init__(self, config: NeoLLMConfig):
3754
- super().__init__()
3755
- self.use_rw = getattr(config, "use_laurel_rw", True)
3756
- self.use_lr = getattr(config, "use_laurel_lr", True)
3757
- D = config.hidden_size
3758
- r = getattr(config, "laurel_lr_rank", 32)
3759
-
3760
- if self.use_rw:
3761
- # Raw logits for softmax([α, β]).
3762
- # Stored as a single 2-vector so softmax is one op.
3763
- # Init to zero → α=β=0.5 at step 0.
3764
- self.rw_logits = nn.Parameter(torch.zeros(2))
3765
-
3766
- if self.use_lr:
3767
- # down: B ∈ ℝ^{r×D}, column-orthogonal init (paper §3.3 LLM recommendation)
3768
- # up: A ∈ ℝ^{D×r}, zero init → lr_term=0 at step 0 (LoRA-style)
3769
- self.lr_down = nn.Linear(D, r, bias=False)
3770
- self.lr_up = nn.Linear(r, D, bias=False)
3771
-
3772
- def forward(
3773
- self,
3774
- f_out: torch.Tensor, # output of f(x): attn or MLP [B,S,D]
3775
- x_res: torch.Tensor, # residual (skip connection) [B,S,D]
3776
- analysis: Optional[LAuReLAnalysis] = None,
3777
- ) -> torch.Tensor:
3778
- """
3779
- Args:
3780
- f_out: Output of f(x) — attention output or MLP delta.
3781
- x_res: Residual tensor — accumulated hidden state.
3782
- analysis: Optional analysis container; populated in eval+analysis mode.
3783
-
3784
- Returns:
3785
- Combined tensor [B, S, D] to be fed into GPAS.
3786
- """
3787
- # ── LR component: AΒ·(BΒ·x_res) ────────────────────────────────────
3788
- lr_term = None
3789
- if self.use_lr:
3790
- lr_term = self.lr_up(self.lr_down(x_res)) # [B,S,D]
3791
- g_res = lr_term + x_res # enriched residual
3792
- else:
3793
- g_res = x_res
3794
-
3795
- # ── RW component: Ξ±Β·f + Ξ²Β·g ──────────────────────────────────────
3796
- if self.use_rw:
3797
- weights = torch.softmax(self.rw_logits, dim=0) # [2]
3798
- alpha = weights[0]
3799
- beta = weights[1]
3800
- out = alpha * f_out + beta * g_res
3801
- else:
3802
- # LR only: standard sum with enriched residual
3803
- out = f_out + g_res
3804
-
3805
- if analysis is not None:
3806
- if self.use_rw:
3807
- analysis.alpha_rw = alpha.item()
3808
- analysis.beta_rw = beta.item()
3809
- if self.use_lr:
3810
- analysis.lr_term = lr_term.detach()
3811
- analysis.output = out.detach()
3812
-
3813
- return out
3814
-
3815
-
3816
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3817
  """
3818
  Decoder layer with standard residual connections, optional JTok-M injection.
@@ -3868,78 +3120,10 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3868
  self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size))
3869
  self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size))
3870
  self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
3871
- _num_blocks = getattr(config, 'attn_res_num_blocks', 0)
3872
- self.attn_res_block_size = (
3873
- max(config.num_hidden_layers // _num_blocks, 1) if _num_blocks > 0 else 1
3874
- )
3875
  else:
3876
  self.attn_res_query_attn = None
3877
  self.attn_res_query_mlp = None
3878
  self.attn_res_norm = None
3879
- self.attn_res_block_size = None
3880
-
3881
- # ── MUDD: separate K/V LayerNorms for qkvr+sepln mode ──────────────
3882
- # Only instantiated when both mudd_dense_type='qkvr' AND mudd_sepln=True.
3883
- # The existing input_layernorm handles the Q stream (unchanged).
3884
- # Separate norms for K and V allow each stream to rescale independently.
3885
- _use_mudd = getattr(config, 'use_mudd', False)
3886
- _mudd_qkvr = getattr(config, 'mudd_dense_type', 'qkvr') == 'qkvr'
3887
- _mudd_sepln = getattr(config, 'mudd_sepln', False)
3888
- if _use_mudd and _mudd_qkvr and _mudd_sepln:
3889
- self.mudd_k_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
3890
- self.mudd_v_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
3891
- else:
3892
- self.mudd_k_norm = None
3893
- self.mudd_v_norm = None
3894
-
3895
- # ── DCA (Heddes et al., 2025, arXiv:2502.06785) ───────────────────
3896
- # GRN-v3 module that aggregates the k-selected depth stack into 3
3897
- # independent streams (Q, K, V). Each stream has its own dimension-
3898
- # and input-dependent weights, enabling richer cross-layer interactions.
3899
- # K and V get their own SeeDNorm + LNS norm chain (same scheme as
3900
- # MUDD sepln) since they now arrive from a different aggregation path.
3901
- # The residual connection uses the Q stream output (xq) as its base,
3902
- # matching the DCA paper's decoder block design (residual = q_input).
3903
- self.use_dca = getattr(config, 'use_dca', False)
3904
- if self.use_dca:
3905
- _dca_k = getattr(config, 'dca_k', 2)
3906
- _num_stack = min(layer_idx + 1, 2 * _dca_k)
3907
- self.dca_grn = NeoLLMGRN(
3908
- hidden_size = config.hidden_size,
3909
- num_stack_layers = _num_stack,
3910
- num_outputs = 3,
3911
- eps = config.rms_norm_eps,
3912
- )
3913
- self.dca_k_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
3914
- self.dca_v_norm = SeeDNorm(config.hidden_size, eps=config.rms_norm_eps)
3915
- else:
3916
- self.dca_grn = None
3917
- self.dca_k_norm = None
3918
- self.dca_v_norm = None
3919
-
3920
- # ── StackTrans (Zhang et al., NeurIPS 2025) ───────────────────────
3921
- # Differentiable multi-head hidden-state stack inserted at the very
3922
- # beginning of the layer forward, before the attention sublayer.
3923
- # Mutually exclusive with use_attn_res, use_mudd, use_dca.
3924
- self.use_stacktrans = getattr(config, 'use_stacktrans', False)
3925
- if self.use_stacktrans:
3926
- self.stack_memory = StackMemory(config)
3927
- else:
3928
- self.stack_memory = None
3929
-
3930
- # ── LAuReL (Menghani, Kumar & Kumar, ICML 2025) ───────────────────
3931
- # Learned augmented residual connection replacing f(x)+x at both
3932
- # the attention and MLP residual sums. Applied immediately before
3933
- # GPAS, so GPAS still controls magnitude via stop-gradient scaling.
3934
- # Two independent instances per layer (attention and MLP).
3935
- # Compatible with use_stacktrans. Incompatible with MUDD/DCA/AttnRes.
3936
- self.use_laurel = getattr(config, 'use_laurel', False)
3937
- if self.use_laurel:
3938
- self.laurel_attn = LAuReLLayer(config)
3939
- self.laurel_mlp = LAuReLLayer(config)
3940
- else:
3941
- self.laurel_attn = None
3942
- self.laurel_mlp = None
3943
 
3944
  def _attn_res(
3945
  self,
@@ -3989,10 +3173,6 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3989
  B_vals: Optional[torch.Tensor] = None,
3990
  attn_res_sources: Optional[list] = None,
3991
  attn_res_partial: Optional[torch.Tensor] = None,
3992
- mudd_streams: Optional[tuple] = None,
3993
- dca_stack: Optional[torch.Tensor] = None,
3994
- stack_state: Optional[torch.Tensor] = None,
3995
- stack_mask: Optional[torch.Tensor] = None,
3996
  layer_analysis: Optional[LayerAnalysis] = None,
3997
  output_attentions: Optional[bool] = False,
3998
  repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
@@ -4002,63 +3182,6 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4002
  if layer_analysis is not None:
4003
  layer_analysis.hidden_states_input = hidden_states.detach()
4004
 
4005
- # ── StackTrans: hidden-state stack (pre-attention, pre-norm) ─────
4006
- # Executed first so attention sees the stack-enriched representation.
4007
- # stack_state / stack_mask carry the stack from the previous layer;
4008
- # both are None for layer 0 (StackMemory initialises to zeros then).
4009
- # Mutually exclusive with MUDD / DCA / AttnRes — those branches are
4010
- # all skipped when use_stacktrans=True (enforced in NeoLLMConfig).
4011
- if self.use_stacktrans and self.stack_memory is not None:
4012
- st_analysis = layer_analysis.stack if layer_analysis is not None else None
4013
- hidden_states, stack_state, stack_mask = self.stack_memory(
4014
- hidden_states, stack_state, stack_mask, analysis=st_analysis
4015
- )
4016
-
4017
- # ── MUDD: unpack streams for Q/K/V/R (layer > 0 only) ────────────
4018
- # mudd_streams is a 4-tuple (xq, xk, xv, xr) when use_mudd=True and
4019
- # layer_idx > 0; None for layer 0 (standard residual there).
4020
- # xr replaces hidden_states as the residual throughout this layer.
4021
- # xq/xk/xv are the aggregated inputs for Q, K, V projections.
4022
- # When mudd_dense_type='l' (single stream), all four are equal.
4023
- # When mudd_sepln=True each stream has its own norm applied below.
4024
- mudd_xk = None
4025
- mudd_xv = None
4026
- if mudd_streams is not None:
4027
- xq_mudd, xk_mudd, xv_mudd, xr_mudd = mudd_streams
4028
- # Replace hidden_states with xr for residual connections
4029
- hidden_states = xr_mudd
4030
- # Norm K and V streams — use separate SeeDNorm if sepln, else
4031
- # they will share the main input_layernorm path via h_attn below
4032
- if self.mudd_k_norm is not None:
4033
- mudd_xk = self.lns_attn(self.mudd_k_norm(xk_mudd))
4034
- mudd_xv = self.lns_attn(self.mudd_v_norm(xv_mudd))
4035
- else:
4036
- # No sepln: K/V also go through the Q-path norm chain
4037
- mudd_xk = self.lns_attn(self.input_layernorm(xk_mudd))
4038
- mudd_xv = self.lns_attn(self.input_layernorm(xv_mudd))
4039
- # Override hidden_states for the Q path
4040
- hidden_states_for_attn = xq_mudd
4041
- else:
4042
- hidden_states_for_attn = hidden_states
4043
-
4044
- # ── DCA: GRN-v3 depth-wise aggregation ───────────────────────────
4045
- # When active, runs the per-layer GRN on the k-selected depth stack
4046
- # to produce three independent aggregated streams (Q, K, V).
4047
- # xq replaces hidden_states as both the Q projection input AND the
4048
- # post-attention residual (DCA paper: residual = q_input).
4049
- # xk and xv go through separate SeeDNorm+LNS chains and are injected
4050
- # into NeoLLMAttention via the existing mudd_xk/mudd_xv parameters.
4051
- dca_residual = None
4052
- dca_a = layer_analysis.dca if layer_analysis is not None else None
4053
- if self.use_dca and dca_stack is not None:
4054
- xq, xk, xv = self.dca_grn(dca_stack, analysis=dca_a)
4055
- dca_residual = xq
4056
- hidden_states_for_attn = xq
4057
- # K and V streams: SeeDNorm + LNS before k_proj / v_proj
4058
- # (reuses the mudd_xk/mudd_xv injection path in NeoLLMAttention)
4059
- mudd_xk = self.lns_attn(self.dca_k_norm(xk))
4060
- mudd_xv = self.lns_attn(self.dca_v_norm(xv))
4061
-
4062
  # ── Attention Residuals: compute pre-attention input ──────────────
4063
  # When active, the input to the attention sublayer is no longer the
4064
  # raw hidden_states (accumulated residual) but a softmax-weighted
@@ -4072,19 +3195,10 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4072
  attn_res_sources, attn_res_partial, self.attn_res_query_attn,
4073
  ar_analysis, "attn",
4074
  )
4075
- # ── Block boundary fires HERE — after pre-attn, before attn sublayer ──
4076
- # Paper pseudocode (Fig. 2) timing: the completed partial of the previous
4077
- # block is pushed to sources AFTER the pre-attn AttnRes call, so the first
4078
- # layer of a new block still sees the old partial as an intra-block source
4079
- # (no duplicate) and the new intra-block accumulation starts from zeros.
4080
- if self.layer_idx > 0 and self.layer_idx % self.attn_res_block_size == 0:
4081
- attn_res_sources.append(attn_res_partial) # in-place; outer loop sees this
4082
- attn_res_partial = torch.zeros_like(attn_res_partial) # fresh delta start
4083
  residual_attn = attn_res_partial
4084
  else:
4085
- h_attn = hidden_states_for_attn # MUDD/DCA: xq stream or unchanged
4086
- # DCA: residual is xq (the GRN Q-stream output), not raw hidden_states
4087
- residual_attn = dca_residual if dca_residual is not None else hidden_states
4088
 
4089
  # ── Attention block ───────────────────────────────────────────────
4090
  sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
@@ -4100,8 +3214,6 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4100
  first_layer_fan=first_layer_fan,
4101
  attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
4102
  repo_rope_args=repo_rope_args,
4103
- mudd_xk=mudd_xk,
4104
- mudd_xv=mudd_xv,
4105
  **kwargs,
4106
  )
4107
 
@@ -4109,18 +3221,7 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4109
  layer_analysis.attn_contribution = hidden_states.detach()
4110
 
4111
  gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
4112
-
4113
- # ── Attention residual sum ────────────────────────────────────────
4114
- # Standard: GPAS(residual_attn + hidden_states)
4115
- # LAuReL: GPAS(LAuReL(f_out=hidden_states, x_res=residual_attn))
4116
- # Both paths feed into GPAS which applies stop-gradient scaling.
4117
- if self.use_laurel and self.laurel_attn is not None:
4118
- la_attn_a = layer_analysis.laurel_attn if layer_analysis is not None else None
4119
- combined_attn = self.laurel_attn(hidden_states, residual_attn, analysis=la_attn_a)
4120
- else:
4121
- combined_attn = residual_attn + hidden_states
4122
-
4123
- h_tilde = self.gpas_attn(combined_attn, analysis=gpas_attn_a)
4124
 
4125
  if layer_analysis is not None:
4126
  layer_analysis.h_tilde = h_tilde.detach()
@@ -4156,11 +3257,8 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4156
  if layer_analysis is not None:
4157
  layer_analysis.mlp_contribution = delta_m.detach()
4158
 
4159
- # ── MLP residual sum ──────────────────────────────────────────────
4160
- # LAuReL treats f(x) = delta_m [+ delta_r when JTok-M active] and
4161
- # x_res = residual_mlp. JTok-M delta_r is additive alongside delta_m,
4162
- # so the nonlinear component is delta_m + delta_r in that path.
4163
- gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
4164
  if self.use_jtokm and z_tilde is not None and B_vals is not None:
4165
  orig_shape = h_tilde.shape
4166
  h_flat = h_tilde.reshape(-1, self.hidden_size)
@@ -4171,21 +3269,11 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4171
  delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
4172
  delta_r = delta_r.reshape(orig_shape)
4173
 
4174
- f_mlp = delta_m + delta_r # combined nonlinear term
4175
- if self.use_laurel and self.laurel_mlp is not None:
4176
- la_mlp_a = layer_analysis.laurel_mlp if layer_analysis is not None else None
4177
- combined_mlp = self.laurel_mlp(f_mlp, residual_mlp, analysis=la_mlp_a)
4178
- else:
4179
- combined_mlp = residual_mlp + f_mlp
4180
- hidden_states = self.gpas_mlp(combined_mlp, analysis=gpas_mlp_a)
4181
  else:
4182
- aux_stats = None
4183
- if self.use_laurel and self.laurel_mlp is not None:
4184
- la_mlp_a = layer_analysis.laurel_mlp if layer_analysis is not None else None
4185
- combined_mlp = self.laurel_mlp(delta_m, residual_mlp, analysis=la_mlp_a)
4186
- else:
4187
- combined_mlp = residual_mlp + delta_m
4188
- hidden_states = self.gpas_mlp(combined_mlp, analysis=gpas_mlp_a)
4189
 
4190
  if layer_analysis is not None:
4191
  layer_analysis.hidden_states_output = hidden_states.detach()
@@ -4197,9 +3285,6 @@ class NeoLLMDecoderLayer(GradientCheckpointingLayer):
4197
  outputs += (aux_stats,)
4198
  if versatile_aux is not None:
4199
  outputs += (versatile_aux,)
4200
- # StackTrans: always append stack state (None, None when inactive)
4201
- # so NeoLLMModel.forward can extract them by position -2 and -1.
4202
- outputs += (stack_state, stack_mask)
4203
  return outputs
4204
 
4205
 
@@ -4583,45 +3668,6 @@ class NeoLLMPreTrainedModel(PreTrainedModel):
4583
  module.attn_res_query_attn.data.zero_()
4584
  module.attn_res_query_mlp.data.zero_()
4585
 
4586
- elif isinstance(module, StackMemory):
4587
- # Truncated-normal for all Linear weights (matches NeoLLM convention).
4588
- # Biases zeroed. res_weight starts at 1.0 so the stack readout
4589
- # contributes equally to the residual from step 0.
4590
- std = getattr(self.config, "initializer_range", 0.02)
4591
- cutoff = getattr(self.config, "init_cutoff_factor", 3.0) * std
4592
- for attr in ("down_proj", "up_proj", "action_head", "gate_proj"):
4593
- layer = getattr(module, attr, None)
4594
- if layer is not None and hasattr(layer, "weight"):
4595
- nn.init.trunc_normal_(
4596
- layer.weight, mean=0.0, std=std, a=-cutoff, b=cutoff
4597
- )
4598
- if layer.bias is not None:
4599
- nn.init.zeros_(layer.bias)
4600
- if hasattr(module, "res_weight"):
4601
- module.res_weight.data.fill_(1.0)
4602
-
4603
- elif isinstance(module, LAuReLLayer):
4604
- # RW: raw logits initialised to zero → softmax([0,0]) = [0.5, 0.5].
4605
- # The model quickly learns the optimal α,β weighting.
4606
- # LR: lr_down (B, down-projection) — column orthogonal init,
4607
- # as recommended by the LAuReL paper §3.3 for LLMs.
4608
- # Column orthogonal preserves the L2 norm of the projected
4609
- # representation, ensuring stable gradient magnitudes
4610
- # through the low-rank bottleneck at init.
4611
- # lr_up (A, up-projection) — zero init → lr_term = A·Bx = 0
4612
- # at step 0, so the module starts as a standard residual.
4613
- # Gradient flows back through lr_down immediately via
4614
- # chain rule; A learns from step 1 onward.
4615
- if hasattr(module, "rw_logits"):
4616
- nn.init.zeros_(module.rw_logits)
4617
- if hasattr(module, "lr_down"):
4618
- # Column-orthogonal: each column of weight^T is orthonormal.
4619
- # nn.init.orthogonal_ produces a row-orthogonal matrix (rows
4620
- # are orthonormal). Transposing gives column-orthogonal.
4621
- nn.init.orthogonal_(module.lr_down.weight)
4622
- if hasattr(module, "lr_up"):
4623
- nn.init.zeros_(module.lr_up.weight)
4624
-
4625
  elif isinstance(module, SpellingBeeEmbedding):
4626
  # byte_emb initialised identically to token embeddings: std=1/√d.
4627
  # Ensures E[β€–e_byteβ€–Β²] β‰ˆ 1 at init, matching etok, so the
@@ -4691,82 +3737,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
4691
  self.gradient_checkpointing = False
4692
  self.first_layer_fan = None
4693
 
4694
- # ── StackTrans state flag ─────────────────────────────────────────
4695
- self.use_stacktrans = getattr(config, 'use_stacktrans', False)
4696
-
4697
- # ── Residual-replacement mutex ────────────────────────────────────
4698
- # AttnRes, MUDD, and DCA all replace the residual aggregation
4699
- # mechanism β€” at most one can be active at a time.
4700
- _use_mudd = getattr(config, 'use_mudd', False)
4701
- _use_attn_res = getattr(config, 'use_attn_res', False)
4702
- _use_dca = getattr(config, 'use_dca', False)
4703
- _active_count = sum([_use_mudd, _use_attn_res, _use_dca])
4704
- if _active_count > 1:
4705
- active = [n for n, f in [('use_mudd', _use_mudd),
4706
- ('use_attn_res', _use_attn_res),
4707
- ('use_dca', _use_dca)] if f]
4708
- raise ValueError(
4709
- f"use_mudd, use_attn_res, and use_dca are mutually exclusive β€” "
4710
- f"got {active} simultaneously active. Set exactly one to True."
4711
- )
4712
- if _use_mudd:
4713
- _mudd_dense_type = getattr(config, 'mudd_dense_type', 'qkvr')
4714
- _mudd_dynamic = getattr(config, 'mudd_dynamic_dense', True)
4715
- _mudd_round64 = getattr(config, 'mudd_round64', False)
4716
- _mudd_expand_last = getattr(config, 'mudd_expand_last', False)
4717
- _C = 4 if _mudd_dense_type == 'qkvr' else 1
4718
-
4719
- # Static bias: one [C, lidx+2] parameter per layer.
4720
- # Initialized with 1 at index [c, lidx+1] (identity on Xi) so that
4721
- # at init (W2=0) each DA output = Xi — reducing to standard Transformer.
4722
- _static_list = []
4723
- for lidx in range(config.num_hidden_layers):
4724
- # Last layer always uses C=1: its DA output is the final
4725
- # model representation fed to the norm and lm_head, collapsing
4726
- # all history into a single stream (paper code, both files).
4727
- _c = 1 if lidx == config.num_hidden_layers - 1 else _C
4728
- a = torch.zeros(_c, lidx + 2)
4729
- a[:, lidx + 1] = 1.0 # last entry = current layer = identity
4730
- _static_list.append(nn.Parameter(a))
4731
- self.mudd_static = nn.ParameterList(_static_list)
4732
-
4733
- # Dynamic DA modules (one per layer)
4734
- if _mudd_dynamic:
4735
- self.mudd_dynamic = nn.ModuleList([
4736
- NeoLLMMUDDModule(
4737
- hidden_size = config.hidden_size,
4738
- lidx = lidx,
4739
- # Last layer: C=1 — collapses to single final repr
4740
- num_ways = 1 if lidx == config.num_hidden_layers - 1 else _C,
4741
- is_last = (lidx == config.num_hidden_layers - 1),
4742
- expand_last = _mudd_expand_last,
4743
- round64 = _mudd_round64,
4744
- )
4745
- for lidx in range(config.num_hidden_layers)
4746
- ])
4747
- else:
4748
- self.mudd_dynamic = None
4749
- else:
4750
- self.mudd_static = None
4751
- self.mudd_dynamic = None
4752
-
4753
- # ── DCA final GRN (Heddes et al., 2025) ───────────────────────────
4754
- # Applied once after all decoder layers to aggregate the full depth
4755
- # stack into the final hidden representation before the output norm.
4756
- # num_stack_layers = min(2*k, L+1) — same cap as per-layer GRNs.
4757
- # num_outputs=1 collapses to a single [B, S, D] tensor.
4758
- if _use_dca and getattr(config, 'dca_use_final_grn', True):
4759
- _dca_k = getattr(config, 'dca_k', 2)
4760
- _dca_eps = getattr(config, 'dca_grn_eps', config.rms_norm_eps)
4761
- self.dca_final_grn = NeoLLMGRN(
4762
- hidden_size = config.hidden_size,
4763
- num_stack_layers = min(2 * _dca_k, config.num_hidden_layers + 1),
4764
- num_outputs = 1,
4765
- eps = _dca_eps,
4766
- )
4767
- else:
4768
- self.dca_final_grn = None
4769
-
4770
  self.post_init()
4771
 
4772
  def get_input_embeddings(self):
@@ -4794,9 +3764,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
4794
  getattr(cfg, "use_repo", False)
4795
  and layer_idx >= getattr(cfg, "repo_start_layer", cfg.num_hidden_layers // 3)
4796
  )
4797
- _versatile = getattr(cfg, "use_versatile_ffn", False)
4798
- _use_stacktrans = getattr(cfg, "use_stacktrans", False)
4799
- _use_laurel = getattr(cfg, "use_laurel", False)
4800
  return LayerAnalysis(
4801
  seednorm_pre_attn = SeeDNormAnalysis(),
4802
  seednorm_post_attn = SeeDNormAnalysis(),
@@ -4810,14 +3778,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
4810
  polynorm = PolyNormAnalysis() if not _versatile else None,
4811
  versatile = VersatileFFNAnalysis() if _versatile else None,
4812
  ),
4813
- gpas_attn = GPASAnalysis(),
4814
- gpas_mlp = GPASAnalysis(),
4815
- jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
4816
- attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
4817
- dca = DCAAnalysis() if getattr(cfg, "use_dca", False) else None,
4818
- stack = StackMemoryAnalysis() if _use_stacktrans else None,
4819
- laurel_attn = LAuReLAnalysis() if _use_laurel else None,
4820
- laurel_mlp = LAuReLAnalysis() if _use_laurel else None,
4821
  )
4822
 
4823
  def forward(
@@ -4933,57 +3897,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
4933
  if use_attn_res:
4934
  attn_res_sources = [hidden_states] # b_0 = token embedding
4935
  attn_res_partial = hidden_states # initial partial sum
4936
- # Block boundary handling now lives inside NeoLLMDecoderLayer.forward(),
4937
- # firing after the pre-attn AttnRes call (paper Fig. 2 timing).
4938
-
4939
- # ── MUDD state ────────────────────────────────────────────────────
4940
- # hiddens[0] = token embedding; hiddens[i] = output of layer i-1.
4941
- # After each layer, its output is appended so layer i receives a
4942
- # history of length i+2 (embedding + i preceding layer outputs).
4943
- # mudd_streams is None for layer 0 (standard residual path there)
4944
- # and a C-tuple of [B,T,D] tensors for layers 1…L.
4945
- use_mudd = getattr(self.config, 'use_mudd', False)
4946
- mudd_hiddens = None
4947
- mudd_streams = None
4948
- if use_mudd:
4949
- mudd_hiddens = [hidden_states] # b_0 = token embedding
4950
-
4951
- # ── DCA state ─────────────────────────────────────────────────────
4952
- # all_tokens[0] = token embedding; grows by one per decoder layer.
4953
- # Before each layer, the stack is built and k-DCA selection applied,
4954
- # capping memory at 2*dca_k stored tensors regardless of depth.
4955
- # dca_stack is always non-None (even layer 0 gets [embedding]).
4956
- use_dca = getattr(self.config, 'use_dca', False)
4957
- _dca_k = getattr(self.config, 'dca_k', 2)
4958
- dca_all_tokens = None
4959
- dca_stack = None
4960
- if use_dca:
4961
- dca_all_tokens = [hidden_states] # [embedding]
4962
-
4963
- # ── StackTrans state ──────────────────────────────────────────────
4964
- # stack_state / stack_mask start as None for the first layer;
4965
- # StackMemory initialises them to zeros internally on first call.
4966
- # After each layer, the returned (new_stack, new_mask) are passed
4967
- # to the next layer as its initial stack — this is "vertical" state
4968
- # propagation: information flows depth-wise through the stack.
4969
- #
4970
- # Temporal accumulation across generation steps is handled by the
4971
- # StackMemory internal k_cache / action_cache mechanism:
4972
- # - enable_cache is set True when use_cache=True (inference)
4973
- # - reset_cache() is called when past_key_values is None
4974
- # (new sequence, not a continuation step)
4975
- # This matches the OLMo reference implementation exactly.
4976
- use_stacktrans = self.use_stacktrans
4977
- stack_state = None
4978
- stack_mask = None
4979
- if use_stacktrans:
4980
- use_cache_flag = kwargs.get("use_cache", False)
4981
- past_kv_flag = kwargs.get("past_key_values", None)
4982
- for layer in self.layers:
4983
- if layer.stack_memory is not None:
4984
- layer.stack_memory.enable_cache = bool(use_cache_flag)
4985
- if past_kv_flag is None:
4986
- layer.stack_memory.reset_cache()
4987
 
4988
  # Pre-allocate per-layer analysis list when analysis is active
4989
  if analysis_state is not None:
@@ -4993,13 +3913,17 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
4993
  if output_hidden_states:
4994
  all_hidden_states = all_hidden_states + (hidden_states,)
4995
 
4996
- # ── DCA: build k-selected stack for this layer ───────────────
4997
- # Stack has layer_idx+1 entries before selection; after k-DCA
4998
- # selection it has at most 2*dca_k entries (first k + last k).
4999
- if use_dca:
5000
- dca_stack = dca_select_layers(
5001
- torch.stack(dca_all_tokens, dim=0), k=_dca_k
5002
- )
 
 
 
 
5003
 
5004
  # Build per-layer analysis container (only in eval + analysis mode)
5005
  layer_analysis = None
@@ -5017,10 +3941,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
5017
  B_vals=B_vals,
5018
  attn_res_sources=attn_res_sources,
5019
  attn_res_partial=attn_res_partial if use_attn_res else None,
5020
- mudd_streams=mudd_streams,
5021
- dca_stack=dca_stack,
5022
- stack_state=stack_state,
5023
- stack_mask=stack_mask,
5024
  layer_analysis=layer_analysis,
5025
  output_attentions=output_attentions,
5026
  repo_rope_args=repo_rope_args,
@@ -5028,76 +3948,23 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
5028
  )
5029
  hidden_states = layer_outputs[0]
5030
 
5031
- # ── StackTrans: extract updated stack state for next layer ─────
5032
- # layer_outputs always ends with (stack_state, stack_mask) —
5033
- # both are None when use_stacktrans=False (zero cost).
5034
- stack_state = layer_outputs[-2]
5035
- stack_mask = layer_outputs[-1]
5036
-
5037
  # Update AttnRes partial sum β€” the new partial is the layer output
5038
  if use_attn_res:
5039
  attn_res_partial = hidden_states
5040
 
5041
- # Append layer output to DCA history for next layer's stack
5042
- if use_dca:
5043
- dca_all_tokens.append(hidden_states)
5044
-
5045
- # ── MUDD: append current output and compute DA for next layer ──
5046
- # mudd_hiddens grows by 1 each iteration; at layer i it has i+2
5047
- # entries (embedding + i outputs). The DA for layer i+1 takes this
5048
- # full history and produces C streams via dynamic + static weights.
5049
- # mudd_streams is passed to layer i+1 as its input streams.
5050
- if use_mudd:
5051
- mudd_hiddens.append(hidden_states)
5052
- # Compute DA module output using the just-appended history
5053
- # (mudd_hiddens now has layer_idx+2 entries)
5054
- is_last_layer = (layer_idx == self.config.num_hidden_layers - 1)
5055
- mudd_da_module = self.mudd_dynamic[layer_idx] if self.mudd_dynamic is not None else None
5056
- if mudd_da_module is not None:
5057
- raw_streams = mudd_da_module(
5058
- hidden_states,
5059
- mudd_hiddens,
5060
- self.mudd_static[layer_idx],
5061
- )
5062
- else:
5063
- # Static-only: apply weighted sum with learnable bias
5064
- # stack history [L, B, T, D], weight by mudd_static
5065
- stacked = torch.stack(mudd_hiddens, dim=0) # [L, B, T, D]
5066
- a = self.mudd_static[layer_idx].to(hidden_states.dtype) # [C, L]
5067
- raw_streams = tuple(
5068
- torch.einsum('cl,lbtd->btd', a[c:c+1], stacked).squeeze(0)
5069
- for c in range(a.shape[0])
5070
- )
5071
- if is_last_layer:
5072
- # Last layer DA always produces C=1 → single final repr.
5073
- # This is the MUDD-aggregated combination of all layer
5074
- # histories weighted by the last layer's output as query.
5075
- # Replace hidden_states so the final norm and lm_head
5076
- # receive this aggregated representation, not the raw
5077
- # last-layer output (paper forward loop: x = x[0] after loop).
5078
- hidden_states = raw_streams[0]
5079
- mudd_streams = None # no next layer
5080
- elif len(raw_streams) == 1:
5081
- # dense_type='l': broadcast to 4-tuple
5082
- mudd_streams = (raw_streams[0],) * 4
5083
- else:
5084
- # 'qkvr': 4 streams → (xq, xk, xv, xr)
5085
- mudd_streams = raw_streams
5086
-
5087
  if output_attentions:
5088
  all_attentions = all_attentions + (layer_outputs[1],)
5089
 
5090
- # Collect JTok-M / VersatileFFN aux stats.
5091
- # layer_outputs always ends with (stack_state, stack_mask) —
5092
- # slice [1:-2] to skip hidden_states[0] and the two stack slots.
5093
- inner_outputs = layer_outputs[1:-2]
5094
-
5095
- if self.config.use_jtokm and len(inner_outputs) > (1 if output_attentions else 0):
5096
- all_aux_stats.append(inner_outputs[-1])
5097
 
 
 
5098
  if getattr(self.config, "use_versatile_ffn", False):
5099
- for item in inner_outputs:
5100
  if isinstance(item, tuple) and len(item) == 3:
 
5101
  all_aux_stats.append(("versatile", item))
5102
  break
5103
 
@@ -5105,16 +3972,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
5105
  and hasattr(decoder_layer, "current_layer_fan")):
5106
  self.first_layer_fan = decoder_layer.current_layer_fan
5107
 
5108
- # ── DCA final GRN ──────────────────────────────────────────────────
5109
- # Aggregates the full depth history (k-selected) into the final
5110
- # hidden representation, matching the DCAGPT forward loop which
5111
- # applies final_grn(stack(all_tokens)) before norm → lm_head.
5112
- if use_dca and self.dca_final_grn is not None:
5113
- final_stack = dca_select_layers(
5114
- torch.stack(dca_all_tokens, dim=0), k=_dca_k
5115
- )
5116
- hidden_states = self.dca_final_grn(final_stack)
5117
-
5118
  hidden_states = self.norm(hidden_states)
5119
 
5120
  if output_hidden_states:
@@ -5127,9 +3984,6 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
5127
  analysis_state.attn_res_sources_final = (
5128
  attn_res_sources if use_attn_res else None
5129
  )
5130
- analysis_state.dca_all_tokens_final = (
5131
- dca_all_tokens if use_dca else None
5132
- )
5133
 
5134
  if not return_dict:
5135
  return tuple(
@@ -5270,7 +4124,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5270
  layers = None, # filled by NeoLLMModel.forward
5271
  jtokm_aux_stats = [] if cfg.use_jtokm else None,
5272
  attn_res_sources_final = [] if getattr(cfg, "use_attn_res", False) else None,
5273
- dca_all_tokens_final = [] if getattr(cfg, "use_dca", False) else None,
5274
  )
5275
 
5276
  # ── Standard model API ────────────────────────────────────────────────
@@ -5408,11 +4261,6 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
5408
  # ==================== AUTOMODEL REGISTRATION ====================
5409
 
5410
  __all__ = [
5411
- "StackMemory",
5412
- "LAuReLLayer",
5413
- "NeoLLMMUDDModule",
5414
- "NeoLLMGRN",
5415
- "dca_select_layers",
5416
  "NeoLLMForCausalLM",
5417
  "NeoLLMModel",
5418
  "NeoLLMPreTrainedModel",
@@ -5430,7 +4278,7 @@ __all__ = [
5430
  "REPOModule",
5431
  "VersatileFFN",
5432
  "compute_versatile_aux_loss",
5433
- # Analysis dataclasses
5434
  "AnalysisState",
5435
  "LayerAnalysis",
5436
  "AttentionAnalysis",
@@ -5444,9 +4292,6 @@ __all__ = [
5444
  "VersatileFFNAnalysis",
5445
  "JTokMAnalysis",
5446
  "AttnResAnalysis",
5447
- "DCAAnalysis",
5448
- "StackMemoryAnalysis",
5449
- "LAuReLAnalysis",
5450
  "GeneratorAnalysis",
5451
  ]
5452
 
 
80
  from configuration_neollm import NeoLLMConfig
81
 
82
  from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
83
+ torch._dynamo.config.capture_scalar_outputs = True
84
  logger = logging.get_logger(__name__)
85
 
86
 
 
339
  lns_scale: Optional[float] = None # 1/√(2ℓ) scaling factor
340
 
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  @dataclass
343
  class AttnResAnalysis:
344
  """
 
350
  sources_count: Optional[int] = None # number of sources including partial
351
 
352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  @dataclass
354
  class LayerAnalysis:
355
  """
 
378
  gpas_mlp: Optional[GPASAnalysis] = None # GPAS after MLP residual
379
 
380
  # Optional components (None when inactive)
381
+ jtokm: Optional[JTokMAnalysis] = None # if use_jtokm
382
+ attn_res: Optional[AttnResAnalysis] = None # if use_attn_res
 
 
 
 
383
 
384
 
385
  @dataclass
 
444
  layers: Optional[List[LayerAnalysis]] = None
445
  jtokm_aux_stats: Optional[list] = None
446
  attn_res_sources_final: Optional[list] = None
 
447
  logits: Optional[torch.Tensor] = None
448
 
449
  class ScalarMultiplier(nn.Module):
 
2363
  first_layer_fan: Optional[torch.Tensor] = None,
2364
  attn_analysis: Optional[AttentionAnalysis] = None,
2365
  repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
 
 
2366
  **kwargs: Unpack[FlashAttentionKwargs],
2367
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
2368
  input_shape = hidden_states.shape[:-1]
 
2373
  h_fan = self.lambda_1 * first_layer_fan + self.lambda_2 * h_fan
2374
  current_layer_fan = h_fan.clone()
2375
 
 
 
 
 
 
 
 
 
2376
  query_shape = (*input_shape, self.config.num_attention_heads, self.head_dim)
2377
  kv_shape = (*input_shape, self.num_mea_component_heads, self.head_dim)
2378
 
 
2387
  attn_analysis.gate_raw = gate.detach()
2388
 
2389
  q = self.q_norm(q_raw.view(query_shape)).transpose(1, 2)
2390
+ k = self.k_norm(self.k_proj(h_fan).view(kv_shape)).transpose(1, 2)
2391
+ v = self.v_proj(h_fan).view(kv_shape).transpose(1, 2)
2392
 
2393
  if attn_analysis is not None:
2394
  attn_analysis.q_post_norm = q.detach()
 
3065
  return result
3066
 
3067
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3068
  class NeoLLMDecoderLayer(GradientCheckpointingLayer):
3069
  """
3070
  Decoder layer with standard residual connections, optional JTok-M injection.
 
3120
  self.attn_res_query_attn = nn.Parameter(torch.zeros(config.hidden_size))
3121
  self.attn_res_query_mlp = nn.Parameter(torch.zeros(config.hidden_size))
3122
  self.attn_res_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
 
 
 
3123
  else:
3124
  self.attn_res_query_attn = None
3125
  self.attn_res_query_mlp = None
3126
  self.attn_res_norm = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3127
 
3128
  def _attn_res(
3129
  self,
 
3173
  B_vals: Optional[torch.Tensor] = None,
3174
  attn_res_sources: Optional[list] = None,
3175
  attn_res_partial: Optional[torch.Tensor] = None,
 
 
 
 
3176
  layer_analysis: Optional[LayerAnalysis] = None,
3177
  output_attentions: Optional[bool] = False,
3178
  repo_rope_args: Optional[Tuple[torch.Tensor, float]] = None,
 
3182
  if layer_analysis is not None:
3183
  layer_analysis.hidden_states_input = hidden_states.detach()
3184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3185
  # ── Attention Residuals: compute pre-attention input ──────────────
3186
  # When active, the input to the attention sublayer is no longer the
3187
  # raw hidden_states (accumulated residual) but a softmax-weighted
 
3195
  attn_res_sources, attn_res_partial, self.attn_res_query_attn,
3196
  ar_analysis, "attn",
3197
  )
 
 
 
 
 
 
 
 
3198
  residual_attn = attn_res_partial
3199
  else:
3200
+ h_attn = hidden_states
3201
+ residual_attn = hidden_states
 
3202
 
3203
  # ── Attention block ───────────────────────────────────────────────
3204
  sn_pre = layer_analysis.seednorm_pre_attn if layer_analysis is not None else None
 
3214
  first_layer_fan=first_layer_fan,
3215
  attn_analysis=layer_analysis.attention if layer_analysis is not None else None,
3216
  repo_rope_args=repo_rope_args,
 
 
3217
  **kwargs,
3218
  )
3219
 
 
3221
  layer_analysis.attn_contribution = hidden_states.detach()
3222
 
3223
  gpas_attn_a = layer_analysis.gpas_attn if layer_analysis is not None else None
3224
+ h_tilde = self.gpas_attn(residual_attn + hidden_states, analysis=gpas_attn_a)
 
 
 
 
 
 
 
 
 
 
 
3225
 
3226
  if layer_analysis is not None:
3227
  layer_analysis.h_tilde = h_tilde.detach()
 
3257
  if layer_analysis is not None:
3258
  layer_analysis.mlp_contribution = delta_m.detach()
3259
 
3260
+ # ── JTok-M injection (additive alongside MLP residual) ────────────
3261
+ aux_stats = None
 
 
 
3262
  if self.use_jtokm and z_tilde is not None and B_vals is not None:
3263
  orig_shape = h_tilde.shape
3264
  h_flat = h_tilde.reshape(-1, self.hidden_size)
 
3269
  delta_r, aux_stats = self.jtokm(h_flat, z_flat, B_flat, analysis=jtokm_a)
3270
  delta_r = delta_r.reshape(orig_shape)
3271
 
3272
+ gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
3273
+ hidden_states = self.gpas_mlp(residual_mlp + delta_m + delta_r, analysis=gpas_mlp_a)
 
 
 
 
 
3274
  else:
3275
+ gpas_mlp_a = layer_analysis.gpas_mlp if layer_analysis is not None else None
3276
+ hidden_states = self.gpas_mlp(residual_mlp + delta_m, analysis=gpas_mlp_a)
 
 
 
 
 
3277
 
3278
  if layer_analysis is not None:
3279
  layer_analysis.hidden_states_output = hidden_states.detach()
 
3285
  outputs += (aux_stats,)
3286
  if versatile_aux is not None:
3287
  outputs += (versatile_aux,)
 
 
 
3288
  return outputs
3289
 
3290
 
 
3668
  module.attn_res_query_attn.data.zero_()
3669
  module.attn_res_query_mlp.data.zero_()
3670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3671
  elif isinstance(module, SpellingBeeEmbedding):
3672
  # byte_emb initialised identically to token embeddings: std=1/√d.
3673
  # Ensures E[‖e_byte‖²] ≈ 1 at init, matching etok, so the
 
3737
  self.gradient_checkpointing = False
3738
  self.first_layer_fan = None
3739
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3740
  self.post_init()
3741
 
3742
  def get_input_embeddings(self):
 
3764
  getattr(cfg, "use_repo", False)
3765
  and layer_idx >= getattr(cfg, "repo_start_layer", cfg.num_hidden_layers // 3)
3766
  )
3767
+ _versatile = getattr(cfg, "use_versatile_ffn", False)
 
 
3768
  return LayerAnalysis(
3769
  seednorm_pre_attn = SeeDNormAnalysis(),
3770
  seednorm_post_attn = SeeDNormAnalysis(),
 
3778
  polynorm = PolyNormAnalysis() if not _versatile else None,
3779
  versatile = VersatileFFNAnalysis() if _versatile else None,
3780
  ),
3781
+ gpas_attn = GPASAnalysis(),
3782
+ gpas_mlp = GPASAnalysis(),
3783
+ jtokm = JTokMAnalysis() if cfg.use_jtokm else None,
3784
+ attn_res = AttnResAnalysis() if getattr(cfg, "use_attn_res", False) else None,
 
 
 
 
3785
  )
3786
 
3787
  def forward(
 
3897
  if use_attn_res:
3898
  attn_res_sources = [hidden_states] # b_0 = token embedding
3899
  attn_res_partial = hidden_states # initial partial sum
3900
+
3901
+ num_blocks = getattr(self.config, 'attn_res_num_blocks', 0)
3902
+ block_size = (
3903
+ max(self.config.num_hidden_layers // num_blocks, 1)
3904
+ if num_blocks > 0
3905
+ else 1 # Full AttnRes: every layer is its own "block"
3906
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3907
 
3908
  # Pre-allocate per-layer analysis list when analysis is active
3909
  if analysis_state is not None:
 
3913
  if output_hidden_states:
3914
  all_hidden_states = all_hidden_states + (hidden_states,)
3915
 
3916
+ # ── Block AttnRes: boundary handling ──────────────────────────
3917
+ # At each block boundary (excluding layer 0): append the current
3918
+ # partial sum to sources as a completed block summary, then reset
3919
+ # partial to None so the new block builds from scratch — matching
3920
+ # the paper's pseudocode exactly.
3921
+ # For Full AttnRes (block_size=1): every layer is a boundary, so
3922
+ # partial is appended and reset after every layer. The partial is
3923
+ # re-seeded from the previous hidden_states below.
3924
+ if use_attn_res and layer_idx > 0 and layer_idx % block_size == 0:
3925
+ attn_res_sources = attn_res_sources + [attn_res_partial]
3926
+ attn_res_partial = hidden_states # start new block from current output
3927
 
3928
  # Build per-layer analysis container (only in eval + analysis mode)
3929
  layer_analysis = None
 
3941
  B_vals=B_vals,
3942
  attn_res_sources=attn_res_sources,
3943
  attn_res_partial=attn_res_partial if use_attn_res else None,
 
 
 
 
3944
  layer_analysis=layer_analysis,
3945
  output_attentions=output_attentions,
3946
  repo_rope_args=repo_rope_args,
 
3948
  )
3949
  hidden_states = layer_outputs[0]
3950
 
 
 
 
 
 
 
3951
  # Update AttnRes partial sum — the new partial is the layer output
3952
  if use_attn_res:
3953
  attn_res_partial = hidden_states
3954
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3955
  if output_attentions:
3956
  all_attentions = all_attentions + (layer_outputs[1],)
3957
 
3958
+ # Collect JTok-M aux stats (last element if present)
3959
+ if self.config.use_jtokm and len(layer_outputs) > (2 if output_attentions else 1):
3960
+ all_aux_stats.append(layer_outputs[-1])
 
 
 
 
3961
 
3962
+ # Collect VersatileFFN aux stats (second-to-last if jtokm also present,
3963
+ # or last if jtokm is absent). Only non-None during training.
3964
  if getattr(self.config, "use_versatile_ffn", False):
3965
+ for item in layer_outputs[1:]:
3966
  if isinstance(item, tuple) and len(item) == 3:
3967
+ # (p_sum, f_sum, N_tokens) signature
3968
  all_aux_stats.append(("versatile", item))
3969
  break
3970
 
 
3972
  and hasattr(decoder_layer, "current_layer_fan")):
3973
  self.first_layer_fan = decoder_layer.current_layer_fan
3974
 
 
 
 
 
 
 
 
 
 
 
3975
  hidden_states = self.norm(hidden_states)
3976
 
3977
  if output_hidden_states:
 
3984
  analysis_state.attn_res_sources_final = (
3985
  attn_res_sources if use_attn_res else None
3986
  )
 
 
 
3987
 
3988
  if not return_dict:
3989
  return tuple(
 
4124
  layers = None, # filled by NeoLLMModel.forward
4125
  jtokm_aux_stats = [] if cfg.use_jtokm else None,
4126
  attn_res_sources_final = [] if getattr(cfg, "use_attn_res", False) else None,
 
4127
  )
4128
 
4129
  # ── Standard model API ────────────────────────────────────────────────
 
4261
  # ==================== AUTOMODEL REGISTRATION ====================
4262
 
4263
  __all__ = [
 
 
 
 
 
4264
  "NeoLLMForCausalLM",
4265
  "NeoLLMModel",
4266
  "NeoLLMPreTrainedModel",
 
4278
  "REPOModule",
4279
  "VersatileFFN",
4280
  "compute_versatile_aux_loss",
4281
+ # Analysis dataclasses — exported so external tools can type-hint against them
4282
  "AnalysisState",
4283
  "LayerAnalysis",
4284
  "AttentionAnalysis",
 
4292
  "VersatileFFNAnalysis",
4293
  "JTokMAnalysis",
4294
  "AttnResAnalysis",
 
 
 
4295
  "GeneratorAnalysis",
4296
  ]
4297