https://huggingface.co/nvidia/nemotron-ocr-v1/tree/main

#8
by Eklavya214 - opened
example.py CHANGED
@@ -8,7 +8,7 @@ from nemotron_ocr.inference.pipeline import NemotronOCR
8
 
9
 
10
  def main(image_path, merge_level, no_visualize, model_dir):
11
- ocr_pipeline = NemotronOCR(model_dir=model_dir)
12
 
13
  predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
14
 
 
8
 
9
 
10
  def main(image_path, merge_level, no_visualize, model_dir):
11
+ ocr_pipeline = NemotronOCR()
12
 
13
  predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
14
 
nemotron-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu CHANGED
@@ -157,8 +157,11 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
157
  torch::PackedTensorAccessor64<T, 3> allConfs,
158
  T confThreshold, T iouThreshold,
159
  torch::PackedTensorAccessor64<int32_t, 1> allOutCounts,
160
- torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads,
161
- torch::PackedTensorAccessor64<int32_t, 2> allOutIds)
 
 
 
162
  {
163
  typedef InPlaceQuad_<T> Quadf;
164
  static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!");
@@ -303,9 +306,11 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
303
  }
304
 
305
  write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1);
 
306
  if (threadRank == 0) {
307
  allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx;
308
  }
 
309
  }
310
 
311
  if (threadRank == 0) {
@@ -316,9 +321,9 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
316
  #undef threadRank
317
  }
318
 
319
- template<typename T>
320
  __global__
321
- void device_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
322
  T iouThreshold,
323
  torch::PackedTensorAccessor64<T, 3> embedQuads,
324
  torch::PackedTensorAccessor64<bool, 2> outIsStart,
@@ -327,11 +332,7 @@ void device_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
327
  {
328
  const uint32_t b = blockIdx.y;
329
 
330
- const int32_t quadCt = ptrQuadCts[b];
331
-
332
- if (quadCt == 0) {
333
- return;
334
- }
335
 
336
  const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
337
  const int32_t row = jobIdx / quadCt;
@@ -342,7 +343,7 @@ void device_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
342
  return;
343
  }
344
 
345
- T* exData = embedQuads[b].data();
346
 
347
  const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(),
348
  qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
@@ -404,9 +405,9 @@ void device_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
404
  }
405
  }
406
 
407
- template<uint32_t NumWarps, typename T, int32_t I_CELL_SIZE>
408
  __global__
409
- void device_a2a_adjacency_build_grid(const int32_t *ptrQuadCts,
410
  torch::PackedTensorAccessor64<T, 3> embedQuads,
411
  torch::PackedTensorAccessor64<int32_t, 4> outGridCells,
412
  torch::PackedTensorAccessor64<int32_t, 3> outQuadCells)
