File size: 11,799 Bytes
c475135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""Training-safe filter cache for HyenaOperator.

**What this validates:**
When `HYDRA_HYENA_TRAIN_CACHE=1`, the filter MLP must:
  1. Run EXACTLY ONCE per optimizer step, not once per micro-batch.
  2. Produce gradients on its params that match the uncached path to within
     bf16 tolerance (we use fp32 CPU tensors here, so atol should be tight).
  3. Not trip `RuntimeError: Trying to backward through the graph a second time`
     under the grad-accum pattern.

**Design under test:**
`HyenaFilter.get_or_build_train_cache(L, fft_size)` returns a LEAF tensor
`k_leaf` whose grad accumulates across micro-batches. After all micro-batch
backwards, `flush_pending_filter_grads()` does one
`torch.autograd.backward(_k_graph, _k_leaf.grad)` to populate the filter
MLP params' `.grad`. Then `invalidate_cache()` resets state for the next
step.

Run (from the repository root, e.g. /home/mikeb/work/feather):
    .venv/bin/pytest tests/test_hyena_train_cache.py -v
"""

from __future__ import annotations

import sys
from pathlib import Path

import pytest
import torch

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from hydra.hyena_block import HyenaBlock  # noqa: E402
from subsystems import hyena_pure  # noqa: E402


def _reset_rfft_counter() -> None:
    """Set the ``hyena_pure._fftconv_filter_rfft_count`` module attribute to 0."""
    hyena_pure._fftconv_filter_rfft_count = 0


def _rfft_count() -> int:
    """Return the current value of ``hyena_pure._fftconv_filter_rfft_count``."""
    return hyena_pure._fftconv_filter_rfft_count


def test_train_cache_runs_filter_mlp_once_per_step(monkeypatch):
    """The implicit filter MLP is evaluated once per step under train-cache.

    With ``HYDRA_HYENA_TRAIN_CACHE=1`` set before construction, several
    gradient-accumulation micro-batches are pushed through one block while
    calls to the ``implicit_filter`` forward are counted.

    The rfft counter is deliberately not used here: per the original
    author's note, rfft also fires for ``k_f`` on every micro-batch (see the
    ``HyenaFilter.get_or_build_train_cache`` docstring), so it cannot tell
    MLP forwards apart. A counting wrapper around ``forward`` is used instead.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    torch.manual_seed(0)
    d_model, seq_len = 32, 16
    block = HyenaBlock(d_model=d_model, seq_len=seq_len)
    block.train()
    assert block.operator._use_train_cache is True

    # Wrap the MLP's forward so that every invocation leaves a mark.
    mlp = block.operator.filter_fn.implicit_filter
    wrapped_forward = mlp.forward
    mlp_calls = []

    def counting_forward(*args, **kwargs):
        mlp_calls.append(1)
        return wrapped_forward(*args, **kwargs)

    mlp.forward = counting_forward

    accum = 3
    for _ in range(accum):
        micro_batch = torch.randn(1, seq_len, d_model)
        out = block(micro_batch)
        (out.pow(2).mean() / accum).backward()

    # One MLP forward in total, independent of the number of micro-batches.
    n_forward = len(mlp_calls)
    assert n_forward == 1, (
        f"expected exactly 1 filter MLP forward under train-cache across "
        f"{accum} micro-batches, got {n_forward}"
    )


