| | """ |
| | Replay episodes from HDF5 datasets and save rollout videos. |
| | Loads recorded actions from record_dataset_<Task>.h5, steps the environment |
| | """ |
| |
|
import os
# Restrict CUDA to GPU 1; set before torch (imported below) initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
| |
|
import json
import re
from pathlib import Path
from typing import Any, Dict, Literal, Union
| |
|
| | import cv2 |
| | import h5py |
| | import imageio |
| | import numpy as np |
| | import torch |
| |
|
| | from robomme.env_record_wrapper import BenchmarkEnvBuilder |
| |
|
GUI_RENDER = False  # set True to call env.render() during replay
REPLAY_VIDEO_DIR = "replay_videos"  # output root; one subdir per action space
VIDEO_FPS = 30  # frame rate of the saved MP4s
VIDEO_BORDER_COLOR = (255, 0, 0)  # border color for video-demo frames
VIDEO_BORDER_THICKNESS = 10  # border thickness in pixels
| |
|
# Task identifiers that have a record_dataset_<Task>.h5 file to replay from.
TaskID = Literal[
    "VideoUnmaskSwap",
]
| |
|
| |
|
# Action encodings understood by _build_action_sequence / the env builder.
ActionSpaceType = Literal["joint_angle", "ee_pose", "waypoint", "multi_choice"]
| |
|
| | def _to_numpy(t) -> np.ndarray: |
| | return t.cpu().numpy() if isinstance(t, torch.Tensor) else np.asarray(t) |
| |
|
| |
|
| | def _frame_from_obs( |
| | front: np.ndarray | torch.Tensor, |
| | wrist: np.ndarray | torch.Tensor, |
| | is_video_demo: bool = False, |
| | ) -> np.ndarray: |
| | frame = np.hstack([_to_numpy(front), _to_numpy(wrist)]).astype(np.uint8) |
| | if is_video_demo: |
| | h, w = frame.shape[:2] |
| | cv2.rectangle(frame, (0, 0), (w, h), |
| | VIDEO_BORDER_COLOR, VIDEO_BORDER_THICKNESS) |
| | return frame |
| |
|
| |
|
def _extract_frames(obs: dict, is_video_demo_fn=None) -> list[np.ndarray]:
    """Build one side-by-side frame per camera sample in *obs*.

    `is_video_demo_fn(i)`, when given, decides per index whether the frame is
    part of the video demo (and gets a border); otherwise no frame is.
    """
    fronts = obs["front_rgb_list"]
    wrists = obs["wrist_rgb_list"]
    frames: list[np.ndarray] = []
    for idx in range(len(fronts)):
        demo_flag = is_video_demo_fn(idx) if is_video_demo_fn else False
        frames.append(_frame_from_obs(fronts[idx], wrists[idx], is_video_demo=demo_flag))
    return frames
| |
|
| |
|
def _is_video_demo(ts: h5py.Group) -> bool:
    """Return True when the timestep's info group flags it as a video-demo step."""
    info = ts.get("info")
    if info is None:
        return False
    if "is_video_demo" not in info:
        return False
    # The flag may be stored as a scalar or 1-element array; take element 0.
    flag = np.reshape(np.asarray(info["is_video_demo"][()]), -1)[0]
    return bool(flag)
| |
|
| |
|
def _is_subgoal_boundary(ts: h5py.Group) -> bool:
    """Return True when the timestep's info group flags it as a subgoal boundary."""
    info = ts.get("info")
    if info is None:
        return False
    if "is_subgoal_boundary" not in info:
        return False
    # The flag may be stored as a scalar or 1-element array; take element 0.
    flag = np.reshape(np.asarray(info["is_subgoal_boundary"][()]), -1)[0]
    return bool(flag)
