Spaces:
Runtime error
Runtime error
| import torch | |
| from torch.autograd import Function | |
| from pointops._C import aggregation_forward_cuda, aggregation_backward_cuda | |
class Aggregation(Function):
    """Autograd wrapper around the ``pointops`` CUDA aggregation kernels.

    The forward and backward computations are delegated entirely to
    ``aggregation_forward_cuda`` / ``aggregation_backward_cuda``; this class
    only validates inputs, allocates the result buffers and wires the two
    kernels into autograd. Use it through ``Aggregation.apply``.
    """

    @staticmethod
    def forward(ctx, input, position, weight, idx):
        """Run the forward aggregation kernel.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            input: features, shape ``(n, c)``.
            position: per-neighbor tensor, shape ``(n, nsample, c)``.
            weight: per-neighbor weights, shape ``(n, nsample, c')``.
            idx: neighbor indices, shape ``(n, nsample)``.

        Returns:
            A float32 tensor of shape ``(n, c)`` on the same device as
            ``input``.

        Raises:
            ValueError: if ``input``, ``position`` or ``weight`` is not
                contiguous.
        """
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; the kernel is handed the tensors as-is.
        for name, tensor in (
            ("input", input),
            ("position", position),
            ("weight", weight),
        ):
            if not tensor.is_contiguous():
                raise ValueError(f"{name} must be contiguous")
        # NOTE(review): idx is not checked for contiguity (the original did
        # not check it either) -- confirm what the kernel expects.

        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        # Allocate on the inputs' device rather than the *current* CUDA
        # device, so the op also works when the data lives on another GPU.
        output = torch.zeros((n, c), dtype=torch.float32, device=input.device)
        aggregation_forward_cuda(
            n, nsample, c, w_c, input, position, weight, idx, output
        )
        ctx.save_for_backward(input, position, weight, idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward aggregation kernel.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient w.r.t. the forward output, shape ``(n, c)``.

        Returns:
            ``(grad_input, grad_position, grad_weight, None)`` with shapes
            ``(n, c)``, ``(n, nsample, c)`` and ``(n, nsample, c')``; the
            trailing ``None`` is the (non-existent) gradient for ``idx``.
        """
        input, position, weight, idx = ctx.saved_tensors
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        # Autograd may hand over a non-contiguous gradient (e.g. an expanded
        # view); forward requires contiguous buffers, so do the same here.
        grad_output = grad_output.contiguous()

        device = grad_output.device
        grad_input = torch.zeros((n, c), dtype=torch.float32, device=device)
        grad_position = torch.zeros(
            (n, nsample, c), dtype=torch.float32, device=device
        )
        grad_weight = torch.zeros(
            (n, nsample, w_c), dtype=torch.float32, device=device
        )
        aggregation_backward_cuda(
            n,
            nsample,
            c,
            w_c,
            input,
            position,
            weight,
            idx,
            grad_output,
            grad_input,
            grad_position,
            grad_weight,
        )
        return grad_input, grad_position, grad_weight, None
# Functional alias: call ``aggregation(input, position, weight, idx)`` so the
# op runs through ``Function.apply`` and is recorded by autograd.
aggregation = Aggregation.apply