forensics-grpo / code /head_sanity.py
sdzt's picture
Add source code
33569f9 verified
Raw
History Blame Contribute Delete
8.93 kB
"""Sanity-check the stage1 ForgeryHead on a sample of train videos.
For each sampled video we:
- load the cached video_inputs / video_kwargs
- run model.visual(...) -> visual features
- run model.forgery_head(...) -> per-second logits, sigmoid -> scores
- compare against GT segments (per-second binary labels)
Aggregate stats reported:
- global AUC across all per-second labels
- mean head score inside vs outside GT
- distribution of (in - out) gap per video
- per-generator breakdown
"""
import json
import os
import random
import sys
import time
import numpy as np
import torch
from transformers import Qwen2_5_VLForConditionalGeneration
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.open_r1.data_loader import (
GENERATOR_TO_DIR, TRAIN_GENERATORS, build_examples,
)
from src.open_r1.forgery_head import (
ForgeryHead, frame_labels_from_segments, head_auc as _head_auc,
)
# Stage1 checkpoint directory: loaded with from_pretrained, scanned for the
# safetensors file(s) holding forgery_head.* weights, and used for the processor.
CKPT = "/mnt/local-fast/zhangt/forensics_grpo/outputs_forensics/stage1_forgery"
# Dataset locations handed to build_examples (annot_dir / video_root).
ANNOT = "/mnt/local-fast/zhangt/annot/annot"
VROOT = "/mnt/local-fast/zhangt/video"
# Preprocessed-input cache; per-video dirs live at CACHE/train/<generator>/<sample_id>.
CACHE = "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0"
# Sampling: number of shuffled train examples to score, and the RNG seed used by main().
N_SAMPLES, SEED = 250, 42
# Passed as fps_to_groups to frame_labels_from_segments when building per-second labels.
FPS_TO_GROUPS = 1.0
def _load_forgery_head(model):
    """Build a ForgeryHead sized like the train-time head and load its weights.

    ``Qwen2_5_VLForConditionalGeneration.from_pretrained`` silently drops the
    ``forgery_head.*`` keys, so they are read directly from the safetensors
    file(s) in ``CKPT``.

    Returns:
        The head in bfloat16, in eval mode, on CPU.

    Raises:
        RuntimeError: if no ``forgery_head.*`` tensors are found in CKPT.
    """
    import glob

    import safetensors.torch as st

    head = ForgeryHead(hidden_dim=model.config.hidden_size, mlp_dim=1024)
    head.to(dtype=torch.bfloat16)

    prefix = "forgery_head."
    head_sd = {}
    # "model*.safetensors" matches both sharded (model-0000X-of-0000N) and
    # single-file (model.safetensors) checkpoints; "model-*" missed the latter.
    for p in sorted(glob.glob(os.path.join(CKPT, "model*.safetensors"))):
        with st.safe_open(p, framework="pt") as f:
            for k in f.keys():
                if k.startswith(prefix):
                    # Strip only the leading prefix (str.replace would also
                    # rewrite any later occurrence inside the key).
                    head_sd[k[len(prefix):]] = f.get_tensor(k)
    print(f" head_sd keys collected: {list(head_sd.keys())}", flush=True)
    if not head_sd:
        raise RuntimeError(f"no '{prefix}*' tensors found in {CKPT}/model*.safetensors")
    res = head.load_state_dict(head_sd, strict=True)
    print(f" head loaded: {res}", flush=True)
    # A freshly constructed nn.Module starts in training mode; switch to eval
    # so mode-dependent layers (e.g. dropout, if the head has any) are off.
    head.eval()
    return head


