"""Zero-shot validation for R1 / R3 binary probing assumptions. Goal: BEFORE committing to a multi-day GRPO training run with binary probing reward, verify that Qwen2.5-VL actually distinguishes forgery boundaries from generic "smooth" video positions. What it tests ------------- For each test video with multi-segment forgery GT, we probe at three kinds of boundary points: - forgery_start : t = GT segment start - forgery_end : t = GT segment end - control : a random t far from any GT boundary (Δ_safe seconds) At each boundary, we run BOTH R1 (3 window probes: pre/post/cross coherence) and R3 (4 point probes: forgery-classification at t±1). Output ------ A JSON with per-class distribution statistics (mean / std / quantiles) and a GO/MARGINAL/NO-GO recommendation per reward variant. Use this to decide whether to add `binary_probing` to the v10 reward stack. Run --- python scripts/probe_zero_shot.py \ --annot_dir /mnt/local-fast/zhangt/annot/annot \ --video_root /mnt/local-fast/zhangt/video \ --preprocessed_data_path /mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0 \ --model_path /mnt/local-fast/zhangt/Qwen2.5-VL-7B-Instruct \ --n_per_class 100 \ --out_json probe_zero_shot_results.json """ import argparse import json import os import random import sys from collections import defaultdict import numpy as np import torch from tqdm import tqdm # Allow execution from anywhere inside the repo. HERE = os.path.dirname(os.path.abspath(__file__)) sys.path.insert(0, os.path.dirname(HERE)) from src.open_r1.data_loader import build_examples, TEST_GENERATORS # noqa: E402 from src.open_r1.binary_prober import BinaryProber, slice_video_by_time # noqa: E402 from src.open_r1.reward import ( # noqa: E402 R1_COHERENCE_QUESTION, R3_FORGERY_QUESTION, ) def parse_args(): p = argparse.ArgumentParser() p.add_argument("--annot_dir", default="/mnt/local-fast/zhangt/annot/annot") p.add_argument("--video_root", default="/mnt/local-fast/zhangt/video") p.add_argument("--preprocessed_data_path", required=True, help="Forensics cache root (output of preprocess_forensics.py)") p.add_argument("--model_path", required=True, help="Path to Qwen2.5-VL checkpoint used as frozen prober") p.add_argument("--n_per_class", type=int, default=100, help="Cap samples per boundary class (forgery_start/end, control)") p.add_argument("--delta_s", type=float, default=2.0) p.add_argument("--point_window_s", type=float, default=1.0) p.add_argument("--safe_band_s", type=float, default=3.0, help="Control points must be at least this many seconds " "from any GT boundary") p.add_argument("--seed", type=int, default=42) p.add_argument("--out_json", required=True) return p.parse_args() def _enumerate_boundaries(examples, safe_band, rng): """Build (example, t_anchor, label) entries for each boundary class.""" by_label = defaultdict(list) for ex in examples: if not ex.get("preprocessed_path"): continue sol = ex["solution"] duration = ex["durations"] if not sol or not duration or duration < 2 * safe_band + 2: continue for (s, e) in sol: if safe_band <= s <= duration - safe_band: by_label["forgery_start"].append((ex, float(s))) if safe_band <= e <= duration - safe_band: by_label["forgery_end"].append((ex, float(e))) # One control point per video (random, away from any GT boundary). for _ in range(20): t = float(rng.uniform(safe_band, duration - safe_band)) far_enough = all( min(abs(t - s), abs(t - e)) > safe_band for (s, e) in sol ) if far_enough: by_label["control"].append((ex, t)) break return by_label def _load_video(ex): """Return (video_tensor, fps, duration) from a forensics example.""" pdir = ex["preprocessed_path"] vi_path = os.path.join(pdir, "video_inputs.pt") vk_path = os.path.join(pdir, "video_kwargs.json") if not (os.path.exists(vi_path) and os.path.exists(vk_path)): return None, None, None vi = torch.load(vi_path, map_location="cpu", weights_only=False) with open(vk_path) as f: vk = json.load(f) if isinstance(vi, list): vi = vi[0] fps = vk.get("fps") if isinstance(fps, list): fps = fps[0] return vi, float(fps), float(ex["durations"]) def _r1_window_probes(t, delta, duration): """Return [(s_s, s_e, expected), ...] for R1 window probes around `t`.""" return [ (max(0.0, t - delta), t, "yes"), # pre (t, min(duration, t + delta), "yes"), # post (max(0.0, t - delta / 2), min(duration, t + delta / 2), "no"), # cross ] def _r3_point_probes(t, point_window, duration): half = point_window / 2 return [ (max(0.0, t - 1 - half), max(0.0, t - 1 + half), "no"), (max(0.0, t + 1 - half), min(duration, t + 1 + half), "yes"), ] def main(): args = parse_args() random.seed(args.seed) rng = np.random.default_rng(args.seed) examples = build_examples( annot_dir=args.annot_dir, video_root=args.video_root, generators=TEST_GENERATORS, split_prefix="test", preprocessed_data_path=args.preprocessed_data_path, require_video_exists=False, ) print(f"Loaded {len(examples)} test examples") by_label = _enumerate_boundaries(examples, args.safe_band_s, rng) print({k: len(v) for k, v in by_label.items()}) # Cap each class to n_per_class. for label in list(by_label.keys()): items = by_label[label] if args.n_per_class > 0 and len(items) > args.n_per_class: idx = rng.choice(len(items), args.n_per_class, replace=False) by_label[label] = [items[i] for i in idx] print(f" {label}: {len(by_label[label])} kept") prober = BinaryProber(model_path=args.model_path) # Result store: results[label][probe_kind][expected] -> list of P(expected) results: dict = defaultdict(lambda: defaultdict(list)) def _run_probes(label, ex, t): vi, fps, duration = _load_video(ex) if vi is None: return # R1 probes (3 per boundary). r1 = _r1_window_probes(t, args.delta_s, duration) clips, fpss, qs, expecteds, probe_keys = [], [], [], [], [] for (s, e, expected) in r1: clip = slice_video_by_time(vi, fps, s, e) if clip is None: continue clips.append(clip) fpss.append(fps) qs.append(R1_COHERENCE_QUESTION) expecteds.append(expected) probe_keys.append(("R1", expected)) # R3 probes (2 per anchor; original spec is 4 around (t1, t2), but # in zero-shot we treat each boundary point in isolation). r3 = _r3_point_probes(t, args.point_window_s, duration) for (s, e, expected) in r3: clip = slice_video_by_time(vi, fps, s, e) if clip is None: continue clips.append(clip) fpss.append(fps) qs.append(R3_FORGERY_QUESTION) expecteds.append(expected) probe_keys.append(("R3", expected)) if not clips: return # Batch in small chunks to avoid OOM on long videos. out = [] BS = 8 for i in range(0, len(clips), BS): out.extend(prober.probe_batch(clips[i:i + BS], fpss[i:i + BS], qs[i:i + BS])) for (kind, expected), (p_yes, p_no) in zip(probe_keys, out): results[label][f"{kind}_{expected}_Pexp"].append( p_yes if expected == "yes" else p_no ) results[label][f"{kind}_{expected}_Pyes"].append(p_yes) for label, items in by_label.items(): for (ex, t) in tqdm(items, desc=label): _run_probes(label, ex, t) # Summarise + decision. summary = {"args": vars(args), "stats": {}, "decision": {}} def stat_of(arr): a = np.asarray(arr) if a.size == 0: return {"n": 0} return { "n": int(a.size), "mean": float(a.mean()), "std": float(a.std()), "median": float(np.median(a)), "q25": float(np.percentile(a, 25)), "q75": float(np.percentile(a, 75)), } for label, kinds in results.items(): summary["stats"][label] = {k: stat_of(v) for k, v in kinds.items()} # R1 decision: cross-window P(no) at forgery boundary vs control. def _mean(label, key): vs = results.get(label, {}).get(key, []) return float(np.mean(vs)) if vs else None forg_cross_pno = np.mean( (1 - np.array(results.get("forgery_start", {}).get("R1_no_Pyes", []) or [1])).tolist() + (1 - np.array(results.get("forgery_end", {}).get("R1_no_Pyes", []) or [1])).tolist() ) ctrl_cross_pno = 1 - np.mean(results.get("control", {}).get("R1_no_Pyes", []) or [1.0]) delta_r1 = float(forg_cross_pno - ctrl_cross_pno) summary["decision"]["R1"] = { "forgery_cross_P_no_mean": float(forg_cross_pno), "control_cross_P_no_mean": float(ctrl_cross_pno), "delta": delta_r1, "verdict": ( "GO (delta>0.20)" if delta_r1 > 0.20 else "MARGINAL (0.10 0.10 else "NO-GO (delta<=0.10)" ), } # R3 decision: P(yes-is-forgery) at boundary+1 vs control+1. forg_yes_after = np.mean( (results.get("forgery_start", {}).get("R3_yes_Pexp", []) or []) + (results.get("forgery_end", {}).get("R3_yes_Pexp", []) or []) ) ctrl_yes_after = np.mean(results.get("control", {}).get("R3_yes_Pexp", []) or [0.0]) delta_r3 = float(forg_yes_after - ctrl_yes_after) summary["decision"]["R3"] = { "forgery_inside_P_forged_mean": float(forg_yes_after), "control_inside_P_forged_mean": float(ctrl_yes_after), "delta": delta_r3, "verdict": ( "GO (delta>0.20)" if delta_r3 > 0.20 else "MARGINAL (0.10 0.10 else "NO-GO (delta<=0.10)" ), } os.makedirs(os.path.dirname(os.path.abspath(args.out_json)) or ".", exist_ok=True) with open(args.out_json, "w") as f: json.dump(summary, f, indent=2) print("\n=== DECISION ===") print(json.dumps(summary["decision"], indent=2)) print(f"\nFull stats written to {args.out_json}") if __name__ == "__main__": main()