# HitPF_demo/tests/test_cpu_attention.py
# Synced from GitHub by github-actions[bot] (commits f6dbbfb, cc0720f).
"""Tests for the CPU-compatible attention patch in src/inference.py."""
import torch
import torch.nn.functional as F
def _cpu_sdpa_under_test(q, k, v, attn_mask=None):
"""Standalone copy of _cpu_sdpa (the patched attention) for testing.
Mirrors the implementation in src.inference._patch_gatr_attention_for_cpu.
"""
B, H, N, D = q.shape
scale = float(D) ** -0.5
q2 = q.reshape(B * H, N, D)
k2 = k.reshape(B * H, N, D)
v2 = v.reshape(B * H, N, D)
attn = torch.bmm(q2 * scale, k2.transpose(1, 2))
if attn_mask is not None:
attn = attn.masked_fill(~attn_mask.unsqueeze(0), float("-inf"))
attn = torch.softmax(attn, dim=-1)
attn = attn.nan_to_num(0.0)
out = torch.bmm(attn, v2)
return out.reshape(B, H, N, D)
def test_cpu_sdpa_matches_reference():
    """The CPU SDPA must agree with PyTorch's reference implementation."""
    torch.manual_seed(42)
    shape = (2, 4, 16, 32)  # (batch, heads, items, head_dim)
    q, k, v = (torch.randn(*shape) for _ in range(3))
    actual = _cpu_sdpa_under_test(q, k, v)
    # PyTorch reference (no mask)
    expected = F.scaled_dot_product_attention(q, k, v)
    assert actual.shape == shape
    assert torch.allclose(actual, expected, atol=1e-5), (
        f"Max diff: {(actual - expected).abs().max().item()}"
    )
def test_cpu_sdpa_output_shape():
    """Output shape must be [B, H, N, D], matching the input convention."""
    dims = (1, 8, 64, 16)
    q, k, v = (torch.randn(*dims) for _ in range(3))
    result = _cpu_sdpa_under_test(q, k, v)
    assert result.shape == dims
def test_cpu_sdpa_single_head():
    """Single-head attention must work correctly."""
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 1, 10, 8) for _ in range(3))
    got = _cpu_sdpa_under_test(q, k, v)
    want = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(got, want, atol=1e-5)
def test_cpu_sdpa_asymmetric_heads_items():
    """Ensure heads and items dimensions are not confused.

    When H != N, swapping them would change the tensor layout and
    produce different (wrong) results.
    """
    torch.manual_seed(123)
    shape = (1, 3, 7, 16)  # heads (3) deliberately differs from items (7)
    q, k, v = (torch.randn(*shape) for _ in range(3))
    actual = _cpu_sdpa_under_test(q, k, v)
    expected = F.scaled_dot_product_attention(q, k, v)
    assert actual.shape == shape
    assert torch.allclose(actual, expected, atol=1e-5), (
        f"Max diff: {(actual - expected).abs().max().item()}"
    )
if __name__ == "__main__":
test_cpu_sdpa_matches_reference()
test_cpu_sdpa_output_shape()
test_cpu_sdpa_single_head()
test_cpu_sdpa_asymmetric_heads_items()
print("All tests passed.")