#!/usr/bin/env python3 """Settle first-frame-sink vs middle-hump: per-LAYER frame-attention profile. Same extraction as frame_attention_probe but keeps the layer axis. For each question: attentions[layer][0, heads, -1, vision_tokens] -> mean heads -> reshape [grid_t, spatial] sum spatial -> per-layer per-bin attention, normalized within video. Average over questions -> [n_layers, grid_t] heatmap. """ from __future__ import annotations import json, os, sys, time from pathlib import Path os.environ.setdefault("HF_HOME", "/mnt/local-fast/opd_zt/hf_cache") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") import numpy as np, torch ROOT = Path("/mnt/local-fast/opd_zt") MODEL = str(ROOT / "hf_cache/hub/models--Qwen--Qwen2.5-VL-7B-Instruct/snapshots/" "cc594898137f460bfe9f0759e9844b3ce807cfb5") VH = ROOT / "data/benchmarks/VideoHallucer/temporal" NF = 32; MAXP = 128*28*28; MINP = 16*28*28 SUFFIX = "\nAnswer the question using 'yes' or 'no'." N_SAMPLE = int(sys.argv[1]) if len(sys.argv) > 1 else 60 def decode(path): from decord import VideoReader, cpu vr = VideoReader(path, ctx=cpu(0), num_threads=1); total = len(vr) if total < 1: return None idx = np.linspace(0, total-1, NF).round().astype(int).clip(0, total-1) return vr.get_batch(idx.tolist()).asnumpy() @torch.no_grad() def main(): from PIL import Image from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration proc = AutoProcessor.from_pretrained(MODEL, trust_remote_code=True, max_pixels=MAXP, min_pixels=MINP) vtok = proc.tokenizer.convert_tokens_to_ids("<|video_pad|>") model = Qwen2_5_VLForConditionalGeneration.from_pretrained( MODEL, torch_dtype=torch.bfloat16, attn_implementation="eager", trust_remote_code=True ).to("cuda:0").eval() data = json.loads((VH/"temporal.json").read_text()) items = [] for idx, pair in enumerate(data): for side in ("basic", "hallucination"): q = pair[side] items.append((str(VH/"videos"/q["video"]), q["question"])) items = items[:N_SAMPLE] acc = None; n = 0; gt_ref = 16 t0 = time.time() for vp, q in items: fr = decode(vp) if fr is None: continue pil = [Image.fromarray(f) for f in fr] text = proc.apply_chat_template( [{"role":"user","content":[{"type":"video"},{"type":"text","text":q+SUFFIX}]}], tokenize=False, add_generation_prompt=True) inp = proc(text=[text], videos=[pil], return_tensors="pt") inp = {k:(v.to(model.device) if hasattr(v,"to") else v) for k,v in inp.items()} out = model(**inp, output_attentions=True, use_cache=False) ids = inp["input_ids"][0]; g = inp["video_grid_thw"][0].tolist(); grid_t = g[0] if grid_t != gt_ref: del out; continue vis = (ids == vtok) nv = int(vis.sum()); spatial = nv // grid_t L = len(out.attentions) mat = np.zeros((L, grid_t), dtype=np.float64) for li in range(L): a = out.attentions[li][0, :, -1, :].mean(0).float() # [seq] va = a[vis][:grid_t*spatial].view(grid_t, spatial).sum(1) s = float(va.sum()); va = (va/s) if s > 0 else va mat[li] = va.cpu().numpy() acc = mat if acc is None else acc + mat n += 1 del out if n % 20 == 0: print(f"[run] {n} ({time.time()-t0:.0f}s)", flush=True) prof = acc / n # [L, grid_t] np.save(ROOT/"outputs/frame_attention/layer_profile.npy", prof) L, gt = prof.shape print(f"\nn={n} questions, layers={L}, grid_t={gt}, uniform={1/gt:.3f}") print("per-layer argmax bin (0=first .. 15=last), and shape tag:") first = mid = last = 0 for li in range(L): row = prof[li]; am = int(row.argmax()) tag = "FIRST" if am <= 1 else ("LAST" if am >= gt-2 else "MID") if tag == "FIRST": first += 1 elif tag == "MID": mid += 1 else: last += 1 bar = "".join("#" if row[b] > 1.5/gt else ("+" if row[b] > 1/gt else ".") for b in range(gt)) print(f" L{li:2d} argmax={am:2d} {tag:5s} [{bar}]") print(f"\nlayers peaking FIRST(<=1)={first} MID={mid} LAST(>=14)={last} of {L}") try: import matplotlib; matplotlib.use("Agg"); import matplotlib.pyplot as plt fig, ax = plt.subplots(figsize=(8, 6)) im = ax.imshow(prof, aspect="auto", cmap="magma", origin="lower") ax.set_xlabel("temporal bin (0=start .. 15=end)"); ax.set_ylabel("layer") ax.set_title(f"Per-layer frame-attention (n={n} VideoHallucer temporal Qs)\n" "first-frame sink vs middle hump?") fig.colorbar(im, label="mean attention within video"); fig.tight_layout() p = ROOT/"outputs/frame_attention/fig_layer_profile.png" fig.savefig(p, dpi=130); plt.close(fig); print(f"[fig] wrote {p}") except Exception as e: print("[fig] skip:", e) if __name__ == "__main__": main()