# equi_diffpo/model/vision/pointnet_extractor.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import copy
from typing import Optional, Dict, Tuple, Union, List, Type
from termcolor import cprint
def create_mlp(
    input_dim: int,
    output_dim: int,
    net_arch: List[int],
    activation_fn: Type[nn.Module] = nn.ReLU,
    squash_output: bool = False,
) -> List[nn.Module]:
    """
    Build the layer list for a multi layer perceptron (MLP): a stack of
    fully-connected layers, each followed by an activation function.

    :param input_dim: Dimension of the input vector
    :param output_dim: Dimension of the output; if <= 0, no output layer
        is appended (the last hidden activation is the output).
    :param net_arch: Architecture of the neural net — the number of units
        per hidden layer. Its length is the number of hidden layers.
    :param activation_fn: The activation function class applied after each
        hidden layer.
    :param squash_output: Whether to append a Tanh to squash the output.
    :return: The list of modules (wrap in nn.Sequential to use).
    """
    layers: List[nn.Module] = []
    # Walk consecutive (in, out) pairs: input_dim -> arch[0] -> ... -> arch[-1]
    dims = [input_dim] + list(net_arch)
    for in_dim, out_dim in zip(dims[:-1], dims[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        layers.append(activation_fn())
    if output_dim > 0:
        # dims[-1] is net_arch[-1] when hidden layers exist, else input_dim
        layers.append(nn.Linear(dims[-1], output_dim))
    if squash_output:
        layers.append(nn.Tanh())
    return layers
class PointNetEncoderXYZRGB(nn.Module):
    """PointNet-style encoder for XYZRGB point clouds.

    A shared per-point MLP lifts each point to a 512-d feature, a max pool
    over the point dimension collapses them into one global vector, and a
    final projection maps that vector to ``out_channels``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int = 1024,
                 use_layernorm: bool = False,
                 final_norm: str = 'none',
                 use_projection: bool = True,
                 **kwargs
                 ):
        """
        Args:
            in_channels (int): per-point feature size of the input (3 or 6).
            out_channels (int): size of the global output feature.
            use_layernorm (bool): insert LayerNorm after each hidden Linear.
            final_norm (str): 'layernorm' to LayerNorm the projected output,
                'none' for a bare Linear; anything else raises
                NotImplementedError.
            use_projection (bool): accepted for config compatibility;
                not used by this class.
            **kwargs: ignored extras forwarded from the config.
        """
        super().__init__()
        widths = [64, 128, 256, 512]
        cprint("pointnet use_layernorm: {}".format(use_layernorm), 'cyan')
        cprint("pointnet use_final_norm: {}".format(final_norm), 'cyan')

        # Hidden stack: Linear -> (LayerNorm | Identity) -> ReLU per width,
        # then a final Linear to the last width with no activation.
        layers: List[nn.Module] = []
        prev = in_channels
        for width in widths[:-1]:
            layers.append(nn.Linear(prev, width))
            layers.append(nn.LayerNorm(width) if use_layernorm else nn.Identity())
            layers.append(nn.ReLU())
            prev = width
        layers.append(nn.Linear(prev, widths[-1]))
        self.mlp = nn.Sequential(*layers)

        if final_norm == 'layernorm':
            self.final_projection = nn.Sequential(
                nn.Linear(widths[-1], out_channels),
                nn.LayerNorm(out_channels)
            )
        elif final_norm == 'none':
            self.final_projection = nn.Linear(widths[-1], out_channels)
        else:
            raise NotImplementedError(f"final_norm: {final_norm}")

    def forward(self, x):
        """x: (B, N, C) points -> (B, out_channels) global feature."""
        per_point = self.mlp(x)
        # Channel-wise max over the point dimension (symmetric pooling).
        pooled = torch.max(per_point, 1)[0]
        return self.final_projection(pooled)
class PointNetEncoderXYZ(nn.Module):
    """PointNet-style encoder for XYZ-only point clouds.

    Lifts each point through a shared MLP (64 -> 128 -> 256), max-pools over
    the point dimension, and (optionally) projects the pooled vector to
    ``out_channels``.
    """
    def __init__(self,
                 in_channels: int=3,
                 out_channels: int=1024,
                 use_layernorm: bool=False,
                 final_norm: str='none',
                 use_projection: bool=True,
                 **kwargs
                 ):
        """
        Args:
            in_channels (int): per-point feature size; must be 3 (XYZ only).
            out_channels (int): dimension of the global output feature.
            use_layernorm (bool): insert LayerNorm after each hidden Linear.
            final_norm (str): 'layernorm' to normalize the projected output,
                'none' for a bare Linear projection; anything else raises
                NotImplementedError.
            use_projection (bool): when False, the final projection is
                replaced by Identity (output dim becomes 256).
            **kwargs: ignored extras forwarded from the config.
        """
        super().__init__()
        block_channel = [64, 128, 256]
        cprint("[PointNetEncoderXYZ] use_layernorm: {}".format(use_layernorm), 'cyan')
        cprint("[PointNetEncoderXYZ] use_final_norm: {}".format(final_norm), 'cyan')

        # FIX: the original passed cprint(...) (which returns None) as the
        # assertion message, so failures read "AssertionError: None".
        assert in_channels == 3, f"PointNetEncoderXYZ only supports 3 channels, but got {in_channels}"

        self.mlp = nn.Sequential(
            nn.Linear(in_channels, block_channel[0]),
            nn.LayerNorm(block_channel[0]) if use_layernorm else nn.Identity(),
            nn.ReLU(),
            nn.Linear(block_channel[0], block_channel[1]),
            nn.LayerNorm(block_channel[1]) if use_layernorm else nn.Identity(),
            nn.ReLU(),
            nn.Linear(block_channel[1], block_channel[2]),
            nn.LayerNorm(block_channel[2]) if use_layernorm else nn.Identity(),
            nn.ReLU(),
        )

        if final_norm == 'layernorm':
            self.final_projection = nn.Sequential(
                nn.Linear(block_channel[-1], out_channels),
                nn.LayerNorm(out_channels)
            )
        elif final_norm == 'none':
            self.final_projection = nn.Linear(block_channel[-1], out_channels)
        else:
            raise NotImplementedError(f"final_norm: {final_norm}")

        self.use_projection = use_projection
        if not use_projection:
            self.final_projection = nn.Identity()
            cprint("[PointNetEncoderXYZ] not use projection", "yellow")

        VIS_WITH_GRAD_CAM = False  # debug-only grad-CAM hooks, disabled by default
        if VIS_WITH_GRAD_CAM:
            self.gradient = None
            self.feature = None
            self.input_pointcloud = None
            self.mlp[0].register_forward_hook(self.save_input)
            self.mlp[6].register_forward_hook(self.save_feature)
            # FIX: register_backward_hook is deprecated in modern PyTorch;
            # register_full_backward_hook takes the same
            # (module, grad_input, grad_output) callback signature.
            self.mlp[6].register_full_backward_hook(self.save_gradient)

    def forward(self, x):
        """x: (B, N, 3) points -> (B, out_channels) global feature."""
        x = self.mlp(x)
        # Channel-wise max over the point dimension (symmetric pooling).
        x = torch.max(x, 1)[0]
        x = self.final_projection(x)
        return x

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: stash the output gradient for grad-CAM."""
        self.gradient = grad_output[0]

    def save_feature(self, module, input, output):
        """Forward hook: stash the (detached) feature map for grad-CAM."""
        if isinstance(output, tuple):
            self.feature = output[0].detach()
        else:
            self.feature = output.detach()

    def save_input(self, module, input, output):
        """Forward hook: stash the (detached) input point cloud for grad-CAM."""
        self.input_pointcloud = input[0].detach()
class DP3Encoder(nn.Module):
    """Observation encoder for DP3: fuses a point-cloud feature with a
    low-dimensional robot-state feature.

    The point cloud (key 'point_cloud') is encoded by a PointNet-style
    extractor; the tensors under ``state_keys`` are concatenated and passed
    through a small MLP. The two features are concatenated along the last
    dimension to form the output.
    """
    def __init__(self,
                 observation_space: Dict,
                 img_crop_shape=None,
                 out_channel=256,
                 state_mlp_size=(64, 64), state_mlp_activation_fn=nn.ReLU,
                 pointcloud_encoder_cfg=None,
                 use_pc_color=False,
                 pointnet_type='pointnet',
                 # FIX: was a mutable list default (shared across calls);
                 # a tuple default is safe and backward-compatible.
                 state_keys=('robot0_eef_pos', 'robot0_eef_quat', 'robot0_gripper_qpos')
                 ):
        """
        Args:
            observation_space (Dict): maps observation keys to their shapes.
            img_crop_shape: unused; kept for config compatibility.
            out_channel (int): output size of the point-cloud extractor.
            state_mlp_size (tuple): layer sizes of the state MLP; the last
                entry is its output size. Must be non-empty.
            state_mlp_activation_fn: activation class for the state MLP.
            pointcloud_encoder_cfg: config for the point-cloud extractor;
                its ``in_channels`` attribute is overwritten here.
            use_pc_color (bool): if True, expect XYZRGB points (6 channels).
            pointnet_type (str): only 'pointnet' is supported.
            state_keys: observation keys concatenated into the state vector.
        """
        super().__init__()
        self.imagination_key = 'imagin_robot'
        self.state_keys = list(state_keys)  # own copy, decoupled from caller
        self.point_cloud_key = 'point_cloud'
        self.rgb_image_key = 'image'
        self.n_output_channels = out_channel

        self.use_imagined_robot = self.imagination_key in observation_space.keys()
        self.point_cloud_shape = observation_space[self.point_cloud_key]
        # Flat state dim: sum of the first shape entry of each state key.
        self.state_size = sum([observation_space[key][0] for key in self.state_keys])
        if self.use_imagined_robot:
            self.imagination_shape = observation_space[self.imagination_key]
        else:
            self.imagination_shape = None

        cprint(f"[DP3Encoder] point cloud shape: {self.point_cloud_shape}", "yellow")
        cprint(f"[DP3Encoder] state shape: {self.state_size}", "yellow")
        cprint(f"[DP3Encoder] imagination point shape: {self.imagination_shape}", "yellow")

        self.use_pc_color = use_pc_color
        self.pointnet_type = pointnet_type
        if pointnet_type == "pointnet":
            if use_pc_color:
                pointcloud_encoder_cfg.in_channels = 6
                self.extractor = PointNetEncoderXYZRGB(**pointcloud_encoder_cfg)
            else:
                pointcloud_encoder_cfg.in_channels = 3
                self.extractor = PointNetEncoderXYZ(**pointcloud_encoder_cfg)
        else:
            raise NotImplementedError(f"pointnet_type: {pointnet_type}")

        if len(state_mlp_size) == 0:
            # FIX: dropped the f-prefix from a placeholder-free f-string.
            raise RuntimeError("State mlp size is empty")
        elif len(state_mlp_size) == 1:
            net_arch = []
        else:
            net_arch = list(state_mlp_size[:-1])
        output_dim = state_mlp_size[-1]

        self.n_output_channels += output_dim
        self.state_mlp = nn.Sequential(*create_mlp(self.state_size, output_dim, net_arch, state_mlp_activation_fn))

        cprint(f"[DP3Encoder] output dim: {self.n_output_channels}", "red")

    def forward(self, observations: Dict) -> torch.Tensor:
        """Encode observations into a (B, n_output_channels) feature."""
        points = observations[self.point_cloud_key]
        # FIX: the original passed cprint(...) (returns None) as the assert
        # message; use a real message string instead.
        assert len(points.shape) == 3, f"point cloud shape: {points.shape}, length should be 3"
        if self.use_imagined_robot:
            # Align the last (channel) dim with the real cloud, then append
            # the imagined-robot points along the point dimension.
            img_points = observations[self.imagination_key][..., :points.shape[-1]]
            # torch.cat (not the concat alias) for consistency with the file.
            points = torch.cat([points, img_points], dim=1)

        pn_feat = self.extractor(points)    # B * out_channel
        state = torch.cat([observations[key] for key in self.state_keys], dim=-1)
        state_feat = self.state_mlp(state)  # B * state_mlp_size[-1]
        final_feat = torch.cat([pn_feat, state_feat], dim=-1)
        return final_feat

    def output_shape(self):
        """Return the total output feature dimension (an int)."""
        return self.n_output_channels