YYYYYYUUU's picture
Backup FULL local core code incl. libs/ CUDA ext + all configs
3499c27 verified
Raw
History Blame Contribute Delete
1.89 kB
import torch
from torch.autograd import Function
from pointops._C import aggregation_forward_cuda, aggregation_backward_cuda
class Aggregation(Function):
    """Autograd wrapper around the pointops CUDA aggregation kernels.

    The numerical work is delegated to ``aggregation_forward_cuda`` and
    ``aggregation_backward_cuda``. This class only checks memory layout,
    allocates the output / gradient buffers, and records what ``backward``
    needs.
    """

    @staticmethod
    def forward(ctx, input, position, weight, idx):
        """Run the forward aggregation kernel.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            input: (n, c) tensor.
            position: (n, nsample, c) tensor.
            weight: (n, nsample, c') tensor.
            idx: (n, nsample) index tensor.

        Returns:
            A newly allocated float32 tensor of shape (n, c) on
            ``input.device``, filled in by the CUDA kernel.

        Raises:
            AssertionError: if ``input``, ``position`` or ``weight`` is not
                contiguous.
        """
        assert (
            input.is_contiguous()
            and position.is_contiguous()
            and weight.is_contiguous()
        ), "input, position and weight must be contiguous"
        # idx is handed to the same kernel as the tensors checked above, so
        # give it a dense layout as well. This is a no-op when idx is already
        # contiguous, so existing callers are unaffected.
        idx = idx.contiguous()

        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        # The kernel writes its result into this pre-zeroed buffer.
        output = torch.zeros((n, c), dtype=torch.float, device=input.device)
        aggregation_forward_cuda(
            n, nsample, c, w_c, input, position, weight, idx, output
        )
        ctx.save_for_backward(input, position, weight, idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the backward aggregation kernel.

        Args:
            ctx: autograd context holding the tensors saved by ``forward``.
            grad_output: (n, c) gradient with respect to the forward output.

        Returns:
            ``(grad_input, grad_position, grad_weight, None)`` with shapes
            (n, c), (n, nsample, c) and (n, nsample, c'). The trailing
            ``None`` is the slot for ``idx``, which receives no gradient.
        """
        input, position, weight, idx = ctx.saved_tensors
        # Autograd does not guarantee a contiguous incoming gradient (e.g.
        # when the consumer of the output was a transpose/expand), while the
        # kernel receives the tensor as a raw buffer. No-op if already dense.
        grad_output = grad_output.contiguous()

        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        device = input.device
        # Zero-initialised buffers that the kernel fills in.
        grad_input = torch.zeros((n, c), dtype=torch.float, device=device)
        grad_position = torch.zeros(
            (n, nsample, c), dtype=torch.float, device=device
        )
        grad_weight = torch.zeros(
            (n, nsample, w_c), dtype=torch.float, device=device
        )
        aggregation_backward_cuda(
            n,
            nsample,
            c,
            w_c,
            input,
            position,
            weight,
            idx,
            grad_output,
            grad_input,
            grad_position,
            grad_weight,
        )
        return grad_input, grad_position, grad_weight, None
# Functional alias: call `aggregation(input, position, weight, idx)` rather
# than spelling out `Aggregation.apply(...)` at every call site.
aggregation = Aggregation.apply