| """Sanity-check the stage1 ForgeryHead on a sample of train videos. |
| |
| For each sampled video we: |
| - load the cached video_inputs / video_kwargs |
| - run model.visual(...) -> visual features |
| - run model.forgery_head(...) -> per-second logits, sigmoid -> scores |
| - compare against GT segments (per-second binary labels) |
| |
| Aggregate stats reported: |
| - global AUC across all per-second labels |
| - mean head score inside vs outside GT |
| - distribution of (in - out) gap per video |
| - per-generator breakdown |
| """ |
| import json |
| import os |
| import random |
| import sys |
| import time |
|
|
| import numpy as np |
| import torch |
| from transformers import Qwen2_5_VLForConditionalGeneration |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from src.open_r1.data_loader import ( |
| GENERATOR_TO_DIR, TRAIN_GENERATORS, build_examples, |
| ) |
| from src.open_r1.forgery_head import ( |
| ForgeryHead, frame_labels_from_segments, head_auc as _head_auc, |
| ) |
|
|
# ---------------------------------------------------------------------------
# Configuration. All inputs live under one local scratch root.
# ---------------------------------------------------------------------------
_DATA_ROOT = "/mnt/local-fast/zhangt"

# Stage1 checkpoint directory: loaded with from_pretrained, and its
# *.safetensors files are scanned for "forgery_head.*" tensors.
CKPT = f"{_DATA_ROOT}/forensics_grpo/outputs_forensics/stage1_forgery"
ANNOT = f"{_DATA_ROOT}/annot/annot"  # annotation dir passed to build_examples
VROOT = f"{_DATA_ROOT}/video"  # raw video root passed to build_examples
# Preprocessed cache: <CACHE>/train/<generator>/<sample_id>/video_inputs.pt
# plus video_kwargs.json, per example.
CACHE = f"{_DATA_ROOT}/forensics_grpo_cache_uniform3584_fps2.0"

N_SAMPLES = 250  # how many shuffled train examples to evaluate
SEED = 42  # RNG seed for the example shuffle
# Forwarded as fps_to_groups to frame_labels_from_segments.
# NOTE(review): presumably head output steps per second of video - confirm.
FPS_TO_GROUPS = 1.0
|
|
|
|
def _load_model_and_head():
    """Load the stage1 checkpoint and attach a ForgeryHead restored from its shards.

    Returns:
        The Qwen2.5-VL model in eval mode on ``cuda:0``, with the head
        registered as ``model.forgery_head``.

    Raises:
        RuntimeError: if no ``forgery_head.*`` tensors are found in CKPT, or
            if they do not match the head's parameters (strict load).
    """
    import glob

    import safetensors.torch as st

    print(f"Loading model from {CKPT} ...", flush=True)
    t0 = time.time()
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        CKPT, torch_dtype=torch.bfloat16, attn_implementation="sdpa",
    )
    model.eval()
    print(f"  loaded in {time.time()-t0:.1f}s. param dtype={next(model.parameters()).dtype}", flush=True)

    head = ForgeryHead(hidden_dim=model.config.hidden_size, mlp_dim=1024)
    head.to(dtype=torch.bfloat16)

    # "model*.safetensors" matches both a single-file checkpoint
    # (model.safetensors) and sharded ones (model-XXXXX-of-YYYYY.safetensors);
    # a "model-*" pattern would silently miss the single-file layout.
    prefix = "forgery_head."
    head_sd = {}
    for shard in sorted(glob.glob(os.path.join(CKPT, "model*.safetensors"))):
        with st.safe_open(shard, framework="pt") as f:
            for k in f.keys():
                if k.startswith(prefix):
                    # Strip only the leading prefix (str.replace would strip
                    # every occurrence).
                    head_sd[k[len(prefix):]] = f.get_tensor(k)
    print(f"  head_sd keys collected: {list(head_sd.keys())}", flush=True)
    if not head_sd:
        raise RuntimeError(
            f"no '{prefix}*' tensors found in {CKPT}/model*.safetensors"
        )
    res = head.load_state_dict(head_sd, strict=True)
    print(f"  head loaded: {res}", flush=True)

    # Assigning a Module attribute registers it as a submodule, so model.to()
    # below moves the head to the GPU as well.
    model.forgery_head = head
    return model.to("cuda:0")


