File size: 4,809 Bytes
d93804e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | from __future__ import annotations
import argparse
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path
def checkpoint_ready(checkpoint_dir: Path) -> bool:
    """Report whether ``checkpoint_dir`` contains both marker files this script checks for.

    The script treats the checkpoint as finalized only when ``_CHECKPOINT_METADATA``
    (top level) and ``params/_METADATA`` both exist as regular files.
    """
    required_markers = (
        checkpoint_dir / "_CHECKPOINT_METADATA",
        checkpoint_dir / "params" / "_METADATA",
    )
    return all(marker.is_file() for marker in required_markers)
def wait_for_checkpoint(checkpoint_dir: Path, poll_s: float) -> None:
    """Block until ``checkpoint_ready(checkpoint_dir)`` is true.

    The readiness check runs once, then again every ``poll_s`` seconds. A
    progress line is printed before each sleep. There is no timeout.
    """
    while True:
        if checkpoint_ready(checkpoint_dir):
            return
        print(f"waiting for checkpoint: {checkpoint_dir}", flush=True)
        time.sleep(poll_s)
def pid_running(pid: int) -> bool:
    """Return True if a process with ``pid`` currently exists.

    On POSIX systems this uses ``os.kill(pid, 0)``. Signal 0 performs the
    kernel's existence/permission check without delivering anything, so the
    check works on any POSIX system, not only where ``/proc`` is mounted.
    (Probing ``/proc/<pid>`` silently reports "not running" on systems without
    ``/proc``, which made ``--wait-for-pid-exit`` a no-op there.)

    Non-positive PIDs return False, as before: ``os.kill`` would interpret 0
    and negative values as process groups / all processes, not one PID.
    """
    if pid <= 0:
        return False
    if os.name != "posix":
        # On Windows os.kill(pid, 0) does not probe; it sends a console event or
        # terminates the target. Keep the original /proc probe there instead.
        return Path(f"/proc/{pid}").exists()
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # The PID exists but belongs to a process we are not allowed to signal.
        return True
    return True
def wait_for_pid_exit(pid: int, poll_s: float) -> None:
    """Block until ``pid_running(pid)`` reports False.

    The check runs once, then again every ``poll_s`` seconds. A progress line
    is printed before each sleep. There is no timeout.
    """
    while True:
        if not pid_running(pid):
            break
        print(f"waiting for pid to exit: {pid}", flush=True)
        time.sleep(poll_s)
def stage_checkpoint(checkpoint_dir: Path, stage_root: Path) -> Path:
    """Hardlink-copy ``checkpoint_dir`` into ``stage_root`` and return the staged path.

    The staged copy is ``stage_root / <checkpoint directory name>``. Any existing
    entry at that location is removed first, so re-running replaces the previous
    stage. Files are hardlinked with ``cp -al`` rather than duplicated, so
    ``stage_root`` must be on the same filesystem as ``checkpoint_dir``.

    Raises:
        ValueError: if ``checkpoint_dir`` has no usable final path component
            (e.g. ``/``), or if the staged location is the checkpoint itself or
            one of its ancestors. Removing it would destroy the source.
        subprocess.CalledProcessError: if ``cp`` fails.
    """
    # Normalize "." / ".." lexically so ``.name`` is the real directory name.
    # With a bare "..", ``stage_root / ".."`` would make the rmtree below delete
    # stage_root's parent.
    source = Path(os.path.abspath(checkpoint_dir))
    if not source.name:
        raise ValueError(f"cannot stage a checkpoint without a directory name: {checkpoint_dir}")
    staged_dir = stage_root / source.name

    # Refuse to stage onto (or above) the source: the rmtree would delete it.
    resolved_source = source.resolve()
    resolved_staged = staged_dir.resolve()
    if resolved_staged == resolved_source or resolved_staged in resolved_source.parents:
        raise ValueError(
            f"stage location {staged_dir} overlaps checkpoint {checkpoint_dir}; choose a different stage root"
        )

    stage_root.mkdir(parents=True, exist_ok=True)
    if staged_dir.exists():
        shutil.rmtree(staged_dir)
    # staged_dir no longer exists, so cp creates exactly that directory.
    subprocess.run(["cp", "-al", str(source), str(staged_dir)], check=True)
    return staged_dir
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the export-and-render pipeline."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint-dir", default=None, help="Checkpoint step directory, e.g. .../4000")
    parser.add_argument(
        "--checkpoint-root",
        default="/workspace/openpi/checkpoints/pi05_kinova_teddybear/teddybear_kinova_ft_v4_b24_w12_fork",
        help="Checkpoint experiment root used with --step.",
    )
    parser.add_argument("--step", type=int, default=None, help="Checkpoint step under --checkpoint-root.")
    parser.add_argument("--wait", action="store_true", help="Wait until the checkpoint has finalized before exporting.")
    parser.add_argument("--poll-s", type=float, default=15.0)
    parser.add_argument("--wait-for-pid-exit", type=int, default=None, help="Wait for this PID to exit before export.")
    parser.add_argument(
        "--stage-root",
        default=None,
        help="Optional directory where the finalized checkpoint is hardlink-staged before export.",
    )
    parser.add_argument("--session-root", default="/workspace/data/teddybear_raw/session_20260327_165944_bear")
    parser.add_argument("--sync-row-index", type=int, default=0)
    parser.add_argument("--prompt", default="pick up the teddy bear and place it in the red box")
    parser.add_argument(
        "--python",
        default="/workspace/openpi/.venv/bin/python",
        help="Python interpreter used to run the OpenPI export step.",
    )
    parser.add_argument(
        "--jax-platforms",
        default="cpu",
        help="Value for JAX_PLATFORMS during action export. Use 'cuda' once training has stopped.",
    )
    parser.add_argument(
        "--output-json",
        default="/workspace/kinova_scene_sim/outputs/pi_actions_preview.json",
    )
    parser.add_argument(
        "--output-gif",
        default="/workspace/kinova_scene_sim/outputs/pi_policy_preview.gif",
    )
    parser.add_argument("--fps", type=int, default=4)
    # These were previously hard-coded; the defaults keep the old behavior.
    parser.add_argument(
        "--export-script",
        default="/workspace/kinova_scene_sim/export_pi_actions.py",
        help="Action-export script, run with the --python interpreter.",
    )
    parser.add_argument(
        "--render-script",
        default="/workspace/kinova_scene_sim/render_pose_sequence.py",
        help="Pose-rendering script, run with the current interpreter.",
    )
    return parser


