# Source: opd_zt/scripts/layer_resolved_attention.py (uploaded by sdzt,
# commit bf46e5d "Add files using upload-large-folder tool", 4.98 kB).
#!/usr/bin/env python3
"""Settle first-frame-sink vs middle-hump: per-LAYER frame-attention profile.
Same extraction as frame_attention_probe but keeps the layer axis. For each
question: attentions[layer][0, heads, -1, vision_tokens] -> mean heads ->
reshape [grid_t, spatial] sum spatial -> per-layer per-bin attention, normalized
within video. Average over questions -> [n_layers, grid_t] heatmap.
"""
from __future__ import annotations
import json, os, sys, time
from pathlib import Path
os.environ.setdefault("HF_HOME", "/mnt/local-fast/opd_zt/hf_cache")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
import numpy as np, torch
ROOT = Path("/mnt/local-fast/opd_zt")
MODEL = str(ROOT / "hf_cache/hub/models--Qwen--Qwen2.5-VL-7B-Instruct/snapshots/"
"cc594898137f460bfe9f0759e9844b3ce807cfb5")
VH = ROOT / "data/benchmarks/VideoHallucer/temporal"
NF = 32; MAXP = 128*28*28; MINP = 16*28*28
SUFFIX = "\nAnswer the question using 'yes' or 'no'."
N_SAMPLE = int(sys.argv[1]) if len(sys.argv) > 1 else 60
def decode(path):
    """Sample ``NF`` evenly spaced frames from the video file at *path*.

    Frame indices come from ``np.linspace`` over ``[0, len(video) - 1]``,
    rounded to the nearest integer. A video shorter than ``NF`` frames
    therefore yields repeated indices rather than fewer frames.

    Returns:
        The decoded batch as a NumPy array (decord's ``asnumpy()``), or
        ``None`` when the reader reports zero frames. Presumably shaped
        ``(NF, H, W, 3)``; TODO confirm against decord.
    """
    # Local import: the module stays importable where decord is not installed.
    from decord import VideoReader, cpu

    reader = VideoReader(path, ctx=cpu(0), num_threads=1)
    n_frames = len(reader)
    if n_frames >= 1:
        last = n_frames - 1
        picks = np.linspace(0, last, NF).round().astype(int).clip(0, last)
        return reader.get_batch(picks.tolist()).asnumpy()
    return None
def _load_items():
    """Return ``(video_path, question)`` pairs from VideoHallucer's temporal split.

    Reads ``VH / "temporal.json"``. Every entry contributes its ``"basic"``
    question followed by its ``"hallucination"`` question. Each question is
    paired with the path ``VH / "videos" / entry[side]["video"]``.
    """
    data = json.loads((VH / "temporal.json").read_text(encoding="utf-8"))
    return [
        (str(VH / "videos" / pair[side]["video"]), pair[side]["question"])
        for pair in data
        for side in ("basic", "hallucination")
    ]


def _layer_profile(attentions, vis, grid_t):
    """Return a ``[n_layers, grid_t]`` float64 array of per-bin attention.

    For each layer the steps are:
      1. Take the attention row of the final query position (``[0, :, -1, :]``).
      2. Average it over heads.
      3. Keep the entries selected by the boolean mask *vis*.
      4. Reshape them to ``[grid_t, spatial]`` and sum over ``spatial``.
      5. Normalize the row to sum to 1. A row with zero total stays all zeros.

    Assumes each layer tensor is ``[batch, heads, q_len, k_len]``. Also assumes
    video tokens are laid out bin-major, so each consecutive run of ``spatial``
    tokens belongs to one temporal bin. NOTE(review): confirm against the
    Qwen2.5-VL processor.
    """
    spatial = int(vis.sum()) // grid_t
    keep = grid_t * spatial  # drops remainder tokens if the count isn't divisible
    mat = np.zeros((len(attentions), grid_t), dtype=np.float64)
    for li, layer_attn in enumerate(attentions):
        # The model runs in bfloat16; upcast before averaging heads so the
        # head mean is computed and kept in float32.
        row = layer_attn[0, :, -1, :].float().mean(0)  # [seq]
        per_bin = row[vis][:keep].view(grid_t, spatial).sum(1)
        total = float(per_bin.sum())
        if total > 0:
            per_bin = per_bin / total
        mat[li] = per_bin.cpu().numpy()
    return mat


def _print_report(prof):
    """Print, per layer, the peak temporal bin, a FIRST/MID/LAST tag and an ASCII bar.

    Bar legend: ``#`` is above 1.5x uniform, ``+`` is above uniform, and ``.``
    is otherwise, where uniform is ``1 / grid_t``. FIRST means the peak is in
    bin 0 or 1. LAST means the peak is in one of the final two bins. MID is
    anything in between.
    """
    n_layers, gt = prof.shape
    print(f"per-layer argmax bin (0=first .. {gt - 1}=last), and shape tag:")
    counts = {"FIRST": 0, "MID": 0, "LAST": 0}
    for li, row in enumerate(prof):
        am = int(row.argmax())
        tag = "FIRST" if am <= 1 else ("LAST" if am >= gt - 2 else "MID")
        counts[tag] += 1
        bar = "".join("#" if v > 1.5 / gt else ("+" if v > 1 / gt else ".") for v in row)
        print(f" L{li:2d} argmax={am:2d} {tag:5s} [{bar}]")
    print(
        f"\nlayers peaking FIRST(<=1)={counts['FIRST']} MID={counts['MID']} "
        f"LAST(>={gt - 2})={counts['LAST']} of {n_layers}"
    )


