"""
MultiStepDemonstrationWrapper: Wraps DemonstrationWrapper to provide waypoint step interface.
Each step(action) receives action = waypoint_p(3) + rpy(3) + gripper_action(1), total 7 dimensions.
Internally converts RPY to quat then calls move_to_pose_with_screw and close_gripper/open_gripper via planner_denseStep,
where PatternLock/RouteStick will force skip close_gripper/open_gripper.
Returns obs as dictionary-of-lists, and reward/terminated/truncated as the last step value.
Caller must ensure scripts/ is in sys.path to import planner_fail_safe.
"""
import numpy as np
import sapien
import torch
import gymnasium as gym
from ..robomme_env.utils import planner_denseStep
from ..robomme_env.utils.rpy_util import rpy_xyz_to_quat_wxyz_torch
from ..robomme_env.utils.planner_fail_safe import ScrewPlanFailure
DATASET_SCREW_MAX_ATTEMPTS = 3
DATASET_RRT_MAX_ATTEMPTS = 3
class RRTPlanFailure(RuntimeError):
    """Signals that RRT* planning failed (move_to_pose_with_RRTStar gave -1)."""
class MultiStepDemonstrationWrapper(gym.Wrapper):
    """
    Wraps DemonstrationWrapper to expose a waypoint-level step interface.

    ``step(action)`` interprets ``action`` as (waypoint_p[3], rpy[3],
    gripper_action[1]) = 7 dims, converts the RPY part to a wxyz quaternion,
    plans and executes the motion via ``planner_denseStep`` (screw planning
    first, RRT* as fallback), and returns the batched observations together
    with the *last* internal step's reward/terminated/truncated values.
    """

    # Environments driven by a stick end-effector: they get the stick planner
    # and any gripper command in the action is force-skipped.
    _STICK_ENV_IDS = ("PatternLock", "RouteStick")

    def __init__(self, env, gui_render=True, vis=True, **kwargs):
        """
        Parameters
        ----------
        env : gym.Env
            The DemonstrationWrapper-wrapped environment.
        gui_render : bool
            Stored on the wrapper; not read by this class itself.
        vis : bool
            Forwarded to the lazily created motion-planning solver.
        **kwargs
            Ignored; accepted for constructor-signature compatibility.
        """
        super().__init__(env)
        self._planner = None  # created lazily per episode; see _get_planner()
        self._gui_render = gui_render
        self._vis = vis
        # Waypoint action: xyz position + rpy orientation + gripper command.
        self.action_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64
        )

    @staticmethod
    def _batch_to_steps(batch):
        """Split a batched (obs, reward, terminated, truncated, info) tuple
        into a list of per-step tuples, one per element of the reward batch."""
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch
        n = int(reward_batch.numel())  # reward_batch assumed to be a torch.Tensor
        steps = []
        obs_keys = list(obs_batch.keys())
        info_keys = list(info_batch.keys())
        for idx in range(n):
            obs = {k: obs_batch[k][idx] for k in obs_keys}
            info = {k: info_batch[k][idx] for k in info_keys}
            steps.append(
                (obs, reward_batch[idx], terminated_batch[idx], truncated_batch[idx], info)
            )
        return steps

    @staticmethod
    def _flatten_info_batch(info_batch):
        """Collapse list-valued info entries to their last element; other
        values (and empty lists) pass through unchanged."""
        return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()}

    @staticmethod
    def _take_last_step_value(value):
        """Return the last element of a tensor/array/sequence, or the value
        itself when it is scalar, empty, or not a sequence."""
        if isinstance(value, torch.Tensor):
            if value.numel() == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, np.ndarray):
            if value.size == 0 or value.ndim == 0:
                return value
            return value.reshape(-1)[-1]
        if isinstance(value, (list, tuple)):
            return value[-1] if value else value
        return value

    def _get_planner(self):
        """Return the cached planner, creating it on first use.

        Stick envs (see _STICK_ENV_IDS) get the stick solver with a reduced
        joint velocity limit; all other envs get the arm solver.
        """
        if self._planner is not None:
            return self._planner
        from ..robomme_env.utils.planner_fail_safe import (
            FailAwarePandaArmMotionPlanningSolver,
            FailAwarePandaStickMotionPlanningSolver,
        )
        env_id = self.env.unwrapped.spec.id
        base_pose = self.env.unwrapped.agent.robot.pose
        if env_id in self._STICK_ENV_IDS:
            self._planner = FailAwarePandaStickMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self._vis,
                base_pose=base_pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            self._planner = FailAwarePandaArmMotionPlanningSolver(
                self.env,
                debug=False,
                vis=self._vis,
                base_pose=base_pose,
                visualize_target_grasp_pose=True,
                print_env_info=False,
            )
        return self._planner

    def _current_tcp_p(self):
        """Return the current TCP position as a flat numpy array.

        NOTE(review): not used by step() anymore; kept for subclasses/callers.
        """
        current_pose = self.env.unwrapped.agent.tcp.pose
        p = current_pose.p
        if hasattr(p, "cpu"):
            p = p.cpu().numpy()
        return np.asarray(p).flatten()

    def _no_op_step(self):
        """Execute one step using current qpos + gripper, without moving the
        arm, only to obtain an observation.

        NOTE(review): not used by step() anymore; kept for subclasses/callers.
        """
        robot = self.env.unwrapped.agent.robot
        qpos = robot.get_qpos().cpu().numpy().flatten()
        arm = qpos[:7]
        gripper = float(qpos[7]) if len(qpos) > 7 else 0.0
        return self.env.step(np.hstack([arm, gripper]))

    @staticmethod
    def _try_plan(planner, plan_fn, label, max_attempts, catch):
        """Run ``plan_fn`` via planner_denseStep up to ``max_attempts`` times.

        Returns the collected dense steps on success, or None when every
        attempt either raised one of ``catch`` or returned -1.
        """
        for attempt in range(1, max_attempts + 1):
            try:
                result = planner_denseStep._collect_dense_steps(planner, plan_fn)
            except catch as exc:
                print(f"[MultiStep] {label} planning failed (attempt {attempt}/{max_attempts}): {exc}")
                continue
            # planner_denseStep signals planning failure with an int -1.
            if isinstance(result, int) and result == -1:
                print(f"[MultiStep] {label} planning returned -1 (attempt {attempt}/{max_attempts})")
                continue
            return result
        return None

    def _plan_move(self, planner, pose):
        """Plan and execute a move to ``pose``: screw first, then RRT* fallback.

        Returns the list of dense (obs, reward, terminated, truncated, info)
        steps collected during execution.

        Raises
        ------
        RRTPlanFailure
            If both planners exhaust their attempt budgets.
        """
        steps = self._try_plan(
            planner,
            lambda: planner.move_to_pose_with_screw(pose),
            "screw",
            DATASET_SCREW_MAX_ATTEMPTS,
            (ScrewPlanFailure,),
        )
        if steps is not None:
            return steps
        print(f"[MultiStep] screw planning exhausted; fallback to RRT* (max {DATASET_RRT_MAX_ATTEMPTS} attempts)")
        steps = self._try_plan(
            planner,
            lambda: planner.move_to_pose_with_RRTStar(pose),
            "RRT*",
            DATASET_RRT_MAX_ATTEMPTS,
            # RRT* failures are not wrapped in a dedicated exception type, so
            # any error counts as a failed attempt (matches original behavior).
            (Exception,),
        )
        if steps is not None:
            return steps
        raise RRTPlanFailure("Both screw and RRTStar planning exhausted.")

    def _apply_gripper(self, planner, gripper_action, collected_steps):
        """Execute a gripper command and append its dense steps in place.

        -1 closes, +1 opens; any other value is a no-op. Exact float equality
        is intentional: demonstration actions encode gripper commands as
        exactly -1/0/+1.
        """
        if gripper_action == -1 and hasattr(planner, "close_gripper"):
            result = planner_denseStep.close_gripper(planner)
        elif gripper_action == 1 and hasattr(planner, "open_gripper"):
            result = planner_denseStep.open_gripper(planner)
        else:
            return
        if result != -1:
            collected_steps.extend(self._batch_to_steps(result))

    def step(self, action):
        """Execute one waypoint step.

        Parameters
        ----------
        action : array-like
            At least 7 floats: waypoint xyz (3), rpy (3), gripper command (1).

        Returns
        -------
        tuple
            (obs_batch, reward, terminated, truncated, info) where
            reward/terminated/truncated are the last internal step's values
            and info has list-valued entries flattened to their last element.

        Raises
        ------
        ValueError
            If ``action`` has fewer than 7 elements.
        RRTPlanFailure
            If both screw and RRT* planning are exhausted.
        """
        action = np.asarray(action, dtype=np.float64).flatten()
        if action.size < 7:
            raise ValueError(f"action must have at least 7 elements, got {action.size}")
        waypoint_p = action[:3]
        rpy = action[3:6]
        gripper_action = float(action[6])
        # RPY -> quaternion in wxyz order, as sapien.Pose expects.
        rpy_t = torch.as_tensor(rpy, dtype=torch.float64)
        waypoint_q = rpy_xyz_to_quat_wxyz_torch(rpy_t).numpy()
        pose = sapien.Pose(p=waypoint_p, q=waypoint_q)
        planner = self._get_planner()
        collected_steps = list(self._plan_move(planner, pose))
        # Stick envs have no gripper: force-skip close/open even though the
        # planner object may expose same-named methods.
        if self.env.unwrapped.spec.id not in self._STICK_ENV_IDS:
            self._apply_gripper(planner, gripper_action, collected_steps)
        obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = (
            planner_denseStep.to_step_batch(collected_steps)
        )
        return (
            obs_batch,
            self._take_last_step_value(reward_batch),
            self._take_last_step_value(terminated_batch),
            self._take_last_step_value(truncated_batch),
            self._flatten_info_batch(info_batch),
        )

    def reset(self, **kwargs):
        """Reset the underlying env; the planner is rebuilt lazily next step."""
        self._planner = None
        return self.env.reset(**kwargs)

    def close(self):
        """Drop the planner and close the underlying env."""
        self._planner = None
        return self.env.close()