Dharun235 commited on
Commit
8b2f8e8
·
1 Parent(s): b5f4f2e

Load model artifacts from archive at startup

Browse files
Files changed (2) hide show
  1. predictor.py +26 -8
  2. requirements.txt +1 -2
predictor.py CHANGED
@@ -1,5 +1,8 @@
1
  import json
2
  import os
 
 
 
3
  from functools import lru_cache
4
  from pathlib import Path
5
  from typing import Any, Dict, Mapping, Union
@@ -8,7 +11,6 @@ import joblib
8
  import numpy as np
9
  import torch
10
  import torch.nn as nn
11
- from huggingface_hub import snapshot_download
12
 
13
 
14
  MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "Dharunkumar9/battery-capacity-predictor")
@@ -24,6 +26,7 @@ class PositionalEncoding(nn.Module):
24
  pe[:, 0::2] = torch.sin(position * div_term)
25
  pe[:, 1::2] = torch.cos(position * div_term)
26
  self.register_buffer("pe", pe.unsqueeze(0))
 
27
 
28
  def forward(self, x: torch.Tensor) -> torch.Tensor:
29
  return x + self.pe[:, : x.size(1), :]
@@ -99,13 +102,28 @@ def _normalize_window(window: Any, expected_rows: int, expected_cols: int) -> np
99
 
100
 
101
  def _download_artifacts() -> Path:
102
- return Path(
103
- snapshot_download(
104
- repo_id=MODEL_REPO_ID,
105
- repo_type="model",
106
- allow_patterns=["config.json", "pytorch_model.bin", "x_scalers.pkl", "y_scalers.pkl"],
107
- )
108
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
 
111
  class BatteryPredictor:
 
1
  import json
2
  import os
3
+ import urllib.request
4
+ import tempfile
5
+ import zipfile
6
  from functools import lru_cache
7
  from pathlib import Path
8
  from typing import Any, Dict, Mapping, Union
 
11
  import numpy as np
12
  import torch
13
  import torch.nn as nn
 
14
 
15
 
16
  MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "Dharunkumar9/battery-capacity-predictor")
 
26
  pe[:, 0::2] = torch.sin(position * div_term)
27
  pe[:, 1::2] = torch.cos(position * div_term)
28
  self.register_buffer("pe", pe.unsqueeze(0))
29
+ self.pe: torch.Tensor
30
 
31
  def forward(self, x: torch.Tensor) -> torch.Tensor:
32
  return x + self.pe[:, : x.size(1), :]
 
102
 
103
 
104
  def _download_artifacts() -> Path:
105
+ archive_name_candidates = ["artifacts_v1.zip", "artifacts-v1.zip"]
106
+ archive_path = None
107
+
108
+ for archive_name in archive_name_candidates:
109
+ try:
110
+ archive_url = f"https://huggingface.co/{MODEL_REPO_ID}/resolve/main/{archive_name}?download=true"
111
+ archive_file = Path(tempfile.mkdtemp(prefix="battery-archive-")) / archive_name
112
+ with urllib.request.urlopen(archive_url) as response, archive_file.open("wb") as output:
113
+ output.write(response.read())
114
+ archive_path = str(archive_file)
115
+ break
116
+ except Exception:
117
+ continue
118
+
119
+ if archive_path is None:
120
+ raise FileNotFoundError("Could not download the model artifact zip from the Hugging Face model repo")
121
+
122
+ extract_dir = Path(tempfile.mkdtemp(prefix="battery-model-"))
123
+ with zipfile.ZipFile(archive_path) as archive:
124
+ archive.extractall(extract_dir)
125
+
126
+ return extract_dir
127
 
128
 
129
  class BatteryPredictor:
requirements.txt CHANGED
@@ -3,5 +3,4 @@ uvicorn[standard]>=0.29.0
3
  torch>=2.2.0
4
  numpy>=1.24.0
5
  joblib>=1.3.0
6
- scikit-learn==1.7.0
7
- huggingface_hub>=0.23.0
 
3
  torch>=2.2.0
4
  numpy>=1.24.0
5
  joblib>=1.3.0
6
+ scikit-learn==1.7.0