File size: 7,402 Bytes
"""ActivityForensics dataset loader for GRPO training.
Reads annotation files at <annot_dir>/{train,test}@<gen>.txt, one sample per line:
<video_filename> <duration> <s1=e1+s2=e2+...>
where the last field joins one or more start=end segments (the example's
solution) with "+".
Maps generator names to their video directories and produces a HuggingFace
Dataset with the schema expected by grpo_forensics.py / Qwen2VL trainers.
"""
import json
import os
import random
from typing import Optional
import torch
from datasets import Dataset, DatasetDict
from tqdm import tqdm
# Generator name -> sub-directory of ``video_root`` holding that generator's
# videos. Directories are numbered (two digits, starting at 01) in this order.
GENERATOR_TO_DIR = {
    name: f"{idx:02d}_{name}"
    for idx, name in enumerate(
        ("vidu", "wan", "fcvg", "scifi", "ltx", "vace-1.3B"), start=1
    )
}

# Generators whose annotations form the train split. "vidu" is absent here and
# appears only in TEST_GENERATORS (presumably an unseen generator -- confirm).
TRAIN_GENERATORS = [
    "wan",
    "scifi",
    "fcvg",
    "vace-1.3B",
    "ltx",
]

# Generators whose annotations form the eval split.
TEST_GENERATORS = [
    "wan",
    "scifi",
    "fcvg",
    "vace-1.3B",
    "ltx",
    "vidu",
]

# Value stored in every example's "problem" field.
FORENSICS_PROBLEM = "forensics"  # placeholder; QUESTION_TEMPLATE is fixed.
def parse_annotation_line(line: str):
    """Parse one line of annot/*.txt -> (filename, duration, [(s, e), ...]).

    Expected format (fields separated by whitespace)::

        <video_filename> <duration> <s1=e1+s2=e2+...>

    Each ``s=e`` pair becomes a ``(float(s), float(e))`` tuple, and each end
    time is clipped to ``duration``.

    Args:
        line: One raw annotation line; surrounding whitespace is ignored.

    Returns:
        ``(filename, duration, segments)`` where ``duration`` is a float and
        ``segments`` is a list of ``(start, end)`` float tuples.

    Raises:
        ValueError: if the line does not have exactly three fields, a segment
            is not of the form ``<start>=<end>``, or a number fails to parse.
    """
    # ``split()`` tolerates tabs and repeated spaces between fields.
    parts = line.split()
    # Explicit check (not ``assert``) so validation survives ``python -O``.
    if len(parts) != 3:
        raise ValueError(f"Bad line: {line!r}")
    filename, duration_str, time_str = parts
    duration = float(duration_str)
    segments = []
    for seg in time_str.split("+"):
        s_str, sep, e_str = seg.partition("=")
        if not sep:
            raise ValueError(f"Bad segment {seg!r} in line: {line!r}")
        s, e = float(s_str), float(e_str)
        # Annotated ends may overshoot the video; clip to the stated duration.
        segments.append((s, min(e, duration)))
    return filename, duration, segments
