Spaces:
Runtime error
Runtime error
| """Flash-FFT-conv integration: opt-in fast path, graceful fallback. | |
| **What this validates:** | |
| * When `flashfftconv` is NOT importable, `fftconv_ref` falls back silently | |
| to the pure-PyTorch path regardless of env-var value. | |
| * `HYDRA_HYENA_FLASH_FFT=0` (default) always uses the pure path. | |
| * The env-var gate + import-probe gate are independent; both must pass for | |
| the fast path to activate. | |
| * The vendored source tree is present and structurally sane (csrc/, | |
| flashfftconv/, LICENSE) so offline builds remain possible. | |
| Numeric equivalence between the CUDA kernel and the pure path is validated | |
| separately when flashfftconv is actually built — that requires a specific | |
| GPU arch match and is run manually (see | |
| `test_flash_fft_vs_pytorch_fftconv_numeric_equivalence`). | |
| Run: | |
| cd /home/mikeb/work/feather | |
| .venv/bin/pytest tests/test_flash_fft_integration.py -v | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import pytest | |
| import torch | |
| sys.path.insert(0, str(Path(__file__).resolve().parents[1])) | |
| from subsystems import hyena_pure # noqa: E402 | |
| from subsystems.hyena_pure import ( # noqa: E402 | |
| _FLASH_FFT_SUPPORTED_SIZES, | |
| _flash_fft_conv_supported, | |
| _try_load_flash_fft_conv, | |
| fftconv_ref, | |
| ) | |
def test_flash_fft_conv_supported_matrix():
    """Check `_flash_fft_conv_supported` against a small (size, dtype) table.

    Each row is ``(size, dtype, expected)``. The helper must return exactly
    that boolean object (identity comparison, not mere truthiness).
    """
    expectations = (
        # 16-bit dtypes at an on-grid size are accepted.
        (4096, torch.bfloat16, True),
        (4096, torch.float16, True),
        # fp32 is rejected (kernel requires 16-bit input).
        (4096, torch.float32, False),
        # Non-power-of-2 / off-grid size.
        (4000, torch.bfloat16, False),
        # Very large size that is not in the supported set.
        (2**24, torch.bfloat16, False),
    )
    for size, dtype, expected in expectations:
        assert _flash_fft_conv_supported(size, dtype) is expected
def test_flash_fft_supported_set_matches_expected():
    """Every fft_size HYDRA may reach must be in the supported set.

    HYDRA's Hyena uses fft_size = 2 * sequence_len. Sequence lengths in
    practice are 512, 1024, 2048 and 4096, which gives fft sizes
    1024, 2048, 4096 and 8192. All of them must be supported.
    """
    required_fft_sizes = (1024, 2048, 4096, 8192)
    for fft_size in required_fft_sizes:
        seq_len = fft_size // 2
        assert fft_size in _FLASH_FFT_SUPPORTED_SIZES, (
            f"fft_size {fft_size} must be supported for HYDRA sequence length "
            f"{seq_len}"
        )
def test_pure_path_used_when_env_off(monkeypatch):
    """HYDRA_HYENA_FLASH_FFT unset -> `fftconv_ref` takes the pure PyTorch path.

    The pure path is detected through the module-level counter
    `hyena_pure._fftconv_filter_rfft_count`: it is zeroed before the call and
    must read exactly 1 afterwards (one filter rfft, because no precomputed
    `k_f` is passed).
    """
    monkeypatch.delenv("HYDRA_HYENA_FLASH_FFT", raising=False)

    torch.manual_seed(0)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen)
    filt = torch.randn(channels, seqlen)
    bias = torch.randn(channels)

    # Reset the invocation counter so the assertion below sees only this call.
    hyena_pure._fftconv_filter_rfft_count = 0
    out = fftconv_ref(signal, filt, bias, gelu=False)

    assert out.shape == (batch, channels, seqlen)
    assert hyena_pure._fftconv_filter_rfft_count == 1
def test_try_load_flash_fft_conv_memoized():
    """`_try_load_flash_fft_conv` probes once and then returns the cached result.

    The memo globals on `hyena_pure` are cleared first so the first call has
    to perform a real probe; the second call must hand back the very same
    object.
    """
    # Clear the memo so this test can observe the probe happening.
    hyena_pure._flash_fft_conv_probed = False
    hyena_pure._flash_fft_conv_cls = None

    first = _try_load_flash_fft_conv()
    assert hyena_pure._flash_fft_conv_probed is True

    second = _try_load_flash_fft_conv()
    assert first is second, "second probe must return the memoized value"
def test_fallback_when_flash_fft_unavailable(monkeypatch):
    """HYDRA_HYENA_FLASH_FFT=1 with flashfftconv unavailable -> pure path.

    The probe memo is pinned to "already probed, nothing found", so the call
    must complete without raising and produce a finite result of the input's
    shape.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")
    # Pin the memo to "unavailable" regardless of what is actually installed.
    monkeypatch.setattr(hyena_pure, "_flash_fft_conv_cls", None)
    monkeypatch.setattr(hyena_pure, "_flash_fft_conv_probed", True)

    torch.manual_seed(1)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen)
    filt = torch.randn(channels, seqlen)
    bias = torch.randn(channels)

    out = fftconv_ref(signal, filt, bias, gelu=False)

    assert out.shape == (batch, channels, seqlen)
    assert torch.isfinite(out).all()