@@ -422,10 +423,10 @@ void device_a2a_adjacency_build_grid(const int32_t *ptrQuadCts,
422
 
423
  const uint32_t b = blockIdx.z;
424
 
425
- const uint32_t quadCt = ptrQuadCts[b];
426
  const uint32_t quadIdx = blockIdx.y;
427
 
428
- if (quadIdx >= quadCt) {
429
  return;
430
  }
431
 
@@ -484,9 +485,9 @@ void device_a2a_adjacency_build_grid(const int32_t *ptrQuadCts,
484
 
485
  typedef uint8_t visit_mask_t;
486
 
487
- template<uint32_t NumWarps, typename T>
488
  __global__
489
- void device_a2a_adjacency_with_grid(const int32_t *ptrQuadCts,
490
  T iouThreshold,
491
  torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
492
  torch::PackedTensorAccessor64<int32_t, 4> allCells,
@@ -502,10 +503,10 @@ void device_a2a_adjacency_with_grid(const int32_t *ptrQuadCts,
502
 
503
  const uint32_t b = blockIdx.z;
504
 
505
- const uint32_t quadCt = ptrQuadCts[b];
506
  const uint32_t quadIdx = blockIdx.y;
507
 
508
- if (quadIdx >= quadCt) {
509
  return;
510
  }
511
 
@@ -534,7 +535,7 @@ void device_a2a_adjacency_with_grid(const int32_t *ptrQuadCts,
534
  auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data());
535
  auto exAdjValues = outSparseAdj[b][quadIdx].data();
536
 
537
- T *exData = allEmbedQuads[b].data();
538
 
539
  const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds();
540
 
@@ -598,8 +599,9 @@ void device_a2a_adjacency_with_grid(const int32_t *ptrQuadCts,
598
  }
599
  }
600
 
 
601
  __global__
602
- void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
603
  torch::PackedTensorAccessor64<bool, 2> allIsStart,
604
  volatile uint32_t *allAdjCounts,
605
  volatile uint32_t *allAdjValues
@@ -620,12 +622,14 @@ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
620
  const uint32_t b = blockIdx.z;
621
  const uint32_t anchorRow = blockIdx.y;
622
 
623
- const uint32_t quadCt = ptrQuadCts[b];
624
 
625
  // Only need to check this if there are multiple examples, since in the case of a single example,
626
  // the grid is precisely sized to that quadCt
627
- if (anchorRow >= quadCt) {
628
- return;
 
 
629
  }
630
 
631
  auto isStart = allIsStart[b].data();
@@ -686,13 +690,12 @@ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
686
  visitStack[1] = anchorRow;
687
  #ifndef NDEBUG
688
  for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) {
689
- visitStack[i] = TERM_VALUE;
690
  }
691
  #endif
692
  int32_t visitPtr = 1;
693
 
694
- // NOTE: This loop is actually terminated by the `if (warpNextCol == TERM_VALUE)` check below
695
- for (uint32_t dfsIter = 0; true; ++dfsIter) {
696
  #ifdef NMS_VERIFY_CORRECTNESS
697
  assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE);
698
  #endif
@@ -704,7 +707,7 @@ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
704
  if (threadNextCol == warpNextCol) {
705
  #ifndef NDEBUG
706
  // This makes it easier to debug where the pointer is
707
- visitStack[visitPtr] = TERM_VALUE;
708
  #endif
709
  --visitPtr;
710
  }
@@ -728,15 +731,12 @@ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
728
  const uint32_t procAdjCount = adjCounts[procRow];
729
  auto procAdjValues = adjValues + (procRow * maxExCount);
730
 
 
 
 
 
731
  for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) {
732
- uint32_t adjCol = procAdjValues[i];
733
-
734
- auto group = cg::coalesced_threads();
735
- // Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp.
736
- // The reason behind this is due to how otherwise, warp-0 will always get a new element, warp-1 iff the adj graph
737
- // has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically,
738
- // the warps have decreasing pressure. With the rotation mechanism, it helps to balance out stack usage.
739
- adjCol = group.shfl(adjCol, (group.thread_rank() + dfsIter) % group.size());
740
 
741
  // This will set the queued flag for this column, if it's not already set.
742
  // It also returns the old state. In our case, we only want to add this value to the
@@ -748,6 +748,7 @@ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
748
 
749
  bool alreadyAdded = oldMask & ADDED_MASK;
750
 
 
751
  const uint32_t gThreadRank = group.thread_rank();
752
  uint32_t notAddedBallot = group.ballot(!alreadyAdded);
753
  if (notAddedBallot) {
@@ -824,7 +825,8 @@ void add_to_set(const torch::TensorAccessor<int32_t, 1>& adjCounts,
824
  }
825
  }
826
 
827
- void cpu_flatten_graph(const int32_t *ptrQuadCts,
 
828
  torch::Tensor isStartTensorGPU,
829
  torch::Tensor adjCountsTensorGPU,
830
  torch::Tensor adjValuesTensorGPU)
@@ -838,7 +840,7 @@ void cpu_flatten_graph(const int32_t *ptrQuadCts,
838
  auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>();
839
 
840
  for (int32_t b = 0; b < allAdjCounts.size(0); ++b) {
841
- const int32_t quadCt = ptrQuadCts[b];
842
 
843
  for (int32_t row = 0; row < quadCt; ++row) {
844
  std::unordered_set<int32_t> fullAdjSet;
@@ -893,9 +895,9 @@ void device_a2a_adj_cleanup(const int32_t *counts,
893
  }
894
  }
895
 
896
- template<uint32_t NumWarps, typename T>
897
  __global__
898
- void device_a2a_collapse(torch::PackedTensorAccessor64<int32_t, 1> quadCounts,
899
  torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
900
  torch::PackedTensorAccessor64<bool, 2> allIsLeadRow,
901
  const int64_t *regionCounts,
@@ -915,14 +917,16 @@ void device_a2a_collapse(torch::PackedTensorAccessor64<int32_t, 1> quadCounts,
915
  const uint32_t b = blockIdx.z;
916
  const uint32_t row = blockIdx.y;
917
 
918
- const int32_t quadCt = quadCounts[b];
919
 
920
- if (row >= quadCt) {
921
- return;
 
 
922
  }
923
 
924
  // Only process the lead rows
925
- const auto isLeadRow = allIsLeadRow[b].data();
926
  if (!isLeadRow[row]) {
927
  return;
928
  }
@@ -941,7 +945,7 @@ void device_a2a_collapse(torch::PackedTensorAccessor64<int32_t, 1> quadCounts,
941
  __syncthreads();
942
  }
943
 
944
- T *exData = allEmbedQuads[b].data();
945
 
946
  const int32_t adjCount = allAdjCounts[b][row];
947
  const int32_t *adjIdxs = allAdjValues[b][row].data();
@@ -982,12 +986,20 @@ void device_a2a_collapse(torch::PackedTensorAccessor64<int32_t, 1> quadCounts,
982
 
983
  // Figure out the output position
984
  uint32_t writePosition = 0;
985
- for (int32_t i = threadRank; i < b; i += BLOCK_WIDTH) {
986
- writePosition += regionCounts[i];
 
 
987
  }
988
 
 
989
  const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow);
990
- for (int32_t i = threadRank; i < row; i += BLOCK_WIDTH) {
 
 
 
 
 
991
  if (pCurrIsLeadRow[i]) {
992
  ++writePosition;
993
  }
@@ -1063,9 +1075,13 @@ CollapseRowsResult collapse_rows(
1063
  int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t);
1064
  auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options());
1065
 
 
1066
  auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) },
1067
  std::numeric_limits<int32_t>::max(),
1068
  counts.options().dtype(torch::kInt32));
 
 
 
1069
 
1070
  dim3 blockSize(32, 3, 1);
1071
  dim3 gridSize(1,
@@ -1077,8 +1093,10 @@ CollapseRowsResult collapse_rows(
1077
  probs.packed_accessor64<scalar_t, 3>(),
1078
  probThreshold, iouThreshold,
1079
  counts.packed_accessor64<int32_t, 1>(),
1080
- rowMergeTensor.packed_accessor64<scalar_t, 3>(),
1081
- idsTensor.packed_accessor64<int32_t, 2>()
 
 
1082
  );
1083
 
1084
  #ifdef NMS_VERIFY_CORRECTNESS
@@ -1101,6 +1119,7 @@ CollapseRowsResult collapse_rows(
1101
 
1102
  counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1);
1103
 
 
1104
  int64_t maxExCount;
1105
  if (counts.size(0) > 1) {
1106
  maxExCount = counts.max().item<int32_t>();
@@ -1112,12 +1131,13 @@ CollapseRowsResult collapse_rows(
1112
 
1113
  rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount);
1114
  idsTensor = idsTensor.slice(1, 0, maxExCount);
1115
- auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder);
1116
 
1117
  auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor);
1118
 
1119
  rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder);
1120
  idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order);
 
1121
 
1122
  return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight };
