| | """ |
| | DemonstrationWrapper: Wrap another layer outside Robomme environment to automatically generate demonstration trajectories and record frames/actions/states/subgoals, etc. |
| | |
| | - Call get_demonstration_trajectory() after reset, use Motion Planner to execute tasks marked with demonstration and record trajectory. |
| | - step receives joint space action, performs segmentation and subgoal placeholder filling, trajectory recording, truncate and success judgment. ee_pose->joint is handled by outer EndeffectorDemonstrationWrapper. |
| | - Does not include video saving function; reset/step returns unified dense batch; step injects current step frames/subgoal etc via _augment_obs_and_info. |
| | """ |
| | import copy |
| | import re |
| | import time |
| | import json |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Callable, Dict, List, Optional, Tuple, Union |
| |
|
| | import gymnasium as gym |
| | import h5py |
| | import numpy as np |
| | import sapien.physx as physx |
| | import torch |
| | import cv2 |
| | import colorsys |
| | import imageio |
| |
|
| | from mani_skill import get_commit_info |
| | from mani_skill.envs.sapien_env import BaseEnv |
| | from mani_skill.utils import common, gym_utils, sapien_utils |
| | from mani_skill.utils.io_utils import dump_json |
| | from mani_skill.utils.logging_utils import logger |
| | from mani_skill.utils.structs.types import Array |
| | from mani_skill.utils.wrappers import CPUGymWrapper |
| |
|
| | from mani_skill.examples.motionplanning.panda.motionplanner import \ |
| | PandaArmMotionPlanningSolver |
| | from mani_skill.examples.motionplanning.panda.motionplanner_stick import PandaStickMotionPlanningSolver |
| | from mani_skill.examples.motionplanning.base_motionplanner.utils import ( |
| | compute_grasp_info_by_obb, |
| | get_actor_obb, |
| | ) |
| | from ..robomme_env.utils import task_goal |
| | from ..robomme_env.utils.vqa_options import get_vqa_options |
| |
|
| | from ..robomme_env.utils import reset_panda |
| |
|
| | from ..robomme_env.utils import planner_denseStep |
| | |
| | from ..robomme_env.utils.rpy_util import build_endeffector_pose_dict |
| |
|
| | from ..logging_utils import logger |
| |
|
| | from typing import Any |
| |
|
| | try: |
| | import torch |
| | _HAS_TORCH = True |
| | except ImportError: |
| | _HAS_TORCH = False |
| |
|
| | def _tensor_to_numpy(value: Any, dtype: np.dtype) -> np.ndarray: |
| | """Convert a single Tensor to an ndarray of specified dtype; if already ndarray, only convert dtype.""" |
| | if _HAS_TORCH and isinstance(value, torch.Tensor): |
| | arr = value.detach().cpu().numpy() |
| | else: |
| | arr = np.asarray(value) |
| | if arr.dtype != dtype: |
| | arr = arr.astype(dtype, copy=False) |
| | return arr |
| | class DemonstrationWrapper(gym.Wrapper): |
| | """ |
| | Demonstration wrapper (does not include video saving function). |
| | |
| | Main functions: |
| | 1. Automatically generate demonstration Trajectory after environment reset, using Motion Planner. |
| | 2. Record data such as frames, actions, states, subgoals during demonstration for downstream tasks. |
| | """ |
| | def __init__(self, env, max_steps_without_demonstration, gui_render, |
| | include_maniskill_obs=False, |
| | include_front_depth=False, |
| | include_wrist_depth=False, |
| | include_front_camera_extrinsic=False, |
| | include_wrist_camera_extrinsic=False, |
| | include_available_multi_choices=False, |
| | include_front_camera_intrinsic=False, |
| | include_wrist_camera_intrinsic=False, |
| | **kwargs): |
| | |
| | |
| | self.max_steps_without_demonstration = max_steps_without_demonstration |
| | self.gui_render = gui_render |
| | self.include_maniskill_obs = include_maniskill_obs |
| | self.include_front_depth = include_front_depth |
| | self.include_wrist_depth = include_wrist_depth |
| | self.include_front_camera_extrinsic = include_front_camera_extrinsic |
| | self.include_wrist_camera_extrinsic = include_wrist_camera_extrinsic |
| | self.include_available_multi_choices = include_available_multi_choices |
| | self.include_front_camera_intrinsic = include_front_camera_intrinsic |
| | self.include_wrist_camera_intrinsic = include_wrist_camera_intrinsic |
| |
|
| | super().__init__(env) |
| | self.unwrapped.use_demonstrationwrapper = True |
| |
|
| | self.demonstration_record_traj = False |
| |
|
| | |
| | self.steps_without_demonstration = 0 |
| | |
| | self._doing_extra_step = False |
| | |
| | self.demonstration_data = None |
| | |
| | self.current_subgoal_segment_filled = None |
| | |
| | self.episode_success = False |
| |
|
| | self._failed_match_save_count = 0 |
| | |
| | self._demo_screw_max_attempts = 1 |
| | |
| | self._demo_rrt_max_attempts = 3 |
| | |
| | self._current_demo_task_screw_failed = False |
| | |
| | |
| | |
| | |
| | self._prev_ee_quat_wxyz = None |
| | self._prev_ee_rpy_xyz = None |
| |
|
| | |
| | def generate_color_map(n=100, s_min=0.70, s_max=0.95, v_min=0.78, v_max=0.95): |
| | phi = 0.6180339887498948 |
| | color_map = {} |
| | for i in range(1, n + 1): |
| | h = (i * phi) % 1.0 |
| | s = s_min + (s_max - s_min) * ((i % 7) / 6) |
| | v = v_min + (v_max - v_min) * (((i * 3) % 5) / 4) |
| | r, g, b = colorsys.hsv_to_rgb(h, s, v) |
| | color_map[i] = [int(round(r * 255)), int(round(g * 255)), int(round(b * 255))] |
| | return color_map |
| |
|
| | self.color_map = generate_color_map(10000) |
| |
|
| |
|
| |
|
| |
|
| | def reset(self, **kwargs): |
| | """Reset environment and generate demonstration trajectory, then execute one initial action step and return unified batch.""" |
| | |
| | self.last_subgoal_segment = None |
| | self.latched_replacements = None |
| | self._failed_match_save_count = 0 |
| | |
| | self.steps_without_demonstration = 0 |
| | |
| | |
| | self._prev_ee_quat_wxyz = None |
| | self._prev_ee_rpy_xyz = None |
| |
|
| | super().reset(**kwargs) |
| | self.episode_success = False |
| | |
| | demo_batch = self.get_demonstration_trajectory() |
| |
|
| | |
| | if self.unwrapped.spec.id == "PatternLock" or self.unwrapped.spec.id == "RouteStick": |
| | gripper = "stick" |
| | else: |
| | gripper = None |
| | if self.unwrapped.spec.id == "PatternLock" or self.unwrapped.spec.id == "RouteStick": |
| | action = self.unwrapped.swing_qpos |
| | else: |
| | action = reset_panda.get_reset_panda_param("action", gripper=gripper) |
| |
|
| | |
| | init_batch = self._step_batch(action) |
| | merged_batch = planner_denseStep.concat_step_batches([demo_batch, init_batch]) |
| | merged_batch = self._filter_no_record_from_step_batch(merged_batch) |
| | self.demonstration_data = merged_batch |
| | |
| | |
| | obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = merged_batch |
| | info_flat = self._flatten_info_batch(info_batch) |
| | return obs_batch, info_flat |
| |
|
| | def _filter_no_record_from_step_batch(self, batch): |
| | """ |
| | Only used before reset return: Filter out frames where info_batch['subgoal'] is "NO RECORD". |
| | |
| | Return contract consistent with input batch: |
| | (obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch) |
| | """ |
| | if not (isinstance(batch, tuple) and len(batch) == 5): |
| | return batch |
| | obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch |
| |
|
| | if ( |
| | not isinstance(reward_batch, torch.Tensor) |
| | or not isinstance(terminated_batch, torch.Tensor) |
| | or not isinstance(truncated_batch, torch.Tensor) |
| | ): |
| | return batch |
| | if not isinstance(info_batch, dict): |
| | return batch |
| |
|
| | n = int(reward_batch.numel()) |
| | if n == 0: |
| | return batch |
| | if int(terminated_batch.numel()) != n or int(truncated_batch.numel()) != n: |
| | return batch |
| |
|
| | subgoal_list = info_batch.get("simple_subgoal_online") |
| | if not isinstance(subgoal_list, list) or len(subgoal_list) != n: |
| | return batch |
| |
|
| | keep_indices = [ |
| | idx for idx, subgoal in enumerate(subgoal_list) |
| | if str(subgoal).strip() != "NO RECORD" |
| | ] |
| | if len(keep_indices) == n: |
| | return batch |
| | |
| | if len(keep_indices) == 0: |
| | return batch |
| |
|
| | index_reward = torch.as_tensor(keep_indices, dtype=torch.long, device=reward_batch.device) |
| | index_terminated = torch.as_tensor(keep_indices, dtype=torch.long, device=terminated_batch.device) |
| | index_truncated = torch.as_tensor(keep_indices, dtype=torch.long, device=truncated_batch.device) |
| |
|
| | def _filter_columnar_dict(batch_dict): |
| | if not isinstance(batch_dict, dict): |
| | return batch_dict |
| | filtered = {} |
| | for key, value in batch_dict.items(): |
| | if isinstance(value, list) and len(value) == n: |
| | filtered[key] = [value[i] for i in keep_indices] |
| | else: |
| | filtered[key] = value |
| | return filtered |
| |
|
| | filtered_obs_batch = _filter_columnar_dict(obs_batch) |
| | filtered_info_batch = _filter_columnar_dict(info_batch) |
| | filtered_reward_batch = reward_batch.index_select(0, index_reward) |
| | filtered_terminated_batch = terminated_batch.index_select(0, index_terminated) |
| | filtered_truncated_batch = truncated_batch.index_select(0, index_truncated) |
| | return ( |
| | filtered_obs_batch, |
| | filtered_reward_batch, |
| | filtered_terminated_batch, |
| | filtered_truncated_batch, |
| | filtered_info_batch, |
| | ) |
| |
|
| |
|
    def _augment_obs_and_info(self, obs, info, action):
        """Extract current-step data directly from *obs* and merge it into the
        returned (obs, info) pair, bypassing any list-buffer intermediate.

        Args:
            obs: raw observation dict from the wrapped env; must contain
                sensor_data.base_camera.rgb and sensor_data.hand_camera.rgb
                (and sensor_param entries when the corresponding include_* flags are set).
            info: step info dict from the wrapped env; copied into the result.
            action: the action just executed (currently unused here; kept for interface).

        Returns:
            (new_obs, new_info) where new_obs holds front/wrist RGB, joint/eef/gripper
            state, plus optional depth/extrinsics, and new_info carries subgoal texts,
            the task goal, plus optional VQA options and intrinsics.
        """
        language_goal = task_goal.get_language_goal(self.env, self.unwrapped.spec.id)

        base_obs = obs if isinstance(obs, dict) else {}
        env_id = self.unwrapped.spec.id
        subgoal_text = getattr(self, 'current_task_name', 'Unknown')
        grounded_subgoal = self.current_subgoal_segment_filled

        # First batch element of each camera stream; assumes batch dim 0 — TODO confirm.
        image = obs['sensor_data']['base_camera']['rgb'][0]
        wrist_image = obs['sensor_data']['hand_camera']['rgb'][0]
        state = self.agent.robot.qpos

        # TCP pose; squeeze a leading batch dim if present.
        _tcp_p = self.agent.tcp.pose.p
        _tcp_q = self.agent.tcp.pose.q
        if _tcp_p.ndim > 1:
            _tcp_p = _tcp_p.squeeze(0)
        if _tcp_q.ndim > 1:
            _tcp_q = _tcp_q.squeeze(0)
        # build_endeffector_pose_dict also returns updated continuity caches so the
        # quaternion/rpy stay sign/wrap-consistent across frames.
        robot_endeffector_pose, self._prev_ee_quat_wxyz, self._prev_ee_rpy_xyz = \
            build_endeffector_pose_dict(
                _tcp_p,
                _tcp_q,
                self._prev_ee_quat_wxyz,
                self._prev_ee_rpy_xyz,
            )

        image_np = _tensor_to_numpy(image, np.uint8)
        wrist_image_np = _tensor_to_numpy(wrist_image, np.uint8)

        robot_endeffector_pose_np = {
            "pose": _tensor_to_numpy(robot_endeffector_pose['pose'], np.float32),
            "quat": _tensor_to_numpy(robot_endeffector_pose['quat'], np.float32),
            "rpy": _tensor_to_numpy(robot_endeffector_pose['rpy'], np.float32),
        }

        # 6-dof EE state: xyz position followed by rpy orientation.
        eef_state_list_f64 = np.concatenate([
            robot_endeffector_pose_np['pose'].flatten()[:3],
            robot_endeffector_pose_np['rpy'].flatten()[:3]
        ]).astype(np.float64, copy=False)

        state_flat = state.detach().cpu().numpy().flatten() if hasattr(state, 'cpu') else np.asarray(state).flatten()

        # Stick envs have no gripper joints; otherwise qpos[7:9] are presumably the
        # two finger joints — TODO confirm against the robot description.
        is_stick_env = self.unwrapped.spec.id in ("PatternLock", "RouteStick")
        if is_stick_env:
            gripper_state = np.zeros(2, dtype=np.float64)
        else:
            gripper_state = state_flat[7:9] if len(state_flat) >= 9 else np.zeros(2, dtype=np.float64)

        # First 7 qpos entries are taken as the arm joints.
        joint_state = state_flat[:7]

        new_obs = {
            'front_rgb_list': image_np,
            'wrist_rgb_list': wrist_image_np,
            'joint_state_list': joint_state,
            'eef_state_list': eef_state_list_f64,
            'gripper_state_list': gripper_state,
        }
        if self.include_maniskill_obs:
            new_obs['maniskill_obs'] = base_obs
        if self.include_front_depth:
            new_obs['front_depth_list'] = _tensor_to_numpy(obs["sensor_data"]["base_camera"]["depth"][0], np.int16)
        if self.include_wrist_depth:
            new_obs['wrist_depth_list'] = _tensor_to_numpy(obs["sensor_data"]["hand_camera"]["depth"][0], np.int16)
        if self.include_front_camera_extrinsic:
            _ext = _tensor_to_numpy(obs["sensor_param"]["base_camera"]["extrinsic_cv"], np.float32)
            if _ext.ndim == 3:
                _ext = _ext.squeeze(0)
            new_obs['front_camera_extrinsic_list'] = _ext
        if self.include_wrist_camera_extrinsic:
            _ext = _tensor_to_numpy(obs["sensor_param"]["hand_camera"]["extrinsic_cv"], np.float32)
            if _ext.ndim == 3:
                _ext = _ext.squeeze(0)
            new_obs['wrist_camera_extrinsic_list'] = _ext

        new_info = {
            **info,
            'simple_subgoal_online': subgoal_text,
            'grounded_subgoal_online': grounded_subgoal,
            'task_goal': language_goal,
        }
        if self.include_available_multi_choices:
            # Dummy target: options are queried independently of a concrete object.
            dummy_target = {"obj": None, "name": None, "seg_id": None}
            raw_options = get_vqa_options(self, None, dummy_target, env_id)
            available_options = [
                {"label": opt.get("label"), "action": opt.get("action", "Unknown"), "need_parameter": bool(opt.get("available"))}
                for opt in raw_options
            ]
            new_info['available_multi_choices'] = available_options
        if self.include_front_camera_intrinsic:
            _intr = _tensor_to_numpy(obs["sensor_param"]["base_camera"]["intrinsic_cv"], np.float32)
            if _intr.ndim == 3:
                _intr = _intr.squeeze(0)
            new_info['front_camera_intrinsic'] = _intr
        if self.include_wrist_camera_intrinsic:
            _intr = _tensor_to_numpy(obs["sensor_param"]["hand_camera"]["intrinsic_cv"], np.float32)
            if _intr.ndim == 3:
                _intr = _intr.squeeze(0)
            new_info['wrist_camera_intrinsic'] = _intr

        return new_obs, new_info
