| import os |
| import glob |
| import random |
| from typing import List, Dict, Any |
| from pathlib import Path |
| import json |
|
|
| import open3d |
| import traceback |
| from tqdm import tqdm |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import einops |
|
|
| from rlbench.observation_config import ObservationConfig, CameraConfig |
| from rlbench.environment import Environment |
| from rlbench.task_environment import TaskEnvironment |
| from rlbench.action_modes.action_mode import MoveArmThenGripper |
| from rlbench.action_modes.gripper_action_modes import Discrete |
| from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaPlanning |
| from rlbench.backend.exceptions import InvalidActionError |
| from rlbench.demo import Demo |
| from pyrep.errors import IKError, ConfigurationPathError |
| from pyrep.const import RenderMode |
|
|
|
|
# Every task name this module can map to an integer id. The trailing entries
# (bimanual_*, coordinated_*, dual_*, handover_*) appear to be the bimanual
# tasks. Ids are list positions, so new tasks should be appended at the end to
# keep existing ids stable.
ALL_RLBENCH_TASKS = [
    'basketball_in_hoop', 'beat_the_buzz', 'change_channel', 'change_clock', 'close_box',
    'close_door', 'close_drawer', 'close_fridge', 'close_grill', 'close_jar', 'close_laptop_lid',
    'close_microwave', 'hang_frame_on_hanger', 'insert_onto_square_peg', 'insert_usb_in_computer',
    'lamp_off', 'lamp_on', 'lift_numbered_block', 'light_bulb_in', 'meat_off_grill', 'meat_on_grill',
    'move_hanger', 'open_box', 'open_door', 'open_drawer', 'open_fridge', 'open_grill',
    'open_microwave', 'open_oven', 'open_window', 'open_wine_bottle', 'phone_on_base',
    'pick_and_lift', 'pick_and_lift_small', 'pick_up_cup', 'place_cups', 'place_hanger_on_rack',
    'place_shape_in_shape_sorter', 'place_wine_at_rack_location', 'play_jenga',
    'plug_charger_in_power_supply', 'press_switch', 'push_button', 'push_buttons', 'put_books_on_bookshelf',
    'put_groceries_in_cupboard', 'put_item_in_drawer', 'put_knife_on_chopping_board', 'put_money_in_safe',
    'put_rubbish_in_bin', 'put_umbrella_in_umbrella_stand', 'reach_and_drag', 'reach_target',
    'scoop_with_spatula', 'screw_nail', 'setup_checkers', 'slide_block_to_color_target',
    'slide_block_to_target', 'slide_cabinet_open_and_place_cups', 'stack_blocks', 'stack_cups',
    'stack_wine', 'straighten_rope', 'sweep_to_dustpan', 'sweep_to_dustpan_of_size', 'take_frame_off_hanger',
    'take_lid_off_saucepan', 'take_money_out_safe', 'take_plate_off_colored_dish_rack', 'take_shoes_out_of_box',
    'take_toilet_roll_off_stand', 'take_umbrella_out_of_umbrella_stand', 'take_usb_out_of_computer',
    'toilet_seat_down', 'toilet_seat_up', 'tower3', 'turn_oven_on', 'turn_tap', 'tv_on', 'unplug_charger',
    'water_plants', 'wipe_desk', 'bimanual_pick_laptop','bimanual_pick_plate','bimanual_straighten_rope',
    'coordinated_lift_ball','coordinated_lift_tray','coordinated_push_box','coordinated_put_bottle_in_fridge','dual_push_buttons',
    'handover_item','bimanual_sweep_to_dustpan','coordinated_take_tray_out_of_oven','handover_item_easy'
]
# Task name -> index in ALL_RLBENCH_TASKS; Actioner.load_episode uses it as the task id.
TASK_TO_ID = {task: i for i, task in enumerate(ALL_RLBENCH_TASKS)}
|
|
|
|
def task_file_to_task_class(task_file):
    """Import ``rlbench.tasks.<name>`` and return its task class.

    ``task_file`` may be a bare task name (``"close_jar"``) or a file name
    (``"close_jar.py"``); every ``".py"`` substring is removed. The class name
    is derived by upper-casing the first letter of each ``_``-separated word
    and joining them (``close_jar`` -> ``CloseJar``). The module is reloaded
    after import, so repeated calls pick up a fresh module object.
    """
    import importlib

    module_name = task_file.replace(".py", "")
    words = module_name.split("_")
    class_name = "".join(word[0].upper() + word[1:] for word in words)

    task_module = importlib.import_module("rlbench.tasks.%s" % module_name)
    task_module = importlib.reload(task_module)
    return getattr(task_module, class_name)
|
|
|
|
def load_episodes() -> Dict[str, Any]:
    """Load ``data_preprocessing/episodes.json`` located two directories above this file.

    The file is decoded as UTF-8 (the encoding JSON text is required to use)
    instead of the platform's locale default, which differs e.g. on Windows.

    :return: the parsed JSON object.
    """
    episodes_path = Path(__file__).parent.parent / "data_preprocessing/episodes.json"
    with open(episodes_path, encoding="utf-8") as fid:
        return json.load(fid)
