"""Splice-Position Invariance (SPI) data augmentation. Hypothesis: Forgery segments are "foreign matter" — their detectability depends on intrinsic artifacts, not on temporal position or surrounding context. Action localization is the opposite (context-dependent). We exploit this by chunk-permuting each training video while keeping the forgery interval as an atomic chunk. The model must locate the forgery regardless of where its surrounding real chunks happen to be — forcing it to use intrinsic forgery features, not contextual semantics. Activation: FORENSICS_SPI_AUG=true enables augmentation FORENSICS_SPI_PROB=0.5 per-sample probability of applying SPI (default 0.5) FORENSICS_SPI_CHUNK_S=2.5 real-chunk size in seconds (default 2.5) FORENSICS_SPI_SAFETY_S=0.5 safety band around each GT interval (default 0.5) SPLIT-FAKE extension (manufactures multi-segment GTs from any sample): FORENSICS_SPI_SPLIT_FAKE=false also chunk the forgery and scatter it FORENSICS_SPI_SPLIT_PROB=0.5 P(split | SPI applied) — keep <1 so atomic long-forgery samples still appear FORENSICS_SPI_FAKE_CHUNK_S=2.5 forgery-chunk size in seconds FORENSICS_SPI_FAKE_MIN_S=2.0 lower bound on a fake chunk (avoid sub-2s fragments whose IoU reward is pure noise) Off by default; existing stage1 / v1 / v5 / v7* / v10 / v12 runs are unaffected. Algorithm per sample: 1. Parse GT intervals -> frame ranges + safety band -> "forgery atoms" (each atom keeps its frame order intact). 2. Slice the gap regions into uniform chunks of ~CHUNK_S seconds each. 3. (SPLIT-FAKE) Optionally slice each forgery atom into ~FAKE_CHUNK_S sub- chunks so a single long forgery becomes several fake atoms. 4. Randomly permute all atoms (forgery atoms + real chunks together, so the forgery's TEMPORAL POSITION varies too — this is what makes the augmentation force position-invariance). Under SPLIT-FAKE the permutation is constrained so no two fake chunks are adjacent (>=1 real chunk between them), preventing them from re-fusing into one segment; the scattered fake chunks become a 3-4 segment GT. 5. Concatenate frames in the permuted order; recompute each resulting forgery piece's new (start, end) timestamps based on where it landed. Fallback (returns data unchanged, or to atomic SPI) when: - No cached preprocessed features for this sample - Video has fewer than 16 frames (too short for meaningful permutation) - Solution format is unparseable - After atom construction, fewer than 2 real chunks exist - (SPLIT-FAKE) too few real chunks to separate the fake chunks -> falls back to atomic (un-split) SPI for that sample - Random permutation happens to equal identity """ from __future__ import annotations import os import random from typing import Any, Dict, List, Tuple import torch def _env_bool(name: str, default: str = "false") -> bool: return os.getenv(name, default).lower() in ("true", "1", "yes") def _normalise_intervals(solution: Any) -> List[Tuple[float, float]] | None: """Coerce HF Dataset's variable solution format -> list of (s, e).""" if solution is None: return None if isinstance(solution, list) and solution: first = solution[0] if isinstance(first, (list, tuple)) and len(first) == 2 \ and isinstance(first[0], (int, float)): return [(float(s), float(e)) for s, e in solution] if isinstance(first, (int, float)) and len(solution) == 2: return [(float(solution[0]), float(solution[1]))] return None def _split_forgery_atoms( forgery_atoms: List[Dict[str, Any]], fake_chunk_f: int, fake_min_f: int, fps: float, ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: """Slice each forgery atom's frame span into ~fake_chunk_f sub-pieces. Returns (fake_atoms, extra_real_atoms). A piece that turns out to contain no forgery frames (e.g. pure safety-band margin) is returned as a real atom so it can be shuffled like any other real chunk. Each fake atom carries explicit forgery frame bounds (gfs, gfe) for GT remapping. Atoms whose span is < 2*fake_chunk_f are kept whole (single fake atom). """ fake_atoms: List[Dict[str, Any]] = [] extra_real: List[Dict[str, Any]] = [] for atom in forgery_atoms: fs, fe = atom["fs"], atom["fe"] # True forgery span inside this (safety-banded) atom. Split only this # span; hand the safety margins back as real so every fake chunk is # pure forgery and its GT == the chunk (respects fake_min_f exactly). ufs = max(fs, min(int(round(s * fps)) for s, _ in atom["orig"])) ufe = min(fe, max(int(round(e * fps)) for _, e in atom["orig"])) if ufs > fs: extra_real.append({"type": "r", "fs": fs, "fe": ufs - 1}) if ufe < fe: extra_real.append({"type": "r", "fs": ufe + 1, "fe": fe}) if (ufe - ufs + 1) < 2 * fake_chunk_f: # too short to split — keep whole fake_atoms.append({"type": "f", "fs": ufs, "fe": ufe, "gfs": ufs, "gfe": ufe}) continue pieces = [[st, min(st + fake_chunk_f - 1, ufe)] for st in range(ufs, ufe + 1, fake_chunk_f)] if len(pieces) >= 2 and (pieces[-1][1] - pieces[-1][0] + 1) < fake_min_f: pieces[-2][1] = pieces[-1][1] pieces.pop() for a, b in pieces: fake_atoms.append({"type": "f", "fs": a, "fe": b, "gfs": a, "gfe": b}) return fake_atoms, extra_real def _interleave_no_adjacent_fakes( fakes: List[Dict[str, Any]], reals: List[Dict[str, Any]], ) -> List[Dict[str, Any]] | None: """Random arrangement with no two fake atoms adjacent (>=1 real between). Places shuffled reals in a row (nR+1 slots around them) and drops one shuffled fake into nF distinct slots. Requires nF <= nR + 1, else None. """ nF, nR = len(fakes), len(reals) if nF > nR + 1: return None F, R = fakes[:], reals[:] random.shuffle(F) random.shuffle(R) chosen = set(random.sample(range(nR + 1), nF)) seq: List[Dict[str, Any]] = [] fi = 0 for i in range(nR + 1): if i in chosen: seq.append(F[fi]); fi += 1 if i < nR: seq.append(R[i]) return seq def maybe_apply_spi(data: Dict[str, Any]) -> Dict[str, Any]: """Apply SPI augmentation in-place if env vars enable it and sample is eligible. Modifies (when augmenting): data["video_inputs"] -> [[shuffled_tensor]] data["solution"] -> remapped list of (s, e) data["_spi"] -> True (marker for downstream logging) """ if not _env_bool("FORENSICS_SPI_AUG"): return data prob = float(os.getenv("FORENSICS_SPI_PROB", "0.5")) if random.random() > prob: return data use_pp = data.get("use_preprocessed", [False]) if not (use_pp and use_pp[0]): return data chunk_s = float(os.getenv("FORENSICS_SPI_CHUNK_S", "2.5")) safety_s = float(os.getenv("FORENSICS_SPI_SAFETY_S", "0.5")) split_fake = _env_bool("FORENSICS_SPI_SPLIT_FAKE") \ and random.random() < float(os.getenv("FORENSICS_SPI_SPLIT_PROB", "0.5")) fake_chunk_s = float(os.getenv("FORENSICS_SPI_FAKE_CHUNK_S", "2.5")) fake_min_s = float(os.getenv("FORENSICS_SPI_FAKE_MIN_S", "2.0")) try: video_list = data["video_inputs"][0] # cached as list-of-one-tensor if not isinstance(video_list, list) or not video_list: return data video = video_list[0] if not torch.is_tensor(video) or video.dim() < 3: return data T = video.shape[0] if T < 16: return data kwargs = data["video_kwargs"][0] fps = float(kwargs["fps"][0]) if fps <= 0: return data intervals = _normalise_intervals(data.get("solution")) if not intervals: return data except (KeyError, IndexError, TypeError, ValueError): return data # Build forgery atoms: each GT interval -> frame range + safety band, then # merge overlapping atoms. safety_f = max(1, int(round(safety_s * fps))) raw = [] for s, e in intervals: fs = max(0, int(round(s * fps)) - safety_f) fe = min(T - 1, int(round(e * fps)) + safety_f - 1) if fe < fs: fe = fs raw.append((fs, fe, s, e)) raw.sort() forgery_atoms: List[Dict[str, Any]] = [] for fs, fe, s_orig, e_orig in raw: if forgery_atoms and fs <= forgery_atoms[-1]["fe"] + 1: forgery_atoms[-1]["fe"] = max(forgery_atoms[-1]["fe"], fe) forgery_atoms[-1]["orig"].append((s_orig, e_orig)) else: forgery_atoms.append({"fs": fs, "fe": fe, "orig": [(s_orig, e_orig)]}) # Real ranges = gaps between forgery atoms (and head/tail). real_ranges: List[Tuple[int, int]] = [] cur = 0 for atom in forgery_atoms: if atom["fs"] > cur: real_ranges.append((cur, atom["fs"] - 1)) cur = atom["fe"] + 1 if cur <= T - 1: real_ranges.append((cur, T - 1)) # Slice real ranges into chunks. chunk_f = max(2, int(round(chunk_s * fps))) real_chunks: List[Tuple[int, int]] = [] for rs, re in real_ranges: i = rs while i <= re: j = min(i + chunk_f - 1, re) real_chunks.append((i, j)) i = j + 1 if len(real_chunks) < 2: return data # not enough free material to permute # Build fake atoms (one per forgery atom, or several when splitting) and # the real atom pool. real_atoms: List[Dict[str, Any]] = [{"type": "r", "fs": fs, "fe": fe} for fs, fe in real_chunks] fake_atoms: List[Dict[str, Any]] if split_fake: fake_chunk_f = max(2, int(round(fake_chunk_s * fps))) fake_min_f = max(2, int(round(fake_min_s * fps))) fake_atoms, extra_real = _split_forgery_atoms(forgery_atoms, fake_chunk_f, fake_min_f, fps) real_atoms.extend(extra_real) # Need >=1 real chunk between each adjacent fake pair; else no point # splitting — fall back to atomic SPI for this sample. if len(fake_atoms) < 2 or len(fake_atoms) > len(real_atoms) + 1: split_fake = False if not split_fake: fake_atoms = [{"type": "f", "fs": a["fs"], "fe": a["fe"], "orig": a["orig"]} for a in forgery_atoms] # Permute. if split_fake: order = _interleave_no_adjacent_fakes(fake_atoms, real_atoms) if order is None: return data else: atoms = fake_atoms + real_atoms idx = list(range(len(atoms))) random.shuffle(idx) if idx == list(range(len(atoms))): # reject identity if len(idx) >= 2: idx[0], idx[1] = idx[1], idx[0] else: return data order = [atoms[i] for i in idx] # Reconstruct frame tensor + GT timestamps based on the new order. pieces = [] new_intervals: List[Tuple[float, float]] = [] cursor = 0 # current frame in new video for atom in order: atom_start, atom_end = atom["fs"], atom["fe"] atom_len = atom_end - atom_start + 1 pieces.append(video[atom_start:atom_end + 1]) if atom["type"] == "f": if "gfs" in atom: # split fake chunk: one forgery span per piece spans = [(atom["gfs"], atom["gfe"])] else: # atomic atom: remap each original GT interval spans = [(max(0, int(round(s * fps))), min(T - 1, int(round(e * fps)))) for s, e in atom["orig"]] for fs_o, fe_o in spans: new_s = (cursor + fs_o - atom_start) / fps new_e = (cursor + fe_o - atom_start) / fps if new_e <= new_s: new_e = new_s + 1.0 / fps new_intervals.append((new_s, new_e)) cursor += atom_len if not new_intervals: return data new_intervals.sort() new_video = torch.cat(pieces, dim=0).contiguous() # Sanity check: shape preserved. if new_video.shape[0] != T: return data data["video_inputs"] = [[new_video]] data["solution"] = new_intervals data["_spi"] = [True] data["_spi_split"] = [bool(split_fake)] return data