lsnu's picture
Add files using upload-large-folder tool
0d89eb9 verified
import copy
import logging
from typing import List
import numpy as np
import torch
import torch.nn as nn
from rlbench.backend.observation import Observation
from rlbench.demo import Demo
from yarr.replay_buffer.prioritized_replay_buffer import (
PrioritizedReplayBuffer,
ObservationElement,
)
from yarr.replay_buffer.replay_buffer import ReplayElement, ReplayBuffer
from yarr.replay_buffer.uniform_replay_buffer import UniformReplayBuffer
from helpers import demo_loading_utils, utils
from helpers.custom_rlbench_env import CustomRLBenchEnv
from helpers.network_utils import (
SiameseNet,
DenseBlock,
Conv2DBlock,
Conv2DUpsampleBlock,
)
from helpers.preprocess_agent import PreprocessAgent
from agents.arm.next_best_pose_agent import NextBestPoseAgent
from agents.arm.qattention_agent import QAttentionAgent
REWARD_SCALE = 100.0
def create_replay(
batch_size: int,
timesteps: int,
prioritisation: bool,
save_dir: str,
cameras: list,
env: CustomRLBenchEnv,
):
observation_elements = env.observation_elements
for cname in cameras:
observation_elements.extend(
[
ObservationElement("%s_pixel_coord" % cname, (2,), np.int32),
]
)
replay_class = UniformReplayBuffer
if prioritisation:
replay_class = PrioritizedReplayBuffer
replay_buffer = replay_class(
save_dir=save_dir,
batch_size=batch_size,
timesteps=timesteps,
replay_capacity=int(1e5),
action_shape=(8,),
action_dtype=np.float32,
reward_shape=(),
reward_dtype=np.float32,
update_horizon=1,
observation_elements=observation_elements,
extra_replay_elements=[ReplayElement("demo", (), np.bool)],
)
return replay_buffer
def _point_to_pixel_index(
point: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray
):
point = np.array([point[0], point[1], point[2], 1])
world_to_cam = np.linalg.inv(extrinsics)
point_in_cam_frame = world_to_cam.dot(point)
px, py, pz = point_in_cam_frame[:3]
px = 2 * intrinsics[0, 2] - int(-intrinsics[0, 0] * (px / pz) + intrinsics[0, 2])
py = 2 * intrinsics[1, 2] - int(-intrinsics[1, 1] * (py / pz) + intrinsics[1, 2])
return px, py
def _get_action(obs_tp1: Observation):
quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
if quat[-1] < 0:
quat = -quat
return np.concatenate(
[obs_tp1.gripper_pose[:3], quat, [float(obs_tp1.gripper_open)]]
)
def _add_keypoints_to_replay(
replay: ReplayBuffer,
inital_obs: Observation,
demo: Demo,
env: CustomRLBenchEnv,
episode_keypoints: List[int],
cameras: List[str],
):
prev_action = None
obs = inital_obs
all_actions = []
for k, keypoint in enumerate(episode_keypoints):
obs_tp1 = demo[keypoint]
action = _get_action(obs_tp1)
all_actions.append(action)
terminal = k == len(episode_keypoints) - 1
reward = float(terminal) * REWARD_SCALE if terminal else 0
obs_dict = env.extract_obs(obs, t=k, prev_action=prev_action)
prev_action = np.copy(action)
others = {"demo": True}
final_obs = {}
for name in cameras:
px, py = _point_to_pixel_index(
obs_tp1.gripper_pose[:3],
obs_tp1.misc["%s_camera_extrinsics" % name],
obs_tp1.misc["%s_camera_intrinsics" % name],
)
final_obs["%s_pixel_coord" % name] = [py, px]
others.update(final_obs)
others.update(obs_dict)
timeout = False
replay.add(action, reward, terminal, timeout, **others)
obs = obs_tp1 # Set the next obs
# Final step
obs_dict_tp1 = env.extract_obs(obs_tp1, t=k + 1, prev_action=prev_action)
obs_dict_tp1.update(final_obs)
replay.add_final(**obs_dict_tp1)
return all_actions
def fill_replay(
replay: ReplayBuffer,
task: str,
env: CustomRLBenchEnv,
num_demos: int,
demo_augmentation: bool,
demo_augmentation_every_n: int,
cameras: List[str],
):
logging.info("Filling replay with demos...")
all_actions = []
for d_idx in range(num_demos):
demo = env.env.get_demos(
task,
1,
variation_number=0,
random_selection=False,
from_episode_number=d_idx,
)[0]
episode_keypoints = demo_loading_utils.keypoint_discovery(demo)
for i in range(len(demo) - 1):
if not demo_augmentation and i > 0:
break
if i % demo_augmentation_every_n != 0:
continue
obs = demo[i]
# If our starting point is past one of the keypoints, then remove it
while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
episode_keypoints = episode_keypoints[1:]
if len(episode_keypoints) == 0:
break
all_actions.extend(
_add_keypoints_to_replay(
replay, obs, demo, env, episode_keypoints, cameras
)
)
logging.info("Replay filled with demos.")
return all_actions
class SharedNet(nn.Module):
def __init__(self, activation: str, norm: str = None):
super(SharedNet, self).__init__()
self._activation = activation
self._norm = norm
def build(self):
self._rgb_pre = nn.Sequential(
Conv2DBlock(3, 32, 3, 1, activation=self._activation, norm=self._norm),
)
self._pcd_pre = nn.Sequential(
Conv2DBlock(3, 32, 3, 1, activation=self._activation, norm=self._norm),
)
def forward(self, observations):
x_rgb, x_pcd = self._rgb_pre(observations[0]), self._pcd_pre(observations[1])
x = torch.cat([x_rgb, x_pcd], dim=1)
return x
class ActorNet(nn.Module):
def __init__(self, activation: str, low_dim_size: int, norm: str = None):
super(ActorNet, self).__init__()
self._activation = activation
self._low_dim_size = low_dim_size
self._norm = norm
def build(self):
self._convs = nn.Sequential(
Conv2DBlock(
64 + self._low_dim_size,
64,
1,
1,
activation=self._activation,
norm=self._norm,
),
Conv2DBlock(64, 64, 3, 1, activation=self._activation, norm=self._norm),
)
self._fcs = nn.Sequential(
DenseBlock(64, 64, activation=self._activation),
DenseBlock(64, 64, activation=self._activation),
DenseBlock(64, 8 * 2),
)
self._maxp = nn.AdaptiveMaxPool2d(1)
def forward(self, observation_feats, low_dim_ins):
low_dim_feats = low_dim_ins
_, _, h, w = observation_feats.shape
low_dim_feats = low_dim_feats.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
x = torch.cat([observation_feats, low_dim_feats], dim=1)
x = self._convs(x)
x = self._maxp(x).squeeze(-1).squeeze(-1)
x = self._fcs(x)
return x
class CriticNet(nn.Module):
def __init__(
self, activation: str, low_dim_size: int, norm: str = None, q_conf: bool = True
):
super(CriticNet, self).__init__()
self._activation = activation
self._low_dim_size = low_dim_size
self._norm = norm
self._q_conf = q_conf
def build(self):
self._convs = nn.Sequential(
Conv2DBlock(
64 + self._low_dim_size, 128, 3, 1, self._norm, self._activation
),
Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
Conv2DBlock(128, 128, 3, 1, self._norm, self._activation),
)
if self._q_conf:
self._final_conv = Conv2DBlock(128, 2, 3, 1)
else:
self._maxp = nn.AdaptiveMaxPool2d(1)
self._fcs = nn.Sequential(
DenseBlock(128, 64, activation=self._activation),
DenseBlock(64, 1),
)
def forward(self, observation_feats, low_dim_ins):
low_dim_feats = low_dim_ins
_, _, h, w = observation_feats.shape
low_dim_feats = low_dim_feats.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
x = torch.cat([observation_feats, low_dim_feats], dim=1)
x = self._convs(x)
if self._q_conf:
x = self._final_conv(x)
x[:, 1] = torch.sigmoid(x[:, 1])
else:
x = self._maxp(x).squeeze(-1).squeeze(-1)
x = self._fcs(x)
return x
class Qattention2DNet(nn.Module):
def __init__(
self,
siamese_net: SiameseNet,
filters: List[int],
kernel_sizes: List[int],
strides: List[int],
low_dim_state_len: int,
norm: str = None,
activation: str = "relu",
output_channels: int = 1,
skip_connections: bool = True,
):
super(Qattention2DNet, self).__init__()
self._siamese_net = copy.deepcopy(siamese_net)
self._input_channels = self._siamese_net.output_channels + low_dim_state_len
self._filters = filters
self._kernel_sizes = kernel_sizes
self._strides = strides
self._norm = norm
self._activation = activation
self._output_channels = output_channels
self._skip_connections = skip_connections
self._build_calls = 0
def build(self):
self._build_calls += 1
if self._build_calls != 1:
raise RuntimeError("Build needs to be called once.")
self._siamese_net.build()
self._down = []
ch = self._input_channels
for filt, ksize, stride in zip(
self._filters, self._kernel_sizes, self._strides
):
conv_block = Conv2DBlock(
ch,
filt,
ksize,
stride,
self._norm,
self._activation,
padding_mode="replicate",
)
ch = filt
self._down.append(conv_block)
self._down = nn.ModuleList(self._down)
reverse_conv_data = list(zip(self._filters, self._kernel_sizes, self._strides))
reverse_conv_data.reverse()
self._up = []
for i, (filt, ksize, stride) in enumerate(reverse_conv_data):
if i > 0 and self._skip_connections:
ch += reverse_conv_data[-i - 1][0]
convt_block = Conv2DUpsampleBlock(
ch, filt, ksize, stride, self._norm, self._activation
)
ch = filt
self._up.append(convt_block)
self._up = nn.ModuleList(self._up)
self._final_conv = Conv2DBlock(
ch, self._output_channels, 3, 1, padding_mode="replicate"
)
def forward(self, observations, low_dim_ins):
x = self._siamese_net(observations)
_, _, h, w = x.shape
if low_dim_ins is not None:
low_dim_latents = low_dim_ins.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, h, w)
x = torch.cat([x, low_dim_latents], dim=1)
self.ups = []
self.downs = []
layers_for_skip = []
for l in self._down:
x = l(x)
layers_for_skip.append(x)
self.downs.append(x)
self.latent = x
layers_for_skip.reverse()
for i, l in enumerate(self._up):
if i > 0 and self._skip_connections:
# Skip connections. Skip the first up layer.
x = torch.cat([layers_for_skip[i], x], 1)
x = l(x)
self.ups.append(x)
x = self._final_conv(x)
return x
def create_agent(
camera_name: str,
activation: str,
q_conf: bool,
action_min_max,
alpha,
alpha_lr,
alpha_auto_tune,
critic_lr,
actor_lr,
next_best_pose_critic_weight_decay,
next_best_pose_actor_weight_decay,
crop_shape,
next_best_pose_tau,
next_best_pose_critic_grad_clip,
next_best_pose_actor_grad_clip,
qattention_tau,
qattention_lr,
qattention_weight_decay,
qattention_lambda_qreg,
low_dim_state_len,
qattention_grad_clip,
):
siamese_net = SiameseNet(
input_channels=[3, 3],
filters=[8],
kernel_sizes=[5],
strides=[1],
activation=activation,
norm=None,
)
qattention_net = Qattention2DNet(
siamese_net=siamese_net,
filters=[16, 16],
kernel_sizes=[5, 5],
strides=[2, 2],
output_channels=1,
norm=None,
activation=activation,
skip_connections=True,
low_dim_state_len=0,
)
qattention_agent = QAttentionAgent(
pixel_unet=qattention_net,
tau=qattention_tau,
camera_name=camera_name,
lr=qattention_lr,
weight_decay=qattention_weight_decay,
lambda_qreg=qattention_lambda_qreg,
include_low_dim_state=False,
grad_clip=qattention_grad_clip,
)
shared_net = SharedNet(activation, norm="layer")
critic_net = CriticNet(
activation, low_dim_state_len + 8, norm="layer", q_conf=q_conf
)
actor_net = ActorNet(activation, low_dim_state_len)
next_best_pose_agent = NextBestPoseAgent(
qattention_agent=qattention_agent,
shared_network=shared_net,
critic_network=critic_net,
actor_network=actor_net,
action_min_max=action_min_max,
camera_name=camera_name,
alpha=alpha,
alpha_lr=alpha_lr,
alpha_auto_tune=alpha_auto_tune,
critic_lr=critic_lr,
actor_lr=actor_lr,
critic_weight_decay=next_best_pose_critic_weight_decay,
actor_weight_decay=next_best_pose_actor_weight_decay,
crop_shape=crop_shape,
critic_tau=next_best_pose_tau,
critic_grad_clip=next_best_pose_critic_grad_clip,
actor_grad_clip=next_best_pose_actor_grad_clip,
q_conf=q_conf,
)
return PreprocessAgent(pose_agent=next_best_pose_agent)