SparseVLM / tests /test_token_scorer.py
Aryan3108's picture
Upload folder using huggingface_hub
176b11a verified
Raw
History Blame Contribute Delete
2.74 kB
import pytest
import torch
from kernels.token_scorer import (
select_raters, score_visual_tokens,
compute_prune_counts, recycle_and_cluster, sparsevlm_score,
)
from tests.conftest import make_attn_weights, make_softmax_matrix
def test_rater_selection_shape(device):
    """Check shape, dtype and non-emptiness of the mask from ``select_raters``.

    The input comes from ``make_softmax_matrix(4, 32, 196, device)``. The
    returned mask must be a boolean tensor of shape ``(4, 32)`` with at least
    one ``True`` along the last dimension of every row.
    """
    # Presumably (batch, text tokens, visual tokens) — confirm against conftest.
    n_batch, n_text, n_vis = 4, 32, 196
    attn_tv = make_softmax_matrix(n_batch, n_text, n_vis, device)
    rater_mask = select_raters(attn_tv)

    assert rater_mask.shape == (n_batch, n_text)
    assert rater_mask.dtype == torch.bool

    # No row may come back with an all-False mask.
    row_has_rater = rater_mask.any(dim=-1)
    assert row_has_rater.all()
def test_raters_above_mean(device):
    """Check that every selected rater has an above-average mean.

    ``A_tv`` is averaged over its last dimension, then those means are
    averaged again (keeping the dimension) to form one threshold per batch
    entry. For each batch entry, the means at the positions selected by
    ``select_raters`` must be strictly greater than that threshold.
    """
    n_batch = 2
    attn_tv = make_softmax_matrix(n_batch, 20, 100, device)
    rater_mask = select_raters(attn_tv)

    row_means = attn_tv.mean(dim=-1)
    thresholds = row_means.mean(dim=-1, keepdim=True)

    for idx in range(n_batch):
        # Boolean indexing keeps only the rows flagged as raters.
        selected_means = row_means[idx][rater_mask[idx]]
        assert (selected_means > thresholds[idx]).all()
def test_score_shape(device):
    """Check the output shapes of ``score_visual_tokens``.

    ``score_visual_tokens`` is called with the matrix and the mask from
    ``select_raters`` and must return a pair: scores of shape ``(4, 196)``
    and a second tensor whose dimension 0 is 4 and dimension 2 is 196.
    """
    # Presumably (batch, text tokens, visual tokens) — confirm against conftest.
    n_batch, n_text, n_vis = 4, 32, 196
    attn_tv = make_softmax_matrix(n_batch, n_text, n_vis, device)
    rater_mask = select_raters(attn_tv)

    token_scores, rater_attn = score_visual_tokens(attn_tv, rater_mask)

    assert token_scores.shape == (n_batch, n_vis)
    # Only the outer dimensions of the rater tensor are pinned here.
    rater_shape = rater_attn.shape
    assert rater_shape[0] == n_batch
    assert rater_shape[2] == n_vis
def test_prune_counts_bounds(device):
    """Check that prune counts stay within ``[0, 196 - 32]``.

    The counts come from ``compute_prune_counts`` called with the second
    output of ``score_visual_tokens``, the per-row number of selected raters,
    the value 196 and ``min_keep=32``.
    """
    n_vis = 196
    min_keep = 32
    max_prunable = n_vis - min_keep  # 164

    attn_tv = make_softmax_matrix(8, 32, n_vis, device)
    rater_mask = select_raters(attn_tv)
    rater_totals = rater_mask.sum(dim=-1)
    _, rater_attn = score_visual_tokens(attn_tv, rater_mask)

    prune_counts = compute_prune_counts(rater_attn, rater_totals, n_vis, min_keep=min_keep)

    assert (prune_counts >= 0).all()
    assert (prune_counts <= max_prunable).all()
def test_recycle_output(device):
    """Check ``recycle_and_cluster`` on 50 random tokens.

    The result must not be ``None``, its dimension 1 must equal the token
    width (256), and it must contain no NaN values.
    """
    torch.manual_seed(0)
    n_deleted, hidden_dim = 50, 256

    # Tokens are drawn before scores so the seeded RNG stream is consumed
    # in a fixed order.
    tokens = torch.randn(n_deleted, hidden_dim, device=device)
    scores = torch.rand(n_deleted, device=device)

    recycled = recycle_and_cluster(tokens, scores)

    assert recycled is not None
    assert recycled.shape[1] == hidden_dim
    nan_mask = torch.isnan(recycled)
    assert not nan_mask.any()
def test_recycle_empty(device):
    """Check that ``recycle_and_cluster`` returns ``None`` for empty inputs.

    Both inputs have zero rows: a ``(0, 256)`` token tensor and a ``(0,)``
    score tensor.
    """
    hidden_dim = 256
    empty_tokens = torch.zeros(0, hidden_dim, device=device)
    empty_scores = torch.zeros(0, device=device)

    result = recycle_and_cluster(empty_tokens, empty_scores)

    assert result is None
def test_sparsevlm_score_shape(device):
    """Check the shape contract of ``sparsevlm_score``.

    The function is called with ``n_vis=64`` and ``min_keep=8`` and must
    return a pair. The first element is a 3-D tensor whose first and last
    dimensions match the input hidden states. The second element must
    satisfy ``8 <= new_n_vis < 64``.
    """
    batch, heads, n_vis, n_text, dim = 2, 8, 64, 16, 256
    min_keep = 8
    seq_len = n_vis + n_text

    attn_weights = make_attn_weights(batch, heads, seq_len, device)
    hidden_states = torch.randn(batch, seq_len, dim, device=device)

    pruned_hidden, pruned_n_vis = sparsevlm_score(
        attn_weights, hidden_states, n_vis=n_vis, min_keep=min_keep
    )

    assert pruned_hidden.dim() == 3
    assert pruned_hidden.shape[0] == batch
    assert pruned_hidden.shape[2] == dim
    # At least min_keep tokens remain, and strictly fewer than the input count.
    assert min_keep <= pruned_n_vis < n_vis
def test_sparsevlm_score_no_nan(device):
    """Check that ``sparsevlm_score`` returns no NaN or infinite values.

    Uses a larger configuration (batch 4, 16 heads, 128 + 32 positions,
    width 512) and leaves ``min_keep`` at the function's default.
    """
    batch, heads, n_vis, n_text, dim = 4, 16, 128, 32, 512
    seq_len = n_vis + n_text

    attn_weights = make_attn_weights(batch, heads, seq_len, device)
    hidden_states = torch.randn(batch, seq_len, dim, device=device)

    pruned_hidden, _ = sparsevlm_score(attn_weights, hidden_states, n_vis=n_vis)

    nan_mask = torch.isnan(pruned_hidden)
    assert not nan_mask.any()
    inf_mask = torch.isinf(pruned_hidden)
    assert not inf_mask.any()