remove_weights_from_python_wheel

#9
by jdye64 - opened
MANIFEST.in CHANGED
@@ -1,3 +1,4 @@
1
  include README.md
2
  include THIRD_PARTY_NOTICES.md
3
- recursive-include nemotron_graphic_elements_v1
 
 
1
  include README.md
2
  include THIRD_PARTY_NOTICES.md
3
+ recursive-include nemotron_graphic_elements_v1 *.py *.json *.png
4
+ recursive-exclude nemotron_graphic_elements_v1 *.pth
nemotron_graphic_elements_v1/__init__.py CHANGED
@@ -6,6 +6,8 @@ Nemotron Graphic Elements v1
6
 
7
  A specialized object detection system designed to identify and extract key elements
8
  from charts and graphs. Based on YOLOX architecture.
 
 
9
  """
10
 
11
  __version__ = "1.0.0"
 
6
 
7
  A specialized object detection system designed to identify and extract key elements
8
  from charts and graphs. Based on YOLOX architecture.
9
+
10
+ Model weights are automatically downloaded from Hugging Face Hub on first use.
11
  """
12
 
13
  __version__ = "1.0.0"
nemotron_graphic_elements_v1/model.py CHANGED
@@ -10,8 +10,13 @@ import numpy.typing as npt
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  from typing import Dict, List, Tuple, Union
 
13
  from .yolox.boxes import postprocess
14
 
 
 
 
 
15
 
16
  def define_model(config_name: str = "graphic_element_v1", verbose: bool = True) -> nn.Module:
17
  """
@@ -30,11 +35,19 @@ def define_model(config_name: str = "graphic_element_v1", verbose: bool = True)
30
  config = Exp()
31
  model = config.get_model()
32
 
33
- # Load weights
 
 
 
 
 
 
 
 
34
  if verbose:
35
- print(" -> Loading weights from", config.ckpt)
36
 
37
- ckpt = torch.load(config.ckpt, map_location="cpu", weights_only=False)
38
  model.load_state_dict(ckpt["model"], strict=True)
39
 
40
  model = YoloXWrapper(model, config)
 
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  from typing import Dict, List, Tuple, Union
13
+ from huggingface_hub import hf_hub_download
14
  from .yolox.boxes import postprocess
15
 
16
+ # HuggingFace repository for downloading model weights
17
+ HF_REPO_ID = "nvidia/nemotron-graphic-elements-v1"
18
+ WEIGHTS_FILENAME = "nemotron_graphic_elements_v1/weights.pth"
19
+
20
 
21
  def define_model(config_name: str = "graphic_element_v1", verbose: bool = True) -> nn.Module:
22
  """
 
35
  config = Exp()
36
  model = config.get_model()
37
 
38
+ # Download weights from HuggingFace Hub (cached locally after first download)
39
+ if verbose:
40
+ print(f" -> Downloading/loading weights from HuggingFace: {HF_REPO_ID}")
41
+
42
+ weights_path = hf_hub_download(
43
+ repo_id=HF_REPO_ID,
44
+ filename=WEIGHTS_FILENAME,
45
+ )
46
+
47
  if verbose:
48
+ print(f" -> Weights cached at: {weights_path}")
49
 
50
+ ckpt = torch.load(weights_path, map_location="cpu", weights_only=False)
51
  model.load_state_dict(ckpt["model"], strict=True)
52
 
53
  model = YoloXWrapper(model, config)
pyproject.toml CHANGED
@@ -32,6 +32,7 @@ dependencies = [
32
  "matplotlib>=3.5.0",
33
  "pandas>=1.3.0",
34
  "Pillow>=9.0.0",
 
35
  ]
36
 
37
  [project.optional-dependencies]
@@ -50,5 +51,5 @@ Repository = "https://huggingface.co/nvidia/nemotron-graphic-elements-v1"
50
  packages = ["nemotron_graphic_elements_v1", "nemotron_graphic_elements_v1.yolox", "nemotron_graphic_elements_v1.post_processing"]
51
 
52
  [tool.setuptools.package-data]
53
- "nemotron_graphic_elements_v1" = ["*.pth", "*.json", "*.png"]
54
 
 
32
  "matplotlib>=3.5.0",
33
  "pandas>=1.3.0",
34
  "Pillow>=9.0.0",
35
+ "huggingface_hub>=0.20.0",
36
  ]
37
 
38
  [project.optional-dependencies]
 
51
  packages = ["nemotron_graphic_elements_v1", "nemotron_graphic_elements_v1.yolox", "nemotron_graphic_elements_v1.post_processing"]
52
 
53
  [tool.setuptools.package-data]
54
+ "nemotron_graphic_elements_v1" = ["*.json", "*.png"]
55