"""
patch.py — SparseVLM for Qwen2-VL and Qwen2.5-VL using PyTorch hooks.

Uses register_forward_hook / register_forward_pre_hook so the original
decoder layers are NEVER replaced — avoiding all module-wrapping issues.

  pre-hook  (all layers):    inject pruned position context from shared_state
  post-hook (target layers): prune output tokens, update shared_state

The "position context" is the set of per-token keyword arguments a decoder
layer receives (``position_ids``, ``position_embeddings``, ``attention_mask``).
After a target layer shortens the hidden sequence, the same selection is
applied to that context and stored in ``shared_state`` so that every later
layer sees tensors of matching length.
"""

import torch
import torch.nn as nn

from kernels.token_scorer import (
    select_raters,
    score_visual_tokens,
    compute_prune_counts,
    get_kept_and_deleted_indices,
    recycle_and_cluster,
)

# Keys of shared_state that hold per-token context forwarded to the layers.
_CONTEXT_KEYS = ("position_ids", "position_embeddings", "attention_mask")


def default_target_layers(n_layers):
    """Return the default pruning layers: every 4th layer index starting at 2."""
    return list(range(2, n_layers, 4))


def _get_layers(model):
    """Locate the decoder layer container of *model*.

    Tries ``model.model.layers`` first, then
    ``model.model.language_model.layers``.

    Raises:
        ValueError: if neither attribute path exists.
    """
    inner = getattr(model, "model", None)
    if inner is not None and hasattr(inner, "layers"):
        return inner.layers
    language_model = getattr(inner, "language_model", None) if inner is not None else None
    if language_model is not None and hasattr(language_model, "layers"):
        return language_model.layers
    raise ValueError(
        f"Cannot find decoder layers in {type(model).__name__}. "
        "Tried model.model.layers and model.model.language_model.layers."
    )


# ── context helpers ──────────────────────────────────────────────────────────

def _incoming_seq_len(args, kwargs):
    """Return the sequence length of the hidden states entering a layer.

    The hidden states are taken from the first positional argument, or from
    the ``hidden_states`` keyword; dimension 1 is read as the sequence axis
    (the same convention the post-hook uses). Returns None when they cannot
    be found.
    """
    hidden = args[0] if args else kwargs.get("hidden_states")
    if torch.is_tensor(hidden) and hidden.dim() >= 2:
        return hidden.shape[1]
    return None


def _sequence_length(key, value):
    """Return the sequence length encoded in one context entry, or None."""
    if key == "position_embeddings":
        first = value[0] if isinstance(value, (tuple, list)) and value else None
        if torch.is_tensor(first) and first.dim() >= 2:
            return first.shape[-2]
        return None
    if not torch.is_tensor(value) or value.dim() == 0:
        return None
    if key == "attention_mask" and value.dim() == 4:
        # [B, 1, N_query, N_key]: the query axis matches the hidden states.
        return value.shape[-2]
    return value.shape[-1]


def _find_attention_weights(rest):
    """Return the index of the first 4-D tensor in *rest*, or None."""
    for i, item in enumerate(rest):
        if torch.is_tensor(item) and item.dim() == 4:
            return i
    return None


def _prune_context(kwargs, shared_state, keep_ix):
    """Apply the token selection *keep_ix* to the layer's context.

    The context is read from *kwargs* — the keyword arguments the layer was
    actually called with, i.e. the model's own tensors or the ones injected by
    the pre-hook — and the pruned result is written to *shared_state*. Entries
    the layer did not receive are stored as None so no stale value survives.
    """
    # position_ids: the sequence axis is last for every layout ([B, N],
    # [B, 3, N], [3, B, N]), so index relative to the end.
    pid = kwargs.get("position_ids")
    if torch.is_tensor(pid):
        pid = pid[..., keep_ix.to(pid.device)]
    shared_state["position_ids"] = pid

    # position_embeddings: (cos, sin); the sequence axis is second to last,
    # which holds for both [B, N, D] and rank-4 tables.
    pe = kwargs.get("position_embeddings")
    if pe is not None:
        cos, sin = pe
        pe = (
            cos[..., keep_ix.to(cos.device), :],
            sin[..., keep_ix.to(sin.device), :],
        )
    shared_state["position_embeddings"] = pe

    # attention_mask: [B, 1, N, N] is pruned on rows and columns, a 2-D
    # [B, N] mask on its last axis; other ranks are passed through as-is.
    am = kwargs.get("attention_mask")
    if torch.is_tensor(am):
        ix = keep_ix.to(am.device)
        if am.dim() == 4:
            am = am[..., ix, :][..., ix]
        elif am.dim() == 2:
            am = am[..., ix]
    shared_state["attention_mask"] = am


