Spaces:
Paused
Paused
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/hszhao/semseg/blob/master/lib/psa/src
| void psamask_forward_impl(const int psa_type, const Tensor input, Tensor output, | |
| const int num_, const int h_feature, | |
| const int w_feature, const int h_mask, | |
| const int w_mask, const int half_h_mask, | |
| const int half_w_mask) { | |
| DISPATCH_DEVICE_IMPL(psamask_forward_impl, psa_type, input, output, num_, | |
| h_feature, w_feature, h_mask, w_mask, half_h_mask, | |
| half_w_mask); | |
| } | |
| void psamask_backward_impl(const int psa_type, const Tensor grad_output, | |
| Tensor grad_input, const int num_, | |
| const int h_feature, const int w_feature, | |
| const int h_mask, const int w_mask, | |
| const int half_h_mask, const int half_w_mask) { | |
| DISPATCH_DEVICE_IMPL(psamask_backward_impl, psa_type, grad_output, grad_input, | |
| num_, h_feature, w_feature, h_mask, w_mask, half_h_mask, | |
| half_w_mask); | |
| } | |
| void psamask_forward(const Tensor input, Tensor output, const int psa_type, | |
| const int num_, const int h_feature, const int w_feature, | |
| const int h_mask, const int w_mask, const int half_h_mask, | |
| const int half_w_mask) { | |
| psamask_forward_impl(psa_type, input, output, num_, h_feature, w_feature, | |
| h_mask, w_mask, half_h_mask, half_w_mask); | |
| } | |
| void psamask_backward(Tensor grad_output, const Tensor grad_input, | |
| const int psa_type, const int num_, const int h_feature, | |
| const int w_feature, const int h_mask, const int w_mask, | |
| const int half_h_mask, const int half_w_mask) { | |
| psamask_backward_impl(psa_type, grad_output, grad_input, num_, h_feature, | |
| w_feature, h_mask, w_mask, half_h_mask, half_w_mask); | |
| } | |