# VLAarchtests2 / third_party / AnyBimanual / utils / utils_with_rlbench.py
# lsnu's picture
# Add files using upload-large-folder tool
# 2f28ec8 verified
import os
import glob
import random
from typing import List, Dict, Any
from pathlib import Path
import json
import open3d
import traceback
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
import einops
from rlbench.observation_config import ObservationConfig, CameraConfig
from rlbench.environment import Environment
from rlbench.task_environment import TaskEnvironment
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.action_modes.arm_action_modes import EndEffectorPoseViaPlanning
from rlbench.backend.exceptions import InvalidActionError
from rlbench.demo import Demo
from pyrep.errors import IKError, ConfigurationPathError
from pyrep.const import RenderMode
# Names of the RLBench tasks known to this module. Each name is also the stem
# of a task file that task_file_to_task_class() can import. The entries from
# 'bimanual_pick_laptop' onward are, judging by their names, the two-arm tasks.
# NOTE: the position of a name in this list is its task id (see TASK_TO_ID), so
# inserting or reordering entries changes the ids of everything that follows.
ALL_RLBENCH_TASKS = [
'basketball_in_hoop', 'beat_the_buzz', 'change_channel', 'change_clock', 'close_box',
'close_door', 'close_drawer', 'close_fridge', 'close_grill', 'close_jar', 'close_laptop_lid',
'close_microwave', 'hang_frame_on_hanger', 'insert_onto_square_peg', 'insert_usb_in_computer',
'lamp_off', 'lamp_on', 'lift_numbered_block', 'light_bulb_in', 'meat_off_grill', 'meat_on_grill',
'move_hanger', 'open_box', 'open_door', 'open_drawer', 'open_fridge', 'open_grill',
'open_microwave', 'open_oven', 'open_window', 'open_wine_bottle', 'phone_on_base',
'pick_and_lift', 'pick_and_lift_small', 'pick_up_cup', 'place_cups', 'place_hanger_on_rack',
'place_shape_in_shape_sorter', 'place_wine_at_rack_location', 'play_jenga',
'plug_charger_in_power_supply', 'press_switch', 'push_button', 'push_buttons', 'put_books_on_bookshelf',
'put_groceries_in_cupboard', 'put_item_in_drawer', 'put_knife_on_chopping_board', 'put_money_in_safe',
'put_rubbish_in_bin', 'put_umbrella_in_umbrella_stand', 'reach_and_drag', 'reach_target',
'scoop_with_spatula', 'screw_nail', 'setup_checkers', 'slide_block_to_color_target',
'slide_block_to_target', 'slide_cabinet_open_and_place_cups', 'stack_blocks', 'stack_cups',
'stack_wine', 'straighten_rope', 'sweep_to_dustpan', 'sweep_to_dustpan_of_size', 'take_frame_off_hanger',
'take_lid_off_saucepan', 'take_money_out_safe', 'take_plate_off_colored_dish_rack', 'take_shoes_out_of_box',
'take_toilet_roll_off_stand', 'take_umbrella_out_of_umbrella_stand', 'take_usb_out_of_computer',
'toilet_seat_down', 'toilet_seat_up', 'tower3', 'turn_oven_on', 'turn_tap', 'tv_on', 'unplug_charger',
'water_plants', 'wipe_desk', 'bimanual_pick_laptop','bimanual_pick_plate','bimanual_straighten_rope',
'coordinated_lift_ball','coordinated_lift_tray','coordinated_push_box','coordinated_put_bottle_in_fridge','dual_push_buttons',
'handover_item','bimanual_sweep_to_dustpan','coordinated_take_tray_out_of_oven','handover_item_easy'
]
# Task name -> integer id (its index in ALL_RLBENCH_TASKS).
TASK_TO_ID = {task: i for i, task in enumerate(ALL_RLBENCH_TASKS)}
def task_file_to_task_class(task_file):
    """Import ``rlbench.tasks.<task_file>`` and return its task class.

    :param task_file: snake_case task name, with or without ``.py``
        (every ``".py"`` occurrence is stripped).
    :return: the attribute of the (freshly reloaded) task module whose name is
        the CamelCase form of the task name, e.g. ``open_door`` -> ``OpenDoor``.
    """
    import importlib

    module_stem = task_file.replace(".py", "")

    # snake_case -> CamelCase: upper-case only the first letter of each word.
    camel_parts = []
    for word in module_stem.split("_"):
        camel_parts.append(word[0].upper() + word[1:])

    # Reload so that the definition returned is the one currently on disk.
    task_module = importlib.reload(
        importlib.import_module("rlbench.tasks.%s" % module_stem)
    )
    return getattr(task_module, "".join(camel_parts))
def load_episodes() -> Dict[str, Any]:
    """Return the parsed contents of ``data_preprocessing/episodes.json``.

    The file is looked up relative to the directory one level above the one
    that contains this module.
    """
    episodes_file = Path(__file__).parent.parent / "data_preprocessing/episodes.json"
    with open(episodes_file) as handle:
        episodes = json.load(handle)
    return episodes
