KitsuVp committed on
Commit
b964bba
·
verified ·
1 Parent(s): eb219fb

Update configuration_neollm.py

Browse files
Files changed (1) hide show
  1. configuration_neollm.py +309 -7
configuration_neollm.py CHANGED
@@ -414,10 +414,59 @@ class NeoLLMConfig(PretrainedConfig):
414
  Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). *REPO: Language Models
415
  with Context Re-Positioning.* arXiv:2512.14391.
416
 
 
 
 
417
  Xiao, D., Meng, Q., Li, S. & Yuan, X. (2025). *MUDDFormer: Breaking
418
  Residual Bottlenecks in Transformers via Multiway Dynamic Dense
419
  Connections.* arXiv:2502.12170.
420
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  use_mudd (:obj:`bool`, *optional*, defaults to ``False``):
422
  Enable **Multiway Dynamic Dense (MUDD) connections** (Xiao et al.,
423
  2025). Replaces standard residual connections with learned,
@@ -478,6 +527,172 @@ class NeoLLMConfig(PretrainedConfig):
478
  Enables independent rescaling per stream when
479
  ``mudd_dense_type="qkvr"``. Adds 2 × SeeDNorm parameters per
480
  decoder layer. Ignored when ``mudd_dense_type="l"``.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  """
482
 
483
  model_type = "neollm"
@@ -519,7 +734,7 @@ class NeoLLMConfig(PretrainedConfig):
519
  directional_routing_temp=3.0,
520
  # ── Attention Residuals (Kimi Team, 2026) ─────────────────────────
521
  use_attn_res=False,
522
- attn_res_num_blocks=4,
523
  fan_ratio=0.125,
524
  fan_ratio_ffn=0.0625,
525
  dropout_rate=0.1,
@@ -557,13 +772,42 @@ class NeoLLMConfig(PretrainedConfig):
557
  versatile_gumbel_temp_end=0.1,
558
  versatile_gumbel_temp_decay=0.99984,
559
  versatile_aux_loss_weight=1e-5,
 
 
 
 
 
560
  # ── MUDD connections (Xiao et al., 2025) ─────────────────────────
561
  use_mudd=False,
562
  mudd_dense_type="qkvr",
563
  mudd_dynamic_dense=True,
564
- mudd_round64=False,
565
  mudd_expand_last=True,
566
  mudd_sepln=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567
  **kwargs,
568
  ):
569
  # ── Generator / tying consistency ─────────────────────────────────
@@ -612,12 +856,41 @@ class NeoLLMConfig(PretrainedConfig):
612
  f"num_hidden_layers={num_hidden_layers}."
613
  )
614
 
615
- # ── MUDD: validate and resolve ────────────────────────────────────
616
- if use_mudd and use_attn_res:
 
 
 
617
  raise ValueError(
618
- "`use_mudd=True` and `use_attn_res=True` are mutually exclusive. "
619
- "Both mechanisms replace residual aggregation across depth and "
620
- "cannot be active simultaneously. Set exactly one to True."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
  )
622
  if use_mudd and mudd_dense_type not in ("qkvr", "l"):
623
  raise ValueError(
@@ -732,6 +1005,12 @@ class NeoLLMConfig(PretrainedConfig):
732
  self.repo_start_layer = repo_start_layer
733
  self.repo_d_p = repo_d_p
734
 
 
 
 
 
 
 
735
  # ── MUDD connections (Xiao et al., 2025) ─────────────────────────
736
  self.use_mudd = use_mudd
737
  self.mudd_dense_type = mudd_dense_type
@@ -740,6 +1019,29 @@ class NeoLLMConfig(PretrainedConfig):
740
  self.mudd_expand_last = mudd_expand_last
741
  self.mudd_sepln = mudd_sepln
742
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
743
  # ── VersatileFFN (Nie et al., 2026) ───────────────────────────────
744
  self.use_versatile_ffn = use_versatile_ffn
745
  self.versatile_total_experts = versatile_total_experts
 
414
  Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). *REPO: Language Models
415
  with Context Re-Positioning.* arXiv:2512.14391.
416
 
417
+ Heddes, M. et al. (2025). *DeepCrossAttention: Supercharging
418
+ Transformer Residual Connections.* arXiv:2502.06785.
419
+
420
  Xiao, D., Meng, Q., Li, S. & Yuan, X. (2025). *MUDDFormer: Breaking
421
  Residual Bottlenecks in Transformers via Multiway Dynamic Dense
422
  Connections.* arXiv:2502.12170.
423
 
424
+ use_dca (:obj:`bool`, *optional*, defaults to ``False``):
425
+ Enable **DeepCrossAttention (DCA)** (Heddes et al., 2025).
426
+ Replaces standard residual connections with three independent
427
+ GRN-v3 modules per decoder layer — one each for the Query,
428
+ Key, and Value streams — that dynamically aggregate outputs of
429
+ all preceding layers with both dimension-dependent (static) and
430
+ input-dependent (dynamic) learned weights.
431
+
432
+ **k-DCA efficiency** (``dca_k``): only the first and last ``k``
433
+ layer outputs are kept in the stack, bounding memory and
434
+ computation at ``O(2k)`` rather than ``O(L)``.
435
+
436
+ **Mutual exclusion**: ``use_dca``, ``use_attn_res``, and
437
+ ``use_mudd`` are mutually exclusive. Enable at most one of them.
438
+
439
+ Reference: Heddes, M. et al. (2025). *DeepCrossAttention:
440
+ Supercharging Transformer Residual Connections.*
441
+ arXiv:2502.06785.
442
+
443
+ dca_k (:obj:`int`, *optional*, defaults to 1):
444
+ Number of first and last layer outputs retained in the
445
+ depth-wise stack for k-DCA. With ``k=2`` the stack contains
446
+ at most 4 tensors (first 2 + last 2) regardless of depth.
447
+
448
+ Paper Table 1 results (24-layer model on LM1B):
449
+
450
+ - ``k=1`` — 0.33× time to transformer PPL, PPL 14.48
451
+ - ``k=2`` — 0.33× time, PPL 14.41 ← recommended setting
452
+ - ``k=4`` — 0.37× time, PPL 14.50
453
+ - ``k=24`` (full) — 0.39× time, PPL 14.35
454
+
455
+ Smaller k gives faster training and lower inference latency
456
+ at a very small perplexity cost. k=2 is the best
457
+ efficiency-quality trade-off per the paper.
458
+
459
+ dca_use_final_grn (:obj:`bool`, *optional*, defaults to ``True``):
460
+ Apply a final GRN-v3 aggregation (``num_outputs=1``) over the
461
+ k-selected depth stack after all decoder layers, before the
462
+ output norm and lm_head. This matches the DCAGPT architecture
463
+ exactly. The final GRN collapses all depth history into the
464
+ final hidden representation using learned weights, rather than
465
+ using the raw last-layer output.
466
+
467
+ dca_grn_eps (:obj:`float`, *optional*, defaults to 1e-6):
468
+ Epsilon for the no-scale RMSNorm inside each GRN-v3 module.
469
+
470
  use_mudd (:obj:`bool`, *optional*, defaults to ``False``):
471
  Enable **Multiway Dynamic Dense (MUDD) connections** (Xiao et al.,
472
  2025). Replaces standard residual connections with learned,
 
527
  Enables independent rescaling per stream when
528
  ``mudd_dense_type="qkvr"``. Adds 2 × SeeDNorm parameters per
529
  decoder layer. Ignored when ``mudd_dense_type="l"``.
530
+
531
+ use_stacktrans (:obj:`bool`, *optional*, defaults to ``False``):
532
+ Enable **StackTrans** (Zhang et al., NeurIPS 2025): inserts a
533
+ differentiable multi-head hidden-state stack between each pair of
534
+ Transformer layers, providing an explicit push/pop memory that
535
+ allows the model to learn Chomsky-hierarchy grammars (regular
536
+ expressions and deterministic context-free grammars) and improves
537
+ compositional generalisation and reasoning.
538
+
539
+ The stack is positioned at the **very beginning** of each decoder
540
+ layer forward pass, before the attention sublayer, so the
541
+ attention computation sees the stack-enriched hidden state.
542
+
543
+ **Mutual exclusion**: ``use_stacktrans`` cannot be active
544
+ simultaneously with ``use_attn_res``, ``use_mudd``, or
545
+ ``use_dca`` because all four alter the information flow entering
546
+ the residual stream. Enable at most one of them.
547
+
548
+ Reference: Zhang, K. et al. (2025). *Recursive Transformer:
549
+ Boosting Reasoning Ability with State Stack.* NeurIPS 2025.
550
+
551
+ stacktrans_num_heads (:obj:`int`, *optional*, defaults to 4):
552
+ Number of independent stack heads ``H``. Each head maintains its
553
+ own low-rank stack of dimension ``ds = stacktrans_stack_d_model //
554
+ H``. The paper ablation (Fig. 4a) shows that ``H = 4`` is the
555
+ optimal trade-off: performance plateaus past this value while
556
+ overhead grows.
557
+
558
+ stacktrans_stack_slots (:obj:`int`, *optional*, defaults to 24):
559
+ Maximum depth ``S`` of the stack (number of slots). Overflow
560
+ elements are truncated to zero (a form of forgetting). The paper
561
+ ablation (Fig. 4c) shows that ``S = 24`` is optimal; increasing
562
+ to 32 yields no measurable gain.
563
+
564
+ stacktrans_stack_d_model (:obj:`int`, *optional*, defaults to 64):
565
+ Total dimensionality of the low-rank stack space, equal to
566
+ ``H × ds``. The hidden state is projected down from
567
+ ``hidden_size`` to this dimension before stack operations, and
568
+ projected back up afterward. The paper uses ``H=4, ds=16``
569
+ (stack_d_model=64). From the ablation (Fig. 4b), ``ds`` in the
570
+ range 16–64 provides the best efficiency–quality trade-off.
571
+
572
+ stacktrans_forward_bs (:obj:`int`, *optional*, defaults to 1):
573
+ Batch size of the internal ``k_cache`` and ``action_cache``
574
+ buffers used for autoregressive generation. Must be ≥ the
575
+ actual generation batch size. At training time these buffers are
576
+ never used (``enable_cache=False``). Increasing this above 1
577
+ is only needed for batched generation.
578
+
579
+ use_laurel (:obj:`bool`, *optional*, defaults to ``False``):
580
+ Enable the **LAuReL** framework (Menghani, Kumar & Kumar, ICML
581
+ 2025): a learned generalisation of the canonical residual
582
+ connection that augments the residual stream with lightweight
583
+ learnable components, improving model quality with minimal
584
+ parameter overhead.
585
+
586
+ Standard Pre-LN residual connection:
587
+
588
+ .. math::
589
+ x_{i+1} = f(x_i) + x_i
590
+
591
+ LAuReL replaces this with:
592
+
593
+ .. math::
594
+ x_{i+1} = \\alpha \\cdot f(x_i) + g(x_i)
595
+
596
+ where :math:`\\alpha` is a learned scalar and :math:`g` is a
597
+ learned linear function. Two sub-variants are controlled
598
+ independently by ``use_laurel_rw`` and ``use_laurel_lr``.
599
+
600
+ In NeoLLM, LAuReL is applied to **both** residual connections
601
+ per decoder layer (attention and MLP), immediately before GPAS:
602
+
603
+ - Attention: ``GPAS(LAuReL(attn_out, residual_attn))``
604
+ - MLP: ``GPAS(LAuReL(delta_m, residual_mlp))``
605
+
606
+ GPAS then operates on the LAuReL-combined stream with its
607
+ stop-gradient scaling, so the two techniques remain
608
+ structurally orthogonal.
609
+
610
+ **Mutual exclusion**: ``use_laurel`` is incompatible with
611
+ ``use_attn_res``, ``use_mudd``, and ``use_dca`` because those
612
+ three replace the residual streams (``residual_attn``,
613
+ ``residual_mlp``) with custom-aggregated tensors whose
614
+ statistical properties differ from the standard accumulated
615
+ hidden state that LAuReL's initialisation guarantees assume.
616
+ ``use_laurel`` is compatible with ``use_stacktrans``.
617
+
618
+ Reference: Menghani, G., Kumar, R. & Kumar, S. (ICML 2025).
619
+ *LAuReL: Learned Augmented Residual Layer.* arXiv:2411.07501.
620
+
621
+ use_laurel_rw (:obj:`bool`, *optional*, defaults to ``True``):
622
+ Enable the **Residual Weights** (RW) sub-variant of LAuReL.
623
+ Requires ``use_laurel=True``.
624
+
625
+ Assigns independently learned scalar weights to the nonlinear
626
+ component :math:`f(x_i)` and the residual :math:`x_i`:
627
+
628
+ .. math::
629
+ x_{i+1} = \\alpha \\cdot f(x_i) + \\beta \\cdot g_{\\text{LR}}(x_i)
630
+
631
+ where :math:`[\\alpha, \\beta] = \\operatorname{softmax}([a, b])`
632
+ with :math:`a, b \\in \\mathbb{R}` learnable scalars (2 parameters
633
+ per LAuReL instantiation). Softmax normalisation prevents
634
+ unbounded growth confirmed by the paper's ablation.
635
+
636
+ Initialized with :math:`a = b = 0`, giving
637
+ :math:`\\alpha = \\beta = 0.5` at step 0 — the model quickly
638
+ learns the optimal weighting. In earlier layers the nonlinear
639
+ component dominates; in deeper layers the residual gains
640
+ relative importance, adaptively mitigating vanishing gradients.
641
+
642
+ When ``use_laurel_lr=False``, the RW formula becomes:
643
+
644
+ .. math::
645
+ x_{i+1} = \\alpha \\cdot f(x_i) + \\beta \\cdot x_i
646
+
647
+ When both RW and LR are active (recommended combination):
648
+
649
+ .. math::
650
+ x_{i+1} = \\alpha \\cdot f(x_i)
651
+ + \\beta \\cdot (A B x_i + x_i)
652
+
653
+ use_laurel_lr (:obj:`bool`, *optional*, defaults to ``True``):
654
+ Enable the **Low-Rank** (LR) sub-variant of LAuReL.
655
+ Requires ``use_laurel=True``.
656
+
657
+ Introduces a learnable low-rank linear transformation of the
658
+ residual :math:`x_i` that runs in parallel with the nonlinear
659
+ :math:`f(x_i)`:
660
+
661
+ .. math::
662
+ x_{i+1} = f(x_i) + W x_i,
663
+ \\quad W = A B + I,
664
+ \\quad A \\in \\mathbb{R}^{D \\times r},\\;
665
+ B^{\\top} \\in \\mathbb{R}^{D \\times r}
666
+
667
+ Equivalently written as a LoRA-style decomposition:
668
+
669
+ .. math::
670
+ x_{i+1} = f(x_i) + x_i + \\underbrace{A (B x_i)}_{\\text{low-rank term}}
671
+
672
+ where :math:`B \\in \\mathbb{R}^{r \\times D}` (down-projection,
673
+ column-orthogonal init) and :math:`A \\in \\mathbb{R}^{D \\times r}`
674
+ (up-projection, **zero init**). Zero init on :math:`A` ensures
675
+ the low-rank term is exactly zero at step 0, so the model starts
676
+ as a standard residual connection. Parameter count per LAuReL
677
+ layer: :math:`2rD`.
678
+
679
+ Adds :math:`2rD` parameters per residual connection (2 per decoder
680
+ layer since both attention and MLP residuals are augmented).
681
+ For hidden_size=512 and ``laurel_lr_rank=32``:
682
+ :math:`2 \\times 32 \\times 512 = 32{,}768` parameters per residual connection (:math:`65{,}536` per decoder layer).
683
+
684
+ The paper recommends :math:`r \\in \\{32, 48, 64\\}` for LLMs.
685
+ For NeoLLM at 135M with 12 layers, ``r=32`` adds ≈786K parameters
686
+ (≈0.6% of total), which is small but above the overhead reported in
687
+ the paper (0.012%–0.1% for 1B–4B models).
688
+
689
+ laurel_lr_rank (:obj:`int`, *optional*, defaults to 32):
690
+ Rank :math:`r` of the low-rank matrices :math:`A` and :math:`B`
691
+ in the LAuReL-LR sub-variant. Controls the capacity vs. overhead
692
+ trade-off. The paper's ablation (Figure 3) shows performance
693
+ peaks at :math:`r \\in \\{16, 32\\}` for ResNet; for LLMs,
694
+ :math:`r \\in \\{32, 48, 64\\}` are recommended. Ignored when
695
+ ``use_laurel_lr=False``.
696
  """
