| import torch
|
| import torch.nn as nn
|
| import warnings
|
| from torch.autograd import Function
|
| from typing import *
|
|
|
# Use the prebuilt extension when it is installed; otherwise JIT-compile the
# bundled C++/CUDA sources with torch's cpp_extension loader.
try:
    import pointnet2_ops._ext as _ext
except ImportError:
    import glob
    import os
    import os.path as osp

    from torch.utils.cpp_extension import load

    warnings.warn("Unable to load pointnet2_ops cpp extension. JIT Compiling.")

    # The sources are looked up in the "_ext-src" directory next to this module.
    _ext_src_root = osp.join(osp.dirname(__file__), "_ext-src")

    # All *.cpp files first, then all *.cu files.
    _ext_sources = [
        source_path
        for pattern in ("*.cpp", "*.cu")
        for source_path in glob.glob(osp.join(_ext_src_root, "src", pattern))
    ]

    # Collected but not passed to ``load`` below.
    _ext_headers = glob.glob(osp.join(_ext_src_root, "include", "*"))

    # NOTE: this replaces any TORCH_CUDA_ARCH_LIST value already present in
    # the process environment.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "3.7+PTX;5.0;6.0;6.1;6.2;7.0;7.5"

    _ext = load(
        "_ext",
        sources=_ext_sources,
        extra_cflags=["-O3"],
        extra_cuda_cflags=["-O3", "-Xfatbin", "-compress-all"],
        extra_include_paths=[osp.join(_ext_src_root, "include")],
        with_cuda=True,
    )
|
|
|
|
|
class FurthestPointSampling(Function):
    r"""
    Autograd wrapper around the extension's ``furthest_point_sampling`` op.

    The output is marked non-differentiable, so no gradient flows through
    this operation.
    """

    @staticmethod
    def forward(ctx, xyz, npoint):
        r"""
        Uses iterative furthest point sampling to select a set of npoint
        features that have the largest minimum distance

        Parameters
        ----------
        xyz : torch.Tensor
            (B, N, 3) tensor where N > npoint
        npoint : int32
            number of features in the sampled set

        Returns
        -------
        torch.Tensor
            (B, npoint) tensor containing the set
        """
        out = _ext.furthest_point_sampling(xyz, npoint)

        ctx.mark_non_differentiable(out)

        return out

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Returns
        -------
        None
            for ``xyz``
        None
            for ``npoint``
        """
        # autograd expects one entry per forward input; neither input
        # receives a gradient.
        return None, None


furthest_point_sample = FurthestPointSampling.apply
|
|
|
|
|
class GatherOperation(Function):
    r"""
    Autograd wrapper around the extension's ``gather_points`` op and its
    gradient ``gather_points_grad``.
    """

    @staticmethod
    def forward(ctx, features, idx):
        r"""
        Gathers the features selected by idx

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor

        idx : torch.Tensor
            (B, npoint) tensor of the features to gather

        Returns
        -------
        torch.Tensor
            (B, C, npoint) tensor
        """
        # backward only needs the size of the last feature dimension, so
        # remember that number instead of keeping the whole feature tensor
        # alive until the backward pass.
        ctx.save_for_backward(idx)
        ctx.num_points = features.size(2)

        return _ext.gather_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            gradient of the output of forward

        Returns
        -------
        grad_features : torch.Tensor
            gradient with respect to ``features``
        None
            ``idx`` receives no gradient
        """
        (idx,) = ctx.saved_tensors

        grad_features = _ext.gather_points_grad(
            grad_out.contiguous(), idx, ctx.num_points
        )
        return grad_features, None


