feather-a10g-large-runtime / overlay /tests /test_flash_fft_integration.py
icarus112's picture
Update Feather a10g-large training runtime image
c475135 verified
"""Flash-FFT-conv integration: opt-in fast path, graceful fallback.
**What this validates:**
* When `flashfftconv` is NOT importable, `fftconv_ref` falls back silently
to the pure-PyTorch path regardless of env-var value.
* `HYDRA_HYENA_FLASH_FFT=0` (default) always uses the pure path.
* The env-var gate + import-probe gate are independent; both must pass for
the fast path to activate.
* The vendored source tree is present and structurally sane (csrc/,
flashfftconv/, LICENSE) so offline builds remain possible.
Numeric equivalence between the CUDA kernel and the pure path is validated
separately when flashfftconv is actually built β€” that requires a specific
GPU arch match and is run manually (see `test_flash_fft_vs_pytorch_fftconv`).
Run:
cd /home/mikeb/work/feather
.venv/bin/pytest tests/test_flash_fft_integration.py -v
"""
from __future__ import annotations
import os
import sys
from pathlib import Path
import pytest
import torch
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from subsystems import hyena_pure # noqa: E402
from subsystems.hyena_pure import ( # noqa: E402
_FLASH_FFT_SUPPORTED_SIZES,
_flash_fft_conv_supported,
_try_load_flash_fft_conv,
fftconv_ref,
)
def test_flash_fft_conv_supported_matrix():
"""Supported seqlens are the specific power-of-2 grid the kernel handles."""
assert _flash_fft_conv_supported(4096, torch.bfloat16) is True
assert _flash_fft_conv_supported(4096, torch.float16) is True
# fp32 not supported (kernel requires 16-bit input).
assert _flash_fft_conv_supported(4096, torch.float32) is False
# Non-power-of-2 / off-grid.
assert _flash_fft_conv_supported(4000, torch.bfloat16) is False
# Very large β€” not in set.
assert _flash_fft_conv_supported(2**24, torch.bfloat16) is False
def test_flash_fft_supported_set_matches_expected():
"""The supported set must include every fft_size HYDRA may reach.
HYDRA's Hyena uses fft_size = 2 * sequence_len. Sequence lengths in
practice: 512, 1024, 2048, 4096. β†’ fft sizes 1024, 2048, 4096, 8192.
All must be in the supported set.
"""
for s in (1024, 2048, 4096, 8192):
assert s in _FLASH_FFT_SUPPORTED_SIZES, (
f"fft_size {s} must be supported for HYDRA sequence length "
f"{s // 2}"
)
def test_pure_path_used_when_env_off(monkeypatch):
"""HYDRA_HYENA_FLASH_FFT=0 (or unset) β†’ pure PyTorch path."""
monkeypatch.delenv("HYDRA_HYENA_FLASH_FFT", raising=False)
torch.manual_seed(0)
B, D, L = 1, 8, 16
u = torch.randn(B, D, L)
k = torch.randn(D, L)
D_bias = torch.randn(D)
# Count filter rfft invocations β€” the pure path calls it once when k_f is None.
hyena_pure._fftconv_filter_rfft_count = 0
y = fftconv_ref(u, k, D_bias, gelu=False)
assert y.shape == (B, D, L)
# Pure path: exactly one filter rfft (k_f was None).
assert hyena_pure._fftconv_filter_rfft_count == 1
def test_try_load_flash_fft_conv_memoized():
"""_try_load_flash_fft_conv probes once and memoizes the result."""
# Reset memo so this test can observe the probe.
hyena_pure._flash_fft_conv_cls = None
hyena_pure._flash_fft_conv_probed = False
r1 = _try_load_flash_fft_conv()
assert hyena_pure._flash_fft_conv_probed is True
r2 = _try_load_flash_fft_conv()
assert r1 is r2, "second probe must return the memoized value"
def test_fallback_when_flash_fft_unavailable(monkeypatch):
"""HYDRA_HYENA_FLASH_FFT=1 + flashfftconv unimportable β†’ pure path.
Fallback must be silent (stderr warning but no crash, no behavior change).
"""
monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")
# Force the probe to record "unavailable" regardless of what's installed.
monkeypatch.setattr(hyena_pure, "_flash_fft_conv_cls", None)
monkeypatch.setattr(hyena_pure, "_flash_fft_conv_probed", True)
torch.manual_seed(1)
B, D, L = 1, 8, 16
u = torch.randn(B, D, L)
k = torch.randn(D, L)
D_bias = torch.randn(D)
y = fftconv_ref(u, k, D_bias, gelu=False)
assert y.shape == (B, D, L)
assert torch.isfinite(y).all()
def test_fallback_when_dtype_unsupported(monkeypatch):
"""fp32 input + env on β†’ falls back even if flashfftconv present."""
monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")
torch.manual_seed(2)
B, D, L = 1, 8, 16
u = torch.randn(B, D, L, dtype=torch.float32)
k = torch.randn(D, L, dtype=torch.float32) # fp32 is NOT supported
D_bias = torch.randn(D)
y = fftconv_ref(u, k, D_bias, gelu=False)
# Pure path handles fp32 fine.
assert y.dtype == torch.float32
assert torch.isfinite(y).all()
def test_fallback_when_k_is_higher_rank(monkeypatch):
"""k.dim()>2 (reverse-filter path) β†’ fall back. HYDRA doesn't use this."""
monkeypatch.setenv("HYDRA_HYENA_FLASH_FFT", "1")
torch.manual_seed(3)
B, D, L = 1, 8, 16
u = torch.randn(B, D, L)
# k shape [C, D, L] β€” upstream reverse-filter shape; kernel doesn't handle it.
k = torch.randn(2, D, L)
D_bias = torch.randn(D)
# The upstream pure-path handles 3-D k by unsqueeze; we must not fast-path.
# Pass k_f=None to force the fall-through.
# Reshape to [D, L] so the pure path accepts it for this test.
y = fftconv_ref(u, k[0], D_bias, gelu=False)
assert y.shape == (B, D, L)
def test_vendored_source_tree_intact():
"""The vendored flash-fft-conv source files must exist at known paths."""
root = Path(__file__).resolve().parents[1] / "kernels" / "cuda" / "flashfftconv"
assert root.exists()
assert (root / "LICENSE").exists()
assert (root / "UPSTREAM_COMMIT").exists()
assert (root / "csrc").exists()
assert (root / "csrc" / "setup.py").exists()
assert (root / "flashfftconv").exists()
assert (root / "flashfftconv" / "conv.py").exists()
# LICENSE must be Apache 2.0 (pin β€” if this drifts, update the vendor).
license_text = (root / "LICENSE").read_text()
assert "Apache License" in license_text
@pytest.mark.skipif(
_try_load_flash_fft_conv() is None or not torch.cuda.is_available(),
reason="flashfftconv not installed or CUDA unavailable",
)
def test_flash_fft_vs_pytorch_fftconv_numeric_equivalence():
"""When the kernel IS available, its output must match pure PyTorch
within bf16 tolerance.
This test only runs on machines with a successful flashfftconv build.
See kernels/cuda/flashfftconv/README.md for setup instructions.
"""
torch.manual_seed(42)
B, D, L = 2, 16, 2048
fft_size = 2 * L
assert fft_size in _FLASH_FFT_SUPPORTED_SIZES
u = torch.randn(B, D, L, device="cuda", dtype=torch.bfloat16)
k = torch.randn(D, L, device="cuda", dtype=torch.bfloat16)
D_bias = torch.randn(D, device="cuda", dtype=torch.bfloat16)
os.environ["HYDRA_HYENA_FLASH_FFT"] = "0"
y_pure = fftconv_ref(u, k, D_bias, gelu=False)
os.environ["HYDRA_HYENA_FLASH_FFT"] = "1"
y_flash = fftconv_ref(u, k, D_bias, gelu=False)
max_abs_diff = (y_pure - y_flash).abs().max().item()
# bf16 tolerance target from the task spec.
assert max_abs_diff < 1e-3, (
f"flash-fft-conv vs pure-PyTorch disagree: |Ξ”| max = {max_abs_diff:.3e}"
)
if __name__ == "__main__":
sys.exit(pytest.main([__file__, "-v"]))