# Source: forensics-grpo/code/src/open_r1/forgery_head.py
# (exported from a Hugging Face repository page; author sdzt,
# commit 33569f9 "Add source code", 5.29 kB)
"""Forgery head + per-second label generation.
Self-contained anomaly head bolted onto Qwen2.5-VL's vision encoder.
Layout assumption (verified by inspection on Qwen2.5-VL-7B):
- Visual encoder output is a flat token sequence of shape (N_tot, hidden_dim).
- video_grid_thw[i] = (T_i, H_i, W_i) is the *pre-merger* grid; merger does 2x2
spatial merge so the post-merger sequence length per video i is
T_i * (H_i // 2) * (W_i // 2). hidden_dim is the LLM hidden size (3584 for 7B).
- Per the input fps=2 + temporal_stride=2 convention, T_i corresponds to ~1
output token group per second of the source video. We treat 1 temporal group
= 1 second when generating frame-level labels.
The head spatial-mean-pools each temporal group, normalizes, and emits a
per-second logit. It is supervised by a BCE loss with labels derived from
GT segments (see frame_labels_from_segments).
"""
import math
from typing import List, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
class ForgeryHead(nn.Module):
    """Per-temporal-group binary forgery classifier on Qwen2.5-VL visual tokens.

    For each video, the visual tokens of every temporal group are mean-pooled
    over their spatial positions, layer-normalized, and mapped by a two-layer
    MLP to one logit per group.

    Args:
        hidden_dim: feature size of each visual token (3584 for Qwen2.5-VL-7B).
        mlp_dim: hidden width of the classification MLP.
        spatial_merge_size: side of the square patch block that the vision
            merger fuses into one token (2 for Qwen2.5-VL). It converts the
            pre-merger ``grid_thw`` into the post-merger token count.
    """

    def __init__(
        self,
        hidden_dim: int = 3584,
        mlp_dim: int = 1024,
        spatial_merge_size: int = 2,
    ):
        super().__init__()
        if spatial_merge_size < 1:
            raise ValueError(f"spatial_merge_size must be >= 1, got {spatial_merge_size}")
        self.spatial_merge_size = spatial_merge_size
        self.norm = nn.LayerNorm(hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, 1),
        )

    def forward(
        self,
        visual_tokens: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> List[torch.Tensor]:
        """Return per-video logits.

        Args:
            visual_tokens: (N_tot, hidden_dim). The post-merger tokens of all
                videos, concatenated in the same order as the rows of
                ``grid_thw``.
            grid_thw: (B, 3). The pre-merger (T, H, W) grid for each video.

        Returns:
            A list of length B. Entry i is a (T_i,) logit tensor.

        Raises:
            ValueError: if ``visual_tokens`` has fewer rows than ``grid_thw``
                requires.
        """
        merge = self.spatial_merge_size
        # Linear/LayerNorm need inputs in the parameter dtype (e.g. bf16
        # tokens feeding an fp32 head would otherwise fail in nn.Linear).
        param_dtype = self.norm.weight.dtype
        outputs = []
        offset = 0
        # A single tolist() avoids one host sync per video on CUDA tensors.
        for i, row in enumerate(grid_thw.tolist()):
            t, h, w = (int(v) for v in row)
            spatial = (h // merge) * (w // merge)
            n = t * spatial
            chunk = visual_tokens[offset : offset + n]
            if chunk.shape[0] != n:
                raise ValueError(
                    f"visual_tokens has {visual_tokens.shape[0]} rows but video {i} "
                    f"(grid {t}x{h}x{w}) needs rows [{offset}, {offset + n})"
                )
            offset += n
            # An explicit hidden size (instead of -1) keeps T == 0 reshapable.
            pooled = chunk.reshape(t, spatial, chunk.shape[-1]).mean(dim=1)  # (T, hidden)
            x = self.norm(pooled.to(param_dtype))
            outputs.append(self.mlp(x).squeeze(-1))  # (T,)
        return outputs
def frame_labels_from_segments(
    segments: Sequence[Tuple[float, float]],
    T: int,
    fps_to_groups: float = 1.0,
) -> torch.Tensor:
    """Convert GT segments (in seconds) to per-temporal-group binary labels.

    Temporal group ``i`` is treated as covering the source interval
    ``[i / fps_to_groups, (i + 1) / fps_to_groups)``. A group is labelled 1
    (forged) when any segment overlaps it; all other groups are 0 (authentic).

    Args:
        segments: list of (start_sec, end_sec); may be empty. Segments with
            ``end_sec <= start_sec`` are ignored. Segments are clipped to the
            ``T`` available groups.
        T: number of temporal groups for this video.
        fps_to_groups: temporal groups per source second. This is 1.0 at
            Qwen2.5-VL's input fps=2 + temporal_stride=2 (one group per second).

    Returns:
        (T,) float32 tensor with values in {0, 1}.
    """
    # Absorbs float round-off (e.g. a product that should be exactly 3 coming
    # out as 3.0000000000000004), so a boundary that numerically sits on a
    # group edge does not spill into the neighbouring group.
    eps = 1e-6
    label = torch.zeros(T, dtype=torch.float32)
    for s, e in segments:
        if e <= s:
            continue
        # floor(start) / ceil(end) give the half-open range [s_idx, e_idx) of
        # every group the segment touches. The previous `int(x + 0.999)`
        # missed groups overlapped by less than 0.001 of a group.
        s_idx = max(0, math.floor(s * fps_to_groups + eps))
        e_idx = min(T, math.ceil(e * fps_to_groups - eps))
        if s_idx < e_idx:
            label[s_idx:e_idx] = 1.0
    return label
def head_bce_loss(
    head_logits: List[torch.Tensor],
    segments_per_video: List[Sequence[Tuple[float, float]]],
    fps_to_groups: float = 1.0,
) -> torch.Tensor:
    """Binary cross-entropy of the forgery head against segment-derived labels.

    The BCE is averaged over the temporal groups of each video, and those
    per-video means are then averaged over the batch, so every video carries
    equal weight regardless of its length. Caller is responsible for any extra
    weighting.

    Args:
        head_logits: one (T_i,) logit tensor per video.
        segments_per_video: GT (start_sec, end_sec) segments per video, in the
            same order as ``head_logits``.
        fps_to_groups: temporal groups per source second.

    Returns:
        A scalar tensor. It is zero when there is nothing to score (no videos,
        or only videos with zero groups).

    Raises:
        ValueError: if the two lists have different lengths.
    """
    # zip() would otherwise silently drop the unmatched tail.
    if len(head_logits) != len(segments_per_video):
        raise ValueError(
            f"got {len(head_logits)} logit tensors but {len(segments_per_video)} segment lists"
        )
    losses = []
    for logits, segs in zip(head_logits, segments_per_video):
        if logits.numel() == 0:
            # A mean over zero groups is NaN and would poison the whole loss.
            continue
        labels = frame_labels_from_segments(segs, logits.shape[0], fps_to_groups).to(logits.device)
        # Compute in fp32: labels are fp32, and low-precision BCE is lossy.
        losses.append(F.binary_cross_entropy_with_logits(logits.float(), labels, reduction="mean"))
    if not losses:
        device = head_logits[0].device if head_logits else "cpu"
        return torch.zeros((), device=device)
    return torch.stack(losses).mean()
def head_auc(
    head_logits: List[torch.Tensor],
    segments_per_video: List[Sequence[Tuple[float, float]]],
    fps_to_groups: float = 1.0,
) -> float:
    """Pooled (micro-averaged) ROC-AUC of the head over a batch (phase-2 gate).

    All (logit, label) pairs from every video are concatenated, and the AUC is
    computed as P(pos > neg) + 0.5 * P(pos == neg) over all positive/negative
    pairs. A NaN score never wins or ties.

    Args:
        head_logits: one (T_i,) logit tensor per video.
        segments_per_video: GT (start_sec, end_sec) segments per video, in the
            same order as ``head_logits``.
        fps_to_groups: temporal groups per source second.

    Returns:
        The AUC in [0, 1], or 0.5 when degenerate (no groups, all-positive or
        all-negative).

    Raises:
        ValueError: if the two lists have different lengths.
    """
    # zip() would otherwise silently drop the unmatched tail.
    if len(head_logits) != len(segments_per_video):
        raise ValueError(
            f"got {len(head_logits)} logit tensors but {len(segments_per_video)} segment lists"
        )
    all_logits, all_labels = [], []
    for logits, segs in zip(head_logits, segments_per_video):
        labels = frame_labels_from_segments(segs, logits.shape[0], fps_to_groups).to(logits.device)
        all_logits.append(logits.detach().float())
        all_labels.append(labels)
    if not all_logits:
        return 0.5
    scores = torch.cat(all_logits)
    is_pos = torch.cat(all_labels) > 0.5
    pos_scores = scores[is_pos]
    neg_scores = scores[~is_pos]
    n_pos, n_neg = pos_scores.numel(), neg_scores.numel()
    if n_pos == 0 or n_neg == 0:
        return 0.5
    # NaN compares false against everything, so it contributes neither wins
    # nor ties. Drop NaNs here but keep them in the n_pos * n_neg denominator.
    pos_scores = pos_scores[~torch.isnan(pos_scores)]
    neg_sorted = neg_scores[~torch.isnan(neg_scores)].sort().values
    # Binary search into the sorted negatives gives, per positive, the count of
    # negatives strictly below it (left) and at-or-below it (right). This is
    # O((P + N) log N) instead of materializing P x N comparison matrices.
    n_below = torch.searchsorted(neg_sorted, pos_scores)
    n_at_or_below = torch.searchsorted(neg_sorted, pos_scores, right=True)
    # wins + 0.5 * ties == 0.5 * (below + at_or_below)
    credit = 0.5 * (n_below + n_at_or_below).double().sum()
    return float(credit / (n_pos * n_neg))