Spaces:
Runtime error
Runtime error
| // CUDA forward declarations | |
| std::vector<torch::Tensor> corr_cuda_forward( | |
| torch::Tensor fmap1, | |
| torch::Tensor fmap2, | |
| torch::Tensor coords, | |
| int radius); | |
| std::vector<torch::Tensor> corr_cuda_backward( | |
| torch::Tensor fmap1, | |
| torch::Tensor fmap2, | |
| torch::Tensor coords, | |
| torch::Tensor corr_grad, | |
| int radius); | |
| // C++ interface | |
| std::vector<torch::Tensor> corr_forward( | |
| torch::Tensor fmap1, | |
| torch::Tensor fmap2, | |
| torch::Tensor coords, | |
| int radius) { | |
| CHECK_INPUT(fmap1); | |
| CHECK_INPUT(fmap2); | |
| CHECK_INPUT(coords); | |
| return corr_cuda_forward(fmap1, fmap2, coords, radius); | |
| } | |
| std::vector<torch::Tensor> corr_backward( | |
| torch::Tensor fmap1, | |
| torch::Tensor fmap2, | |
| torch::Tensor coords, | |
| torch::Tensor corr_grad, | |
| int radius) { | |
| CHECK_INPUT(fmap1); | |
| CHECK_INPUT(fmap2); | |
| CHECK_INPUT(coords); | |
| CHECK_INPUT(corr_grad); | |
| return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); | |
| } | |
| PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
| m.def("forward", &corr_forward, "CORR forward"); | |
| m.def("backward", &corr_backward, "CORR backward"); | |
| } | |