# forensics-grpo / code /fbc_validate.py
# sdzt's picture
# Add source code
# 33569f9 verified
# Raw
# History Blame Contribute Delete
# 11.2 kB
"""Forward-Backward Consistency (FBC) signal validation.
Hypothesis: Real-world video has temporal asymmetry under reversal (gravity,
momentum, causal flow); AI-generated segments often lack this asymmetry.
So a model trained for forgery localization should produce SIMILAR predictions
on the forward and reversed versions of the same video — because the AI
artifact carries through reversal, while real content gets "weird" enough to
suppress false-positive detection.
Quantitative test: run stage1_decomp_boundary ckpt on each test sample twice
(forward video / temporally-flipped video). Map reversed-prediction back to
original coordinates and measure:
IoU(pred_F, GT) — forward accuracy (baseline)
IoU(pred_R_remapped, GT) — reverse accuracy
IoU(pred_F, pred_R_remapped) — KEY: model self-consistency under reversal
For FBC to be a useful GRPO reward:
1. mean IoU(F, R) should be substantially > 0 (i.e. model IS consistent
— if it's near 0, reverse video is just confusing the model and we
can't extract a forensic signal from it).
2. corr(IoU(F, R), IoU(F, GT)) > 0 — consistent predictions correlate
with correct predictions. This is what makes "push toward consistency"
a valid training pressure.
3. Per-generator analysis: AI-heavy generators (wan, ltx, vace, fcvg)
should have higher IoU(F, R) than less-AI generators if the hypothesis
about AI lacking temporal causality holds.
If (1) and (2) fail, FBC is not a usable signal and we need a different idea.
"""
from __future__ import annotations
import argparse
import glob
import json
import os
import sys
import time
from pathlib import Path
import numpy as np
import torch
from transformers import (
AutoProcessor,
GenerationConfig,
Qwen2_5_VLForConditionalGeneration,
)
# Absolute location of the project checkout; hard-coded for the original machine.
REPO = Path("/mnt/local-fast/zhangt/forensics_grpo")
# Put the repo root and its src/ directory at the front of sys.path so the
# `src.open_r1...` imports that follow can be resolved.
sys.path.insert(0, str(REPO))
sys.path.insert(0, str(REPO / "src"))
from src.open_r1.data_loader import TEST_GENERATORS, build_examples
from src.open_r1.reward import parse_segments
from src.open_r1.trainer.grpo_trainer_video_GT_soft import (
SYSTEM_PROMPT,
get_question_template,
)
# Hard-coded dataset locations; main() hands all three to build_examples().
# VROOT is passed as `video_root`.
VROOT = "/mnt/local-fast/zhangt/video"
# ANNOT is passed as `annot_dir`.
ANNOT = "/mnt/local-fast/zhangt/annot/annot"
# CACHE is passed as `preprocessed_data_path`; main() also reads
# <CACHE>/test/<generator>/<sample_id>/video_inputs.pt and video_kwargs.json from it.
CACHE = "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0"
def iou_1d(a, b):
    """Return the 1-D intersection-over-union of two ``(start, end)`` intervals.

    The denominator is the span from the earliest start to the latest end.
    That equals the true union whenever the intervals overlap; when they are
    disjoint the numerator is already 0, so the result is 0 either way.

    Returns 0.0 when that span is not strictly positive.
    """
    (lo_a, hi_a), (lo_b, hi_b) = a, b
    span = max(hi_a, hi_b) - min(lo_a, lo_b)
    if span > 0:
        overlap = min(hi_a, hi_b) - max(lo_a, lo_b)
        # A negative overlap means the intervals do not touch.
        return max(0.0, overlap) / span
    return 0.0
