# Source: forensics-grpo/code/src/open_r1/fbr_aug.py
# (repository page residue: commit 33569f9 "Add source code", 3.44 kB)
"""Forward-Backward Reversal (FBR) data augmentation.
Hypothesis: AI-generated forgery segments do not model temporal causality
(gravity, momentum, physical flow). Reversing the video makes real content
"look weird" but leaves the AI-segment's intrinsic artifacts intact.
Implementation: with probability p, time-flip the cached video tensor and
remap each GT interval [s, e] -> [D - e, D - s], where D is the clip duration
in seconds (number of cached frames / fps). Train
with the standard iou reward. The model implicitly learns that forgery
detection should not depend on temporal direction — a useful inductive bias
when the data has no other signal pointing this way.
This is the AUGMENTATION-form of FBC (the reward-form was validated to
have weak correlation r=0.196 on stage1_decomp_boundary; augmentation costs
no extra compute and creates additional training data with derived GT).
Activation:
FORENSICS_FBR_AUG=true enables augmentation
FORENSICS_FBR_PROB=0.25 per-sample probability (default 0.25)
Off by default; existing stage1 / v1 runs unaffected. Composable with SPI:
when both are enabled, sample is first FBR-flipped then SPI-shuffled (or
vice versa, doesn't matter for the math since flip and shuffle commute up
to GT remapping).
"""
from __future__ import annotations
import os
import random
from typing import Any, Dict, List, Tuple
import torch
def _env_bool(name: str, default: str = "false") -> bool:
    """Interpret an environment variable as a boolean flag.

    Args:
        name: Environment variable to read.
        default: Raw string used when the variable is unset; it goes through
            the same parsing as a real value.

    Returns:
        True when the value, stripped of surrounding whitespace and
        lower-cased, is one of ``"true"``, ``"1"`` or ``"yes"``; False
        otherwise.
    """
    raw = os.getenv(name, default)
    # Strip first: values exported from shell scripts or .env files often
    # carry a trailing space/newline, which would otherwise silently disable
    # the flag.
    return raw.strip().lower() in ("true", "1", "yes")
def _normalise_intervals(solution: Any) -> List[Tuple[float, float]] | None:
    """Coerce a ground-truth annotation into a list of ``(start, end)`` floats.

    Two layouts are recognised:

    * a sequence of two-element pairs, e.g. ``[[s0, e0], (s1, e1)]``;
    * a single flat pair, e.g. ``[s, e]`` or ``(s, e)``.

    Args:
        solution: Raw annotation value; may be ``None`` or any other object.

    Returns:
        The intervals as a list of float tuples, or ``None`` when the value
        is missing, empty, or does not match either layout (including when
        any individual entry cannot be unpacked/converted).
    """
    if not isinstance(solution, (list, tuple)) or not solution:
        return None

    first = solution[0]
    try:
        if (
            isinstance(first, (list, tuple))
            and len(first) == 2
            and isinstance(first[0], (int, float))
        ):
            # Only the first entry is shape-checked up front; a malformed
            # later entry surfaces as TypeError/ValueError during conversion
            # and is reported as "no usable intervals" rather than raised.
            return [(float(s), float(e)) for s, e in solution]
        if isinstance(first, (int, float)) and len(solution) == 2:
            return [(float(solution[0]), float(solution[1]))]
    except (TypeError, ValueError):
        return None
    return None
def maybe_apply_fbr(data: Dict[str, Any]) -> Dict[str, Any]:
    """Randomly time-reverse a preprocessed video sample and remap its GT.

    The augmentation is gated by two environment variables:
    ``FORENSICS_FBR_AUG`` (must be truthy to enable) and
    ``FORENSICS_FBR_PROB`` (per-call probability, default ``0.25``).

    When it fires, the tensor at ``data["video_inputs"][0][0]`` is flipped
    along dim 0 and every ground-truth interval ``[s, e]`` from
    ``data["solution"]`` becomes ``[D - e, D - s]`` with
    ``D = num_frames / fps`` (``fps`` read from
    ``data["video_kwargs"][0]["fps"][0]``). ``data`` is modified in place:
    ``video_inputs`` and ``solution`` are replaced and ``data["_fbr"]`` is
    set to ``[True]``. Any further videos in the first sample, and any
    further entries of ``video_inputs``, are carried over unchanged.

    The sample is returned untouched when it is not flagged via
    ``use_preprocessed``, when the expected keys/layout are missing, when the
    clip has fewer than 8 frames, when ``fps`` is not a positive finite
    number, when no intervals can be parsed, or when an interval has no
    overlap with ``[0, D]`` (no meaningful reversed GT can be derived).

    Args:
        data: Sample dict. Keys read: ``use_preprocessed``, ``video_inputs``,
            ``video_kwargs``, ``solution``.

    Returns:
        The same ``data`` object, augmented or not.

    Raises:
        ValueError: If ``FORENSICS_FBR_PROB`` is set to a non-numeric string.
    """
    if not _env_bool("FORENSICS_FBR_AUG"):
        return data

    prob = float(os.getenv("FORENSICS_FBR_PROB", "0.25"))
    # random() lies in [0, 1): ">=" makes prob=0 a strict no-op while prob=1
    # still always fires.
    if random.random() >= prob:
        return data

    try:
        use_pp = data.get("use_preprocessed", [False])
        if not (use_pp and use_pp[0]):
            return data

        batch_videos = data["video_inputs"]
        video_list = batch_videos[0]
        if not isinstance(video_list, list) or not video_list:
            return data
        # Entries beyond the first sample are kept as-is when writing back.
        other_samples = list(batch_videos[1:])

        video = video_list[0]
        if not torch.is_tensor(video) or video.dim() < 3:
            return data
        num_frames = video.shape[0]  # dim 0 is treated as the time axis
        if num_frames < 8:
            return data

        fps = float(data["video_kwargs"][0]["fps"][0])
        # Chained comparison also rejects NaN and inf, not just fps <= 0.
        if not 0.0 < fps < float("inf"):
            return data

        intervals = _normalise_intervals(data.get("solution"))
        if not intervals:
            return data
    except (KeyError, IndexError, TypeError, ValueError):
        # Best-effort augmentation: an unexpected sample layout means
        # "leave the sample alone", never a crash.
        return data

    duration = num_frames / fps
    frame_dt = 1.0 / fps

    new_intervals = []
    for s, e in intervals:
        # An interval entirely outside the clip cannot be mirrored into a
        # meaningful GT; skip the augmentation instead of fabricating one.
        if s >= duration or e <= 0.0:
            return data
        s_in = max(s, 0.0)
        e_in = min(e, duration)
        # [s, e] in forward time -> [D - e, D - s] in reversed time, keeping
        # every interval at least one frame long.
        new_s = duration - e_in
        new_e = max(new_s + frame_dt, duration - s_in)
        new_intervals.append((new_s, new_e))
    new_intervals.sort()

    new_video = video.flip(0).contiguous()
    data["video_inputs"] = [[new_video, *video_list[1:]], *other_samples]
    data["solution"] = new_intervals
    data["_fbr"] = [True]
    return data