import torch
import torch.nn as nn
from torch.autograd import Function

import sys
import os

# NOTE: the dynamic loader only reads LD_LIBRARY_PATH at process startup, so
# setting it here mainly helps child processes; export it before launching
# Python if the pc_util extension fails to find the CUDA libraries.
os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:' + os.environ.get('LD_LIBRARY_PATH', '')

sys.path.append('pc_util-1.0-py3.10-linux-x86_64.egg')
import pc_util


class BallQuery(Function):

    @staticmethod
    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor,
                new_xyz: torch.Tensor, new_xyz_batch_cnt: torch.Tensor):
        """
        Args:
            ctx:
            radius: float, radius of the balls
            nsample: int, maximum number of features in the balls
            xyz: (N1 + N2 ..., 3) xyz coordinates of the features
            xyz_batch_cnt: (batch_size), [N1, N2, ...]
            new_xyz: (M1 + M2 ..., 3) centers of the ball query
            new_xyz_batch_cnt: (batch_size), [M1, M2, ...]

        Returns:
            idx: (M1 + M2 ..., nsample) tensor with the indices of the features that form the query balls
            empty_ball_mask: (M1 + M2 ...) bool tensor, True where a query ball found no neighbors
        """
        assert new_xyz.is_contiguous()
        assert new_xyz_batch_cnt.is_contiguous()
        assert xyz.is_contiguous()
        assert xyz_batch_cnt.is_contiguous()

        B = xyz_batch_cnt.shape[0]
        M = new_xyz.shape[0]
        idx = torch.cuda.IntTensor(M, nsample).zero_()

        pc_util.ball_query_wrapper_stack(B, M, radius, nsample, new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt, idx)
        # The kernel writes -1 into the first slot of empty balls; remap those to
        # index 0 so downstream gathers stay in bounds, and return the mask so
        # callers can zero out the bogus groups.
        empty_ball_mask = (idx[:, 0] == -1)
        idx[empty_ball_mask] = 0
        return idx, empty_ball_mask

    @staticmethod
    def backward(ctx, a=None, b=None):
        # Neither output is differentiable; return one None per forward input.
        return None, None, None, None, None, None


ball_query = BallQuery.apply
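
# Usage sketch (illustrative values; assumes a CUDA device and two stacked
# clouds of 100 and 80 points — nothing here is enforced by the extension):
#
#   xyz = torch.randn(180, 3).cuda()
#   xyz_batch_cnt = torch.tensor([100, 80], dtype=torch.int32).cuda()
#   new_xyz = torch.randn(48, 3).cuda()
#   new_xyz_batch_cnt = torch.tensor([32, 16], dtype=torch.int32).cuda()
#   idx, empty_ball_mask = ball_query(0.8, 16, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt)
#   # idx: (48, 16) int32, empty_ball_mask: (48,) bool

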
class GroupingOperation(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, features_batch_cnt: torch.Tensor,
                idx: torch.Tensor, idx_batch_cnt: torch.Tensor):
        """
        Args:
            ctx:
            features: (N1 + N2 ..., C) tensor of features to group
            features_batch_cnt: (batch_size), [N1, N2, ...] number of points of each sample in features
            idx: (M1 + M2 ..., nsample) tensor containing the indices of features to group with
            idx_batch_cnt: (batch_size), [M1, M2, ...] number of query points of each sample in idx

        Returns:
            output: (M1 + M2 ..., C, nsample) tensor
        """
        assert features.is_contiguous()
        assert features_batch_cnt.is_contiguous()
        assert idx.is_contiguous()
        assert idx_batch_cnt.is_contiguous()

        assert features.shape[0] == features_batch_cnt.sum(), \
            'features: %s, features_batch_cnt: %s' % (str(features.shape), str(features_batch_cnt))
        assert idx.shape[0] == idx_batch_cnt.sum(), \
            'idx: %s, idx_batch_cnt: %s' % (str(idx.shape), str(idx_batch_cnt))

        M, nsample = idx.size()
        N, C = features.size()
        B = idx_batch_cnt.shape[0]
        output = torch.cuda.FloatTensor(M, C, nsample)

        pc_util.group_points_wrapper_stack(B, M, C, nsample, features, features_batch_cnt, idx, idx_batch_cnt, output)

        # Stash what backward needs to scatter gradients back onto the N points.
        ctx.for_backwards = (B, N, idx, features_batch_cnt, idx_batch_cnt)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """
        Args:
            ctx:
            grad_out: (M1 + M2 ..., C, nsample) tensor of the gradients of the output from forward

        Returns:
            grad_features: (N1 + N2 ..., C) gradient of the features
        """
        B, N, idx, features_batch_cnt, idx_batch_cnt = ctx.for_backwards

        M, C, nsample = grad_out.size()
        grad_features = torch.cuda.FloatTensor(N, C).zero_()

        grad_out = grad_out.contiguous()
        pc_util.group_points_grad_wrapper_stack(B, M, C, N, nsample, grad_out, idx,
                                                idx_batch_cnt, features_batch_cnt, grad_features)
        return grad_features, None, None, None


grouping_operation = GroupingOperation.apply
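
# Usage sketch, continuing the hypothetical tensors from the ball_query
# example above (features carries C=64 channels per point):
#
#   features = torch.randn(180, 64).cuda()
#   grouped = grouping_operation(features, xyz_batch_cnt, idx, new_xyz_batch_cnt)
#   # grouped: (48, 64, 16) — one (C, nsample) patch per query center

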
class QueryAndGroup(nn.Module):
    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
        """
        Args:
            radius: float, radius of ball
            nsample: int, maximum number of features to gather in the ball
            use_xyz: bool, whether to prepend the local xyz offsets to the grouped features
        """
        super().__init__()
        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz

    def forward(self, xyz: torch.Tensor, xyz_batch_cnt: torch.Tensor,
                new_xyz: torch.Tensor, new_xyz_batch_cnt: torch.Tensor,
                features: torch.Tensor = None):
        """
        Args:
            xyz: (N1 + N2 ..., 3) xyz coordinates of the features
            xyz_batch_cnt: (batch_size), [N1, N2, ...]
            new_xyz: (M1 + M2 ..., 3) centers of the ball query
            new_xyz_batch_cnt: (batch_size), [M1, M2, ...]
            features: (N1 + N2 ..., C) tensor of features to group

        Returns:
            new_features: (M1 + M2 ..., C (+ 3), nsample) tensor
            idx: (M1 + M2 ..., nsample) ball query indices
        """
        assert xyz.shape[0] == xyz_batch_cnt.sum(), \
            'xyz: %s, xyz_batch_cnt: %s' % (str(xyz.shape), str(xyz_batch_cnt))
        assert new_xyz.shape[0] == new_xyz_batch_cnt.sum(), \
            'new_xyz: %s, new_xyz_batch_cnt: %s' % (str(new_xyz.shape), str(new_xyz_batch_cnt))

        idx, empty_ball_mask = ball_query(self.radius, self.nsample, xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt)
        grouped_xyz = grouping_operation(xyz, xyz_batch_cnt, idx, new_xyz_batch_cnt)  # (M1 + M2 ..., 3, nsample)
        # Translate each group into local coordinates relative to its center.
        grouped_xyz -= new_xyz.unsqueeze(-1)

        grouped_xyz[empty_ball_mask] = 0

        if features is not None:
            grouped_features = grouping_operation(features, xyz_batch_cnt, idx, new_xyz_batch_cnt)  # (M1 + M2 ..., C, nsample)
            grouped_features[empty_ball_mask] = 0
            if self.use_xyz:
                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (M1 + M2 ..., C + 3, nsample)
            else:
                new_features = grouped_features
        else:
            assert self.use_xyz, 'Cannot have both features=None and use_xyz=False: there would be nothing to group!'
            new_features = grouped_xyz

        return new_features, idx
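
# Sketch of the fused query-and-group step with the same hypothetical tensors
# (use_xyz=True prepends the 3 local xyz channels to the 64 feature channels):
#
#   qg = QueryAndGroup(radius=0.8, nsample=16, use_xyz=True)
#   new_features, idx = qg(xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features)
#   # new_features: (48, 67, 16)

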
class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int):
        """
        Note: unlike the stacked ops in this file, this sampler expects a dense
        batched tensor.

        Args:
            ctx:
            xyz: (B, N, 3) where N > npoint
            npoint: int, number of features in the sampled set

        Returns:
            output: (B, npoint) tensor containing the indices of the sampled set
        """
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        output = torch.cuda.IntTensor(B, npoint)
        # Per-point running distance to the current sample set, initialised large.
        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

        pc_util.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
        return output

    @staticmethod
    def backward(ctx, a=None):
        return None, None


furthest_point_sample = FurthestPointSampling.apply
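
# Sketch: with a hypothetical (B=2, N=1024, 3) batched cloud on the GPU:
#
#   batched_xyz = torch.randn(2, 1024, 3).cuda()
#   sample_idx = furthest_point_sample(batched_xyz, 256)
#   # sample_idx: (2, 256) int32 indices into each cloud

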
class ThreeNN(Function):
    @staticmethod
    def forward(ctx, unknown, unknown_batch_cnt, known, known_batch_cnt):
        """
        Args:
            ctx:
            unknown: (N1 + N2 ..., 3)
            unknown_batch_cnt: (batch_size), [N1, N2, ...]
            known: (M1 + M2 ..., 3)
            known_batch_cnt: (batch_size), [M1, M2, ...]

        Returns:
            dist: (N1 + N2 ..., 3) L2 distance to the three nearest neighbors
            idx: (N1 + N2 ..., 3) indices of the three nearest neighbors, in [0, M1 + M2 + ...)
        """
        assert len(unknown.shape) == 2 and unknown.shape[1] == 3
        assert len(known.shape) == 2 and known.shape[1] == 3
        assert len(unknown_batch_cnt) == len(known_batch_cnt)

        dist2 = unknown.new_zeros(unknown.shape)
        idx = unknown_batch_cnt.new_zeros(unknown.shape).int()

        # The wrapper fills dist2 with squared distances; take the root before returning.
        pc_util.three_nn_wrapper_stack(
            unknown.contiguous(), unknown_batch_cnt.contiguous(),
            known.contiguous(), known_batch_cnt.contiguous(), dist2, idx
        )
        return torch.sqrt(dist2), idx

    @staticmethod
    def backward(ctx, a=None, b=None):
        return None, None, None, None


three_nn = ThreeNN.apply
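
# Sketch with the hypothetical stacked tensors from above: for every query
# center, find its three nearest raw points within the same batch element:
#
#   dist, nn_idx = three_nn(new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt)
#   # dist, nn_idx: (48, 3)

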
class ThreeInterpolate(Function):

    @staticmethod
    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor):
        """
        Args:
            ctx:
            features: (M1 + M2 ..., C)
            idx: (N1 + N2 ..., 3) neighbor indices, e.g. from three_nn
            weight: (N1 + N2 ..., 3) interpolation weights

        Returns:
            out_tensor: (N1 + N2 ..., C)
        """
        assert idx.shape[0] == weight.shape[0] and idx.shape[1] == weight.shape[1] == 3

        ctx.three_interpolate_for_backward = (idx, weight, features.shape[0])
        output = features.new_zeros((idx.shape[0], features.shape[1]))
        pc_util.three_interpolate_wrapper_stack(features.contiguous(), idx.contiguous(), weight.contiguous(), output)
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """
        Args:
            ctx:
            grad_out: (N1 + N2 ..., C)

        Returns:
            grad_features: (M1 + M2 ..., C) gradient of the features
        """
        idx, weight, M = ctx.three_interpolate_for_backward
        grad_features = grad_out.new_zeros((M, grad_out.shape[1]))
        pc_util.three_interpolate_grad_wrapper_stack(
            grad_out.contiguous(), idx.contiguous(), weight.contiguous(), grad_features
        )
        return grad_features, None, None


three_interpolate = ThreeInterpolate.apply
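
# Sketch of the usual inverse-distance weighting on top of three_nn (this
# normalization is the common PointNet++-style convention, not something the
# extension enforces):
#
#   dist_recip = 1.0 / (dist + 1e-8)
#   weight = dist_recip / dist_recip.sum(dim=1, keepdim=True)   # rows sum to 1
#   interpolated = three_interpolate(features, nn_idx, weight)
#   # interpolated: (48, 64)

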
if __name__ == '__main__':
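    # Smoke test: a minimal sketch assuming a CUDA device and a successfully
    # built pc_util extension; all shapes and parameter values are illustrative.
    xyz = torch.randn(180, 3).cuda()
    xyz_batch_cnt = torch.tensor([100, 80], dtype=torch.int32).cuda()
    new_xyz = torch.randn(48, 3).cuda()
    new_xyz_batch_cnt = torch.tensor([32, 16], dtype=torch.int32).cuda()
    features = torch.randn(180, 64).cuda()

    qg = QueryAndGroup(radius=0.8, nsample=16, use_xyz=True)
    new_features, idx = qg(xyz, xyz_batch_cnt, new_xyz, new_xyz_batch_cnt, features)
    print('QueryAndGroup ->', tuple(new_features.shape), tuple(idx.shape))

    # Interpolate the raw point features onto the query centers via weighted 3-NN.
    dist, nn_idx = three_nn(new_xyz, new_xyz_batch_cnt, xyz, xyz_batch_cnt)
    dist_recip = 1.0 / (dist + 1e-8)
    weight = dist_recip / dist_recip.sum(dim=1, keepdim=True)
    interpolated = three_interpolate(features, nn_idx, weight)
    print('three_interpolate ->', tuple(interpolated.shape))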