"""Forgery head + per-second label generation. Self-contained anomaly head bolted onto Qwen2.5-VL's vision encoder. Layout assumption (verified by inspection on Qwen2.5-VL-7B): - Visual encoder output is a flat token sequence of shape (N_tot, hidden_dim). - video_grid_thw[i] = (T_i, H_i, W_i) is the *pre-merger* grid; merger does 2x2 spatial merge so the post-merger sequence length per video i is T_i * (H_i // 2) * (W_i // 2). hidden_dim is the LLM hidden size (3584 for 7B). - Per the input fps=2 + temporal_stride=2 convention, T_i corresponds to ~1 output token group per second of the source video. We treat 1 temporal group = 1 second when generating frame-level labels. The head spatial-mean-pools each temporal group, normalizes, and emits a per-second logit. It is supervised by a BCE loss with labels derived from GT segments (see frame_labels_from_segments). """ from typing import List, Sequence, Tuple import torch import torch.nn as nn import torch.nn.functional as F class ForgeryHead(nn.Module): """Per-temporal-group binary forgery classifier on Qwen2.5-VL visual tokens.""" def __init__(self, hidden_dim: int = 3584, mlp_dim: int = 1024): super().__init__() self.norm = nn.LayerNorm(hidden_dim) self.mlp = nn.Sequential( nn.Linear(hidden_dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, 1), ) def forward( self, visual_tokens: torch.Tensor, grid_thw: torch.Tensor, ) -> List[torch.Tensor]: """Return per-video logits. visual_tokens: (N_tot, hidden_dim) — concatenation across the batch grid_thw: (B, 3) — pre-merger (T, H, W) per video Returns: list of length B, each a (T_i,) logit tensor. """ outputs = [] offset = 0 for i in range(grid_thw.shape[0]): T, H, W = (int(x) for x in grid_thw[i].tolist()) spatial = (H // 2) * (W // 2) n = T * spatial chunk = visual_tokens[offset : offset + n] offset += n x = chunk.view(T, spatial, -1).mean(dim=1) # (T, hidden_dim) spatial pool x = self.norm(x) logits = self.mlp(x).squeeze(-1) # (T,) outputs.append(logits) return outputs def frame_labels_from_segments( segments: Sequence[Tuple[float, float]], T: int, fps_to_groups: float = 1.0, ) -> torch.Tensor: """Convert GT segments (in seconds) to per-temporal-group binary labels. fps_to_groups = how many *temporal groups* correspond to one second of source video. At Qwen2.5-VL's input fps=2 + temporal_stride=2 this is 1.0 (one group per second). Values fed to the head's sigmoid; missing labels are zero (= authentic). Args: segments: list of (start_sec, end_sec); may be empty. T: number of temporal groups for this video. fps_to_groups: groups per source second. Returns: (T,) float tensor in {0, 1}. """ label = torch.zeros(T, dtype=torch.float32) for s, e in segments: if e <= s: continue s_idx = max(0, int(s * fps_to_groups)) e_idx = min(T, int(e * fps_to_groups + 0.999)) # half-open, ceil end if s_idx < e_idx: label[s_idx:e_idx] = 1.0 return label def head_bce_loss( head_logits: List[torch.Tensor], segments_per_video: List[Sequence[Tuple[float, float]]], fps_to_groups: float = 1.0, ) -> torch.Tensor: """Binary cross-entropy across all temporal groups in the batch. Returns a scalar. Caller is responsible for any extra weighting. """ losses = [] for logits, segs in zip(head_logits, segments_per_video): T = logits.shape[0] labels = frame_labels_from_segments(segs, T, fps_to_groups).to(logits.device) losses.append(F.binary_cross_entropy_with_logits(logits, labels, reduction="mean")) if not losses: return torch.zeros((), device=head_logits[0].device if head_logits else "cpu") return torch.stack(losses).mean() def head_auc( head_logits: List[torch.Tensor], segments_per_video: List[Sequence[Tuple[float, float]]], fps_to_groups: float = 1.0, ) -> float: """Cheap macro-AUC across all groups in a batch (used for phase-2 gate). Concatenates all (logit, label) pairs across the batch and returns rank-AUC. Returns 0.5 when degenerate (all-positive or all-negative). """ all_logits, all_labels = [], [] for logits, segs in zip(head_logits, segments_per_video): T = logits.shape[0] labels = frame_labels_from_segments(segs, T, fps_to_groups).to(logits.device) all_logits.append(logits.detach().float()) all_labels.append(labels) if not all_logits: return 0.5 L = torch.cat(all_logits) Y = torch.cat(all_labels) pos = Y > 0.5 neg = ~pos n_pos, n_neg = int(pos.sum()), int(neg.sum()) if n_pos == 0 or n_neg == 0: return 0.5 pos_scores = L[pos] neg_scores = L[neg] # Mann-Whitney U via rank. cmp = (pos_scores.unsqueeze(1) > neg_scores.unsqueeze(0)).float() eq = (pos_scores.unsqueeze(1) == neg_scores.unsqueeze(0)).float() * 0.5 return float((cmp + eq).mean().item())