def _sample_examples():
    """Build all train examples and return a random subset of at most N_SAMPLES.

    Uses the global ``random`` state, so callers seed it first for
    reproducibility.
    """
    print("Building examples ...", flush=True)
    examples = build_examples(
        annot_dir=ANNOT, video_root=VROOT, generators=TRAIN_GENERATORS,
        split_prefix="train", preprocessed_data_path=CACHE, require_video_exists=True,
    )
    print(f"  {len(examples)} train examples", flush=True)
    random.shuffle(examples)
    examples = examples[:N_SAMPLES]
    print(f"  sampling {len(examples)}", flush=True)
    return examples


def _score_example(model, proc, ex):
    """Run the visual tower and the forgery head on one cached example.

    Args:
        model: model returned by ``_load_model_and_head``.
        proc: processor used to turn cached video inputs into pixel tensors.
        ex: example dict; reads ``video_path``, ``generator`` and ``solution``.

    Returns:
        ``(scores, labels, None)`` on success. ``scores`` are sigmoid head
        outputs for the first (only) video. ``labels`` come from
        ``frame_labels_from_segments``, one per head output step.
        ``(None, None, reason)`` when the example is skipped (missing cache or
        processor failure).
    """
    sample_id = os.path.splitext(os.path.basename(ex["video_path"]))[0]
    cache_dir = os.path.join(CACHE, "train", ex["generator"], sample_id)
    vi_path = os.path.join(cache_dir, "video_inputs.pt")
    if not os.path.exists(vi_path):
        return None, None, f"{sample_id}: missing cache {vi_path}"

    # weights_only=False unpickles arbitrary objects: only point CACHE at a
    # trusted, locally produced cache.
    video_inputs = torch.load(vi_path, weights_only=False)
    with open(os.path.join(cache_dir, "video_kwargs.json"), "r") as f:
        video_kwargs = json.load(f)

    # The text is a placeholder; only the packed video tensors are used.
    try:
        packed = proc(text=["dummy"], videos=video_inputs, padding=True,
                      return_tensors="pt", **video_kwargs)
    except Exception as e:  # best-effort: one bad cache entry must not abort the run
        return None, None, f"{sample_id}: {type(e).__name__}: {e}"

    pv = packed["pixel_values_videos"].to("cuda:0", dtype=torch.bfloat16)
    grid = packed["video_grid_thw"].to("cuda:0")
    with torch.no_grad():
        visual = model.visual(pv, grid_thw=grid)
        logits = model.forgery_head(visual, grid)[0].float().cpu()

    labels = frame_labels_from_segments(
        ex["solution"], int(logits.shape[0]), fps_to_groups=FPS_TO_GROUPS,
    )
    return torch.sigmoid(logits).numpy(), labels.numpy(), None


def _mann_whitney_auc(pos, neg):
    """Return the exact ROC AUC of positive vs negative scores.

    Computes P(pos > neg) + 0.5 * P(pos == neg) over all positive/negative
    pairs, using the rank-sum (Mann-Whitney U) identity in O(n log n). This
    avoids materialising an n_pos x n_neg comparison matrix. Returns nan if
    either side is empty.
    """
    n_pos, n_neg = len(pos), len(neg)
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    values = np.concatenate([
        np.asarray(pos, dtype=np.float64).ravel(),
        np.asarray(neg, dtype=np.float64).ravel(),
    ])
    _, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
    # Average 1-based rank per distinct value: tied entries share the mean of
    # the rank positions they occupy.
    avg_rank = np.cumsum(counts) - (counts - 1) / 2.0
    rank_sum_pos = avg_rank[inverse.ravel()[:n_pos]].sum()
    u_pos = rank_sum_pos - n_pos * (n_pos + 1) / 2.0
    return float(u_pos / (n_pos * n_neg))


