"""Sanity-check the stage1 ForgeryHead on a sample of train videos.

For each sampled video we:
  - load the cached video_inputs / video_kwargs
  - run model.visual(...) -> visual features
  - run model.forgery_head(...) -> per-second logits, sigmoid -> scores
  - compare against GT segments (per-second binary labels)

Aggregate stats reported:
  - global AUC across all per-second labels
  - mean head score inside vs outside GT
  - distribution of (in - out) gap per video
  - per-generator breakdown
"""

import glob
import json
import os
import random
import sys
import time

import numpy as np
import safetensors.torch as st
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from src.open_r1.data_loader import (  # noqa: E402
    GENERATOR_TO_DIR,
    TRAIN_GENERATORS,
    build_examples,
)
from src.open_r1.forgery_head import (  # noqa: E402
    ForgeryHead,
    frame_labels_from_segments,
    head_auc as _head_auc,
)

CKPT = "/mnt/local-fast/zhangt/forensics_grpo/outputs_forensics/stage1_forgery"
ANNOT = "/mnt/local-fast/zhangt/annot/annot"
VROOT = "/mnt/local-fast/zhangt/video"
CACHE = "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0"
N_SAMPLES = 250
SEED = 42
FPS_TO_GROUPS = 1.0
DEVICE = "cuda:0"

_HEAD_PREFIX = "forgery_head."


def _load_model_and_head():
    """Load the stage1 checkpoint and attach a ForgeryHead with its trained weights.

    ``Qwen2_5_VLForConditionalGeneration.from_pretrained`` silently drops the
    ``forgery_head.*`` keys, so the head weights are read directly from the
    checkpoint's safetensors file(s). Both sharded (``model-*.safetensors``)
    and single-file (``model.safetensors``) checkpoints are handled.

    Returns:
        (model, head): ``model`` in eval mode on DEVICE with
        ``model.forgery_head`` set to ``head``.

    Raises:
        RuntimeError: if no ``forgery_head.*`` tensors exist in the checkpoint
            (or, via ``load_state_dict(strict=True)``, if keys mismatch).
    """
    print(f"Loading model from {CKPT} ...", flush=True)
    t0 = time.time()
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        CKPT,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )
    model.eval()
    print(f" loaded in {time.time()-t0:.1f}s. param dtype={next(model.parameters()).dtype}", flush=True)

    # Same hidden_dim / mlp_dim as used at train time.
    head = ForgeryHead(hidden_dim=model.config.hidden_size, mlp_dim=1024)
    head.to(dtype=torch.bfloat16)

    head_sd = {}
    # "model*.safetensors" matches both shards and an unsharded model.safetensors.
    for shard in sorted(glob.glob(os.path.join(CKPT, "model*.safetensors"))):
        with st.safe_open(shard, framework="pt") as f:
            for k in f.keys():
                if k.startswith(_HEAD_PREFIX):
                    # Strip only the leading prefix (str.replace would strip every occurrence).
                    head_sd[k[len(_HEAD_PREFIX):]] = f.get_tensor(k)
    print(f" head_sd keys collected: {list(head_sd.keys())}", flush=True)
    if not head_sd:
        raise RuntimeError(
            f"no '{_HEAD_PREFIX}*' tensors found in {CKPT}/model*.safetensors"
        )
    res = head.load_state_dict(head_sd, strict=True)
    print(f" head loaded: {res}", flush=True)

    model.forgery_head = head
    model = model.to(DEVICE)
    head = head.to(DEVICE)
    return model, head


def _sample_examples():
    """Build the train example list and return a shuffled sample of up to N_SAMPLES."""
    print("Building examples ...", flush=True)
    examples = build_examples(
        annot_dir=ANNOT,
        video_root=VROOT,
        generators=TRAIN_GENERATORS,
        split_prefix="train",
        preprocessed_data_path=CACHE,
        require_video_exists=True,
    )
    print(f" {len(examples)} train examples", flush=True)
    random.shuffle(examples)
    examples = examples[:N_SAMPLES]
    print(f" sampling {len(examples)}", flush=True)
    return examples


