"""Tests for the token-scoring kernels in ``kernels.token_scorer``.

Every test takes a ``device`` fixture; presumably it is provided by
``tests/conftest.py`` (which also supplies ``make_attn_weights`` and
``make_softmax_matrix``) -- TODO confirm.

Dimension names used below (grounded by the shape assertions):
``make_softmax_matrix(B, N_text, N_vis, device)`` yields a text-to-visual
attention matrix whose rater mask is ``(B, N_text)`` and whose visual-token
scores are ``(B, N_vis)``.
"""

import pytest
import torch

from kernels.token_scorer import (
    select_raters,
    score_visual_tokens,
    compute_prune_counts,
    recycle_and_cluster,
    sparsevlm_score,
)
from tests.conftest import make_attn_weights, make_softmax_matrix


@pytest.fixture(autouse=True)
def _fixed_seed():
    """Seed torch's global RNG before every test in this module.

    Several tests build inputs with ``torch.randn`` / ``torch.rand`` (and the
    conftest helpers may draw random numbers too); a fixed seed makes any
    failure reproducible.
    """
    torch.manual_seed(0)


def test_rater_selection_shape(device):
    """select_raters returns a (B, N_text) bool mask with >= 1 rater per item."""
    B, N_text, N_vis = 4, 32, 196
    A_tv = make_softmax_matrix(B, N_text, N_vis, device)
    mask = select_raters(A_tv)
    assert mask.shape == (B, N_text)
    assert mask.dtype == torch.bool
    # Every batch item must keep at least one rater text token.
    assert mask.any(dim=-1).all()


def test_raters_above_mean(device):
    """Each selected rater's mean attention exceeds its batch item's mean of means."""
    B, N_text, N_vis = 2, 20, 100
    A_tv = make_softmax_matrix(B, N_text, N_vis, device)
    mask = select_raters(A_tv)
    mean_per_text = A_tv.mean(dim=-1)  # (B, N_text)
    global_mean = mean_per_text.mean(dim=-1, keepdim=True)  # (B, 1)
    # Broadcast the per-item threshold over N_text, then check only the
    # selected positions -- covers every batch item without a hard-coded loop.
    above = mean_per_text > global_mean
    assert above[mask].all()


def test_score_shape(device):
    """score_visual_tokens returns (B, N_vis) scores and a rater-attention tensor."""
    B, N_text, N_vis = 4, 32, 196
    A_tv = make_softmax_matrix(B, N_text, N_vis, device)
    mask = select_raters(A_tv)
    scores, A_rater = score_visual_tokens(A_tv, mask)
    assert scores.shape == (B, N_vis)
    # Only the batch and visual dims of A_rater are pinned; dim 1 (the rater
    # count) may vary with the selected mask.
    assert A_rater.shape[0] == B
    assert A_rater.shape[2] == N_vis


def test_prune_counts_bounds(device):
    """Prune counts are non-negative and never drop below min_keep visual tokens."""
    B, N_text, N_vis, min_keep = 8, 32, 196, 32
    A_tv = make_softmax_matrix(B, N_text, N_vis, device)
    mask = select_raters(A_tv)
    n_raters = mask.sum(dim=-1)
    _, A_rater = score_visual_tokens(A_tv, mask)
    counts = compute_prune_counts(A_rater, n_raters, N_vis, min_keep=min_keep)
    assert (counts >= 0).all()
    # Derived from the inputs rather than a hand-computed constant.
    assert (counts <= N_vis - min_keep).all()


def test_recycle_output(device):
    """recycle_and_cluster returns a finite (K, D) tensor for non-empty input."""
    # Explicit seed kept so this test stays self-contained.
    torch.manual_seed(0)
    N_del, D = 50, 256
    deleted_tokens = torch.randn(N_del, D, device=device)
    deleted_scores = torch.rand(N_del, device=device)
    out = recycle_and_cluster(deleted_tokens, deleted_scores)
    assert out is not None
    assert out.shape[1] == D
    assert not torch.isnan(out).any()


def test_recycle_empty(device):
    """recycle_and_cluster returns None when there is nothing to recycle."""
    D = 256
    out = recycle_and_cluster(
        torch.zeros(0, D, device=device),
        torch.zeros(0, device=device),
    )
    assert out is None


def test_sparsevlm_score_shape(device):
    """sparsevlm_score keeps (B, *, D) hidden states and prunes some visual tokens."""
    B, H, N_vis, N_text, D = 2, 8, 64, 16, 256
    min_keep = 8
    N_total = N_vis + N_text
    attn = make_attn_weights(B, H, N_total, device)
    hidden = torch.randn(B, N_total, D, device=device)
    new_hidden, new_n_vis = sparsevlm_score(attn, hidden, n_vis=N_vis, min_keep=min_keep)
    assert new_hidden.dim() == 3
    assert new_hidden.shape[0] == B
    assert new_hidden.shape[2] == D
    # At least min_keep visual tokens survive, and at least one is pruned.
    assert min_keep <= new_n_vis < N_vis


def test_sparsevlm_score_no_nan(device):
    """sparsevlm_score output contains no NaN or Inf values."""
    B, H, N_vis, N_text, D = 4, 16, 128, 32, 512
    N_total = N_vis + N_text
    attn = make_attn_weights(B, H, N_total, device)
    hidden = torch.randn(B, N_total, D, device=device)
    out, _ = sparsevlm_score(attn, hidden, n_vis=N_vis)
    assert not torch.isnan(out).any()
    assert not torch.isinf(out).any()