forensics-grpo / code /src /open_r1 /data_loader.py
sdzt's picture
Add source code
33569f9 verified
Raw
History Blame Contribute Delete
7.4 kB
"""ActivityForensics dataset loader for GRPO training.
Reads annotation files at <annot_dir>/{train,test}@<gen>.txt with format:
<video_filename> <duration> <s1=e1+s2=e2+...>
Maps generator names to their video directories and produces a HuggingFace
Dataset with the schema expected by grpo_forensics.py / Qwen2VL trainers.
"""
import json
import os
import random
from typing import Optional
import torch
from datasets import Dataset, DatasetDict
from tqdm import tqdm
# Generator name -> name of the sub-directory (joined onto the video root by
# build_examples) that holds that generator's videos.
GENERATOR_TO_DIR = {
    "vidu": "01_vidu",
    "wan": "02_wan",
    "fcvg": "03_fcvg",
    "scifi": "04_scifi",
    "ltx": "05_ltx",
    "vace-1.3B": "06_vace-1.3B",
}

# Default generator lists used by load_forensics_dataset when the caller
# passes none. The test list is the train list plus "vidu", which appears in
# the test list only.
TRAIN_GENERATORS = ["wan", "scifi", "fcvg", "vace-1.3B", "ltx"]
TEST_GENERATORS = [*TRAIN_GENERATORS, "vidu"]

# Value stored in every example's "problem" field.
FORENSICS_PROBLEM = "forensics"  # placeholder; QUESTION_TEMPLATE is fixed.
def parse_annotation_line(line: str):
    """Parse one line of annot/*.txt -> (filename, duration, [(s, e), ...]).

    A line holds exactly three whitespace-separated fields::

        <video_filename> <duration> <s1=e1+s2=e2+...>

    Args:
        line: One annotation line. Leading/trailing whitespace (including the
            trailing newline) is ignored, and fields may be separated by any
            run of spaces or tabs.

    Returns:
        ``(filename, duration, segments)`` where ``duration`` is a float and
        ``segments`` is a list of ``(start, end)`` float tuples in file order.
        Each end is clamped to ``duration``; starts are kept as written.

    Raises:
        ValueError: If the line does not have exactly three fields, a number
            cannot be parsed, or a segment is not of the form ``start=end``.
    """
    # split() without an argument tolerates tabs and repeated spaces, which
    # split(" ") would turn into spurious empty fields.
    parts = line.split()
    if len(parts) != 3:
        # Raise explicitly instead of asserting: asserts vanish under ``-O``.
        raise ValueError(f"Bad line: {line!r}")
    filename, duration_str, time_str = parts

    try:
        duration = float(duration_str)
        segments = []
        for seg in time_str.split("+"):
            s_str, e_str = seg.split("=")
            start, end = float(s_str), float(e_str)
            # Annotated ends may overshoot the clip; cap them at its duration.
            segments.append((start, min(end, duration)))
    except ValueError as err:
        # Re-raise with the offending line so bad rows are easy to locate.
        raise ValueError(f"Bad line: {line!r}") from err

    return filename, duration, segments
