| #pragma once |
|
|
| #include <torch/extension.h> |
|
|
| at::Tensor ms_deform_attn_cuda_c2345_forward( |
| const at::Tensor& feat_c2, |
| const at::Tensor& feat_c3, |
| const at::Tensor& feat_c4, |
| const at::Tensor& feat_c5, |
| const at::Tensor& sampling_loc, |
| const at::Tensor& attn_weight |
| ); |
|
|
| std::vector<at::Tensor> ms_deform_attn_cuda_c2345_backward( |
| const at::Tensor& feat_c2, |
| const at::Tensor& feat_c3, |
| const at::Tensor& feat_c4, |
| const at::Tensor& feat_c5, |
| const at::Tensor& sampling_loc, |
| const at::Tensor& attn_weight, |
| const at::Tensor& grad_output |
| ); |
|
|
| at::Tensor ms_deform_attn_cuda_c23456_forward( |
| const at::Tensor& feat_c2, |
| const at::Tensor& feat_c3, |
| const at::Tensor& feat_c4, |
| const at::Tensor& feat_c5, |
| const at::Tensor& feat_c6, |
| const at::Tensor& sampling_loc, |
| const at::Tensor& attn_weight |
| ); |
|
|
| std::vector<at::Tensor> ms_deform_attn_cuda_c23456_backward( |
| const at::Tensor& grad_output, |
| const at::Tensor& feat_c2, |
| const at::Tensor& feat_c3, |
| const at::Tensor& feat_c4, |
| const at::Tensor& feat_c5, |
| const at::Tensor& feat_c6, |
| const at::Tensor& sampling_loc, |
| const at::Tensor& attn_weight |
| ); |