KitsuVp commited on
Commit
61006d3
Β·
verified Β·
1 Parent(s): a9df2b8

Update configuration_neollm.py

Browse files
Files changed (1) hide show
  1. configuration_neollm.py +104 -0
configuration_neollm.py CHANGED
@@ -349,6 +349,75 @@ class NeoLLMConfig(PretrainedConfig):
349
  is less rich than the full hidden representation. Ignored
350
  when ``use_repo=False``.
351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  Constraints:
353
  - ``use_jtokm=True`` requires ``use_token_generator=True``.
354
  - ``1 ≀ jtokm_top_k < jtokm_num_experts`` when ``use_jtokm=True``.
@@ -358,6 +427,9 @@ class NeoLLMConfig(PretrainedConfig):
358
  - ``repo_start_layer`` must satisfy
359
  ``0 <= repo_start_layer < num_hidden_layers`` when
360
  ``use_repo=True``.
 
 
 
361
 
362
  Examples::
363
 
@@ -413,6 +485,9 @@ class NeoLLMConfig(PretrainedConfig):
413
 
414
  Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). *REPO: Language Models
415
  with Context Re-Positioning.* arXiv:2512.14391.
 
 
 
416
  """
417
 
418
  model_type = "neollm"
@@ -492,6 +567,11 @@ class NeoLLMConfig(PretrainedConfig):
492
  versatile_gumbel_temp_end=0.1,
493
  versatile_gumbel_temp_decay=0.99984,
494
  versatile_aux_loss_weight=1e-5,
 
 
 
 
 
495
  **kwargs,
496
  ):
497
  # ── Generator / tying consistency ─────────────────────────────────
@@ -554,6 +634,24 @@ class NeoLLMConfig(PretrainedConfig):
554
  f"`versatile_total_experts` ({versatile_total_experts})."
555
  )
556
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
558
 
559
  # ── Core Transformer ──────────────────────────────────────────────
@@ -658,6 +756,12 @@ class NeoLLMConfig(PretrainedConfig):
658
  self.versatile_gumbel_temp_decay = versatile_gumbel_temp_decay
659
  self.versatile_aux_loss_weight = versatile_aux_loss_weight
660
 
 
 
 
 
 
 
661
  self.auto_map = {
662
  "AutoConfig": "configuration_neollm.NeoLLMConfig",
663
  "AutoModel": "modeling_neollm.NeoLLMModel",
 
349
  is less rich than the full hidden representation. Ignored
350
  when ``use_repo=False``.
351
 
352
+ use_laurel (:obj:`bool`, *optional*, defaults to ``False``):
353
+ Enable the Learned Augmented Residual Layer (LAUREL) framework
354
+ (Menghani, Kumar & Kumar, ICML 2025). LAUREL generalises the
355
+ canonical residual connection:
356
+
357
+ .. math::
358
+ x_{i+1} = \\alpha \\cdot f(x_i) + g(x_i)
359
+
360
+ where :math:`g` is a learned linear function operating on the
361
+ residual stream. Applied independently to both the attention
362
+ and MLP sublayers of every decoder layer.
363
+
364
+ At least one of ``use_laurel_rw`` or ``use_laurel_lr`` must be
365
+ ``True`` when this flag is active; both may be active
366
+ simultaneously, producing the combined **LAUREL-RW+LR** variant
367
+ (paper eq. 5).
368
+
369
+ Incompatible with ``use_attn_res=True`` β€” both methods modify
370
+ the residual stream and their interaction is undefined.
371
+
372
+ Reference: Menghani, G., Kumar, R. & Kumar, S. (2025).
373
+ *LAUREL: Learned Augmented Residual Layer.* ICML 2025.
374
+
375
+ use_laurel_rw (:obj:`bool`, *optional*, defaults to ``True``):
376
+ Enable the **LAUREL-RW** (Residual Weights) variant. Assigns
377
+ independent learned scalars :math:`\\alpha, \\beta` to the
378
+ sublayer output and residual respectively:
379
+
380
+ .. math::
381
+ x_{i+1} = \\alpha_s \\cdot f(x_i) + \\beta_s \\cdot x_i
382
+
383
+ :math:`\\alpha_s, \\beta_s = \\text{softmax}([\\tilde{\\alpha},
384
+ \\tilde{\\beta}])` so that they are non-negative and sum to 1,
385
+ preventing unbounded growth (paper Β§2.1). Adds **2 parameters
386
+ per sublayer** (4 per decoder layer).
387
+
388
+ When combined with ``use_laurel_lr=True`` (LAUREL-RW+LR,
389
+ paper eq. 5):
390
+
391
+ .. math::
392
+ x_{i+1} = \\alpha_s \\cdot f(x_i)
393
+ + \\beta_s \\cdot (B A x_i + x_i)
394
+
395
+ Ignored when ``use_laurel=False``.
396
+
397
+ use_laurel_lr (:obj:`bool`, *optional*, defaults to ``False``):
398
+ Enable the **LAUREL-LR** (Low-Rank) variant. Augments the
399
+ residual with a rank-``laurel_lr_rank`` correction:
400
+
401
+ .. math::
402
+ x_{i+1} = f(x_i) + B A x_i + x_i
403
+
404
+ where :math:`A \\in \\mathbb{R}^{D \\times r}` and
405
+ :math:`B \\in \\mathbb{R}^{r \\times D}` are learnable matrices
406
+ (paper eq. 3). :math:`A` is initialised with column-orthogonal
407
+ values :math:`A_{i,j} = 1/\\sqrt{rD}` if :math:`i \\bmod r = j`
408
+ else 0; :math:`B` is initialised to zero β€” matching the LoRA
409
+ convention and ensuring the residual starts as identity
410
+ (paper Β§3.3). Adds **2Β·rΒ·D parameters per sublayer**
411
+ (4Β·rΒ·D per decoder layer).
412
+
413
+ Ignored when ``use_laurel=False``.
414
+
415
+ laurel_lr_rank (:obj:`int`, *optional*, defaults to ``32``):
416
+ Rank ``r`` of the low-rank matrices in LAUREL-LR. The paper
417
+ recommends :math:`r \\in \\{32, 48, 64\\}` for LLMs
418
+ (paper Β§3.3). Ignored when ``use_laurel=False`` or
419
+ ``use_laurel_lr=False``.
420
+
421
  Constraints:
422
  - ``use_jtokm=True`` requires ``use_token_generator=True``.
423
  - ``1 ≀ jtokm_top_k < jtokm_num_experts`` when ``use_jtokm=True``.
 
427
  - ``repo_start_layer`` must satisfy
428
  ``0 <= repo_start_layer < num_hidden_layers`` when
429
  ``use_repo=True``.
430
+ - ``use_laurel=True`` is incompatible with ``use_attn_res=True``.
431
+ - When ``use_laurel=True``, at least one of ``use_laurel_rw`` or
432
+ ``use_laurel_lr`` must be ``True``.
433
 
434
  Examples::
435
 
 
485
 
486
  Li, H., Zhao, T., Cai, D. & Sproat, R. (2026). *REPO: Language Models
487
  with Context Re-Positioning.* arXiv:2512.14391.
488
+
489
+ Menghani, G., Kumar, R. & Kumar, S. (2025). *LAUREL: Learned Augmented
490
+ Residual Layer.* ICML 2025. arXiv:2411.07501.
491
  """
