"""Training-safe filter cache for HyenaOperator. **What this validates:** When `HYDRA_HYENA_TRAIN_CACHE=1`, the filter MLP must: 1. Run EXACTLY ONCE per optimizer step, not once per micro-batch. 2. Produce gradients on its params that match the uncached path to within bf16 tolerance (we use fp32 CPU tensors here, so atol should be tight). 3. Not trip `RuntimeError: Trying to backward through the graph a second time` under the grad-accum pattern. **Design under test:** `HyenaFilter.get_or_build_train_cache(L, fft_size)` returns a LEAF tensor `k_leaf` whose grad accumulates across micro-batches. After all micro-batch backwards, `flush_pending_filter_grads()` does one `torch.autograd.backward(_k_graph, _k_leaf.grad)` to populate the filter MLP params' `.grad`. Then `invalidate_cache()` resets state for the next step. Run: cd /home/mikeb/work/feather .venv/bin/pytest tests/test_hyena_train_cache.py -v """ from __future__ import annotations import sys from pathlib import Path import pytest import torch sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from hydra.hyena_block import HyenaBlock # noqa: E402 from subsystems import hyena_pure # noqa: E402 def _reset_rfft_counter(): hyena_pure._fftconv_filter_rfft_count = 0 def _rfft_count() -> int: return hyena_pure._fftconv_filter_rfft_count def test_train_cache_runs_filter_mlp_once_per_step(monkeypatch): """With HYDRA_HYENA_TRAIN_CACHE=1, the IMPLICIT FILTER MLP runs exactly once across N accum micro-batches, not once per micro-batch. We can't distinguish MLP forwards via the rfft counter alone (rfft also fires for `k_f` per micro-batch for graph-safety reasons, see `HyenaFilter.get_or_build_train_cache` docstring). We instead patch the `implicit_filter` Sequential's forward with a counting proxy and verify it ran once. """ monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1") torch.manual_seed(0) D, T = 32, 16 block = HyenaBlock(d_model=D, seq_len=T) block.train() assert block.operator._use_train_cache is True # Count MLP forwards. orig_forward = block.operator.filter_fn.implicit_filter.forward n_calls = {"count": 0} def counting_forward(*args, **kwargs): n_calls["count"] += 1 return orig_forward(*args, **kwargs) block.operator.filter_fn.implicit_filter.forward = counting_forward accum = 3 for _ in range(accum): x = torch.randn(1, T, D) y = block(x) loss = y.pow(2).mean() / accum loss.backward() # EXACTLY 1 MLP forward total, not 3. assert n_calls["count"] == 1, ( f"expected exactly 1 filter MLP forward under train-cache across " f"{accum} micro-batches, got {n_calls['count']}" ) def test_train_cache_no_backward_twice_error(monkeypatch): """Three micro-batches with train-cache on must NOT raise 'Trying to backward through the graph a second time'. This is the core correctness guarantee. Without the fix, this test reliably reproduces the runtime error. """ monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1") torch.manual_seed(1) D, T = 32, 16 block = HyenaBlock(d_model=D, seq_len=T) block.train() accum = 4 # This must not raise. for _ in range(accum): x = torch.randn(1, T, D) y = block(x) loss = y.pow(2).mean() / accum loss.backward() # After all micro-batches, k_leaf.grad must be non-None (grad accumulated). k_leaf = block.operator.filter_fn._k_leaf assert k_leaf is not None, "train-cache failed to populate _k_leaf" assert k_leaf.grad is not None, "no accumulated gradient on _k_leaf" assert torch.isfinite(k_leaf.grad).all(), "k_leaf.grad has NaN/Inf" def test_train_cache_flush_populates_filter_params(monkeypatch): """After flush, the filter MLP params must have non-zero, finite grads.""" monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1") torch.manual_seed(2) D, T = 32, 16 block = HyenaBlock(d_model=D, seq_len=T) block.train() # Zero-init params' grads. for p in block.parameters(): p.grad = None # Run 3 accum micro-batches. for _ in range(3): x = torch.randn(1, T, D) y = block(x) loss = y.pow(2).mean() / 3 loss.backward() # Before flush, filter MLP params should have NO grad (the backward chain # was cut at k_leaf). Only params downstream of k_leaf (short_filter, # in_proj, out_proj) should have grads. # NOTE: the filter's `bias` is actually used AFTER the leaf stash (see # HyenaOperator.forward: bias comes from filter_fn.bias directly, not from # the cached k_leaf) so `bias.grad` WILL be populated by the direct path. for name, p in block.operator.filter_fn.implicit_filter.named_parameters(): if p.requires_grad: assert p.grad is None or p.grad.abs().max() == 0, ( f"implicit_filter.{name} has grad before flush — the leaf " f"cache didn't actually cut the graph" ) # Flush: this invokes torch.autograd.backward(_k_graph, _k_leaf.grad). block.operator.flush_pending_filter_grads() # Now implicit_filter params must have real grads. for name, p in block.operator.filter_fn.implicit_filter.named_parameters(): if p.requires_grad: assert p.grad is not None, f"implicit_filter.{name} has no grad after flush" assert torch.isfinite(p.grad).all(), f"implicit_filter.{name} grad NaN/Inf" # With 3 random micro-batches and dL/dy = 2*y/(B*T*D*3), the # propagated grad MUST be non-zero for every param that's # reachable from the filter output. assert p.grad.abs().max() > 0, ( f"implicit_filter.{name}.grad is all zero — flush didn't " f"push the k_leaf.grad back" ) def test_train_cache_gradient_matches_uncached(monkeypatch): """Parameter gradients under train-cache must numerically match the uncached path within tolerance. We construct two identical blocks, run the same 3 micro-batches on each, flush train-cache, then compare .grad on every param. """ torch.manual_seed(3) D, T = 32, 16 # Block A: no cache (baseline). block_a = HyenaBlock(d_model=D, seq_len=T) block_a.train() # Block B: train-cache on, same weights. # Note: monkeypatch.setenv only affects env reads AT CONSTRUCTION; the # block reads the flag in __init__. So we set before constructing B. monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1") block_b = HyenaBlock(d_model=D, seq_len=T) block_b.load_state_dict(block_a.state_dict()) block_b.train() # Verify the flag actually took effect. assert block_b.operator._use_train_cache is True assert block_a.operator._use_train_cache is False # Same 3 micro-batches. xs = [torch.randn(1, T, D) for _ in range(3)] for block, label in ((block_a, "a"), (block_b, "b")): for p in block.parameters(): p.grad = None for x in xs: y = block(x) loss = y.pow(2).mean() / len(xs) loss.backward() # Flush train-cache (block_b only). block_b.operator.flush_pending_filter_grads() # Compare grads. state_a = dict(block_a.named_parameters()) state_b = dict(block_b.named_parameters()) max_abs_diff = 0.0 max_diff_name = "" for name, p_a in state_a.items(): p_b = state_b[name] if p_a.grad is None: assert p_b.grad is None or p_b.grad.abs().max() == 0, ( f"{name}: A has no grad, B has nonzero grad" ) continue assert p_b.grad is not None, f"{name}: A has grad, B has none" diff = (p_a.grad - p_b.grad).abs().max().item() if diff > max_abs_diff: max_abs_diff = diff max_diff_name = name # Tight tolerance: the two paths do the SAME math in fp32 CPU, just the # cached path defers the backward. Expected diff ≈ 0 modulo FP noise. assert max_abs_diff < 1e-4, ( f"grad mismatch between cached and uncached paths: " f"max |Δgrad| = {max_abs_diff:.3e} on {max_diff_name!r}" ) def test_train_cache_invalidate_resets_state(monkeypatch): """After invalidate_cache(), the next step rebuilds k_graph fresh. Simulates the post-optimizer.step() lifecycle. """ monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1") torch.manual_seed(4) D, T = 32, 16 block = HyenaBlock(d_model=D, seq_len=T) block.train() # Step 1: 2 micro-batches, flush, invalidate. for _ in range(2): y = block(torch.randn(1, T, D)) (y.pow(2).mean() / 2).backward() assert block.operator.filter_fn._k_graph is not None block.operator.flush_pending_filter_grads() block.operator.invalidate_filter_cache() assert block.operator.filter_fn._k_graph is None assert block.operator.filter_fn._k_leaf is None # Zero filter params' grads (simulating optimizer.zero_grad()) for p in block.parameters(): p.grad = None # Step 2: must work the same (not use stale state). for _ in range(2): y = block(torch.randn(1, T, D)) (y.pow(2).mean() / 2).backward() assert block.operator.filter_fn._k_graph is not None, ( "second step failed to rebuild k_graph" ) block.operator.flush_pending_filter_grads() # All filter MLP params got grad again. for name, p in block.operator.filter_fn.implicit_filter.named_parameters(): if p.requires_grad: assert p.grad is not None, f"step 2: {name} has no grad" def test_train_cache_disabled_by_default(monkeypatch): """Unset env var → train-cache OFF → filter runs per micro-batch as before.""" monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False) torch.manual_seed(5) D, T = 32, 16 block = HyenaBlock(d_model=D, seq_len=T) assert block.operator._use_train_cache is False def test_train_cache_forward_output_matches_uncached(monkeypatch): """Cached vs uncached forward outputs must match numerically.""" torch.manual_seed(6) D, T = 32, 16 # Uncached. block_a = HyenaBlock(d_model=D, seq_len=T) block_a.eval() # Cached copy. monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1") block_b = HyenaBlock(d_model=D, seq_len=T) block_b.load_state_dict(block_a.state_dict()) block_b.train() # train-cache only activates under grad_enabled x = torch.randn(1, T, D) y_a = block_a(x) # uncached path (no grad → eval mode anyway) y_b = block_b(x) # cached path max_diff = (y_a - y_b).abs().max().item() assert max_diff < 1e-5, f"forward drift under train-cache: {max_diff:.3e}" def test_flush_is_no_op_on_second_call(monkeypatch): """Idempotent flush: second call in the same step must not crash.""" monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1") torch.manual_seed(7) D, T = 32, 16 block = HyenaBlock(d_model=D, seq_len=T) block.train() y = block(torch.randn(1, T, D)) y.pow(2).mean().backward() # First flush: real work. block.operator.flush_pending_filter_grads() # Second flush: must silently no-op. block.operator.flush_pending_filter_grads() def test_flush_is_no_op_when_no_forward(monkeypatch): """If no Hyena forward ran this step, flush is a safe no-op.""" monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1") D, T = 32, 16 block = HyenaBlock(d_model=D, seq_len=T) block.train() # No forward called. Flush should just return. block.operator.flush_pending_filter_grads() if __name__ == "__main__": sys.exit(pytest.main([__file__, "-v"]))