remove_weights_from_python_wheel

#6
by jdye64 - opened
MANIFEST.in CHANGED
@@ -1,6 +1,7 @@
1
  include README.md
2
  include THIRD_PARTY_NOTICES.md
3
- recursive-include nemotron_table_structure_v1
 
4
 
5
 
6
 
 
1
  include README.md
2
  include THIRD_PARTY_NOTICES.md
3
+ recursive-include nemotron_table_structure_v1 *.py *.json
4
+ recursive-exclude nemotron_table_structure_v1 *.pth
5
 
6
 
7
 
README.md CHANGED
@@ -148,8 +148,12 @@ import numpy as np
148
  import matplotlib.pyplot as plt
149
  from PIL import Image
150
 
151
- from model import define_model
152
- from utils import plot_sample, postprocess_preds_table_structure, reformat_for_plotting
 
 
 
 
153
 
154
  # Load image
155
  path = "./example.png"
 
148
  import matplotlib.pyplot as plt
149
  from PIL import Image
150
 
151
+ from nemotron_table_structure_v1 import (
152
+ define_model,
153
+ plot_sample,
154
+ postprocess_preds_table_structure,
155
+ reformat_for_plotting,
156
+ )
157
 
158
  # Load image
159
  path = "./example.png"
nemotron_table_structure_v1/__init__.py CHANGED
@@ -9,7 +9,7 @@ A specialized object detection model for table structure extraction based on YOL
9
 
10
  __version__ = "1.0.0"
11
 
12
- from .model import define_model, YoloXWrapper
13
  from .utils import (
14
  plot_sample,
15
  postprocess_preds_table_structure,
@@ -19,6 +19,7 @@ from .utils import (
19
 
20
  __all__ = [
21
  "define_model",
 
22
  "YoloXWrapper",
23
  "plot_sample",
24
  "postprocess_preds_table_structure",
 
9
 
10
  __version__ = "1.0.0"
11
 
12
+ from .model import define_model, YoloXWrapper, get_weights_path
13
  from .utils import (
14
  plot_sample,
15
  postprocess_preds_table_structure,
 
19
 
20
  __all__ = [
21
  "define_model",
22
+ "get_weights_path",
23
  "YoloXWrapper",
24
  "plot_sample",
25
  "postprocess_preds_table_structure",
nemotron_table_structure_v1/model.py CHANGED
@@ -10,8 +10,37 @@ import numpy.typing as npt
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  from typing import Dict, List, Tuple, Union
 
13
  from .yolox.boxes import postprocess
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  def define_model(config_name: str = "page_element_v3", verbose: bool = True) -> nn.Module:
17
  """
@@ -32,13 +61,8 @@ def define_model(config_name: str = "page_element_v3", verbose: bool = True) ->
32
  config = exp_module.Exp()
33
  model = config.get_model()
34
 
35
- # Load weights
36
- if verbose:
37
- print(" -> Loading weights from", config.ckpt)
38
-
39
- # Find package directory and load weights (nemotron_table_structure_v1)
40
- package_dir = os.path.dirname(os.path.abspath(__file__))
41
- weights_path = os.path.join(package_dir, "weights.pth")
42
  state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
43
  model.load_state_dict(state_dict["model"], strict=True)
44
 
 
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  from typing import Dict, List, Tuple, Union
13
+ from huggingface_hub import hf_hub_download
14
  from .yolox.boxes import postprocess
15
 
16
+ # HuggingFace repository for weights
17
+ HF_REPO_ID = "nvidia/nemotron-table-structure-v1"
18
+ WEIGHTS_FILENAME = "weights.pth"
19
+
20
+
21
+ def get_weights_path(verbose: bool = True) -> str:
22
+ """
23
+ Get the path to the model weights, downloading from HuggingFace if necessary.
24
+
25
+ The weights are cached in the HuggingFace cache directory after the first download.
26
+
27
+ Args:
28
+ verbose (bool): Whether to print download progress. Defaults to True.
29
+
30
+ Returns:
31
+ str: Path to the weights file.
32
+ """
33
+ if verbose:
34
+ print(f" -> Downloading/loading weights from HuggingFace: {HF_REPO_ID}")
35
+
36
+ weights_path = hf_hub_download(
37
+ repo_id=HF_REPO_ID,
38
+ filename=WEIGHTS_FILENAME,
39
+ repo_type="model",
40
+ )
41
+
42
+ return weights_path
43
+
44
 
45
  def define_model(config_name: str = "page_element_v3", verbose: bool = True) -> nn.Module:
46
  """
 
61
  config = exp_module.Exp()
62
  model = config.get_model()
63
 
64
+ # Load weights (downloaded from HuggingFace if not cached)
65
+ weights_path = get_weights_path(verbose=verbose)
 
 
 
 
 
66
  state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
67
  model.load_state_dict(state_dict["model"], strict=True)
68
 
pyproject.toml CHANGED
@@ -33,6 +33,7 @@ dependencies = [
33
  "matplotlib>=3.3.0",
34
  "pandas>=1.3.0",
35
  "Pillow>=8.0.0",
 
36
  ]
37
 
38
  [project.urls]
@@ -45,7 +46,7 @@ Documentation = "https://huggingface.co/nvidia/nemotron-table-structure-v1"
45
  packages = ["nemotron_table_structure_v1", "nemotron_table_structure_v1.yolox", "nemotron_table_structure_v1.post_processing"]
46
 
47
  [tool.setuptools.package-data]
48
- "*" = ["*.pth", "config.json"]
49
 
50
 
51
 
 
33
  "matplotlib>=3.3.0",
34
  "pandas>=1.3.0",
35
  "Pillow>=8.0.0",
36
+ "huggingface_hub>=0.20.0",
37
  ]
38
 
39
  [project.urls]
 
46
  packages = ["nemotron_table_structure_v1", "nemotron_table_structure_v1.yolox", "nemotron_table_structure_v1.post_processing"]
47
 
48
  [tool.setuptools.package-data]
49
+ "*" = ["config.json"]
50
 
51
 
52