class Mover:
    """Callable that drives a task towards a target end-effector action.

    Each call first tries (up to ``max_tries`` times) to reach the target pose
    while keeping the gripper state of the previous call, and only then sends
    the requested gripper state if it differs from the previous one.
    """

    def __init__(self, task, disabled=False, max_tries=1):
        """
        :param task: object exposing ``step(action) -> (obs, reward, terminate)``
        :param disabled: if True, calls are forwarded to ``task.step`` untouched
        :param max_tries: maximum number of attempts to reach the target pose
        """
        self._task = task
        self._last_action = None
        self._step_id = 0
        self._max_tries = max_tries
        self._disabled = disabled

    def _step(self, action, collision_checking):
        """Append the trailing collision flag to ``action`` and step the task.

        The flag is 1 by default and 0 when ``collision_checking`` is set
        (presumably an "ignore collisions" switch for the planner; confirm
        against the action mode in use).
        """
        action_collision = np.ones(action.shape[0] + 1)
        action_collision[:-1] = action
        if collision_checking:
            action_collision[-1] = 0
        return self._task.step(action_collision)

    def __call__(self, action, collision_checking=False):
        """Execute ``action`` and return ``(obs, reward, terminate, images)``.

        :param action: 1-D array; ``[:3]`` position, ``[3:7]`` rotation and
            ``[7]`` the gripper state
        :param collision_checking: forwarded to the task as the trailing flag
        :return: last observation, reward and terminate flag from the task, and
            an (always empty) list of images. When the mover is disabled the
            raw 3-tuple of ``task.step`` is returned instead.
        """
        if self._disabled:
            return self._task.step(action)

        target = action.copy()
        # Work on a private copy so the caller's array is never modified.
        action = action.copy()
        if self._last_action is not None:
            # Keep the previous gripper state while the arm is moving.
            action[7] = self._last_action[7]

        images = []
        obs = None
        terminate = None
        reward = 0
        reached = False

        for _ in range(self._max_tries):
            obs, reward, terminate = self._step(action, collision_checking)

            pos = obs.gripper_pose[:3]
            rot = obs.gripper_pose[3:7]
            dist_pos = np.sqrt(np.square(target[:3] - pos).sum())
            dist_rot = np.sqrt(np.square(target[3:7] - rot).sum())

            # Only the positional error decides whether the target was reached.
            if dist_pos < 5e-3 or reward == 1:
                reached = True
                break

            print(
                f"Too far away (pos: {dist_pos:.3f}, rot: {dist_rot:.3f}, step: {self._step_id})... Retrying..."
            )

        # we execute the gripper action after re-tries
        action = target
        if (
            not reward == 1.0
            and self._last_action is not None
            and action[7] != self._last_action[7]
        ):
            obs, reward, terminate = self._step(action, collision_checking)

        # Report when every attempt was used up without reaching the target.
        if not reached:
            print(f"Failure after {self._max_tries} tries")

        self._step_id += 1
        self._last_action = action.copy()

        return obs, reward, terminate, images