def _pack_cached_video(proc, cache_dir):
    """Load cached video_inputs.pt / video_kwargs.json and run the processor.

    Returns the processor output; callers read ``pixel_values_videos`` and
    ``video_grid_thw`` from it. Any load/parse/processor error propagates.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects; CACHE must
    # only contain trusted, locally generated files.
    video_inputs = torch.load(os.path.join(cache_dir, "video_inputs.pt"), weights_only=False)
    with open(os.path.join(cache_dir, "video_kwargs.json"), "r", encoding="utf-8") as f:
        video_kwargs = json.load(f)
    # The text is a placeholder; only the video tensors are consumed downstream.
    return proc(text=["dummy"], videos=video_inputs, padding=True,
                return_tensors="pt", **video_kwargs)


def _forward_head(model, head, packed):
    """Return the head's per-second logits (float32, CPU) for one packed video."""
    pv = packed["pixel_values_videos"].to("cuda:0", dtype=torch.bfloat16)
    grid = packed["video_grid_thw"].to("cuda:0")
    with torch.no_grad():
        visual = model.visual(pv, grid_thw=grid)  # (N_tot, hidden)
        logits_list = head(visual, grid)  # list of (T,)
    return logits_list[0].float().cpu()


def _mann_whitney_auc(pos, neg):
    """Return the ROC AUC P(pos > neg) + 0.5 * P(pos == neg).

    Exact, in O((n_pos + n_neg) log n_neg) via binary search over the sorted
    negatives, so no subsampling or O(n_pos * n_neg) matrix is needed.
    Returns nan if either side is empty.
    """
    pos = np.asarray(pos, dtype=np.float64)
    neg = np.sort(np.asarray(neg, dtype=np.float64))
    if pos.size == 0 or neg.size == 0:
        return float("nan")
    n_below = np.searchsorted(neg, pos, side="left")  # negatives strictly below each positive
    n_ties = np.searchsorted(neg, pos, side="right") - n_below  # negatives equal to it
    return float((n_below.sum() + 0.5 * n_ties.sum()) / (pos.size * neg.size))


def _print_report(n_sampled, failures, all_scores, all_labels, per_video_in_minus_out, per_gen):
    """Print aggregate head statistics and a recommendation for option C."""
    print("\n========== HEAD SANITY REPORT ==========")
    print(f"sampled : {n_sampled} (failures: {failures})")
    print(f"video count w/ both pos+neg seconds: {len(per_video_in_minus_out)}")
    if all_scores:
        S = np.concatenate(all_scores)
        Y = np.concatenate(all_labels)
        pos_s = S[Y > 0.5]
        neg_s = S[Y < 0.5]
        print(f"total per-second labels: {len(S)} ({len(pos_s)} positive, {len(neg_s)} negative)")
        if len(pos_s) and len(neg_s):
            m_pos, m_neg = pos_s.mean(), neg_s.mean()
            print(f"global mean score : POS={m_pos:.3f} NEG={m_neg:.3f} gap={m_pos - m_neg:+.3f}")
            print(f"global AUC (exact Mann-Whitney): {_mann_whitney_auc(pos_s, neg_s):.3f}")
        else:
            print("global mean score / AUC: undefined (need both positive and negative seconds)")
    if per_video_in_minus_out:
        arr = np.array(per_video_in_minus_out)
        print(f"\nper-video (in_mean - out_mean) over {len(arr)} videos:")
        for q in [0, 10, 25, 50, 75, 90, 100]:
            print(f" p{q:3d} = {np.percentile(arr, q):+.3f}")
        print(f" mean = {arr.mean():+.3f} std = {arr.std():.3f}")
        frac_useful = float((arr > 0.05).mean())
        print(f" fraction of videos with gap > 0.05 : {frac_useful:.2%}")
        frac_strong = float((arr > 0.15).mean())
        print(f" fraction of videos with gap > 0.15 : {frac_strong:.2%}")
    if per_gen:
        print("\nper-generator mean scores:")
        print(f" {'gen':<12} {'n':>4} {'pos':>6} {'neg':>6} {'gap':>6}")
        for g in sorted(per_gen):
            pairs = per_gen[g]
            mp = np.mean([p[0] for p in pairs])
            mn = np.mean([p[1] for p in pairs])
            print(f" {g:<12} {len(pairs):>4} {mp:>6.3f} {mn:>6.3f} {mp-mn:>+6.3f}")
    print("\nrecommendation:")
    if not per_video_in_minus_out:
        print(" ! degenerate (no videos with both pos+neg seconds) - cannot judge")
        return
    g = float(np.mean(per_video_in_minus_out))
    if g > 0.15:
        print(f" ✓ strong signal (mean gap {g:+.3f}) — option C reward will have teeth")
    elif g > 0.05:
        print(f" ~ moderate signal (mean gap {g:+.3f}) — option C may work but expect noisy gradients")
    else:
        print(f" ✗ weak signal (mean gap {g:+.3f}) — head not discriminative enough; train head more before C")