| |
|
| | def _add_red_border(self, frame, border_width=5): |
| | """Draw red border on four sides of image, used to highlight demonstration frames (currently not used for video saving).""" |
| | frame_with_border = frame.copy() |
| | frame_with_border[:border_width, :] = [255, 0, 0] |
| | frame_with_border[-border_width:, :] = [255, 0, 0] |
| | frame_with_border[:, :border_width] = [255, 0, 0] |
| | frame_with_border[:, -border_width:] = [255, 0, 0] |
| | return frame_with_border |
| |
|
    # Fixed pixel height of the black caption strip stitched above frames by _add_text_to_frame.
    TEXT_AREA_HEIGHT = 80
| |
|
| | def _add_text_to_frame(self, frame, text, position='top_right'): |
| | """Append black text area above frame and stitch, supporting multi-line and auto-wrap. Black border height fixed to TEXT_AREA_HEIGHT.""" |
| | if text is None: |
| | text = "" |
| | text_area_height = self.TEXT_AREA_HEIGHT |
| | if not text and not (isinstance(text, (list, tuple)) and any(text)): |
| | text_area = np.zeros((text_area_height, frame.shape[1], 3), dtype=np.uint8) |
| | return np.vstack((text_area, frame)) |
| |
|
| | if isinstance(text, str): |
| | text_list = [text] |
| | else: |
| | text_list = list(text) if text else [] |
| |
|
| | font = cv2.FONT_HERSHEY_SIMPLEX |
| | font_scale = 0.3 |
| | thickness = 1 |
| | max_width = max(1, frame.shape[1] - 20) |
| |
|
| | lines = [] |
| | for text_item in text_list: |
| | if text_item is None: |
| | continue |
| | text_item = str(text_item).strip() |
| | if not text_item: |
| | continue |
| | words = text_item.replace(',', ' ').split() |
| | if not words: |
| | continue |
| | current_line = words[0] |
| | for word in words[1:]: |
| | test_line = f"{current_line} {word}" |
| | (text_width, _), _ = cv2.getTextSize(test_line, font, font_scale, thickness) |
| | if text_width <= max_width: |
| | current_line = test_line |
| | else: |
| | lines.append(current_line) |
| | current_line = word |
| | lines.append(current_line) |
| |
|
| | if not lines: |
| | text_area = np.zeros((text_area_height, frame.shape[1], 3), dtype=np.uint8) |
| | return np.vstack((text_area, frame)) |
| |
|
| | line_height = 20 |
| | text_area = np.zeros((text_area_height, frame.shape[1], 3), dtype=np.uint8) |
| | text_area[:] = (0, 0, 0) |
| | max_visible_lines = (text_area_height - 15) // line_height |
| | for i, line in enumerate(lines[:max_visible_lines]): |
| | y_position = 15 + i * line_height |
| | cv2.putText(text_area, line, (10, y_position), font, font_scale, (255, 255, 255), thickness) |
| |
|
| | return np.vstack((text_area, frame)) |
| |
|
| | def save_frame_as_image(self, output_path: Union[str, Path], frame: np.ndarray, text=None): |
| | """ |
| | Overlay single frame with text and save as image. |
| | """ |
| | output_path = Path(output_path) |
| | output_path.parent.mkdir(parents=True, exist_ok=True) |
| | combined = self._add_text_to_frame(np.asarray(frame).copy(), text) |
| | if combined.ndim == 2: |
| | combined = cv2.cvtColor(combined, cv2.COLOR_GRAY2RGB) |
| | scale = 2 |
| | out_h, out_w = combined.shape[0] * scale, combined.shape[1] * scale |
| | combined = cv2.resize(combined, (out_w, out_h), interpolation=cv2.INTER_LINEAR) |
| | imageio.imwrite(str(output_path), combined) |
| |
|
    def _compute_segmentation_and_fill_subgoal(
        self,
        obs: Dict,
    ) -> Tuple[Optional[str], bool]:
        """
        Parse the base-camera segmentation from the observation, build the object-id
        mapping relevant to the current task, compute each target object's pixel
        centroid, and replace placeholders (like <target>) in the current subgoal
        text with concrete coordinates <y, x>.

        Supports latching: once a fill succeeds for a subgoal, the same replacements
        are reused on later frames; the latch is cleared when the subgoal changes.

        Args:
            obs: current-step observation; must contain
                sensor_data.base_camera.segmentation (rgb etc. optional).

        Returns:
            filled_text: subgoal text after placeholder replacement; equals the
                current subgoal (unmodified) when there is no subgoal or nothing
                could be replaced.
            failed_match: True when the text has placeholders but this frame could
                not produce a valid fill and no latch exists (used for saving
                failed frames etc.).
        """
        current_subgoal_segment = getattr(self.unwrapped, 'current_subgoal_segment', None)
        current_task_name = getattr(self, 'current_task_name', 'Unknown')

        # Segmentation may be absent depending on obs mode; treat failure as "none".
        segmentation = None
        try:
            segmentation = obs['sensor_data']['base_camera']['segmentation']
        except Exception:
            segmentation = None

        segmentation_2d = None
        active_segments = []
        segment_ids_by_index = {}
        vis_obj_id_list = []

        if segmentation is not None:
            # Normalize to a 2D numpy mask (drop batch/channel dims).
            if hasattr(segmentation, "cpu"):
                segmentation = segmentation.cpu().numpy()
            segmentation = np.asarray(segmentation)
            if segmentation.ndim > 2:
                segmentation = segmentation[0]
            segmentation_2d = segmentation.squeeze()

            # current_segment may be a single target or a list of targets.
            current_segment = getattr(self, "current_segment", None)
            if isinstance(current_segment, (list, tuple)):
                active_segments = list(current_segment)
            elif current_segment is None:
                active_segments = []
            else:
                active_segments = [current_segment]

            # Map each active segment index to the seg ids whose object IS that target
            # (identity comparison, not equality).
            segment_ids_by_index = {idx: [] for idx in range(len(active_segments))}
            segmentation_id_map = getattr(self, "segmentation_id_map", None)
            if isinstance(segmentation_id_map, dict):
                for obj_id, obj in sorted(segmentation_id_map.items()):
                    if active_segments:
                        for idx, target in enumerate(active_segments):
                            if obj is target:
                                vis_obj_id_list.append(obj_id)
                                segment_ids_by_index[idx].append(obj_id)
                                break

                    # Paint the table-workspace black in the visualization palette.
                    if getattr(obj, "name", None) == 'table-workspace':
                        self.color_map[obj_id] = [0, 0, 0]

        # Without a segmentation mask nothing can be filled.
        if segmentation_2d is None:
            return (current_subgoal_segment, False)

        def center_from_ids(segmentation_mask: np.ndarray, ids: List):
            """
            Compute the pixel centroid of the object(s) given a segmentation mask
            and an object-id list.
            Returns (center [y, x] or None, no_object_flag_this), where the flag is
            True when ids is non-empty but no matching pixels exist in the mask.
            """
            if not ids:
                return None, False
            mask = np.isin(segmentation_mask, ids)
            if not np.any(mask):
                return None, True
            coords = np.argwhere(mask)
            if coords.size == 0:
                return None, True
            center_y = int(coords[:, 0].mean())
            center_x = int(coords[:, 1].mean())
            return [center_y, center_x], False

        # Subgoal changed: drop the latched replacements.
        if current_subgoal_segment != self.last_subgoal_segment:
            self.last_subgoal_segment = current_subgoal_segment
            self.latched_replacements = None

        # One centroid per active segment; fall back to all visible ids when
        # no explicit segments are active.
        segment_centers = []
        no_object_flag = False
        if active_segments:
            for idx in range(len(active_segments)):
                center, no_obj = center_from_ids(segmentation_2d, segment_ids_by_index.get(idx, []))
                segment_centers.append(center)
                no_object_flag = no_object_flag or no_obj
        else:
            center, no_obj = center_from_ids(segmentation_2d, vis_obj_id_list)
            segment_centers.append(center)
            no_object_flag = no_obj

        # No subgoal text: nothing to fill.
        if not current_subgoal_segment:
            return (current_subgoal_segment, False)

        # Placeholders are any <...> spans in the subgoal text.
        placeholder_pattern = re.compile(r'<[^>]*>')
        placeholders = list(placeholder_pattern.finditer(current_subgoal_segment))
        placeholder_count = len(placeholders)

        final_replacements = None
        missing_placeholder = False

        # Prefer latched replacements; otherwise derive them from this frame's centroids.
        if self.latched_replacements is not None:
            final_replacements = self.latched_replacements
        else:
            normalized_centers = []
            for center in segment_centers:
                if center is None:
                    normalized_centers.append(None)
                    continue
                center_y, center_x = center
                normalized_centers.append(f'<{center_y}, {center_x}>')

            if placeholder_count > 0 and normalized_centers:
                replacements = normalized_centers.copy()
                # A single center fills every placeholder; too-few centers are padded
                # with None (which marks the frame as a failed match below).
                if len(replacements) == 1 and placeholder_count > 1:
                    replacements = replacements * placeholder_count
                elif len(replacements) < placeholder_count:
                    replacements.extend([None] * (placeholder_count - len(replacements)))

                # Only latch a complete (gap-free) set of replacements.
                temp_missing_placeholder = any(r is None for r in replacements)
                if not temp_missing_placeholder:
                    self.latched_replacements = replacements
                final_replacements = replacements

        # Splice replacements into the text around each placeholder match.
        if final_replacements and placeholder_count > 0:
            new_text_parts = []
            last_idx = 0
            for idx, match in enumerate(placeholders):
                new_text_parts.append(current_subgoal_segment[last_idx:match.start()])
                replacement_text = final_replacements[idx] if idx < len(final_replacements) else None
                if replacement_text is None:
                    missing_placeholder = True
                else:
                    new_text_parts.append(replacement_text)
                last_idx = match.end()
            new_text_parts.append(current_subgoal_segment[last_idx:])
            # On any gap, fall back to the plain task name instead of partial text.
            filled_text = current_task_name if missing_placeholder else ''.join(new_text_parts)

            failed_match = self.latched_replacements is None and (final_replacements is None or missing_placeholder)
            return (filled_text, failed_match)
        else:
            # Nothing replaceable this frame; failed only if placeholders exist and no latch.
            failed_match = placeholder_count > 0 and self.latched_replacements is None
            return (current_subgoal_segment, failed_match)
