| | """ |
| | MultiStepDemonstrationWrapper: Wraps DemonstrationWrapper to provide waypoint step interface. |
| | |
| | Each step(action) receives action = waypoint_p(3) + rpy(3) + gripper_action(1), total 7 dimensions. |
| | Internally converts RPY to quat then calls move_to_pose_with_screw and close_gripper/open_gripper via planner_denseStep, |
| | where PatternLock/RouteStick will force skip close_gripper/open_gripper. |
| | Returns obs as dictionary-of-lists, and reward/terminated/truncated as the last step value. |
| | Caller must ensure scripts/ is in sys.path to import planner_fail_safe. |
| | """ |
import numpy as np
import sapien
import torch
import gymnasium as gym

from ..robomme_env.utils import planner_denseStep
from ..robomme_env.utils.rpy_util import rpy_xyz_to_quat_wxyz_torch
from ..robomme_env.utils.planner_fail_safe import ScrewPlanFailure
| |
|
# Maximum retry attempts for screw-based planning before falling back to RRT*.
DATASET_SCREW_MAX_ATTEMPTS = 3
# Maximum retry attempts for the RRT* fallback planner.
DATASET_RRT_MAX_ATTEMPTS = 3
| |
|
| |
|
class RRTPlanFailure(RuntimeError):
    """Signals that move_to_pose_with_RRTStar reported failure (returned -1)."""
