# Source: forensics-grpo/code/src/open_r1/verifier.py
# Commit: 33569f9 ("Add source code")
"""External forensics verifier — CLIP frozen features + trained temporal head.
This module is the bridge between the offline-trained verifier (see
verifier_m2_train_temporal.py) and the RL trainer. It exposes a single
callable that, given a training-set video identifier, returns per-second
forgery scores (numpy array, shape (T,), values in [0, 1]).
Key design choices:
- CLIP features are pre-extracted ONCE (verifier_m2_extract_clip.py) and cached.
At RL training time we only run the small temporal head, ~ms latency.
- Train-cache layout matches the existing forensics preprocess cache:
<CACHE_ROOT>/<split>/<gen>/<sample_id>/clip_feats.pt
This lets the trainer look up scores by the same (split, gen, sample_id)
used elsewhere in the codebase.
- The verifier is a STATIC reward shaper — it is loaded once, kept frozen on a
dedicated GPU slot if available, and never updates during RL.
"""
import os
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Model definition (must match verifier_m2_train_temporal.py)
# ---------------------------------------------------------------------------
class TemporalVerifier(nn.Module):
    """Transformer encoder over a feature sequence, emitting one logit per step.

    Pipeline: linear projection to ``hidden`` -> add learned positional
    embedding -> pre-norm Transformer encoder (GELU, batch-first) -> LayerNorm
    -> linear head producing a single value per time step.

    Submodule/parameter attribute names are the state-dict keys, so they must
    stay in sync with whatever script produced the checkpoint.
    """

    def __init__(self, in_dim=768, hidden=384, num_layers=4, num_heads=8,
                 dropout=0.1, max_len=512):
        super().__init__()

        # Input projection followed by a learned positional table of
        # ``max_len`` slots; longer sequences cannot be embedded.
        self.in_proj = nn.Linear(in_dim, hidden)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, hidden))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=num_heads,
            dim_feedforward=hidden * 4,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Final normalisation and scalar read-out.
        self.norm = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x, mask=None):
        """Map ``x`` of shape (B, T, in_dim) to logits of shape (B, T).

        ``mask`` is optional; when given it is inverted with ``~`` and passed
        as the encoder's ``src_key_padding_mask``, i.e. the caller marks the
        positions to KEEP and the encoder receives the positions to ignore.
        """
        _batch, seq_len, _feat = x.shape

        padding_mask = None
        if mask is not None:
            padding_mask = ~mask

        # Only the first ``seq_len`` positional slots are used.
        tokens = self.in_proj(x) + self.pos_emb[:, :seq_len]
        encoded = self.encoder(tokens, src_key_padding_mask=padding_mask)
        logits = self.head(self.norm(encoded))
        return logits.squeeze(-1)
