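"""M2 step 2: train a temporal head over cached CLIP features.

Architecture (defaults; sizes are configurable via CLI flags):
    CLIP-L/14 (frozen, features already cached) →
    Linear 768→384 → + learned positional embedding → 4-layer pre-norm
    TransformerEncoder → LayerNorm → Linear 384→1
    Loss: BCE-with-logits (positive frames up-weighted by the train neg/pos
    ratio) against per-second forgery labels.

Train/val split is VIDEO-LEVEL on the AF TRAIN cache only: within each
generator, 90% of videos are used for training and 10% are held out for model
selection (at least one video per generator goes to val). No frames leak across
the split, and the test set is not involved.

Output: best checkpoint (by val gap_mean) at <OUT_DIR>/verifier_temporal_best.pt
(override with --out), plus a per-epoch metric log saved next to it as
verifier_temporal_best_log.json.
"""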
| import argparse |
| import json |
| import math |
| import os |
| import random |
| import sys |
| import time |
| from collections import defaultdict |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
|
|
# Put this script's own directory first on the import path (presumably so
# sibling modules can be imported; nothing in this file currently relies on it).
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Root of the cached CLIP-L/14 features; main() reads its "train" subtree.
CACHE = "/mnt/local-fast/zhangt/forensics_verifier_clip_l14"
# Default output directory for checkpoints/logs: the cache directory itself.
OUT_DIR = CACHE
# Seed for the Python/NumPy/torch RNGs and the train/val split.
SEED = 42
|
|
|
|
| |
| |
| |
class VerifierDataset(Dataset):
    """Map-style dataset with one item per cached video directory.

    Item ``idx`` is read from ``clip_feats.pt`` and ``clip_labels.pt`` inside
    ``video_dirs[idx]`` and returned as
    ``(features, labels, generator_name, video_dir)``. Both tensors are cast to
    float32. ``generator_name`` is the name of the video directory's parent.
    Features are expected to be (T, 768) and labels (T,); this is not checked.
    """

    def __init__(self, video_dirs):
        # One directory path per video; item order follows this list.
        self.video_dirs = video_dirs

    def __len__(self):
        return len(self.video_dirs)

    def __getitem__(self, idx):
        video_dir = self.video_dirs[idx]

        def _load(filename):
            # weights_only=True restricts unpickling to tensors/plain containers.
            return torch.load(os.path.join(video_dir, filename), weights_only=True)

        generator = os.path.basename(os.path.dirname(video_dir))
        return (
            _load("clip_feats.pt").float(),
            _load("clip_labels.pt").float(),
            generator,
            video_dir,
        )
|
|
|
|
def pad_collate(batch):
    """Collate variable-length videos into zero-padded batch tensors.

    Args:
        batch: sequence of ``(feats, labels, generator, video_dir)`` items where
            ``feats`` is (T_i, D) and ``labels`` is (T_i,).

    Returns:
        ``(feats, labels, mask, generators, video_dirs)``. ``feats`` is float32
        (B, T_max, D) and ``labels`` is float32 (B, T_max), both zero-padded.
        ``mask`` is bool (B, T_max) and True on real timesteps. The last two
        are plain lists.
    """
    feat_list, label_list, gen_list, dir_list = (
        [item[col] for item in batch] for col in range(4)
    )
    lengths = [f.shape[0] for f in feat_list]
    longest = max(lengths)
    feat_dim = feat_list[0].shape[1]

    feats = torch.zeros(len(batch), longest, feat_dim, dtype=torch.float32)
    labels = torch.zeros(len(batch), longest, dtype=torch.float32)
    for row, (f, y, n) in enumerate(zip(feat_list, label_list, lengths)):
        feats[row, :n] = f
        labels[row, :n] = y

    # Position t of row b is valid iff t < length of video b.
    mask = torch.arange(longest).unsqueeze(0) < torch.tensor(lengths).unsqueeze(1)
    return feats, labels, mask, gen_list, dir_list
|
|
|
|
| |
| |
| |
class TemporalVerifier(nn.Module):
    """Per-timestep forgery classifier over a sequence of precomputed features.

    Pipeline: ``Linear(in_dim -> hidden)`` + learned positional embedding ->
    pre-norm ``TransformerEncoder`` (GELU, batch_first) -> ``LayerNorm`` ->
    ``Linear(hidden -> 1)``.

    Args:
        in_dim: feature size of each input timestep.
        hidden: transformer width; must be divisible by ``num_heads``.
        num_layers: number of encoder layers.
        num_heads: attention heads per layer.
        dropout: dropout inside each encoder layer.
        max_len: size of the learned positional table, i.e. the longest
            sequence ``forward`` accepts.
    """

    def __init__(self, in_dim=768, hidden=384, num_layers=4, num_heads=8, dropout=0.1, max_len=512):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden)
        # Learned (not sinusoidal) positions: sequences longer than max_len
        # have no embedding and are rejected in forward().
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, hidden))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden, nhead=num_heads, dim_feedforward=hidden * 4,
            dropout=dropout, batch_first=True, activation="gelu", norm_first=True,
        )
        # The nested-tensor fast path is already unavailable with norm_first=True;
        # disabling it explicitly only silences PyTorch's construction warning.
        self.encoder = nn.TransformerEncoder(
            layer, num_layers=num_layers, enable_nested_tensor=False,
        )
        self.norm = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x, mask=None):
        """Return per-timestep logits.

        Args:
            x: (B, T, in_dim) features.
            mask: optional (B, T) bool tensor, True on valid (unpadded) steps.

        Returns:
            (B, T) logits.

        Raises:
            ValueError: if ``T`` exceeds the positional table size ``max_len``.
        """
        _, seq_len, _ = x.shape
        max_len = self.pos_emb.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size max_len={max_len}"
            )
        h = self.in_proj(x) + self.pos_emb[:, :seq_len]
        # PyTorch's key-padding mask is True on positions to IGNORE.
        kpm = None if mask is None else ~mask
        h = self.encoder(h, src_key_padding_mask=kpm)
        return self.head(self.norm(h)).squeeze(-1)
|
|
|
|
| |
| |
| |
def _binary_auc(pos_scores, neg_scores):
    """Return the exact ROC AUC of positive vs negative scores (0.5 if either is empty).

    Uses the Mann-Whitney U statistic with average ranks for ties, which equals
    P(pos > neg) + 0.5 * P(pos == neg), in O(n log n) time and O(n) memory.
    """
    n_pos, n_neg = len(pos_scores), len(neg_scores)
    if n_pos == 0 or n_neg == 0:
        return 0.5
    scores = np.concatenate([pos_scores, neg_scores])
    _, inverse, counts = np.unique(scores, return_inverse=True, return_counts=True)
    # 1-based mean rank of each distinct value: a run of c ties ending at rank e
    # occupies ranks e-c+1..e, whose mean is e - (c-1)/2.
    ends = np.cumsum(counts)
    avg_rank = ends - (counts - 1) / 2.0
    ranks = avg_rank[inverse.reshape(-1)]
    u_stat = ranks[:n_pos].sum() - n_pos * (n_pos + 1) / 2.0
    return float(u_stat / (n_pos * n_neg))


@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate ``model`` on ``loader`` and summarise forged-vs-real separation.

    Args:
        model: module called as ``model(feats, mask=mask)`` that returns (B, T)
            per-timestep logits.
        loader: iterable of ``(feats, labels, mask, generators, video_dirs)``
            batches, as produced by ``pad_collate``.
        device: device that features and mask are moved to for the forward pass.

    Returns:
        dict with keys:
        * ``gap_mean``/``gap_median``/``gap_p25``/``gap_p75`` and
          ``frac_gt_005``/``frac_gt_010``/``frac_gt_015``: statistics of the
          per-video gap ``mean(p | label>0.5) - mean(p | label<0.5)``. Only
          videos containing both classes count; the statistics are taken over
          a single 0.0 if there are none.
        * ``global_auc``: exact frame-level ROC AUC over all valid frames
          (0.5 if a class is absent).
        * ``n_videos_evaluated``: number of videos that contributed a gap.
        * ``per_gen``: per generator, the count and mean pos/neg probability and gap.

    Leaves ``model`` in eval mode.
    """
    model.eval()
    per_video_gap = []
    per_gen = defaultdict(list)
    all_probs, all_labels = [], []
    for feats, lbls, mask, gens, _ in loader:
        feats = feats.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)
        logits = model(feats, mask=mask)
        # One device->host copy per batch; torch.sigmoid also avoids np.exp
        # overflow warnings on very negative logits.
        probs_b = torch.sigmoid(logits.float()).cpu().numpy()
        valid_b = mask.cpu().numpy().astype(bool)
        labels_b = lbls.cpu().numpy()
        for probs, valid, labels, gen in zip(probs_b, valid_b, labels_b, gens):
            s = probs[valid]
            y = labels[valid]
            all_probs.append(s)
            all_labels.append(y)
            # Same 0.5 threshold for "is mixed" and for the class means, so soft
            # labels can never produce the mean of an empty slice.
            is_pos, is_neg = y > 0.5, y < 0.5
            if is_pos.any() and is_neg.any():
                m_in = float(s[is_pos].mean())
                m_out = float(s[is_neg].mean())
                per_video_gap.append(m_in - m_out)
                per_gen[gen].append((m_in, m_out))

    arr = np.array(per_video_gap) if per_video_gap else np.array([0.0])
    S = np.concatenate(all_probs) if all_probs else np.array([0.0])
    Y = np.concatenate(all_labels) if all_labels else np.array([0.0])
    auc = _binary_auc(S[Y > 0.5], S[Y < 0.5])

    return {
        "gap_mean": float(arr.mean()),
        "gap_median": float(np.median(arr)),
        "gap_p25": float(np.percentile(arr, 25)),
        "gap_p75": float(np.percentile(arr, 75)),
        "frac_gt_005": float((arr > 0.05).mean()),
        "frac_gt_010": float((arr > 0.10).mean()),
        "frac_gt_015": float((arr > 0.15).mean()),
        "global_auc": auc,
        "n_videos_evaluated": len(per_video_gap),
        "per_gen": {
            g: {
                "n": len(p),
                "pos": float(np.mean([x[0] for x in p])),
                "neg": float(np.mean([x[1] for x in p])),
                "gap": float(np.mean([x[0] - x[1] for x in p])),
            }
            for g, p in per_gen.items()
        },
    }
|
|
|
|
| |
| |
| |
def _discover_videos(train_root):
    """Return sorted ``<train_root>/<generator>/<video>`` dirs that contain ``clip_feats.pt``."""
    video_dirs = []
    for gen in sorted(os.listdir(train_root)):
        gen_dir = os.path.join(train_root, gen)
        if not os.path.isdir(gen_dir):
            continue
        for sid in sorted(os.listdir(gen_dir)):
            d = os.path.join(gen_dir, sid)
            if os.path.exists(os.path.join(d, "clip_feats.pt")):
                video_dirs.append(d)
    return video_dirs


def _split_by_generator(video_dirs, val_frac, seed):
    """Video-level train/val split stratified by generator (parent dir name).

    Each generator contributes ``max(1, int(n * val_frac))`` of its shuffled
    videos to val and the rest to train. Returns ``(train_dirs, val_dirs)``.
    """
    by_gen = defaultdict(list)
    for d in video_dirs:
        by_gen[os.path.basename(os.path.dirname(d))].append(d)
    rng = random.Random(seed)
    train_dirs, val_dirs = [], []
    for dirs in by_gen.values():
        rng.shuffle(dirs)
        k = max(1, int(len(dirs) * val_frac))
        val_dirs.extend(dirs[:k])
        train_dirs.extend(dirs[k:])
    rng.shuffle(train_dirs)
    return train_dirs, val_dirs


def _scan_datasets(train_ds, val_ds):
    """One pass over the cached data; returns ``(max_T, pos_count, neg_count)``.

    ``max_T`` is the longest sequence across train AND val: the positional
    table must cover val videos too, otherwise evaluation crashes on any val
    video longer than every train video. Label counts use train only.
    """
    max_T = 0
    pos_count = neg_count = 0
    for idx in range(len(train_ds)):
        feats, y, _, _ = train_ds[idx]
        max_T = max(max_T, feats.shape[0])
        pos_count += int((y > 0.5).sum().item())
        neg_count += int((y < 0.5).sum().item())
    for idx in range(len(val_ds)):
        feats, _, _, _ = val_ds[idx]
        max_T = max(max_T, feats.shape[0])
    return max_T, pos_count, neg_count


def main():
    """Train the temporal verifier and keep the checkpoint with the best val ``gap_mean``.

    Steps:
    1. Discover cached train videos and split them video-level per generator.
    2. Size the positional table from the longest train/val sequence.
    3. Weight BCE by the train neg/pos frame ratio.
    4. Train with AdamW + cosine LR.
    5. After each epoch, evaluate on val and save a checkpoint whenever
       ``gap_mean`` improves.

    The per-epoch metric history is written next to the checkpoint as JSON.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--epochs", type=int, default=40)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--lr", type=float, default=5e-4)
    ap.add_argument("--val_frac", type=float, default=0.10)
    ap.add_argument("--num_layers", type=int, default=4)
    ap.add_argument("--hidden", type=int, default=384)
    ap.add_argument("--num_heads", type=int, default=8)
    ap.add_argument("--dropout", type=float, default=0.1)
    ap.add_argument("--out", default=os.path.join(OUT_DIR, "verifier_temporal_best.pt"))
    args = ap.parse_args()

    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    video_dirs = _discover_videos(os.path.join(CACHE, "train"))
    print(f"found {len(video_dirs)} train videos with cached features")

    train_dirs, val_dirs = _split_by_generator(video_dirs, args.val_frac, SEED)
    print(f"split: train={len(train_dirs)} val={len(val_dirs)}")

    # Fall back to CPU so the script still runs on machines without a GPU.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    pin_memory = device.startswith("cuda")

    train_ds = VerifierDataset(train_dirs)
    val_ds = VerifierDataset(val_dirs)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True,
                              collate_fn=pad_collate, num_workers=4, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False,
                            collate_fn=pad_collate, num_workers=2, pin_memory=pin_memory)

    print("scanning max sequence length ...")
    max_T, pos_count, neg_count = _scan_datasets(train_ds, val_ds)
    print(f" max_T = {max_T}")

    model = TemporalVerifier(
        in_dim=768, hidden=args.hidden, num_layers=args.num_layers,
        num_heads=args.num_heads, dropout=args.dropout, max_len=max_T + 1,
    ).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f"verifier params: {n_params/1e6:.2f}M")

    # Up-weight positive frames by the train neg/pos ratio to counter class imbalance.
    pw = torch.tensor([neg_count / max(1, pos_count)], device=device)
    print(f"pos={pos_count} neg={neg_count} pos_weight={pw.item():.3f}")

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=args.epochs)

    out_dir = os.path.dirname(args.out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # -inf (not -1.0) so the first epoch is always saved, even at gap == -1.
    best_val_gap = float("-inf")
    log_history = []
    for epoch in range(1, args.epochs + 1):
        model.train()
        ep_loss, ep_frames = 0.0, 0.0
        t0 = time.time()
        for feats, lbls, mask, _, _ in train_loader:
            feats = feats.to(device, non_blocking=True)
            lbls = lbls.to(device, non_blocking=True)
            mask = mask.to(device, non_blocking=True)
            valid = mask.float()
            logits = model(feats, mask=mask)
            loss_per = F.binary_cross_entropy_with_logits(logits, lbls, pos_weight=pw, reduction="none")
            # Average only over real (unpadded) frames.
            n_valid = valid.sum().clamp_min(1)
            loss = (loss_per * valid).sum() / n_valid
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            # Frame-weighted so train_loss is the epoch's true per-frame mean.
            frames = float(n_valid.item())
            ep_loss += float(loss.item()) * frames
            ep_frames += frames
        sched.step()
        train_loss = ep_loss / max(1.0, ep_frames)

        val_metrics = evaluate(model, val_loader, device)
        log_history.append({"epoch": epoch, "train_loss": train_loss, **val_metrics})

        print(
            f"epoch {epoch:3d}/{args.epochs} train_loss={train_loss:.4f} "
            f"val_gap_mean={val_metrics['gap_mean']:+.3f} (median={val_metrics['gap_median']:+.3f}) "
            f"val_AUC={val_metrics['global_auc']:.3f} "
            f">0.05 {val_metrics['frac_gt_005']:.1%} "
            f">0.10 {val_metrics['frac_gt_010']:.1%} "
            f">0.15 {val_metrics['frac_gt_015']:.1%} "
            f"t={time.time()-t0:.1f}s",
            flush=True,
        )

        if val_metrics["gap_mean"] > best_val_gap:
            best_val_gap = val_metrics["gap_mean"]
            torch.save({
                "model_state": model.state_dict(),
                "args": vars(args),
                "val_metrics": val_metrics,
                "epoch": epoch,
                "max_T": max_T,
            }, args.out)
            print(f" ✓ saved new best to {args.out} (gap={best_val_gap:+.3f})", flush=True)

    print(f"\n=== best val gap_mean = {best_val_gap:+.3f} ===")
    # Guard: with --epochs 0 there is no history to report.
    if log_history:
        print("\nFinal val per-generator:")
        for g, m in sorted(log_history[-1]["per_gen"].items()):
            print(f" {g:<12} n={m['n']:<4} pos={m['pos']:.3f} neg={m['neg']:.3f} gap={m['gap']:+.3f}")

    # splitext instead of str.replace(".pt", ...): replace() rewrote every ".pt"
    # in the path and, for an --out without ".pt", overwrote the checkpoint itself.
    log_path = os.path.splitext(args.out)[0] + "_log.json"
    with open(log_path, "w", encoding="utf-8") as f:
        json.dump(log_history, f, indent=2)


if __name__ == "__main__":
    main()
|
|