"""Zero-shot validation for R1 / R3 binary probing assumptions.

Goal: BEFORE committing to a multi-day GRPO training run with a binary probing
reward, verify that Qwen2.5-VL actually distinguishes forgery boundaries from
generic "smooth" video positions.

What it tests
-------------
For each test video with multi-segment forgery GT, we probe at three kinds
of anchor points:

  - forgery_start : t = GT segment start
  - forgery_end   : t = GT segment end
  - control       : a random t more than `--safe_band_s` seconds away from
                    every GT boundary

At each anchor, we run BOTH R1 (3 window probes: pre/post/cross coherence)
and R3 (2 point probes: forgery classification at t-1 and t+1).

Output
------
A JSON file with per-class distribution statistics (mean / std / quantiles)
and a GO/MARGINAL/NO-GO recommendation per reward variant. Use this to decide
whether to add `binary_probing` to the v10 reward stack.

Run
---
    python scripts/probe_zero_shot.py \
        --annot_dir /mnt/local-fast/zhangt/annot/annot \
        --video_root /mnt/local-fast/zhangt/video \
        --preprocessed_data_path /mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0 \
        --model_path /mnt/local-fast/zhangt/Qwen2.5-VL-7B-Instruct \
        --n_per_class 100 \
        --out_json probe_zero_shot_results.json
"""
| import argparse |
| import json |
| import os |
| import random |
| import sys |
| from collections import defaultdict |
|
|
| import numpy as np |
| import torch |
| from tqdm import tqdm |
|
|
| |
# Put the parent of this script's directory at the front of sys.path so the
# `src.open_r1` imports below resolve when the file is executed directly.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(HERE))
|
|
| from src.open_r1.data_loader import build_examples, TEST_GENERATORS |
| from src.open_r1.binary_prober import BinaryProber, slice_video_by_time |
| from src.open_r1.reward import ( |
| R1_COHERENCE_QUESTION, |
| R3_FORGERY_QUESTION, |
| ) |
|
|
|
|
def parse_args():
    """Parse the command-line options of the zero-shot probing script.

    Returns:
        argparse.Namespace with the dataset locations, the prober checkpoint
        path, the per-class sample cap, the probe window sizes, the control
        safety band, the RNG seed and the output JSON path.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ("--annot_dir", {"default": "/mnt/local-fast/zhangt/annot/annot"}),
        ("--video_root", {"default": "/mnt/local-fast/zhangt/video"}),
        ("--preprocessed_data_path", {
            "required": True,
            "help": "Forensics cache root (output of preprocess_forensics.py)",
        }),
        ("--model_path", {
            "required": True,
            "help": "Path to Qwen2.5-VL checkpoint used as frozen prober",
        }),
        ("--n_per_class", {
            "type": int,
            "default": 100,
            "help": "Cap samples per boundary class (forgery_start/end, control)",
        }),
        ("--delta_s", {"type": float, "default": 2.0}),
        ("--point_window_s", {"type": float, "default": 1.0}),
        ("--safe_band_s", {
            "type": float,
            "default": 3.0,
            "help": "Control points must be at least this many seconds "
                    "from any GT boundary",
        }),
        ("--seed", {"type": int, "default": 42}),
        ("--out_json", {"required": True}),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
def _enumerate_boundaries(examples, safe_band, rng):
    """Group probe anchors ``(example, t)`` by boundary class.

    Args:
        examples: Iterable of example dicts. Each one is read through
            ``"preprocessed_path"``, ``"solution"`` (a sequence of
            ``(start, end)`` pairs) and ``"durations"``.
        safe_band: Margin in seconds. Anchors must lie at least this far from
            both ends of the video, and control anchors must be more than
            this far from every segment start and end.
        rng: Random generator exposing ``uniform(low, high)``.

    Returns:
        A ``defaultdict(list)`` with the keys ``"forgery_start"``,
        ``"forgery_end"`` and ``"control"`` (only those that received at
        least one anchor), each holding ``(example, t)`` tuples.
    """
    anchors = defaultdict(list)
    shortest_usable = 2 * safe_band + 2

    for example in examples:
        # Examples without a preprocessed cache cannot be probed.
        if not example.get("preprocessed_path"):
            continue
        segments = example["solution"]
        total = example["durations"]
        if not segments or not total or total < shortest_usable:
            continue

        earliest = safe_band
        latest = total - safe_band
        for seg_start, seg_end in segments:
            if earliest <= seg_start <= latest:
                anchors["forgery_start"].append((example, float(seg_start)))
            if earliest <= seg_end <= latest:
                anchors["forgery_end"].append((example, float(seg_end)))

        # Rejection-sample at most one control anchor per example (20 tries).
        # NOTE(review): a point deep inside a long forged segment also passes
        # this distance test — confirm that is acceptable for the R3 control.
        tries_left = 20
        while tries_left > 0:
            tries_left -= 1
            candidate = float(rng.uniform(earliest, latest))
            clear_of_all = all(
                min(abs(candidate - seg_start), abs(candidate - seg_end)) > safe_band
                for seg_start, seg_end in segments
            )
            if clear_of_all:
                anchors["control"].append((example, candidate))
                break

    return anchors
|
|
|
|
def _load_video(ex):
    """Load the cached video tensor and sampling rate of one example.

    Args:
        ex: Example dict. ``ex["preprocessed_path"]`` is the cache directory
            expected to hold ``video_inputs.pt`` and ``video_kwargs.json``;
            ``ex["durations"]`` is returned as the duration.

    Returns:
        ``(video, fps, duration)`` with ``fps`` and ``duration`` as floats, or
        ``(None, None, None)`` when a cache file is missing, the JSON carries
        no usable ``"fps"`` entry, or the stored tensor list is empty. Callers
        treat the ``None`` triple as "skip this example".
    """
    missing = (None, None, None)
    cache_dir = ex["preprocessed_path"]
    vi_path = os.path.join(cache_dir, "video_inputs.pt")
    vk_path = os.path.join(cache_dir, "video_kwargs.json")
    if not (os.path.exists(vi_path) and os.path.exists(vk_path)):
        return missing

    # Read the small JSON first so a cache entry without fps is rejected
    # before paying for the tensor load.
    with open(vk_path, encoding="utf-8") as f:
        vk = json.load(f)
    fps = vk.get("fps")
    if isinstance(fps, (list, tuple)):
        fps = fps[0] if fps else None
    if fps is None:
        # Previously float(None) raised TypeError and aborted the whole run.
        return missing

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this script at cache directories you produced yourself.
    vi = torch.load(vi_path, map_location="cpu", weights_only=False)
    if isinstance(vi, list):
        if not vi:
            return missing
        vi = vi[0]
    return vi, float(fps), float(ex["durations"])
|
|
|
|
def _r1_window_probes(t, delta, duration):
    """Build the three R1 window probes around anchor ``t``.

    Args:
        t: Anchor time in seconds.
        delta: Length in seconds of the pre and post windows; the crossing
            window spans ``delta / 2`` on each side of ``t``.
        duration: Video length in seconds; window ends are capped at it and
            window starts are floored at 0.

    Returns:
        ``[(start, end, expected), ...]`` in the order pre-window ("yes"),
        post-window ("yes"), window straddling ``t`` ("no").
    """
    half_span = delta / 2
    before = (max(0.0, t - delta), t, "yes")
    after = (t, min(duration, t + delta), "yes")
    straddling = (max(0.0, t - half_span), min(duration, t + half_span), "no")
    return [before, after, straddling]
|
|
|
|
def _r3_point_probes(t, point_window, duration, offset=1.0):
    """Build the two R3 point probes on either side of anchor ``t``.

    Args:
        t: Anchor time in seconds.
        point_window: Total length in seconds of each probe clip.
        duration: Video length in seconds, used to cap window ends.
        offset: Distance in seconds between ``t`` and the centre of each
            probe. Defaults to 1.0, the previously hard-coded value.

    Returns:
        ``[(start, end, expected), ...]``: the clip centred at ``t - offset``
        with expected answer "no", then the clip centred at ``t + offset``
        with expected answer "yes". The labels describe a forgery *start*;
        callers probing a forgery end must swap them.
    """
    half = point_window / 2
    before_centre = t - offset
    after_centre = t + offset
    return [
        (
            max(0.0, before_centre - half),
            # Floor at 0 as before, and additionally cap at the video end so
            # an oversized window can never run past the clip.
            min(duration, max(0.0, before_centre + half)),
            "no",
        ),
        (
            max(0.0, after_centre - half),
            min(duration, after_centre + half),
            "yes",
        ),
    ]
|
|
|
|
# Maximum number of clips sent to the prober in one probe_batch call.
_PROBE_BATCH_SIZE = 8


def _stat_of(values):
    """Return n / mean / std / median / quartiles of a list of numbers."""
    arr = np.asarray(values, dtype=float)
    if arr.size == 0:
        return {"n": 0}
    return {
        "n": int(arr.size),
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "median": float(np.median(arr)),
        "q25": float(np.percentile(arr, 25)),
        "q75": float(np.percentile(arr, 75)),
    }


def _verdict(delta):
    """Map a forgery-minus-control gap to a recommendation string."""
    if delta is None:
        return "INSUFFICIENT DATA (empty forgery or control group)"
    if delta > 0.20:
        return "GO (delta>0.20)"
    if delta > 0.10:
        return "MARGINAL (0.10<delta<=0.20)"
    return "NO-GO (delta<=0.10)"


def _decide(forgery_values, control_values, forgery_key, control_key):
    """Compare two groups of probabilities and build one decision entry.

    A group with no samples yields ``None`` for its mean and for the delta,
    instead of a made-up placeholder value that could produce a false verdict.
    """
    forgery_mean = float(np.mean(forgery_values)) if len(forgery_values) else None
    control_mean = float(np.mean(control_values)) if len(control_values) else None
    if forgery_mean is None or control_mean is None:
        delta = None
    else:
        delta = forgery_mean - control_mean
    return {
        forgery_key: forgery_mean,
        control_key: control_mean,
        "delta": delta,
        "verdict": _verdict(delta),
    }


def _flip_expected(expected):
    """Swap the expected answer "yes" <-> "no"."""
    return "no" if expected == "yes" else "yes"


def _probe_anchor(prober, ex, t, label, delta_s, point_window_s):
    """Run all R1 and R3 probes around one anchor.

    Returns a list of ``(kind, expected, p_yes, p_no)`` tuples, one per probe
    whose clip could be sliced; the list is empty when the video cache cannot
    be loaded.
    """
    vi, fps, duration = _load_video(ex)
    if vi is None:
        return []

    specs = [
        ("R1", R1_COHERENCE_QUESTION, s, e, expected)
        for (s, e, expected) in _r1_window_probes(t, delta_s, duration)
    ]
    # _r3_point_probes labels the clip before t "no" and the clip after t
    # "yes", which is right for a forgery start. At a forgery end the forged
    # content lies BEFORE t, so the expected answers are swapped; otherwise
    # the "yes" bucket would mix inside-forgery and outside-forgery clips.
    swap_r3 = label == "forgery_end"
    for (s, e, expected) in _r3_point_probes(t, point_window_s, duration):
        if swap_r3:
            expected = _flip_expected(expected)
        specs.append(("R3", R3_FORGERY_QUESTION, s, e, expected))

    clips, questions, keys = [], [], []
    for kind, question, s, e, expected in specs:
        clip = slice_video_by_time(vi, fps, s, e)
        if clip is None:
            continue
        clips.append(clip)
        questions.append(question)
        keys.append((kind, expected))
    if not clips:
        return []

    scores = []
    for lo in range(0, len(clips), _PROBE_BATCH_SIZE):
        hi = lo + _PROBE_BATCH_SIZE
        batch_clips = clips[lo:hi]
        scores.extend(
            prober.probe_batch(batch_clips, [fps] * len(batch_clips), questions[lo:hi])
        )
    return [
        (kind, expected, p_yes, p_no)
        for (kind, expected), (p_yes, p_no) in zip(keys, scores)
    ]


def main():
    """Probe forgery and control anchors zero-shot and write a GO/NO-GO report.

    Steps: parse the CLI, build the test examples, enumerate and subsample
    the anchors per class, run the R1/R3 probes with the frozen prober,
    aggregate per-class statistics, and derive one decision per reward
    variant from the forgery-vs-control gap. The summary is written to
    ``--out_json`` and the decisions are printed.
    """
    args = parse_args()
    random.seed(args.seed)
    rng = np.random.default_rng(args.seed)

    examples = build_examples(
        annot_dir=args.annot_dir,
        video_root=args.video_root,
        generators=TEST_GENERATORS,
        split_prefix="test",
        preprocessed_data_path=args.preprocessed_data_path,
        require_video_exists=False,
    )
    print(f"Loaded {len(examples)} test examples")

    by_label = _enumerate_boundaries(examples, args.safe_band_s, rng)
    print({k: len(v) for k, v in by_label.items()})

    # Cap every class at n_per_class anchors (sampling without replacement).
    for label in list(by_label.keys()):
        items = by_label[label]
        if args.n_per_class > 0 and len(items) > args.n_per_class:
            idx = rng.choice(len(items), args.n_per_class, replace=False)
            by_label[label] = [items[i] for i in idx]
        print(f" {label}: {len(by_label[label])} kept")

    prober = BinaryProber(model_path=args.model_path)

    # results[label]["<kind>_<expected>_Pexp" | "<kind>_<expected>_Pyes"]
    results = defaultdict(lambda: defaultdict(list))
    for label, items in by_label.items():
        for (ex, t) in tqdm(items, desc=label):
            scored = _probe_anchor(
                prober, ex, t, label, args.delta_s, args.point_window_s
            )
            for kind, expected, p_yes, p_no in scored:
                results[label][f"{kind}_{expected}_Pexp"].append(
                    p_yes if expected == "yes" else p_no
                )
                results[label][f"{kind}_{expected}_Pyes"].append(p_yes)

    summary = {"args": vars(args), "stats": {}, "decision": {}}
    for label, kinds in results.items():
        summary["stats"][label] = {k: _stat_of(v) for k, v in kinds.items()}

    def _values(label, key):
        # .get() keeps the nested defaultdicts from growing empty entries.
        return list(results.get(label, {}).get(key, []))

    # R1: P("no") on the window straddling the anchor, forgery vs control.
    forgery_cross = [
        1.0 - p
        for p in _values("forgery_start", "R1_no_Pyes")
        + _values("forgery_end", "R1_no_Pyes")
    ]
    control_cross = [1.0 - p for p in _values("control", "R1_no_Pyes")]
    summary["decision"]["R1"] = _decide(
        forgery_cross,
        control_cross,
        "forgery_cross_P_no_mean",
        "control_cross_P_no_mean",
    )

    # R3: P("yes") on the clip expected to be forged, forgery vs control.
    forgery_inside = (
        _values("forgery_start", "R3_yes_Pexp")
        + _values("forgery_end", "R3_yes_Pexp")
    )
    control_inside = _values("control", "R3_yes_Pexp")
    summary["decision"]["R3"] = _decide(
        forgery_inside,
        control_inside,
        "forgery_inside_P_forged_mean",
        "control_inside_P_forged_mean",
    )

    os.makedirs(os.path.dirname(os.path.abspath(args.out_json)) or ".", exist_ok=True)
    with open(args.out_json, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)
    print("\n=== DECISION ===")
    print(json.dumps(summary["decision"], indent=2))
    print(f"\nFull stats written to {args.out_json}")
|
|
|
|
# Run the probing pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|