| import torch |
| from torch.autograd import Function |
|
|
| from pointops._C import ( |
| attention_relation_step_forward_cuda, |
| attention_relation_step_backward_cuda, |
| attention_fusion_step_forward_cuda, |
| attention_fusion_step_backward_cuda, |
| ) |
|
|
|
|
| class AttentionRelationStep(Function): |
| @staticmethod |
| def forward(ctx, query, key, weight, index_target, index_refer): |
| """ |
| input - query: (n, g, c), key: (n, g, c), weight: (c) 1_c for scatter attention, |
| index_target: (m), index_refer: (m) |
| output - relation: (M, g) |
| """ |
|
|
| assert ( |
| query.is_contiguous() |
| and key.is_contiguous() |
| and index_target.is_contiguous() |
| and index_refer.is_contiguous() |
| and weight.is_contiguous() |
| ) |
|
|
| assert index_target.shape[0] == index_refer.shape[0] |
|
|
| _, g, c = query.shape |
| m = index_target.shape[0] |
| output = torch.zeros((m, g), dtype=torch.float, device=query.device) |
| attention_relation_step_forward_cuda( |
| m, g, c, query, key, weight, index_target.int(), index_refer.int(), output |
| ) |
| ctx.save_for_backward(query, key, weight, index_target, index_refer) |
| return output |
|
|
| @staticmethod |
| def backward(ctx, grad_output): |
| query, key, weight, index_target, index_refer = ctx.saved_tensors |
| n, g, c = query.shape |
| m = index_target.shape[0] |
| grad_query = torch.zeros((n, g, c), dtype=torch.float, device=query.device) |
| grad_key = torch.zeros((n, g, c), dtype=torch.float, device=query.device) |
| grad_weight = torch.zeros(c, dtype=torch.float, device=query.device) |
| attention_relation_step_backward_cuda( |
| m, |
| g, |
| c, |
| query, |
| grad_query, |
| key, |
| grad_key, |
| weight, |
| grad_weight, |
| index_target.int(), |
| index_refer.int(), |
| grad_output, |
| ) |
| return grad_query, grad_key, None, None, None |
|
|
|
|
| class AttentionFusionStep(Function): |
| @staticmethod |
| def forward(ctx, weight, value, index_target, index_refer): |
| """ |
| input - weight: (m, g), value: (n, g, c) |
| index_target: (m), index_value: (m) |
| output - output: (n, g, c) |
| """ |
|
|
| assert ( |
| weight.is_contiguous() |
| and value.is_contiguous() |
| and index_target.is_contiguous() |
| and index_refer.is_contiguous() |
| and weight.is_contiguous() |
| ) |
|
|
| assert index_target.shape[0] == index_refer.shape[0] |
|
|
| n, g, c = value.shape |
| m = index_refer.shape[0] |
| output = torch.zeros((n, g, c), dtype=torch.float, device=weight.device) |
| attention_fusion_step_forward_cuda( |
| m, g, c, weight, value, index_target.int(), index_refer.int(), output |
| ) |
| ctx.save_for_backward(weight, value, index_target, index_refer) |
| return output |
|
|
| @staticmethod |
| def backward(ctx, grad_output): |
| """ |
| input: grad_output: (n, g, c) |
| output: grad_weight: (m, g), grad_value: (n, g, c), none, none |
| """ |
| weight, value, index_target, index_refer = ctx.saved_tensors |
| n, g, c = value.shape |
| m = index_target.shape[0] |
| grad_weight = torch.zeros((m, g), dtype=torch.float, device=weight.device) |
| grad_value = torch.zeros((n, g, c), dtype=torch.float, device=weight.device) |
| attention_fusion_step_backward_cuda( |
| m, |
| g, |
| c, |
| weight, |
| grad_weight, |
| value, |
| grad_value, |
| index_target.int(), |
| index_refer.int(), |
| grad_output, |
| ) |
| return grad_weight, grad_value, None, None |
|
|
|
|
| attention_relation_step = AttentionRelationStep.apply |
| attention_fusion_step = AttentionFusionStep.apply |
|
|