"""External forensics verifier — CLIP frozen features + trained temporal head.
This module is the bridge between the offline-trained verifier (see
verifier_m2_train_temporal.py) and the RL trainer. It exposes a single
callable that, given a training-set video identifier, returns per-second
forgery scores (numpy array, shape (T,), values in [0, 1]).
Key design choices:
  - CLIP features are pre-extracted ONCE (verifier_m2_extract_clip.py) and cached.
    At RL training time we only run the small temporal head, which adds only
    millisecond-scale latency.
  - Train-cache layout matches the existing forensics preprocess cache:
      <CACHE_ROOT>/<split>/<gen>/<sample_id>/clip_feats.pt
    This lets the trainer look up scores by the same (split, gen, sample_id)
    used elsewhere in the codebase.
  - The verifier is a STATIC reward shaper — it is loaded once, kept frozen on a
    dedicated GPU slot if available, and never updates during RL.
"""
import os
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Model definition (must match verifier_m2_train_temporal.py)
# ---------------------------------------------------------------------------
class TemporalVerifier(nn.Module):
    """Transformer encoder that maps a feature sequence to one logit per step.

    Every time step is linearly projected to ``hidden`` dimensions, offset by
    a learned positional embedding, run through a pre-norm Transformer
    encoder, layer-normalised and reduced to a single logit.

    The attribute names (``in_proj``, ``pos_emb``, ``encoder``, ``norm``,
    ``head``) are the state-dict keys of saved checkpoints, so they must stay
    in sync with the training script.
    """

    def __init__(self, in_dim=768, hidden=384, num_layers=4, num_heads=8,
                 dropout=0.1, max_len=512):
        super().__init__()
        # Input projection: feature dim -> model width.
        self.in_proj = nn.Linear(in_dim, hidden)

        # Learned positional table; forward() slices it to the sequence length.
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, hidden))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=num_heads,
            dim_feedforward=hidden * 4,
            dropout=dropout,
            batch_first=True,
            activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Final normalisation followed by the scalar output head.
        self.norm = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x, mask=None):
        """Return logits of shape (B, T) for an input of shape (B, T, in_dim).

        ``mask`` is optional; when given it is inverted and handed to the
        encoder as ``src_key_padding_mask``, so it is expected to be a boolean
        tensor in which ``True`` marks positions to keep.
        """
        _batch, seq_len, _feat = x.shape
        tokens = self.in_proj(x) + self.pos_emb[:, :seq_len]

        if mask is None:
            padding_mask = None
        else:
            padding_mask = ~mask

        encoded = self.encoder(tokens, src_key_padding_mask=padding_mask)
        logits = self.head(self.norm(encoded))
        return logits.squeeze(-1)
