# forensics-grpo: code/src/open_r1/spi_aug.py
# Added by sdzt in commit 33569f9 ("Add source code", verified); 12.4 kB.
"""Splice-Position Invariance (SPI) data augmentation.
Hypothesis: Forgery segments are "foreign matter" — their detectability
depends on intrinsic artifacts, not on temporal position or surrounding
context. Action localization is the opposite (context-dependent).
We exploit this by chunk-permuting each training video while keeping the
forgery interval as an atomic chunk. The model must locate the forgery
regardless of where its surrounding real chunks happen to be — forcing it
to use intrinsic forgery features, not contextual semantics.
Activation:
FORENSICS_SPI_AUG=true enables augmentation
FORENSICS_SPI_PROB=0.5 per-sample probability of applying SPI (default 0.5)
FORENSICS_SPI_CHUNK_S=2.5 real-chunk size in seconds (default 2.5)
FORENSICS_SPI_SAFETY_S=0.5 safety band around each GT interval (default 0.5)
SPLIT-FAKE extension (manufactures multi-segment GTs from any sample):
FORENSICS_SPI_SPLIT_FAKE=false also chunk the forgery and scatter it
FORENSICS_SPI_SPLIT_PROB=0.5 P(split | SPI applied) — keep <1 so atomic
long-forgery samples still appear
FORENSICS_SPI_FAKE_CHUNK_S=2.5 forgery-chunk size in seconds
FORENSICS_SPI_FAKE_MIN_S=2.0 lower bound on a fake chunk (avoid sub-2s
fragments whose IoU reward is pure noise)
Off by default; existing stage1 / v1 / v5 / v7* / v10 / v12 runs are unaffected.
Algorithm per sample:
1. Parse GT intervals -> frame ranges + safety band -> "forgery atoms"
(each atom keeps its frame order intact).
2. Slice the gap regions into uniform chunks of ~CHUNK_S seconds each.
3. (SPLIT-FAKE) Optionally slice each forgery atom into ~FAKE_CHUNK_S sub-
chunks so a single long forgery becomes several fake atoms.
4. Randomly permute all atoms (forgery atoms + real chunks together,
so the forgery's TEMPORAL POSITION varies too — this is what makes
the augmentation force position-invariance). Under SPLIT-FAKE the
permutation is constrained so no two fake chunks are adjacent (>=1 real
chunk between them), preventing them from re-fusing into one segment;
the scattered fake chunks become a multi-segment GT (e.g. a 10 s forgery
at the default 2.5 s FAKE_CHUNK_S yields 4 segments).
5. Concatenate frames in the permuted order; recompute each resulting
forgery piece's new (start, end) timestamps based on where it landed.
Fallback (returns data unchanged, or to atomic SPI) when:
- No cached preprocessed features for this sample
- Video has fewer than 16 frames (too short for meaningful permutation)
- Solution format is unparseable
- After atom construction, fewer than 2 real chunks exist
- (SPLIT-FAKE) too few real chunks to separate the fake chunks -> falls
back to atomic (un-split) SPI for that sample
- Random permutation happens to equal identity
"""
from __future__ import annotations
import os
import random
from typing import Any, Dict, List, Tuple
import torch
def _env_bool(name: str, default: str = "false") -> bool:
    """Read environment variable *name* as a boolean flag.

    The variable's value (or *default* when it is unset) is lower-cased and
    compared against the accepted truthy spellings ``true``, ``1`` and
    ``yes``; any other value yields False.
    """
    raw_value = os.getenv(name, default)
    truthy_spellings = {"true", "1", "yes"}
    return raw_value.lower() in truthy_spellings
def _normalise_intervals(solution: Any) -> List[Tuple[float, float]] | None:
    """Coerce the HF Dataset's variable ``solution`` format into ``[(s, e), ...]``.

    Accepted shapes (outer container may be a list or a tuple):
      * a sequence of ``[s, e]`` pairs, e.g. ``[[1.0, 2.5], [4.0, 6.0]]``;
      * a single flat pair, e.g. ``[1.0, 2.5]``.
    Every endpoint must be an ``int`` or ``float``; ``bool`` is rejected even
    though it subclasses ``int``.

    Returns:
        A list of ``(float, float)`` tuples, or None when ``solution`` is None,
        empty, or malformed anywhere -- including a valid first pair followed
        by a malformed entry, which is reported as None rather than raising.
    """
    def _is_number(value: Any) -> bool:
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    if not isinstance(solution, (list, tuple)) or not solution:
        return None

    # Single flat pair: [s, e].
    if len(solution) == 2 and all(_is_number(v) for v in solution):
        return [(float(solution[0]), float(solution[1]))]

    # Sequence of pairs: every entry must itself be a numeric pair.
    intervals: List[Tuple[float, float]] = []
    for item in solution:
        if not (isinstance(item, (list, tuple)) and len(item) == 2
                and all(_is_number(v) for v in item)):
            return None
        intervals.append((float(item[0]), float(item[1])))
    return intervals
