Spaces:
Runtime error
Runtime error
| """Training-safe filter cache for HyenaOperator. | |
| **What this validates:** | |
| When `HYDRA_HYENA_TRAIN_CACHE=1`, the filter MLP must: | |
| 1. Run EXACTLY ONCE per optimizer step, not once per micro-batch. | |
| 2. Produce gradients on its params that match the uncached path to within | |
| bf16 tolerance (we use fp32 CPU tensors here, so atol should be tight). | |
| 3. Not trip `RuntimeError: Trying to backward through the graph a second time` | |
| under the grad-accum pattern. | |
| **Design under test:** | |
| `HyenaFilter.get_or_build_train_cache(L, fft_size)` returns a LEAF tensor | |
| `k_leaf` whose grad accumulates across micro-batches. After all micro-batch | |
| backwards, `flush_pending_filter_grads()` does one | |
| `torch.autograd.backward(_k_graph, _k_leaf.grad)` to populate the filter | |
| MLP params' `.grad`. Then `invalidate_cache()` resets state for the next | |
| step. | |
| Run: | |
| cd /home/mikeb/work/feather | |
| .venv/bin/pytest tests/test_hyena_train_cache.py -v | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| from pathlib import Path | |
| import pytest | |
| import torch | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| from hydra.hyena_block import HyenaBlock # noqa: E402 | |
| from subsystems import hyena_pure # noqa: E402 | |
def _reset_rfft_counter():
    """Set the ``_fftconv_filter_rfft_count`` attribute of ``hyena_pure`` to 0."""
    # Module-level counter lives on the imported ``hyena_pure`` module; this
    # overwrites it in place so later reads start from zero.
    hyena_pure._fftconv_filter_rfft_count = 0
def _rfft_count() -> int:
    """Return the current ``_fftconv_filter_rfft_count`` value from ``hyena_pure``."""
    current = hyena_pure._fftconv_filter_rfft_count
    return current
def test_train_cache_runs_filter_mlp_once_per_step(monkeypatch):
    """The implicit-filter MLP must be called once across all micro-batches.

    With ``HYDRA_HYENA_TRAIN_CACHE=1`` the test wraps the ``forward`` of
    ``block.operator.filter_fn.implicit_filter`` in a recording proxy, runs
    three forward/backward micro-batches, and asserts the proxy was entered
    exactly one time.

    Per the original author's note, the rfft counter cannot be used for this
    check because rfft still fires per micro-batch for ``k_f``; counting the
    MLP's forward calls directly sidesteps that.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(0)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    assert block.operator._use_train_cache is True

    # Wrap the MLP's forward so every invocation is recorded.
    mlp = block.operator.filter_fn.implicit_filter
    real_forward = mlp.forward
    call_log = []

    def recording_forward(*args, **kwargs):
        call_log.append(1)
        return real_forward(*args, **kwargs)

    mlp.forward = recording_forward

    accum = 3
    for _ in range(accum):
        out = block(torch.randn(1, T, D))
        (out.pow(2).mean() / accum).backward()

    # One MLP forward for the whole accumulation window, not one per batch.
    observed = len(call_log)
    assert observed == 1, (
        f"expected exactly 1 filter MLP forward under train-cache across "
        f"{accum} micro-batches, got {observed}"
    )
def test_train_cache_no_backward_twice_error(monkeypatch):
    """Repeated micro-batch backwards must succeed with the train cache on.

    Runs four forward/backward passes on one block with
    ``HYDRA_HYENA_TRAIN_CACHE=1``. Any exception raised by ``backward()``
    (the docstring of this module names "Trying to backward through the graph
    a second time") fails the test. Afterwards the cached leaf tensor
    ``filter_fn._k_leaf`` must exist and carry a finite accumulated gradient.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(1)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()

    n_micro = 4
    # Each iteration must complete without raising.
    for _ in range(n_micro):
        out = block(torch.randn(1, T, D))
        (out.pow(2).mean() / n_micro).backward()

    # The leaf must have collected gradient from the micro-batches.
    leaf = block.operator.filter_fn._k_leaf
    assert leaf is not None, "train-cache failed to populate _k_leaf"
    leaf_grad = leaf.grad
    assert leaf_grad is not None, "no accumulated gradient on _k_leaf"
    assert torch.isfinite(leaf_grad).all(), "k_leaf.grad has NaN/Inf"
def test_train_cache_flush_populates_filter_params(monkeypatch):
    """Filter-MLP grads must be absent before the flush and real after it.

    Sequence exercised:
      1. clear every ``.grad`` on the block;
      2. run three forward/backward micro-batches;
      3. assert each trainable ``implicit_filter`` parameter has no grad
         (``None`` or all zeros);
      4. call ``block.operator.flush_pending_filter_grads()``;
      5. assert each of those parameters now has a finite, non-zero grad.

    Only ``implicit_filter`` parameters are inspected. The original author
    notes that the filter's ``bias`` is consumed outside the cached leaf and
    therefore receives grad directly; it is not checked here.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(2)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()

    # Start from a clean slate: no parameter carries a gradient.
    for param in block.parameters():
        param.grad = None

    n_micro = 3
    for _ in range(n_micro):
        out = block(torch.randn(1, T, D))
        (out.pow(2).mean() / n_micro).backward()

    def trainable_mlp_params():
        """List (name, param) pairs of the implicit filter that require grad."""
        named = block.operator.filter_fn.implicit_filter.named_parameters()
        return [(name, param) for name, param in named if param.requires_grad]

    # Pre-flush: the MLP must not have received any gradient yet.
    for name, param in trainable_mlp_params():
        grad = param.grad
        assert grad is None or grad.abs().max() == 0, (
            f"implicit_filter.{name} has grad before flush — the leaf "
            f"cache didn't actually cut the graph"
        )

    block.operator.flush_pending_filter_grads()

    # Post-flush: every trainable MLP parameter must hold a usable gradient.
    for name, param in trainable_mlp_params():
        grad = param.grad
        assert grad is not None, f"implicit_filter.{name} has no grad after flush"
        assert torch.isfinite(grad).all(), f"implicit_filter.{name} grad NaN/Inf"
        assert grad.abs().max() > 0, (
            f"implicit_filter.{name}.grad is all zero — flush didn't "
            f"push the k_leaf.grad back"
        )