def main():
    """Score a random sample of train videos with the stage1 ForgeryHead.

    Loads the checkpoint at CKPT plus the forgery head weights, runs
    model.visual + head on N_SAMPLES cached train videos, compares per-second
    sigmoid scores to GT segment labels and prints an aggregate report.
    Requires a CUDA device ("cuda:0") and the cached inputs under CACHE.
    """
    random.seed(SEED)
    print(f"Loading model from {CKPT} ...", flush=True)
    t0 = time.time()
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        CKPT, torch_dtype=torch.bfloat16, attn_implementation="sdpa",
    )
    print(f" loaded in {time.time()-t0:.1f}s. param dtype={next(model.parameters()).dtype}", flush=True)

    head = _load_forgery_head(model)
    model.forgery_head = head
    model = model.to("cuda:0")  # also moves the attached head submodule
    # eval() must run AFTER the head is attached; calling it earlier left the
    # newly created head in training mode.
    model.eval()

    # The processor is loop-invariant: load it once up front.
    from transformers import AutoProcessor
    proc = AutoProcessor.from_pretrained(CKPT)

    print("Building examples ...", flush=True)
    examples = build_examples(
        annot_dir=ANNOT, video_root=VROOT, generators=TRAIN_GENERATORS,
        split_prefix="train", preprocessed_data_path=CACHE, require_video_exists=True,
    )
    print(f" {len(examples)} train examples", flush=True)
    random.shuffle(examples)
    examples = examples[:N_SAMPLES]
    print(f" sampling {len(examples)}", flush=True)

    all_scores = []  # per-video sigmoid scores (one array per scored video)
    all_labels = []  # matching per-second GT labels
    per_video_in_minus_out = []
    per_gen = {}  # gen -> list of (mean_in, mean_out)
    failures = 0
    t0 = time.time()
    for i, ex in enumerate(examples, 1):
        sample_id = os.path.splitext(os.path.basename(ex["video_path"]))[0]
        gen = ex["generator"]
        cache_dir = os.path.join(CACHE, "train", gen, sample_id)

        packed = None
        if os.path.exists(os.path.join(cache_dir, "video_inputs.pt")):
            try:
                packed = _pack_cached_video(proc, cache_dir)
            except Exception as e:  # corrupt cache entry / processor error: skip video
                if failures < 3:  # print only the first few failures
                    print(f" [skip] {sample_id}: {type(e).__name__}: {e}", flush=True)

        if packed is None:
            failures += 1
        else:
            logits = _forward_head(model, head, packed)
            T = int(logits.shape[0])
            labels = frame_labels_from_segments(ex["solution"], T, fps_to_groups=FPS_TO_GROUPS)
            scores = torch.sigmoid(logits).numpy()
            lbl = labels.numpy()
            all_scores.append(scores)
            all_labels.append(lbl)
            # The in/out gap is only defined when the video has both classes.
            if lbl.any() and not lbl.all():
                m_in = float(scores[lbl > 0.5].mean())
                m_out = float(scores[lbl < 0.5].mean())
                per_video_in_minus_out.append(m_in - m_out)
                per_gen.setdefault(gen, []).append((m_in, m_out))

        # Progress is reported for failed videos too (no early `continue`).
        if i % 25 == 0:
            running = (f"{np.mean(per_video_in_minus_out):.3f}"
                       if per_video_in_minus_out else "n/a")
            print(f" [{i}/{len(examples)}] elapsed={time.time() - t0:.0f}s "
                  f"running gap={running} "
                  f"failures={failures}", flush=True)

    _print_report(len(examples), failures, all_scores, all_labels,
                  per_video_in_minus_out, per_gen)


if __name__ == "__main__":
    main()