| import pytest |
| import torch |
| from kernels.sparse_attn import sparse_vision_attn |
|
|
|
|
def _make_inputs(B, N_vis, N_text, D, K, device):
    """Build seeded random inputs for the ``sparse_vision_attn`` tests.

    Args:
        B: batch size.
        N_vis: number of patch tokens per batch element.
        N_text: number of text tokens per batch element.
        D: feature dimension shared by patch and text tokens.
        K: number of patch indices to keep per batch element.
        device: torch device (or device string) to allocate on.

    Returns:
        ``(patch, text, kept)`` where ``patch`` has shape ``(B, N_vis, D)``,
        ``text`` has shape ``(B, N_text, D)`` and ``kept`` is a ``(B, K)``
        int64 tensor of distinct indices into the ``N_vis`` axis, drawn
        independently for each batch row.

    Raises:
        ValueError: if ``K`` is outside ``[0, N_vis]``. Without this check
            ``randperm(N_vis)[:K]`` would silently yield a different number
            of indices than requested.
    """
    if not 0 <= K <= N_vis:
        raise ValueError(f"K must be in [0, N_vis={N_vis}], got K={K}")
    # Reseed the global RNG so every test that uses this helper sees the
    # same inputs.
    torch.manual_seed(42)
    patch = torch.randn(B, N_vis, D, device=device)
    text = torch.randn(B, N_text, D, device=device)
    # One independent permutation per batch row, so the indices within a
    # row are distinct.
    kept = torch.stack(
        [torch.randperm(N_vis, device=device)[:K] for _ in range(B)]
    )
    return patch, text, kept
|
|
|
|
def test_matches_dense(device):
    """Sparse output must equal the dense scaled scores at the kept rows.

    The reference computes ``patch @ text^T * D**-0.5`` for every patch
    token, then gathers the rows listed in ``kept``. The non-Triton sparse
    path must reproduce exactly those rows. ``device`` is a pytest fixture
    defined outside this file (presumably in conftest.py).
    """
    B, N_vis, N_text, D, K = 4, 196, 77, 768, 80
    patch, text, kept = _make_inputs(B, N_vis, N_text, D, K, device)

    # Dense reference: (B, N_vis, D) @ (B, D, N_text) -> (B, N_vis, N_text).
    scale = D ** -0.5
    dense_out = torch.bmm(patch, text.transpose(1, 2)) * scale
    sparse_out = sparse_vision_attn(patch, text, kept, use_triton=False)

    # dense_at_kept[b, k, t] == dense_out[b, kept[b, k], t]
    idx = kept.unsqueeze(-1).expand(B, K, N_text)
    dense_at_kept = torch.gather(dense_out, 1, idx)

    # assert_close requires matching shapes (no silent broadcasting) and
    # reports the worst mismatch on failure. rtol=0 keeps the purely
    # absolute 1e-4 tolerance.
    torch.testing.assert_close(
        sparse_out, dense_at_kept, rtol=0.0, atol=1e-4, check_dtype=False
    )
|
|
|
|
def test_output_shape(device):
    """Non-Triton sparse output must have shape (B, K, N_text).

    ``device`` is a pytest fixture defined outside this file.
    """
    batch, n_vis, n_text, dim, n_keep = 2, 196, 77, 768, 64
    expected_shape = (batch, n_keep, n_text)

    patch, text, kept = _make_inputs(batch, n_vis, n_text, dim, n_keep, device)
    result = sparse_vision_attn(patch, text, kept, use_triton=False)

    assert result.shape == expected_shape
|
|
|
|
def test_high_compression(device):
    """Keep only ~22% of 576 patch tokens and check the output is well-formed.

    Asserts the ``(B, K, N_text)`` shape and that every output value is
    finite. ``device`` is a pytest fixture defined outside this file.
    """
    B, N_vis, N_text, D = 4, 576, 77, 1024
    K = int(N_vis * 0.22)  # 126 kept tokens
    patch, text, kept = _make_inputs(B, N_vis, N_text, D, K, device)
    out = sparse_vision_attn(patch, text, kept, use_triton=False)
    assert out.shape == (B, K, N_text)
    # isfinite rejects +/-inf as well as NaN, so overflow is caught too.
    assert torch.isfinite(out).all()
|
|
|
|
def test_cpu_fallback():
    """The non-Triton path runs on plain CPU tensors.

    Uses the shared seeded helper so the inputs are reproducible across
    runs, then checks the result shape and that it stays on the CPU.
    """
    B, N_vis, N_text, D, K = 2, 64, 32, 128, 20
    patch, text, kept = _make_inputs(B, N_vis, N_text, D, K, "cpu")
    out = sparse_vision_attn(patch, text, kept, use_triton=False)
    assert out.shape == (B, K, N_text)
    assert out.device.type == "cpu"
|
|