1123
  }
@@ -1157,8 +1177,8 @@ struct AdjacencyResult {
1157
  int64_t MaxExCount;
1158
  };
1159
 
1160
- template<typename T>
1161
- void cpu_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
1162
  const T iouThreshold,
1163
  torch::Tensor embedQuadsTensor,
1164
  torch::Tensor outIsStartTensorGPU,
@@ -1176,7 +1196,7 @@ void cpu_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
1176
  auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>();
1177
 
1178
  for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) {
1179
- const int32_t quadCt = ptrQuadCts[b];
1180
 
1181
  T *exData = embedQuads[b].data();
1182
 
@@ -1264,6 +1284,13 @@ AdjacencyResult compute_all_to_all_adjacency(
1264
  counts.options().dtype(torch::kInt32));
1265
  #endif
1266
 
 
 
 
 
 
 
 
1267
  #ifdef NMS_VERIFY_CORRECTNESS
1268
  auto cpuAdjValuesTensor = adjValuesTensor.cpu();
1269
  auto cpuAdjCountsTensor = adjCountsTensor.cpu();
@@ -1291,15 +1318,23 @@ AdjacencyResult compute_all_to_all_adjacency(
1291
  //blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 };
1292
  //gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
1293
 
1294
- //device_a2a_adjacency_build_grid<GRID_NUM_WARPS, scalar_t, CELL_SIZE> KERNEL_ARG2(gridSize, blockSize) (
1295
- // counts.data_ptr<int32_t>(),
 
 
 
 
1296
  // collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1297
  // gridCellsTensor.packed_accessor64<int32_t, 4>(),
1298
  // quadCellExtentsTensor.packed_accessor64<int32_t, 3>()
1299
  //);
1300
 
1301
- //device_a2a_adjacency_with_grid<GRID_NUM_WARPS, scalar_t> KERNEL_ARG3(gridSize, blockSize, smemSize) (
1302
- // counts.data_ptr<int32_t>(),
 
 
 
 
1303
  // iouThreshold,
1304
  // collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1305
  // gridCellsTensor.packed_accessor64<int32_t, 4>(),
@@ -1316,9 +1351,11 @@ AdjacencyResult compute_all_to_all_adjacency(
1316
  gridSize = dim3{div_up(totalWork, blockSize.x),
1317
  static_cast<uint32_t>(counts.size(0))};
1318
 
 
 
1319
  // This algorithm is O(n^2) with n being the current number of quads
1320
- device_a2a_adjacency_sparse<scalar_t> KERNEL_ARG2(gridSize, blockSize) (
1321
- counts.data_ptr<int32_t>(),
1322
  iouThreshold,
1323
  collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1324
  isStartTensor.packed_accessor64<bool, 2>(),
@@ -1328,9 +1365,7 @@ AdjacencyResult compute_all_to_all_adjacency(
1328
 
1329
 
1330
  #ifdef NMS_VERIFY_CORRECTNESS
1331
- auto cpuCounts = counts.cpu();
1332
-
1333
- cpu_a2a_adjacency_sparse<scalar_t>(cpuCounts.data_ptr<int32_t>(), iouThreshold,
1334
  collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
1335
 
1336
  adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
@@ -1345,12 +1380,16 @@ AdjacencyResult compute_all_to_all_adjacency(
1345
  auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options());
1346
  #endif
1347
 
 
 
 
 
1348
  blockSize = dim3{ 128, 1, 1 };
1349
  gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
1350
  smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t);
1351
 
1352
- device_flatten_graph_iterative KERNEL_ARG3(gridSize, blockSize, smemSize) (
1353
- counts.data_ptr<int32_t>(),
1354
  isStartTensor.packed_accessor64<bool, 2>(),
1355
  reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()),
1356
  reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>())
@@ -1360,7 +1399,7 @@ AdjacencyResult compute_all_to_all_adjacency(
1360
  );
1361
 
1362
  #ifdef NMS_VERIFY_CORRECTNESS
1363
- cpu_flatten_graph(cpuCounts.data_ptr<int32_t>(), cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
1364
 
1365
  cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2));
1366
  adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
@@ -1398,6 +1437,7 @@ AdjacencyResult compute_all_to_all_adjacency(
1398
  cpuIsStartTensor = isStartTensor.cpu();
1399
  cpuAdjCountsTensor = adjCountsTensor.cpu();
1400
  cpuAdjValuesTensor = adjValuesTensor.cpu();
 
1401
  auto cpuCollapseIds = collapseResult.QuadIds.cpu();
1402
 
1403
  static std::vector<std::unordered_set<int32_t>> s_knownGroups;
@@ -1549,11 +1589,22 @@ nms_result_t
1549
  dim3 blockSize(NUM_WARPS * 32, 1, 1);
1550
  dim3 gridSize(1, adjResult.MaxExCount, counts.size(0));
1551
 
 
 
 
 
 
 
 
1552
  torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options());
1553
  torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options());
1554
 
1555
- device_a2a_collapse<NUM_WARPS, scalar_t> KERNEL_ARG2(gridSize, blockSize) (
1556
- counts.packed_accessor64<int32_t, 1>(),
 
 
 
 
1557
  embedQuads.packed_accessor64<scalar_t, 3>(),
1558
  isLeadRow.packed_accessor64<bool, 2>(),
1559
  regionCounts.data_ptr<int64_t>(),
 
157
  torch::PackedTensorAccessor64<T, 3> allConfs,
158
  T confThreshold, T iouThreshold,
159
  torch::PackedTensorAccessor64<int32_t, 1> allOutCounts,
160
+ torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads
161
+ #ifdef NMS_VERIFY_CORRECTNESS
162
+ , torch::PackedTensorAccessor64<int32_t, 2> allOutIds
163
+ #endif
164
+ )
165
  {
166
  typedef InPlaceQuad_<T> Quadf;
167
  static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!");
 
306
  }
307
 
308
  write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1);
309
+ #ifdef NMS_VERIFY_CORRECTNESS
310
  if (threadRank == 0) {
311
  allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx;
312
  }
313
+ #endif
314
  }
