File size: 8,908 Bytes
176b11a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
"""
patch.py — SparseVLM for Qwen2-VL and Qwen2.5-VL using PyTorch hooks.

Uses register_forward_hook / register_forward_pre_hook so the original
decoder layers are NEVER replaced — avoiding all module-wrapping issues.

  pre-hook  (all layers): inject pruned position context from shared_state
  post-hook (target layers): prune output tokens, update shared_state
"""

import torch
import torch.nn as nn
from kernels.token_scorer import (
    select_raters, score_visual_tokens,
    compute_prune_counts, get_kept_and_deleted_indices,
    recycle_and_cluster,
)


def default_target_layers(n_layers):
    """Return the default pruning layers: every fourth index starting at 2.

    Indices are generated below ``n_layers``; the list is empty when
    ``n_layers`` is 2 or less.
    """
    first_layer, stride = 2, 4
    return list(range(first_layer, n_layers, stride))


def _get_layers(model):
    """Locate the decoder-layer container of *model*.

    ``model.model.layers`` is tried first, then
    ``model.model.language_model.layers``.

    Raises:
        ValueError: if neither attribute path exists on *model*.
    """
    # Sentinel so that an attribute whose value is None still counts as present.
    absent = object()

    inner = getattr(model, "model", absent)
    if inner is not absent:
        direct = getattr(inner, "layers", absent)
        if direct is not absent:
            return direct

        language_model = getattr(inner, "language_model", absent)
        if language_model is not absent:
            nested = getattr(language_model, "layers", absent)
            if nested is not absent:
                return nested

    raise ValueError(
        f"Cannot find decoder layers in {type(model).__name__}. "
        "Tried model.model.layers and model.model.language_model.layers."
    )


# ── hook factories ────────────────────────────────────────────────────────────

def _make_pre_hook(shared_state, is_target=False):
    """Build a forward pre-hook that injects the shared position context.

    The returned hook has the ``(module, args, kwargs)`` signature of a
    pre-hook registered with ``with_kwargs=True`` and returns the
    ``(args, kwargs)`` pair the layer should be called with.

    Args:
        shared_state: dict whose ``"position_ids"``, ``"position_embeddings"``
            and ``"attention_mask"`` entries, when not None, override the
            layer's keyword arguments of the same name.
        is_target: when True the hook also sets ``output_attentions=True`` so
            the layer's post-hook can read the attention weights.

    Single-token (decode) calls are passed through untouched: the post-hook
    skips them as well, so requesting attention weights there would only add
    an unused tensor to the layer output, and context stored during prefill
    does not have the length of a one-token input.
    """
    context_keys = ("position_ids", "position_embeddings", "attention_mask")

    def pre_hook(module, args, kwargs):
        # assumes the hidden states are the first positional argument or the
        # "hidden_states" keyword, shaped (batch, seq, dim) — TODO confirm
        # against the decoder layer's forward signature.
        hidden = args[0] if args else kwargs.get("hidden_states")
        is_decode_step = (
            torch.is_tensor(hidden) and hidden.dim() == 3 and hidden.shape[1] <= 1
        )
        if is_decode_step:
            return args, kwargs

        overrides = {}
        for key in context_keys:
            value = shared_state.get(key)
            if value is not None:
                overrides[key] = value

        if not overrides and not is_target:
            return args, kwargs

        # Copy so the caller's kwargs dict is never mutated.
        kwargs = dict(kwargs)
        kwargs.update(overrides)
        if is_target:
            # Request attention weights from this layer so the post-hook can score tokens
            kwargs["output_attentions"] = True
        return args, kwargs

    return pre_hook