| |
|
| |
|
| | def _decode_h5_str(raw) -> str: |
| | """Uniformly decode bytes / numpy bytes / str from HDF5 to str.""" |
| | if isinstance(raw, np.ndarray): |
| | raw = raw.flatten()[0] |
| | if isinstance(raw, (bytes, np.bytes_)): |
| | raw = raw.decode("utf-8") |
| | return raw |
| |
|
| |
|
def _build_action_sequence(
    episode_data: h5py.Group, action_space_type: str
) -> list[Union[np.ndarray, Dict[str, Any]]]:
    """
    Scan the entire episode and return the deduplicated action sequence:
    - joint_angle / ee_pose: actions of all non-video-demo steps (sequential, not deduplicated)
    - waypoint: remove adjacent duplicate waypoint_action (like EpisodeDatasetResolver)
    - multi_choice: choice_action (JSON dict) only for steps where is_subgoal_boundary=True

    Args:
        episode_data: HDF5 group containing `timestep_<n>` subgroups.
        action_space_type: One of "joint_angle", "ee_pose", "waypoint",
            "multi_choice".

    Raises:
        ValueError: If `action_space_type` is not one of the known types.
    """
    # Sort numerically: "timestep_10" must come after "timestep_9".
    timestep_keys = sorted(
        (k for k in episode_data.keys() if k.startswith("timestep_")),
        key=lambda k: int(k.split("_")[1]),
    )

    actions: list[Union[np.ndarray, Dict[str, Any]]] = []
    prev_waypoint: np.ndarray | None = None  # last accepted waypoint, for adjacent dedup

    for key in timestep_keys:
        ts = episode_data[key]
        # Video-demo steps are playback-only and carry no replayable action.
        if _is_video_demo(ts):
            continue

        action_grp = ts.get("action")
        if action_grp is None:
            continue

        if action_space_type == "joint_angle":
            if "joint_action" not in action_grp:
                continue
            actions.append(np.asarray(action_grp["joint_action"][()], dtype=np.float32))

        elif action_space_type == "ee_pose":
            if "eef_action" not in action_grp:
                continue
            actions.append(np.asarray(action_grp["eef_action"][()], dtype=np.float32))

        elif action_space_type == "waypoint":
            if "waypoint_action" not in action_grp:
                continue
            wa = np.asarray(action_grp["waypoint_action"][()], dtype=np.float32).flatten()
            # A valid waypoint is a finite 7-vector (presumably xyz + quaternion
            # — TODO confirm layout); skip malformed entries.
            if wa.shape != (7,) or not np.all(np.isfinite(wa)):
                continue

            # Only keep waypoints that differ from the previous accepted one.
            if prev_waypoint is None or not np.array_equal(wa, prev_waypoint):
                actions.append(wa)
                prev_waypoint = wa.copy()

        elif action_space_type == "multi_choice":
            # Choices are only recorded as actions at subgoal boundaries.
            if not _is_subgoal_boundary(ts):
                continue
            if "choice_action" not in action_grp:
                continue
            raw = _decode_h5_str(action_grp["choice_action"][()])
            try:
                payload = json.loads(raw)
            except (TypeError, ValueError, json.JSONDecodeError):
                continue
            # Require a well-formed {"choice": <non-empty str>, "point": ...}.
            if not isinstance(payload, dict):
                continue
            choice = payload.get("choice")
            if not isinstance(choice, str) or not choice.strip():
                continue
            if "point" not in payload:
                continue
            actions.append({"choice": choice, "point": payload.get("point")})

        else:
            raise ValueError(f"Unknown action space type: {action_space_type}")

    return actions
| |
|
| |
|
def _save_video(
    frames: list[np.ndarray],
    task_id: str,
    episode_idx: int,
    task_goal: str,
    outcome: str,
    action_space_type: str,
) -> Path:
    """Write *frames* as an MP4 under REPLAY_VIDEO_DIR/<action_space_type>/.

    Args:
        frames: RGB uint8 frames to encode.
        task_id: Task identifier embedded in the filename.
        episode_idx: Episode index embedded in the filename.
        task_goal: Free-text goal; sanitized before use in the filename.
        outcome: Episode outcome string (e.g. "success"/"unknown").
        action_space_type: Subdirectory name grouping videos by action space.

    Returns:
        Path of the written video file.
    """
    video_dir = Path(REPLAY_VIDEO_DIR) / action_space_type
    video_dir.mkdir(parents=True, exist_ok=True)
    # task_goal is free text and may contain spaces, slashes, etc.; keep only
    # filesystem-safe characters so the filename is valid and cannot contain
    # path separators.
    safe_goal = re.sub(r"[^\w\-.]+", "_", task_goal).strip("_")
    name = f"{outcome}_{task_id}_ep{episode_idx}_{safe_goal}.mp4"
    path = video_dir / name
    imageio.mimsave(str(path), frames, fps=VIDEO_FPS)
    return path
