Ryan Chesler committed on
Commit
3575cd8
·
1 Parent(s): 11e8b7f

Simplify weight download to use hf_hub_download consistently

- Remove weight_downloader.py module
- Inline hf_hub_download calls in pipeline.py
- Remove hf_token and force_download params from NemotronOCR
- Simplify example.py

Browse files
example.py CHANGED
@@ -8,8 +8,7 @@ from nemotron_ocr.inference.pipeline import NemotronOCR
8
 
9
 
10
  def main(image_path, merge_level, no_visualize, model_dir):
11
- # model_dir can be None to use HuggingFace cache, or a path to local checkpoints
12
- ocr_pipeline = NemotronOCR(model_dir=model_dir if model_dir else None)
13
 
14
  predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
15
 
 
8
 
9
 
10
  def main(image_path, merge_level, no_visualize, model_dir):
11
+ ocr_pipeline = NemotronOCR(model_dir=model_dir)
 
12
 
13
  predictions = ocr_pipeline(image_path, merge_level=merge_level, visualize=not no_visualize)
14
 
nemotron-ocr/src/nemotron_ocr/inference/pipeline.py CHANGED
@@ -21,7 +21,7 @@ from nemotron_ocr.inference.post_processing.data.text_region import TextBlock
21
  from nemotron_ocr.inference.post_processing.quad_rectify import QuadRectify
22
  from nemotron_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes
23
  from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square
24
- from nemotron_ocr.inference.weight_downloader import ensure_weights_available
25
  from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads
26
  from PIL import Image, ImageDraw, ImageFont
27
  from torch import amp
@@ -39,54 +39,46 @@ MERGE_LEVELS = {"word", "sentence", "paragraph"}
39
  DEFAULT_MERGE_LEVEL = "paragraph"
40
 
41
 
 
 
 
 
 
42
  class NemotronOCR:
43
  """
44
  A high-level pipeline for performing OCR on images.
45
-
46
  Model weights are automatically downloaded from Hugging Face Hub
47
  (nvidia/nemotron-ocr-v1) if not found locally.
48
-
49
- Args:
50
- model_dir: Path to directory containing model checkpoints.
51
- If None, weights are downloaded to HuggingFace cache.
52
- If provided path exists and contains weights, uses them directly.
53
- If provided path doesn't have weights, downloads to HF cache.
54
- hf_token: Hugging Face authentication token (optional).
55
- force_download: If True, re-download weights even if they exist.
56
  """
57
 
58
- def __init__(
59
- self,
60
- model_dir: Optional[str] = None,
61
- hf_token: Optional[str] = None,
62
- force_download: bool = False,
63
- ):
64
- # Resolve model directory - download from HuggingFace if needed
65
  if model_dir is not None:
66
  local_path = Path(model_dir)
67
- # Check if the provided path has all required files
68
- required_files = ["detector.pth", "recognizer.pth", "relational.pth", "charset.txt"]
69
- if all((local_path / f).is_file() for f in required_files) and not force_download:
70
  self._model_dir = local_path
71
  else:
72
- # Download from HuggingFace
73
- self._model_dir = ensure_weights_available(
74
- model_dir=local_path,
75
- force_download=force_download,
76
- token=hf_token,
77
- )
78
  else:
79
- # No model_dir specified - download to HuggingFace cache
80
- self._model_dir = ensure_weights_available(
81
- model_dir=None,
82
- force_download=force_download,
83
- token=hf_token,
84
- )
85
 
86
  self._load_models()
87
  self._load_charset()
88
  self._initialize_processors()
89
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def _load_models(self):
91
  """Loads all necessary models into memory."""
92
  self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False)
 
21
  from nemotron_ocr.inference.post_processing.quad_rectify import QuadRectify
22
  from nemotron_ocr.inference.post_processing.research_ops import parse_relational_results, reorder_boxes
23
  from nemotron_ocr.inference.pre_processing import interpolate_and_pad, pad_to_square
24
+ from huggingface_hub import hf_hub_download
25
  from nemotron_ocr_cpp import quad_non_maximal_suppression, region_counts_to_indices, rrect_to_quads
26
  from PIL import Image, ImageDraw, ImageFont
27
  from torch import amp
 
39
  DEFAULT_MERGE_LEVEL = "paragraph"
40
 
41
 
