| import pytest |
| import torch |
| from kernels.token_scorer import ( |
| select_raters, score_visual_tokens, |
| compute_prune_counts, recycle_and_cluster, sparsevlm_score, |
| ) |
| from tests.conftest import make_attn_weights, make_softmax_matrix |
|
|
|
|
def test_rater_selection_shape(device):
    """select_raters returns a (batch, text) bool mask with at least one rater per row."""
    batch, n_text, n_vis = 4, 32, 196
    attn_text_to_vis = make_softmax_matrix(batch, n_text, n_vis, device)
    rater_mask = select_raters(attn_text_to_vis)

    assert rater_mask.shape == (batch, n_text)
    assert rater_mask.dtype == torch.bool
    # Reduce over the text dim: every batch row must select at least one rater.
    assert torch.any(rater_mask, dim=-1).all()
|
|
|
|
def test_raters_above_mean(device):
    """Every selected rater's mean attention exceeds its batch row's average.

    The check runs over all batch rows at once via broadcasting, so it follows
    whatever batch size is passed to make_softmax_matrix instead of looping
    over a hard-coded count.
    """
    A_tv = make_softmax_matrix(2, 20, 100, device)
    mask = select_raters(A_tv)
    mean_per_text = A_tv.mean(dim=-1)
    # keepdim=True leaves one value per batch row, which broadcasts across that
    # row in the comparison below.
    global_mean = mean_per_text.mean(dim=-1, keepdim=True)
    above_mean = mean_per_text > global_mean
    # NOTE(review): if make_softmax_matrix normalises over the last dim, every
    # row mean is ~1/N and this comparison only sees rounding noise — confirm.
    # Boolean-mask indexing keeps only the entries select_raters flagged.
    assert above_mean[mask].all()
|
|
|
|
def test_score_shape(device):
    """score_visual_tokens returns (batch, n_vis) scores and a rater matrix.

    Only dim 0 (batch) and dim 2 (visual) of the rater matrix are checked;
    its middle dim is left unconstrained.
    """
    batch, n_vis = 4, 196
    A_tv = make_softmax_matrix(batch, 32, n_vis, device)
    scores, A_rater = score_visual_tokens(A_tv, select_raters(A_tv))

    assert scores.shape == (batch, n_vis)
    assert (A_rater.shape[0], A_rater.shape[2]) == (batch, n_vis)
|
|
|
|
def test_prune_counts_bounds(device):
    """compute_prune_counts stays within [0, n_vis - min_keep].

    The upper bound is derived from n_vis and min_keep rather than written as
    a hand-computed literal, so it stays correct if either value changes.
    """
    n_vis, min_keep = 196, 32
    A_tv = make_softmax_matrix(8, 32, n_vis, device)
    mask = select_raters(A_tv)
    n_raters = mask.sum(dim=-1)
    _, A_rater = score_visual_tokens(A_tv, mask)
    counts = compute_prune_counts(A_rater, n_raters, n_vis, min_keep=min_keep)
    assert (counts >= 0).all()
    assert (counts <= n_vis - min_keep).all()
|
|
|
|
def test_recycle_output(device):
    """recycle_and_cluster returns finite tokens of width D for non-empty input."""
    torch.manual_seed(0)
    D = 256
    deleted_tokens = torch.randn(50, D, device=device)
    deleted_scores = torch.rand(50, device=device)
    out = recycle_and_cluster(deleted_tokens, deleted_scores)
    assert out is not None
    assert out.shape[1] == D
    # isfinite is False for NaN, +inf and -inf; a NaN-only check would let
    # overflowed values through.
    assert torch.isfinite(out).all()
|
|
|
|
def test_recycle_empty(device):
    """recycle_and_cluster returns None when given zero deleted tokens."""
    D = 256
    no_tokens = torch.zeros(0, D, device=device)
    no_scores = torch.zeros(0, device=device)
    assert recycle_and_cluster(no_tokens, no_scores) is None
|
|
|
|
def test_sparsevlm_score_shape(device):
    """sparsevlm_score keeps batch/hidden dims and prunes some visual tokens.

    The returned visual count must be at least min_keep and strictly below the
    input count. That bound is data-dependent, so the RNG is seeded: `hidden`
    (and any torch randomness inside make_attn_weights) is then reproducible,
    and a failure can be replayed.
    """
    torch.manual_seed(0)
    B, H, N_vis, N_text, D = 2, 8, 64, 16, 256
    min_keep = 8
    N_total = N_vis + N_text
    attn = make_attn_weights(B, H, N_total, device)
    hidden = torch.randn(B, N_total, D, device=device)

    new_hidden, new_n_vis = sparsevlm_score(
        attn, hidden, n_vis=N_vis, min_keep=min_keep
    )

    assert new_hidden.dim() == 3
    assert new_hidden.shape[0] == B
    assert new_hidden.shape[2] == D
    assert min_keep <= new_n_vis < N_vis
|
|
|
|
def test_sparsevlm_score_no_nan(device):
    """sparsevlm_score's first output contains neither NaN nor +/-inf."""
    batch, heads, n_vis, n_text, dim = 4, 16, 128, 32, 512
    seq_len = n_vis + n_text
    attn = make_attn_weights(batch, heads, seq_len, device)
    hidden = torch.randn(batch, seq_len, dim, device=device)
    pruned_hidden, _ = sparsevlm_score(attn, hidden, n_vis=n_vis)
    # isfinite is False exactly for NaN, +inf and -inf, so one check covers both.
    assert torch.isfinite(pruned_hidden).all()
|
|