class Actioner:
    """Wraps a policy and the per-episode instruction it is conditioned on."""

    def __init__(
        self,
        policy=None,
        instructions=None,
        apply_cameras=("over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"),
        action_dim=7,
        predict_trajectory=True
    ):
        """
        :param policy: model used by ``predict``; switched to eval mode if given
        :param instructions: nested mapping ``task name -> variation -> instructions``
        :param apply_cameras: camera names (stored, not used by this class)
        :param action_dim: number of gripper-state entries fed to a keypose policy
        :param predict_trajectory: predict a full trajectory instead of a keypose
        """
        self._policy = policy
        self._instructions = instructions
        self._apply_cameras = apply_cameras
        self._action_dim = action_dim
        self._predict_trajectory = predict_trajectory

        self._actions = {}
        self._instr = None
        self._task_id = None
        self._task_str = None

        # The default policy is None; only a real policy can be put in eval mode.
        if self._policy is not None:
            self._policy.eval()

    def load_episode(self, task_str, variation):
        """Pick a random instruction of ``task_str``/``variation`` for the next episode."""
        self._task_str = task_str
        instructions = list(self._instructions[task_str][variation])
        self._instr = random.choice(instructions).unsqueeze(0)
        self._task_id = torch.tensor(TASK_TO_ID[task_str]).unsqueeze(0)
        self._actions = {}

    def get_action_from_demo(self, demo):
        """
        Fetch the desired state and action based on the provided demo.
        :param demo: fetch each demo and save key-point observations
        :return: per key-frame lists of actions (1 x D tensors), of the
            trajectories leading up to each key-frame (arrays) and of
            all-False boolean masks with one entry per trajectory step
        """
        key_frame = keypoint_discovery(demo)

        action_ls = []
        trajectory_ls = []
        trajectory_mask_ls = []
        for i, frame in enumerate(key_frame):
            obs = demo[frame]
            action_np = np.concatenate([obs.gripper_pose, [obs.gripper_open]])
            action_ls.append(torch.from_numpy(action_np).unsqueeze(0))

            # Frames from the previous key-frame (or the start) up to, but
            # excluding, this key-frame.
            start = key_frame[i - 1] if i > 0 else 0
            trajectory_np = [
                np.concatenate([demo[j].gripper_pose, [demo[j].gripper_open]])
                for j in range(start, frame)
            ]
            trajectory_ls.append(np.stack(trajectory_np))
            trajectory_mask_ls.append(torch.zeros(1, frame - start).bool())

        return action_ls, trajectory_ls, trajectory_mask_ls

    def predict(self, rgbs, pcds, gripper,
                interpolation_length=None):
        """
        Args:
            rgbs: (bs, num_hist, num_cameras, 3, H, W)
            pcds: (bs, num_hist, num_cameras, 3, H, W)
            gripper: (B, nhist, output_dim)
            interpolation_length: an integer; required when a trajectory is
                predicted

        Returns:
            {"action": torch.Tensor, "trajectory": torch.Tensor}

        Raises:
            ValueError: if ``load_episode`` has not been called yet
        """
        if self._instr is None:
            raise ValueError(
                "No instruction loaded; call load_episode() before predict()"
            )

        output = {"action": None, "trajectory": None}

        rgbs = rgbs / 2 + 0.5  # in [0, 1]

        self._instr = self._instr.to(rgbs.device)
        self._task_id = self._task_id.to(rgbs.device)

        # Predict trajectory
        if self._predict_trajectory:
            print('Predict Trajectory')
            # Placeholder trajectory and mask handed to the policy at inference.
            fake_traj = torch.full(
                [1, interpolation_length - 1, gripper.shape[-1]], 0
            ).to(rgbs.device)
            traj_mask = torch.full(
                [1, interpolation_length - 1], False
            ).to(rgbs.device)
            output["trajectory"] = self._policy(
                fake_traj,
                traj_mask,
                rgbs[:, -1],
                pcds[:, -1],
                self._instr,
                gripper[..., :7],
                run_inference=True
            )
        else:
            print('Predict Keypose')
            pred = self._policy(
                rgbs[:, -1],
                pcds[:, -1],
                self._instr,
                gripper[:, -1, :self._action_dim],
            )
            # Hackish, assume self._policy is an instance of Act3D
            output["action"] = self._policy.prepare_action(pred)

        return output

    @property
    def device(self):
        """Device of the policy's first parameter."""
        return next(self._policy.parameters()).device
def obs_to_attn(obs, camera):
    """Project ``obs.gripper_pose[:3]`` into the image of ``camera``.

    The homogeneous gripper position is multiplied by the inverse of the
    camera's extrinsics matrix and then by its intrinsics (both read from
    ``obs.misc``).

    :return: rounded pixel coordinates ``(u, v)`` as Python ints
    """
    misc = obs.misc

    inv_extrinsics = torch.linalg.inv(
        torch.from_numpy(misc[f"{camera}_camera_extrinsics"]).float()
    )

    # 3x3 intrinsics with a zero column appended -> 3x4 projection matrix.
    intrinsics = torch.from_numpy(misc[f"{camera}_camera_intrinsics"]).float()
    projection = torch.cat(
        [intrinsics, intrinsics.new_zeros(intrinsics.shape[0], 1)], dim=1
    )

    # Homogeneous column vector (x, y, z, 1).
    xyz = torch.from_numpy(obs.gripper_pose[:3]).float()
    xyz1 = torch.cat([xyz, xyz.new_ones(1)]).unsqueeze(1)

    pixel_h = (projection @ (inv_extrinsics @ xyz1)).float().squeeze(1)
    depth = pixel_h[2]
    return int((pixel_h[0] / depth).round()), int((pixel_h[1] / depth).round())
def obs_to_attn_right(obs, camera):
    """Project ``obs.right.gripper_pose[:3]`` into the image of ``camera``.

    The homogeneous gripper position is multiplied by the inverse of the
    camera's extrinsics matrix and then by its intrinsics (both read from
    ``obs.misc``).

    :return: rounded pixel coordinates ``(u, v)`` as Python ints
    """
    misc = obs.misc

    inv_extrinsics = torch.linalg.inv(
        torch.from_numpy(misc[f"{camera}_camera_extrinsics"]).float()
    )

    # 3x3 intrinsics with a zero column appended -> 3x4 projection matrix.
    intrinsics = torch.from_numpy(misc[f"{camera}_camera_intrinsics"]).float()
    projection = torch.cat(
        [intrinsics, intrinsics.new_zeros(intrinsics.shape[0], 1)], dim=1
    )

    # Homogeneous column vector (x, y, z, 1) of the right gripper.
    xyz = torch.from_numpy(obs.right.gripper_pose[:3]).float()
    xyz1 = torch.cat([xyz, xyz.new_ones(1)]).unsqueeze(1)

    pixel_h = (projection @ (inv_extrinsics @ xyz1)).float().squeeze(1)
    depth = pixel_h[2]
    return int((pixel_h[0] / depth).round()), int((pixel_h[1] / depth).round())
