| import torch |
| from torch.autograd import Variable |
| from torch.autograd import Function |
| import torch.nn as nn |
| from typing import Tuple |
|
|
| import pointnet2_cuda as pointnet2 |
|
|
|
|
class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance
        :param ctx:
        :param xyz: (B, N, 3) where N > npoint
        :param npoint: int, number of features in the sampled set
        :return:
            output: (B, npoint) int32 tensor containing the indices of the sampled set
        """
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        # Allocate on xyz's device. The legacy torch.cuda.*Tensor constructors allocate on the
        # *current* CUDA device, which differs from xyz.device when the model lives on another GPU.
        output = torch.empty(B, npoint, dtype=torch.int32, device=xyz.device)
        # Scratch buffer handed to the kernel, pre-filled with a large sentinel value.
        temp = torch.full((B, N), 1e10, dtype=torch.float32, device=xyz.device)

        # Make xyz's device current while the extension runs so the launch targets the same GPU
        # that owns the buffers.
        with torch.cuda.device_of(xyz):
            pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
        # Integer indices carry no gradient.
        ctx.mark_non_differentiable(output)
        return output

    @staticmethod
    def backward(ctx, a=None):
        # Sampling is not differentiable w.r.t. xyz or npoint.
        return None, None


furthest_point_sample = FurthestPointSampling.apply
|
|
|
|
class GatherOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        Gather feature columns at the given point indices.
        :param ctx:
        :param features: (B, C, N)
        :param idx: (B, npoint) index tensor of the features to gather
        :return:
            output: (B, C, npoint)
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()

        B, npoint = idx.size()
        _, C, N = features.size()
        # Allocate on the input's device rather than the current CUDA device.
        output = torch.empty(B, C, npoint, dtype=torch.float32, device=features.device)

        with torch.cuda.device_of(features):
            pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)

        ctx.for_backwards = (idx, C, N)
        return output

    @staticmethod
    def backward(ctx, grad_out):
        """
        :param ctx:
        :param grad_out: (B, C, npoint) gradient of the forward output
        :return:
            grad_features: (B, C, N) gradient w.r.t. features
            None: idx is not differentiable
        """
        idx, C, N = ctx.for_backwards
        B, npoint = idx.size()

        # Zero-filled: the grad kernel presumably accumulates into this buffer
        # (several outputs may read the same point) - confirm in the CUDA source.
        grad_features = torch.zeros(B, C, N, dtype=torch.float32, device=grad_out.device)
        grad_out_data = grad_out.contiguous()
        with torch.cuda.device_of(grad_out):
            pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features)
        return grad_features, None


gather_operation = GatherOperation.apply
|
|
class KNN(Function):

    @staticmethod
    def forward(ctx, k: int, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the k nearest neighbors of unknown in known
        :param ctx:
        :param k: int, number of neighbors returned per query point
        :param unknown: (B, N, 3)
        :param known: (B, M, 3)
        :return:
            dist: (B, N, k) l2 distance to the k nearest neighbors
            idx: (B, N, k) int32 index of the k nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        B, N, _ = unknown.size()
        m = known.size(1)
        # Allocate on the input's device rather than the current CUDA device.
        dist2 = torch.empty(B, N, k, dtype=torch.float32, device=unknown.device)
        idx = torch.empty(B, N, k, dtype=torch.int32, device=unknown.device)

        with torch.cuda.device_of(unknown):
            pointnet2.knn_wrapper(B, N, m, k, unknown, known, dist2, idx)
        ctx.mark_non_differentiable(idx)
        # dist2 presumably holds squared distances; sqrt yields the documented l2 distance.
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        # No gradient flows to k, unknown or known.
        return None, None, None


knn = KNN.apply
|
|
class ThreeNN(Function):

    @staticmethod
    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Find the three nearest neighbors of unknown in known
        :param ctx:
        :param unknown: (B, N, 3)
        :param known: (B, M, 3)
        :return:
            dist: (B, N, 3) l2 distance to the three nearest neighbors
            idx: (B, N, 3) int32 index of 3 nearest neighbors
        """
        assert unknown.is_contiguous()
        assert known.is_contiguous()

        B, N, _ = unknown.size()
        m = known.size(1)
        # Allocate on the input's device rather than the current CUDA device.
        dist2 = torch.empty(B, N, 3, dtype=torch.float32, device=unknown.device)
        idx = torch.empty(B, N, 3, dtype=torch.int32, device=unknown.device)

        with torch.cuda.device_of(unknown):
            pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
        ctx.mark_non_differentiable(idx)
        # dist2 presumably holds squared distances; sqrt yields the documented l2 distance.
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        # No gradient flows to unknown or known.
        return None, None


