KitsuVp committed on
Commit 761a91d · verified · 1 Parent(s): 2574d55

Update configuration_neollm.py

Files changed (1)
  1. configuration_neollm.py +3 -374
configuration_neollm.py CHANGED
@@ -413,286 +413,6 @@ class NeoLLMConfig(PretrainedConfig):
413

414
  Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). *REPO: Language Models
415
  with Context Re-Positioning.* arXiv:2512.14391.
416
-
417
- Heddes, M. et al. (2025). *DeepCrossAttention: Supercharging
418
- Transformer Residual Connections.* arXiv:2502.06785.
419
-
420
- Xiao, D., Meng, Q., Li, S. & Yuan, X. (2025). *MUDDFormer: Breaking
421
- Residual Bottlenecks in Transformers via Multiway Dynamic Dense
422
- Connections.* arXiv:2502.12170.
423
-
424
- use_dca (:obj:`bool`, *optional*, defaults to ``False``):
425
- Enable **DeepCrossAttention (DCA)** (Heddes et al., 2025).
426
- Replaces standard residual connections with three independent
427
- GRN-v3 modules per decoder layer (one each for the Query,
428
- Key, and Value streams) that dynamically aggregate outputs of
429
- all preceding layers with both dimension-dependent (static) and
430
- input-dependent (dynamic) learned weights.
431
-
432
- **k-DCA efficiency** (``dca_k``): only the first and last ``k``
433
- layer outputs are kept in the stack, bounding memory and
434
- computation at ``O(2k)`` rather than ``O(L)``.
435
-
436
- **Mutual exclusion**: ``use_dca``, ``use_attn_res``, and
437
- ``use_mudd`` are mutually exclusive. Set exactly one to True.
438
-
439
- Reference: Heddes, M. et al. (2025). *DeepCrossAttention:
440
- Supercharging Transformer Residual Connections.*
441
- arXiv:2502.06785.
442
-
443
- dca_k (:obj:`int`, *optional*, defaults to 2):
444
- Number of first and last layer outputs retained in the
445
- depth-wise stack for k-DCA. With ``k=2`` the stack contains
446
- at most 4 tensors (first 2 + last 2) regardless of depth.
447
-
448
- Paper Table 1 results (24-layer model on LM1B):
449
-
450
- - ``k=1``: 0.33× time to transformer PPL, PPL 14.48
451
- - ``k=2``: 0.33× time, PPL 14.41 ← recommended default
452
- - ``k=4``: 0.37× time, PPL 14.50
453
- - ``k=24`` (full): 0.39× time, PPL 14.35
454
-
455
- Smaller k gives faster training and lower inference latency
456
- at a very small perplexity cost. k=2 is the best
457
- efficiency-quality trade-off per the paper.
458
-
459
- dca_use_final_grn (:obj:`bool`, *optional*, defaults to ``True``):
460
- Apply a final GRN-v3 aggregation (``num_outputs=1``) over the
461
- k-selected depth stack after all decoder layers, before the
462
- output norm and lm_head. This matches the DCAGPT architecture
463
- exactly. The final GRN collapses all depth history into the
464
- final hidden representation using learned weights, rather than
465
- using the raw last-layer output.
466
-
467
- dca_grn_eps (:obj:`float`, *optional*, defaults to 1e-6):
468
- Epsilon for the no-scale RMSNorm inside each GRN-v3 module.
469
-
470
- use_mudd (:obj:`bool`, *optional*, defaults to ``False``):
471
- Enable **Multiway Dynamic Dense (MUDD) connections** (Xiao et al.,
472
- 2025). Replaces standard residual connections with learned,
473
- input-dependent depth-wise aggregation over all preceding layer
474
- outputs, producing up to four decoupled input streams (Q, K, V, R)
475
- for each Transformer block.
476
-
477
- **Mutually exclusive with** ``use_attn_res``. Both mechanisms
478
- replace residual aggregation and cannot be active simultaneously.
479
-
480
- Reference: Xiao, D. et al. (2025). *MUDDFormer: Breaking Residual
481
- Bottlenecks in Transformers via Multiway Dynamic Dense Connections.*
482
- arXiv:2502.12170.
483
-
484
- mudd_dense_type (:obj:`str`, *optional*, defaults to ``"qkvr"``):
485
- Stream configuration for the DA modules. Two options:
486
-
487
- - ``"qkvr"``: four independent aggregated streams, one each for the
488
- Query, Key, Value and Residual inputs of every Transformer block.
489
- This is the full MUDDFormer configuration and the main
490
- contribution of the paper. Cross-layer communication bandwidth is
491
- expanded 4× relative to single-stream approaches.
492
- - ``"l"``: a single aggregated stream applied only to the residual
493
- path (equivalent to DDFormer / DenseFormer-dynamic).
494
-
495
- Ablation (Table 5 of the paper): removing any single stream hurts
496
- performance; the value stream benefits most.
497
-
498
- mudd_dynamic_dense (:obj:`bool`, *optional*, defaults to ``True``):
499
- Whether to generate connection weights dynamically from the current
500
- hidden state (``True``, MUDDFormer) or use only learned static
501
- scalar weights (``False``, equivalent to DenseFormer).
502
-
503
- Dynamic weights are computed position-wise via a two-layer MLP:
504
-
505
- .. math::
506
- A_i(X_i) = \text{GELU}(\text{RMSNorm}(X_i)\,W_1)\,W_2 + a_i
507
-
508
- where :math:`a_i` is a learnable static prior (initialized as
509
- identity on the current layer). Setting this to ``False`` disables
510
- :math:`W_1` and :math:`W_2`, retaining only the static bias.
511
-
512
- mudd_round64 (:obj:`bool`, *optional*, defaults to ``True``):
513
- Round the inner hidden dimension of each DA module up to the
514
- nearest multiple of 64 for hardware-aligned tensor operations.
515
- Recommended for training on CUDA devices. Slightly increases
516
- parameter count but improves throughput.
517
-
518
- mudd_expand_last (:obj:`bool`, *optional*, defaults to ``True``):
519
- Multiply the DA module hidden dimension by 4 for the final
520
- Transformer layer. The last layer's aggregation benefits from
521
- higher capacity because it summarizes the entire depth of the
522
- network before the output projection.
523
-
524
- mudd_sepln (:obj:`bool`, *optional*, defaults to ``False``):
525
- Use separate SeeDNorm pre-normalization layers for the K and V
526
- input streams (Q already uses the existing ``input_layernorm``).
527
- Enables independent rescaling per stream when
528
- ``mudd_dense_type="qkvr"``. Adds 2 × SeeDNorm parameters per
529
- decoder layer. Ignored when ``mudd_dense_type="l"``.
530
-
531
- use_stacktrans (:obj:`bool`, *optional*, defaults to ``False``):
532
- Enable **StackTrans** (Zhang et al., NeurIPS 2025): inserts a
533
- differentiable multi-head hidden-state stack between each pair of
534
- Transformer layers, providing an explicit push/pop memory that
535
- allows the model to learn Chomsky-hierarchy grammars (regular
536
- expressions and deterministic context-free grammars) and improves
537
- compositional generalisation and reasoning.
538
-
539
- The stack is positioned at the **very beginning** of each decoder
540
- layer forward pass, before the attention sublayer, so the
541
- attention computation sees the stack-enriched hidden state.
542
-
543
- **Mutual exclusion**: ``use_stacktrans`` cannot be active
544
- simultaneously with ``use_attn_res``, ``use_mudd``, or
545
- ``use_dca`` because all four alter the information flow entering
546
- the residual stream. Set exactly one to ``True``.
547
-
548
- Reference: Zhang, K. et al. (2025). *Recursive Transformer:
549
- Boosting Reasoning Ability with State Stack.* NeurIPS 2025.
550
-
551
- stacktrans_num_heads (:obj:`int`, *optional*, defaults to 4):
552
- Number of independent stack heads ``H``. Each head maintains its
553
- own low-rank stack of dimension ``ds = stacktrans_stack_d_model //
554
- H``. The paper ablation (Fig. 4a) shows that ``H = 4`` is the
555
- optimal trade-off: performance plateaus past this value while
556
- overhead grows.
557
-
558
- stacktrans_stack_slots (:obj:`int`, *optional*, defaults to 24):
559
- Maximum depth ``S`` of the stack (number of slots). Overflow
560
- elements are truncated to zero (a form of forgetting). The paper
561
- ablation (Fig. 4c) shows that ``S = 24`` is optimal; increasing
562
- to 32 yields no measurable gain.
563
-
564
- stacktrans_stack_d_model (:obj:`int`, *optional*, defaults to 64):
565
- Total dimensionality of the low-rank stack space, equal to
566
- ``H × ds``. The hidden state is projected down from
567
- ``hidden_size`` to this dimension before stack operations, and
568
- projected back up afterward. The paper uses ``H=4, ds=16``
569
- (stack_d_model=64). From the ablation (Fig. 4b), ``ds`` in the
570
- range 16–64 provides the best efficiency–quality trade-off.
571
-
572
- stacktrans_forward_bs (:obj:`int`, *optional*, defaults to 1):
573
- Batch size of the internal ``k_cache`` and ``action_cache``
574
- buffers used for autoregressive generation. Must be ≥ the
575
- actual generation batch size. At training time these buffers are
576
- never used (``enable_cache=False``). Increasing this above 1
577
- is only needed for batched generation.
578
-
579
- use_laurel (:obj:`bool`, *optional*, defaults to ``False``):
580
- Enable the **LAuReL** framework (Menghani, Kumar & Kumar, ICML
581
- 2025): a learned generalisation of the canonical residual
582
- connection that augments the residual stream with lightweight
583
- learnable components, improving model quality with minimal
584
- parameter overhead.
585
-
586
- Standard Pre-LN residual connection:
587
-
588
- .. math::
589
- x_{i+1} = f(x_i) + x_i
590
-
591
- LAuReL replaces this with:
592
-
593
- .. math::
594
- x_{i+1} = \\alpha \\cdot f(x_i) + g(x_i)
595
-
596
- where :math:`\\alpha` is a learned scalar and :math:`g` is a
597
- learned linear function. Two sub-variants are controlled
598
- independently by ``use_laurel_rw`` and ``use_laurel_lr``.
599
-
600
- In NeoLLM, LAuReL is applied to **both** residual connections
601
- per decoder layer (attention and MLP), immediately before GPAS:
602
-
603
- - Attention: ``GPAS(LAuReL(attn_out, residual_attn))``
604
- - MLP: ``GPAS(LAuReL(delta_m, residual_mlp))``
605
-
606
- GPAS then operates on the LAuReL-combined stream with its
607
- stop-gradient scaling, so the two techniques remain
608
- structurally orthogonal.
609
-
610
- **Mutual exclusion**: ``use_laurel`` is incompatible with
611
- ``use_attn_res``, ``use_mudd``, and ``use_dca`` because those
612
- three replace the residual streams (``residual_attn``,
613
- ``residual_mlp``) with custom-aggregated tensors whose
614
- statistical properties differ from the standard accumulated
615
- hidden state that LAuReL's initialisation guarantees assume.
616
- ``use_laurel`` is compatible with ``use_stacktrans``.
617
-
618
- Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
619
- *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.
620
-
621
- use_laurel_rw (:obj:`bool`, *optional*, defaults to ``True``):
622
- Enable the **Residual Weights** (RW) sub-variant of LAuReL.
623
- Requires ``use_laurel=True``.
624
-
625
- Assigns independently learned scalar weights to the nonlinear
626
- component :math:`f(x_i)` and the residual :math:`x_i`:
627
-
628
- .. math::
629
- x_{i+1} = \\alpha \\cdot f(x_i) + \\beta \\cdot g_{\\text{LR}}(x_i)
630
-
631
- where :math:`[\\alpha, \\beta] = \\operatorname{softmax}([a, b])`
632
- with :math:`a, b \\in \\mathbb{R}` learnable scalars (2 parameters
633
- per LAuReL instantiation). Softmax normalisation prevents
634
- unbounded growth confirmed by the paper's ablation.
635
-
636
- Initialized with :math:`a = b = 0`, giving
637
- :math:`\\alpha = \\beta = 0.5` at step 0; the model quickly
638
- learns the optimal weighting. In earlier layers the nonlinear
639
- component dominates; in deeper layers the residual gains
640
- relative importance, adaptively mitigating vanishing gradients.
641
-
642
- When ``use_laurel_lr=False``, the RW formula becomes:
643
-
644
- .. math::
645
- x_{i+1} = \\alpha \\cdot f(x_i) + \\beta \\cdot x_i
646
-
647
- When both RW and LR are active (recommended combination):
648
-
649
- .. math::
650
- x_{i+1} = \\alpha \\cdot f(x_i)
651
- + \\beta \\cdot (B A x_i + x_i)
652
-
653
- use_laurel_lr (:obj:`bool`, *optional*, defaults to ``True``):
654
- Enable the **Low-Rank** (LR) sub-variant of LAuReL.
655
- Requires ``use_laurel=True``.
656
-
657
- Introduces a learnable low-rank linear transformation of the
658
- residual :math:`x_i` that runs in parallel with the nonlinear
659
- :math:`f(x_i)`:
660
-
661
- .. math::
662
- x_{i+1} = f(x_i) + W x_i + x_i,
663
- \\quad W = A B + I,
664
- \\quad A \\in \\mathbb{R}^{D \\times r},\\;
665
- B^{\\top} \\in \\mathbb{R}^{D \\times r}
666
-
667
- Equivalently written as a LoRA-style decomposition:
668
-
669
- .. math::
670
- x_{i+1} = f(x_i) + x_i + \\underbrace{A (B x_i)}_{\\text{low-rank term}}
671
-
672
- where :math:`B \\in \\mathbb{R}^{r \\times D}` (down-projection,
673
- column-orthogonal init) and :math:`A \\in \\mathbb{R}^{D \\times r}`
674
- (up-projection, **zero init**). Zero init on :math:`A` ensures
675
- the low-rank term is exactly zero at step 0, so the model starts
676
- as a standard residual connection. Parameter count per LAuReL
677
- layer: :math:`2rD`.
678
-
679
- Adds :math:`2rD` parameters per residual connection (2 per decoder
680
- layer since both attention and MLP residuals are augmented).
681
- For hidden_size=512 and ``laurel_lr_rank=32``:
682
- :math:`2 \\times 32 \\times 512 = 32{,}768` parameters per layer.
683
-
684
- The paper recommends :math:`r \\in \\{32, 48, 64\\}` for LLMs.
685
- For NeoLLM at 135M with 12 layers, ``r=32`` adds ≈786K parameters
686
- (≈0.6% of total), within the negligible overhead budget of the
687
- paper (0.012%–0.1% for 1B–4B models).
688
-
689
- laurel_lr_rank (:obj:`int`, *optional*, defaults to 32):
690
- Rank :math:`r` of the low-rank matrices :math:`A` and :math:`B`
691
- in the LAuReL-LR sub-variant. Controls the capacity vs. overhead
692
- trade-off. The paper's ablation (Figure 3) shows performance
693
- peaks at :math:`r \\in \\{16, 32\\}` for ResNet; for LLMs,
694
- :math:`r \\in \\{32, 48, 64\\}` are recommended. Ignored when
695
- ``use_laurel_lr=False``.
696
  """
697
 
698
  model_type = "neollm"
@@ -719,7 +439,7 @@ class NeoLLMConfig(PretrainedConfig):
719
  head_dim=64,
720
  use_momentum_attention=True,
721
  momentum_gamma=0.10,
722
- use_mea_attention=False,
723
  mea_component_key_value_heads=None,
724
  mea_groupnorm_eps=1e-6,
725
  use_lucid_attention=True,
@@ -734,7 +454,7 @@ class NeoLLMConfig(PretrainedConfig):
734
  directional_routing_temp=3.0,
735
  # ── Attention Residuals (Kimi Team, 2026) ─────────────────────────
736
  use_attn_res=False,
737
- attn_res_num_blocks=2,
738
  fan_ratio=0.125,
739
  fan_ratio_ffn=0.0625,
740
  dropout_rate=0.1,
@@ -765,36 +485,13 @@ class NeoLLMConfig(PretrainedConfig):
765
  repo_d_p=None,
766
  # ── VersatileFFN (Nie et al., 2026) ───────────────────────────────
767
  use_versatile_ffn=False,
768
- versatile_total_experts=8,
769
  versatile_active_experts=2,
770
  versatile_max_depth=2,
771
  versatile_gumbel_temp_start=5.0,
772
  versatile_gumbel_temp_end=0.1,
773
  versatile_gumbel_temp_decay=0.99984,
774
  versatile_aux_loss_weight=1e-5,
775
- # ── DCA (Heddes et al., 2025) ─────────────────────────────────────
776
- use_dca=False,
777
- dca_k=1,
778
- dca_use_final_grn=False,
779
- dca_grn_eps=1e-6,
780
- # ── MUDD connections (Xiao et al., 2025) ─────────────────────────
781
- use_mudd=False,
782
- mudd_dense_type="qkvr",
783
- mudd_dynamic_dense=False,
784
- mudd_round64=True,
785
- mudd_expand_last=True,
786
- mudd_sepln=False,
787
- # ── StackTrans (Zhang et al., NeurIPS 2025) ───────────────────────
788
- use_stacktrans=False,
789
- stacktrans_num_heads=4,
790
- stacktrans_stack_slots=16,
791
- stacktrans_stack_d_model=32,
792
- stacktrans_forward_bs=1,
793
- # ── LAuReL (Menghani, Kumar & Kumar, ICML 2025) ───────────────────
794
- use_laurel=False,
795
- use_laurel_rw=False,
796
- use_laurel_lr=True,
797
- laurel_lr_rank=32,
798
  **kwargs,
799
  ):
800
  # ── Generator / tying consistency ─────────────────────────────────
@@ -843,47 +540,6 @@ class NeoLLMConfig(PretrainedConfig):
843
  f"num_hidden_layers={num_hidden_layers}."
844
  )
845
 
846
- # ── Residual-replacement mutex ────────────────────────────────────
847
- _active = [n for n, f in [('use_dca', use_dca),
848
- ('use_mudd', use_mudd),
849
- ('use_attn_res', use_attn_res)] if f]
850
- if len(_active) > 1:
851
- raise ValueError(
852
- f"use_dca, use_mudd, and use_attn_res are mutually exclusive. "
853
- f"Got {_active} simultaneously. Set exactly one to True."
854
- )
855
-
856
- # ── StackTrans / residual-replacement mutex ───────────────────────
857
- if use_stacktrans and len(_active) > 0:
858
- raise ValueError(
859
- f"use_stacktrans is mutually exclusive with use_attn_res, "
860
- f"use_mudd, and use_dca. Got use_stacktrans=True alongside "
861
- f"{_active}. Set exactly one residual-replacement flag to True."
862
- )
863
-
864
- # ── LAuReL / residual-replacement mutex ───────────────────────────
865
- # LAuReL's initialisation guarantees (BA=0 at step 0) assume that
866
- # residual_attn and residual_mlp are standard accumulated hidden
867
- # states. MUDD, DCA, and AttnRes replace these with custom-
868
- # aggregated tensors, invalidating the assumption.
869
- # LAuReL IS compatible with use_stacktrans (different position).
870
- if use_laurel and len(_active) > 0:
871
- raise ValueError(
872
- f"use_laurel is mutually exclusive with use_attn_res, "
873
- f"use_mudd, and use_dca (residual tensors are not standard "
874
- f"accumulated hidden states when those flags are active). "
875
- f"Got use_laurel=True alongside {_active}."
876
- )
877
- if use_laurel and not (use_laurel_rw or use_laurel_lr):
878
- raise ValueError(
879
- "use_laurel=True requires at least one of "
880
- "use_laurel_rw=True or use_laurel_lr=True."
881
- )
882
- if use_mudd and mudd_dense_type not in ("qkvr", "l"):
883
- raise ValueError(
884
- f"`mudd_dense_type` must be 'qkvr' or 'l', got '{mudd_dense_type}'."
885
- )
886
-
887
  # ── VersatileFFN: validate expert configuration ────────────────────
888
  if use_versatile_ffn:
889
  if not (1 <= versatile_active_experts < versatile_total_experts):
@@ -992,33 +648,6 @@ class NeoLLMConfig(PretrainedConfig):
992
  self.repo_start_layer = repo_start_layer
993
  self.repo_d_p = repo_d_p
994
 
995
- # ── DCA (Heddes et al., 2025) ─────────────────────────────────────
996
- self.use_dca = use_dca
997
- self.dca_k = dca_k
998
- self.dca_use_final_grn = dca_use_final_grn
999
- self.dca_grn_eps = dca_grn_eps
1000
-
1001
- # ── MUDD connections (Xiao et al., 2025) ─────────────────────────
1002
- self.use_mudd = use_mudd
1003
- self.mudd_dense_type = mudd_dense_type
1004
- self.mudd_dynamic_dense = mudd_dynamic_dense
1005
- self.mudd_round64 = mudd_round64
1006
- self.mudd_expand_last = mudd_expand_last
1007
- self.mudd_sepln = mudd_sepln
1008
-
1009
- # ── StackTrans (Zhang et al., NeurIPS 2025) ───────────────────────
1010
- self.use_stacktrans = use_stacktrans
1011
- self.stacktrans_num_heads = stacktrans_num_heads
1012
- self.stacktrans_stack_slots = stacktrans_stack_slots
1013
- self.stacktrans_stack_d_model = stacktrans_stack_d_model
1014
- self.stacktrans_forward_bs = stacktrans_forward_bs
1015
-
1016
- # ── LAuReL (Menghani, Kumar & Kumar, ICML 2025) ───────────────────
1017
- self.use_laurel = use_laurel
1018
- self.use_laurel_rw = use_laurel_rw
1019
- self.use_laurel_lr = use_laurel_lr
1020
- self.laurel_lr_rank = laurel_lr_rank
1021
-
1022
  # ── VersatileFFN (Nie et al., 2026) ───────────────────────────────
1023
  self.use_versatile_ffn = use_versatile_ffn
1024
  self.versatile_total_experts = versatile_total_experts
 
413

414
  Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). *REPO: Language Models
415
  with Context Re-Positioning.* arXiv:2512.14391.
416
  """
