# RoboMME — scripts/dev/eval_dataset_replay.py
# (Hugging Face Space header: code-only mirror, no binary assets; commit 06c11b0)
# -*- coding: utf-8 -*-
# Script function: Unified dataset replay entry point, supporting 4 action spaces: joint_angle / ee_pose / waypoint / multi_choice.
# Consistent with subgoal_evaluate_func.py main loop; difference is actions come from EpisodeDatasetResolver.
import os
from typing import Any, Optional
# NOTE(review): duplicate of the `import os` above — harmless but redundant.
import os
# NOTE(review): GPU is hard-pinned to device 1; must be set before torch/cv2
# initialize CUDA. Consider making this configurable.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import cv2
import numpy as np
import torch
# NOTE(review): wildcard imports pull the whole env namespace into this module;
# explicit imports would make the dependencies of this script visible.
from robomme.robomme_env import *
from robomme.robomme_env.utils import *
from robomme.env_record_wrapper import (
BenchmarkEnvBuilder,
EpisodeDatasetResolver,
)
from robomme.env_record_wrapper.OraclePlannerDemonstrationWrapper import (
OraclePlannerDemonstrationWrapper,
)
from robomme.robomme_env.utils.choice_action_mapping import (
_unique_candidates,
extract_actor_position_xyz,
project_world_to_pixel,
select_target_with_pixel,
)
from robomme.robomme_env.utils.save_reset_video import save_robomme_video
# Only enable one ACTION_SPACE; others are commented out for manual switching
# Valid values (per the script header): "joint_angle" / "ee_pose" / "waypoint" / "multi_choice".
ACTION_SPACE = "joint_angle"
# When True, env.render() is called every step (see the step loop in main()).
GUI_RENDER = False
# Root directory of the recorded dataset consumed by EpisodeDatasetResolver.
DATASET_ROOT = "/data/hongzefu/data_0226"
# Task ids to replay; most are commented out for manual selection.
DEFAULT_ENV_IDS = [
#"PickXtimes",
#"StopCube",
#"SwingXtimes",
"BinFill",
# "VideoUnmaskSwap",
# "VideoUnmask",
# "ButtonUnmaskSwap",
# "ButtonUnmask",
# "VideoRepick",
# "VideoPlaceButton",
# "VideoPlaceOrder",
#"PickHighlight",
#"InsertPeg",
#"MoveCube",
#"PatternLock",
#"RouteStick",
]
# Directory where save_robomme_video() writes the replay videos.
OUT_VIDEO_DIR = "/data/hongzefu/dataset_replay"
# Step cap passed to make_env_for_episode (truncation limit per episode).
MAX_STEPS = 1000
def _parse_oracle_command(choice_action: Optional[Any]) -> Optional[dict[str, Any]]:
if not isinstance(choice_action, dict):
return None
choice = choice_action.get("choice")
if not isinstance(choice, str) or not choice.strip():
return None
point = choice_action.get("point")
if not isinstance(point, (list, tuple, np.ndarray)) or len(point) != 2:
return None
return choice_action
def _to_numpy_copy(value: Any) -> np.ndarray:
if isinstance(value, torch.Tensor):
value = value.detach().cpu().numpy()
else:
value = np.asarray(value)
return np.array(value, copy=True)
def _to_frame_list(frames_like: Any) -> list[np.ndarray]:
if frames_like is None:
return []
if isinstance(frames_like, torch.Tensor):
arr = frames_like.detach().cpu().numpy()
if arr.ndim == 3:
return [np.array(arr, copy=True)]
if arr.ndim == 4:
return [np.array(x, copy=True) for x in arr]
return []
if isinstance(frames_like, np.ndarray):
if frames_like.ndim == 3:
return [np.array(frames_like, copy=True)]
if frames_like.ndim == 4:
return [np.array(x, copy=True) for x in frames_like]
return []
if isinstance(frames_like, (list, tuple)):
out = []
for frame in frames_like:
if frame is None:
continue
out.append(_to_numpy_copy(frame))
return out
try:
arr = np.asarray(frames_like)
except Exception:
return []
if arr.ndim == 3:
return [np.array(arr, copy=True)]
if arr.ndim == 4:
return [np.array(x, copy=True) for x in arr]
return []
def _normalize_pixel_xy(pixel_like: Any) -> Optional[list[int]]:
if not isinstance(pixel_like, (list, tuple, np.ndarray)):
return None
if len(pixel_like) < 2:
return None
try:
x = float(pixel_like[0])
y = float(pixel_like[1])
except (TypeError, ValueError):
return None
if not np.isfinite(x) or not np.isfinite(y):
return None
return [int(np.rint(x)), int(np.rint(y))]
def _normalize_point_yx_to_pixel_xy(point_like: Any) -> Optional[list[int]]:
if not isinstance(point_like, (list, tuple, np.ndarray)):
return None
if len(point_like) < 2:
return None
try:
y = float(point_like[0])
x = float(point_like[1])
except (TypeError, ValueError):
return None
if not np.isfinite(x) or not np.isfinite(y):
return None
return [int(np.rint(x)), int(np.rint(y))]
def _find_oracle_wrapper(env_like: Any) -> Optional[OraclePlannerDemonstrationWrapper]:
    """Walk the `.env` wrapper chain looking for the oracle planner wrapper.

    Follows at most 16 hops and tracks visited object ids to stay safe
    against cycles. Returns the wrapper instance, or None when the chain
    ends, repeats, or the hop budget is exhausted.
    """
    seen: set[int] = set()
    node = env_like
    hops = 0
    while node is not None and hops < 16:
        if isinstance(node, OraclePlannerDemonstrationWrapper):
            return node
        marker = id(node)
        if marker in seen:
            return None
        seen.add(marker)
        node = getattr(node, "env", None)
        hops += 1
    return None
