"""Export OpenPI policy actions from a checkpoint and render a preview GIF.

Pipeline:
  1. Resolve the checkpoint step directory (``--checkpoint-dir``, or
     ``--checkpoint-root`` joined with ``--step``).
  2. Optionally wait until the checkpoint is finalized (``--wait``).
  3. Optionally hardlink-stage the checkpoint under ``--stage-root``.
  4. Optionally wait for a given PID to exit (``--wait-for-pid-exit``).
  5. Run the action-export script with ``JAX_PLATFORMS`` set, then the
     render script to produce the GIF.
"""

from __future__ import annotations

import argparse
import os
import shutil
import subprocess
import sys
import time
from pathlib import Path

DEFAULT_EXPORT_SCRIPT = "/workspace/kinova_scene_sim/export_pi_actions.py"
DEFAULT_RENDER_SCRIPT = "/workspace/kinova_scene_sim/render_pose_sequence.py"


def checkpoint_ready(checkpoint_dir: Path) -> bool:
    """Return True once the checkpoint step directory looks finalized.

    Both ``<checkpoint_dir>/_CHECKPOINT_METADATA`` and
    ``<checkpoint_dir>/params/_METADATA`` must exist as regular files.
    """
    return (checkpoint_dir / "_CHECKPOINT_METADATA").is_file() and (
        checkpoint_dir / "params" / "_METADATA"
    ).is_file()


def wait_for_checkpoint(checkpoint_dir: Path, poll_s: float) -> None:
    """Poll every ``poll_s`` seconds until ``checkpoint_ready`` is True.

    There is no timeout; this blocks indefinitely.
    """
    while not checkpoint_ready(checkpoint_dir):
        print(f"waiting for checkpoint: {checkpoint_dir}", flush=True)
        time.sleep(poll_s)


def pid_running(pid: int) -> bool:
    """Return True while ``/proc/<pid>`` exists.

    Linux-specific: on systems without procfs this always returns False.
    NOTE(review): an exited-but-unreaped (zombie) process keeps its /proc
    entry, so it is still reported as running until it is reaped.
    """
    return Path(f"/proc/{pid}").exists()


def wait_for_pid_exit(pid: int, poll_s: float) -> None:
    """Poll every ``poll_s`` seconds until ``pid_running(pid)`` is False."""
    while pid_running(pid):
        print(f"waiting for pid to exit: {pid}", flush=True)
        time.sleep(poll_s)


def stage_checkpoint(checkpoint_dir: Path, stage_root: Path) -> Path:
    """Hardlink-copy ``checkpoint_dir`` to ``stage_root/<name>`` and return it.

    Every regular file is hard-linked rather than copied, so no file data is
    duplicated. Hard links require the source and destination to be on the
    same filesystem. Any existing staged copy with the same name is removed
    first. The source path is resolved, so a symlinked checkpoint directory
    is staged as a real directory tree.

    Raises:
        ValueError: if the checkpoint path has no usable name (filesystem
            root), or if the staged path would coincide with, contain, or lie
            inside the checkpoint directory. Removing the old staged copy
            would otherwise delete the source checkpoint.
        shutil.Error: if any file could not be hard-linked (e.g. across
            filesystems).
    """
    source = checkpoint_dir.resolve()
    if not source.name:
        raise ValueError(f"cannot stage filesystem root: {checkpoint_dir}")

    staged_dir = stage_root / source.name
    target = staged_dir.resolve()
    # Guard before touching the filesystem: rmtree(staged_dir) below must
    # never reach the source checkpoint.
    if target == source or source in target.parents or target in source.parents:
        raise ValueError(
            f"stage directory {staged_dir} overlaps checkpoint {checkpoint_dir}"
        )

    stage_root.mkdir(parents=True, exist_ok=True)
    if staged_dir.exists():
        shutil.rmtree(staged_dir)
    # symlinks=True recreates symlinks instead of following them (like `cp -a`);
    # copy_function=os.link hard-links each file instead of copying its bytes.
    shutil.copytree(source, staged_dir, symlinks=True, copy_function=os.link)
    return staged_dir