def soft_f1_iou(preds, gts):
    """Set-level soft IoU = soft_F1 of pairwise IoU matrix (matches reward.py).

    "Precision" is the mean, over predicted intervals, of the best IoU with
    any ground-truth interval; "recall" is the same with the roles swapped.
    The result is their harmonic mean.

    Returns 1.0 when both sets are empty, 0.0 when exactly one is empty or
    when precision + recall is not positive.
    """
    if not preds or not gts:
        # Both empty counts as perfect agreement; exactly one empty is a miss.
        return 0.0 if (preds or gts) else 1.0

    def _mean_best_match(queries, targets):
        # For every query interval keep its best IoU against all targets.
        best = [max(iou_1d(q, t) for t in targets) for q in queries]
        return sum(best) / len(best)

    precision = _mean_best_match(preds, gts)
    recall = _mean_best_match(gts, preds)
    total = precision + recall
    if total > 0:
        return 2 * precision * recall / total
    return 0.0
def remap_reversed(segs, duration):
    """Map intervals from reversed-time coords back to original coords.

    A reversed-time interval ``(s, e)`` becomes ``(duration - e, duration - s)``
    with both ends clamped below at 0.0. Returns a new list of tuples in the
    same order as ``segs``.
    """
    remapped = []
    for rev_start, rev_end in segs:
        # The end of the reversed interval is the start in original time.
        orig_start = max(0.0, duration - rev_end)
        orig_end = max(0.0, duration - rev_start)
        remapped.append((orig_start, orig_end))
    return remapped
def run_inference(model, processor, video_tensor, fps, question, gen_cfg, device):
    """Generate the model's text answer to ``question`` for one video.

    Builds a system + user chat prompt, runs it through ``processor`` together
    with ``video_tensor`` and ``fps``, moves the resulting inputs to ``device``,
    calls ``model.generate`` under ``torch.no_grad()`` and decodes the tokens
    that follow the prompt.

    Returns the decoded string with special tokens skipped.
    """
    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": SYSTEM_PROMPT}],
    }
    user_turn = {
        "role": "user",
        "content": [
            # Only a placeholder entry for the template; the real frames are
            # supplied to the processor through `videos=` below.
            {"type": "video", "video": "placeholder"},
            {"type": "text", "text": question},
        ],
    }
    prompt = processor.apply_chat_template(
        [system_turn, user_turn], tokenize=False, add_generation_prompt=True
    )
    batch = processor(
        text=[prompt],
        videos=[video_tensor],
        fps=[fps],
        padding=True,
        return_tensors="pt",
        padding_side="left",
        add_special_tokens=False,
    )

    # Move anything tensor-like to the target device; leave other values alone.
    on_device = {}
    for key, value in batch.items():
        on_device[key] = value.to(device) if hasattr(value, "to") else value

    with torch.no_grad():
        generated = model.generate(**on_device, generation_config=gen_cfg, use_cache=True)

    # Drop the first `prompt_len` positions so only the completion is decoded
    # (assumes generate() returns prompt + completion — TODO confirm).
    prompt_len = on_device["input_ids"].shape[1]
    completion_ids = generated[0][prompt_len:]
    return processor.tokenizer.decode(completion_ids, skip_special_tokens=True)
