| import copy |
| import logging |
| from typing import List |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from rlbench.backend.observation import Observation |
| from rlbench.demo import Demo |
| from yarr.replay_buffer.prioritized_replay_buffer import ( |
| PrioritizedReplayBuffer, |
| ObservationElement, |
| ) |
| from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer |
| from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer |
|
|
| from helpers import demo_loading_utils, utils |
| from helpers.custom_rlbench_env import CustomRLBenchEnv |
| from helpers.network_utils import ( |
| SiameseNet, |
| DenseBlock, |
| Conv2DBlock, |
| Conv2DUpsampleBlock, |
| ) |
| from helpers.preprocess_agent import PreprocessAgent |
| from agents.arm.next_best_pose_agent import NextBestPoseAgent |
| from agents.arm.qattention_agent import QAttentionAgent |
|
|
# Reward given to the terminal (last) keypoint transition of a demo when it is
# added to the replay buffer; every other keypoint transition gets 0.
REWARD_SCALE = 100.0
|
|
|
|
def create_replay(
    batch_size: int,
    timesteps: int,
    prioritisation: bool,
    save_dir: str,
    cameras: list,
    env: CustomRLBenchEnv,
):
    """Build the replay buffer that stores keypoint transitions.

    Args:
        batch_size: Forwarded to the replay buffer as ``batch_size``.
        timesteps: Forwarded to the replay buffer as ``timesteps``.
        prioritisation: When true a ``PrioritizedReplayBuffer`` is built,
            otherwise a ``UniformReplayBuffer``.
        save_dir: Forwarded to the replay buffer as ``save_dir``.
        cameras: Camera names. One ``"<camera>_pixel_coord"`` observation
            element of shape ``(2,)`` and dtype ``np.int32`` is registered
            for each of them.
        env: Environment whose ``observation_elements`` seed the observation
            layout of the buffer.

    Returns:
        The constructed replay buffer (capacity 1e5, 8-dimensional float32
        actions, scalar float32 rewards, one-step update horizon and an
        extra boolean ``"demo"`` element).
    """
    # Work on a copy so the per-camera entries are not appended to whatever
    # list ``env`` hands out (repeated calls would otherwise accumulate them).
    observation_elements = list(env.observation_elements)
    observation_elements.extend(
        ObservationElement("%s_pixel_coord" % cname, (2,), np.int32)
        for cname in cameras
    )

    replay_class = PrioritizedReplayBuffer if prioritisation else UniformReplayBuffer
    replay_buffer = replay_class(
        save_dir=save_dir,
        batch_size=batch_size,
        timesteps=timesteps,
        replay_capacity=int(1e5),
        action_shape=(8,),
        action_dtype=np.float32,
        reward_shape=(),
        reward_dtype=np.float32,
        update_horizon=1,
        observation_elements=observation_elements,
        # ``np.bool`` was only an alias of the builtin ``bool`` and is missing
        # from NumPy 1.24-1.26; the builtin works on every NumPy version.
        extra_replay_elements=[ReplayElement("demo", (), bool)],
    )
    return replay_buffer
|
|
|
|
| def _point_to_pixel_index( |
| point: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray |
| ): |
| point = np.array([point[0], point[1], point[2], 1]) |
| world_to_cam = np.linalg.inv(extrinsics) |
| point_in_cam_frame = world_to_cam.dot(point) |
| px, py, pz = point_in_cam_frame[:3] |
| px = 2 * intrinsics[0, 2] - int(-intrinsics[0, 0] * (px / pz) + intrinsics[0, 2]) |
| py = 2 * intrinsics[1, 2] - int(-intrinsics[1, 1] * (py / pz) + intrinsics[1, 2]) |
| return px, py |
|
|
|
|
def _get_action(obs_tp1: Observation):
    """Build the action vector that targets the gripper state of ``obs_tp1``.

    The result concatenates, in order: the first three entries of
    ``obs_tp1.gripper_pose``, the remaining pose entries normalised with
    ``utils.normalize_quaternion`` and negated if their last component is
    negative, and ``obs_tp1.gripper_open`` converted to ``float``.

    Args:
        obs_tp1: Observation providing ``gripper_pose`` and ``gripper_open``.

    Returns:
        The concatenated action as a NumPy array.
    """
    pose = obs_tp1.gripper_pose
    rotation = utils.normalize_quaternion(pose[3:])
    # Keep the last quaternion component non-negative so that a rotation
    # always maps to a single sign choice.
    if rotation[-1] < 0:
        rotation = -rotation
    gripper_state = [float(obs_tp1.gripper_open)]
    return np.concatenate([pose[:3], rotation, gripper_state])
