| """ActivityForensics dataset loader for GRPO training. |
| |
| Reads annotation files at <annot_dir>/{train,test}@<gen>.txt with format: |
| <video_filename> <duration> <s1=e1+s2=e2+...> |
| |
| Maps generator names to their video directories and produces a HuggingFace |
| Dataset with the schema expected by grpo_forensics.py / Qwen2VL trainers. |
| """ |
| import json |
| import os |
| import random |
| from typing import Optional |
|
|
| import torch |
| from datasets import Dataset, DatasetDict |
| from tqdm import tqdm |
|
|
|
|
# Short generator name -> name of the sub-directory (joined onto the video
# root by build_examples) that holds that generator's video files.
GENERATOR_TO_DIR = dict(
    [
        ("vidu", "01_vidu"),
        ("wan", "02_wan"),
        ("fcvg", "03_fcvg"),
        ("scifi", "04_scifi"),
        ("ltx", "05_ltx"),
        ("vace-1.3B", "06_vace-1.3B"),
    ]
)

# Default generator lists used by load_forensics_dataset when the caller does
# not pass its own. The test list is the train list plus "vidu".
TRAIN_GENERATORS = ["wan", "scifi", "fcvg", "vace-1.3B", "ltx"]
TEST_GENERATORS = [*TRAIN_GENERATORS, "vidu"]

# Value stored in the "problem" field of every example.
FORENSICS_PROBLEM = "forensics"
|
|
|
|
def parse_annotation_line(line: str):
    """Parse one line of annot/*.txt -> (filename, duration, [(s, e), ...]).

    The line holds three whitespace-separated fields:
    ``<video_filename> <duration> <s1=e1+s2=e2+...>``.

    Args:
        line: One raw annotation line. Leading/trailing whitespace is ignored
            and the fields may be separated by any run of spaces or tabs.

    Returns:
        A ``(filename, duration, segments)`` tuple. ``duration`` is a float and
        ``segments`` is a list of ``(start, end)`` float tuples in file order.
        Each ``end`` is clamped to ``duration``; ``start`` is left as written.

    Raises:
        ValueError: If the line does not have exactly three fields, a number
            cannot be parsed, or a segment is not of the form ``start=end``.
    """
    # split() without an argument tolerates tabs and repeated spaces, which
    # split(" ") would turn into empty or merged fields.
    fields = line.split()
    if len(fields) != 3:
        # Raised explicitly (not via assert) so the check survives `python -O`
        # and matches the ValueError the numeric parsing below already raises.
        raise ValueError(f"Bad line: {line!r}")
    filename, duration_str, time_str = fields
    duration = float(duration_str)

    segments = []
    for seg in time_str.split("+"):
        s_str, e_str = seg.split("=")
        # An annotated end time may overshoot the clip; cap it at the duration.
        segments.append((float(s_str), min(float(e_str), duration)))
    return filename, duration, segments
