# RoboMME — scripts/run_example.py
# Author: HongzeFu — HF Space: code-only (no binary assets), commit 06c11b0
"""
Run a single benchmark episode and save the rollout as a video.
Use this script to sanity-check the environment and action space
"""
from pathlib import Path
from typing import Literal
import cv2
import imageio
import numpy as np
import torch
import tyro
from robomme.env_record_wrapper import BenchmarkEnvBuilder
from robomme.robomme_env.utils import generate_sample_actions
# --- Runtime / output configuration ----------------------------------------
GUI_RENDER: bool = False  # when True, env.render() is called after each step
VIDEO_FPS: int = 30  # frame rate of the saved rollout videos
VIDEO_OUTPUT_DIR: str = "sample_run_videos"  # root folder for saved videos
MAX_STEPS: int = 300  # step cap passed to the environment builder
# Episodes available per dataset split; used to validate --episode-idx.
EPISODE_LIMITS: dict[str, int] = {"train": 100, "test": 50, "val": 50}
# Border drawn on frames that belong to the video-demo segment.
# (255, 0, 0) renders red in the imageio (RGB) output — TODO confirm channel order.
VIDEO_BORDER_COLOR: tuple[int, int, int] = (255, 0, 0)
VIDEO_BORDER_THICKNESS: int = 10  # border stroke width in pixels

# Valid task identifiers; "All" expands to every task in the benchmark.
TaskID = Literal[
    "BinFill", "PickXtimes", "SwingXtimes", "StopCube",
    "VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
    "PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
    "MoveCube", "InsertPeg", "PatternLock", "RouteStick",
    "All",
]
# Action formats accepted by BenchmarkEnvBuilder.
ActionSpaceType = Literal["joint_angle", "ee_pose", "waypoint", "multi_choice"]
# Dataset splits recognized by the benchmark.
DatasetType = Literal["train", "test", "val"]
def _to_numpy(t) -> np.ndarray:
return t.cpu().numpy() if isinstance(t, torch.Tensor) else np.asarray(t)
def _frame_from_obs(
    front: np.ndarray | torch.Tensor,
    wrist: np.ndarray | torch.Tensor,
    is_video_demo: bool = False,
) -> np.ndarray:
    """Compose one video frame: wrist view stacked to the right of the front view.

    Args:
        front: Front-camera RGB image — assumed (H, W, C) with the same height
            as *wrist*; TODO confirm against the env's observation spec.
        wrist: Wrist-camera RGB image.
        is_video_demo: When True, draw a colored border marking the frame as
            part of the video-demonstration segment.

    Returns:
        The composed uint8 frame (drawn on in place when a border is added).
    """
    frame = np.hstack([_to_numpy(front), _to_numpy(wrist)]).astype(np.uint8)
    if is_video_demo:
        h, w = frame.shape[:2]
        # Anchor the bottom-right corner at (w-1, h-1): cv2.rectangle centers
        # the stroke on the corner coordinates, so using (w, h) clips roughly
        # half of the border thickness outside the image.
        cv2.rectangle(frame, (0, 0), (w - 1, h - 1),
                      VIDEO_BORDER_COLOR, VIDEO_BORDER_THICKNESS)
    return frame
def _validate_episode_index(episode_idx: int, dataset: DatasetType) -> None:
    """Raise ValueError unless *episode_idx* is -1 (all) or within the split's range."""
    if episode_idx == -1:  # sentinel meaning "run every episode"
        return
    limit = EPISODE_LIMITS[dataset]
    if episode_idx < 0 or episode_idx >= limit:
        raise ValueError(
            f"Invalid episode_idx {episode_idx} for '{dataset}'; allowed: [0, {limit})"
        )
def _save_video(
    frames: list[np.ndarray],
    task_id: str,
    episode_idx: int,
    action_space_type: str,
    task_goal: str,
) -> Path:
    """Write *frames* as an MP4 under VIDEO_OUTPUT_DIR/<action_space_type>/.

    Args:
        frames: Sequence of uint8 RGB frames.
        task_id: Task identifier used in the filename.
        episode_idx: Episode index used in the filename.
        action_space_type: Sub-directory name for the video.
        task_goal: Free-form natural-language goal embedded in the filename
            after sanitization.

    Returns:
        Path of the written video file.
    """
    video_dir = Path(VIDEO_OUTPUT_DIR) / action_space_type
    video_dir.mkdir(parents=True, exist_ok=True)
    # task_goal is free text and may contain path separators or other
    # filesystem-unsafe characters; map anything outside [alnum . _ -] to '_'
    # so the resulting path stays valid.
    safe_goal = "".join(c if (c.isalnum() or c in "._-") else "_" for c in task_goal)
    path = video_dir / f"{task_id}_ep{episode_idx}_{safe_goal}.mp4"
    imageio.mimsave(str(path), frames, fps=VIDEO_FPS)
    return path