# ---------------------------------------------------------------------------
# Verifier wrapper #
# ---------------------------------------------------------------------------
class ForensicsVerifier:
    """Loads a trained temporal verifier and provides per-video score lookup.

    The model is built from the hyper-parameters stored in the checkpoint,
    put in eval mode on ``device`` and never trained here. Pre-extracted
    feature tensors are read from
    ``<cache_root>/<split>/<generator>/<sample_id>/clip_feats.pt`` and kept in
    an in-memory cache (on the CPU) keyed by ``(split, generator, sample_id)``.

    Usage:
        verifier = ForensicsVerifier(
            ckpt="/mnt/.../verifier_temporal_best.pt",
            cache_root="/mnt/.../forensics_verifier_clip_l14",
            device="cuda:0",
        )
        scores = verifier.scores_for("train", "scifi", "v_abc...")
        # scores: numpy (T,), per-second forgery prob in [0, 1]
    """

    def __init__(self, ckpt: str, cache_root: str, device: str = "cuda:0"):
        """Build the model from ``ckpt`` and prepare an empty feature cache.

        Args:
            ckpt: Path to the checkpoint. It must contain ``"model_state"``
                and may contain ``"args"`` (hyper-parameters) and ``"max_T"``.
            cache_root: Root directory of the pre-extracted feature cache.
            device: Torch device the model runs on.

        Raises:
            FileNotFoundError: If ``ckpt`` does not exist.
            KeyError: If the checkpoint has no ``"model_state"`` entry.
        """
        self.cache_root = cache_root
        self.device = device
        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"verifier checkpoint missing: {ckpt}")

        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects. Only load checkpoints produced by our own training script;
        # never point this at an untrusted file.
        state = torch.load(ckpt, map_location="cpu", weights_only=False)

        args = state.get("args", {}) or {}
        if not isinstance(args, dict):
            # Training scripts sometimes store the argparse.Namespace itself.
            args = vars(args)
        max_T = int(state.get("max_T", 512))

        model_state = state["model_state"]
        # Read the feature width from the checkpoint so a verifier trained on
        # a different backbone loads correctly; fall back to 768 otherwise.
        in_proj_weight = model_state.get("in_proj.weight")
        in_dim = int(in_proj_weight.shape[1]) if in_proj_weight is not None else 768

        self.model = TemporalVerifier(
            in_dim=in_dim,
            hidden=args.get("hidden", 384),
            num_layers=args.get("num_layers", 4),
            num_heads=args.get("num_heads", 8),
            dropout=0.0,  # disable dropout at inference
            max_len=max_T + 1,
        ).to(device).eval()
        self.model.load_state_dict(model_state)

        # Cache features in memory keyed by (split, gen, sample_id) -> tensor
        self._feat_cache: dict = {}

    def _load_feats(self, key):
        """Load the feature tensor for ``key`` from disk into the cache.

        Returns the tensor, or ``None`` when the feature file does not exist.
        """
        split, generator, sample_id = key
        path = os.path.join(self.cache_root, split, generator, sample_id, "clip_feats.pt")
        if not os.path.exists(path):
            return None
        # map_location="cpu": tensors saved from a GPU would otherwise be
        # restored onto that same GPU, which may not exist on this host and
        # would make the cache consume accelerator memory.
        feats = torch.load(path, map_location="cpu", weights_only=True)
        self._feat_cache[key] = feats  # keep a copy, RL re-uses same videos
        return feats

    @torch.no_grad()
    def scores_for(self, split: str, generator: str, sample_id: str) -> Optional[np.ndarray]:
        """Return per-second forgery probability array, or None if not cached.

        Args:
            split: Dataset split directory name.
            generator: Generator directory name.
            sample_id: Sample directory name.

        Returns:
            A 1-D numpy array of sigmoid probabilities (one per time step),
            an empty array when the stored feature tensor has no time steps,
            or ``None`` when no feature file exists for the key.
        """
        key = (split, generator, sample_id)
        feats = self._feat_cache.get(key)
        if feats is None:
            feats = self._load_feats(key)
            if feats is None:
                return None

        if feats.shape[0] == 0:
            # Nothing to score; avoid pushing a zero-length sequence through
            # the transformer.
            return np.zeros((0,), dtype=np.float32)

        batch = feats.to(self.device, dtype=torch.float32).unsqueeze(0)  # (1, T, D)
        logits = self.model(batch)  # (1, T)
        return torch.sigmoid(logits).squeeze(0).cpu().numpy()

    def warmup(self, video_keys):
        """Pre-load features into the in-memory cache before RL starts.

        Args:
            video_keys: Iterable of ``(split, generator, sample_id)`` tuples.

        Returns:
            The number of keys whose feature file was found and loaded.
        """
        n = 0
        for split, gen, sid in video_keys:
            if self._load_feats((split, gen, sid)) is not None:
                n += 1
        return n
def sample_id_from_video_path(video_path: str) -> str:
    """Return the file name of ``video_path`` with its last extension removed.

    This mirrors the trainer's convention of using the basename stem as the
    sample identifier.
    """
    file_name = os.path.basename(video_path)
    stem, _extension = os.path.splitext(file_name)
    return stem
def format_verifier_scores(scores: Optional[np.ndarray], per_line: int = 8) -> str:
    """Format per-second forgery probabilities as compact text for prompt context.

    Used in the verifier-as-context experiment: instead of consuming the verifier
    output as a reward shaper (which we found to be ~86% redundant with IoU and
    empirically harmful), we give the VLM the raw per-second scores so it can
    reason over them. RL reward stays pure IoU.

    Each score is multiplied by 100, rounded to the nearest integer and printed
    with at least two digits (``0.05`` -> ``05``).

    Format example for a 12-second video:
        External forensics verifier per-second forgery scores (range 0-100, higher = more suspicious):
        s0=12 s1=15 s2=11 s3=91 s4=95 s5=93 s6=34 s7=21
        s8=18 s9=15 s10=12 s11=10

    Args:
        scores: Sequence of probabilities, or ``None``.
        per_line: Maximum number of entries per output line; must be >= 1.

    Returns:
        The header line followed by the score lines, joined with newlines, or
        an empty string when ``scores`` is ``None`` or empty.

    Raises:
        ValueError: If ``scores`` is non-empty and ``per_line`` is smaller
            than 1.
    """
    if scores is None or len(scores) == 0:
        return ""
    if per_line < 1:
        # A non-positive width cannot be laid out; fail with a clear message
        # rather than an opaque arithmetic error.
        raise ValueError(f"per_line must be >= 1, got {per_line}")

    header = "External forensics verifier per-second forgery scores (range 0-100, higher = more suspicious):"
    cells = [
        f"s{i}={int(round(float(s) * 100)):02d}"
        for i, s in enumerate(scores)
    ]
    rows = [
        " ".join(cells[start:start + per_line])
        for start in range(0, len(cells), per_line)
    ]
    return "\n".join([header, *rows])
|