| """Flash-FFT-conv integration: opt-in fast path, graceful fallback. |
| |
| **What this validates:** |
| * When `flashfftconv` is NOT importable, `fftconv_ref` falls back silently |
| to the pure-PyTorch path regardless of env-var value. |
| * `HYDRA_HYENA_FLASH_FFT=0` (default) always uses the pure path. |
| * The env-var gate + import-probe gate are independent; both must pass for |
| the fast path to activate. |
| * The vendored source tree is present and structurally sane (csrc/, |
| flashfftconv/, LICENSE) so offline builds remain possible. |
| |
| Numeric equivalence between the CUDA kernel and the pure path is validated |
| separately when flashfftconv is actually built — that requires a specific |
| GPU arch match and is run manually (see |
| `test_flash_fft_vs_pytorch_fftconv_numeric_equivalence`). |
| |
| Run: |
| cd /home/mikeb/work/feather |
| .venv/bin/pytest tests/test_flash_fft_integration.py -v |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| import sys |
| from pathlib import Path |
|
|
| import pytest |
| import torch |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) |
|
|
| from subsystems import hyena_pure |
| from subsystems.hyena_pure import ( |
| _FLASH_FFT_SUPPORTED_SIZES, |
| _flash_fft_conv_supported, |
| _try_load_flash_fft_conv, |
| fftconv_ref, |
| ) |
|
|
|
|
def test_flash_fft_conv_supported_matrix():
    """Pin the verdicts of ``_flash_fft_conv_supported`` for known inputs.

    The probe is called with an ``(fft_size, dtype)`` pair and must return
    exactly ``True`` or ``False`` (identity-checked, not merely truthy).
    """
    on_grid_size = 4096

    # Both half-precision dtypes are accepted at an on-grid size.
    for half_dtype in (torch.bfloat16, torch.float16):
        assert _flash_fft_conv_supported(on_grid_size, half_dtype) is True

    # Same size, but fp32 is rejected.
    assert _flash_fft_conv_supported(on_grid_size, torch.float32) is False

    # 4000 is not a power of two, so it is rejected even in bf16.
    assert _flash_fft_conv_supported(4000, torch.bfloat16) is False

    # 2**24 is a power of two yet is expected to be rejected
    # (presumably above the largest size the kernel handles -- confirm).
    assert _flash_fft_conv_supported(2**24, torch.bfloat16) is False
|
|
|
|
def test_flash_fft_supported_set_matches_expected():
    """Every fft_size HYDRA may reach must be in the supported set.

    HYDRA's Hyena uses fft_size = 2 * sequence_len. Sequence lengths in
    practice are 512, 1024, 2048 and 4096 → fft sizes 1024, 2048, 4096 and
    8192, each of which must appear in ``_FLASH_FFT_SUPPORTED_SIZES``.
    """
    required_fft_sizes = (1024, 2048, 4096, 8192)
    for fft_size in required_fft_sizes:
        # Only used to make the failure message name the sequence length.
        seq_len = fft_size // 2
        assert fft_size in _FLASH_FFT_SUPPORTED_SIZES, (
            f"fft_size {fft_size} must be supported for HYDRA sequence length "
            f"{seq_len}"
        )
|
|
|
|
def test_pure_path_used_when_env_off(monkeypatch):
    """HYDRA_HYENA_FLASH_FFT unset → ``fftconv_ref`` takes the pure path.

    The environment variable is removed for the duration of the test, the
    module-level ``_fftconv_filter_rfft_count`` is zeroed, and a single
    ``fftconv_ref`` call must leave that counter at exactly 1 while
    returning an output shaped like the input.

    Args:
        monkeypatch: pytest fixture; every change made through it (env var
            and module attribute) is undone automatically at teardown.
    """
    monkeypatch.delenv("HYDRA_HYENA_FLASH_FFT", raising=False)

    torch.manual_seed(0)
    B, D, L = 1, 8, 16
    u = torch.randn(B, D, L)
    k = torch.randn(D, L)
    D_bias = torch.randn(D)

    # Zero the counter through monkeypatch rather than by direct assignment
    # so the previous value is restored afterwards and this test does not
    # leak module state into others. ``raising=False`` keeps the original
    # tolerance for the attribute not existing yet.
    # NOTE(review): the counter is presumably bumped once per filter rfft on
    # the pure path -- confirm against hyena_pure.
    monkeypatch.setattr(
        hyena_pure, "_fftconv_filter_rfft_count", 0, raising=False
    )
    y = fftconv_ref(u, k, D_bias, gelu=False)
    assert y.shape == (B, D, L)

    # Exactly one increment is expected from the single call above.
    assert hyena_pure._fftconv_filter_rfft_count == 1
|
|
|
|
def test_try_load_flash_fft_conv_memoized():
    """``_try_load_flash_fft_conv`` probes once, then returns the cached value.

    The cached probe state on ``hyena_pure`` is cleared first so the first
    call below has to probe; the second call must hand back the very same
    object.
    """
    # Drop any earlier probe result so the first call is a fresh probe.
    hyena_pure._flash_fft_conv_probed = False
    hyena_pure._flash_fft_conv_cls = None

    first_result = _try_load_flash_fft_conv()
    # The probe must record that it has run, whatever it found.
    assert hyena_pure._flash_fft_conv_probed is True

    second_result = _try_load_flash_fft_conv()
    assert second_result is first_result, "second probe must return the memoized value"
|
|
|
|
def test_fallback_when_flash_fft_unavailable(monkeypatch):
    """HYDRA_HYENA_FLASH_FFT=1 + flashfftconv unimportable → pure path.

    Fallback must be silent (stderr warning but no crash, no behavior
    change): the call has to return a finite output of the input's shape.

    Args:
        monkeypatch: pytest fixture used to set the env var and to fake the
            memoized probe state; both are reverted at teardown.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")

    # Pretend the import probe has already run and found nothing.
    monkeypatch.setattr(hyena_pure, "_flash_fft_conv_cls", None)
    monkeypatch.setattr(hyena_pure, "_flash_fft_conv_probed", True)

    torch.manual_seed(1)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen)
    filt = torch.randn(channels, seqlen)
    bias = torch.randn(channels)

    out = fftconv_ref(signal, filt, bias, gelu=False)

    assert out.shape == (batch, channels, seqlen)
    assert torch.isfinite(out).all()
