Update modeling_neollm.py
Browse files- modeling_neollm.py +158 -11
modeling_neollm.py
CHANGED
|
@@ -148,6 +148,23 @@ class PolyNormAnalysis:
|
|
| 148 |
output: Optional[torch.Tensor] = None # final PolyNorm output
|
| 149 |
|
| 150 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
@dataclass
|
| 152 |
class AttentionAnalysis:
|
| 153 |
"""
|
|
@@ -218,6 +235,9 @@ class AttentionAnalysis:
|
|
| 218 |
attn_output_pre_gate: Optional[torch.Tensor] = None # pre gate multiply [B,S,H,d]
|
| 219 |
attn_output_final: Optional[torch.Tensor] = None # after o_proj [B,S,D]
|
| 220 |
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
@dataclass
|
| 223 |
class MLPAnalysis:
|
|
@@ -1617,15 +1637,110 @@ def affine_scaled_flash_attention_forward(
|
|
| 1617 |
# ── Combine and apply dropout to the full affine output ───────────────
|
| 1618 |
output = alpha_t * flash_out + beta_t * v_cumsum_t # [B, S, H_q, d_head]
|
| 1619 |
|
| 1620 |
-
# Apply output dropout on the combined affine result.
|
| 1621 |
-
# This regularises the full [α·flash + β·V_cumsum] output consistently.
|
| 1622 |
-
if dropout > 0.0 and module.training:
|
| 1623 |
-
output = nn.functional.dropout(output, p=dropout, training=True)
|
| 1624 |
-
|
| 1625 |
# attn_weights is None — flash never exposes the softmax weight matrix.
|
| 1626 |
return output, None
|
| 1627 |
|
| 1628 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1629 |
class NeoLLMAttention(nn.Module):
|
| 1630 |
"""
|
| 1631 |
Full attention with FANformer, SeeDNorm, ResFormer, Learnable Multipliers,
|
|
@@ -1644,8 +1759,16 @@ class NeoLLMAttention(nn.Module):
|
|
| 1644 |
→ MEAHeadSeeDNorm → XSA → Directional Routing → reshape
|
| 1645 |
→ o_proj · sigmoid(gate) → dropout
|
| 1646 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1647 |
References:
|
| 1648 |
Directional Routing: Taylor (2026). arXiv:2603.14923.
|
|
|
|
| 1649 |
"""
|
| 1650 |
|
| 1651 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
|
@@ -1674,6 +1797,7 @@ class NeoLLMAttention(nn.Module):
|
|
| 1674 |
self.lucid_attention_eps = float(
|
| 1675 |
getattr(config, "lucid_attention_eps", config.rms_norm_eps)
|
| 1676 |
)
|
|
|
|
| 1677 |
|
| 1678 |
self.fan_layer = FANLayer(
|
| 1679 |
hidden_size=config.hidden_size,
|
|
@@ -1699,10 +1823,23 @@ class NeoLLMAttention(nn.Module):
|
|
| 1699 |
fan_output_dim, self.num_mea_component_heads * self.head_dim,
|
| 1700 |
bias=config.attention_bias,
|
| 1701 |
)
|
| 1702 |
-
|
| 1703 |
-
|
| 1704 |
-
|
| 1705 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1706 |
|
| 1707 |
self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 1708 |
self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
|
|
@@ -2089,7 +2226,12 @@ class NeoLLMAttention(nn.Module):
|
|
| 2089 |
attn_analysis.attn_output_pre_gate = attn_out_flat.detach()
|
| 2090 |
gate_sig = torch.sigmoid(gate)
|
| 2091 |
attn_analysis.gate_sigmoid = gate_sig.detach()
|
| 2092 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2093 |
else:
|
| 2094 |
attn_out_gated = self.o_proj(attn_out_flat * torch.sigmoid(gate))
|
| 2095 |
|
|
@@ -2644,7 +2786,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 2644 |
return LayerAnalysis(
|
| 2645 |
seednorm_pre_attn = SeeDNormAnalysis(),
|
| 2646 |
seednorm_post_attn = SeeDNormAnalysis(),
|
| 2647 |
-
attention = AttentionAnalysis(
|
|
|
|
|
|
|
|
|
|
| 2648 |
mlp = MLPAnalysis(
|
| 2649 |
fan = FANAnalysis(),
|
| 2650 |
polynorm = PolyNormAnalysis(),
|
|
@@ -3098,6 +3243,7 @@ __all__ = [
|
|
| 3098 |
"VectorMultiplier",
|
| 3099 |
"LinearWithMultipliers",
|
| 3100 |
"MEAHeadSeeDNorm",
|
|
|
|
| 3101 |
# Analysis dataclasses — exported so external tools can type-hint against them
|
| 3102 |
"AnalysisState",
|
| 3103 |
"LayerAnalysis",
|
|
@@ -3107,6 +3253,7 @@ __all__ = [
|
|
| 3107 |
"SeeDNormAnalysis",
|
| 3108 |
"GPASAnalysis",
|
| 3109 |
"PolyNormAnalysis",
|
|
|
|
| 3110 |
"JTokMAnalysis",
|
| 3111 |
"AttnResAnalysis",
|
| 3112 |
"GeneratorAnalysis",
|
|
|
|
| 148 |
output: Optional[torch.Tensor] = None # final PolyNorm output
|
| 149 |
|
| 150 |
|
| 151 |
@dataclass
class HadamardAnalysis:
    """
    Internals of a HadamardOProj forward pass.
    Only populated when use_hadamard_o_proj=True.

    Reference: Aggarwal & Kumar (2026). arXiv:2603.08343.

    post_fwht: WHT output before α scaling [..., D] — useful to verify
        that the transform is truly norm-preserving (κ=1 sanity check).
    alpha_snapshot: detached copy of the learnable α vector [D] — tracks how
        per-channel scaling evolves during training analysis.
    """
    # FWHT(x)/√d captured before the α multiply — shape [B, S, D].
    post_fwht: Optional[torch.Tensor] = None
    # Detached view of the learned per-channel scale α — shape [D].
    alpha_snapshot: Optional[torch.Tensor] = None
|
| 166 |
+
|
| 167 |
+
|
| 168 |
@dataclass
|
| 169 |
class AttentionAnalysis:
|
| 170 |
"""
|
|
|
|
| 235 |
attn_output_pre_gate: Optional[torch.Tensor] = None # pre gate multiply [B,S,H,d]
|
| 236 |
attn_output_final: Optional[torch.Tensor] = None # after o_proj [B,S,D]
|
| 237 |
|
| 238 |
+
# ── HadamardOProj internals (conditional on use_hadamard_o_proj) ──
|
| 239 |
+
hadamard: Optional["HadamardAnalysis"] = None # None when dense o_proj active
|
| 240 |
+
|
| 241 |
|
| 242 |
@dataclass
|
| 243 |
class MLPAnalysis:
|
|
|
|
| 1637 |
# ── Combine and apply dropout to the full affine output ───────────────
|
| 1638 |
output = alpha_t * flash_out + beta_t * v_cumsum_t # [B, S, H_q, d_head]
|
| 1639 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1640 |
# attn_weights is None — flash never exposes the softmax weight matrix.
|
| 1641 |
return output, None
|
| 1642 |
|
| 1643 |
|
| 1644 |
class HadamardOProj(nn.Module):
    """
    Parameter-free Walsh–Hadamard output projection with a learnable
    per-channel affine rescaling:  output = α ⊙ (FWHT(x) / √d) + β.

    Replaces the dense W_O ∈ R^{d×d} in multi-head attention with a fixed
    orthogonal Walsh–Hadamard Transform followed by a learnable affine.

    Motivation (Aggarwal & Kumar, 2026, arXiv:2603.08343):
        The standard dense o_proj develops extreme condition numbers during
        training (κ up to 10^5 observed in practice) because the optimiser
        has no incentive to keep singular values balanced — some directions
        are amplified while others collapse toward zero. This makes the
        layer hostile to FP8 quantisation, which uses a single per-tensor
        scale and therefore loses the low-magnitude directions entirely.

        The Walsh–Hadamard Transform is a fixed orthogonal matrix whose
        singular values are all identically 1, making κ = 1 by
        construction. It cannot develop condition-number pathology because
        it has no parameters. The learnable α/β restore per-channel
        expressivity at a cost of 2·d parameters instead of d².

    Properties:
        - Condition number: κ = 1 (exact, permanent, by construction)
        - Parameters: 2·d vs d² for dense (~25% attention params saved)
        - Forward FLOPs: O(d log d) vs O(d²) for dense
        - Norm preservation: FWHT is isometric — ‖FWHT(x)/√d‖₂ = ‖x‖₂
        - FP8 friendliness: single per-tensor scale covers all directions
        - Requires: d must be a power of 2

    The FWHT is an iterative butterfly (Cooley–Tukey pattern over
    additions/subtractions) followed by 1/√d normalisation to produce an
    orthonormal transform (H^T H = I). No external dependency.

    Reference:
        Aggarwal, S. & Kumar, L. (2026). "Rethinking Attention Output
        Projection: Structured Hadamard Transforms for Efficient
        Transformers." arXiv:2603.08343.
    """

    def __init__(self, dim: int, bias: bool = True):
        """
        Args:
            dim:  input/output width; must be a power of 2.
            bias: when True, adds the learnable per-channel β term.

        Raises:
            ValueError: if dim is not a positive power of 2.
        """
        super().__init__()
        # Raise (not assert) so the guard survives `python -O`,
        # which strips assert statements.
        if dim <= 0 or (dim & (dim - 1)) != 0:
            raise ValueError(
                f"HadamardOProj requires dim to be a power of 2, got {dim}"
            )
        self.dim = dim
        self.norm = dim ** -0.5  # 1/√d — makes H^T H = I

        # Learnable affine rescaling: α ⊙ FWHT(x) + β.
        # Initialised to α=1, β=0 so the layer starts as a pure WHT,
        # identical to an orthonormal projection with unit gain.
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim)) if bias else None

    def _fwht(self, x: torch.Tensor) -> torch.Tensor:
        """
        Iterative Fast Walsh–Hadamard Transform over the last dimension
        (unnormalised — the caller applies the 1/√d factor).

        Butterfly pattern: log₂(d) stages, each pairing elements at stride h.
        Cost: d·log₂(d) additions/subtractions, zero multiplications.
        The Python `while` unrolls into a fixed graph once `dim` is known,
        so the op remains torch.compile friendly.
        """
        h = 1
        while h < self.dim:
            # Reshape to expose pairs at the current stride.
            x = x.reshape(*x.shape[:-1], -1, 2 * h)
            a, b = x[..., :h], x[..., h:]
            # Butterfly: (a+b, a-b) — only additions and subtractions.
            x = torch.cat([a + b, a - b], dim=-1)
            x = x.reshape(*x.shape[:-2], self.dim)
            h *= 2
        return x

    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional["HadamardAnalysis"] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: [..., dim] — concatenated multi-head attention outputs.
            analysis: HadamardAnalysis container populated when analysis
                mode is active (eval + model.enable_analysis()). None
                otherwise.

        Returns:
            α ⊙ (FWHT(x) / √dim) + β of shape [..., dim].
        """
        out = self._fwht(x) * self.norm  # normalise: H^T H = I

        if analysis is not None:
            analysis.post_fwht = out.detach()
            # clone() so the snapshot is a true copy — detach() alone
            # shares storage with the parameter and would silently mutate
            # on every optimiser step.
            analysis.alpha_snapshot = self.alpha.detach().clone()

        out = out * self.alpha  # per-channel learnable scale
        if self.beta is not None:
            out = out + self.beta  # per-channel learnable bias
        return out
|
| 1742 |
+
|
| 1743 |
+
|
| 1744 |
class NeoLLMAttention(nn.Module):
|
| 1745 |
"""
|
| 1746 |
Full attention with FANformer, SeeDNorm, ResFormer, Learnable Multipliers,
|
|
|
|
| 1759 |
→ MEAHeadSeeDNorm → XSA → Directional Routing → reshape
|
| 1760 |
→ o_proj · sigmoid(gate) → dropout
|
| 1761 |
|
| 1762 |
+
o_proj variants (controlled by config.use_hadamard_o_proj):
|
| 1763 |
+
False (default): dense LinearWithMultipliers — full expressivity,
|
| 1764 |
+
develops high κ during training (FP8 risk).
|
| 1765 |
+
True: HadamardOProj — fixed WHT + learnable α/β,
|
| 1766 |
+
κ = 1 by construction, 25% fewer attention params,
|
| 1767 |
+
FP8-friendly (Aggarwal & Kumar, 2026, arXiv:2603.08343).
|
| 1768 |
+
|
| 1769 |
References:
|
| 1770 |
Directional Routing: Taylor (2026). arXiv:2603.14923.
|
| 1771 |
+
Hadamard o_proj: Aggarwal & Kumar (2026). arXiv:2603.08343.
|
| 1772 |
"""
|
| 1773 |
|
| 1774 |
def __init__(self, config: NeoLLMConfig, layer_idx: int):
|
|
|
|
| 1797 |
self.lucid_attention_eps = float(
|
| 1798 |
getattr(config, "lucid_attention_eps", config.rms_norm_eps)
|
| 1799 |
)
|
| 1800 |
+
self.use_hadamard_o_proj = getattr(config, "use_hadamard_o_proj", False)
|
| 1801 |
|
| 1802 |
self.fan_layer = FANLayer(
|
| 1803 |
hidden_size=config.hidden_size,
|
|
|
|
| 1823 |
fan_output_dim, self.num_mea_component_heads * self.head_dim,
|
| 1824 |
bias=config.attention_bias,
|
| 1825 |
)
|
| 1826 |
+
# ── Output projection (Aggarwal & Kumar, 2026, arXiv:2603.08343) ────
|
| 1827 |
+
# use_hadamard_o_proj=False (default): dense LinearWithMultipliers.
|
| 1828 |
+
# use_hadamard_o_proj=True: HadamardOProj — fixed WHT + learnable α/β.
|
| 1829 |
+
# κ = 1 by construction, 25% fewer attention params, FP8-friendly.
|
| 1830 |
+
# Requires hidden_size to be a power of 2 (512 ✓, 1024 ✓, 768 ✗).
|
| 1831 |
+
_o_in = config.num_attention_heads * self.head_dim
|
| 1832 |
+
if self.use_hadamard_o_proj:
|
| 1833 |
+
assert _o_in == config.hidden_size, (
|
| 1834 |
+
f"HadamardOProj requires in_dim == out_dim, "
|
| 1835 |
+
f"got {_o_in} vs {config.hidden_size}"
|
| 1836 |
+
)
|
| 1837 |
+
self.o_proj = HadamardOProj(config.hidden_size, bias=config.attention_bias)
|
| 1838 |
+
else:
|
| 1839 |
+
self.o_proj = LinearWithMultipliers(
|
| 1840 |
+
_o_in, config.hidden_size,
|
| 1841 |
+
bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=True,
|
| 1842 |
+
)
|
| 1843 |
|
| 1844 |
self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
|
| 1845 |
self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
|
|
|
|
| 2226 |
attn_analysis.attn_output_pre_gate = attn_out_flat.detach()
|
| 2227 |
gate_sig = torch.sigmoid(gate)
|
| 2228 |
attn_analysis.gate_sigmoid = gate_sig.detach()
|
| 2229 |
+
gated = attn_out_flat * gate_sig
|
| 2230 |
+
if self.use_hadamard_o_proj:
|
| 2231 |
+
# Pass HadamardAnalysis sub-object so post_fwht and alpha are captured
|
| 2232 |
+
attn_out_gated = self.o_proj(gated, analysis=attn_analysis.hadamard)
|
| 2233 |
+
else:
|
| 2234 |
+
attn_out_gated = self.o_proj(gated)
|
| 2235 |
else:
|
| 2236 |
attn_out_gated = self.o_proj(attn_out_flat * torch.sigmoid(gate))
|
| 2237 |
|
|
|
|
| 2786 |
return LayerAnalysis(
|
| 2787 |
seednorm_pre_attn = SeeDNormAnalysis(),
|
| 2788 |
seednorm_post_attn = SeeDNormAnalysis(),
|
| 2789 |
+
attention = AttentionAnalysis(
|
| 2790 |
+
fan = FANAnalysis(),
|
| 2791 |
+
hadamard = HadamardAnalysis() if getattr(cfg, "use_hadamard_o_proj", False) else None,
|
| 2792 |
+
),
|
| 2793 |
mlp = MLPAnalysis(
|
| 2794 |
fan = FANAnalysis(),
|
| 2795 |
polynorm = PolyNormAnalysis(),
|
|
|
|
| 3243 |
"VectorMultiplier",
|
| 3244 |
"LinearWithMultipliers",
|
| 3245 |
"MEAHeadSeeDNorm",
|
| 3246 |
+
"HadamardOProj",
|
| 3247 |
# Analysis dataclasses — exported so external tools can type-hint against them
|
| 3248 |
"AnalysisState",
|
| 3249 |
"LayerAnalysis",
|
|
|
|
| 3253 |
"SeeDNormAnalysis",
|
| 3254 |
"GPASAnalysis",
|
| 3255 |
"PolyNormAnalysis",
|
| 3256 |
+
"HadamardAnalysis",
|
| 3257 |
"JTokMAnalysis",
|
| 3258 |
"AttnResAnalysis",
|
| 3259 |
"GeneratorAnalysis",
|