42
+ # HuggingFace repository for downloading model weights
43
+ HF_REPO_ID = "nvidia/nemotron-ocr-v1"
44
+ CHECKPOINT_FILES = ["detector.pth", "recognizer.pth", "relational.pth", "charset.txt"]
45
+
46
+
47
  class NemotronOCR:
48
  """
49
  A high-level pipeline for performing OCR on images.
50
+
51
  Model weights are automatically downloaded from Hugging Face Hub
52
  (nvidia/nemotron-ocr-v1) if not found locally.
 
 
 
 
 
 
 
 
53
  """
54
 
55
def __init__(self, model_dir: Optional[str] = None):
    """Set up the OCR pipeline, resolving where model weights live.

    A user-supplied ``model_dir`` is used directly when it already
    contains every required checkpoint file; otherwise the weights are
    fetched from the HuggingFace Hub cache via ``_download_checkpoints``.

    Args:
        model_dir: Optional path to a directory holding the checkpoint
            files listed in ``CHECKPOINT_FILES``.
    """
    local_dir = Path(model_dir) if model_dir is not None else None
    has_all_checkpoints = local_dir is not None and all(
        (local_dir / name).is_file() for name in CHECKPOINT_FILES
    )
    # Fall back to the Hub download whenever the local dir is absent
    # or incomplete.
    self._model_dir = local_dir if has_all_checkpoints else self._download_checkpoints()

    self._load_models()
    self._load_charset()
    self._initialize_processors()
69
 
70
@staticmethod
def _download_checkpoints() -> Path:
    """Download model checkpoints from HuggingFace Hub.

    ``hf_hub_download`` caches files locally, so repeat calls after the
    first download are cheap no-ops.

    Returns:
        Path to the local directory containing all checkpoint files.
    """
    downloaded_path = None
    for filename in CHECKPOINT_FILES:
        # The repo stores weights under a top-level "checkpoints/" folder;
        # interpolate the loop variable (the rendered diff had garbled
        # this f-string placeholder).
        downloaded_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=f"checkpoints/{filename}",
        )
    # All checkpoint files land in the same cache directory, so the
    # parent of the last downloaded file is the checkpoint dir.
    return Path(downloaded_path).parent
81
+
82
  def _load_models(self):
83
  """Loads all necessary models into memory."""
84
  self.detector = FOTSDetector(coordinate_mode="RBOX", backbone="regnet_y_8gf", verbose=False)
