"""
Replay episodes from HDF5 datasets and save rollout videos.
Loads recorded end-effector (EE pose) actions from record_dataset_<Task>.h5,
steps the environment, and writes side-by-side front/wrist camera videos to disk.
"""
import os
import cv2
import h5py
import imageio
import numpy as np
from robomme.robomme_env import *
from robomme.robomme_env.utils import *
from robomme.env_record_wrapper import BenchmarkEnvBuilder
from robomme.robomme_env.utils import EE_POSE_ACTION_SPACE
# --- Config ---
# Forwarded to the env builder as gui_render; also enables env.render() per step.
GUI_RENDER: bool = False
# Directory (created on demand) where replay videos are written.
REPLAY_VIDEO_DIR: str = "replay_videos"
# Frame rate passed to imageio when writing videos.
VIDEO_FPS: int = 30
# Step limit handed to make_env_for_episode(max_steps=...).
MAX_STEPS: int = 1000
def _frame_from_obs(obs: dict, is_video_frame: bool = False) -> np.ndarray:
    """Compose one video frame: the front view on the left, the wrist view on the right.

    Args:
        obs: Observation mapping; element 0 of ``obs["front_camera"]`` and
            ``obs["wrist_camera"]`` is taken, moved to CPU and converted to numpy.
            Both images must share a height (they are concatenated along axis 1).
        is_video_frame: When true, a 10-px (255, 0, 0) border is drawn around
            the frame so video-demo frames are visually distinguishable.

    Returns:
        The combined image as a uint8 array.
    """
    views = [obs[key][0].cpu().numpy() for key in ("front_camera", "wrist_camera")]
    combined = np.concatenate(views, axis=1).astype(np.uint8)
    if not is_video_frame:
        return combined
    height, width = combined.shape[0], combined.shape[1]
    # NOTE(review): (255, 0, 0) is red if frames are RGB (imageio's convention) — confirm camera order.
    return cv2.rectangle(combined, (0, 0), (width, height), (255, 0, 0), 10)
def _first_execution_step(episode_data) -> int:
    """Return the index of the first timestep that is not a video-demo step.

    Timesteps are looked up as ``episode_data["timestep_<i>"]`` for
    i = 0, 1, ..., and each one's ``["info"]["is_video_demo"][()]`` flag is read.
    The search stops at the first missing timestep key. If every stored
    timestep is a video-demo step, or there are none, the returned index
    equals the number of consecutive timesteps found. Returning instead of
    raising KeyError means a caller's ``while step_idx < total_steps`` loop
    simply does nothing.

    Args:
        episode_data: Mapping-like episode group (e.g. an h5py Group) keyed by
            ``"timestep_<i>"``.

    Returns:
        The first execution (non-demo) step index.
    """
    step_idx = 0
    while True:
        key = f"timestep_{step_idx}"
        # Bound the scan by key existence; the original indexed past the end.
        if key not in episode_data:
            return step_idx
        if not episode_data[key]["info"]["is_video_demo"][()]:
            return step_idx
        step_idx += 1
def _last_flag(value) -> bool:
    """Reduce a possibly nested flag (lists / tensors / arrays) to its last boolean.

    Step flags may arrive wrapped in several layers (the previous code indexed
    ``[-1][-1]``). This repeatedly takes the last element of lists/tuples,
    converts tensors to numpy, and returns the last element of the flattened
    result. ``None`` and empty containers count as ``False``.
    """
    while isinstance(value, (list, tuple)):
        if not value:
            return False
        value = value[-1]
    if value is None:
        return False
    if hasattr(value, "cpu"):  # torch tensor -> host numpy array
        value = value.cpu().numpy()
    flat = np.asarray(value).reshape(-1)
    return bool(flat[-1]) if flat.size else False


def _initial_frames(obs: dict) -> list:
    """Build frames from a reset observation's camera sequences.

    Each camera entry is a sequence: length 1 means no demo video; for
    length > 1 every element but the last is a demo-video frame (drawn with a
    border) and the last element is the current view. Only the two camera
    keys are indexed, because other obs entries need not share that length.
    """
    front, wrist = obs["front_camera"], obs["wrist_camera"]
    n_obs = len(front)
    return [
        _frame_from_obs(
            {"front_camera": [front[i]], "wrist_camera": [wrist[i]]},
            is_video_frame=(i < n_obs - 1),
        )
        for i in range(n_obs)
    ]


def _save_video(frames: list, outcome: str, env_id: str, episode_idx: int, task_goal: str) -> None:
    """Write ``frames`` to an mp4 in REPLAY_VIDEO_DIR, named by outcome/task/episode/goal."""
    if not frames:
        print(f"No frames recorded for {env_id} episode {episode_idx}; skipping video.")
        return
    safe_goal = task_goal.replace(" ", "_").replace("/", "_")
    os.makedirs(REPLAY_VIDEO_DIR, exist_ok=True)
    video_name = f"{outcome}_{env_id}_ep{episode_idx}_{safe_goal}_step-{len(frames)}.mp4"
    video_path = os.path.join(REPLAY_VIDEO_DIR, video_name)
    imageio.mimsave(video_path, frames, fps=VIDEO_FPS)
    print(f"Saved video to {video_path}")


