"""Flash-FFT-conv integration: opt-in fast path, graceful fallback. **What this validates:** * When `flashfftconv` is NOT importable, `fftconv_ref` falls back silently to the pure-PyTorch path regardless of env-var value. * `HYDRA_HYENA_FLASH_FFT=0` (default) always uses the pure path. * The env-var gate + import-probe gate are independent; both must pass for the fast path to activate. * The vendored source tree is present and structurally sane (csrc/, flashfftconv/, LICENSE) so offline builds remain possible. Numeric equivalence between the CUDA kernel and the pure path is validated separately when flashfftconv is actually built — that requires a specific GPU arch match and is run manually (see `test_flash_fft_vs_pytorch_fftconv`). Run: cd /home/mikeb/work/feather .venv/bin/pytest tests/test_flash_fft_integration.py -v """ from __future__ import annotations import os import sys from pathlib import Path import pytest import torch sys.path.insert(0, str(Path(__file__).resolve().parents[1])) from subsystems import hyena_pure # noqa: E402 from subsystems.hyena_pure import ( # noqa: E402 _FLASH_FFT_SUPPORTED_SIZES, _flash_fft_conv_supported, _try_load_flash_fft_conv, fftconv_ref, ) def test_flash_fft_conv_supported_matrix(): """Supported seqlens are the specific power-of-2 grid the kernel handles.""" assert _flash_fft_conv_supported(4096, torch.bfloat16) is True assert _flash_fft_conv_supported(4096, torch.float16) is True # fp32 not supported (kernel requires 16-bit input). assert _flash_fft_conv_supported(4096, torch.float32) is False # Non-power-of-2 / off-grid. assert _flash_fft_conv_supported(4000, torch.bfloat16) is False # Very large — not in set. assert _flash_fft_conv_supported(2**24, torch.bfloat16) is False def test_flash_fft_supported_set_matches_expected(): """The supported set must include every fft_size HYDRA may reach. HYDRA's Hyena uses fft_size = 2 * sequence_len. Sequence lengths in practice: 512, 1024, 2048, 4096. → fft sizes 1024, 2048, 4096, 8192. All must be in the supported set. """ for s in (1024, 2048, 4096, 8192): assert s in _FLASH_FFT_SUPPORTED_SIZES, ( f"fft_size {s} must be supported for HYDRA sequence length " f"{s // 2}" ) def test_pure_path_used_when_env_off(monkeypatch): """HYDRA_HYENA_FLASH_FFT=0 (or unset) → pure PyTorch path.""" monkeypatch.delenv("HYDRA_HYENA_FLASH_FFT", raising=False) torch.manual_seed(0) B, D, L = 1, 8, 16 u = torch.randn(B, D, L) k = torch.randn(D, L) D_bias = torch.randn(D) # Count filter rfft invocations — the pure path calls it once when k_f is None. hyena_pure._fftconv_filter_rfft_count = 0 y = fftconv_ref(u, k, D_bias, gelu=False) assert y.shape == (B, D, L) # Pure path: exactly one filter rfft (k_f was None). assert hyena_pure._fftconv_filter_rfft_count == 1 def test_try_load_flash_fft_conv_memoized(): """_try_load_flash_fft_conv probes once and memoizes the result.""" # Reset memo so this test can observe the probe. hyena_pure._flash_fft_conv_cls = None hyena_pure._flash_fft_conv_probed = False r1 = _try_load_flash_fft_conv() assert hyena_pure._flash_fft_conv_probed is True r2 = _try_load_flash_fft_conv() assert r1 is r2, "second probe must return the memoized value" def test_fallback_when_flash_fft_unavailable(monkeypatch): """HYDRA_HYENA_FLASH_FFT=1 + flashfftconv unimportable → pure path. Fallback must be silent (stderr warning but no crash, no behavior change). """ monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1") # Force the probe to record "unavailable" regardless of what's installed. monkeypatch.setattr(hyena_pure, "_flash_fft_conv_cls", None) monkeypatch.setattr(hyena_pure, "_flash_fft_conv_probed", True) torch.manual_seed(1) B, D, L = 1, 8, 16 u = torch.randn(B, D, L) k = torch.randn(D, L) D_bias = torch.randn(D) y = fftconv_ref(u, k, D_bias, gelu=False) assert y.shape == (B, D, L) assert torch.isfinite(y).all() def test_fallback_when_dtype_unsupported(monkeypatch): """fp32 input + env on → falls back even if flashfftconv present.""" monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1") torch.manual_seed(2) B, D, L = 1, 8, 16 u = torch.randn(B, D, L, dtype=torch.float32) k = torch.randn(D, L, dtype=torch.float32) # fp32 is NOT supported D_bias = torch.randn(D) y = fftconv_ref(u, k, D_bias, gelu=False) # Pure path handles fp32 fine. assert y.dtype == torch.float32 assert torch.isfinite(y).all() def test_fallback_when_k_is_higher_rank(monkeypatch): """k.dim()>2 (reverse-filter path) → fall back. HYDRA doesn't use this.""" monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1") torch.manual_seed(3) B, D, L = 1, 8, 16 u = torch.randn(B, D, L) # k shape [C, D, L] — upstream reverse-filter shape; kernel doesn't handle it. k = torch.randn(2, D, L) D_bias = torch.randn(D) # The upstream pure-path handles 3-D k by unsqueeze; we must not fast-path. # Pass k_f=None to force the fall-through. # Reshape to [D, L] so the pure path accepts it for this test. y = fftconv_ref(u, k[0], D_bias, gelu=False) assert y.shape == (B, D, L) def test_vendored_source_tree_intact(): """The vendored flash-fft-conv source files must exist at known paths.""" root = Path(__file__).resolve().parents[1] / "kernels" / "cuda" / "flashfftconv" assert root.exists() assert (root / "LICENSE").exists() assert (root / "UPSTREAM_COMMIT").exists() assert (root / "csrc").exists() assert (root / "csrc" / "setup.py").exists() assert (root / "flashfftconv").exists() assert (root / "flashfftconv" / "conv.py").exists() # LICENSE must be Apache 2.0 (pin — if this drifts, update the vendor). license_text = (root / "LICENSE").read_text() assert "Apache License" in license_text @pytest.mark.skipif( _try_load_flash_fft_conv() is None or not torch.cuda.is_available(), reason="flashfftconv not installed or CUDA unavailable", ) def test_flash_fft_vs_pytorch_fftconv_numeric_equivalence(): """When the kernel IS available, its output must match pure PyTorch within bf16 tolerance. This test only runs on machines with a successful flashfftconv build. See kernels/cuda/flashfftconv/README.md for setup instructions. """ torch.manual_seed(42) B, D, L = 2, 16, 2048 fft_size = 2 * L assert fft_size in _FLASH_FFT_SUPPORTED_SIZES u = torch.randn(B, D, L, device="cuda", dtype=torch.bfloat16) k = torch.randn(D, L, device="cuda", dtype=torch.bfloat16) D_bias = torch.randn(D, device="cuda", dtype=torch.bfloat16) os.environ["HYDRA_HYENA_FLASH_FFT"] = "0" y_pure = fftconv_ref(u, k, D_bias, gelu=False) os.environ["HYDRA_HYENA_FLASH_FFT"] = "1" y_flash = fftconv_ref(u, k, D_bias, gelu=False) max_abs_diff = (y_pure - y_flash).abs().max().item() # bf16 tolerance target from the task spec. assert max_abs_diff < 1e-3, ( f"flash-fft-conv vs pure-PyTorch disagree: |Δ| max = {max_abs_diff:.3e}" ) if __name__ == "__main__": sys.exit(pytest.main([__file__, "-v"]))