|
|
|
|
class Mover:
    """Execute target gripper poses on a task, retrying when the arm falls short.

    An ``action`` is an array whose entries 0-2 are the target gripper
    position, entries 3-6 the target gripper quaternion and entry 7 the
    gripper state. The arm is first moved while keeping the *previous*
    gripper state; if the requested gripper state differs from the previous
    one, it is applied in one additional step.
    """

    def __init__(self, task, disabled=False, max_tries=1):
        """
        :param task: object exposing ``step(action) -> (obs, reward, terminate)``.
        :param disabled: when True, ``__call__`` forwards actions to ``task.step`` unchanged.
        :param max_tries: number of attempts made to reach the target position.
        """
        self._task = task
        self._last_action = None
        self._step_id = 0
        self._max_tries = max_tries
        self._disabled = disabled

    @staticmethod
    def _with_collision_flag(action, collision_checking):
        """Return ``action`` with one extra trailing entry appended.

        The extra entry is 0 when ``collision_checking`` is requested and 1
        otherwise (presumably RLBench's ``ignore_collisions`` flag).
        """
        stepped = np.ones(action.shape[0] + 1)
        stepped[:-1] = action
        if collision_checking:
            stepped[-1] = 0
        return stepped

    def __call__(self, action, collision_checking=False):
        """Move towards ``action`` and return ``(obs, reward, terminate, images)``.

        The arm step is retried up to ``max_tries`` times until the reached
        gripper position is within 5 mm (Euclidean) of the target or the task
        reports ``reward == 1``. ``images`` is always an empty list.
        The caller's ``action`` array is never modified.
        """
        if self._disabled:
            # NOTE(review): returns task.step's 3-tuple as-is, unlike the 4-tuple below.
            return self._task.step(action)

        target = action.copy()
        # Work on a private copy: callers may pass a row view of a larger
        # trajectory array, which must not be altered.
        action = action.copy()
        if self._last_action is not None:
            # Move the arm with the previous gripper state first; the new
            # gripper state is applied separately after the arm motion.
            action[7] = self._last_action[7]

        images = []
        obs = None
        terminate = None
        reward = 0
        reached = False

        for _ in range(self._max_tries):
            obs, reward, terminate = self._task.step(
                self._with_collision_flag(action, collision_checking)
            )

            pos = obs.gripper_pose[:3]
            rot = obs.gripper_pose[3:7]
            dist_pos = np.sqrt(np.square(target[:3] - pos).sum())
            dist_rot = np.sqrt(np.square(target[3:7] - rot).sum())
            # Only the position error decides success; dist_rot is just logged.
            if dist_pos < 5e-3 or reward == 1:
                reached = True
                break

            print(
                f"Too far away (pos: {dist_pos:.3f}, rot: {dist_rot:.3f}, step: {self._step_id})... Retrying..."
            )

        # Execute the gripper action after the arm re-tries.
        action = target
        if (
            reward != 1.0
            and self._last_action is not None
            and action[7] != self._last_action[7]
        ):
            obs, reward, terminate = self._task.step(
                self._with_collision_flag(action, collision_checking)
            )

        # The previous check compared the loop variable (at most max_tries - 1)
        # to max_tries and could never fire; track success explicitly instead.
        if not reached:
            print(f"Failure after {self._max_tries} tries")

        self._step_id += 1
        self._last_action = action.copy()

        return obs, reward, terminate, images