def _collect_multi_choice_visualization(
    env_like: Any,
    command: dict[str, Any],
) -> tuple[list[list[int]], Optional[list[int]], Optional[list[int]]]:
    """Gather pixel positions for visualizing a multi-choice command.

    Returns (candidate_pixels, clicked_pixel, matched_pixel):
      * candidate_pixels — projections of every unique available actor
        through the wrapper's front-camera intrinsics/extrinsics;
      * clicked_pixel    — the command's (y, x) point converted to [x, y];
      * matched_pixel    — the projected pixel of the actor that
        select_target_with_pixel resolves for the click, if any.

    Falls back to ([], clicked_pixel, None) whenever the oracle wrapper is
    missing or its option resolution fails or returns an out-of-range index.
    """
    clicked_pixel = _normalize_point_yx_to_pixel_xy(command.get("point"))
    wrapper = _find_oracle_wrapper(env_like)
    if wrapper is None:
        return [], clicked_pixel, None
    # Option resolution reaches into wrapper internals; treat any failure
    # as "nothing to visualize" rather than aborting the replay.
    try:
        _target, options = wrapper._build_step_options()
        idx, _ = wrapper._resolve_command(command, options)
    except Exception:
        return [], clicked_pixel, None
    if idx is None or not (0 <= idx < len(options)):
        return [], clicked_pixel, None

    chosen = options[idx]
    available = chosen.get("available")
    intrinsic = getattr(wrapper, "_front_camera_intrinsic_cv", None)
    extrinsic = getattr(wrapper, "_front_camera_extrinsic_cv", None)
    rgb_shape = getattr(wrapper, "_front_rgb_shape", None)

    candidate_pixels: list[list[int]] = []
    matched_pixel: Optional[list[int]] = None
    if available is not None:
        for candidate in _unique_candidates(available):
            world_xyz = extract_actor_position_xyz(candidate)
            if world_xyz is None:
                continue
            pixel = project_world_to_pixel(
                world_xyz,
                intrinsic_cv=intrinsic,
                extrinsic_cv=extrinsic,
                image_shape=rgb_shape,
            )
            if pixel is not None:
                candidate_pixels.append([int(pixel[0]), int(pixel[1])])
        if clicked_pixel is not None:
            match = select_target_with_pixel(
                available=available,
                pixel_like=clicked_pixel,
                intrinsic_cv=intrinsic,
                extrinsic_cv=extrinsic,
                image_shape=rgb_shape,
            )
            if isinstance(match, dict):
                matched_pixel = _normalize_pixel_xy(match.get("projected_pixel"))
    return candidate_pixels, clicked_pixel, matched_pixel