| |
|
    # Env ids driven by the stick end-effector: 7-dim actions and no gripper joints.
    _STICK_ENV_IDS = ("PatternLock", "RouteStick")
| |
|
| | def _normalize_action_for_env_step(self, action) -> np.ndarray: |
| | """ |
| | Normalize external action to the dimensionality required by the wrapped env.step. |
| | - PatternLock/RouteStick: accept len>=7 and pass first 7 dims. |
| | - Other envs: accept len>=8 and pass first 8 dims. |
| | """ |
| | env_spec = getattr(self.unwrapped, "spec", None) |
| | env_id = getattr(env_spec, "id", "<unknown_env>") |
| | action_arr = np.asarray(action, dtype=np.float64).flatten() |
| | if env_id in self._STICK_ENV_IDS: |
| | if action_arr.size < 7: |
| | raise ValueError(f"[{env_id}] action must have at least 7 elements, got {action_arr.size}") |
| | return action_arr[:7] |
| | if action_arr.size < 8: |
| | raise ValueError(f"[{env_id}] action must have at least 8 elements, got {action_arr.size}") |
| | return action_arr[:8] |
| |
|
| | @staticmethod |
| | def _flatten_info_batch(info_batch: dict) -> dict: |
| | """Convert columnar info dict-of-lists to flat dict by taking the last value of each key.""" |
| | return {k: v[-1] if isinstance(v, list) and v else v for k, v in info_batch.items()} |
| |
|
| | def _step_batch(self, action): |
| | """Internal step returning full batch format (dict-of-lists for both obs and info). |
| | |
| | Used by reset() and other internal callers that need batch-compatible output |
| | for concat_step_batches. |
| | """ |
| | normalized_action = self._normalize_action_for_env_step(action) |
| | obs, reward, terminated, truncated, info = super().step(normalized_action) |
| |
|
| | |
| | filled_text, failed_match = self._compute_segmentation_and_fill_subgoal(obs) |
| | current_subgoal_segment = getattr(self.unwrapped, 'current_subgoal_segment', None) |
| | self.current_subgoal_segment_filled = filled_text if filled_text is not None else current_subgoal_segment |
| |
|
| | |
| | if self.current_task_demonstration == False: |
| | self.steps_without_demonstration += 1 |
| | if self.steps_without_demonstration >= self.max_steps_without_demonstration: |
| | truncated = torch.tensor([True]) |
| |
|
| | |
| | if terminated.any(): |
| | if info.get("success") == torch.tensor([True]) or (isinstance(info.get("success"), torch.Tensor) and info.get("success").item()): |
| | self.episode_success = True |
| | |
| | else: |
| | self.episode_success = False |
| | |
| |
|
| | |
| | if terminated.any() and not self._doing_extra_step: |
| | |
| | |
| | |
| | cached_prev_quat = None if self._prev_ee_quat_wxyz is None else self._prev_ee_quat_wxyz.detach().clone() |
| | cached_prev_rpy = None if self._prev_ee_rpy_xyz is None else self._prev_ee_rpy_xyz.detach().clone() |
| | self._doing_extra_step = True |
| | try: |
| | self._step_batch(normalized_action) |
| | finally: |
| | self._doing_extra_step = False |
| | |
| | self._prev_ee_quat_wxyz = cached_prev_quat |
| | self._prev_ee_rpy_xyz = cached_prev_rpy |
| |
|
| | obs, info = self._augment_obs_and_info(obs, info, normalized_action) |
| |
|
| | |
| | raw_success = info.get("success") |
| | is_success = (isinstance(raw_success, torch.Tensor) and raw_success.item()) or raw_success is True |
| | if is_success: |
| | info["status"] = "success" |
| | elif terminated.any(): |
| | info["status"] = "fail" |
| | elif truncated.any(): |
| | info["status"] = "timeout" |
| | else: |
| | info["status"] = "ongoing" |
| |
|
| | return planner_denseStep.to_step_batch([(obs, reward, terminated, truncated, info)]) |
| |
|
| | def step(self, action): |
| | """Execute one step and return (obs_batch, reward, terminated, truncated, info). |
| | |
| | obs_batch is dict[str, list]; info is a flat dict (last values only). |
| | |
| | If an exception occurs during _step_batch(), the exception is caught and |
| | returned as a structured error via info["status"] = "error" and |
| | info["error_message"] = "<ExceptionType>: <message>", instead of propagating. |
| | Callers should check ``info.get("status") == "error"`` to detect step failures. |
| | """ |
| | batch = self._step_batch(action) |
| | obs_batch, reward_batch, terminated_batch, truncated_batch, info_batch = batch |
| | info_flat = self._flatten_info_batch(info_batch) |
| | return (obs_batch, reward_batch[-1], terminated_batch[-1], truncated_batch[-1], info_flat) |
| |
|
| | def close(self): |
| | """Close environment, release resources (this wrapper no longer saves video).""" |
| | super().close() |
| | return None |
| |
|
    def get_demonstration_trajectory(self):
        """
        Generate the demonstration trajectory.

        Flow:
            1. Select the appropriate motion planner (PandaArm or PandaStick) based
               on the environment id; fail-aware variants are preferred when the
               planner_fail_safe module imports successfully.
            2. Iterate the task list and pick the tasks marked as demonstrations.
            3. For each demonstration task, wrap the whole solve call with
               planner_denseStep._collect_dense_steps, which monkey-patches
               planner.env.step so that every env.step call is collected
               (move_to_pose_with_screw, follow_path, direct env.step, and all
               other paths).
            4. Return a unified batch (obs/info dict values as lists,
               reward/terminated/truncated as 1D tensors).
        """
        # Fail-aware planners raise/flag instead of hard-crashing; fall back to the
        # base planners (and RuntimeError as the screw-failure type) when the
        # optional module is unavailable.
        try:
            from ..robomme_env.utils.planner_fail_safe import (
                FailAwarePandaArmMotionPlanningSolver,
                FailAwarePandaStickMotionPlanningSolver,
                ScrewPlanFailure,
            )
        except Exception as exc:
            logger.debug(f"[DemonstrationWrapper] Warning: failed to import planner_fail_safe, fallback to base planners: {exc}")
            FailAwarePandaArmMotionPlanningSolver = PandaArmMotionPlanningSolver
            FailAwarePandaStickMotionPlanningSolver = PandaStickMotionPlanningSolver
            ScrewPlanFailure = RuntimeError

        # Stick envs get the stick solver (with a reduced joint velocity limit);
        # everything else uses the arm solver.
        if self.unwrapped.spec.id == "PatternLock" or self.unwrapped.spec.id == "RouteStick":
            planner = FailAwarePandaStickMotionPlanningSolver(
                self,
                debug=False,
                vis=False,
                base_pose=self.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
                joint_vel_limits=0.3,
            )
        else:
            planner = FailAwarePandaArmMotionPlanningSolver(
                self,
                debug=False,
                vis=False,
                base_pose=self.unwrapped.agent.robot.pose,
                visualize_target_grasp_pose=False,
                print_env_info=False,
            )

        # Keep handles to the original planner methods before monkey-patching.
        original_move_to_pose_with_screw = planner.move_to_pose_with_screw
        original_move_to_pose_with_rrt = planner.move_to_pose_with_RRTStar

        def _move_to_pose_with_screw_then_rrt_retry(*args, **kwargs):
            """Try screw planning up to _demo_screw_max_attempts, then fall back to
            RRT* up to _demo_rrt_max_attempts; return -1 (and set the failure flag)
            when every attempt fails or returns -1."""
            for attempt in range(1, self._demo_screw_max_attempts + 1):
                try:
                    result = original_move_to_pose_with_screw(*args, **kwargs)
                except ScrewPlanFailure as exc:
                    logger.debug(
                        f"[DemonstrationWrapper] screw planning failed "
                        f"(attempt {attempt}/{self._demo_screw_max_attempts}): {exc}"
                    )
                    continue

                # -1 is the planners' soft-failure return code.
                if isinstance(result, int) and result == -1:
                    logger.debug(
                        f"[DemonstrationWrapper] screw planning returned -1 "
                        f"(attempt {attempt}/{self._demo_screw_max_attempts})"
                    )
                    continue

                return result

            logger.debug(
                "[DemonstrationWrapper] screw planning exhausted; "
                f"fallback to RRT* (max {self._demo_rrt_max_attempts} attempts)"
            )

            for attempt in range(1, self._demo_rrt_max_attempts + 1):
                try:
                    result = original_move_to_pose_with_rrt(*args, **kwargs)
                except Exception as exc:
                    # Broad catch is deliberate: any RRT* failure counts as a retry.
                    logger.debug(
                        f"[DemonstrationWrapper] RRT* planning failed "
                        f"(attempt {attempt}/{self._demo_rrt_max_attempts}): {exc}"
                    )
                    continue

                if isinstance(result, int) and result == -1:
                    logger.debug(
                        f"[DemonstrationWrapper] RRT* planning returned -1 "
                        f"(attempt {attempt}/{self._demo_rrt_max_attempts})"
                    )
                    continue

                return result

            # All attempts exhausted: flag the current demo task as failed.
            self._current_demo_task_screw_failed = True
            logger.debug("[DemonstrationWrapper] screw->RRT* planning exhausted; return -1")
            return -1

        # All screw-motion requests now go through the retry/fallback wrapper.
        planner.move_to_pose_with_screw = _move_to_pose_with_screw_then_rrt_retry
        tasks = getattr(self, 'task_list', [])
        self.task_list_length = len(tasks)
        logger.debug(f"Task list length: {self.task_list_length}")

        demonstration_tasks = [task for task in tasks if task.get("demonstration", False)]
        self.non_demonstration_task_length = len(tasks) - len(demonstration_tasks)
        logger.debug(f"Non-demonstration task length: {self.non_demonstration_task_length}")

        all_collected_steps = []

        for idx, task_entry in enumerate(demonstration_tasks):
            self.unwrapped.demonstration_record_traj = True
            self._current_demo_task_screw_failed = False
            task_name = task_entry.get("name", f"Task {idx}")
            logger.debug(f"Executing task {idx+1}/{len(demonstration_tasks)}: {task_name}")

            solve_callable = task_entry.get("solve")
            if not callable(solve_callable):
                raise ValueError(f"Task '{task_name}' must supply a callable 'solve'.")

            # Refresh env evaluation before (and after) solving each task.
            self.evaluate(solve_complete_eval=True)

            def _solve_task_without_hard_fail():
                """Run the task's solve; convert screw failures and -1 returns into
                a None result plus the failure flag instead of propagating."""
                try:
                    solve_result = solve_callable(self, planner)
                except ScrewPlanFailure as exc:
                    self._current_demo_task_screw_failed = True
                    logger.debug(f"[DemonstrationWrapper] task '{task_name}' screw failure: {exc}")
                    return None
                if isinstance(solve_result, int) and solve_result == -1:
                    self._current_demo_task_screw_failed = True
                    logger.debug(f"[DemonstrationWrapper] task '{task_name}' returned -1 after screw->RRT* retries")
                    return None
                return solve_result

            # _collect_dense_steps records every env.step triggered inside solve.
            task_steps = planner_denseStep._collect_dense_steps(
                planner,
                _solve_task_without_hard_fail,
            )
            if task_steps == -1:
                # Collector-level failure: skip this task's steps but keep going.
                logger.debug(f"[DemonstrationWrapper] task '{task_name}' returned -1 from collector; continuing")
            else:
                all_collected_steps.extend(task_steps)

            if self._current_demo_task_screw_failed:
                logger.debug(f"[DemonstrationWrapper] task '{task_name}' marked failed after screw->RRT* retries; continuing")
            self.evaluate(solve_complete_eval=True)

        self.unwrapped.demonstration_record_traj = False
        return planner_denseStep.to_step_batch(all_collected_steps)
| |
|