# Source: forensics-grpo / code / verifier_m2_train_temporal.py
# Author: sdzt · commit 33569f9 "Add source code" (verified) · 12.3 kB
"""M2 step 2: train temporal head over cached CLIP features.
Architecture:
CLIP-L/14 (frozen, features already cached) →
Linear 768→384 → +PosEnc → 4-layer TransformerEncoder → Linear 384→1
BCE loss against per-second forgery labels.
Train/val split is VIDEO-LEVEL on the AF TRAIN cache only. 90% videos for
training, 10% held-out for model selection. No frame leaks across split, no
test-set involvement.
Output: best checkpoint at <CACHE_PARENT>/verifier_temporal_best.pt
"""
import argparse
import json
import math
import os
import random
import sys
import time
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
# Make modules that sit next to this script importable regardless of the CWD.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Root of the cached CLIP-L/14 features (train videos live under
# CACHE/train/<gen>/<sample_id>/).
CACHE = "/mnt/local-fast/zhangt/forensics_verifier_clip_l14"
# Checkpoints and training logs are written alongside the cache.
OUT_DIR = CACHE
# Seed used by main() for the python/numpy/torch RNGs and the train/val split.
SEED = 42
# --------------------------------------------------------------------------- #
# Dataset #
# --------------------------------------------------------------------------- #
class VerifierDataset(Dataset):
    """Per-video dataset over a cached-feature directory tree.

    Each item is ``(feats, labels, gen, video_dir)``:
      * ``feats`` / ``labels``: the tensors stored in ``clip_feats.pt`` /
        ``clip_labels.pt``, cast to float32. Shapes are whatever was cached --
        presumably ``(T, 768)`` and ``(T,)``; TODO confirm against the caching
        script.
      * ``gen``: name of the video directory's parent folder (the generator).
      * ``video_dir``: the directory path exactly as passed in.
    """

    def __init__(self, video_dirs):
        # list of paths to <split>/<gen>/<sample_id>/ (trailing separator allowed)
        self.video_dirs = video_dirs

    def __len__(self):
        return len(self.video_dirs)

    def __getitem__(self, idx):
        d = self.video_dirs[idx]
        # weights_only=True restricts unpickling to tensors/plain containers.
        feats = torch.load(os.path.join(d, "clip_feats.pt"), weights_only=True)
        labels = torch.load(os.path.join(d, "clip_labels.pt"), weights_only=True)
        # normpath drops a trailing separator; without it dirname() would return
        # the sample dir itself and the "generator" would be the sample id.
        gen = os.path.basename(os.path.dirname(os.path.normpath(d)))
        return feats.float(), labels.float(), gen, d
def pad_collate(batch):
    """Collate variable-length videos into one zero-padded batch.

    ``batch`` is a list of ``(feats, labels, gen, video_dir)`` items as produced
    by ``VerifierDataset``. Every sequence is right-padded with zeros to the
    longest length in the batch.

    Returns:
        ``(feats, labels, mask, gens, dirs)`` where ``feats`` is float32
        ``(B, T_max, D)``, ``labels`` is float32 ``(B, T_max)``, ``mask`` is
        bool ``(B, T_max)`` and True on real (non-padded) steps, and ``gens`` /
        ``dirs`` are plain lists in batch order.
    """
    n_items = len(batch)
    longest = max(item[0].shape[0] for item in batch)
    feat_dim = batch[0][0].shape[1]

    out_feats = torch.zeros(n_items, longest, feat_dim, dtype=torch.float32)
    out_lbls = torch.zeros(n_items, longest, dtype=torch.float32)
    valid = torch.zeros(n_items, longest, dtype=torch.bool)

    for row, item in enumerate(batch):
        length = item[0].shape[0]
        out_feats[row, :length] = item[0]
        out_lbls[row, :length] = item[1]
        valid[row, :length] = True

    gen_names = [item[2] for item in batch]
    video_dirs = [item[3] for item in batch]
    return out_feats, out_lbls, valid, gen_names, video_dirs
