| """ |
| generate.py — KV cache pruning for SparseVLM. |
| |
| Usage: |
| from sparsevlm import sparsevlm_generate |
| |
| output = sparsevlm_generate( |
| model, processor, inputs, |
| n_vis=256, # total visual tokens in the sequence |
| keep_n_vis=64, # how many to keep (25%) |
| max_new_tokens=256, |
| ) |
| """ |
|
|
| import torch |
|
|
|
|
def _prune_kv_cache(cache, kept_indices):
    """
    Drop cached key/value entries that are not listed in ``kept_indices``.

    Every layer in ``cache.layers`` has its ``keys`` and ``values`` tensors
    re-indexed along dimension 2 and stored back on the layer, so the cache
    is modified in place. The layout targeted is the transformers 5.x
    DynamicCache (``cache.layers[i].keys`` / ``.values``).

    Args:
        cache: object exposing ``layers``, each with 4-D ``keys`` and
            ``values`` tensors (assumed [batch, heads, seq, head_dim] —
            TODO confirm against the cache implementation in use).
        kept_indices: index tensor of the positions to retain; entries
            appear in the pruned cache in the order given here.

    Returns:
        The same ``cache`` object, after pruning.
    """
    for entry in cache.layers:
        # Layers may sit on different devices, so move the index per layer.
        positions = kept_indices.to(entry.keys.device)
        pruned_keys = entry.keys[:, :, positions, :]
        pruned_values = entry.values[:, :, positions, :]
        # Store densely packed tensors back on the layer.
        entry.keys = pruned_keys.contiguous()
        entry.values = pruned_values.contiguous()
    return cache
|
|
|
|
def sparsevlm_generate(
    model,
    processor,
    inputs,
    n_vis: int,
    keep_n_vis: int,
    max_new_tokens: int = 256,
    target_layer: int = 2,
    device: str = "cuda",
):
    """
    SparseVLM generation via KV cache pruning.

    Runs one prefill pass with ``output_attentions=True``, scores each of the
    first ``n_vis`` positions by the attention it receives from the remaining
    positions at ``target_layer``, keeps the ``keep_n_vis`` highest-scoring
    ones plus every position from ``n_vis`` onwards, prunes the KV cache to
    that set, and decodes from the pruned cache.

    Layout assumptions made by the code:
        * Positions ``[0, n_vis)`` are treated as the visual tokens and
          ``[n_vis, seq_len)`` as text.
          NOTE(review): chat templates often place text before the image
          tokens — confirm the caller's inputs match this layout.
        * Only batch element 0 is scored and the decode attention mask has a
          single row, so a batch size of 1 is assumed.

    Args:
        model: Qwen2_5_VLForConditionalGeneration (loaded with
            attn_implementation="eager" so attention weights are returned).
        processor: AutoProcessor. Not used by this function; kept in the
            signature for callers.
        inputs: dict from processor(..., return_tensors="pt"); must contain
            "input_ids".
        n_vis: number of visual tokens in the sequence
            (inputs["input_ids"].shape[1] - n_text).
        keep_n_vis: how many visual tokens to keep.
        max_new_tokens: generation length passed to ``model.generate``.
        target_layer: which layer's attention to use for scoring (default 2).
        device: device on which the decode attention mask is created
            (default "cuda").

    Returns:
        Whatever ``model.generate`` returns for the decode call.
        NOTE(review): generate usually prepends the ``input_ids`` it was
        given, so the first column is presumably the re-fed last prompt
        token rather than a newly generated one — confirm.

    Raises:
        ValueError: if ``keep_n_vis`` / ``n_vis`` are inconsistent with each
            other or with the sequence length.
    """
    n_total = inputs["input_ids"].shape[1]

    # Fail before the expensive prefill instead of inside topk or slicing.
    if not 0 <= keep_n_vis <= n_vis <= n_total:
        raise ValueError(
            "expected 0 <= keep_n_vis <= n_vis <= sequence length, got "
            f"keep_n_vis={keep_n_vis}, n_vis={n_vis}, sequence length={n_total}"
        )

    with torch.no_grad():
        prefill = model(**inputs, use_cache=True, output_attentions=True)

        # Attention indexed as [batch, head, query, key] (4-D indexing below):
        # take text-query rows and visual-key columns, average over heads,
        # then sum over the text queries to get one score per visual token.
        layer_attn = prefill.attentions[target_layer]
        text_to_vis = layer_attn[:, :, n_vis:, :n_vis].mean(dim=1)
        scores = text_to_vis.sum(dim=1)[0]

    full_cache = prefill.past_key_values
    # The per-layer attention maps are no longer needed; drop the references
    # so they can be freed before decoding.
    del prefill, layer_attn, text_to_vis

    # NOTE(review): topk returns indices ordered by score, not by position,
    # so the kept visual entries are permuted inside the cache. This is
    # presumably harmless because positional information is already applied
    # to the cached keys — confirm.
    kept_vis = scores.topk(keep_n_vis).indices
    text_idx = torch.arange(n_vis, n_total, device=kept_vis.device)
    kept_all = torch.cat([kept_vis, text_idx])

    cache = _prune_kv_cache(full_cache, kept_all)
    n_kept = cache.get_seq_length()

    # The cache is now shorter than the original sequence. Shift the model's
    # rope_deltas by the number of removed entries so decode positions line
    # up with the unpruned sequence (presumably rope_deltas is added to the
    # cache length when positions are computed — verify against the model).
    n_pruned = n_total - n_kept
    orig_deltas = model.model.rope_deltas.clone()
    model.model.rope_deltas = orig_deltas + n_pruned

    # NOTE(review): the pruned cache still contains the last prompt token and
    # that same token is fed again as input_ids below — confirm that this
    # duplication is intended.
    attn_mask = torch.ones(1, n_kept + 1, device=device, dtype=torch.long)
    try:
        with torch.no_grad():
            output = model.generate(
                input_ids=inputs["input_ids"][:, -1:],
                attention_mask=attn_mask,
                past_key_values=cache,
                max_new_tokens=max_new_tokens,
                use_cache=True,
            )
    finally:
        # Always undo the temporary shift, even if generation raises, so the
        # shared model object is not left with altered rope_deltas.
        model.model.rope_deltas = orig_deltas

    return output
|
|