Charles Blackmon-Luca committed on
Commit
4ba3706
·
unverified ·
1 Parent(s): 90015d3

Fix misaligned address error in quad NMS cuda implementation

Browse files
nemotron-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu CHANGED
@@ -157,11 +157,8 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
157
  torch::PackedTensorAccessor64<T, 3> allConfs,
158
  T confThreshold, T iouThreshold,
159
  torch::PackedTensorAccessor64<int32_t, 1> allOutCounts,
160
- torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads
161
- #ifdef NMS_VERIFY_CORRECTNESS
162
- , torch::PackedTensorAccessor64<int32_t, 2> allOutIds
163
- #endif
164
- )
165
  {
166
  typedef InPlaceQuad_<T> Quadf;
167
  static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!");
@@ -306,11 +303,9 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
306
  }
307
 
308
  write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1);
309
- #ifdef NMS_VERIFY_CORRECTNESS
310
  if (threadRank == 0) {
311
  allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx;
312
  }
313
- #endif
314
  }
315
 
316
  if (threadRank == 0) {
@@ -321,9 +316,9 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
321
  #undef threadRank
322
  }
323
 
324
- template<bool IsSingleExample, typename T>
325
  __global__
326
- void device_a2a_adjacency_sparse(const uint64_t punCounts,
327
  T iouThreshold,
328
  torch::PackedTensorAccessor64<T, 3> embedQuads,
329
  torch::PackedTensorAccessor64<bool, 2> outIsStart,
@@ -332,7 +327,11 @@ void device_a2a_adjacency_sparse(const uint64_t punCounts,
332
  {
333
  const uint32_t b = blockIdx.y;
334
 
335
- const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
 
 
 
 
336
 
337
  const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
338
  const int32_t row = jobIdx / quadCt;
@@ -343,7 +342,7 @@ void device_a2a_adjacency_sparse(const uint64_t punCounts,
343
  return;
344
  }
345
 
346
- T* exData = IsSingleExample ? embedQuads.data() : embedQuads[b].data();
347
 
348
  const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(),
349
  qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
@@ -405,9 +404,9 @@ void device_a2a_adjacency_sparse(const uint64_t punCounts,
405
  }
406
  }
407
 
408
- template<uint32_t NumWarps, bool IsSingleExample, typename T, int32_t I_CELL_SIZE>
409
  __global__
410
- void device_a2a_adjacency_build_grid(const uint64_t punCounts,
411
  torch::PackedTensorAccessor64<T, 3> embedQuads,
412
  torch::PackedTensorAccessor64<int32_t, 4> outGridCells,
413
  torch::PackedTensorAccessor64<int32_t, 3> outQuadCells)
@@ -423,10 +422,10 @@ void device_a2a_adjacency_build_grid(const uint64_t punCounts,
423
 
424
  const uint32_t b = blockIdx.z;
425
 
426
- const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
427
  const uint32_t quadIdx = blockIdx.y;
428
 
429
- if (!IsSingleExample && quadIdx >= quadCt) {
430
  return;
431
  }
432
 
@@ -485,9 +484,9 @@ void device_a2a_adjacency_build_grid(const uint64_t punCounts,
485
 
486
  typedef uint8_t visit_mask_t;
487
 
488
- template<uint32_t NumWarps, bool IsSingleExample, typename T>
489
  __global__
490
- void device_a2a_adjacency_with_grid(const uint64_t punCounts,
491
  T iouThreshold,
492
  torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
493
  torch::PackedTensorAccessor64<int32_t, 4> allCells,
@@ -503,10 +502,10 @@ void device_a2a_adjacency_with_grid(const uint64_t punCounts,
503
 
504
  const uint32_t b = blockIdx.z;
505
 
506
- const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
507
  const uint32_t quadIdx = blockIdx.y;
508
 
509
- if (!IsSingleExample && quadIdx >= quadCt) {
510
  return;
511
  }
512
 
@@ -535,7 +534,7 @@ void device_a2a_adjacency_with_grid(const uint64_t punCounts,
535
  auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data());
536
  auto exAdjValues = outSparseAdj[b][quadIdx].data();
537
 
538
- T *exData = IsSingleExample ? allEmbedQuads.data() : allEmbedQuads[b].data();
539
 
540
  const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds();
541
 
@@ -599,9 +598,8 @@ void device_a2a_adjacency_with_grid(const uint64_t punCounts,
599
  }
600
  }
601
 
602
- template<bool IsSingleExample>
603
  __global__
604
- void device_flatten_graph_iterative(const uint64_t punCounts,
605
  torch::PackedTensorAccessor64<bool, 2> allIsStart,
606
  volatile uint32_t *allAdjCounts,
607
  volatile uint32_t *allAdjValues
@@ -622,14 +620,12 @@ void device_flatten_graph_iterative(const uint64_t punCounts,
622
  const uint32_t b = blockIdx.z;
623
  const uint32_t anchorRow = blockIdx.y;
624
 
625
- const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
626
 
627
  // Only need to check this if there are multiple examples, since in the case of a single example,
628
  // the grid is precisely sized to that quadCt
629
- if constexpr (!IsSingleExample) {
630
- if (anchorRow >= quadCt) {
631
- return;
632
- }
633
  }
634
 
635
  auto isStart = allIsStart[b].data();
@@ -690,12 +686,13 @@ void device_flatten_graph_iterative(const uint64_t punCounts,
690
  visitStack[1] = anchorRow;
691
  #ifndef NDEBUG
692
  for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) {
693
- visitStack[i] = -2;
694
  }
695
  #endif
696
  int32_t visitPtr = 1;
697
 
698
- while (true) {
 
699
  #ifdef NMS_VERIFY_CORRECTNESS
700
  assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE);
701
  #endif
@@ -707,7 +704,7 @@ void device_flatten_graph_iterative(const uint64_t punCounts,
707
  if (threadNextCol == warpNextCol) {
708
  #ifndef NDEBUG
709
  // This makes it easier to debug where the pointer is
710
- visitStack[visitPtr] = -2;
711
  #endif
712
  --visitPtr;
713
  }
@@ -731,12 +728,15 @@ void device_flatten_graph_iterative(const uint64_t punCounts,
731
  const uint32_t procAdjCount = adjCounts[procRow];
732
  auto procAdjValues = adjValues + (procRow * maxExCount);
733
 
734
- // Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp.
735
- // The reason behind this is due to how otherwise, warp-0 will always get a new element, warp-1 iff the adj graph
736
- // has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically,
737
- // the warps have decreasing pressure. With the rotation mechanism, it helps to balance out stack usage.
738
  for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) {
739
- const uint32_t adjCol = procAdjValues[i];
 
 
 
 
 
 
 
740
 
741
  // This will set the queued flag for this column, if it's not already set.
742
  // It also returns the old state. In our case, we only want to add this value to the
@@ -748,7 +748,6 @@ void device_flatten_graph_iterative(const uint64_t punCounts,
748
 
749
  bool alreadyAdded = oldMask & ADDED_MASK;
750
 
751
- auto group = cg::coalesced_threads();
752
  const uint32_t gThreadRank = group.thread_rank();
753
  uint32_t notAddedBallot = group.ballot(!alreadyAdded);
754
  if (notAddedBallot) {
@@ -825,8 +824,7 @@ void add_to_set(const torch::TensorAccessor<int32_t, 1>& adjCounts,
825
  }
826
  }
827
 
828
- template<bool IsSingleExample>
829
- void cpu_flatten_graph(const uint64_t punCounts,
830
  torch::Tensor isStartTensorGPU,
831
  torch::Tensor adjCountsTensorGPU,
832
  torch::Tensor adjValuesTensorGPU)
@@ -840,7 +838,7 @@ void cpu_flatten_graph(const uint64_t punCounts,
840
  auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>();
841
 
842
  for (int32_t b = 0; b < allAdjCounts.size(0); ++b) {
843
- const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
844
 
845
  for (int32_t row = 0; row < quadCt; ++row) {
846
  std::unordered_set<int32_t> fullAdjSet;
@@ -895,9 +893,9 @@ void device_a2a_adj_cleanup(const int32_t *counts,
895
  }
896
  }
897
 
898
- template<uint32_t NumWarps, typename T, bool IsSingleExample>
899
  __global__
900
- void device_a2a_collapse(const uint64_t punCounts,
901
  torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
902
  torch::PackedTensorAccessor64<bool, 2> allIsLeadRow,
903
  const int64_t *regionCounts,
@@ -917,16 +915,14 @@ void device_a2a_collapse(const uint64_t punCounts,
917
  const uint32_t b = blockIdx.z;
918
  const uint32_t row = blockIdx.y;
919
 
920
- const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
921
 
922
- if constexpr (!IsSingleExample) {
923
- if (row >= quadCt) {
924
- return;
925
- }
926
  }
927
 
928
  // Only process the lead rows
929
- const auto isLeadRow = IsSingleExample ? allIsLeadRow.data() : allIsLeadRow[b].data();
930
  if (!isLeadRow[row]) {
931
  return;
932
  }
@@ -945,7 +941,7 @@ void device_a2a_collapse(const uint64_t punCounts,
945
  __syncthreads();
946
  }
947
 
948
- T *exData = IsSingleExample ? allEmbedQuads.data() : allEmbedQuads[b].data();
949
 
950
  const int32_t adjCount = allAdjCounts[b][row];
951
  const int32_t *adjIdxs = allAdjValues[b][row].data();
@@ -986,20 +982,12 @@ void device_a2a_collapse(const uint64_t punCounts,
986
 
987
  // Figure out the output position
988
  uint32_t writePosition = 0;
989
- if constexpr (!IsSingleExample) {
990
- for (int32_t i = threadRank; i < b; i += BLOCK_WIDTH) {
991
- writePosition += regionCounts[i];
992
- }
993
  }
994
 
995
- const int32_t numLongs = row >> 3; // Divide by 8
996
  const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow);
997
- const uint64_t *lpCurrIsLeadRow = reinterpret_cast<const uint64_t*>(pCurrIsLeadRow);
998
-
999
- for (int32_t i = threadRank; i < numLongs; i += BLOCK_WIDTH) {
1000
- writePosition += __popcll(lpCurrIsLeadRow[i]);
1001
- }
1002
- for (int32_t i = (numLongs * 8) + threadRank; i < row; i += BLOCK_WIDTH) {
1003
  if (pCurrIsLeadRow[i]) {
1004
  ++writePosition;
1005
  }
@@ -1075,13 +1063,9 @@ CollapseRowsResult collapse_rows(
1075
  int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t);
1076
  auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options());
1077
 
1078
- #ifdef NMS_VERIFY_CORRECTNESS
1079
  auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) },
1080
  std::numeric_limits<int32_t>::max(),
1081
  counts.options().dtype(torch::kInt32));
1082
- #else
1083
- torch::Tensor idsTensor;
1084
- #endif
1085
 
1086
  dim3 blockSize(32, 3, 1);
1087
  dim3 gridSize(1,
@@ -1093,10 +1077,8 @@ CollapseRowsResult collapse_rows(
1093
  probs.packed_accessor64<scalar_t, 3>(),
1094
  probThreshold, iouThreshold,
1095
  counts.packed_accessor64<int32_t, 1>(),
1096
- rowMergeTensor.packed_accessor64<scalar_t, 3>()
1097
- #ifdef NMS_VERIFY_CORRECTNESS
1098
- , idsTensor.packed_accessor64<int32_t, 2>()
1099
- #endif
1100
  );
1101
 
1102
  #ifdef NMS_VERIFY_CORRECTNESS
@@ -1119,7 +1101,6 @@ CollapseRowsResult collapse_rows(
1119
 
1120
  counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1);
1121
 
1122
- #ifdef NMS_VERIFY_CORRECTNESS
1123
  int64_t maxExCount;
1124
  if (counts.size(0) > 1) {
1125
  maxExCount = counts.max().item<int32_t>();
@@ -1131,13 +1112,12 @@ CollapseRowsResult collapse_rows(
1131
 
1132
  rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount);
1133
  idsTensor = idsTensor.slice(1, 0, maxExCount);
1134
- auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder); s_sortOrder = !s_sortOrder;
1135
 
1136
  auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor);
1137
 
1138
  rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder);
1139
  idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order);
1140
- #endif
1141
 
1142
  return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight };
1143
  }
@@ -1177,8 +1157,8 @@ struct AdjacencyResult {
1177
  int64_t MaxExCount;
1178
  };
1179
 
1180
- template<bool IsSingleExample, typename T>
1181
- void cpu_a2a_adjacency_sparse(const uint64_t punCounts,
1182
  const T iouThreshold,
1183
  torch::Tensor embedQuadsTensor,
1184
  torch::Tensor outIsStartTensorGPU,
@@ -1196,7 +1176,7 @@ void cpu_a2a_adjacency_sparse(const uint64_t punCounts,
1196
  auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>();
1197
 
1198
  for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) {
1199
- const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
1200
 
1201
  T *exData = embedQuads[b].data();
1202
 
@@ -1284,13 +1264,6 @@ AdjacencyResult compute_all_to_all_adjacency(
1284
  counts.options().dtype(torch::kInt32));
1285
  #endif
1286
 
1287
- // If the batch is only a single example, instead of hitting global memory for the count, we can
1288
- // just encode the count into the pointer instead
1289
- uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
1290
- if (counts.size(0) == 1) {
1291
- ptrCounts = maxExCount;
1292
- }
1293
-
1294
  #ifdef NMS_VERIFY_CORRECTNESS
1295
  auto cpuAdjValuesTensor = adjValuesTensor.cpu();
1296
  auto cpuAdjCountsTensor = adjCountsTensor.cpu();
@@ -1318,23 +1291,15 @@ AdjacencyResult compute_all_to_all_adjacency(
1318
  //blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 };
1319
  //gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
1320
 
1321
- //auto buildGridFn = counts.size(0) == 1 ?
1322
- // device_a2a_adjacency_build_grid<GRID_NUM_WARPS, true, scalar_t, CELL_SIZE> :
1323
- // device_a2a_adjacency_build_grid<GRID_NUM_WARPS, false, scalar_t, CELL_SIZE>;
1324
-
1325
- //buildGridFn KERNEL_ARG2(gridSize, blockSize) (
1326
- // ptrCounts,
1327
  // collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1328
  // gridCellsTensor.packed_accessor64<int32_t, 4>(),
1329
  // quadCellExtentsTensor.packed_accessor64<int32_t, 3>()
1330
  //);
1331
 
1332
- //auto adjGridFn = counts.size(0) == 1 ?
1333
- // device_a2a_adjacency_with_grid<GRID_NUM_WARPS, true, scalar_t> :
1334
- // device_a2a_adjacency_with_grid<GRID_NUM_WARPS, false, scalar_t>;
1335
-
1336
- //adjGridFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
1337
- // ptrCounts,
1338
  // iouThreshold,
1339
  // collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1340
  // gridCellsTensor.packed_accessor64<int32_t, 4>(),
@@ -1351,11 +1316,9 @@ AdjacencyResult compute_all_to_all_adjacency(
1351
  gridSize = dim3{div_up(totalWork, blockSize.x),
1352
  static_cast<uint32_t>(counts.size(0))};
1353
 
1354
- auto adjFn = counts.size(0) == 1 ? device_a2a_adjacency_sparse<true, scalar_t> : device_a2a_adjacency_sparse<false, scalar_t>;
1355
-
1356
  // This algorithm is O(n^2) with n being the current number of quads
1357
- adjFn KERNEL_ARG2(gridSize, blockSize) (
1358
- ptrCounts,
1359
  iouThreshold,
1360
  collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1361
  isStartTensor.packed_accessor64<bool, 2>(),
@@ -1365,7 +1328,9 @@ AdjacencyResult compute_all_to_all_adjacency(
1365
 
1366
 
1367
  #ifdef NMS_VERIFY_CORRECTNESS
1368
- cpu_a2a_adjacency_sparse<true>(ptrCounts, iouThreshold,
 
 
1369
  collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
1370
 
1371
  adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
@@ -1380,16 +1345,12 @@ AdjacencyResult compute_all_to_all_adjacency(
1380
  auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options());
1381
  #endif
1382
 
1383
- auto traverseFn = counts.size(0) == 1 ?
1384
- device_flatten_graph_iterative<true> :
1385
- device_flatten_graph_iterative<false>;
1386
-
1387
  blockSize = dim3{ 128, 1, 1 };
1388
  gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
1389
  smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t);
1390
 
1391
- traverseFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
1392
- ptrCounts,
1393
  isStartTensor.packed_accessor64<bool, 2>(),
1394
  reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()),
1395
  reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>())
@@ -1399,7 +1360,7 @@ AdjacencyResult compute_all_to_all_adjacency(
1399
  );
1400
 
1401
  #ifdef NMS_VERIFY_CORRECTNESS
1402
- cpu_flatten_graph<true>(ptrCounts, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
1403
 
1404
  cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2));
1405
  adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
@@ -1437,7 +1398,6 @@ AdjacencyResult compute_all_to_all_adjacency(
1437
  cpuIsStartTensor = isStartTensor.cpu();
1438
  cpuAdjCountsTensor = adjCountsTensor.cpu();
1439
  cpuAdjValuesTensor = adjValuesTensor.cpu();
1440
- auto cpuCounts = counts.cpu();
1441
  auto cpuCollapseIds = collapseResult.QuadIds.cpu();
1442
 
1443
  static std::vector<std::unordered_set<int32_t>> s_knownGroups;
@@ -1589,22 +1549,11 @@ nms_result_t
1589
  dim3 blockSize(NUM_WARPS * 32, 1, 1);
1590
  dim3 gridSize(1, adjResult.MaxExCount, counts.size(0));
1591
 
1592
- // If the batch is only a single example, instead of hitting global memory for the count, we can
1593
- // just encode the count into the pointer instead
1594
- uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
1595
- if (counts.size(0) == 1) {
1596
- ptrCounts = adjResult.MaxExCount;
1597
- }
1598
-
1599
  torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options());
1600
  torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options());
1601
 
1602
- auto collapseFn = counts.size(0) == 1 ?
1603
- device_a2a_collapse<NUM_WARPS, scalar_t, true> :
1604
- device_a2a_collapse<NUM_WARPS, scalar_t, false>;
1605
-
1606
- collapseFn KERNEL_ARG2(gridSize, blockSize) (
1607
- ptrCounts,
1608
  embedQuads.packed_accessor64<scalar_t, 3>(),
1609
  isLeadRow.packed_accessor64<bool, 2>(),
1610
  regionCounts.data_ptr<int64_t>(),
 
157
  torch::PackedTensorAccessor64<T, 3> allConfs,
158
  T confThreshold, T iouThreshold,
159
  torch::PackedTensorAccessor64<int32_t, 1> allOutCounts,
160
+ torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads,
161
+ torch::PackedTensorAccessor64<int32_t, 2> allOutIds)
 
 
 
162
  {
163
  typedef InPlaceQuad_<T> Quadf;
164
  static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!");
 
303
  }
304
 
305
  write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1);
 
306
  if (threadRank == 0) {
307
  allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx;
308
  }
 
309
  }
310
 
311
  if (threadRank == 0) {
 
316
  #undef threadRank
317
  }
318
 
319
+ template<typename T>
320
  __global__
321
+ void device_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
322
  T iouThreshold,
323
  torch::PackedTensorAccessor64<T, 3> embedQuads,
324
  torch::PackedTensorAccessor64<bool, 2> outIsStart,
 
327
  {
328
  const uint32_t b = blockIdx.y;
329
 
330
+ const int32_t quadCt = ptrQuadCts[b];
331
+
332
+ if (quadCt == 0) {
333
+ return;
334
+ }
335
 
336
  const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
337
  const int32_t row = jobIdx / quadCt;
 
342
  return;
343
  }
344
 
345
+ T* exData = embedQuads[b].data();
346
 
347
  const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(),
348
  qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
 
404
  }
405
  }
406
 
407
+ template<uint32_t NumWarps, typename T, int32_t I_CELL_SIZE>
408
  __global__
409
+ void device_a2a_adjacency_build_grid(const int32_t *ptrQuadCts,
410
  torch::PackedTensorAccessor64<T, 3> embedQuads,
411
  torch::PackedTensorAccessor64<int32_t, 4> outGridCells,
412
  torch::PackedTensorAccessor64<int32_t, 3> outQuadCells)
 
422
 
423
  const uint32_t b = blockIdx.z;
424
 
425
+ const uint32_t quadCt = ptrQuadCts[b];
426
  const uint32_t quadIdx = blockIdx.y;
427
 
428
+ if (quadIdx >= quadCt) {
429
  return;
430
  }
431
 
 
484
 
485
  typedef uint8_t visit_mask_t;
486
 
487
+ template<uint32_t NumWarps, typename T>
488
  __global__
489
+ void device_a2a_adjacency_with_grid(const int32_t *ptrQuadCts,
490
  T iouThreshold,
491
  torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
492
  torch::PackedTensorAccessor64<int32_t, 4> allCells,
 
502
 
503
  const uint32_t b = blockIdx.z;
504
 
505
+ const uint32_t quadCt = ptrQuadCts[b];
506
  const uint32_t quadIdx = blockIdx.y;
507
 
508
+ if (quadIdx >= quadCt) {
509
  return;
510
  }
511
 
 
534
  auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data());
535
  auto exAdjValues = outSparseAdj[b][quadIdx].data();
536
 
537
+ T *exData = allEmbedQuads[b].data();
538
 
539
  const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds();
540
 
 
598
  }
599
  }
600
 
 
601
  __global__
602
+ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
603
  torch::PackedTensorAccessor64<bool, 2> allIsStart,
604
  volatile uint32_t *allAdjCounts,
605
  volatile uint32_t *allAdjValues
 
620
  const uint32_t b = blockIdx.z;
621
  const uint32_t anchorRow = blockIdx.y;
622
 
623
+ const uint32_t quadCt = ptrQuadCts[b];
624
 
625
  // Only need to check this if there are multiple examples, since in the case of a single example,
626
  // the grid is precisely sized to that quadCt
627
+ if (anchorRow >= quadCt) {
628
+ return;
 
 
629
  }
630
 
631
  auto isStart = allIsStart[b].data();
 
686
  visitStack[1] = anchorRow;
687
  #ifndef NDEBUG
688
  for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) {
689
+ visitStack[i] = TERM_VALUE;
690
  }
691
  #endif
692
  int32_t visitPtr = 1;
693
 
694
+ // NOTE: This loop is actually terminated by the `if (warpNextCol == TERM_VALUE)` check below
695
+ for (uint32_t dfsIter = 0; true; ++dfsIter) {
696
  #ifdef NMS_VERIFY_CORRECTNESS
697
  assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE);
698
  #endif
 
704
  if (threadNextCol == warpNextCol) {
705
  #ifndef NDEBUG
706
  // This makes it easier to debug where the pointer is
707
+ visitStack[visitPtr] = TERM_VALUE;
708
  #endif
709
  --visitPtr;
710
  }
 
728
  const uint32_t procAdjCount = adjCounts[procRow];
729
  auto procAdjValues = adjValues + (procRow * maxExCount);
730
 
 
 
 
 
731
  for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) {
732
+ uint32_t adjCol = procAdjValues[i];
733
+
734
+ auto group = cg::coalesced_threads();
735
+ // Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp.
736
+ // The reason behind this is due to how otherwise, warp-0 will always get a new element, warp-1 iff the adj graph
737
+ // has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically,
738
+ // the warps have decreasing pressure. With the rotation mechanism, it helps to balance out stack usage.
739
+ adjCol = group.shfl(adjCol, (group.thread_rank() + dfsIter) % group.size());
740
 
741
  // This will set the queued flag for this column, if it's not already set.
742
  // It also returns the old state. In our case, we only want to add this value to the
 
748
 
749
  bool alreadyAdded = oldMask & ADDED_MASK;
750
 
 
751
  const uint32_t gThreadRank = group.thread_rank();
752
  uint32_t notAddedBallot = group.ballot(!alreadyAdded);
753
  if (notAddedBallot) {
 
824
  }
825
  }
826
 
827
+ void cpu_flatten_graph(const int32_t *ptrQuadCts,
 
828
  torch::Tensor isStartTensorGPU,
829
  torch::Tensor adjCountsTensorGPU,
830
  torch::Tensor adjValuesTensorGPU)
 
838
  auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>();
839
 
840
  for (int32_t b = 0; b < allAdjCounts.size(0); ++b) {
841
+ const int32_t quadCt = ptrQuadCts[b];
842
 
843
  for (int32_t row = 0; row < quadCt; ++row) {
844
  std::unordered_set<int32_t> fullAdjSet;
 
893
  }
894
  }
895
 
896
+ template<uint32_t NumWarps, typename T>
897
  __global__
898
+ void device_a2a_collapse(torch::PackedTensorAccessor64<int32_t, 1> quadCounts,
899
  torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
900
  torch::PackedTensorAccessor64<bool, 2> allIsLeadRow,
901
  const int64_t *regionCounts,
 
915
  const uint32_t b = blockIdx.z;
916
  const uint32_t row = blockIdx.y;
917
 
918
+ const int32_t quadCt = quadCounts[b];
919
 
920
+ if (row >= quadCt) {
921
+ return;
 
 
922
  }
923
 
924
  // Only process the lead rows
925
+ const auto isLeadRow = allIsLeadRow[b].data();
926
  if (!isLeadRow[row]) {
927
  return;
928
  }
 
941
  __syncthreads();
942
  }
943
 
944
+ T *exData = allEmbedQuads[b].data();
945
 
946
  const int32_t adjCount = allAdjCounts[b][row];
947
  const int32_t *adjIdxs = allAdjValues[b][row].data();
 
982
 
983
  // Figure out the output position
984
  uint32_t writePosition = 0;
985
+ for (int32_t i = threadRank; i < b; i += BLOCK_WIDTH) {
986
+ writePosition += regionCounts[i];
 
 
987
  }
988
 
 
989
  const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow);
