"""ActivityForensics dataset loader for GRPO training.

Reads annotation files at ``<annot_dir>/{train,test}@<generator>.txt``, one
sample per line, in the format::

    <filename> <duration> <s1>=<e1>[+<s2>=<e2>...]

where each ``s=e`` pair is a segment in seconds (segment end times are clamped
to the video duration). Maps generator names to their video directories and
produces a HuggingFace Dataset with the schema expected by grpo_forensics.py /
Qwen2VL trainers.
"""

import json
import os
import random
from typing import Optional

import torch
from datasets import Dataset, DatasetDict
from tqdm import tqdm

# Generator name -> sub-directory of ``video_root`` holding that generator's videos.
GENERATOR_TO_DIR = {
    "vidu": "01_vidu",
    "wan": "02_wan",
    "fcvg": "03_fcvg",
    "scifi": "04_scifi",
    "ltx": "05_ltx",
    "vace-1.3B": "06_vace-1.3B",
}

# "vidu" appears only in the test list, i.e. it is held out from training.
TRAIN_GENERATORS = ["wan", "scifi", "fcvg", "vace-1.3B", "ltx"]
TEST_GENERATORS = ["wan", "scifi", "fcvg", "vace-1.3B", "ltx", "vidu"]

FORENSICS_PROBLEM = "forensics"  # placeholder; QUESTION_TEMPLATE is fixed.


def parse_annotation_line(line: str):
    """Parse one line of annot/*.txt -> (filename, duration, [(s, e), ...]).

    Args:
        line: ``"<filename> <duration> <s>=<e>[+<s>=<e>...]"``. Surrounding
            whitespace is ignored, and fields may be separated by any run of
            whitespace.

    Returns:
        ``(filename, duration, segments)``. ``duration`` is a float, and
        ``segments`` is a non-empty list of ``(start, end)`` float tuples whose
        ``end`` is clamped to ``duration``.

    Raises:
        ValueError: if the line does not have exactly three fields, or a
            segment / number is malformed.
    """
    # split() (not split(" ")) tolerates repeated spaces/tabs and trailing
    # whitespace; an explicit raise survives ``python -O``, unlike assert.
    parts = line.split()
    if len(parts) != 3:
        raise ValueError(f"Bad line: {line!r}")
    filename, duration_str, time_str = parts
    duration = float(duration_str)
    segments = []
    for seg in time_str.split("+"):
        s_str, e_str = seg.split("=")
        s, e = float(s_str), float(e_str)
        segments.append((s, min(e, duration)))
    return filename, duration, segments


