# Source: forensics-grpo/code/evaluate_forensics.py
# (commit 33569f9 "Add source code", uploaded by sdzt; 11.6 kB)
"""Evaluate a trained forensics-GRPO model on the AF test split.
Adapted from evaluate.py (single-span Charades-style grounding) for:
- multi-segment localisation (list of (s, e) tuples per video)
- the forensics CoT prompt template (FORENSICS_COT toggle preserved)
- cached video_inputs.pt to avoid re-decoding
- multi-GPU sharding (one process per device)
- multiple matching metrics: soft_F1, mean_F1@{0.5,0.75,0.85,0.95}, hungarian_IoU
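Output:
    <out_dir>/rank_<r>.jsonl    one record per evaluated test video on this rank
    <out_dir>/summary.json      aggregate metrics (overall + per-generator).
                                NOTE: this script only writes the per-rank
                                JSONL files; summary.json is presumably
                                produced by a separate aggregation step.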
"""
import argparse
import json
import os
import random
import sys
import time
import torch
from transformers import (
AutoProcessor,
GenerationConfig,
Qwen2_5_VLForConditionalGeneration,
)
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.open_r1.data_loader import TEST_GENERATORS, build_examples
from src.open_r1.reward import (
hungarian_iou_reward,
mean_f1_at_tiou,
parse_segments,
soft_f1,
)
from src.open_r1.trainer.grpo_trainer_video_GT_soft import (
SYSTEM_PROMPT,
get_question_template,
)
from src.open_r1.verifier import (
ForensicsVerifier,
format_verifier_scores,
sample_id_from_video_path,
)
# Machine-specific absolute paths for the evaluation data.
# Annotation directory; passed to build_examples() as annot_dir.
ANNOT = "/mnt/local-fast/zhangt/annot/annot"
# Video root directory; passed to build_examples() as video_root.
VROOT = "/mnt/local-fast/zhangt/video"
# Preprocessed-input cache; passed to build_examples() as preprocessed_data_path
# and read per sample from <CACHE>/test/<generator>/<sample_id>/.
CACHE = "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0"
def get_args(argv=None):
    """Parse and validate the command-line options for one evaluation shard.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding the options declared below.

    Raises:
        SystemExit: raised by argparse for unknown/malformed options, and for
            an invalid sharding setup (``--world_size`` < 1 or ``--rank``
            outside ``[0, world_size)``).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model_path", required=True)
    p.add_argument("--rank", type=int, default=0)
    p.add_argument("--world_size", type=int, default=1)
    p.add_argument("--device", type=int, default=0,
                   help="cuda device index (set CUDA_VISIBLE_DEVICES to pin physical GPU)")
    p.add_argument("--out_dir", default="eval_outputs/stage2_verifier_grounded")
    p.add_argument("--cot", choices=["true", "false"], default="true",
                   help="Use CoT prompt template ('true') or no-CoT ('false').")
    p.add_argument("--cot_variant",
                   choices=["descriptive", "counterfactual", "counterfactual_parsimonious"],
                   default="descriptive",
                   help="CoT prompt variant; must match the variant used at training time.")
    p.add_argument("--verifier_context", choices=["true", "false"], default="false",
                   help="If true, inject external verifier per-second scores into the prompt.")
    p.add_argument("--verifier_ckpt",
                   default="/mnt/local-fast/zhangt/forensics_verifier_clip_l14/verifier_temporal_best.pt")
    p.add_argument("--verifier_cache", default="/mnt/local-fast/zhangt/forensics_verifier_clip_l14")
    p.add_argument("--max_new_tokens", type=int, default=640)
    p.add_argument("--temperature", type=float, default=0.0,
                   help="Greedy if 0 else sample with this temp.")
    p.add_argument("--limit", type=int, default=0, help="Cap number of videos per rank (0=all)")
    args = p.parse_args(argv)

    # The shard filter is `i % world_size == rank`: world_size == 0 would
    # divide by zero, and an out-of-range rank would silently select nothing.
    if args.world_size < 1:
        p.error(f"--world_size must be >= 1, got {args.world_size}")
    if not 0 <= args.rank < args.world_size:
        p.error(f"--rank must be in [0, {args.world_size}), got {args.rank}")
    return args
def load_cached(sample_dir):
    """Load the cached, pre-decoded video inputs for one sample.

    Args:
        sample_dir: Directory containing ``video_inputs.pt`` and
            ``video_kwargs.json``.

    Returns:
        A ``(feats, kwargs)`` tuple: the object stored in ``video_inputs.pt``
        (the caller indexes it with ``[0]``) and the parsed JSON content of
        ``video_kwargs.json`` (the caller reads ``kwargs["fps"][0]``).

    Raises:
        FileNotFoundError: if either file is missing.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # so this must only be pointed at caches produced by trusted code.
    # map_location="cpu" keeps loading independent of the device the cache was
    # saved from; the caller moves the processed inputs to its own GPU later.
    feats = torch.load(
        os.path.join(sample_dir, "video_inputs.pt"),
        map_location="cpu",
        weights_only=False,
    )
    with open(os.path.join(sample_dir, "video_kwargs.json"), "r", encoding="utf-8") as f:
        kwargs = json.load(f)
    return feats, kwargs
def _gencond_question(question, generator, mode, prob, wrong_rng, matched_rng):
    """Return the per-sample question, optionally prefixed with a generator-name sentence.

    Exactly one RNG draw is made per call in ``wrong`` mode (from ``wrong_rng``)
    and in ``matched`` mode (from ``matched_rng``); no draw is made otherwise.
    """
    if mode == "correct":
        named = generator
    elif mode == "wrong":
        others = [g for g in TEST_GENERATORS if g != generator]
        named = wrong_rng.choice(others)
    elif mode == "matched":
        named = generator if matched_rng.random() < prob else None
    else:
        named = None
    if named is None:
        return question
    return f"The forged segments in this video were generated by {named}. " + question


def main():
    """Evaluate this rank's shard of the test split and write one JSONL record per video.

    Reads its options from the command line (see ``get_args``) and from the
    environment:
        FORENSICS_GENCOND_MODE  none|matched|correct|wrong (default "none")
        FORENSICS_GENCOND_PROB  conditioning probability for "matched" (default 0.5)
    and exports FORENSICS_COT / FORENSICS_COT_VARIANT from ``--cot`` /
    ``--cot_variant`` before building the question template.

    Output is written to ``<out_dir>/rank_<rank>.jsonl``. Samples whose cached
    inputs are missing, or whose generation raises, are logged, counted as
    failed, and skipped.

    Raises:
        ValueError: if FORENSICS_GENCOND_MODE is not one of the accepted values
            (or FORENSICS_GENCOND_PROB is not a float).
    """
    args = get_args()
    device = f"cuda:{args.device}"
    os.makedirs(args.out_dir, exist_ok=True)
    # Toggle CoT/no-CoT prompt template via env var (the template fn reads it).
    os.environ["FORENSICS_COT"] = args.cot
    os.environ["FORENSICS_COT_VARIANT"] = args.cot_variant
    print(f"[rank {args.rank}/{args.world_size}] device={device} model={args.model_path}", flush=True)
    print(f" cot={args.cot} cot_variant={args.cot_variant} max_new_tokens={args.max_new_tokens} temp={args.temperature}", flush=True)
    if args.verifier_context == "true":
        # The flag is parsed but nothing below reads it; say so loudly instead
        # of letting the run look verifier-grounded when it is not.
        print(" [warn] --verifier_context=true was requested, but this script does not "
              "inject verifier scores into the prompt; evaluating with the plain prompt.",
              flush=True)

    t0 = time.time()
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        use_sliding_window=True,
        attn_implementation="flash_attention_2",
        device_map=device,
    )
    model.eval()
    processor = AutoProcessor.from_pretrained(args.model_path)
    print(f" loaded model+processor in {time.time()-t0:.1f}s", flush=True)

    # Build all test examples, then shard round-robin by index.
    examples = build_examples(
        annot_dir=ANNOT, video_root=VROOT, generators=TEST_GENERATORS,
        split_prefix="test", preprocessed_data_path=CACHE, require_video_exists=True,
    )
    examples = [ex for i, ex in enumerate(examples) if i % args.world_size == args.rank]
    if args.limit > 0:
        examples = examples[: args.limit]
    print(f" rank {args.rank} processes {len(examples)} test videos", flush=True)
    question = get_question_template()  # CoT or no-CoT depending on env var

    # v14 generator-conditional eval mode (env var FORENSICS_GENCOND_MODE):
    #   none    - plain prompt (default; deployment-equivalent, the headline ship number)
    #   matched - with probability FORENSICS_GENCOND_PROB prepend the correct gen name,
    #             else generic (mirrors the training distribution; this is the
    #             "what the model was optimized for" reading)
    #   correct - always prepend the true generator name (oracle upper bound)
    #   wrong   - always prepend a deterministic *other* generator name (token control)
    gencond_mode = os.getenv("FORENSICS_GENCOND_MODE", "none").lower()
    if gencond_mode not in ("none", "matched", "correct", "wrong"):
        raise ValueError(f"FORENSICS_GENCOND_MODE must be none|matched|correct|wrong, got {gencond_mode!r}")
    # matched-mode prob MUST equal the FORENSICS_GENCOND_PROB used at training
    # time (the 3-way driver script enforces this by exporting both from the
    # same value).
    gencond_prob = float(os.getenv("FORENSICS_GENCOND_PROB", "0.5"))
    # Deterministic per-rank RNGs so reruns of the same mode produce identical
    # prompts (wrong: which other gen; matched: which samples are conditioned).
    # The streams advance only for samples that reach prompt construction, so
    # the draw for a sample also depends on which earlier samples were skipped.
    wrong_rng = random.Random(0xC0FFEE + args.rank)
    matched_rng = random.Random(0xBEEF00 + args.rank)
    print(f" FORENSICS_GENCOND_MODE = {gencond_mode}"
          + (f" PROB = {gencond_prob}" if gencond_mode == "matched" else ""), flush=True)

    gen_cfg = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        do_sample=args.temperature > 0,
        temperature=max(args.temperature, 1e-6),
        pad_token_id=processor.tokenizer.pad_token_id,
        use_cache=True,
    )
    # Force use_cache=True in 3 places (HF merges model defaults into GenerationConfig and
    # silently overrides use_cache=False when gradient_checkpointing was on at train time).
    model.config.use_cache = True
    if hasattr(model, "generation_config"):
        model.generation_config.use_cache = True

    out_path = os.path.join(args.out_dir, f"rank_{args.rank}.jsonl")
    t_start = time.time()
    done = failed = 0
    # The context manager closes the file even if scoring (which is outside
    # the per-sample try) raises part-way through the shard.
    with open(out_path, "w", encoding="utf-8") as fout:
        for ex in examples:
            sample_id = os.path.splitext(os.path.basename(ex["video_path"]))[0]
            sample_dir = os.path.join(CACHE, "test", ex["generator"], sample_id)
            if not os.path.exists(os.path.join(sample_dir, "video_inputs.pt")):
                failed += 1
                print(f" [skip] {sample_id}: missing cached video_inputs.pt in {sample_dir}", flush=True)
                continue
            try:
                video_inputs, video_kwargs = load_cached(sample_dir)
                # Per-sample prompt: optionally prepend a generator-name sentence to
                # exercise the v14 generator-conditional pathway. The prepend string
                # is byte-identical to the trainer's injection (single source of truth
                # is grpo_trainer_video_GT_soft.py:make_conversation_video).
                q_text = _gencond_question(
                    question, ex["generator"], gencond_mode, gencond_prob,
                    wrong_rng, matched_rng,
                )
                # Build chat-template prompt BYTE-IDENTICAL to the trainer's
                # make_conversation_video (grpo_trainer_video_GT_soft.py:664-677):
                #   - system content is a plain string (NOT a list of blocks)
                #   - video block carries the same max_pixels/min_pixels/fps/max_frames
                # The video tensor is supplied separately via processor(videos=[...]);
                # the "video" key value is unused by apply_chat_template, but kwargs
                # are preserved verbatim to remove any divergence risk.
                chat = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": [
                        {"type": "video",
                         "video": ex["video_path"],
                         "max_pixels": 3584 * 28 * 28,
                         "min_pixels": 200704,
                         "fps": 2.0,
                         "max_frames": 64,
                         },
                        {"type": "text", "text": q_text},
                    ],
                    },
                ]
                text = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
                inputs = processor(
                    text=[text],
                    videos=[video_inputs[0]],
                    fps=[video_kwargs["fps"][0]],
                    padding=True,
                    return_tensors="pt",
                    padding_side="left",
                    add_special_tokens=False,
                )
                inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
                with torch.no_grad():
                    out_ids = model.generate(**inputs, generation_config=gen_cfg, use_cache=True)
                # Drop the prompt tokens; keep only the newly generated suffix.
                gen_ids = out_ids[0][inputs["input_ids"].shape[1]:]
                output_text = processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
            except Exception as e:
                # Deliberate best-effort: one bad sample must not abort the shard.
                failed += 1
                print(f" [skip] {sample_id}: {type(e).__name__}: {e}", flush=True)
                continue

            # Score.
            pred = parse_segments(output_text)
            gt = [tuple(s) for s in ex["solution"]]
            sf = soft_f1(pred, gt)
            mf = mean_f1_at_tiou(pred, gt)
            hg = hungarian_iou_reward(pred, gt)
            rec = {
                "sample_id": sample_id,
                "generator": ex["generator"],
                "gt": gt,
                "pred": pred,
                "output_text": output_text,
                "soft_F1": sf,
                "mean_F1_tIoU": mf,
                "hungarian_iou": hg,
                "n_pred": len(pred),
                "n_gt": len(gt),
                "parse_failed": len(pred) == 0,
            }
            # Flush per record so partial results survive a crash or preemption.
            fout.write(json.dumps(rec) + "\n")
            fout.flush()
            done += 1
            if done % 20 == 0:
                elapsed = time.time() - t_start
                rate = done / max(1e-6, elapsed)
                remaining = (len(examples) - done - failed) / max(1e-6, rate)
                print(
                    f" rank={args.rank} done={done} fail={failed} "
                    f"rate={rate:.2f}/s eta={remaining/60:.1f}min",
                    flush=True,
                )
    print(f"[rank {args.rank}] DONE done={done} failed={failed} elapsed={time.time()-t_start:.0f}s", flush=True)
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()