File size: 1,676 Bytes
d19bd3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#pragma once

#include <torch/extension.h>

// Forward entry point for the four-feature-map (c2..c5) variant.
//
// Only the declaration lives in this header; the definition is expected in a
// separate translation unit (presumably a CUDA source, judging by the name --
// confirm against the build).
//
// Parameters (all taken by const reference, none are modified through this
// interface):
//   feat_c2..feat_c5  the four input feature tensors.
//   sampling_loc      sampling locations.
//   attn_weight       attention weights.
// The bracketed shapes next to each parameter are the original author's
// annotations; this header does not enforce them.
//
// Returns a single at::Tensor. Its shape is not documented here -- TODO
// confirm against the implementation.
at::Tensor ms_deform_attn_cuda_c2345_forward(
    const at::Tensor& feat_c2,  // [B, N, H, W, C]
    const at::Tensor& feat_c3,  // [B, N, H, W, C]
    const at::Tensor& feat_c4,  // [B, N, H, W, C]
    const at::Tensor& feat_c5,  // [B, N, H, W, C]
    const at::Tensor& sampling_loc,  // [B, Q, P, 3]
    const at::Tensor& attn_weight  // [B, Q, P, 4]
);

// Backward entry point for the four-feature-map (c2..c5) variant.
//
// Only the declaration lives in this header; the definition is expected in a
// separate translation unit.
//
// Parameters (all taken by const reference):
//   feat_c2..feat_c5  the same four feature tensors passed to the forward call.
//   sampling_loc      sampling locations.
//   attn_weight       attention weights.
//   grad_output       incoming gradient; presumably w.r.t. the forward
//                     output -- TODO confirm.
//
// CAUTION: grad_output is the LAST parameter here, whereas
// ms_deform_attn_cuda_c23456_backward takes it FIRST. Every argument has the
// same type, so a mis-ordered call compiles silently.
//
// Returns a vector of tensors, presumably the gradients for the inputs. The
// number and order of entries are not visible from this header -- TODO confirm
// against the implementation.
std::vector<at::Tensor> ms_deform_attn_cuda_c2345_backward(
    const at::Tensor& feat_c2,  // [B, N, H, W, C]
    const at::Tensor& feat_c3,  // [B, N, H, W, C]
    const at::Tensor& feat_c4,  // [B, N, H, W, C]
    const at::Tensor& feat_c5,  // [B, N, H, W, C]
    const at::Tensor& sampling_loc,  // [B, Q, P, 3]
    const at::Tensor& attn_weight,  // [B, Q, P, 4]
    const at::Tensor& grad_output  // shape not annotated by the author; TODO confirm
);

// Forward entry point for the five-feature-map (c2..c6) variant.
//
// Same parameter layout as ms_deform_attn_cuda_c2345_forward, with one extra
// feature tensor (feat_c6) inserted before sampling_loc.
//
// Only the declaration lives in this header; the definition is expected in a
// separate translation unit.
//
// Parameters (all taken by const reference):
//   feat_c2..feat_c6  the five input feature tensors.
//   sampling_loc      sampling locations.
//   attn_weight       attention weights.
// The bracketed shapes are the original author's annotations; this header
// does not enforce them.
//
// Returns a single at::Tensor. Its shape is not documented here -- TODO
// confirm against the implementation.
at::Tensor ms_deform_attn_cuda_c23456_forward(
    const at::Tensor& feat_c2,  // [B, N, H, W, C]
    const at::Tensor& feat_c3,  // [B, N, H, W, C]
    const at::Tensor& feat_c4,  // [B, N, H, W, C]
    const at::Tensor& feat_c5,  // [B, N, H, W, C]
    const at::Tensor& feat_c6,  // [B, N, H, W, C]
    const at::Tensor& sampling_loc,  // [B, Q, P, 3]
    // NOTE(review): annotated with a trailing dimension of 4 although five
    // feature tensors are passed; this may be a copy-paste from the c2345
    // variant -- confirm against the kernel.
    const at::Tensor& attn_weight  // [B, Q, P, 4]
);

// Backward entry point for the five-feature-map (c2..c6) variant.
//
// Only the declaration lives in this header; the definition is expected in a
// separate translation unit.
//
// Parameters (all taken by const reference):
//   grad_output       incoming gradient; presumably w.r.t. the forward
//                     output -- TODO confirm.
//   feat_c2..feat_c6  the same five feature tensors passed to the forward call.
//   sampling_loc      sampling locations.
//   attn_weight       attention weights.
//
// CAUTION: grad_output is the FIRST parameter here, whereas
// ms_deform_attn_cuda_c2345_backward takes it LAST. Every argument has the
// same type, so a mis-ordered call compiles silently.
//
// Returns a vector of tensors, presumably the gradients for the inputs. The
// number and order of entries are not visible from this header -- TODO confirm
// against the implementation.
std::vector<at::Tensor> ms_deform_attn_cuda_c23456_backward(
    const at::Tensor& grad_output,  // shape not annotated by the author; TODO confirm
    const at::Tensor& feat_c2,  // [B, N, H, W, C]
    const at::Tensor& feat_c3,  // [B, N, H, W, C]
    const at::Tensor& feat_c4,  // [B, N, H, W, C]
    const at::Tensor& feat_c5,  // [B, N, H, W, C]
    const at::Tensor& feat_c6,  // [B, N, H, W, C]
    const at::Tensor& sampling_loc,  // [B, Q, P, 3]
    // NOTE(review): annotated with a trailing dimension of 4 although five
    // feature tensors are passed; this may be a copy-paste from the c2345
    // variant -- confirm against the kernel.
    const at::Tensor& attn_weight  // [B, Q, P, 4]
);