| import torch |
| from torch.autograd import Function |
|
|
| from pointops._C import aggregation_forward_cuda, aggregation_backward_cuda |
|
|
|
|
class Aggregation(Function):
    """Autograd wrapper around the ``pointops`` aggregation extension kernels.

    Both passes delegate the numeric work to the compiled functions
    ``aggregation_forward_cuda`` / ``aggregation_backward_cuda``. This class
    only:

    * validates tensor layouts,
    * allocates the zero-initialised output and gradient buffers, and
    * saves the tensors needed by ``backward``.
    """

    @staticmethod
    def forward(ctx, input, position, weight, idx):
        """Run the aggregation forward kernel.

        Args:
            ctx: autograd context; the four inputs are saved on it for
                ``backward``.
            input: ``(n, c)`` tensor.
            position: ``(n, nsample, c)`` tensor; ``n``, ``nsample`` and ``c``
                are all read from its shape.
            weight: ``(n, nsample, c')`` tensor; ``c'`` is read from its last
                dimension.
            idx: ``(n, nsample)`` index tensor.

        Returns:
            A new ``(n, c)`` ``torch.float`` tensor on ``input.device``,
            zero-initialised and then filled by the forward kernel.

        Raises:
            AssertionError: if any of ``input``, ``position``, ``weight`` or
                ``idx`` is not contiguous.
        """
        # The kernel is handed the tensors directly, so every one of them
        # (including idx, which the original check omitted) must be dense.
        assert (
            input.is_contiguous()
            and position.is_contiguous()
            and weight.is_contiguous()
            and idx.is_contiguous()
        ), "Aggregation.forward requires contiguous input, position, weight and idx"
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        output = torch.zeros((n, c), dtype=torch.float, device=input.device)
        aggregation_forward_cuda(
            n, nsample, c, w_c, input, position, weight, idx, output
        )
        ctx.save_for_backward(input, position, weight, idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the aggregation backward kernel.

        Args:
            ctx: autograd context holding ``(input, position, weight, idx)``
                saved by ``forward``.
            grad_output: ``(n, c)`` gradient w.r.t. the forward output.

        Returns:
            ``(grad_input, grad_position, grad_weight, None)``:
            * ``grad_input`` is ``(n, c)``.
            * ``grad_position`` is ``(n, nsample, c)``.
            * ``grad_weight`` is ``(n, nsample, c')``.
            * ``None`` is returned for ``idx``, which receives no gradient.
        """
        input, position, weight, idx = ctx.saved_tensors
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        # Autograd does not guarantee a dense incoming gradient. For example,
        # ``out.sum().backward()`` yields a stride-0 expanded tensor. Forward
        # insists on contiguous tensors, so the kernel presumably reads raw
        # dense buffers; materialise the gradient before handing it over.
        grad_output = grad_output.contiguous()
        device = input.device
        grad_input = torch.zeros((n, c), dtype=torch.float, device=device)
        grad_position = torch.zeros(
            (n, nsample, c), dtype=torch.float, device=device
        )
        grad_weight = torch.zeros(
            (n, nsample, w_c), dtype=torch.float, device=device
        )
        aggregation_backward_cuda(
            n,
            nsample,
            c,
            w_c,
            input,
            position,
            weight,
            idx,
            grad_output,
            grad_input,
            grad_position,
            grad_weight,
        )
        return grad_input, grad_position, grad_weight, None
|
|
|
|
# Functional entry point for the custom autograd op. Call it as
# ``aggregation(input, position, weight, idx)``; it returns the ``(n, c)``
# output produced by ``Aggregation.forward``.
aggregation = Aggregation.apply
|
|