def _save_figure(prof, n, path):
    """Best-effort: save a layer x temporal-bin heatmap of *prof* to *path*.

    Any failure (e.g. matplotlib not installed) is printed and skipped, so a
    plotting problem never aborts the run.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend; no display required
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(8, 6))
        im = ax.imshow(prof, aspect="auto", cmap="magma", origin="lower")
        ax.set_xlabel(f"temporal bin (0=start .. {prof.shape[1] - 1}=end)")
        ax.set_ylabel("layer")
        ax.set_title(
            f"Per-layer frame-attention (n={n} VideoHallucer temporal Qs)\n"
            "first-frame sink vs middle hump?"
        )
        fig.colorbar(im, label="mean attention within video")
        fig.tight_layout()
        fig.savefig(path, dpi=130)
        plt.close(fig)
        print(f"[fig] wrote {path}")
    except Exception as e:  # deliberate best-effort: report and move on
        print("[fig] skip:", e)


@torch.no_grad()
def main():
    """Compute, save, print and plot the per-layer frame-attention profile.

    Steps:
      1. Load Qwen2.5-VL-7B-Instruct from ``MODEL`` in bf16 on ``cuda:0``.
         It uses eager attention, presumably so ``output_attentions`` returns
         the weights.
      2. Run the first ``N_SAMPLE`` VideoHallucer temporal questions.
      3. For each question whose video gives ``grid_t == 16`` temporal bins,
         accumulate a ``[n_layers, grid_t]`` profile (see ``_layer_profile``).
      4. Save the mean over those questions to
         ``ROOT/outputs/frame_attention/layer_profile.npy``.
      5. Summarize it on stdout and plot it to ``fig_layer_profile.png`` in
         the same directory.

    If no question qualifies, print a message and return without writing
    anything.
    """
    from PIL import Image
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    proc = AutoProcessor.from_pretrained(
        MODEL, trust_remote_code=True, max_pixels=MAXP, min_pixels=MINP
    )
    vtok = proc.tokenizer.convert_tokens_to_ids("<|video_pad|>")
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL, torch_dtype=torch.bfloat16, attn_implementation="eager", trust_remote_code=True
    ).to("cuda:0").eval()

    items = _load_items()[:N_SAMPLE]
    # Only videos with exactly this many temporal bins are averaged, so every
    # accumulated profile has the same width. Presumably NF / 2 given
    # Qwen2.5-VL's temporal patching. NOTE(review): confirm.
    gt_ref = 16
    acc = None
    n = 0
    t0 = time.time()
    for vp, q in items:
        fr = decode(vp)
        if fr is None:
            continue
        pil = [Image.fromarray(f) for f in fr]
        text = proc.apply_chat_template(
            [{"role": "user", "content": [{"type": "video"}, {"type": "text", "text": q + SUFFIX}]}],
            tokenize=False, add_generation_prompt=True)
        inp = proc(text=[text], videos=[pil], return_tensors="pt")
        grid_t = int(inp["video_grid_thw"][0][0])
        if grid_t != gt_ref:
            # Check before the forward pass: the output_attentions=True pass
            # is the expensive step, and its result would be discarded anyway.
            continue
        inp = {k: (v.to(model.device) if hasattr(v, "to") else v) for k, v in inp.items()}
        out = model(**inp, output_attentions=True, use_cache=False)
        vis = inp["input_ids"][0] == vtok
        mat = _layer_profile(out.attentions, vis, grid_t)
        del out  # release the per-layer attention tensors before the next question
        acc = mat if acc is None else acc + mat
        n += 1
        if n % 20 == 0:
            print(f"[run] {n} ({time.time()-t0:.0f}s)", flush=True)

    if n == 0:
        # acc is still None here; dividing would raise TypeError.
        print(f"[run] no usable questions (grid_t == {gt_ref}); nothing saved", flush=True)
        return

    prof = acc / n  # [n_layers, grid_t]
    out_dir = ROOT / "outputs/frame_attention"
    out_dir.mkdir(parents=True, exist_ok=True)  # don't lose a long run to a missing dir
    np.save(out_dir / "layer_profile.npy", prof)
    n_layers, gt = prof.shape
    print(f"\nn={n} questions, layers={n_layers}, grid_t={gt}, uniform={1/gt:.3f}")
    _print_report(prof)
    _save_figure(prof, n, out_dir / "fig_layer_profile.png")


if __name__ == "__main__":
    main()