def obs_to_attn_left(obs, camera):
    """Project ``obs.left.gripper_pose[:3]`` into the image of ``camera``.

    The homogeneous gripper position is multiplied by the inverse of the
    camera's extrinsics matrix and then by its intrinsics (both read from
    ``obs.misc``).

    :return: rounded pixel coordinates ``(u, v)`` as Python ints
    """
    misc = obs.misc

    inv_extrinsics = torch.linalg.inv(
        torch.from_numpy(misc[f"{camera}_camera_extrinsics"]).float()
    )

    # 3x3 intrinsics with a zero column appended -> 3x4 projection matrix.
    intrinsics = torch.from_numpy(misc[f"{camera}_camera_intrinsics"]).float()
    projection = torch.cat(
        [intrinsics, intrinsics.new_zeros(intrinsics.shape[0], 1)], dim=1
    )

    # Homogeneous column vector (x, y, z, 1) of the left gripper.
    xyz = torch.from_numpy(obs.left.gripper_pose[:3]).float()
    xyz1 = torch.cat([xyz, xyz.new_ones(1)]).unsqueeze(1)

    pixel_h = (projection @ (inv_extrinsics @ xyz1)).float().squeeze(1)
    depth = pixel_h[2]
    return int((pixel_h[0] / depth).round()), int((pixel_h[1] / depth).round())
