"""Training-safe filter cache for HyenaOperator.
**What this validates:**
When `HYDRA_HYENA_TRAIN_CACHE=1`, the filter MLP must:
1. Run EXACTLY ONCE per optimizer step, not once per micro-batch.
2. Produce gradients on its params that match the uncached path to within a
tight tolerance (these tests use fp32 CPU tensors, so atol can be far tighter
than a bf16 bound would allow).
3. Not trip `RuntimeError: Trying to backward through the graph a second time`
under the grad-accum pattern.
**Design under test:**
`HyenaFilter.get_or_build_train_cache(L, fft_size)` returns a LEAF tensor
`k_leaf` whose grad accumulates across micro-batches. After all micro-batch
backwards, `flush_pending_filter_grads()` does one
`torch.autograd.backward(_k_graph, _k_leaf.grad)` to populate the filter
MLP params' `.grad`. Then `invalidate_cache()` resets state for the next
step.
Run:
cd /home/mikeb/work/feather
.venv/bin/pytest tests/test_hyena_train_cache.py -v
"""
from __future__ import annotations
import sys
from pathlib import Path
import pytest
import torch
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from hydra.hyena_block import HyenaBlock # noqa: E402
from subsystems import hyena_pure # noqa: E402
def _reset_rfft_counter():
    """Zero the module-level rfft call counter kept in ``hyena_pure``.

    NOTE(review): no test in this module currently calls this helper.
    """
    counter_owner = hyena_pure
    counter_owner._fftconv_filter_rfft_count = 0
def _rfft_count() -> int:
    """Return the current value of ``hyena_pure._fftconv_filter_rfft_count``.

    NOTE(review): no test in this module currently calls this helper.
    """
    current = hyena_pure._fftconv_filter_rfft_count
    return current
def test_train_cache_runs_filter_mlp_once_per_step(monkeypatch):
    """Under HYDRA_HYENA_TRAIN_CACHE=1 the IMPLICIT FILTER MLP must run
    exactly once across all accumulation micro-batches, not once each.

    The rfft counter in ``hyena_pure`` cannot isolate MLP forwards, because
    rfft also fires for ``k_f`` on every micro-batch for graph-safety reasons
    (see the ``HyenaFilter.get_or_build_train_cache`` docstring). Instead the
    ``implicit_filter`` module's ``forward`` is shadowed on the instance by a
    wrapper that records each call.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(0)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    assert block.operator._use_train_cache is True

    mlp = block.operator.filter_fn.implicit_filter
    wrapped_forward = mlp.forward
    calls = []

    def recording_forward(*args, **kwargs):
        # Note the call, then defer to the real forward untouched.
        calls.append(None)
        return wrapped_forward(*args, **kwargs)

    # Instance attribute shadows the class's forward for this module only.
    mlp.forward = recording_forward

    accum = 3
    for _step in range(accum):
        x = torch.randn(1, T, D)
        loss = block(x).pow(2).mean() / accum
        loss.backward()

    n_forward = len(calls)
    # One MLP forward in total, not one per micro-batch.
    assert n_forward == 1, (
        f"expected exactly 1 filter MLP forward under train-cache across "
        f"{accum} micro-batches, got {n_forward}"
    )
def test_train_cache_no_backward_twice_error(monkeypatch):
    """Running ``accum`` (here four) micro-batches with train-cache on must
    NOT raise 'Trying to backward through the graph a second time'.

    This is the core correctness guarantee. Without the fix, this test
    reliably reproduces the runtime error. Afterwards the cached leaf must
    hold a finite, accumulated gradient.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(1)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    accum = 4
    # This must not raise.
    for _ in range(accum):
        x = torch.randn(1, T, D)
        y = block(x)
        loss = y.pow(2).mean() / accum
        loss.backward()
    # After all micro-batches, k_leaf.grad must be non-None (grad accumulated).
    k_leaf = block.operator.filter_fn._k_leaf
    assert k_leaf is not None, "train-cache failed to populate _k_leaf"
    assert k_leaf.grad is not None, "no accumulated gradient on _k_leaf"
    assert torch.isfinite(k_leaf.grad).all(), "k_leaf.grad has NaN/Inf"