def build_examples(
    annot_dir: str,
    video_root: str,
    generators: list,
    split_prefix: str,
    preprocessed_data_path: Optional[str] = None,
    require_video_exists: bool = True,
):
    """Build a list of example dicts for the given split and generator list.

    For every generator, reads ``<annot_dir>/<split_prefix>@<gen>.txt`` and
    turns each non-blank line into one or more example dicts with the keys
    ``problem``, ``solution``, ``video_path``, ``durations``, ``generator``
    and ``preprocessed_path``.

    Env var FORENSICS_SPLIT_SINGLE_SPAN (default false): when true AND
    split_prefix == "train", expand each multi-segment row into one example
    per segment (solution becomes a length-1 list). Used for the strict
    TempSamp-R1 / Charades-style baseline. Eval split is never expanded so
    grounding metrics remain comparable across runs.

    Args:
        annot_dir: Directory containing the annotation ``.txt`` files.
        video_root: Root directory; each generator's videos live in the
            sub-directory given by ``GENERATOR_TO_DIR``.
        generators: Generator names to include (keys of ``GENERATOR_TO_DIR``).
        split_prefix: Annotation file prefix, e.g. ``"train"`` or ``"test"``.
        preprocessed_data_path: If set, each example's ``preprocessed_path``
            is ``<preprocessed_data_path>/<split_prefix>/<gen>/<stem>``;
            otherwise it is the empty string.
        require_video_exists: When true, rows whose video file is missing on
            disk are skipped (and counted in the summary message).

    Returns:
        A list of example dicts, in generator order then file order.

    Raises:
        KeyError: If a generator name is not in ``GENERATOR_TO_DIR``.
    """
    split_single_span = (
        split_prefix == "train"
        and os.getenv("FORENSICS_SPLIT_SINGLE_SPAN", "false").lower() in ("true", "1", "yes")
    )
    examples = []
    skipped = 0
    expanded = 0
    for gen in generators:
        try:
            gen_subdir = GENERATOR_TO_DIR[gen]
        except KeyError:
            # Same exception type as a plain lookup, but say what is valid.
            raise KeyError(
                f"unknown generator {gen!r}; expected one of {sorted(GENERATOR_TO_DIR)}"
            ) from None
        annot_file = os.path.join(annot_dir, f"{split_prefix}@{gen}.txt")
        video_dir = os.path.join(video_root, gen_subdir)
        if not os.path.exists(annot_file):
            print(f"[warn] missing annot file: {annot_file}")
            continue
        # Explicit encoding so parsing does not depend on the host locale.
        with open(annot_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                filename, duration, segments = parse_annotation_line(line)
                video_path = os.path.join(video_dir, filename)
                if require_video_exists and not os.path.exists(video_path):
                    skipped += 1
                    continue

                # The cache location depends only on the video, so compute it
                # once per line rather than once per expanded segment.
                preprocessed_path = ""
                if preprocessed_data_path:
                    sample_id = os.path.splitext(filename)[0]
                    preprocessed_path = os.path.join(
                        preprocessed_data_path, split_prefix, gen, sample_id
                    )

                if split_single_span:
                    # One example per segment; count the rows this adds.
                    emit_solutions = [[seg] for seg in segments]
                    if len(segments) > 1:
                        expanded += len(segments) - 1
                else:
                    emit_solutions = [segments]

                examples.extend(
                    {
                        "problem": FORENSICS_PROBLEM,
                        "solution": sol,  # list[(s, e)]
                        "video_path": video_path,
                        "durations": duration,
                        "generator": gen,
                        "preprocessed_path": preprocessed_path,
                    }
                    for sol in emit_solutions
                )
    if skipped:
        print(f"[forensics] {split_prefix}: skipped {skipped} examples with missing videos")
    if split_single_span:
        print(f"[forensics] {split_prefix}: SPLIT_SINGLE_SPAN on, +{expanded} rows from multi-seg expansion")
    return examples
def make_dataset(examples: list, shuffle: bool = True, seed: int = 42):
    """Turn a list of example dicts into a ``Dataset``, optionally shuffled.

    Args:
        examples: Example dicts to store. The list is not modified.
        shuffle: When true, the rows are put in a pseudo-random order that is
            fully determined by ``seed``.
        seed: Seed for the private ``random.Random`` used for shuffling.

    Returns:
        The result of ``Dataset.from_list`` on the (possibly shuffled) rows.
    """
    if shuffle:
        # Shuffle a copy so the caller's list keeps its order; the resulting
        # permutation is the same as shuffling the original in place.
        examples = list(examples)
        random.Random(seed).shuffle(examples)
    return Dataset.from_list(examples)
def load_forensics_dataset(
    annot_dir: str,
    video_root: str,
    preprocessed_data_path: Optional[str] = None,
    train_generators: Optional[list] = None,
    test_generators: Optional[list] = None,
    shuffle_train: bool = True,
    seed: int = 42,
) -> DatasetDict:
    """Build the train/eval ``DatasetDict`` for forensics training.

    Args:
        annot_dir: Directory holding the ``{train,test}@<gen>.txt`` files.
        video_root: Root directory containing the per-generator video dirs.
        preprocessed_data_path: Optional root of cached preprocessed inputs;
            forwarded to ``build_examples``.
        train_generators: Generators for the train split; a falsy value
            (``None`` or an empty list) selects ``TRAIN_GENERATORS``.
        test_generators: Generators for the eval split; a falsy value selects
            ``TEST_GENERATORS``.
        shuffle_train: Whether to shuffle the train rows (eval is never
            shuffled).
        seed: Shuffle seed for the train split.

    Returns:
        ``DatasetDict`` with keys ``"train"`` and ``"eval"``. Both datasets
        carry an instance-level ``__getitem__`` that attaches cached video
        inputs when available; the train one also runs the augmentations.
    """
    train_generators = train_generators or TRAIN_GENERATORS
    test_generators = test_generators or TEST_GENERATORS

    train_examples = build_examples(
        annot_dir, video_root, train_generators, "train", preprocessed_data_path
    )
    eval_examples = build_examples(
        annot_dir, video_root, test_generators, "test", preprocessed_data_path
    )

    train_dataset = make_dataset(train_examples, shuffle=shuffle_train, seed=seed)
    eval_dataset = make_dataset(eval_examples, shuffle=False)

    print(f"[forensics] train: {len(train_dataset)} eval: {len(eval_dataset)}")
    if len(train_dataset) > 0:
        print(f"[forensics] sample train example: {train_dataset[0]}")

    # SPI / FBR augmentations (off unless their env vars are set). Eval set
    # never augmented; only train. Order: FBR (full reversal) then SPI
    # (chunk permutation) — composable, GT remapping flows through both.
    # Imported here rather than at module top; being relative imports they
    # need this module to be loaded as part of its package.
    from .spi_aug import maybe_apply_spi
    from .fbr_aug import maybe_apply_fbr

    def _build_getitem(apply_augs: bool):
        """Return a ``__getitem__`` replacement; ``apply_augs`` toggles FBR/SPI."""

        def __getitem__(self, idx):
            example = Dataset.__getitem__(self, idx)
            data = dict(example)

            # The column may arrive as a list (batched access) or a scalar.
            pp = example.get("preprocessed_path", "")
            if isinstance(pp, list):
                pp = pp[0] if pp else ""

            use_preprocessed = False
            if pp:
                try:
                    # NOTE(security): torch.load unpickles the file; only
                    # point preprocessed_data_path at trusted caches.
                    video_inputs = torch.load(os.path.join(pp, "video_inputs.pt"))
                    with open(
                        os.path.join(pp, "video_kwargs.json"), "r", encoding="utf-8"
                    ) as f:
                        video_kwargs = json.load(f)
                except Exception as e:
                    # Best effort: a broken cache entry falls back to the
                    # non-preprocessed path instead of aborting the run.
                    print(f"[warn] failed to load cache {pp}: {e}")
                else:
                    # Assign only after BOTH files loaded, so a half-read
                    # cache never leaves video_inputs without video_kwargs.
                    data["video_inputs"] = [video_inputs]
                    data["video_kwargs"] = [video_kwargs]
                    use_preprocessed = True
            data["use_preprocessed"] = [use_preprocessed]

            if apply_augs:
                data = maybe_apply_fbr(data)
                data = maybe_apply_spi(data)
            return data

        return __getitem__

    # NOTE(review): these are bound on the *instances*. Python's implicit
    # special-method lookup (``ds[i]``) consults the type, so plain
    # subscripting bypasses them; only explicit ``ds.__getitem__(...)`` calls
    # reach the functions above. Presumably the data-loading path calls it
    # that way with a list of indices (which would explain the list-wrapped
    # values) — confirm against the trainer / ``datasets`` version in use.
    train_dataset.__getitem__ = _build_getitem(apply_augs=True).__get__(train_dataset, Dataset)
    eval_dataset.__getitem__ = _build_getitem(apply_augs=False).__get__(eval_dataset, Dataset)

    return DatasetDict({"train": train_dataset, "eval": eval_dataset})
if __name__ == "__main__":
    # Smoke test: build both splits and print the first few train rows.
    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--annot_dir", default="/ces/zt/activityforensics/annot")
    p.add_argument("--video_root", default="/ces/zt")
    args = p.parse_args()

    ds = load_forensics_dataset(args.annot_dir, args.video_root)
    print(ds)
    # Show at most three rows; a fixed range(3) raises IndexError when the
    # train split is empty or has fewer than three examples.
    for i in range(min(3, len(ds["train"]))):
        print(ds["train"][i])