# --------------------------------------------------------------------------- #
# Model #
# --------------------------------------------------------------------------- #
class TemporalVerifier(nn.Module):
    """Per-timestep forgery scorer over pre-extracted frame features.

    Pipeline: Linear(in_dim -> hidden) + learnable positional embedding ->
    pre-norm (``norm_first=True``) GELU TransformerEncoder -> LayerNorm ->
    Linear(hidden -> 1). ``forward`` returns raw logits, one per timestep;
    input sequences may be at most ``max_len`` steps long.
    """

    def __init__(self, in_dim=768, hidden=384, num_layers=4, num_heads=8, dropout=0.1, max_len=512):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden)
        # Learnable positional embedding (simpler than sinusoidal for short sequences)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, hidden))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden, nhead=num_heads, dim_feedforward=hidden * 4,
            dropout=dropout, batch_first=True, activation="gelu", norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x, mask=None):
        """Score every timestep.

        Args:
            x: ``(B, T, in_dim)`` features.
            mask: optional ``(B, T)`` bool tensor, True on valid (non-padded) steps.

        Returns:
            ``(B, T)`` logits. Padded positions also receive values; callers
            should ignore them via ``mask``.

        Raises:
            ValueError: if ``T`` exceeds the positional-embedding length (``max_len``).
        """
        _, T, _ = x.shape
        max_len = self.pos_emb.shape[1]
        if T > max_len:
            # Without this check the add below fails with an opaque broadcast error.
            raise ValueError(
                f"sequence length {T} exceeds positional embedding max_len={max_len}"
            )
        h = self.in_proj(x) + self.pos_emb[:, :T]
        kpm = ~mask if mask is not None else None  # transformer expects True for PAD
        h = self.encoder(h, src_key_padding_mask=kpm)
        h = self.norm(h)
        return self.head(h).squeeze(-1)
# --------------------------------------------------------------------------- #
# Eval #
# --------------------------------------------------------------------------- #
@torch.no_grad()
def evaluate(model, loader, device):
model.eval()
per_video_gap = []
per_gen = defaultdict(list)
all_logits, all_labels = [], []
for feats, lbls, mask, gens, _ in loader:
feats = feats.to(device, non_blocking=True)
mask = mask.to(device, non_blocking=True)
logits = model(feats, mask=mask)
for i in range(feats.size(0)):
valid = mask[i].cpu().numpy().astype(bool)
l = logits[i].cpu().float().numpy()[valid]
y = lbls[i].cpu().numpy()[valid]
s = 1.0 / (1.0 + np.exp(-l))
all_logits.append(s)
all_labels.append(y)
if y.any() and not y.all():
m_in = float(s[y > 0.5].mean())
m_out = float(s[y < 0.5].mean())
per_video_gap.append(m_in - m_out)
per_gen[gens[i]].append((m_in, m_out))
arr = np.array(per_video_gap) if per_video_gap else np.array([0.0])
S = np.concatenate(all_logits) if all_logits else np.array([0.0])
Y = np.concatenate(all_labels) if all_labels else np.array([0.0])
# AUC
pos_s, neg_s = S[Y > 0.5], S[Y < 0.5]
auc = 0.5
if len(pos_s) and len(neg_s):
rng = np.random.default_rng(SEED)
if len(pos_s) > 4000: pos_s = rng.choice(pos_s, 4000, replace=False)
if len(neg_s) > 4000: neg_s = rng.choice(neg_s, 4000, replace=False)
cmp = (pos_s[:, None] > neg_s[None, :]).astype(float)
eq = (pos_s[:, None] == neg_s[None, :]).astype(float) * 0.5
auc = float((cmp + eq).mean())
out = {
"gap_mean": float(arr.mean()),
"gap_median": float(np.median(arr)),
"gap_p25": float(np.percentile(arr, 25)),
"gap_p75": float(np.percentile(arr, 75)),
"frac_gt_005": float((arr > 0.05).mean()),
"frac_gt_010": float((arr > 0.10).mean()),
"frac_gt_015": float((arr > 0.15).mean()),
"global_auc": auc,
"n_videos_evaluated": len(per_video_gap),
"per_gen": {g: {"n": len(p), "pos": float(np.mean([x[0] for x in p])),
"neg": float(np.mean([x[1] for x in p])),
"gap": float(np.mean([x[0] - x[1] for x in p]))}
for g, p in per_gen.items()},
}
return out
# --------------------------------------------------------------------------- #
# Main #
# --------------------------------------------------------------------------- #
def _discover_train_videos(train_root):
    """Return ``<train_root>/<gen>/<sample_id>`` dirs that contain ``clip_feats.pt``.

    Generators and samples are visited in sorted order, so the list (and hence
    the seeded split built from it) is deterministic across runs.
    """
    video_dirs = []
    for gen in sorted(os.listdir(train_root)):
        gen_dir = os.path.join(train_root, gen)
        if not os.path.isdir(gen_dir):
            continue
        for sid in sorted(os.listdir(gen_dir)):
            d = os.path.join(gen_dir, sid)
            if os.path.exists(os.path.join(d, "clip_feats.pt")):
                video_dirs.append(d)
    return video_dirs


