File size: 6,817 Bytes
663494c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
// Modified from
// https://github.com/open-mmlab/OpenPCDet/blob/master/pcdet/ops/iou3d_nms/src/iou3d_nms.cpp
/*
3D IoU Calculation and Rotated NMS (modified from 2D NMS written by others)
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>
// Argument-validation and error-checking helpers used by the bindings below.
// Each macro is wrapped in do { ... } while (0) so that it expands to exactly
// one statement. This makes it safe inside an unbraced if/else, and the caller
// still supplies the trailing semicolon.
#define CHECK_CUDA(x)                                                    \
  do {                                                                   \
    TORCH_CHECK((x).device().is_cuda(), #x, " must be a CUDAtensor ");   \
  } while (0)
#define CHECK_CONTIGUOUS(x)                                      \
  do {                                                           \
    TORCH_CHECK((x).is_contiguous(), #x, " must be contiguous "); \
  } while (0)
#define CHECK_INPUT(x)   \
  do {                   \
    CHECK_CUDA(x);       \
    CHECK_CONTIGUOUS(x); \
  } while (0)
// Integer ceiling division: the number of n-sized chunks needed to cover m
// (intended for non-negative m and positive n).
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
// Evaluates a cudaError_t-returning expression once and forwards it to
// gpuAssert together with the call site's file and line.
#define CHECK_ERROR(ans)                  \
  do {                                    \
    gpuAssert((ans), __FILE__, __LINE__); \
  } while (0)
// Reports a failed CUDA runtime call.
//   code  - status returned by the CUDA runtime call; cudaSuccess is a no-op.
//   file  - source file of the call site (normally __FILE__).
//   line  - source line of the call site (normally __LINE__).
//   abort - when true (the default), the process exits with `code` as its
//           status after the message is printed to stderr.
inline void gpuAssert(cudaError_t code, const char *file, int line,
                      bool abort = true) {
  if (code == cudaSuccess) {
    return;
  }
  const char *message = cudaGetErrorString(code);
  fprintf(stderr, "GPUassert: %s %s %d\n", message, file, line);
  if (abort) {
    exit(code);
  }
}
const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
void boxesoverlapLauncher(const int num_a, const float *boxes_a,
const int num_b, const float *boxes_b,
float *ans_overlap);
void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b,
const float *boxes_b, float *ans_iou);
void nmsLauncher(const float *boxes, unsigned long long *mask, int boxes_num,
float nms_overlap_thresh);
void nmsNormalLauncher(const float *boxes, unsigned long long *mask,
int boxes_num, float nms_overlap_thresh);
int boxes_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b,
                          at::Tensor ans_overlap) {
  // Computes pairwise overlaps between boxes_a and boxes_b with
  // boxesoverlapLauncher, writing the results into ans_overlap.
  // params boxes_a: (N, 5) [x1, y1, x2, y2, ry]
  // params boxes_b: (M, 5)
  // params ans_overlap: (N, M)
  // returns: always 1
  // NOTE(review): the 5-values-per-box layout is taken from the comment
  // above; it is not validated here. Confirm against the launcher.
  CHECK_INPUT(boxes_a);
  CHECK_INPUT(boxes_b);
  CHECK_INPUT(ans_overlap);
  const int num_a = boxes_a.size(0);
  const int num_b = boxes_b.size(0);
  // ans_overlap is documented as (N, M). Reject a buffer with the wrong
  // element count rather than risk out-of-bounds device writes.
  TORCH_CHECK(ans_overlap.numel() == static_cast<int64_t>(num_a) * num_b,
              "ans_overlap must hold ", num_a, " x ", num_b, " elements");
  // Fetch the pointers first so that dtype validation still runs for empty
  // inputs.
  const float *boxes_a_data = boxes_a.data_ptr<float>();
  const float *boxes_b_data = boxes_b.data_ptr<float>();
  float *ans_overlap_data = ans_overlap.data_ptr<float>();
  // With no pairs there is nothing to compute. This presumably also avoids
  // an empty-grid launch inside the launcher.
  if (num_a == 0 || num_b == 0) return 1;
  boxesoverlapLauncher(num_a, boxes_a_data, num_b, boxes_b_data,
                       ans_overlap_data);
  return 1;
}
int boxes_iou_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b,
                      at::Tensor ans_iou) {
  // Computes pairwise bird's-eye-view IoU between boxes_a and boxes_b with
  // boxesioubevLauncher, writing the results into ans_iou.
  // params boxes_a: (N, 5) [x1, y1, x2, y2, ry]
  // params boxes_b: (M, 5)
  // params ans_iou: (N, M)
  // returns: always 1
  // NOTE(review): the 5-values-per-box layout is taken from the comment
  // above; it is not validated here. Confirm against the launcher.
  CHECK_INPUT(boxes_a);
  CHECK_INPUT(boxes_b);
  CHECK_INPUT(ans_iou);
  const int num_a = boxes_a.size(0);
  const int num_b = boxes_b.size(0);
  // ans_iou is documented as (N, M). Reject a buffer with the wrong element
  // count rather than risk out-of-bounds device writes.
  TORCH_CHECK(ans_iou.numel() == static_cast<int64_t>(num_a) * num_b,
              "ans_iou must hold ", num_a, " x ", num_b, " elements");
  // Fetch the pointers first so that dtype validation still runs for empty
  // inputs.
  const float *boxes_a_data = boxes_a.data_ptr<float>();
  const float *boxes_b_data = boxes_b.data_ptr<float>();
  float *ans_iou_data = ans_iou.data_ptr<float>();
  // With no pairs there is nothing to compute. This presumably also avoids
  // an empty-grid launch inside the launcher.
  if (num_a == 0 || num_b == 0) return 1;
  boxesioubevLauncher(num_a, boxes_a_data, num_b, boxes_b_data, ans_iou_data);
  return 1;
}
int nms_gpu(at::Tensor boxes, at::Tensor keep,
float nms_overlap_thresh, int device_id) {
// params boxes: (N, 5) [x1, y1, x2, y2, ry]
// params keep: (N)
CHECK_INPUT(boxes);
CHECK_CONTIGUOUS(keep);
cudaSetDevice(device_id);
int boxes_num = boxes.size(0);
const float *boxes_data = boxes.data_ptr<float>();
long *keep_data = keep.data_ptr<long>();
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
unsigned long long *mask_data = NULL;
CHECK_ERROR(cudaMalloc((void **)&mask_data,
boxes_num * col_blocks * sizeof(unsigned long long)));
nmsLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);
// unsigned long long mask_cpu[boxes_num * col_blocks];
// unsigned long long *mask_cpu = new unsigned long long [boxes_num *
// col_blocks];
std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
boxes_num * col_blocks * sizeof(unsigned long long),
cudaMemcpyDeviceToHost));
cudaFree(mask_data);
unsigned long long remv_cpu[col_blocks];
memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / THREADS_PER_BLOCK_NMS;
int inblock = i % THREADS_PER_BLOCK_NMS;
if (!(remv_cpu[nblock] & (1ULL << inblock))) {
keep_data[num_to_keep++] = i;
unsigned long long *p = &mask_cpu[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv_cpu[j] |= p[j];
}
}
}
if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
return num_to_keep;
}
int nms_normal_gpu(at::Tensor boxes, at::Tensor keep,
float nms_overlap_thresh, int device_id) {
// params boxes: (N, 5) [x1, y1, x2, y2, ry]
// params keep: (N)
CHECK_INPUT(boxes);
CHECK_CONTIGUOUS(keep);
cudaSetDevice(device_id);
int boxes_num = boxes.size(0);
const float *boxes_data = boxes.data_ptr<float>();
long *keep_data = keep.data_ptr<long>();
const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
unsigned long long *mask_data = NULL;
CHECK_ERROR(cudaMalloc((void **)&mask_data,
boxes_num * col_blocks * sizeof(unsigned long long)));
nmsNormalLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);
// unsigned long long mask_cpu[boxes_num * col_blocks];
// unsigned long long *mask_cpu = new unsigned long long [boxes_num *
// col_blocks];
std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
boxes_num * col_blocks * sizeof(unsigned long long),
cudaMemcpyDeviceToHost));
cudaFree(mask_data);
unsigned long long remv_cpu[col_blocks];
memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / THREADS_PER_BLOCK_NMS;
int inblock = i % THREADS_PER_BLOCK_NMS;
if (!(remv_cpu[nblock] & (1ULL << inblock))) {
keep_data[num_to_keep++] = i;
unsigned long long *p = &mask_cpu[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv_cpu[j] |= p[j];
}
}
}
if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
return num_to_keep;
}
// Python bindings: exposes the BEV overlap/IoU and NMS entry points under the
// extension module named by TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, module) {
  // Pairwise box metrics.
  module.def("boxes_overlap_bev_gpu", &boxes_overlap_bev_gpu,
             "oriented boxes overlap");
  module.def("boxes_iou_bev_gpu", &boxes_iou_bev_gpu,
             "oriented boxes iou");
  // Non-maximum suppression variants.
  module.def("nms_gpu", &nms_gpu,
             "oriented nms gpu");
  module.def("nms_normal_gpu", &nms_normal_gpu,
             "nms gpu");
}
|