"""M2 step 2: train temporal head over cached CLIP features. Architecture: CLIP-L/14 (frozen, features already cached) → Linear 768→384 → +PosEnc → 4-layer TransformerEncoder → Linear 384→1 BCE loss against per-second forgery labels. Train/val split is VIDEO-LEVEL on the AF TRAIN cache only. 90% videos for training, 10% held-out for model selection. No frame leaks across split, no test-set involvement. Output: best checkpoint at /verifier_temporal_best.pt """ import argparse import json import math import os import random import sys import time from collections import defaultdict import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) CACHE = "/mnt/local-fast/zhangt/forensics_verifier_clip_l14" OUT_DIR = "/mnt/local-fast/zhangt/forensics_verifier_clip_l14" SEED = 42 # --------------------------------------------------------------------------- # # Dataset # # --------------------------------------------------------------------------- # class VerifierDataset(Dataset): """Loads (T, 768) features + (T,) labels per video, with metadata.""" def __init__(self, video_dirs): self.video_dirs = video_dirs # list of paths to /// def __len__(self): return len(self.video_dirs) def __getitem__(self, idx): d = self.video_dirs[idx] feats = torch.load(os.path.join(d, "clip_feats.pt"), weights_only=True) labels = torch.load(os.path.join(d, "clip_labels.pt"), weights_only=True) gen = os.path.basename(os.path.dirname(d)) return feats.float(), labels.float(), gen, d def pad_collate(batch): feats = [b[0] for b in batch] lbls = [b[1] for b in batch] gens = [b[2] for b in batch] dirs = [b[3] for b in batch] T_max = max(f.shape[0] for f in feats) D = feats[0].shape[1] B = len(batch) pad_feats = torch.zeros(B, T_max, D, dtype=torch.float32) pad_lbls = torch.zeros(B, T_max, dtype=torch.float32) mask = torch.zeros(B, T_max, dtype=torch.bool) for i, (f, l) in enumerate(zip(feats, lbls)): T = f.shape[0] pad_feats[i, :T] = f pad_lbls[i, :T] = l mask[i, :T] = True return pad_feats, pad_lbls, mask, gens, dirs # --------------------------------------------------------------------------- # # Model # # --------------------------------------------------------------------------- # class TemporalVerifier(nn.Module): def __init__(self, in_dim=768, hidden=384, num_layers=4, num_heads=8, dropout=0.1, max_len=512): super().__init__() self.in_proj = nn.Linear(in_dim, hidden) # Learnable positional embedding (simpler than sinusoidal for short sequences) self.pos_emb = nn.Parameter(torch.zeros(1, max_len, hidden)) nn.init.trunc_normal_(self.pos_emb, std=0.02) layer = nn.TransformerEncoderLayer( d_model=hidden, nhead=num_heads, dim_feedforward=hidden * 4, dropout=dropout, batch_first=True, activation="gelu", norm_first=True, ) self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers) self.norm = nn.LayerNorm(hidden) self.head = nn.Linear(hidden, 1) def forward(self, x, mask=None): """x: (B, T, in_dim); mask: (B, T) True for valid; out: (B, T).""" B, T, _ = x.shape h = self.in_proj(x) + self.pos_emb[:, :T] kpm = ~mask if mask is not None else None # transformer expects True for PAD h = self.encoder(h, src_key_padding_mask=kpm) h = self.norm(h) return self.head(h).squeeze(-1) # --------------------------------------------------------------------------- # # Eval # # --------------------------------------------------------------------------- # @torch.no_grad() def evaluate(model, loader, device): model.eval() per_video_gap = [] per_gen = defaultdict(list) all_logits, all_labels = [], [] for feats, lbls, mask, gens, _ in loader: feats = feats.to(device, non_blocking=True) mask = mask.to(device, non_blocking=True) logits = model(feats, mask=mask) for i in range(feats.size(0)): valid = mask[i].cpu().numpy().astype(bool) l = logits[i].cpu().float().numpy()[valid] y = lbls[i].cpu().numpy()[valid] s = 1.0 / (1.0 + np.exp(-l)) all_logits.append(s) all_labels.append(y) if y.any() and not y.all(): m_in = float(s[y > 0.5].mean()) m_out = float(s[y < 0.5].mean()) per_video_gap.append(m_in - m_out) per_gen[gens[i]].append((m_in, m_out)) arr = np.array(per_video_gap) if per_video_gap else np.array([0.0]) S = np.concatenate(all_logits) if all_logits else np.array([0.0]) Y = np.concatenate(all_labels) if all_labels else np.array([0.0]) # AUC pos_s, neg_s = S[Y > 0.5], S[Y < 0.5] auc = 0.5 if len(pos_s) and len(neg_s): rng = np.random.default_rng(SEED) if len(pos_s) > 4000: pos_s = rng.choice(pos_s, 4000, replace=False) if len(neg_s) > 4000: neg_s = rng.choice(neg_s, 4000, replace=False) cmp = (pos_s[:, None] > neg_s[None, :]).astype(float) eq = (pos_s[:, None] == neg_s[None, :]).astype(float) * 0.5 auc = float((cmp + eq).mean()) out = { "gap_mean": float(arr.mean()), "gap_median": float(np.median(arr)), "gap_p25": float(np.percentile(arr, 25)), "gap_p75": float(np.percentile(arr, 75)), "frac_gt_005": float((arr > 0.05).mean()), "frac_gt_010": float((arr > 0.10).mean()), "frac_gt_015": float((arr > 0.15).mean()), "global_auc": auc, "n_videos_evaluated": len(per_video_gap), "per_gen": {g: {"n": len(p), "pos": float(np.mean([x[0] for x in p])), "neg": float(np.mean([x[1] for x in p])), "gap": float(np.mean([x[0] - x[1] for x in p]))} for g, p in per_gen.items()}, } return out # --------------------------------------------------------------------------- # # Main # # --------------------------------------------------------------------------- # def main(): ap = argparse.ArgumentParser() ap.add_argument("--epochs", type=int, default=40) ap.add_argument("--batch_size", type=int, default=16) ap.add_argument("--lr", type=float, default=5e-4) ap.add_argument("--val_frac", type=float, default=0.10) ap.add_argument("--num_layers", type=int, default=4) ap.add_argument("--hidden", type=int, default=384) ap.add_argument("--num_heads", type=int, default=8) ap.add_argument("--dropout", type=float, default=0.1) ap.add_argument("--out", default=os.path.join(OUT_DIR, "verifier_temporal_best.pt")) args = ap.parse_args() random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED) # Enumerate train videos train_root = os.path.join(CACHE, "train") video_dirs = [] for gen in sorted(os.listdir(train_root)): gen_dir = os.path.join(train_root, gen) if not os.path.isdir(gen_dir): continue for sid in sorted(os.listdir(gen_dir)): d = os.path.join(gen_dir, sid) if os.path.exists(os.path.join(d, "clip_feats.pt")): video_dirs.append(d) print(f"found {len(video_dirs)} train videos with cached features") # Video-level 90/10 split, stratified across generators (rough) by_gen = defaultdict(list) for d in video_dirs: by_gen[os.path.basename(os.path.dirname(d))].append(d) rng = random.Random(SEED) train_dirs, val_dirs = [], [] for g, dirs in by_gen.items(): rng.shuffle(dirs) k = max(1, int(len(dirs) * args.val_frac)) val_dirs.extend(dirs[:k]) train_dirs.extend(dirs[k:]) rng.shuffle(train_dirs) print(f"split: train={len(train_dirs)} val={len(val_dirs)}") train_ds = VerifierDataset(train_dirs) val_ds = VerifierDataset(val_dirs) train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=pad_collate, num_workers=4, pin_memory=True) val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=pad_collate, num_workers=2, pin_memory=True) device = "cuda:0" # Need max_len = max video duration we encountered print("scanning max sequence length ...") max_T = 0 for f, _, _, _ in train_ds: max_T = max(max_T, f.shape[0]) print(f" max_T = {max_T}") model = TemporalVerifier( in_dim=768, hidden=args.hidden, num_layers=args.num_layers, num_heads=args.num_heads, dropout=args.dropout, max_len=max_T + 1, ).to(device) n_params = sum(p.numel() for p in model.parameters()) print(f"verifier params: {n_params/1e6:.2f}M") # Class-balance from training data pos_count = neg_count = 0 for _, y, _, _ in train_ds: pos_count += int((y > 0.5).sum().item()) neg_count += int((y < 0.5).sum().item()) pw = torch.tensor([neg_count / max(1, pos_count)], device=device) print(f"pos={pos_count} neg={neg_count} pos_weight={pw.item():.3f}") opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs) best_val_gap = -1.0 log_history = [] for epoch in range(1, args.epochs + 1): model.train() ep_loss, ep_n = 0.0, 0 t0 = time.time() for feats, lbls, mask, _, _ in train_loader: feats = feats.to(device, non_blocking=True) lbls = lbls.to(device, non_blocking=True) mask = mask.to(device, non_blocking=True) logits = model(feats, mask=mask) loss_per = F.binary_cross_entropy_with_logits(logits, lbls, pos_weight=pw, reduction="none") loss = (loss_per * mask.float()).sum() / mask.float().sum().clamp_min(1) opt.zero_grad() loss.backward() nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() ep_loss += float(loss.item()) * feats.size(0) ep_n += feats.size(0) sched.step() train_loss = ep_loss / max(1, ep_n) val_metrics = evaluate(model, val_loader, device) log_history.append({"epoch": epoch, "train_loss": train_loss, **val_metrics}) print( f"epoch {epoch:3d}/{args.epochs} train_loss={train_loss:.4f} " f"val_gap_mean={val_metrics['gap_mean']:+.3f} (median={val_metrics['gap_median']:+.3f}) " f"val_AUC={val_metrics['global_auc']:.3f} " f">0.05 {val_metrics['frac_gt_005']:.1%} " f">0.10 {val_metrics['frac_gt_010']:.1%} " f">0.15 {val_metrics['frac_gt_015']:.1%} " f"t={time.time()-t0:.1f}s", flush=True, ) if val_metrics["gap_mean"] > best_val_gap: best_val_gap = val_metrics["gap_mean"] torch.save({ "model_state": model.state_dict(), "args": vars(args), "val_metrics": val_metrics, "epoch": epoch, "max_T": max_T, }, args.out) print(f" ✓ saved new best to {args.out} (gap={best_val_gap:+.3f})", flush=True) print(f"\n=== best val gap_mean = {best_val_gap:+.3f} ===") print("\nFinal val per-generator:") final = log_history[-1] if "per_gen" in final: for g, m in sorted(final["per_gen"].items()): print(f" {g:<12} n={m['n']:<4} pos={m['pos']:.3f} neg={m['neg']:.3f} gap={m['gap']:+.3f}") with open(args.out.replace(".pt", "_log.json"), "w") as f: json.dump(log_history, f, indent=2) if __name__ == "__main__": main()