Kernels

danieldk committed
Commit ba38f49 · 1 Parent(s): 2c8e21f

Remove source
README.md CHANGED
@@ -7,3 +7,5 @@ tags:
  ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/deformable-detr)

  ## deformable-detr
+
+ Kernel source: https://github.com/huggingface/kernels-community/tree/main/deformable-detr
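With the source removed from the repository, the compiled kernel is fetched from the Hub at load time. Below is a minimal usage sketch with the `kernels` library; the exported op name `ms_deform_attn_forward` and its argument order are assumptions inferred from the CUDA entry points deleted below, since `torch-ext/torch_binding.cpp` is not shown in this diff.

```python
# Minimal sketch: loading this kernel from the Hub with the `kernels` library.
# ASSUMPTION: the binding exports `ms_deform_attn_forward` with the same
# argument order as ms_deform_attn_cuda_forward below; check the kernel
# source's torch_binding.cpp for the actual interface.
import torch
from kernels import get_kernel

ops = get_kernel("kernels-community/deformable-detr")

batch, heads, channels = 2, 8, 32
queries, points = 16, 4
spatial_shapes = torch.tensor([[8, 8]], dtype=torch.int64, device="cuda")  # one 8x8 level
level_start_index = torch.tensor([0], dtype=torch.int64, device="cuda")
spatial = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())

value = torch.rand(batch, spatial, heads, channels, device="cuda")
sampling_loc = torch.rand(batch, queries, heads, 1, points, 2, device="cuda")
attn_weight = torch.rand(batch, queries, heads, 1, points, device="cuda")

out = ops.ms_deform_attn_forward(
    value, spatial_shapes, level_start_index, sampling_loc, attn_weight, 64
)
print(out.shape)  # expected: (batch, queries, heads * channels)
```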
build.toml DELETED
@@ -1,20 +0,0 @@
- [general]
- name = "deformable_detr"
- universal = false
-
- [torch]
- src = [
-   "torch-ext/torch_binding.cpp",
-   "torch-ext/torch_binding.h",
- ]
-
- [kernel.activation]
- backend = "cuda"
- depends = ["torch"]
- include = ["."]
- src = [
-   "deformable_detr/ms_deform_attn_cuda.cu",
-   "deformable_detr/ms_deform_im2col_cuda.cuh",
-   "deformable_detr/ms_deform_attn_cuda.cuh",
-   "deformable_detr/ms_deform_attn_cuda.h",
- ]
deformable_detr/ms_deform_attn_cuda.cu DELETED
@@ -1,158 +0,0 @@
- /*!
- **************************************************************************************************
- * Deformable DETR
- * Copyright (c) 2020 SenseTime. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
- **************************************************************************************************
- * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
- **************************************************************************************************
- */
-
- #include <vector>
- #include "deformable_detr/ms_deform_im2col_cuda.cuh"
-
- #include <ATen/ATen.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <cuda.h>
- #include <cuda_runtime.h>
-
- #include <torch/all.h>
-
-
- at::Tensor ms_deform_attn_cuda_forward(
-     const at::Tensor &value,
-     const at::Tensor &spatial_shapes,
-     const at::Tensor &level_start_index,
-     const at::Tensor &sampling_loc,
-     const at::Tensor &attn_weight,
-     const int64_t im2col_step)
- {
-     at::DeviceGuard guard(value.device());
-
-     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
-     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
-     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
-     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
-     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-
-     AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
-     AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
-     AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
-     AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
-     AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
-
-     const int batch = value.size(0);
-     const int spatial_size = value.size(1);
-     const int num_heads = value.size(2);
-     const int channels = value.size(3);
-
-     const int num_levels = spatial_shapes.size(0);
-
-     const int num_query = sampling_loc.size(1);
-     const int num_point = sampling_loc.size(4);
-
-     const int im2col_step_ = std::min(batch, static_cast<int>(im2col_step));
-
-     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-
-     auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
-
-     const int batch_n = im2col_step_;
-     auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
-     auto per_value_size = spatial_size * num_heads * channels;
-     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
-     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
-     for (int n = 0; n < batch/im2col_step_; ++n)
-     {
-         auto columns = output_n.select(0, n);
-         AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
-             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
-                 value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
-                 spatial_shapes.data_ptr<int64_t>(),
-                 level_start_index.data_ptr<int64_t>(),
-                 sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                 attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
-                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                 columns.data_ptr<scalar_t>());
-
-         }));
-     }
-
-     output = output.view({batch, num_query, num_heads*channels});
-
-     return output;
- }
-
-
- std::vector<at::Tensor> ms_deform_attn_cuda_backward(
-     const at::Tensor &value,
-     const at::Tensor &spatial_shapes,
-     const at::Tensor &level_start_index,
-     const at::Tensor &sampling_loc,
-     const at::Tensor &attn_weight,
-     const at::Tensor &grad_output,
-     const int64_t im2col_step)
- {
-     at::DeviceGuard guard(value.device());
-
-     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
-     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
-     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
-     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
-     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-     AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
-
-     AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
-     AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
-     AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
-     AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
-     AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
-     AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
-
-     const int batch = value.size(0);
-     const int spatial_size = value.size(1);
-     const int num_heads = value.size(2);
-     const int channels = value.size(3);
-
-     const int num_levels = spatial_shapes.size(0);
-
-     const int num_query = sampling_loc.size(1);
-     const int num_point = sampling_loc.size(4);
-
-     const int im2col_step_ = std::min(batch, static_cast<int>(im2col_step));
-
-     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-
-     auto grad_value = at::zeros_like(value);
-     auto grad_sampling_loc = at::zeros_like(sampling_loc);
-     auto grad_attn_weight = at::zeros_like(attn_weight);
-
-     const int batch_n = im2col_step_;
-     auto per_value_size = spatial_size * num_heads * channels;
-     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
-     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
-     auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
-
-     for (int n = 0; n < batch/im2col_step_; ++n)
-     {
-         auto grad_output_g = grad_output_n.select(0, n);
-         AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
-             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
-                 grad_output_g.data_ptr<scalar_t>(),
-                 value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
-                 spatial_shapes.data_ptr<int64_t>(),
-                 level_start_index.data_ptr<int64_t>(),
-                 sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                 attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
-                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                 grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
-                 grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                 grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
-
-         }));
-     }
-
-     return {
-         grad_value, grad_sampling_loc, grad_attn_weight
-     };
- }
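For orientation, here is what the deleted forward kernel computes, as a slow pure-PyTorch sketch adapted from the reference implementation in the original Deformable DETR codebase. Shapes follow `ms_deform_attn_cuda_forward` above; this sketch is illustrative and not part of the repository. The `-0.5` pixel offset in the CUDA kernel corresponds to `align_corners=False` in `grid_sample` here.

```python
# Hedged sketch: pure-PyTorch reference of multi-scale deformable attention,
# matching the tensor shapes used by ms_deform_attn_cuda_forward above.
# Slow, but handy for checking the CUDA kernel's output numerically.
import torch
import torch.nn.functional as F

def ms_deform_attn_pytorch(value, spatial_shapes, sampling_loc, attn_weight):
    # value:         (bs, spatial_size, num_heads, channels)
    # spatial_shapes:(num_levels, 2) as (H, W) per feature level
    # sampling_loc:  (bs, num_query, num_heads, num_levels, num_point, 2) in [0, 1]
    # attn_weight:   (bs, num_query, num_heads, num_levels, num_point)
    bs, _, num_heads, channels = value.shape
    _, num_query, _, num_levels, num_point, _ = sampling_loc.shape
    value_list = value.split([int(h * w) for h, w in spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_loc - 1  # grid_sample expects [-1, 1] coords
    sampled = []
    for lvl, (h_l, w_l) in enumerate(spatial_shapes):
        # (bs, H*W, heads, ch) -> (bs*heads, ch, H, W)
        v = (value_list[lvl].flatten(2).transpose(1, 2)
             .reshape(bs * num_heads, channels, int(h_l), int(w_l)))
        # (bs, nq, heads, points, 2) -> (bs*heads, nq, points, 2)
        g = sampling_grids[:, :, :, lvl].transpose(1, 2).flatten(0, 1)
        # bilinear sampling with zero padding, as in the CUDA bilinear helper
        sampled.append(F.grid_sample(v, g, mode="bilinear",
                                     padding_mode="zeros", align_corners=False))
    # (bs*heads, ch, nq, levels, points) -> (bs*heads, ch, nq, levels*points)
    out = torch.stack(sampled, dim=-2).flatten(-2)
    w = attn_weight.transpose(1, 2).reshape(bs * num_heads, 1,
                                            num_query, num_levels * num_point)
    # weighted sum over all sampling points, then back to (bs, nq, heads*ch)
    return ((out * w).sum(-1)
            .view(bs, num_heads * channels, num_query)
            .transpose(1, 2).contiguous())
```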
deformable_detr/ms_deform_attn_cuda.cuh DELETED
@@ -1,1467 +0,0 @@
- /*!
- **************************************************************************************************
- * Deformable DETR
- * Copyright (c) 2020 SenseTime. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
- **************************************************************************************************
- * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
- **************************************************************************************************
- */
-
- #include <vector>
-
- #include <cuda.h>
- #include <cuda_runtime.h>
-
- #include <cstdio>
- #include <algorithm>
- #include <cstring>
-
- #include <ATen/ATen.h>
- #include <ATen/cuda/CUDAContext.h>
-
- #include <THC/THCAtomics.cuh>
-
- #define CUDA_KERNEL_LOOP(i, n)                          \
-     for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
-          i < (n);                                       \
-          i += blockDim.x * gridDim.x)
-
-
- at::Tensor ms_deform_attn_cuda_forward(
-     const at::Tensor &value,
-     const at::Tensor &spatial_shapes,
-     const at::Tensor &level_start_index,
-     const at::Tensor &sampling_loc,
-     const at::Tensor &attn_weight,
-     const int im2col_step)
- {
-     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
-     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
-     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
-     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
-     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-
-     AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
-     AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
-     AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
-     AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
-     AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
-
-     const int batch = value.size(0);
-     const int spatial_size = value.size(1);
-     const int num_heads = value.size(2);
-     const int channels = value.size(3);
-
-     const int num_levels = spatial_shapes.size(0);
-
-     const int num_query = sampling_loc.size(1);
-     const int num_point = sampling_loc.size(4);
-
-     const int im2col_step_ = std::min(batch, im2col_step);
-
-     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-
-     auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
-
-     const int batch_n = im2col_step_;
-     auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
-     auto per_value_size = spatial_size * num_heads * channels;
-     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
-     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
-     for (int n = 0; n < batch/im2col_step_; ++n)
-     {
-         auto columns = output_n.select(0, n);
-         AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
-             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
-                 value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
-                 spatial_shapes.data_ptr<int64_t>(),
-                 level_start_index.data_ptr<int64_t>(),
-                 sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                 attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
-                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                 columns.data_ptr<scalar_t>());
-
-         }));
-     }
-
-     output = output.view({batch, num_query, num_heads*channels});
-
-     return output;
- }
-
-
- std::vector<at::Tensor> ms_deform_attn_cuda_backward(
-     const at::Tensor &value,
-     const at::Tensor &spatial_shapes,
-     const at::Tensor &level_start_index,
-     const at::Tensor &sampling_loc,
-     const at::Tensor &attn_weight,
-     const at::Tensor &grad_output,
-     const int im2col_step)
- {
-
-     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
-     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
-     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
-     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
-     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-     AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
-
-     AT_ASSERTM(value.is_cuda(), "value must be a CUDA tensor");
-     AT_ASSERTM(spatial_shapes.is_cuda(), "spatial_shapes must be a CUDA tensor");
-     AT_ASSERTM(level_start_index.is_cuda(), "level_start_index must be a CUDA tensor");
-     AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
-     AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
-     AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
-
-     const int batch = value.size(0);
-     const int spatial_size = value.size(1);
-     const int num_heads = value.size(2);
-     const int channels = value.size(3);
-
-     const int num_levels = spatial_shapes.size(0);
-
-     const int num_query = sampling_loc.size(1);
-     const int num_point = sampling_loc.size(4);
-
-     const int im2col_step_ = std::min(batch, im2col_step);
-
-     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
-
-     auto grad_value = at::zeros_like(value);
-     auto grad_sampling_loc = at::zeros_like(sampling_loc);
-     auto grad_attn_weight = at::zeros_like(attn_weight);
-
-     const int batch_n = im2col_step_;
-     auto per_value_size = spatial_size * num_heads * channels;
-     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
-     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
-     auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
-
-     for (int n = 0; n < batch/im2col_step_; ++n)
-     {
-         auto grad_output_g = grad_output_n.select(0, n);
-         AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
-             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
-                 grad_output_g.data_ptr<scalar_t>(),
-                 value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
-                 spatial_shapes.data_ptr<int64_t>(),
-                 level_start_index.data_ptr<int64_t>(),
-                 sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                 attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
-                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                 grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
-                 grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                 grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
-
-         }));
-     }
-
-     return {
-         grad_value, grad_sampling_loc, grad_attn_weight
-     };
- }
-
- const int CUDA_NUM_THREADS = 1024;
- inline int GET_BLOCKS(const int N, const int num_threads)
- {
-     return (N + num_threads - 1) / num_threads;
- }
-
-
- template <typename scalar_t>
- __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
-     const int &height, const int &width, const int &nheads, const int &channels,
-     const scalar_t &h, const scalar_t &w, const int &m, const int &c)
- {
-     const int h_low = floor(h);
-     const int w_low = floor(w);
-     const int h_high = h_low + 1;
-     const int w_high = w_low + 1;
-
-     const scalar_t lh = h - h_low;
-     const scalar_t lw = w - w_low;
-     const scalar_t hh = 1 - lh, hw = 1 - lw;
-
-     const int w_stride = nheads * channels;
-     const int h_stride = width * w_stride;
-     const int h_low_ptr_offset = h_low * h_stride;
-     const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
-     const int w_low_ptr_offset = w_low * w_stride;
-     const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
-     const int base_ptr = m * channels + c;
-
-     scalar_t v1 = 0;
-     if (h_low >= 0 && w_low >= 0)
-     {
-         const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
-         v1 = bottom_data[ptr1];
-     }
-     scalar_t v2 = 0;
-     if (h_low >= 0 && w_high <= width - 1)
-     {
-         const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
-         v2 = bottom_data[ptr2];
-     }
-     scalar_t v3 = 0;
-     if (h_high <= height - 1 && w_low >= 0)
-     {
-         const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
-         v3 = bottom_data[ptr3];
-     }
-     scalar_t v4 = 0;
-     if (h_high <= height - 1 && w_high <= width - 1)
-     {
-         const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
-         v4 = bottom_data[ptr4];
-     }
-
-     const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
-
-     const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
-     return val;
- }
-
-
- template <typename scalar_t>
- __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
-     const int &height, const int &width, const int &nheads, const int &channels,
-     const scalar_t &h, const scalar_t &w, const int &m, const int &c,
-     const scalar_t &top_grad,
-     const scalar_t &attn_weight,
-     scalar_t* &grad_value,
-     scalar_t* grad_sampling_loc,
-     scalar_t* grad_attn_weight)
- {
-     const int h_low = floor(h);
-     const int w_low = floor(w);
-     const int h_high = h_low + 1;
-     const int w_high = w_low + 1;
-
-     const scalar_t lh = h - h_low;
-     const scalar_t lw = w - w_low;
-     const scalar_t hh = 1 - lh, hw = 1 - lw;
-
-     const int w_stride = nheads * channels;
-     const int h_stride = width * w_stride;
-     const int h_low_ptr_offset = h_low * h_stride;
-     const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
-     const int w_low_ptr_offset = w_low * w_stride;
-     const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
-     const int base_ptr = m * channels + c;
-
-     const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
-     const scalar_t top_grad_value = top_grad * attn_weight;
-     scalar_t grad_h_weight = 0, grad_w_weight = 0;
-
-     scalar_t v1 = 0;
-     if (h_low >= 0 && w_low >= 0)
-     {
-         const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
-         v1 = bottom_data[ptr1];
-         grad_h_weight -= hw * v1;
-         grad_w_weight -= hh * v1;
-         atomicAdd(grad_value+ptr1, w1*top_grad_value);
-     }
-     scalar_t v2 = 0;
-     if (h_low >= 0 && w_high <= width - 1)
-     {
-         const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
-         v2 = bottom_data[ptr2];
-         grad_h_weight -= lw * v2;
-         grad_w_weight += hh * v2;
-         atomicAdd(grad_value+ptr2, w2*top_grad_value);
-     }
-     scalar_t v3 = 0;
-     if (h_high <= height - 1 && w_low >= 0)
-     {
-         const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
-         v3 = bottom_data[ptr3];
-         grad_h_weight += hw * v3;
-         grad_w_weight -= lh * v3;
-         atomicAdd(grad_value+ptr3, w3*top_grad_value);
-     }
-     scalar_t v4 = 0;
-     if (h_high <= height - 1 && w_high <= width - 1)
-     {
-         const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
-         v4 = bottom_data[ptr4];
-         grad_h_weight += lw * v4;
-         grad_w_weight += lh * v4;
-         atomicAdd(grad_value+ptr4, w4*top_grad_value);
-     }
-
-     const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
-     *grad_attn_weight = top_grad * val;
-     *grad_sampling_loc = width * grad_w_weight * top_grad_value;
-     *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
- }
-
-
- template <typename scalar_t>
- __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
-     const int &height, const int &width, const int &nheads, const int &channels,
-     const scalar_t &h, const scalar_t &w, const int &m, const int &c,
-     const scalar_t &top_grad,
-     const scalar_t &attn_weight,
-     scalar_t* &grad_value,
-     scalar_t* grad_sampling_loc,
-     scalar_t* grad_attn_weight)
- {
-     const int h_low = floor(h);
-     const int w_low = floor(w);
-     const int h_high = h_low + 1;
-     const int w_high = w_low + 1;
-
-     const scalar_t lh = h - h_low;
-     const scalar_t lw = w - w_low;
-     const scalar_t hh = 1 - lh, hw = 1 - lw;
-
-     const int w_stride = nheads * channels;
-     const int h_stride = width * w_stride;
-     const int h_low_ptr_offset = h_low * h_stride;
-     const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
-     const int w_low_ptr_offset = w_low * w_stride;
-     const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
-     const int base_ptr = m * channels + c;
-
-     const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
-     const scalar_t top_grad_value = top_grad * attn_weight;
-     scalar_t grad_h_weight = 0, grad_w_weight = 0;
-
-     scalar_t v1 = 0;
-     if (h_low >= 0 && w_low >= 0)
-     {
-         const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
-         v1 = bottom_data[ptr1];
-         grad_h_weight -= hw * v1;
-         grad_w_weight -= hh * v1;
-         atomicAdd(grad_value+ptr1, w1*top_grad_value);
-     }
-     scalar_t v2 = 0;
-     if (h_low >= 0 && w_high <= width - 1)
-     {
-         const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
-         v2 = bottom_data[ptr2];
-         grad_h_weight -= lw * v2;
-         grad_w_weight += hh * v2;
-         atomicAdd(grad_value+ptr2, w2*top_grad_value);
-     }
-     scalar_t v3 = 0;
-     if (h_high <= height - 1 && w_low >= 0)
-     {
-         const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
-         v3 = bottom_data[ptr3];
-         grad_h_weight += hw * v3;
-         grad_w_weight -= lh * v3;
-         atomicAdd(grad_value+ptr3, w3*top_grad_value);
-     }
-     scalar_t v4 = 0;
-     if (h_high <= height - 1 && w_high <= width - 1)
-     {
-         const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
-         v4 = bottom_data[ptr4];
-         grad_h_weight += lw * v4;
-         grad_w_weight += lh * v4;
-         atomicAdd(grad_value+ptr4, w4*top_grad_value);
-     }
-
-     const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
-     atomicAdd(grad_attn_weight, top_grad * val);
-     atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
-     atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
- }
-
-
- template <typename scalar_t>
- __global__ void ms_deformable_im2col_gpu_kernel(const int n,
-     const scalar_t *data_value,
-     const int64_t *data_spatial_shapes,
-     const int64_t *data_level_start_index,
-     const scalar_t *data_sampling_loc,
-     const scalar_t *data_attn_weight,
-     const int batch_size,
-     const int spatial_size,
-     const int num_heads,
-     const int channels,
-     const int num_levels,
-     const int num_query,
-     const int num_point,
-     scalar_t *data_col)
- {
-     CUDA_KERNEL_LOOP(index, n)
-     {
-         int _temp = index;
-         const int c_col = _temp % channels;
-         _temp /= channels;
-         const int sampling_index = _temp;
-         const int m_col = _temp % num_heads;
-         _temp /= num_heads;
-         [[maybe_unused]] const int q_col = _temp % num_query;
-         _temp /= num_query;
-         const int b_col = _temp;
-
-         scalar_t *data_col_ptr = data_col + index;
-         int data_weight_ptr = sampling_index * num_levels * num_point;
-         int data_loc_w_ptr = data_weight_ptr << 1;
-         const int qid_stride = num_heads * channels;
-         const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-         scalar_t col = 0;
-
-         for (int l_col=0; l_col < num_levels; ++l_col)
-         {
-             const int level_start_id = data_level_start_index[l_col];
-             const int spatial_h_ptr = l_col << 1;
-             const int spatial_h = data_spatial_shapes[spatial_h_ptr];
-             const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
-             const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
-             for (int p_col=0; p_col < num_point; ++p_col)
-             {
-                 const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
-                 const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
-                 const scalar_t weight = data_attn_weight[data_weight_ptr];
-
-                 const scalar_t h_im = loc_h * spatial_h - 0.5;
-                 const scalar_t w_im = loc_w * spatial_w - 0.5;
-
-                 if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
-                 {
-                     col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
-                 }
-
-                 data_weight_ptr += 1;
-                 data_loc_w_ptr += 2;
-             }
-         }
-         *data_col_ptr = col;
-     }
- }
-
- template <typename scalar_t, unsigned int blockSize>
- __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
-     const scalar_t *grad_col,
-     const scalar_t *data_value,
-     const int64_t *data_spatial_shapes,
-     const int64_t *data_level_start_index,
-     const scalar_t *data_sampling_loc,
-     const scalar_t *data_attn_weight,
-     const int batch_size,
-     const int spatial_size,
-     const int num_heads,
-     const int channels,
-     const int num_levels,
-     const int num_query,
-     const int num_point,
-     scalar_t *grad_value,
-     scalar_t *grad_sampling_loc,
-     scalar_t *grad_attn_weight)
- {
-     CUDA_KERNEL_LOOP(index, n)
-     {
-         __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
-         __shared__ scalar_t cache_grad_attn_weight[blockSize];
-         unsigned int tid = threadIdx.x;
-         int _temp = index;
-         const int c_col = _temp % channels;
-         _temp /= channels;
-         const int sampling_index = _temp;
-         const int m_col = _temp % num_heads;
-         _temp /= num_heads;
-         [[maybe_unused]] const int q_col = _temp % num_query;
-         _temp /= num_query;
-         const int b_col = _temp;
-
-         const scalar_t top_grad = grad_col[index];
-
-         int data_weight_ptr = sampling_index * num_levels * num_point;
-         int data_loc_w_ptr = data_weight_ptr << 1;
-         const int grad_sampling_ptr = data_weight_ptr;
-         grad_sampling_loc += grad_sampling_ptr << 1;
-         grad_attn_weight += grad_sampling_ptr;
-         const int grad_weight_stride = 1;
-         const int grad_loc_stride = 2;
-         const int qid_stride = num_heads * channels;
-         const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
-         for (int l_col=0; l_col < num_levels; ++l_col)
-         {
-             const int level_start_id = data_level_start_index[l_col];
-             const int spatial_h_ptr = l_col << 1;
-             const int spatial_h = data_spatial_shapes[spatial_h_ptr];
-             const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
-             const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
-             const scalar_t *data_value_ptr = data_value + value_ptr_offset;
-             scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
-             for (int p_col=0; p_col < num_point; ++p_col)
-             {
-                 const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
-                 const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
-                 const scalar_t weight = data_attn_weight[data_weight_ptr];
-
-                 const scalar_t h_im = loc_h * spatial_h - 0.5;
-                 const scalar_t w_im = loc_w * spatial_w - 0.5;
-                 *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
-                 *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
-                 *(cache_grad_attn_weight+threadIdx.x)=0;
-                 if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
-                 {
-                     ms_deform_attn_col2im_bilinear(
-                         data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
-                         top_grad, weight, grad_value_ptr,
-                         cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
-                 }
-
-                 __syncthreads();
-                 if (tid == 0)
-                 {
-                     scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
-                     int sid=2;
-                     for (unsigned int tid = 1; tid < blockSize; ++tid)
-                     {
-                         _grad_w += cache_grad_sampling_loc[sid];
-                         _grad_h += cache_grad_sampling_loc[sid + 1];
-                         _grad_a += cache_grad_attn_weight[tid];
-                         sid += 2;
-                     }
-
-
-                     *grad_sampling_loc = _grad_w;
-                     *(grad_sampling_loc + 1) = _grad_h;
-                     *grad_attn_weight = _grad_a;
-                 }
-                 __syncthreads();
-
-                 data_weight_ptr += 1;
-                 data_loc_w_ptr += 2;
-                 grad_attn_weight += grad_weight_stride;
-                 grad_sampling_loc += grad_loc_stride;
-             }
-         }
-     }
- }
-
-
- template <typename scalar_t, unsigned int blockSize>
- __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
-     const scalar_t *grad_col,
-     const scalar_t *data_value,
-     const int64_t *data_spatial_shapes,
-     const int64_t *data_level_start_index,
-     const scalar_t *data_sampling_loc,
-     const scalar_t *data_attn_weight,
-     const int batch_size,
-     const int spatial_size,
-     const int num_heads,
-     const int channels,
-     const int num_levels,
-     const int num_query,
-     const int num_point,
-     scalar_t *grad_value,
-     scalar_t *grad_sampling_loc,
-     scalar_t *grad_attn_weight)
- {
-     CUDA_KERNEL_LOOP(index, n)
-     {
-         __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
-         __shared__ scalar_t cache_grad_attn_weight[blockSize];
-         unsigned int tid = threadIdx.x;
-         int _temp = index;
-         const int c_col = _temp % channels;
-         _temp /= channels;
-         const int sampling_index = _temp;
-         const int m_col = _temp % num_heads;
-         _temp /= num_heads;
-         [[maybe_unused]] const int q_col = _temp % num_query;
-         _temp /= num_query;
-         const int b_col = _temp;
-
-         const scalar_t top_grad = grad_col[index];
-
-         int data_weight_ptr = sampling_index * num_levels * num_point;
-         int data_loc_w_ptr = data_weight_ptr << 1;
-         const int grad_sampling_ptr = data_weight_ptr;
-         grad_sampling_loc += grad_sampling_ptr << 1;
-         grad_attn_weight += grad_sampling_ptr;
-         const int grad_weight_stride = 1;
-         const int grad_loc_stride = 2;
-         const int qid_stride = num_heads * channels;
-         const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
-         for (int l_col=0; l_col < num_levels; ++l_col)
-         {
-             const int level_start_id = data_level_start_index[l_col];
-             const int spatial_h_ptr = l_col << 1;
-             const int spatial_h = data_spatial_shapes[spatial_h_ptr];
-             const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
-             const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
-             const scalar_t *data_value_ptr = data_value + value_ptr_offset;
-             scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
-             for (int p_col=0; p_col < num_point; ++p_col)
-             {
-                 const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
-                 const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
-                 const scalar_t weight = data_attn_weight[data_weight_ptr];
-
-                 const scalar_t h_im = loc_h * spatial_h - 0.5;
-                 const scalar_t w_im = loc_w * spatial_w - 0.5;
-                 *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
-                 *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
-                 *(cache_grad_attn_weight+threadIdx.x)=0;
-                 if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
-                 {
-                     ms_deform_attn_col2im_bilinear(
-                         data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
-                         top_grad, weight, grad_value_ptr,
-                         cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
-                 }
-
-                 __syncthreads();
-
-                 for (unsigned int s=blockSize/2; s>0; s>>=1)
-                 {
-                     if (tid < s) {
-                         const unsigned int xid1 = tid << 1;
-                         const unsigned int xid2 = (tid + s) << 1;
-                         cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
-                         cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
-                         cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
-                     }
-                     __syncthreads();
-                 }
-
-                 if (tid == 0)
-                 {
-                     *grad_sampling_loc = cache_grad_sampling_loc[0];
-                     *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
-                     *grad_attn_weight = cache_grad_attn_weight[0];
-                 }
-                 __syncthreads();
-
-                 data_weight_ptr += 1;
-                 data_loc_w_ptr += 2;
-                 grad_attn_weight += grad_weight_stride;
-                 grad_sampling_loc += grad_loc_stride;
-             }
-         }
-     }
- }
-
-
- template <typename scalar_t>
- __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
-     const scalar_t *grad_col,
-     const scalar_t *data_value,
-     const int64_t *data_spatial_shapes,
-     const int64_t *data_level_start_index,
-     const scalar_t *data_sampling_loc,
-     const scalar_t *data_attn_weight,
-     const int batch_size,
-     const int spatial_size,
-     const int num_heads,
-     const int channels,
-     const int num_levels,
-     const int num_query,
-     const int num_point,
-     scalar_t *grad_value,
-     scalar_t *grad_sampling_loc,
-     scalar_t *grad_attn_weight)
- {
-     CUDA_KERNEL_LOOP(index, n)
-     {
-         extern __shared__ int _s[];
-         scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
-         scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-         unsigned int tid = threadIdx.x;
-         int _temp = index;
-         const int c_col = _temp % channels;
-         _temp /= channels;
-         const int sampling_index = _temp;
-         const int m_col = _temp % num_heads;
-         _temp /= num_heads;
-         [[maybe_unused]] const int q_col = _temp % num_query;
-         _temp /= num_query;
-         const int b_col = _temp;
-
-         const scalar_t top_grad = grad_col[index];
-
-         int data_weight_ptr = sampling_index * num_levels * num_point;
-         int data_loc_w_ptr = data_weight_ptr << 1;
-         const int grad_sampling_ptr = data_weight_ptr;
-         grad_sampling_loc += grad_sampling_ptr << 1;
-         grad_attn_weight += grad_sampling_ptr;
-         const int grad_weight_stride = 1;
-         const int grad_loc_stride = 2;
-         const int qid_stride = num_heads * channels;
-         const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
-         for (int l_col=0; l_col < num_levels; ++l_col)
-         {
-             const int level_start_id = data_level_start_index[l_col];
-             const int spatial_h_ptr = l_col << 1;
-             const int spatial_h = data_spatial_shapes[spatial_h_ptr];
-             const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
-             const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
-             const scalar_t *data_value_ptr = data_value + value_ptr_offset;
-             scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
-             for (int p_col=0; p_col < num_point; ++p_col)
-             {
-                 const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
-                 const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
-                 const scalar_t weight = data_attn_weight[data_weight_ptr];
-
-                 const scalar_t h_im = loc_h * spatial_h - 0.5;
-                 const scalar_t w_im = loc_w * spatial_w - 0.5;
-                 *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
-                 *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
-                 *(cache_grad_attn_weight+threadIdx.x)=0;
-                 if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
-                 {
-                     ms_deform_attn_col2im_bilinear(
-                         data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
-                         top_grad, weight, grad_value_ptr,
-                         cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
-                 }
-
-                 __syncthreads();
-                 if (tid == 0)
-                 {
-                     scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
-                     int sid=2;
-                     for (unsigned int tid = 1; tid < blockDim.x; ++tid)
-                     {
-                         _grad_w += cache_grad_sampling_loc[sid];
-                         _grad_h += cache_grad_sampling_loc[sid + 1];
-                         _grad_a += cache_grad_attn_weight[tid];
-                         sid += 2;
-                     }
-
-
-                     *grad_sampling_loc = _grad_w;
-                     *(grad_sampling_loc + 1) = _grad_h;
-                     *grad_attn_weight = _grad_a;
-                 }
-                 __syncthreads();
-
-                 data_weight_ptr += 1;
-                 data_loc_w_ptr += 2;
-                 grad_attn_weight += grad_weight_stride;
-                 grad_sampling_loc += grad_loc_stride;
-             }
-         }
-     }
- }
-
- template <typename scalar_t>
- __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
-     const scalar_t *grad_col,
-     const scalar_t *data_value,
-     const int64_t *data_spatial_shapes,
-     const int64_t *data_level_start_index,
-     const scalar_t *data_sampling_loc,
-     const scalar_t *data_attn_weight,
-     const int batch_size,
-     const int spatial_size,
-     const int num_heads,
-     const int channels,
-     const int num_levels,
-     const int num_query,
-     const int num_point,
-     scalar_t *grad_value,
-     scalar_t *grad_sampling_loc,
-     scalar_t *grad_attn_weight)
- {
-     CUDA_KERNEL_LOOP(index, n)
-     {
-         extern __shared__ int _s[];
-         scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
-         scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-         unsigned int tid = threadIdx.x;
-         int _temp = index;
-         const int c_col = _temp % channels;
-         _temp /= channels;
-         const int sampling_index = _temp;
-         const int m_col = _temp % num_heads;
-         _temp /= num_heads;
-         [[maybe_unused]] const int q_col = _temp % num_query;
-         _temp /= num_query;
-         const int b_col = _temp;
-
-         const scalar_t top_grad = grad_col[index];
-
-         int data_weight_ptr = sampling_index * num_levels * num_point;
-         int data_loc_w_ptr = data_weight_ptr << 1;
-         const int grad_sampling_ptr = data_weight_ptr;
-         grad_sampling_loc += grad_sampling_ptr << 1;
-         grad_attn_weight += grad_sampling_ptr;
-         const int grad_weight_stride = 1;
-         const int grad_loc_stride = 2;
-         const int qid_stride = num_heads * channels;
-         const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
-         for (int l_col=0; l_col < num_levels; ++l_col)
-         {
-             const int level_start_id = data_level_start_index[l_col];
-             const int spatial_h_ptr = l_col << 1;
-             const int spatial_h = data_spatial_shapes[spatial_h_ptr];
-             const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
-             const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
-             const scalar_t *data_value_ptr = data_value + value_ptr_offset;
-             scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
-             for (int p_col=0; p_col < num_point; ++p_col)
-             {
-                 const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
-                 const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
-                 const scalar_t weight = data_attn_weight[data_weight_ptr];
-
-                 const scalar_t h_im = loc_h * spatial_h - 0.5;
-                 const scalar_t w_im = loc_w * spatial_w - 0.5;
-                 *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
-                 *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
-                 *(cache_grad_attn_weight+threadIdx.x)=0;
-                 if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
-                 {
-                     ms_deform_attn_col2im_bilinear(
-                         data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
-                         top_grad, weight, grad_value_ptr,
-                         cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
-                 }
-
-                 __syncthreads();
-
-                 for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
-                 {
-                     if (tid < s) {
-                         const unsigned int xid1 = tid << 1;
-                         const unsigned int xid2 = (tid + s) << 1;
-                         cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
-                         cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
-                         cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
-                         if (tid + (s << 1) < spre)
-                         {
-                             cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
-                             cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
-                             cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
-                         }
-                     }
-                     __syncthreads();
-                 }
-
-                 if (tid == 0)
-                 {
-                     *grad_sampling_loc = cache_grad_sampling_loc[0];
-                     *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
-                     *grad_attn_weight = cache_grad_attn_weight[0];
-                 }
-                 __syncthreads();
-
-                 data_weight_ptr += 1;
-                 data_loc_w_ptr += 2;
-                 grad_attn_weight += grad_weight_stride;
-                 grad_sampling_loc += grad_loc_stride;
-             }
-         }
-     }
- }
-
- template <typename scalar_t>
- __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
-     const scalar_t *grad_col,
-     const scalar_t *data_value,
-     const int64_t *data_spatial_shapes,
-     const int64_t *data_level_start_index,
-     const scalar_t *data_sampling_loc,
-     const scalar_t *data_attn_weight,
-     const int batch_size,
-     const int spatial_size,
-     const int num_heads,
-     const int channels,
-     const int num_levels,
-     const int num_query,
-     const int num_point,
-     scalar_t *grad_value,
-     scalar_t *grad_sampling_loc,
-     scalar_t *grad_attn_weight)
- {
-     CUDA_KERNEL_LOOP(index, n)
-     {
-         extern __shared__ int _s[];
-         scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
-         scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
-         unsigned int tid = threadIdx.x;
-         int _temp = index;
-         const int c_col = _temp % channels;
-         _temp /= channels;
-         const int sampling_index = _temp;
-         const int m_col = _temp % num_heads;
-         _temp /= num_heads;
-         [[maybe_unused]] const int q_col = _temp % num_query;
-         _temp /= num_query;
-         const int b_col = _temp;
-
-         const scalar_t top_grad = grad_col[index];
-
-         int data_weight_ptr = sampling_index * num_levels * num_point;
-         int data_loc_w_ptr = data_weight_ptr << 1;
-         const int grad_sampling_ptr = data_weight_ptr;
-         grad_sampling_loc += grad_sampling_ptr << 1;
-         grad_attn_weight += grad_sampling_ptr;
-         const int grad_weight_stride = 1;
-         const int grad_loc_stride = 2;
-         const int qid_stride = num_heads * channels;
-         const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
-         for (int l_col=0; l_col < num_levels; ++l_col)
-         {
-             const int level_start_id = data_level_start_index[l_col];
-             const int spatial_h_ptr = l_col << 1;
-             const int spatial_h = data_spatial_shapes[spatial_h_ptr];
-             const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
-             const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
-             const scalar_t *data_value_ptr = data_value + value_ptr_offset;
-             scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
-             for (int p_col=0; p_col < num_point; ++p_col)
-             {
-                 const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
-                 const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
-                 const scalar_t weight = data_attn_weight[data_weight_ptr];
-
-                 const scalar_t h_im = loc_h * spatial_h - 0.5;
-                 const scalar_t w_im = loc_w * spatial_w - 0.5;
-                 *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
-                 *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
-                 *(cache_grad_attn_weight+threadIdx.x)=0;
-                 if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
-                 {
-                     ms_deform_attn_col2im_bilinear(
-                         data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
-                         top_grad, weight, grad_value_ptr,
-                         cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
-                 }
-
-                 __syncthreads();
-
-                 for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
-                 {
-                     if (tid < s) {
-                         const unsigned int xid1 = tid << 1;
-                         const unsigned int xid2 = (tid + s) << 1;
-                         cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
-                         cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
-                         cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
-                         if (tid + (s << 1) < spre)
-                         {
-                             cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
-                             cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
-                             cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
-                         }
-                     }
-                     __syncthreads();
-                 }
-
-                 if (tid == 0)
-                 {
-                     atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
-                     atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
-                     atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
-                 }
-                 __syncthreads();
-
-                 data_weight_ptr += 1;
-                 data_loc_w_ptr += 2;
-                 grad_attn_weight += grad_weight_stride;
-                 grad_sampling_loc += grad_loc_stride;
-             }
-         }
-     }
- }
-
-
- template <typename scalar_t>
- __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
-     const scalar_t *grad_col,
-     const scalar_t *data_value,
-     const int64_t *data_spatial_shapes,
-     const int64_t *data_level_start_index,
-     const scalar_t *data_sampling_loc,
-     const scalar_t *data_attn_weight,
-     const int batch_size,
-     const int spatial_size,
-     const int num_heads,
-     const int channels,
-     const int num_levels,
-     const int num_query,
-     const int num_point,
-     scalar_t *grad_value,
-     scalar_t *grad_sampling_loc,
-     scalar_t *grad_attn_weight)
- {
-     CUDA_KERNEL_LOOP(index, n)
-     {
-         int _temp = index;
-         const int c_col = _temp % channels;
-         _temp /= channels;
-         const int sampling_index = _temp;
-         const int m_col = _temp % num_heads;
-         _temp /= num_heads;
-         [[maybe_unused]] const int q_col = _temp % num_query;
-         _temp /= num_query;
-         const int b_col = _temp;
-
-         const scalar_t top_grad = grad_col[index];
-
-         int data_weight_ptr = sampling_index * num_levels * num_point;
-         int data_loc_w_ptr = data_weight_ptr << 1;
-         const int grad_sampling_ptr = data_weight_ptr;
-         grad_sampling_loc += grad_sampling_ptr << 1;
-         grad_attn_weight += grad_sampling_ptr;
-         const int grad_weight_stride = 1;
-         const int grad_loc_stride = 2;
-         const int qid_stride = num_heads * channels;
-         const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
-
-         for (int l_col=0; l_col < num_levels; ++l_col)
-         {
-             const int level_start_id = data_level_start_index[l_col];
-             const int spatial_h_ptr = l_col << 1;
-             const int spatial_h = data_spatial_shapes[spatial_h_ptr];
-             const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
-             const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
-             const scalar_t *data_value_ptr = data_value + value_ptr_offset;
-             scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
-
-             for (int p_col=0; p_col < num_point; ++p_col)
-             {
-                 const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
-                 const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
-                 const scalar_t weight = data_attn_weight[data_weight_ptr];
-
-                 const scalar_t h_im = loc_h * spatial_h - 0.5;
-                 const scalar_t w_im = loc_w * spatial_w - 0.5;
-                 if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
-                 {
-                     ms_deform_attn_col2im_bilinear_gm(
-                         data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
-                         top_grad, weight, grad_value_ptr,
-                         grad_sampling_loc, grad_attn_weight);
-                 }
-                 data_weight_ptr += 1;
-                 data_loc_w_ptr += 2;
-                 grad_attn_weight += grad_weight_stride;
-                 grad_sampling_loc += grad_loc_stride;
-             }
-         }
-     }
- }
-
-
- template <typename scalar_t>
- void ms_deformable_im2col_cuda(cudaStream_t stream,
-     const scalar_t* data_value,
-     const int64_t* data_spatial_shapes,
-     const int64_t* data_level_start_index,
-     const scalar_t* data_sampling_loc,
-     const scalar_t* data_attn_weight,
-     const int batch_size,
-     const int spatial_size,
-     const int num_heads,
-     const int channels,
-     const int num_levels,
-     const int num_query,
-     const int num_point,
-     scalar_t* data_col)
- {
-     const int num_kernels = batch_size * num_query * num_heads * channels;
-     const int num_actual_kernels = batch_size * num_query * num_heads * channels;
-     const int num_threads = CUDA_NUM_THREADS;
-     ms_deformable_im2col_gpu_kernel<scalar_t>
-         <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-             0, stream>>>(
-         num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
-         batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
-
-     cudaError_t err = cudaGetLastError();
-     if (err != cudaSuccess)
-     {
-         printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
-     }
-
- }
-
- template <typename scalar_t>
- void ms_deformable_col2im_cuda(cudaStream_t stream,
-     const scalar_t* grad_col,
-     const scalar_t* data_value,
-     const int64_t * data_spatial_shapes,
-     const int64_t * data_level_start_index,
-     const scalar_t * data_sampling_loc,
-     const scalar_t * data_attn_weight,
-     const int batch_size,
-     const int spatial_size,
-     const int num_heads,
-     const int channels,
-     const int num_levels,
-     const int num_query,
-     const int num_point,
-     scalar_t* grad_value,
-     scalar_t* grad_sampling_loc,
-     scalar_t* grad_attn_weight)
- {
-     const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
-     const int num_kernels = batch_size * num_query * num_heads * channels;
-     const int num_actual_kernels = batch_size * num_query * num_heads * channels;
-     if (channels > 1024)
-     {
-         if ((channels & 1023) == 0)
-         {
-             ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
-                 <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                     num_threads*3*sizeof(scalar_t), stream>>>(
-                 num_kernels,
-                 grad_col,
-                 data_value,
-                 data_spatial_shapes,
-                 data_level_start_index,
-                 data_sampling_loc,
-                 data_attn_weight,
-                 batch_size,
-                 spatial_size,
-                 num_heads,
-                 channels,
-                 num_levels,
-                 num_query,
-                 num_point,
-                 grad_value,
-                 grad_sampling_loc,
-                 grad_attn_weight);
-         }
-         else
-         {
-             ms_deformable_col2im_gpu_kernel_gm<scalar_t>
-                 <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                     0, stream>>>(
-                 num_kernels,
-                 grad_col,
-                 data_value,
-                 data_spatial_shapes,
-                 data_level_start_index,
-                 data_sampling_loc,
-                 data_attn_weight,
-                 batch_size,
-                 spatial_size,
-                 num_heads,
-                 channels,
-                 num_levels,
-                 num_query,
-                 num_point,
-                 grad_value,
-                 grad_sampling_loc,
-                 grad_attn_weight);
-         }
-     }
-     else{
-         switch(channels)
-         {
-             case 1:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             case 2:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             case 4:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             case 8:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             case 16:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             case 32:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             case 64:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             case 128:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             case 256:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             case 512:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             case 1024:
-                 ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
-                     <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                         0, stream>>>(
-                     num_kernels,
-                     grad_col,
-                     data_value,
-                     data_spatial_shapes,
-                     data_level_start_index,
-                     data_sampling_loc,
-                     data_attn_weight,
-                     batch_size,
-                     spatial_size,
-                     num_heads,
-                     channels,
-                     num_levels,
-                     num_query,
-                     num_point,
-                     grad_value,
-                     grad_sampling_loc,
-                     grad_attn_weight);
-                 break;
-             default:
-                 if (channels < 64)
-                 {
-                     ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
-                         <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
-                             num_threads*3*sizeof(scalar_t), stream>>>(
-                         num_kernels,
-                         grad_col,
-                         data_value,
-                         data_spatial_shapes,
-                         data_level_start_index,
1423
- data_sampling_loc,
1424
- data_attn_weight,
1425
- batch_size,
1426
- spatial_size,
1427
- num_heads,
1428
- channels,
1429
- num_levels,
1430
- num_query,
1431
- num_point,
1432
- grad_value,
1433
- grad_sampling_loc,
1434
- grad_attn_weight);
1435
- }
1436
- else
1437
- {
1438
- ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1439
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1440
- num_threads*3*sizeof(scalar_t), stream>>>(
1441
- num_kernels,
1442
- grad_col,
1443
- data_value,
1444
- data_spatial_shapes,
1445
- data_level_start_index,
1446
- data_sampling_loc,
1447
- data_attn_weight,
1448
- batch_size,
1449
- spatial_size,
1450
- num_heads,
1451
- channels,
1452
- num_levels,
1453
- num_query,
1454
- num_point,
1455
- grad_value,
1456
- grad_sampling_loc,
1457
- grad_attn_weight);
1458
- }
1459
- }
1460
- }
1461
- cudaError_t err = cudaGetLastError();
1462
- if (err != cudaSuccess)
1463
- {
1464
- printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1465
- }
1466
-
1467
- }
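The deleted CUDA sources above implement multi-scale deformable attention: bilinear sampling of a multi-level value map at predicted locations, weighted by per-point attention weights. For readers who want a reference point without the CUDA build, a pure-PyTorch sketch of the same forward computation follows. This is an illustration, not code from this repository: the function name is invented here, and `align_corners=False` is chosen on the assumption that it mirrors the kernels' `loc * spatial - 0.5` pixel-center convention.

```python
# Pure-PyTorch reference sketch for the forward pass of the deleted kernels.
import torch
import torch.nn.functional as F

def multi_scale_deformable_attention(value, value_spatial_shapes,
                                     sampling_locations, attention_weights):
    # value:                (bs, sum(H_l*W_l), num_heads, head_dim)
    # value_spatial_shapes: (num_levels, 2) rows of (H_l, W_l)
    # sampling_locations:   (bs, num_queries, num_heads, num_levels, num_points, 2) in [0, 1]
    # attention_weights:    (bs, num_queries, num_heads, num_levels, num_points)
    bs, _, num_heads, head_dim = value.shape
    _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([int(h) * int(w) for h, w in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1  # grid_sample expects [-1, 1]
    sampled = []
    for level, (h, w) in enumerate(value_spatial_shapes):
        # (bs, H*W, heads, dim) -> (bs*heads, dim, H, W)
        value_l = (value_list[level].flatten(2).transpose(1, 2)
                   .reshape(bs * num_heads, head_dim, int(h), int(w)))
        # (bs, queries, heads, points, 2) -> (bs*heads, queries, points, 2)
        grid_l = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # Out-of-range points read zeros, like the boundary checks in the kernels.
        sampled.append(F.grid_sample(value_l, grid_l, mode="bilinear",
                                     padding_mode="zeros", align_corners=False))
    # (bs*heads, 1, queries, levels*points)
    attention_weights = attention_weights.transpose(1, 2).reshape(
        bs * num_heads, 1, num_queries, num_levels * num_points)
    out = (torch.stack(sampled, dim=-2).flatten(-2) * attention_weights).sum(-1)
    return out.view(bs, num_heads * head_dim, num_queries).transpose(1, 2).contiguous()
```

Unlike the fused kernels, this version materializes per-level sampled tensors and leaves the backward pass to autograd, so it serves as a correctness reference rather than a performance substitute.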
deformable_detr/ms_deform_attn_cuda.h DELETED
@@ -1,46 +0,0 @@
- /*!
- **************************************************************************************************
- * Deformable DETR
- * Copyright (c) 2020 SenseTime. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
- **************************************************************************************************
- * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
- **************************************************************************************************
- */
- 
- #pragma once
- #include <torch/torch.h>
- 
- at::Tensor ms_deform_attn_cuda_forward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const int im2col_step);
- 
- at::Tensor ms_deform_attn_cuda_forward_bf16(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const int im2col_step);
- 
- std::vector<at::Tensor> ms_deform_attn_cuda_backward(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const at::Tensor &grad_output,
- const int im2col_step);
- 
- std::vector<at::Tensor> ms_deform_attn_cuda_backward_bf16(
- const at::Tensor &value,
- const at::Tensor &spatial_shapes,
- const at::Tensor &level_start_index,
- const at::Tensor &sampling_loc,
- const at::Tensor &attn_weight,
- const at::Tensor &grad_output,
- const int im2col_step);
deformable_detr/ms_deform_im2col_cuda.cuh DELETED
@@ -1,1327 +0,0 @@
- /*!
- **************************************************************************
- * Deformable DETR
- * Copyright (c) 2020 SenseTime. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
- **************************************************************************
- * Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
- * Copyright (c) 2018 Microsoft
- **************************************************************************
- */
- 
- #include <cstdio>
- #include <algorithm>
- #include <cstring>
- 
- #include <ATen/ATen.h>
- #include <ATen/cuda/CUDAContext.h>
- 
- #include <THC/THCAtomics.cuh>
- 
- #define CUDA_KERNEL_LOOP(i, n) \
- for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
- i < (n); \
- i += blockDim.x * gridDim.x)
- 
- const int CUDA_NUM_THREADS = 1024;
- inline int GET_BLOCKS(const int N, const int num_threads)
- {
- return (N + num_threads - 1) / num_threads;
- }
- 
- 
- template <typename scalar_t>
- __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
- const int &height, const int &width, const int &nheads, const int &channels,
- const scalar_t &h, const scalar_t &w, const int &m, const int &c)
- {
- const int h_low = floor(h);
- const int w_low = floor(w);
- const int h_high = h_low + 1;
- const int w_high = w_low + 1;
- 
- const scalar_t lh = h - h_low;
- const scalar_t lw = w - w_low;
- const scalar_t hh = 1 - lh, hw = 1 - lw;
- 
- const int w_stride = nheads * channels;
- const int h_stride = width * w_stride;
- const int h_low_ptr_offset = h_low * h_stride;
- const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
- const int w_low_ptr_offset = w_low * w_stride;
- const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
- const int base_ptr = m * channels + c;
- 
- scalar_t v1 = 0;
- if (h_low >= 0 && w_low >= 0)
- {
- const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
- v1 = bottom_data[ptr1];
- }
- scalar_t v2 = 0;
- if (h_low >= 0 && w_high <= width - 1)
- {
- const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
- v2 = bottom_data[ptr2];
- }
- scalar_t v3 = 0;
- if (h_high <= height - 1 && w_low >= 0)
- {
- const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
- v3 = bottom_data[ptr3];
- }
- scalar_t v4 = 0;
- if (h_high <= height - 1 && w_high <= width - 1)
- {
- const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
- v4 = bottom_data[ptr4];
- }
- 
- const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
- 
- const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
- return val;
- }
- 
- 
- template <typename scalar_t>
- __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
- const int &height, const int &width, const int &nheads, const int &channels,
- const scalar_t &h, const scalar_t &w, const int &m, const int &c,
- const scalar_t &top_grad,
- const scalar_t &attn_weight,
- scalar_t* &grad_value,
- scalar_t* grad_sampling_loc,
- scalar_t* grad_attn_weight)
- {
- const int h_low = floor(h);
- const int w_low = floor(w);
- const int h_high = h_low + 1;
- const int w_high = w_low + 1;
- 
- const scalar_t lh = h - h_low;
- const scalar_t lw = w - w_low;
- const scalar_t hh = 1 - lh, hw = 1 - lw;
- 
- const int w_stride = nheads * channels;
- const int h_stride = width * w_stride;
- const int h_low_ptr_offset = h_low * h_stride;
- const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
- const int w_low_ptr_offset = w_low * w_stride;
- const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
- const int base_ptr = m * channels + c;
- 
- const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
- const scalar_t top_grad_value = top_grad * attn_weight;
- scalar_t grad_h_weight = 0, grad_w_weight = 0;
- 
- scalar_t v1 = 0;
- if (h_low >= 0 && w_low >= 0)
- {
- const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
- v1 = bottom_data[ptr1];
- grad_h_weight -= hw * v1;
- grad_w_weight -= hh * v1;
- atomicAdd(grad_value+ptr1, w1*top_grad_value);
- }
- scalar_t v2 = 0;
- if (h_low >= 0 && w_high <= width - 1)
- {
- const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
- v2 = bottom_data[ptr2];
- grad_h_weight -= lw * v2;
- grad_w_weight += hh * v2;
- atomicAdd(grad_value+ptr2, w2*top_grad_value);
- }
- scalar_t v3 = 0;
- if (h_high <= height - 1 && w_low >= 0)
- {
- const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
- v3 = bottom_data[ptr3];
- grad_h_weight += hw * v3;
- grad_w_weight -= lh * v3;
- atomicAdd(grad_value+ptr3, w3*top_grad_value);
- }
- scalar_t v4 = 0;
- if (h_high <= height - 1 && w_high <= width - 1)
- {
- const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
- v4 = bottom_data[ptr4];
- grad_h_weight += lw * v4;
- grad_w_weight += lh * v4;
- atomicAdd(grad_value+ptr4, w4*top_grad_value);
- }
- 
- const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
- *grad_attn_weight = top_grad * val;
- *grad_sampling_loc = width * grad_w_weight * top_grad_value;
- *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
- }
- 
- 
- template <typename scalar_t>
- __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
- const int &height, const int &width, const int &nheads, const int &channels,
- const scalar_t &h, const scalar_t &w, const int &m, const int &c,
- const scalar_t &top_grad,
- const scalar_t &attn_weight,
- scalar_t* &grad_value,
- scalar_t* grad_sampling_loc,
- scalar_t* grad_attn_weight)
- {
- const int h_low = floor(h);
- const int w_low = floor(w);
- const int h_high = h_low + 1;
- const int w_high = w_low + 1;
- 
- const scalar_t lh = h - h_low;
- const scalar_t lw = w - w_low;
- const scalar_t hh = 1 - lh, hw = 1 - lw;
- 
- const int w_stride = nheads * channels;
- const int h_stride = width * w_stride;
- const int h_low_ptr_offset = h_low * h_stride;
- const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
- const int w_low_ptr_offset = w_low * w_stride;
- const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
- const int base_ptr = m * channels + c;
- 
- const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
- const scalar_t top_grad_value = top_grad * attn_weight;
- scalar_t grad_h_weight = 0, grad_w_weight = 0;
- 
- scalar_t v1 = 0;
- if (h_low >= 0 && w_low >= 0)
- {
- const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
- v1 = bottom_data[ptr1];
- grad_h_weight -= hw * v1;
- grad_w_weight -= hh * v1;
- atomicAdd(grad_value+ptr1, w1*top_grad_value);
- }
- scalar_t v2 = 0;
- if (h_low >= 0 && w_high <= width - 1)
- {
- const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
- v2 = bottom_data[ptr2];
- grad_h_weight -= lw * v2;
- grad_w_weight += hh * v2;
- atomicAdd(grad_value+ptr2, w2*top_grad_value);
- }
- scalar_t v3 = 0;
- if (h_high <= height - 1 && w_low >= 0)
- {
- const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
- v3 = bottom_data[ptr3];
- grad_h_weight += hw * v3;
- grad_w_weight -= lh * v3;
- atomicAdd(grad_value+ptr3, w3*top_grad_value);
- }
- scalar_t v4 = 0;
- if (h_high <= height - 1 && w_high <= width - 1)
- {
- const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
- v4 = bottom_data[ptr4];
- grad_h_weight += lw * v4;
- grad_w_weight += lh * v4;
- atomicAdd(grad_value+ptr4, w4*top_grad_value);
- }
- 
- const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
- atomicAdd(grad_attn_weight, top_grad * val);
- atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
- atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
- }
- 
- 
- template <typename scalar_t>
- __global__ void ms_deformable_im2col_gpu_kernel(const int n,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *data_col)
- {
- CUDA_KERNEL_LOOP(index, n)
- {
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- [[maybe_unused]] const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
- 
- scalar_t *data_col_ptr = data_col + index;
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
- scalar_t col = 0;
- 
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
- 
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- 
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
- }
- 
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- }
- }
- *data_col_ptr = col;
- }
- }
- 
- template <typename scalar_t, unsigned int blockSize>
- __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
- {
- CUDA_KERNEL_LOOP(index, n)
- {
- __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
- __shared__ scalar_t cache_grad_attn_weight[blockSize];
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- [[maybe_unused]] const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
- 
- const scalar_t top_grad = grad_col[index];
- 
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
- 
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
- 
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
- 
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
- 
- __syncthreads();
- if (tid == 0)
- {
- scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
- int sid=2;
- for (unsigned int tid = 1; tid < blockSize; ++tid)
- {
- _grad_w += cache_grad_sampling_loc[sid];
- _grad_h += cache_grad_sampling_loc[sid + 1];
- _grad_a += cache_grad_attn_weight[tid];
- sid += 2;
- }
- 
- 
- *grad_sampling_loc = _grad_w;
- *(grad_sampling_loc + 1) = _grad_h;
- *grad_attn_weight = _grad_a;
- }
- __syncthreads();
- 
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
- }
- 
- 
- template <typename scalar_t, unsigned int blockSize>
- __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
- {
- CUDA_KERNEL_LOOP(index, n)
- {
- __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
- __shared__ scalar_t cache_grad_attn_weight[blockSize];
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- [[maybe_unused]] const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
- 
- const scalar_t top_grad = grad_col[index];
- 
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
- 
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
- 
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
- 
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
- 
- __syncthreads();
- 
- for (unsigned int s=blockSize/2; s>0; s>>=1)
- {
- if (tid < s) {
- const unsigned int xid1 = tid << 1;
- const unsigned int xid2 = (tid + s) << 1;
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
- }
- __syncthreads();
- }
- 
- if (tid == 0)
- {
- *grad_sampling_loc = cache_grad_sampling_loc[0];
- *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
- *grad_attn_weight = cache_grad_attn_weight[0];
- }
- __syncthreads();
- 
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
- }
- 
- 
- template <typename scalar_t>
- __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
- {
- CUDA_KERNEL_LOOP(index, n)
- {
- extern __shared__ int _s[];
- scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
- scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- [[maybe_unused]] const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
- 
- const scalar_t top_grad = grad_col[index];
- 
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
- 
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
- 
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
- 
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
- 
- __syncthreads();
- if (tid == 0)
- {
- scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
- int sid=2;
- for (unsigned int tid = 1; tid < blockDim.x; ++tid)
- {
- _grad_w += cache_grad_sampling_loc[sid];
- _grad_h += cache_grad_sampling_loc[sid + 1];
- _grad_a += cache_grad_attn_weight[tid];
- sid += 2;
- }
- 
- 
- *grad_sampling_loc = _grad_w;
- *(grad_sampling_loc + 1) = _grad_h;
- *grad_attn_weight = _grad_a;
- }
- __syncthreads();
- 
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
- }
- 
- template <typename scalar_t>
- __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
- {
- CUDA_KERNEL_LOOP(index, n)
- {
- extern __shared__ int _s[];
- scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
- scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- [[maybe_unused]] const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
- 
- const scalar_t top_grad = grad_col[index];
- 
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
- 
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
- 
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
- 
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
- 
- __syncthreads();
- 
- for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
- {
- if (tid < s) {
- const unsigned int xid1 = tid << 1;
- const unsigned int xid2 = (tid + s) << 1;
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
- if (tid + (s << 1) < spre)
- {
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
- }
- }
- __syncthreads();
- }
- 
- if (tid == 0)
- {
- *grad_sampling_loc = cache_grad_sampling_loc[0];
- *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
- *grad_attn_weight = cache_grad_attn_weight[0];
- }
- __syncthreads();
- 
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
- }
- 
- template <typename scalar_t>
- __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
- {
- CUDA_KERNEL_LOOP(index, n)
- {
- extern __shared__ int _s[];
- scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
- scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
- unsigned int tid = threadIdx.x;
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- [[maybe_unused]] const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
- 
- const scalar_t top_grad = grad_col[index];
- 
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
- 
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
- 
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
- 
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
- *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
- *(cache_grad_attn_weight+threadIdx.x)=0;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
- }
- 
- __syncthreads();
- 
- for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
- {
- if (tid < s) {
- const unsigned int xid1 = tid << 1;
- const unsigned int xid2 = (tid + s) << 1;
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
- if (tid + (s << 1) < spre)
- {
- cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
- cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
- cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
- }
- }
- __syncthreads();
- }
- 
- if (tid == 0)
- {
- atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
- atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
- atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
- }
- __syncthreads();
- 
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
- }
- 
- 
- template <typename scalar_t>
- __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
- const scalar_t *grad_col,
- const scalar_t *data_value,
- const int64_t *data_spatial_shapes,
- const int64_t *data_level_start_index,
- const scalar_t *data_sampling_loc,
- const scalar_t *data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t *grad_value,
- scalar_t *grad_sampling_loc,
- scalar_t *grad_attn_weight)
- {
- CUDA_KERNEL_LOOP(index, n)
- {
- int _temp = index;
- const int c_col = _temp % channels;
- _temp /= channels;
- const int sampling_index = _temp;
- const int m_col = _temp % num_heads;
- _temp /= num_heads;
- [[maybe_unused]] const int q_col = _temp % num_query;
- _temp /= num_query;
- const int b_col = _temp;
- 
- const scalar_t top_grad = grad_col[index];
- 
- int data_weight_ptr = sampling_index * num_levels * num_point;
- int data_loc_w_ptr = data_weight_ptr << 1;
- const int grad_sampling_ptr = data_weight_ptr;
- grad_sampling_loc += grad_sampling_ptr << 1;
- grad_attn_weight += grad_sampling_ptr;
- const int grad_weight_stride = 1;
- const int grad_loc_stride = 2;
- const int qid_stride = num_heads * channels;
- const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
- 
- for (int l_col=0; l_col < num_levels; ++l_col)
- {
- const int level_start_id = data_level_start_index[l_col];
- const int spatial_h_ptr = l_col << 1;
- const int spatial_h = data_spatial_shapes[spatial_h_ptr];
- const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
- const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
- const scalar_t *data_value_ptr = data_value + value_ptr_offset;
- scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
- 
- for (int p_col=0; p_col < num_point; ++p_col)
- {
- const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
- const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
- const scalar_t weight = data_attn_weight[data_weight_ptr];
- 
- const scalar_t h_im = loc_h * spatial_h - 0.5;
- const scalar_t w_im = loc_w * spatial_w - 0.5;
- if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
- {
- ms_deform_attn_col2im_bilinear_gm(
- data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
- top_grad, weight, grad_value_ptr,
- grad_sampling_loc, grad_attn_weight);
- }
- data_weight_ptr += 1;
- data_loc_w_ptr += 2;
- grad_attn_weight += grad_weight_stride;
- grad_sampling_loc += grad_loc_stride;
- }
- }
- }
- }
- 
- 
- template <typename scalar_t>
- void ms_deformable_im2col_cuda(cudaStream_t stream,
- const scalar_t* data_value,
- const int64_t* data_spatial_shapes,
- const int64_t* data_level_start_index,
- const scalar_t* data_sampling_loc,
- const scalar_t* data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t* data_col)
- {
- const int num_kernels = batch_size * num_query * num_heads * channels;
- const int num_actual_kernels = batch_size * num_query * num_heads * channels;
- const int num_threads = CUDA_NUM_THREADS;
- ms_deformable_im2col_gpu_kernel<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
- batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
- 
- cudaError_t err = cudaGetLastError();
- if (err != cudaSuccess)
- {
- printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
- }
- 
- }
- 
- template <typename scalar_t>
- void ms_deformable_col2im_cuda(cudaStream_t stream,
- const scalar_t* grad_col,
- const scalar_t* data_value,
- const int64_t * data_spatial_shapes,
- const int64_t * data_level_start_index,
- const scalar_t * data_sampling_loc,
- const scalar_t * data_attn_weight,
- const int batch_size,
- const int spatial_size,
- const int num_heads,
- const int channels,
- const int num_levels,
- const int num_query,
- const int num_point,
- scalar_t* grad_value,
- scalar_t* grad_sampling_loc,
- scalar_t* grad_attn_weight)
- {
- const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
- const int num_kernels = batch_size * num_query * num_heads * channels;
- const int num_actual_kernels = batch_size * num_query * num_heads * channels;
- if (channels > 1024)
- {
- if ((channels & 1023) == 0)
- {
- ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- num_threads*3*sizeof(scalar_t), stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- else
- {
- ms_deformable_col2im_gpu_kernel_gm<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- }
- else{
- switch(channels)
- {
- case 1:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 2:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 4:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 8:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 16:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 32:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 64:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 128:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 256:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 512:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- case 1024:
- ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- 0, stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- break;
- default:
- if (channels < 64)
- {
- ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- num_threads*3*sizeof(scalar_t), stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- else
- {
- ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
- <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
- num_threads*3*sizeof(scalar_t), stream>>>(
- num_kernels,
- grad_col,
- data_value,
- data_spatial_shapes,
- data_level_start_index,
- data_sampling_loc,
- data_attn_weight,
- batch_size,
- spatial_size,
- num_heads,
- channels,
- num_levels,
- num_query,
- num_point,
- grad_value,
- grad_sampling_loc,
- grad_attn_weight);
- }
- }
- }
- cudaError_t err = cudaGetLastError();
- if (err != cudaSuccess)
- {
- printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
- }
- 
- }
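With the source removed here, the intended consumption path is the prebuilt binary on the Hub (see the kernel source link added to the README above). A minimal sketch using the `kernels` Python package follows; the exported op name `ms_deform_attn_forward` and its positional argument order are assumptions carried over from the deleted `ms_deform_attn_cuda.h` declaration, so verify them against the linked repository.

```python
# Hedged sketch: fetch the prebuilt kernel instead of compiling the deleted
# sources. Requires `pip install kernels torch` and a CUDA GPU.
import torch
from kernels import get_kernel

ops = get_kernel("kernels-community/deformable-detr")

batch, num_heads, channels = 2, 8, 32
num_levels, num_query, num_point = 2, 100, 4
# Two feature levels, 32x32 and 16x16, flattened into one value tensor.
shapes = torch.tensor([[32, 32], [16, 16]], dtype=torch.int64, device="cuda")
level_start_index = torch.tensor([0, 32 * 32], dtype=torch.int64, device="cuda")
spatial_size = int((shapes[:, 0] * shapes[:, 1]).sum())

value = torch.rand(batch, spatial_size, num_heads, channels, device="cuda")
sampling_loc = torch.rand(batch, num_query, num_heads, num_levels, num_point, 2, device="cuda")
attn_weight = torch.rand(batch, num_query, num_heads, num_levels, num_point, device="cuda")

# Assumed to mirror the deleted ms_deform_attn_cuda_forward signature,
# with im2col_step=64 as the final argument.
out = ops.ms_deform_attn_forward(
    value, shapes, level_start_index, sampling_loc, attn_weight, 64
)
print(out.shape)  # expected: (batch, num_query, num_heads * channels)
```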
flake.lock DELETED
@@ -1,169 +0,0 @@
- {
-   "nodes": {
-     "flake-compat": {
-       "locked": {
-         "lastModified": 1747046372,
-         "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
-         "type": "github"
-       },
-       "original": {
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "type": "github"
-       }
-     },
-     "flake-compat_2": {
-       "locked": {
-         "lastModified": 1733328505,
-         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
-         "type": "github"
-       },
-       "original": {
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "type": "github"
-       }
-     },
-     "flake-utils": {
-       "inputs": {
-         "systems": "systems"
-       },
-       "locked": {
-         "lastModified": 1731533236,
-         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "type": "github"
-       }
-     },
-     "flake-utils_2": {
-       "inputs": {
-         "systems": "systems_2"
-       },
-       "locked": {
-         "lastModified": 1731533236,
-         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "type": "github"
-       }
-     },
-     "hf-nix": {
-       "inputs": {
-         "flake-compat": "flake-compat_2",
-         "flake-utils": "flake-utils_2",
-         "nixpkgs": "nixpkgs"
-       },
-       "locked": {
-         "lastModified": 1753354560,
-         "narHash": "sha256-vmOfRmr0Qm/IbZTWB2sBn+UFrABSTTA/cTg+m27Yt/E=",
-         "owner": "huggingface",
-         "repo": "hf-nix",
-         "rev": "7f2aceda2a2e72cd573bdb25e5c0667fd75f89d3",
-         "type": "github"
-       },
-       "original": {
-         "owner": "huggingface",
-         "repo": "hf-nix",
-         "type": "github"
-       }
-     },
-     "kernel-builder": {
-       "inputs": {
-         "flake-compat": "flake-compat",
-         "flake-utils": "flake-utils",
-         "hf-nix": "hf-nix",
-         "nixpkgs": [
-           "kernel-builder",
-           "hf-nix",
-           "nixpkgs"
-         ]
-       },
-       "locked": {
-         "lastModified": 1753354632,
-         "narHash": "sha256-31SX3Raiyx0qCuY9JSlx9ZZgxljeUxvW+JdujjxbofQ=",
-         "owner": "huggingface",
-         "repo": "kernel-builder",
-         "rev": "524b628fd8e58525dbd28455bffb0628092c5265",
-         "type": "github"
-       },
-       "original": {
-         "owner": "huggingface",
-         "ref": "torch-2.8",
-         "repo": "kernel-builder",
-         "type": "github"
-       }
-     },
-     "nixpkgs": {
-       "locked": {
-         "lastModified": 1752785354,
-         "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
-         "owner": "nixos",
-         "repo": "nixpkgs",
-         "rev": "d38025438a6ee456758dc03188ca6873a415463b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "nixos",
-         "repo": "nixpkgs",
-         "rev": "d38025438a6ee456758dc03188ca6873a415463b",
-         "type": "github"
-       }
-     },
-     "root": {
-       "inputs": {
-         "kernel-builder": "kernel-builder"
-       }
-     },
-     "systems": {
-       "locked": {
-         "lastModified": 1681028828,
-         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-         "owner": "nix-systems",
-         "repo": "default",
-         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-         "type": "github"
-       },
-       "original": {
-         "owner": "nix-systems",
-         "repo": "default",
-         "type": "github"
-       }
-     },
-     "systems_2": {
-       "locked": {
-         "lastModified": 1681028828,
-         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-         "owner": "nix-systems",
-         "repo": "default",
-         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-         "type": "github"
-       },
-       "original": {
-         "owner": "nix-systems",
-         "repo": "default",
-         "type": "github"
-       }
-     }
-   },
-   "root": "root",
-   "version": 7
- }
 
flake.nix DELETED
@@ -1,17 +0,0 @@
- {
-   description = "Flake for deformable_detr kernels";
-
-   inputs = {
-     kernel-builder.url = "github:huggingface/kernel-builder/torch-2.8";
-   };
-
-   outputs =
-     {
-       self,
-       kernel-builder,
-     }:
-     kernel-builder.lib.genFlakeOutputs {
-       path = ./.;
-       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
-     };
- }
 
torch-ext/deformable_detr/__init__.py DELETED
@@ -1,46 +0,0 @@
- from typing import List
- import torch
-
- from ._ops import ops
- from . import layers
-
-
- def ms_deform_attn_backward(
-     value: torch.Tensor,
-     spatial_shapes: torch.Tensor,
-     level_start_index: torch.Tensor,
-     sampling_loc: torch.Tensor,
-     attn_weight: torch.Tensor,
-     grad_output: torch.Tensor,
-     im2col_step: int,
- ) -> List[torch.Tensor]:
-     return ops.ms_deform_attn_backward(
-         value,
-         spatial_shapes,
-         level_start_index,
-         sampling_loc,
-         attn_weight,
-         grad_output,
-         im2col_step,
-     )
-
-
- def ms_deform_attn_forward(
-     value: torch.Tensor,
-     spatial_shapes: torch.Tensor,
-     level_start_index: torch.Tensor,
-     sampling_loc: torch.Tensor,
-     attn_weight: torch.Tensor,
-     im2col_step: int,
- ) -> torch.Tensor:
-     return ops.ms_deform_attn_forward(
-         value,
-         spatial_shapes,
-         level_start_index,
-         sampling_loc,
-         attn_weight,
-         im2col_step,
-     )
-
-
- __all__ = ["layers", "ms_deform_attn_forward", "ms_deform_attn_backward"]
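A minimal smoke test for these wrappers might look like the following sketch. All shapes are illustrative, a CUDA device is required (the CPU stubs only raise), and the import assumes the built package is importable as `deformable_detr`:

```python
import torch
from deformable_detr import ms_deform_attn_forward  # assumes the built extension is on the path

spatial_shapes = torch.tensor([[32, 32], [16, 16]], dtype=torch.long, device="cuda")
areas = spatial_shapes[:, 0] * spatial_shapes[:, 1]
level_start_index = torch.cat((areas.new_zeros(1), areas.cumsum(0)[:-1]))

N, M, D, Lq, L, P = 2, 8, 16, 100, 2, 4  # batch, heads, head dim, queries, levels, points
value = torch.rand(N, int(areas.sum()), M, D, device="cuda")
sampling_loc = torch.rand(N, Lq, M, L, P, 2, device="cuda")
attn_weight = torch.rand(N, Lq, M, L, P, device="cuda")
attn_weight = attn_weight / attn_weight.flatten(-2).sum(-1)[..., None, None]  # normalize over L*P

out = ms_deform_attn_forward(
    value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step=64
)
assert out.shape == (N, Lq, M * D)
```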
 
torch-ext/deformable_detr/layers.py DELETED
@@ -1,84 +0,0 @@
- from typing import List, Union, Tuple
-
- from torch import Tensor
- from torch.autograd import Function
- from torch.autograd.function import once_differentiable
- import torch.nn as nn
-
- from ._ops import ops
-
-
- class MultiScaleDeformableAttentionFunction(Function):
-     @staticmethod
-     def forward(
-         context,
-         value: Tensor,
-         value_spatial_shapes: Tensor,
-         value_level_start_index: Tensor,
-         sampling_locations: Tensor,
-         attention_weights: Tensor,
-         im2col_step: int,
-     ):
-         context.im2col_step = im2col_step
-         output = ops.ms_deform_attn_forward(
-             value,
-             value_spatial_shapes,
-             value_level_start_index,
-             sampling_locations,
-             attention_weights,
-             context.im2col_step,
-         )
-         context.save_for_backward(
-             value,
-             value_spatial_shapes,
-             value_level_start_index,
-             sampling_locations,
-             attention_weights,
-         )
-         return output
-
-     @staticmethod
-     @once_differentiable
-     def backward(context, grad_output):
-         (
-             value,
-             value_spatial_shapes,
-             value_level_start_index,
-             sampling_locations,
-             attention_weights,
-         ) = context.saved_tensors
-         grad_value, grad_sampling_loc, grad_attn_weight = ops.ms_deform_attn_backward(
-             value,
-             value_spatial_shapes,
-             value_level_start_index,
-             sampling_locations,
-             attention_weights,
-             grad_output,
-             context.im2col_step,
-         )
-
-         return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
-
-
- class MultiScaleDeformableAttention(nn.Module):
-     def forward(
-         self,
-         value: Tensor,
-         value_spatial_shapes: Tensor,
-         value_spatial_shapes_list: List[Tuple],
-         level_start_index: Tensor,
-         sampling_locations: Tensor,
-         attention_weights: Tensor,
-         im2col_step: int,
-     ):
-         return MultiScaleDeformableAttentionFunction.apply(
-             value,
-             value_spatial_shapes,
-             level_start_index,
-             sampling_locations,
-             attention_weights,
-             im2col_step,
-         )
-
-
- __all__ = ["MultiScaleDeformableAttention"]
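Gradients flow through this module via the `Function` above; note that `value_spatial_shapes_list` is accepted for API parity but not forwarded to the op. A sketch of a backward pass, again with illustrative shapes and a CUDA device assumed:

```python
import torch

spatial_shapes = torch.tensor([[16, 16], [8, 8]], dtype=torch.long, device="cuda")
areas = spatial_shapes[:, 0] * spatial_shapes[:, 1]
level_start_index = torch.cat((areas.new_zeros(1), areas.cumsum(0)[:-1]))

value = torch.rand(2, int(areas.sum()), 8, 16, device="cuda", requires_grad=True)
sampling_loc = torch.rand(2, 50, 8, 2, 4, 2, device="cuda", requires_grad=True)
attn_weight = torch.rand(2, 50, 8, 2, 4, device="cuda")
attn_weight = (attn_weight / attn_weight.flatten(-2).sum(-1)[..., None, None]).requires_grad_(True)

attn = MultiScaleDeformableAttention()
out = attn(value, spatial_shapes, [(16, 16), (8, 8)], level_start_index,
           sampling_loc, attn_weight, 64)
out.sum().backward()  # dispatches to ops.ms_deform_attn_backward via the Function above
assert value.grad is not None and sampling_loc.grad is not None
```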
 
torch-ext/ms_deform_attn_cpu.cpp DELETED
@@ -1,40 +0,0 @@
- /*!
- **************************************************************************************************
- * Deformable DETR
- * Copyright (c) 2020 SenseTime. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
- **************************************************************************************************
- * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
- **************************************************************************************************
- */
-
- #include <vector>
-
- #include <ATen/ATen.h>
- #include <ATen/cuda/CUDAContext.h>
-
-
- at::Tensor
- ms_deform_attn_cpu_forward(
-     const at::Tensor &value,
-     const at::Tensor &spatial_shapes,
-     const at::Tensor &level_start_index,
-     const at::Tensor &sampling_loc,
-     const at::Tensor &attn_weight,
-     const int im2col_step)
- {
-     AT_ERROR("Not implemented on CPU");
- }
-
- std::vector<at::Tensor>
- ms_deform_attn_cpu_backward(
-     const at::Tensor &value,
-     const at::Tensor &spatial_shapes,
-     const at::Tensor &level_start_index,
-     const at::Tensor &sampling_loc,
-     const at::Tensor &attn_weight,
-     const at::Tensor &grad_output,
-     const int im2col_step)
- {
-     AT_ERROR("Not implemented on CPU");
- }
 
torch-ext/ms_deform_attn_cpu.h DELETED
@@ -1,32 +0,0 @@
- /*!
- **************************************************************************************************
- * Deformable DETR
- * Copyright (c) 2020 SenseTime. All Rights Reserved.
- * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
- **************************************************************************************************
- * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
- **************************************************************************************************
- */
-
- #pragma once
- #include <torch/extension.h>
-
- at::Tensor
- ms_deform_attn_cpu_forward(
-     const at::Tensor &value,
-     const at::Tensor &spatial_shapes,
-     const at::Tensor &level_start_index,
-     const at::Tensor &sampling_loc,
-     const at::Tensor &attn_weight,
-     const int im2col_step);
-
- std::vector<at::Tensor>
- ms_deform_attn_cpu_backward(
-     const at::Tensor &value,
-     const at::Tensor &spatial_shapes,
-     const at::Tensor &level_start_index,
-     const at::Tensor &sampling_loc,
-     const at::Tensor &attn_weight,
-     const at::Tensor &grad_output,
-     const int im2col_step);
-
 
torch-ext/torch_binding.cpp DELETED
@@ -1,19 +0,0 @@
- #include <torch/library.h>
-
- #include "registration.h"
- #include "torch_binding.h"
-
- TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
-   ops.def("ms_deform_attn_forward(Tensor value, Tensor spatial_shapes,"
-           " Tensor level_start_index, Tensor sampling_loc,"
-           " Tensor attn_weight, int im2col_step) -> Tensor");
-   ops.impl("ms_deform_attn_forward", torch::kCUDA, &ms_deform_attn_cuda_forward);
-
-   ops.def("ms_deform_attn_backward(Tensor value, Tensor spatial_shapes,"
-           " Tensor level_start_index, Tensor sampling_loc,"
-           " Tensor attn_weight, Tensor grad_output,"
-           " int im2col_step) -> Tensor[]");
-   ops.impl("ms_deform_attn_backward", torch::kCUDA, &ms_deform_attn_cuda_backward);
- }
-
- REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
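Once the shared library is loaded, the two schemas registered above are reachable through `torch.ops` under the namespace fixed at build time by `TORCH_EXTENSION_NAME`; the package's `_ops` module resolves that indirection for callers. A hypothetical inspection, with path and namespace purely illustrative:

```python
import torch

# Illustrative path and namespace; both are determined by the build.
torch.ops.load_library("build/_deformable_detr.so")
print(torch.ops._deformable_detr.ms_deform_attn_forward)
```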
 
torch-ext/torch_binding.h DELETED
@@ -1,16 +0,0 @@
- #pragma once
-
- #include <torch/torch.h>
-
- at::Tensor ms_deform_attn_cuda_forward(const at::Tensor &value,
-                                        const at::Tensor &spatial_shapes,
-                                        const at::Tensor &level_start_index,
-                                        const at::Tensor &sampling_loc,
-                                        const at::Tensor &attn_weight,
-                                        const int64_t im2col_step);
-
- std::vector<at::Tensor> ms_deform_attn_cuda_backward(
-     const at::Tensor &value, const at::Tensor &spatial_shapes,
-     const at::Tensor &level_start_index, const at::Tensor &sampling_loc,
-     const at::Tensor &attn_weight, const at::Tensor &grad_output,
-     const int64_t im2col_step);