def _print_report(n_sampled, failures, all_scores, all_labels, per_video_gap, per_gen):
    """Print the aggregate sanity report and a go/no-go recommendation."""
    print("\n========== HEAD SANITY REPORT ==========")
    print(f"sampled   : {n_sampled}  (failures: {failures})")
    print(f"video count w/ both pos+neg seconds: {len(per_video_gap)}")

    if all_scores:
        S = np.concatenate(all_scores).astype(np.float64)
        Y = np.concatenate(all_labels)
        print(f"total per-second labels: {len(S)}  ({int(Y.sum())} positive, {int((1-Y).sum())} negative)")
        pos_s = S[Y > 0.5]
        neg_s = S[Y < 0.5]
        if len(pos_s) and len(neg_s):
            mp, mn = pos_s.mean(), neg_s.mean()
            print(f"global mean score : POS={mp:.3f}  NEG={mn:.3f}  gap={mp-mn:+.3f}")
            print(f"global AUC (exact, rank-based): {_mann_whitney_auc(pos_s, neg_s):.3f}")
        else:
            # Means/AUC are undefined without both classes; avoid nan + warnings.
            print("global mean score / AUC: undefined (need both positive and negative seconds)")

    if per_video_gap:
        arr = np.array(per_video_gap)
        print(f"\nper-video (in_mean - out_mean) over {len(arr)} videos:")
        for q in [0, 10, 25, 50, 75, 90, 100]:
            print(f"  p{q:3d} = {np.percentile(arr, q):+.3f}")
        print(f"  mean = {arr.mean():+.3f}  std = {arr.std():.3f}")
        print(f"  fraction of videos with gap > 0.05 : {float((arr > 0.05).mean()):.2%}")
        print(f"  fraction of videos with gap > 0.15 : {float((arr > 0.15).mean()):.2%}")

    if per_gen:
        print("\nper-generator mean scores:")
        print(f"  {'gen':<12} {'n':>4} {'pos':>6} {'neg':>6} {'gap':>6}")
        for g in sorted(per_gen):
            pairs = per_gen[g]
            mp = np.mean([p[0] for p in pairs])
            mn = np.mean([p[1] for p in pairs])
            print(f"  {g:<12} {len(pairs):>4} {mp:>6.3f} {mn:>6.3f} {mp-mn:>+6.3f}")

    print("\nrecommendation:")
    if not per_video_gap:
        print("  ! degenerate (no videos with both pos+neg seconds) - cannot judge")
        return
    g = float(np.mean(per_video_gap))
    if g > 0.15:
        print(f"  ✓ strong signal (mean gap {g:+.3f}) — option C reward will have teeth")
    elif g > 0.05:
        print(f"  ~ moderate signal (mean gap {g:+.3f}) — option C may work but expect noisy gradients")
    else:
        print(f"  ✗ weak signal (mean gap {g:+.3f}) — head not discriminative enough; train head more before C")


def main():
    """Score a seeded sample of train videos with the stage1 ForgeryHead and report.

    Loads the checkpoint and its forgery head, samples up to N_SAMPLES train
    examples, and runs the visual tower plus head on each cached example.
    Head scores are compared against the GT labels, and per-video,
    per-generator and global statistics are printed.
    """
    from transformers import AutoProcessor

    random.seed(SEED)
    model = _load_model_and_head()
    # Load the processor once up front rather than lazily inside the loop.
    proc = AutoProcessor.from_pretrained(CKPT)
    examples = _sample_examples()

    all_scores, all_labels = [], []
    per_video_gap = []  # mean score inside GT minus mean score outside GT, per video
    per_gen = {}  # generator -> list of (m_in, m_out)
    failures = 0

    t0 = time.time()
    for i, ex in enumerate(examples, 1):
        scores, lbl, skip_reason = _score_example(model, proc, ex)
        if skip_reason is not None:
            failures += 1
            if failures <= 3:
                print(f"  [skip] {skip_reason}")
        else:
            all_scores.append(scores)
            all_labels.append(lbl)
            # The in/out gap is only defined when the video has both classes.
            if lbl.any() and not lbl.all():
                m_in = float(scores[lbl > 0.5].mean())
                m_out = float(scores[lbl < 0.5].mean())
                per_video_gap.append(m_in - m_out)
                per_gen.setdefault(ex["generator"], []).append((m_in, m_out))

        # Report progress every 25 examples, whether or not this one was skipped.
        if i % 25 == 0:
            running = f"{np.mean(per_video_gap):.3f}" if per_video_gap else "n/a"
            print(f"  [{i}/{len(examples)}] elapsed={time.time() - t0:.0f}s  "
                  f"running gap={running}  failures={failures}", flush=True)

    _print_report(len(examples), failures, all_scores, all_labels, per_video_gap, per_gen)
|
|
|
|
if __name__ == "__main__":
    # Script entry point only; importing this module runs nothing.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
|
|