def _split_by_generator(video_dirs, val_frac, seed):
    """Video-level train/val split, roughly stratified per generator.

    Each generator contributes ``max(1, int(n * val_frac))`` shuffled videos to
    val and the rest to train (so a generator with a single video lands
    entirely in val). Returns ``(train_dirs, val_dirs)``.
    """
    by_gen = defaultdict(list)
    for d in video_dirs:
        by_gen[os.path.basename(os.path.dirname(d))].append(d)
    rng = random.Random(seed)
    train_dirs, val_dirs = [], []
    for dirs in by_gen.values():
        rng.shuffle(dirs)
        k = max(1, int(len(dirs) * val_frac))
        val_dirs.extend(dirs[:k])
        train_dirs.extend(dirs[k:])
    rng.shuffle(train_dirs)
    return train_dirs, val_dirs


def _scan_datasets(train_ds, val_ds):
    """Single pass over the cached data for model/loss setup.

    Returns ``(max_T, pos_count, neg_count)``: the longest sequence across BOTH
    splits (the model also scores val videos, so its positional table must
    cover them) and the forged (> 0.5) / real (< 0.5) frame counts of the
    training split only.
    """
    max_T = 0
    pos_count = neg_count = 0
    for idx in range(len(train_ds)):
        f, y, _, _ = train_ds[idx]
        max_T = max(max_T, f.shape[0])
        pos_count += int((y > 0.5).sum().item())
        neg_count += int((y < 0.5).sum().item())
    for idx in range(len(val_ds)):
        f, _, _, _ = val_ds[idx]
        max_T = max(max_T, f.shape[0])
    return max_T, pos_count, neg_count


