| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torchvision |
| | import copy |
| |
|
| | from typing import Optional, Dict, Tuple, Union, List, Type |
| | from termcolor import cprint |
| |
|
| |
|
def create_mlp(
    input_dim: int,
    output_dim: int,
    net_arch: List[int],
    activation_fn: Type[nn.Module] = nn.ReLU,
    squash_output: bool = False,
) -> List[nn.Module]:
    """
    Assemble the layers of a multi layer perceptron (MLP): a stack of
    fully-connected layers, each hidden one followed by an activation.

    :param input_dim: Dimension of the input vector
    :param output_dim: Dimension of the final linear layer. When it is
        not strictly positive, no final linear layer is appended.
    :param net_arch: Number of units of every hidden layer, in order.
        Its length is the number of hidden layers; it may be empty.
    :param activation_fn: Module class instantiated (with no arguments)
        after each hidden linear layer.
    :param squash_output: Whether to append a Tanh activation as the
        very last module
    :return: The modules in the order they should be applied. The list
        is not wrapped in ``nn.Sequential``; that is left to the caller.
    """
    # Layer widths along the hidden stack, starting at the input width.
    widths = [input_dim, *net_arch]

    layers = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        layers.append(nn.Linear(fan_in, fan_out))
        layers.append(activation_fn())

    if output_dim > 0:
        # widths[-1] is the last hidden width, or input_dim when there
        # are no hidden layers at all.
        layers.append(nn.Linear(widths[-1], output_dim))

    if squash_output:
        layers.append(nn.Tanh())
    return layers