990
+ for (int32_t i = threadRank; i < row; i += BLOCK_WIDTH) {
 
 
 
 
 
991
  if (pCurrIsLeadRow[i]) {
992
  ++writePosition;
993
  }
 
1063
  int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t);
1064
  auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options());
1065
 
 
1066
  auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) },
1067
  std::numeric_limits<int32_t>::max(),
1068
  counts.options().dtype(torch::kInt32));
 
 
 
1069
 
1070
  dim3 blockSize(32, 3, 1);
1071
  dim3 gridSize(1,
 
1077
  probs.packed_accessor64<scalar_t, 3>(),
1078
  probThreshold, iouThreshold,
1079
  counts.packed_accessor64<int32_t, 1>(),
1080
+ rowMergeTensor.packed_accessor64<scalar_t, 3>(),
1081
+ idsTensor.packed_accessor64<int32_t, 2>()
 
 
1082
  );
1083
 
1084
  #ifdef NMS_VERIFY_CORRECTNESS
 
1101
 
1102
  counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1);
1103
 
 
1104
  int64_t maxExCount;
1105
  if (counts.size(0) > 1) {
1106
  maxExCount = counts.max().item<int32_t>();
 
1112
 
1113
  rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount);
1114
  idsTensor = idsTensor.slice(1, 0, maxExCount);