def _split_forgery_atoms(
    forgery_atoms: List[Dict[str, Any]], fake_chunk_f: int, fake_min_f: int, fps: float,
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Slice each forgery atom's true forgery span into ~fake_chunk_f-frame pieces.

    Each input atom is a dict with inclusive frame bounds ``fs``/``fe`` (GT span
    widened by a safety band) and ``orig``, the list of original
    ``(start_s, end_s)`` GT intervals it covers.

    Frame convention (the one used when the atoms are built): an interval
    ``(s, e)`` covers frames ``round(s*fps) .. round(e*fps) - 1``, i.e.
    ``round(e*fps)`` is an exclusive end boundary.

    Returns ``(fake_atoms, extra_real_atoms)``:
      * fake atom: ``{"type": "f", "fs", "fe", "gfs", "gfe"}`` where ``fs..fe``
        are the piece's inclusive frames and ``gfs``/``gfe`` are the forgery's
        start / exclusive-end frame boundaries. Every piece is pure forgery, so
        ``gfs == fs`` and ``gfe == fe + 1`` and the remapped GT (start + offset,
        end boundary + offset) spans exactly the piece;
      * extra real atom: ``{"type": "r", "fs", "fe"}`` for the safety-band
        margins outside the true forgery span, to be shuffled like any other
        real chunk.

    The step between pieces is ``max(fake_chunk_f, fake_min_f)`` so no
    non-final piece is shorter than ``fake_min_f``. Spans shorter than twice
    the step are kept whole (one fake atom); a trailing piece shorter than
    ``fake_min_f`` is merged into its predecessor.
    """
    fake_atoms: List[Dict[str, Any]] = []
    extra_real: List[Dict[str, Any]] = []
    # Enforce the documented lower bound even if FAKE_CHUNK_S < FAKE_MIN_S,
    # and never use a non-positive step with range().
    step = max(1, fake_chunk_f, fake_min_f)
    for atom in forgery_atoms:
        fs, fe = atom["fs"], atom["fe"]
        # True forgery frames inside this (safety-banded) atom as an inclusive
        # range [ufs, ufe]; round(e*fps) is exclusive, hence the "- 1". Both
        # ends are kept inside [fs, fe] so pieces + margins tile the atom.
        # NOTE(review): when the atom merged several GT intervals, any short
        # real gap between them is treated as forgery here.
        ufs = min(fe, max(fs, min(int(round(s * fps)) for s, _ in atom["orig"])))
        ufe = min(fe, max(int(round(e * fps)) for _, e in atom["orig"]) - 1)
        if ufe < ufs:  # sub-frame (or reversed) GT interval: keep one frame
            ufe = ufs
        if ufs > fs:
            extra_real.append({"type": "r", "fs": fs, "fe": ufs - 1})
        if ufe < fe:
            extra_real.append({"type": "r", "fs": ufe + 1, "fe": fe})

        if (ufe - ufs + 1) < 2 * step:  # too short to split -- keep whole
            bounds = [[ufs, ufe]]
        else:
            bounds = [[st, min(st + step - 1, ufe)] for st in range(ufs, ufe + 1, step)]
            tail_start, tail_end = bounds[-1]
            if len(bounds) >= 2 and (tail_end - tail_start + 1) < fake_min_f:
                bounds[-2][1] = tail_end
                bounds.pop()

        # gfe is an exclusive boundary so the GT covers the whole piece.
        fake_atoms.extend(
            {"type": "f", "fs": a, "fe": b, "gfs": a, "gfe": b + 1} for a, b in bounds
        )
    return fake_atoms, extra_real
def _interleave_no_adjacent_fakes(
    fakes: List[Dict[str, Any]], reals: List[Dict[str, Any]],
) -> List[Dict[str, Any]] | None:
    """Return a random ordering of all atoms in which no two fakes are adjacent.

    The shuffled reals are laid out in a row, which leaves ``len(reals) + 1``
    gaps (before, between and after them). Each shuffled fake goes into its
    own randomly chosen gap, so at least one real separates any two fakes.
    The input lists are not mutated.

    Returns None when there are more fakes than gaps.
    """
    n_gaps = len(reals) + 1
    if len(fakes) > n_gaps:
        return None
    shuffled_fakes = list(fakes)
    shuffled_reals = list(reals)
    random.shuffle(shuffled_fakes)
    random.shuffle(shuffled_reals)
    occupied_gaps = set(random.sample(range(n_gaps), len(fakes)))
    next_fake = iter(shuffled_fakes)
    arranged: List[Dict[str, Any]] = []
    for gap in range(n_gaps):
        if gap in occupied_gaps:
            arranged.append(next(next_fake))
        if gap < len(shuffled_reals):
            arranged.append(shuffled_reals[gap])
    return arranged
def maybe_apply_spi(data: Dict[str, Any]) -> Dict[str, Any]:
    """Apply SPI augmentation in-place if env vars enable it and the sample is eligible.

    Reads (a missing or malformed field leaves the sample unchanged):
        data["use_preprocessed"][0]        must be truthy
        data["video_inputs"][0][0]         frame tensor, frames along dim 0
        data["video_kwargs"][0]["fps"][0]  frame rate of that tensor
        data["solution"]                   GT interval(s) in seconds

    Frame convention: a GT interval (s, e) covers frames
    round(s*fps) .. round(e*fps) - 1, i.e. round(e*fps) is an exclusive end.

    Modifies (only when augmenting):
        data["video_inputs"] -> [[shuffled_tensor]]  (same number of frames T)
        data["solution"]     -> sorted list of remapped (s, e) tuples
        data["_spi"]         -> [True]  (marker for downstream logging)
        data["_spi_split"]   -> [bool]  whether SPLIT-FAKE was used

    Returns:
        The same ``data`` dict (mutated only when augmentation was applied).
    """
    if not _env_bool("FORENSICS_SPI_AUG"):
        return data
    prob = float(os.getenv("FORENSICS_SPI_PROB", "0.5"))
    # ">=" so PROB=0 disables SPI exactly (random() may return 0.0).
    if random.random() >= prob:
        return data
    use_pp = data.get("use_preprocessed", [False])
    if not (use_pp and use_pp[0]):
        return data

    chunk_s = float(os.getenv("FORENSICS_SPI_CHUNK_S", "2.5"))
    safety_s = float(os.getenv("FORENSICS_SPI_SAFETY_S", "0.5"))
    split_fake = _env_bool("FORENSICS_SPI_SPLIT_FAKE") \
        and random.random() < float(os.getenv("FORENSICS_SPI_SPLIT_PROB", "0.5"))
    fake_chunk_s = float(os.getenv("FORENSICS_SPI_FAKE_CHUNK_S", "2.5"))
    fake_min_s = float(os.getenv("FORENSICS_SPI_FAKE_MIN_S", "2.0"))

    try:
        video_list = data["video_inputs"][0]  # cached as list-of-one-tensor
        if not isinstance(video_list, list) or not video_list:
            return data
        video = video_list[0]
        if not torch.is_tensor(video) or video.dim() < 3:
            return data
        T = video.shape[0]
        if T < 16:
            return data
        fps = float(data["video_kwargs"][0]["fps"][0])
        if not 0.0 < fps < float("inf"):  # also rejects NaN
            return data
        intervals = _normalise_intervals(data.get("solution"))
        if not intervals:
            return data
    except (KeyError, IndexError, TypeError, ValueError):
        return data

    # 1. Forgery atoms: each GT interval -> frame range widened by a safety
    #    band and clipped to the video; overlapping/touching atoms are merged.
    safety_f = max(1, int(round(safety_s * fps)))
    raw = []
    for s, e in intervals:
        start_f = int(round(s * fps))
        if start_f > T - 1:
            # GT starts past the sampled video: its atom would reference
            # frames the tensor does not have, so skip augmenting this sample.
            return data
        fs = max(0, start_f - safety_f)
        fe = min(T - 1, int(round(e * fps)) + safety_f - 1)
        if fe < fs:
            fe = fs
        raw.append((fs, fe, s, e))
    raw.sort()
    forgery_atoms: List[Dict[str, Any]] = []
    for fs, fe, s_orig, e_orig in raw:
        if forgery_atoms and fs <= forgery_atoms[-1]["fe"] + 1:
            forgery_atoms[-1]["fe"] = max(forgery_atoms[-1]["fe"], fe)
            forgery_atoms[-1]["orig"].append((s_orig, e_orig))
        else:
            forgery_atoms.append({"fs": fs, "fe": fe, "orig": [(s_orig, e_orig)]})

    # 2. Real material = gaps between atoms (and head/tail), cut into chunks
    #    of chunk_f frames. A sentinel gap end at T handles the tail.
    chunk_f = max(2, int(round(chunk_s * fps)))
    real_chunks: List[Tuple[int, int]] = []
    cur = 0
    gap_limits = [(a["fs"], a["fe"] + 1) for a in forgery_atoms] + [(T, T)]
    for gap_end_excl, next_cur in gap_limits:
        for start in range(cur, gap_end_excl, chunk_f):
            real_chunks.append((start, min(start + chunk_f - 1, gap_end_excl - 1)))
        cur = next_cur
    if len(real_chunks) < 2:
        return data  # not enough free material to permute
    real_atoms: List[Dict[str, Any]] = [{"type": "r", "fs": a, "fe": b} for a, b in real_chunks]

    # 3. Fake atoms: several per forgery when SPLIT-FAKE is kept, else one each.
    #    The split's safety margins join the real pool ONLY if the split is
    #    kept -- the atomic fake atoms below already contain those frames, and
    #    adding them twice would duplicate frames.
    fake_atoms: List[Dict[str, Any]] = []
    if split_fake:
        fake_chunk_f = max(2, int(round(fake_chunk_s * fps)))
        fake_min_f = max(2, int(round(fake_min_s * fps)))
        fake_atoms, extra_real = _split_forgery_atoms(forgery_atoms, fake_chunk_f, fake_min_f, fps)
        pool = real_atoms + extra_real
        # Need >=1 real chunk between each adjacent fake pair; else no point
        # splitting -- fall back to atomic SPI for this sample.
        if 2 <= len(fake_atoms) <= len(pool) + 1:
            real_atoms = pool
        else:
            split_fake = False
    if not split_fake:
        fake_atoms = [{"type": "f", "fs": a["fs"], "fe": a["fe"], "orig": a["orig"]}
                      for a in forgery_atoms]

    def _in_source_order(seq: List[Dict[str, Any]]) -> bool:
        # True when the arrangement reproduces the original frame order.
        starts = [a["fs"] for a in seq]
        return starts == sorted(starts)

    # 4. Permute.
    if split_fake:
        order = _interleave_no_adjacent_fakes(fake_atoms, real_atoms)
        if order is None or _in_source_order(order):
            return data
    else:
        order = fake_atoms + real_atoms
        random.shuffle(order)
        if _in_source_order(order):
            # Shuffle reproduced the original video; swap two atoms so the
            # sample is actually perturbed (always >= 3 atoms here).
            order[0], order[1] = order[1], order[0]

    # 5. Concatenate frames in the new order and remap GT timestamps.
    pieces = []
    new_intervals: List[Tuple[float, float]] = []
    cursor = 0  # first frame of the current atom in the new video
    for atom in order:
        atom_start, atom_end = atom["fs"], atom["fe"]
        pieces.append(video[atom_start:atom_end + 1])
        if atom["type"] == "f":
            if "gfs" in atom:  # split fake chunk: one forgery span per piece
                spans = [(atom["gfs"], atom["gfe"])]
            else:
                # Atomic atom: remap each original GT interval. The end is an
                # exclusive frame boundary, so it is clamped to T (not T-1);
                # otherwise a forgery reaching the last frame loses one frame.
                spans = [(max(0, int(round(s * fps))), min(T, int(round(e * fps))))
                         for s, e in atom["orig"]]
            for fs_o, fe_o in spans:
                new_s = (cursor + fs_o - atom_start) / fps
                new_e = (cursor + fe_o - atom_start) / fps
                if new_e <= new_s:
                    new_e = new_s + 1.0 / fps
                new_intervals.append((new_s, new_e))
        cursor += atom_end - atom_start + 1
    if not new_intervals:
        return data
    # Sanity check: the atoms must tile exactly the T frames of the video
    # (slicing silently clips out-of-range atoms, so check the bookkeeping too).
    if cursor != T:
        return data
    new_intervals.sort()
    new_video = torch.cat(pieces, dim=0).contiguous()
    if new_video.shape[0] != T:
        return data

    data["video_inputs"] = [[new_video]]
    data["solution"] = new_intervals
    data["_spi"] = [True]
    data["_spi_split"] = [bool(split_fake)]
    return data