def _make_blackboard(frame_like: Any) -> np.ndarray:
    """Create an all-zeros uint8 (H, W, 3) canvas matching *frame_like*.

    Falls back to a 1x1 canvas when the input has fewer than two dimensions
    or a non-positive height/width.
    """
    frame = _to_numpy_copy(frame_like)
    if frame.ndim >= 2:
        height, width = int(frame.shape[0]), int(frame.shape[1])
        if height > 0 and width > 0:
            return np.zeros((height, width, 3), dtype=np.uint8)
    return np.zeros((1, 1, 3), dtype=np.uint8)
def _draw_candidate_blackboard(
    frame_like: Any,
    candidate_pixels: list[list[int]],
) -> np.ndarray:
    """Render each candidate pixel as a small circle on a black canvas.

    The canvas matches *frame_like*'s height/width; entries with fewer than
    two coordinates are skipped.
    """
    canvas = _make_blackboard(frame_like)
    for px in candidate_pixels:
        if len(px) >= 2:
            cv2.circle(canvas, (int(px[0]), int(px[1])), 4, (0, 255, 255), 1)
    return canvas
def _draw_selection_blackboard(
    frame_like: Any,
    clicked_pixel: Optional[list[int]],
    matched_pixel: Optional[list[int]],
) -> np.ndarray:
    """Render the clicked point (tilted cross) and matched target (circle).

    Both markers are drawn on a fresh black canvas sized like *frame_like*;
    either pixel may be None, in which case its marker is omitted.
    """
    canvas = _make_blackboard(frame_like)
    if clicked_pixel is not None:
        click_xy = (int(clicked_pixel[0]), int(clicked_pixel[1]))
        cv2.drawMarker(
            canvas,
            click_xy,
            (255, 255, 0),
            markerType=cv2.MARKER_TILTED_CROSS,
            markerSize=10,
            thickness=1,
        )
    if matched_pixel is not None:
        match_xy = (int(matched_pixel[0]), int(matched_pixel[1]))
        cv2.circle(canvas, match_xy, 5, (255, 0, 0), 2)
    return canvas