|
|
|
|
def _add_keypoints_to_replay(
    replay: ReplayBuffer,
    inital_obs: Observation,
    demo: Demo,
    env: CustomRLBenchEnv,
    episode_keypoints: List[int],
    cameras: List[str],
):
    """Add one keypoint-to-keypoint episode from ``demo`` to ``replay``.

    Starting from ``inital_obs``, every keypoint contributes one transition:
    the observation is the previous keypoint's (or the initial) observation
    and the action is derived from the keypoint observation itself. The last
    keypoint is terminal and is rewarded with ``REWARD_SCALE``; all others
    get 0. After the loop the final observation is stored with
    ``replay.add_final``.

    Args:
        replay: Buffer receiving the transitions.
        inital_obs: Observation the episode starts from.
        demo: Demonstration indexed by the entries of ``episode_keypoints``.
        env: Environment used to convert observations via ``extract_obs``.
        episode_keypoints: Indices into ``demo`` to use as keypoints.
        cameras: Camera names for which a ``"<camera>_pixel_coord"`` entry is
            stored with each transition.

    Returns:
        List of the actions that were added, one per keypoint. Empty when
        ``episode_keypoints`` is empty, in which case nothing is added.
    """
    if len(episode_keypoints) == 0:
        # Without keypoints there is no transition and no final observation;
        # falling through would read loop variables that were never bound.
        return []

    last_index = len(episode_keypoints) - 1
    prev_action = None
    obs = inital_obs
    all_actions = []
    for k, keypoint in enumerate(episode_keypoints):
        obs_tp1 = demo[keypoint]
        action = _get_action(obs_tp1)
        all_actions.append(action)
        terminal = k == last_index
        reward = REWARD_SCALE if terminal else 0
        obs_dict = env.extract_obs(obs, t=k, prev_action=prev_action)
        prev_action = np.copy(action)
        others = {"demo": True}
        # Pixel index of the keypoint's gripper position in each camera,
        # stored as [py, px].
        final_obs = {}
        for name in cameras:
            px, py = _point_to_pixel_index(
                obs_tp1.gripper_pose[:3],
                obs_tp1.misc["%s_camera_extrinsics" % name],
                obs_tp1.misc["%s_camera_intrinsics" % name],
            )
            final_obs["%s_pixel_coord" % name] = [py, px]
        others.update(final_obs)
        others.update(obs_dict)
        timeout = False
        replay.add(action, reward, terminal, timeout, **others)
        # The keypoint just reached is the starting point of the next step.
        obs = obs_tp1

    # Close the episode with the observation at the last keypoint; it reuses
    # the pixel coordinates computed for that keypoint.
    obs_dict_tp1 = env.extract_obs(obs_tp1, t=k + 1, prev_action=prev_action)
    obs_dict_tp1.update(final_obs)
    replay.add_final(**obs_dict_tp1)
    return all_actions
