File size: 6,545 Bytes
33569f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""External forensics verifier — CLIP frozen features + trained temporal head.

This module is the bridge between the offline-trained verifier (see
verifier_m2_train_temporal.py) and the RL trainer.  Its main entry point,
ForensicsVerifier.scores_for, takes a (split, generator, sample_id) video key
and returns per-second forgery scores (numpy array, shape (T,), values in
[0, 1]), or None when no cached features exist for that video.

Key design choices:
- CLIP features are pre-extracted ONCE (verifier_m2_extract_clip.py) and cached.
  At RL training time we only run the small temporal head, ~ms latency.
- Train-cache layout matches the existing forensics preprocess cache:
    <CACHE_ROOT>/<split>/<gen>/<sample_id>/clip_feats.pt
  This lets the trainer look up scores by the same (split, gen, sample_id)
  used elsewhere in the codebase.
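- The verifier is STATIC — it is loaded once, kept frozen on the device passed
  to ForensicsVerifier (ideally a dedicated GPU), and never updates during RL.
  Its scores can be consumed as a reward shaper or, via format_verifier_scores,
  as plain-text context for the policy.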
"""
import os
from typing import Optional

import numpy as np
import torch
import torch.nn as nn


# ---------------------------------------------------------------------------
# Model definition (must match verifier_m2_train_temporal.py)
# ---------------------------------------------------------------------------
class TemporalVerifier(nn.Module):
    """1D Transformer over per-frame CLIP features → per-second forgery logit.

    Submodule / parameter names (``in_proj``, ``pos_emb``, ``encoder``,
    ``norm``, ``head``) must stay identical to verifier_m2_train_temporal.py,
    otherwise ``load_state_dict`` on a trained checkpoint fails.

    Args:
        in_dim: feature size of each input step.
        hidden: Transformer model width.
        num_layers: number of encoder layers.
        num_heads: attention heads per layer (PyTorch requires it to divide
            ``hidden``).
        dropout: dropout probability inside each encoder layer.
        max_len: length of the learned positional table, i.e. the longest
            sequence ``forward`` accepts.
    """

    def __init__(self, in_dim=768, hidden=384, num_layers=4, num_heads=8,
                 dropout=0.1, max_len=512):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden)
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, hidden))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden, nhead=num_heads, dim_feedforward=hidden * 4,
            dropout=dropout, batch_first=True, activation="gelu", norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x, mask=None):
        """Score every step of a batch of feature sequences.

        Args:
            x: tensor of shape (B, T, in_dim).
            mask: optional (B, T) tensor whose truthy entries mark VALID steps.
                Bool, 0/1 integer, or float masks are all accepted (converted
                with ``.bool()``).

        Returns:
            Pre-sigmoid logits of shape (B, T).

        Raises:
            ValueError: if T exceeds the positional table length ``max_len``.
        """
        _, seq_len, _ = x.shape
        table_len = self.pos_emb.shape[1]
        if seq_len > table_len:
            # Without this check the slice below silently returns a shorter
            # table and the addition fails with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds positional table size {table_len}"
            )
        h = self.in_proj(x) + self.pos_emb[:, :seq_len]
        # The encoder wants True at PADDED positions; callers pass True at valid
        # ones. `.bool()` first, since `~` on an int mask is a bitwise NOT.
        kpm = None if mask is None else ~mask.bool()
        h = self.encoder(h, src_key_padding_mask=kpm)
        h = self.norm(h)
        return self.head(h).squeeze(-1)


# ---------------------------------------------------------------------------
# Verifier wrapper
# ---------------------------------------------------------------------------
class ForensicsVerifier:
    """Loads a trained temporal verifier and provides per-video score lookup.

    Usage:
        verifier = ForensicsVerifier(
            ckpt="/mnt/.../verifier_temporal_best.pt",
            cache_root="/mnt/.../forensics_verifier_clip_l14",
            device="cuda:0",
        )
        scores = verifier.scores_for("train", "scifi", "v_abc...")
        # scores: numpy (T,), per-second forgery prob in [0, 1]

    Feature files live at ``<cache_root>/<split>/<generator>/<sample_id>/clip_feats.pt``.
    Loaded features are kept in an in-memory (CPU) cache keyed by
    ``(split, generator, sample_id)`` and never evicted.
    """

    def __init__(self, ckpt: str, cache_root: str, device: str = "cuda:0"):
        """Build the model from ``ckpt`` and move it to ``device``.

        Raises:
            FileNotFoundError: if ``ckpt`` does not exist.
            KeyError: if the checkpoint has no ``"model_state"`` entry.
        """
        self.cache_root = cache_root
        self.device = device

        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"verifier checkpoint missing: {ckpt}")
        # NOTE: weights_only=False allows arbitrary unpickling — only load
        # checkpoints from a trusted source.
        state = torch.load(ckpt, map_location="cpu", weights_only=False)
        model_state = state["model_state"]

        args = state.get("args", {}) or {}
        if not isinstance(args, dict):
            # Training scripts often store an argparse.Namespace; normalise so
            # the .get() calls below work.
            args = vars(args)

        # Derive shapes from the saved tensors where possible so the rebuilt
        # module always matches the checkpoint; fall back to the old defaults.
        pos_emb = model_state.get("pos_emb")
        if pos_emb is not None:
            max_len = int(pos_emb.shape[1])
        else:
            max_len = int(state.get("max_T", 512)) + 1
        in_proj_w = model_state.get("in_proj.weight")
        in_dim = int(in_proj_w.shape[1]) if in_proj_w is not None else 768

        self.model = TemporalVerifier(
            in_dim=in_dim,
            hidden=args.get("hidden", 384),
            num_layers=args.get("num_layers", 4),
            num_heads=args.get("num_heads", 8),
            dropout=0.0,                 # disable dropout at inference
            max_len=max_len,
        ).to(device).eval()
        self.model.load_state_dict(model_state)
        # The verifier is never trained during RL; make that explicit.
        self.model.requires_grad_(False)

        # Cache features in memory keyed by (split, gen, sample_id) -> tensor
        self._feat_cache: dict = {}

    def _feat_path(self, split: str, generator: str, sample_id: str) -> str:
        """Return the on-disk location of one video's cached CLIP features."""
        return os.path.join(self.cache_root, split, generator, sample_id, "clip_feats.pt")

    def _load_feats(self, split: str, generator: str, sample_id: str):
        """Load one video's features onto the CPU, or return None if absent."""
        path = self._feat_path(split, generator, sample_id)
        if not os.path.exists(path):
            return None
        # map_location="cpu": tensors saved from a GPU would otherwise be
        # restored onto that GPU and pin device memory inside the cache.
        return torch.load(path, map_location="cpu", weights_only=True)

    @torch.no_grad()
    def scores_for(self, split: str, generator: str, sample_id: str) -> Optional[np.ndarray]:
        """Return per-second forgery probability array, or None if not cached.

        The result is a numpy array of shape (T,) with sigmoid values in
        [0, 1]. An empty feature sequence yields an empty float32 array.
        Features read from disk are memoised in ``self._feat_cache``.
        """
        key = (split, generator, sample_id)
        feats = self._feat_cache.get(key)
        if feats is None:
            feats = self._load_feats(*key)
            if feats is None:
                return None
            self._feat_cache[key] = feats   # keep a copy, RL re-uses same videos
        if feats.shape[0] == 0:
            return np.zeros(0, dtype=np.float32)
        # assumes features are stored as (T, in_dim); add the batch dimension.
        x = feats.to(self.device, dtype=torch.float32).unsqueeze(0)  # (1, T, in_dim)
        logits = self.model(x)                                       # (1, T)
        return torch.sigmoid(logits).squeeze(0).cpu().numpy()

    def warmup(self, video_keys) -> int:
        """Pre-load features into the in-memory cache before RL starts.

        Args:
            video_keys: iterable of ``(split, generator, sample_id)`` tuples.

        Returns:
            How many of the given keys have features available in the cache
            afterwards (keys with no feature file are skipped).
        """
        n_available = 0
        for split, gen, sid in video_keys:
            key = (split, gen, sid)
            if key not in self._feat_cache:
                feats = self._load_feats(*key)
                if feats is None:
                    continue
                self._feat_cache[key] = feats
            n_available += 1
        return n_available


def sample_id_from_video_path(video_path: str) -> str:
    """Derive a sample id from a video path, following the trainer's convention.

    The directory part and the final extension are dropped, e.g.
    ``/data/v_abc.mp4`` -> ``v_abc``.
    """
    file_name = os.path.basename(video_path)
    stem, _extension = os.path.splitext(file_name)
    return stem


def format_verifier_scores(scores: Optional[np.ndarray], per_line: int = 8) -> str:
    """Format per-second forgery probabilities as compact text for prompt context.

    Used in the verifier-as-context experiment: instead of consuming the verifier
    output as a reward shaper (which we found to be ~86% redundant with IoU and
    empirically harmful), we give the VLM the raw per-second scores so it can
    reason over them. RL reward stays pure IoU.

    Args:
        scores: per-second forgery probabilities (any 1-D sequence of numbers),
            or None.
        per_line: number of ``sN=VV`` entries per output line; must be >= 1.

    Returns:
        A header line followed by the scores as integer percentages, or ""
        when ``scores`` is None or empty. Values are clipped to [0, 1] before
        scaling, so every printed number lies in the advertised 0-100 range.

    Raises:
        ValueError: if ``per_line`` < 1, or if a score is NaN.

    Format example for a 12-second video:
        External forensics verifier per-second forgery scores (range 0-100, higher = more suspicious):
        s0=12 s1=15 s2=11 s3=91 s4=95 s5=93 s6=34 s7=21
        s8=18 s9=15 s10=12 s11=10
    """
    if scores is None or len(scores) == 0:
        return ""
    if per_line < 1:
        raise ValueError(f"per_line must be >= 1, got {per_line}")
    header = "External forensics verifier per-second forgery scores (range 0-100, higher = more suspicious):"
    entries = []
    for i, s in enumerate(scores):
        # np.clip keeps NaN as NaN, so int() below still raises ValueError for it.
        prob = float(np.clip(float(s), 0.0, 1.0))
        entries.append(f"s{i}={int(round(prob * 100)):02d}")
    rows = [" ".join(entries[k:k + per_line]) for k in range(0, len(entries), per_line)]
    return "\n".join([header] + rows)