| from pathlib import Path |
| import argparse |
| import json |
| import pickle |
| import sys |
|
|
|
|
# Repository root: the parent of the directory that contains this script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]

# Put the repository root at the front of the import path (only once) so that
# project-local packages such as `rr_label_study` can be imported when this
# file is run directly as a script.
if str(PROJECT_ROOT) not in sys.path:
    sys.path[:0] = [str(PROJECT_ROOT)]
|
|
| from rr_label_study.oven_study import ( |
| _compute_frame_rows_independent, |
| _compute_frame_rows_sequential, |
| _load_demo, |
| ) |
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description=(
            "Compute per-frame rows for selected frames of one episode and "
            "write each row to <output-dir>/frame_NNNN.json."
        )
    )
    parser.add_argument(
        "--episode-dir",
        required=True,
        help="Episode directory passed to the demo loader and row computation.",
    )
    parser.add_argument(
        "--templates-pkl",
        required=True,
        help="Pickle file containing the templates (only use trusted files).",
    )
    parser.add_argument(
        "--frame-indices",
        nargs="+",
        type=int,
        required=True,
        help="One or more frame indices to compute rows for.",
    )
    parser.add_argument(
        "--checkpoint-stride",
        type=int,
        default=16,
        help="Positive checkpoint stride forwarded to the row computation "
        "(default: 16).",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Directory (created if missing) receiving one JSON file per row.",
    )
    parser.add_argument(
        "--independent-replay",
        action="store_true",
        help="Compute rows with independent replay instead of sequential replay.",
    )
    return parser


def _parse_args(argv: "list[str] | None") -> argparse.Namespace:
    """Parse and validate *argv*; ``None`` means ``sys.argv[1:]``."""
    parser = _build_parser()
    args = parser.parse_args(argv)
    # A stride of 0 or less is not a usable step. Reject it here with a normal
    # usage error (exit status 2) instead of forwarding it to the replay helpers.
    if args.checkpoint_stride <= 0:
        parser.error("--checkpoint-stride must be a positive integer")
    return args


def _write_rows(rows, output_dir: Path) -> None:
    """Write each row to ``output_dir/frame_NNNN.json``.

    ``NNNN`` is ``int(row["frame_index"])`` zero-padded to at least 4 digits.
    The directory is created (with parents) if it does not already exist.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    for row in rows:
        frame_index = int(row["frame_index"])
        # Serialize before touching the file: if the row is not JSON
        # serializable, this raises without leaving an empty or truncated
        # frame file behind.
        payload = json.dumps(row)
        output_dir.joinpath(f"frame_{frame_index:04d}.json").write_text(
            payload, encoding="utf-8"
        )


def main(argv: "list[str] | None" = None) -> int:
    """Compute rows for the requested frames of one episode and dump them as JSON.

    Args:
        argv: Command-line arguments excluding the program name. ``None``
            (the default) reads ``sys.argv[1:]``, as before.

    Returns:
        0 on success.

    Raises:
        SystemExit: On invalid arguments, including a non-positive
            ``--checkpoint-stride`` (argparse usage error, status 2).
        FileNotFoundError: If the templates pickle does not exist.
    """
    args = _parse_args(argv)

    episode_dir = Path(args.episode_dir)
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point --templates-pkl at files produced by a trusted source.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)

    compute_rows = (
        _compute_frame_rows_independent
        if args.independent_replay
        else _compute_frame_rows_sequential
    )
    rows = compute_rows(
        episode_dir=episode_dir,
        demo=demo,
        templates=templates,
        checkpoint_stride=args.checkpoint_stride,
        frame_indices=args.frame_indices,
    )

    _write_rows(rows, Path(args.output_dir))
    return 0
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    sys.exit(main())
|
|