# feather-a10g-large-runtime/overlay/tests/test_hyena_train_cache.py
# Committed by: icarus112
# Commit c475135 (verified): "Update Feather a10g-large training runtime image"
"""Training-safe filter cache for HyenaOperator.
**What this validates:**
When `HYDRA_HYENA_TRAIN_CACHE=1`, the filter MLP must:
1. Run EXACTLY ONCE per optimizer step, not once per micro-batch.
2. Produce gradients on its params that match the uncached path to within a
tight tolerance (these tests use fp32 CPU tensors, so the bound is far
stricter than bf16 would allow).
3. Not trip `RuntimeError: Trying to backward through the graph a second time`
under the grad-accum pattern.
**Design under test:**
`HyenaFilter.get_or_build_train_cache(L, fft_size)` returns a LEAF tensor
`k_leaf` whose grad accumulates across micro-batches. After all micro-batch
backwards, `flush_pending_filter_grads()` does one
`torch.autograd.backward(_k_graph, _k_leaf.grad)` to populate the filter
MLP params' `.grad`. Then `invalidate_cache()` resets state for the next
step.
Run:
cd /home/mikeb/work/feather
.venv/bin/pytest tests/test_hyena_train_cache.py -v
"""
from __future__ import annotations
import sys
from pathlib import Path
import pytest
import torch
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from hydra.hyena_block import HyenaBlock # noqa: E402
from subsystems import hyena_pure # noqa: E402
def _reset_rfft_counter() -> None:
    """Set ``hyena_pure._fftconv_filter_rfft_count`` back to zero."""
    hyena_pure._fftconv_filter_rfft_count = 0
def _rfft_count() -> int:
    """Return the current value of ``hyena_pure._fftconv_filter_rfft_count``.

    The module attribute is read at call time, not cached.
    """
    current = hyena_pure._fftconv_filter_rfft_count
    return current
def test_train_cache_runs_filter_mlp_once_per_step(monkeypatch):
    """Train-cache ON: the implicit filter MLP forwards once per step.

    Across ``accum`` gradient-accumulation micro-batches the MLP must run a
    single time, not once per micro-batch. The rfft counter cannot isolate
    MLP forwards (rfft also fires for ``k_f`` on every micro-batch for
    graph-safety reasons, see ``HyenaFilter.get_or_build_train_cache``), so
    the ``implicit_filter`` module's ``forward`` is replaced by a wrapper
    that records every call.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(0)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    assert block.operator._use_train_cache is True

    mlp = block.operator.filter_fn.implicit_filter
    real_forward = mlp.forward
    calls = []

    def recording_forward(*args, **kwargs):
        calls.append(None)
        return real_forward(*args, **kwargs)

    # The instance attribute shadows the class-level forward, so both
    # `mlp(...)` (via nn.Module.__call__) and `mlp.forward(...)` are recorded.
    mlp.forward = recording_forward

    accum = 3
    for _step in range(accum):
        x = torch.randn(1, T, D)
        loss = block(x).pow(2).mean() / accum
        loss.backward()

    # EXACTLY 1 MLP forward total, not `accum`.
    n_forwards = len(calls)
    assert n_forwards == 1, (
        f"expected exactly 1 filter MLP forward under train-cache across "
        f"{accum} micro-batches, got {n_forwards}"
    )
def test_train_cache_no_backward_twice_error(monkeypatch):
    """Four micro-batches (``accum = 4``) with train-cache on must NOT raise
    'Trying to backward through the graph a second time'.

    This is the core correctness guarantee. Without the fix, this test
    reliably reproduces the runtime error. After the loop, the cached
    ``_k_leaf`` must also carry a finite accumulated gradient.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(1)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    accum = 4
    # This must not raise: every micro-batch backward reuses the same
    # cached filter, which is exactly where the double-backward used to occur.
    for _ in range(accum):
        x = torch.randn(1, T, D)
        y = block(x)
        loss = y.pow(2).mean() / accum
        loss.backward()
    # After all micro-batches, k_leaf.grad must be non-None (grad accumulated).
    k_leaf = block.operator.filter_fn._k_leaf
    assert k_leaf is not None, "train-cache failed to populate _k_leaf"
    assert k_leaf.grad is not None, "no accumulated gradient on _k_leaf"
    assert torch.isfinite(k_leaf.grad).all(), "k_leaf.grad has NaN/Inf"
