| import torch |
| from torch.autograd import Function |
|
|
| from pointops._C import grouping_forward_cuda, grouping_backward_cuda |
|
|
|
|
class Grouping(Function):
    """Autograd op that gathers neighbour features via ``pointops._C`` kernels.

    Forward gathers rows of ``input`` at the indices stored in ``idx`` using
    ``grouping_forward_cuda``; backward hands the incoming gradient to
    ``grouping_backward_cuda`` to produce the gradient w.r.t. ``input``.
    """

    @staticmethod
    def forward(ctx, input, idx):
        """Gather features of grouped neighbours.

        Args:
            input: feature tensor of shape (n, c); must be contiguous.
            idx: neighbour index tensor of shape (m, nsample); must be
                contiguous.

        Returns:
            A float32 tensor of shape (m, nsample, c) filled by
            ``grouping_forward_cuda``.
        """
        assert input.is_contiguous() and idx.is_contiguous()
        m, nsample = idx.shape[0], idx.shape[1]
        n, c = input.shape[0], input.shape[1]
        output = torch.zeros((m, nsample, c), dtype=torch.float, device=input.device)
        grouping_forward_cuda(m, nsample, c, input, idx, output)
        # Keep the number of source rows so backward can size grad_input.
        ctx.n = n
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate the gradient back to the gathered source rows.

        Args:
            grad_output: gradient w.r.t. the forward output, shape
                (m, nsample, c).

        Returns:
            ``(grad_input, None)`` where ``grad_input`` is a float32 tensor of
            shape (n, c) and ``idx`` receives no gradient.
        """
        n = ctx.n
        (idx,) = ctx.saved_tensors
        # Autograd may deliver a non-contiguous gradient (e.g. after a
        # downstream expand/permute). forward insists on contiguous tensors
        # for the kernel, so enforce the same here; no-op if already
        # contiguous.
        grad_output = grad_output.contiguous()
        m, nsample, c = grad_output.shape
        grad_input = torch.zeros((n, c), dtype=torch.float, device=idx.device)
        grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input)
        return grad_input, None
|
|
|
|
def grouping(idx, feat, xyz, new_xyz=None, with_xyz=False):
    """Gather neighbour features, optionally prefixed by relative coordinates.

    Pure-PyTorch gather: rows of ``feat`` (and ``xyz``) are selected by
    ``idx``. An all-zero row is appended to the inputs first, so an index of
    -1 selects zeros (padding).

    Args:
        idx: integer index tensor of shape (m, nsample); -1 marks padding.
        feat: feature tensor of shape (n, c); must be contiguous.
        xyz: coordinate tensor of shape (n, d); must be contiguous.
        new_xyz: query coordinates of shape (m, d); defaults to ``xyz``.
            Must be contiguous when ``with_xyz`` is True.
        with_xyz: if True, prepend each neighbour's coordinates relative to
            its query point (zeroed for padded slots) to the features.

    Returns:
        Tensor of shape (m, nsample, c), or (m, nsample, d + c) when
        ``with_xyz`` is True. The dtype follows ``feat``/``xyz``.
    """
    if new_xyz is None:
        new_xyz = xyz
    assert xyz.is_contiguous() and feat.is_contiguous()
    m, nsample = idx.shape[0], idx.shape[1]
    c, d = feat.shape[1], xyz.shape[1]
    # reshape (not view) so a non-contiguous idx is accepted as well.
    flat_idx = idx.reshape(-1).long()
    # Append a zero row so idx == -1 gathers zeros. new_zeros keeps the
    # input's dtype and device (no host allocation + copy, no promotion).
    xyz = torch.cat([xyz, xyz.new_zeros((1, d))], dim=0)
    feat = torch.cat([feat, feat.new_zeros((1, c))], dim=0)
    grouped_feat = feat[flat_idx, :].view(m, nsample, c)
    if not with_xyz:
        return grouped_feat

    assert new_xyz.is_contiguous()
    grouped_xyz = xyz[flat_idx, :].view(m, nsample, d) - new_xyz.unsqueeze(1)
    # sign(idx + 1) is 0 for padding (idx == -1) and 1 for real neighbours,
    # so padded slots get zero relative coordinates instead of -new_xyz.
    mask = torch.sign(idx + 1).unsqueeze(-1).to(grouped_xyz.dtype)
    grouped_xyz = grouped_xyz * mask
    return torch.cat((grouped_xyz, grouped_feat), -1)
|
|
|
|
# Functional alias for the custom autograd op: grouping2(input, idx) runs
# Grouping.forward through the pointops._C kernel and registers
# Grouping.backward with autograd. Note the argument order is (input, idx),
# the reverse of grouping(idx, feat, xyz, ...).
grouping2 = Grouping.apply
|
|