Spaces:
Sleeping
Sleeping
| """Tests for the CPU-compatible attention patch in src/inference.py.""" | |
| import torch | |
| import torch.nn.functional as F | |
| def _cpu_sdpa_under_test(q, k, v, attn_mask=None): | |
| """Standalone copy of _cpu_sdpa (the patched attention) for testing. | |
| Mirrors the implementation in src.inference._patch_gatr_attention_for_cpu. | |
| """ | |
| B, H, N, D = q.shape | |
| scale = float(D) ** -0.5 | |
| q2 = q.reshape(B * H, N, D) | |
| k2 = k.reshape(B * H, N, D) | |
| v2 = v.reshape(B * H, N, D) | |
| attn = torch.bmm(q2 * scale, k2.transpose(1, 2)) | |
| if attn_mask is not None: | |
| attn = attn.masked_fill(~attn_mask.unsqueeze(0), float("-inf")) | |
| attn = torch.softmax(attn, dim=-1) | |
| attn = attn.nan_to_num(0.0) | |
| out = torch.bmm(attn, v2) | |
| return out.reshape(B, H, N, D) | |
def test_cpu_sdpa_matches_reference():
    """The CPU SDPA must agree with PyTorch's reference implementation."""
    torch.manual_seed(42)
    shape = (2, 4, 16, 32)  # (B, H, N, D)
    q, k, v = (torch.randn(shape) for _ in range(3))
    actual = _cpu_sdpa_under_test(q, k, v)
    # PyTorch reference (no mask)
    expected = F.scaled_dot_product_attention(q, k, v)
    assert actual.shape == shape
    max_diff = (actual - expected).abs().max().item()
    assert torch.allclose(actual, expected, atol=1e-5), f"Max diff: {max_diff}"
def test_cpu_sdpa_output_shape():
    """Output shape must be [B, H, N, D], matching the input convention."""
    dims = (1, 8, 64, 16)  # (B, H, N, D)
    q = torch.randn(dims)
    k = torch.randn(dims)
    v = torch.randn(dims)
    assert _cpu_sdpa_under_test(q, k, v).shape == dims
def test_cpu_sdpa_single_head():
    """Single-head attention must work correctly."""
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 1, 10, 8) for _ in range(3))  # H == 1
    assert torch.allclose(
        _cpu_sdpa_under_test(q, k, v),
        F.scaled_dot_product_attention(q, k, v),
        atol=1e-5,
    )
def test_cpu_sdpa_asymmetric_heads_items():
    """Ensure heads and items dimensions are not confused.

    When H != N, swapping them would change the tensor layout and
    produce different (wrong) results.
    """
    torch.manual_seed(123)
    shape = (1, 3, 7, 16)  # H != N
    q, k, v = (torch.randn(shape) for _ in range(3))
    got = _cpu_sdpa_under_test(q, k, v)
    want = F.scaled_dot_product_attention(q, k, v)
    assert got.shape == shape
    assert torch.allclose(got, want, atol=1e-5), (
        f"Max diff: {(got - want).abs().max().item()}"
    )
if __name__ == "__main__":
    # Allow running this file directly without a pytest runner.
    for check in (
        test_cpu_sdpa_matches_reference,
        test_cpu_sdpa_output_shape,
        test_cpu_sdpa_single_head,
        test_cpu_sdpa_asymmetric_heads_items,
    ):
        check()
    print("All tests passed.")