|
|
|
|
def fill_replay(
    replay: ReplayBuffer,
    task: str,
    env: CustomRLBenchEnv,
    num_demos: int,
    demo_augmentation: bool,
    demo_augmentation_every_n: int,
    cameras: List[str],
):
    """Load demos for ``task`` and add their keypoint episodes to ``replay``.

    For each of the ``num_demos`` demos (variation 0, loaded one at a time by
    episode number) keypoints are discovered and an episode is added starting
    from frame 0. With ``demo_augmentation`` enabled, further episodes are
    added starting from every ``demo_augmentation_every_n``-th frame, each
    using only the keypoints that lie after its starting frame.

    Args:
        replay: Buffer receiving the transitions.
        task: Task identifier passed to ``env.env.get_demos``.
        env: Environment wrapper providing demos and observation extraction.
        num_demos: Number of demos to load.
        demo_augmentation: Whether to also start episodes from later frames.
        demo_augmentation_every_n: Frame stride between augmented starts.
        cameras: Camera names forwarded to ``_add_keypoints_to_replay``.

    Returns:
        List with every action that was added to the buffer.
    """
    logging.info("Filling replay with demos...")
    collected_actions = []
    for demo_index in range(num_demos):
        loaded = env.env.get_demos(
            task,
            1,
            variation_number=0,
            random_selection=False,
            from_episode_number=demo_index,
        )
        demo = loaded[0]
        remaining_keypoints = demo_loading_utils.keypoint_discovery(demo)

        for start in range(len(demo) - 1):
            if start > 0 and not demo_augmentation:
                break
            if start % demo_augmentation_every_n != 0:
                continue
            start_obs = demo[start]
            # Keypoints at or before the starting frame are already behind
            # us, so they are dropped for this and all later starts.
            while (
                len(remaining_keypoints) > 0
                and start >= remaining_keypoints[0]
            ):
                remaining_keypoints = remaining_keypoints[1:]
            if len(remaining_keypoints) == 0:
                break
            new_actions = _add_keypoints_to_replay(
                replay, start_obs, demo, env, remaining_keypoints, cameras
            )
            collected_actions.extend(new_actions)
    logging.info("Replay filled with demos.")
    return collected_actions
|
|
|
|
class SharedNet(nn.Module):
    """Two parallel convolution stems whose outputs are concatenated.

    ``build`` must be called before ``forward``; the constructor only stores
    the configuration.
    """

    def __init__(self, activation: str, norm: str = None):
        """Store the activation and normalisation names used by ``build``."""
        super().__init__()
        self._norm = norm
        self._activation = activation

    def _make_stem(self):
        """Return one 3 -> 32 channel, kernel 3, stride 1 convolution stem."""
        block = Conv2DBlock(
            3, 32, 3, 1, activation=self._activation, norm=self._norm
        )
        return nn.Sequential(block)

    def build(self):
        """Create the stem applied to the first and to the second input."""
        self._rgb_pre = self._make_stem()
        self._pcd_pre = self._make_stem()

    def forward(self, observations):
        """Run each stem on its input and concatenate along dimension 1.

        Args:
            observations: Indexable whose element 0 goes through the first
                stem and element 1 through the second (presumably RGB and
                point-cloud images, going by the attribute names).

        Returns:
            The two stem outputs concatenated along dimension 1.
        """
        first_feats = self._rgb_pre(observations[0])
        second_feats = self._pcd_pre(observations[1])
        return torch.cat((first_feats, second_feats), dim=1)
|
|
|
|
class ActorNet(nn.Module):
    """Actor head mapping image features plus a low-dim vector to 16 values.

    ``build`` must be called before ``forward``; the constructor only stores
    the configuration.
    """

    def __init__(self, activation: str, low_dim_size: int, norm: str = None):
        """Store the configuration used by ``build``.

        Args:
            activation: Activation name handed to the conv and dense blocks.
            low_dim_size: Length of the low-dimensional input vector; it is
                added to the 64 feature channels of the first convolution.
            norm: Normalisation name handed to the conv blocks.
        """
        super().__init__()
        self._norm = norm
        self._low_dim_size = low_dim_size
        self._activation = activation

    def build(self):
        """Create the convolution stack, the dense stack and the pooling."""
        conv_layers = [
            # 1x1 convolution mixing image features with the tiled low-dim
            # vector.
            Conv2DBlock(
                64 + self._low_dim_size,
                64,
                1,
                1,
                activation=self._activation,
                norm=self._norm,
            ),
            Conv2DBlock(64, 64, 3, 1, activation=self._activation, norm=self._norm),
        ]
        self._convs = nn.Sequential(*conv_layers)

        dense_layers = [
            DenseBlock(64, 64, activation=self._activation),
            DenseBlock(64, 64, activation=self._activation),
            # 8 * 2 outputs, no activation (presumably two values per action
            # dimension -- TODO confirm against the agent that consumes it).
            DenseBlock(64, 8 * 2),
        ]
        self._fcs = nn.Sequential(*dense_layers)
        self._maxp = nn.AdaptiveMaxPool2d(1)

    def forward(self, observation_feats, low_dim_ins):
        """Compute the actor output.

        Args:
            observation_feats: 4-D feature tensor; its last two dimensions
                give the spatial size the low-dim vector is tiled to.
            low_dim_ins: Low-dimensional input, expanded with two trailing
                singleton dimensions and repeated over the spatial size.

        Returns:
            Output of the dense stack applied to the globally max-pooled
            convolution features.
        """
        height, width = observation_feats.shape[2:]
        as_map = low_dim_ins.unsqueeze(-1).unsqueeze(-1)
        tiled = as_map.repeat(1, 1, height, width)
        stacked = torch.cat([observation_feats, tiled], dim=1)
        conv_out = self._convs(stacked)
        pooled = self._maxp(conv_out).squeeze(-1).squeeze(-1)
        return self._fcs(pooled)