def build_examples(
    annot_dir: str,
    video_root: str,
    generators: list,
    split_prefix: str,
    preprocessed_data_path: Optional[str] = None,
    require_video_exists: bool = True,
):
    """Build a list of example dicts for the given split and generator list.

    For each generator ``gen`` this reads
    ``<annot_dir>/<split_prefix>@<gen>.txt``. Generators whose annotation file
    is missing are skipped with a warning. Videos are resolved under
    ``<video_root>/GENERATOR_TO_DIR[gen]``. When ``require_video_exists`` is
    true, rows whose video file is absent are skipped and counted.

    Each example dict has the keys ``problem``, ``solution`` (list of
    ``(s, e)`` tuples), ``video_path``, ``durations``, ``generator`` and
    ``preprocessed_path``. ``preprocessed_path`` is ``""`` unless
    ``preprocessed_data_path`` is given, in which case it is
    ``<preprocessed_data_path>/<split_prefix>/<gen>/<video filename stem>``.

    Env var FORENSICS_SPLIT_SINGLE_SPAN (default false): when true AND
    split_prefix == "train", expand each multi-segment row into one example
    per segment (solution becomes a length-1 list). Used for the strict
    TempSamp-R1 / Charades-style baseline. Eval split is never expanded so
    grounding metrics remain comparable across runs.

    Returns:
        The list of example dicts, in annotation-file order.

    Raises:
        KeyError: if a generator name is not in ``GENERATOR_TO_DIR``.
        ValueError: if an annotation line is malformed; the message names the
            annotation file and line number.
    """
    split_single_span = (
        split_prefix == "train"
        and os.getenv("FORENSICS_SPLIT_SINGLE_SPAN", "false").lower() in ("true", "1", "yes")
    )
    examples = []
    skipped = 0  # rows dropped because their video file is missing
    expanded = 0  # extra rows created by single-span expansion
    for gen in generators:
        annot_file = os.path.join(annot_dir, f"{split_prefix}@{gen}.txt")
        video_dir = os.path.join(video_root, GENERATOR_TO_DIR[gen])
        if not os.path.exists(annot_file):
            print(f"[warn] missing annot file: {annot_file}")
            continue
        with open(annot_file, "r", encoding="utf-8") as f:
            for lineno, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    filename, duration, segments = parse_annotation_line(line)
                except (ValueError, AssertionError) as err:
                    # The parser may signal a bad row either way; name the
                    # file and line so the offending annotation can be found.
                    raise ValueError(f"{annot_file}:{lineno}: {err}") from err
                video_path = os.path.join(video_dir, filename)
                if require_video_exists and not os.path.exists(video_path):
                    skipped += 1
                    continue
                # Same cache location for every example emitted from this row.
                preprocessed_path = ""
                if preprocessed_data_path:
                    sample_id = os.path.splitext(filename)[0]
                    preprocessed_path = os.path.join(
                        preprocessed_data_path, split_prefix, gen, sample_id
                    )
                if split_single_span:
                    emit_solutions = [[seg] for seg in segments]
                    expanded += max(len(segments) - 1, 0)
                else:
                    emit_solutions = [segments]
                for sol in emit_solutions:
                    examples.append(
                        {
                            "problem": FORENSICS_PROBLEM,
                            "solution": sol,  # list[(s, e)]
                            "video_path": video_path,
                            "durations": duration,
                            "generator": gen,
                            "preprocessed_path": preprocessed_path,
                        }
                    )
    if skipped:
        print(f"[forensics] {split_prefix}: skipped {skipped} examples with missing videos")
    if split_single_span:
        print(f"[forensics] {split_prefix}: SPLIT_SINGLE_SPAN on, +{expanded} rows from multi-seg expansion")
    return examples
def make_dataset(examples: list, shuffle: bool = True, seed: int = 42):
    """Convert example dicts into a HuggingFace ``Dataset``.

    Args:
        examples: Example dicts, e.g. from ``build_examples``. Not modified.
        shuffle: If true, rows are shuffled with ``random.Random(seed)``
            before conversion.
        seed: Shuffle seed; ignored when ``shuffle`` is false.

    Returns:
        ``Dataset.from_list`` over the (optionally shuffled) rows.
    """
    rows = list(examples)
    if shuffle:
        # Shuffle a copy so the caller's list keeps its original order; the
        # resulting permutation for a given seed is unchanged.
        random.Random(seed).shuffle(rows)
    return Dataset.from_list(rows)
