| |
| #pragma once |
| #include <torch/types.h> |
|
|
| namespace detectron2 { |
|
|
| at::Tensor nms_rotated_cpu( |
| const at::Tensor& dets, |
| const at::Tensor& scores, |
| const double iou_threshold); |
|
|
| #if defined(WITH_CUDA) || defined(WITH_HIP) |
| at::Tensor nms_rotated_cuda( |
| const at::Tensor& dets, |
| const at::Tensor& scores, |
| const double iou_threshold); |
| #endif |
|
|
| |
| |
| |
| inline at::Tensor nms_rotated( |
| const at::Tensor& dets, |
| const at::Tensor& scores, |
| const double iou_threshold) { |
| assert(dets.device().is_cuda() == scores.device().is_cuda()); |
| if (dets.device().is_cuda()) { |
| #if defined(WITH_CUDA) || defined(WITH_HIP) |
| return nms_rotated_cuda( |
| dets.contiguous(), scores.contiguous(), iou_threshold); |
| #else |
| AT_ERROR("Detectron2 is not compiled with GPU support!"); |
| #endif |
| } |
|
|
| return nms_rotated_cpu(dets.contiguous(), scores.contiguous(), iou_threshold); |
| } |
|
|
| } |
|
|