|
|
|
|
class Actioner:
    """Run a policy network on observations during evaluation.

    ``load_episode`` selects the instruction and task id for the current
    episode; ``predict`` then returns either a predicted trajectory or a
    single predicted action.
    """

    def __init__(
        self,
        policy=None,
        instructions=None,
        apply_cameras=("over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"),
        action_dim=7,
        predict_trajectory=True
    ):
        """
        :param policy: network called by ``predict``; switched to eval mode when given.
        :param instructions: ``instructions[task_str][variation]`` -> iterable of
            instruction tensors (presumably; ``.unsqueeze(0)`` is called on them).
        :param apply_cameras: camera names (stored, not used by this class).
        :param action_dim: number of gripper entries passed in keypose mode.
        :param predict_trajectory: predict a trajectory (True) or a keypose (False).
        """
        self._policy = policy
        self._instructions = instructions
        self._apply_cameras = apply_cameras
        self._action_dim = action_dim
        self._predict_trajectory = predict_trajectory

        self._actions = {}
        self._instr = None
        self._task_id = None
        self._task_str = None

        # ``policy`` defaults to None; calling .eval() on it would crash.
        if self._policy is not None:
            self._policy.eval()

    def load_episode(self, task_str, variation):
        """Pick a random instruction for ``(task_str, variation)`` and set the task id."""
        self._task_str = task_str
        instructions = list(self._instructions[task_str][variation])
        self._instr = random.choice(instructions).unsqueeze(0)
        self._task_id = torch.tensor(TASK_TO_ID[task_str]).unsqueeze(0)
        self._actions = {}

    def get_action_from_demo(self, demo):
        """
        Build per-keyframe actions and the trajectories leading to them.

        For keyframe ``k_i`` (from ``keypoint_discovery``), the action is
        ``[gripper_pose, gripper_open]`` at ``k_i``; the trajectory stacks the
        same vector for frames ``k_{i-1} .. k_i - 1`` (starting at 0 for the
        first keyframe); the mask is an all-False ``(1, k_i - k_{i-1})`` tensor.

        NOTE(review): reads ``obs.gripper_pose`` directly, whereas
        ``keypoint_discovery`` reads per-arm ``obs.right`` / ``obs.left`` —
        confirm this works for bimanual demos.

        :param demo: demo to extract keyframes from
        :return: (action_ls, trajectory_ls, trajectory_mask_ls)
        """
        key_frame = keypoint_discovery(demo)

        action_ls = []
        trajectory_ls = []
        trajectory_mask_ls = []
        for i, frame in enumerate(key_frame):
            start = key_frame[i - 1] if i > 0 else 0

            obs = demo[frame]
            action_np = np.concatenate([obs.gripper_pose, [obs.gripper_open]])
            action_ls.append(torch.from_numpy(action_np).unsqueeze(0))

            trajectory_ls.append(np.stack([
                np.concatenate([demo[j].gripper_pose, [demo[j].gripper_open]])
                for j in range(start, frame)
            ]))
            trajectory_mask_ls.append(torch.zeros(1, frame - start).bool())

        return action_ls, trajectory_ls, trajectory_mask_ls

    def predict(self, rgbs, pcds, gripper,
                interpolation_length=None):
        """
        Args:
            rgbs: (bs, num_hist, num_cameras, 3, H, W)
            pcds: (bs, num_hist, num_cameras, 3, H, W)
            gripper: (B, nhist, output_dim)
            interpolation_length: an integer; required in trajectory mode

        Returns:
            {"action": torch.Tensor, "trajectory": torch.Tensor}; the entry not
            produced by the current mode is None.

        Raises:
            ValueError: if ``load_episode`` has not been called.
            TypeError: if trajectory mode is used without ``interpolation_length``.
        """
        output = {"action": None, "trajectory": None}

        if self._instr is None:
            raise ValueError("No instruction loaded; call load_episode() before predict().")

        # Undo the [-1, 1] scaling applied by ``transform`` -> [0, 1].
        rgbs = rgbs / 2 + 0.5

        self._instr = self._instr.to(rgbs.device)
        if self._task_id is not None:
            self._task_id = self._task_id.to(rgbs.device)

        if self._predict_trajectory:
            if interpolation_length is None:
                raise TypeError("interpolation_length is required when predicting trajectories")
            print('Predict Trajectory')
            # Placeholder trajectory / mask: presumably only their shapes matter
            # to the policy when run_inference=True — confirm with the policy.
            fake_traj = torch.full(
                [1, interpolation_length - 1, gripper.shape[-1]], 0
            ).to(rgbs.device)
            traj_mask = torch.full(
                [1, interpolation_length - 1], False
            ).to(rgbs.device)
            output["trajectory"] = self._policy(
                fake_traj,
                traj_mask,
                rgbs[:, -1],
                pcds[:, -1],
                self._instr,
                gripper[..., :7],
                run_inference=True
            )
        else:
            print('Predict Keypose')
            pred = self._policy(
                rgbs[:, -1],
                pcds[:, -1],
                self._instr,
                gripper[:, -1, :self._action_dim],
            )
            output["action"] = self._policy.prepare_action(pred)

        return output

    @property
    def device(self):
        """Device of the policy's first parameter."""
        return next(self._policy.parameters()).device
|
|
|
|
def _project_point_to_pixel(obs, camera, point_3):
    """Project a 3-D point into ``camera``'s image and return integer pixel ``(u, v)``.

    Uses the inverse of ``obs.misc["<camera>_camera_extrinsics"]`` (presumably
    a camera-to-world pose) and ``obs.misc["<camera>_camera_intrinsics"]``.
    NOTE(review): no check for points behind the camera (depth <= 0).
    """
    extrinsics_44 = torch.from_numpy(obs.misc[f"{camera}_camera_extrinsics"]).float()
    inv_extrinsics_44 = torch.linalg.inv(extrinsics_44)
    intrinsics_33 = torch.from_numpy(obs.misc[f"{camera}_camera_intrinsics"]).float()
    # Append a zero column so K can multiply homogeneous 4-vectors.
    intrinsics_34 = F.pad(intrinsics_33, (0, 1, 0, 0))

    point_41 = F.pad(torch.from_numpy(point_3).float(), (0, 1), value=1).unsqueeze(1)
    points_cam_41 = inv_extrinsics_44 @ point_41
    proj_3 = (intrinsics_34 @ points_cam_41).squeeze(1)

    u = int((proj_3[0] / proj_3[2]).round())
    v = int((proj_3[1] / proj_3[2]).round())
    return u, v