|
|
|
|
class CriticNet(nn.Module):
    """Critic head operating on image features plus a low-dim vector.

    With ``q_conf`` enabled the network stays fully convolutional and ends in
    a two-channel map whose channel 1 is passed through a sigmoid. Otherwise
    the features are globally max-pooled and reduced to a single value by a
    dense stack. ``build`` must be called before ``forward``.
    """

    def __init__(
        self, activation: str, low_dim_size: int, norm: str = None, q_conf: bool = True
    ):
        """Store the configuration used by ``build``.

        Args:
            activation: Activation name handed to the conv and dense blocks.
            low_dim_size: Length of the low-dimensional input vector; it is
                added to the 64 feature channels of the first convolution.
            norm: Normalisation name handed to the conv blocks.
            q_conf: Selects the convolutional two-channel head (true) or the
                pooled single-value head (false).
        """
        super().__init__()
        self._activation = activation
        self._low_dim_size = low_dim_size
        self._norm = norm
        self._q_conf = q_conf

    def build(self):
        """Create the convolution trunk and the head selected by ``q_conf``."""
        self._convs = nn.Sequential(
            Conv2DBlock(
                64 + self._low_dim_size, 128, 3, 1, self._norm, self._activation
            ),
            Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
            Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
            Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
        )
        if self._q_conf:
            self._final_conv = Conv2DBlock(128, 2, 3, 1)
        else:
            self._maxp = nn.AdaptiveMaxPool2d(1)
            self._fcs = nn.Sequential(
                DenseBlock(128, 64, activation=self._activation),
                DenseBlock(64, 1),
            )

    def forward(self, observation_feats, low_dim_ins):
        """Compute the critic output.

        Args:
            observation_feats: 4-D feature tensor; its last two dimensions
                give the spatial size the low-dim vector is tiled to.
            low_dim_ins: Low-dimensional input, expanded with two trailing
                singleton dimensions and repeated over the spatial size.

        Returns:
            With ``q_conf``: the final convolution output with a sigmoid
            applied to channel 1 (presumably the confidence channel -- TODO
            confirm) and all other channels unchanged. Without ``q_conf``:
            the dense-stack output of the globally max-pooled features.
        """
        _, _, h, w = observation_feats.shape
        low_dim_feats = low_dim_ins.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
        x = torch.cat([observation_feats, low_dim_feats], dim=1)
        x = self._convs(x)
        if self._q_conf:
            x = self._final_conv(x)
            # Assemble the result out of place. Writing the sigmoid back into
            # ``x`` would modify a tensor autograd may still need, which makes
            # backward fail when the last op of the final block saves its
            # output (e.g. a ReLU).
            x = torch.cat(
                [x[:, :1], torch.sigmoid(x[:, 1:2]), x[:, 2:]], dim=1
            )
        else:
            x = self._maxp(x).squeeze(-1).squeeze(-1)
            x = self._fcs(x)
        return x
