"""Flash-FFT-conv integration: opt-in fast path, graceful fallback.
**What this validates:**
* When `flashfftconv` is NOT importable, `fftconv_ref` falls back silently
to the pure-PyTorch path regardless of env-var value.
* `HYDRA_HYENA_FLASH_FFT=0` (default) always uses the pure path.
* The env-var gate + import-probe gate are independent; both must pass for
the fast path to activate.
* The vendored source tree is present and structurally sane (csrc/,
flashfftconv/, LICENSE) so offline builds remain possible.
Numeric equivalence between the CUDA kernel and the pure path is validated
separately when flashfftconv is actually built — that requires a specific
GPU arch match and is run manually (see
`test_flash_fft_vs_pytorch_fftconv_numeric_equivalence`).
Run:
cd /home/mikeb/work/feather
.venv/bin/pytest tests/test_flash_fft_integration.py -v
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
import pytest
import torch
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from subsystems import hyena_pure # noqa: E402
from subsystems.hyena_pure import ( # noqa: E402
_FLASH_FFT_SUPPORTED_SIZES,
_flash_fft_conv_supported,
_try_load_flash_fft_conv,
fftconv_ref,
)
def test_flash_fft_conv_supported_matrix():
    """Check `_flash_fft_conv_supported` against a small (size, dtype) table.

    Each row is ``(fft_size, dtype, expected)``; the predicate must return
    exactly the boolean ``expected`` (identity check, not mere truthiness).
    """
    expectations = (
        # 16-bit inputs at an on-grid size are accepted.
        (4096, torch.bfloat16, True),
        (4096, torch.float16, True),
        # fp32 not supported (kernel requires 16-bit input).
        (4096, torch.float32, False),
        # Non-power-of-2 / off-grid.
        (4000, torch.bfloat16, False),
        # Very large — not in the supported set.
        (2**24, torch.bfloat16, False),
    )
    for fft_size, dtype, expected in expectations:
        assert _flash_fft_conv_supported(fft_size, dtype) is expected
def test_flash_fft_supported_set_matches_expected():
    """Every fft size HYDRA may reach must be in `_FLASH_FFT_SUPPORTED_SIZES`.

    HYDRA's Hyena uses fft_size = 2 * sequence_len. Sequence lengths in
    practice are 512, 1024, 2048, 4096 → fft sizes 1024, 2048, 4096, 8192.
    """
    for fft_size in (1024, 2048, 4096, 8192):
        seq_len = fft_size // 2
        failure = (
            f"fft_size {fft_size} must be supported for HYDRA sequence length "
            f"{seq_len}"
        )
        assert fft_size in _FLASH_FFT_SUPPORTED_SIZES, failure
def test_pure_path_used_when_env_off(monkeypatch):
    """With HYDRA_HYENA_FLASH_FFT unset, `fftconv_ref` takes the pure path.

    The pure path is detected through the module-level counter
    `hyena_pure._fftconv_filter_rfft_count`: it is zeroed before the call
    and must read exactly 1 afterwards (one filter rfft, since no
    precomputed `k_f` is passed).
    """
    monkeypatch.delenv("HYDRA_HYENA_FLASH_FFT", raising=False)

    torch.manual_seed(0)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen)
    filt = torch.randn(channels, seqlen)
    bias = torch.randn(channels)

    # Zero the counter right before the call so only this invocation counts.
    hyena_pure._fftconv_filter_rfft_count = 0
    out = fftconv_ref(signal, filt, bias, gelu=False)

    assert out.shape == (batch, channels, seqlen)
    # Pure path: exactly one filter rfft (k_f was None).
    assert hyena_pure._fftconv_filter_rfft_count == 1
def test_try_load_flash_fft_conv_memoized():
    """`_try_load_flash_fft_conv` probes once and then returns the cached value.

    The memo lives in two module globals on `hyena_pure`; both are cleared
    first so the probe in this test is observable.
    """
    # Clear the memo: no cached class, and "not yet probed".
    hyena_pure._flash_fft_conv_cls = None
    hyena_pure._flash_fft_conv_probed = False

    first = _try_load_flash_fft_conv()
    assert hyena_pure._flash_fft_conv_probed is True

    second = _try_load_flash_fft_conv()
    assert first is second, "second probe must return the memoized value"
def test_fallback_when_flash_fft_unavailable(monkeypatch):
    """HYDRA_HYENA_FLASH_FFT=1 + flashfftconv unimportable → pure path.

    The probe memo is patched to "already probed, nothing found", so the
    outcome does not depend on what is installed. The call must not crash
    and must return a finite tensor of the input's shape.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")
    # Force the probe to record "unavailable" regardless of what's installed.
    monkeypatch.setattr(hyena_pure, "_flash_fft_conv_cls", None)
    monkeypatch.setattr(hyena_pure, "_flash_fft_conv_probed", True)

    torch.manual_seed(1)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen)
    filt = torch.randn(channels, seqlen)
    bias = torch.randn(channels)

    out = fftconv_ref(signal, filt, bias, gelu=False)

    assert out.shape == (batch, channels, seqlen)
    assert torch.isfinite(out).all()
