# forensics-grpo / code / preprocess_forensics.py
# Committed by sdzt in 33569f9 ("Add source code"), 4.63 kB.
"""Pre-encode ActivityForensics videos with the Qwen2.5-VL processor.
For each example produced by data_loader.build_examples(), runs
qwen_vl_utils.process_vision_info to extract `video_inputs` and `video_kwargs`,
then saves them under <output_dir>/<split>/<gen>/<sample_id>/ as
`video_inputs.pt` + `video_kwargs.json`.
The training-time __getitem__ (in src/open_r1/data_loader.py) loads these
cached tensors and skips re-encoding.
"""
import argparse
import json
import multiprocessing as mp
import os
import sys
import torch
from qwen_vl_utils import process_vision_info
from tqdm import tqdm
# Allow `python preprocess_forensics.py` from project root
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.open_r1.data_loader import (
GENERATOR_TO_DIR, TRAIN_GENERATORS, TEST_GENERATORS, build_examples,
)
def parse_args():
    """Define this script's command-line interface and parse ``sys.argv``.

    The two ``*_pix_size`` options are expressed in units of 28x28 pixels
    (see their help strings); ``--output_dir`` is the only required flag.

    Returns:
        argparse.Namespace with one attribute per option below.
    """
    parser = argparse.ArgumentParser()

    # Where the annotations and raw videos live.
    parser.add_argument("--annot_dir", default="/ces/zt/activityforensics/annot")
    parser.add_argument("--video_root", default="/ces/zt")

    # Where the encoded cache is written.
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Where to dump cached video_inputs.pt / video_kwargs.json",
    )

    # Pixel budgets and frame sampling.
    parser.add_argument(
        "--max_pix_size",
        type=int,
        default=3584,
        help="total_pixels = max_pix_size * 28 * 28 (uniform for all videos)",
    )
    parser.add_argument(
        "--min_pix_size",
        type=int,
        default=16,
        help="min_pixels = min_pix_size * 28 * 28",
    )
    parser.add_argument(
        "--fps",
        type=float,
        default=2.0,
        help="Fixed sampling fps for every video (no max_frames cap).",
    )

    # Execution control.
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument(
        "--splits",
        nargs="+",
        default=["train", "test"],
        choices=["train", "test"],
    )
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="Skip samples whose cache dir already has video_inputs.pt",
    )

    namespace = parser.parse_args()
    return namespace
def _encode_one(task):
video_path, max_pixels, min_pixels, fps, out_dir, skip_existing = task
if skip_existing and os.path.exists(os.path.join(out_dir, "video_inputs.pt")):
return ("skipped", video_path, None)
try:
messages = [
{"role": "user", "content": [
{"type": "video", "video": video_path,
"total_pixels": max_pixels, "min_pixels": min_pixels,
"fps": fps},
]},
]
_, video_inputs, video_kwargs = process_vision_info(
[messages], return_video_kwargs=True
)
os.makedirs(out_dir, exist_ok=True)
torch.save(video_inputs, os.path.join(out_dir, "video_inputs.pt"))
with open(os.path.join(out_dir, "video_kwargs.json"), "w") as f:
json.dump(video_kwargs, f)
return ("ok", video_path, None)
except Exception as e:
return ("fail", video_path, repr(e))
def _tally_results(split, results, total):
    """Drain ``(status, video_path, err)`` tuples under a tqdm progress bar.

    Returns:
        ``(n_ok, n_skip, n_fail)`` counts.
    """
    n_ok = n_skip = n_fail = 0
    for status, vp, err in tqdm(results, total=total, desc=f"preprocess[{split}]"):
        if status == "ok":
            n_ok += 1
        elif status == "skipped":
            n_skip += 1
        else:
            n_fail += 1
            # tqdm.write prints above the bar instead of corrupting it.
            tqdm.write(f"[fail] {vp}: {err}")
    return n_ok, n_skip, n_fail


def process_split(split, args, base_max_pixels, min_pixels):
    """Encode every video of one split into the cache.

    Each video is cached under
    ``<args.output_dir>/<split>/<generator>/<sample_id>/``, where
    ``sample_id`` is the video's file name without its extension.

    Args:
        split: ``"train"`` selects ``TRAIN_GENERATORS``; any other value
            selects ``TEST_GENERATORS``.
        args: parsed CLI namespace from ``parse_args``.
        base_max_pixels: ``total_pixels`` budget passed for every video.
        min_pixels: ``min_pixels`` passed for every video.

    Examples that map to an already-queued cache dir are encoded only once.
    Encoding them again would waste work and would let two workers write
    the same files concurrently. ``args.num_workers <= 0`` encodes
    in-process, since ``mp.Pool`` rejects fewer than one process.
    """
    generators = TRAIN_GENERATORS if split == "train" else TEST_GENERATORS
    examples = build_examples(
        annot_dir=args.annot_dir, video_root=args.video_root,
        generators=generators, split_prefix=split,
        preprocessed_data_path=None, require_video_exists=True,
    )
    tasks = []
    owner_of = {}  # cache dir -> video_path that claimed it first
    n_dup = 0
    for ex in examples:
        video_path = ex["video_path"]
        sample_id = os.path.splitext(os.path.basename(video_path))[0]
        out_dir = os.path.join(args.output_dir, split, ex["generator"], sample_id)
        owner = owner_of.get(out_dir)
        if owner is not None:
            n_dup += 1
            if owner != video_path:
                # Two distinct videos share a basename within one generator.
                print(f"[warn] {video_path} collides with {owner} at {out_dir}; "
                      f"keeping the first")
            continue
        owner_of[out_dir] = video_path
        tasks.append((video_path, base_max_pixels, min_pixels,
                      args.fps, out_dir, args.skip_existing))
    if n_dup:
        print(f"[{split}] dropped {n_dup} examples mapping to an already-queued cache dir")
    print(f"[{split}] uniform total_pixels={base_max_pixels} fixed_fps={args.fps} "
          f"for {len(tasks)} videos")
    print(f"[{split}] encoding {len(tasks)} videos with {args.num_workers} workers")
    if args.num_workers > 0:
        with mp.Pool(processes=args.num_workers) as pool:
            n_ok, n_skip, n_fail = _tally_results(
                split, pool.imap_unordered(_encode_one, tasks), len(tasks)
            )
    else:
        n_ok, n_skip, n_fail = _tally_results(
            split, map(_encode_one, tasks), len(tasks)
        )
    print(f"[{split}] ok={n_ok} skipped={n_skip} fail={n_fail}")
def main():
    """Script entry point.

    Parses the CLI and converts the 28x28-unit size flags into pixel
    counts. It then ensures the output directory exists and preprocesses
    each requested split in order.
    """
    args = parse_args()
    # --max_pix_size / --min_pix_size count 28x28 blocks (see their help text).
    block_area = 28 * 28
    base_max_pixels = args.max_pix_size * block_area
    min_pixels = args.min_pix_size * block_area
    print(f"output_dir={args.output_dir} total_pixels={base_max_pixels} min_pixels={min_pixels}")
    os.makedirs(args.output_dir, exist_ok=True)
    for split_name in args.splits:
        process_split(split_name, args, base_max_pixels, min_pixels)


if __name__ == "__main__":
    main()