|
|
|
|
def test_fallback_when_dtype_unsupported(monkeypatch):
    """fp32 input + env on → falls back even if flashfftconv is present.

    What is actually asserted: with the flag set to "1" and fp32 operands,
    ``fftconv_ref`` returns a finite fp32 tensor.

    Args:
        monkeypatch: pytest fixture used to set the env var for this test.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")

    torch.manual_seed(2)
    batch, channels, seqlen = 1, 8, 16
    fp32 = torch.float32
    signal = torch.randn(batch, channels, seqlen, dtype=fp32)
    filt = torch.randn(channels, seqlen, dtype=fp32)
    bias = torch.randn(channels)

    out = fftconv_ref(signal, filt, bias, gelu=False)

    # The output must stay in the input precision.
    assert out.dtype == torch.float32
    assert torch.isfinite(out).all()
|
|
|
|
def test_fallback_when_k_is_higher_rank(monkeypatch):
    """k.dim()>2 (reverse-filter path) → fall back. HYDRA doesn't use this.

    NOTE(review): a rank-3 filter is built, but only its first 2-D slice is
    handed to ``fftconv_ref``, so the rank-3 case itself is not exercised
    by this call -- confirm whether that is intentional.

    Args:
        monkeypatch: pytest fixture used to set the env var for this test.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")

    torch.manual_seed(3)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen)

    # Two stacked filters, i.e. a rank-3 tensor.
    stacked_filters = torch.randn(2, channels, seqlen)
    bias = torch.randn(channels)

    first_filter = stacked_filters[0]
    out = fftconv_ref(signal, first_filter, bias, gelu=False)
    assert out.shape == (batch, channels, seqlen)
|
|
|
|
def test_vendored_source_tree_intact():
    """The vendored flash-fft-conv sources must exist at their known paths.

    The tree is looked up at ``kernels/cuda/flashfftconv`` under the
    directory one level above this file's directory. Besides the existence
    checks, the LICENSE file must mention "Apache License".
    """
    repo_root = Path(__file__).resolve().parents[1]
    root = repo_root / "kernels" / "cuda" / "flashfftconv"
    assert root.exists()

    # Checked in this order; each entry is a path relative to ``root``.
    required_entries = (
        ("LICENSE",),
        ("UPSTREAM_COMMIT",),
        ("csrc",),
        ("csrc", "setup.py"),
        ("flashfftconv",),
        ("flashfftconv", "conv.py"),
    )
    for parts in required_entries:
        assert root.joinpath(*parts).exists()

    license_text = root.joinpath("LICENSE").read_text()
    assert "Apache License" in license_text
|
|
|
|
@pytest.mark.skipif(
    _try_load_flash_fft_conv() is None or not torch.cuda.is_available(),
    reason="flashfftconv not installed or CUDA unavailable",
)
def test_flash_fft_vs_pytorch_fftconv_numeric_equivalence():
    """When the kernel IS available, its output must match pure PyTorch
    within bf16 tolerance.

    This test only runs on machines with a successful flashfftconv build.
    See kernels/cuda/flashfftconv/README.md for setup instructions.

    The same bf16 CUDA inputs are run through ``fftconv_ref`` twice, first
    with HYDRA_HYENA_FLASH_FFT="0" and then with "1", and the largest
    absolute element-wise difference must stay below 1e-3. The previous
    value of the environment variable is restored afterwards so the flag
    does not leak into other tests.
    """
    env_name = "HYDRA_HYENA_FLASH_FFT"

    torch.manual_seed(42)
    B, D, L = 2, 16, 2048
    fft_size = 2 * L
    assert fft_size in _FLASH_FFT_SUPPORTED_SIZES

    u = torch.randn(B, D, L, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(D, L, device="cuda", dtype=torch.bfloat16)
    D_bias = torch.randn(D, device="cuda", dtype=torch.bfloat16)

    # Remember the caller's setting (None means "was unset") so it can be
    # put back even if one of the calls below raises.
    previous_value = os.environ.get(env_name)
    try:
        os.environ[env_name] = "0"
        y_pure = fftconv_ref(u, k, D_bias, gelu=False)

        os.environ[env_name] = "1"
        y_flash = fftconv_ref(u, k, D_bias, gelu=False)
    finally:
        if previous_value is None:
            os.environ.pop(env_name, None)
        else:
            os.environ[env_name] = previous_value

    max_abs_diff = (y_pure - y_flash).abs().max().item()

    # NOTE(review): this is an absolute 1e-3 bound on bf16 outputs --
    # confirm it is the intended tolerance for this sequence length.
    assert max_abs_diff < 1e-3, (
        f"flash-fft-conv vs pure-PyTorch disagree: |Δ| max = {max_abs_diff:.3e}"
    )
|
|
|
|
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode and uses
    # pytest's return value as the process exit status.
    exit_status = pytest.main([__file__, "-v"])
    raise SystemExit(exit_status)
|
|