File size: 13,094 Bytes
d19bd3e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 | /*!
* Modified from Deformable DETR
*/
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
// Visits every flat index in [0, n) owned by the calling thread: start at the
// thread's global index and step by the total number of launched threads, so
// any grid size covers all n items. Indices are plain `int`.
#define CUDA_KERNEL_LOOP(i, n)                                  \
  for (int i = threadIdx.x + blockIdx.x * blockDim.x;           \
       i < (n);                                                 \
       i += gridDim.x * blockDim.x)
// Threads per block used by the host launchers in this file.
#define CUDA_NUM_THREADS 512
// Capacity of fixed-size per-thread sampling-point buffers; num_point must not
// exceed it wherever such a buffer is used.
#define MAX_POINT 32
// Number of thread blocks of `num_threads` threads needed to cover N work
// items, i.e. ceil(N / num_threads). Returns 0 when N <= 0.
// Computed as quotient + remainder so it cannot overflow for N close to
// INT_MAX (the usual (N + num_threads - 1) / num_threads form can).
// Precondition: num_threads > 0.
inline int GET_BLOCKS(const int N, const int num_threads) {
  if (N <= 0) {
    return 0;
  }
  return N / num_threads + (N % num_threads != 0 ? 1 : 0);
}
// Bilinearly samples channel `c` of one feature map at the fractional pixel
// position (h, w). The map is addressed as [row][col][channel]: one column
// step is `channels` elements and one row step is `width * channels`.
// Corners that fall outside [0, height-1] x [0, width-1] contribute 0 (the
// remaining weights are not renormalised).
__device__ float ms_deform_attn_im2col_bilinear(
    const float*& bottom_data,
    const int& height, const int& width, const int& channels,
    const float& h, const float& w, const int& c) {
  // Top-left integer corner of the 2x2 neighbourhood and the opposite corner.
  const int top = floor(h);
  const int left = floor(w);
  const int bottom = top + 1;
  const int right = left + 1;

  // Interpolation fractions along each axis and their complements.
  const float frac_y = h - top;
  const float frac_x = w - left;
  const float inv_y = 1 - frac_y;
  const float inv_x = 1 - frac_x;

  // Element strides: one column step and one row step.
  const int pixel_step = channels;
  const int row_step = width * pixel_step;

  // Which corners lie inside the map.
  const bool top_ok = top >= 0;
  const bool left_ok = left >= 0;
  const bool bottom_ok = bottom <= height - 1;
  const bool right_ok = right <= width - 1;

  // Corner values; only in-range corners are read, the rest are zero.
  const float v_tl = (top_ok && left_ok)
      ? bottom_data[top * row_step + left * pixel_step + c] : 0.f;
  const float v_tr = (top_ok && right_ok)
      ? bottom_data[top * row_step + right * pixel_step + c] : 0.f;
  const float v_bl = (bottom_ok && left_ok)
      ? bottom_data[bottom * row_step + left * pixel_step + c] : 0.f;
  const float v_br = (bottom_ok && right_ok)
      ? bottom_data[bottom * row_step + right * pixel_step + c] : 0.f;

  // Each corner is weighted by the area of the opposite sub-rectangle.
  const float w_tl = inv_y * inv_x;
  const float w_tr = inv_y * frac_x;
  const float w_bl = frac_y * inv_x;
  const float w_br = frac_y * frac_x;
  return w_tl * v_tl + w_tr * v_tr + w_bl * v_bl + w_br * v_br;
}
// Multi-view, multi-scale deformable sampling over four feature levels
// (C2, C3, C4, C5).
//
// Each CUDA_KERNEL_LOOP iteration produces the num_point outputs of one
// (batch, query, channel) element. For every sampling point it selects a
// view, bilinearly samples every level at the same normalised (w, h)
// location (align_corners = True) and sums the samples weighted by the
// per-level attention weights.
//
// Layouts implied by the index arithmetic below:
//   feat_cX           : [batch, num_views, h_cX, w_cX, channels]
//   data_sampling_loc : [batch, num_query, num_point, 3] = (w, h, view); w/h
//                       are scaled by (size - 1), view by (num_views - 1) and
//                       rounded to the nearest view index
//   data_attn_weight  : [batch, num_query, num_point, 4] = (C2, C3, C4, C5)
//   data_col (output) : [batch, num_query, channels, num_point]
//
// A level contributes nothing when the sample falls outside that map. A point
// whose rounded view index lies outside [0, num_views - 1] (or is NaN)
// contributes nothing instead of reading out of bounds. Results accumulate in
// a register per point, so num_point is not limited by MAX_POINT.
__global__ void ms_deformable_im2col_gpu_kernel_c2345(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {
  constexpr int kNumLevels = 4;
  // Per-level inputs, in the same order as the attention weights.
  const float* const level_feat[kNumLevels] = {feat_c2, feat_c3, feat_c4, feat_c5};
  const int level_h[kNumLevels] = {h_c2, h_c3, h_c4, h_c5};
  const int level_w[kNumLevels] = {w_c2, w_c3, w_c4, w_c5};

  CUDA_KERNEL_LOOP(index, batch_size * num_query * channels) {  // n: bs x query x channels
    const int c_col = index % channels;
    const int sampling_index = index / channels;  // flattened (batch, query)
    const int b_col = sampling_index / num_query;
    // 64-bit offsets: the flattened tensors can exceed 2^31 elements.
    const int64_t out_base = static_cast<int64_t>(index) * num_point;

    for (int p_col = 0; p_col < num_point; ++p_col) {
      const int64_t point = static_cast<int64_t>(sampling_index) * num_point + p_col;
      const int64_t loc_ptr = point * 3;
      const int64_t weight_ptr = point * kNumLevels;

      // Sampling location in range [0, 1]
      const float loc_w = data_sampling_loc[loc_ptr];
      const float loc_h = data_sampling_loc[loc_ptr + 1];
      const float view = round(data_sampling_loc[loc_ptr + 2] * (num_views - 1));

      float acc = 0;
      // Comparisons with NaN are false, so a NaN view is rejected as well.
      if (view >= 0 && view <= num_views - 1) {
        const int loc_v = static_cast<int>(view);
#pragma unroll
        for (int lvl = 0; lvl < kNumLevels; ++lvl) {
          const int h = level_h[lvl];
          const int w = level_w[lvl];
          const float h_im = loc_h * (h - 1);  // align_corners = True
          const float w_im = loc_w * (w - 1);
          if (h_im > -1 && w_im > -1 && h_im < h && w_im < w) {
            const float* feat_ptr = level_feat[lvl] +
                (static_cast<int64_t>(b_col) * num_views + loc_v) * h * w * channels;
            acc += ms_deform_attn_im2col_bilinear(feat_ptr, h, w, channels, h_im, w_im, c_col) *
                   data_attn_weight[weight_ptr + lvl];
          }
        }
      }
      data_col[out_base + p_col] = acc;
    }
  }
}
// Multi-view, multi-scale deformable sampling over five feature levels
// (C2, C3, C4, C5, C6).
//
// Each CUDA_KERNEL_LOOP iteration produces the num_point outputs of one
// (batch, query, channel) element. For every sampling point it selects a
// view, bilinearly samples every level at the same normalised (w, h)
// location (align_corners = True) and sums the samples weighted by the
// per-level attention weights.
//
// Layouts implied by the index arithmetic below:
//   feat_cX           : [batch, num_views, h_cX, w_cX, channels]
//   data_sampling_loc : [batch, num_query, num_point, 3] = (w, h, view); w/h
//                       are scaled by (size - 1), view by (num_views - 1) and
//                       rounded to the nearest view index
//   data_attn_weight  : [batch, num_query, num_point, 5] = (C2, ..., C6)
//   data_col (output) : [batch, num_query, channels, num_point]
//
// A level contributes nothing when the sample falls outside that map. A point
// whose rounded view index lies outside [0, num_views - 1] (or is NaN)
// contributes nothing instead of reading out of bounds. Results accumulate in
// a register per point, so num_point is not limited by MAX_POINT.
__global__ void ms_deformable_im2col_gpu_kernel_c23456(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const float* feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {
  constexpr int kNumLevels = 5;
  // Per-level inputs, in the same order as the attention weights.
  const float* const level_feat[kNumLevels] = {feat_c2, feat_c3, feat_c4, feat_c5, feat_c6};
  const int level_h[kNumLevels] = {h_c2, h_c3, h_c4, h_c5, h_c6};
  const int level_w[kNumLevels] = {w_c2, w_c3, w_c4, w_c5, w_c6};

  CUDA_KERNEL_LOOP(index, batch_size * num_query * channels) {  // n: bs x query x channels
    const int c_col = index % channels;
    const int sampling_index = index / channels;  // flattened (batch, query)
    const int b_col = sampling_index / num_query;
    // 64-bit offsets: the flattened tensors can exceed 2^31 elements.
    const int64_t out_base = static_cast<int64_t>(index) * num_point;

    for (int p_col = 0; p_col < num_point; ++p_col) {
      const int64_t point = static_cast<int64_t>(sampling_index) * num_point + p_col;
      const int64_t loc_ptr = point * 3;
      const int64_t weight_ptr = point * kNumLevels;

      // Sampling location in range [0, 1]
      const float loc_w = data_sampling_loc[loc_ptr];
      const float loc_h = data_sampling_loc[loc_ptr + 1];
      const float view = round(data_sampling_loc[loc_ptr + 2] * (num_views - 1));

      float acc = 0;
      // Comparisons with NaN are false, so a NaN view is rejected as well.
      if (view >= 0 && view <= num_views - 1) {
        const int loc_v = static_cast<int>(view);
#pragma unroll
        for (int lvl = 0; lvl < kNumLevels; ++lvl) {
          const int h = level_h[lvl];
          const int w = level_w[lvl];
          const float h_im = loc_h * (h - 1);  // align_corners = True
          const float w_im = loc_w * (w - 1);
          if (h_im > -1 && w_im > -1 && h_im < h && w_im < w) {
            const float* feat_ptr = level_feat[lvl] +
                (static_cast<int64_t>(b_col) * num_views + loc_v) * h * w * channels;
            acc += ms_deform_attn_im2col_bilinear(feat_ptr, h, w, channels, h_im, w_im, c_col) *
                   data_attn_weight[weight_ptr + lvl];
          }
        }
      }
      data_col[out_base + p_col] = acc;
    }
  }
}
// Host launcher for ms_deformable_im2col_gpu_kernel_c2345 (four feature
// levels C2..C5). Arguments are forwarded unchanged to the kernel; the
// pointers are presumably device pointers to contiguous float32 tensors —
// the caller must guarantee this.
//
// The kernel is enqueued on PyTorch's current CUDA stream, so it is ordered
// with the surrounding tensor operations. It is not synchronised here. Launch
// errors are reported via printf; faults during execution only surface at the
// next synchronising CUDA call.
void ms_deformable_im2col_cuda_c2345(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {
  const int num_kernels = batch_size * num_query * channels;
  // Empty output: a zero-block grid would be an invalid launch configuration.
  if (num_kernels == 0) {
    return;
  }
  const int num_threads = CUDA_NUM_THREADS;
  // Use PyTorch's current stream rather than the legacy default stream, which
  // would race with work queued on a non-default stream.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  ms_deformable_im2col_gpu_kernel_c2345<<<GET_BLOCKS(num_kernels, num_threads), num_threads, 0, stream>>>(
      feat_c2, feat_c3, feat_c4, feat_c5, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
      data_sampling_loc, data_attn_weight, batch_size, channels, num_views, num_query, num_point, data_col);
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in ms_deformable_im2col_cuda_c2345: %s\n", cudaGetErrorString(err));
  }
}
// Host launcher for ms_deformable_im2col_gpu_kernel_c23456 (five feature
// levels C2..C6). Arguments are forwarded unchanged to the kernel; the
// pointers are presumably device pointers to contiguous float32 tensors —
// the caller must guarantee this.
//
// The kernel is enqueued on PyTorch's current CUDA stream, so it is ordered
// with the surrounding tensor operations. It is not synchronised here. Launch
// errors are reported via printf; faults during execution only surface at the
// next synchronising CUDA call.
void ms_deformable_im2col_cuda_c23456(
    const float* feat_c2,
    const float* feat_c3,
    const float* feat_c4,
    const float* feat_c5,
    const float* feat_c6,
    const int h_c2, const int w_c2,
    const int h_c3, const int w_c3,
    const int h_c4, const int w_c4,
    const int h_c5, const int w_c5,
    const int h_c6, const int w_c6,
    const float* data_sampling_loc,
    const float* data_attn_weight,
    const int batch_size,
    const int channels,
    const int num_views,
    const int num_query,
    const int num_point,
    float* data_col) {
  const int num_kernels = batch_size * num_query * channels;
  // Empty output: a zero-block grid would be an invalid launch configuration.
  if (num_kernels == 0) {
    return;
  }
  const int num_threads = CUDA_NUM_THREADS;
  // Use PyTorch's current stream rather than the legacy default stream, which
  // would race with work queued on a non-default stream.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  ms_deformable_im2col_gpu_kernel_c23456<<<GET_BLOCKS(num_kernels, num_threads), num_threads, 0, stream>>>(
      feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
      data_sampling_loc, data_attn_weight, batch_size, channels, num_views, num_query, num_point, data_col);
  cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    printf("error in ms_deformable_im2col_cuda_c23456: %s\n", cudaGetErrorString(err));
  }
}
|