gather_operation = GatherOperation.apply
|
|
|
|
|
class ThreeNN(Function):
    r"""
    Autograd wrapper around the extension's ``three_nn`` op.

    Both outputs are marked non-differentiable, so no gradient flows through
    this operation.
    """

    @staticmethod
    def forward(ctx, unknown, known):
        r"""
        Find the three nearest neighbors of unknown in known

        Parameters
        ----------
        unknown : torch.Tensor
            (B, n, 3) tensor of unknown features
        known : torch.Tensor
            (B, m, 3) tensor of known features

        Returns
        -------
        dist : torch.Tensor
            (B, n, 3) l2 distance to the three nearest neighbors
        idx : torch.Tensor
            (B, n, 3) index of 3 nearest neighbors
        """
        # the extension's first output is square-rooted to obtain ``dist``
        dist2, idx = _ext.three_nn(unknown, known)
        dist = torch.sqrt(dist2)

        ctx.mark_non_differentiable(dist, idx)

        return dist, idx

    @staticmethod
    def backward(ctx, grad_dist, grad_idx):
        r"""
        Returns
        -------
        None
            for ``unknown``
        None
            for ``known``
        """
        # autograd expects one entry per forward input; neither input
        # receives a gradient.
        return None, None


three_nn = ThreeNN.apply
|
|
|
|
|
class ThreeInterpolate(Function):
    r"""
    Autograd wrapper around the extension's ``three_interpolate`` op and its
    gradient ``three_interpolate_grad``.
    """

    @staticmethod
    def forward(ctx, features, idx, weight):
        r"""
        Performs weight linear interpolation on 3 features

        Parameters
        ----------
        features : torch.Tensor
            (B, c, m) Features descriptors to be interpolated from
        idx : torch.Tensor
            (B, n, 3) three nearest neighbors of the target features in features
        weight : torch.Tensor
            (B, n, 3) weights

        Returns
        -------
        torch.Tensor
            (B, c, n) tensor of the interpolated features
        """
        # backward only needs the size of the last feature dimension, so
        # remember that number instead of keeping the whole feature tensor
        # alive until the backward pass.
        ctx.save_for_backward(idx, weight)
        ctx.num_known = features.size(2)

        return _ext.three_interpolate(features, idx, weight)

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs

        Returns
        -------
        grad_features : torch.Tensor
            (B, c, m) tensor with gradients of features

        None
            ``idx`` is an index tensor and receives no gradient

        torch.Tensor or None
            all-zero tensor shaped like ``weight`` when autograd requests a
            gradient for it, otherwise None
        """
        idx, weight = ctx.saved_tensors

        grad_features = _ext.three_interpolate_grad(
            grad_out.contiguous(), idx, weight, ctx.num_known
        )

        # Only the feature gradient is computed by the extension. The zero
        # placeholder for ``weight`` is kept for callers whose weights require
        # grad, but is no longer allocated when autograd would discard it.
        grad_weight = torch.zeros_like(weight) if ctx.needs_input_grad[2] else None

        return grad_features, None, grad_weight


three_interpolate = ThreeInterpolate.apply
|
|
|
|
|
class GroupingOperation(Function):
    r"""
    Autograd wrapper around the extension's ``group_points`` op and its
    gradient ``group_points_grad``.
    """

    @staticmethod
    def forward(ctx, features, idx):
        r"""
        Groups the features selected by idx

        Parameters
        ----------
        features : torch.Tensor
            (B, C, N) tensor of features to group
        idx : torch.Tensor
            (B, npoint, nsample) tensor containing the indices of features to group with

        Returns
        -------
        torch.Tensor
            (B, C, npoint, nsample) tensor
        """
        # backward only needs the size of the last feature dimension, so
        # remember that number instead of keeping the whole feature tensor
        # alive until the backward pass.
        ctx.save_for_backward(idx)
        ctx.num_points = features.size(2)

        return _ext.group_points(features, idx)

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, C, npoint, nsample) tensor of the gradients of the output from forward

        Returns
        -------
        torch.Tensor
            (B, C, N) gradient of the features
        None
            ``idx`` is an index tensor and receives no gradient
        """
        (idx,) = ctx.saved_tensors

        grad_features = _ext.group_points_grad(
            grad_out.contiguous(), idx, ctx.num_points
        )

        return grad_features, None


