# RoboMME/src/robomme/env_record_wrapper/MultiStepDemonstrationWrapper.py
# Author: HongzeFu — HF Space: code-only (no binary assets) — commit 06c11b0
"""
MultiStepDemonstrationWrapper: Wraps DemonstrationWrapper to provide waypoint step interface.
Each step(action) receives action = waypoint_p(3) + rpy(3) + gripper_action(1), total 7 dimensions.
Internally converts RPY to quat then calls move_to_pose_with_screw and close_gripper/open_gripper via planner_denseStep,
where PatternLock/RouteStick will force skip close_gripper/open_gripper.
Returns obs as dictionary-of-lists, and reward/terminated/truncated as the last step value.
Caller must ensure scripts/ is in sys.path to import planner_fail_safe.
"""
import numpy as np
import sapien
import torch
import gymnasium as gym
from ..robomme_env.utils import planner_denseStep
from ..robomme_env.utils.rpy_util import rpy_xyz_to_quat_wxyz_torch
from ..robomme_env.utils.planner_fail_safe import ScrewPlanFailure
# Max retries for screw-motion planning before falling back to RRT*.
DATASET_SCREW_MAX_ATTEMPTS = 3
# Max retries for the RRT* fallback before giving up with RRTPlanFailure.
DATASET_RRT_MAX_ATTEMPTS = 3
class RRTPlanFailure(RuntimeError):
    """Raised when motion planning ultimately fails: move_to_pose_with_RRTStar
    returns -1 (planning failed), or — as used in step() — both the screw and
    RRT* planning retries are exhausted."""