def build_examples(
    annot_dir: str,
    video_root: str,
    generators: list,
    split_prefix: str,
    preprocessed_data_path: Optional[str] = None,
    require_video_exists: bool = True,
):
    """Build a list of example dicts for the given split and generator list.

    Env var FORENSICS_SPLIT_SINGLE_SPAN (default false): when true AND
    split_prefix == "train", expand each multi-segment row into one example
    per segment (solution becomes a length-1 list). Used for the strict
    TempSamp-R1 / Charades-style baseline. Eval split is never expanded so
    grounding metrics remain comparable across runs.

    Args:
        annot_dir: Directory containing ``<split_prefix>@<generator>.txt``.
        video_root: Root containing the directories from GENERATOR_TO_DIR.
        generators: Generator names (keys of GENERATOR_TO_DIR); an unknown
            name raises KeyError.
        split_prefix: Annotation file prefix, e.g. ``"train"`` or ``"test"``.
        preprocessed_data_path: Optional cache root. When set, each example's
            ``preprocessed_path`` is
            ``<root>/<split_prefix>/<generator>/<filename without extension>``;
            otherwise it is ``""``.
        require_video_exists: Skip (and count) rows whose video file is missing.

    Returns:
        List of dicts with keys ``problem``, ``solution`` (list of
        ``(start, end)`` tuples), ``video_path``, ``durations``, ``generator``,
        ``preprocessed_path``. Missing annotation files are warned about and
        skipped.
    """
    split_single_span = split_prefix == "train" and os.getenv(
        "FORENSICS_SPLIT_SINGLE_SPAN", "false"
    ).lower() in ("true", "1", "yes")

    examples = []
    skipped = 0
    expanded = 0
    for gen in generators:
        annot_file = os.path.join(annot_dir, f"{split_prefix}@{gen}.txt")
        video_dir = os.path.join(video_root, GENERATOR_TO_DIR[gen])
        if not os.path.exists(annot_file):
            print(f"[warn] missing annot file: {annot_file}")
            continue
        with open(annot_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                filename, duration, segments = parse_annotation_line(line)
                video_path = os.path.join(video_dir, filename)
                if require_video_exists and not os.path.exists(video_path):
                    skipped += 1
                    continue

                # Identical for every row emitted from this line, so compute once.
                preprocessed_path = ""
                if preprocessed_data_path:
                    sample_id = os.path.splitext(filename)[0]
                    preprocessed_path = os.path.join(
                        preprocessed_data_path, split_prefix, gen, sample_id
                    )

                if split_single_span:
                    emit_solutions = [[seg] for seg in segments]
                    expanded += len(segments) - 1  # rows added beyond the first
                else:
                    emit_solutions = [segments]

                for sol in emit_solutions:
                    examples.append(
                        {
                            "problem": FORENSICS_PROBLEM,
                            "solution": sol,  # list[(s, e)]
                            "video_path": video_path,
                            "durations": duration,
                            "generator": gen,
                            "preprocessed_path": preprocessed_path,
                        }
                    )
    if skipped:
        print(f"[forensics] {split_prefix}: skipped {skipped} examples with missing videos")
    if split_single_span:
        print(f"[forensics] {split_prefix}: SPLIT_SINGLE_SPAN on, +{expanded} rows from multi-seg expansion")
    return examples


def make_dataset(examples: list, shuffle: bool = True, seed: int = 42):
    """Convert example dicts into a ``datasets.Dataset``.

    When ``shuffle`` is true the rows are shuffled deterministically with a
    private ``random.Random(seed)``. The shuffle runs on a copy, so the
    caller's ``examples`` list is not mutated.
    """
    if shuffle:
        examples = list(examples)
        random.Random(seed).shuffle(examples)
    return Dataset.from_list(examples)


def load_forensics_dataset(
    annot_dir: str,
    video_root: str,
    preprocessed_data_path: Optional[str] = None,
    train_generators: Optional[list] = None,
    test_generators: Optional[list] = None,
    shuffle_train: bool = True,
    seed: int = 42,
) -> DatasetDict:
    """Build the train/eval ``DatasetDict`` for ActivityForensics.

    Args:
        annot_dir: Directory containing ``{train,test}@<generator>.txt``.
        video_root: Root containing the per-generator video directories.
        preprocessed_data_path: Optional cache root. Per-example caches are
            read from ``<preprocessed_path>/video_inputs.pt`` and
            ``<preprocessed_path>/video_kwargs.json``.
        train_generators: Generators for the train split; ``None``/empty
            falls back to TRAIN_GENERATORS.
        test_generators: Generators for the eval split; ``None``/empty falls
            back to TEST_GENERATORS.
        shuffle_train: Shuffle train rows (eval rows are never shuffled).
        seed: Seed for the train shuffle.

    Returns:
        ``DatasetDict`` with ``"train"`` (from ``train@*.txt``) and ``"eval"``
        (from ``test@*.txt``). Each split gets an instance-level
        ``__getitem__`` that adds ``video_inputs`` / ``video_kwargs`` /
        ``use_preprocessed`` from the cache. Only train additionally runs the
        FBR then SPI augmentation hooks.
    """
    train_generators = train_generators or TRAIN_GENERATORS
    test_generators = test_generators or TEST_GENERATORS

    train_examples = build_examples(
        annot_dir, video_root, train_generators, "train", preprocessed_data_path
    )
    eval_examples = build_examples(
        annot_dir, video_root, test_generators, "test", preprocessed_data_path
    )

    train_dataset = make_dataset(train_examples, shuffle=shuffle_train, seed=seed)
    eval_dataset = make_dataset(eval_examples, shuffle=False)

    print(f"[forensics] train: {len(train_dataset)} eval: {len(eval_dataset)}")
    if len(train_dataset) > 0:
        print(f"[forensics] sample train example: {train_dataset[0]}")

    # SPI / FBR augmentations (off unless their env vars are set). Eval set
    # never augmented; only train. Order: FBR (full reversal) then SPI
    # (chunk permutation) — composable, GT remapping flows through both.
    from .spi_aug import maybe_apply_spi
    from .fbr_aug import maybe_apply_fbr

    def _build_getitem(apply_augs: bool):
        """Return an unbound ``__getitem__`` adding cache fields (and augs)."""

        def __getitem__(self, idx):
            example = Dataset.__getitem__(self, idx)
            data = dict(example)
            pp = example.get("preprocessed_path", "")
            # A list here means ``idx`` selected several rows (column -> list
            # of values); only the first row's cache is loaded.
            # NOTE(review): that is only consistent for single-row selections
            # — confirm callers never pass multi-row keys.
            if isinstance(pp, list):
                pp = pp[0] if pp else ""
            if pp:
                try:
                    # NOTE(review): torch.load on cache files — assumes the cache
                    # directory is trusted.
                    data["video_inputs"] = [torch.load(os.path.join(pp, "video_inputs.pt"))]
                    with open(os.path.join(pp, "video_kwargs.json"), "r", encoding="utf-8") as f:
                        data["video_kwargs"] = [json.load(f)]
                    data["use_preprocessed"] = [True]
                except Exception as e:
                    # Best-effort cache: mark as not preprocessed and drop any
                    # half-loaded fields so they cannot contradict the flag.
                    print(f"[warn] failed to load cache {pp}: {e}")
                    data.pop("video_inputs", None)
                    data.pop("video_kwargs", None)
                    data["use_preprocessed"] = [False]
            else:
                data["use_preprocessed"] = [False]
            if apply_augs:
                data = maybe_apply_fbr(data)
                data = maybe_apply_spi(data)
            return data

        return __getitem__

    # NOTE(review): ``obj[idx]`` resolves ``__getitem__`` on the type, not the
    # instance, so these overrides only run for code that calls
    # ``dataset.__getitem__(...)`` explicitly (presumably the trainer or
    # ``datasets.Dataset.__getitems__`` path — confirm for the installed
    # ``datasets`` version). Plain ``train_dataset[i]`` bypasses them.
    train_dataset.__getitem__ = _build_getitem(apply_augs=True).__get__(train_dataset, Dataset)
    eval_dataset.__getitem__ = _build_getitem(apply_augs=False).__get__(eval_dataset, Dataset)

    return DatasetDict({"train": train_dataset, "eval": eval_dataset})


if __name__ == "__main__":
    # Smoke test
    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--annot_dir", default="/ces/zt/activityforensics/annot")
    p.add_argument("--video_root", default="/ces/zt")
    args = p.parse_args()
    ds = load_forensics_dataset(args.annot_dir, args.video_root)
    print(ds)
    # Indexing goes through the class-level Dataset.__getitem__, so these rows
    # are the raw examples (no cache loading / augmentation).
    for i in range(min(3, len(ds["train"]))):
        print(ds["train"][i])