# TeddyBearKinova / bundle / sim / export_pi_actions.py
# Uploaded by lsnu ("Upload folder using huggingface_hub"), commit d93804e (verified).
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image
# Order of the state vector fed to the policy; each entry is a column of
# sync_index.csv (tool position, tool orientation, gripper).
_STATE_COLUMNS = (
    "tool_x_m",
    "tool_y_m",
    "tool_z_m",
    "tool_theta_x_deg",
    "tool_theta_y_deg",
    "tool_theta_z_deg",
    "gripper_pos",
)


def _load_rgb(path: Path) -> np.ndarray:
    """Read an image file and return it as an RGB numpy array."""
    # The context manager closes the file handle deterministically.
    with Image.open(path) as image:
        return np.array(image.convert("RGB"))


def _load_observation(
    session_root: Path, row_index: int
) -> tuple[pd.Series, np.ndarray, np.ndarray, np.ndarray]:
    """Load one synchronized observation from a recorded session.

    Args:
        session_root: Session directory containing ``sync_index.csv``; the
            ``azure_rgb_file`` and ``rgb_file`` columns are paths relative to it.
        row_index: Positional row of ``sync_index.csv`` (negative values count
            from the end, as with ``DataFrame.iloc``).

    Returns:
        ``(row, azure_rgb, wrist_rgb, state)`` where ``state`` is a float32
        vector ordered as ``_STATE_COLUMNS``.

    Raises:
        IndexError: If ``row_index`` does not address a row of the table.
    """
    csv_path = session_root / "sync_index.csv"
    sync = pd.read_csv(csv_path)
    n_rows = len(sync)
    if not -n_rows <= row_index < n_rows:
        raise IndexError(
            f"sync row index {row_index} is out of range for {n_rows} rows in {csv_path}"
        )
    row = sync.iloc[row_index]
    azure_rgb = _load_rgb(session_root / row["azure_rgb_file"])
    wrist_rgb = _load_rgb(session_root / row["rgb_file"])
    state = np.array([row[column] for column in _STATE_COLUMNS], dtype=np.float32)
    return row, azure_rgb, wrist_rgb, state


def main() -> None:
    """Run the policy on one recorded frame and save the action chunk as JSON.

    Reads one row of ``<session-root>/sync_index.csv``, loads the Azure and
    wrist RGB frames plus the 7-value tool/gripper state it references, calls
    ``policy.infer`` once with the given prompt, and writes the inputs together
    with the returned ``actions`` to ``--output-json`` (parent directories are
    created as needed).
    """
    # perf_counter is monotonic, so the elapsed-time logs are immune to clock changes.
    t0 = time.perf_counter()
    parser = argparse.ArgumentParser(
        description="Export one openpi action chunk for a recorded Kinova session frame."
    )
    parser.add_argument("--checkpoint-dir", required=True, help="Local checkpoint directory, e.g. .../4000")
    parser.add_argument("--session-root", default="/workspace/data/teddybear_raw/session_20260327_165944_bear")
    parser.add_argument("--sync-row-index", type=int, default=0)
    parser.add_argument("--config-name", default="pi05_kinova_teddybear")
    parser.add_argument("--output-json", default="/workspace/kinova_scene_sim/outputs/pi_actions.json")
    parser.add_argument("--prompt", default="pick up the teddy bear and place it in the red box")
    args = parser.parse_args()

    session_root = Path(args.session_root)
    print(f"[export] loading session from {session_root}", flush=True)
    row, azure_rgb, wrist_rgb, state = _load_observation(session_root, args.sync_row_index)
    print(f"[export] loaded observation in {time.perf_counter() - t0:.2f}s", flush=True)

    # openpi is imported only after the observation has loaded, and its import
    # is timed separately, so data/argument errors surface before that cost.
    print("[export] importing openpi modules", flush=True)
    from openpi.policies import policy_config as _policy_config
    from openpi.training import config as _config

    print(f"[export] imported openpi in {time.perf_counter() - t0:.2f}s", flush=True)

    print(f"[export] building train config {args.config_name}", flush=True)
    train_config = _config.get_config(args.config_name)
    checkpoint_dir = Path(args.checkpoint_dir)
    print(f"[export] creating policy from {args.checkpoint_dir}", flush=True)
    policy = _policy_config.create_trained_policy(train_config, checkpoint_dir)
    print(f"[export] policy ready in {time.perf_counter() - t0:.2f}s", flush=True)

    print("[export] running inference", flush=True)
    results = policy.infer(
        {
            "observation/state": state,
            "observation/image": azure_rgb,
            "observation/wrist_image": wrist_rgb,
            "prompt": args.prompt,
        }
    )
    print(f"[export] inference done in {time.perf_counter() - t0:.2f}s", flush=True)
    actions = np.asarray(results["actions"], dtype=np.float32)

    output = {
        "checkpoint_dir": str(checkpoint_dir),
        "session_root": str(session_root),
        "sync_row_index": int(args.sync_row_index),
        "azure_rgb_file": str(row["azure_rgb_file"]),
        "wrist_rgb_file": str(row["rgb_file"]),
        # The action chunk depends on these inputs too; record them so the
        # export can be reproduced from the JSON alone.
        "config_name": args.config_name,
        "prompt": args.prompt,
        "state": state.tolist(),
        "actions": actions.tolist(),
    }
    output_path = Path(args.output_json)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(output, indent=2), encoding="utf-8")
    print(f"saved action chunk to {output_path}", flush=True)


if __name__ == "__main__":
    main()