from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from . import pointnet2_utils
from . import pytorch_utils as pt_utils
|
|
|
|
class _PointnetSAModuleBase(nn.Module):
    """Shared forward pass for PointNet++ set-abstraction (SA) layers.

    Subclasses populate these attributes in ``__init__``:

    * ``self.npoint`` -- number of centroids to pick with furthest point
      sampling, or ``None`` to skip sampling;
    * ``self.groupers`` -- one grouping module per scale;
    * ``self.mlps`` -- one shared MLP per scale, applied to that scale's
      grouper output;
    * ``self.pool_method`` -- ``'max_pool'`` (default) or ``'avg_pool'``.
    """

    # Supported values of ``self.pool_method`` mapped to the pooling op used.
    _POOL_FUNCTIONS = {'max_pool': F.max_pool2d, 'avg_pool': F.avg_pool2d}

    def __init__(self):
        super().__init__()
        self.npoint = None
        self.groupers = None
        self.mlps = None
        self.pool_method = 'max_pool'

    def forward(
        self,
        xyz: torch.Tensor,
        features: Optional[torch.Tensor] = None,
        new_xyz: Optional[torch.Tensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        r"""Choose centroids, group points at every scale, run the MLPs and pool.

        :param xyz: (B, N, 3) tensor of the xyz coordinates of the input points
        :param features: optional per-point descriptors, passed unchanged to
            every grouper (NOTE(review): earlier docs said (B, N, C); groupers
            of this kind usually expect (B, C, N) -- TODO confirm in
            ``pointnet2_utils``)
        :param new_xyz: optional (B, npoint, 3) centroids. When ``None`` and
            ``self.npoint`` is set, they are chosen by furthest point sampling.
        :return:
            new_xyz: (B, npoint, 3) centroid coordinates, or ``None`` when no
                centroids were given and ``self.npoint`` is ``None``
            new_features: the per-scale pooled features concatenated along
                dim 1. For 4-D (B, C, npoint, nsample) MLP outputs (assumed --
                TODO confirm) this is (B, \sum_k mlps[k][-1], npoint).
        :raises NotImplementedError: if ``self.pool_method`` is not
            ``'max_pool'`` or ``'avg_pool'``
        """
        # Validate once up front instead of on every loop iteration.
        pool = self._POOL_FUNCTIONS.get(self.pool_method)
        if pool is None:
            raise NotImplementedError(
                "unsupported pool_method %r; expected 'max_pool' or 'avg_pool'" % (self.pool_method,)
            )

        if new_xyz is None and self.npoint is not None:
            # gather_operation is fed (B, 3, N) coordinates; the result is
            # transposed back to (B, npoint, 3).
            xyz_flipped = xyz.transpose(1, 2).contiguous()
            fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.npoint)
            new_xyz = pointnet2_utils.gather_operation(xyz_flipped, fps_idx).transpose(1, 2).contiguous()

        new_features_list = []
        for grouper, mlp in zip(self.groupers, self.mlps):
            grouped = mlp(grouper(xyz, new_xyz, features))
            # Pool over the entire last axis, leaving a singleton dim that is
            # squeezed away.
            pooled = pool(grouped, kernel_size=[1, grouped.size(3)])
            new_features_list.append(pooled.squeeze(-1))

        return new_xyz, torch.cat(new_features_list, dim=1)
|
|
|
|
class PointnetSAModuleMSG(_PointnetSAModuleBase):
    """Pointnet set abstraction layer with multiscale grouping"""

    def __init__(self, *, npoint: Optional[int], radii: List[float], nsamples: List[int], mlps: List[List[int]],
                 bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
        """
        :param npoint: int, number of centroids to sample. ``None`` selects
            ``pointnet2_utils.GroupAll`` instead of ball-query grouping.
        :param radii: list of float, list of radii to group with (one per scale)
        :param nsamples: list of int, number of samples in each ball query
        :param mlps: list of list of int, spec of the pointnet before the global
            pooling for each scale. The lists are copied and are not modified.
        :param bn: whether to use batchnorm
        :param use_xyz: passed to the groupers; when True each MLP's input width
            is increased by 3
        :param pool_method: max_pool / avg_pool
        :param instance_norm: whether to use instance_norm
        """
        super().__init__()

        assert len(radii) == len(nsamples) == len(mlps), (
            'radii, nsamples and mlps must have the same length (one entry per scale)'
        )

        self.npoint = npoint
        self.groupers = nn.ModuleList()
        self.mlps = nn.ModuleList()
        for radius, nsample, spec in zip(radii, nsamples, mlps):
            self.groupers.append(
                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
            )
            # Copy so the caller's spec is never mutated. Otherwise reusing the
            # same config for another module would add 3 more input channels.
            mlp_spec = list(spec)
            if use_xyz:
                # With use_xyz the groupers presumably concatenate xyz to the
                # features, hence 3 extra input channels -- TODO confirm.
                mlp_spec[0] += 3

            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
        self.pool_method = pool_method
|
|
|
|
class PointnetSAModule(PointnetSAModuleMSG):
    """Pointnet set abstraction layer (single scale).

    Convenience wrapper around the multiscale layer: each scalar argument is
    wrapped in a one-element list, giving exactly one grouping scale.
    """

    def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
                 bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
        """
        :param mlp: list of int, layer widths of the shared MLP run before pooling
        :param npoint: int, number of centroids to sample (None: group all points)
        :param radius: float, radius of the ball query
        :param nsample: int, number of samples in the ball query
        :param bn: whether to use batchnorm
        :param use_xyz: forwarded to the grouper
        :param pool_method: max_pool / avg_pool
        :param instance_norm: whether to use instance_norm
        """
        one_scale = dict(mlps=[mlp], radii=[radius], nsamples=[nsample])
        super().__init__(
            npoint=npoint,
            bn=bn,
            use_xyz=use_xyz,
            pool_method=pool_method,
            instance_norm=instance_norm,
            **one_scale,
        )
|
|
|
|
class PointnetFPModule(nn.Module):
    r"""Propagates (interpolates) the features of one point set onto another."""

    def __init__(self, *, mlp: List[int], bn: bool = True):
        """
        :param mlp: list of int, layer widths of the shared MLP applied after interpolation
        :param bn: whether to use batchnorm
        """
        super().__init__()
        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)

    def forward(
        self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
    ) -> torch.Tensor:
        """
        :param unknown: (B, n, 3) tensor of the xyz positions of the points receiving features
        :param known: (B, m, 3) tensor of the xyz positions of the points carrying features,
            or None to broadcast ``known_feats`` to every target point
        :param unknow_feats: (B, C1, n) features already present on the target points, or None
        :param known_feats: (B, C2, m) features to propagate
        :return:
            new_features: (B, mlp[-1], n) tensor of features for the target points
        """
        if known is None:
            # Plain broadcast. expand() only works if known_feats' last dim is
            # 1 (or already n).
            interpolated_feats = known_feats.expand(*known_feats.shape[:2], unknown.size(1))
        else:
            # Inverse-distance weights over the neighbours returned by three_nn
            # (presumably the 3 nearest known points), normalised over dim 2.
            # The 1e-8 guards against division by zero for coincident points.
            dist, idx = pointnet2_utils.three_nn(unknown, known)
            inv_dist = 1.0 / (dist + 1e-8)
            weight = inv_dist / torch.sum(inv_dist, dim=2, keepdim=True)
            interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)

        if unknow_feats is None:
            stacked = interpolated_feats
        else:
            stacked = torch.cat([interpolated_feats, unknow_feats], dim=1)

        # The shared MLP consumes a 4-D tensor: add a trailing singleton, then drop it.
        out = self.mlp(stacked.unsqueeze(-1))
        return out.squeeze(-1)
|
|
|
|
if __name__ == '__main__':
    # Running this module directly currently does nothing; it only defines layers.
    ...
|
|