Update configuration_neollm.py
Browse files — configuration_neollm.py: +27 −0
configuration_neollm.py
CHANGED
|
@@ -247,6 +247,28 @@ class NeoLLMConfig(PretrainedConfig):
|
|
| 247 |
Coefficient ``λ`` for the load-balancing auxiliary loss.
|
| 248 |
jtokm_norm_eps (:obj:`float`, *optional*, defaults to 1e-6):
|
| 249 |
Epsilon for L2 normalisation of modulation vectors.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
Constraints:
|
| 252 |
- ``use_jtokm=True`` requires ``use_token_generator=True``.
|
|
@@ -353,6 +375,8 @@ class NeoLLMConfig(PretrainedConfig):
|
|
| 353 |
jtokm_num_modes=4,
|
| 354 |
jtokm_aux_loss_weight=1e-4,
|
| 355 |
jtokm_norm_eps=1e-6,
|
|
|
|
|
|
|
| 356 |
**kwargs,
|
| 357 |
):
|
| 358 |
# ── Generator / tying consistency ─────────────────────────────────
|
|
@@ -463,6 +487,9 @@ class NeoLLMConfig(PretrainedConfig):
|
|
| 463 |
self.jtokm_aux_loss_weight = jtokm_aux_loss_weight
|
| 464 |
self.jtokm_norm_eps = jtokm_norm_eps
|
| 465 |
|
|
|
|
|
|
|
|
|
|
| 466 |
self.auto_map = {
|
| 467 |
"AutoConfig": "configuration_neollm.NeoLLMConfig",
|
| 468 |
"AutoModel": "modeling_neollm.NeoLLMModel",
|
|
|
|
| 247 |
Coefficient ``λ`` for the load-balancing auxiliary loss.
|
| 248 |
jtokm_norm_eps (:obj:`float`, *optional*, defaults to 1e-6):
|
| 249 |
Epsilon for L2 normalisation of modulation vectors.
|
| 250 |
+
use_hadamard_o_proj (:obj:`bool`, *optional*, defaults to ``False``):
|
| 251 |
+
Replace the dense ``W_O ∈ R^{d×d}`` output projection in every
|
| 252 |
+
multi-head attention block with a fixed Walsh–Hadamard Transform
|
| 253 |
+
followed by a learnable per-channel affine rescaling
|
| 254 |
+
``α ⊙ FWHT(x)/√d + β``.
|
| 255 |
+
|
| 256 |
+
The WHT is a parameter-free orthogonal matrix whose singular values
|
| 257 |
+
are all identically 1, so the effective condition number is
|
| 258 |
+
``κ = 1`` by construction and cannot grow during training. This
|
| 259 |
+
directly addresses the high-κ pathology (κ up to 10^5) observed in
|
| 260 |
+
the dense ``o_proj`` matrices, which causes FP8 per-tensor
|
| 261 |
+
quantisation to lose low-magnitude directions entirely.
|
| 262 |
+
|
| 263 |
+
Parameter reduction: replaces ``d²`` weights with ``2d``
|
| 264 |
+
(``α`` and ``β``), saving ≈25% of attention parameters per block.
|
| 265 |
+
Requires ``hidden_size`` to be a power of 2 (512 ✓, 1024 ✓,
|
| 266 |
+
768 ✗).
|
| 267 |
+
|
| 268 |
+
Reference: Aggarwal & Kumar (2026). *Rethinking Attention Output
|
| 269 |
+
Projection: Structured Hadamard Transforms for Efficient
|
| 270 |
+
Transformers.* arXiv:2603.08343.
|
| 271 |
+
|
| 272 |
|
| 273 |
Constraints:
|
| 274 |
- ``use_jtokm=True`` requires ``use_token_generator=True``.
|
|
|
|
| 375 |
jtokm_num_modes=4,
|
| 376 |
jtokm_aux_loss_weight=1e-4,
|
| 377 |
jtokm_norm_eps=1e-6,
|
| 378 |
+
# ── Hadamard output projection (Aggarwal & Kumar, 2026) ───────────
|
| 379 |
+
use_hadamard_o_proj=False,
|
| 380 |
**kwargs,
|
| 381 |
):
|
| 382 |
# ── Generator / tying consistency ─────────────────────────────────
|
|
|
|
| 487 |
self.jtokm_aux_loss_weight = jtokm_aux_loss_weight
|
| 488 |
self.jtokm_norm_eps = jtokm_norm_eps
|
| 489 |
|
| 490 |
+
# ── Hadamard output projection (Aggarwal & Kumar, 2026) ───────────
|
| 491 |
+
self.use_hadamard_o_proj = use_hadamard_o_proj
|
| 492 |
+
|
| 493 |
self.auto_map = {
|
| 494 |
"AutoConfig": "configuration_neollm.NeoLLMConfig",
|
| 495 |
"AutoModel": "modeling_neollm.NeoLLMModel",
|