def test_train_cache_flush_populates_filter_params(monkeypatch):
    """Flushing the train-cache must hand real gradients to the filter MLP.

    Before ``flush_pending_filter_grads()`` the trainable implicit-filter
    params must carry no (or only all-zero) gradient, because the graph is cut
    at ``k_leaf``. After it, every one of them must have a finite, non-zero
    gradient.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(2)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()

    # Drop any existing gradient tensors before accumulating.
    for param in block.parameters():
        param.grad = None

    accum = 3
    for _step in range(accum):
        x = torch.randn(1, T, D)
        loss = block(x).pow(2).mean() / accum
        loss.backward()

    mlp_params = [
        (name, p)
        for name, p in block.operator.filter_fn.implicit_filter.named_parameters()
        if p.requires_grad
    ]

    # Pre-flush: the backward chain stops at k_leaf, so only params
    # downstream of it (short_filter, in_proj, out_proj) have grads so far.
    # filter_fn.bias is used directly by HyenaOperator.forward rather than
    # through the cached k_leaf, so it legitimately has a grad already; it is
    # a filter_fn attribute, not part of implicit_filter, and is not checked.
    for name, p in mlp_params:
        assert p.grad is None or p.grad.abs().max() == 0, (
            f"implicit_filter.{name} has grad before flush — the leaf "
            f"cache didn't actually cut the graph"
        )

    # Flush: torch.autograd.backward(_k_graph, _k_leaf.grad) under the hood.
    block.operator.flush_pending_filter_grads()

    for name, p in mlp_params:
        assert p.grad is not None, f"implicit_filter.{name} has no grad after flush"
        assert torch.isfinite(p.grad).all(), f"implicit_filter.{name} grad NaN/Inf"
        # Three random micro-batches give dL/dy = 2*y/(B*T*D*3), so every
        # param reachable from the filter output must get a non-zero grad.
        assert p.grad.abs().max() > 0, (
            f"implicit_filter.{name}.grad is all zero — flush didn't "
            f"push the k_leaf.grad back"
        )
def test_train_cache_gradient_matches_uncached(monkeypatch):
    """Parameter gradients under train-cache must numerically match
    the uncached path within tolerance.

    Two identical blocks (A: cache off, B: cache on) run the same 3
    micro-batches. B's pending filter grads are flushed, then ``.grad`` is
    compared on every parameter. A NaN/Inf gradient difference is treated as
    a mismatch rather than silently ignored.
    """
    import math

    torch.manual_seed(3)
    D, T = 32, 16

    # Block A: no cache (baseline). Clear the flag explicitly so a value
    # inherited from the outer environment can't enable the cache here too.
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)
    block_a = HyenaBlock(d_model=D, seq_len=T)
    block_a.train()

    # Block B: train-cache on, same weights. The block reads the flag in
    # __init__, so setting the env var only now leaves A uncached.
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block_b = HyenaBlock(d_model=D, seq_len=T)
    block_b.load_state_dict(block_a.state_dict())
    block_b.train()

    # Verify the flag actually took effect on both sides.
    assert block_b.operator._use_train_cache is True
    assert block_a.operator._use_train_cache is False

    # Same 3 micro-batches for both blocks.
    xs = [torch.randn(1, T, D) for _ in range(3)]
    for block in (block_a, block_b):
        for p in block.parameters():
            p.grad = None
        for x in xs:
            y = block(x)
            loss = y.pow(2).mean() / len(xs)
            loss.backward()

    # Flush train-cache (block_b only).
    block_b.operator.flush_pending_filter_grads()

    # Compare grads.
    state_a = dict(block_a.named_parameters())
    state_b = dict(block_b.named_parameters())
    max_abs_diff = 0.0
    max_diff_name = ""
    for name, p_a in state_a.items():
        p_b = state_b[name]
        if p_a.grad is None:
            assert p_b.grad is None or p_b.grad.abs().max() == 0, (
                f"{name}: A has no grad, B has nonzero grad"
            )
            continue
        assert p_b.grad is not None, f"{name}: A has grad, B has none"
        diff = (p_a.grad - p_b.grad).abs().max().item()
        # NaN compares False against everything, so `diff > max_abs_diff`
        # alone would silently skip a NaN gradient. Record it explicitly;
        # once max_abs_diff is NaN the tolerance check below fails.
        if math.isnan(diff) or diff > max_abs_diff:
            max_abs_diff = diff
            max_diff_name = name
    # Tight tolerance: the two paths do the SAME math in fp32 CPU, just the
    # cached path defers the backward. Expected diff ≈ 0 modulo FP noise.
    assert max_abs_diff < 1e-4, (
        f"grad mismatch between cached and uncached paths: "
        f"max |Δgrad| = {max_abs_diff:.3e} on {max_diff_name!r}"
    )
def test_train_cache_invalidate_resets_state(monkeypatch):
    """``invalidate_filter_cache()`` must clear the cached state so the next
    step rebuilds ``k_graph`` from scratch.

    Mimics the lifecycle around ``optimizer.step()``: accumulate, flush,
    invalidate, zero grads, then run a second step.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(4)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    op = block.operator

    def run_micro_batches(count=2):
        # Accumulate `count` micro-batch backwards, each loss scaled by 1/count.
        for _ in range(count):
            out = block(torch.randn(1, T, D))
            (out.pow(2).mean() / count).backward()

    # --- step 1: accumulate, flush, invalidate ---
    run_micro_batches()
    assert op.filter_fn._k_graph is not None
    op.flush_pending_filter_grads()
    op.invalidate_filter_cache()
    assert op.filter_fn._k_graph is None
    assert op.filter_fn._k_leaf is None

    # Stand-in for optimizer.zero_grad(set_to_none=True).
    for param in block.parameters():
        param.grad = None

    # --- step 2: must rebuild fresh state rather than reuse stale state ---
    run_micro_batches()
    assert op.filter_fn._k_graph is not None, (
        "second step failed to rebuild k_graph"
    )
    op.flush_pending_filter_grads()

    # Every trainable filter-MLP param received a gradient again.
    for name, p in op.filter_fn.implicit_filter.named_parameters():
        if p.requires_grad:
            assert p.grad is not None, f"step 2: {name} has no grad"