315
 
316
  if (threadRank == 0) {
 
321
  #undef threadRank
322
  }
323
 
324
+ template<bool IsSingleExample, typename T>
325
  __global__
326
+ void device_a2a_adjacency_sparse(const uint64_t punCounts,
327
  T iouThreshold,
328
  torch::PackedTensorAccessor64<T, 3> embedQuads,
329
  torch::PackedTensorAccessor64<bool, 2> outIsStart,
 
332
  {
333
  const uint32_t b = blockIdx.y;
334
 
335
+ const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
 
 
 
 
336
 
337
  const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
338
  const int32_t row = jobIdx / quadCt;
 
343
  return;
344
  }
345
 
346
+ T* exData = IsSingleExample ? embedQuads.data() : embedQuads[b].data();
347
 
348
  const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(),
349
  qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
 
405
  }
406
  }
407
 
408
+ template<uint32_t NumWarps, bool IsSingleExample, typename T, int32_t I_CELL_SIZE>
409
  __global__
410
+ void device_a2a_adjacency_build_grid(const uint64_t punCounts,
411
  torch::PackedTensorAccessor64<T, 3> embedQuads,
412
  torch::PackedTensorAccessor64<int32_t, 4> outGridCells,
413
  torch::PackedTensorAccessor64<int32_t, 3> outQuadCells)
 
423
 
424
  const uint32_t b = blockIdx.z;
425
 
426
+ const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
427
  const uint32_t quadIdx = blockIdx.y;
428
 
429
+ if (!IsSingleExample && quadIdx >= quadCt) {
430
  return;
431
  }
432
 
 
485
 
486
  typedef uint8_t visit_mask_t;
487
 
488
+ template<uint32_t NumWarps, bool IsSingleExample, typename T>
489
  __global__
490
+ void device_a2a_adjacency_with_grid(const uint64_t punCounts,
491
  T iouThreshold,
492
  torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
493
  torch::PackedTensorAccessor64<int32_t, 4> allCells,
 
503
 
504
  const uint32_t b = blockIdx.z;
505
 
506
+ const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
507
  const uint32_t quadIdx = blockIdx.y;
508
 
509
+ if (!IsSingleExample && quadIdx >= quadCt) {
510
  return;
511
  }
512
 
 
535
  auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data());
536
  auto exAdjValues = outSparseAdj[b][quadIdx].data();
537
 
538
+ T *exData = IsSingleExample ? allEmbedQuads.data() : allEmbedQuads[b].data();
539
 
540
  const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds();
541
 
 
599
  }
600
  }
601
 
602
+ template<bool IsSingleExample>
603
  __global__