def obs_to_attn(obs, camera):
    """Pixel ``(u, v)`` of ``obs.gripper_pose``'s position in ``camera``'s image."""
    return _project_point_to_pixel(obs, camera, obs.gripper_pose[:3])


def obs_to_attn_right(obs, camera):
    """Pixel ``(u, v)`` of the right gripper (``obs.right.gripper_pose``) in ``camera``'s image."""
    return _project_point_to_pixel(obs, camera, obs.right.gripper_pose[:3])


def obs_to_attn_left(obs, camera):
    """Pixel ``(u, v)`` of the left gripper (``obs.left.gripper_pose``) in ``camera``'s image."""
    return _project_point_to_pixel(obs, camera, obs.left.gripper_pose[:3])
|
|
class RLBenchEnv:
    """Wrapper around an RLBench ``Environment`` used for evaluation.

    Configures cameras and the action mode, converts observations into the
    tensors consumed by ``Actioner``, loads stored demos, and runs policy
    rollouts (``evaluate_task_on_multiple_variations``) or ground-truth demo
    replays (``verify_demos``).
    """

    # Cameras read by get_obs_action_right / get_obs_action_left respectively.
    _RIGHT_ARM_CAMERAS = ("over_shoulder_right", "overhead", "wrist_right", "front")
    _LEFT_ARM_CAMERAS = ("over_shoulder_left", "overhead", "wrist_left", "front")

    def __init__(
        self,
        data_path,
        image_size=(256,256),
        apply_rgb=False,
        apply_depth=False,
        apply_pc=False,
        headless=False,
        apply_cameras=("over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"),
        fine_sampling_ball_diameter=None,
        collision_checking=False
    ):
        """
        :param data_path: dataset directory passed to ``Environment`` and searched for variations.
        :param image_size: camera image size, also the size of the gripper attention maps.
        :param apply_rgb/apply_depth/apply_pc: which modalities the enabled cameras render.
        :param headless: run the simulator without a window.
        :param apply_cameras: cameras to enable.
        :param fine_sampling_ball_diameter: stored only.
        :param collision_checking: forwarded to ``EndEffectorPoseViaPlanning``.
        """
        self.data_path = data_path
        self.apply_rgb = apply_rgb
        self.apply_depth = apply_depth
        self.apply_pc = apply_pc
        self.apply_cameras = apply_cameras
        self.fine_sampling_ball_diameter = fine_sampling_ball_diameter

        self.obs_config = self.create_obs_config(
            image_size, apply_rgb, apply_depth, apply_pc, apply_cameras
        )

        # Arm targets are end-effector poses reached by planning; gripper is discrete.
        self.action_mode = MoveArmThenGripper(
            arm_action_mode=EndEffectorPoseViaPlanning(collision_checking=collision_checking),
            gripper_action_mode=Discrete()
        )
        self.env = Environment(
            self.action_mode, str(data_path), self.obs_config,
            headless=headless
        )
        self.image_size = image_size

    def _collect_perception(self, obs, cameras):
        """Gather the enabled rgb / depth / point-cloud images of ``cameras`` from ``obs``."""
        state_dict = {"rgb": [], "depth": [], "pc": []}
        for cam in cameras:
            if self.apply_rgb:
                state_dict["rgb"].append(obs.perception_data["{}_rgb".format(cam)])
            if self.apply_depth:
                state_dict["depth"].append(obs.perception_data["{}_depth".format(cam)])
            if self.apply_pc:
                state_dict["pc"].append(obs.perception_data["{}_point_cloud".format(cam)])
        return state_dict

    @staticmethod
    def _arm_action(arm):
        """``[gripper_pose, gripper_open]`` of one arm as a float32 tensor."""
        return torch.from_numpy(np.concatenate([arm.gripper_pose, [arm.gripper_open]])).float()

    def get_obs_action(self, obs):
        """
        Collect images of ``self.apply_cameras`` and both arms' gripper states.
        :param obs: incoming obs
        :return: (state_dict, right_action, left_action)
        """
        state_dict = self._collect_perception(obs, self.apply_cameras)
        return state_dict, self._arm_action(obs.right), self._arm_action(obs.left)

    def get_obs_action_right(self, obs):
        """
        Collect images of the right-arm cameras and the right gripper state.
        :param obs: incoming obs
        :return: (state_dict, right_action)
        """
        state_dict = self._collect_perception(obs, self._RIGHT_ARM_CAMERAS)
        return state_dict, self._arm_action(obs.right)

    def get_obs_action_left(self, obs):
        """
        Collect images of the left-arm cameras and the left gripper state.
        :param obs: incoming obs
        :return: (state_dict, left_action)
        """
        state_dict = self._collect_perception(obs, self._LEFT_ARM_CAMERAS)
        return state_dict, self._arm_action(obs.left)

    def get_rgb_pcd_gripper_from_obs(self, obs):
        """
        Return rgb, pcd, and gripper from a given observation.

        Expects exactly rgb + point cloud (no depth) per camera, matching the
        ``m=2`` rearrange below. Each camera's rgb gets a 4th channel: a map
        that is 1 at the projected gripper pixel and 0 elsewhere.
        :param obs: an Observation from the env
        :return: rgb (1, ncam, 4, H, W), pcd (1, ncam, 3, H, W), gripper
        """
        # NOTE(review): get_obs_action returns (state, right, left), so this
        # two-value unpack raises ValueError, and obs_to_attn reads
        # obs.gripper_pose while the rest of this class uses obs.right/obs.left.
        # Decide how both arms' grippers are fed to the policy before using this.
        state_dict, gripper = self.get_obs_action(obs)
        state = transform(state_dict, augmentation=False)
        state = einops.rearrange(
            state,
            "(m n ch) h w -> n m ch h w",
            ch=3,
            n=len(self.apply_cameras),
            m=2
        )
        rgb = state[:, 0].unsqueeze(0)
        pcd = state[:, 1].unsqueeze(0)
        gripper = gripper.unsqueeze(0)

        height, width = self.image_size[0], self.image_size[1]
        attn_maps = []
        for cam in self.apply_cameras:
            u, v = obs_to_attn(obs, cam)
            attn = torch.zeros(1, 1, 1, height, width)
            if 0 <= u < width and 0 <= v < height:
                attn[0, 0, 0, v, u] = 1
            attn_maps.append(attn)
        rgb = torch.cat([rgb, torch.cat(attn_maps, 1)], 2)

        return rgb, pcd, gripper

    def get_obs_action_from_demo(self, demo):
        """
        Transformed states and actions at frame 0 and every keyframe of ``demo``.
        :param demo: demo to extract key-point observations from
        :return: (state_ls, action_ls)
        """
        key_frame = keypoint_discovery(demo)
        key_frame.insert(0, 0)
        state_ls = []
        action_ls = []
        for f in key_frame:
            # NOTE(review): get_obs_action returns three values; this unpack fails.
            state, action = self.get_obs_action(demo._observations[f])
            state = transform(state, augmentation=False)
            state_ls.append(state.unsqueeze(0))
            action_ls.append(action.unsqueeze(0))
        return state_ls, action_ls

    def get_gripper_matrix_from_action(self, action):
        """4x4 homogeneous matrix from ``action`` = [x, y, z, qx, qy, qz, qw, ...]."""
        action = action.cpu().numpy()
        position = action[:3]
        quaternion = action[3:7]
        # open3d expects (w, x, y, z) ordering.
        rotation = open3d.geometry.get_rotation_matrix_from_quaternion(
            np.array((quaternion[3], quaternion[0], quaternion[1], quaternion[2]))
        )
        gripper_matrix = np.eye(4)
        gripper_matrix[:3, :3] = rotation
        gripper_matrix[:3, 3] = position
        return gripper_matrix

    def get_demo(self, task_name, variation, episode_index):
        """
        Fetch a demo from the saved environment.
        :param task_name: fetch task name
        :param variation: fetch variation id
        :param episode_index: fetch episode index: 0 ~ 99
        :return: list holding the requested demo
        """
        demos = self.env.get_demos(
            task_name=task_name,
            variation_number=variation,
            amount=1,
            from_episode_number=episode_index,
            random_selection=False
        )
        return demos

    def evaluate_task_on_multiple_variations(
        self,
        task_str: str,
        max_steps: int,
        num_variations: int,
        num_demos: int,
        actioner: Actioner,
        max_tries: int = 1,
        verbose: bool = False,
        dense_interpolation=False,
        interpolation_length=100,
        num_history=1,
    ):
        """Evaluate ``actioner`` on several variations of ``task_str``.

        With ``num_variations > 0`` the first ``num_variations`` variations are
        used; otherwise every ``variation*`` directory under
        ``data_path/task_str`` (sorted by number).

        :return: dict variation -> number of successes, plus ``"mean"`` = total
            successes / total valid demos (0.0 when there are no valid demos).
        """
        var_success_rates = {}
        var_num_valid_demos = {}

        self.env.launch()
        try:
            task_type = task_file_to_task_class(task_str)
            task = self.env.get_task(task_type)
            task_variations = task.variation_count()

            if num_variations > 0:
                task_variations = range(min(num_variations, task_variations))
            else:
                variation_dirs = glob.glob(os.path.join(self.data_path, task_str, "variation*"))
                task_variations = sorted(
                    int(os.path.basename(n).replace('variation', ''))
                    for n in variation_dirs
                )

            # Demo budget per variation; max() only guards the division.
            demos_per_variation = num_demos // max(len(task_variations), 1) + 1

            for variation in task_variations:
                task.set_variation(variation)
                success_rate, valid, num_valid_demos = (
                    self._evaluate_task_on_one_variation(
                        task_str=task_str,
                        task=task,
                        max_steps=max_steps,
                        variation=variation,
                        num_demos=demos_per_variation,
                        actioner=actioner,
                        max_tries=max_tries,
                        verbose=verbose,
                        dense_interpolation=dense_interpolation,
                        interpolation_length=interpolation_length,
                        num_history=num_history
                    )
                )
                if valid:
                    var_success_rates[variation] = success_rate
                    var_num_valid_demos[variation] = num_valid_demos
        finally:
            # Always release the simulator, even if evaluation raised.
            self.env.shutdown()

        total_valid = sum(var_num_valid_demos.values())
        var_success_rates["mean"] = (
            sum(var_success_rates.values()) / total_valid if total_valid > 0 else 0.0
        )

        return var_success_rates

    @torch.no_grad()
    def _evaluate_task_on_one_variation(
        self,
        task_str: str,
        task: TaskEnvironment,
        max_steps: int,
        variation: int,
        num_demos: int,
        actioner: Actioner,
        max_tries: int = 1,
        verbose: bool = False,
        dense_interpolation=False,
        interpolation_length=50,
        num_history=0,
    ):
        """Roll out ``actioner`` from the initial state of each stored demo.

        Demos that fail to load are skipped and not counted as valid.
        :return: (number of successful episodes, valid, number of valid demos)
        """
        device = actioner.device

        success_rate = 0
        num_valid_demos = 0
        total_reward = 0

        for demo_id in range(num_demos):
            if verbose:
                print()
                print(f"Starting demo {demo_id}")

            try:
                demo = self.get_demo(task_str, variation, episode_index=demo_id)[0]
                num_valid_demos += 1
            except Exception:
                continue

            rgbs = torch.Tensor([]).to(device)
            pcds = torch.Tensor([]).to(device)
            grippers = torch.Tensor([]).to(device)

            # Reset the scene to the demo's initial state.
            _, obs = task.reset_to_demo(demo)

            actioner.load_episode(task_str, variation)

            move = Mover(task, max_tries=max_tries)
            reward = 0.0
            max_reward = 0.0

            for step_id in range(max_steps):
                rgb, pcd, gripper = self.get_rgb_pcd_gripper_from_obs(obs)
                rgb = rgb.to(device)
                pcd = pcd.to(device)
                gripper = gripper.to(device)

                rgbs = torch.cat([rgbs, rgb.unsqueeze(1)], dim=1)
                pcds = torch.cat([pcds, pcd.unsqueeze(1)], dim=1)
                grippers = torch.cat([grippers, gripper.unsqueeze(1)], dim=1)

                # Only the latest frame is fed to the policy; the rgb slice
                # drops the appended gripper-attention channel.
                rgbs_input = rgbs[:, -1:][:, :, :, :3]
                pcds_input = pcds[:, -1:]
                if num_history < 1:
                    gripper_input = grippers[:, -1]
                else:
                    # Last num_history gripper states, front-padded in
                    # 'replicate' mode when fewer are available.
                    gripper_input = grippers[:, -num_history:]
                    npad = num_history - gripper_input.shape[1]
                    gripper_input = F.pad(
                        gripper_input, (0, 0, npad, 0), mode='replicate'
                    )

                output = actioner.predict(
                    rgbs_input,
                    pcds_input,
                    gripper_input,
                    interpolation_length=interpolation_length
                )

                if verbose:
                    print(f"Step {step_id}")

                terminate = True

                try:
                    if output.get("trajectory", None) is not None:
                        trajectory = output["trajectory"][-1].cpu().numpy()
                        # Binarise the gripper-state column.
                        trajectory[:, -1] = trajectory[:, -1].round()

                        for action in tqdm(trajectory):
                            collision_checking = self._collision_checking(task_str, step_id)
                            obs, reward, terminate, _ = move(action, collision_checking=collision_checking)
                    else:
                        print("Plan with RRT")
                        action = output["action"]
                        action[..., -1] = torch.round(action[..., -1])
                        action = action[-1].detach().cpu().numpy()

                        collision_checking = self._collision_checking(task_str, step_id)
                        obs, reward, terminate, _ = move(action, collision_checking=collision_checking)

                    max_reward = max(max_reward, reward)

                    if reward == 1:
                        success_rate += 1
                        break

                    if terminate:
                        print("The episode has terminated!")

                except (IKError, ConfigurationPathError, InvalidActionError) as e:
                    print(task_str, demo, step_id, success_rate, e)
                    reward = 0

            total_reward += max_reward

            print(
                task_str,
                "Variation",
                variation,
                "Demo",
                demo_id,
                "Reward",
                f"{reward:.2f}",
                "max_reward",
                f"{max_reward:.2f}",
                f"SR: {success_rate}/{demo_id+1}",
                f"SR: {total_reward:.2f}/{demo_id+1}",
                "# valid demos", num_valid_demos,
            )

        if num_valid_demos == 0:
            assert success_rate == 0
            valid = False
        else:
            valid = True

        return success_rate, valid, num_valid_demos

    def _collision_checking(self, task_str, step_id):
        """Collision checking for planner: currently always False (arguments unused)."""
        collision_checking = False
        return collision_checking

    def verify_demos(
        self,
        task_str: str,
        variation: int,
        num_demos: int,
        max_tries: int = 1,
        verbose: bool = False,
    ):
        """Replay each demo's ground-truth keyframe poses and measure success.

        Demos that fail to load are reported, counted as invalid and skipped.
        :return: (success rate over valid demos, valid, number of invalid demos)
        """
        if verbose:
            print()
            print(f"{task_str}, variation {variation}, {num_demos} demos")

        self.env.launch()
        try:
            task_type = task_file_to_task_class(task_str)
            task = self.env.get_task(task_type)
            task.set_variation(variation)

            success_rate = 0.0
            invalid_demos = 0

            for demo_id in range(num_demos):
                if verbose:
                    print(f"Starting demo {demo_id}")

                try:
                    demo = self.get_demo(task_str, variation, episode_index=demo_id)[0]
                except Exception:
                    print(f"Invalid demo {demo_id} for {task_str} variation {variation}")
                    print()
                    traceback.print_exc()
                    invalid_demos += 1
                    # Without this, the code below would replay an undefined
                    # or the previous iteration's demo.
                    continue

                task.reset_to_demo(demo)

                # NOTE(review): reads obs.gripper_pose; bimanual observations
                # elsewhere in this class are read via obs.right / obs.left.
                gt_keyframe_actions = []
                for f in keypoint_discovery(demo):
                    obs = demo[f]
                    action = np.concatenate([obs.gripper_pose, [obs.gripper_open]])
                    gt_keyframe_actions.append(action)

                move = Mover(task, max_tries=max_tries)

                for step_id, action in enumerate(gt_keyframe_actions):
                    if verbose:
                        print(f"Step {step_id}")

                    try:
                        obs, reward, terminate, _ = move(action)
                        if reward == 1:
                            success_rate += 1 / num_demos
                            break
                        if terminate and verbose:
                            print("The episode has terminated!")

                    except (IKError, ConfigurationPathError, InvalidActionError) as e:
                        print(task_type, demo, success_rate, e)
                        break

                if verbose:
                    print(f"Finished demo {demo_id}, SR: {success_rate}")

            # Renormalise from "per requested demo" to "per valid demo".
            if (num_demos - invalid_demos) == 0:
                success_rate = 0.0
                valid = False
            else:
                success_rate = success_rate * num_demos / (num_demos - invalid_demos)
                valid = True
        finally:
            self.env.shutdown()

        return success_rate, valid, invalid_demos

    def create_obs_config(
        self, image_size, apply_rgb, apply_depth, apply_pc, apply_cameras, **kwargs
    ):
        """
        Set up observation config for RLBench environment.
        :param image_size: Image size.
        :param apply_rgb: Applying RGB as inputs.
        :param apply_depth: Applying Depth as inputs.
        :param apply_pc: Applying Point Cloud as inputs.
        :param apply_cameras: Desired cameras; the other known cameras are disabled.
        :param kwargs: extra CameraConfig arguments for the enabled cameras.
        :return: observation config
        """
        unused_cams = CameraConfig()
        unused_cams.set_all(False)
        used_cams = CameraConfig(
            rgb=apply_rgb,
            point_cloud=apply_pc,
            depth=apply_depth,
            mask=False,
            image_size=image_size,
            render_mode=RenderMode.OPENGL,
            **kwargs,
        )

        # Use the config objects directly. Re-wrapping as CameraConfig(cfg) does
        # not copy cfg: RLBench's constructor takes flags (rgb first), so cfg
        # became the rgb value and image size / modalities fell back to defaults.
        enabled = set(apply_cameras)
        camera_configs = {
            name: used_cams if name in enabled else unused_cams
            for name in (
                "front",
                "over_shoulder_left",
                "over_shoulder_right",
                "wrist_left",
                "wrist_right",
                "overhead",
            )
        }
        obs_config = ObservationConfig(
            camera_configs=camera_configs,
            joint_forces=False,
            joint_positions=False,
            # Joint velocities are needed by keypoint discovery.
            joint_velocities=True,
            task_low_dim_state=False,
            gripper_touch_forces=False,
            gripper_pose=True,
            gripper_open=True,
            gripper_matrix=True,
            gripper_joint_positions=True,
        )

        return obs_config
