SparseVLM / kernels /token_scorer.py
Aryan3108's picture
Upload folder using huggingface_hub
176b11a verified
Raw
History Blame Contribute Delete
7.75 kB
"""
token_scorer.py
---------------
Faithful implementation of SparseVLM paper Sections 3.2 and 3.3.
Section 3.2 β€” Sparsification Guidance from Text to Vision:
1. Extract text→visual submatrix from LLM's own self-attention
2. Select rater tokens: text tokens with above-average visual attention
3. Score visual tokens by summed rater attention
4. Rank of A_rater β†’ adaptive prune count
5. Return kept_indices
Section 3.3 β€” Visual Token Recycling:
Cluster pruned tokens β†’ compact aggregate representations
"""
import torch
import torch.nn.functional as F
from .rank_estimator import sketch_rank
# ── Rater selection ───────────────────────────────────────────────────────────
def select_raters(A_tv: torch.Tensor) -> torch.Tensor:
"""
A text token is a rater if its mean attention to visual tokens
exceeds the global mean across all text tokens.
Args:
A_tv: [B, N_text, N_vis]
Returns:
rater_mask: [B, N_text] bool
"""
mean_per_text = A_tv.mean(dim=-1) # [B, N_text]
global_mean = mean_per_text.mean(dim=-1, keepdim=True) # [B, 1]
return mean_per_text > global_mean
def score_visual_tokens(
A_tv: torch.Tensor,
rater_mask: torch.Tensor,
) -> tuple:
"""
Score each visual token by summed attention from rater tokens only.
Args:
A_tv: [B, N_text, N_vis]
rater_mask: [B, N_text] bool
Returns:
vision_scores: [B, N_vis]
A_rater: [B, max_raters, N_vis] padded rater attention matrix
"""
B, N_text, N_vis = A_tv.shape
max_raters = rater_mask.sum(dim=-1).max().item()
A_rater = torch.zeros(B, max_raters, N_vis, device=A_tv.device, dtype=A_tv.dtype)
for b in range(B):
rows = A_tv[b, rater_mask[b]]
A_rater[b, :rows.shape[0]] = rows
vision_scores = A_rater.sum(dim=1) # [B, N_vis]
return vision_scores, A_rater
def compute_prune_counts(
A_rater: torch.Tensor,
n_raters: torch.Tensor,
N_vis: int,
min_keep: int = 32,
) -> torch.Tensor:
"""
Rank-adaptive prune count: prune_count = 0.5 * (N_vis - rank(A_rater))
Uses sketch_rank instead of SVD β€” 15-50x faster, same result.
Returns: [B] int prune counts
"""
ranks = sketch_rank(A_rater)
prune_counts = (0.5 * (N_vis - ranks.float())).int()
return prune_counts.clamp(min=0, max=N_vis - min_keep)
def get_kept_and_deleted_indices(
vision_scores: torch.Tensor,
prune_counts: torch.Tensor,
) -> tuple:
"""Split visual tokens into kept and deleted sets."""
B, N_vis = vision_scores.shape
kept_list = []
deleted_list = []
deleted_scores_list = []
for b in range(B):
P = prune_counts[b].item()
K = N_vis - P
topk_result = torch.topk(vision_scores[b], k=K)
kept_idx = topk_result.indices
all_idx = torch.arange(N_vis, device=vision_scores.device)
deleted_mask = torch.ones(N_vis, dtype=torch.bool, device=vision_scores.device)
deleted_mask[kept_idx] = False
deleted_idx = all_idx[deleted_mask]
kept_list.append(kept_idx)
deleted_list.append(deleted_idx)
deleted_scores_list.append(vision_scores[b, deleted_idx])
return kept_list, deleted_list, deleted_scores_list
# ── Token recycling ───────────────────────────────────────────────────────────
def recycle_and_cluster(
deleted_tokens: torch.Tensor,
deleted_scores: torch.Tensor,
tau: float = 0.5,
theta: float = 0.5,
) -> torch.Tensor | None:
"""
Paper Section 3.3: cluster pruned tokens into compact representations.
Args:
deleted_tokens: [P, D]
deleted_scores: [P]
tau: fraction of pruned to recycle
theta: cluster ratio
Returns:
aggregated: [n_clusters, D] or None
"""
P = deleted_tokens.shape[0]
if P < 1:
return None
n_recycle = max(1, int(tau * P))
recycle_idx = torch.topk(deleted_scores, n_recycle).indices
recycled_tokens = deleted_tokens[recycle_idx]
recycled_scores = deleted_scores[recycle_idx]
n_clusters = max(1, int(theta * n_recycle))
recycled_norm = F.normalize(recycled_tokens, dim=-1)
# Greedy k-means++ center selection
centers = [recycled_norm[recycled_scores.argmax()]]
for _ in range(1, n_clusters):
sims = torch.stack([torch.matmul(recycled_norm, c.unsqueeze(-1)).squeeze(-1)
for c in centers], dim=1)
dists = 1 - sims.max(dim=1).values
centers.append(recycled_norm[dists.argmax()])
sims = torch.stack([torch.matmul(recycled_norm, c.unsqueeze(-1)).squeeze(-1)
for c in centers], dim=1)
assignments = sims.argmax(dim=1)
aggregated = []
for k in range(n_clusters):
members = recycled_tokens[assignments == k]
if members.shape[0] > 0:
aggregated.append(members.sum(dim=0))
return torch.stack(aggregated) if aggregated else None
# ── Main entry point ──────────────────────────────────────────────────────────
def sparsevlm_score(
attn_weights: torch.Tensor, # [B, H, N_total, N_total]
hidden_states: torch.Tensor, # [B, N_total, D]
n_vis: int,
min_keep: int = 32,
tau: float = 0.5,
theta: float = 0.5,
) -> tuple:
"""
Full SparseVLM scoring for one transformer layer.
Called from the attention hook after attn_weights are computed.
Returns:
new_hidden_states: [B, N_new, D]
new_n_vis: int
"""
B, H, N_total, _ = attn_weights.shape
# Text→visual submatrix, averaged over heads
A_tv = attn_weights[:, :, n_vis:, :n_vis].mean(dim=1) # [B, N_text, N_vis]
rater_mask = select_raters(A_tv)
n_raters = rater_mask.sum(dim=-1)
vision_scores, A_rater = score_visual_tokens(A_tv, rater_mask)
prune_counts = compute_prune_counts(A_rater, n_raters, n_vis, min_keep)
kept_list, deleted_list, deleted_scores_list = get_kept_and_deleted_indices(
vision_scores, prune_counts
)
vis_tokens = hidden_states[:, :n_vis, :]
text_tokens = hidden_states[:, n_vis:, :]
new_sequences = []
new_n_vis_per_item = []
for b in range(B):
kept_tokens = vis_tokens[b, kept_list[b]]
recycled = None
if deleted_list[b].numel() > 0:
recycled = recycle_and_cluster(
vis_tokens[b, deleted_list[b]],
deleted_scores_list[b],
tau=tau, theta=theta,
)
parts = [kept_tokens]
if recycled is not None:
parts.append(recycled)
parts.append(text_tokens[b])
combined = torch.cat(parts, dim=0)
new_sequences.append(combined)
n_vis_b = kept_tokens.shape[0] + (recycled.shape[0] if recycled is not None else 0)
new_n_vis_per_item.append(n_vis_b)
# Pad to same length for batched output
max_len = max(s.shape[0] for s in new_sequences)
D = hidden_states.shape[-1]
padded = torch.zeros(B, max_len, D, device=hidden_states.device, dtype=hidden_states.dtype)
for b, seq in enumerate(new_sequences):
padded[b, :seq.shape[0]] = seq
new_n_vis = min(new_n_vis_per_item)
return padded, new_n_vis