604
+ void device_flatten_graph_iterative(const uint64_t punCounts,
605
  torch::PackedTensorAccessor64<bool, 2> allIsStart,
606
  volatile uint32_t *allAdjCounts,
607
  volatile uint32_t *allAdjValues
 
622
  const uint32_t b = blockIdx.z;
623
  const uint32_t anchorRow = blockIdx.y;
624
 
625
+ const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
626
 
627
  // Only need to check this if there are multiple examples, since in the case of a single example,
628
  // the grid is precisely sized to that quadCt
629
+ if constexpr (!IsSingleExample) {
630
+ if (anchorRow >= quadCt) {
631
+ return;
632
+ }
633
  }
634
 
635
  auto isStart = allIsStart[b].data();
 
690
  visitStack[1] = anchorRow;
691
  #ifndef NDEBUG
692
  for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) {
693
+ visitStack[i] = -2;
694
  }
695
  #endif
696
  int32_t visitPtr = 1;
697
 
698
+ while (true) {
 
699
  #ifdef NMS_VERIFY_CORRECTNESS
700
  assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE);
701
  #endif
 
707
  if (threadNextCol == warpNextCol) {
708
  #ifndef NDEBUG
709
  // This makes it easier to debug where the pointer is
710
+ visitStack[visitPtr] = -2;
711
  #endif
712
  --visitPtr;
713
  }
 
731
  const uint32_t procAdjCount = adjCounts[procRow];
732
  auto procAdjValues = adjValues + (procRow * maxExCount);
733
 
734
+ // Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp.
735
+ // The reason behind this is due to how otherwise, warp-0 will always get a new element, warp-1 iff the adj graph
736
+ // has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically,
737
+ // the warps have decreasing pressure. With the rotation mechanism, it helps to balance out stack usage.
738
  for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) {
739
+ const uint32_t adjCol = procAdjValues[i];
 
 
 
 
 
 
 
740
 
741
  // This will set the queued flag for this column, if it's not already set.
742
  // It also returns the old state. In our case, we only want to add this value to the
 
748
 
749
  bool alreadyAdded = oldMask & ADDED_MASK;
750
 
751
+ auto group = cg::coalesced_threads();
752
  const uint32_t gThreadRank = group.thread_rank();
753
  uint32_t notAddedBallot = group.ballot(!alreadyAdded);
754
  if (notAddedBallot) {
 
825
  }
826
  }
827
 
828
+ template<bool IsSingleExample>
829
+ void cpu_flatten_graph(const uint64_t punCounts,
830
  torch::Tensor isStartTensorGPU,
831
  torch::Tensor adjCountsTensorGPU,
832
  torch::Tensor adjValuesTensorGPU)
 
840
  auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>();
841
 
842
  for (int32_t b = 0; b < allAdjCounts.size(0); ++b) {
843
+ const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
844
 
845
  for (int32_t row = 0; row < quadCt; ++row) {
846
  std::unordered_set<int32_t> fullAdjSet;
 
895
  }
896
  }
897
 
898
+ template<uint32_t NumWarps, typename T, bool IsSingleExample>
899
  __global__
900
+ void device_a2a_collapse(const uint64_t punCounts,
901
  torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
902
  torch::PackedTensorAccessor64<bool, 2> allIsLeadRow,
903
  const int64_t *regionCounts,
 
917
  const uint32_t b = blockIdx.z;
918
  const uint32_t row = blockIdx.y;
919
 
920
+ const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
921
 
922
+ if constexpr (!IsSingleExample) {
923
+ if (row >= quadCt) {
924
+ return;
925
+ }
926
  }
927
 
928
  // Only process the lead rows
929
+ const auto isLeadRow = IsSingleExample ? allIsLeadRow.data() : allIsLeadRow[b].data();
930
  if (!isLeadRow[row]) {
931
  return;
932
  }
 
945
  __syncthreads();
946
  }
947
 
948
+ T *exData = IsSingleExample ? allEmbedQuads.data() : allEmbedQuads[b].data();
949
 
950
  const int32_t adjCount = allAdjCounts[b][row];
951
  const int32_t *adjIdxs = allAdjValues[b][row].data();
 
986
 
987
  // Figure out the output position
988
  uint32_t writePosition = 0;
989
+ if constexpr (!IsSingleExample) {
990
+ for (int32_t i = threadRank; i < b; i += BLOCK_WIDTH) {
991
+ writePosition += regionCounts[i];
992
+ }
993
  }
994
 
995
+ const int32_t numLongs = row >> 3; // Divide by 8
996
  const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow);
997
+ const uint64_t *lpCurrIsLeadRow = reinterpret_cast<const uint64_t*>(pCurrIsLeadRow);
998
+
999
+ for (int32_t i = threadRank; i < numLongs; i += BLOCK_WIDTH) {
1000
+ writePosition += __popcll(lpCurrIsLeadRow[i]);
1001
+ }
1002
+ for (int32_t i = (numLongs * 8) + threadRank; i < row; i += BLOCK_WIDTH) {
1003
  if (pCurrIsLeadRow[i]) {
1004
  ++writePosition;
1005
  }
 
1075
  int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t);
1076
  auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options());
1077
 
1078
+ #ifdef NMS_VERIFY_CORRECTNESS
1079
  auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) },