def test_train_cache_flush_populates_filter_params(monkeypatch):
    """After flush, the filter MLP params must have non-zero, finite grads."""
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(2)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    # Start with no .grad on any param (set to None, not zero-filled).
    for param in block.parameters():
        param.grad = None

    n_micro = 3
    for _ in range(n_micro):
        out = block(torch.randn(1, T, D))
        (out.pow(2).mean() / n_micro).backward()

    # Only the implicit-filter MLP is inspected. Per the HyenaOperator.forward
    # design, the filter's `bias` comes from filter_fn.bias directly, not from
    # the cached k_leaf. Its grad is therefore populated by the ordinary
    # backward, and it is not part of implicit_filter anyway.
    mlp = block.operator.filter_fn.implicit_filter
    trainable = [
        (name, param)
        for name, param in mlp.named_parameters()
        if param.requires_grad
    ]

    # Pre-flush: the backward chain was cut at k_leaf, so no MLP param may
    # hold a non-zero grad yet. Params downstream of k_leaf (short_filter,
    # in_proj, out_proj) do receive grads.
    for name, param in trainable:
        assert param.grad is None or param.grad.abs().max() == 0, (
            f"implicit_filter.{name} has grad before flush — the leaf "
            f"cache didn't actually cut the graph"
        )

    # Flush: runs torch.autograd.backward(_k_graph, _k_leaf.grad).
    block.operator.flush_pending_filter_grads()

    # Post-flush: with random inputs, dL/dy = 2*y/(B*T*D*3) is non-zero, so
    # every MLP param reachable from the filter output must carry a real,
    # finite, non-zero grad.
    for name, param in trainable:
        assert param.grad is not None, f"implicit_filter.{name} has no grad after flush"
        assert torch.isfinite(param.grad).all(), f"implicit_filter.{name} grad NaN/Inf"
        assert param.grad.abs().max() > 0, (
            f"implicit_filter.{name}.grad is all zero — flush didn't "
            f"push the k_leaf.grad back"
        )
def test_train_cache_gradient_matches_uncached(monkeypatch):
    """Parameter gradients under train-cache must numerically match
    the uncached path within tolerance.

    We construct two identical blocks (A uncached, B cached), run the same
    3 micro-batches on each, flush B's train-cache, then compare ``.grad``
    on every param.
    """
    # The block reads HYDRA_HYENA_TRAIN_CACHE in __init__. Clear any value
    # inherited from the outer environment (e.g. a shell exporting it) before
    # building the uncached baseline. Otherwise A is built with the cache on
    # and the `_use_train_cache is False` check below fails for reasons
    # unrelated to the code under test.
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)
    torch.manual_seed(3)
    D, T = 32, 16
    # Block A: no cache (baseline).
    block_a = HyenaBlock(d_model=D, seq_len=T)
    block_a.train()
    # Block B: train-cache on, same weights. The flag is only sampled at
    # construction, so set it after building A and before building B.
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block_b = HyenaBlock(d_model=D, seq_len=T)
    block_b.load_state_dict(block_a.state_dict())
    block_b.train()
    # Verify the flag took effect on B and did not leak into A.
    assert block_b.operator._use_train_cache is True
    assert block_a.operator._use_train_cache is False
    # Same 3 micro-batches for both blocks.
    xs = [torch.randn(1, T, D) for _ in range(3)]
    for block in (block_a, block_b):
        for p in block.parameters():
            p.grad = None
        for x in xs:
            y = block(x)
            loss = y.pow(2).mean() / len(xs)
            loss.backward()
    # Flush train-cache (block_b only).
    block_b.operator.flush_pending_filter_grads()
    # Compare grads.
    state_a = dict(block_a.named_parameters())
    state_b = dict(block_b.named_parameters())
    max_abs_diff = 0.0
    max_diff_name = ""
    for name, p_a in state_a.items():
        p_b = state_b[name]
        if p_a.grad is None:
            assert p_b.grad is None or p_b.grad.abs().max() == 0, (
                f"{name}: A has no grad, B has nonzero grad"
            )
            continue
        assert p_b.grad is not None, f"{name}: A has grad, B has none"
        diff = (p_a.grad - p_b.grad).abs().max().item()
        if diff > max_abs_diff:
            max_abs_diff = diff
            max_diff_name = name
    # Tight tolerance: both paths do the SAME math in fp32 CPU; the cached
    # path only defers the filter backward. Expected diff is ~0 modulo FP noise.
    assert max_abs_diff < 1e-4, (
        f"grad mismatch between cached and uncached paths: "
        f"max |Δgrad| = {max_abs_diff:.3e} on {max_diff_name!r}"
    )