def _print_summary(records):
    """Print the aggregate FBC report for the evaluated samples.

    Prints overall mean/median of the three IoU metrics, the two validation
    criteria (mean iou_f_r > 0.3, corr(iou_f_r, iou_f_gt) > 0.2), a
    per-generator table sorted by iou_f_r, and the final verdict.

    Args:
        records: list of per-sample dicts holding at least "generator",
            "iou_f_gt", "iou_r_gt" and "iou_f_r".
    """
    print(f"\n=== FBC SIGNAL VALIDATION SUMMARY (n={len(records)}) ===")
    if not records:
        # With zero rows np.array([]) has shape (0,), so the column indexing
        # below would raise IndexError; report and stop instead.
        print("No samples were evaluated successfully; nothing to summarize.")
        return

    A = np.array([(r["iou_f_gt"], r["iou_r_gt"], r["iou_f_r"]) for r in records])
    iou_f_gt, iou_r_gt, iou_f_r = A[:, 0], A[:, 1], A[:, 2]
    print("\nOverall:")
    print(f" iou_f_gt (forward acc) : mean={iou_f_gt.mean():.3f} median={np.median(iou_f_gt):.3f}")
    print(f" iou_r_gt (reverse acc) : mean={iou_r_gt.mean():.3f} median={np.median(iou_r_gt):.3f}")
    print(f" iou_f_r (consistency) : mean={iou_f_r.mean():.3f} median={np.median(iou_f_r):.3f} "
          f">0.5 frac={(iou_f_r > 0.5).mean()*100:.1f}%")

    # Validation criterion 1: is iou_f_r substantially > 0?
    crit1 = iou_f_r.mean() > 0.3
    print(f"\n[Criterion 1] mean iou_f_r > 0.3? {'PASS' if crit1 else 'FAIL'} "
          f"({iou_f_r.mean():.3f})")
    print(" Interpretation: " +
          ("model IS consistent under reversal — signal exists" if crit1 else
           "model produces unrelated predictions on reversed input — no useful signal"))

    # Validation criterion 2: does iou_f_r correlate with iou_f_gt?
    # The correlation is undefined for constant series, hence the std() guards.
    if len(A) > 3 and iou_f_r.std() > 0 and iou_f_gt.std() > 0:
        corr = np.corrcoef(iou_f_r, iou_f_gt)[0, 1]
    else:
        corr = 0.0
    crit2 = corr > 0.2
    print(f"\n[Criterion 2] corr(iou_f_r, iou_f_gt) > 0.2? {'PASS' if crit2 else 'FAIL'} "
          f"({corr:.3f})")
    print(" Interpretation: " +
          ("consistency under reversal predicts correctness — FBC reward will steer toward right answers" if crit2 else
           "consistency is uncorrelated with correctness — FBC reward will push toward random consistency"))

    # Per-generator breakdown
    print("\nPer-generator (sorted by iou_f_r):")
    by_gen = {}
    for r in records:
        by_gen.setdefault(r["generator"], []).append(r)
    rows = []
    for g, rs in by_gen.items():
        arr = np.array([(x["iou_f_gt"], x["iou_f_r"]) for x in rs])
        rows.append((g, len(rs), arr[:, 0].mean(), arr[:, 1].mean()))
    for g, n, fg, fr in sorted(rows, key=lambda x: -x[3]):
        print(f" {g:<12s} n={n:3d} iou_f_gt={fg:.3f} iou_f_r={fr:.3f}")

    # Verdict
    print(f"\n{'='*60}")
    if crit1 and crit2:
        print("VERDICT: FBC signal exists. Proceed to implement as GRPO reward.")
    elif crit1 and not crit2:
        print("VERDICT: model is consistent but not in a useful way. FBC alone "
              "won't steer training; combine with iou or rethink.")
    else:
        print("VERDICT: FBC signal absent. Reversed video doesn't elicit meaningful "
              "model behavior. Rethink the spatial / temporal causality framing.")