| |
|
| |
|
def _get_episode_indices(data: h5py.File) -> list[int]:
    """Return the sorted numeric indices of every 'episode_<n>' group in *data*."""
    indices = [
        int(key.split("_")[1])
        for key in data.keys()
        if key.startswith("episode_")
    ]
    indices.sort()
    return indices
| |
|
| |
|
def process_episode(
    env_data: h5py.File,
    episode_idx: int,
    task_id: str,
    action_space_type: ActionSpaceType,
) -> None:
    """Replay one episode from HDF5 data, record frames, and save a video.

    Args:
        env_data: Open HDF5 file containing `episode_<n>` groups.
        episode_idx: Index of the episode group to replay.
        task_id: Benchmark task identifier; used to build the env and filename.
        action_space_type: Which recorded action stream to replay.
    """
    episode_data = env_data[f"episode_{episode_idx}"]
    # task_goal is stored as bytes; index 0 then decode to str.
    task_goal = episode_data["setup"]["task_goal"][()][0].decode()
    action_sequence = _build_action_sequence(episode_data, action_space_type)

    # Rebuild the environment instance this episode was recorded in.
    env = BenchmarkEnvBuilder(
        env_id=task_id,
        dataset="train",
        action_space=action_space_type,
        gui_render=GUI_RENDER,
    ).make_env_for_episode(episode_idx)

    print(f"\nTask: {task_id}, Episode: {episode_idx}, ",
          f"Seed: {env.unwrapped.seed}, Difficulty: {env.unwrapped.difficulty}")
    print(f"Task goal: {task_goal}")
    print(f"Total actions after dedup: {len(action_sequence)}")

    obs, _ = env.reset()
    # All reset frames except the last are marked as video-demo (they get a
    # colored border); the last frame is the live starting observation.
    frames = _extract_frames(
        obs, is_video_demo_fn=lambda i, n=len(obs["front_rgb_list"]): i < n - 1
    )

    outcome = "unknown"
    for seq_idx, action in enumerate(action_sequence):
        try:
            obs, _, terminated, truncated, info = env.step(action)
            frames.extend(_extract_frames(obs))
        except Exception as e:
            # Best-effort replay: keep the frames gathered so far and stop.
            print(f"Error at seq_idx {seq_idx}: {e}")
            break

        if GUI_RENDER:
            env.render()
        if terminated or truncated:
            outcome = info.get("status", "unknown")
            print(f"Outcome: {outcome}")
            break

    env.close()
    path = _save_video(frames, task_id, episode_idx, task_goal, outcome, action_space_type)
    print(f"Saved video to {path}\n")
| |
|
| |
|
def replay(
    h5_data_dir: str = "/data/hongzefu/data_0226",
    action_space_type: ActionSpaceType = "ee_pose",
    replay_number: int = 10,
) -> None:
    """Replay episodes from HDF5 dataset files and save rollout videos."""
    task_ids = ["VideoUnmaskSwap"]
    for task_id in task_ids:
        file_path = Path(h5_data_dir) / f"record_dataset_{task_id}.h5"

        if not file_path.exists():
            print(f"Skipping {task_id}: file not found: {file_path}")
            continue

        with h5py.File(file_path, "r") as data:
            # Replay at most `replay_number` episodes; slicing clamps to length.
            for episode_idx in _get_episode_indices(data)[:replay_number]:
                process_episode(data, episode_idx, task_id, action_space_type)
| |
|
| |
|
if __name__ == "__main__":
    import tyro
    # tyro builds a CLI from replay()'s signature (data dir, action space, count).
    tyro.cli(replay)
| |
|