def _extract_frames(obs: dict, is_video_demo_fn=None) -> list[np.ndarray]:
    """Compose one frame per (front, wrist) image pair found in *obs*.

    Args:
        obs: Observation dict holding parallel "front_rgb_list" and
            "wrist_rgb_list" sequences.
        is_video_demo_fn: Optional predicate mapping a frame index to whether
            that frame belongs to the video-demo segment (and gets a border).

    Returns:
        List of composed uint8 frames, in observation order.
    """
    fronts = obs["front_rgb_list"]
    wrists = obs["wrist_rgb_list"]
    frames: list[np.ndarray] = []
    for idx, front in enumerate(fronts):
        flag = is_video_demo_fn(idx) if is_video_demo_fn else False
        frames.append(_frame_from_obs(front, wrists[idx], is_video_demo=flag))
    return frames
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main(
    dataset: DatasetType = "test",
    task_id: TaskID = "PickXtimes",
    action_space_type: ActionSpaceType = "joint_angle",
    episode_idx: int = 0,
) -> None:
    """
    Run a single benchmark episode and save the rollout as a video.

    Args:
        dataset: Dataset split (train / test / val).
        task_id: Task identifier, or "All" to run every task.
        action_space_type: Type of action space to use.
        episode_idx: Episode index (-1 = all episodes).
    """
    task_ids = (
        BenchmarkEnvBuilder.get_task_list() if task_id == "All" else [task_id]
    )
    _validate_episode_index(episode_idx, dataset)
    for tid in task_ids:
        env_builder = BenchmarkEnvBuilder(
            env_id=tid,
            dataset=dataset,
            action_space=action_space_type,
            gui_render=GUI_RENDER,
            max_steps=MAX_STEPS,
        )
        # -1 expands to every episode the builder knows about for this task.
        episodes = (
            list(range(env_builder.get_episode_num()))
            if episode_idx == -1
            else [episode_idx]
        )
        for ep in episodes:
            print(f"\nRunning task: {tid}, episode: {ep}, action_space: {action_space_type}, dataset: {dataset}")
            env = env_builder.make_env_for_episode(
                ep,
                include_maniskill_obs=True,
                include_front_depth=True,
                include_wrist_depth=True,
                include_front_camera_extrinsic=True,
                include_wrist_camera_extrinsic=True,
                include_available_multi_choices=True,
                include_front_camera_intrinsic=True,
                include_wrist_camera_intrinsic=True,
            )
            # try/finally guarantees the env is released even if reset() or
            # step() raises mid-rollout (the original leaked it on error).
            try:
                obs, info = env.reset()
                task_goal = info["task_goal"]
                if isinstance(task_goal, list):
                    task_goal = task_goal[0]
                print(f"Task goal: {task_goal}")
                # The initial obs may carry a video demonstration; mark every
                # frame except the last one with the demo border.
                frames = _extract_frames(
                    obs, is_video_demo_fn=lambda i, n=len(obs["front_rgb_list"]): i < n - 1
                )
                action_gen = generate_sample_actions(action_space_type, env=env)
                for action in action_gen:
                    obs, _, terminated, truncated, info = env.step(action)
                    status = info.get("status", "unknown")
                    if status == "error":
                        print(f"Step error: {info.get('error_message', 'unknown error')}")
                        break
                    frames.extend(_extract_frames(obs))
                    if GUI_RENDER:
                        env.render()
                    if terminated or truncated:
                        print(f"Outcome: {status}")
                        break
            finally:
                env.close()
            path = _save_video(frames, tid, ep, action_space_type, task_goal)
            print(f"Saved video: {path}\n")
# Expose main()'s keyword arguments as CLI flags (tyro parses the docstring
# for --help text).
if __name__ == "__main__":
    tyro.cli(main)