"""
patch.py — SparseVLM for Qwen2-VL and Qwen2.5-VL using PyTorch hooks.
Uses register_forward_hook / register_forward_pre_hook so the original
decoder layers are NEVER replaced — avoiding all module-wrapping issues.
pre-hook (all layers): inject pruned position context from shared_state
post-hook (target layers): prune output tokens, update shared_state
"""
import torch
import torch.nn as nn
from kernels.token_scorer import (
select_raters, score_visual_tokens,
compute_prune_counts, get_kept_and_deleted_indices,
recycle_and_cluster,
)
def default_target_layers(n_layers):
    """Return the default pruning layers: every 4th layer starting at layer 2.

    Args:
        n_layers: total number of decoder layers.

    Returns:
        A list ``[2, 6, 10, ...]`` of indices strictly below ``n_layers``
        (empty when ``n_layers <= 2``).
    """
    return list(range(2, n_layers, 4))
def _get_layers(model):
if hasattr(model, "model") and hasattr(model.model, "layers"):
return model.model.layers
if (hasattr(model, "model") and hasattr(model.model, "language_model")
and hasattr(model.model.language_model, "layers")):
return model.model.language_model.layers
raise ValueError(
f"Cannot find decoder layers in {type(model).__name__}. "
"Tried model.model.layers and model.model.language_model.layers."
)
# ── hook factories ─────────────────────────────────────────────────────────────
def _make_pre_hook(shared_state, is_target=False):
"""
Inject updated position context before each layer.
For target layers, also request attention weights.
"""
def pre_hook(module, args, kwargs):
pid = shared_state.get("position_ids")
pe = shared_state.get("position_embeddings")
am = shared_state.get("attention_mask")
need_update = pid is not None or pe is not None or am is not None or is_target
if not need_update:
return args, kwargs
kwargs = dict(kwargs)
if pid is not None:
kwargs["position_ids"] = pid
if pe is not None:
kwargs["position_embeddings"] = pe
if am is not None:
kwargs["attention_mask"] = am
if is_target:
# Request attention weights from this layer so the post-hook can score tokens
kwargs["output_attentions"] = True
return args, kwargs
return pre_hook
def _make_post_hook(shared_state, layer_idx, min_keep, tau, theta):
"""After target layer: score visual tokens, prune, update context."""
def post_hook(module, args, kwargs, output):
n_vis = shared_state["n_vis"]
if n_vis <= min_keep:
return output
hidden_check = output[0]
# Skip decode steps (seq_len==1) β only prune during prefill
if hidden_check.shape[1] <= 1:
return output
hidden_out = output[0]
rest = list(output[1:])
# Find 4-D attention weight tensor produced when output_attentions=True
attn_weights = None
attn_rest_idx = None
for i, r in enumerate(rest):
if r is not None and torch.is_tensor(r) and r.dim() == 4:
attn_weights = r
attn_rest_idx = i
break
if attn_weights is None:
return output # no attn weights β can't score, skip
B, H, N_total, _ = attn_weights.shape
device = hidden_out.device
# Textβvisual submatrix, averaged over heads: [B, N_text, N_vis]
A_tv = attn_weights[:, :, n_vis:, :n_vis].mean(dim=1)
rater_mask = select_raters(A_tv)
n_raters = rater_mask.sum(dim=-1)
vision_scores, A_rater = score_visual_tokens(A_tv, rater_mask)
# float32 for rank estimation (bfloat16/fp16 not supported by linalg)
prune_counts = compute_prune_counts(
A_rater.float(), n_raters, n_vis, min_keep
)
kept_list, deleted_list, deleted_scores_list = \
get_kept_and_deleted_indices(vision_scores, prune_counts)
vis_tokens = hidden_out[:, :n_vis, :]
text_tokens = hidden_out[:, n_vis:, :]
new_seqs = []
new_n_vis_list = []
for b in range(B):
kept = vis_tokens[b, kept_list[b]]
recycled = None
if deleted_list[b].numel() > 0:
recycled = recycle_and_cluster(
vis_tokens[b, deleted_list[b]],
deleted_scores_list[b],
tau=tau, theta=theta,
)
parts = [kept]
if recycled is not None:
parts.append(recycled)
parts.append(text_tokens[b])
new_seqs.append(torch.cat(parts, dim=0))
new_n_vis_list.append(
kept.shape[0] + (recycled.shape[0] if recycled is not None else 0)
)
max_len = max(s.shape[0] for s in new_seqs)
D = hidden_out.shape[-1]
padded = torch.zeros(B, max_len, D, device=device, dtype=hidden_out.dtype)
for b, seq in enumerate(new_seqs):
padded[b, :seq.shape[0]] = seq
new_n_vis = min(new_n_vis_list)
hidden_out = padded
shared_state["n_vis"] = new_n_vis
# Build kept-all indices (kept vis + all text)
n_text = text_tokens.shape[1]
kept0 = kept_list[0].to(device) # batch size 1 in inference
text_ix = torch.arange(n_vis, n_vis + n_text, device=device)
kept_all = torch.cat([kept0, text_ix])
# Prune position_ids: [B, N] or [B, 3, N]
pid = shared_state.get("position_ids")
if pid is not None:
shared_state["position_ids"] = (
pid[:, kept_all] if pid.dim() == 2 else pid[:, :, kept_all]
)
# Prune position_embeddings: (cos, sin) each [B, N, D]
pe = shared_state.get("position_embeddings")
if pe is not None:
cos, sin = pe
shared_state["position_embeddings"] = (
cos[:, kept_all, :], sin[:, kept_all, :]
)
# Prune attention_mask: [B, 1, N, N]
am = shared_state.get("attention_mask")
if am is not None and am.dim() == 4:
shared_state["attention_mask"] = \
am[:, :, kept_all, :][:, :, :, kept_all]
# Remove attn_weights from output (caller didn't request them)
if attn_rest_idx is not None:
rest[attn_rest_idx] = None
return (hidden_out,) + tuple(rest)
return post_hook
# ── public API ─────────────────────────────────────────────────────────────────
def patch_qwen2vl(model, n_vis, target_layers=None,
                  min_keep=32, tau=0.5, theta=0.5):
    """Install SparseVLM pruning hooks on a Qwen2-VL / Qwen2.5-VL model.

    A forward pre-hook is registered on every decoder layer, plus a forward
    hook on each layer in ``target_layers``. No modules are replaced.

    Args:
        model: model whose decoder layers are located by ``_get_layers``.
        n_vis: number of visual tokens in the prompt.
        target_layers: indices of the layers that prune. ``None`` selects
            ``default_target_layers(n_layers)``. An explicit empty sequence
            registers no pruning hooks.
        min_keep, tau, theta: forwarded to the target-layer post-hooks.

    Returns:
        The shared state dict. The target-layer hooks modify ``"n_vis"`` and
        the stored position context, so call ``reset_n_vis`` on it before
        each new prompt. Its ``"_hooks"`` list holds every registered
        handle; pass the state to ``remove_hooks`` to uninstall.
    """
    layers = _get_layers(model)
    n_layers = len(layers)
    # `is None` (not `or`) so that an explicit empty list disables pruning.
    if target_layers is None:
        target_layers = default_target_layers(n_layers)
    target_set = set(target_layers)
    shared_state = {
        "n_vis": n_vis,
        "position_ids": None,
        "position_embeddings": None,
        "attention_mask": None,
        "_hooks": [],
    }
    for layer_idx, layer in enumerate(layers):
        is_target = layer_idx in target_set
        # Pre-hook on every layer: inject context; on target layers also request attn.
        h_pre = layer.register_forward_pre_hook(
            _make_pre_hook(shared_state, is_target=is_target), with_kwargs=True
        )
        shared_state["_hooks"].append(h_pre)
        if is_target:
            h_post = layer.register_forward_hook(
                _make_post_hook(shared_state, layer_idx, min_keep, tau, theta),
                with_kwargs=True,
            )
            shared_state["_hooks"].append(h_post)
    print(
        f"[SparseVLM] Registered hooks on {n_layers} layers "
        f"(pre-hook all, post-hook at {sorted(target_set)}). "
        f"n_vis={n_vis}, min_keep={min_keep}."
    )
    return shared_state