class RLBenchEnv:
    """Thin wrapper around an RLBench ``Environment`` used to read demos,
    turn observations into tensors and roll out / verify policies."""

    def __init__(
        self,
        data_path,
        image_size=(256,256),
        apply_rgb=False,
        apply_depth=False,
        apply_pc=False,
        headless=False,
        apply_cameras=("over_shoulder_left", "over_shoulder_right", "overhead", "wrist_right", "wrist_left", "front"),
        fine_sampling_ball_diameter=None,
        collision_checking=False
    ):
        """
        :param data_path: root folder of the stored demos
        :param image_size: image size handed to the camera configs; also the
            size of the attention maps built by ``get_rgb_pcd_gripper_from_obs``
        :param apply_rgb: collect RGB images
        :param apply_depth: collect depth images
        :param apply_pc: collect point clouds
        :param headless: forwarded to ``Environment``
        :param apply_cameras: names of the cameras to read
        :param fine_sampling_ball_diameter: stored, not used by this class
        :param collision_checking: forwarded to the arm action mode
        """
        # setup required inputs
        self.data_path = data_path
        self.apply_rgb = apply_rgb
        self.apply_depth = apply_depth
        self.apply_pc = apply_pc
        self.apply_cameras = apply_cameras
        self.fine_sampling_ball_diameter = fine_sampling_ball_diameter

        # setup RLBench environments
        self.obs_config = self.create_obs_config(
            image_size, apply_rgb, apply_depth, apply_pc, apply_cameras
        )

        self.action_mode = MoveArmThenGripper(
            arm_action_mode=EndEffectorPoseViaPlanning(collision_checking=collision_checking),
            gripper_action_mode=Discrete()
        )
        self.env = Environment(
            self.action_mode, str(data_path), self.obs_config,
            headless=headless
        )
        self.image_size = image_size

    def _collect_state(self, obs, cameras):
        """Gather the enabled modalities of ``cameras`` from ``obs.perception_data``.

        :return: dict with the keys "rgb", "depth" and "pc"; each value lists
            one entry per camera, or is empty when the modality is disabled
        """
        state_dict = {"rgb": [], "depth": [], "pc": []}
        for cam in cameras:
            if self.apply_rgb:
                state_dict["rgb"].append(obs.perception_data["{}_rgb".format(cam)])
            if self.apply_depth:
                state_dict["depth"].append(obs.perception_data["{}_depth".format(cam)])
            if self.apply_pc:
                state_dict["pc"].append(obs.perception_data["{}_point_cloud".format(cam)])
        return state_dict

    @staticmethod
    def _arm_action(arm_obs):
        """Return ``[gripper_pose, gripper_open]`` of one arm as a float tensor."""
        action = np.concatenate([arm_obs.gripper_pose, [arm_obs.gripper_open]])
        return torch.from_numpy(action).float()

    def get_obs_action(self, obs):
        """
        Fetch the desired state and both arm actions from an observation.
        :param obs: incoming obs
        :return: state dict over ``self.apply_cameras``, right action, left action
        """
        state_dict = self._collect_state(obs, self.apply_cameras)
        right_action = self._arm_action(obs.right)
        left_action = self._arm_action(obs.left)
        return state_dict, right_action, left_action

    def get_obs_action_right(self, obs):
        """
        Fetch the state seen by the right-arm cameras and the right arm action.
        :param obs: incoming obs
        :return: state dict over the right-arm cameras, right action
        """
        right_apply_cameras = ("over_shoulder_right", "overhead", "wrist_right", "front")
        return self._collect_state(obs, right_apply_cameras), self._arm_action(obs.right)

    def get_obs_action_left(self, obs):
        """
        Fetch the state seen by the left-arm cameras and the left arm action.
        :param obs: incoming obs
        :return: state dict over the left-arm cameras, left action
        """
        left_apply_cameras = ("over_shoulder_left", "overhead", "wrist_left", "front")
        return self._collect_state(obs, left_apply_cameras), self._arm_action(obs.left)

    def get_rgb_pcd_gripper_from_obs(self, obs):
        """
        Return rgb, pcd, and gripper from a given observation
        :param obs: an Observation from the env
        :return: rgb, pcd, gripper
        """
        # NOTE(review): get_obs_action() returns three values (state, right
        # action, left action), so this two-name unpacking raises ValueError.
        # This method and obs_to_attn() below still read a single gripper;
        # which arm (or combination) is wanted here must be decided by the
        # owner before this path can work.
        state_dict, gripper = self.get_obs_action(obs)
        state = transform(state_dict, augmentation=False)
        state = einops.rearrange(
            state,
            "(m n ch) h w -> n m ch h w",
            ch=3,
            n=len(self.apply_cameras),
            m=2
        )
        rgb = state[:, 0].unsqueeze(0)  # 1, N, C, H, W
        pcd = state[:, 1].unsqueeze(0)  # 1, N, C, H, W
        gripper = gripper.unsqueeze(0)  # 1, D

        # One single-pixel attention map per camera, marking where the gripper
        # projects; left all-zero when the projection falls outside the image.
        attns = torch.Tensor([])
        height, width = self.image_size[0], self.image_size[1]
        for cam in self.apply_cameras:
            u, v = obs_to_attn(obs, cam)
            attn = torch.zeros(1, 1, 1, height, width)
            if 0 <= u <= width - 1 and 0 <= v <= height - 1:
                attn[0, 0, 0, v, u] = 1
            attns = torch.cat([attns, attn], 1)
        rgb = torch.cat([rgb, attns], 2)

        return rgb, pcd, gripper

    def get_obs_action_from_demo(self, demo):
        """
        Fetch the desired state and action based on the provided demo.
        :param demo: fetch each demo and save key-point observations
        :return: a list of obs and action
        """
        key_frame = keypoint_discovery(demo)
        key_frame.insert(0, 0)
        state_ls = []
        action_ls = []
        for f in key_frame:
            # NOTE(review): get_obs_action() returns three values, so this
            # two-name unpacking raises ValueError; see the note in
            # get_rgb_pcd_gripper_from_obs().
            state, action = self.get_obs_action(demo._observations[f])
            state = transform(state, augmentation=False)
            state_ls.append(state.unsqueeze(0))
            action_ls.append(action.unsqueeze(0))
        return state_ls, action_ls

    def get_gripper_matrix_from_action(self, action):
        """Build a 4x4 pose matrix from ``action`` (``[:3]`` position,
        ``[3:7]`` quaternion whose last component is moved to the front
        before it is handed to open3d)."""
        action = action.cpu().numpy()
        position = action[:3]
        quaternion = action[3:7]
        rotation = open3d.geometry.get_rotation_matrix_from_quaternion(
            np.array((quaternion[3], quaternion[0], quaternion[1], quaternion[2]))
        )
        gripper_matrix = np.eye(4)
        gripper_matrix[:3, :3] = rotation
        gripper_matrix[:3, 3] = position
        return gripper_matrix

    def get_demo(self, task_name, variation, episode_index):
        """
        Fetch a demo from the saved environment.
        :param task_name: fetch task name
        :param variation: fetch variation id
        :param episode_index: fetch episode index: 0 ~ 99
        :return: desired demo (as returned by ``env.get_demos`` with amount=1)
        """
        demos = self.env.get_demos(
            task_name=task_name,
            variation_number=variation,
            amount=1,
            from_episode_number=episode_index,
            random_selection=False
        )
        return demos

    def evaluate_task_on_multiple_variations(
        self,
        task_str: str,
        max_steps: int,
        num_variations: int,  # -1 means all variations
        num_demos: int,
        actioner: "Actioner",
        max_tries: int = 1,
        verbose: bool = False,
        dense_interpolation=False,
        interpolation_length=100,
        num_history=1,
    ):
        """Evaluate ``actioner`` on several variations of one task.

        :param num_variations: if positive, evaluate the first
            ``min(num_variations, task.variation_count())`` variations;
            otherwise every ``variation*`` folder found under
            ``data_path/task_str``
        :param num_demos: total demo budget, split across the variations
        :return: dict mapping each valid variation to its number of successes,
            plus ``"mean"``: total successes / total valid demos (0.0 when no
            demo was valid)
        """
        self.env.launch()
        var_success_rates = {}
        var_num_valid_demos = {}
        # Always shut the simulator down, even if the evaluation raises.
        try:
            task_type = task_file_to_task_class(task_str)
            task = self.env.get_task(task_type)
            task_variations = task.variation_count()

            if num_variations > 0:
                task_variations = np.minimum(num_variations, task_variations)
                task_variations = range(task_variations)
            else:
                variation_dirs = glob.glob(
                    os.path.join(self.data_path, task_str, "variation*")
                )
                task_variations = [
                    int(os.path.basename(path).replace('variation', ''))
                    for path in variation_dirs
                ]

            for variation in task_variations:
                task.set_variation(variation)
                success_rate, valid, num_valid_demos = (
                    self._evaluate_task_on_one_variation(
                        task_str=task_str,
                        task=task,
                        max_steps=max_steps,
                        variation=variation,
                        num_demos=num_demos // len(task_variations) + 1,
                        actioner=actioner,
                        max_tries=max_tries,
                        verbose=verbose,
                        dense_interpolation=dense_interpolation,
                        interpolation_length=interpolation_length,
                        num_history=num_history
                    )
                )
                if valid:
                    var_success_rates[variation] = success_rate
                    var_num_valid_demos[variation] = num_valid_demos
        finally:
            self.env.shutdown()

        total_valid = sum(var_num_valid_demos.values())
        total_success = sum(var_success_rates.values())
        # Without any valid demo there is nothing to average over.
        var_success_rates["mean"] = (
            total_success / total_valid if total_valid > 0 else 0.0
        )

        return var_success_rates

    @torch.no_grad()
    def _evaluate_task_on_one_variation(
        self,
        task_str: str,
        task: "TaskEnvironment",
        max_steps: int,
        variation: int,
        num_demos: int,
        actioner: "Actioner",
        max_tries: int = 1,
        verbose: bool = False,
        dense_interpolation=False,
        interpolation_length=50,
        num_history=0,
    ):
        """Roll out ``actioner`` on up to ``num_demos`` demos of one variation.

        :return: ``(number of successful demos, whether any demo could be
            loaded, number of demos that could be loaded)``
        """
        device = actioner.device

        success_rate = 0
        num_valid_demos = 0
        total_reward = 0

        for demo_id in range(num_demos):
            if verbose:
                print()
                print(f"Starting demo {demo_id}")

            # Demos that cannot be loaded are skipped and not counted as valid.
            try:
                demo = self.get_demo(task_str, variation, episode_index=demo_id)[0]
                num_valid_demos += 1
            except Exception:
                continue

            rgbs = torch.Tensor([]).to(device)
            pcds = torch.Tensor([]).to(device)
            grippers = torch.Tensor([]).to(device)

            # descriptions, obs = task.reset()
            descriptions, obs = task.reset_to_demo(demo)

            actioner.load_episode(task_str, variation)

            move = Mover(task, max_tries=max_tries)
            reward = 0.0
            max_reward = 0.0

            for step_id in range(max_steps):

                # Fetch the current observation, and predict one action
                rgb, pcd, gripper = self.get_rgb_pcd_gripper_from_obs(obs)
                rgb = rgb.to(device)
                pcd = pcd.to(device)
                gripper = gripper.to(device)

                rgbs = torch.cat([rgbs, rgb.unsqueeze(1)], dim=1)
                pcds = torch.cat([pcds, pcd.unsqueeze(1)], dim=1)
                grippers = torch.cat([grippers, gripper.unsqueeze(1)], dim=1)

                # Prepare proprioception history
                rgbs_input = rgbs[:, -1:][:, :, :, :3]
                pcds_input = pcds[:, -1:]
                if num_history < 1:
                    gripper_input = grippers[:, -1]
                else:
                    gripper_input = grippers[:, -num_history:]
                    # Left-pad short histories by repeating the oldest entry.
                    npad = num_history - gripper_input.shape[1]
                    gripper_input = F.pad(
                        gripper_input, (0, 0, npad, 0), mode='replicate'
                    )

                output = actioner.predict(
                    rgbs_input,
                    pcds_input,
                    gripper_input,
                    interpolation_length=interpolation_length
                )

                if verbose:
                    print(f"Step {step_id}")

                terminate = True

                # Update the observation based on the predicted action
                try:
                    # Execute entire predicted trajectory step by step
                    if output.get("trajectory", None) is not None:
                        trajectory = output["trajectory"][-1].cpu().numpy()
                        trajectory[:, -1] = trajectory[:, -1].round()

                        # execute
                        for action in tqdm(trajectory):
                            collision_checking = self._collision_checking(task_str, step_id)
                            obs, reward, terminate, _ = move(action, collision_checking=collision_checking)

                    # Or plan to reach next predicted keypoint
                    else:
                        print("Plan with RRT")
                        action = output["action"]
                        action[..., -1] = torch.round(action[..., -1])
                        action = action[-1].detach().cpu().numpy()

                        collision_checking = self._collision_checking(task_str, step_id)
                        obs, reward, terminate, _ = move(action, collision_checking=collision_checking)

                    max_reward = max(max_reward, reward)

                    if reward == 1:
                        success_rate += 1
                        break

                    if terminate:
                        print("The episode has terminated!")

                except (IKError, ConfigurationPathError, InvalidActionError) as e:
                    # A failed motion does not end the episode; the next step
                    # re-plans from the last observation.
                    print(task_str, demo, step_id, success_rate, e)
                    reward = 0

            total_reward += max_reward

            print(
                task_str,
                "Variation",
                variation,
                "Demo",
                demo_id,
                "Reward",
                f"{reward:.2f}",
                "max_reward",
                f"{max_reward:.2f}",
                f"SR: {success_rate}/{demo_id+1}",
                f"SR: {total_reward:.2f}/{demo_id+1}",
                "# valid demos", num_valid_demos,
            )

        # Compensate for failed demos
        if num_valid_demos == 0:
            assert success_rate == 0
            valid = False
        else:
            valid = True

        return success_rate, valid, num_valid_demos

    def _collision_checking(self, task_str, step_id):
        """Collision checking for planner.

        Currently disabled for every task and step; the arguments are kept so
        that task/step specific rules can be added here.
        """
        collision_checking = False
        return collision_checking

    def verify_demos(
        self,
        task_str: str,
        variation: int,
        num_demos: int,
        max_tries: int = 1,
        verbose: bool = False,
    ):
        """Replay the key-frame actions of stored demos and measure success.

        :return: ``(success rate over the loadable demos, whether at least one
            demo could be loaded, number of demos that could not be loaded)``
        """
        if verbose:
            print()
            print(f"{task_str}, variation {variation}, {num_demos} demos")

        self.env.launch()
        success_rate = 0.0
        invalid_demos = 0
        # Always shut the simulator down, even if the verification raises.
        try:
            task_type = task_file_to_task_class(task_str)
            task = self.env.get_task(task_type)
            task.set_variation(variation)  # type: ignore

            for demo_id in range(num_demos):
                if verbose:
                    print(f"Starting demo {demo_id}")

                try:
                    demo = self.get_demo(task_str, variation, episode_index=demo_id)[0]
                except Exception:
                    print(f"Invalid demo {demo_id} for {task_str} variation {variation}")
                    print()
                    traceback.print_exc()
                    invalid_demos += 1
                    # Nothing to replay for a demo that could not be loaded.
                    continue

                task.reset_to_demo(demo)

                gt_keyframe_actions = []
                for f in keypoint_discovery(demo):
                    obs = demo[f]
                    action = np.concatenate([obs.gripper_pose, [obs.gripper_open]])
                    gt_keyframe_actions.append(action)

                move = Mover(task, max_tries=max_tries)

                for step_id, action in enumerate(gt_keyframe_actions):
                    if verbose:
                        print(f"Step {step_id}")

                    try:
                        obs, reward, terminate, step_images = move(action)
                        if reward == 1:
                            success_rate += 1 / num_demos
                            break
                        if terminate and verbose:
                            print("The episode has terminated!")

                    except (IKError, ConfigurationPathError, InvalidActionError) as e:
                        print(task_type, demo, success_rate, e)
                        reward = 0
                        break

                if verbose:
                    print(f"Finished demo {demo_id}, SR: {success_rate}")
        finally:
            self.env.shutdown()

        # Compensate for failed demos
        if (num_demos - invalid_demos) == 0:
            success_rate = 0.0
            valid = False
        else:
            success_rate = success_rate * num_demos / (num_demos - invalid_demos)
            valid = True

        return success_rate, valid, invalid_demos

    def create_obs_config(
        self, image_size, apply_rgb, apply_depth, apply_pc, apply_cameras, **kwargs
    ):
        """
        Set up observation config for RLBench environment.
            :param image_size: Image size.
            :param apply_rgb: Applying RGB as inputs.
            :param apply_depth: Applying Depth as inputs.
            :param apply_pc: Applying Point Cloud as inputs.
            :param apply_cameras: Desired cameras.
            :param kwargs: extra keyword arguments for the used cameras' config
            :return: observation config
        """
        unused_cams = CameraConfig()
        unused_cams.set_all(False)
        used_cams = CameraConfig(
            rgb=apply_rgb,
            point_cloud=apply_pc,
            depth=apply_depth,
            mask=False,
            image_size=image_size,
            render_mode=RenderMode.OPENGL,
            **kwargs,
        )

        # Cameras that were asked for get the shared "used" config; every
        # other camera falls back to the all-off config.
        per_camera = {name: used_cams for name in apply_cameras}

        # NOTE(review): each entry wraps an existing CameraConfig in another
        # CameraConfig(...) call, i.e. passes it as the first positional
        # argument. Confirm that the RLBench fork in use accepts a config
        # there; otherwise the settings above would not take effect.
        camera_configs = {
            name: CameraConfig(per_camera.get(name, unused_cams))
            for name in (
                "front",
                "over_shoulder_left",
                "over_shoulder_right",
                "wrist_left",
                "wrist_right",
                "overhead",
            )
        }

        obs_config = ObservationConfig(
            camera_configs=camera_configs,
            joint_forces=False,
            joint_positions=False,
            joint_velocities=True,
            task_low_dim_state=False,
            gripper_touch_forces=False,
            gripper_pose=True,
            gripper_open=True,
            gripper_matrix=True,
            gripper_joint_positions=True,
        )

        return obs_config