def main():
    """Replay recorded dataset episodes through the benchmark env and save videos.

    For each env id in DEFAULT_ENV_IDS: build a per-episode environment,
    feed the recorded actions from EpisodeDatasetResolver back into it,
    collect RGB frames (plus multi-choice debug overlays when
    ACTION_SPACE == "multi_choice"), and write one replay video per episode.

    Fixes vs. original:
      * every per-episode env is now closed in a try/finally (previously a
        new env was created each episode but only the last one — or a
        KeyError-skipped one — was ever closed, leaking simulator resources);
      * the step-loop read of obs["maniskill_obs"] had been accidentally
        fused into a comment; it is restored;
      * a None action from the resolver ends the step loop in every action
        space (previously only checked for multi_choice).
    """
    from robomme.logging_utils import setup_logging
    setup_logging(level="DEBUG")
    env_id_list = BenchmarkEnvBuilder.get_task_list()
    print(f"Running envs: {env_id_list}")
    print(f"Using action_space: {ACTION_SPACE}")
    # for env_id in env_id_list:  # full task list; currently narrowed to DEFAULT_ENV_IDS
    for env_id in DEFAULT_ENV_IDS:
        env_builder = BenchmarkEnvBuilder(
            env_id=env_id,
            dataset="train",
            action_space=ACTION_SPACE,
            gui_render=GUI_RENDER,
        )
        episode_count = env_builder.get_episode_num()
        print(f"[{env_id}] episode_count from metadata: {episode_count}")
        for episode in range(episode_count):
            # NOTE(review): debug leftover — only episode 15 is replayed.
            # Remove this guard to replay every episode in the dataset.
            if episode != 15:
                continue
            env = env_builder.make_env_for_episode(
                episode,
                max_steps=MAX_STEPS,
                include_maniskill_obs=True,
                include_front_depth=True,
                include_wrist_depth=True,
                include_front_camera_extrinsic=True,
                include_wrist_camera_extrinsic=True,
                include_available_multi_choices=True,
                include_front_camera_intrinsic=True,
                include_wrist_camera_intrinsic=True,
            )
            try:  # ensure this episode's env is always closed (was leaked before)
                try:
                    dataset_resolver = EpisodeDatasetResolver(
                        env_id=env_id,
                        episode=episode,
                        dataset_directory=DATASET_ROOT,
                    )
                except KeyError as e:
                    print(f"[{env_id}] Episode {episode} missing in H5, skipping. ({e})")
                    continue
                # ======== Reset ========
                # obs: dict-of-lists (columnar batch, list length = number of demo frames)
                # info: flat dict (last frame values only)
                obs, info = env.reset()
                # --- Explicitly read all obs fields (each is a list) ---
                maniskill_obs = obs["maniskill_obs"]
                front_rgb_list = _to_frame_list(obs["front_rgb_list"])
                wrist_rgb_list = _to_frame_list(obs["wrist_rgb_list"])
                front_depth_list = obs["front_depth_list"]
                wrist_depth_list = obs["wrist_depth_list"]
                end_effector_pose_raw = obs["end_effector_pose_raw"]
                eef_state_list = obs["eef_state_list"]
                joint_state_list = obs["joint_state_list"]
                # velocity = obs["velocity"]
                gripper_state_list = obs["gripper_state_list"]
                front_camera_extrinsic_list = obs["front_camera_extrinsic_list"]
                wrist_camera_extrinsic_list = obs["wrist_camera_extrinsic_list"]
                # --- Explicitly read all info fields (flat dict, last frame values) ---
                task_goal = info["task_goal"]
                simple_subgoal_online = info["simple_subgoal_online"]
                grounded_subgoal_online = info["grounded_subgoal_online"]
                available_multi_choices = info.get("available_multi_choices")
                front_camera_intrinsic = info["front_camera_intrinsic"]
                wrist_camera_intrinsic = info["wrist_camera_intrinsic"]
                status = info.get("status")
                # --- Video saving variable preparation (reset phase) ---
                reset_base_frames = [_to_numpy_copy(f) for f in front_rgb_list]
                reset_wrist_frames = [_to_numpy_copy(f) for f in wrist_rgb_list]
                # Multi-choice mode gets two extra (initially blank) panels.
                reset_right_frames = (
                    [_make_blackboard(f) for f in reset_base_frames]
                    if ACTION_SPACE == "multi_choice"
                    else None
                )
                reset_far_right_frames = (
                    [_make_blackboard(f) for f in reset_base_frames]
                    if ACTION_SPACE == "multi_choice"
                    else None
                )
                reset_subgoal_grounded = [grounded_subgoal_online] * len(front_rgb_list)
                step = 0
                episode_success = False
                rollout_base_frames: list[np.ndarray] = []
                rollout_wrist_frames: list[np.ndarray] = []
                rollout_right_frames: list[np.ndarray] = []
                rollout_far_right_frames: list[np.ndarray] = []
                rollout_subgoal_grounded: list[Any] = []
                # ======== Step loop ========
                while True:
                    replay_key = ACTION_SPACE
                    action = dataset_resolver.get_step(replay_key, step)
                    if ACTION_SPACE == "multi_choice":
                        action = _parse_oracle_command(action)
                    # A None action means the recorded data is exhausted/invalid.
                    if action is None:
                        break
                    candidate_pixels: list[list[int]] = []
                    clicked_pixel: Optional[list[int]] = None
                    matched_pixel: Optional[list[int]] = None
                    if ACTION_SPACE == "multi_choice":
                        candidate_pixels, clicked_pixel, matched_pixel = (
                            _collect_multi_choice_visualization(env, action)
                        )
                    # step returns: obs (dict-of-lists), reward (scalar tensor),
                    # terminated (scalar tensor), truncated (scalar tensor), info (flat dict)
                    obs, reward, terminated, truncated, info = env.step(action)
                    # --- Explicitly read all obs fields (dict-of-lists, typically 1 element per list) ---
                    maniskill_obs = obs["maniskill_obs"]
                    front_rgb_list = _to_frame_list(obs["front_rgb_list"])
                    wrist_rgb_list = _to_frame_list(obs["wrist_rgb_list"])
                    front_depth_list = obs["front_depth_list"]
                    wrist_depth_list = obs["wrist_depth_list"]
                    end_effector_pose_raw = obs["end_effector_pose_raw"]
                    eef_state_list = obs["eef_state_list"]
                    joint_state_list = obs["joint_state_list"]
                    gripper_state_list = obs["gripper_state_list"]
                    front_camera_extrinsic_list = obs["front_camera_extrinsic_list"]
                    wrist_camera_extrinsic_list = obs["wrist_camera_extrinsic_list"]
                    # --- Explicitly read all info fields (flat dict) ---
                    task_goal = info["task_goal"]
                    simple_subgoal_online = info["simple_subgoal_online"]
                    grounded_subgoal_online = info["grounded_subgoal_online"]
                    available_multi_choices = info.get("available_multi_choices")
                    front_camera_intrinsic = info["front_camera_intrinsic"]
                    wrist_camera_intrinsic = info["wrist_camera_intrinsic"]
                    status = info.get("status")
                    # --- Video saving variable preparation (replay phase) ---
                    rollout_base_frames.extend(
                        _to_numpy_copy(f) for f in front_rgb_list
                    )
                    rollout_wrist_frames.extend(
                        _to_numpy_copy(f) for f in wrist_rgb_list
                    )
                    if ACTION_SPACE == "multi_choice":
                        for base_frame in front_rgb_list:
                            rollout_right_frames.append(
                                _draw_candidate_blackboard(
                                    base_frame,
                                    candidate_pixels=candidate_pixels,
                                )
                            )
                            rollout_far_right_frames.append(
                                _draw_selection_blackboard(
                                    base_frame,
                                    clicked_pixel=clicked_pixel,
                                    matched_pixel=matched_pixel,
                                )
                            )
                    rollout_subgoal_grounded.extend([grounded_subgoal_online] * len(front_rgb_list))
                    terminated_flag = bool(terminated.item())
                    truncated_flag = bool(truncated.item())
                    step += 1
                    if GUI_RENDER:
                        env.render()
                    if truncated_flag:
                        print(f"[{env_id}] episode {episode} steps exceeded, step {step}.")
                        break
                    if terminated_flag:
                        if status == "success":
                            print(f"[{env_id}] episode {episode} success.")
                            episode_success = True
                        elif status == "fail":
                            print(f"[{env_id}] episode {episode} failed.")
                        break
                # ======== Video saving ========
                save_robomme_video(
                    reset_base_frames=reset_base_frames,
                    reset_wrist_frames=reset_wrist_frames,
                    rollout_base_frames=rollout_base_frames,
                    rollout_wrist_frames=rollout_wrist_frames,
                    reset_subgoal_grounded=reset_subgoal_grounded,
                    rollout_subgoal_grounded=rollout_subgoal_grounded,
                    out_video_dir=OUT_VIDEO_DIR,
                    action_space=ACTION_SPACE,
                    env_id=env_id,
                    episode=episode,
                    episode_success=episode_success,
                    reset_right_frames=reset_right_frames if ACTION_SPACE == "multi_choice" else None,
                    rollout_right_frames=rollout_right_frames if ACTION_SPACE == "multi_choice" else None,
                    reset_far_right_frames=(
                        reset_far_right_frames if ACTION_SPACE == "multi_choice" else None
                    ),
                    rollout_far_right_frames=(
                        rollout_far_right_frames if ACTION_SPACE == "multi_choice" else None
                    ),
                )
            finally:
                env.close()
# Script entry point: run the full replay when executed directly.
if __name__ == "__main__":
main()