def _make_post_hook(shared_state, layer_idx, min_keep, tau, theta):
    """Build the forward hook that prunes visual tokens after a target layer.

    The returned hook treats the first ``shared_state["n_vis"]`` positions of
    the layer output as visual tokens and the rest as text tokens. Using the
    layer's attention weights it scores the visual tokens, keeps the selected
    ones, appends whatever ``recycle_and_cluster`` returns for the dropped
    ones, and writes the new visual-token count plus the pruned position
    context back into ``shared_state``.

    Args:
        shared_state: dict shared by all hooks; ``"n_vis"`` is read and
            updated, and ``"position_ids"``, ``"position_embeddings"`` and
            ``"attention_mask"`` are rewritten with their pruned versions.
        layer_idx: index of the hooked layer (currently unused).
        min_keep: pruning is skipped once ``n_vis <= min_keep``; the value is
            also forwarded to ``compute_prune_counts``.
        tau: forwarded to ``recycle_and_cluster``.
        theta: forwarded to ``recycle_and_cluster``.

    Returns:
        A hook with the ``(module, args, kwargs, output)`` signature of a
        forward hook registered with ``with_kwargs=True``. It returns
        ``output`` unchanged when it cannot or need not prune, otherwise a
        tuple with the pruned hidden states and the attention-weight slot set
        to None.
    """
    def post_hook(module, args, kwargs, output):
        n_vis = shared_state["n_vis"]
        if n_vis <= min_keep:
            return output

        hidden_out = output[0]
        seq_len = hidden_out.shape[1]
        # Skip decode steps (seq_len==1) — only prune during prefill
        if seq_len <= 1:
            return output
        # No text positions after the visual span means there is nothing to
        # rate the visual tokens with, so leave the output alone.
        if n_vis >= seq_len:
            return output

        rest = list(output[1:])

        # Find the 4-D attention weight tensor produced when output_attentions=True
        attn_rest_idx = next(
            (i for i, r in enumerate(rest) if torch.is_tensor(r) and r.dim() == 4),
            None,
        )
        if attn_rest_idx is None:
            return output   # no attn weights → can't score, skip
        attn_weights = rest[attn_rest_idx]

        B = attn_weights.shape[0]
        device = hidden_out.device

        # Text→visual submatrix, averaged over heads: [B, N_text, N_vis]
        A_tv = attn_weights[:, :, n_vis:, :n_vis].mean(dim=1)

        rater_mask   = select_raters(A_tv)
        n_raters     = rater_mask.sum(dim=-1)
        vision_scores, A_rater = score_visual_tokens(A_tv, rater_mask)
        # float32 for rank estimation (bfloat16/fp16 not supported by linalg)
        prune_counts = compute_prune_counts(
            A_rater.float(), n_raters, n_vis, min_keep
        )
        kept_list, deleted_list, deleted_scores_list = \
            get_kept_and_deleted_indices(vision_scores, prune_counts)

        vis_tokens  = hidden_out[:, :n_vis, :]
        text_tokens = hidden_out[:, n_vis:, :]
        new_seqs    = []
        new_n_vis_list = []

        # Rebuild each sample as [kept visual | recycled | text].
        for b in range(B):
            kept = vis_tokens[b, kept_list[b]]
            parts = [kept]
            n_new_vis = kept.shape[0]
            if deleted_list[b].numel() > 0:
                recycled = recycle_and_cluster(
                    vis_tokens[b, deleted_list[b]],
                    deleted_scores_list[b],
                    tau=tau, theta=theta,
                )
                if recycled is not None:
                    parts.append(recycled)
                    n_new_vis += recycled.shape[0]
            parts.append(text_tokens[b])
            new_seqs.append(torch.cat(parts, dim=0))
            new_n_vis_list.append(n_new_vis)

        # Samples may end up with different lengths; right-pad with zeros.
        max_len = max(s.shape[0] for s in new_seqs)
        D = hidden_out.shape[-1]
        padded = torch.zeros(B, max_len, D, device=device, dtype=hidden_out.dtype)
        for b, seq in enumerate(new_seqs):
            padded[b, :seq.shape[0]] = seq

        length_changed = max_len != seq_len
        hidden_out = padded
        shared_state["n_vis"] = min(new_n_vis_list)

        # Build kept-all indices (kept vis + all text)
        # NOTE(review): recycled tokens get no entry here, so whenever
        # recycle_and_cluster returns tokens the pruned context is shorter
        # than hidden_out — needs a decision on which positions they carry.
        n_text  = text_tokens.shape[1]
        kept0   = kept_list[0].to(device)           # batch size 1 in inference
        text_ix = torch.arange(n_vis, n_vis + n_text, device=device)
        kept_all = torch.cat([kept0, text_ix])

        def current_context(key):
            # The shared entries start out as None, so the first time tokens
            # are dropped the context has to come from what this layer was
            # actually called with. While the length is unchanged the caller's
            # own context still fits and nothing is taken over.
            value = shared_state.get(key)
            if value is None and length_changed:
                value = kwargs.get(key)
            return value

        # Prune position_ids along the last (sequence) axis; equals
        # pid[:, kept_all] for 2-D and pid[:, :, kept_all] for 3-D input.
        pid = current_context("position_ids")
        if pid is not None:
            shared_state["position_ids"] = pid[..., kept_all]

        # Prune position_embeddings (cos, sin) along the second-to-last axis;
        # equals cos[:, kept_all, :] for [B, N, D] input and also handles
        # extra leading axes.
        pe = current_context("position_embeddings")
        if pe is not None:
            cos, sin = pe
            shared_state["position_embeddings"] = (
                cos[..., kept_all, :], sin[..., kept_all, :]
            )

        # Prune attention_mask: [B, 1, N, N]
        am = current_context("attention_mask")
        if torch.is_tensor(am) and am.dim() == 4:
            shared_state["attention_mask"] = \
                am[:, :, kept_all, :][:, :, :, kept_all]

        # Blank the attention weights (the caller didn't request them); the
        # slot itself stays so the other outputs keep their positions.
        rest[attn_rest_idx] = None

        return (hidden_out,) + tuple(rest)

    return post_hook