| |
|
| |
|
| |
|
| |
|
class PointNetEncoderXYZRGB(nn.Module):
    """Encoder for Pointcloud

    Applies a shared MLP to the last dimension of the input, takes the
    maximum over dimension 1 and feeds the result through a final linear
    projection.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int=1024,
                 use_layernorm: bool=False,
                 final_norm: str='none',
                 use_projection: bool=True,
                 **kwargs
                 ):
        """Build the shared MLP and the final projection.

        Args:
            in_channels (int): input width of the first linear layer.
            out_channels (int, optional): output width of the final projection. Defaults to 1024.
            use_layernorm (bool, optional): put a LayerNorm (instead of an Identity) after each
                of the first three linear layers. Defaults to False.
            final_norm (str, optional): 'layernorm' appends a LayerNorm to the final projection,
                'none' leaves it as a bare linear layer. Defaults to 'none'.
            use_projection (bool, optional): accepted but not read by this class.
            **kwargs: accepted and ignored.

        Raises:
            NotImplementedError: if ``final_norm`` is neither 'layernorm' nor 'none'.
        """
        super().__init__()
        cprint("pointnet use_layernorm: {}".format(use_layernorm), 'cyan')
        cprint("pointnet use_final_norm: {}".format(final_norm), 'cyan')

        hidden_dims = (64, 128, 256, 512)

        def make_norm(width):
            return nn.LayerNorm(width) if use_layernorm else nn.Identity()

        # Three (Linear, norm, ReLU) stages followed by one bare Linear.
        stages = []
        prev_width = in_channels
        for width in hidden_dims[:-1]:
            stages += [nn.Linear(prev_width, width), make_norm(width), nn.ReLU()]
            prev_width = width
        stages.append(nn.Linear(prev_width, hidden_dims[-1]))
        self.mlp = nn.Sequential(*stages)

        if final_norm not in ('layernorm', 'none'):
            raise NotImplementedError(f"final_norm: {final_norm}")

        projection = nn.Linear(hidden_dims[-1], out_channels)
        if final_norm == 'layernorm':
            projection = nn.Sequential(projection, nn.LayerNorm(out_channels))
        self.final_projection = projection

    def forward(self, x):
        """Encode ``x``: shared MLP, max over dimension 1, final projection.

        NOTE(review): presumably ``x`` is (batch, num_points, in_channels),
        so the max pools over points - confirm against callers.
        """
        per_point = self.mlp(x)
        pooled, _ = torch.max(per_point, 1)
        return self.final_projection(pooled)
| | |
| |
|
class PointNetEncoderXYZ(nn.Module):
    """Encoder for Pointcloud

    Applies a shared MLP to the last dimension of the input, takes the
    maximum over dimension 1 and optionally feeds the result through a
    final linear projection.
    """

    def __init__(self,
                 in_channels: int=3,
                 out_channels: int=1024,
                 use_layernorm: bool=False,
                 final_norm: str='none',
                 use_projection: bool=True,
                 **kwargs
                 ):
        """Build the shared MLP and the final projection.

        Args:
            in_channels (int, optional): input width of the first linear layer; must be 3.
            out_channels (int, optional): output width of the final projection. Defaults to 1024.
            use_layernorm (bool, optional): put a LayerNorm (instead of an Identity) after each
                linear layer of the shared MLP. Defaults to False.
            final_norm (str, optional): 'layernorm' appends a LayerNorm to the final projection,
                'none' leaves it as a bare linear layer. Defaults to 'none'.
            use_projection (bool, optional): when False the final projection is replaced by an
                Identity, so the encoder outputs the 256-wide pooled feature. Defaults to True.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``in_channels`` is not 3.
            NotImplementedError: if ``final_norm`` is neither 'layernorm' nor 'none'.
        """
        super().__init__()

        # Explicit raise instead of `assert ..., cprint(...)`: an assert is
        # stripped under `python -O`, and cprint returns None so the
        # AssertionError used to carry no message.
        if in_channels != 3:
            raise ValueError(f"PointNetEncoderXYZ only supports 3 channels, but got {in_channels}")

        block_channel = [64, 128, 256]
        cprint("[PointNetEncoderXYZ] use_layernorm: {}".format(use_layernorm), 'cyan')
        cprint("[PointNetEncoderXYZ] use_final_norm: {}".format(final_norm), 'cyan')

        self.mlp = nn.Sequential(
            nn.Linear(in_channels, block_channel[0]),
            nn.LayerNorm(block_channel[0]) if use_layernorm else nn.Identity(),
            nn.ReLU(),
            nn.Linear(block_channel[0], block_channel[1]),
            nn.LayerNorm(block_channel[1]) if use_layernorm else nn.Identity(),
            nn.ReLU(),
            nn.Linear(block_channel[1], block_channel[2]),
            nn.LayerNorm(block_channel[2]) if use_layernorm else nn.Identity(),
            nn.ReLU(),
        )

        if final_norm == 'layernorm':
            self.final_projection = nn.Sequential(
                nn.Linear(block_channel[-1], out_channels),
                nn.LayerNorm(out_channels)
            )
        elif final_norm == 'none':
            self.final_projection = nn.Linear(block_channel[-1], out_channels)
        else:
            raise NotImplementedError(f"final_norm: {final_norm}")

        # The projection above is built (and final_norm validated) even when
        # it is discarded here, so parameter initialization consumes the
        # random stream exactly as before and bad final_norm still raises.
        self.use_projection = use_projection
        if not use_projection:
            self.final_projection = nn.Identity()
            cprint("[PointNetEncoderXYZ] not use projection", "yellow")

        # Debug switch: flip to True to record what Grad-CAM needs.
        VIS_WITH_GRAD_CAM = False
        if VIS_WITH_GRAD_CAM:
            self.gradient = None
            self.feature = None
            self.input_pointcloud = None
            self.mlp[0].register_forward_hook(self.save_input)
            # mlp[6] is the third (last) linear layer of the shared MLP.
            self.mlp[6].register_forward_hook(self.save_feature)
            # register_backward_hook is deprecated in favour of the
            # "full" variant, which has the same hook signature.
            self.mlp[6].register_full_backward_hook(self.save_gradient)

    def forward(self, x):
        """Encode ``x``: shared MLP, max over dimension 1, final projection.

        NOTE(review): presumably ``x`` is (batch, num_points, 3), so the
        max pools over points - confirm against callers.
        """
        x = self.mlp(x)
        x = torch.max(x, 1)[0]
        x = self.final_projection(x)
        return x

    def save_gradient(self, module, grad_input, grad_output):
        """
        for grad-cam: backward hook storing the first output gradient
        """
        self.gradient = grad_output[0]

    def save_feature(self, module, input, output):
        """
        for grad-cam: forward hook storing a detached copy of the output
        (its first element when the output is a tuple)
        """
        if isinstance(output, tuple):
            self.feature = output[0].detach()
        else:
            self.feature = output.detach()

    def save_input(self, module, input, output):
        """
        for grad-cam: forward hook storing a detached copy of the first input
        """
        self.input_pointcloud = input[0].detach()
| |
|
| | |
| |
|
| |
|
class DP3Encoder(nn.Module):
    """Observation encoder: point-cloud feature concatenated with a state feature.

    The point cloud goes through a PointNet-style extractor, the selected
    state entries go through an MLP, and ``forward`` returns both features
    concatenated along the last dimension.
    """

    def __init__(self,
                 observation_space: Dict,
                 img_crop_shape=None,
                 out_channel=256,
                 state_mlp_size=(64, 64), state_mlp_activation_fn=nn.ReLU,
                 pointcloud_encoder_cfg=None,
                 use_pc_color=False,
                 pointnet_type='pointnet',
                 state_keys=['robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos']
                 ):
        """Build the point-cloud extractor and the state MLP.

        Args:
            observation_space (Dict): maps observation keys to their shapes. Must contain
                'point_cloud' and every key in ``state_keys``; 'imagin_robot' is optional.
                Assumes each state entry is a sequence whose first element is the state
                width - TODO confirm against the caller.
            img_crop_shape: accepted but not used by this class.
            out_channel (int, optional): counted as the extractor's contribution to the
                reported output size; it is not passed to the extractor. Defaults to 256.
            state_mlp_size (sequence of int, optional): widths of the state MLP; the last
                entry is its output width, the others are hidden layers. Must be non-empty.
            state_mlp_activation_fn: activation class used inside the state MLP.
            pointcloud_encoder_cfg (mapping, optional): keyword arguments for the extractor.
                It is copied, so the caller's object is left untouched; ``in_channels`` is
                overridden (6 with colour, 3 without). ``None`` means "extractor defaults".
            use_pc_color (bool, optional): pick the XYZRGB extractor instead of the XYZ one.
            pointnet_type (str, optional): only 'pointnet' is implemented.
            state_keys (sequence of str, optional): observation keys concatenated into the
                state vector, in this order.

        Raises:
            NotImplementedError: if ``pointnet_type`` is not 'pointnet'.
            RuntimeError: if ``state_mlp_size`` is empty.
        """
        super().__init__()
        self.imagination_key = 'imagin_robot'
        # Copy so the shared default list (or a caller-owned list) can never
        # be altered through this instance.
        self.state_keys = list(state_keys)
        self.point_cloud_key = 'point_cloud'
        self.rgb_image_key = 'image'
        self.n_output_channels = out_channel

        self.use_imagined_robot = self.imagination_key in observation_space
        self.point_cloud_shape = observation_space[self.point_cloud_key]
        self.state_size = sum(observation_space[key][0] for key in self.state_keys)
        if self.use_imagined_robot:
            self.imagination_shape = observation_space[self.imagination_key]
        else:
            self.imagination_shape = None

        cprint(f"[DP3Encoder] point cloud shape: {self.point_cloud_shape}", "yellow")
        cprint(f"[DP3Encoder] state shape: {self.state_size}", "yellow")
        cprint(f"[DP3Encoder] imagination point shape: {self.imagination_shape}", "yellow")

        self.use_pc_color = use_pc_color
        self.pointnet_type = pointnet_type
        if pointnet_type != "pointnet":
            raise NotImplementedError(f"pointnet_type: {pointnet_type}")

        # Work on a private copy: the caller's config is not mutated, and
        # the None default no longer fails on attribute assignment.
        extractor_kwargs = {} if pointcloud_encoder_cfg is None else dict(pointcloud_encoder_cfg)
        if use_pc_color:
            extractor_kwargs['in_channels'] = 6
            self.extractor = PointNetEncoderXYZRGB(**extractor_kwargs)
        else:
            extractor_kwargs['in_channels'] = 3
            self.extractor = PointNetEncoderXYZ(**extractor_kwargs)

        if len(state_mlp_size) == 0:
            raise RuntimeError("State mlp size is empty")
        # All but the last width are hidden layers (none for a single width).
        net_arch = list(state_mlp_size[:-1])
        output_dim = state_mlp_size[-1]

        self.n_output_channels += output_dim
        self.state_mlp = nn.Sequential(*create_mlp(self.state_size, output_dim, net_arch, state_mlp_activation_fn))

        cprint(f"[DP3Encoder] output dim: {self.n_output_channels}", "red")

    def forward(self, observations: Dict) -> torch.Tensor:
        """Encode one batch of observations.

        Args:
            observations (Dict): must hold the point cloud under 'point_cloud' (a 3-D
                tensor) and every key in ``self.state_keys``; when the imagined robot is
                enabled it must also hold 'imagin_robot'.

        Returns:
            torch.Tensor: extractor feature and state feature concatenated on the last dim.

        Raises:
            ValueError: if the point cloud is not 3-dimensional.
        """
        points = observations[self.point_cloud_key]
        # Explicit raise instead of `assert ..., cprint(...)`: survives
        # `python -O` and puts the message on the exception itself.
        if len(points.shape) != 3:
            raise ValueError(f"point cloud shape: {points.shape}, length should be 3")

        if self.use_imagined_robot:
            # Trim the imagined points to the point cloud's channel count,
            # then append them along dimension 1.
            img_points = observations[self.imagination_key][..., :points.shape[-1]]
            points = torch.cat([points, img_points], dim=1)

        pn_feat = self.extractor(points)

        state = torch.cat([observations[key] for key in self.state_keys], dim=-1)
        state_feat = self.state_mlp(state)
        final_feat = torch.cat([pn_feat, state_feat], dim=-1)
        return final_feat

    def output_shape(self):
        """Return the output size recorded at construction: ``out_channel`` plus the state MLP output width."""
        return self.n_output_channels