| """Training-safe filter cache for HyenaOperator. |
| |
| **What this validates:** |
| When `HYDRA_HYENA_TRAIN_CACHE=1`, the filter MLP must: |
| 1. Run EXACTLY ONCE per optimizer step, not once per micro-batch. |
| 2. Produce gradients on its params that match the uncached path to within |
| bf16 tolerance (we use fp32 CPU tensors here, so atol should be tight). |
| 3. Not trip `RuntimeError: Trying to backward through the graph a second time` |
| under the grad-accum pattern. |
| |
| **Design under test:** |
| `HyenaFilter.get_or_build_train_cache(L, fft_size)` returns a LEAF tensor |
| `k_leaf` whose grad accumulates across micro-batches. After all micro-batch |
| backwards, `flush_pending_filter_grads()` does one |
| `torch.autograd.backward(_k_graph, _k_leaf.grad)` to populate the filter |
| MLP params' `.grad`. Then `invalidate_cache()` resets state for the next |
| step. |
| |
| Run (from the repository root; replace the path below with your own checkout): |
| cd /home/mikeb/work/feather |
| .venv/bin/pytest tests/test_hyena_train_cache.py -v |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| from pathlib import Path |
|
|
| import pytest |
| import torch |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
|
|
| from hydra.hyena_block import HyenaBlock |
| from subsystems import hyena_pure |
|
|
|
|
def _reset_rfft_counter():
    """Set the rfft call counter stored on the ``hyena_pure`` module back to 0."""
    hyena_pure._fftconv_filter_rfft_count = 0
|
|
|
|
def _rfft_count() -> int:
    """Return the current value of the rfft call counter on ``hyena_pure``."""
    current = hyena_pure._fftconv_filter_rfft_count
    return current
|
|
|
|
def test_train_cache_runs_filter_mlp_once_per_step(monkeypatch):
    """The implicit-filter MLP forward fires once across all micro-batches.

    With ``HYDRA_HYENA_TRAIN_CACHE=1`` the filter MLP must run a single time
    for the whole accumulation window rather than once per micro-batch.

    The rfft counter is not usable for this check: rfft also fires for
    ``k_f`` on every micro-batch (see the
    ``HyenaFilter.get_or_build_train_cache`` docstring). Instead the
    ``forward`` of the ``implicit_filter`` module is replaced with a
    counting wrapper and the number of invocations is asserted.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    torch.manual_seed(0)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    assert block.operator._use_train_cache is True

    # Wrap the filter MLP's forward so every invocation is tallied.
    implicit_filter = block.operator.filter_fn.implicit_filter
    wrapped_forward = implicit_filter.forward
    forward_calls = 0

    def counting_forward(*args, **kwargs):
        nonlocal forward_calls
        forward_calls += 1
        return wrapped_forward(*args, **kwargs)

    implicit_filter.forward = counting_forward

    # Gradient-accumulation pattern: several forward/backward pairs, no flush.
    accum = 3
    for _ in range(accum):
        micro_batch = torch.randn(1, T, D)
        out = block(micro_batch)
        (out.pow(2).mean() / accum).backward()

    assert forward_calls == 1, (
        f"expected exactly 1 filter MLP forward under train-cache across "
        f"{accum} micro-batches, got {forward_calls}"
    )