697
 
698
  model_type = "neollm"
 
734
  directional_routing_temp=3.0,
735
  # ── Attention Residuals (Kimi Team, 2026) ─────────────────────────
736
  use_attn_res=False,
737
+ attn_res_num_blocks=2,
738
  fan_ratio=0.125,
739
  fan_ratio_ffn=0.0625,
740
  dropout_rate=0.1,
 
772
  versatile_gumbel_temp_end=0.1,
773
  versatile_gumbel_temp_decay=0.99984,
774
  versatile_aux_loss_weight=1e-5,
775
+ # ── DCA (Heddes et al., 2025) ─────────────────────────────────────
776
+ use_dca=False,
777
+ dca_k=1,
778
+ dca_use_final_grn=True,
779
+ dca_grn_eps=1e-6,
780
  # ── MUDD connections (Xiao et al., 2025) ─────────────────────────
781
  use_mudd=False,
782
  mudd_dense_type="qkvr",
783
  mudd_dynamic_dense=True,
784
+ mudd_round64=True,
785
  mudd_expand_last=True,
786
  mudd_sepln=False,
787
+ # ── StackTrans (Zhang et al., NeurIPS 2025) ───────────────────────
788
+ use_stacktrans=False,
789
+ stacktrans_num_heads=4,
790
+ stacktrans_stack_slots=24,
791
+ stacktrans_stack_d_model=64,
792
+ stacktrans_forward_bs=1,
793
+ # ── LAuReL (Menghani, Kumar & Kumar, ICML 2025) ───────────────────
794
+ use_laurel=False,
795
+ use_laurel_rw=True,
796
+ use_laurel_lr=True,
797
+ laurel_lr_rank=32,
798
+ # ── GatedDeltaNet linear attention (Yang et al., 2024) ───────────
799
+ # Replaces full attention every `linear_attention_every_n` layers
800
+ # (0-indexed: layers 2, 5, 8, ... for every_n=3).
801
+ # REPO applies to linear attention layers when both
802
+ # use_repo=True and use_repo_in_linear_attn=True.
803
+ use_linear_attention=False,
804
+ linear_attention_every_n=3,
805
+ use_repo_in_linear_attn=False,
806
+ linear_conv_kernel_dim=4,
807
+ linear_key_head_dim=32,
808
+ linear_value_head_dim=32,
809
+ linear_num_key_heads=8,
810
+ linear_num_value_heads=16,
811
  **kwargs,
