| """ |
| Test Flash Attention unified interface - verify FA3 and SDPA produce identical results. |
| |
| Run: python -m pytest tests/test_attention_fallback.py -v -s |
| |
| Note on test structure: |
| Tests are split into two classes due to dtype/device constraints: |
| |
| 1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs |
| and verify they produce identical results. These require a Hopper GPU (FA3 only |
| works on sm90+) and use bfloat16 (FA3 doesn't support float32). |
| |
| 2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run |
| on any device (CUDA, CPU, MPS) with the appropriate dtype for that device. |
| """ |
| import torch |
| import pytest |
| import nanochat.flash_attention as fa_module |
| from nanochat.flash_attention import flash_attn, HAS_FA3 |
| from nanochat.engine import KVCache |
|
|
|
|
| def set_impl(impl): |
| """Set the implementation override ('fa3', 'sdpa', or None for auto).""" |
| fa_module._override_impl = impl |
|
|
|
|
| def run_both_impls(fn): |
| """Run a function with both FA3 and SDPA, return both outputs.""" |
| set_impl('fa3') |
| out_fa3 = fn() |
| set_impl('sdpa') |
| out_sdpa = fn() |
| set_impl(None) |
| return out_fa3, out_sdpa |
|
|
|
|
| def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2): |
| """Assert two tensors are close, with helpful error message.""" |
| max_diff = (t1 - t2).abs().max().item() |
| mean_diff = (t1 - t2).abs().mean().item() |
| assert torch.allclose(t1, t2, atol=atol, rtol=rtol), \ |
| f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}" |
| return max_diff, mean_diff |
|
|
|
|
| |
| |
| |
| @pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations") |
| class TestFA3VsSDPA: |
| """Compare FA3 and SDPA produce identical results. Requires Hopper GPU.""" |
|
|
| DEVICE = "cuda" |
| DTYPE = torch.bfloat16 |
|
|
| def test_basic_causal(self): |
| """Basic causal attention.""" |
| B, T, H, D = 2, 64, 4, 32 |
| q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| def run(): |
| return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0)) |
|
|
| y_fa3, y_sdpa = run_both_impls(run) |
| max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal") |
| print(f"basic_causal: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| def test_full_context(self): |
| """Full context (window_size=-1).""" |
| B, T, H, D = 2, 128, 4, 32 |
| q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| def run(): |
| return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1)) |
|
|
| y_fa3, y_sdpa = run_both_impls(run) |
| max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context") |
| print(f"full_context: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| def test_sliding_window(self): |
| """Sliding window attention.""" |
| B, T, H, D = 2, 128, 4, 32 |
| window = 32 |
| q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| def run(): |
| return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0)) |
|
|
| y_fa3, y_sdpa = run_both_impls(run) |
| max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window") |
| print(f"sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| def test_gqa(self): |
| """Group Query Attention (fewer KV heads than Q heads).""" |
| B, T, D = 2, 64, 32 |
| n_heads = 8 |
| n_kv_heads = 2 |
|
|
| q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE) |
| k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE) |
| v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| def run(): |
| return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0)) |
|
|
| y_fa3, y_sdpa = run_both_impls(run) |
| max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa") |
| print(f"gqa: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| def test_larger_model(self): |
| """Larger dimensions closer to real model.""" |
| B, T, H, D = 4, 256, 12, 64 |
| q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| def run(): |
| return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1)) |
|
|
| y_fa3, y_sdpa = run_both_impls(run) |
| max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model") |
| print(f"larger_model: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| def test_kvcache_prefill(self): |
| """Test prefill (inserting multiple tokens into empty cache).""" |
| B, T_max, H, D = 2, 64, 4, 32 |
| T_prefill = 16 |
|
|
| q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| def run(): |
| k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| cache_seqlens = torch.zeros(B, dtype=torch.int32, device=self.DEVICE) |
| return flash_attn.flash_attn_with_kvcache( |
| q, k_cache, v_cache, k=k, v=v, |
| cache_seqlens=cache_seqlens, |
| causal=True, window_size=(T_max, 0) |
| ) |
|
|
| y_fa3, y_sdpa = run_both_impls(run) |
| max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "prefill") |
| print(f"prefill: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| def test_kvcache_single_token(self): |
| """Test single token generation (cache already has content).""" |
| B, T_max, H, D = 2, 64, 4, 32 |
| T_prefill = 16 |
|
|
| k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| def run(): |
| k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k_cache[:, :T_prefill, :, :] = k_init |
| v_cache[:, :T_prefill, :, :] = v_init |
| cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE) |
| return flash_attn.flash_attn_with_kvcache( |
| q_single, k_cache, v_cache, k=k_single, v=v_single, |
| cache_seqlens=cache_seqlens, |
| causal=True, window_size=(T_max, 0) |
| ) |
|
|
| y_fa3, y_sdpa = run_both_impls(run) |
| max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token") |
| print(f"single_token: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| def test_kvcache_single_token_sliding_window(self): |
| """Test single token decode with sliding window smaller than cache size. |
| |
| This catches the bug where SDPA ignores window_size during Tq=1 decode. |
| When window < Tk, FA3 only attends to the last (window+1) tokens, |
| but SDPA was attending to all cached tokens. |
| """ |
| B, T_max, H, D = 2, 64, 4, 32 |
| T_prefill = 32 |
| window = 8 |
|
|
| k_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v_init = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| def run(): |
| k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k_cache[:, :T_prefill, :, :] = k_init |
| v_cache[:, :T_prefill, :, :] = v_init |
| cache_seqlens = torch.full((B,), T_prefill, dtype=torch.int32, device=self.DEVICE) |
| return flash_attn.flash_attn_with_kvcache( |
| q_single, k_cache, v_cache, k=k_single, v=v_single, |
| cache_seqlens=cache_seqlens, |
| causal=True, window_size=(window, 0) |
| ) |
|
|
| y_fa3, y_sdpa = run_both_impls(run) |
| max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "single_token_sliding_window") |
| print(f"single_token_sliding_window: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| def test_backward_gradients_match(self): |
| """Verify gradients are similar between FA3 and SDPA.""" |
| B, T, H, D = 2, 32, 4, 16 |
|
|
| q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| def run(): |
| q = q_data.clone().requires_grad_(True) |
| k = k_data.clone().requires_grad_(True) |
| v = v_data.clone().requires_grad_(True) |
| y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0)) |
| loss = y.sum() |
| loss.backward() |
| return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach() |
|
|
| set_impl('fa3') |
| y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run() |
| set_impl('sdpa') |
| y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run() |
| set_impl(None) |
|
|
| max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "backward_output") |
| print(f"backward_output: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| max_diff, mean_diff = assert_close(q_grad_fa3, q_grad_sdpa, "q_grad", atol=0.05, rtol=0.05) |
| print(f"q_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| max_diff, mean_diff = assert_close(k_grad_fa3, k_grad_sdpa, "k_grad", atol=0.05, rtol=0.05) |
| print(f"k_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
| max_diff, mean_diff = assert_close(v_grad_fa3, v_grad_sdpa, "v_grad", atol=0.05, rtol=0.05) |
| print(f"v_grad: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") |
|
|
|
|
| |
| |
| |
| class TestSDPAOnly: |
| """Test SDPA fallback works correctly. Runs on any device.""" |
|
|
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" |
| DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32 |
|
|
| def test_basic_forward(self): |
| """Test SDPA forward pass produces valid output.""" |
| set_impl('sdpa') |
| B, T, H, D = 2, 64, 4, 32 |
| q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0)) |
|
|
| assert y.shape == (B, T, H, D) |
| assert not torch.isnan(y).any(), "Output contains NaN" |
| set_impl(None) |
|
|
| def test_backward(self): |
| """Test gradients flow through SDPA.""" |
| set_impl('sdpa') |
| B, T, H, D = 2, 32, 4, 16 |
| q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True) |
| k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True) |
| v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True) |
|
|
| y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0)) |
| loss = y.sum() |
| loss.backward() |
|
|
| assert q.grad is not None, "No gradient for q" |
| assert k.grad is not None, "No gradient for k" |
| assert v.grad is not None, "No gradient for v" |
| assert not torch.isnan(q.grad).any(), "NaN in q gradient" |
| set_impl(None) |
|
|
| def test_kvcache(self): |
| """Test SDPA with KV cache.""" |
| set_impl('sdpa') |
| B, T_max, H, D = 2, 64, 4, 32 |
| n_layers = 1 |
|
|
| cache = KVCache( |
| batch_size=B, num_heads=H, seq_len=T_max, head_dim=D, |
| num_layers=n_layers, device=self.DEVICE, dtype=self.DTYPE |
| ) |
| k_cache, v_cache = cache.get_layer_cache(0) |
|
|
| |
| T_prefill = 16 |
| q = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v = torch.randn(B, T_prefill, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| y = flash_attn.flash_attn_with_kvcache( |
| q, k_cache, v_cache, k=k, v=v, |
| cache_seqlens=cache.cache_seqlens, |
| causal=True, window_size=(T_max, 0) |
| ) |
| cache.advance(T_prefill) |
|
|
| assert y.shape == (B, T_prefill, H, D) |
| assert cache.get_pos() == T_prefill |
|
|
| |
| q_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| k_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE) |
| v_single = torch.randn(B, 1, H, D, device=self.DEVICE, dtype=self.DTYPE) |
|
|
| y_single = flash_attn.flash_attn_with_kvcache( |
| q_single, k_cache, v_cache, k=k_single, v=v_single, |
| cache_seqlens=cache.cache_seqlens, |
| causal=True, window_size=(T_max, 0) |
| ) |
| cache.advance(1) |
|
|
| assert y_single.shape == (B, 1, H, D) |
| assert cache.get_pos() == T_prefill + 1 |
| set_impl(None) |
|
|
|
|
| |
| |
| |
| class TestOverrideMechanism: |
| """Test that the override mechanism works correctly.""" |
|
|
| @pytest.mark.skipif(not HAS_FA3, reason="FA3 required") |
| def test_override_fa3(self): |
| """Test that override='fa3' uses FA3.""" |
| set_impl('fa3') |
| assert fa_module._use_fa3() == True |
| set_impl(None) |
|
|
| def test_override_sdpa(self): |
| """Test that override='sdpa' uses SDPA.""" |
| set_impl('sdpa') |
| assert fa_module._use_fa3() == False |
| set_impl(None) |
|
|
| def test_override_auto(self): |
| """Test that override=None uses auto-detection.""" |
| set_impl(None) |
| assert fa_module._use_fa3() == HAS_FA3 |
|
|
|
|
| if __name__ == "__main__": |
| print(f"PyTorch version: {torch.__version__}") |
| print(f"CUDA available: {torch.cuda.is_available()}") |
| if torch.cuda.is_available(): |
| print(f"CUDA device: {torch.cuda.get_device_name()}") |
| major, minor = torch.cuda.get_device_capability() |
| print(f"Compute capability: {major}.{minor}") |
| print(f"HAS_FA3: {HAS_FA3}") |
| print() |
|
|
| pytest.main([__file__, "-v", "-s"]) |
|
|