remove_weights_from_python_wheel

#6
by jdye64 - opened
example.py CHANGED
@@ -8,7 +8,7 @@ from nemotron_ocr.inference.pipeline import NemotronOCR
8
 
9
 
10
  def main(image_path, merge_level, no_visualize, model_dir):
11
- ocr_pipeline = NemotronOCR()
12
 
13
  predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
14
 
 
8
 
9
 
10
  def main(image_path, merge_level, no_visualize, model_dir):
11
+ ocr_pipeline = NemotronOCR(model_dir=model_dir)
12
 
13
  predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
14
 
nemotron-ocr/pyproject.toml CHANGED
@@ -5,6 +5,7 @@ description = "Nemoton OCR"
5
  authors = [{ name = "NVIDIA Nemotron" }]
6
  requires-python = ">=3.12,<3.13"
7
  dependencies = [
 
8
  "pandas>=2.3.3",
9
  "pillow>=12.0.0",
10
  "scikit-learn>=1.7.2",
 
5
  authors = [{ name = "NVIDIA Nemotron" }]
6
  requires-python = ">=3.12,<3.13"
7
  dependencies = [
8
+ "huggingface_hub>=0.20.0",
9
  "pandas>=2.3.3",
10
  "pillow>=12.0.0",
11
  "scikit-learn>=1.7.2",
nemotron-ocr/src/nemotron_ocr/inference/pipeline.py CHANGED
@@ -6,6 +6,7 @@ import io
6
  import json
7
  import os
8
  from pathlib import Path
 
9
 
10
  import numpy as np
11
  import torch
@@ -20,6 +21,7 @@ from nemotron_ocr.inference.post_processing.data.text_region import TextBlock
20
  from nemotron_ocr.inference.post_processing.quad_rectify import QuadRectify
21
  from nemotron_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes
22
  from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square
 
23
  from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads
24
  from PIL import Image, ImageDraw, ImageFont
25
  from torch import amp
@@ -37,25 +39,57 @@ MERGE_LEVELS = {"word", "sentence", "paragraph"}
37
  DEFAULT_MERGE_LEVEL = "paragraph"
38
 
39
 
 
 
 
 
 
40
  class NemotronOCR:
41
  """
42
  A high-level pipeline for performing OCR on images.
 
 
 
43
  """
44
 
45
- def __init__(self, model_dir="./checkpoints"):
46
- self._model_dir = Path(model_dir)
 
 
 
 
 
 
 
 
47
 
48
  self._load_models()
49
  self._load_charset()
50
  self._initialize_processors()
51
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def _load_models(self):
53
  """Loads all necessary models into memory."""
54
  self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False)
55
- self.detector.load_state_dict(torch.load(self._model_dir / "detector.pth"), strict=True)
 
 
56
 
57
  self.recognizer = TransformerRecognizer(nic=self.detector.num_features[-1], num_tokens=858, max_width=32)
58
- self.recognizer.load_state_dict(torch.load(self._model_dir / "recognizer.pth"), strict=True)
 
 
59
 
60
  self.relational = GlobalRelationalModel(
61
  num_input_channels=self.detector.num_features,
@@ -64,7 +98,9 @@ class NemotronOCR:
64
  k=16,
65
  num_layers=4,
66
  )
67
- self.relational.load_state_dict(torch.load(self._model_dir / "relational.pth"), strict=True)
 
 
68
 
69
  for model in (self.detector, self.recognizer, self.relational):
70
  model = model.cuda()
 
6
  import json
7
  import os
8
  from pathlib import Path
9
+ from typing import Optional
10
 
11
  import numpy as np
12
  import torch
 
21
  from nemotron_ocr.inference.post_processing.quad_rectify import QuadRectify
22
  from nemotron_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes
23
  from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square
24
+ from huggingface_hub import hf_hub_download
25
  from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads
26
  from PIL import Image, ImageDraw, ImageFont
27
  from torch import amp
 
39
  DEFAULT_MERGE_LEVEL = "paragraph"
40
 
41
 
42
+ # Hugging Face Hub repository that hosts the checkpoint files (model weights and charset) listed below
43
+ HF_REPO_ID = "nvidia/nemotron-ocr-v1"
44
+ CHECKPOINT_FILES = ["detector.pth", "recognizer.pth", "relational.pth", "charset.txt"]
45
+
46
+
47
  class NemotronOCR:
48
  """
49
  A high-level pipeline for performing OCR on images.
50
+
51
+ Model weights and the charset file are downloaded from the Hugging Face Hub
52
+ (nvidia/nemotron-ocr-v1) when model_dir is omitted or lacks any required file.
53
  """
54
 
55
+ def __init__(self, model_dir: Optional[str] = None):
56
+ # If model_dir is provided and contains all required files, use it directly
57
+ if model_dir is not None:
58
+ local_path = Path(model_dir)
59
+ if all((local_path / f).is_file() for f in CHECKPOINT_FILES):
60
+ self._model_dir = local_path
61
+ else:
62
+ self._model_dir = self._download_checkpoints()
63
+ else:
64
+ self._model_dir = self._download_checkpoints()
65
 
66
  self._load_models()
67
  self._load_charset()
68
  self._initialize_processors()
69
 
70
+ @staticmethod
71
+ def _download_checkpoints() -> Path:
72
+ """Download the checkpoint files from the Hugging Face Hub (cached locally after first download) and return their directory."""
73
+ downloaded_path = None
74
+ for filename in CHECKPOINT_FILES:
75
+ downloaded_path = hf_hub_download(
76
+ repo_id=HF_REPO_ID,
77
+ filename=f"checkpoints/{filename}",
78
+ )
79
+ # All checkpoint files are in the same directory, so the parent of the last downloaded file is the model directory
80
+ return Path(downloaded_path).parent
81
+
82
  def _load_models(self):
83
  """Loads all necessary models into memory."""
84
  self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False)
85
+ self.detector.load_state_dict(
86
+ torch.load(self._model_dir / "detector.pth", weights_only=True), strict=True
87
+ )
88
 
89
  self.recognizer = TransformerRecognizer(nic=self.detector.num_features[-1], num_tokens=858, max_width=32)
90
+ self.recognizer.load_state_dict(
91
+ torch.load(self._model_dir / "recognizer.pth", weights_only=True), strict=True
92
+ )
93
 
94
  self.relational = GlobalRelationalModel(
95
  num_input_channels=self.detector.num_features,
 
98
  k=16,
99
  num_layers=4,
100
  )
101
+ self.relational.load_state_dict(
102
+ torch.load(self._model_dir / "relational.pth", weights_only=True), strict=True
103
+ )
104
 
105
  for model in (self.detector, self.recognizer, self.relational):
106
  model = model.cuda()