""" Replay episodes from HDF5 datasets and save rollout videos. Loads recorded joint actions from record_dataset_.h5, steps the environment, and writes side-by-side front/wrist camera videos to disk. """ import os import cv2 import h5py import imageio import numpy as np from robomme.robomme_env import * from robomme.robomme_env.utils import * from robomme.env_record_wrapper import BenchmarkEnvBuilder from robomme.robomme_env.utils import EE_POSE_ACTION_SPACE # --- Config --- GUI_RENDER = False REPLAY_VIDEO_DIR = "replay_videos" VIDEO_FPS = 30 MAX_STEPS = 1000 def _frame_from_obs(obs: dict, is_video_frame: bool = False) -> np.ndarray: """Build a single side-by-side frame from front and wrist camera obs.""" front = obs["front_camera"][0].cpu().numpy() wrist = obs["wrist_camera"][0].cpu().numpy() frame = np.concatenate([front, wrist], axis=1).astype(np.uint8) if is_video_frame: frame = cv2.rectangle( frame, (0, 0), (frame.shape[1], frame.shape[0]), (255, 0, 0), 10 ) return frame def _first_execution_step(episode_data) -> int: """Return the first step index that is not a video-demo step.""" step_idx = 0 while episode_data[f"timestep_{step_idx}"]["info"]["is_video_demo"][()]: step_idx += 1 return step_idx def process_episode(env_data: h5py.File, episode_idx: int, env_id: str) -> None: """Replay one episode from HDF5 data, record frames, and save a video.""" episode_data = env_data[f"episode_{episode_idx}"] task_goal = episode_data["setup"]["task_goal"][()].decode() total_steps = sum(1 for k in episode_data.keys() if k.startswith("timestep_")) step_idx = _first_execution_step(episode_data) print(f"Execution start step index: {step_idx}") env_builder = BenchmarkEnvBuilder( env_id=env_id, dataset="test", action_space=EE_POSE_ACTION_SPACE, gui_render=GUI_RENDER, ) env = env_builder.make_env_for_episode( episode_idx, max_steps=MAX_STEPS, include_maniskill_obs=True, include_front_depth=True, include_wrist_depth=True, include_front_camera_extrinsic=True, include_wrist_camera_extrinsic=True, include_available_multi_choices=True, include_front_camera_intrinsic=True, include_wrist_camera_intrinsic=True, ) print(f"task_name: {env_id}, episode_idx: {episode_idx}, task_goal: {task_goal}") obs, info = env.reset() # Obs lists: length 1 = no video, length > 1 = video; last element is current. frames = [] n_obs = len(obs["front_camera"]) for i in range(n_obs): single_obs = {k: [v[i]] for k, v in obs.items()} frames.append(_frame_from_obs(single_obs, is_video_frame=(i < n_obs - 1))) print(f"Initial frames (video + current): {len(frames)}") outcome = "unknown" try: while step_idx < total_steps: action = np.asarray( episode_data[f"timestep_{step_idx}"]["action"]["eef_action"][()], dtype=np.float32, ) obs, _, terminated, _, info = env.step(action) frames.append(_frame_from_obs(obs)) if GUI_RENDER: env.render() # TODO: hongze makes this correct # there are two many nested lists here, need to flatten them if terminated: if info.get("success", False)[-1][-1]: outcome = "success" if info.get("fail", False)[-1][-1]: outcome = "fail" break step_idx += 1 finally: env.close() safe_goal = task_goal.replace(" ", "_").replace("/", "_") os.makedirs(REPLAY_VIDEO_DIR, exist_ok=True) video_name = f"{outcome}_{env_id}_ep{episode_idx}_{safe_goal}_step-{len(frames)}.mp4" video_path = os.path.join(REPLAY_VIDEO_DIR, video_name) imageio.mimsave(video_path, frames, fps=VIDEO_FPS) print(f"Saved video to {video_path}") def replay(h5_data_dir: str = "/data/hongzefu/dataset_generate") -> None: """Replay all episodes from all task HDF5 files in the given directory.""" env_id_list = BenchmarkEnvBuilder.get_task_list() env_id_list =[ "PickXtimes", # "StopCube", # "SwingXtimes", # "BinFill", # "VideoUnmaskSwap", # "VideoUnmask", # "ButtonUnmaskSwap", # "ButtonUnmask", # "VideoRepick", # "VideoPlaceButton", # "VideoPlaceOrder", # "PickHighlight", # "InsertPeg", # 'MoveCube', # "PatternLock", # "RouteStick" ] for env_id in env_id_list: file_name = f"record_dataset_{env_id}.h5" file_path = os.path.join(h5_data_dir, file_name) if not os.path.exists(file_path): print(f"Skipping {env_id}: file not found: {file_path}") continue with h5py.File(file_path, "r") as data: episode_indices = sorted( int(k.split("_")[1]) for k in data.keys() if k.startswith("episode_") ) print(f"Task: {env_id}, has {len(episode_indices)} episodes") for episode_idx in episode_indices[:1]: process_episode(data, episode_idx, env_id) if __name__ == "__main__": import tyro tyro.cli(replay)