three_nn = ThreeNN.apply
|
|
|
|
class ThreeInterpolate(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        Performs weighted linear interpolation on 3 features
        :param ctx:
        :param features: (B, C, M) Features descriptors to be interpolated from
        :param idx: (B, n, 3) three nearest neighbors of the target features in features
        :param weight: (B, n, 3) weights
        :return:
            output: (B, C, n) tensor of the interpolated features
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        assert weight.is_contiguous()

        B, c, m = features.size()
        n = idx.size(1)
        ctx.three_interpolate_for_backward = (idx, weight, m)
        # Allocate on the input's device rather than the current CUDA device.
        output = torch.empty(B, c, n, dtype=torch.float32, device=features.device)

        with torch.cuda.device_of(features):
            pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (B, C, N) tensor with gradients of outputs
        :return:
            grad_features: (B, C, M) tensor with gradients of features
            None: idx is not differentiable
            None: no gradient is propagated to weight
        """
        idx, weight, m = ctx.three_interpolate_for_backward
        B, c, n = grad_out.size()

        # Zero-filled: the grad kernel presumably accumulates into this buffer - confirm.
        grad_features = torch.zeros(B, c, m, dtype=torch.float32, device=grad_out.device)
        grad_out_data = grad_out.contiguous()

        with torch.cuda.device_of(grad_out):
            pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features)
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply
|
|
|
|
class GroupingOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param features: (B, C, N) tensor of features to group
        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
        :return:
            output: (B, C, npoint, nsample) tensor
        """
        assert features.is_contiguous()
        assert idx.is_contiguous()
        # Cast to int32 (the dtype ball_query produces) before handing indices to the kernel.
        idx = idx.int()
        B, nfeatures, nsample = idx.size()
        _, C, N = features.size()
        # Allocate on the input's device rather than the current CUDA device.
        output = torch.empty(B, C, nfeatures, nsample, dtype=torch.float32, device=features.device)

        with torch.cuda.device_of(features):
            pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)

        ctx.for_backwards = (idx, N)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
        :return:
            grad_features: (B, C, N) gradient of the features
            None: idx is not differentiable
        """
        idx, N = ctx.for_backwards

        B, C, npoint, nsample = grad_out.size()
        # Zero-filled: the grad kernel presumably accumulates into this buffer - confirm.
        grad_features = torch.zeros(B, C, N, dtype=torch.float32, device=grad_out.device)

        grad_out_data = grad_out.contiguous()
        with torch.cuda.device_of(grad_out):
            pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features)
        return grad_features, None


grouping_operation = GroupingOperation.apply
|
|
|
|
class BallQuery(Function):

    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param radius: float, radius of the balls
        :param nsample: int, maximum number of features in the balls
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centers of the ball query
        :return:
            idx: (B, npoint, nsample) int32 tensor with the indices of the features that form the query balls
        """
        assert new_xyz.is_contiguous()
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        npoint = new_xyz.size(1)
        # Zero-initialised and allocated on xyz's device rather than the current CUDA device.
        idx = torch.zeros(B, npoint, nsample, dtype=torch.int32, device=xyz.device)

        with torch.cuda.device_of(xyz):
            pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
        # Integer indices carry no gradient.
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        # No gradient flows to radius, nsample, xyz or new_xyz.
        return None, None, None, None


ball_query = BallQuery.apply
|
|
|
|
class QueryAndGroup(nn.Module):
    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        """
        :param radius: float, radius of ball
        :param nsample: int, maximum number of features to gather in the ball
        :param use_xyz: bool, if True the centre-relative xyz of each neighbour is concatenated
            in front of the grouped features
        """
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> torch.Tensor:
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: (B, npoint, 3) centroids
        :param features: (B, C, N) descriptors of the features, or None
        :return:
            new_features: a single tensor of shape
                (B, 3 + C, npoint, nsample) if features is given and use_xyz,
                (B, C, npoint, nsample)     if features is given and not use_xyz,
                (B, 3, npoint, nsample)     if features is None (requires use_xyz)
        """
        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
        xyz_trans = xyz.transpose(1, 2).contiguous()
        grouped_xyz = grouping_operation(xyz_trans, idx)
        # Express each grouped point relative to its query centre.
        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)

        if features is not None:
            grouped_features = grouping_operation(features, idx)
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)
            else:
                new_features = grouped_features
        else:
            assert self.use_xyz, "Cannot have not features and not use xyz as a feature!"
            new_features = grouped_xyz

        return new_features
|
|
|
|
class GroupAll(nn.Module):
    """Group every point of each cloud into one single group (no sampling, no ball query)."""

    def __init__(self, use_xyz: bool = True):
        """
        :param use_xyz: bool, whether xyz coordinates are placed in front of the features
        """
        super().__init__()
        self.use_xyz = use_xyz

    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
        """
        :param xyz: (B, N, 3) xyz coordinates of the features
        :param new_xyz: ignored
        :param features: (B, C, N) descriptors of the features, or None
        :return:
            new_features: (B, C + 3, 1, N) when features is given and use_xyz;
                (B, C, 1, N) when features is given and not use_xyz;
                (B, 3, 1, N) when features is None (regardless of use_xyz)
        """
        # Channels-first xyz with a singleton "group" axis: (B, 3, 1, N).
        xyz_group = xyz.transpose(1, 2).unsqueeze(2)
        if features is None:
            return xyz_group

        feature_group = features.unsqueeze(2)
        if not self.use_xyz:
            return feature_group
        return torch.cat([xyz_group, feature_group], dim=1)
|
|