KitsuVp committed on
Commit
fe8dd15
·
verified ·
1 Parent(s): 758569d

Update modeling_neollm.py

Browse files
Files changed (1) hide show
  1. modeling_neollm.py +158 -11
modeling_neollm.py CHANGED
@@ -148,6 +148,23 @@ class PolyNormAnalysis:
148
  output: Optional[torch.Tensor] = None # final PolyNorm output
149
 
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  @dataclass
152
  class AttentionAnalysis:
153
  """
@@ -218,6 +235,9 @@ class AttentionAnalysis:
218
  attn_output_pre_gate: Optional[torch.Tensor] = None # pre gate multiply [B,S,H,d]
219
  attn_output_final: Optional[torch.Tensor] = None # after o_proj [B,S,D]
220
 
 
 
 
221
 
222
  @dataclass
223
  class MLPAnalysis:
@@ -1617,15 +1637,110 @@ def affine_scaled_flash_attention_forward(
1617
  # ── Combine and apply dropout to the full affine output ───────────────
1618
  output = alpha_t * flash_out + beta_t * v_cumsum_t # [B, S, H_q, d_head]
1619
 
1620
- # Apply output dropout on the combined affine result.
1621
- # This regularises the full [α·flash + β·V_cumsum] output consistently.
1622
- if dropout > 0.0 and module.training:
1623
- output = nn.functional.dropout(output, p=dropout, training=True)
1624
-
1625
  # attn_weights is None — flash never exposes the softmax weight matrix.
1626
  return output, None
1627
 
1628
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1629
  class NeoLLMAttention(nn.Module):
1630
  """
1631
  Full attention with FANformer, SeeDNorm, ResFormer, Learnable Multipliers,
@@ -1644,8 +1759,16 @@ class NeoLLMAttention(nn.Module):
1644
  → MEAHeadSeeDNorm → XSA → Directional Routing → reshape
1645
  → o_proj · sigmoid(gate) → dropout
1646
 
 
 
 
 
 
 
 
1647
  References:
1648
  Directional Routing: Taylor (2026). arXiv:2603.14923.
 
1649
  """
1650
 
1651
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
@@ -1674,6 +1797,7 @@ class NeoLLMAttention(nn.Module):
1674
  self.lucid_attention_eps = float(
1675
  getattr(config, "lucid_attention_eps", config.rms_norm_eps)
1676
  )
 
1677
 
1678
  self.fan_layer = FANLayer(
1679
  hidden_size=config.hidden_size,
@@ -1699,10 +1823,23 @@ class NeoLLMAttention(nn.Module):
1699
  fan_output_dim, self.num_mea_component_heads * self.head_dim,
1700
  bias=config.attention_bias,
1701
  )
1702
- self.o_proj = LinearWithMultipliers(
1703
- config.num_attention_heads * self.head_dim, config.hidden_size,
1704
- bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=True,
1705
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
1706
 
1707
  self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
1708
  self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
@@ -2089,7 +2226,12 @@ class NeoLLMAttention(nn.Module):
2089
  attn_analysis.attn_output_pre_gate = attn_out_flat.detach()
2090
  gate_sig = torch.sigmoid(gate)
2091
  attn_analysis.gate_sigmoid = gate_sig.detach()
2092
- attn_out_gated = self.o_proj(attn_out_flat * gate_sig)
 
 
 
 
 
2093
  else:
2094
  attn_out_gated = self.o_proj(attn_out_flat * torch.sigmoid(gate))
2095
 
@@ -2644,7 +2786,10 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
2644
  return LayerAnalysis(
2645
  seednorm_pre_attn = SeeDNormAnalysis(),
2646
  seednorm_post_attn = SeeDNormAnalysis(),
2647
- attention = AttentionAnalysis(fan=FANAnalysis()),
 
 
 
2648
  mlp = MLPAnalysis(
2649
  fan = FANAnalysis(),
2650
  polynorm = PolyNormAnalysis(),
@@ -3098,6 +3243,7 @@ __all__ = [
3098
  "VectorMultiplier",
3099
  "LinearWithMultipliers",
3100
  "MEAHeadSeeDNorm",
 
3101
  # Analysis dataclasses — exported so external tools can type-hint against them
3102
  "AnalysisState",
3103
  "LayerAnalysis",
@@ -3107,6 +3253,7 @@ __all__ = [
3107
  "SeeDNormAnalysis",
3108
  "GPASAnalysis",
3109
  "PolyNormAnalysis",
 
3110
  "JTokMAnalysis",
3111
  "AttnResAnalysis",
3112
  "GeneratorAnalysis",
 
148
  output: Optional[torch.Tensor] = None # final PolyNorm output
149
 
150
 
151
@dataclass
class HadamardAnalysis:
    """
    Internals of a HadamardOProj forward pass.
    Only populated when use_hadamard_o_proj=True.

    Reference: Aggarwal & Kumar (2026). arXiv:2603.08343.

    Attributes:
        post_fwht: WHT output before α scaling [..., D] — useful to verify
            that the transform is truly norm-preserving (κ=1 sanity check).
        alpha_snapshot: detached copy of the learnable α vector [D] — tracks how
            per-channel scaling evolves during training analysis.
    """
    post_fwht: Optional[torch.Tensor] = None       # FWHT(x)/√d before α [B,S,D]
    alpha_snapshot: Optional[torch.Tensor] = None  # self.alpha [D] — learned scale
166
+
167
+
168
  @dataclass
169
  class AttentionAnalysis:
170
  """
 