# ---------------------------------------------------------------------------
# Verifier wrapper #
# ---------------------------------------------------------------------------
class ForensicsVerifier:
    """Loads a trained temporal verifier and provides per-video score lookup.

    The checkpoint is read once in the constructor and the model is switched
    to eval mode with dropout set to 0. Per-video features are read lazily
    from ``<cache_root>/<split>/<generator>/<sample_id>/clip_feats.pt`` and
    kept in an in-memory dict so repeated lookups do not touch the disk.

    Usage:
        verifier = ForensicsVerifier(
            ckpt="/mnt/.../verifier_temporal_best.pt",
            cache_root="/mnt/.../forensics_verifier_clip_l14",
            device="cuda:0",
        )
        scores = verifier.scores_for("train", "scifi", "v_abc...")
        # scores: numpy (T,), per-second forgery prob in [0, 1]
    """

    # File name of the pre-extracted feature tensor inside each sample dir.
    _FEATS_FILENAME = "clip_feats.pt"

    def __init__(self, ckpt: str, cache_root: str, device: str = "cuda:0"):
        """Build the model from ``ckpt`` and place it on ``device``.

        Args:
            ckpt: Path to a checkpoint dict holding ``"model_state"`` and,
                optionally, ``"args"`` (hyper-parameters) and ``"max_T"``.
            cache_root: Root directory of the per-video feature cache.
            device: Torch device string the model runs on.

        Raises:
            FileNotFoundError: If ``ckpt`` does not exist.
        """
        self.cache_root = cache_root
        self.device = device
        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"verifier checkpoint missing: {ckpt}")

        # SECURITY NOTE: weights_only=False performs full unpickling, which
        # can execute arbitrary code. Only point this at checkpoints you
        # produced yourself or otherwise trust.
        state = torch.load(ckpt, map_location="cpu", weights_only=False)
        args = state.get("args", {}) or {}
        max_T = int(state.get("max_T", 512))
        self.model = TemporalVerifier(
            in_dim=768,
            hidden=args.get("hidden", 384),
            num_layers=args.get("num_layers", 4),
            num_heads=args.get("num_heads", 8),
            dropout=0.0,  # disable dropout at inference
            max_len=max_T + 1,
        ).to(device).eval()
        self.model.load_state_dict(state["model_state"])

        # In-memory feature cache: (split, gen, sample_id) -> tensor.
        self._feat_cache: dict = {}

    def _load_feats(self, key):
        """Return the feature tensor for ``key``, or None if no file exists.

        ``key`` is a ``(split, generator, sample_id)`` tuple. A tensor already
        held in memory is returned as-is; otherwise it is read from disk and
        remembered for later calls.
        """
        feats = self._feat_cache.get(key)
        if feats is not None:
            return feats

        path = os.path.join(self.cache_root, *key, self._FEATS_FILENAME)
        if not os.path.exists(path):
            return None

        # map_location="cpu": load regardless of the device the tensor was
        # saved from, and keep the (potentially large) cache in host memory.
        feats = torch.load(path, map_location="cpu", weights_only=True)
        self._feat_cache[key] = feats  # RL re-uses the same videos
        return feats

    @torch.no_grad()
    def scores_for(self, split: str, generator: str, sample_id: str) -> Optional[np.ndarray]:
        """Return per-second forgery probability array, or None if not cached.

        The stored features are moved to ``self.device`` as float32, given a
        leading batch dimension, run through the model, and the sigmoid of
        the logits is returned as a NumPy array on the CPU.
        """
        feats = self._load_feats((split, generator, sample_id))
        if feats is None:
            return None
        batch = feats.to(self.device, dtype=torch.float32).unsqueeze(0)  # (1, T, D)
        logits = self.model(batch)  # (1, T)
        return torch.sigmoid(logits).squeeze(0).cpu().numpy()

    def warmup(self, video_keys):
        """Pre-load features into the in-memory cache before RL starts.

        Args:
            video_keys: Iterable of ``(split, generator, sample_id)`` tuples.

        Returns:
            Number of requested keys whose features are available in memory
            afterwards. Keys without a feature file are skipped; keys that
            are already cached are counted without re-reading the file.
        """
        n = 0
        for split, gen, sid in video_keys:
            if self._load_feats((split, gen, sid)) is not None:
                n += 1
        return n
def sample_id_from_video_path(video_path: str) -> str:
    """Return the file name of ``video_path`` with its last extension removed.

    This mirrors the trainer's convention of using the basename stem as the
    sample identifier.
    """
    filename = os.path.basename(video_path)
    stem, _extension = os.path.splitext(filename)
    return stem
def format_verifier_scores(scores: Optional[np.ndarray], per_line: int = 8) -> str:
    """Render per-second forgery probabilities as compact prompt text.

    Used in the verifier-as-context experiment: rather than feeding the
    verifier output into the reward (which the authors found to be ~86%
    redundant with IoU and empirically harmful), the raw per-second scores
    are handed to the VLM as context. RL reward stays pure IoU.

    Each score is multiplied by 100, rounded to an integer and zero-padded to
    at least two digits; ``per_line`` entries are placed on each row. ``None``
    or an empty sequence yields the empty string.

    Example for a 12-second video:
        External forensics verifier per-second forgery scores (range 0-100, higher = more suspicious):
        s0=12 s1=15 s2=11 s3=91 s4=95 s5=93 s6=34 s7=21
        s8=18 s9=15 s10=12 s11=10
    """
    # Explicit length test: truthiness of a NumPy array is ambiguous.
    if scores is None or len(scores) == 0:
        return ""

    tokens = [
        f"s{second}={int(round(float(prob) * 100)):02d}"
        for second, prob in enumerate(scores)
    ]

    rows = ["External forensics verifier per-second forgery scores (range 0-100, higher = more suspicious):"]
    pending = []
    for count, token in enumerate(tokens, start=1):
        pending.append(token)
        if count % per_line:
            continue  # row not full yet
        rows.append(" ".join(pending))
        pending = []

    # Flush a trailing, partially filled row.
    if pending:
        rows.append(" ".join(pending))
    return "\n".join(rows)