812
  ):
813
  # ── Generator / tying consistency ─────────────────────────────────
 
856
  f"num_hidden_layers={num_hidden_layers}."
857
  )
858
 
859
+ # ── Residual-replacement mutex ────────────────────────────────────
860
+ _active = [n for n, f in [('use_dca', use_dca),
861
+ ('use_mudd', use_mudd),
862
+ ('use_attn_res', use_attn_res)] if f]
863
+ if len(_active) > 1:
864
  raise ValueError(
865
+ f"use_dca, use_mudd, and use_attn_res are mutually exclusive. "
866
+ f"Got {_active} simultaneously. Set exactly one to True."
867
+ )
868
+
869
+ # ── StackTrans / residual-replacement mutex ───────────────────────
870
+ if use_stacktrans and len(_active) > 0:
871
+ raise ValueError(
872
+ f"use_stacktrans is mutually exclusive with use_attn_res, "
873
+ f"use_mudd, and use_dca. Got use_stacktrans=True alongside "
874
+ f"{_active}. Set exactly one residual-replacement flag to True."
875
+ )
876
+
877
+ # ── LAuReL / residual-replacement mutex ───────────────────────────
878
+ # LAuReL's initialisation guarantees (AB=0 at step 0) assume that
879
+ # residual_attn and residual_mlp are standard accumulated hidden
880
+ # states. MUDD, DCA, and AttnRes replace these with custom-
881
+ # aggregated tensors, invalidating the assumption.
882
+ # LAuReL IS compatible with use_stacktrans (different position).
883
+ if use_laurel and len(_active) > 0:
884
+ raise ValueError(
885
+ f"use_laurel is mutually exclusive with use_attn_res, "
886
+ f"use_mudd, and use_dca (residual tensors are not standard "
887
+ f"accumulated hidden states when those flags are active). "
888
+ f"Got use_laurel=True alongside {_active}."
889
+ )
890
+ if use_laurel and not (use_laurel_rw or use_laurel_lr):
891
+ raise ValueError(
892
+ "use_laurel=True requires at least one of "
893
+ "use_laurel_rw=True or use_laurel_lr=True."
894
  )