# Identify way-point in each RLBench Demo
def _is_stopped(demo, i, obs, delta=0.1):
next_is_not_final = i == (len(demo) - 2)
gripper_state_no_change = i < (len(demo) - 2) and (
obs.gripper_open == demo[i + 1].gripper_open
and obs.gripper_open == demo[i - 1].gripper_open
and demo[i - 2].gripper_open == demo[i - 1].gripper_open
)
small_delta = np.allclose(obs.joint_velocities, 0, atol=delta)
return small_delta and (not next_is_not_final) and gripper_state_no_change
def _is_stopped_right(demo, i, obs, delta=0.1):
next_is_not_final = i == (len(demo) - 2)
gripper_state_no_change = i < (len(demo) - 2) and (
obs.gripper_open == demo[i + 1].right.gripper_open
and obs.gripper_open == demo[i - 1].right.gripper_open
and demo[i - 2].right.gripper_open == demo[i - 1].right.gripper_open
)
small_delta = np.allclose(obs.joint_velocities, 0, atol=delta)
return small_delta and (not next_is_not_final) and gripper_state_no_change
def _is_stopped_left(demo, i, obs, delta=0.1):
next_is_not_final = i == (len(demo) - 2)
gripper_state_no_change = i < (len(demo) - 2) and (
obs.gripper_open == demo[i + 1].left.gripper_open
and obs.gripper_open == demo[i - 1].left.gripper_open
and demo[i - 2].left.gripper_open == demo[i - 1].left.gripper_open
)
small_delta = np.allclose(obs.joint_velocities, 0, atol=delta)
return small_delta and (not next_is_not_final) and gripper_state_no_change
def keypoint_discovery(demo: Demo, stopping_delta=0.1) -> List[int]:
    """Return the indices of the key-frames of a two-arm ``demo``.

    A frame other than the first is a key-frame when either gripper's open
    state differs from the previous frame, when it is the last frame, or when
    both arms are stopped (see ``_is_stopped_right`` / ``_is_stopped_left``).
    After a "stopped" key-frame, the next four frames cannot count as stopped.

    :param demo: demo whose observations expose ``.right`` and ``.left``
    :param stopping_delta: velocity tolerance handed to the stop checks
    :return: list of key-frame indices (also printed)
    """
    keypoints = []
    prev_open = {
        "right": demo[0].right.gripper_open,
        "left": demo[0].left.gripper_open,
    }
    final_index = len(demo) - 1
    cooldown = 0

    for i, obs in enumerate(demo._observations):
        both_stopped = (
            _is_stopped_right(demo, i, obs.right, stopping_delta),
            _is_stopped_left(demo, i, obs.left, stopping_delta),
        )
        stopped = (cooldown <= 0) and both_stopped[0] and both_stopped[1]
        if stopped:
            cooldown = 4
        else:
            cooldown -= 1

        # If change in gripper, or end of episode.
        gripper_changed = (
            obs.right.gripper_open != prev_open["right"]
            or obs.left.gripper_open != prev_open["left"]
        )
        if i != 0 and (gripper_changed or i == final_index or stopped):
            keypoints.append(i)

        prev_open["right"] = obs.right.gripper_open
        prev_open["left"] = obs.left.gripper_open

    # If the last two key-frames are adjacent frames, keep only the last one.
    if len(keypoints) > 1 and keypoints[-2] == keypoints[-1] - 1:
        del keypoints[-2]

    print("Found %d keypoints." % len(keypoints), keypoints)
    return keypoints