1080
  std::numeric_limits<int32_t>::max(),
1081
  counts.options().dtype(torch::kInt32));
1082
+ #else
1083
+ torch::Tensor idsTensor;
1084
+ #endif
1085
 
1086
  dim3 blockSize(32, 3, 1);
1087
  dim3 gridSize(1,
 
1093
  probs.packed_accessor64<scalar_t, 3>(),
1094
  probThreshold, iouThreshold,
1095
  counts.packed_accessor64<int32_t, 1>(),
1096
+ rowMergeTensor.packed_accessor64<scalar_t, 3>()
1097
+ #ifdef NMS_VERIFY_CORRECTNESS
1098
+ , idsTensor.packed_accessor64<int32_t, 2>()
1099
+ #endif
1100
  );
1101
 
1102
  #ifdef NMS_VERIFY_CORRECTNESS
 
1119
 
1120
  counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1);
1121
 
1122
+ #ifdef NMS_VERIFY_CORRECTNESS
1123
  int64_t maxExCount;
1124
  if (counts.size(0) > 1) {
1125
  maxExCount = counts.max().item<int32_t>();
 
1131
 
1132
  rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount);
1133
  idsTensor = idsTensor.slice(1, 0, maxExCount);
1134
+ auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder); s_sortOrder = !s_sortOrder;
1135
 
1136
  auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor);
1137
 
1138
  rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder);
1139
  idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order);
1140
+ #endif
1141
 
1142
  return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight };
1143
  }
 
1177
  int64_t MaxExCount;
1178
  };
1179
 
1180
+ template<bool IsSingleExample, typename T>
1181
+ void cpu_a2a_adjacency_sparse(const uint64_t punCounts,
1182
  const T iouThreshold,
1183
  torch::Tensor embedQuadsTensor,
1184
  torch::Tensor outIsStartTensorGPU,
 
1196
  auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>();
1197
 
1198
  for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) {
1199
+ const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
1200
 
1201
  T *exData = embedQuads[b].data();
1202
 
 
1284
  counts.options().dtype(torch::kInt32));
1285
  #endif
1286
 
1287
+ // If the batch is only a single example, instead of hitting global memory for the count, we can
1288
+ // just encode the count into the pointer instead
1289
+ uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
1290
+ if (counts.size(0) == 1) {
1291
+ ptrCounts = maxExCount;
1292
+ }
1293
+
1294
  #ifdef NMS_VERIFY_CORRECTNESS
1295
  auto cpuAdjValuesTensor = adjValuesTensor.cpu();
1296
  auto cpuAdjCountsTensor = adjCountsTensor.cpu();
 
1318
  //blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 };
1319
  //gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
1320
 
1321
+ //auto buildGridFn = counts.size(0) == 1 ?
1322
+ // device_a2a_adjacency_build_grid<GRID_NUM_WARPS, true, scalar_t, CELL_SIZE> :
1323
+ // device_a2a_adjacency_build_grid<GRID_NUM_WARPS, false, scalar_t, CELL_SIZE>;
1324
+
1325
+ //buildGridFn KERNEL_ARG2(gridSize, blockSize) (
1326
+ // ptrCounts,
1327
  // collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1328
  // gridCellsTensor.packed_accessor64<int32_t, 4>(),
1329
  // quadCellExtentsTensor.packed_accessor64<int32_t, 3>()
1330
  //);
1331
 
1332
+ //auto adjGridFn = counts.size(0) == 1 ?
1333
+ // device_a2a_adjacency_with_grid<GRID_NUM_WARPS, true, scalar_t> :
1334
+ // device_a2a_adjacency_with_grid<GRID_NUM_WARPS, false, scalar_t>;
1335
+
1336
+ //adjGridFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
1337
+ // ptrCounts,
1338
  // iouThreshold,
1339
  // collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1340
  // gridCellsTensor.packed_accessor64<int32_t, 4>(),
 
1351
  gridSize = dim3{div_up(totalWork, blockSize.x),
1352
  static_cast<uint32_t>(counts.size(0))};
1353
 
1354
+ auto adjFn = counts.size(0) == 1 ? device_a2a_adjacency_sparse<true, scalar_t> : device_a2a_adjacency_sparse<false, scalar_t>;
1355
+
1356
  // This algorithm is O(n^2) with n being the current number of quads
1357
+ adjFn KERNEL_ARG2(gridSize, blockSize) (
1358
+ ptrCounts,
1359
  iouThreshold,
1360
  collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
1361
  isStartTensor.packed_accessor64<bool, 2>(),
 
1365
 
1366
 
1367
  #ifdef NMS_VERIFY_CORRECTNESS
1368
+ cpu_a2a_adjacency_sparse<true>(ptrCounts, iouThreshold,
 
 
1369
  collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
1370
 
1371
  adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
 
1380
  auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options());
1381
  #endif
1382
 
1383
+ auto traverseFn = counts.size(0) == 1 ?
1384
+ device_flatten_graph_iterative<true> :
1385
+ device_flatten_graph_iterative<false>;
1386
+
1387
  blockSize = dim3{ 128, 1, 1 };
1388
  gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
1389
  smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t);