def test_fallback_when_dtype_unsupported(monkeypatch):
    """fp32 input with the env gate on still yields a valid fp32 result.

    fp32 is not a supported dtype for the fast path, so the call is expected
    to fall back to the pure path whether or not flashfftconv is present.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")

    torch.manual_seed(2)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen, dtype=torch.float32)
    # fp32 is NOT supported by the fast path.
    filt = torch.randn(channels, seqlen, dtype=torch.float32)
    bias = torch.randn(channels)

    out = fftconv_ref(signal, filt, bias, gelu=False)

    # Pure path handles fp32 fine.
    assert out.dtype == torch.float32
    assert torch.isfinite(out).all()
def test_fallback_when_k_is_higher_rank(monkeypatch):
    """k.dim() > 2 (reverse-filter path) → fall back. HYDRA doesn't use this.

    NOTE(review): a 3-D filter is built, but only its first slice (shape
    [D, L]) is passed to `fftconv_ref`, so the higher-rank branch itself is
    not exercised here — confirm whether that is intentional.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")

    torch.manual_seed(3)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen)
    # Shape [C, D, L] — the upstream reverse-filter shape; the kernel
    # doesn't handle it.
    stacked_filters = torch.randn(2, channels, seqlen)
    bias = torch.randn(channels)

    # k_f is left at its default so the call falls through to the pure path.
    # Take the [D, L] slice so the pure path accepts the filter in this test.
    filt = stacked_filters[0]
    out = fftconv_ref(signal, filt, bias, gelu=False)

    assert out.shape == (batch, channels, seqlen)
def test_vendored_source_tree_intact():
    """The vendored flash-fft-conv source files must exist at known paths.

    Checks the tree under ``<repo>/kernels/cuda/flashfftconv`` (resolved
    relative to this file) for the expected files and directories, then
    confirms the LICENSE text mentions "Apache License".
    """
    root = Path(__file__).resolve().parents[1] / "kernels" / "cuda" / "flashfftconv"
    required_paths = (
        root,
        root / "LICENSE",
        root / "UPSTREAM_COMMIT",
        root / "csrc",
        root / "csrc" / "setup.py",
        root / "flashfftconv",
        root / "flashfftconv" / "conv.py",
    )
    for path in required_paths:
        # Name the offending path so a failure is actionable on its own.
        assert path.exists(), f"missing vendored flash-fft-conv path: {path}"

    # LICENSE must be Apache (pin — if this drifts, update the vendor).
    # Decode explicitly as UTF-8 so the result doesn't depend on the locale.
    license_text = (root / "LICENSE").read_text(encoding="utf-8")
    assert "Apache License" in license_text
@pytest.mark.skipif(
    _try_load_flash_fft_conv() is None or not torch.cuda.is_available(),
    reason="flashfftconv not installed or CUDA unavailable",
)
def test_flash_fft_vs_pytorch_fftconv_numeric_equivalence():
    """When the kernel IS available, its output must match pure PyTorch
    within bf16 tolerance.

    This test only runs on machines with a successful flashfftconv build.
    See kernels/cuda/flashfftconv/README.md for setup instructions.

    The HYDRA_HYENA_FLASH_FFT environment variable is toggled to select each
    path and is restored to its prior state afterwards, so the setting does
    not leak into other tests running in the same process.
    """
    torch.manual_seed(42)
    B, D, L = 2, 16, 2048
    fft_size = 2 * L
    assert fft_size in _FLASH_FFT_SUPPORTED_SIZES

    u = torch.randn(B, D, L, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(D, L, device="cuda", dtype=torch.bfloat16)
    D_bias = torch.randn(D, device="cuda", dtype=torch.bfloat16)

    env_key = "HYDRA_HYENA_FLASH_FFT"
    saved_value = os.environ.get(env_key)
    try:
        os.environ[env_key] = "0"
        y_pure = fftconv_ref(u, k, D_bias, gelu=False)
        os.environ[env_key] = "1"
        y_flash = fftconv_ref(u, k, D_bias, gelu=False)
    finally:
        # Put the variable back exactly as it was (including "unset").
        if saved_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = saved_value

    max_abs_diff = (y_pure - y_flash).abs().max().item()
    # bf16 tolerance target from the task spec.
    assert max_abs_diff < 1e-3, (
        f"flash-fft-conv vs pure-PyTorch disagree: |Δ| max = {max_abs_diff:.3e}"
    )
# Allow running this file directly: hand it to pytest in verbose mode and
# propagate pytest's exit status to the shell.
if __name__ == "__main__":
    exit_status = pytest.main([__file__, "-v"])
    sys.exit(exit_status)
|