| from typing import List, Optional, Tuple
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from pointnet2_ops import pointnet2_utils
|
|
|
|
|
def build_shared_mlp(mlp_spec: List[int], bn: bool = True):
    """Build a shared (point-wise) MLP out of 1x1 2-D convolutions.

    Every consecutive pair of channel counts ``(c_in, c_out)`` in ``mlp_spec``
    becomes ``Conv2d(c_in, c_out, kernel_size=1)``. That conv is optionally
    followed by ``BatchNorm2d(c_out)``, then by an in-place ``ReLU``.

    Parameters
    ----------
    mlp_spec : list of int
        Channel counts, input first. A spec with fewer than two entries
        yields an empty ``nn.Sequential``.
    bn : bool
        Insert batch norm after each conv. When enabled, the conv bias is
        dropped because batch norm has its own shift term.

    Returns
    -------
    nn.Sequential
        The layers in order: conv, [bn], relu, conv, [bn], relu, ...
    """
    modules = []
    for in_ch, out_ch in zip(mlp_spec[:-1], mlp_spec[1:]):
        modules.append(nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=not bn))
        if bn:
            modules.append(nn.BatchNorm2d(out_ch))
        modules.append(nn.ReLU(True))
    return nn.Sequential(*modules)
|
|
|
|
|
class _PointnetSAModuleBase(nn.Module):
    """Shared forward pass for PointNet++ set-abstraction modules.

    Subclasses are expected to set:

    * ``npoint``   -- number of centroids to sample, or None to skip sampling;
    * ``groupers`` -- a sequence of callables ``grouper(xyz, new_xyz, features)``;
    * ``mlps``     -- a sequence of modules; ``mlps[k]`` consumes the output of
      ``groupers[k]``.
    """

    def __init__(self):
        super().__init__()
        self.npoint = None
        self.groupers = None
        self.mlps = None

    def forward(
        self, xyz: torch.Tensor, features: Optional[torch.Tensor]
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor of the xyz coordinates of the features
        features : torch.Tensor or None
            (B, C, N) tensor of the descriptors of the features; passed
            unchanged to every grouper.

        Returns
        -------
        new_xyz : torch.Tensor or None
            (B, npoint, 3) tensor of the sampled centroids' xyz, or None when
            ``self.npoint`` is None (no sampling is performed).
        new_features : torch.Tensor
            (B, \sum_k(mlps[k][-1]), npoint) tensor of the new_features
            descriptors, i.e. the per-grouper results concatenated on dim 1.
        """
        new_xyz = None
        if self.npoint is not None:
            # Only sampling needs the (B, 3, N) layout; skip the transpose and
            # contiguous copy entirely when npoint is None.
            xyz_flipped = xyz.transpose(1, 2).contiguous()
            sample_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            new_xyz = (
                pointnet2_utils.gather_operation(xyz_flipped, sample_idx)
                .transpose(1, 2)
                .contiguous()
            )

        new_features_list = []
        for i, grouper in enumerate(self.groupers):
            # Grouper output is 4-D; presumably (B, C', npoint, nsample) -- TODO confirm.
            grouped = self.mlps[i](grouper(xyz, new_xyz, features))
            # Max-pool over the whole last axis, then drop that now size-1 axis.
            pooled = F.max_pool2d(grouped, kernel_size=[1, grouped.size(3)])
            new_features_list.append(pooled.squeeze(-1))

        return new_xyz, torch.cat(new_features_list, dim=1)
|
|
|
|
|
class PointnetSAModuleMSG(_PointnetSAModuleBase):
    r"""PointNet set abstraction layer with multi-scale grouping.

    Parameters
    ----------
    npoint : int or None
        Number of points to sample. When None, every scale uses
        ``pointnet2_utils.GroupAll`` instead of ``QueryAndGroup``, so the
        radius / nsample of that scale are not passed to the grouper.
    radii : list of float32
        list of radii to group with
    nsamples : list of int32
        Number of samples in each ball query
    mlps : list of list of int32
        Spec of the pointnet before the global max_pool for each scale. Each
        spec is copied before use, so the caller's lists are never modified.
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded to the groupers. When True, 3 is added to the first channel
        count of each MLP spec (presumably because the grouper concatenates
        xyz onto the features -- confirm in pointnet2_utils).
    """

    def __init__(self, npoint, radii, nsamples, mlps, bn=True, use_xyz=True):
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps)

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, spec in zip(radii, nsamples, mlps):
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None
                else pointnet2_utils.GroupAll(use_xyz)
            )
            # Copy before adjusting. Editing the caller's list in place would
            # add 3 again each time the same spec is used to build a module.
            mlp_spec = list(spec)
            if use_xyz:
                mlp_spec[0] += 3

            self.mlps.append(build_shared_mlp(mlp_spec, bn))
|
|
|
|
|
class PointnetSAModule(PointnetSAModuleMSG):
    r"""PointNet set abstraction layer with a single grouping scale.

    Thin wrapper around ``PointnetSAModuleMSG`` with exactly one
    radius / nsample / MLP spec.

    Parameters
    ----------
    mlp : list
        Spec of the pointnet before the global max_pool. A copy is handed to
        the parent, so the caller's list is never modified.
    npoint : int or None
        Number of features (points) to sample; None means no sampling.
    radius : float
        Radius of ball
    nsample : int
        Number of samples in the ball query
    bn : bool
        Use batchnorm
    use_xyz : bool
        Forwarded unchanged to ``PointnetSAModuleMSG``.
    """

    def __init__(
        self, mlp, npoint=None, radius=None, nsample=None, bn=True, use_xyz=True
    ):
        super().__init__(
            # Copy so that any in-place channel adjustment made by the parent
            # (e.g. for use_xyz) cannot leak back into the caller's list.
            mlps=[list(mlp)],
            npoint=npoint,
            radii=[radius],
            nsamples=[nsample],
            bn=bn,
            use_xyz=use_xyz,
        )
|
|
|
|
|
class PointnetFPModule(nn.Module):
    r"""Propagates the features of one point set onto another.

    Parameters
    ----------
    mlp : list
        Channel spec of the shared MLP (built with ``build_shared_mlp``)
        applied after interpolation; ``mlp[0]`` must match the channel count
        of the concatenated features (C2, plus C1 when ``unknow_feats`` is given).
    bn : bool
        Use batchnorm
    """

    def __init__(self, mlp, bn=True):
        super().__init__()
        self.mlp = build_shared_mlp(mlp, bn=bn)

    def forward(self, unknown, known, unknow_feats, known_feats):
        r"""
        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of the xyz positions of the unknown features
        known : torch.Tensor or None
            (B, m, 3) tensor of the xyz positions of the known features. If
            None, ``known_feats`` is broadcast to all n unknown points. It must
            then have a size-1 last dim, i.e. (B, C2, 1), because ``expand``
            can only enlarge singleton dimensions.
        unknow_feats : torch.Tensor or None
            (B, C1, n) tensor of the features to be propagated to; if given,
            it is concatenated after the interpolated features.
        known_feats : torch.Tensor
            (B, C2, m) tensor of features to be propagated

        Returns
        -------
        new_features : torch.Tensor
            (B, mlp[-1], n) tensor of the features of the unknown features
        """
        if known is not None:
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            # Inverse-distance weights, normalized over the neighbor axis.
            # The epsilon guards against division by zero for coincident points.
            dist_recip = 1.0 / (dist + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm

            interpolated_feats = pointnet2_utils.three_interpolate(
                known_feats, idx, weight
            )
        else:
            # torch.Size is a tuple subclass and cannot be concatenated with a
            # list, so pass the target sizes explicitly.
            interpolated_feats = known_feats.expand(
                known_feats.size(0), known_feats.size(1), unknown.size(1)
            )

        if unknow_feats is not None:
            new_features = torch.cat([interpolated_feats, unknow_feats], dim=1)
        else:
            new_features = interpolated_feats

        # The shared MLP is built from Conv2d layers, so add a trailing
        # singleton axis for it and remove it afterwards.
        new_features = self.mlp(new_features.unsqueeze(-1))

        return new_features.squeeze(-1)
|
|
|