# Source: forensics-grpo/code/scripts/probe_zero_shot.py
# (Hugging Face file-viewer residue: commit 33569f9 "Add source code",
#  uploaded by sdzt, 10.9 kB.)
"""Zero-shot validation for R1 / R3 binary probing assumptions.
Goal: BEFORE committing to a multi-day GRPO training run with binary probing
reward, verify that Qwen2.5-VL actually distinguishes forgery boundaries from
generic "smooth" video positions.
What it tests
-------------
For each test video with multi-segment forgery GT, we probe at three kinds
of boundary points:
- forgery_start : t = GT segment start
- forgery_end : t = GT segment end
- control : a random t far from any GT boundary (Δ_safe seconds)
At each boundary, we run BOTH R1 (3 window probes: pre/post/cross coherence)
and R3 (2 point probes per anchor: forgery classification at t-1 and t+1).
Output
------
A JSON with per-class distribution statistics (mean / std / quantiles) and a
GO/MARGINAL/NO-GO recommendation per reward variant. Use this to decide
whether to add `binary_probing` to the v10 reward stack.
Run
---
python scripts/probe_zero_shot.py \
--annot_dir /mnt/local-fast/zhangt/annot/annot \
--video_root /mnt/local-fast/zhangt/video \
--preprocessed_data_path /mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0 \
--model_path /mnt/local-fast/zhangt/Qwen2.5-VL-7B-Instruct \
--n_per_class 100 \
--out_json probe_zero_shot_results.json
"""
import argparse
import json
import os
import random
import sys
from collections import defaultdict
import numpy as np
import torch
from tqdm import tqdm
# Allow execution from anywhere inside the repo.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(HERE))
from src.open_r1.data_loader import build_examples, TEST_GENERATORS # noqa: E402
from src.open_r1.binary_prober import BinaryProber, slice_video_by_time # noqa: E402
from src.open_r1.reward import ( # noqa: E402
R1_COHERENCE_QUESTION,
R3_FORGERY_QUESTION,
)
def parse_args():
    """Parse the command-line options of the zero-shot probing script.

    Required flags: ``--preprocessed_data_path``, ``--model_path`` and
    ``--out_json``. Everything else has a default.

    Returns:
        argparse.Namespace with one attribute per flag.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Dataset / model locations.
    add("--annot_dir", default="/mnt/local-fast/zhangt/annot/annot")
    add("--video_root", default="/mnt/local-fast/zhangt/video")
    add(
        "--preprocessed_data_path",
        required=True,
        help="Forensics cache root (output of preprocess_forensics.py)",
    )
    add(
        "--model_path",
        required=True,
        help="Path to Qwen2.5-VL checkpoint used as frozen prober",
    )

    # Sampling and probe geometry (all durations are in seconds).
    add(
        "--n_per_class",
        type=int,
        default=100,
        help="Cap samples per boundary class (forgery_start/end, control)",
    )
    add("--delta_s", type=float, default=2.0)
    add("--point_window_s", type=float, default=1.0)
    add(
        "--safe_band_s",
        type=float,
        default=3.0,
        help="Control points must be at least this many seconds "
             "from any GT boundary",
    )

    # Reproducibility and output.
    add("--seed", type=int, default=42)
    add("--out_json", required=True)

    return parser.parse_args()
def _enumerate_boundaries(examples, safe_band, rng):
    """Group probe anchors by boundary class.

    Args:
        examples: iterable of dicts read via ``"preprocessed_path"``,
            ``"solution"`` (iterable of ``(start, end)`` pairs) and
            ``"durations"``.
        safe_band: margin in seconds. Anchors must lie at least this far
            from both ends of the video, and control anchors must lie more
            than this far from every GT segment start/end.
        rng: object exposing ``uniform(low, high)`` (the script passes a
            NumPy ``Generator``).

    Returns:
        defaultdict mapping ``"forgery_start"`` / ``"forgery_end"`` /
        ``"control"`` to lists of ``(example, t_anchor)`` tuples.
    """
    anchors = defaultdict(list)

    for example in examples:
        # Examples without a preprocessed cache entry cannot be probed.
        if not example.get("preprocessed_path"):
            continue

        segments = example["solution"]
        length = example["durations"]
        # Need GT segments and a video long enough to host an anchor with a
        # safe band on either side.
        if not segments or not length or length < 2 * safe_band + 2:
            continue

        latest = length - safe_band

        for seg_start, seg_end in segments:
            if safe_band <= seg_start <= latest:
                anchors["forgery_start"].append((example, float(seg_start)))
            if safe_band <= seg_end <= latest:
                anchors["forgery_end"].append((example, float(seg_end)))

        def _clear_of_boundaries(candidate):
            # True when the candidate is farther than safe_band from every
            # segment start AND end.
            # NOTE(review): this only measures distance to boundaries, so a
            # control anchor may still fall inside a long forged segment --
            # confirm that this is intended for the R3 comparison.
            return all(
                min(abs(candidate - s), abs(candidate - e)) > safe_band
                for (s, e) in segments
            )

        # At most one control anchor per video: rejection-sample up to 20
        # candidates and keep the first acceptable one (none if all fail).
        for _attempt in range(20):
            candidate = float(rng.uniform(safe_band, latest))
            if _clear_of_boundaries(candidate):
                anchors["control"].append((example, candidate))
                break

    return anchors
def _load_video(ex):
    """Return (video_tensor, fps, duration) from a forensics example.

    Reads ``video_inputs.pt`` and ``video_kwargs.json`` from the directory
    ``ex["preprocessed_path"]``.

    Args:
        ex: example dict; ``"preprocessed_path"`` and ``"durations"`` are read.

    Returns:
        ``(video, fps, duration)`` with ``fps`` and ``duration`` as floats, or
        ``(None, None, None)`` when the cache entry is unusable: a file is
        missing, ``fps`` is absent/empty in the kwargs, or the stored video
        list is empty. Callers treat the ``None`` triple as "skip this
        example", so one bad cache entry does not abort the whole run.
    """
    missing = (None, None, None)

    pdir = ex["preprocessed_path"]
    vi_path = os.path.join(pdir, "video_inputs.pt")
    vk_path = os.path.join(pdir, "video_kwargs.json")
    if not (os.path.exists(vi_path) and os.path.exists(vk_path)):
        return missing

    # Read the small JSON first so an unusable entry is rejected before the
    # (potentially large) tensor file is deserialised.
    with open(vk_path, encoding="utf-8") as f:
        vk = json.load(f)
    fps = vk.get("fps")
    if isinstance(fps, list):
        # fps may be stored as a list; use its first element.
        fps = fps[0] if fps else None
    if fps is None:
        return missing

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only point this script at cache directories you produced yourself.
    vi = torch.load(vi_path, map_location="cpu", weights_only=False)
    if isinstance(vi, list):
        # The video may be stored as a list; use its first element.
        if not vi:
            return missing
        vi = vi[0]

    return vi, float(fps), float(ex["durations"])
def _r1_window_probes(t, delta, duration):
    """Build the three R1 window probes around the anchor time `t`.

    Each probe is ``(start_s, end_s, expected_answer)``; windows are clipped
    to ``[0, duration]``.

    Returns:
        ``[pre, post, cross]`` where ``pre`` covers ``[t - delta, t]``,
        ``post`` covers ``[t, t + delta]`` (both expected "yes") and ``cross``
        is the ``delta``-wide window centred on ``t`` (expected "no").
    """
    half = delta / 2

    pre = (max(0.0, t - delta), t, "yes")
    post = (t, min(duration, t + delta), "yes")
    cross = (max(0.0, t - half), min(duration, t + half), "no")

    return [pre, post, cross]
def _r3_point_probes(t, point_window, duration, offset_s=1.0):
    """Build the two R3 point probes on either side of the anchor time `t`.

    Args:
        t: anchor time in seconds.
        point_window: width in seconds of each probe window.
        duration: video length in seconds; windows are clipped to
            ``[0, duration]``.
        offset_s: distance in seconds from `t` to the centre of each probe
            window (default 1.0, the previously hard-coded value).

    Returns:
        ``[(start, end, "no"), (start, end, "yes")]``: the window centred at
        ``t - offset_s`` (expected "no") followed by the window centred at
        ``t + offset_s`` (expected "yes"). A window that falls entirely
        outside the video collapses to zero length.
    """
    half = point_window / 2

    def _clamp(x):
        # Keep every window edge inside the video. Previously the "before"
        # window's end was never clamped to `duration`.
        return min(duration, max(0.0, x))

    before_centre = t - offset_s
    after_centre = t + offset_s
    return [
        (_clamp(before_centre - half), _clamp(before_centre + half), "no"),
        (_clamp(after_centre - half), _clamp(after_centre + half), "yes"),
    ]
def _stat_of(arr):
    """Return count, mean, std, median and quartiles of `arr` as plain floats."""
    a = np.asarray(arr)
    if a.size == 0:
        return {"n": 0}
    return {
        "n": int(a.size),
        "mean": float(a.mean()),
        "std": float(a.std()),
        "median": float(np.median(a)),
        "q25": float(np.percentile(a, 25)),
        "q75": float(np.percentile(a, 75)),
    }


def _pooled(results, labels, key):
    """Concatenate the samples stored under `key` for each label in `labels`."""
    pooled = []
    for label in labels:
        pooled.extend(results.get(label, {}).get(key, []))
    return pooled


def _decide(forgery_vals, control_vals):
    """Grade the gap between mean forgery and mean control probabilities.

    Returns:
        ``(forgery_mean, control_mean, delta, verdict)``. A mean is ``None``
        when its side has no samples. In that case ``delta`` is ``None`` and
        the verdict says so explicitly, rather than reporting a number built
        from placeholder values (or NaN) as a NO-GO.
    """
    forgery_mean = float(np.mean(forgery_vals)) if len(forgery_vals) else None
    control_mean = float(np.mean(control_vals)) if len(control_vals) else None
    if forgery_mean is None or control_mean is None:
        return forgery_mean, control_mean, None, "NO-DATA (missing forgery or control samples)"

    delta = forgery_mean - control_mean
    if delta > 0.20:
        verdict = "GO (delta>0.20)"
    elif delta > 0.10:
        verdict = "MARGINAL (0.10<delta<=0.20)"
    else:
        verdict = "NO-GO (delta<=0.10)"
    return forgery_mean, control_mean, delta, verdict


def main():
    """Run the zero-shot R1/R3 probes and write the summary JSON.

    Steps: parse args, build the test examples, enumerate and subsample
    boundary anchors per class, probe each anchor with the frozen model,
    then write per-class statistics and a per-reward verdict to
    ``--out_json`` and print the verdicts.
    """
    args = parse_args()
    random.seed(args.seed)
    rng = np.random.default_rng(args.seed)

    examples = build_examples(
        annot_dir=args.annot_dir,
        video_root=args.video_root,
        generators=TEST_GENERATORS,
        split_prefix="test",
        preprocessed_data_path=args.preprocessed_data_path,
        require_video_exists=False,
    )
    print(f"Loaded {len(examples)} test examples")

    by_label = _enumerate_boundaries(examples, args.safe_band_s, rng)
    print({k: len(v) for k, v in by_label.items()})

    # Cap each class to n_per_class.
    for label in list(by_label.keys()):
        items = by_label[label]
        if args.n_per_class > 0 and len(items) > args.n_per_class:
            idx = rng.choice(len(items), args.n_per_class, replace=False)
            by_label[label] = [items[i] for i in idx]
        print(f" {label}: {len(by_label[label])} kept")

    prober = BinaryProber(model_path=args.model_path)

    # Result store: results[label]["<kind>_<expected>_Pexp" | "..._Pyes"]
    # -> list of probabilities, one per probe.
    results: dict = defaultdict(lambda: defaultdict(list))

    def _run_probes(label, ex, t):
        """Probe one anchor and append its probabilities to `results`."""
        vi, fps, duration = _load_video(ex)
        if vi is None:
            return

        clips, fpss, qs, probe_keys = [], [], [], []

        def _queue(kind, question, windows):
            # Slice each window; skip windows that yield no clip.
            for (s, e, expected) in windows:
                clip = slice_video_by_time(vi, fps, s, e)
                if clip is None:
                    continue
                clips.append(clip)
                fpss.append(fps)
                qs.append(question)
                probe_keys.append((kind, expected))

        # R1 probes (3 per boundary).
        _queue("R1", R1_COHERENCE_QUESTION,
               _r1_window_probes(t, args.delta_s, duration))

        # R3 probes (2 per anchor; original spec is 4 around (t1, t2), but
        # in zero-shot we treat each boundary point in isolation).
        r3 = _r3_point_probes(t, args.point_window_s, duration)
        if label == "forgery_end":
            # At a segment END the forged content lies BEFORE t, so the
            # window at t-1 is the one expected to be classified as forged
            # and the window at t+1 is expected to be clean. Without this
            # flip, the clean post-forgery window would be pooled into the
            # "inside the forgery" statistic below.
            r3 = [(s, e, "no" if expected == "yes" else "yes")
                  for (s, e, expected) in r3]
        _queue("R3", R3_FORGERY_QUESTION, r3)

        if not clips:
            return

        # Batch in small chunks to avoid OOM on long videos.
        out = []
        BS = 8
        for i in range(0, len(clips), BS):
            out.extend(prober.probe_batch(clips[i:i + BS],
                                          fpss[i:i + BS],
                                          qs[i:i + BS]))

        for (kind, expected), (p_yes, p_no) in zip(probe_keys, out):
            results[label][f"{kind}_{expected}_Pexp"].append(
                p_yes if expected == "yes" else p_no
            )
            results[label][f"{kind}_{expected}_Pyes"].append(p_yes)

    for label, items in by_label.items():
        for (ex, t) in tqdm(items, desc=label):
            _run_probes(label, ex, t)

    # Summarise + decision.
    summary = {"args": vars(args), "stats": {}, "decision": {}}
    for label, kinds in results.items():
        summary["stats"][label] = {k: _stat_of(v) for k, v in kinds.items()}

    forgery_labels = ("forgery_start", "forgery_end")

    # R1 decision: cross-window P(no) at forgery boundary vs control.
    # P(no) is taken as 1 - P(yes), pooled over real samples only.
    r1_forgery = [1.0 - p for p in _pooled(results, forgery_labels, "R1_no_Pyes")]
    r1_control = [1.0 - p for p in _pooled(results, ("control",), "R1_no_Pyes")]
    forg_cross_pno, ctrl_cross_pno, delta_r1, verdict_r1 = _decide(r1_forgery, r1_control)
    summary["decision"]["R1"] = {
        "forgery_cross_P_no_mean": forg_cross_pno,
        "control_cross_P_no_mean": ctrl_cross_pno,
        "delta": delta_r1,
        "verdict": verdict_r1,
    }

    # R3 decision: P(forged) just inside the forged segment (t+1 for a start
    # boundary, t-1 for an end boundary) vs the t+1 window at control points.
    r3_forgery = _pooled(results, forgery_labels, "R3_yes_Pexp")
    r3_control = _pooled(results, ("control",), "R3_yes_Pexp")
    forg_yes_inside, ctrl_yes_after, delta_r3, verdict_r3 = _decide(r3_forgery, r3_control)
    summary["decision"]["R3"] = {
        "forgery_inside_P_forged_mean": forg_yes_inside,
        "control_inside_P_forged_mean": ctrl_yes_after,
        "delta": delta_r3,
        "verdict": verdict_r3,
    }

    os.makedirs(os.path.dirname(os.path.abspath(args.out_json)) or ".", exist_ok=True)
    with open(args.out_json, "w") as f:
        json.dump(summary, f, indent=2)

    print("\n=== DECISION ===")
    print(json.dumps(summary["decision"], indent=2))
    print(f"\nFull stats written to {args.out_json}")
# Script entry point: run the probe only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()