895
  if use_mudd and mudd_dense_type not in ("qkvr", "l"):
896
  raise ValueError(
 
1005
  self.repo_start_layer = repo_start_layer
1006
  self.repo_d_p = repo_d_p
1007
 
1008
+ # ── DCA (Heddes et al., 2025) ─────────────────────────────────────
1009
+ self.use_dca = use_dca
1010
+ self.dca_k = dca_k
1011
+ self.dca_use_final_grn = dca_use_final_grn
1012
+ self.dca_grn_eps = dca_grn_eps
1013
+
1014
  # ── MUDD connections (Xiao et al., 2025) ─────────────────────────
1015
  self.use_mudd = use_mudd
1016
  self.mudd_dense_type = mudd_dense_type
 
1019
  self.mudd_expand_last = mudd_expand_last
1020
  self.mudd_sepln = mudd_sepln
1021
 
1022
+ # ── StackTrans (Zhang et al., NeurIPS 2025) ───────────────────────
1023
+ self.use_stacktrans = use_stacktrans
1024
+ self.stacktrans_num_heads = stacktrans_num_heads
1025
+ self.stacktrans_stack_slots = stacktrans_stack_slots
1026
+ self.stacktrans_stack_d_model = stacktrans_stack_d_model
1027
+ self.stacktrans_forward_bs = stacktrans_forward_bs
1028
+
1029
+ # ── LAuReL (Menghani, Kumar & Kumar, ICML 2025) ───────────────────
1030
+ self.use_laurel = use_laurel
1031
+ self.use_laurel_rw = use_laurel_rw
1032
+ self.use_laurel_lr = use_laurel_lr
1033
+ self.laurel_lr_rank = laurel_lr_rank
1034
+
1035
+ # ── GatedDeltaNet linear attention ────────────────────────────────
1036
+ self.use_linear_attention = use_linear_attention
1037
+ self.linear_attention_every_n = linear_attention_every_n
1038
+ self.use_repo_in_linear_attn = use_repo_in_linear_attn
1039
+ self.linear_conv_kernel_dim = linear_conv_kernel_dim
1040
+ self.linear_key_head_dim = linear_key_head_dim
1041
+ self.linear_value_head_dim = linear_value_head_dim
1042
+ self.linear_num_key_heads = linear_num_key_heads
1043
+ self.linear_num_value_heads = linear_num_value_heads
1044
+
1045
  # ── VersatileFFN (Nie et al., 2026) ───────────────────────────────
1046
  self.use_versatile_ffn = use_versatile_ffn
1047
  self.versatile_total_experts = versatile_total_experts