1115
+ auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder);
1116
 
1117
  auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor);
1118
 
1119
  rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder);
1120
  idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order);
 
1121
 
1122
  return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight };
1123
  }
 
1157
  int64_t MaxExCount;
1158
  };
1159
 
1160
+ template<typename T>
1161
+ void cpu_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
1162
  const T iouThreshold,
1163
  torch::Tensor embedQuadsTensor,
1164
  torch::Tensor outIsStartTensorGPU,
 
1176
  auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>();
1177
 
1178
  for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) {
1179
+ const int32_t quadCt = ptrQuadCts[b];
1180
 
1181
  T *exData = embedQuads[b].data();
1182
 
 
1264
  counts.options().dtype(torch::kInt32));
1265
  #endif
1266
 
 
 
 
 
 
 
 
1267
  #ifdef NMS_VERIFY_CORRECTNESS
1268
  auto cpuAdjValuesTensor = adjValuesTensor.cpu();
1269
  auto cpuAdjCountsTensor = adjCountsTensor.cpu();
 
1291
  //blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 };
1292
  //gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
1293
 
1294
+ //device_a2a_adjacency_build_grid<GRID_NUM_WARPS, scalar_t, CELL_SIZE> KERNEL_ARG2(gridSize, blockSize) (
1295
+ // counts.data_ptr<int32_t>(),
 
 
 
 
1296
  // collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1297
  // gridCellsTensor.packed_accessor64<int32_t, 4>(),