def main():
    """Run the forward/backward consistency validation and print a report.

    Steps:
      1. Parse CLI arguments and load the checkpoint and its processor.
      2. Take the first ``--n`` test examples that have cached features.
      3. For each sample, run inference on the cached frames and on the
         frame-reversed copy, remap the reversed prediction to original time
         and compute the three soft-F1 IoU scores.
      4. Append one JSON line per sample to ``--out`` (flushed immediately).
      5. Print the aggregate summary via ``_print_summary``.

    Samples that raise during loading/inference/scoring are reported and
    skipped rather than aborting the run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_path", default=str(REPO / "outputs_forensics/stage1_decomp_boundary"))
    ap.add_argument("--n", type=int, default=200, help="number of test samples to evaluate")
    ap.add_argument("--device", default="cuda:0")
    ap.add_argument("--max_new_tokens", type=int, default=64)
    ap.add_argument("--out", default=str(REPO / "fbc_signal_validation.jsonl"))
    args = ap.parse_args()

    # No-CoT prompt since stage1 was trained without CoT.
    # NOTE(review): this is set after the project modules were imported, so it
    # only takes effect if the variable is read lazily (presumably inside
    # get_question_template()) — confirm.
    os.environ["FORENSICS_COT"] = "false"

    print(f"[fbc-validate] device={args.device} model={args.model_path} n={args.n}",
          flush=True)
    t0 = time.time()
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16,
        use_sliding_window=True, attn_implementation="flash_attention_2",
        device_map=args.device,
    )
    model.eval()
    processor = AutoProcessor.from_pretrained(args.model_path)
    model.config.use_cache = True
    if hasattr(model, "generation_config"):
        model.generation_config.use_cache = True
    print(f" loaded in {time.time()-t0:.1f}s", flush=True)

    examples = build_examples(
        annot_dir=ANNOT, video_root=VROOT, generators=TEST_GENERATORS,
        split_prefix="test", preprocessed_data_path=CACHE, require_video_exists=True,
    )

    # Deterministic sample: first N with cached features.
    sampled = []
    for ex in examples:
        sample_id = os.path.splitext(os.path.basename(ex["video_path"]))[0]
        sample_dir = os.path.join(CACHE, "test", ex["generator"], sample_id)
        if os.path.exists(os.path.join(sample_dir, "video_inputs.pt")):
            sampled.append((ex, sample_id, sample_dir))
        if len(sampled) >= args.n:
            break
    print(f" using {len(sampled)} samples", flush=True)

    question = get_question_template()
    gen_cfg = GenerationConfig(
        max_new_tokens=args.max_new_tokens, do_sample=False,
        temperature=1e-6,
        pad_token_id=processor.tokenizer.pad_token_id, use_cache=True,
    )

    records = []
    t_start = time.time()
    # The context manager closes the JSONL file even if the run is interrupted
    # or something outside the per-sample try block raises.
    with open(args.out, "w") as fout:
        for i, (ex, sample_id, sample_dir) in enumerate(sampled):
            try:
                # NOTE(security): weights_only=False unpickles arbitrary
                # objects; only acceptable because the cache is locally
                # generated and trusted.
                feats = torch.load(os.path.join(sample_dir, "video_inputs.pt"),
                                   weights_only=False)
                with open(os.path.join(sample_dir, "video_kwargs.json")) as f:
                    kw = json.load(f)
                video_f = feats[0]  # (T, C, H, W)
                # Reverse along the frame axis to get the backward video.
                video_r = video_f.flip(0).contiguous()
                fps = kw["fps"][0]
                duration = video_f.shape[0] / fps

                out_f = run_inference(model, processor, video_f, fps, question, gen_cfg, args.device)
                pred_f = parse_segments(out_f)
                out_r = run_inference(model, processor, video_r, fps, question, gen_cfg, args.device)
                pred_r = parse_segments(out_r)
                pred_r_remapped = remap_reversed(pred_r, duration)

                gt = [tuple(s) for s in ex["solution"]]
                iou_f_gt = soft_f1_iou(pred_f, gt)
                iou_r_gt = soft_f1_iou(pred_r_remapped, gt)
                iou_f_r = soft_f1_iou(pred_f, pred_r_remapped)
            except Exception as e:
                # Deliberate best-effort: one bad sample must not kill the run.
                print(f" [skip] {sample_id}: {type(e).__name__}: {e}", flush=True)
                continue

            rec = {
                "sample_id": sample_id,
                "generator": ex["generator"],
                "duration": duration,
                "gt": gt,
                "pred_f": pred_f,
                "pred_r_remapped": pred_r_remapped,
                "iou_f_gt": iou_f_gt,
                "iou_r_gt": iou_r_gt,
                "iou_f_r": iou_f_r,
                "n_pred_f": len(pred_f),
                "n_pred_r": len(pred_r),
                "n_gt": len(gt),
            }
            records.append(rec)
            # Flush per record so partial results survive a crash.
            fout.write(json.dumps(rec) + "\n")
            fout.flush()

            if (i + 1) % 20 == 0:
                elapsed = time.time() - t_start
                rate = (i + 1) / elapsed
                eta = (len(sampled) - i - 1) / rate
                cur = np.array([(r["iou_f_gt"], r["iou_r_gt"], r["iou_f_r"]) for r in records])
                print(f" i={i+1}/{len(sampled)} rate={rate:.2f}/s eta={eta/60:.1f}min "
                      f"f_gt={cur[:,0].mean():.3f} r_gt={cur[:,1].mean():.3f} f_r={cur[:,2].mean():.3f}",
                      flush=True)

    _print_summary(records)
# Script entry point: run the validation only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()