1390
 
1391
+ traverseFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
1392
+ ptrCounts,
1393
  isStartTensor.packed_accessor64<bool, 2>(),
1394
  reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()),
1395
  reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>())
 
1399
  );
1400
 
1401
  #ifdef NMS_VERIFY_CORRECTNESS
1402
+ cpu_flatten_graph<true>(ptrCounts, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
1403
 
1404
  cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2));
1405
  adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
 
1437
  cpuIsStartTensor = isStartTensor.cpu();
1438
  cpuAdjCountsTensor = adjCountsTensor.cpu();
1439
  cpuAdjValuesTensor = adjValuesTensor.cpu();
1440
+ auto cpuCounts = counts.cpu();
1441
  auto cpuCollapseIds = collapseResult.QuadIds.cpu();
1442
 
1443
  static std::vector<std::unordered_set<int32_t>> s_knownGroups;
 
1589
  dim3 blockSize(NUM_WARPS * 32, 1, 1);
1590
  dim3 gridSize(1, adjResult.MaxExCount, counts.size(0));
1591
 
1592
+ // If the batch is only a single example, instead of hitting global memory for the count, we can
1593
+ // just encode the count into the pointer instead
1594
+ uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
1595
+ if (counts.size(0) == 1) {
1596
+ ptrCounts = adjResult.MaxExCount;
1597
+ }
1598
+
1599
  torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options());
1600
  torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options());
1601
 
