| """ Adapted from https://github.com/dyson-ai/hdp/blob/main/rk_diffuser/models/pointnet.py """ |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.nn.parallel |
| import torch.utils.data |
| from torch.autograd import Variable |
| from diffusion_policy.common.pytorch_util import replace_submodules |
|
|
|
|
class STN3d(nn.Module):
    """Spatial transformer network that regresses a 3x3 matrix per point cloud.

    The network predicts an offset that is added to the identity matrix, so an
    output layer with all-zero weights and bias yields exactly the identity.
    """

    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        # Not used by forward (which calls F.relu); kept so the attribute still exists.
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        # The fully connected stage is normalised with LayerNorm rather than
        # BatchNorm1d (presumably so it is independent of batch statistics).
        self.bn4 = nn.LayerNorm(512)
        self.bn5 = nn.LayerNorm(256)

    def forward(self, x):
        """Predict one 3x3 transform per cloud.

        Args:
            x: ``(B, 3, N)`` tensor of points (channels first, as required by Conv1d).

        Returns:
            ``(B, 3, 3)`` tensor: learned offset plus the identity matrix.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Symmetric max-pool over the point axis gives an order-invariant feature.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Build the identity on x's own device and dtype (works for CPU, CUDA,
        # MPS, half precision, ...) and broadcast it over the batch dimension.
        iden = torch.eye(3, device=x.device, dtype=x.dtype).view(1, 9)
        return (x + iden).view(-1, 3, 3)
|
|
|
|
class STNkd(nn.Module):
    """Spatial transformer network that regresses a ``k x k`` matrix per input.

    The network predicts an offset that is added to the identity matrix, so an
    output layer with all-zero weights and bias yields exactly the identity.

    Args:
        k: channel count of the input and side length of the predicted matrix.
    """

    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        # Not used by forward (which calls F.relu); kept so the attribute still exists.
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        # The fully connected stage is normalised with LayerNorm rather than
        # BatchNorm1d (presumably so it is independent of batch statistics).
        self.bn4 = nn.LayerNorm(512)
        self.bn5 = nn.LayerNorm(256)

        self.k = k

    def forward(self, x):
        """Predict one ``k x k`` transform per input.

        Args:
            x: ``(B, k, N)`` tensor (channels first, as required by Conv1d).

        Returns:
            ``(B, k, k)`` tensor: learned offset plus the identity matrix.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Symmetric max-pool over the point axis gives an order-invariant feature.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        # Build the identity on x's own device and dtype (works for CPU, CUDA,
        # MPS, half precision, ...) and broadcast it over the batch dimension.
        iden = torch.eye(self.k, device=x.device, dtype=x.dtype).view(1, self.k * self.k)
        return (x + iden).view(-1, self.k, self.k)
|
|
|
|
class PointNetfeat(nn.Module):
    """PointNet encoder that maps a point cloud to a 1024-d global feature.

    Args:
        input_channels: number of channels per point (e.g. 3 for xyz).
        input_transform: if True, apply a learned ``input_channels x input_channels``
            matrix (``STNkd``) to the points before the first convolution.
        feature_transform: if True, apply a learned 64x64 matrix (``STNkd``) to
            the features produced by the first convolution.
    """

    def __init__(self, input_channels: int, input_transform: bool, feature_transform=False):
        super(PointNetfeat, self).__init__()
        self.input_transform = input_transform
        if self.input_transform:
            self.stn = STNkd(k=input_channels)
        self.conv1 = torch.nn.Conv1d(input_channels, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        """Encode a batch of point clouds.

        Args:
            x: ``(B, input_channels, N)`` tensor, or a 4-D tensor laid out
                channels-last (last axis = ``input_channels``, assumed from the
                reshape below), whose middle axes are flattened into the point axis.

        Returns:
            ``(B, 1024)`` max-pooled global feature.
        """
        b = x.size(0)
        if x.dim() == 4:
            # Use the configured channel count rather than a hard-coded 3 so
            # 4-D input also works for e.g. 6-channel clouds; reshape (not view)
            # also accepts non-contiguous input.
            channels = self.conv1.in_channels
            x = x.reshape(b, -1, channels).permute(0, 2, 1).contiguous()

        if self.input_transform:
            # Right-multiply each (N, C) point matrix by the predicted (C, C) transform.
            trans = self.stn(x)
            x = torch.bmm(x.transpose(2, 1), trans).transpose(2, 1)

        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            # Same as above, in the 64-d feature space.
            trans_feat = self.fstn(x)
            x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)

        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        # Symmetric max-pool over the point axis gives an order-invariant feature.
        x = torch.max(x, 2, keepdim=True)[0]
        return x.view(-1, 1024)