def test_train_cache_no_backward_twice_error(monkeypatch):
    """Repeated micro-batch backwards under train-cache must succeed.

    Four forward/backward rounds are run on a single block with the cache
    enabled; none may raise
    'Trying to backward through the graph a second time'.
    Afterwards the cached ``_k_leaf`` must exist and hold a finite
    accumulated gradient.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    torch.manual_seed(1)
    d_model, seq_len = 32, 16
    block = HyenaBlock(d_model=d_model, seq_len=seq_len)
    block.train()

    # Any exception raised inside this loop fails the test.
    n_micro = 4
    for _ in range(n_micro):
        out = block(torch.randn(1, seq_len, d_model))
        scaled_loss = out.pow(2).mean() / n_micro
        scaled_loss.backward()

    # The leaf must have been created and must have collected a usable grad.
    leaf = block.operator.filter_fn._k_leaf
    assert leaf is not None, "train-cache failed to populate _k_leaf"
    leaf_grad = leaf.grad
    assert leaf_grad is not None, "no accumulated gradient on _k_leaf"
    assert torch.isfinite(leaf_grad).all(), "k_leaf.grad has NaN/Inf"


def test_train_cache_flush_populates_filter_params(monkeypatch):
    """Filter MLP grads appear only after ``flush_pending_filter_grads()``.

    Three micro-batches are backpropagated with train-cache on. Before the
    flush every trainable ``implicit_filter`` parameter must have no grad
    (``None`` or all zeros); after the flush each must have a finite,
    non-zero grad.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    torch.manual_seed(2)
    d_model, seq_len = 32, 16
    block = HyenaBlock(d_model=d_model, seq_len=seq_len)
    block.train()

    # Start from a clean slate: drop any existing grads.
    for param in block.parameters():
        param.grad = None

    n_micro = 3
    for _ in range(n_micro):
        out = block(torch.randn(1, seq_len, d_model))
        (out.pow(2).mean() / n_micro).backward()

    mlp = block.operator.filter_fn.implicit_filter

    # Pre-flush expectation: the backward chain stops at the cached leaf, so
    # nothing reaches the MLP parameters yet.
    # NOTE (from the original author): the filter's `bias` is read directly
    # from filter_fn.bias in HyenaOperator.forward rather than through the
    # cached leaf, so `bias.grad` is populated by the direct path. Only
    # implicit_filter parameters are inspected here.
    for name, param in mlp.named_parameters():
        if not param.requires_grad:
            continue
        grad = param.grad
        assert grad is None or grad.abs().max() == 0, (
            f"implicit_filter.{name} has grad before flush — the leaf "
            f"cache didn't actually cut the graph"
        )

    # Push the gradient accumulated on the leaf back into the MLP.
    block.operator.flush_pending_filter_grads()

    # Post-flush expectation: every trainable MLP parameter has a real grad.
    for name, param in mlp.named_parameters():
        if not param.requires_grad:
            continue
        grad = param.grad
        assert grad is not None, f"implicit_filter.{name} has no grad after flush"
        assert torch.isfinite(grad).all(), f"implicit_filter.{name} grad NaN/Inf"
        # Random inputs with a squared-output loss should give a non-zero
        # signal to every parameter reachable from the filter output.
        assert grad.abs().max() > 0, (
            f"implicit_filter.{name}.grad is all zero — flush didn't "
            f"push the k_leaf.grad back"
        )


def test_train_cache_gradient_matches_uncached(monkeypatch):
    """Parameter gradients under train-cache must numerically match
    the uncached path within tolerance.

    We construct two blocks with identical weights (A: flag unset, B: flag
    set), run the same 3 micro-batches on each, flush train-cache on B, then
    compare .grad on every param.
    """
    # The baseline must be built with the flag unset no matter what the
    # ambient environment contains; otherwise an exported
    # HYDRA_HYENA_TRAIN_CACHE=1 would silently turn block A into a cached
    # block too.
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)

    torch.manual_seed(3)
    D, T = 32, 16

    # Block A: no cache (baseline).
    block_a = HyenaBlock(d_model=D, seq_len=T)
    block_a.train()
    # Block B: train-cache on, same weights.
    # Note (original author): the block reads the flag in __init__, so the
    # env var has to be set before constructing B.
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block_b = HyenaBlock(d_model=D, seq_len=T)
    block_b.load_state_dict(block_a.state_dict())
    block_b.train()
    # Verify the flag actually took effect.
    assert block_b.operator._use_train_cache is True
    assert block_a.operator._use_train_cache is False

    # Same 3 micro-batches for both blocks.
    xs = [torch.randn(1, T, D) for _ in range(3)]

    for block in (block_a, block_b):
        for p in block.parameters():
            p.grad = None
        for x in xs:
            y = block(x)
            loss = y.pow(2).mean() / len(xs)
            loss.backward()

    # Flush train-cache (block_b only).
    block_b.operator.flush_pending_filter_grads()

    # Compare grads parameter by parameter, tracking the worst offender.
    state_a = dict(block_a.named_parameters())
    state_b = dict(block_b.named_parameters())
    max_abs_diff = 0.0
    max_diff_name = ""
    n_compared = 0
    for name, p_a in state_a.items():
        p_b = state_b[name]
        if p_a.grad is None:
            assert p_b.grad is None or p_b.grad.abs().max() == 0, (
                f"{name}: A has no grad, B has nonzero grad"
            )
            continue
        assert p_b.grad is not None, f"{name}: A has grad, B has none"
        n_compared += 1
        diff = (p_a.grad - p_b.grad).abs().max().item()
        if diff > max_abs_diff:
            max_abs_diff = diff
            max_diff_name = name

    # Guard against a vacuous pass where no gradient pair was ever compared.
    assert n_compared > 0, "no parameter gradients were compared"

    # Tight tolerance: the two paths do the SAME math in fp32 CPU, just the
    # cached path defers the backward. Expected diff ≈ 0 modulo FP noise.
    assert max_abs_diff < 1e-4, (
        f"grad mismatch between cached and uncached paths: "
        f"max |Δgrad| = {max_abs_diff:.3e} on {max_diff_name!r}"
    )


