File size: 4,168 Bytes
f6cafdf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""
generate.py — KV cache pruning for SparseVLM.

Usage:
    from sparsevlm import sparsevlm_generate

    output = sparsevlm_generate(
        model, processor, inputs,
        n_vis=256,        # total visual tokens in the sequence
        keep_n_vis=64,    # how many to keep (25%)
        max_new_tokens=256,
    )
"""

import torch


def _prune_kv_cache(cache, kept_indices):
    """
    Keep only the selected sequence positions in every layer of a KV cache.

    The cache is modified in place: each layer's ``keys`` and ``values``
    tensors are replaced by pruned copies. The same object is also returned
    for convenience. The cache is expected to be a transformers 5.x
    ``DynamicCache`` exposing ``cache.layers[i].keys`` / ``.values``.

    Positions are selected along dim 2. The tensors are assumed to be laid
    out as [batch, heads, seq, head_dim].

    Args:
        cache:        cache object whose ``layers`` iterable holds objects
                      with ``keys`` and ``values`` tensors.
        kept_indices: index tensor of sequence positions to keep, in the
                      order they should appear in the pruned cache. It is
                      moved to each layer's device before indexing.

    Returns:
        The same ``cache`` object, now pruned.
    """
    for entry in cache.layers:
        # Index on the layer's own device (layers may live on different GPUs).
        positions = kept_indices.to(entry.keys.device)
        for field in ("keys", "values"):
            pruned = getattr(entry, field)[:, :, positions, :]
            # Materialise a dense copy so there are no stride gaps afterwards.
            setattr(entry, field, pruned.contiguous())
    return cache


def sparsevlm_generate(
    model,
    processor,
    inputs,
    n_vis: int,
    keep_n_vis: int,
    max_new_tokens: int = 256,
    target_layer: int = 2,
    device: str = "cuda",
):
    """
    SparseVLM generation via KV cache pruning.

    Steps:
      1. Run prefill once with output_attentions=True.
      2. Score each of the n_vis visual tokens by the total attention it
         receives from the text tokens in layer ``target_layer``, averaged
         over heads.
      3. Keep the top keep_n_vis visual tokens.
      4. Decode with the pruned KV cache.

    Layout assumption: the visual tokens occupy positions [0, n_vis) of
    ``inputs["input_ids"]``, and every later position is treated as text.
    Only batch size 1 is supported: scoring uses batch row 0 and the decode
    attention mask has a single row.

    ``model.model.rope_deltas`` is shifted temporarily during decoding. It is
    always restored, even if generation raises.

    Args:
        model:          Qwen2_5_VLForConditionalGeneration (loaded with
                        attn_implementation="eager" so that attention weights
                        are returned)
        processor:      AutoProcessor (not used by this function; kept for
                        API compatibility)
        inputs:         dict from processor(..., return_tensors="pt")
        n_vis:          number of visual tokens in the sequence
                        (inputs["input_ids"].shape[1] - n_text)
        keep_n_vis:     how many visual tokens to keep (0 <= keep_n_vis <= n_vis)
        max_new_tokens: generation length limit passed to model.generate
        target_layer:   which layer's attention to use for scoring (default 2)
        device:         device for the decode attention mask (default "cuda")

    Returns:
        Token ids exactly as returned by ``model.generate``. For a
        decoder-only model this is the input it was given (here the re-fed
        last prompt token), followed by up to max_new_tokens generated ids.

    Raises:
        ValueError: if the batch size is not 1, if n_vis leaves no text
            tokens, or if keep_n_vis is outside [0, n_vis].
    """
    batch, n_total = inputs["input_ids"].shape
    if batch != 1:
        raise ValueError(f"sparsevlm_generate supports batch size 1, got {batch}")
    if not 0 <= n_vis < n_total:
        raise ValueError(
            f"n_vis must be in [0, {n_total}) so at least one text token "
            f"follows the visual tokens, got {n_vis}"
        )
    if not 0 <= keep_n_vis <= n_vis:
        raise ValueError(f"keep_n_vis must be in [0, {n_vis}], got {keep_n_vis}")

    # ── 1. Prefill — get KV cache + attention weights ─────────────────────────
    with torch.no_grad():
        prefill = model(**inputs, use_cache=True, output_attentions=True)

    # ── 2. Score all n_vis visual tokens ──────────────────────────────────────
    # text→visual attention submatrix: [B, H, N_text, N_vis], averaged over heads
    attn = prefill.attentions[target_layer]
    a_tv = attn[:, :, n_vis:, :n_vis].mean(dim=1)  # [B, N_text, N_vis]
    scores = a_tv.sum(dim=1)[0]                      # [N_vis]

    cache = prefill.past_key_values
    # Release the prefill output before decoding. Its per-layer attention
    # maps (requested only for scoring) and logits would otherwise stay
    # alive for the whole generate() call.
    del prefill, attn, a_tv

    # ── 3. Keep top-keep_n_vis visual tokens by attention score ───────────────
    # Sort the winners so the pruned cache keeps the original token order.
    kept_vis = scores.topk(keep_n_vis).indices.sort().values
    text_idx = torch.arange(n_vis, n_total, device=kept_vis.device)
    kept_all = torch.cat([kept_vis, text_idx])

    cache = _prune_kv_cache(cache, kept_all)
    n_kept = cache.get_seq_length()

    # ── 4. Fix rope_deltas so decode positions are correct ────────────────────
    # generate() computes: next_pos = cache.get_seq_length() + rope_deltas
    # After pruning get_seq_length() = n_kept < n_total, so we compensate:
    n_pruned = n_total - n_kept
    orig_deltas = model.model.rope_deltas.clone()
    model.model.rope_deltas = orig_deltas + n_pruned

    # ── 5. Decode with pruned cache, always restoring rope_deltas ─────────────
    # NOTE(review): kept_all includes position n_total - 1, so the last prompt
    # token's KV is already in the cache, and that same token is fed again
    # below. It is therefore processed twice, and (per the position formula
    # above) the re-fed copy lands at n_total + orig_deltas rather than
    # n_total - 1 + orig_deltas. Confirm this is intended for the targeted
    # transformers version.
    attn_mask = torch.ones(1, n_kept + 1, device=device, dtype=torch.long)
    try:
        with torch.no_grad():
            output = model.generate(
                input_ids=inputs["input_ids"][:, -1:],
                attention_mask=attn_mask,
                past_key_values=cache,
                max_new_tokens=max_new_tokens,
                use_cache=True,
            )
    finally:
        # ── 6. Restore rope_deltas (also on error) ────────────────────────────
        model.model.rope_deltas = orig_deltas

    return output