SparseVLM / sparsevlm /generate.py
Aryan3108's picture
Upload folder using huggingface_hub
f6cafdf verified
Raw
History Blame Contribute Delete
4.17 kB
"""
generate.py — KV cache pruning for SparseVLM.
Usage:
from sparsevlm import sparsevlm_generate
output = sparsevlm_generate(
model, processor, inputs,
n_vis=256, # total visual tokens in the sequence
keep_n_vis=64, # how many to keep (25%)
max_new_tokens=256,
)
"""
import torch
def _prune_kv_cache(cache, kept_indices):
    """
    Drop every cached key/value position that is not listed in kept_indices.

    Written against the transformers 5.x DynamicCache layout, where each
    entry of cache.layers exposes .keys and .values tensors. Both tensors
    are re-indexed along dimension 2 and stored back on the same layer
    object, so the cache is modified in place and then returned.
    """

    def _gather_positions(tensor, positions):
        # Advanced indexing can leave stride gaps; .contiguous() guarantees
        # a densely packed result.
        return tensor[:, :, positions, :].contiguous()

    for entry in cache.layers:
        # Layers may live on different devices, so the index tensor is moved
        # to wherever this layer's keys are stored.
        positions = kept_indices.to(entry.keys.device)
        entry.keys = _gather_positions(entry.keys, positions)
        entry.values = _gather_positions(entry.values, positions)
    return cache
def sparsevlm_generate(
    model,
    processor,
    inputs,
    n_vis: int,
    keep_n_vis: int,
    max_new_tokens: int = 256,
    target_layer: int = 2,
    device: str = "cuda",
):
    """
    SparseVLM generation via KV cache pruning.

    Runs prefill once with output_attentions=True, scores the n_vis visual
    tokens by the total attention they receive from the text tokens, keeps
    the top keep_n_vis of them, and decodes with the pruned KV cache.

    Positions [0, n_vis) of the sequence are treated as visual tokens and
    positions [n_vis, N_total) as text tokens. Only batch element 0 is
    scored and the decode attention mask is built with batch size 1, so the
    function is effectively single-sample.

    While decoding, model.model.rope_deltas is temporarily shifted by the
    number of pruned positions. The original value is always put back,
    including when generate() raises.

    Args:
        model:          Qwen2_5_VLForConditionalGeneration (loaded with
                        attn_implementation="eager")
        processor:      AutoProcessor (not used by this function; kept for
                        interface compatibility)
        inputs:         dict from processor(..., return_tensors="pt")
        n_vis:          number of visual tokens in the sequence
                        (inputs["input_ids"].shape[1] - n_text)
        keep_n_vis:     how many visual tokens to keep
        max_new_tokens: generation length
        target_layer:   which layer's attention to use for scoring (default 2)
        device:         device on which the decode attention mask is
                        created (default "cuda")

    Returns:
        The value returned by model.generate() when prompted with only the
        last input token.
        NOTE(review): for decoder-only models generate() normally returns the
        prompt followed by the new tokens, i.e. [1, 1 + n_generated] rather
        than [B, max_new_tokens] -- confirm against the caller.
    """
    n_total = inputs["input_ids"].shape[1]

    # -- 1. Prefill: get KV cache + attention weights -------------------------
    with torch.no_grad():
        prefill = model(**inputs, use_cache=True, output_attentions=True)

    attn = prefill.attentions[target_layer]
    full_cache = prefill.past_key_values
    # Only one layer's attention map and the cache are needed from here on;
    # dropping the output object lets the other layers' attention maps be
    # freed before decoding starts.
    del prefill

    # -- 2. Score all n_vis visual tokens -------------------------------------
    # text->visual attention submatrix [B, H, N_text, N_vis], averaged over heads
    a_tv = attn[:, :, n_vis:, :n_vis].mean(dim=1)  # [B, N_text, N_vis]
    scores = a_tv.sum(dim=1)[0]                    # [N_vis], batch element 0

    # -- 3. Keep top-keep_n_vis visual tokens by attention score --------------
    kept_vis = scores.topk(keep_n_vis).indices
    text_idx = torch.arange(n_vis, n_total, device=kept_vis.device)
    kept_all = torch.cat([kept_vis, text_idx])
    del attn, a_tv, scores  # scoring tensors are no longer needed

    cache = _prune_kv_cache(full_cache, kept_all)
    n_kept = cache.get_seq_length()

    # -- 4. Fix rope_deltas so decode positions are correct -------------------
    # generate() computes: next_pos = cache.get_seq_length() + rope_deltas
    # After pruning get_seq_length() = n_kept < n_total, so we compensate:
    n_pruned = n_total - n_kept
    orig_deltas = model.model.rope_deltas.clone()
    model.model.rope_deltas = orig_deltas + n_pruned

    # -- 5. Decode with pruned cache ------------------------------------------
    # The mask covers the n_kept cached positions plus the single prompt token.
    attn_mask = torch.ones(1, n_kept + 1, device=device, dtype=torch.long)
    try:
        with torch.no_grad():
            output = model.generate(
                input_ids=inputs["input_ids"][:, -1:],
                attention_mask=attn_mask,
                past_key_values=cache,
                max_new_tokens=max_new_tokens,
                use_cache=True,
            )
    finally:
        # -- 6. Restore rope_deltas -------------------------------------------
        # Must run on failure too, otherwise the shared model object keeps the
        # shifted deltas and corrupts positions in every later call.
        model.model.rope_deltas = orig_deltas
    return output