Spaces:
Sleeping
Sleeping
Load model artifacts from archive at startup
Browse files- predictor.py +26 -8
- requirements.txt +1 -2
predictor.py
CHANGED
|
@@ -1,5 +1,8 @@
|
|
| 1 |
import json
|
| 2 |
import os
|
|
|
|
|
|
|
|
|
|
| 3 |
from functools import lru_cache
|
| 4 |
from pathlib import Path
|
| 5 |
from typing import Any, Dict, Mapping, Union
|
|
@@ -8,7 +11,6 @@ import joblib
|
|
| 8 |
import numpy as np
|
| 9 |
import torch
|
| 10 |
import torch.nn as nn
|
| 11 |
-
from huggingface_hub import snapshot_download
|
| 12 |
|
| 13 |
|
| 14 |
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "Dharunkumar9/battery-capacity-predictor")
|
|
@@ -24,6 +26,7 @@ class PositionalEncoding(nn.Module):
|
|
| 24 |
pe[:, 0::2] = torch.sin(position * div_term)
|
| 25 |
pe[:, 1::2] = torch.cos(position * div_term)
|
| 26 |
self.register_buffer("pe", pe.unsqueeze(0))
|
|
|
|
| 27 |
|
| 28 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 29 |
return x + self.pe[:, : x.size(1), :]
|
|
@@ -99,13 +102,28 @@ def _normalize_window(window: Any, expected_rows: int, expected_cols: int) -> np
|
|
| 99 |
|
| 100 |
|
| 101 |
def _download_artifacts() -> Path:
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
|
| 111 |
class BatteryPredictor:
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
+
import urllib.request
|
| 4 |
+
import tempfile
|
| 5 |
+
import zipfile
|
| 6 |
from functools import lru_cache
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Any, Dict, Mapping, Union
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import torch
|
| 13 |
import torch.nn as nn
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "Dharunkumar9/battery-capacity-predictor")
|
|
|
|
| 26 |
pe[:, 0::2] = torch.sin(position * div_term)
|
| 27 |
pe[:, 1::2] = torch.cos(position * div_term)
|
| 28 |
self.register_buffer("pe", pe.unsqueeze(0))
|
| 29 |
+
self.pe: torch.Tensor
|
| 30 |
|
| 31 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 32 |
return x + self.pe[:, : x.size(1), :]
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
def _download_artifacts() -> Path:
|
| 105 |
+
archive_name_candidates = ["artifacts_v1.zip", "artifacts-v1.zip"]
|
| 106 |
+
archive_path = None
|
| 107 |
+
|
| 108 |
+
for archive_name in archive_name_candidates:
|
| 109 |
+
try:
|
| 110 |
+
archive_url = f"https://huggingface.co/{MODEL_REPO_ID}/resolve/main/{archive_name}?download=true"
|
| 111 |
+
archive_file = Path(tempfile.mkdtemp(prefix="battery-archive-")) / archive_name
|
| 112 |
+
with urllib.request.urlopen(archive_url) as response, archive_file.open("wb") as output:
|
| 113 |
+
output.write(response.read())
|
| 114 |
+
archive_path = str(archive_file)
|
| 115 |
+
break
|
| 116 |
+
except Exception:
|
| 117 |
+
continue
|
| 118 |
+
|
| 119 |
+
if archive_path is None:
|
| 120 |
+
raise FileNotFoundError("Could not download the model artifact zip from the Hugging Face model repo")
|
| 121 |
+
|
| 122 |
+
extract_dir = Path(tempfile.mkdtemp(prefix="battery-model-"))
|
| 123 |
+
with zipfile.ZipFile(archive_path) as archive:
|
| 124 |
+
archive.extractall(extract_dir)
|
| 125 |
+
|
| 126 |
+
return extract_dir
|
| 127 |
|
| 128 |
|
| 129 |
class BatteryPredictor:
|
requirements.txt
CHANGED
|
@@ -3,5 +3,4 @@ uvicorn[standard]>=0.29.0
|
|
| 3 |
torch>=2.2.0
|
| 4 |
numpy>=1.24.0
|
| 5 |
joblib>=1.3.0
|
| 6 |
-
scikit-learn==1.7.0
|
| 7 |
-
huggingface_hub>=0.23.0
|
|
|
|
| 3 |
torch>=2.2.0
|
| 4 |
numpy>=1.24.0
|
| 5 |
joblib>=1.3.0
|
| 6 |
+
scikit-learn==1.7.0
|
|
|