File size: 7,361 Bytes
c475135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""Flash-FFT-conv integration: opt-in fast path, graceful fallback.

**What this validates:**
  * When `flashfftconv` is NOT importable, `fftconv_ref` falls back silently
    to the pure-PyTorch path regardless of env-var value.
  * `HYDRA_HYENA_FLASH_FFT=0` (default) always uses the pure path.
  * The env-var gate + import-probe gate are independent; both must pass for
    the fast path to activate.
  * The vendored source tree is present and structurally sane (csrc/,
    flashfftconv/, LICENSE) so offline builds remain possible.

Numeric equivalence between the CUDA kernel and the pure path is validated
separately when flashfftconv is actually built — that requires a specific
GPU arch match, so the check is skipped automatically when the kernel or CUDA
is unavailable (see `test_flash_fft_vs_pytorch_fftconv_numeric_equivalence`).

Run:
    cd /home/mikeb/work/feather
    .venv/bin/pytest tests/test_flash_fft_integration.py -v
"""

from __future__ import annotations

import os
import sys
from pathlib import Path

import pytest
import torch

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from subsystems import hyena_pure  # noqa: E402
from subsystems.hyena_pure import (  # noqa: E402
    _FLASH_FFT_SUPPORTED_SIZES,
    _flash_fft_conv_supported,
    _try_load_flash_fft_conv,
    fftconv_ref,
)


def test_flash_fft_conv_supported_matrix():
    """Spot-check `_flash_fft_conv_supported` over a small (size, dtype) table.

    Each row is ``(size, dtype, expected)``; the predicate must return
    exactly the boolean listed for that row.
    """
    expectations = (
        # On-grid size with a 16-bit dtype is accepted.
        (4096, torch.bfloat16, True),
        (4096, torch.float16, True),
        # fp32 not supported (kernel requires 16-bit input).
        (4096, torch.float32, False),
        # Non-power-of-2 / off-grid.
        (4000, torch.bfloat16, False),
        # Very large — not in set.
        (2**24, torch.bfloat16, False),
    )
    for size, dtype, expected in expectations:
        assert _flash_fft_conv_supported(size, dtype) is expected


def test_flash_fft_supported_set_matches_expected():
    """Every fft_size HYDRA may reach must be in the supported set.

    HYDRA's Hyena uses fft_size = 2 * sequence_len. The sequence lengths used
    in practice are 512, 1024, 2048 and 4096 → fft sizes 1024, 2048, 4096 and
    8192, all of which must appear in `_FLASH_FFT_SUPPORTED_SIZES`.
    """
    sequence_lengths = (512, 1024, 2048, 4096)
    # Derive the fft sizes from the sequence lengths rather than listing them.
    for s in (2 * seq_len for seq_len in sequence_lengths):
        assert s in _FLASH_FFT_SUPPORTED_SIZES, (
            f"fft_size {s} must be supported for HYDRA sequence length "
            f"{s // 2}"
        )


def test_pure_path_used_when_env_off(monkeypatch):
    """HYDRA_HYENA_FLASH_FFT unset or "0" → pure PyTorch path.

    Both "off" states of the env-var gate are exercised: the variable being
    absent and the variable being set to ``"0"``. For each state the test
    calls ``fftconv_ref`` on small CPU tensors and checks that

      * the output has the same ``(B, D, L)`` shape as the input, and
      * ``hyena_pure._fftconv_filter_rfft_count`` went from 0 to exactly 1,
        which is how the pure path is recognised here (it performs one filter
        rfft when no precomputed ``k_f`` is supplied).

    Args:
        monkeypatch: pytest fixture; used for the env var and for the module
            counter so neither change outlives this test.
    """
    torch.manual_seed(0)
    B, D, L = 1, 8, 16
    u = torch.randn(B, D, L)
    k = torch.randn(D, L)
    D_bias = torch.randn(D)

    # None means "variable absent"; "0" is the explicit off value.
    for env_value in (None, "0"):
        if env_value is None:
            monkeypatch.delenv("HYDRA_HYENA_FLASH_FFT", raising=False)
        else:
            monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", env_value)

        # Reset the counter through monkeypatch (not a bare assignment) so the
        # module global is restored afterwards; raising=False keeps the reset
        # working even if the attribute has not been created yet.
        monkeypatch.setattr(
            hyena_pure, "_fftconv_filter_rfft_count", 0, raising=False
        )

        y = fftconv_ref(u, k, D_bias, gelu=False)
        assert y.shape == (B, D, L)
        # Pure path: exactly one filter rfft (k_f was None).
        assert hyena_pure._fftconv_filter_rfft_count == 1


def test_try_load_flash_fft_conv_memoized():
    """Two consecutive loader calls must hand back the very same object.

    The memo fields on ``hyena_pure`` are cleared first so that the first call
    is the one that flips ``_flash_fft_conv_probed`` to True.
    """
    # Clear the memo so this test can observe the probe happening.
    hyena_pure._flash_fft_conv_probed = False
    hyena_pure._flash_fft_conv_cls = None

    first = _try_load_flash_fft_conv()
    assert hyena_pure._flash_fft_conv_probed is True

    second = _try_load_flash_fft_conv()
    assert first is second, "second probe must return the memoized value"


def test_fallback_when_flash_fft_unavailable(monkeypatch):
    """HYDRA_HYENA_FLASH_FFT=1 + flashfftconv unimportable → pure path.

    With the env gate on but the probe memo pinned to "already probed, nothing
    found", ``fftconv_ref`` must still return a finite result of the input's
    shape: no crash and no behavior change.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")
    # Pin the memo to "unavailable" regardless of what is actually installed.
    monkeypatch.setattr(hyena_pure, "_flash_fft_conv_probed", True)
    monkeypatch.setattr(hyena_pure, "_flash_fft_conv_cls", None)

    torch.manual_seed(1)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen)
    filt = torch.randn(channels, seqlen)
    skip_bias = torch.randn(channels)

    out = fftconv_ref(signal, filt, skip_bias, gelu=False)

    assert out.shape == (batch, channels, seqlen)
    assert torch.isfinite(out).all()