def process_episode(env_data: h5py.File, episode_idx: int, env_id: str) -> None:
    """Replay one recorded episode and save the rendered rollout as a video.

    Leading video-demo timesteps are skipped. The remaining recorded
    ``eef_action`` entries are stepped through a freshly built environment
    until the recording runs out or the env reports termination. The video
    filename is prefixed with the outcome: "success", "fail" (takes precedence
    if both flags are set) or "unknown" (no termination observed).

    Args:
        env_data: Open HDF5 file containing ``episode_<idx>`` groups.
        episode_idx: Index of the episode to replay.
        env_id: Task id used to build the environment and name the video.
    """
    episode_data = env_data[f"episode_{episode_idx}"]
    task_goal = episode_data["setup"]["task_goal"][()].decode()
    total_steps = sum(1 for k in episode_data.keys() if k.startswith("timestep_"))
    step_idx = _first_execution_step(episode_data)
    print(f"Execution start step index: {step_idx}")

    env_builder = BenchmarkEnvBuilder(
        env_id=env_id,
        dataset="test",
        action_space=EE_POSE_ACTION_SPACE,
        gui_render=GUI_RENDER,
    )
    env = env_builder.make_env_for_episode(
        episode_idx,
        max_steps=MAX_STEPS,
        include_maniskill_obs=True,
        include_front_depth=True,
        include_wrist_depth=True,
        include_front_camera_extrinsic=True,
        include_wrist_camera_extrinsic=True,
        include_available_multi_choices=True,
        include_front_camera_intrinsic=True,
        include_wrist_camera_intrinsic=True,
    )
    print(f"task_name: {env_id}, episode_idx: {episode_idx}, task_goal: {task_goal}")

    frames = []
    outcome = "unknown"
    try:
        # reset() is inside the try so the env is closed even if reset or the
        # initial frame extraction fails.
        obs, _ = env.reset()
        frames.extend(_initial_frames(obs))
        print(f"Initial frames (video + current): {len(frames)}")
        while step_idx < total_steps:
            action = np.asarray(
                episode_data[f"timestep_{step_idx}"]["action"]["eef_action"][()],
                dtype=np.float32,
            )
            obs, _, terminated, _, info = env.step(action)
            frames.append(_frame_from_obs(obs))
            if GUI_RENDER:
                env.render()
            if _last_flag(terminated):
                # "fail" wins over "success" if both are set (original precedence).
                if _last_flag(info.get("fail")):
                    outcome = "fail"
                elif _last_flag(info.get("success")):
                    outcome = "success"
                break
            step_idx += 1
    finally:
        env.close()

    _save_video(frames, outcome, env_id, episode_idx, task_goal)
def replay(
    h5_data_dir: str = "/data/hongzefu/dataset_generate",
    tasks: str = "PickXtimes",
    max_episodes: int = 1,
) -> None:
    """Replay recorded episodes from per-task HDF5 files and save videos.

    For each selected task, opens ``record_dataset_<task>.h5`` in
    ``h5_data_dir`` and replays its episodes in ascending index order with
    ``process_episode``. Tasks whose file is missing are skipped with a message.
    The defaults reproduce the previous behavior: the first episode of
    PickXtimes only.

    Args:
        h5_data_dir: Directory holding the ``record_dataset_<task>.h5`` files.
            NOTE(review): the default is a machine-specific path.
        tasks: Comma-separated task ids to replay, or ``"all"`` for every task
            returned by ``BenchmarkEnvBuilder.get_task_list()``.
        max_episodes: Maximum number of episodes to replay per task; a value
            ``<= 0`` replays every episode in the file.
    """
    if tasks.strip().lower() == "all":
        env_id_list = list(BenchmarkEnvBuilder.get_task_list())
    else:
        env_id_list = [t.strip() for t in tasks.split(",") if t.strip()]

    for env_id in env_id_list:
        file_name = f"record_dataset_{env_id}.h5"
        file_path = os.path.join(h5_data_dir, file_name)
        if not os.path.exists(file_path):
            print(f"Skipping {env_id}: file not found: {file_path}")
            continue
        with h5py.File(file_path, "r") as data:
            episode_indices = sorted(
                int(k.split("_")[1])
                for k in data.keys()
                if k.startswith("episode_")
            )
            print(f"Task: {env_id}, has {len(episode_indices)} episodes")
            selected = episode_indices[:max_episodes] if max_episodes > 0 else episode_indices
            for episode_idx in selected:
                process_episode(data, episode_idx, env_id)
if __name__ == "__main__":
    # Parse command-line arguments into replay()'s parameters via tyro and run it.
    from tyro import cli

    cli(replay)