def transform(obs_dict, scale_size=(0.75, 1.25), augmentation=False):
    """Stack the per-camera images of ``obs_dict`` into one channel-first tensor.

    Each image is converted to a float tensor and permuted from (H, W, C) to
    (C, H, W). RGB is rescaled from [0, 255] to [-1, 1]. The result is all RGB
    images, then all depth images, then all point clouds, concatenated along
    dim 0.

    :param obs_dict: dict with lists under "rgb", "pc" and optionally "depth"
    :param scale_size: unused
    :param augmentation: must be False; augmentation is deprecated
    :raises NotImplementedError: if ``augmentation`` is set and there is at
        least one camera
    """
    has_depth = len(obs_dict.get("depth", [])) > 0
    has_pc = len(obs_dict["pc"]) > 0

    def _channel_first(image):
        return torch.tensor(image).float().permute(2, 0, 1)

    rgb_planes, depth_planes, pc_planes = [], [], []
    for cam_idx, raw_rgb in enumerate(obs_dict["rgb"]):
        rgb = _channel_first(raw_rgb)
        if has_depth:
            depth_planes.append(_channel_first(obs_dict["depth"][cam_idx]).float())
        if has_pc:
            pc_planes.append(_channel_first(obs_dict["pc"][cam_idx]).float())

        if augmentation:
            raise NotImplementedError()  # Deprecated

        # normalise to [-1, 1]
        rgb_planes.append((2 * (rgb / 255.0 - 0.5)).float())

    return torch.cat(rgb_planes + depth_planes + pc_planes, dim=0)