def load_forensics_dataset(
    annot_dir: str,
    video_root: str,
    preprocessed_data_path: Optional[str] = None,
    train_generators: Optional[list] = None,
    test_generators: Optional[list] = None,
    shuffle_train: bool = True,
    seed: int = 42,
) -> DatasetDict:
    """Build the ActivityForensics ``DatasetDict`` with ``train`` and ``eval`` splits.

    Args:
        annot_dir: Directory holding the ``{train,test}@<gen>.txt`` files.
        video_root: Root directory holding the per-generator video dirs.
        preprocessed_data_path: Optional root of per-sample cache dirs that
            contain ``video_inputs.pt`` and ``video_kwargs.json``.
        train_generators: Generators for the train split; ``None`` means
            ``TRAIN_GENERATORS``. An explicit empty list is respected.
        test_generators: Generators for the eval split; ``None`` means
            ``TEST_GENERATORS``. An explicit empty list is respected.
        shuffle_train: Shuffle the train split. Eval is never shuffled.
        seed: Seed for the train shuffle.

    Returns:
        ``DatasetDict({"train": ..., "eval": ...})``. On each split, the
        instance attribute ``__getitem__`` is replaced by a loader that
        attaches cached video inputs when ``preprocessed_path`` is set. For
        train only, it also applies FBR then SPI augmentation.

        Caveat: Python resolves implicit indexing (``ds[i]``) on the class, so
        ``ds[i]`` still returns raw rows. Only code that calls
        ``ds.__getitem__(...)`` explicitly sees the override.
    """
    # ``is None`` rather than ``or`` so an explicitly empty generator list is
    # not silently replaced by the defaults.
    if train_generators is None:
        train_generators = TRAIN_GENERATORS
    if test_generators is None:
        test_generators = TEST_GENERATORS
    train_examples = build_examples(
        annot_dir, video_root, train_generators, "train", preprocessed_data_path
    )
    eval_examples = build_examples(
        annot_dir, video_root, test_generators, "test", preprocessed_data_path
    )
    train_dataset = make_dataset(train_examples, shuffle=shuffle_train, seed=seed)
    eval_dataset = make_dataset(eval_examples, shuffle=False)
    print(f"[forensics] train: {len(train_dataset)} eval: {len(eval_dataset)}")
    if len(train_dataset) > 0:
        print(f"[forensics] sample train example: {train_dataset[0]}")

    # SPI / FBR augmentations (off unless their env vars are set). Eval set
    # never augmented; only train. Order: FBR (full reversal) then SPI
    # (chunk permutation) — composable, GT remapping flows through both.
    # These are relative imports, so this module must be imported as part of
    # its package; they fail when the file is executed directly as a script.
    from .spi_aug import maybe_apply_spi
    from .fbr_aug import maybe_apply_fbr

    def _build_getitem(apply_augs: bool):
        """Return a ``__getitem__`` that adds cached video inputs and, when
        ``apply_augs`` is true, the train-time augmentations."""

        def __getitem__(self, idx):
            example = Dataset.__getitem__(self, idx)
            data = {k: v for k, v in example.items()}
            pp = example.get("preprocessed_path", "")
            # Added values are wrapped in length-1 lists, and a list-valued
            # preprocessed_path is unwrapped to its first entry. Presumably
            # ``idx`` is a batch containing a single index here, so columns
            # come back as lists.
            # NOTE(review): confirm — with several indices only the first
            # row's cache would be loaded.
            if isinstance(pp, list):
                pp = pp[0] if pp else ""
            if pp:
                try:
                    data["video_inputs"] = [torch.load(os.path.join(pp, "video_inputs.pt"))]
                    with open(os.path.join(pp, "video_kwargs.json"), "r") as f:
                        data["video_kwargs"] = [json.load(f)]
                    data["use_preprocessed"] = [True]
                except Exception as e:
                    # Best effort: a cache that fails to load is reported and
                    # the row is marked use_preprocessed=False instead of
                    # aborting the fetch.
                    print(f"[warn] failed to load cache {pp}: {e}")
                    data["use_preprocessed"] = [False]
            else:
                data["use_preprocessed"] = [False]
            if apply_augs:
                data = maybe_apply_fbr(data)
                data = maybe_apply_spi(data)
            return data

        return __getitem__

    # Bind per instance (not on the class) so train and eval get different
    # behavior while sharing the Dataset class.
    train_dataset.__getitem__ = _build_getitem(apply_augs=True).__get__(train_dataset, Dataset)
    eval_dataset.__getitem__ = _build_getitem(apply_augs=False).__get__(eval_dataset, Dataset)
    return DatasetDict({"train": train_dataset, "eval": eval_dataset})
if __name__ == "__main__":
    # Smoke test: build both splits and print a few raw training rows.
    # NOTE: load_forensics_dataset uses relative imports (.spi_aug, .fbr_aug),
    # so run this with ``python -m <package>.<module>`` rather than by file path.
    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--annot_dir", default="/ces/zt/activityforensics/annot")
    p.add_argument("--video_root", default="/ces/zt")
    args = p.parse_args()
    ds = load_forensics_dataset(args.annot_dir, args.video_root)
    print(ds)
    # ``ds["train"][i]`` resolves __getitem__ on the class, so these are raw
    # rows (no cache loading / augmentation). Bound the loop so a split with
    # fewer than three rows (e.g. videos missing locally) doesn't IndexError.
    for i in range(min(3, len(ds["train"]))):
        print(ds["train"][i])
|