|
|
|
|
| |
| def _is_stopped(demo, i, obs, delta=0.1): |
| next_is_not_final = i == (len(demo) - 2) |
| gripper_state_no_change = i < (len(demo) - 2) and ( |
| obs.gripper_open == demo[i + 1].gripper_open |
| and obs.gripper_open == demo[i - 1].gripper_open |
| and demo[i - 2].gripper_open == demo[i - 1].gripper_open |
| ) |
| small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) |
| return small_delta and (not next_is_not_final) and gripper_state_no_change |
|
|
| def _is_stopped_right(demo, i, obs, delta=0.1): |
| next_is_not_final = i == (len(demo) - 2) |
| gripper_state_no_change = i < (len(demo) - 2) and ( |
| obs.gripper_open == demo[i + 1].right.gripper_open |
| and obs.gripper_open == demo[i - 1].right.gripper_open |
| and demo[i - 2].right.gripper_open == demo[i - 1].right.gripper_open |
| ) |
| small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) |
| return small_delta and (not next_is_not_final) and gripper_state_no_change |
|
|
|
|
| def _is_stopped_left(demo, i, obs, delta=0.1): |
| next_is_not_final = i == (len(demo) - 2) |
| gripper_state_no_change = i < (len(demo) - 2) and ( |
| obs.gripper_open == demo[i + 1].left.gripper_open |
| and obs.gripper_open == demo[i - 1].left.gripper_open |
| and demo[i - 2].left.gripper_open == demo[i - 1].left.gripper_open |
| ) |
| small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) |
| return small_delta and (not next_is_not_final) and gripper_state_no_change |
|
|
def keypoint_discovery(demo: Demo, stopping_delta=0.1) -> List[int]:
    """Return the keyframe indices of a bimanual ``demo``.

    A frame other than frame 0 is a keyframe when:

    * either arm's ``gripper_open`` differs from the previous frame, or
    * it is the last frame, or
    * both arms are judged stopped (``_is_stopped_right`` /
      ``_is_stopped_left``).

    After a "stopped" detection, the following 4 frames cannot trigger
    another one. If the last two keyframes are on adjacent frames, the
    earlier one is dropped. The result is printed before being returned.
    """
    right_open_before = demo[0].right.gripper_open
    left_open_before = demo[0].left.gripper_open
    cooldown = 0
    keyframes = []

    for idx, frame in enumerate(demo._observations):
        # Both heuristics are always evaluated, then combined with the cooldown.
        right_is_still = _is_stopped_right(demo, idx, frame.right, stopping_delta)
        left_is_still = _is_stopped_left(demo, idx, frame.left, stopping_delta)
        both_still = (cooldown <= 0) and right_is_still and left_is_still
        cooldown = 4 if both_still else cooldown - 1

        is_final = idx == (len(demo) - 1)
        gripper_toggled = (
            frame.right.gripper_open != right_open_before
            or frame.left.gripper_open != left_open_before
        )
        if idx > 0 and (gripper_toggled or is_final or both_still):
            keyframes.append(idx)

        right_open_before = frame.right.gripper_open
        left_open_before = frame.left.gripper_open

    if len(keyframes) > 1 and keyframes[-2] == keyframes[-1] - 1:
        del keyframes[-2]
    print("Found %d keypoints." % len(keyframes), keyframes)
    return keyframes
