# SparseVLM/tests/test_sparse_attn.py
# Uploaded by Aryan3108 via huggingface_hub (commit 176b11a).
import pytest
import torch
from kernels.sparse_attn import sparse_vision_attn
def _make_inputs(B, N_vis, N_text, D, K, device):
    """Build a reproducible (patch, text, kept) triple for the tests.

    Reseeds the global torch RNG with 42 on every call, so repeated calls with
    the same arguments return identical tensors.

    Returns:
        patch: standard-normal tensor of shape (B, N_vis, D).
        text:  standard-normal tensor of shape (B, N_text, D).
        kept:  int64 tensor of shape (B, K). Each row holds K distinct patch
               indices in [0, N_vis), drawn independently per batch element.
               If K > N_vis the slice is clamped and rows have N_vis entries.
    """
    torch.manual_seed(42)
    # Draw order (patch, then text, then one permutation per row) fixes the
    # exact values produced for a given seed.
    patch_feats = torch.randn(B, N_vis, D, device=device)
    text_feats = torch.randn(B, N_text, D, device=device)

    index_rows = []
    for _ in range(B):
        shuffled = torch.randperm(N_vis, device=device)
        index_rows.append(shuffled[:K])
    kept_idx = torch.stack(index_rows)

    return patch_feats, text_feats, kept_idx
def test_matches_dense(device):
    """Sparse scores must equal the dense score matrix restricted to kept rows.

    ``device`` is a pytest fixture that is not defined in this file
    (presumably conftest.py -- TODO confirm).
    """
    B, N_vis, N_text, D, K = 4, 196, 77, 768, 80
    patch, text, kept = _make_inputs(B, N_vis, N_text, D, K, device)
    scale = D ** -0.5

    # Reference: full (B, N_vis, N_text) scaled dot-product scores.
    dense_out = torch.bmm(patch, text.transpose(1, 2)) * scale
    sparse_out = sparse_vision_attn(patch, text, kept, use_triton=False)

    # Select the dense rows that correspond to each batch's kept patch indices.
    idx = kept.unsqueeze(-1).expand(B, K, N_text)
    dense_at_kept = torch.gather(dense_out, 1, idx)

    # assert_close enforces matching shapes. A raw `(a - b).abs().max()` check
    # would accept any sparse_out that merely broadcasts against the reference.
    # It also reports the worst offending elements on failure.
    # The tolerance is the same absolute 1e-4 as before. Dtypes are promoted
    # rather than required to match.
    torch.testing.assert_close(
        sparse_out,
        dense_at_kept,
        rtol=0.0,
        atol=1e-4,
        check_dtype=False,
    )
def test_output_shape(device):
    """The result holds one row of text scores per kept patch: (B, K, N_text).

    ``device`` is a pytest fixture that is not defined in this file
    (presumably conftest.py -- TODO confirm).
    """
    batch, n_vis, n_text, dim, n_keep = 2, 196, 77, 768, 64
    inputs = _make_inputs(batch, n_vis, n_text, dim, n_keep, device)
    result = sparse_vision_attn(*inputs, use_triton=False)
    expected = (batch, n_keep, n_text)
    assert tuple(result.shape) == expected
def test_high_compression(device):
    """Keep ~22% of a 576-patch grid and check the output is well formed.

    Checks the output shape and that every score is finite.
    ``device`` is a pytest fixture that is not defined in this file
    (presumably conftest.py -- TODO confirm).
    """
    B, N_vis, N_text, D = 4, 576, 77, 1024
    K = int(N_vis * 0.22)  # 126 of 576 patches kept
    patch, text, kept = _make_inputs(B, N_vis, N_text, D, K, device)
    out = sparse_vision_attn(patch, text, kept, use_triton=False)
    assert out.shape == (B, K, N_text)
    # isfinite rejects +/-inf as well as NaN, so overflowed scores are caught too.
    assert torch.isfinite(out).all()
def test_cpu_fallback():
    """Run the non-Triton path on CPU tensors, independent of the device fixture.

    Checks the output shape and that the result stays on the CPU.
    """
    B, N_vis, N_text, D, K = 2, 64, 32, 128, 20
    # Reuse the seeded helper so this test is reproducible like the others.
    # The inputs have the same shapes and distributions as before.
    patch, text, kept = _make_inputs(B, N_vis, N_text, D, K, "cpu")
    out = sparse_vision_attn(patch, text, kept, use_triton=False)
    assert out.shape == (B, K, N_text)
    # The fallback should not move work off the CPU.
    assert out.device.type == "cpu"