nemotron-ocr/src/nemotron_ocr/inference/weight_downloader.py DELETED
@@ -1,168 +0,0 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """
5
- Utility for downloading model weights from Hugging Face Hub.
6
-
7
- This module provides functionality to automatically download the Nemotron OCR
8
- model weights from the Hugging Face repository if they are not present locally.
9
- """
10
-
11
- from pathlib import Path
12
- from typing import Optional
13
-
14
- from huggingface_hub import hf_hub_download, snapshot_download
15
-
16
- # Hugging Face repository for Nemotron OCR weights
17
- HF_REPO_ID = "nvidia/nemotron-ocr-v1"
18
-
19
- # List of required checkpoint files
20
- CHECKPOINT_FILES = [
21
- "checkpoints/detector.pth",
22
- "checkpoints/recognizer.pth",
23
- "checkpoints/relational.pth",
24
- "checkpoints/charset.txt",
25
- ]
26
-
27
-
28
- def get_default_cache_dir() -> Path:
29
- """
30
- Get the default cache directory for storing downloaded weights.
31
-
32
- Uses the standard HuggingFace cache location.
33
-
34
- Returns:
35
- Path to the cache directory.
36
- """
37
- from huggingface_hub import constants
38
- return Path(constants.HF_HUB_CACHE)
39
-
40
-
41
- def ensure_weights_available(
42
- model_dir: Optional[Path] = None,
43
- repo_id: str = HF_REPO_ID,
44
- force_download: bool = False,
45
- token: Optional[str] = None,
46
- ) -> Path:
47
- """
48
- Ensure model weights are available, downloading them if necessary.
49
-
50
- This function checks if the required checkpoint files exist in the specified
51
- model directory. If any files are missing, it downloads them from the
52
- Hugging Face Hub.
53
-
54
- Args:
55
- model_dir: Path to the directory containing model weights.
56
- If None, uses the HuggingFace cache directory.
57
- repo_id: Hugging Face repository ID.
58
- force_download: If True, re-download even if files exist.
59
- token: Hugging Face authentication token (optional, for private repos).
60
-
61
- Returns:
62
- Path to the directory containing the model checkpoints.
63
-
64
- Raises:
65
- RuntimeError: If download fails.
66
- """
67
- # If model_dir is provided and all files exist, use it directly
68
- if model_dir is not None and not force_download:
69
- model_path = Path(model_dir)
70
- if _all_checkpoints_present(model_path):
71
- return model_path
72
-
73
- # Download to HuggingFace cache if no local path provided or files missing
74
- try:
75
- # Download only the checkpoints folder from the repo
76
- cache_dir = snapshot_download(
77
- repo_id=repo_id,
78
- allow_patterns=["checkpoints/*"],
79
- force_download=force_download,
80
- token=token,
81
- )
82
- checkpoint_dir = Path(cache_dir) / "checkpoints"
83
-
84
- if not _all_checkpoints_present_flat(checkpoint_dir):
85
- raise RuntimeError(
86
- f"Downloaded weights are incomplete. Expected files in {checkpoint_dir}"
87
- )
88
-
89
- return checkpoint_dir
90
-
91
- except Exception as e:
92
- raise RuntimeError(
93
- f"Failed to download model weights from {repo_id}. "
94
- f"Please ensure you have internet access and the repository exists. "
95
- f"Error: {e}"
96
- ) from e
97
-
98
-
99
- def _all_checkpoints_present(base_path: Path) -> bool:
100
- """Check if all required checkpoint files are present in the given directory."""
101
- required_files = ["detector.pth", "recognizer.pth", "relational.pth", "charset.txt"]
102
- return all((base_path / f).is_file() for f in required_files)
103
-
104
-
105
- def _all_checkpoints_present_flat(checkpoint_dir: Path) -> bool:
106
- """Check if all required checkpoint files are present in a flat directory."""
107
- required_files = ["detector.pth", "recognizer.pth", "relational.pth", "charset.txt"]
108
- return all((checkpoint_dir / f).is_file() for f in required_files)
109
-
110
-
111
- def download_weights(
112
- output_dir: Optional[Path] = None,
113
- repo_id: str = HF_REPO_ID,
114
- force_download: bool = False,
115
- token: Optional[str] = None,
116
- ) -> Path:
117
- """
118
- Explicitly download model weights to a specified directory.
119
-
120
- This is a convenience function for users who want to pre-download
121
- weights to a specific location.
122
-
123
- Args:
124
- output_dir: Directory to save the weights. If None, uses HuggingFace cache.
125
- repo_id: Hugging Face repository ID.
126
- force_download: If True, re-download even if files exist.
127
- token: Hugging Face authentication token (optional).
128
-
129
- Returns:
130
- Path to the directory containing the downloaded checkpoints.
131
-
132
- Example:
133
- >>> from nemotron_ocr.inference.weight_downloader import download_weights
134
- >>> checkpoint_dir = download_weights(output_dir=Path("./my_checkpoints"))
135
- >>> # Use checkpoint_dir with NemotronOCR
136
- >>> from nemotron_ocr.inference.pipeline import NemotronOCR
137
- >>> ocr = NemotronOCR(model_dir=checkpoint_dir)
138
- """
139
- if output_dir is not None:
140
- output_path = Path(output_dir)
141
- output_path.mkdir(parents=True, exist_ok=True)
142
-
143
- # Download individual files to the output directory
144
- required_files = ["detector.pth", "recognizer.pth", "relational.pth", "charset.txt"]
145
- for filename in required_files:
146
- hf_hub_download(
147
- repo_id=repo_id,
148
- filename=f"checkpoints/{filename}",
149
- local_dir=output_path.parent,
150
- force_download=force_download,
151
- token=token,
152
- )
153
-
154
- # The files are downloaded to output_path.parent/checkpoints/
155
- checkpoint_dir = output_path.parent / "checkpoints"
156
- if output_path != checkpoint_dir:
157
- # If user specified a different path, we downloaded to parent/checkpoints
158
- # Return the actual location
159
- return checkpoint_dir
160
- return output_path
161
- else:
162
- return ensure_weights_available(
163
- model_dir=None,
164
- repo_id=repo_id,
165
- force_download=force_download,
166
- token=token,
167
- )
168
-