1602
+ auto collapseFn = counts.size(0) == 1 ?
1603
+ device_a2a_collapse<NUM_WARPS, scalar_t, true> :
1604
+ device_a2a_collapse<NUM_WARPS, scalar_t, false>;
1605
+
1606
+ collapseFn KERNEL_ARG2(gridSize, blockSize) (
1607
+ ptrCounts,
1608
  embedQuads.packed_accessor64<scalar_t, 3>(),
1609
  isLeadRow.packed_accessor64<bool, 2>(),
1610
  regionCounts.data_ptr<int64_t>(),
nemotron-ocr/pyproject.toml CHANGED
@@ -5,7 +5,6 @@ description = "Nemoton OCR"
5
  authors = [{ name = "NVIDIA Nemotron" }]
6
  requires-python = ">=3.12,<3.13"
7
  dependencies = [
8
- "huggingface_hub>=0.20.0",
9
  "pandas>=2.3.3",
10
  "pillow>=12.0.0",
11
  "scikit-learn>=1.7.2",
 
5
  authors = [{ name = "NVIDIA Nemotron" }]
6
  requires-python = ">=3.12,<3.13"
7
  dependencies = [
 
8
  "pandas>=2.3.3",
9
  "pillow>=12.0.0",
10
  "scikit-learn>=1.7.2",
nemotron-ocr/src/nemotron_ocr/inference/pipeline.py CHANGED
@@ -6,7 +6,6 @@ import io
6
  import json
7
  import os
8
  from pathlib import Path
9
- from typing import Optional
10
 
11
  import numpy as np
12
  import torch
@@ -21,7 +20,6 @@ from nemotron_ocr.inference.post_processing.data.text_region import TextBlock
21
  from nemotron_ocr.inference.post_processing.quad_rectify import QuadRectify
22
  from nemotron_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes
23
  from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square
24
- from huggingface_hub import hf_hub_download
25
  from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads
26
  from PIL import Image, ImageDraw, ImageFont
27
  from torch import amp
@@ -39,57 +37,25 @@ MERGE_LEVELS = {"word", "sentence", "paragraph"}
39
  DEFAULT_MERGE_LEVEL = "paragraph"
40
 
41
 
42
- # HuggingFace repository for downloading model weights
43
- HF_REPO_ID = "nvidia/nemotron-ocr-v1"
44
- CHECKPOINT_FILES = ["detector.pth", "recognizer.pth", "relational.pth", "charset.txt"]
45
-
46
-
47
  class NemotronOCR:
48
  """
49
  A high-level pipeline for performing OCR on images.
50
-
51
- Model weights are automatically downloaded from Hugging Face Hub
52
- (nvidia/nemotron-ocr-v1) if not found locally.
53
  """
54
 
55
- def __init__(self, model_dir: Optional[str] = None):
56
- # If model_dir is provided and contains all required files, use it directly
57
- if model_dir is not None:
58
- local_path = Path(model_dir)
59
- if all((local_path / f).is_file() for f in CHECKPOINT_FILES):
60
- self._model_dir = local_path
61
- else:
62
- self._model_dir = self._download_checkpoints()
63
- else:
64
- self._model_dir = self._download_checkpoints()
65
 
66
  self._load_models()
67
  self._load_charset()
68
  self._initialize_processors()
69
 
70
- @staticmethod
71
- def _download_checkpoints() -> Path:
72
- """Download model checkpoints from HuggingFace Hub (cached locally after first download)."""
73
- downloaded_path = None
74
- for filename in CHECKPOINT_FILES:
75
- downloaded_path = hf_hub_download(
76
- repo_id=HF_REPO_ID,
77
- filename=f"checkpoints/{filename}",
78
- )
79
- # All checkpoint files are in the same directory
80
- return Path(downloaded_path).parent
81
-
82
  def _load_models(self):
83
  """Loads all necessary models into memory."""
84
  self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False)
85
- self.detector.load_state_dict(
86
- torch.load(self._model_dir / "detector.pth", weights_only=True), strict=True
87
- )
88
 
89
  self.recognizer = TransformerRecognizer(nic=self.detector.num_features[-1], num_tokens=858, max_width=32)
90
- self.recognizer.load_state_dict(
91
- torch.load(self._model_dir / "recognizer.pth", weights_only=True), strict=True
92
- )
93
 
94
  self.relational = GlobalRelationalModel(
95
  num_input_channels=self.detector.num_features,
@@ -98,9 +64,7 @@ class NemotronOCR:
98
  k=16,
99
  num_layers=4,
100
  )
101
- self.relational.load_state_dict(
102
- torch.load(self._model_dir / "relational.pth", weights_only=True), strict=True
103
- )
104
 
105
  for model in (self.detector, self.recognizer, self.relational):
106
  model = model.cuda()
@@ -217,17 +181,29 @@ class NemotronOCR:
217
  e2e_det_conf = torch.sigmoid(det_conf)
218
  e2e_det_coords = rrect_to_quads(det_rboxes.float(), DETECTOR_DOWNSAMPLE)
219
 
220
- quads, confidence, region_counts = quad_non_maximal_suppression(
221
- e2e_det_coords,
222
- e2e_det_conf,
223
- prob_threshold=NMS_PROB_THRESHOLD,
224
- iou_threshold=NMS_IOU_THRESHOLD,
225
- kernel_height=2,
226
- kernel_width=3,
227
- max_regions=NMS_MAX_REGIONS,
228
- verbose=False,
229
- )[:3]
230
-
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  if quads.shape[0] == 0:
233
  rec_rectified_quads = torch.empty(0, 128, 8, 32, dtype=torch.float32, device=padded_image.device)
 
6
  import json
7
  import os
8
  from pathlib import Path
 
9
 
10
  import numpy as np
11
  import torch
 
20
  from nemotron_ocr.inference.post_processing.quad_rectify import QuadRectify
21
  from nemotron_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes
22
  from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square
 
23
  from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads
24
  from PIL import Image, ImageDraw, ImageFont
25
  from torch import amp
 
37
  DEFAULT_MERGE_LEVEL = "paragraph"
38
 
39
 
 
 
 
 
 
40
  class NemotronOCR:
41
  """
42
  A high-level pipeline for performing OCR on images.
 
 
 
43
  """
44
 
45
+ def __init__(self, model_dir="./checkpoints"):
46
+ self._model_dir = Path(model_dir)
 
 
 
 
 
 
 
 
47
 
48
  self._load_models()
49
  self._load_charset()
50
  self._initialize_processors()
51
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def _load_models(self):
53
  """Loads all necessary models into memory."""
54
  self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False)
55
+ self.detector.load_state_dict(torch.load(self._model_dir / "detector.pth"), strict=True)
 
 
56
 
57
  self.recognizer = TransformerRecognizer(nic=self.detector.num_features[-1], num_tokens=858, max_width=32)
58
+ self.recognizer.load_state_dict(torch.load(self._model_dir / "recognizer.pth"), strict=True)
 
 
59
 
60
  self.relational = GlobalRelationalModel(
61
  num_input_channels=self.detector.num_features,
 
64
  k=16,
65
  num_layers=4,
66
  )
67
+ self.relational.load_state_dict(torch.load(self._model_dir / "relational.pth"), strict=True)
 
 
68
 
69
  for model in (self.detector, self.recognizer, self.relational):
70
  model = model.cuda()
 
181
  e2e_det_conf = torch.sigmoid(det_conf)
182
  e2e_det_coords = rrect_to_quads(det_rboxes.float(), DETECTOR_DOWNSAMPLE)
183
 
184
+ # FIXME: quad_non_maximal_suppression fails with batch size > 1
185
+ all_quads = []
186
+ all_confidence = []
187
+ all_region_counts = []
188
+
189
+ for idx in range(e2e_det_coords.shape[0]):
190
+ quads, confidence, region_counts = quad_non_maximal_suppression(
191
+ e2e_det_coords[idx].unsqueeze(0),
192
+ e2e_det_conf[idx].unsqueeze(0),
193
+ prob_threshold=NMS_PROB_THRESHOLD,
194
+ iou_threshold=NMS_IOU_THRESHOLD,
195
+ kernel_height=2,
196
+ kernel_width=3,
197
+ max_regions=NMS_MAX_REGIONS,
198
+ verbose=False,
199
+ )[:3]
200
+ all_quads.append(quads)
201
+ all_confidence.append(confidence)
202
+ all_region_counts.append(region_counts)
203
+
204
+ quads = torch.cat(all_quads, dim=0)
205
+ confidence = torch.cat(all_confidence, dim=0)
206
+ region_counts = torch.cat(all_region_counts, dim=0)
207
 
208
  if quads.shape[0] == 0:
209
  rec_rectified_quads = torch.empty(0, 128, 8, 32, dtype=torch.float32, device=padded_image.device)