def test_train_cache_gradient_matches_uncached(monkeypatch):
    """Parameter gradients under train-cache must match the uncached path.

    Two blocks with identical weights process the same three micro-batches:
    ``block_a`` is built with ``HYDRA_HYENA_TRAIN_CACHE`` unset (baseline) and
    ``block_b`` with it set to ``"1"``. After flushing ``block_b``'s pending
    filter grads, every parameter's ``.grad`` is compared and the largest
    absolute difference must be below ``1e-4``.
    """
    torch.manual_seed(3)
    D, T = 32, 16

    # Block A: no cache (baseline). Clear the flag explicitly so the baseline
    # does not silently become a cached block when the developer's shell or CI
    # exports HYDRA_HYENA_TRAIN_CACHE.
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)
    block_a = HyenaBlock(d_model=D, seq_len=T)
    block_a.train()

    # Block B: train-cache on, same weights. Per the original note the flag
    # is read in __init__, so it is set before constructing B.
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block_b = HyenaBlock(d_model=D, seq_len=T)
    block_b.load_state_dict(block_a.state_dict())
    block_b.train()

    # Verify the flag actually took effect on each block.
    assert block_b.operator._use_train_cache is True
    assert block_a.operator._use_train_cache is False

    # Same 3 micro-batches for both blocks.
    xs = [torch.randn(1, T, D) for _ in range(3)]
    for block in (block_a, block_b):
        for p in block.parameters():
            p.grad = None
        for x in xs:
            y = block(x)
            loss = y.pow(2).mean() / len(xs)
            loss.backward()

    # Flush train-cache (block_b only).
    block_b.operator.flush_pending_filter_grads()

    # Compare grads parameter by parameter, tracking the worst offender.
    state_a = dict(block_a.named_parameters())
    state_b = dict(block_b.named_parameters())
    max_abs_diff = 0.0
    max_diff_name = ""
    for name, p_a in state_a.items():
        p_b = state_b[name]
        if p_a.grad is None:
            assert p_b.grad is None or p_b.grad.abs().max() == 0, (
                f"{name}: A has no grad, B has nonzero grad"
            )
            continue
        assert p_b.grad is not None, f"{name}: A has grad, B has none"
        diff = (p_a.grad - p_b.grad).abs().max().item()
        if diff > max_abs_diff:
            max_abs_diff = diff
            max_diff_name = name

    # Tight tolerance: both paths run in fp32 on CPU; the cached path only
    # defers the backward, so the expected difference is floating-point noise.
    assert max_abs_diff < 1e-4, (
        f"grad mismatch between cached and uncached paths: "
        f"max |Δgrad| = {max_abs_diff:.3e} on {max_diff_name!r}"
    )
