File size: 7,402 Bytes
33569f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""ActivityForensics dataset loader for GRPO training.

Reads annotation files at <annot_dir>/{train,test}@<gen>.txt with format:
    <video_filename> <duration> <s1=e1+s2=e2+...>

Maps generator names to their video directories and produces a HuggingFace
Dataset with the schema expected by grpo_forensics.py / Qwen2VL trainers.
"""
import json
import os
import random
from typing import Optional

import torch
from datasets import Dataset, DatasetDict
from tqdm import tqdm


# Generator name -> video sub-directory under ``video_root``. Each directory is
# named "<NN>_<generator>", where NN is the 1-based position in this tuple.
GENERATOR_TO_DIR = {
    name: f"{idx:02d}_{name}"
    for idx, name in enumerate(
        ("vidu", "wan", "fcvg", "scifi", "ltx", "vace-1.3B"), start=1
    )
}

# Generators whose annotations are used for training by default.
TRAIN_GENERATORS = ["wan", "scifi", "fcvg", "vace-1.3B", "ltx"]
# Default eval generators: all training generators plus "vidu", which is
# therefore absent from the default training set.
TEST_GENERATORS = [*TRAIN_GENERATORS, "vidu"]

# Value written to every example's "problem" field; a placeholder because
# QUESTION_TEMPLATE is fixed.
FORENSICS_PROBLEM = "forensics"


def parse_annotation_line(line: str):
    """Parse one line of annot/*.txt -> (filename, duration, [(s, e), ...]).

    Expected format (three whitespace-separated fields)::

        <video_filename> <duration> <s1=e1+s2=e2+...>

    Args:
        line: One raw annotation line; surrounding whitespace, including the
            trailing newline, is ignored, and fields may be separated by any
            run of spaces/tabs.

    Returns:
        ``(filename, duration, segments)`` where ``duration`` is a float and
        ``segments`` is a list of ``(start, end)`` float tuples. Each segment
        end is clamped to ``duration``; starts are left unchanged.

    Raises:
        ValueError: if the line does not have exactly three fields, a segment
            lacks the ``=`` separator, or a number cannot be parsed. (An
            explicit raise rather than ``assert`` so validation still happens
            under ``python -O``.)
    """
    parts = line.split()
    if len(parts) != 3:
        raise ValueError(f"Bad line: {line!r}")
    filename, duration_str, time_str = parts
    duration = float(duration_str)
    segments = []
    for seg in time_str.split("+"):
        s_str, sep, e_str = seg.partition("=")
        if not sep:
            raise ValueError(f"Bad segment {seg!r} in line: {line!r}")
        s, e = float(s_str), float(e_str)
        # Clip segment ends that extend past the stated duration.
        segments.append((s, min(e, duration)))
    return filename, duration, segments


def build_examples(
    annot_dir: str,
    video_root: str,
    generators: list,
    split_prefix: str,
    preprocessed_data_path: Optional[str] = None,
    require_video_exists: bool = True,
):
    """Build a list of example dicts for the given split and generator list.

    For each generator ``gen`` this reads
    ``<annot_dir>/<split_prefix>@<gen>.txt`` (a missing file is warned about
    and skipped) and resolves every listed video under
    ``<video_root>/GENERATOR_TO_DIR[gen]``.

    Env var FORENSICS_SPLIT_SINGLE_SPAN (default false): when true AND
    split_prefix == "train", expand each multi-segment row into one example
    per segment (solution becomes a length-1 list). Used for the strict
    TempSamp-R1 / Charades-style baseline. Eval split is never expanded so
    grounding metrics remain comparable across runs.

    Args:
        annot_dir: Directory containing the annotation files.
        video_root: Root directory holding the per-generator video folders.
        generators: Generator names; each must be a key of GENERATOR_TO_DIR.
        split_prefix: Split name used in the annotation filename (e.g.
            "train" / "test") and in the preprocessed-cache path.
        preprocessed_data_path: Optional cache root. When set, each example's
            "preprocessed_path" is ``<root>/<split_prefix>/<gen>/<sample_id>``
            with sample_id = video filename without its extension.
        require_video_exists: Drop (and count) rows whose video file is absent.

    Returns:
        List of dicts with keys "problem", "solution" (list of (start, end)),
        "video_path", "durations", "generator" and "preprocessed_path"
        ("" when no cache root is given).

    Raises:
        KeyError: if a generator name is not in GENERATOR_TO_DIR.
        ValueError: if an annotation line is malformed; the message names the
            annotation file and the 1-based line number.
    """
    split_single_span = (
        split_prefix == "train"
        and os.getenv("FORENSICS_SPLIT_SINGLE_SPAN", "false").lower() in ("true", "1", "yes")
    )
    examples = []
    skipped = 0
    expanded = 0
    for gen in generators:
        annot_file = os.path.join(annot_dir, f"{split_prefix}@{gen}.txt")
        video_dir = os.path.join(video_root, GENERATOR_TO_DIR[gen])
        if not os.path.exists(annot_file):
            print(f"[warn] missing annot file: {annot_file}")
            continue
        with open(annot_file, "r", encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    filename, duration, segments = parse_annotation_line(line)
                except (ValueError, AssertionError) as err:
                    # AssertionError is caught too, for parsers that validate
                    # with ``assert``; either way, report where the bad row is.
                    raise ValueError(f"{annot_file}:{lineno}: {err}") from err
                video_path = os.path.join(video_dir, filename)
                if require_video_exists and not os.path.exists(video_path):
                    skipped += 1
                    continue

                # Same for every solution emitted from this row, so compute once.
                preprocessed_path = ""
                if preprocessed_data_path:
                    sample_id = os.path.splitext(filename)[0]
                    preprocessed_path = os.path.join(
                        preprocessed_data_path, split_prefix, gen, sample_id
                    )

                if split_single_span:
                    emit_solutions = [[seg] for seg in segments]
                    if len(segments) > 1:
                        expanded += len(segments) - 1
                else:
                    emit_solutions = [segments]

                for sol in emit_solutions:
                    examples.append({
                        "problem": FORENSICS_PROBLEM,
                        "solution": sol,                 # list[(s, e)]
                        "video_path": video_path,
                        "durations": duration,
                        "generator": gen,
                        "preprocessed_path": preprocessed_path,
                    })
    if skipped:
        print(f"[forensics] {split_prefix}: skipped {skipped} examples with missing videos")
    if split_single_span:
        print(f"[forensics] {split_prefix}: SPLIT_SINGLE_SPAN on, +{expanded} rows from multi-seg expansion")
    return examples


def make_dataset(examples: list, shuffle: bool = True, seed: int = 42):
    """Wrap ``examples`` in a HuggingFace ``Dataset``, optionally shuffled.

    Shuffling uses a private ``random.Random(seed)``, so it is deterministic
    for a given seed and does not touch the global RNG. It is applied to a
    shallow copy, leaving the caller's list in its original order.

    Args:
        examples: List of example dicts (one per row).
        shuffle: Whether to shuffle row order before building the dataset.
        seed: Seed for the shuffle; ignored when ``shuffle`` is False.

    Returns:
        ``Dataset.from_list`` of the (possibly shuffled) rows.
    """
    rows = list(examples)
    if shuffle:
        random.Random(seed).shuffle(rows)
    return Dataset.from_list(rows)


def load_forensics_dataset(
    annot_dir: str,
    video_root: str,
    preprocessed_data_path: Optional[str] = None,
    train_generators: Optional[list] = None,
    test_generators: Optional[list] = None,
    shuffle_train: bool = True,
    seed: int = 42,
) -> DatasetDict:
    """Build the ActivityForensics ``DatasetDict`` with "train" and "eval" splits.

    Args:
        annot_dir: Directory holding ``train@<gen>.txt`` / ``test@<gen>.txt``.
        video_root: Root containing one sub-directory per generator.
        preprocessed_data_path: Optional cache root; per-sample caches live at
            ``<root>/<split>/<gen>/<sample_id>/`` and are read as
            ``video_inputs.pt`` + ``video_kwargs.json``.
        train_generators: Generators for the train split. Any falsy value
            (``None`` or an empty list) falls back to TRAIN_GENERATORS.
        test_generators: Generators for the eval split. Any falsy value falls
            back to TEST_GENERATORS.
        shuffle_train: Shuffle the train split with ``seed``. The eval split is
            never shuffled.
        seed: Seed for the train shuffle.

    Returns:
        ``DatasetDict({"train": ..., "eval": ...})``. Both splits get an
        instance-level ``__getitem__`` override. It loads cached tensors when
        available, and the train split additionally applies the FBR/SPI
        augmentations.
    """
    # `or` (not `is None`) means an explicitly empty list also selects the default.
    train_generators = train_generators or TRAIN_GENERATORS
    test_generators = test_generators or TEST_GENERATORS

    train_examples = build_examples(
        annot_dir, video_root, train_generators, "train", preprocessed_data_path
    )
    eval_examples = build_examples(
        annot_dir, video_root, test_generators, "test", preprocessed_data_path
    )

    train_dataset = make_dataset(train_examples, shuffle=shuffle_train, seed=seed)
    eval_dataset = make_dataset(eval_examples, shuffle=False)

    print(f"[forensics] train: {len(train_dataset)}  eval: {len(eval_dataset)}")
    if len(train_dataset) > 0:
        # Printed before the __getitem__ override below, so this is the raw row.
        print(f"[forensics] sample train example: {train_dataset[0]}")

    # SPI / FBR augmentations (off unless their env vars are set). Eval set
    # never augmented; only train. Order: FBR (full reversal) then SPI
    # (chunk permutation) — composable, GT remapping flows through both.
    # Package-relative imports: this module must be imported as part of its
    # package (not run as a top-level script) for these to resolve.
    from .spi_aug import maybe_apply_spi
    from .fbr_aug import maybe_apply_fbr

    def _build_getitem(apply_augs: bool):
        """Return a ``__getitem__`` that adds cache tensors and, optionally, augmentations."""
        def __getitem__(self, idx):
            # Call the class implementation explicitly; `self[idx]` here would
            # not recurse into this override anyway (see note below), but this
            # makes the intent unambiguous.
            example = Dataset.__getitem__(self, idx)
            # Shallow copy so added keys don't touch the object `datasets` returned.
            data = {k: v for k, v in example.items()}
            pp = example.get("preprocessed_path", "")
            # When idx is a list of indices, Dataset returns column lists; only
            # the first row's cache path is used.
            # NOTE(review): this and the 1-element lists below look like they
            # assume a batch of exactly one row — confirm against the trainer.
            if isinstance(pp, list):
                pp = pp[0] if pp else ""
            if pp:
                try:
                    # NOTE(review): torch.load unpickles; only point this at a
                    # trusted cache (consider weights_only=True if compatible).
                    data["video_inputs"] = [torch.load(os.path.join(pp, "video_inputs.pt"))]
                    with open(os.path.join(pp, "video_kwargs.json"), "r") as f:
                        data["video_kwargs"] = [json.load(f)]
                    data["use_preprocessed"] = [True]
                except Exception as e:
                    # Best-effort: fall back to the non-preprocessed path.
                    # "video_inputs" may already be set if only the JSON step failed.
                    print(f"[warn] failed to load cache {pp}: {e}")
                    data["use_preprocessed"] = [False]
            else:
                data["use_preprocessed"] = [False]
            if apply_augs:
                data = maybe_apply_fbr(data)
                data = maybe_apply_spi(data)
            return data
        return __getitem__

    # NOTE(review): this binds __getitem__ on the *instance*. Implicit
    # special-method lookup (`ds[idx]`) resolves on the type, so plain
    # subscripting still calls Dataset.__getitem__ and bypasses this override.
    # Only explicit `ds.__getitem__(...)` calls reach it. Presumably that means
    # batched fetching via Dataset.__getitems__ — TODO confirm for the
    # installed `datasets` version.
    train_dataset.__getitem__ = _build_getitem(apply_augs=True).__get__(train_dataset, Dataset)
    eval_dataset.__getitem__ = _build_getitem(apply_augs=False).__get__(eval_dataset, Dataset)

    return DatasetDict({"train": train_dataset, "eval": eval_dataset})


if __name__ == "__main__":
    # Smoke test: build the splits and print a few raw training rows.
    # load_forensics_dataset uses package-relative imports (`from .spi_aug
    # import ...`), so run this as `python -m <package>.<module>`; running the
    # file directly fails with "attempted relative import with no known parent
    # package".
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument("--annot_dir", default="/ces/zt/activityforensics/annot")
    p.add_argument("--video_root", default="/ces/zt")
    args = p.parse_args()
    ds = load_forensics_dataset(args.annot_dir, args.video_root)
    print(ds)
    # `ds["train"][i]` resolves __getitem__ on the type, so these rows come
    # from Dataset.__getitem__ without cache loading or augmentation.
    # Bounded by the split size so tiny/empty splits don't raise IndexError.
    for i in range(min(3, len(ds["train"]))):
        print(ds["train"][i])