"""Settle first-frame sink vs. middle hump: per-LAYER frame-attention profile.

Same extraction as frame_attention_probe, but keeps the layer axis. For each
question: attentions[layer][0, heads, -1, vision_tokens] -> mean over heads ->
reshape to [grid_t, spatial] and sum over spatial -> per-layer, per-bin
attention, normalized within the video. Averaging over questions yields an
[n_layers, grid_t] heatmap.
"""
| from __future__ import annotations |
| import json, os, sys, time |
| from pathlib import Path |
| os.environ.setdefault("HF_HOME", "/mnt/local-fast/opd_zt/hf_cache") |
| os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") |
| import numpy as np, torch |
|
|
| ROOT = Path("/mnt/local-fast/opd_zt") |
| MODEL = str(ROOT / "hf_cache/hub/models--Qwen--Qwen2.5-VL-7B-Instruct/snapshots/" |
| "cc594898137f460bfe9f0759e9844b3ce807cfb5") |
| VH = ROOT / "data/benchmarks/VideoHallucer/temporal" |
| NF = 32; MAXP = 128*28*28; MINP = 16*28*28 |
| SUFFIX = "\nAnswer the question using 'yes' or 'no'." |
| N_SAMPLE = int(sys.argv[1]) if len(sys.argv) > 1 else 60 |
|
|
|
|
def decode(path):
    """Decode NF frames spread evenly from the first to the last frame.

    Returns None when the reader reports zero frames. Otherwise returns the
    result of decord's ``get_batch(...).asnumpy()`` for the chosen indices
    (presumably a (NF, H, W, 3) uint8 array -- confirm against decord).
    Short videos repeat frames because the rounded indices can collide.
    """
    from decord import VideoReader, cpu

    reader = VideoReader(path, ctx=cpu(0), num_threads=1)
    n_frames = len(reader)
    if n_frames < 1:
        return None
    last = n_frames - 1
    # Evenly spaced positions over [0, last], rounded to the nearest frame.
    picks = np.clip(np.linspace(0, last, NF).round().astype(int), 0, last)
    return reader.get_batch(picks.tolist()).asnumpy()


def _load_items(n_sample):
    """Return the first *n_sample* (video_path, question) pairs of the split.

    Questions are read from VH/temporal.json; every entry contributes its
    "basic" question followed by its "hallucination" question.
    """
    data = json.loads((VH / "temporal.json").read_text())
    items = [
        (str(VH / "videos" / pair[side]["video"]), pair[side]["question"])
        for pair in data
        for side in ("basic", "hallucination")
    ]
    return items[:n_sample]


def _per_layer_bins(attentions, vis, grid_t):
    """Collapse the last token's attention onto temporal bins, one row per layer.

    attentions: sequence of per-layer tensors indexed [batch, head, query, key];
        only batch 0 and the last query position are used, averaged over heads.
    vis: boolean mask over the key axis marking video tokens.
    grid_t: number of temporal bins; the video tokens are viewed as
        [grid_t, spatial] and summed over the spatial axis.
    Returns a float64 array of shape [n_layers, grid_t]. Each row is
    normalized to sum to 1 unless that layer puts zero attention on video
    tokens, in which case it stays all zeros.
    """
    spatial = int(vis.sum()) // grid_t
    keep = grid_t * spatial  # drop any tail that does not fill a whole bin
    mat = np.zeros((len(attentions), grid_t), dtype=np.float64)
    for li, layer_att in enumerate(attentions):
        a = layer_att[0, :, -1, :].mean(0).float()
        bins = a[vis][:keep].view(grid_t, spatial).sum(1)
        total = float(bins.sum())
        if total > 0:
            bins = bins / total
        mat[li] = bins.cpu().numpy()
    return mat


def _print_report(prof, n):
    """Print each layer's peak temporal bin and an ASCII profile bar.

    prof: [n_layers, grid_t] averaged profile; n: number of questions averaged.
    A layer is tagged FIRST if its argmax bin is 0 or 1, LAST if it is one of
    the final two bins, and MID otherwise. Bar glyphs: '#' above 1.5x uniform,
    '+' above uniform, '.' otherwise. Returns the (first, mid, last) counts.
    """
    n_layers, gt = prof.shape
    print(f"\nn={n} questions, layers={n_layers}, grid_t={gt}, uniform={1/gt:.3f}")
    print(f"per-layer argmax bin (0=first .. {gt-1}=last), and shape tag:")
    counts = {"FIRST": 0, "MID": 0, "LAST": 0}
    for li, row in enumerate(prof):
        am = int(row.argmax())
        tag = "FIRST" if am <= 1 else ("LAST" if am >= gt - 2 else "MID")
        counts[tag] += 1
        bar = "".join("#" if v > 1.5 / gt else ("+" if v > 1 / gt else ".") for v in row)
        print(f" L{li:2d} argmax={am:2d} {tag:5s} [{bar}]")
    print(f"\nlayers peaking FIRST(<=1)={counts['FIRST']} MID={counts['MID']} "
          f"LAST(>={gt-2})={counts['LAST']} of {n_layers}")
    return counts["FIRST"], counts["MID"], counts["LAST"]