235
  attn_output_pre_gate: Optional[torch.Tensor] = None # pre gate multiply [B,S,H,d]
236
  attn_output_final: Optional[torch.Tensor] = None # after o_proj [B,S,D]
237
 
238
+ # ── HadamardOProj internals (conditional on use_hadamard_o_proj) ──
239
+ hadamard: Optional["HadamardAnalysis"] = None # None when dense o_proj active
240
+
241
 
242
  @dataclass
243
  class MLPAnalysis:
 
1637
  # ── Combine and apply dropout to the full affine output ───────────────
1638
  output = alpha_t * flash_out + beta_t * v_cumsum_t # [B, S, H_q, d_head]
1639
 
 
 
 
 
 
1640
  # attn_weights is None — flash never exposes the softmax weight matrix.
1641
  return output, None
1642
 
1643
 
1644
class HadamardOProj(nn.Module):
    """
    Parameter-free Walsh–Hadamard output projection with learnable affine rescaling.

    Replaces the dense W_O ∈ R^{d×d} in multi-head attention with a fixed
    orthogonal Walsh–Hadamard Transform followed by a per-channel learnable
    affine:  output = α ⊙ FWHT(x) + β

    Motivation (Aggarwal & Kumar, 2026, arXiv:2603.08343):
        The standard dense o_proj develops extreme condition numbers during
        training (κ up to 10^5 observed in practice) because the optimiser has
        no incentive to keep singular values balanced — some directions are
        amplified while others collapse toward zero. This makes the layer
        hostile to FP8 quantisation, which uses a single per-tensor scale and
        therefore loses the low-magnitude directions entirely.

        The Walsh–Hadamard Transform is a fixed orthogonal matrix whose
        singular values are all identically 1, making κ = 1 by construction.
        It cannot develop condition-number pathology because it has no
        parameters. The learnable α/β restore per-channel expressivity at
        a cost of 2·d parameters instead of d².

    Properties:
        - Condition number: κ = 1 (exact, permanent, by construction)
        - Parameters: 2·d vs d² for dense (~25% attention params saved)
        - Forward FLOPs: O(d log d) vs O(d²) for dense
        - Norm preservation: FWHT is isometric — ‖FWHT(x)‖₂ = ‖x‖₂
        - FP8 friendliness: single per-tensor scale covers all directions equally
        - Requires: d must be a power of 2

    The FWHT is implemented as an iterative butterfly (Cooley-Tukey pattern
    over additions/subtractions; each stage allocates a fresh tensor, so the
    input is never mutated) followed by 1/√d normalisation to produce an
    orthonormal transform (H^T H = I). No external dependency.

    Reference:
        Aggarwal, S. & Kumar, L. (2026). "Rethinking Attention Output
        Projection: Structured Hadamard Transforms for Efficient Transformers."
        arXiv:2603.08343.
    """

    def __init__(self, dim: int, bias: bool = True):
        """
        Args:
            dim: channel dimension d; must be a positive power of 2.
            bias: when True, adds the learnable per-channel β term.

        Raises:
            ValueError: if dim is not a positive power of 2.
        """
        super().__init__()
        # Raise rather than assert: asserts are stripped under `python -O`,
        # which would let an invalid dim silently corrupt _fwht's reshapes.
        if dim <= 0 or (dim & (dim - 1)) != 0:
            raise ValueError(
                f"HadamardOProj requires dim to be a power of 2, got {dim}"
            )
        self.dim = dim
        self.norm = dim ** -0.5  # 1/√d — makes H^T H = I

        # Learnable affine rescaling: α ⊙ FWHT(x) + β
        # Initialised to α=1, β=0 so the layer starts as a pure WHT,
        # identical to an orthonormal projection with unit gain.
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim)) if bias else None

    def _fwht(self, x: torch.Tensor) -> torch.Tensor:
        """
        Iterative Fast Walsh–Hadamard Transform over the last dimension.

        Butterfly pattern: log₂(d) stages, each pairing elements at stride h.
        Cost: d·log₂(d) additions/subtractions, zero multiplications.
        Compatible with torch.compile — all shapes are static once d is fixed,
        so the Python while-loop unrolls at trace time.

        Note: despite the butterfly idiom this is NOT in-place — torch.cat
        allocates a new tensor each stage, leaving the caller's input intact.
        """
        h = 1
        while h < self.dim:
            # Reshape to expose pairs at current stride
            x = x.reshape(*x.shape[:-1], -1, 2 * h)
            a, b = x[..., :h], x[..., h:]
            # Butterfly: (a+b, a-b) — only additions and subtractions
            x = torch.cat([a + b, a - b], dim=-1)
            x = x.reshape(*x.shape[:-2], self.dim)
            h *= 2
        return x

    def forward(
        self,
        x: torch.Tensor,
        analysis: Optional["HadamardAnalysis"] = None,
    ) -> torch.Tensor:
        """
        Args:
            x: [..., dim] — concatenated multi-head attention outputs
            analysis: HadamardAnalysis container populated when analysis mode
                is active (eval + model.enable_analysis()). None otherwise.

        Returns:
            α ⊙ (FWHT(x) / √dim) + β of shape [..., dim]
        """
        out = self._fwht(x) * self.norm  # normalise: H^T H = I

        if analysis is not None:
            analysis.post_fwht = out.detach()
            # .detach() alone returns a VIEW of the live parameter storage —
            # optimizer updates would silently mutate the "snapshot".
            # .clone() makes it a genuine frozen copy, as documented.
            analysis.alpha_snapshot = self.alpha.detach().clone()

        out = out * self.alpha       # per-channel learnable scale
        if self.beta is not None:
            out = out + self.beta    # per-channel learnable bias
        return out