|
|
|
|
def build_examples(
    annot_dir: str,
    video_root: str,
    generators: list,
    split_prefix: str,
    preprocessed_data_path: Optional[str] = None,
    require_video_exists: bool = True,
):
    """Collect example dicts for one split across the given generators.

    For every generator, ``<annot_dir>/<split_prefix>@<gen>.txt`` is read line
    by line (blank lines ignored) and each row becomes one or more dicts with
    the keys ``problem``, ``solution``, ``video_path``, ``durations``,
    ``generator`` and ``preprocessed_path``. A missing annotation file is
    reported with a warning and skipped. Rows whose video file is absent are
    dropped (and counted) when ``require_video_exists`` is true.

    ``preprocessed_path`` is ``""`` unless ``preprocessed_data_path`` is given,
    in which case it is
    ``<preprocessed_data_path>/<split_prefix>/<gen>/<filename without ext>``.

    Env var FORENSICS_SPLIT_SINGLE_SPAN (default false): when it is one of
    "true"/"1"/"yes" (case-insensitive) AND split_prefix == "train", a row
    with several segments is expanded into one example per segment, each with
    a length-1 ``solution`` list. Other splits are never expanded.

    Returns:
        The list of example dicts, in generator order then file order.

    Raises:
        KeyError: If a generator name is not a key of GENERATOR_TO_DIR.
    """
    env_flag = os.getenv("FORENSICS_SPLIT_SINGLE_SPAN", "false").lower()
    split_single_span = split_prefix == "train" and env_flag in ("true", "1", "yes")

    examples = []
    skipped = 0
    expanded = 0

    for gen in generators:
        annot_file = os.path.join(annot_dir, f"{split_prefix}@{gen}.txt")
        # Looked up before the existence check, so an unknown generator name
        # fails even when its annotation file is also missing.
        video_dir = os.path.join(video_root, GENERATOR_TO_DIR[gen])
        if not os.path.exists(annot_file):
            print(f"[warn] missing annot file: {annot_file}")
            continue

        with open(annot_file, "r") as f:
            for raw in f:
                row = raw.strip()
                if not row:
                    continue

                filename, duration, segments = parse_annotation_line(row)
                video_path = os.path.join(video_dir, filename)
                if require_video_exists and not os.path.exists(video_path):
                    # Counted once per annotation row, before any expansion.
                    skipped += 1
                    continue

                # The cache location depends only on the row, so every example
                # expanded from the same row shares it.
                if preprocessed_data_path:
                    cache_path = os.path.join(
                        preprocessed_data_path,
                        split_prefix,
                        gen,
                        os.path.splitext(filename)[0],
                    )
                else:
                    cache_path = ""

                if split_single_span:
                    solutions = [[seg] for seg in segments]
                    expanded += max(len(segments) - 1, 0)
                else:
                    solutions = [segments]

                examples.extend(
                    {
                        "problem": FORENSICS_PROBLEM,
                        "solution": sol,
                        "video_path": video_path,
                        "durations": duration,
                        "generator": gen,
                        "preprocessed_path": cache_path,
                    }
                    for sol in solutions
                )

    if skipped:
        print(f"[forensics] {split_prefix}: skipped {skipped} examples with missing videos")
    if split_single_span:
        print(f"[forensics] {split_prefix}: SPLIT_SINGLE_SPAN on, +{expanded} rows from multi-seg expansion")
    return examples
|
|
|
|
def make_dataset(examples: list, shuffle: bool = True, seed: int = 42):
    """Build a HuggingFace ``Dataset`` from example dicts, optionally shuffled.

    Args:
        examples: List of example dicts. It is not modified.
        shuffle: When true, the rows are put in a seeded pseudo-random order
            before the dataset is built.
        seed: Seed for the private ``random.Random`` used for shuffling, so the
            order is reproducible and the global ``random`` state is untouched.

    Returns:
        The result of ``Dataset.from_list`` on the (possibly shuffled) rows.
    """
    if shuffle:
        # Shuffle a shallow copy so the caller's list keeps its order. The
        # same Random(seed).shuffle call is used, so the dataset order is
        # identical to shuffling the original list in place.
        examples = list(examples)
        random.Random(seed).shuffle(examples)
    return Dataset.from_list(examples)