# ── hook factories ────────────────────────────────────────────────────────────

def _make_pre_hook(shared_state, is_target=False):
    """
    Inject updated position context before each layer.
    For target layers, also request attention weights.

    A stored context entry is injected only when its sequence length matches
    the hidden states entering the layer. This keeps context pruned during one
    forward pass (e.g. prefill) from being replayed into a later pass of a
    different length (e.g. a single-token decode step).
    """
    def pre_hook(module, args, kwargs):
        incoming_len = _incoming_seq_len(args, kwargs)
        updates = {}
        for key in _CONTEXT_KEYS:
            value = shared_state.get(key)
            if value is None:
                continue
            stored_len = _sequence_length(key, value)
            if (
                incoming_len is not None
                and stored_len is not None
                and stored_len != incoming_len
            ):
                continue  # stale entry from a pass with another length
            updates[key] = value
        if is_target:
            # Request attention weights from this layer so the post-hook can
            # score tokens. If the layer returns none, the post-hook skips.
            updates["output_attentions"] = True
        if not updates:
            return args, kwargs
        return args, {**kwargs, **updates}

    return pre_hook


def _make_post_hook(shared_state, layer_idx, min_keep, tau, theta):
    """After target layer: score visual tokens, prune, update context.

    Args:
        shared_state: dict shared by all hooks; ``n_vis`` is read and updated,
            the context entries are rewritten after pruning.
        layer_idx: index of the hooked layer (not used by the hook itself;
            kept so the factory signature stays stable).
        min_keep: pruning is skipped once ``n_vis`` is at or below this value;
            it is also forwarded to ``compute_prune_counts``.
        tau, theta: forwarded to ``recycle_and_cluster``.

    The first ``n_vis`` positions of the sequence are treated as visual tokens
    and everything after them as text tokens.
    """
    def post_hook(module, args, kwargs, output):
        n_vis = shared_state["n_vis"]
        if n_vis <= min_keep:
            return output
        # Layers that return a bare tensor expose no attention weights.
        if not isinstance(output, (tuple, list)):
            return output

        hidden_out = output[0]
        # Skip decode steps (seq_len==1) — only prune during prefill
        if hidden_out.shape[1] <= 1:
            return output

        rest = list(output[1:])
        # Find 4-D attention weight tensor produced when output_attentions=True
        attn_rest_idx = _find_attention_weights(rest)
        if attn_rest_idx is None:
            return output  # no attn weights → can't score, skip
        attn_weights = rest[attn_rest_idx]

        batch_size = attn_weights.shape[0]
        device = hidden_out.device

        # Text→visual submatrix, averaged over heads: [B, N_text, N_vis]
        A_tv = attn_weights[:, :, n_vis:, :n_vis].mean(dim=1)
        rater_mask = select_raters(A_tv)
        n_raters = rater_mask.sum(dim=-1)
        vision_scores, A_rater = score_visual_tokens(A_tv, rater_mask)

        # float32 for rank estimation (bfloat16/fp16 not supported by linalg)
        prune_counts = compute_prune_counts(
            A_rater.float(), n_raters, n_vis, min_keep
        )
        kept_list, deleted_list, deleted_scores_list = \
            get_kept_and_deleted_indices(vision_scores, prune_counts)

        vis_tokens = hidden_out[:, :n_vis, :]
        text_tokens = hidden_out[:, n_vis:, :]

        # Rebuild every sample as [kept visual | recycled clusters | text].
        new_seqs = []
        new_n_vis_list = []
        n_recycled_list = []
        for b in range(batch_size):
            kept = vis_tokens[b, kept_list[b]]
            parts = [kept]
            n_recycled = 0
            if deleted_list[b].numel() > 0:
                recycled = recycle_and_cluster(
                    vis_tokens[b, deleted_list[b]],
                    deleted_scores_list[b],
                    tau=tau, theta=theta,
                )
                if recycled is not None:
                    parts.append(recycled)
                    n_recycled = recycled.shape[0]
            parts.append(text_tokens[b])
            new_seqs.append(torch.cat(parts, dim=0))
            new_n_vis_list.append(kept.shape[0] + n_recycled)
            n_recycled_list.append(n_recycled)

        # Zero-pad to the longest sample so the batch stays rectangular.
        max_len = max(seq.shape[0] for seq in new_seqs)
        padded = torch.zeros(
            batch_size, max_len, hidden_out.shape[-1],
            device=device, dtype=hidden_out.dtype,
        )
        for b, seq in enumerate(new_seqs):
            padded[b, :seq.shape[0]] = seq

        shared_state["n_vis"] = min(new_n_vis_list)

        # Build the context index in the same order as the hidden sequence.
        # It is derived from sample 0 only (batch size 1 in inference).
        n_text = text_tokens.shape[1]
        index_parts = [kept_list[0].to(device)]
        n_recycled0 = n_recycled_list[0]
        if n_recycled0 > 0:
            # Recycled tokens have no position of their own; each borrows the
            # context slot of a deleted token so the context length matches
            # the hidden states. The first R deleted indices are used, and the
            # last one is repeated should there be more clusters than deleted
            # tokens.
            deleted0 = deleted_list[0].to(device)
            slots = torch.arange(n_recycled0, device=device).clamp(
                max=deleted0.numel() - 1
            )
            index_parts.append(deleted0[slots])
        index_parts.append(torch.arange(n_vis, n_vis + n_text, device=device))
        kept_all = torch.cat(index_parts)

        _prune_context(kwargs, shared_state, kept_all)

        # Blank the attn_weights slot in the output (caller didn't request them)
        rest[attn_rest_idx] = None
        return (padded,) + tuple(rest)

    return post_hook