class MultiStepDemonstrationWrapper(gym.Wrapper):
    """
    Wraps DemonstrationWrapper. step(action) interprets action as
    (waypoint_p, rpy, gripper_action) total 7 dims, internally converts RPY to quat,
    executes planning via planner_denseStep, and returns last-step signals.
    """

    def __init__(self, env, gui_render=True, vis=True, **kwargs):
        """Wrap *env* and expose a 7-dim waypoint action space.

        Args:
            env: DemonstrationWrapper-wrapped environment to drive.
            gui_render: Stored on the instance; not read by this class itself
                — presumably consumed externally, TODO confirm.
            vis: Forwarded to the fail-aware motion-planning solver.
            **kwargs: Accepted and ignored for call-site compatibility.
        """
        super().__init__(env)
        # Planner is built lazily by _get_planner() and dropped on reset()/close().
        self._planner = None
        self._gui_render = gui_render
        self._vis = vis
        # waypoint xyz (3) + roll/pitch/yaw (3) + gripper command (1)
        self.action_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64
        )

    @staticmethod
    def _batch_to_steps(batch):
        """Split one batched (obs, reward, terminated, truncated, info) tuple
        into a list of per-step tuples, one per element of the reward tensor."""
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
        n = int(reward_batch.numel())
        obs_keys = list(obs_batch.keys())
        info_keys = list(info_batch.keys())
        steps = []
        for idx in range(n):
            steps.append(
                (
                    {k: obs_batch[k][idx] for k in obs_keys},
                    reward_batch[idx],
                    terminated_batch[idx],
                    truncated_batch[idx],
                    {k: info_batch[k][idx] for k in info_keys},
                )
            )
        return steps

    @staticmethod
    def _flatten_info_batch(info_batch):
        """Collapse non-empty list values in *info_batch* to their last element;
        empty lists and non-list values pass through unchanged."""
        return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}

    @staticmethod
    def _take_last_step_value(value):
        """Return the last scalar of a tensor/array/sequence; empty, 0-d, and
        non-sequence inputs are returned unchanged."""
        if isinstance(value, torch.Tensor):
            if value.numel() == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, np.ndarray):
            if value.size == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, (list, tuple)):
            return value[-1] if value else value
        return value

    def _get_planner(self):
        """Lazily build and cache the fail-aware motion-planning solver.

        PatternLock/RouteStick get the stick solver; every other env id gets
        the arm solver. The import is deferred so module load stays cheap.
        """
        if self._planner is not None:
            return self._planner
        from ..robomme_env.utils.planner_fail_safe import (
            FailAwarePandaArmMotionPlanningSolver,
            FailAwarePandaStickMotionPlanningSolver,
        )
        env_id = self.env.unwrapped.spec.id
        base_pose = self.env.unwrapped.agent.robot.pose
        if env_id in ("PatternLock", "RouteStick"):
            self._planner = FailAwarePandaStickMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self._vis,
                base_pose=base_pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            self._planner = FailAwarePandaArmMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self._vis,
                base_pose=base_pose,
                visualize_target_grasp_pose=True,
                print_env_info=False,
            )
        return self._planner

    def _current_tcp_p(self):
        """Return the current TCP position as a flat numpy array.

        NOTE(review): no longer called by step() (the distance-based no-op
        shortcut that consumed it was disabled); retained for debugging and
        possible external callers.
        """
        current_pose = self.env.unwrapped.agent.tcp.pose
        p = current_pose.p
        if hasattr(p, "cpu"):
            p = p.cpu().numpy()
        p = np.asarray(p).flatten()
        return p

    def _no_op_step(self):
        """Execute one step using current qpos + gripper, without moving arm, only to get observation."""
        robot = self.env.unwrapped.agent.robot
        qpos = robot.get_qpos().cpu().numpy().flatten()
        arm = qpos[:7]
        # assumes qpos index 7 is the gripper joint — TODO confirm robot layout
        gripper = float(qpos[7]) if len(qpos) > 7 else 0.0
        action = np.hstack([arm, gripper])
        return self.env.step(action)

    def _plan_with_retries(self, planner, plan_fn, label, max_attempts, failure_types):
        """Collect dense steps for *plan_fn*, retrying up to *max_attempts* times.

        Exceptions matching *failure_types* and a -1 sentinel result are both
        treated as retryable failures (logged with *label*); any other
        exception propagates. Returns the dense steps, or -1 when every
        attempt failed.
        """
        for attempt in range(1, max_attempts + 1):
            try:
                result = planner_denseStep._collect_dense_steps(planner, plan_fn)
            except failure_types as exc:
                print(f"[MultiStep] {label} planning failed (attempt {attempt}/{max_attempts}): {exc}")
                continue
            if isinstance(result, int) and result == -1:
                print(f"[MultiStep] {label} planning returned -1 (attempt {attempt}/{max_attempts})")
                continue
            return result
        return -1

    def step(self, action):
        """Execute waypoint step and return last-step signals for reward/terminated/truncated.

        Args:
            action: array-like with >= 7 floats — waypoint xyz (3), RPY (3),
                gripper command (1): -1 closes, +1 opens, any other value
                leaves the gripper untouched.

        Returns:
            (obs_batch, reward, terminated, truncated, info): obs_batch is a
            dict-of-lists over all dense sub-steps; the scalar signals are
            taken from the final sub-step.

        Raises:
            ValueError: if action has fewer than 7 elements.
            RRTPlanFailure: if both screw and RRT* planning exhaust retries.
        """
        action = np.asarray(action, dtype=np.float64).flatten()
        if action.size < 7:
            raise ValueError(f"action must have at least 7 elements, got {action.size}")
        waypoint_p = action[:3]
        rpy = action[3:6]
        gripper_action = float(action[6])
        # RPY -> quat (wxyz) for sapien.Pose
        rpy_t = torch.as_tensor(rpy, dtype=torch.float64)
        waypoint_q = rpy_xyz_to_quat_wxyz_torch(rpy_t).numpy()
        pose = sapien.Pose(p=waypoint_p, q=waypoint_q)
        planner = self._get_planner()
        is_stick_env = self.env.unwrapped.spec.id in ("PatternLock", "RouteStick")
        collected_steps = []
        # Primary planner: screw motion; only ScrewPlanFailure is retryable.
        move_steps = self._plan_with_retries(
            planner,
            lambda: planner.move_to_pose_with_screw(pose),
            "screw",
            DATASET_SCREW_MAX_ATTEMPTS,
            ScrewPlanFailure,
        )
        if move_steps == -1:
            print(f"[MultiStep] screw planning exhausted; fallback to RRT* (max {DATASET_RRT_MAX_ATTEMPTS} attempts)")
            # Fallback planner: RRT*; any exception counts as a retryable failure.
            move_steps = self._plan_with_retries(
                planner,
                lambda: planner.move_to_pose_with_RRTStar(pose),
                "RRT*",
                DATASET_RRT_MAX_ATTEMPTS,
                Exception,
            )
        if move_steps == -1:
            raise RRTPlanFailure("Both screw and RRTStar planning exhausted.")
        collected_steps.extend(move_steps)
        # PatternLock/RouteStick force skip gripper action (even if planner object has method with same name).
        if not is_stick_env:
            if gripper_action == -1:
                if hasattr(planner, "close_gripper"):
                    result = planner_denseStep.close_gripper(planner)
                    if result != -1:
                        collected_steps.extend(self._batch_to_steps(result))
            elif gripper_action == 1:
                if hasattr(planner, "open_gripper"):
                    result = planner_denseStep.open_gripper(planner)
                    if result != -1:
                        collected_steps.extend(self._batch_to_steps(result))
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = planner_denseStep.to_step_batch(
            collected_steps
        )
        info_flat = self._flatten_info_batch(info_batch)
        return (
            obs_batch,
            self._take_last_step_value(reward_batch),
            self._take_last_step_value(terminated_batch),
            self._take_last_step_value(truncated_batch),
            info_flat,
        )

    def reset(self, **kwargs):
        """Drop the cached planner (env state changed) and reset the wrapped env."""
        self._planner = None
        return self.env.reset(**kwargs)

    def close(self):
        """Drop the cached planner and close the wrapped environment."""
        self._planner = None
        return self.env.close()