Ryan Chesler commited on
Commit
3f0560d
·
1 Parent(s): d75d375

Simplify weight download to use hf_hub_download consistently- Remove get_weights_path helper, inline hf_hub_download in define_model- Fix WEIGHTS_FILENAME to use subdirectory path- Fix copy-paste bug in define_model default config_name- Remove get_weights_path from __init__.py exports

Browse files
nemotron_table_structure_v1/__init__.py CHANGED
@@ -9,7 +9,7 @@ A specialized object detection model for table structure extraction based on YOL
9
 
10
  __version__ = "1.0.0"
11
 
12
- from .model import define_model, YoloXWrapper, get_weights_path
13
  from .utils import (
14
  plot_sample,
15
  postprocess_preds_table_structure,
@@ -19,7 +19,6 @@ from .utils import (
19
 
20
  __all__ = [
21
  "define_model",
22
- "get_weights_path",
23
  "YoloXWrapper",
24
  "plot_sample",
25
  "postprocess_preds_table_structure",
 
9
 
10
  __version__ = "1.0.0"
11
 
12
+ from .model import define_model, YoloXWrapper
13
  from .utils import (
14
  plot_sample,
15
  postprocess_preds_table_structure,
 
19
 
20
  __all__ = [
21
  "define_model",
 
22
  "YoloXWrapper",
23
  "plot_sample",
24
  "postprocess_preds_table_structure",
nemotron_table_structure_v1/model.py CHANGED
@@ -13,56 +13,42 @@ from typing import Dict, List, Tuple, Union
13
  from huggingface_hub import hf_hub_download
14
  from .yolox.boxes import postprocess
15
 
16
- # HuggingFace repository for weights
17
  HF_REPO_ID = "nvidia/nemotron-table-structure-v1"
18
- WEIGHTS_FILENAME = "weights.pth"
19
 
20
 
21
- def get_weights_path(verbose: bool = True) -> str:
22
- """
23
- Get the path to the model weights, downloading from HuggingFace if necessary.
24
-
25
- The weights are cached in the HuggingFace cache directory after the first download.
26
-
27
- Args:
28
- verbose (bool): Whether to print download progress. Defaults to True.
29
-
30
- Returns:
31
- str: Path to the weights file.
32
- """
33
- if verbose:
34
- print(f" -> Downloading/loading weights from HuggingFace: {HF_REPO_ID}")
35
-
36
- weights_path = hf_hub_download(
37
- repo_id=HF_REPO_ID,
38
- filename=WEIGHTS_FILENAME,
39
- repo_type="model",
40
- )
41
-
42
- return weights_path
43
-
44
-
45
- def define_model(config_name: str = "page_element_v3", verbose: bool = True) -> nn.Module:
46
  """
47
  Defines and initializes the model based on the configuration.
48
 
49
  Args:
50
- config_name (str): Configuration name. Defaults to "page_element_v3".
51
  verbose (bool): Whether to print verbose output. Defaults to True.
52
 
53
  Returns:
54
  torch.nn.Module: The initialized YOLOX model.
55
  """
56
  # Load model from exp_file
57
- # page_element_v3.py is in the same directory as model.py
58
  sys.path.append(os.path.dirname(__file__))
59
  exp_module = importlib.import_module("table_structure_v1")
60
 
61
  config = exp_module.Exp()
62
  model = config.get_model()
63
 
64
- # Load weights (downloaded from HuggingFace if not cached)
65
- weights_path = get_weights_path(verbose=verbose)
 
 
 
 
 
 
 
 
 
 
66
  state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
67
  model.load_state_dict(state_dict["model"], strict=True)
68
 
 
13
  from huggingface_hub import hf_hub_download
14
  from .yolox.boxes import postprocess
15
 
16
+ # HuggingFace repository for downloading model weights
17
  HF_REPO_ID = "nvidia/nemotron-table-structure-v1"
18
+ WEIGHTS_FILENAME = "nemotron_table_structure_v1/weights.pth"
19
 
20
 
21
+ def define_model(config_name: str = "table_structure_v1", verbose: bool = True) -> nn.Module:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  """
23
  Defines and initializes the model based on the configuration.
24
 
25
  Args:
26
+ config_name (str): Configuration name. Defaults to "table_structure_v1".
27
  verbose (bool): Whether to print verbose output. Defaults to True.
28
 
29
  Returns:
30
  torch.nn.Module: The initialized YOLOX model.
31
  """
32
  # Load model from exp_file
33
+ # table_structure_v1.py is in the same directory as model.py
34
  sys.path.append(os.path.dirname(__file__))
35
  exp_module = importlib.import_module("table_structure_v1")
36
 
37
  config = exp_module.Exp()
38
  model = config.get_model()
39
 
40
+ # Download weights from HuggingFace Hub (cached locally after first download)
41
+ if verbose:
42
+ print(f" -> Downloading/loading weights from HuggingFace: {HF_REPO_ID}")
43
+
44
+ weights_path = hf_hub_download(
45
+ repo_id=HF_REPO_ID,
46
+ filename=WEIGHTS_FILENAME,
47
+ )
48
+
49
+ if verbose:
50
+ print(f" -> Weights cached at: {weights_path}")
51
+
52
  state_dict = torch.load(weights_path, map_location="cpu", weights_only=False)
53
  model.load_state_dict(state_dict["model"], strict=True)
54