def _load_cached(ex):
    """Return ``(sample_id, video_inputs, video_kwargs)`` from CACHE, or None if not cached."""
    sample_id = os.path.splitext(os.path.basename(ex["video_path"]))[0]
    cache_dir = os.path.join(CACHE, "train", ex["generator"], sample_id)
    vi_path = os.path.join(cache_dir, "video_inputs.pt")
    if not os.path.exists(vi_path):
        return None
    # NOTE: weights_only=False unpickles arbitrary objects; acceptable only
    # because CACHE is a locally produced, trusted cache.
    video_inputs = torch.load(vi_path, weights_only=False)
    with open(os.path.join(cache_dir, "video_kwargs.json"), "r") as f:
        video_kwargs = json.load(f)
    return sample_id, video_inputs, video_kwargs


def _head_logits(model, head, packed):
    """Run model.visual + head on processor output; return the first video's logits as float32 on CPU."""
    pv = packed["pixel_values_videos"].to(DEVICE, dtype=torch.bfloat16)
    grid = packed["video_grid_thw"].to(DEVICE)
    with torch.no_grad():
        visual = model.visual(pv, grid_thw=grid)  # (N_tot, hidden)
        logits_list = head(visual, grid)  # list of per-video logits
    return logits_list[0].float().cpu()


def _fmt_mean(values):
    """Format the mean of ``values`` with 3 decimals, or 'n/a' when empty (avoids nan + RuntimeWarning)."""
    return f"{np.mean(values):.3f}" if values else "n/a"


def _stable_sigmoid(x):
    """Numerically stable sigmoid (no overflow warnings for large-magnitude inputs)."""
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def _mann_whitney_auc(pos, neg):
    """Exact ROC AUC = P(pos > neg) + 0.5 * P(pos == neg), via the Mann-Whitney rank sum.

    Ties receive their average rank. Runs in O(n log n) memory-light time.
    Returns nan when either side is empty.
    """
    pos = np.asarray(pos, dtype=np.float64).ravel()
    neg = np.asarray(neg, dtype=np.float64).ravel()
    n_pos, n_neg = pos.size, neg.size
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    _, inv, counts = np.unique(
        np.concatenate([pos, neg]), return_inverse=True, return_counts=True
    )
    ends = np.cumsum(counts)  # 1-based rank of the last member of each tie group
    avg_rank = ends - (counts - 1) / 2.0
    ranks = avg_rank[inv.ravel()]
    u = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u / (n_pos * n_neg))