| |
|
| |
|
class MultiStepDemonstrationWrapper(gym.Wrapper):
    """
    Wraps DemonstrationWrapper with a waypoint-level ``step`` interface.

    ``step(action)`` interprets ``action`` as (waypoint_p(3), rpy(3),
    gripper_action(1)), converts the RPY angles to a wxyz quaternion, executes
    the motion through ``planner_denseStep`` (screw planning with an RRT*
    fallback), and returns the batched observations together with the
    reward/terminated/truncated values of the last internally executed step.
    """

    # Environments driven by the stick planner; gripper commands are skipped for them.
    _STICK_ENV_IDS = ("PatternLock", "RouteStick")

    def __init__(self, env, gui_render=True, vis=True, **kwargs):
        """
        Args:
            env: Environment to wrap (expected to be a DemonstrationWrapper).
            gui_render: Stored for external consumers; not read inside this class.
            vis: Whether the lazily created planner visualizes its motions.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(env)
        self._planner = None  # created lazily on first step(); dropped on reset()/close()
        self._gui_render = gui_render
        self._vis = vis
        # 7-dim action: xyz waypoint + roll/pitch/yaw + gripper command.
        self.action_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64
        )

    @staticmethod
    def _batch_to_steps(batch):
        """Split a batched (obs, reward, terminated, truncated, info) tuple into a
        list of per-step tuples; the step count is taken from the reward tensor."""
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
        num_steps = int(reward_batch.numel())
        obs_keys = list(obs_batch.keys())
        info_keys = list(info_batch.keys())
        steps = []
        for idx in range(num_steps):
            steps.append(
                (
                    {k: obs_batch[k][idx] for k in obs_keys},
                    reward_batch[idx],
                    terminated_batch[idx],
                    truncated_batch[idx],
                    {k: info_batch[k][idx] for k in info_keys},
                )
            )
        return steps

    @staticmethod
    def _flatten_info_batch(info_batch):
        """Collapse non-empty list-valued info entries to their last element;
        all other values pass through unchanged."""
        return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}

    @staticmethod
    def _take_last_step_value(value):
        """Return the last scalar of a batched tensor/array/sequence, or the value
        itself when it is empty, 0-dimensional, or not a known batch container."""
        if isinstance(value, torch.Tensor):
            if value.numel() == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, np.ndarray):
            if value.size == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, (list, tuple)):
            return value[-1] if value else value
        return value

    def _get_planner(self):
        """Lazily build and cache the fail-aware motion planner for the wrapped env.

        Stick environments (see ``_STICK_ENV_IDS``) get the stick solver with a
        lowered joint velocity limit; all others get the arm solver.
        """
        if self._planner is not None:
            return self._planner
        # Imported here (not at module top) so module import works without scripts/ on sys.path.
        from ..robomme_env.utils.planner_fail_safe import (
            FailAwarePandaArmMotionPlanningSolver,
            FailAwarePandaStickMotionPlanningSolver,
        )

        env_id = self.env.unwrapped.spec.id
        base_pose = self.env.unwrapped.agent.robot.pose
        if env_id in self._STICK_ENV_IDS:
            self._planner = FailAwarePandaStickMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self._vis,
                base_pose=base_pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            self._planner = FailAwarePandaArmMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self._vis,
                base_pose=base_pose,
                visualize_target_grasp_pose=True,
                print_env_info=False,
            )
        return self._planner

    def _current_tcp_p(self):
        """Return the current TCP position as a flat numpy array.

        Kept as a public-ish helper for callers/debugging even though step()
        no longer uses it internally.
        """
        p = self.env.unwrapped.agent.tcp.pose.p
        if hasattr(p, "cpu"):
            p = p.cpu().numpy()
        return np.asarray(p).flatten()

    def _no_op_step(self):
        """Execute one env step at the current qpos (arm held still), only to
        obtain a fresh observation.

        NOTE(review): assumes qpos layout is 7 arm joints followed by the
        gripper joint at index 7 — confirm against the robot model.
        """
        robot = self.env.unwrapped.agent.robot
        qpos = robot.get_qpos().cpu().numpy().flatten()
        arm = qpos[:7]
        gripper = float(qpos[7]) if len(qpos) > 7 else 0.0
        return self.env.step(np.hstack([arm, gripper]))

    def _plan_move_steps(self, planner, pose):
        """Plan and execute the motion to ``pose``; return the collected dense steps.

        Tries screw planning up to DATASET_SCREW_MAX_ATTEMPTS times, then falls
        back to RRT* for up to DATASET_RRT_MAX_ATTEMPTS attempts.

        Raises:
            RRTPlanFailure: If both planners exhaust their attempts.
        """
        for attempt in range(1, DATASET_SCREW_MAX_ATTEMPTS + 1):
            try:
                # ``pose`` is loop-invariant, so the late-binding lambda is safe here.
                result = planner_denseStep._collect_dense_steps(
                    planner, lambda: planner.move_to_pose_with_screw(pose)
                )
            except ScrewPlanFailure as exc:
                print(f"[MultiStep] screw planning failed (attempt {attempt}/{DATASET_SCREW_MAX_ATTEMPTS}): {exc}")
                continue
            # -1 is the planner's failure sentinel; retry. NOTE(review): assumes the
            # sentinel is a plain Python int — confirm against _collect_dense_steps.
            if isinstance(result, int) and result == -1:
                print(f"[MultiStep] screw planning returned -1 (attempt {attempt}/{DATASET_SCREW_MAX_ATTEMPTS})")
                continue
            return result

        print(f"[MultiStep] screw planning exhausted; fallback to RRT* (max {DATASET_RRT_MAX_ATTEMPTS} attempts)")
        for attempt in range(1, DATASET_RRT_MAX_ATTEMPTS + 1):
            try:
                result = planner_denseStep._collect_dense_steps(
                    planner, lambda: planner.move_to_pose_with_RRTStar(pose)
                )
            except Exception as exc:  # RRT* failures are not typed upstream; retry best-effort
                print(f"[MultiStep] RRT* planning failed (attempt {attempt}/{DATASET_RRT_MAX_ATTEMPTS}): {exc}")
                continue
            if isinstance(result, int) and result == -1:
                print(f"[MultiStep] RRT* planning returned -1 (attempt {attempt}/{DATASET_RRT_MAX_ATTEMPTS})")
                continue
            return result

        raise RRTPlanFailure("Both screw and RRTStar planning exhausted.")

    def step(self, action):
        """Execute one waypoint action and return last-step signals.

        Args:
            action: Array-like with at least 7 values:
                [:3] target TCP position, [3:6] target RPY orientation,
                [6] gripper command (-1 close, +1 open, anything else: no gripper motion).

        Returns:
            (obs_batch, reward, terminated, truncated, info): obs_batch is a dict of
            per-step sequences; reward/terminated/truncated come from the last
            internal step; info is flattened to last-step values.

        Raises:
            ValueError: If ``action`` has fewer than 7 elements.
            RRTPlanFailure: If both screw and RRT* planning fail.
        """
        action = np.asarray(action, dtype=np.float64).flatten()
        if action.size < 7:
            raise ValueError(f"action must have at least 7 elements, got {action.size}")
        waypoint_p = action[:3]
        rpy = action[3:6]
        gripper_action = float(action[6])

        # RPY -> wxyz quaternion for the planner's target pose.
        rpy_t = torch.as_tensor(rpy, dtype=torch.float64)
        waypoint_q = rpy_xyz_to_quat_wxyz_torch(rpy_t).numpy()
        pose = sapien.Pose(p=waypoint_p, q=waypoint_q)

        planner = self._get_planner()
        collected_steps = list(self._plan_move_steps(planner, pose))

        # Stick environments have no gripper; skip gripper commands for them.
        if self.env.unwrapped.spec.id not in self._STICK_ENV_IDS:
            if gripper_action == -1:
                if hasattr(planner, "close_gripper"):
                    result = planner_denseStep.close_gripper(planner)
                    if result != -1:
                        collected_steps.extend(self._batch_to_steps(result))
            elif gripper_action == 1:
                if hasattr(planner, "open_gripper"):
                    result = planner_denseStep.open_gripper(planner)
                    if result != -1:
                        collected_steps.extend(self._batch_to_steps(result))

        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = planner_denseStep.to_step_batch(
            collected_steps
        )
        return (
            obs_batch,
            self._take_last_step_value(reward_batch),
            self._take_last_step_value(terminated_batch),
            self._take_last_step_value(truncated_batch),
            self._flatten_info_batch(info_batch),
        )

    def reset(self, **kwargs):
        """Reset the wrapped env and drop the cached planner (rebuilt on next step).

        NOTE(review): the old planner is discarded without an explicit shutdown —
        confirm the solver holds no resources that need releasing.
        """
        self._planner = None
        return self.env.reset(**kwargs)

    def close(self):
        """Drop the cached planner and close the wrapped env."""
        self._planner = None
        return self.env.close()
| |
|