def test_fallback_when_dtype_unsupported(monkeypatch):
    """fp32 input + env on → falls back even if flashfftconv is present.

    The inputs are explicitly float32; the result must stay float32 and be
    finite.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")

    torch.manual_seed(2)
    batch, channels, seqlen = 1, 8, 16
    fp32 = torch.float32  # fp32 is NOT supported by the fast path
    signal = torch.randn(batch, channels, seqlen, dtype=fp32)
    filt = torch.randn(channels, seqlen, dtype=fp32)
    skip_bias = torch.randn(channels)

    out = fftconv_ref(signal, filt, skip_bias, gelu=False)

    # Pure path handles fp32 fine.
    assert out.dtype == torch.float32
    assert torch.isfinite(out).all()


def test_fallback_when_k_is_higher_rank(monkeypatch):
    """k.dim()>2 (reverse-filter path) → fall back. HYDRA doesn't use this.

    NOTE(review): a ``[C, D, L]`` filter is built, but only its first 2-D
    slice is handed to ``fftconv_ref``, so the function never receives a
    rank-3 ``k`` here. What is actually checked is that the call succeeds
    with the env gate on and that the output keeps the input's shape.
    Confirm whether a genuine rank-3 call should be added.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")

    torch.manual_seed(3)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen)
    # [C, D, L] — upstream reverse-filter shape; kernel doesn't handle it.
    filter_bank = torch.randn(2, channels, seqlen)
    skip_bias = torch.randn(channels)

    # Take one [D, L] slice so the pure path accepts the filter in this test.
    single_filter = filter_bank[0]
    out = fftconv_ref(signal, single_filter, skip_bias, gelu=False)

    assert out.shape == (batch, channels, seqlen)


def test_vendored_source_tree_intact():
    """The vendored flash-fft-conv tree must contain its expected entries.

    The tree is looked up at ``kernels/cuda/flashfftconv`` under the directory
    one level above this file's directory. Only existence is checked, plus the
    presence of the text "Apache License" in the LICENSE file.
    """
    repo_root = Path(__file__).resolve().parents[1]
    vendor_root = repo_root.joinpath("kernels", "cuda", "flashfftconv")
    assert vendor_root.exists()

    # Relative entries (as path components) that must exist under the root.
    required_entries = (
        ("LICENSE",),
        ("UPSTREAM_COMMIT",),
        ("csrc",),
        ("csrc", "setup.py"),
        ("flashfftconv",),
        ("flashfftconv", "conv.py"),
    )
    for parts in required_entries:
        assert vendor_root.joinpath(*parts).exists()

    # LICENSE must be Apache 2.0 (pin — if this drifts, update the vendor).
    assert "Apache License" in vendor_root.joinpath("LICENSE").read_text()


@pytest.mark.skipif(
    _try_load_flash_fft_conv() is None or not torch.cuda.is_available(),
    reason="flashfftconv not installed or CUDA unavailable",
)
def test_flash_fft_vs_pytorch_fftconv_numeric_equivalence():
    """When the kernel IS available, its output must match pure PyTorch
    within bf16 tolerance.

    The same bf16 CUDA inputs are pushed through ``fftconv_ref`` twice, once
    with ``HYDRA_HYENA_FLASH_FFT=0`` and once with ``=1``, and the maximum
    absolute element-wise difference must stay below ``1e-3``.

    The test is skipped unless ``_try_load_flash_fft_conv()`` returns
    something other than None and CUDA is available (evaluated at collection
    time by the ``skipif`` marker).
    See kernels/cuda/flashfftconv/README.md for setup instructions.

    The env var is restored to its previous state on exit, even when an
    assertion or the kernel raises, so the "1" setting cannot leak into tests
    that run afterwards in the same process.
    """
    torch.manual_seed(42)
    B, D, L = 2, 16, 2048
    fft_size = 2 * L
    assert fft_size in _FLASH_FFT_SUPPORTED_SIZES

    u = torch.randn(B, D, L, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(D, L, device="cuda", dtype=torch.bfloat16)
    D_bias = torch.randn(D, device="cuda", dtype=torch.bfloat16)

    env_key = "HYDRA_HYENA_FLASH_FFT"
    # Remember the caller's setting (None == variable absent) for restoration.
    previous = os.environ.get(env_key)
    try:
        os.environ[env_key] = "0"
        y_pure = fftconv_ref(u, k, D_bias, gelu=False)

        os.environ[env_key] = "1"
        y_flash = fftconv_ref(u, k, D_bias, gelu=False)
    finally:
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous

    max_abs_diff = (y_pure - y_flash).abs().max().item()
    # bf16 tolerance target from the task spec.
    assert max_abs_diff < 1e-3, (
        f"flash-fft-conv vs pure-PyTorch disagree: |Δ| max = {max_abs_diff:.3e}"
    )


if __name__ == "__main__":
    # Direct execution: run this file under pytest in verbose mode and use
    # pytest's return value as the process exit status.
    raise SystemExit(pytest.main([__file__, "-v"]))