# ── public API ────────────────────────────────────────────────────────────────

def patch_qwen2vl(model, n_vis, target_layers=None,
                  min_keep=32, tau=0.5, theta=0.5):
    """Register the SparseVLM hooks on every decoder layer of *model*.

    Every layer gets a pre-hook; layers listed in ``target_layers`` (or, when
    that is empty or None, the result of ``default_target_layers``) also get a
    post-hook that prunes visual tokens.

    Args:
        model: model whose decoder layers are found via ``_get_layers``.
        n_vis: initial number of visual tokens stored in the shared state.
        target_layers: layer indices that prune.
        min_keep: forwarded to the post-hooks.
        tau: forwarded to the post-hooks.
        theta: forwarded to the post-hooks.

    Returns:
        The state dict shared by all hooks. Its ``"_hooks"`` entry holds the
        handles returned by the registration calls.
    """
    layers = _get_layers(model)
    n_layers = len(layers)
    target_set = set(target_layers or default_target_layers(n_layers))

    handles = []
    shared_state = {
        "n_vis": n_vis,
        "position_ids": None,
        "position_embeddings": None,
        "attention_mask": None,
        "_hooks": handles,
    }

    for layer_idx, layer in enumerate(layers):
        is_target = layer_idx in target_set

        # Pre-hook on every layer: inject context; on target layers also request attn
        pre = _make_pre_hook(shared_state, is_target=is_target)
        handles.append(layer.register_forward_pre_hook(pre, with_kwargs=True))

        if not is_target:
            continue

        post = _make_post_hook(shared_state, layer_idx, min_keep, tau, theta)
        handles.append(layer.register_forward_hook(post, with_kwargs=True))

    print(
        f"[SparseVLM] Registered hooks on {n_layers} layers "
        f"(pre-hook all, post-hook at {sorted(target_set)}). "
        f"n_vis={n_vis}, min_keep={min_keep}."
    )
    return shared_state


def reset_n_vis(shared_state, n_vis):
    """Reset the shared state in place.

    Stores ``n_vis`` as the visual-token count and sets the three position
    context entries to None. Other keys are left untouched.
    """
    shared_state["n_vis"] = n_vis
    for key in ("position_ids", "position_embeddings", "attention_mask"):
        shared_state[key] = None


def unpatch_qwen2vl(model):
    """Print instructions for removing the hooks; *model* is not modified.

    This function only receives the model, not the hook handles, so it
    removes nothing itself and only prints how to remove them through a
    state dict's '_hooks' list.
    """
    hints = (
        "[SparseVLM] unpatch: use the state dict's '_hooks' list to remove hooks.",
        "  Hint: for h in state['_hooks']: h.remove()",
    )
    for line in hints:
        print(line)


def remove_hooks(shared_state):
    """Detach every SparseVLM hook recorded in *shared_state*.

    Calls ``remove()`` on each handle in ``shared_state["_hooks"]`` (if the
    key exists) and then stores a fresh empty list under that key. Use this
    instead of unpatch_qwen2vl.
    """
    handles = shared_state.get("_hooks", [])
    for handle in handles:
        handle.remove()
    shared_state["_hooks"] = []
    print("[SparseVLM] All hooks removed.")