492
 
493
  model_type = "neollm"
 
567
  versatile_gumbel_temp_end=0.1,
568
  versatile_gumbel_temp_decay=0.99984,
569
  versatile_aux_loss_weight=1e-5,
570
+ # ── LAuReL: Learned Augmented Residual Layer (Menghani et al., 2025) ─
571
+ use_laurel=True,
572
+ use_laurel_rw=True,
573
+ use_laurel_lr=True,
574
+ laurel_lr_rank=32,
575
  **kwargs,
576
  ):
577
  # ── Generator / tying consistency ─────────────────────────────────
 
634
  f"`versatile_total_experts` ({versatile_total_experts})."
635
  )
636
 
637
+ # ── LAuReL: mutual exclusion and sub-flag consistency ─────────────
638
+ # use_laurel and use_attn_res both modify the residual stream and are
639
+ # structurally incompatible: AttnRes replaces the accumulation entirely
640
+ # with learned depth-wise attention, while LAuReL scales/augments the
641
+ # additive residual in-place.
642
+ if use_laurel and use_attn_res:
643
+ raise ValueError(
644
+ "`use_laurel=True` is incompatible with `use_attn_res=True`. "
645
+ "Both methods modify the residual stream: AttnRes replaces it "
646
+ "with depth-wise softmax attention, while LAuReL applies learned "
647
+ "scalar/low-rank augmentation in-place. Enable at most one."
648
+ )
649
+ if use_laurel and not use_laurel_rw and not use_laurel_lr:
650
+ raise ValueError(
651
+ "`use_laurel=True` requires at least one sub-variant to be active. "
652
+ "Set `use_laurel_rw=True` and/or `use_laurel_lr=True`."
653
+ )
654
+
655
  super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
656
 
657
  # ── Core Transformer ──────────────────────────────────────────────
 
756
  self.versatile_gumbel_temp_decay = versatile_gumbel_temp_decay
757
  self.versatile_aux_loss_weight = versatile_aux_loss_weight
758
 
759
+ # ── LAuReL: Learned Augmented Residual Layer (Menghani et al., 2025) ─
760
+ self.use_laurel = use_laurel
761
+ self.use_laurel_rw = use_laurel_rw
762
+ self.use_laurel_lr = use_laurel_lr
763
+ self.laurel_lr_rank = laurel_lr_rank
764
+
765
  self.auto_map = {
766
  "AutoConfig": "configuration_neollm.NeoLLMConfig",
767
  "AutoModel": "modeling_neollm.NeoLLMModel",