# ── public API ────────────────────────────────────────────────────────────────

def patch_qwen2vl(model, n_vis, target_layers=None,
                  min_keep=32, tau=0.5, theta=0.5):
    """Register SparseVLM hooks on every decoder layer of *model*.

    Args:
        model: model whose decoder layers are found by ``_get_layers``.
        n_vis: number of visual tokens at the start of the sequence.
        target_layers: layer indices that prune; None selects
            ``default_target_layers``. An explicit empty list registers no
            pruning hooks.
        min_keep: lower bound on visual tokens, see ``_make_post_hook``.
        tau, theta: forwarded to ``recycle_and_cluster``.

    Returns:
        The shared state dict. Its ``"_hooks"`` list holds the hook handles;
        pass the dict to ``remove_hooks`` to undo the patch.
    """
    layers = _get_layers(model)
    n_layers = len(layers)
    if target_layers is None:
        target_layers = default_target_layers(n_layers)
    target_set = set(target_layers)

    shared_state = {
        "n_vis": n_vis,
        "position_ids": None,
        "position_embeddings": None,
        "attention_mask": None,
        "_hooks": [],
    }
    handles = shared_state["_hooks"]

    for layer_idx, layer in enumerate(layers):
        is_target = layer_idx in target_set

        # Pre-hook on every layer: inject context; on target layers also request attn
        handles.append(
            layer.register_forward_pre_hook(
                _make_pre_hook(shared_state, is_target=is_target),
                with_kwargs=True,
            )
        )
        if is_target:
            handles.append(
                layer.register_forward_hook(
                    _make_post_hook(shared_state, layer_idx, min_keep, tau, theta),
                    with_kwargs=True,
                )
            )

    print(
        f"[SparseVLM] Registered hooks on {n_layers} layers "
        f"(pre-hook all, post-hook at {sorted(target_set)}). "
        f"n_vis={n_vis}, min_keep={min_keep}."
    )
    return shared_state


def reset_n_vis(shared_state, n_vis):
    """Reset the visual-token count and clear all stored position context."""
    shared_state["n_vis"] = n_vis
    for key in _CONTEXT_KEYS:
        shared_state[key] = None


def unpatch_qwen2vl(model):
    """Print a hint on how to remove the hooks; removes nothing itself.

    The hook handles live in the state dict returned by ``patch_qwen2vl``,
    not on *model*, so use ``remove_hooks(state)`` instead.
    """
    print("[SparseVLM] unpatch: use the state dict's '_hooks' list to remove hooks.")
    print("  Hint: for h in state['_hooks']: h.remove()")


def remove_hooks(shared_state):
    """Remove all SparseVLM hooks. Call this instead of unpatch_qwen2vl."""
    for handle in shared_state.get("_hooks", []):
        handle.remove()
    shared_state["_hooks"] = []
    print("[SparseVLM] All hooks removed.")