def test_train_cache_invalidate_resets_state(monkeypatch):
    """A second optimizer step after ``invalidate_filter_cache()`` works.

    Step 1 runs two micro-batches, flushes, and invalidates; ``_k_graph`` and
    ``_k_leaf`` must then both be ``None``. After clearing all grads, step 2
    repeats the two micro-batches and the flush, and every trainable
    ``implicit_filter`` parameter must have a grad again.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    torch.manual_seed(4)
    d_model, seq_len = 32, 16
    block = HyenaBlock(d_model=d_model, seq_len=seq_len)
    block.train()
    operator = block.operator

    def run_two_micro_batches():
        # One accumulation window: two forward/backward rounds, loss halved.
        for _ in range(2):
            out = block(torch.randn(1, seq_len, d_model))
            (out.pow(2).mean() / 2).backward()

    # --- Step 1: accumulate, flush, invalidate. ---
    run_two_micro_batches()
    assert operator.filter_fn._k_graph is not None
    operator.flush_pending_filter_grads()
    operator.invalidate_filter_cache()
    assert operator.filter_fn._k_graph is None
    assert operator.filter_fn._k_leaf is None

    # Stand-in for optimizer.zero_grad(): drop every grad.
    for param in block.parameters():
        param.grad = None

    # --- Step 2: the cache must be rebuilt rather than reused. ---
    run_two_micro_batches()
    assert operator.filter_fn._k_graph is not None, (
        "second step failed to rebuild k_graph"
    )
    operator.flush_pending_filter_grads()

    # Every trainable filter MLP parameter received a grad once more.
    for name, param in operator.filter_fn.implicit_filter.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad is not None, f"step 2: {name} has no grad"


def test_train_cache_disabled_by_default(monkeypatch):
    """With the env var unset, a new block reports train-cache as off.

    Only the ``_use_train_cache`` flag on the operator is checked here.
    """
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)

    torch.manual_seed(5)
    d_model, seq_len = 32, 16
    operator = HyenaBlock(d_model=d_model, seq_len=seq_len).operator
    assert operator._use_train_cache is False


def test_train_cache_forward_output_matches_uncached(monkeypatch):
    """Cached vs uncached forward outputs must match numerically.

    Block A is built with the flag unset and put in eval mode; block B is
    built with the flag set, loaded with A's weights and put in train mode.
    Both are fed the same input and the max absolute output difference must
    stay below 1e-5.
    """
    # Make sure the reference block is really uncached, even if the variable
    # is exported in the ambient environment.
    monkeypatch.delenv("HYDRA_HYENA_TRAIN_CACHE", raising=False)

    torch.manual_seed(6)
    D, T = 32, 16

    # Uncached.
    block_a = HyenaBlock(d_model=D, seq_len=T)
    block_a.eval()
    assert block_a.operator._use_train_cache is False

    # Cached copy.
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")
    block_b = HyenaBlock(d_model=D, seq_len=T)
    block_b.load_state_dict(block_a.state_dict())
    block_b.train()  # original note: train-cache only activates under grad_enabled
    # Without this check the comparison could silently run uncached-vs-uncached.
    assert block_b.operator._use_train_cache is True

    x = torch.randn(1, T, D)
    y_a = block_a(x)  # uncached reference (flag off, eval mode)
    y_b = block_b(x)  # cached path

    max_diff = (y_a - y_b).abs().max().item()
    assert max_diff < 1e-5, f"forward drift under train-cache: {max_diff:.3e}"


def test_flush_is_no_op_on_second_call(monkeypatch):
    """Idempotent flush: a second call in the same step must not crash and
    must leave every parameter gradient exactly as the first flush left it.
    """
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    torch.manual_seed(7)
    D, T = 32, 16
    block = HyenaBlock(d_model=D, seq_len=T)
    block.train()

    y = block(torch.randn(1, T, D))
    y.pow(2).mean().backward()

    # First flush: real work.
    block.operator.flush_pending_filter_grads()

    # Snapshot the grads so the second flush can be shown to change nothing
    # (a flush that re-applied the leaf grad would double-count it).
    grads_after_first = {
        name: None if p.grad is None else p.grad.detach().clone()
        for name, p in block.named_parameters()
    }

    # Second flush: must silently no-op.
    block.operator.flush_pending_filter_grads()

    for name, p in block.named_parameters():
        before = grads_after_first[name]
        if before is None:
            assert p.grad is None or p.grad.abs().max() == 0, (
                f"{name}: second flush produced a grad where there was none"
            )
            continue
        assert p.grad is not None, f"{name}: second flush dropped the grad"
        assert torch.equal(p.grad, before), (
            f"{name}: second flush changed the grad"
        )


def test_flush_is_no_op_when_no_forward(monkeypatch):
    """Flushing before any forward pass must return without raising."""
    monkeypatch.setenv("HYDRA_HYENA_TRAIN_CACHE", "1")

    d_model, seq_len = 32, 16
    fresh_block = HyenaBlock(d_model=d_model, seq_len=seq_len)
    fresh_block.train()

    # No forward has run on this block; the call only needs to return.
    fresh_block.operator.flush_pending_filter_grads()


if __name__ == "__main__":
    # Allow running this file directly; pytest's return code becomes the
    # process exit status.
    raise SystemExit(pytest.main([__file__, "-v"]))