def _report(examples, failures, all_logits, all_labels, per_video_in_minus_out, per_gen):
    """Print the aggregate sanity report and a recommendation."""
    print("\n========== HEAD SANITY REPORT ==========")
    print(f"sampled : {len(examples)} (failures: {failures})")
    print(f"video count w/ both pos+neg seconds: {len(per_video_in_minus_out)}")

    if all_logits:
        L = np.concatenate(all_logits)
        Y = np.concatenate(all_labels)
        S = _stable_sigmoid(L)
        print(f"total per-second labels: {len(L)} ({int(Y.sum())} positive, {int((1-Y).sum())} negative)")
        pos_mask = Y > 0.5
        neg_mask = Y < 0.5
        if pos_mask.any() and neg_mask.any():
            pos_mean = S[pos_mask].mean()
            neg_mean = S[neg_mask].mean()
            print(f"global mean score : POS={pos_mean:.3f} NEG={neg_mean:.3f} gap={pos_mean-neg_mean:+.3f}")
            # AUC is invariant to the monotone sigmoid; ranking raw logits avoids
            # artificial ties where the sigmoid saturates in floating point.
            auc = _mann_whitney_auc(L[pos_mask], L[neg_mask])
            print(f"global AUC (exact, rank-based): {auc:.3f}")
        else:
            print("global mean score / AUC: n/a (need both positive and negative seconds)")

    if per_video_in_minus_out:
        arr = np.array(per_video_in_minus_out)
        print(f"\nper-video (in_mean - out_mean) over {len(arr)} videos:")
        for q in [0, 10, 25, 50, 75, 90, 100]:
            print(f" p{q:3d} = {np.percentile(arr, q):+.3f}")
        print(f" mean = {arr.mean():+.3f} std = {arr.std():.3f}")
        frac_useful = float((arr > 0.05).mean())
        print(f" fraction of videos with gap > 0.05 : {frac_useful:.2%}")
        frac_strong = float((arr > 0.15).mean())
        print(f" fraction of videos with gap > 0.15 : {frac_strong:.2%}")

    if per_gen:
        print("\nper-generator mean scores:")
        print(f" {'gen':<12} {'n':>4} {'pos':>6} {'neg':>6} {'gap':>6}")
        for g in sorted(per_gen):
            pairs = per_gen[g]
            mp = np.mean([p[0] for p in pairs])
            mn = np.mean([p[1] for p in pairs])
            print(f" {g:<12} {len(pairs):>4} {mp:>6.3f} {mn:>6.3f} {mp-mn:>+6.3f}")

    print("\nrecommendation:")
    if not per_video_in_minus_out:
        print(" ! degenerate (no videos with both pos+neg seconds) - cannot judge")
        return
    g = float(np.array(per_video_in_minus_out).mean())
    if g > 0.15:
        print(f" ✓ strong signal (mean gap {g:+.3f}) — option C reward will have teeth")
    elif g > 0.05:
        print(f" ~ moderate signal (mean gap {g:+.3f}) — option C may work but expect noisy gradients")
    else:
        print(f" ✗ weak signal (mean gap {g:+.3f}) — head not discriminative enough; train head more before C")


def main():
    """Score a random sample of cached train videos with the ForgeryHead and print a report."""
    random.seed(SEED)
    model, head = _load_model_and_head()
    examples = _sample_examples()
    # Loaded once and reused: it turns the cached raw video tensors into
    # pixel_values_videos + video_grid_thw without re-decoding the video files.
    proc = AutoProcessor.from_pretrained(CKPT)

    all_logits = []
    all_labels = []
    per_video_in_minus_out = []
    per_gen = {}  # gen -> list of (mean_in, mean_out)
    failures = 0
    t0 = time.time()

    for i, ex in enumerate(examples, 1):
        gen = ex["generator"]
        packed = None
        cached = _load_cached(ex)
        if cached is not None:
            sample_id, video_inputs, video_kwargs = cached
            try:
                # text is a placeholder; only the video outputs are consumed.
                packed = proc(text=["dummy"], videos=video_inputs, padding=True,
                              return_tensors="pt", **video_kwargs)
            except Exception as e:  # count and skip samples the processor rejects
                if failures < 3:
                    print(f" [skip] {sample_id}: {type(e).__name__}: {e}")

        if packed is None:
            failures += 1
        else:
            logits = _head_logits(model, head, packed)
            T = int(logits.shape[0])
            labels = frame_labels_from_segments(ex["solution"], T, fps_to_groups=FPS_TO_GROUPS)
            scores = torch.sigmoid(logits).numpy()
            lbl = labels.numpy()
            all_logits.append(logits.numpy())
            all_labels.append(lbl)
            # The in/out gap is only defined when the video has both kinds of seconds.
            if lbl.any() and not lbl.all():
                m_in = float(scores[lbl > 0.5].mean())
                m_out = float(scores[lbl < 0.5].mean())
                per_video_in_minus_out.append(m_in - m_out)
                per_gen.setdefault(gen, []).append((m_in, m_out))

        # Reported regardless of whether this particular sample failed.
        if i % 25 == 0:
            elapsed = time.time() - t0
            print(f" [{i}/{len(examples)}] elapsed={elapsed:.0f}s "
                  f"running gap={_fmt_mean(per_video_in_minus_out)} "
                  f"failures={failures}", flush=True)

    _report(examples, failures, all_logits, all_labels, per_video_in_minus_out, per_gen)


if __name__ == "__main__":
    main()