| import torch |
| from torch.autograd import Function |
|
|
| from pointops._C import interpolation_forward_cuda, interpolation_backward_cuda |
| from .query import knn_query |
|
|
|
|
def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3):
    """Inverse-distance-weighted feature interpolation (pure-PyTorch path).

    For every query point in ``new_xyz`` the ``k`` nearest points of ``xyz``
    are found with ``knn_query``. Their features are blended with weights
    proportional to ``1 / (dist + 1e-8)``, normalised to sum to 1 per query
    point.

    input: xyz: (m, 3), new_xyz: (n, 3), feat: (m, c), offset: (b), new_offset: (b)
    output: (n, c)

    Args:
        xyz: source point coordinates; must be contiguous.
        new_xyz: query point coordinates; must be contiguous.
        feat: per-source-point features; must be contiguous.
        offset: batch offsets for ``xyz``, forwarded to ``knn_query``.
        new_offset: batch offsets for ``new_xyz``, forwarded to ``knn_query``.
        k: number of neighbours blended per query point.

    Returns:
        Interpolated features for ``new_xyz``. The dtype is ``feat``'s dtype
        promoted with float32: float64 stays float64, and half, bfloat16,
        float32 and integer inputs give float32.
    """
    assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous()
    idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset)
    # The epsilon keeps a query point that coincides with a source point finite.
    dist_recip = 1.0 / (dist + 1e-8)
    norm = torch.sum(dist_recip, dim=1, keepdim=True)
    weight = dist_recip / norm

    # Previously hard-coded to torch.float, which silently downcast float64
    # features. Promoting with float32 keeps every other dtype's old behaviour.
    out_dtype = torch.promote_types(feat.dtype, torch.float)
    new_feat = torch.zeros(
        (new_xyz.shape[0], feat.shape[1]), dtype=out_dtype, device=feat.device
    )
    # Cast neighbour indices once instead of once per neighbour slot.
    # NOTE(review): if knn_query can emit -1 for missing neighbours, the gather
    # below would silently pick the last row of feat. Confirm against knn_query.
    idx_long = idx.long()
    for i in range(k):
        new_feat += feat[idx_long[:, i], :] * weight[:, i].unsqueeze(-1)
    return new_feat
|
|
|
|
class Interpolation(Function):
    """Autograd function backed by the pointops CUDA interpolation kernels.

    ``forward`` finds the ``k`` nearest ``xyz`` points of every ``new_xyz``
    point with ``knn_query`` and computes normalised inverse-distance weights.
    It then calls ``interpolation_forward_cuda`` to produce the (n, c) output.
    ``backward`` calls ``interpolation_backward_cuda`` with the saved neighbour
    indices and weights to build a gradient for ``input``. No other argument
    receives a gradient.
    """

    @staticmethod
    def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3):
        """
        input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b)
        output: (n, c) float32 tensor
        """
        # The CUDA kernels are given these tensors directly, so require dense layouts.
        assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous()
        idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset)
        # The epsilon keeps a query point that coincides with a source point finite.
        dist_recip = 1.0 / (dist + 1e-8)
        norm = torch.sum(dist_recip, dim=1, keepdim=True)
        weight = dist_recip / norm

        n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0]
        output = torch.zeros((n, c), dtype=torch.float, device=xyz.device)
        interpolation_forward_cuda(n, c, k, input, idx, weight, output)
        # m is needed to size grad_input, and idx/weight to route gradients back.
        ctx.m, ctx.k = m, k
        ctx.save_for_backward(idx, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_output: (n, c), gradient w.r.t. the forward output
        output: gradients for (xyz, new_xyz, input, offset, new_offset, k);
            only input gets a (m, c) float32 tensor, the rest are None
        """
        m, k = ctx.m, ctx.k
        idx, weight = ctx.saved_tensors
        # Autograd may deliver a non-contiguous gradient, e.g. the stride-0
        # expanded tensor produced by `out.sum().backward()`. The CUDA kernel
        # receives the tensor as-is (forward asserts contiguity for the same
        # reason), so densify it before handing it over.
        grad_output = grad_output.contiguous()
        n, c = grad_output.shape
        grad_input = torch.zeros((m, c), dtype=torch.float, device=idx.device)
        interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input)
        return None, None, grad_input, None, None, None


# Functional entry point for the CUDA-backed path:
# interpolation2(xyz, new_xyz, input, offset, new_offset, k) -> (n, c) tensor.
interpolation2 = Interpolation.apply
|
|