|
|
|
|
class Qattention2DNet(nn.Module):
    """Encoder-decoder network producing a per-pixel output map.

    The input goes through a private copy of ``siamese_net``, optionally gets
    a tiled low-dimensional vector appended as extra channels, is processed
    by a stack of down-sampling conv blocks and a mirrored stack of
    up-sampling blocks (optionally with skip connections), and finishes with
    a 3x3 convolution to ``output_channels`` channels. ``build`` must be
    called exactly once before ``forward``.
    """

    def __init__(
        self,
        siamese_net: SiameseNet,
        filters: List[int],
        kernel_sizes: List[int],
        strides: List[int],
        low_dim_state_len: int,
        norm: str = None,
        activation: str = "relu",
        output_channels: int = 1,
        skip_connections: bool = True,
    ):
        """Store the configuration used by ``build``.

        Args:
            siamese_net: Front-end network; it is deep-copied so this module
                owns its own instance.
            filters: Output channels of each down block, in order.
            kernel_sizes: Kernel size of each down block, in order.
            strides: Stride of each down block, in order.
            low_dim_state_len: Number of channels added to the front-end
                output to obtain the first down block's input channels.
            norm: Normalisation name handed to the conv blocks.
            activation: Activation name handed to the conv blocks.
            output_channels: Channels of the final convolution.
            skip_connections: Whether up blocks (except the first) also
                receive the matching down-block output.
        """
        super().__init__()
        self._siamese_net = copy.deepcopy(siamese_net)
        self._input_channels = self._siamese_net.output_channels + low_dim_state_len
        self._filters = filters
        self._kernel_sizes = kernel_sizes
        self._strides = strides
        self._norm = norm
        self._activation = activation
        self._output_channels = output_channels
        self._skip_connections = skip_connections
        self._build_calls = 0

    def build(self):
        """Create all sub-modules.

        Raises:
            RuntimeError: If called more than once.
        """
        self._build_calls += 1
        if self._build_calls != 1:
            raise RuntimeError("Build needs to be called once.")
        self._siamese_net.build()

        conv_data = list(zip(self._filters, self._kernel_sizes, self._strides))

        ch = self._input_channels
        down_blocks = []
        for filt, ksize, stride in conv_data:
            down_blocks.append(
                Conv2DBlock(
                    ch,
                    filt,
                    ksize,
                    stride,
                    self._norm,
                    self._activation,
                    padding_mode="replicate",
                )
            )
            ch = filt
        self._down = nn.ModuleList(down_blocks)

        # The decoder mirrors the encoder, so walk the configuration backwards.
        reverse_conv_data = conv_data[::-1]
        up_blocks = []
        for i, (filt, ksize, stride) in enumerate(reverse_conv_data):
            if i > 0 and self._skip_connections:
                # forward() concatenates layers_for_skip[i], the output of the
                # down block configured by reverse_conv_data[i], so that
                # block's filter count is what gets added to the input here.
                ch += reverse_conv_data[i][0]
            up_blocks.append(
                Conv2DUpsampleBlock(
                    ch, filt, ksize, stride, self._norm, self._activation
                )
            )
            ch = filt
        self._up = nn.ModuleList(up_blocks)

        self._final_conv = Conv2DBlock(
            ch, self._output_channels, 3, 1, padding_mode="replicate"
        )

    def forward(self, observations, low_dim_ins):
        """Compute the output map.

        Besides returning the result, the intermediate activations of the
        last call are kept on the instance as ``downs``, ``ups`` and
        ``latent`` (the output of the last down block).

        Args:
            observations: Input passed unchanged to the front-end network.
            low_dim_ins: Optional low-dimensional input; when not ``None`` it
                is tiled over the spatial size of the front-end output and
                appended as extra channels.

        Returns:
            Output of the final convolution.
        """
        x = self._siamese_net(observations)
        _, _, h, w = x.shape
        if low_dim_ins is not None:
            low_dim_latents = low_dim_ins.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
            x = torch.cat([x, low_dim_latents], dim=1)

        self.ups = []
        self.downs = []
        for down_block in self._down:
            x = down_block(x)
            self.downs.append(x)
        self.latent = x

        # Deepest activation first, matching the order of the up blocks.
        layers_for_skip = self.downs[::-1]
        for i, up_block in enumerate(self._up):
            if i > 0 and self._skip_connections:
                x = torch.cat([layers_for_skip[i], x], 1)
            x = up_block(x)
            self.ups.append(x)
        return self._final_conv(x)