1742
+
1743
+
1744
  class NeoLLMAttention(nn.Module):
1745
  """
1746
  Full attention with FANformer, SeeDNorm, ResFormer, Learnable Multipliers,
 
1759
  → MEAHeadSeeDNorm → XSA → Directional Routing → reshape
1760
  → o_proj · sigmoid(gate) → dropout
1761
 
1762
+ o_proj variants (controlled by config.use_hadamard_o_proj):
1763
+ False (default): dense LinearWithMultipliers — full expressivity,
1764
+ develops high κ during training (FP8 risk).
1765
+ True: HadamardOProj — fixed WHT + learnable α/β,
1766
+ κ = 1 by construction, 25% fewer attention params,
1767
+ FP8-friendly (Aggarwal & Kumar, 2026, arXiv:2603.08343).
1768
+
1769
  References:
1770
  Directional Routing: Taylor (2026). arXiv:2603.14923.
1771
+ Hadamard o_proj: Aggarwal & Kumar (2026). arXiv:2603.08343.
1772
  """
1773
 
1774
  def __init__(self, config: NeoLLMConfig, layer_idx: int):
 
1797
  self.lucid_attention_eps = float(
1798
  getattr(config, "lucid_attention_eps", config.rms_norm_eps)
1799
  )
1800
+ self.use_hadamard_o_proj = getattr(config, "use_hadamard_o_proj", False)
1801
 