1298
  // quadCellExtentsTensor.packed_accessor64<int32_t, 3>()
1299
  //);
1300
 
1301
+ //device_a2a_adjacency_with_grid<GRID_NUM_WARPS, scalar_t> KERNEL_ARG3(gridSize, blockSize, smemSize) (
1302
+ // counts.data_ptr<int32_t>(),
 
 
 
 
1303
  // iouThreshold,
1304
  // collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1305
  // gridCellsTensor.packed_accessor64<int32_t, 4>(),
 
1316
  gridSize = dim3{div_up(totalWork, blockSize.x),
1317
  static_cast<uint32_t>(counts.size(0))};
1318
 
 
 
1319
  // This algorithm is O(n^2) with n being the current number of quads
1320
+ device_a2a_adjacency_sparse<scalar_t> KERNEL_ARG2(gridSize, blockSize) (
1321
+ counts.data_ptr<int32_t>(),
1322
  iouThreshold,
1323
  collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1324
  isStartTensor.packed_accessor64<bool, 2>(),
 
1328
 
1329
 
1330
  #ifdef NMS_VERIFY_CORRECTNESS
1331
+ auto cpuCounts = counts.cpu();
1332
+
1333
+ cpu_a2a_adjacency_sparse<scalar_t>(cpuCounts.data_ptr<int32_t>(), iouThreshold,
1334
  collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
1335
 
1336
  adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
 
1345
  auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options());
1346
  #endif
1347
 
 
 
 
 
1348
  blockSize = dim3{ 128, 1, 1 };
1349
  gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
1350
  smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t);