|
|
|
|
def create_agent(
    camera_name: str,
    activation: str,
    q_conf: bool,
    action_min_max,
    alpha,
    alpha_lr,
    alpha_auto_tune,
    critic_lr,
    actor_lr,
    next_best_pose_critic_weight_decay,
    next_best_pose_actor_weight_decay,
    crop_shape,
    next_best_pose_tau,
    next_best_pose_critic_grad_clip,
    next_best_pose_actor_grad_clip,
    qattention_tau,
    qattention_lr,
    qattention_weight_decay,
    qattention_lambda_qreg,
    low_dim_state_len,
    qattention_grad_clip,
):
    """Assemble the full agent: Q-attention, next-best-pose, preprocessing.

    A ``QAttentionAgent`` is built around a ``Qattention2DNet`` that receives
    no low-dimensional state. It is wrapped by a ``NextBestPoseAgent`` made
    of a ``SharedNet``, a ``CriticNet`` and an ``ActorNet``, which in turn is
    wrapped by a ``PreprocessAgent``.

    Args:
        camera_name: Camera name given to both agents.
        activation: Activation name used by every network.
        q_conf: Given to the critic network and the next-best-pose agent.
        action_min_max, alpha, alpha_lr, alpha_auto_tune, critic_lr,
            actor_lr, crop_shape: Forwarded to ``NextBestPoseAgent`` under
            the same names.
        next_best_pose_critic_weight_decay, next_best_pose_actor_weight_decay,
            next_best_pose_tau, next_best_pose_critic_grad_clip,
            next_best_pose_actor_grad_clip: Forwarded to ``NextBestPoseAgent``
            as ``critic_weight_decay``, ``actor_weight_decay``,
            ``critic_tau``, ``critic_grad_clip`` and ``actor_grad_clip``.
        qattention_tau, qattention_lr, qattention_weight_decay,
            qattention_lambda_qreg, qattention_grad_clip: Forwarded to
            ``QAttentionAgent`` as ``tau``, ``lr``, ``weight_decay``,
            ``lambda_qreg`` and ``grad_clip``.
        low_dim_state_len: Low-dimensional state length used to size the
            actor and critic inputs.

    Returns:
        A ``PreprocessAgent`` wrapping the next-best-pose agent.
    """
    # Q-attention stack.
    pixel_front_end = SiameseNet(
        input_channels=[3, 3],
        filters=[8],
        kernel_sizes=[5],
        strides=[1],
        activation=activation,
        norm=None,
    )
    pixel_unet = Qattention2DNet(
        siamese_net=pixel_front_end,
        filters=[16, 16],
        kernel_sizes=[5, 5],
        strides=[2, 2],
        output_channels=1,
        norm=None,
        activation=activation,
        skip_connections=True,
        low_dim_state_len=0,
    )
    attention_agent = QAttentionAgent(
        pixel_unet=pixel_unet,
        tau=qattention_tau,
        camera_name=camera_name,
        lr=qattention_lr,
        weight_decay=qattention_weight_decay,
        lambda_qreg=qattention_lambda_qreg,
        include_low_dim_state=False,
        grad_clip=qattention_grad_clip,
    )

    # Actor-critic stack. The critic gets 8 inputs on top of the low-dim
    # state (presumably the action vector, which has shape (8,) in
    # create_replay -- TODO confirm).
    critic_low_dim_size = low_dim_state_len + 8
    shared = SharedNet(activation=activation, norm="layer")
    critic = CriticNet(
        activation=activation,
        low_dim_size=critic_low_dim_size,
        norm="layer",
        q_conf=q_conf,
    )
    actor = ActorNet(activation=activation, low_dim_size=low_dim_state_len)

    pose_agent = NextBestPoseAgent(
        qattention_agent=attention_agent,
        shared_network=shared,
        critic_network=critic,
        actor_network=actor,
        action_min_max=action_min_max,
        camera_name=camera_name,
        alpha=alpha,
        alpha_lr=alpha_lr,
        alpha_auto_tune=alpha_auto_tune,
        critic_lr=critic_lr,
        actor_lr=actor_lr,
        critic_weight_decay=next_best_pose_critic_weight_decay,
        actor_weight_decay=next_best_pose_actor_weight_decay,
        crop_shape=crop_shape,
        critic_tau=next_best_pose_tau,
        critic_grad_clip=next_best_pose_critic_grad_clip,
        actor_grad_clip=next_best_pose_actor_grad_clip,
        q_conf=q_conf,
    )
    return PreprocessAgent(pose_agent=pose_agent)
|
|