def main():
    """Train the temporal verifier on the AF train cache and keep the best epoch.

    Model selection uses ``gap_mean`` from ``evaluate`` on the held-out val
    videos. The best checkpoint is written to ``--out``; the per-epoch metric
    history goes next to it as ``<out stem>_log.json``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--epochs", type=int, default=40)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--val_frac", type=float, default=0.10)
    ap.add_argument("--num_layers", type=int, default=4)
    ap.add_argument("--hidden", type=int, default=384)
    ap.add_argument("--num_heads", type=int, default=8)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--out", default=os.path.join(OUT_DIR, "verifier_temporal_best.pt"))
    args = ap.parse_args()
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Enumerate train videos
    video_dirs = _discover_train_videos(os.path.join(CACHE, "train"))
    print(f"found {len(video_dirs)} train videos with cached features")

    # Video-level 90/10 split, stratified across generators (rough)
    train_dirs, val_dirs = _split_by_generator(video_dirs, args.val_frac, SEED)
    print(f"split: train={len(train_dirs)} val={len(val_dirs)}")
    train_ds = VerifierDataset(train_dirs)
    val_ds = VerifierDataset(val_dirs)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=pad_collate, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            collate_fn=pad_collate, num_workers=2, pin_memory=True)
    device = "cuda:0"

    # max_len must cover every video scored during this run -- val included --
    # otherwise evaluate() fails on a val video longer than any train video.
    print("scanning max sequence length ...")
    max_T, pos_count, neg_count = _scan_datasets(train_ds, val_ds)
    print(f"  max_T = {max_T}")
    model = TemporalVerifier(
        in_dim=768, hidden=args.hidden, num_layers=args.num_layers,
        num_heads=args.num_heads, dropout=args.dropout, max_len=max_T + 1,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"verifier params: {n_params/1e6:.2f}M")

    # Class-balance from training data: pos_weight = #neg / #pos frames.
    pw = torch.tensor([neg_count / max(1, pos_count)], device=device)
    print(f"pos={pos_count} neg={neg_count} pos_weight={pw.item():.3f}")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)

    os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
    # Derive the log path from the checkpoint's stem: str.replace(".pt", ...)
    # leaves paths without ".pt" unchanged, so the log would overwrite the checkpoint.
    log_path = os.path.splitext(args.out)[0] + "_log.json"

    best_val_gap = -math.inf  # any real gap (even the minimum -1.0) beats this
    log_history = []
    for epoch in range(1, args.epochs + 1):
        model.train()
        ep_loss, ep_n = 0.0, 0
        t0 = time.time()
        for feats, lbls, mask, _, _ in train_loader:
            feats = feats.to(device, non_blocking=True)
            lbls = lbls.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            logits = model(feats, mask=mask)
            loss_per = F.binary_cross_entropy_with_logits(logits, lbls, pos_weight=pw, reduction="none")
            # Mean over real (non-padded) timesteps only.
            loss = (loss_per * mask.float()).sum() / mask.float().sum().clamp_min(1)
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            ep_loss += float(loss.item()) * feats.size(0)
            ep_n += feats.size(0)
        sched.step()
        train_loss = ep_loss / max(1, ep_n)
        val_metrics = evaluate(model, val_loader, device)
        log_history.append({"epoch": epoch, "train_loss": train_loss, **val_metrics})
        print(
            f"epoch {epoch:3d}/{args.epochs}  train_loss={train_loss:.4f}  "
            f"val_gap_mean={val_metrics['gap_mean']:+.3f} (median={val_metrics['gap_median']:+.3f})  "
            f"val_AUC={val_metrics['global_auc']:.3f}  "
            f">0.05 {val_metrics['frac_gt_005']:.1%}  "
            f">0.10 {val_metrics['frac_gt_010']:.1%}  "
            f">0.15 {val_metrics['frac_gt_015']:.1%}  "
            f"t={time.time()-t0:.1f}s",
            flush=True,
        )
        if val_metrics["gap_mean"] > best_val_gap:
            best_val_gap = val_metrics["gap_mean"]
            torch.save({
                "model_state": model.state_dict(),
                "args": vars(args),
                "val_metrics": val_metrics,
                "epoch": epoch,
                "max_T": max_T,  # model was built with max_len = max_T + 1
            }, args.out)
            print(f"  ✓ saved new best to {args.out}  (gap={best_val_gap:+.3f})", flush=True)

    print(f"\n=== best val gap_mean = {best_val_gap:+.3f} ===")
    if log_history:  # empty when --epochs 0
        print("\nFinal val per-generator:")
        final = log_history[-1]
        for g, m in sorted(final.get("per_gen", {}).items()):
            print(f"  {g:<12} n={m['n']:<4} pos={m['pos']:.3f} neg={m['neg']:.3f} gap={m['gap']:+.3f}")
    with open(log_path, "w") as f:
        json.dump(log_history, f, indent=2)


if __name__ == "__main__":
    main()