import pytest
import torch

from kernels.sparse_attn import sparse_vision_attn


def _make_inputs(B, N_vis, N_text, D, K, device):
    """Build seeded random inputs for ``sparse_vision_attn``.

    Args:
        B: Batch size.
        N_vis: Number of rows in each ``patch`` sample.
        N_text: Number of rows in each ``text`` sample.
        D: Feature dimension shared by ``patch`` and ``text``.
        K: Number of indices kept per batch element (drawn without
            replacement from ``range(N_vis)``).
        device: Torch device on which all tensors are created.

    Returns:
        A ``(patch, text, kept)`` tuple with shapes ``(B, N_vis, D)``,
        ``(B, N_text, D)`` and ``(B, K)`` respectively.
    """
    # Fixed seed so every test sees the same data for a given shape/device.
    torch.manual_seed(42)
    patch = torch.randn(B, N_vis, D, device=device)
    text = torch.randn(B, N_text, D, device=device)
    # One independent permutation per batch element, truncated to K entries.
    kept = torch.stack(
        [torch.randperm(N_vis, device=device)[:K] for _ in range(B)]
    )
    return patch, text, kept


def test_matches_dense(device):
    """Sparse output equals the dense scaled scores gathered at ``kept``.

    NOTE(review): ``device`` is a fixture defined outside this file
    (presumably in a conftest) - confirm.
    """
    B, N_vis, N_text, D, K = 4, 196, 77, 768, 80
    patch, text, kept = _make_inputs(B, N_vis, N_text, D, K, device)

    # Dense reference: every patch row against every text row, scaled.
    scale = D ** -0.5
    dense_out = torch.bmm(patch, text.transpose(1, 2)) * scale

    sparse_out = sparse_vision_attn(patch, text, kept, use_triton=False)

    # Select the kept rows of the dense reference along dim 1.
    idx = kept.unsqueeze(-1).expand(B, K, N_text)
    dense_at_kept = torch.gather(dense_out, 1, idx)

    # Check the shape first so broadcasting in the subtraction below cannot
    # mask a shape mismatch.
    assert sparse_out.shape == dense_at_kept.shape
    max_err = (dense_at_kept - sparse_out).abs().max().item()
    assert max_err < 1e-4, f"max abs error {max_err} exceeds 1e-4"


def test_output_shape(device):
    """Output has shape ``(B, K, N_text)``."""
    B, N_vis, N_text, D, K = 2, 196, 77, 768, 64
    patch, text, kept = _make_inputs(B, N_vis, N_text, D, K, device)
    out = sparse_vision_attn(patch, text, kept, use_triton=False)
    assert out.shape == (B, K, N_text)


def test_high_compression(device):
    """Keeping 22% of 576 rows yields the right shape and no NaNs."""
    B, N_vis, N_text, D = 4, 576, 77, 1024
    K = int(N_vis * 0.22)
    patch, text, kept = _make_inputs(B, N_vis, N_text, D, K, device)
    out = sparse_vision_attn(patch, text, kept, use_triton=False)
    assert out.shape == (B, K, N_text)
    assert not torch.isnan(out).any()


def test_cpu_fallback():
    """The ``use_triton=False`` call works on CPU tensors.

    Uses the shared seeded helper so the inputs are reproducible.
    """
    B, N_vis, N_text, D, K = 2, 64, 32, 128, 20
    patch, text, kept = _make_inputs(B, N_vis, N_text, D, K, "cpu")
    out = sparse_vision_attn(patch, text, kept, use_triton=False)
    assert out.shape == (B, K, N_text)