def test_fallback_when_dtype_unsupported(monkeypatch):
    """fp32 input with the env flag on -> result is still fp32 and finite.

    fp32 is not a dtype the fast path supports, so the call is expected to be
    served by the pure path even when flashfftconv is present.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")

    torch.manual_seed(2)
    batch, channels, seqlen = 1, 8, 16
    fp32 = torch.float32  # deliberately NOT a supported fast-path dtype
    signal = torch.randn(batch, channels, seqlen, dtype=fp32)
    filt = torch.randn(channels, seqlen, dtype=fp32)
    bias = torch.randn(channels)

    out = fftconv_ref(signal, filt, bias, gelu=False)

    # The pure path handles fp32 fine.
    assert out.dtype == torch.float32
    assert torch.isfinite(out).all()
def test_fallback_when_k_is_higher_rank(monkeypatch):
    """Env flag on, filter drawn from a 3-D tensor -> output keeps input shape.

    A [C, D, L] filter (the upstream reverse-filter shape) is created, but
    only its first [D, L] slice is handed to `fftconv_ref`.

    NOTE(review): because the call receives a 2-D slice, this test does not
    itself pass a k with dim() > 2 to `fftconv_ref`; confirm whether the
    higher-rank fallback is covered elsewhere.
    """
    monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")

    torch.manual_seed(3)
    batch, channels, seqlen = 1, 8, 16
    signal = torch.randn(batch, channels, seqlen)
    # [C, D, L]: upstream reverse-filter layout.
    stacked_filters = torch.randn(2, channels, seqlen)
    bias = torch.randn(channels)

    # Take the first [D, L] filter so the pure path accepts it for this test.
    first_filter = stacked_filters[0]
    out = fftconv_ref(signal, first_filter, bias, gelu=False)

    assert out.shape == (batch, channels, seqlen)
def test_vendored_source_tree_intact():
    """The vendored flash-fft-conv source files must exist at known paths.

    The tree is looked up at ``<repo>/kernels/cuda/flashfftconv`` where
    ``<repo>`` is the parent of this file's directory. Each required entry is
    checked individually so a failure names the missing path. The LICENSE is
    read as UTF-8 explicitly so the result does not depend on the platform's
    default locale encoding.
    """
    root = Path(__file__).resolve().parents[1] / "kernels" / "cuda" / "flashfftconv"
    assert root.exists(), f"vendored flash-fft-conv tree missing: {root}"

    required_entries = (
        Path("LICENSE"),
        Path("UPSTREAM_COMMIT"),
        Path("csrc"),
        Path("csrc") / "setup.py",
        Path("flashfftconv"),
        Path("flashfftconv") / "conv.py",
    )
    for relative in required_entries:
        candidate = root / relative
        assert candidate.exists(), f"vendored flash-fft-conv entry missing: {candidate}"

    # LICENSE must be Apache 2.0 (pin: if this drifts, update the vendor).
    license_text = (root / "LICENSE").read_text(encoding="utf-8")
    assert "Apache License" in license_text
def test_flash_fft_vs_pytorch_fftconv_numeric_equivalence():
    """When the kernel IS available, its output must match pure PyTorch
    within bf16 tolerance.

    The test is skipped when CUDA is not available or when the flashfftconv
    probe reports nothing loadable, since there is then no kernel to compare
    against. See kernels/cuda/flashfftconv/README.md for setup instructions.

    HYDRA_HYENA_FLASH_FFT is toggled around the two `fftconv_ref` calls and
    restored to its previous value afterwards, so the setting does not leak
    into other tests running in the same process.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available; flash-fft-conv kernel cannot run")
    # NOTE(review): assumes the probe returns None when flashfftconv is not
    # importable (the memo is reset to None elsewhere in this file) - confirm.
    if _try_load_flash_fft_conv() is None:
        pytest.skip("flashfftconv is not importable; nothing to compare against")

    torch.manual_seed(42)
    B, D, L = 2, 16, 2048
    fft_size = 2 * L
    assert fft_size in _FLASH_FFT_SUPPORTED_SIZES

    u = torch.randn(B, D, L, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(D, L, device="cuda", dtype=torch.bfloat16)
    D_bias = torch.randn(D, device="cuda", dtype=torch.bfloat16)

    env_key = "HYDRA_HYENA_FLASH_FFT"
    previous = os.environ.get(env_key)
    try:
        os.environ[env_key] = "0"
        y_pure = fftconv_ref(u, k, D_bias, gelu=False)
        os.environ[env_key] = "1"
        y_flash = fftconv_ref(u, k, D_bias, gelu=False)
    finally:
        # Put the process environment back exactly as it was found.
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous

    max_abs_diff = (y_pure - y_flash).abs().max().item()
    # bf16 tolerance target from the task spec.
    assert max_abs_diff < 1e-3, (
        f"flash-fft-conv vs pure-PyTorch disagree: |Δ| max = {max_abs_diff:.3e}"
    )
if __name__ == "__main__":
    # Direct execution: run this file under pytest in verbose mode and use
    # pytest's return value as the process exit status.
    exit_status = pytest.main([__file__, "-v"])
    sys.exit(exit_status)