def _resolve_checkpoint_dir(parser: argparse.ArgumentParser, args: argparse.Namespace) -> Path:
    """Return the checkpoint step directory selected on the command line.

    ``--checkpoint-dir`` takes precedence. Otherwise ``--checkpoint-root/--step``
    is used. If neither is given, the parser reports a usage error and exits.
    """
    if args.checkpoint_dir is not None:
        return Path(args.checkpoint_dir)
    if args.step is None:
        # A missing required choice is a usage error: print usage and exit(2)
        # instead of surfacing a traceback.
        parser.error("provide either --checkpoint-dir or --step")
    return Path(args.checkpoint_root) / str(args.step)


def _run_export(args: argparse.Namespace, checkpoint_dir: Path) -> None:
    """Run the action-export script with the ``--python`` interpreter; raise on failure."""
    export_cmd = [
        args.python,
        args.export_script,
        "--checkpoint-dir",
        str(checkpoint_dir),
        "--session-root",
        args.session_root,
        "--sync-row-index",
        str(args.sync_row_index),
        "--prompt",
        args.prompt,
        "--output-json",
        args.output_json,
    ]
    # Only JAX_PLATFORMS is overridden; the rest of the environment is inherited.
    # The default "cpu" presumably avoids contending for the GPU while training
    # still runs (see the --jax-platforms help).
    export_env = os.environ.copy()
    export_env["JAX_PLATFORMS"] = args.jax_platforms
    subprocess.run(export_cmd, check=True, env=export_env)


def _run_render(args: argparse.Namespace) -> None:
    """Render the exported poses to a GIF with the current interpreter; raise on failure."""
    render_cmd = [
        sys.executable,
        args.render_script,
        "--session-root",
        args.session_root,
        "--poses-json",
        args.output_json,
        "--output",
        args.output_gif,
        "--fps",
        str(args.fps),
    ]
    subprocess.run(render_cmd, check=True)


def main() -> None:
    """Export policy actions from a checkpoint and render them to a preview GIF.

    Steps:
    1. Resolve the checkpoint directory.
    2. Wait for it to be ready (or fail if ``--wait`` is not given and it isn't).
    3. Optionally hardlink-stage it.
    4. Optionally wait for a PID to exit.
    5. Run the export script, then the render script.

    Raises:
        FileNotFoundError: if the checkpoint is not ready and ``--wait`` was not given.
        subprocess.CalledProcessError: if staging, export, or rendering fails.
    """
    parser = _build_parser()
    args = parser.parse_args()
    checkpoint_dir = _resolve_checkpoint_dir(parser, args)

    if args.wait:
        wait_for_checkpoint(checkpoint_dir, args.poll_s)
    elif not checkpoint_ready(checkpoint_dir):
        raise FileNotFoundError(f"checkpoint not ready: {checkpoint_dir}")

    # Staging happens before waiting on the PID. Hardlinks keep the files
    # reachable even if the original names are later removed. Presumably this
    # guards against the (still running) writer deleting the directory — confirm.
    if args.stage_root is not None:
        checkpoint_dir = stage_checkpoint(checkpoint_dir, Path(args.stage_root))
        print(f"staged checkpoint to {checkpoint_dir}", flush=True)

    if args.wait_for_pid_exit is not None:
        wait_for_pid_exit(args.wait_for_pid_exit, args.poll_s)

    _run_export(args, checkpoint_dir)
    _run_render(args)
    print(f"saved preview gif to {args.output_gif}", flush=True)
if __name__ == "__main__":
    # main() returns None, so SystemExit(None) exits with status 0; exceptions propagate unchanged.
    raise SystemExit(main())
|