def _save_figure(prof, n, path):
    """Save a layer x temporal-bin heatmap of *prof* to *path* (best effort)."""
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend; no display needed
        import matplotlib.pyplot as plt
        gt = prof.shape[1]
        fig, ax = plt.subplots(figsize=(8, 6))
        im = ax.imshow(prof, aspect="auto", cmap="magma", origin="lower")
        ax.set_xlabel(f"temporal bin (0=start .. {gt-1}=end)")
        ax.set_ylabel("layer")
        ax.set_title(f"Per-layer frame-attention (n={n} VideoHallucer temporal Qs)\n"
                     "first-frame sink vs middle hump?")
        fig.colorbar(im, label="mean attention within video")
        fig.tight_layout()
        fig.savefig(path, dpi=130)
        plt.close(fig)
        print(f"[fig] wrote {path}")
    except Exception as e:  # plotting is optional; main() saves the .npy first
        print("[fig] skip:", e)


@torch.no_grad()
def main():
    """Build, save and summarize the per-layer temporal attention profile.

    For each sampled VideoHallucer temporal question whose video yields
    exactly ``gt_ref`` temporal bins, run one forward pass with attentions
    enabled and collapse the last token's attention onto those bins for every
    layer. Average over questions, save the [n_layers, grid_t] profile as
    layer_profile.npy, print a per-layer summary and write a heatmap (best
    effort). If no question qualifies, report it and save nothing.
    """
    from PIL import Image
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    out_dir = ROOT / "outputs/frame_attention"
    # Create the output directory before the long model run so the final
    # np.save / savefig cannot fail on a missing path.
    out_dir.mkdir(parents=True, exist_ok=True)

    proc = AutoProcessor.from_pretrained(
        MODEL, trust_remote_code=True, max_pixels=MAXP, min_pixels=MINP)
    vtok = proc.tokenizer.convert_tokens_to_ids("<|video_pad|>")
    # Eager attention -- presumably required so output_attentions=True below
    # actually returns attention weights.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL, torch_dtype=torch.bfloat16, attn_implementation="eager",
        trust_remote_code=True,
    ).to("cuda:0").eval()

    items = _load_items(N_SAMPLE)

    # Only videos with exactly gt_ref temporal bins are averaged, so every
    # profile shares one x-axis. 16 presumably = NF (32) frames merged in
    # pairs by the vision encoder -- confirm against the processor config.
    gt_ref = 16
    acc = None
    n = 0
    skipped = 0
    t0 = time.time()
    for vp, q in items:
        fr = decode(vp)
        if fr is None:
            skipped += 1
            continue
        pil = [Image.fromarray(f) for f in fr]
        text = proc.apply_chat_template(
            [{"role": "user", "content": [{"type": "video"},
                                          {"type": "text", "text": q + SUFFIX}]}],
            tokenize=False, add_generation_prompt=True)
        inp = proc(text=[text], videos=[pil], return_tensors="pt")
        grid_t = inp["video_grid_thw"][0].tolist()[0]
        if grid_t != gt_ref:
            # The grid comes from the processor, so skip before paying for a
            # full forward pass with attentions enabled.
            skipped += 1
            continue
        inp = {k: (v.to(model.device) if hasattr(v, "to") else v) for k, v in inp.items()}
        out = model(**inp, output_attentions=True, use_cache=False)
        vis = inp["input_ids"][0] == vtok
        mat = _per_layer_bins(out.attentions, vis, grid_t)
        del out  # free the per-layer attention maps before the next video
        acc = mat if acc is None else acc + mat
        n += 1
        if n % 20 == 0:
            print(f"[run] {n} ({time.time()-t0:.0f}s)", flush=True)

    print(f"[run] used {n}, skipped {skipped} (undecodable or grid_t != {gt_ref})",
          flush=True)
    if n == 0:
        # acc is still None here; there is nothing to average or save.
        print("[run] no usable questions; nothing saved", flush=True)
        return

    prof = acc / n
    np.save(out_dir / "layer_profile.npy", prof)
    _print_report(prof, n)
    _save_figure(prof, n, out_dir / "fig_layer_profile.png")


if __name__ == "__main__":
    # Run the probe only when executed as a script, never on import.
    main()