|
|
|
|
def load_forensics_dataset(
    annot_dir: str,
    video_root: str,
    preprocessed_data_path: Optional[str] = None,
    train_generators: Optional[list] = None,
    test_generators: Optional[list] = None,
    shuffle_train: bool = True,
    seed: int = 42,
) -> DatasetDict:
    """Build the train/eval ``DatasetDict`` for the forensics task.

    Args:
        annot_dir: Directory holding the ``train@<gen>.txt`` / ``test@<gen>.txt``
            annotation files.
        video_root: Root directory that contains the per-generator video
            directories.
        preprocessed_data_path: Optional root of a preprocessed cache; when
            given, each example carries a ``preprocessed_path`` below it.
        train_generators: Generators for the train split. ``None`` or an empty
            list falls back to ``TRAIN_GENERATORS``.
        test_generators: Generators for the eval split. ``None`` or an empty
            list falls back to ``TEST_GENERATORS``.
        shuffle_train: Whether to shuffle the train rows (eval is never
            shuffled).
        seed: Seed used for the train shuffle.

    Returns:
        ``DatasetDict`` with keys ``"train"`` and ``"eval"``. Both datasets get
        a per-instance ``__getitem__`` attribute that attaches cached video
        tensors (and, for train only, the FBR/SPI augmentations).
    """
    train_generators = train_generators or TRAIN_GENERATORS
    test_generators = test_generators or TEST_GENERATORS

    train_examples = build_examples(
        annot_dir, video_root, train_generators, "train", preprocessed_data_path
    )
    eval_examples = build_examples(
        annot_dir, video_root, test_generators, "test", preprocessed_data_path
    )

    train_dataset = make_dataset(train_examples, shuffle=shuffle_train, seed=seed)
    eval_dataset = make_dataset(eval_examples, shuffle=False)

    print(f"[forensics] train: {len(train_dataset)} eval: {len(eval_dataset)}")
    if len(train_dataset) > 0:
        print(f"[forensics] sample train example: {train_dataset[0]}")

    # Imported here rather than at module top, so the augmentation modules are
    # only required once a dataset is actually built. These are relative
    # imports, so this function only works when the module is loaded as part
    # of its package.
    from .spi_aug import maybe_apply_spi
    from .fbr_aug import maybe_apply_fbr

    def _build_getitem(apply_augs: bool):
        """Return a ``__getitem__`` replacement that adds cached video inputs.

        When ``apply_augs`` is true the returned function also passes the
        record through ``maybe_apply_fbr`` and then ``maybe_apply_spi``.
        """
        def __getitem__(self, idx):
            example = Dataset.__getitem__(self, idx)
            data = {k: v for k, v in example.items()}
            pp = example.get("preprocessed_path", "")
            # The path may arrive as a list (presumably when idx selects a
            # batch and columns come back as lists - TODO confirm); only the
            # first entry is used.
            if isinstance(pp, list):
                pp = pp[0] if pp else ""
            if pp:
                try:
                    # NOTE(review): torch.load unpickles the file; only point
                    # preprocessed_path at cache directories you trust.
                    data["video_inputs"] = [torch.load(os.path.join(pp, "video_inputs.pt"))]
                    with open(os.path.join(pp, "video_kwargs.json"), "r") as f:
                        data["video_kwargs"] = [json.load(f)]
                    data["use_preprocessed"] = [True]
                except Exception as e:
                    # Best effort: any failure to read the cache is logged and
                    # the record is marked as not preprocessed. If the tensor
                    # loaded but the JSON did not, "video_inputs" stays set.
                    print(f"[warn] failed to load cache {pp}: {e}")
                    data["use_preprocessed"] = [False]
            else:
                data["use_preprocessed"] = [False]
            if apply_augs:
                # FBR runs before SPI; both helpers are defined in sibling
                # modules and decide internally whether to change the record.
                data = maybe_apply_fbr(data)
                data = maybe_apply_spi(data)
            return data
        return __getitem__

    # NOTE(review): these set __getitem__ on the instances, not on the class.
    # Python's subscript syntax (dataset[i]) looks special methods up on the
    # type, so it does not go through these attributes; only an explicit
    # dataset.__getitem__(...) call does. Presumably the consumer reaches the
    # patched method that way - confirm before relying on it.
    train_dataset.__getitem__ = _build_getitem(apply_augs=True).__get__(train_dataset, Dataset)
    eval_dataset.__getitem__ = _build_getitem(apply_augs=False).__get__(eval_dataset, Dataset)

    return DatasetDict({"train": train_dataset, "eval": eval_dataset})
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the datasets from the given directories and print the
    # DatasetDict summary followed by the first few train examples.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--annot_dir", default="/ces/zt/activityforensics/annot")
    parser.add_argument("--video_root", default="/ces/zt")
    args = parser.parse_args()

    ds = load_forensics_dataset(args.annot_dir, args.video_root)
    print(ds)
    # Show at most three rows; a small or empty train split (e.g. when the
    # default paths do not exist) must not end the run with an IndexError.
    for i in range(min(3, len(ds["train"]))):
        print(ds["train"][i])
|
|