def test_train_cache_invalidate_resets_state(monkeypatch):
    """A second accumulation step must rebuild the cache after invalidation.

    Step 1 runs two micro-batches, flushes, then calls
    ``invalidate_filter_cache()`` and checks that both ``_k_graph`` and
    ``_k_leaf`` on the filter are ``None``. All grads are then cleared (as an
    optimizer's zero-grad would) and step 2 repeats the two micro-batches,
    checking that ``_k_graph`` is populated again and that the flush gives
    every trainable ``implicit_filter`` parameter a grad.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(4)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    op = block.operator

    def run_micro_batches(count):
        """Run ``count`` forward/backward passes with loss scaled by 1/count."""
        for _ in range(count):
            out = block(torch.randn(1, T, D))
            (out.pow(2).mean() / count).backward()

    # --- Step 1: accumulate, flush, invalidate. ---
    run_micro_batches(2)
    assert op.filter_fn._k_graph is not None
    op.flush_pending_filter_grads()
    op.invalidate_filter_cache()
    assert op.filter_fn._k_graph is None
    assert op.filter_fn._k_leaf is None

    # Clear grads between steps.
    for param in block.parameters():
        param.grad = None

    # --- Step 2: must not rely on state left over from step 1. ---
    run_micro_batches(2)
    assert op.filter_fn._k_graph is not None, (
        "second step failed to rebuild k_graph"
    )
    op.flush_pending_filter_grads()

    for name, param in op.filter_fn.implicit_filter.named_parameters():
        if param.requires_grad:
            assert param.grad is not None, f"step 2: {name} has no grad"
def test_train_cache_disabled_by_default(monkeypatch):
    """With the env var removed, a new block reports the train cache as off.

    Only the ``_use_train_cache`` flag on the operator is checked; no forward
    pass is run.
    """
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)
    torch.manual_seed(5)
    D, T = 32, 16
    op = HyenaBlock(d_model=D, seq_len=T).operator
    assert op._use_train_cache is False
def test_train_cache_forward_output_matches_uncached(monkeypatch):
    """Cached vs uncached forward outputs must match numerically.

    ``block_a`` is built with ``HYDRA_HYENA_TRAIN_CACHE`` unset and put in
    eval mode; ``block_b`` is built with the flag set to ``"1"``, receives
    ``block_a``'s weights, and is put in train mode. Both see the same input
    and their outputs must agree to within ``1e-5`` max absolute difference.
    """
    torch.manual_seed(6)
    D, T = 32, 16

    # Uncached baseline. Remove the flag first so an exported
    # HYDRA_HYENA_TRAIN_CACHE in the ambient environment cannot turn the
    # baseline into a second cached block (which would make the comparison
    # vacuous).
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)
    block_a = HyenaBlock(d_model=D, seq_len=T)
    block_a.eval()

    # Cached copy with identical weights.
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block_b = HyenaBlock(d_model=D, seq_len=T)
    block_b.load_state_dict(block_a.state_dict())
    # NOTE(review): the original comment says the train-cache only activates
    # under grad_enabled; here the block is switched to train mode — confirm
    # which condition the operator actually checks.
    block_b.train()

    # Make sure the two blocks really exercise different paths.
    assert block_a.operator._use_train_cache is False
    assert block_b.operator._use_train_cache is True

    x = torch.randn(1, T, D)
    y_a = block_a(x)  # baseline: block built without the train-cache flag
    y_b = block_b(x)  # block built with the train-cache flag
    max_diff = (y_a - y_b).abs().max().item()
    assert max_diff < 1e-5, f"forward drift under train-cache: {max_diff:.3e}"
def test_flush_is_no_op_on_second_call(monkeypatch):
    """Idempotent flush: a second call in the same step must change nothing.

    After one forward/backward and a first flush, the gradients of every
    parameter are snapshotted. The second flush must neither raise nor alter
    any gradient: parameters that had a grad keep exactly the same values,
    and parameters without one stay at ``None``. Checking the values (not just
    the absence of a crash) catches a flush that re-accumulates into the
    filter parameters.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(7)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    y = block(torch.randn(1, T, D))
    y.pow(2).mean().backward()

    # First flush: real work.
    block.operator.flush_pending_filter_grads()
    grads_after_first = {
        name: p.grad.detach().clone()
        for name, p in block.named_parameters()
        if p.grad is not None
    }

    # Second flush: must silently no-op.
    block.operator.flush_pending_filter_grads()
    for name, p in block.named_parameters():
        if name in grads_after_first:
            assert p.grad is not None, f"{name}: grad vanished on second flush"
            assert torch.equal(p.grad, grads_after_first[name]), (
                f"{name}: grad changed on second flush"
            )
        else:
            assert p.grad is None, f"{name}: grad appeared on second flush"
def test_flush_is_no_op_when_no_forward(monkeypatch):
    """Flushing a freshly built block (no forward run) must not raise."""
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    D, T = 32, 16
    fresh_block = HyenaBlock(d_model=D, seq_len=T)
    fresh_block.train()
    # Nothing has been cached yet; the call is expected to simply return.
    op = fresh_block.operator
    op.flush_pending_filter_grads()
if __name__ == "__main__":
    # Allow running this file directly: hand it to pytest in verbose mode and
    # exit the process with pytest's return code.
    exit_code = pytest.main([__file__, "-v"])
    raise SystemExit(exit_code)