| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import time |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| from PIL import Image |
|
|
|
|
# Column order of the state vector passed to the policy as "observation/state":
# tool position, tool orientation, then gripper (column names suggest metres and
# degrees -- TODO confirm against the recorder).
_STATE_COLUMNS = (
    "tool_x_m",
    "tool_y_m",
    "tool_z_m",
    "tool_theta_x_deg",
    "tool_theta_y_deg",
    "tool_theta_z_deg",
    "gripper_pos",
)


def _load_rgb(path: Path) -> np.ndarray:
    """Read the image at *path*, convert it to RGB and return it as a NumPy array.

    The context manager closes the file deterministically instead of relying on
    Pillow closing it after the lazy load triggered by ``convert``.
    """
    with Image.open(path) as img:
        return np.array(img.convert("RGB"))


def _load_observation(session_root: Path, row: pd.Series) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Build the policy inputs for one ``sync_index.csv`` row.

    Args:
        session_root: Session directory; image paths in *row* are relative to it.
        row: One row of ``sync_index.csv``. It must provide ``azure_rgb_file``,
            ``rgb_file`` and every column in ``_STATE_COLUMNS``.

    Returns:
        ``(azure_rgb, wrist_rgb, state)``: the two RGB frames, and a float32
        vector holding ``_STATE_COLUMNS`` in order.
    """
    azure_rgb = _load_rgb(session_root / row["azure_rgb_file"])
    wrist_rgb = _load_rgb(session_root / row["rgb_file"])
    state = np.array([row[column] for column in _STATE_COLUMNS], dtype=np.float32)
    return azure_rgb, wrist_rgb, state


def _write_output(output_path: Path, output: dict) -> None:
    """Write *output* as indented JSON to *output_path*, creating parent dirs."""
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(output, indent=2))


def main() -> None:
    """Run a trained openpi policy on one recorded frame and save its action chunk.

    Steps:
      1. Read ``sync_index.csv`` from ``--session-root`` and select the row at
         position ``--sync-row-index`` (negative positions count from the end,
         as with ``DataFrame.iloc``).
      2. Load that row's Azure and wrist RGB images and its tool/gripper state.
      3. Build the training config ``--config-name``.
      4. Load the trained policy from ``--checkpoint-dir``.
      5. Call ``policy.infer`` once with those inputs and ``--prompt``.
      6. Write the inputs' provenance, the state and the returned ``actions`` as
         JSON to ``--output-json``.

    Raises:
        SystemExit: via argparse, on invalid CLI arguments or when
            ``--sync-row-index`` is outside the rows of ``sync_index.csv``.
    """
    # Monotonic clock for the elapsed-time progress messages.
    t0 = time.perf_counter()
    parser = argparse.ArgumentParser(
        description="Export one openpi action chunk for a recorded session frame."
    )
    parser.add_argument("--checkpoint-dir", required=True, help="Local checkpoint directory, e.g. .../4000")
    parser.add_argument("--session-root", default="/workspace/data/teddybear_raw/session_20260327_165944_bear")
    parser.add_argument("--sync-row-index", type=int, default=0)
    parser.add_argument("--config-name", default="pi05_kinova_teddybear")
    parser.add_argument("--output-json", default="/workspace/kinova_scene_sim/outputs/pi_actions.json")
    parser.add_argument("--prompt", default="pick up the teddy bear and place it in the red box")
    args = parser.parse_args()

    session_root = Path(args.session_root)
    print(f"[export] loading session from {session_root}", flush=True)
    sync_csv = session_root / "sync_index.csv"
    sync = pd.read_csv(sync_csv)
    n_rows = len(sync)
    # Keep iloc's negative-position semantics, but turn pandas' opaque
    # "out-of-bounds" IndexError into a clear CLI error (also covers empty CSVs).
    if not -n_rows <= args.sync_row_index < n_rows:
        parser.error(
            f"--sync-row-index {args.sync_row_index} is out of range for "
            f"{n_rows} row(s) in {sync_csv}"
        )
    row = sync.iloc[args.sync_row_index]

    azure_rgb, wrist_rgb, state = _load_observation(session_root, row)
    print(f"[export] loaded observation in {time.perf_counter() - t0:.2f}s", flush=True)

    # openpi is imported lazily so argument and data errors surface before the
    # (slow) heavy import.
    print("[export] importing openpi modules", flush=True)
    from openpi.policies import policy_config as _policy_config
    from openpi.training import config as _config

    print(f"[export] imported openpi in {time.perf_counter() - t0:.2f}s", flush=True)
    print(f"[export] building train config {args.config_name}", flush=True)
    train_config = _config.get_config(args.config_name)
    print(f"[export] creating policy from {args.checkpoint_dir}", flush=True)
    policy = _policy_config.create_trained_policy(train_config, Path(args.checkpoint_dir))
    print(f"[export] policy ready in {time.perf_counter() - t0:.2f}s", flush=True)
    print("[export] running inference", flush=True)
    results = policy.infer(
        {
            "observation/state": state,
            "observation/image": azure_rgb,
            "observation/wrist_image": wrist_rgb,
            "prompt": args.prompt,
        }
    )
    print(f"[export] inference done in {time.perf_counter() - t0:.2f}s", flush=True)

    actions = np.asarray(results["actions"], dtype=np.float32)
    output = {
        "checkpoint_dir": str(Path(args.checkpoint_dir)),
        "session_root": str(session_root),
        "sync_row_index": int(args.sync_row_index),
        "azure_rgb_file": str(row["azure_rgb_file"]),
        "wrist_rgb_file": str(row["rgb_file"]),
        "state": state.tolist(),
        "actions": actions.tolist(),
    }

    output_path = Path(args.output_json)
    _write_output(output_path, output)
    print(f"saved action chunk to {output_path}")
|
|
|
|
if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
|
|