https://huggingface.co/nvidia/nemotron-ocr-v1/tree/main
#8
by
Eklavya214
- opened
example.py
CHANGED
|
@@ -8,7 +8,7 @@ from nemotron_ocr.inference.pipeline import NemotronOCR
|
|
| 8 |
|
| 9 |
|
| 10 |
def main(image_path, merge_level, no_visualize, model_dir):
|
| 11 |
-
ocr_pipeline = NemotronOCR(
|
| 12 |
|
| 13 |
predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
|
| 14 |
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
def main(image_path, merge_level, no_visualize, model_dir):
|
| 11 |
+
ocr_pipeline = NemotronOCR()
|
| 12 |
|
| 13 |
predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
|
| 14 |
|
nemotron-ocr/cpp/non_maximal_suppression/cuda_non_maximal_suppression.cu
CHANGED
|
@@ -157,8 +157,11 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
|
|
| 157 |
torch::PackedTensorAccessor64<T, 3> allConfs,
|
| 158 |
T confThreshold, T iouThreshold,
|
| 159 |
torch::PackedTensorAccessor64<int32_t, 1> allOutCounts,
|
| 160 |
-
torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads
|
| 161 |
-
|
|
|
|
|
|
|
|
|
|
| 162 |
{
|
| 163 |
typedef InPlaceQuad_<T> Quadf;
|
| 164 |
static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!");
|
|
@@ -303,9 +306,11 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
|
|
| 303 |
}
|
| 304 |
|
| 305 |
write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1);
|
|
|
|
| 306 |
if (threadRank == 0) {
|
| 307 |
allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx;
|
| 308 |
}
|
|
|
|
| 309 |
}
|
| 310 |
|
| 311 |
if (threadRank == 0) {
|
|
@@ -316,9 +321,9 @@ void device_row_collapse(torch::PackedTensorAccessor64<T, 5> allQuads,
|
|
| 316 |
#undef threadRank
|
| 317 |
}
|
| 318 |
|
| 319 |
-
template<typename T>
|
| 320 |
__global__
|
| 321 |
-
void device_a2a_adjacency_sparse(const
|
| 322 |
T iouThreshold,
|
| 323 |
torch::PackedTensorAccessor64<T, 3> embedQuads,
|
| 324 |
torch::PackedTensorAccessor64<bool, 2> outIsStart,
|
|
@@ -327,11 +332,7 @@ void device_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
|
|
| 327 |
{
|
| 328 |
const uint32_t b = blockIdx.y;
|
| 329 |
|
| 330 |
-
const int32_t quadCt =
|
| 331 |
-
|
| 332 |
-
if (quadCt == 0) {
|
| 333 |
-
return;
|
| 334 |
-
}
|
| 335 |
|
| 336 |
const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 337 |
const int32_t row = jobIdx / quadCt;
|
|
@@ -342,7 +343,7 @@ void device_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
|
|
| 342 |
return;
|
| 343 |
}
|
| 344 |
|
| 345 |
-
T* exData = embedQuads[b].data();
|
| 346 |
|
| 347 |
const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(),
|
| 348 |
qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
|
|
@@ -404,9 +405,9 @@ void device_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
|
|
| 404 |
}
|
| 405 |
}
|
| 406 |
|
| 407 |
-
template<uint32_t NumWarps, typename T, int32_t I_CELL_SIZE>
|
| 408 |
__global__
|
| 409 |
-
void device_a2a_adjacency_build_grid(const
|
| 410 |
torch::PackedTensorAccessor64<T, 3> embedQuads,
|
| 411 |
torch::PackedTensorAccessor64<int32_t, 4> outGridCells,
|
| 412 |
torch::PackedTensorAccessor64<int32_t, 3> outQuadCells)
|
|
@@ -422,10 +423,10 @@ void device_a2a_adjacency_build_grid(const int32_t *ptrQuadCts,
|
|
| 422 |
|
| 423 |
const uint32_t b = blockIdx.z;
|
| 424 |
|
| 425 |
-
const uint32_t quadCt =
|
| 426 |
const uint32_t quadIdx = blockIdx.y;
|
| 427 |
|
| 428 |
-
if (quadIdx >= quadCt) {
|
| 429 |
return;
|
| 430 |
}
|
| 431 |
|
|
@@ -484,9 +485,9 @@ void device_a2a_adjacency_build_grid(const int32_t *ptrQuadCts,
|
|
| 484 |
|
| 485 |
typedef uint8_t visit_mask_t;
|
| 486 |
|
| 487 |
-
template<uint32_t NumWarps, typename T>
|
| 488 |
__global__
|
| 489 |
-
void device_a2a_adjacency_with_grid(const
|
| 490 |
T iouThreshold,
|
| 491 |
torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
|
| 492 |
torch::PackedTensorAccessor64<int32_t, 4> allCells,
|
|
@@ -502,10 +503,10 @@ void device_a2a_adjacency_with_grid(const int32_t *ptrQuadCts,
|
|
| 502 |
|
| 503 |
const uint32_t b = blockIdx.z;
|
| 504 |
|
| 505 |
-
const uint32_t quadCt =
|
| 506 |
const uint32_t quadIdx = blockIdx.y;
|
| 507 |
|
| 508 |
-
if (quadIdx >= quadCt) {
|
| 509 |
return;
|
| 510 |
}
|
| 511 |
|
|
@@ -534,7 +535,7 @@ void device_a2a_adjacency_with_grid(const int32_t *ptrQuadCts,
|
|
| 534 |
auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data());
|
| 535 |
auto exAdjValues = outSparseAdj[b][quadIdx].data();
|
| 536 |
|
| 537 |
-
T *exData = allEmbedQuads[b].data();
|
| 538 |
|
| 539 |
const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds();
|
| 540 |
|
|
@@ -598,8 +599,9 @@ void device_a2a_adjacency_with_grid(const int32_t *ptrQuadCts,
|
|
| 598 |
}
|
| 599 |
}
|
| 600 |
|
|
|
|
| 601 |
__global__
|
| 602 |
-
void device_flatten_graph_iterative(const
|
| 603 |
torch::PackedTensorAccessor64<bool, 2> allIsStart,
|
| 604 |
volatile uint32_t *allAdjCounts,
|
| 605 |
volatile uint32_t *allAdjValues
|
|
@@ -620,12 +622,14 @@ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
|
|
| 620 |
const uint32_t b = blockIdx.z;
|
| 621 |
const uint32_t anchorRow = blockIdx.y;
|
| 622 |
|
| 623 |
-
const uint32_t quadCt =
|
| 624 |
|
| 625 |
// Only need to check this if there are multiple examples, since in the case of a single example,
|
| 626 |
// the grid is precisely sized to that quadCt
|
| 627 |
-
if (
|
| 628 |
-
|
|
|
|
|
|
|
| 629 |
}
|
| 630 |
|
| 631 |
auto isStart = allIsStart[b].data();
|
|
@@ -686,13 +690,12 @@ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
|
|
| 686 |
visitStack[1] = anchorRow;
|
| 687 |
#ifndef NDEBUG
|
| 688 |
for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) {
|
| 689 |
-
visitStack[i] =
|
| 690 |
}
|
| 691 |
#endif
|
| 692 |
int32_t visitPtr = 1;
|
| 693 |
|
| 694 |
-
|
| 695 |
-
for (uint32_t dfsIter = 0; true; ++dfsIter) {
|
| 696 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 697 |
assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE);
|
| 698 |
#endif
|
|
@@ -704,7 +707,7 @@ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
|
|
| 704 |
if (threadNextCol == warpNextCol) {
|
| 705 |
#ifndef NDEBUG
|
| 706 |
// This makes it easier to debug where the pointer is
|
| 707 |
-
visitStack[visitPtr] =
|
| 708 |
#endif
|
| 709 |
--visitPtr;
|
| 710 |
}
|
|
@@ -728,15 +731,12 @@ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
|
|
| 728 |
const uint32_t procAdjCount = adjCounts[procRow];
|
| 729 |
auto procAdjValues = adjValues + (procRow * maxExCount);
|
| 730 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 731 |
for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) {
|
| 732 |
-
uint32_t adjCol = procAdjValues[i];
|
| 733 |
-
|
| 734 |
-
auto group = cg::coalesced_threads();
|
| 735 |
-
// Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp.
|
| 736 |
-
// The reason behind this is due to how otherwise, warp-0 will always get a new element, warp-1 iff the adj graph
|
| 737 |
-
// has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically,
|
| 738 |
-
// the warps have decreasing pressure. With the rotation mechanism, it helps to balance out stack usage.
|
| 739 |
-
adjCol = group.shfl(adjCol, (group.thread_rank() + dfsIter) % group.size());
|
| 740 |
|
| 741 |
// This will set the queued flag for this column, if it's not already set.
|
| 742 |
// It also returns the old state. In our case, we only want to add this value to the
|
|
@@ -748,6 +748,7 @@ void device_flatten_graph_iterative(const int32_t *ptrQuadCts,
|
|
| 748 |
|
| 749 |
bool alreadyAdded = oldMask & ADDED_MASK;
|
| 750 |
|
|
|
|
| 751 |
const uint32_t gThreadRank = group.thread_rank();
|
| 752 |
uint32_t notAddedBallot = group.ballot(!alreadyAdded);
|
| 753 |
if (notAddedBallot) {
|
|
@@ -824,7 +825,8 @@ void add_to_set(const torch::TensorAccessor<int32_t, 1>& adjCounts,
|
|
| 824 |
}
|
| 825 |
}
|
| 826 |
|
| 827 |
-
|
|
|
|
| 828 |
torch::Tensor isStartTensorGPU,
|
| 829 |
torch::Tensor adjCountsTensorGPU,
|
| 830 |
torch::Tensor adjValuesTensorGPU)
|
|
@@ -838,7 +840,7 @@ void cpu_flatten_graph(const int32_t *ptrQuadCts,
|
|
| 838 |
auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>();
|
| 839 |
|
| 840 |
for (int32_t b = 0; b < allAdjCounts.size(0); ++b) {
|
| 841 |
-
const int32_t quadCt =
|
| 842 |
|
| 843 |
for (int32_t row = 0; row < quadCt; ++row) {
|
| 844 |
std::unordered_set<int32_t> fullAdjSet;
|
|
@@ -893,9 +895,9 @@ void device_a2a_adj_cleanup(const int32_t *counts,
|
|
| 893 |
}
|
| 894 |
}
|
| 895 |
|
| 896 |
-
template<uint32_t NumWarps, typename T>
|
| 897 |
__global__
|
| 898 |
-
void device_a2a_collapse(
|
| 899 |
torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
|
| 900 |
torch::PackedTensorAccessor64<bool, 2> allIsLeadRow,
|
| 901 |
const int64_t *regionCounts,
|
|
@@ -915,14 +917,16 @@ void device_a2a_collapse(torch::PackedTensorAccessor64<int32_t, 1> quadCounts,
|
|
| 915 |
const uint32_t b = blockIdx.z;
|
| 916 |
const uint32_t row = blockIdx.y;
|
| 917 |
|
| 918 |
-
const int32_t quadCt =
|
| 919 |
|
| 920 |
-
if (
|
| 921 |
-
|
|
|
|
|
|
|
| 922 |
}
|
| 923 |
|
| 924 |
// Only process the lead rows
|
| 925 |
-
const auto isLeadRow = allIsLeadRow[b].data();
|
| 926 |
if (!isLeadRow[row]) {
|
| 927 |
return;
|
| 928 |
}
|
|
@@ -941,7 +945,7 @@ void device_a2a_collapse(torch::PackedTensorAccessor64<int32_t, 1> quadCounts,
|
|
| 941 |
__syncthreads();
|
| 942 |
}
|
| 943 |
|
| 944 |
-
T *exData = allEmbedQuads[b].data();
|
| 945 |
|
| 946 |
const int32_t adjCount = allAdjCounts[b][row];
|
| 947 |
const int32_t *adjIdxs = allAdjValues[b][row].data();
|
|
@@ -982,12 +986,20 @@ void device_a2a_collapse(torch::PackedTensorAccessor64<int32_t, 1> quadCounts,
|
|
| 982 |
|
| 983 |
// Figure out the output position
|
| 984 |
uint32_t writePosition = 0;
|
| 985 |
-
|
| 986 |
-
|
|
|
|
|
|
|
| 987 |
}
|
| 988 |
|
|
|
|
| 989 |
const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow);
|
| 990 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 991 |
if (pCurrIsLeadRow[i]) {
|
| 992 |
++writePosition;
|
| 993 |
}
|
|
@@ -1063,9 +1075,13 @@ CollapseRowsResult collapse_rows(
|
|
| 1063 |
int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t);
|
| 1064 |
auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options());
|
| 1065 |
|
|
|
|
| 1066 |
auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) },
|
| 1067 |
std::numeric_limits<int32_t>::max(),
|
| 1068 |
counts.options().dtype(torch::kInt32));
|
|
|
|
|
|
|
|
|
|
| 1069 |
|
| 1070 |
dim3 blockSize(32, 3, 1);
|
| 1071 |
dim3 gridSize(1,
|
|
@@ -1077,8 +1093,10 @@ CollapseRowsResult collapse_rows(
|
|
| 1077 |
probs.packed_accessor64<scalar_t, 3>(),
|
| 1078 |
probThreshold, iouThreshold,
|
| 1079 |
counts.packed_accessor64<int32_t, 1>(),
|
| 1080 |
-
rowMergeTensor.packed_accessor64<scalar_t, 3>()
|
| 1081 |
-
|
|
|
|
|
|
|
| 1082 |
);
|
| 1083 |
|
| 1084 |
#ifdef NMS_VERIFY_CORRECTNESS
|
|
@@ -1101,6 +1119,7 @@ CollapseRowsResult collapse_rows(
|
|
| 1101 |
|
| 1102 |
counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1);
|
| 1103 |
|
|
|
|
| 1104 |
int64_t maxExCount;
|
| 1105 |
if (counts.size(0) > 1) {
|
| 1106 |
maxExCount = counts.max().item<int32_t>();
|
|
@@ -1112,12 +1131,13 @@ CollapseRowsResult collapse_rows(
|
|
| 1112 |
|
| 1113 |
rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount);
|
| 1114 |
idsTensor = idsTensor.slice(1, 0, maxExCount);
|
| 1115 |
-
auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder);
|
| 1116 |
|
| 1117 |
auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor);
|
| 1118 |
|
| 1119 |
rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder);
|
| 1120 |
idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order);
|
|
|
|
| 1121 |
|
| 1122 |
return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight };
|
| 1123 |
}
|
|
@@ -1157,8 +1177,8 @@ struct AdjacencyResult {
|
|
| 1157 |
int64_t MaxExCount;
|
| 1158 |
};
|
| 1159 |
|
| 1160 |
-
template<typename T>
|
| 1161 |
-
void cpu_a2a_adjacency_sparse(const
|
| 1162 |
const T iouThreshold,
|
| 1163 |
torch::Tensor embedQuadsTensor,
|
| 1164 |
torch::Tensor outIsStartTensorGPU,
|
|
@@ -1176,7 +1196,7 @@ void cpu_a2a_adjacency_sparse(const int32_t *ptrQuadCts,
|
|
| 1176 |
auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>();
|
| 1177 |
|
| 1178 |
for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) {
|
| 1179 |
-
const int32_t quadCt =
|
| 1180 |
|
| 1181 |
T *exData = embedQuads[b].data();
|
| 1182 |
|
|
@@ -1264,6 +1284,13 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1264 |
counts.options().dtype(torch::kInt32));
|
| 1265 |
#endif
|
| 1266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1267 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1268 |
auto cpuAdjValuesTensor = adjValuesTensor.cpu();
|
| 1269 |
auto cpuAdjCountsTensor = adjCountsTensor.cpu();
|
|
@@ -1291,15 +1318,23 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1291 |
//blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 };
|
| 1292 |
//gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
|
| 1293 |
|
| 1294 |
-
//
|
| 1295 |
-
//
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1296 |
// collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1297 |
// gridCellsTensor.packed_accessor64<int32_t, 4>(),
|
| 1298 |
// quadCellExtentsTensor.packed_accessor64<int32_t, 3>()
|
| 1299 |
//);
|
| 1300 |
|
| 1301 |
-
//
|
| 1302 |
-
//
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1303 |
// iouThreshold,
|
| 1304 |
// collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1305 |
// gridCellsTensor.packed_accessor64<int32_t, 4>(),
|
|
@@ -1316,9 +1351,11 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1316 |
gridSize = dim3{div_up(totalWork, blockSize.x),
|
| 1317 |
static_cast<uint32_t>(counts.size(0))};
|
| 1318 |
|
|
|
|
|
|
|
| 1319 |
// This algorithm is O(n^2) with n being the current number of quads
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
iouThreshold,
|
| 1323 |
collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1324 |
isStartTensor.packed_accessor64<bool, 2>(),
|
|
@@ -1328,9 +1365,7 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1328 |
|
| 1329 |
|
| 1330 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1331 |
-
|
| 1332 |
-
|
| 1333 |
-
cpu_a2a_adjacency_sparse<scalar_t>(cpuCounts.data_ptr<int32_t>(), iouThreshold,
|
| 1334 |
collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
|
| 1335 |
|
| 1336 |
adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
|
|
@@ -1345,12 +1380,16 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1345 |
auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options());
|
| 1346 |
#endif
|
| 1347 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1348 |
blockSize = dim3{ 128, 1, 1 };
|
| 1349 |
gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
|
| 1350 |
smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t);
|
| 1351 |
|
| 1352 |
-
|
| 1353 |
-
|
| 1354 |
isStartTensor.packed_accessor64<bool, 2>(),
|
| 1355 |
reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()),
|
| 1356 |
reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>())
|
|
@@ -1360,7 +1399,7 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1360 |
);
|
| 1361 |
|
| 1362 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1363 |
-
cpu_flatten_graph
|
| 1364 |
|
| 1365 |
cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2));
|
| 1366 |
adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
|
|
@@ -1398,6 +1437,7 @@ AdjacencyResult compute_all_to_all_adjacency(
|
|
| 1398 |
cpuIsStartTensor = isStartTensor.cpu();
|
| 1399 |
cpuAdjCountsTensor = adjCountsTensor.cpu();
|
| 1400 |
cpuAdjValuesTensor = adjValuesTensor.cpu();
|
|
|
|
| 1401 |
auto cpuCollapseIds = collapseResult.QuadIds.cpu();
|
| 1402 |
|
| 1403 |
static std::vector<std::unordered_set<int32_t>> s_knownGroups;
|
|
@@ -1549,11 +1589,22 @@ nms_result_t
|
|
| 1549 |
dim3 blockSize(NUM_WARPS * 32, 1, 1);
|
| 1550 |
dim3 gridSize(1, adjResult.MaxExCount, counts.size(0));
|
| 1551 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1552 |
torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options());
|
| 1553 |
torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options());
|
| 1554 |
|
| 1555 |
-
|
| 1556 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1557 |
embedQuads.packed_accessor64<scalar_t, 3>(),
|
| 1558 |
isLeadRow.packed_accessor64<bool, 2>(),
|
| 1559 |
regionCounts.data_ptr<int64_t>(),
|
|
|
|
| 157 |
torch::PackedTensorAccessor64<T, 3> allConfs,
|
| 158 |
T confThreshold, T iouThreshold,
|
| 159 |
torch::PackedTensorAccessor64<int32_t, 1> allOutCounts,
|
| 160 |
+
torch::PackedTensorAccessor64<T, 3> allOutEmbedQuads
|
| 161 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 162 |
+
, torch::PackedTensorAccessor64<int32_t, 2> allOutIds
|
| 163 |
+
#endif
|
| 164 |
+
)
|
| 165 |
{
|
| 166 |
typedef InPlaceQuad_<T> Quadf;
|
| 167 |
static_assert(sizeof(Quadf) == sizeof(T) * 8, "Invalid QuadMem size!");
|
|
|
|
| 306 |
}
|
| 307 |
|
| 308 |
write_embed_quad(outEmbedQuads, outQuad, storeOff + procLabel - 1);
|
| 309 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 310 |
if (threadRank == 0) {
|
| 311 |
allOutIds[b][storeOff + procLabel - 1] = r * 32 + startIdx;
|
| 312 |
}
|
| 313 |
+
#endif
|
| 314 |
}
|
| 315 |
|
| 316 |
if (threadRank == 0) {
|
|
|
|
| 321 |
#undef threadRank
|
| 322 |
}
|
| 323 |
|
| 324 |
+
template<bool IsSingleExample, typename T>
|
| 325 |
__global__
|
| 326 |
+
void device_a2a_adjacency_sparse(const uint64_t punCounts,
|
| 327 |
T iouThreshold,
|
| 328 |
torch::PackedTensorAccessor64<T, 3> embedQuads,
|
| 329 |
torch::PackedTensorAccessor64<bool, 2> outIsStart,
|
|
|
|
| 332 |
{
|
| 333 |
const uint32_t b = blockIdx.y;
|
| 334 |
|
| 335 |
+
const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
const int32_t jobIdx = blockIdx.x * blockDim.x + threadIdx.x;
|
| 338 |
const int32_t row = jobIdx / quadCt;
|
|
|
|
| 343 |
return;
|
| 344 |
}
|
| 345 |
|
| 346 |
+
T* exData = IsSingleExample ? embedQuads.data() : embedQuads[b].data();
|
| 347 |
|
| 348 |
const auto qRow = StridedEmbedQuad_<T>{ exData + row * embedQuads.stride(2), embedQuads.stride(1) }.Bounds(),
|
| 349 |
qCol = StridedEmbedQuad_<T>{ exData + col * embedQuads.stride(2), embedQuads.stride(1) }.Bounds();
|
|
|
|
| 405 |
}
|
| 406 |
}
|
| 407 |
|
| 408 |
+
template<uint32_t NumWarps, bool IsSingleExample, typename T, int32_t I_CELL_SIZE>
|
| 409 |
__global__
|
| 410 |
+
void device_a2a_adjacency_build_grid(const uint64_t punCounts,
|
| 411 |
torch::PackedTensorAccessor64<T, 3> embedQuads,
|
| 412 |
torch::PackedTensorAccessor64<int32_t, 4> outGridCells,
|
| 413 |
torch::PackedTensorAccessor64<int32_t, 3> outQuadCells)
|
|
|
|
| 423 |
|
| 424 |
const uint32_t b = blockIdx.z;
|
| 425 |
|
| 426 |
+
const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 427 |
const uint32_t quadIdx = blockIdx.y;
|
| 428 |
|
| 429 |
+
if (!IsSingleExample && quadIdx >= quadCt) {
|
| 430 |
return;
|
| 431 |
}
|
| 432 |
|
|
|
|
| 485 |
|
| 486 |
typedef uint8_t visit_mask_t;
|
| 487 |
|
| 488 |
+
template<uint32_t NumWarps, bool IsSingleExample, typename T>
|
| 489 |
__global__
|
| 490 |
+
void device_a2a_adjacency_with_grid(const uint64_t punCounts,
|
| 491 |
T iouThreshold,
|
| 492 |
torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
|
| 493 |
torch::PackedTensorAccessor64<int32_t, 4> allCells,
|
|
|
|
| 503 |
|
| 504 |
const uint32_t b = blockIdx.z;
|
| 505 |
|
| 506 |
+
const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 507 |
const uint32_t quadIdx = blockIdx.y;
|
| 508 |
|
| 509 |
+
if (!IsSingleExample && quadIdx >= quadCt) {
|
| 510 |
return;
|
| 511 |
}
|
| 512 |
|
|
|
|
| 535 |
auto exAdjCounts = reinterpret_cast<uint32_t*>(outAdjCounts[b].data());
|
| 536 |
auto exAdjValues = outSparseAdj[b][quadIdx].data();
|
| 537 |
|
| 538 |
+
T *exData = IsSingleExample ? allEmbedQuads.data() : allEmbedQuads[b].data();
|
| 539 |
|
| 540 |
const auto bdsAnchor = Quad_<T>{ s_quadVerts }.Bounds();
|
| 541 |
|
|
|
|
| 599 |
}
|
| 600 |
}
|
| 601 |
|
| 602 |
+
template<bool IsSingleExample>
|
| 603 |
__global__
|
| 604 |
+
void device_flatten_graph_iterative(const uint64_t punCounts,
|
| 605 |
torch::PackedTensorAccessor64<bool, 2> allIsStart,
|
| 606 |
volatile uint32_t *allAdjCounts,
|
| 607 |
volatile uint32_t *allAdjValues
|
|
|
|
| 622 |
const uint32_t b = blockIdx.z;
|
| 623 |
const uint32_t anchorRow = blockIdx.y;
|
| 624 |
|
| 625 |
+
const uint32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 626 |
|
| 627 |
// Only need to check this if there are multiple examples, since in the case of a single example,
|
| 628 |
// the grid is precisely sized to that quadCt
|
| 629 |
+
if constexpr (!IsSingleExample) {
|
| 630 |
+
if (anchorRow >= quadCt) {
|
| 631 |
+
return;
|
| 632 |
+
}
|
| 633 |
}
|
| 634 |
|
| 635 |
auto isStart = allIsStart[b].data();
|
|
|
|
| 690 |
visitStack[1] = anchorRow;
|
| 691 |
#ifndef NDEBUG
|
| 692 |
for (uint32_t i = 2; i < VISIT_STACK_SIZE; ++i) {
|
| 693 |
+
visitStack[i] = -2;
|
| 694 |
}
|
| 695 |
#endif
|
| 696 |
int32_t visitPtr = 1;
|
| 697 |
|
| 698 |
+
while (true) {
|
|
|
|
| 699 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 700 |
assert(visitPtr >= 0 && visitPtr < VISIT_STACK_SIZE);
|
| 701 |
#endif
|
|
|
|
| 707 |
if (threadNextCol == warpNextCol) {
|
| 708 |
#ifndef NDEBUG
|
| 709 |
// This makes it easier to debug where the pointer is
|
| 710 |
+
visitStack[visitPtr] = -2;
|
| 711 |
#endif
|
| 712 |
--visitPtr;
|
| 713 |
}
|
|
|
|
| 731 |
const uint32_t procAdjCount = adjCounts[procRow];
|
| 732 |
auto procAdjValues = adjValues + (procRow * maxExCount);
|
| 733 |
|
| 734 |
+
// Offsetting by the iteration number will help balance out the maximum depth of any stack in the warp.
|
| 735 |
+
// The reason behind this is due to how otherwise, warp-0 will always get a new element, warp-1 iff the adj graph
|
| 736 |
+
// has more than one element, warp-2 iff the adj graph has more than two elements, and so on. Basically,
|
| 737 |
+
// the warps have decreasing pressure. With the rotation mechanism, it helps to balance out stack usage.
|
| 738 |
for (uint32_t i = threadRank; i < procAdjCount; i += WARP_SIZE) {
|
| 739 |
+
const uint32_t adjCol = procAdjValues[i];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 740 |
|
| 741 |
// This will set the queued flag for this column, if it's not already set.
|
| 742 |
// It also returns the old state. In our case, we only want to add this value to the
|
|
|
|
| 748 |
|
| 749 |
bool alreadyAdded = oldMask & ADDED_MASK;
|
| 750 |
|
| 751 |
+
auto group = cg::coalesced_threads();
|
| 752 |
const uint32_t gThreadRank = group.thread_rank();
|
| 753 |
uint32_t notAddedBallot = group.ballot(!alreadyAdded);
|
| 754 |
if (notAddedBallot) {
|
|
|
|
| 825 |
}
|
| 826 |
}
|
| 827 |
|
| 828 |
+
template<bool IsSingleExample>
|
| 829 |
+
void cpu_flatten_graph(const uint64_t punCounts,
|
| 830 |
torch::Tensor isStartTensorGPU,
|
| 831 |
torch::Tensor adjCountsTensorGPU,
|
| 832 |
torch::Tensor adjValuesTensorGPU)
|
|
|
|
| 840 |
auto allAdjValues = adjValuesTensor.accessor<int32_t, 3>();
|
| 841 |
|
| 842 |
for (int32_t b = 0; b < allAdjCounts.size(0); ++b) {
|
| 843 |
+
const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 844 |
|
| 845 |
for (int32_t row = 0; row < quadCt; ++row) {
|
| 846 |
std::unordered_set<int32_t> fullAdjSet;
|
|
|
|
| 895 |
}
|
| 896 |
}
|
| 897 |
|
| 898 |
+
template<uint32_t NumWarps, typename T, bool IsSingleExample>
|
| 899 |
__global__
|
| 900 |
+
void device_a2a_collapse(const uint64_t punCounts,
|
| 901 |
torch::PackedTensorAccessor64<T, 3> allEmbedQuads,
|
| 902 |
torch::PackedTensorAccessor64<bool, 2> allIsLeadRow,
|
| 903 |
const int64_t *regionCounts,
|
|
|
|
| 917 |
const uint32_t b = blockIdx.z;
|
| 918 |
const uint32_t row = blockIdx.y;
|
| 919 |
|
| 920 |
+
const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 921 |
|
| 922 |
+
if constexpr (!IsSingleExample) {
|
| 923 |
+
if (row >= quadCt) {
|
| 924 |
+
return;
|
| 925 |
+
}
|
| 926 |
}
|
| 927 |
|
| 928 |
// Only process the lead rows
|
| 929 |
+
const auto isLeadRow = IsSingleExample ? allIsLeadRow.data() : allIsLeadRow[b].data();
|
| 930 |
if (!isLeadRow[row]) {
|
| 931 |
return;
|
| 932 |
}
|
|
|
|
| 945 |
__syncthreads();
|
| 946 |
}
|
| 947 |
|
| 948 |
+
T *exData = IsSingleExample ? allEmbedQuads.data() : allEmbedQuads[b].data();
|
| 949 |
|
| 950 |
const int32_t adjCount = allAdjCounts[b][row];
|
| 951 |
const int32_t *adjIdxs = allAdjValues[b][row].data();
|
|
|
|
| 986 |
|
| 987 |
// Figure out the output position
|
| 988 |
uint32_t writePosition = 0;
|
| 989 |
+
if constexpr (!IsSingleExample) {
|
| 990 |
+
for (int32_t i = threadRank; i < b; i += BLOCK_WIDTH) {
|
| 991 |
+
writePosition += regionCounts[i];
|
| 992 |
+
}
|
| 993 |
}
|
| 994 |
|
| 995 |
+
const int32_t numLongs = row >> 3; // Divide by 8
|
| 996 |
const uint8_t *pCurrIsLeadRow = reinterpret_cast<const uint8_t*>(isLeadRow);
|
| 997 |
+
const uint64_t *lpCurrIsLeadRow = reinterpret_cast<const uint64_t*>(pCurrIsLeadRow);
|
| 998 |
+
|
| 999 |
+
for (int32_t i = threadRank; i < numLongs; i += BLOCK_WIDTH) {
|
| 1000 |
+
writePosition += __popcll(lpCurrIsLeadRow[i]);
|
| 1001 |
+
}
|
| 1002 |
+
for (int32_t i = (numLongs * 8) + threadRank; i < row; i += BLOCK_WIDTH) {
|
| 1003 |
if (pCurrIsLeadRow[i]) {
|
| 1004 |
++writePosition;
|
| 1005 |
}
|
|
|
|
| 1075 |
int64_t embedSize = sizeof(EmbedQuad_<scalar_t>) / sizeof(scalar_t);
|
| 1076 |
auto rowMergeTensor = torch::empty({ quads.size(0), embedSize, quads.size(1) * quads.size(2) }, quads.options());
|
| 1077 |
|
| 1078 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1079 |
auto idsTensor = torch::full({ quads.size(0), quads.size(1) * quads.size(2) },
|
| 1080 |
std::numeric_limits<int32_t>::max(),
|
| 1081 |
counts.options().dtype(torch::kInt32));
|
| 1082 |
+
#else
|
| 1083 |
+
torch::Tensor idsTensor;
|
| 1084 |
+
#endif
|
| 1085 |
|
| 1086 |
dim3 blockSize(32, 3, 1);
|
| 1087 |
dim3 gridSize(1,
|
|
|
|
| 1093 |
probs.packed_accessor64<scalar_t, 3>(),
|
| 1094 |
probThreshold, iouThreshold,
|
| 1095 |
counts.packed_accessor64<int32_t, 1>(),
|
| 1096 |
+
rowMergeTensor.packed_accessor64<scalar_t, 3>()
|
| 1097 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1098 |
+
, idsTensor.packed_accessor64<int32_t, 2>()
|
| 1099 |
+
#endif
|
| 1100 |
);
|
| 1101 |
|
| 1102 |
#ifdef NMS_VERIFY_CORRECTNESS
|
|
|
|
| 1119 |
|
| 1120 |
counts = counts.slice(/*dim=*/ 0, 0, counts.size(0) - 1);
|
| 1121 |
|
| 1122 |
+
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1123 |
int64_t maxExCount;
|
| 1124 |
if (counts.size(0) > 1) {
|
| 1125 |
maxExCount = counts.max().item<int32_t>();
|
|
|
|
| 1131 |
|
| 1132 |
rowMergeTensor = rowMergeTensor.slice(2, 0, maxExCount);
|
| 1133 |
idsTensor = idsTensor.slice(1, 0, maxExCount);
|
| 1134 |
+
auto order = torch::argsort(idsTensor, /*dim=*/ 1, s_sortOrder); s_sortOrder = !s_sortOrder;
|
| 1135 |
|
| 1136 |
auto embOrder = order.unsqueeze(1).expand_as(rowMergeTensor);
|
| 1137 |
|
| 1138 |
rowMergeTensor = torch::gather(rowMergeTensor, /*dim=*/ 2, embOrder);
|
| 1139 |
idsTensor = torch::gather(idsTensor, /*dim=*/ 1, order);
|
| 1140 |
+
#endif
|
| 1141 |
|
| 1142 |
return { counts, rowMergeTensor, totalQuads, idsTensor, imageWidth, imageHeight };
|
| 1143 |
}
|
|
|
|
| 1177 |
int64_t MaxExCount;
|
| 1178 |
};
|
| 1179 |
|
| 1180 |
+
template<bool IsSingleExample, typename T>
|
| 1181 |
+
void cpu_a2a_adjacency_sparse(const uint64_t punCounts,
|
| 1182 |
const T iouThreshold,
|
| 1183 |
torch::Tensor embedQuadsTensor,
|
| 1184 |
torch::Tensor outIsStartTensorGPU,
|
|
|
|
| 1196 |
auto adjValues = outSparseAdjTensor.accessor<int32_t, 3>();
|
| 1197 |
|
| 1198 |
for (int32_t b = 0; b < embedQuadsTensor.size(0); ++b) {
|
| 1199 |
+
const int32_t quadCt = IsSingleExample ? punCounts : reinterpret_cast<const int32_t*>(punCounts)[b];
|
| 1200 |
|
| 1201 |
T *exData = embedQuads[b].data();
|
| 1202 |
|
|
|
|
| 1284 |
counts.options().dtype(torch::kInt32));
|
| 1285 |
#endif
|
| 1286 |
|
| 1287 |
+
// If the batch is only a single example, instead of hitting global memory for the count, we can
|
| 1288 |
+
// just encode the count into the pointer instead
|
| 1289 |
+
uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
|
| 1290 |
+
if (counts.size(0) == 1) {
|
| 1291 |
+
ptrCounts = maxExCount;
|
| 1292 |
+
}
|
| 1293 |
+
|
| 1294 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1295 |
auto cpuAdjValuesTensor = adjValuesTensor.cpu();
|
| 1296 |
auto cpuAdjCountsTensor = adjCountsTensor.cpu();
|
|
|
|
| 1318 |
//blockSize = dim3{ GRID_NUM_WARPS * 32, 1, 1 };
|
| 1319 |
//gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
|
| 1320 |
|
| 1321 |
+
//auto buildGridFn = counts.size(0) == 1 ?
|
| 1322 |
+
// device_a2a_adjacency_build_grid<GRID_NUM_WARPS, true, scalar_t, CELL_SIZE> :
|
| 1323 |
+
// device_a2a_adjacency_build_grid<GRID_NUM_WARPS, false, scalar_t, CELL_SIZE>;
|
| 1324 |
+
|
| 1325 |
+
//buildGridFn KERNEL_ARG2(gridSize, blockSize) (
|
| 1326 |
+
// ptrCounts,
|
| 1327 |
// collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1328 |
// gridCellsTensor.packed_accessor64<int32_t, 4>(),
|
| 1329 |
// quadCellExtentsTensor.packed_accessor64<int32_t, 3>()
|
| 1330 |
//);
|
| 1331 |
|
| 1332 |
+
//auto adjGridFn = counts.size(0) == 1 ?
|
| 1333 |
+
// device_a2a_adjacency_with_grid<GRID_NUM_WARPS, true, scalar_t> :
|
| 1334 |
+
// device_a2a_adjacency_with_grid<GRID_NUM_WARPS, false, scalar_t>;
|
| 1335 |
+
|
| 1336 |
+
//adjGridFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
|
| 1337 |
+
// ptrCounts,
|
| 1338 |
// iouThreshold,
|
| 1339 |
// collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1340 |
// gridCellsTensor.packed_accessor64<int32_t, 4>(),
|
|
|
|
| 1351 |
gridSize = dim3{div_up(totalWork, blockSize.x),
|
| 1352 |
static_cast<uint32_t>(counts.size(0))};
|
| 1353 |
|
| 1354 |
+
auto adjFn = counts.size(0) == 1 ? device_a2a_adjacency_sparse<true, scalar_t> : device_a2a_adjacency_sparse<false, scalar_t>;
|
| 1355 |
+
|
| 1356 |
// This algorithm is O(n^2) with n being the current number of quads
|
| 1357 |
+
adjFn KERNEL_ARG2(gridSize, blockSize) (
|
| 1358 |
+
ptrCounts,
|
| 1359 |
iouThreshold,
|
| 1360 |
collapseResult.StridedMergeQuads.packed_accessor64<scalar_t, 3>(),
|
| 1361 |
isStartTensor.packed_accessor64<bool, 2>(),
|
|
|
|
| 1365 |
|
| 1366 |
|
| 1367 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1368 |
+
cpu_a2a_adjacency_sparse<true>(ptrCounts, iouThreshold,
|
|
|
|
|
|
|
| 1369 |
collapseResult.StridedMergeQuads, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
|
| 1370 |
|
| 1371 |
adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
|
|
|
|
| 1380 |
auto maxDepthTensor = torch::tensor(0, adjCountsTensor.options());
|
| 1381 |
#endif
|
| 1382 |
|
| 1383 |
+
auto traverseFn = counts.size(0) == 1 ?
|
| 1384 |
+
device_flatten_graph_iterative<true> :
|
| 1385 |
+
device_flatten_graph_iterative<false>;
|
| 1386 |
+
|
| 1387 |
blockSize = dim3{ 128, 1, 1 };
|
| 1388 |
gridSize = dim3{ 1, static_cast<uint32_t>(maxExCount), static_cast<uint32_t>(counts.size(0)) };
|
| 1389 |
smemSize = div_up(maxExCount * sizeof(visit_mask_t), sizeof(uint32_t)) * sizeof(uint32_t);
|
| 1390 |
|
| 1391 |
+
traverseFn KERNEL_ARG3(gridSize, blockSize, smemSize) (
|
| 1392 |
+
ptrCounts,
|
| 1393 |
isStartTensor.packed_accessor64<bool, 2>(),
|
| 1394 |
reinterpret_cast<uint32_t*>(adjCountsTensor.data_ptr<int32_t>()),
|
| 1395 |
reinterpret_cast<uint32_t*>(adjValuesTensor.data_ptr<int32_t>())
|
|
|
|
| 1399 |
);
|
| 1400 |
|
| 1401 |
#ifdef NMS_VERIFY_CORRECTNESS
|
| 1402 |
+
cpu_flatten_graph<true>(ptrCounts, cpuIsStartTensor, cpuAdjCountsTensor, cpuAdjValuesTensor);
|
| 1403 |
|
| 1404 |
cpuAdjValuesTensor = std::get<0>(torch::sort(cpuAdjValuesTensor, /*dim=*/ 2));
|
| 1405 |
adjValuesTensor = std::get<0>(torch::sort(adjValuesTensor, /*dim=*/ 2));
|
|
|
|
| 1437 |
cpuIsStartTensor = isStartTensor.cpu();
|
| 1438 |
cpuAdjCountsTensor = adjCountsTensor.cpu();
|
| 1439 |
cpuAdjValuesTensor = adjValuesTensor.cpu();
|
| 1440 |
+
auto cpuCounts = counts.cpu();
|
| 1441 |
auto cpuCollapseIds = collapseResult.QuadIds.cpu();
|
| 1442 |
|
| 1443 |
static std::vector<std::unordered_set<int32_t>> s_knownGroups;
|
|
|
|
| 1589 |
dim3 blockSize(NUM_WARPS * 32, 1, 1);
|
| 1590 |
dim3 gridSize(1, adjResult.MaxExCount, counts.size(0));
|
| 1591 |
|
| 1592 |
+
// If the batch is only a single example, instead of hitting global memory for the count, we can
|
| 1593 |
+
// just encode the count into the pointer instead
|
| 1594 |
+
uint64_t ptrCounts = reinterpret_cast<uint64_t>(counts.data_ptr<int32_t>());
|
| 1595 |
+
if (counts.size(0) == 1) {
|
| 1596 |
+
ptrCounts = adjResult.MaxExCount;
|
| 1597 |
+
}
|
| 1598 |
+
|
| 1599 |
torch::Tensor outQuads = torch::empty({ numOutQuads, 4, 2 }, embedQuads.options());
|
| 1600 |
torch::Tensor outConf = torch::empty({ numOutQuads }, embedQuads.options());
|
| 1601 |
|
| 1602 |
+
auto collapseFn = counts.size(0) == 1 ?
|
| 1603 |
+
device_a2a_collapse<NUM_WARPS, scalar_t, true> :
|
| 1604 |
+
device_a2a_collapse<NUM_WARPS, scalar_t, false>;
|
| 1605 |
+
|
| 1606 |
+
collapseFn KERNEL_ARG2(gridSize, blockSize) (
|
| 1607 |
+
ptrCounts,
|
| 1608 |
embedQuads.packed_accessor64<scalar_t, 3>(),
|
| 1609 |
isLeadRow.packed_accessor64<bool, 2>(),
|
| 1610 |
regionCounts.data_ptr<int64_t>(),
|
nemotron-ocr/pyproject.toml
CHANGED
|
@@ -5,7 +5,6 @@ description = "Nemoton OCR"
|
|
| 5 |
authors = [{ name = "NVIDIA Nemotron" }]
|
| 6 |
requires-python = ">=3.12,<3.13"
|
| 7 |
dependencies = [
|
| 8 |
-
"huggingface_hub>=0.20.0",
|
| 9 |
"pandas>=2.3.3",
|
| 10 |
"pillow>=12.0.0",
|
| 11 |
"scikit-learn>=1.7.2",
|
|
|
|
| 5 |
authors = [{ name = "NVIDIA Nemotron" }]
|
| 6 |
requires-python = ">=3.12,<3.13"
|
| 7 |
dependencies = [
|
|
|
|
| 8 |
"pandas>=2.3.3",
|
| 9 |
"pillow>=12.0.0",
|
| 10 |
"scikit-learn>=1.7.2",
|
nemotron-ocr/src/nemotron_ocr/inference/pipeline.py
CHANGED
|
@@ -6,7 +6,6 @@ import io
|
|
| 6 |
import json
|
| 7 |
import os
|
| 8 |
from pathlib import Path
|
| 9 |
-
from typing import Optional
|
| 10 |
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
|
@@ -21,7 +20,6 @@ from nemotron_ocr.inference.post_processing.data.text_region import TextBlock
|
|
| 21 |
from nemotron_ocr.inference.post_processing.quad_rectify import QuadRectify
|
| 22 |
from nemotron_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes
|
| 23 |
from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square
|
| 24 |
-
from huggingface_hub import hf_hub_download
|
| 25 |
from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads
|
| 26 |
from PIL import Image, ImageDraw, ImageFont
|
| 27 |
from torch import amp
|
|
@@ -39,57 +37,25 @@ MERGE_LEVELS = {"word", "sentence", "paragraph"}
|
|
| 39 |
DEFAULT_MERGE_LEVEL = "paragraph"
|
| 40 |
|
| 41 |
|
| 42 |
-
# HuggingFace repository for downloading model weights
|
| 43 |
-
HF_REPO_ID = "nvidia/nemotron-ocr-v1"
|
| 44 |
-
CHECKPOINT_FILES = ["detector.pth", "recognizer.pth", "relational.pth", "charset.txt"]
|
| 45 |
-
|
| 46 |
-
|
| 47 |
class NemotronOCR:
|
| 48 |
"""
|
| 49 |
A high-level pipeline for performing OCR on images.
|
| 50 |
-
|
| 51 |
-
Model weights are automatically downloaded from Hugging Face Hub
|
| 52 |
-
(nvidia/nemotron-ocr-v1) if not found locally.
|
| 53 |
"""
|
| 54 |
|
| 55 |
-
def __init__(self, model_dir
|
| 56 |
-
|
| 57 |
-
if model_dir is not None:
|
| 58 |
-
local_path = Path(model_dir)
|
| 59 |
-
if all((local_path / f).is_file() for f in CHECKPOINT_FILES):
|
| 60 |
-
self._model_dir = local_path
|
| 61 |
-
else:
|
| 62 |
-
self._model_dir = self._download_checkpoints()
|
| 63 |
-
else:
|
| 64 |
-
self._model_dir = self._download_checkpoints()
|
| 65 |
|
| 66 |
self._load_models()
|
| 67 |
self._load_charset()
|
| 68 |
self._initialize_processors()
|
| 69 |
|
| 70 |
-
@staticmethod
|
| 71 |
-
def _download_checkpoints() -> Path:
|
| 72 |
-
"""Download model checkpoints from HuggingFace Hub (cached locally after first download)."""
|
| 73 |
-
downloaded_path = None
|
| 74 |
-
for filename in CHECKPOINT_FILES:
|
| 75 |
-
downloaded_path = hf_hub_download(
|
| 76 |
-
repo_id=HF_REPO_ID,
|
| 77 |
-
filename=f"checkpoints/{filename}",
|
| 78 |
-
)
|
| 79 |
-
# All checkpoint files are in the same directory
|
| 80 |
-
return Path(downloaded_path).parent
|
| 81 |
-
|
| 82 |
def _load_models(self):
|
| 83 |
"""Loads all necessary models into memory."""
|
| 84 |
self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False)
|
| 85 |
-
self.detector.load_state_dict(
|
| 86 |
-
torch.load(self._model_dir / "detector.pth", weights_only=True), strict=True
|
| 87 |
-
)
|
| 88 |
|
| 89 |
self.recognizer = TransformerRecognizer(nic=self.detector.num_features[-1], num_tokens=858, max_width=32)
|
| 90 |
-
self.recognizer.load_state_dict(
|
| 91 |
-
torch.load(self._model_dir / "recognizer.pth", weights_only=True), strict=True
|
| 92 |
-
)
|
| 93 |
|
| 94 |
self.relational = GlobalRelationalModel(
|
| 95 |
num_input_channels=self.detector.num_features,
|
|
@@ -98,9 +64,7 @@ class NemotronOCR:
|
|
| 98 |
k=16,
|
| 99 |
num_layers=4,
|
| 100 |
)
|
| 101 |
-
self.relational.load_state_dict(
|
| 102 |
-
torch.load(self._model_dir / "relational.pth", weights_only=True), strict=True
|
| 103 |
-
)
|
| 104 |
|
| 105 |
for model in (self.detector, self.recognizer, self.relational):
|
| 106 |
model = model.cuda()
|
|
@@ -217,17 +181,29 @@ class NemotronOCR:
|
|
| 217 |
e2e_det_conf = torch.sigmoid(det_conf)
|
| 218 |
e2e_det_coords = rrect_to_quads(det_rboxes.float(), DETECTOR_DOWNSAMPLE)
|
| 219 |
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
if quads.shape[0] == 0:
|
| 233 |
rec_rectified_quads = torch.empty(0, 128, 8, 32, dtype=torch.float32, device=padded_image.device)
|
|
|
|
| 6 |
import json
|
| 7 |
import os
|
| 8 |
from pathlib import Path
|
|
|
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
|
|
|
| 20 |
from nemotron_ocr.inference.post_processing.quad_rectify import QuadRectify
|
| 21 |
from nemotron_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes
|
| 22 |
from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square
|
|
|
|
| 23 |
from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads
|
| 24 |
from PIL import Image, ImageDraw, ImageFont
|
| 25 |
from torch import amp
|
|
|
|
| 37 |
DEFAULT_MERGE_LEVEL = "paragraph"
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
class NemotronOCR:
|
| 41 |
"""
|
| 42 |
A high-level pipeline for performing OCR on images.
|
|
|
|
|
|
|
|
|
|
| 43 |
"""
|
| 44 |
|
| 45 |
+
def __init__(self, model_dir="./checkpoints"):
|
| 46 |
+
self._model_dir = Path(model_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
self._load_models()
|
| 49 |
self._load_charset()
|
| 50 |
self._initialize_processors()
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def _load_models(self):
|
| 53 |
"""Loads all necessary models into memory."""
|
| 54 |
self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False)
|
| 55 |
+
self.detector.load_state_dict(torch.load(self._model_dir / "detector.pth"), strict=True)
|
|
|
|
|
|
|
| 56 |
|
| 57 |
self.recognizer = TransformerRecognizer(nic=self.detector.num_features[-1], num_tokens=858, max_width=32)
|
| 58 |
+
self.recognizer.load_state_dict(torch.load(self._model_dir / "recognizer.pth"), strict=True)
|
|
|
|
|
|
|
| 59 |
|
| 60 |
self.relational = GlobalRelationalModel(
|
| 61 |
num_input_channels=self.detector.num_features,
|
|
|
|
| 64 |
k=16,
|
| 65 |
num_layers=4,
|
| 66 |
)
|
| 67 |
+
self.relational.load_state_dict(torch.load(self._model_dir / "relational.pth"), strict=True)
|
|
|
|
|
|
|
| 68 |
|
| 69 |
for model in (self.detector, self.recognizer, self.relational):
|
| 70 |
model = model.cuda()
|
|
|
|
| 181 |
e2e_det_conf = torch.sigmoid(det_conf)
|
| 182 |
e2e_det_coords = rrect_to_quads(det_rboxes.float(), DETECTOR_DOWNSAMPLE)
|
| 183 |
|
| 184 |
+
# FIXME: quad_non_maximal_suppression fails with batch size > 1
|
| 185 |
+
all_quads = []
|
| 186 |
+
all_confidence = []
|
| 187 |
+
all_region_counts = []
|
| 188 |
+
|
| 189 |
+
for idx in range(e2e_det_coords.shape[0]):
|
| 190 |
+
quads, confidence, region_counts = quad_non_maximal_suppression(
|
| 191 |
+
e2e_det_coords[idx].unsqueeze(0),
|
| 192 |
+
e2e_det_conf[idx].unsqueeze(0),
|
| 193 |
+
prob_threshold=NMS_PROB_THRESHOLD,
|
| 194 |
+
iou_threshold=NMS_IOU_THRESHOLD,
|
| 195 |
+
kernel_height=2,
|
| 196 |
+
kernel_width=3,
|
| 197 |
+
max_regions=NMS_MAX_REGIONS,
|
| 198 |
+
verbose=False,
|
| 199 |
+
)[:3]
|
| 200 |
+
all_quads.append(quads)
|
| 201 |
+
all_confidence.append(confidence)
|
| 202 |
+
all_region_counts.append(region_counts)
|
| 203 |
+
|
| 204 |
+
quads = torch.cat(all_quads, dim=0)
|
| 205 |
+
confidence = torch.cat(all_confidence, dim=0)
|
| 206 |
+
region_counts = torch.cat(all_region_counts, dim=0)
|
| 207 |
|
| 208 |
if quads.shape[0] == 0:
|
| 209 |
rec_rectified_quads = torch.empty(0, 128, 8, 32, dtype=torch.float32, device=padded_image.device)
|