1351
 
1352
+ device_flatten_graph_iterative KERNEL_ARG3(gridSize, blockSize, smemSize) (
1353
+ counts.data_ptr<int32_t>(),
1354
  isStartTensor.packed_accessor64<bool, 2>(),
1355
  reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()),
1356
  reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>())
 
1360
  );
1361
 
1362
  #ifdef NMS_VERIFY_CORRECTNESS
1363
+ cpu_flatten_graph(cpuCounts.data_ptr<int32_t>(), cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
1364
 
1365
  cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2));
1366
  adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
 
1398
  cpuIsStartTensor = isStartTensor.cpu();
1399
  cpuAdjCountsTensor = adjCountsTensor.cpu();
1400
  cpuAdjValuesTensor = adjValuesTensor.cpu();
 
1401
  auto cpuCollapseIds = collapseResult.QuadIds.cpu();
1402
 
1403
  static std::vector<std::unordered_set<int32_t>> s_knownGroups;
 
1549
  dim3 blockSize(NUM_WARPS * 32, 1, 1);
1550
  dim3 gridSize(1, adjResult.MaxExCount, counts.size(0));
1551
 
 
 
 
 
 
 
 
1552
  torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options());
1553
  torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options());
1554
 
1555
+ device_a2a_collapse<NUM_WARPS, scalar_t> KERNEL_ARG2(gridSize, blockSize) (
1556
+ counts.packed_accessor64<int32_t, 1>(),
 
 
 
 
1557
  embedQuads.packed_accessor64<scalar_t, 3>(),
1558
  isLeadRow.packed_accessor64<bool, 2>(),
1559
  regionCounts.data_ptr<int64_t>(),
nemotron-ocr/src/nemotron_ocr/inference/pipeline.py CHANGED
@@ -181,29 +181,17 @@ class NemotronOCR:
181
  e2e_det_conf = torch.sigmoid(det_conf)
182
  e2e_det_coords = rrect_to_quads(det_rboxes.float(), DETECTOR_DOWNSAMPLE)
183
 
184
- # FIXME: quad_non_maximal_suppression fails with batch size > 1
185
- all_quads = []
186
- all_confidence = []
187
- all_region_counts = []
188
-
189
- for idx in range(e2e_det_coords.shape[0]):
190
- quads, confidence, region_counts = quad_non_maximal_suppression(
191
- e2e_det_coords[idx].unsqueeze(0),
192
- e2e_det_conf[idx].unsqueeze(0),
193
- prob_threshold=NMS_PROB_THRESHOLD,
194
- iou_threshold=NMS_IOU_THRESHOLD,
195
- kernel_height=2,
196
- kernel_width=3,
197
- max_regions=NMS_MAX_REGIONS,
198
- verbose=False,
199
- )[:3]
200
- all_quads.append(quads)
201
- all_confidence.append(confidence)
202
- all_region_counts.append(region_counts)
203
-
204
- quads = torch.cat(all_quads, dim=0)
205
- confidence = torch.cat(all_confidence, dim=0)
206
- region_counts = torch.cat(all_region_counts, dim=0)
207
 
208
  if quads.shape[0] == 0:
209
  rec_rectified_quads = torch.empty(0, 128, 8, 32, dtype=torch.float32, device=padded_image.device)
 
181
  e2e_det_conf = torch.sigmoid(det_conf)
182
  e2e_det_coords = rrect_to_quads(det_rboxes.float(), DETECTOR_DOWNSAMPLE)
183
 
184
+ quads, confidence, region_counts = quad_non_maximal_suppression(
185
+ e2e_det_coords[idx].unsqueeze(0),
186
+ e2e_det_conf[idx].unsqueeze(0),
187
+ prob_threshold=NMS_PROB_THRESHOLD,
188
+ iou_threshold=NMS_IOU_THRESHOLD,
189
+ kernel_height=2,
190
+ kernel_width=3,
191
+ max_regions=NMS_MAX_REGIONS,
192
+ verbose=False,
193
+ )[:3]
194
+
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  if quads.shape[0] == 0:
197
  rec_rectified_quads = torch.empty(0, 128, 8, 32, dtype=torch.float32, device=padded_image.device)