Update configuration_neollm.py
Browse files- configuration_neollm.py +309 -7
configuration_neollm.py
CHANGED
|
@@ -414,10 +414,59 @@ class NeoLLMConfig(PretrainedConfig):
|
|
| 414 |
Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). *REPO: Language Models
|
| 415 |
with Context Re-Positioning.* arXiv:2512.14391.
|
| 416 |
|
|
|
|
|
|
|
|
|
|
| 417 |
Xiao, D., Meng, Q., Li, S. & Yuan, X. (2025). *MUDDFormer: Breaking
|
| 418 |
Residual Bottlenecks in Transformers via Multiway Dynamic Dense
|
| 419 |
Connections.* arXiv:2502.12170.
|
| 420 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
use_mudd (:obj:`bool`, *optional*, defaults to ``False``):
|
| 422 |
Enable **Multiway Dynamic Dense (MUDD) connections** (Xiao et al.,
|
| 423 |
2025). Replaces standard residual connections with learned,
|
|
@@ -478,6 +527,172 @@ class NeoLLMConfig(PretrainedConfig):
|
|
| 478 |
Enables independent rescaling per stream when
|
| 479 |
``mudd_dense_type="qkvr"``. Adds 2 Γ SeeDNorm parameters per
|
| 480 |
decoder layer. Ignored when ``mudd_dense_type="l"``.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
"""
|
| 482 |
|
| 483 |
model_type = "neollm"
|
|
@@ -519,7 +734,7 @@ class NeoLLMConfig(PretrainedConfig):
|
|
| 519 |
directional_routing_temp=3.0,
|
| 520 |
# ββ Attention Residuals (Kimi Team, 2026) βββββββββββββββββββββββββ
|
| 521 |
use_attn_res=False,
|
| 522 |
-
attn_res_num_blocks=
|
| 523 |
fan_ratio=0.125,
|
| 524 |
fan_ratio_ffn=0.0625,
|
| 525 |
dropout_rate=0.1,
|
|
@@ -557,13 +772,42 @@ class NeoLLMConfig(PretrainedConfig):
|
|
| 557 |
versatile_gumbel_temp_end=0.1,
|
| 558 |
versatile_gumbel_temp_decay=0.99984,
|
| 559 |
versatile_aux_loss_weight=1e-5,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 560 |
# ββ MUDD connections (Xiao et al., 2025) βββββββββββββββββββββββββ
|
| 561 |
use_mudd=False,
|
| 562 |
mudd_dense_type="qkvr",
|
| 563 |
mudd_dynamic_dense=True,
|
| 564 |
-
mudd_round64=
|
| 565 |
mudd_expand_last=True,
|
| 566 |
mudd_sepln=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
**kwargs,
|
| 568 |
):
|
| 569 |
# ββ Generator / tying consistency βββββββββββββββββββββββββββββββββ
|
|
@@ -612,12 +856,41 @@ class NeoLLMConfig(PretrainedConfig):
|
|
| 612 |
f"num_hidden_layers={num_hidden_layers}."
|
| 613 |
)
|
| 614 |
|
| 615 |
-
# ββ
|
| 616 |
-
|
|
|
|
|
|
|
|
|
|
| 617 |
raise ValueError(
|
| 618 |
-
"
|
| 619 |
-
"
|
| 620 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 621 |
)
|
| 622 |
if use_mudd and mudd_dense_type not in ("qkvr", "l"):
|
| 623 |
raise ValueError(
|
|
@@ -732,6 +1005,12 @@ class NeoLLMConfig(PretrainedConfig):
|
|
| 732 |
self.repo_start_layer = repo_start_layer
|
| 733 |
self.repo_d_p = repo_d_p
|
| 734 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 735 |
# ββ MUDD connections (Xiao et al., 2025) βββββββββββββββββββββββββ
|
| 736 |
self.use_mudd = use_mudd
|
| 737 |
self.mudd_dense_type = mudd_dense_type
|
|
@@ -740,6 +1019,29 @@ class NeoLLMConfig(PretrainedConfig):
|
|
| 740 |
self.mudd_expand_last = mudd_expand_last
|
| 741 |
self.mudd_sepln = mudd_sepln
|
| 742 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 743 |
# ββ VersatileFFN (Nie et al., 2026) βββββββββββββββββββββββββββββββ
|
| 744 |
self.use_versatile_ffn = use_versatile_ffn
|
| 745 |
self.versatile_total_experts = versatile_total_experts
|
|
|
|
| 414 |
Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). *REPO: Language Models
|
| 415 |
with Context Re-Positioning.* arXiv:2512.14391.
|
| 416 |
|
| 417 |
+
Heddes, M. et al. (2025). *DeepCrossAttention: Supercharging
|
| 418 |
+
Transformer Residual Connections.* arXiv:2502.06785.
|
| 419 |
+
|
| 420 |
Xiao, D., Meng, Q., Li, S. & Yuan, X. (2025). *MUDDFormer: Breaking
|
| 421 |
Residual Bottlenecks in Transformers via Multiway Dynamic Dense
|
| 422 |
Connections.* arXiv:2502.12170.
|
| 423 |
|
| 424 |
+
use_dca (:obj:`bool`, *optional*, defaults to ``False``):
|
| 425 |
+
Enable **DeepCrossAttention (DCA)** (Heddes et al., 2025).
|
| 426 |
+
Replaces standard residual connections with three independent
|
| 427 |
+
GRN-v3 modules per decoder layer β one each for the Query,
|
| 428 |
+
Key, and Value streams β that dynamically aggregate outputs of
|
| 429 |
+
all preceding layers with both dimension-dependent (static) and
|
| 430 |
+
input-dependent (dynamic) learned weights.
|
| 431 |
+
|
| 432 |
+
**k-DCA efficiency** (``dca_k``): only the first and last ``k``
|
| 433 |
+
layer outputs are kept in the stack, bounding memory and
|
| 434 |
+
computation at ``O(2k)`` rather than ``O(L)``.
|
| 435 |
+
|
| 436 |
+
**Mutual exclusion**: ``use_dca``, ``use_attn_res``, and
|
| 437 |
+
``use_mudd`` are mutually exclusive. Set exactly one to True.
|
| 438 |
+
|
| 439 |
+
Reference: Heddes, M. et al. (2025). *DeepCrossAttention:
|
| 440 |
+
Supercharging Transformer Residual Connections.*
|
| 441 |
+
arXiv:2502.06785.
|
| 442 |
+
|
| 443 |
+
dca_k (:obj:`int`, *optional*, defaults to 2):
|
| 444 |
+
Number of first and last layer outputs retained in the
|
| 445 |
+
depth-wise stack for k-DCA. With ``k=2`` the stack contains
|
| 446 |
+
at most 4 tensors (first 2 + last 2) regardless of depth.
|
| 447 |
+
|
| 448 |
+
Paper Table 1 results (24-layer model on LM1B):
|
| 449 |
+
|
| 450 |
+
- ``k=1`` β 0.33Γ time to transformer PPL, PPL 14.48
|
| 451 |
+
- ``k=2`` β 0.33Γ time, PPL 14.41 β recommended default
|
| 452 |
+
- ``k=4`` β 0.37Γ time, PPL 14.50
|
| 453 |
+
- ``k=24`` (full) β 0.39Γ time, PPL 14.35
|
| 454 |
+
|
| 455 |
+
Smaller k gives faster training and lower inference latency
|
| 456 |
+
at a very small perplexity cost. k=2 is the best
|
| 457 |
+
efficiency-quality trade-off per the paper.
|
| 458 |
+
|
| 459 |
+
dca_use_final_grn (:obj:`bool`, *optional*, defaults to ``True``):
|
| 460 |
+
Apply a final GRN-v3 aggregation (``num_outputs=1``) over the
|
| 461 |
+
k-selected depth stack after all decoder layers, before the
|
| 462 |
+
output norm and lm_head. This matches the DCAGPT architecture
|
| 463 |
+
exactly. The final GRN collapses all depth history into the
|
| 464 |
+
final hidden representation using learned weights, rather than
|
| 465 |
+
using the raw last-layer output.
|
| 466 |
+
|
| 467 |
+
dca_grn_eps (:obj:`float`, *optional*, defaults to 1e-6):
|
| 468 |
+
Epsilon for the no-scale RMSNorm inside each GRN-v3 module.
|
| 469 |
+
|
| 470 |
use_mudd (:obj:`bool`, *optional*, defaults to ``False``):
|
| 471 |
Enable **Multiway Dynamic Dense (MUDD) connections** (Xiao et al.,
|
| 472 |
2025). Replaces standard residual connections with learned,
|
|
|
|
| 527 |
Enables independent rescaling per stream when
|
| 528 |
``mudd_dense_type="qkvr"``. Adds 2 Γ SeeDNorm parameters per
|
| 529 |
decoder layer. Ignored when ``mudd_dense_type="l"``.
|
| 530 |
+
|
| 531 |
+
use_stacktrans (:obj:`bool`, *optional*, defaults to ``False``):
|
| 532 |
+
Enable **StackTrans** (Zhang et al., NeurIPS 2025): inserts a
|
| 533 |
+
differentiable multi-head hidden-state stack between each pair of
|
| 534 |
+
Transformer layers, providing an explicit push/pop memory that
|
| 535 |
+
allows the model to learn Chomsky-hierarchy grammars (regular
|
| 536 |
+
expressions and deterministic context-free grammars) and improves
|
| 537 |
+
compositional generalisation and reasoning.
|
| 538 |
+
|
| 539 |
+
The stack is positioned at the **very beginning** of each decoder
|
| 540 |
+
layer forward pass, before the attention sublayer, so the
|
| 541 |
+
attention computation sees the stack-enriched hidden state.
|
| 542 |
+
|
| 543 |
+
**Mutual exclusion**: ``use_stacktrans`` cannot be active
|
| 544 |
+
simultaneously with ``use_attn_res``, ``use_mudd``, or
|
| 545 |
+
``use_dca`` because all four alter the information flow entering
|
| 546 |
+
the residual stream. Set exactly one to ``True``.
|
| 547 |
+
|
| 548 |
+
Reference: Zhang, K. et al. (2025). *Recursive Transformer:
|
| 549 |
+
Boosting Reasoning Ability with State Stack.* NeurIPS 2025.
|
| 550 |
+
|
| 551 |
+
stacktrans_num_heads (:obj:`int`, *optional*, defaults to 4):
|
| 552 |
+
Number of independent stack heads ``H``. Each head maintains its
|
| 553 |
+
own low-rank stack of dimension ``ds = stacktrans_stack_d_model //
|
| 554 |
+
H``. The paper ablation (Fig. 4a) shows that ``H = 4`` is the
|
| 555 |
+
optimal trade-off: performance plateaus past this value while
|
| 556 |
+
overhead grows.
|
| 557 |
+
|
| 558 |
+
stacktrans_stack_slots (:obj:`int`, *optional*, defaults to 24):
|
| 559 |
+
Maximum depth ``S`` of the stack (number of slots). Overflow
|
| 560 |
+
elements are truncated to zero (a form of forgetting). The paper
|
| 561 |
+
ablation (Fig. 4c) shows that ``S = 24`` is optimal; increasing
|
| 562 |
+
to 32 yields no measurable gain.
|
| 563 |
+
|
| 564 |
+
stacktrans_stack_d_model (:obj:`int`, *optional*, defaults to 64):
|
| 565 |
+
Total dimensionality of the low-rank stack space, equal to
|
| 566 |
+
``H Γ ds``. The hidden state is projected down from
|
| 567 |
+
``hidden_size`` to this dimension before stack operations, and
|
| 568 |
+
projected back up afterward. The paper uses ``H=4, ds=16``
|
| 569 |
+
(stack_d_model=64). From the ablation (Fig. 4b), ``ds`` in the
|
| 570 |
+
range 16β64 provides the best efficiencyβquality trade-off.
|
| 571 |
+
|
| 572 |
+
stacktrans_forward_bs (:obj:`int`, *optional*, defaults to 1):
|
| 573 |
+
Batch size of the internal ``k_cache`` and ``action_cache``
|
| 574 |
+
buffers used for autoregressive generation. Must be β₯ the
|
| 575 |
+
actual generation batch size. At training time these buffers are
|
| 576 |
+
never used (``enable_cache=False``). Increasing this above 1
|
| 577 |
+
is only needed for batched generation.
|
| 578 |
+
|
| 579 |
+
use_laurel (:obj:`bool`, *optional*, defaults to ``False``):
|
| 580 |
+
Enable the **LAuReL** framework (Menghani, Kumar & Kumar, ICML
|
| 581 |
+
2025): a learned generalisation of the canonical residual
|
| 582 |
+
connection that augments the residual stream with lightweight
|
| 583 |
+
learnable components, improving model quality with minimal
|
| 584 |
+
parameter overhead.
|
| 585 |
+
|
| 586 |
+
Standard Pre-LN residual connection:
|
| 587 |
+
|
| 588 |
+
.. math::
|
| 589 |
+
x_{i+1} = f(x_i) + x_i
|
| 590 |
+
|
| 591 |
+
LAuReL replaces this with:
|
| 592 |
+
|
| 593 |
+
.. math::
|
| 594 |
+
x_{i+1} = \\alpha \\cdot f(x_i) + g(x_i)
|
| 595 |
+
|
| 596 |
+
where :math:`\\alpha` is a learned scalar and :math:`g` is a
|
| 597 |
+
learned linear function. Two sub-variants are controlled
|
| 598 |
+
independently by ``use_laurel_rw`` and ``use_laurel_lr``.
|
| 599 |
+
|
| 600 |
+
In NeoLLM, LAuReL is applied to **both** residual connections
|
| 601 |
+
per decoder layer (attention and MLP), immediately before GPAS:
|
| 602 |
+
|
| 603 |
+
- Attention: ``GPAS(LAuReL(attn_out, residual_attn))``
|
| 604 |
+
- MLP: ``GPAS(LAuReL(delta_m, residual_mlp))``
|
| 605 |
+
|
| 606 |
+
GPAS then operates on the LAuReL-combined stream with its
|
| 607 |
+
stop-gradient scaling, so the two techniques remain
|
| 608 |
+
structurally orthogonal.
|
| 609 |
+
|
| 610 |
+
**Mutual exclusion**: ``use_laurel`` is incompatible with
|
| 611 |
+
``use_attn_res``, ``use_mudd``, and ``use_dca`` because those
|
| 612 |
+
three replace the residual streams (``residual_attn``,
|
| 613 |
+
``residual_mlp``) with custom-aggregated tensors whose
|
| 614 |
+
statistical properties differ from the standard accumulated
|
| 615 |
+
hidden state that LAuReL's initialisation guarantees assume.
|
| 616 |
+
``use_laurel`` is compatible with ``use_stacktrans``.
|
| 617 |
+
|
| 618 |
+
Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
|
| 619 |
+
*LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.
|
| 620 |
+
|
| 621 |
+
use_laurel_rw (:obj:`bool`, *optional*, defaults to ``True``):
|
| 622 |
+
Enable the **Residual Weights** (RW) sub-variant of LAuReL.
|
| 623 |
+
Requires ``use_laurel=True``.
|
| 624 |
+
|
| 625 |
+
Assigns independently learned scalar weights to the nonlinear
|
| 626 |
+
component :math:`f(x_i)` and the residual :math:`x_i`:
|
| 627 |
+
|
| 628 |
+
.. math::
|
| 629 |
+
x_{i+1} = \\alpha \\cdot f(x_i) + \\beta \\cdot g_{\\text{LR}}(x_i)
|
| 630 |
+
|
| 631 |
+
where :math:`[\\alpha, \\beta] = \\operatorname{softmax}([a, b])`
|
| 632 |
+
with :math:`a, b \\in \\mathbb{R}` learnable scalars (2 parameters
|
| 633 |
+
per LAuReL instantiation). Softmax normalisation prevents
|
| 634 |
+
unbounded growth confirmed by the paper's ablation.
|
| 635 |
+
|
| 636 |
+
Initialized with :math:`a = b = 0`, giving
|
| 637 |
+
:math:`\\alpha = \\beta = 0.5` at step 0 β the model quickly
|
| 638 |
+
learns the optimal weighting. In earlier layers the nonlinear
|
| 639 |
+
component dominates; in deeper layers the residual gains
|
| 640 |
+
relative importance, adaptively mitigating vanishing gradients.
|
| 641 |
+
|
| 642 |
+
When ``use_laurel_lr=False``, the RW formula becomes:
|
| 643 |
+
|
| 644 |
+
.. math::
|
| 645 |
+
x_{i+1} = \\alpha \\cdot f(x_i) + \\beta \\cdot x_i
|
| 646 |
+
|
| 647 |
+
When both RW and LR are active (recommended combination):
|
| 648 |
+
|
| 649 |
+
.. math::
|
| 650 |
+
x_{i+1} = \\alpha \\cdot f(x_i)
|
| 651 |
+
+ \\beta \\cdot (B A x_i + x_i)
|
| 652 |
+
|
| 653 |
+
use_laurel_lr (:obj:`bool`, *optional*, defaults to ``True``):
|
| 654 |
+
Enable the **Low-Rank** (LR) sub-variant of LAuReL.
|
| 655 |
+
Requires ``use_laurel=True``.
|
| 656 |
+
|
| 657 |
+
Introduces a learnable low-rank linear transformation of the
|
| 658 |
+
residual :math:`x_i` that runs in parallel with the nonlinear
|
| 659 |
+
:math:`f(x_i)`:
|
| 660 |
+
|
| 661 |
+
.. math::
|
| 662 |
+
x_{i+1} = f(x_i) + W x_i + x_i,
|
| 663 |
+
\\quad W = A B + I,
|
| 664 |
+
\\quad A \\in \\mathbb{R}^{D \\times r},\\;
|
| 665 |
+
B^{\\top} \\in \\mathbb{R}^{D \\times r}
|
| 666 |
+
|
| 667 |
+
Equivalently written as a LoRA-style decomposition:
|
| 668 |
+
|
| 669 |
+
.. math::
|
| 670 |
+
x_{i+1} = f(x_i) + x_i + \\underbrace{A (B x_i)}_{\\text{low-rank term}}
|
| 671 |
+
|
| 672 |
+
where :math:`B \\in \\mathbb{R}^{r \\times D}` (down-projection,
|
| 673 |
+
column-orthogonal init) and :math:`A \\in \\mathbb{R}^{D \\times r}`
|
| 674 |
+
(up-projection, **zero init**). Zero init on :math:`A` ensures
|
| 675 |
+
the low-rank term is exactly zero at step 0, so the model starts
|
| 676 |
+
as a standard residual connection. Parameter count per LAuReL
|
| 677 |
+
layer: :math:`2rD`.
|
| 678 |
+
|
| 679 |
+
Adds :math:`2rD` parameters per residual connection (2 per decoder
|
| 680 |
+
layer since both attention and MLP residuals are augmented).
|
| 681 |
+
For hidden_size=512 and ``laurel_lr_rank=32``:
|
| 682 |
+
:math:`2 \\times 32 \\times 512 = 32{,}768` parameters per layer.
|
| 683 |
+
|
| 684 |
+
The paper recommends :math:`r \\in \\{32, 48, 64\\}` for LLMs.
|
| 685 |
+
For NeoLLM at 135M with 12 layers, ``r=32`` adds β786K parameters
|
| 686 |
+
(β0.6% of total), within the negligible overhead budget of the
|
| 687 |
+
paper (0.012%β0.1% for 1Bβ4B models).
|
| 688 |
+
|
| 689 |
+
laurel_lr_rank (:obj:`int`, *optional*, defaults to 32):
|
| 690 |
+
Rank :math:`r` of the low-rank matrices :math:`A` and :math:`B`
|
| 691 |
+
in the LAuReL-LR sub-variant. Controls the capacity vs. overhead
|
| 692 |
+
trade-off. The paper's ablation (Figure 3) shows performance
|
| 693 |
+
peaks at :math:`r \\in \\{16, 32\\}` for ResNet; for LLMs,
|
| 694 |
+
:math:`r \\in \\{32, 48, 64\\}` are recommended. Ignored when
|
| 695 |
+
``use_laurel_lr=False``.
|
| 696 |
"""
|
| 697 |
|
| 698 |
model_type = "neollm"
|
|
|
|
| 734 |
directional_routing_temp=3.0,
|
| 735 |
# ββ Attention Residuals (Kimi Team, 2026) βββββββββββββββββββββββββ
|
| 736 |
use_attn_res=False,
|
| 737 |
+
attn_res_num_blocks=2,
|
| 738 |
fan_ratio=0.125,
|
| 739 |
fan_ratio_ffn=0.0625,
|
| 740 |
dropout_rate=0.1,
|
|
|
|
| 772 |
versatile_gumbel_temp_end=0.1,
|
| 773 |
versatile_gumbel_temp_decay=0.99984,
|
| 774 |
versatile_aux_loss_weight=1e-5,
|
| 775 |
+
# ββ DCA (Heddes et al., 2025) βββββββββββββββββββββββββββββββββββββ
|
| 776 |
+
use_dca=False,
|
| 777 |
+
dca_k=1,
|
| 778 |
+
dca_use_final_grn=True,
|
| 779 |
+
dca_grn_eps=1e-6,
|
| 780 |
# ββ MUDD connections (Xiao et al., 2025) βββββββββββββββββββββββββ
|
| 781 |
use_mudd=False,
|
| 782 |
mudd_dense_type="qkvr",
|
| 783 |
mudd_dynamic_dense=True,
|
| 784 |
+
mudd_round64=True,
|
| 785 |
mudd_expand_last=True,
|
| 786 |
mudd_sepln=False,
|
| 787 |
+
# ββ StackTrans (Zhang et al., NeurIPS 2025) βββββββββββββββββββββββ
|
| 788 |
+
use_stacktrans=False,
|
| 789 |
+
stacktrans_num_heads=4,
|
| 790 |
+
stacktrans_stack_slots=24,
|
| 791 |
+
stacktrans_stack_d_model=64,
|
| 792 |
+
stacktrans_forward_bs=1,
|
| 793 |
+
# ββ LAuReL (Menghani, Kumar & Kumar, ICML 2025) βββββββββββββββββββ
|
| 794 |
+
use_laurel=False,
|
| 795 |
+
use_laurel_rw=True,
|
| 796 |
+
use_laurel_lr=True,
|
| 797 |
+
laurel_lr_rank=32,
|
| 798 |
+
# ββ GatedDeltaNet linear attention (Yang et al., 2024) βββββββββββ
|
| 799 |
+
# Replaces full attention every `linear_attention_every_n` layers
|
| 800 |
+
# (0-indexed: layers 2, 5, 8, ... for every_n=3).
|
| 801 |
+
# REPO applies to linear attention layers when both
|
| 802 |
+
# use_repo=True and use_repo_in_linear_attn=True.
|
| 803 |
+
use_linear_attention=False,
|
| 804 |
+
linear_attention_every_n=3,
|
| 805 |
+
use_repo_in_linear_attn=False,
|
| 806 |
+
linear_conv_kernel_dim=4,
|
| 807 |
+
linear_key_head_dim=32,
|
| 808 |
+
linear_value_head_dim=32,
|
| 809 |
+
linear_num_key_heads=8,
|
| 810 |
+
linear_num_value_heads=16,
|
| 811 |
**kwargs,
|
| 812 |
):
|
| 813 |
# ββ Generator / tying consistency βββββββββββββββββββββββββββββββββ
|
|
|
|
| 856 |
f"num_hidden_layers={num_hidden_layers}."
|
| 857 |
)
|
| 858 |
|
| 859 |
+
# ββ Residual-replacement mutex ββββββββββββββββββββββββββββββββββββ
|
| 860 |
+
_active = [n for n, f in [('use_dca', use_dca),
|
| 861 |
+
('use_mudd', use_mudd),
|
| 862 |
+
('use_attn_res', use_attn_res)] if f]
|
| 863 |
+
if len(_active) > 1:
|
| 864 |
raise ValueError(
|
| 865 |
+
f"use_dca, use_mudd, and use_attn_res are mutually exclusive. "
|
| 866 |
+
f"Got {_active} simultaneously. Set exactly one to True."
|
| 867 |
+
)
|
| 868 |
+
|
| 869 |
+
# ββ StackTrans / residual-replacement mutex βββββββββββββββββββββββ
|
| 870 |
+
if use_stacktrans and len(_active) > 0:
|
| 871 |
+
raise ValueError(
|
| 872 |
+
f"use_stacktrans is mutually exclusive with use_attn_res, "
|
| 873 |
+
f"use_mudd, and use_dca. Got use_stacktrans=True alongside "
|
| 874 |
+
f"{_active}. Set exactly one residual-replacement flag to True."
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
# ββ LAuReL / residual-replacement mutex βββββββββββββββββββββββββββ
|
| 878 |
+
# LAuReL's initialisation guarantees (BA=0 at step 0) assume that
|
| 879 |
+
# residual_attn and residual_mlp are standard accumulated hidden
|
| 880 |
+
# states. MUDD, DCA, and AttnRes replace these with custom-
|
| 881 |
+
# aggregated tensors, invalidating the assumption.
|
| 882 |
+
# LAuReL IS compatible with use_stacktrans (different position).
|
| 883 |
+
if use_laurel and len(_active) > 0:
|
| 884 |
+
raise ValueError(
|
| 885 |
+
f"use_laurel is mutually exclusive with use_attn_res, "
|
| 886 |
+
f"use_mudd, and use_dca (residual tensors are not standard "
|
| 887 |
+
f"accumulated hidden states when those flags are active). "
|
| 888 |
+
f"Got use_laurel=True alongside {_active}."
|
| 889 |
+
)
|
| 890 |
+
if use_laurel and not (use_laurel_rw or use_laurel_lr):
|
| 891 |
+
raise ValueError(
|
| 892 |
+
"use_laurel=True requires at least one of "
|
| 893 |
+
"use_laurel_rw=True or use_laurel_lr=True."
|
| 894 |
)
|
| 895 |
if use_mudd and mudd_dense_type not in ("qkvr", "l"):
|
| 896 |
raise ValueError(
|
|
|
|
| 1005 |
self.repo_start_layer = repo_start_layer
|
| 1006 |
self.repo_d_p = repo_d_p
|
| 1007 |
|
| 1008 |
+
# ββ DCA (Heddes et al., 2025) βββββββββββββββββββββββββββββββββββββ
|
| 1009 |
+
self.use_dca = use_dca
|
| 1010 |
+
self.dca_k = dca_k
|
| 1011 |
+
self.dca_use_final_grn = dca_use_final_grn
|
| 1012 |
+
self.dca_grn_eps = dca_grn_eps
|
| 1013 |
+
|
| 1014 |
# ββ MUDD connections (Xiao et al., 2025) βββββββββββββββββββββββββ
|
| 1015 |
self.use_mudd = use_mudd
|
| 1016 |
self.mudd_dense_type = mudd_dense_type
|
|
|
|
| 1019 |
self.mudd_expand_last = mudd_expand_last
|
| 1020 |
self.mudd_sepln = mudd_sepln
|
| 1021 |
|
| 1022 |
+
# ββ StackTrans (Zhang et al., NeurIPS 2025) βββββββββββββββββββββββ
|
| 1023 |
+
self.use_stacktrans = use_stacktrans
|
| 1024 |
+
self.stacktrans_num_heads = stacktrans_num_heads
|
| 1025 |
+
self.stacktrans_stack_slots = stacktrans_stack_slots
|
| 1026 |
+
self.stacktrans_stack_d_model = stacktrans_stack_d_model
|
| 1027 |
+
self.stacktrans_forward_bs = stacktrans_forward_bs
|
| 1028 |
+
|
| 1029 |
+
# ββ LAuReL (Menghani, Kumar & Kumar, ICML 2025) βββββββββββββββββββ
|
| 1030 |
+
self.use_laurel = use_laurel
|
| 1031 |
+
self.use_laurel_rw = use_laurel_rw
|
| 1032 |
+
self.use_laurel_lr = use_laurel_lr
|
| 1033 |
+
self.laurel_lr_rank = laurel_lr_rank
|
| 1034 |
+
|
| 1035 |
+
# ββ GatedDeltaNet linear attention ββββββββββββββββββββββββββββββββ
|
| 1036 |
+
self.use_linear_attention = use_linear_attention
|
| 1037 |
+
self.linear_attention_every_n = linear_attention_every_n
|
| 1038 |
+
self.use_repo_in_linear_attn = use_repo_in_linear_attn
|
| 1039 |
+
self.linear_conv_kernel_dim = linear_conv_kernel_dim
|
| 1040 |
+
self.linear_key_head_dim = linear_key_head_dim
|
| 1041 |
+
self.linear_value_head_dim = linear_value_head_dim
|
| 1042 |
+
self.linear_num_key_heads = linear_num_key_heads
|
| 1043 |
+
self.linear_num_value_heads = linear_num_value_heads
|
| 1044 |
+
|
| 1045 |
# ββ VersatileFFN (Nie et al., 2026) βββββββββββββββββββββββββββββββ
|
| 1046 |
self.use_versatile_ffn = use_versatile_ffn
|
| 1047 |
self.versatile_total_experts = versatile_total_experts
|