def _build_parser() -> argparse.ArgumentParser:
    """Return the CLI parser; defaults target the teddy-bear Kinova setup."""
    parser = argparse.ArgumentParser(
        description="Export OpenPI actions from a checkpoint and render a preview GIF."
    )
    parser.add_argument(
        "--checkpoint-dir",
        default=None,
        help="Checkpoint step directory, e.g. .../4000",
    )
    parser.add_argument(
        "--checkpoint-root",
        default="/workspace/openpi/checkpoints/pi05_kinova_teddybear/teddybear_kinova_ft_v4_b24_w12_fork",
        help="Checkpoint experiment root used with --step.",
    )
    parser.add_argument(
        "--step",
        type=int,
        default=None,
        help="Checkpoint step under --checkpoint-root.",
    )
    parser.add_argument(
        "--wait",
        action="store_true",
        help="Wait until the checkpoint has finalized before exporting.",
    )
    parser.add_argument("--poll-s", type=float, default=15.0)
    parser.add_argument(
        "--wait-for-pid-exit",
        type=int,
        default=None,
        help="Wait for this PID to exit before export.",
    )
    parser.add_argument(
        "--stage-root",
        default=None,
        help="Optional directory where the finalized checkpoint is hardlink-staged before export.",
    )
    parser.add_argument(
        "--session-root",
        default="/workspace/data/teddybear_raw/session_20260327_165944_bear",
    )
    parser.add_argument("--sync-row-index", type=int, default=0)
    parser.add_argument(
        "--prompt",
        default="pick up the teddy bear and place it in the red box",
    )
    parser.add_argument(
        "--python",
        default="/workspace/openpi/.venv/bin/python",
        help="Python interpreter used to run the OpenPI export step.",
    )
    parser.add_argument(
        "--jax-platforms",
        default="cpu",
        help="Value for JAX_PLATFORMS during action export. Use 'cuda' once training has stopped.",
    )
    parser.add_argument(
        "--output-json",
        default="/workspace/kinova_scene_sim/outputs/pi_actions_preview.json",
    )
    parser.add_argument(
        "--output-gif",
        default="/workspace/kinova_scene_sim/outputs/pi_policy_preview.gif",
    )
    parser.add_argument("--fps", type=int, default=4)
    parser.add_argument(
        "--export-script",
        default=DEFAULT_EXPORT_SCRIPT,
        help="Script run with --python to export policy actions to --output-json.",
    )
    parser.add_argument(
        "--render-script",
        default=DEFAULT_RENDER_SCRIPT,
        help="Script run with the current interpreter to render --output-gif.",
    )
    return parser


def _resolve_checkpoint_dir(args: argparse.Namespace) -> Path:
    """Return the checkpoint step directory selected on the command line.

    ``--checkpoint-dir`` takes precedence. Otherwise ``--step`` is appended
    to ``--checkpoint-root``.

    Raises:
        ValueError: if neither ``--checkpoint-dir`` nor ``--step`` is given.
    """
    if args.checkpoint_dir is not None:
        return Path(args.checkpoint_dir)
    if args.step is None:
        raise ValueError("provide either --checkpoint-dir or --step")
    return Path(args.checkpoint_root) / str(args.step)


def main(argv: list[str] | None = None) -> None:
    """Run the wait, stage, export, and render pipeline.

    Args:
        argv: Arguments to parse; ``None`` (the default) uses ``sys.argv``.

    Raises:
        ValueError: if neither ``--checkpoint-dir`` nor ``--step`` is given,
            or if ``--stage-root`` overlaps the checkpoint directory.
        FileNotFoundError: if the checkpoint is not finalized and ``--wait``
            was not requested.
        subprocess.CalledProcessError: if the export or render step fails.
    """
    args = _build_parser().parse_args(argv)
    checkpoint_dir = _resolve_checkpoint_dir(args)

    if args.wait:
        wait_for_checkpoint(checkpoint_dir, args.poll_s)
    elif not checkpoint_ready(checkpoint_dir):
        raise FileNotFoundError(f"checkpoint not ready: {checkpoint_dir}")

    if args.stage_root is not None:
        checkpoint_dir = stage_checkpoint(checkpoint_dir, Path(args.stage_root))
        print(f"staged checkpoint to {checkpoint_dir}", flush=True)

    # Staging happens before this wait, so the snapshot can be taken while
    # the watched process is still running.
    if args.wait_for_pid_exit is not None:
        wait_for_pid_exit(args.wait_for_pid_exit, args.poll_s)

    export_cmd = [
        args.python,
        args.export_script,
        "--checkpoint-dir",
        str(checkpoint_dir),
        "--session-root",
        args.session_root,
        "--sync-row-index",
        str(args.sync_row_index),
        "--prompt",
        args.prompt,
        "--output-json",
        args.output_json,
    ]
    # Only the export subprocess sees the overridden JAX_PLATFORMS.
    export_env = {**os.environ, "JAX_PLATFORMS": args.jax_platforms}
    subprocess.run(export_cmd, check=True, env=export_env)

    render_cmd = [
        sys.executable,
        args.render_script,
        "--session-root",
        args.session_root,
        "--poses-json",
        args.output_json,
        "--output",
        args.output_gif,
        "--fps",
        str(args.fps),
    ]
    subprocess.run(render_cmd, check=True)
    print(f"saved preview gif to {args.output_gif}")


if __name__ == "__main__":
    main()