|
|
|
|
class PointNetCls(nn.Module):
    """Global point-cloud classifier: a ``PointNetfeat`` encoder followed by an MLP head.

    Args:
        k: number of output classes.
        feature_transform: forwarded to ``PointNetfeat`` (learned 64x64 feature transform).
    """

    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        # PointNetfeat takes (input_channels, input_transform, feature_transform)
        # and has no ``global_feat`` keyword; the previous call raised TypeError.
        # 3-channel input with a learned input transform is assumed here
        # (TODO confirm against the intended use).
        self.feat = PointNetfeat(
            input_channels=3, input_transform=True, feature_transform=feature_transform
        )
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        # Not used by forward (which calls F.relu); kept so the attribute still exists.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Classify a batch of point clouds.

        Args:
            x: point clouds in any layout accepted by ``PointNetfeat.forward``.

        Returns:
            ``(log_probs, trans, trans_feat)``. ``log_probs`` is the ``(B, k)``
            log-softmax output. ``trans`` and ``trans_feat`` are always ``None``
            because ``PointNetfeat.forward`` returns only the pooled feature and
            does not expose its transform matrices. Both slots are kept so the
            return value remains a 3-tuple.
        """
        x = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), None, None
|
|
|
|
class PointNetDenseCls(nn.Module):
    """Per-point (dense) classification head producing ``k`` scores per point.

    NOTE(review): this class is out of sync with ``PointNetfeat`` in this file.
    ``PointNetfeat.__init__`` takes ``(input_channels, input_transform,
    feature_transform)`` and has no ``global_feat`` parameter, so the
    constructor below raises ``TypeError``. In addition, ``PointNetfeat.forward``
    returns a single pooled ``(B, 1024)`` tensor, whereas ``forward`` here
    unpacks three values and ``conv1`` expects 1088-channel per-point features.
    Making this class usable requires a per-point output mode in ``PointNetfeat``.

    Args:
        k: number of classes per point.
        feature_transform: forwarded to ``PointNetfeat``.
    """

    def __init__(self, k=2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
        self.k = k
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
        # Pointwise (kernel size 1) MLP: 1088 -> 512 -> 256 -> 128 -> k channels.
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        """Score every point.

        Args:
            x: assumed ``(B, C, N)`` (the point count is read from axis 2) — TODO confirm.

        Returns:
            ``(log_probs, trans, trans_feat)`` where ``log_probs`` has shape
            ``(B, N, k)`` and is log-softmax-normalised over the class axis.
        """
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        # (B, k, N) -> (B, N, k) so log_softmax can run over the class axis of a flat view.
        x = x.transpose(2, 1).contiguous()
        x = F.log_softmax(x.view(-1, self.k), dim=-1)
        x = x.view(batchsize, n_pts, self.k)
        return x, trans, trans_feat
|
|
|
|
class PointNetBackbone(nn.Module):
    """PointNet point-cloud encoder that optionally appends robot-state features.

    Args:
        embed_dim: size of the embedding produced for each point cloud.
        input_channels: channels per point; must be 3 or 6.
        input_transform: forwarded to ``PointNetfeat`` (learned input transform).
        use_group_norm: if True, every ``BatchNorm1d`` inside the encoder is
            replaced by ``GroupNorm(num_features // 16, num_features)``.
    """

    def __init__(
        self,
        embed_dim: int,
        input_channels: int,
        input_transform: bool,
        use_group_norm: bool = False,
    ):
        super().__init__()
        assert input_channels in [3, 6], "Input channels must be 3 or 6"
        self.backbone = nn.Sequential(
            PointNetfeat(input_channels, input_transform),
            nn.Mish(),
            nn.Linear(1024, 512),
            nn.Mish(),
            nn.Linear(512, embed_dim),
        )
        if use_group_norm:
            self.backbone = replace_submodules(
                root_module=self.backbone,
                predicate=lambda x: isinstance(x, nn.BatchNorm1d),
                func=lambda x: nn.GroupNorm(
                    num_groups=x.num_features // 16, num_channels=x.num_features
                ),
            )

    def forward(self, pcd: torch.Tensor, robot_state_obs: torch.Tensor = None) -> torch.Tensor:
        """Encode point clouds and flatten the result per batch element.

        Args:
            pcd: assumed ``(B, T, N, C)`` (batch, presumably time/frames, points,
                channels). The first two axes are folded together before encoding.
            robot_state_obs: optional tensor whose first two axes match ``pcd``'s
                (assumed ``(B, T, D)``). If ``None``, only the point-cloud
                embedding is returned.

        Returns:
            ``(B, T * (embed_dim + D))`` when ``robot_state_obs`` is given,
            otherwise ``(B, T * embed_dim)``.
        """
        B = pcd.shape[0]

        # Fold the leading two axes, then move channels first for Conv1d: (B*T, C, N).
        pcd = pcd.float().reshape(-1, *pcd.shape[2:])
        pcd = pcd.permute(0, 2, 1)

        nx = self.backbone(pcd)
        # The default of None was previously dereferenced unconditionally.
        if robot_state_obs is not None:
            robot_state_obs = robot_state_obs.float().reshape(-1, *robot_state_obs.shape[2:])
            nx = torch.cat([nx, robot_state_obs], dim=1)

        return nx.reshape(B, -1)
|
|