grouping_operation = GroupingOperation.apply
|
|
|
|
|
class BallQuery(Function):
    r"""
    Autograd wrapper around the extension's ``ball_query`` op.

    The output is marked non-differentiable, so no gradient flows through
    this operation.
    """

    @staticmethod
    def forward(ctx, radius, nsample, xyz, new_xyz):
        r"""
        Finds, for every ball center, the features that fall inside the ball

        Parameters
        ----------
        radius : float
            radius of the balls
        nsample : int
            maximum number of features in the balls
        xyz : torch.Tensor
            (B, N, 3) xyz coordinates of the features
        new_xyz : torch.Tensor
            (B, npoint, 3) centers of the ball query

        Returns
        -------
        torch.Tensor
            (B, npoint, nsample) tensor with the indices of the features that form the query balls
        """
        # note the argument order of the extension call: centers first
        output = _ext.ball_query(new_xyz, xyz, radius, nsample)

        ctx.mark_non_differentiable(output)

        return output

    @staticmethod
    def backward(ctx, grad_out):
        r"""
        Returns
        -------
        None, None, None, None
            one entry each for ``radius``, ``nsample``, ``xyz`` and ``new_xyz``
        """
        # autograd expects one entry per forward input; none of the inputs
        # receives a gradient.
        return None, None, None, None


ball_query = BallQuery.apply
|
|
|
|
|
class QueryAndGroup(nn.Module):
    r"""
    Ball-query grouping: collects, around every center, the features that a
    ball query of the configured radius returns

    Parameters
    ----------
    radius : float32
        Radius of ball
    nsample : int32
        Maximum number of features to gather in the ball
    use_xyz : bool
        Whether the center-relative coordinates are part of the output
    """

    def __init__(self, radius, nsample, use_xyz=True):
        super().__init__()
        self.radius = radius
        self.nsample = nsample
        self.use_xyz = use_xyz

    def forward(self, xyz, new_xyz, features=None):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            centroids (B, npoint, 3)
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, 3 + C, npoint, nsample) tensor
        """
        neighbor_idx = ball_query(self.radius, self.nsample, xyz, new_xyz)

        # Group the coordinates, then shift them in place so that they are
        # expressed relative to their ball center.
        local_xyz = grouping_operation(xyz.transpose(1, 2).contiguous(), neighbor_idx)
        local_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is None:
            # Without descriptors the coordinates are the only possible output.
            assert (
                self.use_xyz
            ), "Cannot have not features and not use xyz as a feature!"
            return local_xyz

        local_features = grouping_operation(features, neighbor_idx)
        if not self.use_xyz:
            return local_features

        # Coordinates come first along the channel dimension.
        return torch.cat([local_xyz, local_features], dim=1)
|
|
|
|
|
class GroupAll(nn.Module):
    r"""
    Groups all features into one single group

    Parameters
    ----------
    use_xyz : bool
        Whether the coordinates are concatenated in front of the descriptors
    """

    def __init__(self, use_xyz=True):
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz, new_xyz, features=None):
        r"""
        Parameters
        ----------
        xyz : torch.Tensor
            xyz coordinates of the features (B, N, 3)
        new_xyz : torch.Tensor
            Ignored
        features : torch.Tensor
            Descriptors of the features (B, C, N)

        Returns
        -------
        new_features : torch.Tensor
            (B, C + 3, 1, N) tensor
        """
        # (B, N, 3) -> (B, 3, 1, N)
        all_xyz = xyz.transpose(1, 2).unsqueeze(2)

        if features is None:
            # The coordinates are returned here regardless of ``use_xyz``.
            return all_xyz

        # (B, C, N) -> (B, C, 1, N)
        all_features = features.unsqueeze(2)
        if not self.use_xyz:
            return all_features

        # Coordinates come first along the channel dimension.
        return torch.cat([all_xyz, all_features], dim=1)
|
|
|