def test_train_cache_invalidate_resets_state(monkeypatch):
    """After ``operator.invalidate_filter_cache()``, the next step rebuilds
    k_graph fresh.

    Simulates the post-optimizer.step() lifecycle: accumulate, flush,
    invalidate, clear grads. A second step must then rebuild the cache and
    give every filter MLP param a finite grad again.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(4)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    # Step 1: 2 micro-batches, flush, invalidate.
    for _ in range(2):
        y = block(torch.randn(1, T, D))
        (y.pow(2).mean() / 2).backward()
    assert block.operator.filter_fn._k_graph is not None
    block.operator.flush_pending_filter_grads()
    block.operator.invalidate_filter_cache()
    assert block.operator.filter_fn._k_graph is None
    assert block.operator.filter_fn._k_leaf is None
    # Drop every param's grad, like optimizer.zero_grad(set_to_none=True).
    for p in block.parameters():
        p.grad = None
    # Step 2: must work the same (not use stale state).
    for _ in range(2):
        y = block(torch.randn(1, T, D))
        (y.pow(2).mean() / 2).backward()
    assert block.operator.filter_fn._k_graph is not None, (
        "second step failed to rebuild k_graph"
    )
    block.operator.flush_pending_filter_grads()
    # All filter MLP params got a (finite) grad again.
    for name, p in block.operator.filter_fn.implicit_filter.named_parameters():
        if p.requires_grad:
            assert p.grad is not None, f"step 2: {name} has no grad"
            assert torch.isfinite(p.grad).all(), f"step 2: {name} grad NaN/Inf"
def test_train_cache_disabled_by_default(monkeypatch):
    """Unset env var → the train-cache flag is OFF on a freshly built block.

    Only the construction-time flag is asserted here. With it off, the
    operator is expected to take its ordinary per-micro-batch filter path,
    but that behaviour is not exercised by this test.
    """
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)
    torch.manual_seed(5)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    assert block.operator._use_train_cache is False
def test_train_cache_forward_output_matches_uncached(monkeypatch):
    """Cached vs uncached forward outputs must match numerically."""
    # The flag is read at construction. Clear any inherited value so the
    # baseline really is uncached.
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)
    torch.manual_seed(6)
    D, T = 32, 16
    # Uncached.
    block_a = HyenaBlock(d_model=D, seq_len=T)
    block_a.eval()
    # Cached copy with identical weights.
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block_b = HyenaBlock(d_model=D, seq_len=T)
    block_b.load_state_dict(block_a.state_dict())
    block_b.train()  # train-cache only activates under grad_enabled
    # Make sure we are really comparing uncached (A) against cached (B).
    assert block_a.operator._use_train_cache is False
    assert block_b.operator._use_train_cache is True
    x = torch.randn(1, T, D)
    y_a = block_a(x)  # uncached path (eval mode; autograd still enabled)
    y_b = block_b(x)  # cached path
    max_diff = (y_a - y_b).abs().max().item()
    assert max_diff < 1e-5, f"forward drift under train-cache: {max_diff:.3e}"
def test_flush_is_no_op_on_second_call(monkeypatch):
    """Idempotent flush: a second call in the same step must not crash and
    must leave every parameter gradient exactly as the first call left it."""
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    torch.manual_seed(7)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()
    y = block(torch.randn(1, T, D))
    y.pow(2).mean().backward()
    # First flush: real work.
    block.operator.flush_pending_filter_grads()
    # Snapshot grads so "no-op" is actually checked, not merely "no crash".
    snapshot = {
        name: None if p.grad is None else p.grad.detach().clone()
        for name, p in block.named_parameters()
    }
    # Second flush: must silently no-op.
    block.operator.flush_pending_filter_grads()
    for name, p in block.named_parameters():
        before = snapshot[name]
        if before is None:
            assert p.grad is None, f"{name}: second flush created a grad"
        else:
            assert p.grad is not None and torch.equal(p.grad, before), (
                f"{name}: second flush changed the grad"
            )
def test_flush_is_no_op_when_no_forward(monkeypatch):
    """If no Hyena forward ran this step, flush is a safe no-op."""
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block = HyenaBlock(d_model=32, seq_len=16)
    block.train()
    # Nothing has been cached yet (no forward), so flushing must simply
    # return without raising.
    block.operator.flush_pending_filter_grads()
if __name__ == "__main__":
    # Allow `python tests/test_hyena_train_cache.py`. Propagate pytest's
    # exit code as the process status.
    raise SystemExit(pytest.main([__file__, "-v"]))