| |
| |
|
|
| #include "non_maximal_suppression.h" |
|
|
| #include <cooperative_groups.h> |
| #include <cooperative_groups/reduce.h> |
|
|
| #include <thrust/binary_search.h> |
| #include <thrust/device_vector.h> |
| #include <thrust/execution_policy.h> |
|
|
| #include <c10/cuda/CUDACachingAllocator.h> |
|
|
| #include <trove/ptr.h> |
|
|
| #include "../cuda_intellisense.cuh" |
| #include "../geometry.h" |
| #include "../common.h" |
| #include "../scope_timer.h" |
| #include "strided_quad.h" |
|
|
| |
| |
| |
|
|
| namespace cg = cooperative_groups; |
| namespace ix = torch::indexing; |
|
|
| inline |
| void print_tensor_stats2(const std::string &msg, const torch::Tensor& tensor) { |
|
|
| auto fTensor = tensor.to(torch::kDouble).cpu(); |
|
|
| std::stringstream ss; |
    if (tensor.numel() > 1) {
        ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device()
           << " Max: " << fTensor.max().item<double>() << " Min: " << fTensor.min().item<double>()
           << " Mean: " << fTensor.mean().item<double>() << " Std: " << fTensor.std().item<double>();
    }
    else if (tensor.numel() == 1) {
        ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device()
           << " Value: " << fTensor.item<double>();
    }
    else {
        ss << msg << " Size: " << tensor.sizes() << " Type: " << tensor.dtype() << " Device: " << tensor.device();
    }
| std::cout << ss.str() << std::endl; |
| } |
|
|
| inline |
| void print_tensor_vec_stats2(std::string msg, const std::vector<torch::Tensor>& tensorVec) { |
| std::cout << msg << " Size: " << tensorVec.size() << std::endl; |
| std::stringstream ss; |
| msg = " - "; |
    for (size_t i = 0; i < tensorVec.size(); ++i) {
| ss << msg << "[" << i << "]:"; |
| auto tensor = tensorVec[i]; |
| print_tensor_stats2(ss.str(), tensor); |
| ss.str(""); |
| } |
| } |
|
|
| std::ostream &operator<<(std::ostream &os, dim3 d) |
| { |
| return os << "(" << d.x << ", " << d.y << ", " << d.z << ")"; |
| } |
|
|
| #define ADD_OP2(vector2_t) __device__ \ |
| vector2_t operator+(const vector2_t &a, const vector2_t &b) { \ |
| return { a.x + b.x, a.y + b.y }; \ |
| } |
| ADD_OP2(float2); |
| ADD_OP2(double2); |
| #undef ADD_OP2 |
|
|
| #define ADD_OP4(vector4_t) __device__ \ |
| vector4_t operator+(const vector4_t &a, const vector4_t &b) { \ |
| return { a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w }; \ |
| } |
| ADD_OP4(float4); |
| ADD_OP4(double4); |
| #undef ADD_OP4 |
|
|
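// Element-wise sum for std::array in device code. Note: `_Elems` is the MSVC STL's internal
// storage member, so this overload is MSVC-specific and will not compile against other standard libraries.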
| template<typename T, size_t Size> |
| __device__ |
| std::array<T, Size> operator+(const std::array<T, Size> &a, const std::array<T, Size> &b) { |
| std::array<T, Size> ret; |
| #pragma unroll |
| for (size_t i = 0; i < Size; ++i) { |
| ret._Elems[i] = a._Elems[i] + b._Elems[i]; |
| } |
| return ret; |
| } |
|
|
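// Full-warp reductions: use the SM80+ warp reduce intrinsics when available, otherwise fall back
// to a cooperative-groups reduction over a 32-thread tile.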
| #if __CUDA_ARCH__ >= 800 |
| #define __reduce_add_full_warp(val) __reduce_add_sync(0xFFFFFFFF, val) |
| #define __reduce_max_full_warp(val) __reduce_max_sync(0xFFFFFFFF, val) |
| #define __reduce_min_full_warp(val) __reduce_min_sync(0xFFFFFFFF, val) |
| #else |
| #define __reduce_add_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::plus<decltype(val)>()) |
| #define __reduce_max_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::greater<decltype(val)>()) |
| #define __reduce_min_full_warp(val) cg::reduce(cg::tiled_partition<32>(cg::this_thread_block()), val, cg::less<decltype(val)>()) |
| #endif |
|
|
| template<typename T> |
| struct TToVec; |
| template<> |
| struct TToVec<float> { typedef float2 type2; typedef float4 type4; }; |
| template<> |
| struct TToVec<double> { typedef double2 type2; typedef double4 type4; }; |
|
|
| template<typename T, typename accessor_t> |
| __device__ |
| void write_embed_quad(accessor_t &acc, const MergeQuad_<T> &quad, int64_t storeOff) |
| { |
| constexpr auto EMBED_QUAD_SIZE = sizeof(EmbedQuad_<T>) / sizeof(T); |
| static_assert(EMBED_QUAD_SIZE == 10, "Unsupported embed quad size!"); |
|
|
| const T *mergeBuff = reinterpret_cast<const T*>(&quad); |
|
|
| const T confidence = quad.Confidence; |
| const auto i = threadIdx.x; |
|
|
| if (i >= 10) { |
| return; |
| } |
|
|
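    // Lanes 0-7 emit the confidence-weighted average vertex coordinates, lane 8 the mean
    // confidence (accumulated confidence divided by the merge count), and lane 9 the merge count itself.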
| T outVal; |
| |
| if (i < 8) { |
| outVal = mergeBuff[i] / confidence; |
| |
| } else if (i == 8) { |
| outVal = confidence / mergeBuff[9]; |
| |
| } else { |
| outVal = mergeBuff[9]; |
| } |
|
|
| acc[i][storeOff] = outVal; |
| } |
|
|
|
|
| template<typename group_t, typename ...Args> |
| __device__ |
| void ordered_print(group_t &group, const char *const fmt, const Args& ...args) |
| { |
| for (uint32_t i = 0; i < group.size(); ++i) { |
| if (group.thread_rank() == i) { |
| printf(fmt, args...); |
| } |
| group.sync(); |
| } |
| } |
|
|
| template<typename T> |
| __global__ |
| void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads, |
| torch::PackedTensorAccessor64<T, 3> allConfs, |
| T confThreshold, T iouThreshold, |
| torch::PackedTensorAccessor64<int32_t, 1> allOutCounts, |
| torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads, |
| torch::PackedTensorAccessor64<int32_t, 2> allOutIds) |
| { |
| typedef InPlaceQuad_<T> Quadf; |
| static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!"); |
|
|
| constexpr uint32_t ALL_MASK = 0xFFFFFFFF; |
| constexpr uint32_t WARP_SIZE = 32; |
| constexpr T MIN_VALID_AREA = 8; |
|
|
| const uint32_t B = allQuads.size(0); |
| const uint32_t H = allQuads.size(1); |
|
|
| const uint32_t b = blockIdx.z; |
| const uint32_t r = blockIdx.y * blockDim.y + threadIdx.y; |
|
|
| if (r >= H) { |
| return; |
| } |
|
|
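// Each warp processes one row of 32 candidate quads; threadIdx.x doubles as the lane index.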
| #define threadRank threadIdx.x |
|
|
| auto rawQuads = reinterpret_cast<Quadf*>(allQuads[b][r].data()); |
| #if defined(NDEBUG) |
| trove::coalesced_ptr<Quadf> quads(rawQuads); |
| #else |
| auto quads = rawQuads; |
| #endif |
|
|
| auto confs = allConfs[b][r]; |
|
|
| T conf = confs[threadRank]; |
|
|
| bool quadValid = conf >= confThreshold; |
| uint32_t ballot = __ballot_sync(ALL_MASK, quadValid); |
|
|
| |
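    // Early out: no quad in this warp's row clears the confidence threshold.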
| if (ballot == 0) { |
| return; |
| } |
|
|
| const Quadf currQuad = quads[threadRank]; |
|
|
| const T qArea = currQuad.Area(); |
|
|
| quadValid = quadValid && qArea > MIN_VALID_AREA; |
| ballot = __ballot_sync(ALL_MASK, quadValid); |
| if (ballot == 0) { |
| return; |
| } |
| if (! quadValid) { |
| conf = 0; |
| } |
|
|
| MergeQuad_<T> qAccum{ZeroInitTag{}}; |
|
|
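    // Fetch the previous lane's quad and confidence via shuffle-up so each lane can test
    // overlap against its left-hand neighbour.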
| Quadf prevQuad; |
| auto pCurrQuad = reinterpret_cast<const T*>(&currQuad); |
| auto pPrevQuad = reinterpret_cast<T*>(&prevQuad); |
| #pragma unroll |
| for (uint32_t i = 0; i < 8; ++i) { |
| pPrevQuad[i] = __shfl_up_sync(ALL_MASK, pCurrQuad[i], 1); |
| } |
| T prevConf = __shfl_up_sync(ALL_MASK, conf, 1); |
|
|
| if (threadRank == 0) { |
| prevConf = 0; |
| } |
|
|
| bool iouValid = false; |
| T iou = 0; |
| if (quadValid) { |
| qAccum.Append(currQuad, conf); |
|
|
| if (prevConf >= confThreshold) { |
| iou = prevQuad.IOU_UpperBound(currQuad); |
| if (iou >= iouThreshold) { |
| iouValid = true; |
| } |
| } |
| } |
|
|
| |
| |
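    // A quad starts a new merge span when it is valid but does not overlap its left-hand
    // neighbour strongly enough.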
| const bool isStartOfSpan = quadValid && !iouValid; |
|
|
| uint32_t label = isStartOfSpan; |
| |
| |
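    // Inclusive warp scan (shuffle-up) over the start flags: each lane ends up with the
    // 1-based label of the span it belongs to.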
| #pragma unroll |
| for (uint32_t offset = 1; offset < 32; offset <<= 1) { |
| auto inc = __shfl_up_sync(ALL_MASK, label, offset); |
| if (threadRank >= offset) { |
| label += inc; |
| } |
| } |
|
|
| |
| const uint32_t numLabels = __shfl_sync(ALL_MASK, label, WARP_SIZE - 1); |
|
|
| |
| label = quadValid ? label : 0; |
|
|
| T* accumPtr = reinterpret_cast<T*>(&qAccum); |
| |
| |
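    // Segmented warp reduction: every lane folds in downstream lanes that share its label, so the
    // first lane of each span ends up holding the span's full accumulator.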
| #pragma unroll |
| for (uint32_t offset = 1; offset < 32; offset <<= 1) { |
| const auto otherLabel = __shfl_down_sync(ALL_MASK, label, offset); |
|
|
| |
| |
| const T factor = otherLabel == label && offset + threadRank < WARP_SIZE ? 1.0f : 0.0f; |
|
|
| #pragma unroll |
| for (uint32_t i = 0; i < 10; ++i) { |
| accumPtr[i] += factor * __shfl_down_sync(ALL_MASK, accumPtr[i], offset); |
| } |
| } |
|
|
| |
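    // Lane 0 reserves one output slot per span for this row; the base offset is broadcast to the warp below.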
| uint32_t storeOff = 0; |
| if (threadRank == 0) { |
| storeOff = atomicAdd(&allOutCounts[b], numLabels); |
| } |
| |
| storeOff = __shfl_sync(ALL_MASK, storeOff, 0); |
|
|
| auto outEmbedQuads = allOutEmbedQuads[b]; |
| |
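    // Emit one merged quad per span: broadcast the accumulator from the span's first lane and
    // write its ten components with ten lanes.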
| for (uint32_t procLabel = 1; procLabel <= numLabels; ++procLabel) { |
| |
| ballot = __ballot_sync(ALL_MASK, procLabel == label); |
| |
| |
| uint32_t startIdx = __ffs(ballot) - 1; |
|
|
| const T* inT = reinterpret_cast<T*>(&qAccum); |
| MergeQuad_<T> outQuad; |
| T* outT = reinterpret_cast<T*>(&outQuad); |
| #pragma unroll |
| for (uint32_t i = 0; i < 10; ++i) { |
| outT[i] = __shfl_sync(ALL_MASK, inT[i], startIdx); |
| } |
|
|
| write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1); |
| if (threadRank == 0) { |
| allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx; |
| } |
| } |
|
|
| if (threadRank == 0) { |
| |
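        // The extra slot at index B accumulates the grand total across the whole batch.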
| atomicAdd(&allOutCounts[B], numLabels); |
| } |
|
|
| #undef threadRank |
| } |
|
|
| template<typename T> |
| __global__ |
| void device_a2a_adjacency_sparse(const int32_t *ptrQuadCts, |
| T iouThreshold, |
| torch::PackedTensorAccessor64<T, 3> embedQuads, |
| torch::PackedTensorAccessor64<bool, 2> outIsStart, |
| torch::PackedTensorAccessor64<int32_t, 2> outAdjCounts, |
| torch::PackedTensorAccessor64<int32_t, 3> outSparseAdj) |
| { |
| const uint32_t b = blockIdx.y; |
|
|
| const int32_t quadCt = ptrQuadCts[b]; |
|
|
| if (quadCt == 0) { |
| return; |
| } |
|
|
| const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x; |
| const int32_t row = jobIdx / quadCt; |
| const int32_t col = jobIdx % quadCt; |
|
|
| |
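    // Each thread handles one (row, col) pair of the upper triangle; pairs below the diagonal are
    // skipped here and filled in by mirroring below.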
| if (row >= quadCt || col < row) { |
| return; |
| } |
|
|
| T* exData = embedQuads[b].data(); |
|
|
    const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
    const auto qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
|
|
| T pctRow, pctCol, iou; |
| thrust::tie(pctRow, pctCol, iou) = geometry_region_sizes(qRow, qCol); |
|
|
| auto warpGroup = cg::tiled_partition<32>(cg::this_thread_block()); |
|
|
| auto rowGroup = cg::labeled_partition(warpGroup, row); |
|
|
| const bool isValid = iou >= iouThreshold; |
|
|
| const uint32_t ballot = rowGroup.ballot(isValid); |
| const uint32_t numValid = __popc(ballot); |
|
|
| auto exAdjCounts = outAdjCounts[b].data(); |
|
|
| int32_t storeOff = 0; |
| if (numValid > 0 && rowGroup.thread_rank() == 0) { |
| storeOff = atomicAdd(exAdjCounts + row, numValid); |
| } |
| storeOff = rowGroup.shfl(storeOff, 0); |
|
|
| if (isValid) { |
| |
| |
| |
| uint32_t lowerMask = (1 << rowGroup.thread_rank()) - 1; |
|
|
| storeOff += __popc(ballot & lowerMask); |
|
|
| outSparseAdj[b][row][storeOff] = col; |
| if (row != col) { |
| |
| |
| outIsStart[b][col] = false; |
|
|
| |
| storeOff = atomicAdd(exAdjCounts + col, 1); |
| outSparseAdj[b][col][storeOff] = row; |
| } |
| } else if (pctRow > 0.8f || pctCol > 0.8f) { |
| T anchorHeight = qRow.Height(); |
| T otherHeight = qCol.Height(); |
|
|
| T ratio = anchorHeight > otherHeight ? |
| otherHeight / anchorHeight : |
| anchorHeight / otherHeight; |
| if (ratio > 0.9f) { |
| if (pctRow > 0.8f) { |
| |
| outIsStart[b][row] = false; |
| } |
| else { |
| outIsStart[b][col] = false; |
| } |
| } |
| } |
| } |
|
|
| template<uint32_t NumWarps, typename T, int32_t I_CELL_SIZE> |
| __global__ |
| void device_a2a_adjacency_build_grid(const int32_t *ptrQuadCts, |
| torch::PackedTensorAccessor64<T, 3> embedQuads, |
| torch::PackedTensorAccessor64<int32_t, 4> outGridCells, |
| torch::PackedTensorAccessor64<int32_t, 3> outQuadCells) |
| { |
| constexpr T MIN_T = std::numeric_limits<T>::min(); |
| constexpr T MAX_T = std::numeric_limits<T>::max(); |
| constexpr uint32_t WARP_SIZE = 32; |
| constexpr uint32_t BLOCK_SIZE = NumWarps * WARP_SIZE; |
| constexpr uint32_t FULL_WARP = 0xFFFFFFFF; |
| constexpr uint32_t FIRST_16_THREADS = 0x0FFFF; |
| constexpr T CELL_SIZE = I_CELL_SIZE; |
| constexpr T INV_CELL_SIZE = 1 / CELL_SIZE; |
|
|
| const uint32_t b = blockIdx.z; |
|
|
| const uint32_t quadCt = ptrQuadCts[b]; |
| const uint32_t quadIdx = blockIdx.y; |
|
|
| if (quadIdx >= quadCt) { |
| return; |
| } |
|
|
| const uint32_t threadRank = threadIdx.x; |
| const uint32_t localThreadRank = threadRank & 0x1F; |
|
|
| auto exQuads = embedQuads[b]; |
|
|
| const uint32_t numCells[2] = { outGridCells.size(2), outGridCells.size(1) }; |
|
|
| const uint32_t numRows = outGridCells.size(1); |
| const uint32_t numCols = outGridCells.size(2); |
|
|
| |
| |
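    // Lanes 0-7 load the quad's eight vertex coordinates and lanes 8-15 their negations, so a single
    // strided min-reduction yields the bounding box: lanes 0/1 hold min x/y, lanes 8/9 the negated max x/y.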
| T sign = localThreadRank < 8 ? 1.0f : -1.0f; |
| T myVal = sign * (localThreadRank < 16 ? exQuads[localThreadRank & 0x7][quadIdx] : MIN_T); |
| #pragma unroll |
| for (uint32_t offset = 2; offset < 8; offset <<= 1) { |
| T nextVal = __shfl_down_sync(FIRST_16_THREADS, myVal, offset); |
| myVal = min(myVal, nextVal); |
| } |
| const uint32_t cellVal = max(0.0f, sign * INV_CELL_SIZE * myVal); |
|
|
| uint32_t minCell[2] = { __shfl_sync(FULL_WARP, cellVal, 0), __shfl_sync(FULL_WARP, cellVal, 1) }, |
| maxCell[2] = { __shfl_sync(FULL_WARP, cellVal, 8), __shfl_sync(FULL_WARP, cellVal, 9) }; |
|
|
| #pragma unroll |
| for (uint32_t i = 0; i < 2; ++i) { |
| maxCell[i] = min(numCells[i] - 1, maxCell[i]); |
| } |
|
|
| const uint32_t sizes[2] = { maxCell[0] - minCell[0] + 1, maxCell[1] - minCell[1] + 1 }; |
|
|
| const uint32_t totalCells = sizes[0] * sizes[1]; |
|
|
| auto exGridCells = outGridCells[b]; |
|
|
| for (uint32_t i = threadRank; i < totalCells; i += BLOCK_SIZE) { |
| uint32_t row = minCell[1] + i / sizes[0]; |
| uint32_t col = minCell[0] + i % sizes[0]; |
|
|
| int32_t *pCell = exGridCells[row][col].data(); |
|
|
| |
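            // Slot 0 of each grid cell stores its occupancy count; quad indices are appended from slot 1 onwards.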
| int32_t storeOff = atomicAdd(pCell, 1) + 1; |
| pCell[storeOff] = quadIdx; |
| } |
|
|
| if (threadRank < 2) { |
| outQuadCells[b][quadIdx][threadRank] = minCell[threadRank]; |
| } else if (threadRank < 4) { |
| outQuadCells[b][quadIdx][threadRank] = maxCell[threadRank - 2]; |
| } |
| } |
|
|
| typedef uint8_t visit_mask_t; |
|
|
| template<uint32_t NumWarps, typename T> |
| __global__ |
| void device_a2a_adjacency_with_grid(const int32_t *ptrQuadCts, |
| T iouThreshold, |
| torch::PackedTensorAccessor64<T, 3> allEmbedQuads, |
| torch::PackedTensorAccessor64<int32_t, 4> allCells, |
| torch::PackedTensorAccessor64<int32_t, 3> allQuadExtents, |
| torch::PackedTensorAccessor64<bool, 2> outIsStart, |
| torch::PackedTensorAccessor64<int32_t, 2> outAdjCounts, |
| torch::PackedTensorAccessor64<int32_t, 3> outSparseAdj) |
| { |
| constexpr T MIN_T = std::numeric_limits<T>::min(); |
| constexpr T MAX_T = std::numeric_limits<T>::max(); |
| constexpr uint32_t WARP_SIZE = 32; |
| constexpr uint32_t BLOCK_SIZE = NumWarps * WARP_SIZE; |
|
|
| const uint32_t b = blockIdx.z; |
|
|
| const uint32_t quadCt = ptrQuadCts[b]; |
| const uint32_t quadIdx = blockIdx.y; |
|
|
| if (quadIdx >= quadCt) { |
| return; |
| } |
|
|
| const uint32_t threadRank = threadIdx.x; |
|
|
| auto exQuads = allEmbedQuads[b]; |
|
|
| __shared__ T s_quadVerts[8]; |
| __shared__ uint32_t s_quadExtent[4]; |
| extern __shared__ uint32_t s_alreadyVisited[]; |
|
|
| if (threadRank < 8) { |
| s_quadVerts[threadRank] = exQuads[threadRank][quadIdx]; |
| } else if (threadRank < 12) { |
| s_quadExtent[threadRank - 8] = reinterpret_cast<uint32_t*>(allQuadExtents[b][quadIdx].data())[threadRank - 8]; |
| } |
|
|
| uint32_t zeroTerm = (quadCt + 31u) >> 5u; |
| for (uint32_t col = threadRank; col < zeroTerm; col += BLOCK_SIZE) { |
| s_alreadyVisited[col] = 0; |
| } |
|
|
| __syncthreads(); |
|
|
| auto exCells = allCells[b]; |
| auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data()); |
| auto exAdjValues = outSparseAdj[b][quadIdx].data(); |
|
|
| T *exData = allEmbedQuads[b].data(); |
|
|
| const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds(); |
|
|
| const uint32_t startCol = s_quadExtent[0], |
| endCol = s_quadExtent[2]; |
| for (uint32_t row = s_quadExtent[1], endRow = s_quadExtent[3]; row <= endRow; ++row) { |
| auto rowCells = exCells[row]; |
|
|
| for (uint32_t col = startCol; col <= endCol; ++col) { |
| auto colCells = reinterpret_cast<const uint32_t*>(rowCells[col].data()); |
|
|
| const uint32_t ct = colCells[0]; |
|
|
| for (uint32_t i = threadRank + 1; i <= ct; i += BLOCK_SIZE) { |
| const uint32_t otherIdx = colCells[i]; |
|
|
| const uint32_t maskIdx = otherIdx >> 5; |
| const uint32_t maskBit = 1 << (otherIdx & 0x1F); |
|
|
| const bool alreadyVisited = atomicOr(s_alreadyVisited + maskIdx, maskBit) & maskBit; |
|
|
| if (!alreadyVisited) { |
| const auto bdsOther = StridedEmbedQuad_<T>{ exData + otherIdx * allEmbedQuads.stride(2), allEmbedQuads.stride(1) }.Bounds(); |
|
|
| T pctAnchor, pctOther, iou; |
| thrust::tie(pctAnchor, pctOther, iou) = geometry_region_sizes(bdsAnchor, bdsOther); |
|
|
| if (iou >= iouThreshold) { |
| auto validGroup = cg::coalesced_threads(); |
|
|
| uint32_t storeOff = 0; |
| if (validGroup.thread_rank() == 0) { |
| storeOff = atomicAdd(exAdjCounts + quadIdx, validGroup.size()); |
| } |
| storeOff = validGroup.shfl(storeOff, 0) + validGroup.thread_rank(); |
|
|
| exAdjValues[storeOff] = otherIdx; |
|
|
| if (otherIdx > quadIdx) { |
| outIsStart[b][otherIdx] = false; |
| } |
| } else if (pctAnchor > 0.8f || pctOther > 0.8f) { |
| T anchorHeight = bdsAnchor.Height(); |
| T otherHeight = bdsOther.Height(); |
|
|
| T ratio = anchorHeight > otherHeight ? |
| otherHeight / anchorHeight : |
| anchorHeight / otherHeight; |
| if (ratio > 0.9f) { |
| if (pctAnchor > 0.8f) { |
| |
| outIsStart[b][quadIdx] = false; |
| } else { |
| outIsStart[b][otherIdx] = false; |
| } |
| } |
| } |
| } |
| } |
| } |
| } |
| } |
|
|
| __global__ |
| void device_flatten_graph_iterative(const int32_t *ptrQuadCts, |
| torch::PackedTensorAccessor64<bool, 2> allIsStart, |
| volatile uint32_t *allAdjCounts, |
| volatile uint32_t *allAdjValues |
| #ifdef NMS_VERIFY_CORRECTNESS |
| , int32_t *maxDepth |
| #endif |
| ) |
| { |
| constexpr uint32_t WARP_SIZE = 32; |
| constexpr uint32_t VISIT_STACK_SIZE = 9; |
| constexpr uint32_t TERM_VALUE = std::numeric_limits<uint32_t>::max(); |
|
|
| constexpr visit_mask_t VISITED_MASK = 0b001; |
| constexpr visit_mask_t ADDED_MASK = 0b010; |
| constexpr visit_mask_t QUEUED_MASK = 0b100; |
| constexpr visit_mask_t QUEUED_OR_VISITED_MASK = VISITED_MASK | QUEUED_MASK; |
|
|
| const uint32_t b = blockIdx.z; |
| const uint32_t anchorRow = blockIdx.y; |
|
|
| const uint32_t quadCt = ptrQuadCts[b]; |
|
|
| |
| |
| if (anchorRow >= quadCt) { |
| return; |
| } |
|
|
| auto isStart = allIsStart[b].data(); |
|
|
| const uint32_t threadRank = threadIdx.x; |
|
|
| extern __shared__ visit_mask_t s_visitedMask[]; |
|
|
| #ifndef NMS_VERIFY_CORRECTNESS |
| |
| |
| |
| |
| |
| const bool anchorIsStart = isStart[anchorRow]; |
| if (!anchorIsStart) { |
| return; |
| } |
| #endif |
|
|
| uint32_t *pIntVisitedMask = reinterpret_cast<uint32_t*>(s_visitedMask); |
| uint32_t zeroTerm = (quadCt + 3) >> 2; |
| for (uint32_t col = threadRank; col < zeroTerm; col += blockDim.x) { |
| pIntVisitedMask[col] = 0; |
| } |
|
|
| __syncthreads(); |
|
|
| const uint32_t maxExCount = allIsStart.size(1); |
| auto adjCounts = allAdjCounts + (b * maxExCount); |
| auto adjValues = allAdjValues + (b * maxExCount * maxExCount); |
|
|
| auto adjAnchorValues = adjValues + (anchorRow * maxExCount); |
| |
| |
| |
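    // Mark the anchor's existing adjacency as ADDED so the DFS below does not append those columns again.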
| for (uint32_t i = threadRank, ct = adjCounts[anchorRow]; i < ct; i += blockDim.x) { |
| const auto adjCol = adjAnchorValues[i]; |
| s_visitedMask[adjCol] = ADDED_MASK; |
| } |
|
|
| __syncthreads(); |
|
|
| if (threadRank == 0) { |
| s_visitedMask[anchorRow] |= QUEUED_MASK; |
| } |
|
|
| __syncthreads(); |
|
|
| |
| |
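    // Only the first warp runs the DFS; the remaining warps were only needed for the shared-memory
    // initialisation above.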
| if (threadRank >= WARP_SIZE) { |
| return; |
| } |
|
|
| uint32_t visitStack[VISIT_STACK_SIZE]; |
| visitStack[0] = TERM_VALUE; |
| visitStack[1] = anchorRow; |
| #ifndef NDEBUG |
| for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) { |
| visitStack[i] = TERM_VALUE; |
| } |
| #endif |
| int32_t visitPtr = 1; |
|
|
| |
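    // Warp-cooperative DFS over the adjacency graph: each lane keeps a small private stack, and every
    // iteration the warp pops and expands the smallest pending column across all lanes.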
| for (uint32_t dfsIter = 0; true; ++dfsIter) { |
| #ifdef NMS_VERIFY_CORRECTNESS |
| assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE); |
| #endif |
| const uint32_t threadNextCol = visitStack[visitPtr]; |
| const uint32_t warpNextCol = __reduce_min_full_warp(threadNextCol); |
|
|
| |
| |
| if (threadNextCol == warpNextCol) { |
| #ifndef NDEBUG |
| |
| visitStack[visitPtr] = TERM_VALUE; |
| #endif |
| --visitPtr; |
| } |
|
|
| |
| |
| if (warpNextCol == TERM_VALUE) { |
| break; |
| } |
|
|
| const uint32_t procRow = warpNextCol; |
|
|
| __syncthreads(); |
|
|
| bool isAlreadyVisited = s_visitedMask[procRow] & VISITED_MASK; |
|
|
| if (isAlreadyVisited) { |
| continue; |
| } |
|
|
| const uint32_t procAdjCount = adjCounts[procRow]; |
| auto procAdjValues = adjValues + (procRow * maxExCount); |
|
|
| for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) { |
| uint32_t adjCol = procAdjValues[i]; |
|
|
| auto group = cg::coalesced_threads(); |
| |
| |
| |
| |
| adjCol = group.shfl(adjCol, (group.thread_rank() + dfsIter) % group.size()); |
|
|
| |
| |
| |
| |
| |
| const auto oldMask = s_visitedMask[adjCol]; |
| auto newMask = oldMask; |
|
|
| bool alreadyAdded = oldMask & ADDED_MASK; |
|
|
| const uint32_t gThreadRank = group.thread_rank(); |
| uint32_t notAddedBallot = group.ballot(!alreadyAdded); |
| if (notAddedBallot) { |
| |
| |
| |
| |
| |
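                // All lanes in the group read the same base offset; each lane's slot comes from the ballot
                // bits below its rank, and lane 0 bumps the shared count afterwards.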
| const uint32_t globalStoreOff = adjCounts[anchorRow]; |
| |
| const uint32_t localStoreOff = __popc(notAddedBallot & ((1 << gThreadRank) - 1)); |
|
|
| if (!alreadyAdded) { |
| adjAnchorValues[globalStoreOff + localStoreOff] = adjCol; |
| if (adjCol > anchorRow) { |
| |
| isStart[adjCol] = false; |
| } |
| newMask |= ADDED_MASK; |
| } |
|
|
| |
| if (gThreadRank == 0) { |
| adjCounts[anchorRow] += __popc(notAddedBallot); |
| } |
| } |
|
|
| bool alreadyHandled = oldMask & QUEUED_OR_VISITED_MASK; |
|
|
| if (!alreadyHandled) { |
| #ifdef NMS_VERIFY_CORRECTNESS |
| newMask |= QUEUED_MASK; |
| ++visitPtr; |
| assert(visitPtr < VISIT_STACK_SIZE); |
| atomicMax(maxDepth, visitPtr); |
| visitStack[visitPtr] = adjCol; |
| #else |
| |
| if (visitPtr < VISIT_STACK_SIZE - 1) { |
| newMask |= QUEUED_MASK; |
| ++visitPtr; |
| visitStack[visitPtr] = adjCol; |
| } |
| #endif |
| } |
|
|
| if (newMask != oldMask) { |
| s_visitedMask[adjCol] = newMask; |
| } |
| } |
|
|
| |
| __syncthreads(); |
| } |
| } |
|
|
| void add_to_set(const torch::TensorAccessor<int32_t, 1>& adjCounts, |
| const torch::TensorAccessor<int32_t, 2>& adjValues, |
| int32_t row, |
| std::unordered_set<int32_t>& possible) |
| { |
| if (possible.count(row)) { |
| return; |
| } |
|
|
| possible.insert(row); |
|
|
| const int32_t adjCount = adjCounts[row]; |
| auto values = adjValues[row].data(); |
|
|
| for (int32_t i = 0; i < adjCount; ++i) { |
| const int32_t col = values[i]; |
| add_to_set(adjCounts, adjValues, col, possible); |
| } |
| } |
|
|
| void cpu_flatten_graph(const int32_t *ptrQuadCts, |
| torch::Tensor isStartTensorGPU, |
| torch::Tensor adjCountsTensorGPU, |
| torch::Tensor adjValuesTensorGPU) |
| { |
| auto isStartTensor = isStartTensorGPU.cpu(); |
| auto adjCountsTensor = adjCountsTensorGPU.cpu(); |
| auto adjValuesTensor = adjValuesTensorGPU.cpu(); |
|
|
| auto allIsStart = isStartTensor.accessor<bool, 2>(); |
| auto allAdjCounts = adjCountsTensor.accessor<int32_t, 2>(); |
| auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>(); |
|
|
| for (int32_t b = 0; b < allAdjCounts.size(0); ++b) { |
| const int32_t quadCt = ptrQuadCts[b]; |
|
|
| for (int32_t row = 0; row < quadCt; ++row) { |
| std::unordered_set<int32_t> fullAdjSet; |
| add_to_set(allAdjCounts[b], allAdjValues[b], row, fullAdjSet); |
|
|
| int32_t &currCt = allAdjCounts[b][row]; |
| int32_t *currValues = allAdjValues[b][row].data(); |
| std::unordered_set<int32_t> existingSet{ currValues, currValues + currCt }; |
|
|
| for (int32_t adjCol : fullAdjSet) { |
| if (existingSet.count(adjCol)) { |
| continue; |
| } |
|
|
| currValues[currCt] = adjCol; |
| ++currCt; |
|
|
| if (adjCol > row) { |
| allIsStart[b][adjCol] = false; |
| } |
| } |
| } |
| } |
|
|
| isStartTensorGPU.copy_(isStartTensor); |
| adjCountsTensorGPU.copy_(adjCountsTensor); |
| adjValuesTensorGPU.copy_(adjValuesTensor); |
| } |
|
|
|
|
| __global__ |
| void device_a2a_adj_cleanup(const int32_t *counts, |
| torch::PackedTensorAccessor64<uint8_t, 3> inOutAdjacency) |
| { |
| const uint32_t b = blockIdx.y; |
| const uint32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x; |
| const uint32_t numQuads = counts[b]; |
| const uint32_t row = jobIdx / numQuads; |
| const uint32_t col = jobIdx % numQuads; |
|
|
| if (row >= numQuads) { |
| return; |
| } |
|
|
| auto adjacency = inOutAdjacency[b]; |
|
|
| bool rowPivot = adjacency[row][row] > 0; |
| bool colPivot = adjacency[col][col] > 0; |
|
|
| if (!rowPivot || !colPivot) { |
| adjacency[row][col] = 0; |
| } |
| } |
|
|
| template<uint32_t NumWarps, typename T> |
| __global__ |
| void device_a2a_collapse(torch::PackedTensorAccessor64<int32_t, 1> quadCounts, |
| torch::PackedTensorAccessor64<T, 3> allEmbedQuads, |
| torch::PackedTensorAccessor64<bool, 2> allIsLeadRow, |
| const int64_t *regionCounts, |
| torch::PackedTensorAccessor64<int32_t, 2> allAdjCounts, |
| torch::PackedTensorAccessor64<int32_t, 3> allAdjValues, |
| |
| torch::PackedTensorAccessor64<T, 3> outQuads, |
| T *outConf) |
| { |
| constexpr uint32_t WARP_SIZE = 32; |
| constexpr uint32_t FULL_WARP = 0xFFFFFFFF; |
| constexpr uint32_t BLOCK_WIDTH = NumWarps * WARP_SIZE; |
| constexpr size_t MERGE_QUAD_SIZE = sizeof(MergeQuad_<T>) / sizeof(T); |
|
|
    static_assert(NumWarps < WARP_SIZE, "NumWarps must be smaller than the warp size so the cross-warp reduction below fits in a single warp!");
|
|
| const uint32_t b = blockIdx.z; |
| const uint32_t row = blockIdx.y; |
|
|
| const int32_t quadCt = quadCounts[b]; |
|
|
| if (row >= quadCt) { |
| return; |
| } |
|
|
| |
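    // Only lead rows (the representative quad of each connected group) produce a merged output quad.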
| const auto isLeadRow = allIsLeadRow[b].data(); |
| if (!isLeadRow[row]) { |
| return; |
| } |
|
|
| const uint32_t threadRank = threadIdx.x; |
| const uint32_t localThreadRank = threadRank & 0x1F; |
| const uint32_t warpIdx = threadRank >> 5; |
|
|
| __shared__ T s_mergeQuad[MERGE_QUAD_SIZE]; |
|
|
| if constexpr (NumWarps > 1) { |
| if (threadRank < MERGE_QUAD_SIZE) { |
| s_mergeQuad[threadRank] = 0.0f; |
| } |
|
|
| __syncthreads(); |
| } |
|
|
| T *exData = allEmbedQuads[b].data(); |
|
|
| const int32_t adjCount = allAdjCounts[b][row]; |
| const int32_t *adjIdxs = allAdjValues[b][row].data(); |
|
|
| MergeQuad_<T> localMerge{ZeroInitTag{}}; |
|
|
| for (int32_t i = threadRank; i < adjCount; i += BLOCK_WIDTH) { |
| const int32_t currQuadIdx = adjIdxs[i]; |
| const StridedEmbedQuad_<T> qCurr{ exData + currQuadIdx * allEmbedQuads.stride(2), allEmbedQuads.stride(1) }; |
|
|
| localMerge.Append(qCurr); |
| } |
|
|
| T *mqV = reinterpret_cast<T*>(&localMerge); |
| #pragma unroll |
| for (uint32_t offset = 1; offset < WARP_SIZE; offset <<= 1) { |
| T mergeFactor = offset + localThreadRank < 32; |
| #pragma unroll |
| for (uint32_t i = 0; i < MERGE_QUAD_SIZE; ++i) { |
| mqV[i] += mergeFactor * __shfl_down_sync(FULL_WARP, mqV[i], offset); |
| } |
| } |
| #pragma unroll |
| for (uint32_t i = 0; i < MERGE_QUAD_SIZE; ++i) { |
| mqV[i] = __shfl_sync(FULL_WARP, mqV[i], 0); |
| } |
|
|
| |
| if (NumWarps > 1 && adjCount > WARP_SIZE) { |
| if (localThreadRank < MERGE_QUAD_SIZE) { |
| atomicAdd(s_mergeQuad + localThreadRank, mqV[localThreadRank]); |
| } |
|
|
| __syncthreads(); |
|
|
| mqV = s_mergeQuad; |
| } |
|
|
| |
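    // Output slot = number of lead rows in earlier batch elements plus the number of earlier lead rows
    // in this element, accumulated cooperatively and then reduced across the warp(s).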
| uint32_t writePosition = 0; |
| for (int32_t i = threadRank; i < b; i += BLOCK_WIDTH) { |
| writePosition += regionCounts[i]; |
| } |
|
|
| const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow); |
| for (int32_t i = threadRank; i < row; i += BLOCK_WIDTH) { |
| if (pCurrIsLeadRow[i]) { |
| ++writePosition; |
| } |
| } |
| |
| writePosition = __reduce_add_full_warp(writePosition); |
| |
| if constexpr (NumWarps > 1) { |
| __shared__ uint32_t s_threadWritePositions[NumWarps]; |
| if (localThreadRank == 0) { |
| s_threadWritePositions[warpIdx] = writePosition; |
| } |
| __syncthreads(); |
| writePosition = threadRank < NumWarps ? s_threadWritePositions[threadRank] : 0; |
| writePosition = __reduce_add_full_warp(writePosition); |
| } |
|
|
| if (threadRank >= 9) { |
| return; |
| } |
|
|
| const T sumConfidence = mqV[8]; |
| const T numQuads = mqV[9]; |
| const T divisor = threadRank < 8 ? sumConfidence : numQuads; |
|
|
| const T myVal = mqV[threadRank] / divisor; |
|
|
| auto writeVerts = outQuads[writePosition].data(); |
|
|
| if (threadRank < 8) { |
| writeVerts[threadRank] = myVal; |
| } else { |
| outConf[writePosition] = myVal; |
| } |
| } |
|
|
| struct CollapseRowsResult { |
| torch::Tensor ExCounts; |
| torch::Tensor StridedMergeQuads; |
| int32_t TotalNumQuads; |
| |
| torch::Tensor QuadIds; |
| int32_t ImageWidth; |
| int32_t ImageHeight; |
| }; |
|
|
| template<typename scalar_t> |
| CollapseRowsResult collapse_rows( |
| torch::Tensor quads, torch::Tensor probs, scalar_t probThreshold, scalar_t iouThreshold |
| ) |
| { |
| if (! quads.is_contiguous()) { |
| throw std::runtime_error("Expected `quads` to be contiguous!"); |
| } |
|
|
| if ((quads.size(2) % 32) != 0) { |
| throw std::runtime_error("Expected the width of the `quads` buffer to be a multiple of 32!"); |
| } |
|
|
| int32_t imageWidth = quads.size(2) * 4; |
| int32_t imageHeight = quads.size(1) * 4; |
|
|
| quads = quads.reshape({ quads.size(0), -1, 32, 4, 2 }); |
| probs = probs.reshape({ probs.size(0), -1, 32 }); |
|
|
| if (quads.size(0) != probs.size(0) || quads.size(1) != probs.size(1)) { |
| throw std::runtime_error("Dimension mismatch between `quads` and `probs`"); |
| } |
|
|
| |
| auto counts = torch::zeros({ quads.size(0) + 1 }, quads.options().dtype(torch::kInt32)); |
|
|
| int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t); |
| auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options()); |
|
|
| auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) }, |
| std::numeric_limits<int32_t>::max(), |
| counts.options().dtype(torch::kInt32)); |
|
|
| dim3 blockSize(32, 3, 1); |
| dim3 gridSize(1, |
| div_up(quads.size(1), blockSize.y), |
| quads.size(0)); |
|
|
| device_row_collapse KERNEL_ARG2(gridSize, blockSize) ( |
| quads.packed_accessor64<scalar_t, 5>(), |
| probs.packed_accessor64<scalar_t, 3>(), |
| probThreshold, iouThreshold, |
| counts.packed_accessor64<int32_t, 1>(), |
| rowMergeTensor.packed_accessor64<scalar_t, 3>(), |
| idsTensor.packed_accessor64<int32_t, 2>() |
| ); |
|
|
| #ifdef NMS_VERIFY_CORRECTNESS |
| static std::unordered_set<int32_t> s_quadIds; |
| auto cpuIdsTensor = idsTensor.cpu(); |
| const int32_t *idsPtr = cpuIdsTensor.data_ptr<int32_t>(); |
| if (s_quadIds.empty()) { |
| s_quadIds.insert(idsPtr, idsPtr + idsTensor.numel()); |
| } else { |
| std::unordered_set<int32_t> otherIds{ idsPtr, idsPtr + idsTensor.numel() }; |
|
|
| if (s_quadIds != otherIds) { |
| throw std::runtime_error("Inconsistent Ids!"); |
| } |
| } |
| #endif |
|
|
| |
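    // The last element of `counts` holds the batch-wide total; the per-image counts precede it.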
| int32_t totalQuads = counts[-1].item<int32_t>(); |
|
|
| counts = counts.slice( 0, 0, counts.size(0) - 1); |
|
|
| int64_t maxExCount; |
| if (counts.size(0) > 1) { |
| maxExCount = counts.max().item<int32_t>(); |
| } else { |
| maxExCount = totalQuads; |
| } |
|
|
| static bool s_sortOrder = false; |
|
|
| rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount); |
| idsTensor = idsTensor.slice(1, 0, maxExCount); |
| auto order = torch::argsort(idsTensor, 1, s_sortOrder); |
|
|
| auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor); |
|
|
| rowMergeTensor = torch::gather(rowMergeTensor, 2, embOrder); |
| idsTensor = torch::gather(idsTensor, 1, order); |
|
|
| return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight }; |
| } |
|
|
|
|
|
|
| void verify_row(const torch::TensorAccessor<int32_t, 1> &adjCounts, |
| const torch::TensorAccessor<int32_t, 2> &adjValues, |
| int32_t row) |
| { |
| |
| |
| |
| std::unordered_set<int32_t> possible; |
| add_to_set(adjCounts, adjValues, row, possible); |
|
|
| std::unordered_set<int32_t> thisRow{ row }; |
| const int32_t thisCount = adjCounts[row]; |
| auto thisValues = adjValues[row].data(); |
| thisRow.insert(thisValues, thisValues + thisCount); |
|
|
| if (thisRow != possible) { |
| throw std::runtime_error("The merge_up algorithm is not correct!"); |
| } |
| } |
|
|
| struct AdjacencyResult { |
| |
| |
| torch::Tensor IsLeadRow; |
| |
| |
| torch::Tensor AdjCounts; |
| |
| |
| torch::Tensor AdjValues; |
| int64_t MaxExCount; |
| }; |
|
|
| template<typename T> |
| void cpu_a2a_adjacency_sparse(const int32_t *ptrQuadCts, |
| const T iouThreshold, |
| torch::Tensor embedQuadsTensor, |
| torch::Tensor outIsStartTensorGPU, |
| torch::Tensor outAdjCountsTensorGPU, |
| torch::Tensor outSparseAdjTensorGPU) |
| { |
| embedQuadsTensor = embedQuadsTensor.cpu(); |
| auto outIsStartTensor = outIsStartTensorGPU.cpu(); |
| auto outAdjCountsTensor = outAdjCountsTensorGPU.cpu(); |
| auto outSparseAdjTensor = outSparseAdjTensorGPU.cpu(); |
|
|
| auto embedQuads = embedQuadsTensor.accessor<T, 3>(); |
| auto isStart = outIsStartTensor.accessor<bool, 2>(); |
| auto adjCounts = outAdjCountsTensor.accessor<int32_t, 2>(); |
| auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>(); |
|
|
| for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) { |
| const int32_t quadCt = ptrQuadCts[b]; |
|
|
| T *exData = embedQuads[b].data(); |
|
|
| for (int32_t row = 0; row < quadCt; ++row) { |
| const auto qRow = StridedEmbedQuad_<T>{ exData + row, embedQuads.stride(1) }.Bounds(); |
|
|
| for (int32_t col = 0; col < quadCt; ++col) { |
| const auto qCol = StridedEmbedQuad_<T>{ exData + col, embedQuads.stride(1) }.Bounds(); |
|
|
| T pctRow, pctCol, iou; |
| thrust::tie(pctRow, pctCol, iou) = geometry_region_sizes(qRow, qCol); |
|
|
| if (iou >= iouThreshold) { |
| int32_t &storeIdx = adjCounts[b][row]; |
| adjValues[b][row][storeIdx] = col; |
| ++storeIdx; |
| if (row < col) { |
| isStart[b][col] = false; |
| } |
| } else if (pctRow > 0.8f || pctCol > 0.8f) { |
| T anchorHeight = qRow.Height(); |
| T otherHeight = qCol.Height(); |
|
|
| T ratio = anchorHeight > otherHeight ? |
| otherHeight / anchorHeight : |
| anchorHeight / otherHeight; |
| if (ratio > 0.9f) { |
| if (pctRow > 0.8f) { |
| |
| isStart[b][row] = false; |
| } |
| else { |
| isStart[b][col] = false; |
| } |
| } |
| } |
| } |
| } |
| } |
|
|
| outIsStartTensorGPU.copy_(outIsStartTensor); |
| outAdjCountsTensorGPU.copy_(outAdjCountsTensor); |
| outSparseAdjTensorGPU.copy_(outSparseAdjTensor); |
| } |
|
|
| template<typename T> |
| std::string to_flat_string(torch::Tensor tensor) { |
| tensor = tensor.flatten(); |
|
|
| auto acc = tensor.accessor<T, 1>(); |
|
|
| std::ostringstream oss; |
| oss << "["; |
| if (acc.size(0) > 0) { |
| oss << acc[0]; |
| for (int64_t i = 1; i < acc.size(0); ++i) { |
| oss << ", " << acc[i]; |
| } |
| } |
| oss << "]"; |
| return oss.str(); |
| } |
|
|
| template<typename scalar_t> |
| AdjacencyResult compute_all_to_all_adjacency( |
| const CollapseRowsResult &collapseResult, |
| scalar_t iouThreshold) |
| { |
| torch::Tensor counts = collapseResult.ExCounts; |
|
|
| int64_t maxExCount; |
| if (counts.size(0) > 1) { |
| maxExCount = counts.max().item<int32_t>(); |
| } else { |
| maxExCount = collapseResult.TotalNumQuads; |
| } |
|
|
| auto isStartTensor = torch::ones({ counts.size(0), maxExCount }, counts.options().dtype(torch::kBool)); |
| auto adjCountsTensor = torch::zeros({ counts.size(0), maxExCount }, counts.options().dtype(torch::kInt32)); |
| #ifndef NMS_VERIFY_CORRECTNESS |
| auto adjValuesTensor = torch::empty({ counts.size(0), maxExCount, maxExCount }, counts.options().dtype(torch::kInt32)); |
| #else |
| auto adjValuesTensor = torch::full({ counts.size(0), maxExCount, maxExCount }, |
| 5000, |
| counts.options().dtype(torch::kInt32)); |
| #endif |
|
|
| #ifdef NMS_VERIFY_CORRECTNESS |
| auto cpuAdjValuesTensor = adjValuesTensor.cpu(); |
| auto cpuAdjCountsTensor = adjCountsTensor.cpu(); |
| auto cpuIsStartTensor = isStartTensor.cpu(); |
| #endif |
|
|
| size_t smemSize; |
| dim3 gridSize, blockSize; |
|
|
|
|
| uint32_t totalWork = maxExCount * maxExCount; |
|
|
| blockSize = dim3{96, 1}; |
| gridSize = dim3{div_up(totalWork, blockSize.x), |
| static_cast<uint32_t>(counts.size(0))}; |
|
|
| |
| device_a2a_adjacency_sparse<scalar_t> KERNEL_ARG2(gridSize, blockSize) ( |
| counts.data_ptr<int32_t>(), |
| iouThreshold, |
| collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(), |
| isStartTensor.packed_accessor64<bool, 2>(), |
| adjCountsTensor.packed_accessor64<int32_t, 2>(), |
| adjValuesTensor.packed_accessor64<int32_t, 3>() |
| ); |
|
|
|
|
| #ifdef NMS_VERIFY_CORRECTNESS |
| auto cpuCounts = counts.cpu(); |
|
|
| cpu_a2a_adjacency_sparse<scalar_t>(cpuCounts.data_ptr<int32_t>(), iouThreshold, |
| collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor); |
|
|
| adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, 2)); |
|
|
| assert(torch::all(cpuAdjCountsTensor == adjCountsTensor.cpu()).item<bool>()); |
| assert(torch::all(cpuIsStartTensor == isStartTensor.cpu()).item<bool>()); |
| assert(torch::all(cpuAdjValuesTensor == adjValuesTensor.cpu()).item<bool>()); |
|
|
| std::cout << "\tA2A Is Start Count: " << isStartTensor.sum(torch::kInt32).item<int32_t>() |
| << ", Most Adjacent: " << adjCountsTensor.max().item<int32_t>() << std::endl; |
|
|
| auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options()); |
| #endif |
|
|
| blockSize = dim3{ 128, 1, 1 }; |
| gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) }; |
| smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t); |
|
|
| device_flatten_graph_iterative KERNEL_ARG3(gridSize, blockSize, smemSize) ( |
| counts.data_ptr<int32_t>(), |
| isStartTensor.packed_accessor64<bool, 2>(), |
| reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()), |
| reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>()) |
| #ifdef NMS_VERIFY_CORRECTNESS |
| , maxDepthTensor.data_ptr<int32_t>() |
| #endif |
| ); |
|
|
| #ifdef NMS_VERIFY_CORRECTNESS |
| cpu_flatten_graph(cpuCounts.data_ptr<int32_t>(), cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor); |
|
|
| cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, 2)); |
| adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, 2)); |
|
|
| torch::Tensor diffStartIdxs = (cpuIsStartTensor != isStartTensor.cpu()).nonzero_numpy()[0]; |
|
|
| assert(diffStartIdxs.numel() == 0); |
|
|
| torch::Tensor diffCountIdxs = (cpuAdjCountsTensor != adjCountsTensor.cpu()).nonzero_numpy()[0]; |
|
|
| assert(diffCountIdxs.numel() == 0); |
|
|
| auto diffValuesTensor = torch::any(cpuAdjValuesTensor != adjValuesTensor.cpu(), 2, false).flatten().nonzero().flatten(); |
|
|
| std::cout << "\t\tDiff Indices: " << to_flat_string<int64_t>(diffValuesTensor) << std::endl; |
|
|
| auto cpuDiffCountsTensor = cpuAdjCountsTensor.flatten().index({ diffValuesTensor }); |
| auto cpuDiffRowsTensor = cpuAdjValuesTensor.flatten(0, 1).index({ diffValuesTensor }); |
| auto gpuDiffRowsTensor = adjValuesTensor.cpu().flatten(0, 1).index({ diffValuesTensor }); |
|
|
| for (int64_t i = 0, ct = cpuDiffRowsTensor.size(0); i < ct; ++i) { |
| auto z = cpuDiffCountsTensor[i].item<int32_t>(); |
| auto diffRow = diffValuesTensor[i].item<int64_t>(); |
| std::cout << "\t\tRow " << diffRow << std::endl; |
| std::cout << "\t\t\tExpected: " << to_flat_string<int32_t>(cpuDiffRowsTensor[i].slice(0, 0, z + 1)) << std::endl; |
| std::cout << "\t\t\t GPU: " << to_flat_string<int32_t>(gpuDiffRowsTensor[i].slice(0, 0, z + 1)) << std::endl; |
| } |
|
|
| assert(diffValuesTensor.size(0) == 0); |
|
|
| std::cout << "\tA2A - Flatten - Is Start Count: " << isStartTensor.sum(torch::kInt32).item<int32_t>() |
| << ", Most Adjacent: " << adjCountsTensor.max().item<int32_t>() |
| << ", Max Depth: " << maxDepthTensor.item<int32_t>() << std::endl; |
|
|
| cpuIsStartTensor = isStartTensor.cpu(); |
| cpuAdjCountsTensor = adjCountsTensor.cpu(); |
| cpuAdjValuesTensor = adjValuesTensor.cpu(); |
| auto cpuCollapseIds = collapseResult.QuadIds.cpu(); |
|
|
| static std::vector<std::unordered_set<int32_t>> s_knownGroups; |
| static std::unordered_map<int32_t, std::unordered_set<int32_t>> s_groupLookup; |
|
|
| std::vector<std::unordered_set<int32_t>> idGroups; |
| decltype(s_groupLookup) groupLookup; |
| for (int64_t b = 0; b < counts.size(0); ++b) { |
| int64_t quadCt = cpuCounts[b].item<int32_t>(); |
| for (int64_t row = 0; row < quadCt; ++row) { |
| bool isLeadRow = cpuIsStartTensor[b][row].item<bool>(); |
| auto bCountsTensor = cpuAdjCountsTensor[b]; |
| auto bValuesTensor = cpuAdjValuesTensor[b]; |
| auto bCounts = bCountsTensor.accessor<int32_t, 1>(); |
| auto bValues = bValuesTensor.accessor<int32_t, 2>(); |
|
|
| auto bIdsTensor = cpuCollapseIds[b]; |
| auto bIds = bIdsTensor.accessor<int32_t, 1>(); |
|
|
| std::unordered_set<int32_t> sIds; |
| for (int32_t i = 0, ct = bCounts[row]; i < ct; ++i) { |
| int32_t col = bValues[row][i]; |
| int32_t id = bIds[col]; |
| sIds.insert(id); |
| } |
|
|
| if (sIds.empty()) { |
| throw std::runtime_error("The ids tensor is empty!"); |
| } |
|
|
| groupLookup[bIds[row]] = sIds; |
|
|
| if (isLeadRow) { |
| verify_row(bCounts, bValues, row); |
                idGroups.push_back(std::move(sIds));
| } |
| } |
| } |
|
|
| if (s_knownGroups.empty()) { |
        s_knownGroups = std::move(idGroups);
        s_groupLookup = std::move(groupLookup);
| } else { |
| |
| auto remOrigGroups = s_knownGroups; |
| auto remOrigGroupLookup = s_groupLookup; |
|
|
| std::vector<int32_t> quadIds; |
| for (auto &kv : remOrigGroupLookup) { |
| quadIds.push_back(kv.first); |
| } |
| for (int32_t qId : quadIds) { |
| assert(groupLookup.count(qId)); |
| } |
| assert(groupLookup.size() == remOrigGroupLookup.size()); |
|
|
| for (int32_t qId : quadIds) { |
| auto &oldGroup = remOrigGroupLookup[qId]; |
| auto &newGroup = groupLookup[qId]; |
|
|
| if (oldGroup == newGroup) { |
| remOrigGroupLookup.erase(qId); |
| groupLookup.erase(qId); |
| } else { |
| throw std::runtime_error("Group mismatch!"); |
| } |
| } |
|
|
| for (int i = idGroups.size() - 1; i >= 0; --i) { |
| for (int j = remOrigGroups.size() - 1; j >= 0; --j) { |
| auto &idGroup = idGroups[i]; |
| auto &knownGroup = remOrigGroups[j]; |
|
|
| if (idGroup == knownGroup) { |
| idGroups.erase(begin(idGroups) + i); |
| remOrigGroups.erase(begin(remOrigGroups) + j); |
| break; |
| } |
| } |
| } |
|
|
| if (!idGroups.empty() || !remOrigGroups.empty()) { |
| auto group_str = [] (auto &group) { |
| std::vector<int32_t> vGroup{ std::begin(group), std::end(group) }; |
| std::sort(std::begin(vGroup), std::end(vGroup)); |
|
|
| auto id_str = [] (int32_t id) { |
| std::ostringstream oss; |
| |
| oss << id; |
| return oss.str(); |
| }; |
|
|
| std::ostringstream oss; |
| oss << "[" << id_str(vGroup[0]); |
| for (size_t i = 1; i < vGroup.size(); ++i) { |
| oss << ", " << id_str(vGroup[i]); |
| } |
| oss << "]"; |
| return oss.str(); |
| }; |
|
|
| std::cout << "\tEncountered a difference in groups!" << std::endl |
| << "\t\tOrig groups:" << std::endl; |
| for (auto &group : remOrigGroups) { |
| std::cout << "\t\t\t" << group_str(group) << std::endl; |
| } |
| std::cout << "\t\tNew groups:" << std::endl; |
| for (auto &group : idGroups) { |
| std::cout << "\t\t\t" << group_str(group) << std::endl; |
| } |
| } |
| } |
| #endif |
|
|
| return { isStartTensor, adjCountsTensor, adjValuesTensor, maxExCount }; |
| } |
|
|
|
|
|
|
| template<typename scalar_t> |
| nms_result_t |
| all_to_all_collapse( |
| const CollapseRowsResult &collapseRowsRes, |
| const AdjacencyResult &adjResult) |
| { |
| auto counts = collapseRowsRes.ExCounts; |
| auto embedQuads = collapseRowsRes.StridedMergeQuads; |
|
|
| if (!embedQuads.is_contiguous()) { |
| throw std::runtime_error("Input embed quads were not contiguous!"); |
| } |
|
|
| torch::Tensor isLeadRow; |
| if (counts.size(0) == 1) { |
| isLeadRow = adjResult.IsLeadRow; |
| } else { |
| |
| |
| |
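        // Padded rows beyond an image's real quad count keep their initial IsLeadRow == true but have a
        // zero adjacency count; ANDing with AdjCounts > 0 filters them out.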
| isLeadRow = torch::logical_and(adjResult.IsLeadRow, adjResult.AdjCounts > 0); |
| } |
|
|
| auto regionCounts = isLeadRow.sum( 1, false, torch::kInt64); |
|
|
| const int64_t numOutQuads = counts.size(0) == 1 ? regionCounts.item<int64_t>() : regionCounts.sum().item<int64_t>(); |
|
|
| constexpr int32_t NUM_WARPS = 4; |
| dim3 blockSize(NUM_WARPS * 32, 1, 1); |
| dim3 gridSize(1, adjResult.MaxExCount, counts.size(0)); |
|
|
| torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options()); |
| torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options()); |
|
|
| device_a2a_collapse<NUM_WARPS, scalar_t> KERNEL_ARG2(gridSize, blockSize) ( |
| counts.packed_accessor64<int32_t, 1>(), |
| embedQuads.packed_accessor64<scalar_t, 3>(), |
| isLeadRow.packed_accessor64<bool, 2>(), |
| regionCounts.data_ptr<int64_t>(), |
| adjResult.AdjCounts.packed_accessor64<int32_t, 2>(), |
| adjResult.AdjValues.packed_accessor64<int32_t, 3>(), |
| outQuads.packed_accessor64<scalar_t, 3>(), |
| outConf.data_ptr<scalar_t>() |
| ); |
|
|
| return { outQuads, outConf, regionCounts }; |
| } |
|
|
| template<typename scalar_t> |
| nms_result_t cuda_quad_non_maximal_suppression_impl( |
| torch::Tensor quads, torch::Tensor probs, |
| scalar_t probThreshold, scalar_t iouThreshold, |
| int64_t maxRegions, bool verbose) |
| { |
| static const bool s_timerEnabled = true; |
| static const bool s_verboseLevel2 = true; |
|
|
| |
| if (quads.dim() == 4) { |
| |
| quads = quads.unsqueeze(0); |
| |
| probs = probs.unsqueeze(0); |
| } |
|
|
| |
|
|
| double msRowCollapse = -1, |
| msAdjacency = -1, |
| msA2ACollapse = -1, |
| msTotal = -1; |
|
|
| CollapseRowsResult collapseRows; |
| AdjacencyResult adjacency; |
| torch::Tensor retQuads, retConf, regionCounts; |
|
|
| { |
| CudaStoreTimer tTotal{msTotal, s_timerEnabled}; |
| { |
| CudaStoreTimer t{msRowCollapse, s_timerEnabled && verbose && s_verboseLevel2}; |
|
|
| |
| collapseRows = collapse_rows(quads, probs, probThreshold, iouThreshold); |
|
|
| if (collapseRows.TotalNumQuads == 0) { |
| return { |
| torch::empty({ 0, 4, 2 }, quads.options()), |
| torch::empty({ 0 }, probs.options()), |
| collapseRows.ExCounts.toType(torch::kInt64) |
| }; |
| } |
| } |
| { |
| CudaStoreTimer t{msAdjacency, s_timerEnabled && verbose && s_verboseLevel2}; |
| adjacency = compute_all_to_all_adjacency(collapseRows, iouThreshold); |
| } |
| { |
| CudaStoreTimer t{msA2ACollapse, s_timerEnabled && verbose && s_verboseLevel2}; |
| std::tie(retQuads, retConf, regionCounts) = all_to_all_collapse<scalar_t>(collapseRows, adjacency); |
| } |
| } |
|
|
| #ifndef NDEBUG |
| assert(regionCounts.sum().item<int64_t>() == retQuads.size(0)); |
| #endif |
|
|
| |
|
|
| if (s_timerEnabled && verbose) { |
| std::cout << "NMS Cuda " << retQuads.size(0) |
| << " - Row Collapse (" << quads.size(0) << ", " << quads.size(1) << ", " << quads.size(2) << ") - (" << collapseRows.TotalNumQuads << "): " << msRowCollapse << "ms" |
| << ", Adjacency (" << adjacency.AdjCounts.sum(torch::kInt32).item<int32_t>() << "): " << msAdjacency << "ms" |
| << ", A2A Collapse (" << retQuads.size(0) << "): " << msA2ACollapse << "ms" |
| << ", Total: " << msTotal << "ms" |
| << std::endl; |
| } |
|
|
| return { retQuads, retConf, regionCounts }; |
| } |
|
|
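// Illustrative usage sketch (names and threshold values are placeholders; shapes follow the checks in
// collapse_rows: `predQuads` is (B, H, W, 4, 2) with W a multiple of 32, `predProbs` is (B, H, W)):
//
//   torch::Tensor mergedQuads, mergedConf, regionCounts;
//   std::tie(mergedQuads, mergedConf, regionCounts) = cuda_quad_non_maximal_suppression(
//       predQuads, predProbs,
//       /*probThreshold=*/0.5f, /*iouThreshold=*/0.4f,
//       /*kernelHeight=*/3, /*kernelWidth=*/3,
//       /*maxRegions=*/1000, /*verbose=*/false);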
| nms_result_t cuda_quad_non_maximal_suppression( |
| torch::Tensor quads, torch::Tensor probs, |
| float probThreshold, float iouThreshold, |
| int64_t kernelHeight, int64_t kernelWidth, |
| int64_t maxRegions, bool verbose) |
| { |
| nms_result_t ret; |
|
|
| ret = cuda_quad_non_maximal_suppression_impl<float>( |
| quads.toType(torch::kFloat32), probs.toType(torch::kFloat32), |
| probThreshold, iouThreshold, |
| maxRegions, verbose |
| ); |
|
|
|
|
| return ret; |
| } |
|
|