| """Forgery head + per-second label generation. |
| |
| Self-contained anomaly head bolted onto Qwen2.5-VL's vision encoder. |
| |
| Layout assumption (verified by inspection on Qwen2.5-VL-7B): |
| - Visual encoder output is a flat token sequence of shape (N_tot, hidden_dim). |
| - video_grid_thw[i] = (T_i, H_i, W_i) is the *pre-merger* grid; merger does 2x2 |
| spatial merge so the post-merger sequence length per video i is |
| T_i * (H_i // 2) * (W_i // 2). hidden_dim is the LLM hidden size (3584 for 7B). |
| - Per the input fps=2 + temporal_stride=2 convention, T_i corresponds to ~1 |
| output token group per second of the source video. We treat 1 temporal group |
| = 1 second when generating frame-level labels. |
| |
| The head spatial-mean-pools each temporal group, normalizes, and emits a |
| per-second logit. It is supervised by a BCE loss with labels derived from |
| GT segments (see frame_labels_from_segments). |
| """ |
| from typing import List, Sequence, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class ForgeryHead(nn.Module):
    """Per-temporal-group binary forgery classifier on Qwen2.5-VL visual tokens.

    For each video, the post-merger visual tokens of every temporal group are
    mean-pooled over their spatial positions, layer-normalized, and mapped to
    one logit by a two-layer GELU MLP.

    Args:
        hidden_dim: width of the incoming visual tokens.
        mlp_dim: hidden width of the classifier MLP.
        spatial_merge_size: side length of the square spatial merge applied by
            the vision merger (2 means a 2x2 merge). The number of post-merger
            tokens per temporal group is
            ``(H // spatial_merge_size) * (W // spatial_merge_size)``.

    Raises:
        ValueError: if ``spatial_merge_size`` is smaller than 1.
    """

    def __init__(
        self,
        hidden_dim: int = 3584,
        mlp_dim: int = 1024,
        spatial_merge_size: int = 2,
    ):
        super().__init__()
        if spatial_merge_size < 1:
            raise ValueError(
                f"spatial_merge_size must be >= 1, got {spatial_merge_size}"
            )
        self.spatial_merge_size = spatial_merge_size
        self.norm = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, 1),
        )

    def forward(
        self,
        visual_tokens: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> List[torch.Tensor]:
        """Return per-video logits.

        Args:
            visual_tokens: (N_tot, hidden_dim) post-merger tokens, concatenated
                across the batch in the same order as ``grid_thw``.
            grid_thw: (B, 3) pre-merger (T, H, W) per video.

        Returns:
            List of length B; entry i is a (T_i,) logit tensor.

        Raises:
            ValueError: if ``grid_thw`` describes more tokens than
                ``visual_tokens`` holds, or if a grid is smaller than the merge
                window (which would otherwise average over zero tokens).
        """
        merge = self.spatial_merge_size
        n_total = visual_tokens.shape[0]
        hidden = visual_tokens.shape[-1]
        outputs = []
        offset = 0
        for i, (T, H, W) in enumerate(grid_thw.tolist()):
            T, H, W = int(T), int(H), int(W)
            spatial = (H // merge) * (W // merge)
            if T > 0 and spatial == 0:
                raise ValueError(
                    f"grid_thw[{i}]=({T}, {H}, {W}) is smaller than the "
                    f"{merge}x{merge} merge window"
                )
            n = T * spatial
            if offset + n > n_total:
                raise ValueError(
                    f"grid_thw[{i}]=({T}, {H}, {W}) needs {n} tokens at offset "
                    f"{offset}, but visual_tokens has only {n_total} rows; "
                    f"check that grid_thw matches the encoder output"
                )
            chunk = visual_tokens[offset : offset + n]
            offset += n
            # Assumes this video's tokens are time-major, so each run of
            # `spatial` consecutive rows is one temporal group. The explicit
            # hidden size (instead of -1) keeps T == 0 well-defined.
            x = chunk.reshape(T, spatial, hidden).mean(dim=1)
            x = self.norm(x)
            outputs.append(self.mlp(x).squeeze(-1))
        return outputs
|
|
|
|
def frame_labels_from_segments(
    segments: Sequence[Tuple[float, float]],
    T: int,
    fps_to_groups: float = 1.0,
) -> torch.Tensor:
    """Convert GT segments (in seconds) to per-temporal-group binary labels.

    Temporal group ``k`` covers source time ``[k, k + 1) / fps_to_groups``
    seconds. A group is labelled 1 (forged) when it overlaps any segment by
    more than a tiny float tolerance. Every group not covered by a segment is
    0 (authentic). The result is used as the BCE target for the head's logits.

    ``fps_to_groups`` is how many temporal groups correspond to one second of
    source video. At Qwen2.5-VL's input fps=2 with 2 frames per temporal patch,
    this is 1.0 (one group per second).

    Args:
        segments: list of (start_sec, end_sec); may be empty. Segments with
            end <= start are ignored. Bounds outside [0, T) (including +/-inf)
            are clipped.
        T: number of temporal groups for this video.
        fps_to_groups: groups per source second; must be positive.

    Returns:
        (T,) float32 tensor in {0, 1}.

    Raises:
        ValueError: if ``fps_to_groups`` is not positive.
    """
    import math

    if not fps_to_groups > 0:
        raise ValueError(f"fps_to_groups must be positive, got {fps_to_groups}")
    # Tolerance in group units. It absorbs float noise (e.g. an end time that
    # evaluates to 3.0000000000000004) so that a boundary lying on a group edge
    # does not spill into the neighbouring group.
    eps = 1e-6
    label = torch.zeros(T, dtype=torch.float32)
    for s, e in segments:
        if e <= s:
            continue
        # Clip in float space first so infinite bounds never reach floor/ceil.
        lo = max(s * fps_to_groups, 0.0)
        hi = min(e * fps_to_groups, float(T))
        s_idx = math.floor(lo + eps)
        e_idx = math.ceil(hi - eps)
        if s_idx < e_idx:
            label[s_idx:e_idx] = 1.0
    return label
|
|
|
|
def head_bce_loss(
    head_logits: List[torch.Tensor],
    segments_per_video: List[Sequence[Tuple[float, float]]],
    fps_to_groups: float = 1.0,
) -> torch.Tensor:
    """Binary cross-entropy of the head's per-group logits against GT segments.

    Each video's loss is the mean BCE over its temporal groups. The returned
    scalar is the unweighted mean of those per-video losses, so every video
    counts equally regardless of length. Videos with zero temporal groups are
    skipped. Caller is responsible for any extra weighting.

    Args:
        head_logits: one (T_i,) logit tensor per video.
        segments_per_video: GT (start_sec, end_sec) segments per video, in the
            same order as ``head_logits``.
        fps_to_groups: temporal groups per source second.

    Returns:
        Scalar float32 tensor. If there is nothing to supervise, a zero tensor
        (without grad) is returned instead.

    Raises:
        ValueError: if the two lists have different lengths. Previously the
            extra videos were silently left unsupervised.
    """
    if len(head_logits) != len(segments_per_video):
        raise ValueError(
            f"got {len(head_logits)} logit tensors but "
            f"{len(segments_per_video)} segment lists"
        )
    losses = []
    for logits, segs in zip(head_logits, segments_per_video):
        T = logits.shape[0]
        if T == 0:
            # A mean-reduced BCE over zero elements is NaN and would poison
            # the batch mean.
            continue
        labels = frame_labels_from_segments(segs, T, fps_to_groups).to(logits.device)
        # Labels are float32. Upcasting the logits avoids mixed-dtype inputs
        # and computes the loss at full precision even when the head runs in
        # a reduced-precision dtype.
        losses.append(
            F.binary_cross_entropy_with_logits(logits.float(), labels, reduction="mean")
        )
    if not losses:
        return torch.zeros((), device=head_logits[0].device if head_logits else "cpu")
    return torch.stack(losses).mean()
|
|
|
|
def head_auc(
    head_logits: List[torch.Tensor],
    segments_per_video: List[Sequence[Tuple[float, float]]],
    fps_to_groups: float = 1.0,
) -> float:
    """Pooled (micro) ROC-AUC of the head over all temporal groups in a batch.

    Used for the phase-2 gate. All (logit, label) pairs across the batch are
    concatenated; the AUC is the fraction of (positive, negative) group pairs
    in which the positive group scores higher, with ties counted as 1/2.
    Returns 0.5 when degenerate: no groups, all-positive, or all-negative.

    Note: this builds an (n_pos, n_neg) comparison matrix, so memory grows
    quadratically with the number of groups. That is fine for a training
    batch, but not for a whole evaluation set.

    Args:
        head_logits: one (T_i,) logit tensor per video.
        segments_per_video: GT (start_sec, end_sec) segments per video, in the
            same order as ``head_logits``.
        fps_to_groups: temporal groups per source second.

    Raises:
        ValueError: if the two lists have different lengths.
    """
    if len(head_logits) != len(segments_per_video):
        raise ValueError(
            f"got {len(head_logits)} logit tensors but "
            f"{len(segments_per_video)} segment lists"
        )
    all_logits, all_labels = [], []
    for logits, segs in zip(head_logits, segments_per_video):
        T = logits.shape[0]
        labels = frame_labels_from_segments(segs, T, fps_to_groups).to(logits.device)
        all_logits.append(logits.detach().float())
        all_labels.append(labels)
    if not all_logits:
        return 0.5
    L = torch.cat(all_logits)
    Y = torch.cat(all_labels)
    pos = Y > 0.5
    neg = ~pos
    n_pos, n_neg = int(pos.sum()), int(neg.sum())
    if n_pos == 0 or n_neg == 0:
        return 0.5
    pos_scores = L[pos]
    neg_scores = L[neg]
    # Mann-Whitney pair comparison: wins count 1 and ties count 1/2.
    cmp = (pos_scores.unsqueeze(1) > neg_scores.unsqueeze(0)).float()
    eq = (pos_scores.unsqueeze(1) == neg_scores.unsqueeze(0)).float() * 0.5
    return float((cmp + eq).mean().item())
|
|