1802
  self.fan_layer = FANLayer(
1803
  hidden_size=config.hidden_size,
 
1823
  fan_output_dim, self.num_mea_component_heads * self.head_dim,
1824
  bias=config.attention_bias,
1825
  )
1826
+ # ── Output projection (Aggarwal & Kumar, 2026, arXiv:2603.08343) ────
1827
+ # use_hadamard_o_proj=False (default): dense LinearWithMultipliers.
1828
+ # use_hadamard_o_proj=True: HadamardOProj — fixed WHT + learnable α/β.
1829
+ # κ = 1 by construction, 25% fewer attention params, FP8-friendly.
1830
+ # Requires hidden_size to be a power of 2 (512 ✓, 1024 ✓, 768 ✗).
1831
+ _o_in = config.num_attention_heads * self.head_dim
1832
+ if self.use_hadamard_o_proj:
1833
+ assert _o_in == config.hidden_size, (
1834
+ f"HadamardOProj requires in_dim == out_dim, "
1835
+ f"got {_o_in} vs {config.hidden_size}"
1836
+ )
1837
+ self.o_proj = HadamardOProj(config.hidden_size, bias=config.attention_bias)
1838
+ else:
1839
+ self.o_proj = LinearWithMultipliers(
1840
+ _o_in, config.hidden_size,
1841
+ bias=config.attention_bias, use_row_multiplier=True, use_column_multiplier=True,
1842
+ )
1843
 
1844
  self.q_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
1845
  self.k_norm = SeeDNorm(self.head_dim, eps=config.rms_norm_eps)
 
2226
  attn_analysis.attn_output_pre_gate = attn_out_flat.detach()
2227
  gate_sig = torch.sigmoid(gate)
2228
  attn_analysis.gate_sigmoid = gate_sig.detach()
2229
+ gated = attn_out_flat * gate_sig
2230
+ if self.use_hadamard_o_proj:
2231
+ # Pass HadamardAnalysis sub-object so post_fwht and alpha are captured
2232
+ attn_out_gated = self.o_proj(gated, analysis=attn_analysis.hadamard)
2233
+ else:
2234
+ attn_out_gated = self.o_proj(gated)
2235
  else:
2236
  attn_out_gated = self.o_proj(attn_out_flat * torch.sigmoid(gate))
2237
 
 
2786
  return LayerAnalysis(
2787
  seednorm_pre_attn = SeeDNormAnalysis(),
2788
  seednorm_post_attn = SeeDNormAnalysis(),
2789
+ attention = AttentionAnalysis(
2790
+ fan = FANAnalysis(),
2791
+ hadamard = HadamardAnalysis() if getattr(cfg, "use_hadamard_o_proj", False) else None,
2792
+ ),
2793
  mlp = MLPAnalysis(
2794
  fan = FANAnalysis(),
2795
  polynorm = PolyNormAnalysis(),
 
3243
  "VectorMultiplier",
3244
  "LinearWithMultipliers",
3245
  "MEAHeadSeeDNorm",
3246
+ "HadamardOProj",
3247
  # Analysis dataclasses — exported so external tools can type-hint against them
3248
  "AnalysisState",
3249
  "LayerAnalysis",
 
3253
  "SeeDNormAnalysis",
3254
  "GPASAnalysis",
3255
  "PolyNormAnalysis",
3256
+ "HadamardAnalysis",
3257
  "JTokMAnalysis",
3258
  "AttnResAnalysis",
3259
  "GeneratorAnalysis",