File size: 1,733 Bytes
712dc89
 
 
 
 
 
 
 
 
 
 
7f173cd
 
 
 
 
712dc89
 
 
 
 
 
 
 
 
7f173cd
712dc89
 
 
 
 
 
7f173cd
 
 
 
 
 
712dc89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from pathlib import Path
import argparse
import json
import pickle
import sys


# Put the parent of this file's directory at the front of ``sys.path`` (once)
# so the project-local ``rr_label_study`` package imported below resolves.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))

from rr_label_study.oven_study import (
    _compute_frame_rows_independent,
    _compute_frame_rows_sequential,
    _load_demo,
)


def _write_json_atomic(path: Path, payload: object) -> None:
    """Write *payload* as JSON to *path* without leaving a partial file behind.

    The document is first written to a sibling ``<name>.tmp`` file and then
    moved over *path* with ``Path.replace`` (an atomic rename on POSIX
    filesystems). *path* therefore either keeps its previous state or
    receives the complete document.

    Raises:
        Whatever ``json.dump`` or the filesystem raises (e.g. ``TypeError``
        for a non-JSON-serializable value); the temp file is removed first.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as handle:
            json.dump(payload, handle)
        tmp_path.replace(path)
    except BaseException:
        # Do not leave a half-written temp file around; surface the original error.
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass
        raise


def main(argv: "list[str] | None" = None) -> int:
    """Compute per-frame rows for one episode and write each row as JSON.

    Loads the pickled templates and the episode demo, computes rows for the
    requested frame indices with either the sequential or the independent
    row helper, and writes one ``frame_<index:04d>.json`` per row into
    ``--output-dir``.

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) lets argparse read ``sys.argv[1:]``.

    Returns:
        ``0`` on success. Failures propagate as exceptions; argparse usage
        errors raise ``SystemExit``.
    """
    parser = argparse.ArgumentParser(
        description="Compute frame rows for one episode and dump them as JSON."
    )
    parser.add_argument(
        "--episode-dir",
        required=True,
        help="Episode directory used to load the demo and compute rows.",
    )
    parser.add_argument(
        "--templates-pkl",
        required=True,
        help="Pickled templates file (trusted input only; unpickling can run code).",
    )
    parser.add_argument(
        "--frame-indices",
        nargs="+",
        type=int,
        required=True,
        help="Frame indices to compute rows for.",
    )
    parser.add_argument(
        "--checkpoint-stride",
        type=int,
        default=16,
        help="Forwarded as checkpoint_stride to the row computation (default: %(default)s).",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="Directory that receives one frame_NNNN.json file per row.",
    )
    parser.add_argument(
        "--independent-replay",
        action="store_true",
        help="Use _compute_frame_rows_independent instead of _compute_frame_rows_sequential.",
    )
    args = parser.parse_args(argv)

    # Create the destination up front so an unusable output directory fails
    # before any row computation runs.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    episode_dir = Path(args.episode_dir)
    # SECURITY: pickle.load can execute arbitrary code from the file; only
    # pass templates files that come from a trusted source.
    with Path(args.templates_pkl).open("rb") as handle:
        templates = pickle.load(handle)
    demo = _load_demo(episode_dir)
    compute_rows = (
        _compute_frame_rows_independent
        if args.independent_replay
        else _compute_frame_rows_sequential
    )
    rows = compute_rows(
        episode_dir=episode_dir,
        demo=demo,
        templates=templates,
        checkpoint_stride=args.checkpoint_stride,
        frame_indices=args.frame_indices,
    )

    for row in rows:
        frame_index = int(row["frame_index"])
        # Atomic write: a failure mid-dump never leaves a truncated frame file.
        _write_json_atomic(output_dir / f"frame_{frame_index:04d}.json", row)
    return 0


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    sys.exit(main())