""" generate.py — KV cache pruning for SparseVLM. Usage: from sparsevlm import sparsevlm_generate output = sparsevlm_generate( model, processor, inputs, n_vis=256, # total visual tokens in the sequence keep_n_vis=64, # how many to keep (25%) max_new_tokens=256, ) """ import torch def _prune_kv_cache(cache, kept_indices): """ Remove KV entries for pruned visual tokens. Works with transformers 5.x DynamicCache (cache.layers[i].keys / .values). .contiguous() ensures no stride gaps after indexing. """ for layer in cache.layers: k = kept_indices.to(layer.keys.device) layer.keys = layer.keys[:, :, k, :].contiguous() layer.values = layer.values[:, :, k, :].contiguous() return cache def sparsevlm_generate( model, processor, inputs, n_vis: int, keep_n_vis: int, max_new_tokens: int = 256, target_layer: int = 2, device: str = "cuda", ): """ SparseVLM generation via KV cache pruning. Runs prefill once with output_attentions=True, scores all n_vis visual tokens by their total text attention, keeps the top keep_n_vis, and decodes with the pruned KV cache. Args: model: Qwen2_5_VLForConditionalGeneration (loaded with attn_implementation="eager") processor: AutoProcessor inputs: dict from processor(..., return_tensors="pt") n_vis: number of visual tokens in the sequence (inputs["input_ids"].shape[1] - n_text) keep_n_vis: how many visual tokens to keep max_new_tokens: generation length target_layer: which layer's attention to use for scoring (default 2) device: primary device (default "cuda") Returns: generated token ids [B, max_new_tokens] """ N_TOTAL = inputs["input_ids"].shape[1] # ── 1. Prefill — get KV cache + attention weights ───────────────────────── with torch.no_grad(): prefill = model(**inputs, use_cache=True, output_attentions=True) # ── 2. Score all n_vis visual tokens ────────────────────────────────────── # text→visual attention submatrix: [B, H, N_text, N_vis] averaged over heads attn = prefill.attentions[target_layer] A_tv = attn[:, :, n_vis:, :n_vis].mean(dim=1) # [B, N_text, N_vis] scores = A_tv.sum(dim=1)[0] # [N_vis] # ── 3. Keep top-keep_n_vis visual tokens by attention score ─────────────── kept_vis = scores.topk(keep_n_vis).indices text_idx = torch.arange(n_vis, N_TOTAL, device=kept_vis.device) kept_all = torch.cat([kept_vis, text_idx]) cache = _prune_kv_cache(prefill.past_key_values, kept_all) n_kept = cache.get_seq_length() # ── 4. Fix rope_deltas so decode positions are correct ──────────────────── # generate() computes: next_pos = cache.get_seq_length() + rope_deltas # After pruning get_seq_length() = n_kept < N_TOTAL, so we compensate: n_pruned = N_TOTAL - n_kept orig_deltas = model.model.rope_deltas.clone() model.model.rope_deltas = orig_deltas + n_pruned # ── 5. Decode with pruned cache ──────────────────────────────────────────── attn_mask = torch.ones(1, n_kept + 1, device=device, dtype=torch.long) with torch.no_grad(): output = model.generate( input_ids=inputs["input_ids"][:, -1:], attention_mask=attn_mask, past_key_values=cache, max_new_tokens=max_new_tokens, use_cache=True, ) # ── 6. Restore rope_deltas ───────────────────────────────────────────────── model.model.rope_deltas = orig_deltas return output