File size: 4,632 Bytes
33569f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Pre-encode ActivityForensics videos with the Qwen2.5-VL processor.

For each example produced by data_loader.build_examples(), runs
qwen_vl_utils.process_vision_info to extract `video_inputs` and `video_kwargs`,
then saves them under <output_dir>/<split>/<gen>/<sample_id>/ as
`video_inputs.pt` + `video_kwargs.json`.

The training-time __getitem__ (in src/open_r1/data_loader.py) loads these
cached tensors and skips re-encoding.
"""
import argparse
import json
import multiprocessing as mp
import os
import sys

import torch
from qwen_vl_utils import process_vision_info
from tqdm import tqdm

# Allow `python preprocess_forensics.py` from project root
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.open_r1.data_loader import (
    GENERATOR_TO_DIR, TRAIN_GENERATORS, TEST_GENERATORS, build_examples,
)


def parse_args():
    """Parse the command-line options of the pre-encoding script.

    Returns:
        argparse.Namespace with the attributes ``annot_dir``, ``video_root``,
        ``output_dir`` (required), ``max_pix_size``, ``min_pix_size``, ``fps``,
        ``num_workers``, ``splits`` and ``skip_existing``.
    """
    parser = argparse.ArgumentParser()

    # Where the annotations / raw videos live and where the cache is written.
    parser.add_argument("--annot_dir", default="/ces/zt/activityforensics/annot")
    parser.add_argument("--video_root", default="/ces/zt")
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Where to dump cached video_inputs.pt / video_kwargs.json",
    )

    # Pixel budgets, expressed in units of 28x28 patches, and the sampling rate.
    parser.add_argument(
        "--max_pix_size",
        type=int,
        default=3584,
        help="total_pixels = max_pix_size * 28 * 28 (uniform for all videos)",
    )
    parser.add_argument(
        "--min_pix_size",
        type=int,
        default=16,
        help="min_pixels = min_pix_size * 28 * 28",
    )
    parser.add_argument(
        "--fps",
        type=float,
        default=2.0,
        help="Fixed sampling fps for every video (no max_frames cap).",
    )

    # Execution controls.
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument(
        "--splits",
        nargs="+",
        default=["train", "test"],
        choices=["train", "test"],
    )
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="Skip samples whose cache dir already has video_inputs.pt",
    )

    return parser.parse_args()


def _encode_one(task):
    video_path, max_pixels, min_pixels, fps, out_dir, skip_existing = task
    if skip_existing and os.path.exists(os.path.join(out_dir, "video_inputs.pt")):
        return ("skipped", video_path, None)
    try:
        messages = [
            {"role": "user", "content": [
                {"type": "video", "video": video_path,
                 "total_pixels": max_pixels, "min_pixels": min_pixels,
                 "fps": fps},
            ]},
        ]
        _, video_inputs, video_kwargs = process_vision_info(
            [messages], return_video_kwargs=True
        )
        os.makedirs(out_dir, exist_ok=True)
        torch.save(video_inputs, os.path.join(out_dir, "video_inputs.pt"))
        with open(os.path.join(out_dir, "video_kwargs.json"), "w") as f:
            json.dump(video_kwargs, f)
        return ("ok", video_path, None)
    except Exception as e:
        return ("fail", video_path, repr(e))


def process_split(split, args, base_max_pixels, min_pixels):
    """Pre-encode every example of one split using a worker pool.

    Args:
        split: Split name; ``"train"`` selects ``TRAIN_GENERATORS``, any other
            value selects ``TEST_GENERATORS``. Also used as ``split_prefix``
            for ``build_examples`` and as a path component of the cache dir.
        args: Parsed CLI namespace; reads ``annot_dir``, ``video_root``,
            ``output_dir``, ``fps``, ``skip_existing`` and ``num_workers``.
        base_max_pixels: Value passed to every task as its ``total_pixels``.
        min_pixels: Value passed to every task as its ``min_pixels``.

    Each example is cached under
    ``<output_dir>/<split>/<generator>/<sample_id>``, where ``sample_id`` is
    the video file name without its extension. Failures are printed and
    counted; nothing is returned.
    """
    if split == "train":
        generators = TRAIN_GENERATORS
    else:
        generators = TEST_GENERATORS

    examples = build_examples(
        annot_dir=args.annot_dir,
        video_root=args.video_root,
        generators=generators,
        split_prefix=split,
        preprocessed_data_path=None,
        require_video_exists=True,
    )

    # One task tuple per example, in the field order _encode_one unpacks.
    tasks = []
    for example in examples:
        video_path = example["video_path"]
        sample_id, _extension = os.path.splitext(os.path.basename(video_path))
        cache_dir = os.path.join(
            args.output_dir, split, example["generator"], sample_id
        )
        tasks.append(
            (video_path, base_max_pixels, min_pixels,
             args.fps, cache_dir, args.skip_existing)
        )

    n_tasks = len(tasks)
    print(f"[{split}] uniform total_pixels={base_max_pixels} fixed_fps={args.fps} "
          f"for {n_tasks} videos")
    print(f"[{split}] encoding {n_tasks} videos with {args.num_workers} workers")

    # Any status other than "ok"/"skipped" is tallied (and reported) as a failure.
    tally = {"ok": 0, "skipped": 0, "fail": 0}
    with mp.Pool(processes=args.num_workers) as pool:
        results = pool.imap_unordered(_encode_one, tasks)
        progress = tqdm(results, total=n_tasks, desc=f"preprocess[{split}]")
        for status, failed_path, error in progress:
            if status in ("ok", "skipped"):
                tally[status] += 1
                continue
            tally["fail"] += 1
            print(f"[fail] {failed_path}: {error}")

    n_ok, n_skip, n_fail = tally["ok"], tally["skipped"], tally["fail"]
    print(f"[{split}] ok={n_ok} skipped={n_skip} fail={n_fail}")


def main():
    """Script entry point: parse the CLI options and process each requested split.

    The pixel budgets are given on the command line in units of 28x28 patches
    and converted to absolute pixel counts here, once, so every split uses the
    same values.
    """
    args = parse_args()

    patch_area = 28 * 28
    base_max_pixels = args.max_pix_size * patch_area
    min_pixels = args.min_pix_size * patch_area

    print(f"output_dir={args.output_dir}  total_pixels={base_max_pixels}  min_pixels={min_pixels}")

    # Create the cache root up front; per-sample directories are made later.
    os.makedirs(args.output_dir, exist_ok=True)

    for split_name in args.splits:
        process_split(split_name, args, base_max_pixels, min_pixels)


# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()