| import torch |
| from torch.autograd import Function |
|
|
| from pointops._C import subtraction_forward_cuda, subtraction_backward_cuda |
|
|
|
|
class Subtraction(Function):
    """Autograd wrapper around the ``pointops`` subtraction CUDA kernels.

    Shape contract (from the original docstrings):
        input1: (n, c), input2: (n, c), idx: (n, nsample)
        output: (n, nsample, c)

    NOTE(review): the arithmetic itself happens inside the compiled
    extension, which is not visible here; presumably
    ``output[i, j] = input1[i] - input2[idx[i, j]]`` -- TODO confirm
    against the kernel source.
    """

    @staticmethod
    def forward(ctx, input1, input2, idx):
        """
        input: input1: (n, c), input2: (n, c), idx: (n, nsample)
        output: (n, nsample, c)

        ``input1`` and ``input2`` must already be contiguous (asserted).
        ``idx`` is made contiguous here because it is also handed straight
        to the extension. The output is allocated as float32 zeros on
        ``input1``'s device and then filled by the kernel.
        """
        assert input1.is_contiguous() and input2.is_contiguous()
        # idx is passed to the kernel just like the inputs, so give it a dense
        # layout too; .contiguous() returns the same tensor if it already is.
        # NOTE(review): the kernel presumably expects a specific integer dtype
        # (likely int32) -- TODO confirm; no dtype check is done here.
        idx = idx.contiguous()
        n, c = input1.shape
        nsample = idx.shape[-1]
        # NOTE(review): float32 is hard-coded, presumably because the kernel
        # only supports float -- TODO confirm before using half/double inputs.
        output = torch.zeros((n, nsample, c), dtype=torch.float, device=input1.device)
        subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output)
        # backward only needs idx, not the input values.
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_out: (n, nsample, c)
        output: grad_input1: (n, c), grad_input2: (n, c)

        ``idx`` receives no gradient (``None``). ``grad_input2`` is sized
        with the same ``n`` as ``grad_output``, so this assumes ``input2``
        has ``n`` rows, as the forward contract states.
        """
        (idx,) = ctx.saved_tensors
        # Autograd does not guarantee a contiguous incoming gradient (e.g. a
        # downstream .sum() yields an expanded, stride-0 tensor). The extension
        # presumably reads it as a dense buffer, as the forward's contiguity
        # assert suggests, so normalise it; this is free when already contiguous.
        grad_output = grad_output.contiguous()
        n, nsample, c = grad_output.shape
        # Zero-initialised, presumably because the kernel accumulates into
        # these buffers -- TODO confirm.
        grad_input1 = torch.zeros((n, c), dtype=torch.float, device=idx.device)
        grad_input2 = torch.zeros((n, c), dtype=torch.float, device=idx.device)
        subtraction_backward_cuda(
            n, nsample, c, idx, grad_output, grad_input1, grad_input2
        )
        return grad_input1, grad_input2, None
|
|
|
|
# Functional entry point: ``subtraction(input1, input2, idx)`` runs the
# custom autograd op; it is simply the bound ``apply`` of ``Subtraction``.
subtraction = Subtraction.apply
|
|