|
|
|
|
def test_train_cache_no_backward_twice_error(monkeypatch):
    """Repeated micro-batch backwards must not re-traverse a freed graph.

    Four forward/backward pairs are run with the train-cache enabled; none
    may raise 'Trying to backward through the graph a second time'. The
    test then checks that the cached leaf exists and carries a finite
    accumulated gradient.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    torch.manual_seed(1)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()

    accum = 4
    for _ in range(accum):
        micro_batch = torch.randn(1, T, D)
        out = block(micro_batch)
        # Any "backward twice" RuntimeError would surface here.
        (out.pow(2).mean() / accum).backward()

    k_leaf = block.operator.filter_fn._k_leaf
    assert k_leaf is not None, "train-cache failed to populate _k_leaf"
    assert k_leaf.grad is not None, "no accumulated gradient on _k_leaf"
    all_finite = torch.isfinite(k_leaf.grad).all()
    assert all_finite, "k_leaf.grad has NaN/Inf"
|
|
|
|
def test_train_cache_flush_populates_filter_params(monkeypatch):
    """Filter MLP grads appear only after the flush, and are finite and non-zero.

    Before ``flush_pending_filter_grads()`` every trainable
    ``implicit_filter`` parameter must have no gradient (or an all-zero
    one); afterwards each must have a finite gradient with at least one
    non-zero entry.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    torch.manual_seed(2)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()

    # Start from a clean slate so any grad seen later comes from this test.
    for param in block.parameters():
        param.grad = None

    for _ in range(3):
        micro_batch = torch.randn(1, T, D)
        out = block(micro_batch)
        (out.pow(2).mean() / 3).backward()

    # Pre-flush: the micro-batch backwards must not have reached the MLP params.
    for name, param in block.operator.filter_fn.implicit_filter.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad is None or param.grad.abs().max() == 0, (
            f"implicit_filter.{name} has grad before flush — the leaf "
            f"cache didn't actually cut the graph"
        )

    block.operator.flush_pending_filter_grads()

    # Post-flush: every trainable MLP param carries a usable gradient.
    for name, param in block.operator.filter_fn.implicit_filter.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad is not None, f"implicit_filter.{name} has no grad after flush"
        assert torch.isfinite(param.grad).all(), f"implicit_filter.{name} grad NaN/Inf"
        assert param.grad.abs().max() > 0, (
            f"implicit_filter.{name}.grad is all zero — flush didn't "
            f"push the k_leaf.grad back"
        )
|
|
|
|
def test_train_cache_gradient_matches_uncached(monkeypatch):
    """Parameter gradients under train-cache must numerically match
    the uncached path within tolerance.

    Two blocks with identical weights are built: block A with the env var
    unset (uncached), block B with ``HYDRA_HYENA_TRAIN_CACHE=1``. Both run
    the same 3 micro-batches; block B is then flushed and ``.grad`` is
    compared on every parameter. The maximum absolute difference must be
    below 1e-4.
    """
    # Block A has to be the uncached reference even when the surrounding
    # shell or CI job exports HYDRA_HYENA_TRAIN_CACHE; clear it first.
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)

    torch.manual_seed(3)
    D, T = 32, 16

    # Uncached reference block.
    block_a = HyenaBlock(d_model=D, seq_len=T)
    block_a.train()

    # Cached block, sharing block A's weights.
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block_b = HyenaBlock(d_model=D, seq_len=T)
    block_b.load_state_dict(block_a.state_dict())
    block_b.train()

    assert block_b.operator._use_train_cache is True
    assert block_a.operator._use_train_cache is False

    # Same inputs for both blocks.
    xs = [torch.randn(1, T, D) for _ in range(3)]

    for block in (block_a, block_b):
        for p in block.parameters():
            p.grad = None
        for x in xs:
            y = block(x)
            loss = y.pow(2).mean() / len(xs)
            loss.backward()

    # Push the accumulated leaf gradient back into block B's filter params.
    block_b.operator.flush_pending_filter_grads()

    state_a = dict(block_a.named_parameters())
    state_b = dict(block_b.named_parameters())
    max_abs_diff = 0.0
    max_diff_name = ""
    compared = 0
    for name, p_a in state_a.items():
        p_b = state_b[name]
        if p_a.grad is None:
            assert p_b.grad is None or p_b.grad.abs().max() == 0, (
                f"{name}: A has no grad, B has nonzero grad"
            )
            continue
        assert p_b.grad is not None, f"{name}: A has grad, B has none"
        # A NaN difference compares False against any threshold and would
        # otherwise be skipped silently, so reject non-finite grads up front.
        assert torch.isfinite(p_a.grad).all(), f"{name}: uncached grad has NaN/Inf"
        assert torch.isfinite(p_b.grad).all(), f"{name}: cached grad has NaN/Inf"
        diff = (p_a.grad - p_b.grad).abs().max().item()
        compared += 1
        if diff > max_abs_diff:
            max_abs_diff = diff
            max_diff_name = name

    # Guard against a vacuous pass where nothing was actually compared.
    assert compared > 0, "no parameter had a gradient on the uncached path"
    assert max_abs_diff < 1e-4, (
        f"grad mismatch between cached and uncached paths: "
        f"max |Δgrad| = {max_abs_diff:.3e} on {max_diff_name!r}"
    )