def test_train_cache_disabled_by_default(monkeypatch):
    """With HYDRA_HYENA_TRAIN_CACHE unset, a freshly built block must have the
    train-cache flag off, so the filter keeps running per micro-batch.

    Only the construction-time flag is checked here.
    """
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)
    torch.manual_seed(5)
    block = HyenaBlock(d_model=32, seq_len=16)
    assert block.operator._use_train_cache is False
def test_train_cache_forward_output_matches_uncached(monkeypatch):
    """Cached vs uncached forward outputs must match numerically.

    Block A is built with the train-cache flag cleared and run in eval mode.
    Block B is an identical copy built with the flag set and run in train
    mode, so it takes the cached path.
    """
    torch.manual_seed(6)
    D, T = 32, 16

    # Uncached baseline. Clear the flag explicitly so a value inherited from
    # the outer environment can't silently enable the cache on the reference.
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)
    block_a = HyenaBlock(d_model=D, seq_len=T)
    block_a.eval()

    # Cached copy with identical weights.
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block_b = HyenaBlock(d_model=D, seq_len=T)
    block_b.load_state_dict(block_a.state_dict())
    block_b.train()  # train-cache only activates under grad_enabled

    # Confirm each block is on the intended path before comparing outputs.
    assert block_a.operator._use_train_cache is False
    assert block_b.operator._use_train_cache is True

    x = torch.randn(1, T, D)
    y_a = block_a(x)  # uncached: flag was off when block_a was constructed
    y_b = block_b(x)  # cached path
    max_diff = (y_a - y_b).abs().max().item()
    assert max_diff < 1e-5, f"forward drift under train-cache: {max_diff:.3e}"
def test_flush_is_no_op_on_second_call(monkeypatch):
    """Idempotent flush: a second call in the same step must neither crash nor
    change any parameter gradient produced by the first call."""
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(7)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    y = block(torch.randn(1, T, D))
    y.pow(2).mean().backward()

    # First flush: real work.
    block.operator.flush_pending_filter_grads()
    grads_after_first = {
        name: p.grad.clone()
        for name, p in block.named_parameters()
        if p.grad is not None
    }

    # Second flush: must silently no-op. In particular it must not
    # re-accumulate the pending k_leaf grad into the filter params.
    block.operator.flush_pending_filter_grads()
    for name, p in block.named_parameters():
        if name in grads_after_first:
            assert p.grad is not None and torch.equal(p.grad, grads_after_first[name]), (
                f"{name}: grad changed on second flush — flush is not idempotent"
            )
        else:
            assert p.grad is None or p.grad.abs().max() == 0, (
                f"{name}: grad appeared on second flush"
            )
def test_flush_is_no_op_when_no_forward(monkeypatch):
    """Calling flush before any Hyena forward has run this step must simply
    return without error."""
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block = HyenaBlock(d_model=32, seq_len=16)
    block.train()
    # No forward has run, so there should be nothing pending to push back.
    block.operator.flush_pending_filter_grads()
if __name__ == "__main__":
    # Running the file directly hands off to pytest in verbose mode and
    # propagates its exit code.
    raise SystemExit(pytest.main([__file__, "-v"]))