def reset_n_vis(shared_state, n_vis):
    """Prepare ``shared_state`` for a new prompt with ``n_vis`` visual tokens.

    Restores the visual-token count, which the target-layer hooks lower as
    they prune. Also clears the stored position ids, position embeddings
    and attention mask. Other keys (e.g. ``"_hooks"``) are left untouched.
    """
    shared_state.update(
        n_vis=n_vis,
        position_ids=None,
        position_embeddings=None,
        attention_mask=None,
    )
def unpatch_qwen2vl(model):
    """Print instructions for removing SparseVLM hooks; removes nothing itself.

    The hook handles are kept in the ``"_hooks"`` list of the state dict
    returned by ``patch_qwen2vl``, not on ``model``, so they cannot be found
    from here. Call ``remove_hooks(state)`` instead.
    """
    hint_lines = (
        "[SparseVLM] unpatch: use the state dict's '_hooks' list to remove hooks.",
        " Hint: for h in state['_hooks']: h.remove()",
    )
    for line in hint_lines:
        print(line)
def remove_hooks(shared_state):
    """Remove every hook registered by ``patch_qwen2vl`` (use instead of ``unpatch_qwen2vl``).

    Calls ``.remove()`` on each handle in ``shared_state["_hooks"]`` (a
    missing key is treated as empty). Afterwards the key is set to a fresh
    empty list, so repeated calls are harmless.
    """
    handles = shared_state.get("_hooks", [])
    for handle in handles:
        handle.remove()
    shared_state["_hooks"] = []
    print("[SparseVLM] All hooks removed.")
|