|
|
|
|
def _hwc_to_chw(array):
    """Convert an (H, W, C) or (H, W) array to a float (C, H, W) tensor."""
    tensor = torch.tensor(array).float()
    if tensor.ndim == 2:
        # Single-channel images such as (H, W) depth maps.
        tensor = tensor.unsqueeze(-1)
    return tensor.permute(2, 0, 1)


def transform(obs_dict, scale_size=(0.75, 1.25), augmentation=False):
    """Stack per-camera images into one (channels, H, W) tensor.

    Output channel order: all rgb images, then all depth images (if any), then
    all point clouds (if any), each group ordered by camera. RGB is scaled
    from [0, 255] to [-1, 1]; depth and point clouds are left unscaled.

    :param obs_dict: dict with an ``"rgb"`` list and optional ``"depth"`` /
        ``"pc"`` lists of (H, W, C) or (H, W) arrays; the number of cameras is
        ``len(obs_dict["rgb"])``.
    :param scale_size: unused (reserved for augmentation).
    :param augmentation: not implemented; True raises NotImplementedError.
    :return: float tensor of shape (total_channels, H, W).
    """
    if augmentation:
        raise NotImplementedError()

    # Both depth and point clouds are optional.
    apply_depth = len(obs_dict.get("depth", [])) > 0
    apply_pc = len(obs_dict.get("pc", [])) > 0
    num_cams = len(obs_dict["rgb"])

    obs_rgb = []
    obs_depth = []
    obs_pc = []
    for i in range(num_cams):
        rgb = _hwc_to_chw(obs_dict["rgb"][i])
        obs_rgb.append(2 * (rgb / 255.0 - 0.5))
        if apply_depth:
            obs_depth.append(_hwc_to_chw(obs_dict["depth"][i]))
        if apply_pc:
            obs_pc.append(_hwc_to_chw(obs_dict["pc"][i]))

    return torch.cat(obs_rgb + obs_depth + obs_pc, dim=0)
|
|