KitsuVp committed on
Commit
758569d
·
verified ·
1 Parent(s): 033bc2c

Update configuration_neollm.py

Browse files
Files changed (1) hide show
  1. configuration_neollm.py +27 -0
configuration_neollm.py CHANGED
@@ -247,6 +247,28 @@ class NeoLLMConfig(PretrainedConfig):
247
  Coefficient ``λ`` for the load-balancing auxiliary loss.
248
  jtokm_norm_eps (:obj:`float`, *optional*, defaults to 1e-6):
249
  Epsilon for L2 normalisation of modulation vectors.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  Constraints:
252
  - ``use_jtokm=True`` requires ``use_token_generator=True``.
@@ -353,6 +375,8 @@ class NeoLLMConfig(PretrainedConfig):
353
  jtokm_num_modes=4,
354
  jtokm_aux_loss_weight=1e-4,
355
  jtokm_norm_eps=1e-6,
 
 
356
  **kwargs,
357
  ):
358
  # ── Generator / tying consistency ─────────────────────────────────
@@ -463,6 +487,9 @@ class NeoLLMConfig(PretrainedConfig):
463
  self.jtokm_aux_loss_weight = jtokm_aux_loss_weight
464
  self.jtokm_norm_eps = jtokm_norm_eps
465
 
 
 
 
466
  self.auto_map = {
467
  "AutoConfig": "configuration_neollm.NeoLLMConfig",
468
  "AutoModel": "modeling_neollm.NeoLLMModel",
 
247
  Coefficient ``λ`` for the load-balancing auxiliary loss.
248
  jtokm_norm_eps (:obj:`float`, *optional*, defaults to 1e-6):
249
  Epsilon for L2 normalisation of modulation vectors.
250
+ use_hadamard_o_proj (:obj:`bool`, *optional*, defaults to ``False``):
251
+ Replace the dense ``W_O ∈ R^{d×d}`` output projection in every
252
+ multi-head attention block with a fixed Walsh–Hadamard Transform
253
+ followed by a learnable per-channel affine rescaling
254
+ ``α ⊙ FWHT(x)/√d + β``.
255
+
256
+ The WHT is a parameter-free orthogonal matrix whose singular values
257
+ are all identically 1, so the effective condition number is
258
+ ``κ = 1`` by construction and cannot grow during training. This
259
+ directly addresses the high-κ pathology (κ up to 10^5) observed in
260
+ the dense ``o_proj`` matrices, which causes FP8 per-tensor
261
+ quantisation to lose low-magnitude directions entirely.
262
+
263
+ Parameter reduction: replaces ``d²`` weights with ``2d``
264
+ (``α`` and ``β``), saving ≈25% of attention parameters per block.
265
+ Requires ``hidden_size`` to be a power of 2 (512 ✓, 1024 ✓,
266
+ 768 ✗).
267
+
268
+ Reference: Aggarwal & Kumar (2026). *Rethinking Attention Output
269
+ Projection: Structured Hadamard Transforms for Efficient
270
+ Transformers.* arXiv:2603.08343.
271
+
272
 
273
  Constraints:
274
  - ``use_jtokm=True`` requires ``use_token_generator=True``.
 
375
  jtokm_num_modes=4,
376
  jtokm_aux_loss_weight=1e-4,
377
  jtokm_norm_eps=1e-6,
378
+ # ── Hadamard output projection (Aggarwal & Kumar, 2026) ───────────
379
+ use_hadamard_o_proj=False,
380
  **kwargs,
381
  ):
382
  # ── Generator / tying consistency ─────────────────────────────────
 
487
  self.jtokm_aux_loss_weight = jtokm_aux_loss_weight
488
  self.jtokm_norm_eps = jtokm_norm_eps
489
 
490
+ # ── Hadamard output projection (Aggarwal & Kumar, 2026) ───────────
491
+ self.use_hadamard_o_proj = use_hadamard_o_proj
492
+
493
  self.auto_map = {
494
  "AutoConfig": "configuration_neollm.NeoLLMConfig",
495
  "AutoModel": "modeling_neollm.NeoLLMModel",