Charles Blackmon-Luca commited on
Fix misaligned address error in quad NMS cuda implementation
Browse files
nemotron-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu
CHANGED
|
@@ -157,11 +157,8 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
|
|
| 157 |
torch::PackedTensorAccessor64<T, 3> allConfs,
|
| 158 |
T confThreshold, T iouThreshold,
|
| 159 |
torch::PackedTensorAccessor64<int32_t, 1> allOutCounts,
|
| 160 |
-
torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads
|
| 161 |
-
|
| 162 |
-
, torch::PackedTensorAccessor64<int32_t, 2> allOutIds
|
| 163 |
-
#endif
|
| 164 |
-
)
|
| 165 |
{
|
| 166 |
typedef InPlaceQuad_<T> Quadf;
|
| 167 |
static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!");
|
|
@@ -306,11 +303,9 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
|
|
| 306 |
}
|
| 307 |
|
| 308 |
write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1);
|
| 309 |
-
#ifdef NMS_VERIFY_CORRECTNESS
|
| 310 |
if (threadRank == 0) {
|
| 311 |
allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx;
|
| 312 |
}
|
| 313 |
-
#endif
|
| 314 |
}
|
| 315 |
|
| 316 |
if (threadRank == 0) {
|
|
@@ -321,9 +316,9 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
|
|
| 321 |
#undef threadRank
|
| 322 |
}
|
| 323 |
|
| 324 |
-
template<
|
| 325 |
__global__
|
| 326 |
-
void device_a2a_adjacency_sparse(const
|
| 327 |
T iouThreshold,
|
| 328 |
torch::PackedTensorAccessor64<T, 3> embedQuads,
|
| 329 |
torch::PackedTensorAccessor64<bool, 2> outIsStart,
|
|
@@ -332,7 +327,11 @@ void device_a2a_adjacency_sparse(const uint64_t punCounts,
|
|
| 332 |
{
|
| 333 |
const uint32_t b = blockIdx.y;
|
| 334 |
|
| 335 |
-
const int32_t quadCt =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 338 |
const int32_t row = jobIdx / quadCt;
|
|
@@ -343,7 +342,7 @@ void device_a2a_adjacency_sparse(const uint64_t punCounts,
|
|
| 343 |
return;
|
| 344 |
}
|
| 345 |
|
| 346 |
-
T* exData =
|
| 347 |
|
| 348 |
const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(),
|
| 349 |
qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
|
|
@@ -405,9 +404,9 @@ void device_a2a_adjacency_sparse(const uint64_t punCounts,
|
|
| 405 |
}
|
| 406 |
}
|
| 407 |
|
| 408 |
-
template<uint32_t NumWarps,
|
| 409 |
__global__
|
| 410 |
-
void device_a2a_adjacency_build_grid(const
|
| 411 |
torch::PackedTensorAccessor64<T, 3> embedQuads,
|
| 412 |
torch::PackedTensorAccessor64<int32_t, 4> outGridCells,
|
| 413 |
torch::PackedTensorAccessor64<int32_t, 3> outQuadCells)
|
|
@@ -423,10 +422,10 @@ void device_a2a_adjacency_build_grid(const uint64_t punCounts,
|
|
| 423 |
|
| 424 |
const uint32_t b = blockIdx.z;
|
| 425 |
|
| 426 |
-
const uint32_t quadCt =
|
| 427 |
const uint32_t quadIdx = blockIdx.y;
|
| 428 |
|
| 429 |
-
if (
|
| 430 |
return;
|
| 431 |
}
|
| 432 |
|
|
@@ -485,9 +484,9 @@ void device_a2a_adjacency_build_grid(const uint64_t punCounts,
|
|
| 485 |
|
| 486 |
typedef uint8_t visit_mask_t;
|
| 487 |
|
| 488 |
-
template<uint32_t NumWarps,
|
| 489 |
__global__
|
| 490 |
-
void device_a2a_adjacency_with_grid(const
|
| 491 |
T iouThreshold,
|
| 492 |
torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
|
| 493 |
torch::PackedTensorAccessor64<int32_t, 4> allCells,
|
|
@@ -503,10 +502,10 @@ void device_a2a_adjacency_with_grid(const uint64_t punCounts,
|
|
| 503 |
|
| 504 |
const uint32_t b = blockIdx.z;
|
| 505 |
|
| 506 |
-
const uint32_t quadCt =
|
| 507 |
const uint32_t quadIdx = blockIdx.y;
|
| 508 |
|
| 509 |
-
if (
|
| 510 |
return;
|
| 511 |
}
|
| 512 |
|
|
@@ -535,7 +534,7 @@ void device_a2a_adjacency_with_grid(const uint64_t punCounts,
|
|
| 535 |
auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data());
|
| 536 |
auto exAdjValues = outSparseAdj[b][quadIdx].data();
|
| 537 |
|
| 538 |
-
T *exData =
|
| 539 |
|
| 540 |
const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds();
|
| 541 |
|
|
@@ -599,9 +598,8 @@ void device_a2a_adjacency_with_grid(const uint64_t punCounts,
|
|
| 599 |
}
|
| 600 |
}
|
| 601 |
|
| 602 |
-
template<bool IsSingleExample>
|
| 603 |
__global__
|
| 604 |
-
void device_flatten_graph_iterative(const
|
| 605 |
torch::PackedTensorAccessor64<bool, 2> allIsStart,
|
| 606 |
volatile uint32_t *allAdjCounts,
|
| 607 |
volatile uint32_t *allAdjValues
|
|
@@ -622,14 +620,12 @@ void device_flatten_graph_iterative(const uint64_t punCounts,
|
|
| 622 |
const uint32_t b = blockIdx.z;
|
| 623 |
const uint32_t anchorRow = blockIdx.y;
|
| 624 |
|
| 625 |
-
const uint32_t quadCt =
|
| 626 |
|
| 627 |
// Only need to check this if there are multiple examples, since in the case of a single example,
|
| 628 |
// the grid is precisely sized to that quadCt
|
| 629 |
-
if
|
| 630 |
-
|
| 631 |
-
return;
|
| 632 |
-
}
|
| 633 |
}
|
| 634 |
|
| 635 |
auto isStart = allIsStart[b].data();
|
|
@@ -690,12 +686,13 @@ void device_flatten_graph_iterative(const uint64_t punCounts,
|
|
| 690 |
visitStack[1] = anchorRow;
|
| 691 |
#ifndef NDEBUG
|
| 692 |
for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) {
|
| 693 |
-
visitStack[i] =
|
| 694 |
}
|
| 695 |
#endif
|
| 696 |
int32_t visitPtr = 1;
|
| 697 |
|
| 698 |
-
|
|
|
|
| 699 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 700 |
assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE);
|
| 701 |
#endif
|
|
@@ -707,7 +704,7 @@ void device_flatten_graph_iterative(const uint64_t punCounts,
|
|
| 707 |
if (threadNextCol == warpNextCol) {
|
| 708 |
#ifndef NDEBUG
|
| 709 |
// This makes it easier to debug where the pointer is
|
| 710 |
-
visitStack[visitPtr] =
|
| 711 |
#endif
|
| 712 |
--visitPtr;
|
| 713 |
}
|
|
@@ -731,12 +728,15 @@ void device_flatten_graph_iterative(const uint64_t punCounts,
|
|
| 731 |
const uint32_t procAdjCount = adjCounts[procRow];
|
| 732 |
auto procAdjValues = adjValues + (procRow * maxExCount);
|
| 733 |
|
| 734 |
-
// Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp.
|
| 735 |
-
// The reason behind this is due to how otherwise, warp-0 will always get a new element, warp-1 iff the adj graph
|
| 736 |
-
// has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically,
|
| 737 |
-
// the warps have decreasing pressure. With the rotation mechanism, it helps to balance out stack usage.
|
| 738 |
for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) {
|
| 739 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 740 |
|
| 741 |
// This will set the queued flag for this column, if it's not already set.
|
| 742 |
// It also returns the old state. In our case, we only want to add this value to the
|
|
@@ -748,7 +748,6 @@ void device_flatten_graph_iterative(const uint64_t punCounts,
|
|
| 748 |
|
| 749 |
bool alreadyAdded = oldMask & ADDED_MASK;
|
| 750 |
|
| 751 |
-
auto group = cg::coalesced_threads();
|
| 752 |
const uint32_t gThreadRank = group.thread_rank();
|
| 753 |
uint32_t notAddedBallot = group.ballot(!alreadyAdded);
|
| 754 |
if (notAddedBallot) {
|
|
@@ -825,8 +824,7 @@ void add_to_set(const torch::TensorAccessor<int32_t, 1>& adjCounts,
|
|
| 825 |
}
|
| 826 |
}
|
| 827 |
|
| 828 |
-
|
| 829 |
-
void cpu_flatten_graph(const uint64_t punCounts,
|
| 830 |
torch::Tensor isStartTensorGPU,
|
| 831 |
torch::Tensor adjCountsTensorGPU,
|
| 832 |
torch::Tensor adjValuesTensorGPU)
|
|
@@ -840,7 +838,7 @@ void cpu_flatten_graph(const uint64_t punCounts,
|
|
| 840 |
auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>();
|
| 841 |
|
| 842 |
for (int32_t b = 0; b < allAdjCounts.size(0); ++b) {
|
| 843 |
-
const int32_t quadCt =
|
| 844 |
|
| 845 |
for (int32_t row = 0; row < quadCt; ++row) {
|
| 846 |
std::unordered_set<int32_t> fullAdjSet;
|
|
@@ -895,9 +893,9 @@ void device_a2a_adj_cleanup(const int32_t *counts,
|
|
| 895 |
}
|
| 896 |
}
|
| 897 |
|
| 898 |
-
template<uint32_t NumWarps, typename T
|
| 899 |
__global__
|
| 900 |
-
void device_a2a_collapse(
|
| 901 |
torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
|
| 902 |
torch::PackedTensorAccessor64<bool, 2> allIsLeadRow,
|
| 903 |
const int64_t *regionCounts,
|
|
@@ -917,16 +915,14 @@ void device_a2a_collapse(const uint64_t punCounts,
|
|
| 917 |
const uint32_t b = blockIdx.z;
|
| 918 |
const uint32_t row = blockIdx.y;
|
| 919 |
|
| 920 |
-
const int32_t quadCt =
|
| 921 |
|
| 922 |
-
if
|
| 923 |
-
|
| 924 |
-
return;
|
| 925 |
-
}
|
| 926 |
}
|
| 927 |
|
| 928 |
// Only process the lead rows
|
| 929 |
-
const auto isLeadRow =
|
| 930 |
if (!isLeadRow[row]) {
|
| 931 |
return;
|
| 932 |
}
|
|
@@ -945,7 +941,7 @@ void device_a2a_collapse(const uint64_t punCounts,
|
|
| 945 |
__syncthreads();
|
| 946 |
}
|
| 947 |
|
| 948 |
-
T *exData =
|
| 949 |
|
| 950 |
const int32_t adjCount = allAdjCounts[b][row];
|
| 951 |
const int32_t *adjIdxs = allAdjValues[b][row].data();
|
|
@@ -986,20 +982,12 @@ void device_a2a_collapse(const uint64_t punCounts,
|
|
| 986 |
|
| 987 |
// Figure out the output position
|
| 988 |
uint32_t writePosition = 0;
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
writePosition += regionCounts[i];
|
| 992 |
-
}
|
| 993 |
}
|
| 994 |
|
| 995 |
-
const int32_t numLongs = row >> 3; // Divide by 8
|
| 996 |
const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow);
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
for (int32_t i = threadRank; i < numLongs; i += BLOCK_WIDTH) {
|
| 1000 |
-
writePosition += __popcll(lpCurrIsLeadRow[i]);
|
| 1001 |
-
}
|
| 1002 |
-
for (int32_t i = (numLongs * 8) + threadRank; i < row; i += BLOCK_WIDTH) {
|
| 1003 |
if (pCurrIsLeadRow[i]) {
|
| 1004 |
++writePosition;
|
| 1005 |
}
|
|
@@ -1075,13 +1063,9 @@ CollapseRowsResult collapse_rows(
|
|
| 1075 |
int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t);
|
| 1076 |
auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options());
|
| 1077 |
|
| 1078 |
-
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1079 |
auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) },
|
| 1080 |
std::numeric_limits<int32_t>::max(),
|
| 1081 |
counts.options().dtype(torch::kInt32));
|
| 1082 |
-
#else
|
| 1083 |
-
torch::Tensor idsTensor;
|
| 1084 |
-
#endif
|
| 1085 |
|
| 1086 |
dim3 blockSize(32, 3, 1);
|
| 1087 |
dim3 gridSize(1,
|
|
@@ -1093,10 +1077,8 @@ CollapseRowsResult collapse_rows(
|
|
| 1093 |
probs.packed_accessor64<scalar_t, 3>(),
|
| 1094 |
probThreshold, iouThreshold,
|
| 1095 |
counts.packed_accessor64<int32_t, 1>(),
|
| 1096 |
-
rowMergeTensor.packed_accessor64<scalar_t, 3>()
|
| 1097 |
-
|
| 1098 |
-
, idsTensor.packed_accessor64<int32_t, 2>()
|
| 1099 |
-
#endif
|
| 1100 |
);
|
| 1101 |
|
| 1102 |
#ifdef NMS_VERIFY_CORRECTNESS
|
|
@@ -1119,7 +1101,6 @@ CollapseRowsResult collapse_rows(
|
|
| 1119 |
|
| 1120 |
counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1);
|
| 1121 |
|
| 1122 |
-
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1123 |
int64_t maxExCount;
|
| 1124 |
if (counts.size(0) > 1) {
|
| 1125 |
maxExCount = counts.max().item<int32_t>();
|
|
@@ -1131,13 +1112,12 @@ CollapseRowsResult collapse_rows(
|
|
| 1131 |
|
| 1132 |
rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount);
|
| 1133 |
idsTensor = idsTensor.slice(1, 0, maxExCount);
|
| 1134 |
-
auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder);
|
| 1135 |
|
| 1136 |
auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor);
|
| 1137 |
|
| 1138 |
rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder);
|
| 1139 |
idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order);
|
| 1140 |
-
#endif
|
| 1141 |
|
| 1142 |
return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight };
|
| 1143 |
}
|
|
@@ -1177,8 +1157,8 @@ struct AdjacencyResult {
|
|
| 1177 |
int64_t MaxExCount;
|
| 1178 |
};
|
| 1179 |
|
| 1180 |
-
template<
|
| 1181 |
-
void cpu_a2a_adjacency_sparse(const
|
| 1182 |
const T iouThreshold,
|
| 1183 |
torch::Tensor embedQuadsTensor,
|
| 1184 |
torch::Tensor outIsStartTensorGPU,
|
|
@@ -1196,7 +1176,7 @@ void cpu_a2a_adjacency_sparse(const uint64_t punCounts,
|
|
| 1196 |
auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>();
|
| 1197 |
|
| 1198 |
for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) {
|
| 1199 |
-
const int32_t quadCt =
|
| 1200 |
|
| 1201 |
T *exData = embedQuads[b].data();
|
| 1202 |
|
|
@@ -1284,13 +1264,6 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1284 |
counts.options().dtype(torch::kInt32));
|
| 1285 |
#endif
|
| 1286 |
|
| 1287 |
-
// If the batch is only a single example, instead of hitting global memory for the count, we can
|
| 1288 |
-
// just encode the count into the pointer instead
|
| 1289 |
-
uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
|
| 1290 |
-
if (counts.size(0) == 1) {
|
| 1291 |
-
ptrCounts = maxExCount;
|
| 1292 |
-
}
|
| 1293 |
-
|
| 1294 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1295 |
auto cpuAdjValuesTensor = adjValuesTensor.cpu();
|
| 1296 |
auto cpuAdjCountsTensor = adjCountsTensor.cpu();
|
|
@@ -1318,23 +1291,15 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1318 |
//blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 };
|
| 1319 |
//gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
|
| 1320 |
|
| 1321 |
-
//
|
| 1322 |
-
//
|
| 1323 |
-
// device_a2a_adjacency_build_grid<GRID_NUM_WARPS, false, scalar_t, CELL_SIZE>;
|
| 1324 |
-
|
| 1325 |
-
//buildGridFn KERNEL_ARG2(gridSize, blockSize) (
|
| 1326 |
-
// ptrCounts,
|
| 1327 |
// collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1328 |
// gridCellsTensor.packed_accessor64<int32_t, 4>(),
|
| 1329 |
// quadCellExtentsTensor.packed_accessor64<int32_t, 3>()
|
| 1330 |
//);
|
| 1331 |
|
| 1332 |
-
//
|
| 1333 |
-
//
|
| 1334 |
-
// device_a2a_adjacency_with_grid<GRID_NUM_WARPS, false, scalar_t>;
|
| 1335 |
-
|
| 1336 |
-
//adjGridFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
|
| 1337 |
-
// ptrCounts,
|
| 1338 |
// iouThreshold,
|
| 1339 |
// collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1340 |
// gridCellsTensor.packed_accessor64<int32_t, 4>(),
|
|
@@ -1351,11 +1316,9 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1351 |
gridSize = dim3{div_up(totalWork, blockSize.x),
|
| 1352 |
static_cast<uint32_t>(counts.size(0))};
|
| 1353 |
|
| 1354 |
-
auto adjFn = counts.size(0) == 1 ? device_a2a_adjacency_sparse<true, scalar_t> : device_a2a_adjacency_sparse<false, scalar_t>;
|
| 1355 |
-
|
| 1356 |
// This algorithm is O(n^2) with n being the current number of quads
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
iouThreshold,
|
| 1360 |
collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1361 |
isStartTensor.packed_accessor64<bool, 2>(),
|
|
@@ -1365,7 +1328,9 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1365 |
|
| 1366 |
|
| 1367 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1368 |
-
|
|
|
|
|
|
|
| 1369 |
collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
|
| 1370 |
|
| 1371 |
adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
|
|
@@ -1380,16 +1345,12 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1380 |
auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options());
|
| 1381 |
#endif
|
| 1382 |
|
| 1383 |
-
auto traverseFn = counts.size(0) == 1 ?
|
| 1384 |
-
device_flatten_graph_iterative<true> :
|
| 1385 |
-
device_flatten_graph_iterative<false>;
|
| 1386 |
-
|
| 1387 |
blockSize = dim3{ 128, 1, 1 };
|
| 1388 |
gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
|
| 1389 |
smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t);
|
| 1390 |
|
| 1391 |
-
|
| 1392 |
-
|
| 1393 |
isStartTensor.packed_accessor64<bool, 2>(),
|
| 1394 |
reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()),
|
| 1395 |
reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>())
|
|
@@ -1399,7 +1360,7 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1399 |
);
|
| 1400 |
|
| 1401 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1402 |
-
cpu_flatten_graph<
|
| 1403 |
|
| 1404 |
cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2));
|
| 1405 |
adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
|
|
@@ -1437,7 +1398,6 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1437 |
cpuIsStartTensor = isStartTensor.cpu();
|
| 1438 |
cpuAdjCountsTensor = adjCountsTensor.cpu();
|
| 1439 |
cpuAdjValuesTensor = adjValuesTensor.cpu();
|
| 1440 |
-
auto cpuCounts = counts.cpu();
|
| 1441 |
auto cpuCollapseIds = collapseResult.QuadIds.cpu();
|
| 1442 |
|
| 1443 |
static std::vector<std::unordered_set<int32_t>> s_knownGroups;
|
|
@@ -1589,22 +1549,11 @@ nms_result_t
|
|
| 1589 |
dim3 blockSize(NUM_WARPS * 32, 1, 1);
|
| 1590 |
dim3 gridSize(1, adjResult.MaxExCount, counts.size(0));
|
| 1591 |
|
| 1592 |
-
// If the batch is only a single example, instead of hitting global memory for the count, we can
|
| 1593 |
-
// just encode the count into the pointer instead
|
| 1594 |
-
uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
|
| 1595 |
-
if (counts.size(0) == 1) {
|
| 1596 |
-
ptrCounts = adjResult.MaxExCount;
|
| 1597 |
-
}
|
| 1598 |
-
|
| 1599 |
torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options());
|
| 1600 |
torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options());
|
| 1601 |
|
| 1602 |
-
|
| 1603 |
-
|
| 1604 |
-
device_a2a_collapse<NUM_WARPS, scalar_t, false>;
|
| 1605 |
-
|
| 1606 |
-
collapseFn KERNEL_ARG2(gridSize, blockSize) (
|
| 1607 |
-
ptrCounts,
|
| 1608 |
embedQuads.packed_accessor64<scalar_t, 3>(),
|
| 1609 |
isLeadRow.packed_accessor64<bool, 2>(),
|
| 1610 |
regionCounts.data_ptr<int64_t>(),
|
|
|
|
| 157 |
torch::PackedTensorAccessor64<T, 3> allConfs,
|
| 158 |
T confThreshold, T iouThreshold,
|
| 159 |
torch::PackedTensorAccessor64<int32_t, 1> allOutCounts,
|
| 160 |
+
torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads,
|
| 161 |
+
torch::PackedTensorAccessor64<int32_t, 2> allOutIds)
|
|
|
|
|
|
|
|
|
|
| 162 |
{
|
| 163 |
typedef InPlaceQuad_<T> Quadf;
|
| 164 |
static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!");
|
|
|
|
| 303 |
}
|
| 304 |
|
| 305 |
write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1);
|
|
|
|
| 306 |
if (threadRank == 0) {
|
| 307 |
allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx;
|
| 308 |
}
|
|
|
|
| 309 |
}
|
| 310 |
|
| 311 |
if (threadRank == 0) {
|
|
|
|
| 316 |
#undef threadRank
|
| 317 |
}
|
| 318 |
|
| 319 |
+
template<typename T>
|
| 320 |
__global__
|
| 321 |
+
void device_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
|
| 322 |
T iouThreshold,
|
| 323 |
torch::PackedTensorAccessor64<T, 3> embedQuads,
|
| 324 |
torch::PackedTensorAccessor64<bool, 2> outIsStart,
|
|
|
|
| 327 |
{
|
| 328 |
const uint32_t b = blockIdx.y;
|
| 329 |
|
| 330 |
+
const int32_t quadCt = ptrQuadCts[b];
|
| 331 |
+
|
| 332 |
+
if (quadCt == 0) {
|
| 333 |
+
return;
|
| 334 |
+
}
|
| 335 |
|
| 336 |
const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 337 |
const int32_t row = jobIdx / quadCt;
|
|
|
|
| 342 |
return;
|
| 343 |
}
|
| 344 |
|
| 345 |
+
T* exData = embedQuads[b].data();
|
| 346 |
|
| 347 |
const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(),
|
| 348 |
qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
|
|
|
|
| 404 |
}
|
| 405 |
}
|
| 406 |
|
| 407 |
+
template<uint32_t NumWarps, typename T, int32_t I_CELL_SIZE>
|
| 408 |
__global__
|
| 409 |
+
void device_a2a_adjacency_build_grid(const int32_t *ptrQuadCts,
|
| 410 |
torch::PackedTensorAccessor64<T, 3> embedQuads,
|
| 411 |
torch::PackedTensorAccessor64<int32_t, 4> outGridCells,
|
| 412 |
torch::PackedTensorAccessor64<int32_t, 3> outQuadCells)
|
|
|
|
| 422 |
|
| 423 |
const uint32_t b = blockIdx.z;
|
| 424 |
|
| 425 |
+
const uint32_t quadCt = ptrQuadCts[b];
|
| 426 |
const uint32_t quadIdx = blockIdx.y;
|
| 427 |
|
| 428 |
+
if (quadIdx >= quadCt) {
|
| 429 |
return;
|
| 430 |
}
|
| 431 |
|
|
|
|
| 484 |
|
| 485 |
typedef uint8_t visit_mask_t;
|
| 486 |
|
| 487 |
+
template<uint32_t NumWarps, typename T>
|
| 488 |
__global__
|
| 489 |
+
void device_a2a_adjacency_with_grid(const int32_t *ptrQuadCts,
|
| 490 |
T iouThreshold,
|
| 491 |
torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
|
| 492 |
torch::PackedTensorAccessor64<int32_t, 4> allCells,
|
|
|
|
| 502 |
|
| 503 |
const uint32_t b = blockIdx.z;
|
| 504 |
|
| 505 |
+
const uint32_t quadCt = ptrQuadCts[b];
|
| 506 |
const uint32_t quadIdx = blockIdx.y;
|
| 507 |
|
| 508 |
+
if (quadIdx >= quadCt) {
|
| 509 |
return;
|
| 510 |
}
|
| 511 |
|
|
|
|
| 534 |
auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data());
|
| 535 |
auto exAdjValues = outSparseAdj[b][quadIdx].data();
|
| 536 |
|
| 537 |
+
T *exData = allEmbedQuads[b].data();
|
| 538 |
|
| 539 |
const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds();
|
| 540 |
|
|
|
|
| 598 |
}
|
| 599 |
}
|
| 600 |
|
|
|
|
| 601 |
__global__
|
| 602 |
+
void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
|
| 603 |
torch::PackedTensorAccessor64<bool, 2> allIsStart,
|
| 604 |
volatile uint32_t *allAdjCounts,
|
| 605 |
volatile uint32_t *allAdjValues
|
|
|
|
| 620 |
const uint32_t b = blockIdx.z;
|
| 621 |
const uint32_t anchorRow = blockIdx.y;
|
| 622 |
|
| 623 |
+
const uint32_t quadCt = ptrQuadCts[b];
|
| 624 |
|
| 625 |
// Only need to check this if there are multiple examples, since in the case of a single example,
|
| 626 |
// the grid is precisely sized to that quadCt
|
| 627 |
+
if (anchorRow >= quadCt) {
|
| 628 |
+
return;
|
|
|
|
|
|
|
| 629 |
}
|
| 630 |
|
| 631 |
auto isStart = allIsStart[b].data();
|
|
|
|
| 686 |
visitStack[1] = anchorRow;
|
| 687 |
#ifndef NDEBUG
|
| 688 |
for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) {
|
| 689 |
+
visitStack[i] = TERM_VALUE;
|
| 690 |
}
|
| 691 |
#endif
|
| 692 |
int32_t visitPtr = 1;
|
| 693 |
|
| 694 |
+
// NOTE: This loop is actually terminated by the `if (warpNextCol == TERM_VALUE)` check below
|
| 695 |
+
for (uint32_t dfsIter = 0; true; ++dfsIter) {
|
| 696 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 697 |
assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE);
|
| 698 |
#endif
|
|
|
|
| 704 |
if (threadNextCol == warpNextCol) {
|
| 705 |
#ifndef NDEBUG
|
| 706 |
// This makes it easier to debug where the pointer is
|
| 707 |
+
visitStack[visitPtr] = TERM_VALUE;
|
| 708 |
#endif
|
| 709 |
--visitPtr;
|
| 710 |
}
|
|
|
|
| 728 |
const uint32_t procAdjCount = adjCounts[procRow];
|
| 729 |
auto procAdjValues = adjValues + (procRow * maxExCount);
|
| 730 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 731 |
for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) {
|
| 732 |
+
uint32_t adjCol = procAdjValues[i];
|
| 733 |
+
|
| 734 |
+
auto group = cg::coalesced_threads();
|
| 735 |
+
// Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp.
|
| 736 |
+
// The reason behind this is due to how otherwise, warp-0 will always get a new element, warp-1 iff the adj graph
|
| 737 |
+
// has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically,
|
| 738 |
+
// the warps have decreasing pressure. With the rotation mechanism, it helps to balance out stack usage.
|
| 739 |
+
adjCol = group.shfl(adjCol, (group.thread_rank() + dfsIter) % group.size());
|
| 740 |
|
| 741 |
// This will set the queued flag for this column, if it's not already set.
|
| 742 |
// It also returns the old state. In our case, we only want to add this value to the
|
|
|
|
| 748 |
|
| 749 |
bool alreadyAdded = oldMask & ADDED_MASK;
|
| 750 |
|
|
|
|
| 751 |
const uint32_t gThreadRank = group.thread_rank();
|
| 752 |
uint32_t notAddedBallot = group.ballot(!alreadyAdded);
|
| 753 |
if (notAddedBallot) {
|
|
|
|
| 824 |
}
|
| 825 |
}
|
| 826 |
|
| 827 |
+
void cpu_flatten_graph(const int32_t *ptrQuadCts,
|
|
|
|
| 828 |
torch::Tensor isStartTensorGPU,
|
| 829 |
torch::Tensor adjCountsTensorGPU,
|
| 830 |
torch::Tensor adjValuesTensorGPU)
|
|
|
|
| 838 |
auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>();
|
| 839 |
|
| 840 |
for (int32_t b = 0; b < allAdjCounts.size(0); ++b) {
|
| 841 |
+
const int32_t quadCt = ptrQuadCts[b];
|
| 842 |
|
| 843 |
for (int32_t row = 0; row < quadCt; ++row) {
|
| 844 |
std::unordered_set<int32_t> fullAdjSet;
|
|
|
|
| 893 |
}
|
| 894 |
}
|
| 895 |
|
| 896 |
+
template<uint32_t NumWarps, typename T>
|
| 897 |
__global__
|
| 898 |
+
void device_a2a_collapse(torch::PackedTensorAccessor64<int32_t, 1> quadCounts,
|
| 899 |
torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
|
| 900 |
torch::PackedTensorAccessor64<bool, 2> allIsLeadRow,
|
| 901 |
const int64_t *regionCounts,
|
|
|
|
| 915 |
const uint32_t b = blockIdx.z;
|
| 916 |
const uint32_t row = blockIdx.y;
|
| 917 |
|
| 918 |
+
const int32_t quadCt = quadCounts[b];
|
| 919 |
|
| 920 |
+
if (row >= quadCt) {
|
| 921 |
+
return;
|
|
|
|
|
|
|
| 922 |
}
|
| 923 |
|
| 924 |
// Only process the lead rows
|
| 925 |
+
const auto isLeadRow = allIsLeadRow[b].data();
|
| 926 |
if (!isLeadRow[row]) {
|
| 927 |
return;
|
| 928 |
}
|
|
|
|
| 941 |
__syncthreads();
|
| 942 |
}
|
| 943 |
|
| 944 |
+
T *exData = allEmbedQuads[b].data();
|
| 945 |
|
| 946 |
const int32_t adjCount = allAdjCounts[b][row];
|
| 947 |
const int32_t *adjIdxs = allAdjValues[b][row].data();
|
|
|
|
| 982 |
|
| 983 |
// Figure out the output position
|
| 984 |
uint32_t writePosition = 0;
|
| 985 |
+
for (int32_t i = threadRank; i < b; i += BLOCK_WIDTH) {
|
| 986 |
+
writePosition += regionCounts[i];
|
|
|
|
|
|
|
| 987 |
}
|
| 988 |
|
|
|
|
| 989 |
const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow);
|
| 990 |
+
for (int32_t i = threadRank; i < row; i += BLOCK_WIDTH) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 991 |
if (pCurrIsLeadRow[i]) {
|
| 992 |
++writePosition;
|
| 993 |
}
|
|
|
|
| 1063 |
int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t);
|
| 1064 |
auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options());
|
| 1065 |
|
|
|
|
| 1066 |
auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) },
|
| 1067 |
std::numeric_limits<int32_t>::max(),
|
| 1068 |
counts.options().dtype(torch::kInt32));
|
|
|
|
|
|
|
|
|
|
| 1069 |
|
| 1070 |
dim3 blockSize(32, 3, 1);
|
| 1071 |
dim3 gridSize(1,
|
|
|
|
| 1077 |
probs.packed_accessor64<scalar_t, 3>(),
|
| 1078 |
probThreshold, iouThreshold,
|
| 1079 |
counts.packed_accessor64<int32_t, 1>(),
|
| 1080 |
+
rowMergeTensor.packed_accessor64<scalar_t, 3>(),
|
| 1081 |
+
idsTensor.packed_accessor64<int32_t, 2>()
|
|
|
|
|
|
|
| 1082 |
);
|
| 1083 |
|
| 1084 |
#ifdef NMS_VERIFY_CORRECTNESS
|
|
|
|
| 1101 |
|
| 1102 |
counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1);
|
| 1103 |
|
|
|
|
| 1104 |
int64_t maxExCount;
|
| 1105 |
if (counts.size(0) > 1) {
|
| 1106 |
maxExCount = counts.max().item<int32_t>();
|
|
|
|
| 1112 |
|
| 1113 |
rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount);
|
| 1114 |
idsTensor = idsTensor.slice(1, 0, maxExCount);
|
| 1115 |
+
auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder);
|
| 1116 |
|
| 1117 |
auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor);
|
| 1118 |
|
| 1119 |
rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder);
|
| 1120 |
idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order);
|
|
|
|
| 1121 |
|
| 1122 |
return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight };
|
| 1123 |
}
|
|
|
|
| 1157 |
int64_t MaxExCount;
|
| 1158 |
};
|
| 1159 |
|
| 1160 |
+
template<typename T>
|
| 1161 |
+
void cpu_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
|
| 1162 |
const T iouThreshold,
|
| 1163 |
torch::Tensor embedQuadsTensor,
|
| 1164 |
torch::Tensor outIsStartTensorGPU,
|
|
|
|
| 1176 |
auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>();
|
| 1177 |
|
| 1178 |
for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) {
|
| 1179 |
+
const int32_t quadCt = ptrQuadCts[b];
|
| 1180 |
|
| 1181 |
T *exData = embedQuads[b].data();
|
| 1182 |
|
|
|
|
| 1264 |
counts.options().dtype(torch::kInt32));
|
| 1265 |
#endif
|
| 1266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1267 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1268 |
auto cpuAdjValuesTensor = adjValuesTensor.cpu();
|
| 1269 |
auto cpuAdjCountsTensor = adjCountsTensor.cpu();
|
|
|
|
| 1291 |
//blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 };
|
| 1292 |
//gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
|
| 1293 |
|
| 1294 |
+
//device_a2a_adjacency_build_grid<GRID_NUM_WARPS, scalar_t, CELL_SIZE> KERNEL_ARG2(gridSize, blockSize) (
|
| 1295 |
+
// counts.data_ptr<int32_t>(),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1296 |
// collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1297 |
// gridCellsTensor.packed_accessor64<int32_t, 4>(),
|
| 1298 |
// quadCellExtentsTensor.packed_accessor64<int32_t, 3>()
|
| 1299 |
//);
|
| 1300 |
|
| 1301 |
+
//device_a2a_adjacency_with_grid<GRID_NUM_WARPS, scalar_t> KERNEL_ARG3(gridSize, blockSize, smemSize) (
|
| 1302 |
+
// counts.data_ptr<int32_t>(),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1303 |
// iouThreshold,
|
| 1304 |
// collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1305 |
// gridCellsTensor.packed_accessor64<int32_t, 4>(),
|
|
|
|
| 1316 |
gridSize = dim3{div_up(totalWork, blockSize.x),
|
| 1317 |
static_cast<uint32_t>(counts.size(0))};
|
| 1318 |
|
|
|
|
|
|
|
| 1319 |
// This algorithm is O(n^2) with n being the current number of quads
|
| 1320 |
+
device_a2a_adjacency_sparse<scalar_t> KERNEL_ARG2(gridSize, blockSize) (
|
| 1321 |
+
counts.data_ptr<int32_t>(),
|
| 1322 |
iouThreshold,
|
| 1323 |
collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1324 |
isStartTensor.packed_accessor64<bool, 2>(),
|
|
|
|
| 1328 |
|
| 1329 |
|
| 1330 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1331 |
+
auto cpuCounts = counts.cpu();
|
| 1332 |
+
|
| 1333 |
+
cpu_a2a_adjacency_sparse<scalar_t>(cpuCounts.data_ptr<int32_t>(), iouThreshold,
|
| 1334 |
collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
|
| 1335 |
|
| 1336 |
adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
|
|
|
|
| 1345 |
auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options());
|
| 1346 |
#endif
|
| 1347 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1348 |
blockSize = dim3{ 128, 1, 1 };
|
| 1349 |
gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
|
| 1350 |
smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t);
|
| 1351 |
|
| 1352 |
+
device_flatten_graph_iterative KERNEL_ARG3(gridSize, blockSize, smemSize) (
|
| 1353 |
+
counts.data_ptr<int32_t>(),
|
| 1354 |
isStartTensor.packed_accessor64<bool, 2>(),
|
| 1355 |
reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()),
|
| 1356 |
reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>())
|
|
|
|
| 1360 |
);
|
| 1361 |
|
| 1362 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1363 |
+
cpu_flatten_graph(cpuCounts.data_ptr<int32_t>(), cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
|
| 1364 |
|
| 1365 |
cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2));
|
| 1366 |
adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
|
|
|
|
| 1398 |
cpuIsStartTensor = isStartTensor.cpu();
|
| 1399 |
cpuAdjCountsTensor = adjCountsTensor.cpu();
|
| 1400 |
cpuAdjValuesTensor = adjValuesTensor.cpu();
|
|
|
|
| 1401 |
auto cpuCollapseIds = collapseResult.QuadIds.cpu();
|
| 1402 |
|
| 1403 |
static std::vector<std::unordered_set<int32_t>> s_knownGroups;
|
|
|
|
| 1549 |
dim3 blockSize(NUM_WARPS * 32, 1, 1);
|
| 1550 |
dim3 gridSize(1, adjResult.MaxExCount, counts.size(0));
|
| 1551 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1552 |
torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options());
|
| 1553 |
torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options());
|
| 1554 |
|
| 1555 |
+
device_a2a_collapse<NUM_WARPS, scalar_t> KERNEL_ARG2(gridSize, blockSize) (
|
| 1556 |
+
counts.packed_accessor64<int32_t, 1>(),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1557 |
embedQuads.packed_accessor64<scalar_t, 3>(),
|
| 1558 |
isLeadRow.packed_accessor64<bool, 2>(),
|
| 1559 |
regionCounts.data_ptr<int64_t>(),
|
nemotron-ocr/src/nemotron_ocr/inference/pipeline.py
CHANGED
|
@@ -181,29 +181,17 @@ class NemotronOCR:
|
|
| 181 |
e2e_det_conf = torch.sigmoid(det_conf)
|
| 182 |
e2e_det_coords = rrect_to_quads(det_rboxes.float(), DETECTOR_DOWNSAMPLE)
|
| 183 |
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
kernel_height=2,
|
| 196 |
-
kernel_width=3,
|
| 197 |
-
max_regions=NMS_MAX_REGIONS,
|
| 198 |
-
verbose=False,
|
| 199 |
-
)[:3]
|
| 200 |
-
all_quads.append(quads)
|
| 201 |
-
all_confidence.append(confidence)
|
| 202 |
-
all_region_counts.append(region_counts)
|
| 203 |
-
|
| 204 |
-
quads = torch.cat(all_quads, dim=0)
|
| 205 |
-
confidence = torch.cat(all_confidence, dim=0)
|
| 206 |
-
region_counts = torch.cat(all_region_counts, dim=0)
|
| 207 |
|
| 208 |
if quads.shape[0] == 0:
|
| 209 |
rec_rectified_quads = torch.empty(0, 128, 8, 32, dtype=torch.float32, device=padded_image.device)
|
|
|
|
| 181 |
e2e_det_conf = torch.sigmoid(det_conf)
|
| 182 |
e2e_det_coords = rrect_to_quads(det_rboxes.float(), DETECTOR_DOWNSAMPLE)
|
| 183 |
|
| 184 |
+
quads, confidence, region_counts = quad_non_maximal_suppression(
|
| 185 |
+
e2e_det_coords[idx].unsqueeze(0),
|
| 186 |
+
e2e_det_conf[idx].unsqueeze(0),
|
| 187 |
+
prob_threshold=NMS_PROB_THRESHOLD,
|
| 188 |
+
iou_threshold=NMS_IOU_THRESHOLD,
|
| 189 |
+
kernel_height=2,
|
| 190 |
+
kernel_width=3,
|
| 191 |
+
max_regions=NMS_MAX_REGIONS,
|
| 192 |
+
verbose=False,
|
| 193 |
+
)[:3]
|
| 194 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
if quads.shape[0] == 0:
|
| 197 |
rec_rectified_quads = torch.empty(0, 128, 8, 32, dtype=torch.float32, device=padded_image.device)
|