417
 
418
  model_type = "neollm"
 
439
  head_dim=64,
440
  use_momentum_attention=True,
441
  momentum_gamma=0.10,
442
+ use_mea_attention=True,
443
  mea_component_key_value_heads=None,
444
  mea_groupnorm_eps=1e-6,
445
  use_lucid_attention=True,
 
454
  directional_routing_temp=3.0,
455
  # ── Attention Residuals (Kimi Team, 2026) ─────────────────────────
456
  use_attn_res=False,
457
+ attn_res_num_blocks=4,
458
  fan_ratio=0.125,
459
  fan_ratio_ffn=0.0625,
460
  dropout_rate=0.1,
 
485
  repo_d_p=None,
486
  # ── VersatileFFN (Nie et al., 2026) ───────────────────────────────
487
  use_versatile_ffn=False,
488
+ versatile_total_experts=4,
489
  versatile_active_experts=2,
490
  versatile_max_depth=2,
491
  versatile_gumbel_temp_start=5.0,
492
  versatile_gumbel_temp_end=0.1,
493
  versatile_gumbel_temp_decay=0.99984,
494
  versatile_aux_loss_weight=1e-5,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
  **kwargs,
496
  ):
497
  # ── Generator / tying consistency ─────────────────────────────────
 
540
  f"num_hidden_layers={num_hidden_layers}."
541
  )
542
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  # ── VersatileFFN: validate expert configuration ────────────────────
544
  if use_versatile_ffn:
545
  if not (1 <= versatile_active_experts < versatile_total_experts):
 
648
  self.repo_start_layer = repo_start_layer
649
  self.repo_d_p = repo_d_p
650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
651
  # ── VersatileFFN (Nie et al., 2026) ───────────────────────────────
652
  self.use_versatile_ffn = use_versatile_ffn
653
  self.versatile_total_experts = versatile_total_experts