|
|
|
|
def test_train_cache_invalidate_resets_state(monkeypatch):
    """After ``invalidate_filter_cache()`` the next step rebuilds the cache.

    Mimics two consecutive optimizer steps: accumulate, flush, invalidate,
    clear grads, then accumulate and flush again. The cached graph and leaf
    must be gone after invalidation and present again in the second step,
    and the filter MLP params must receive gradients in the second step.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    torch.manual_seed(4)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    operator = block.operator

    def run_micro_batches():
        # Two forward/backward pairs, each contributing half of the loss.
        for _ in range(2):
            out = block(torch.randn(1, T, D))
            (out.pow(2).mean() / 2).backward()

    # Step 1: build the cache, flush, then tear it down.
    run_micro_batches()
    assert operator.filter_fn._k_graph is not None
    operator.flush_pending_filter_grads()
    operator.invalidate_filter_cache()
    assert operator.filter_fn._k_graph is None
    assert operator.filter_fn._k_leaf is None

    # Drop step-1 grads so step-2 assertions only see fresh gradients.
    for param in block.parameters():
        param.grad = None

    # Step 2: the cache must be rebuilt from scratch.
    run_micro_batches()
    assert operator.filter_fn._k_graph is not None, (
        "second step failed to rebuild k_graph"
    )
    operator.flush_pending_filter_grads()

    for name, param in operator.filter_fn.implicit_filter.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad is not None, f"step 2: {name} has no grad"
|
|
|
|
def test_train_cache_disabled_by_default(monkeypatch):
    """With the env var unset, a new block reports the train-cache as off."""
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)

    torch.manual_seed(5)
    D, T = 32, 16
    operator = HyenaBlock(d_model=D, seq_len=T).operator
    assert operator._use_train_cache is False
|
|
|
|
def test_train_cache_forward_output_matches_uncached(monkeypatch):
    """Cached vs uncached forward outputs must match numerically.

    Block A is built with the env var unset and put in eval mode; block B
    is built with ``HYDRA_HYENA_TRAIN_CACHE=1``, loaded with A's weights and
    put in train mode. Both receive the same input and their outputs must
    agree to within a max absolute difference of 1e-5.
    """
    # Make the reference block independent of whatever the ambient
    # environment exports for HYDRA_HYENA_TRAIN_CACHE.
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)

    torch.manual_seed(6)
    D, T = 32, 16

    # Uncached reference.
    block_a = HyenaBlock(d_model=D, seq_len=T)
    block_a.eval()

    # Cached block with identical weights.
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block_b = HyenaBlock(d_model=D, seq_len=T)
    block_b.load_state_dict(block_a.state_dict())
    block_b.train()

    # Confirm the two blocks really differ in cache mode; otherwise the
    # comparison below would be comparing a path against itself.
    assert block_a.operator._use_train_cache is False
    assert block_b.operator._use_train_cache is True

    x = torch.randn(1, T, D)
    y_a = block_a(x)
    y_b = block_b(x)

    max_diff = (y_a - y_b).abs().max().item()
    assert max_diff < 1e-5, f"forward drift under train-cache: {max_diff:.3e}"
|
|
|
|
def test_flush_is_no_op_on_second_call(monkeypatch):
    """Calling flush twice within one step must not raise."""
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    torch.manual_seed(7)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()

    out = block(torch.randn(1, T, D))
    loss = out.pow(2).mean()
    loss.backward()

    # The first call does the real flush; the repeat must be tolerated.
    for _ in range(2):
        block.operator.flush_pending_filter_grads()
|
|
|
|
def test_flush_is_no_op_when_no_forward(monkeypatch):
    """Flushing before any forward pass has run must not raise."""
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()

    # No forward/backward has happened, so there is nothing pending.
    operator = block.operator
    operator.flush_pending_filter_grads()
|
|
|
|
# Allow running this file directly: execute its tests verbosely and use
# pytest's return value as the process exit status.
if __name__ == "__main__":
    exit_status = pytest.main([__file__, "-v"])
    raise SystemExit(exit_status)
|
|