remove_weights_from_python_wheel
#6
by
jdye64
- opened
example.py
CHANGED
|
@@ -8,7 +8,7 @@ from nemotron_ocr.inference.pipeline import NemotronOCR
|
|
| 8 |
|
| 9 |
|
| 10 |
def main(image_path, merge_level, no_visualize, model_dir):
|
| 11 |
-
ocr_pipeline = NemotronOCR()
|
| 12 |
|
| 13 |
predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
|
| 14 |
|
|
|
|
| 8 |
|
| 9 |
|
| 10 |
def main(image_path, merge_level, no_visualize, model_dir):
|
| 11 |
+
ocr_pipeline = NemotronOCR(model_dir=model_dir)
|
| 12 |
|
| 13 |
predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
|
| 14 |
|
nemotron-ocr/pyproject.toml
CHANGED
|
@@ -5,6 +5,7 @@ description = "Nemoton OCR"
|
|
| 5 |
authors = [{ name = "NVIDIA Nemotron" }]
|
| 6 |
requires-python = ">=3.12,<3.13"
|
| 7 |
dependencies = [
|
|
|
|
| 8 |
"pandas>=2.3.3",
|
| 9 |
"pillow>=12.0.0",
|
| 10 |
"scikit-learn>=1.7.2",
|
|
|
|
| 5 |
authors = [{ name = "NVIDIA Nemotron" }]
|
| 6 |
requires-python = ">=3.12,<3.13"
|
| 7 |
dependencies = [
|
| 8 |
+
"huggingface_hub>=0.20.0",
|
| 9 |
"pandas>=2.3.3",
|
| 10 |
"pillow>=12.0.0",
|
| 11 |
"scikit-learn>=1.7.2",
|
nemotron-ocr/src/nemotron_ocr/inference/pipeline.py
CHANGED
|
@@ -6,6 +6,7 @@ import io
|
|
| 6 |
import json
|
| 7 |
import os
|
| 8 |
from pathlib import Path
|
|
|
|
| 9 |
|
| 10 |
import numpy as np
|
| 11 |
import torch
|
|
@@ -20,6 +21,7 @@ from nemotron_ocr.inference.post_processing.data.text_region import TextBlock
|
|
| 20 |
from nemotron_ocr.inference.post_processing.quad_rectify import QuadRectify
|
| 21 |
from nemotron_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes
|
| 22 |
from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square
|
|
|
|
| 23 |
from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads
|
| 24 |
from PIL import Image, ImageDraw, ImageFont
|
| 25 |
from torch import amp
|
|
@@ -37,25 +39,57 @@ MERGE_LEVELS = {"word", "sentence", "paragraph"}
|
|
| 37 |
DEFAULT_MERGE_LEVEL = "paragraph"
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
class NemotronOCR:
|
| 41 |
"""
|
| 42 |
A high-level pipeline for performing OCR on images.
|
|
|
|
|
|
|
|
|
|
| 43 |
"""
|
| 44 |
|
| 45 |
-
def __init__(self, model_dir=
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
self._load_models()
|
| 49 |
self._load_charset()
|
| 50 |
self._initialize_processors()
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def _load_models(self):
|
| 53 |
"""Loads all necessary models into memory."""
|
| 54 |
self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False)
|
| 55 |
-
self.detector.load_state_dict(
|
|
|
|
|
|
|
| 56 |
|
| 57 |
self.recognizer = TransformerRecognizer(nic=self.detector.num_features[-1], num_tokens=858, max_width=32)
|
| 58 |
-
self.recognizer.load_state_dict(
|
|
|
|
|
|
|
| 59 |
|
| 60 |
self.relational = GlobalRelationalModel(
|
| 61 |
num_input_channels=self.detector.num_features,
|
|
@@ -64,7 +98,9 @@ class NemotronOCR:
|
|
| 64 |
k=16,
|
| 65 |
num_layers=4,
|
| 66 |
)
|
| 67 |
-
self.relational.load_state_dict(
|
|
|
|
|
|
|
| 68 |
|
| 69 |
for model in (self.detector, self.recognizer, self.relational):
|
| 70 |
model = model.cuda()
|
|
|
|
| 6 |
import json
|
| 7 |
import os
|
| 8 |
from pathlib import Path
|
| 9 |
+
from typing import Optional
|
| 10 |
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
|
|
|
| 21 |
from nemotron_ocr.inference.post_processing.quad_rectify import QuadRectify
|
| 22 |
from nemotron_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes
|
| 23 |
from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square
|
| 24 |
+
from huggingface_hub import hf_hub_download
|
| 25 |
from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads
|
| 26 |
from PIL import Image, ImageDraw, ImageFont
|
| 27 |
from torch import amp
|
|
|
|
| 39 |
DEFAULT_MERGE_LEVEL = "paragraph"
|
| 40 |
|
| 41 |
|
| 42 |
+
# HuggingFace repository for downloading model weights
HF_REPO_ID = "nvidia/nemotron-ocr-v1"
# Every file required at inference time; __init__ checks for all of them
# before trusting a caller-supplied model_dir. NOTE(review): charset.txt is
# presumably consumed by _load_charset() — its body is not shown here, confirm.
CHECKPOINT_FILES = ["detector.pth", "recognizer.pth", "relational.pth", "charset.txt"]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
class NemotronOCR:
|
| 48 |
"""
|
| 49 |
A high-level pipeline for performing OCR on images.
|
| 50 |
+
|
| 51 |
+
Model weights are automatically downloaded from Hugging Face Hub
|
| 52 |
+
(nvidia/nemotron-ocr-v1) if not found locally.
|
| 53 |
"""
|
| 54 |
|
| 55 |
+
def __init__(self, model_dir: Optional[str] = None):
    """Build the OCR pipeline, resolving where the model weights live.

    Args:
        model_dir: Optional directory expected to contain every file in
            ``CHECKPOINT_FILES``. When omitted, or when any required file
            is missing from it, the weights are fetched from HuggingFace
            Hub instead (cached locally after the first download).
    """
    candidate = Path(model_dir) if model_dir is not None else None
    # Trust the caller-supplied directory only if it is complete;
    # anything less falls back to the Hub download/cache.
    local_ok = candidate is not None and all(
        (candidate / name).is_file() for name in CHECKPOINT_FILES
    )
    self._model_dir = candidate if local_ok else self._download_checkpoints()

    self._load_models()
    self._load_charset()
    self._initialize_processors()
|
| 69 |
|
| 70 |
+
@staticmethod
def _download_checkpoints() -> Path:
    """Download model checkpoints from HuggingFace Hub.

    huggingface_hub caches each file locally, so only the first call
    actually hits the network.

    Returns:
        Path: the local directory that contains all downloaded
        checkpoint files.
    """
    downloaded_path = None
    for filename in CHECKPOINT_FILES:
        # BUG FIX: interpolate the loop variable into the remote path.
        # The previous literal never used `filename`, so every iteration
        # requested the same (nonexistent) file instead of the four
        # distinct checkpoints.
        downloaded_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=f"checkpoints/{filename}",
        )
    # hf_hub_download returns the cached local path of the file it fetched;
    # all files of one repo revision land in the same snapshot directory,
    # so the parent of the last download is the model directory.
    return Path(downloaded_path).parent
|
| 81 |
+
|
| 82 |
def _load_models(self):
|
| 83 |
"""Loads all necessary models into memory."""
|
| 84 |
self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False)
|
| 85 |
+
self.detector.load_state_dict(
|
| 86 |
+
torch.load(self._model_dir / "detector.pth", weights_only=True), strict=True
|
| 87 |
+
)
|
| 88 |
|
| 89 |
self.recognizer = TransformerRecognizer(nic=self.detector.num_features[-1], num_tokens=858, max_width=32)
|
| 90 |
+
self.recognizer.load_state_dict(
|
| 91 |
+
torch.load(self._model_dir / "recognizer.pth", weights_only=True), strict=True
|
| 92 |
+
)
|
| 93 |
|
| 94 |
self.relational = GlobalRelationalModel(
|
| 95 |
num_input_channels=self.detector.num_features,
|
|
|
|
| 98 |
k=16,
|
| 99 |
num_layers=4,
|
| 100 |
)
|
| 101 |
+
self.relational.load_state_dict(
|
| 102 |
+
torch.load(self._model_dir / "relational.pth", weights_only=True), strict=True
|
| 103 |
+
)
|
| 104 |
|
| 105 |
for model in (self.detector, self.recognizer, self.relational):
|
| 106 |
model = model.cuda()
|