Spaces:
Sleeping
Sleeping
File size: 6,016 Bytes
b5f4f2e 8b2f8e8 b5f4f2e 8b2f8e8 b5f4f2e 8b2f8e8 b5f4f2e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | import json
import os
import urllib.request
import tempfile
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Mapping, Union
import joblib
import numpy as np
import torch
import torch.nn as nn
# Hugging Face model repository that hosts the artifact archive; can be
# overridden through the MODEL_REPO_ID environment variable.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "Dharunkumar9/battery-capacity-predictor")
# Known battery identifiers; a battery's position in this list is the 0-based
# index callers may pass instead of the identifier string.
BATTERY_ORDER = "B0005 B0006 B0007 B0018".split()
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence tensor.

    The table is precomputed for ``max_len`` positions and stored in the
    ``pe`` buffer with shape ``(1, max_len, d_model)``. Even feature columns
    hold ``sin(pos * freq)`` and odd columns hold ``cos(pos * freq)``.

    Args:
        d_model: Feature size of the tensors this module is applied to. Odd
            sizes are supported: the last (even-indexed) column gets a sine
            term with no cosine partner.
        max_len: Number of positions covered by the precomputed table.
    """

    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        # One frequency per sin/cos column pair: 10000 ** (-2i / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than there are
        # frequencies, so the cosine half uses only the first d_model // 2.
        # For even d_model this slice is the whole tensor (unchanged values).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer: it follows .to(device) and is saved/loaded
        # with the state dict, but it is not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))
        self.pe: torch.Tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is ``(batch, seq_len, d_model)`` with
        ``seq_len <= max_len``; longer sequences fail to broadcast.
        """
        return x + self.pe[:, : x.size(1), :]
class BatteryTransformer(nn.Module):
    """Transformer-encoder regressor mapping a feature window to one scalar.

    Pipeline: linear projection of the input features to ``d_model``, then
    sinusoidal positional encoding, then a stack of batch-first
    ``nn.TransformerEncoderLayer``s, then a weighted mean over the sequence
    axis, then dropout, then a linear head with a single output.

    The pooling weight is 1 for every position except the trailing
    ``last_frac`` fraction of the sequence, which gets ``last_weight``. The
    weights are normalized to sum to 1.

    Args:
        num_features: Size of the last axis of the input.
        d_model: Model (embedding) width.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each encoder layer's feed-forward block.
        dropout: Dropout probability in the encoder layers and before the head.
        last_frac: Fraction of trailing positions that receive ``last_weight``.
        last_weight: Pooling weight for those trailing positions.
    """

    def __init__(
        self,
        num_features: int,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        last_frac: float = 0.4,
        last_weight: float = 3.0,
    ):
        super().__init__()
        self.input_proj = nn.Linear(num_features, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)
        self.regressor = nn.Linear(d_model, 1)
        self.last_frac = last_frac
        self.last_weight = last_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict one scalar per sequence in the batch.

        Args:
            x: Tensor of shape ``(batch, seq_len, num_features)``. The encoder
                is batch-first, and axis 1 is treated as the sequence axis.

        Returns:
            Tensor of shape ``(batch,)``: the head's output with its trailing
            size-1 axis squeezed away.
        """
        seq_len = x.size(1)
        x = self.input_proj(x)
        x = self.pos_encoder(x)
        x = self.encoder(x)
        # Build the pooling weights in the activations' dtype. A float32
        # weight vector would promote bf16/fp16 activations to float32, and
        # the low-precision regressor below would then reject the input.
        weights = torch.ones(seq_len, device=x.device, dtype=x.dtype)
        last_start = int(seq_len * (1 - self.last_frac))
        weights[last_start:] = self.last_weight
        weights = weights / weights.sum()
        # (seq_len, 1) broadcasts against (batch, seq_len, d_model); summing
        # over axis 1 yields the weighted mean, shape (batch, d_model).
        x = (x * weights.unsqueeze(1)).sum(dim=1)
        x = self.dropout(x)
        return self.regressor(x).squeeze(-1)
def _resolve_battery_id(battery_id: Union[str, int]) -> str:
    """Return the canonical battery key for *battery_id*.

    Args:
        battery_id: Either a key from ``BATTERY_ORDER`` or a 0-based integer
            index into ``BATTERY_ORDER``. For keys, surrounding whitespace and
            letter case are ignored. Any ``numbers.Integral`` counts as an
            index, including NumPy integer scalars; ``bool`` is rejected.

    Returns:
        The matching entry of ``BATTERY_ORDER``.

    Raises:
        ValueError: If the index is out of range, the key is unknown, or a
            ``bool`` was passed.
    """
    import numbers

    # bool is a subclass of int; without this check a JSON ``true`` would
    # silently select index 1.
    if isinstance(battery_id, bool):
        raise ValueError(f"battery_id must be one of {BATTERY_ORDER} or a 0-based index")
    # numbers.Integral also covers NumPy integer scalars (e.g. np.int64). These
    # are not ``int`` instances and would otherwise be stringified and rejected.
    if isinstance(battery_id, numbers.Integral):
        index = int(battery_id)
        if index < 0 or index >= len(BATTERY_ORDER):
            raise ValueError(f"battery_id index must be between 0 and {len(BATTERY_ORDER) - 1}")
        return BATTERY_ORDER[index]
    key = str(battery_id).strip().upper()
    if key not in BATTERY_ORDER:
        raise ValueError(f"battery_id must be one of {BATTERY_ORDER} or a 0-based index")
    return key
def _normalize_window(window: Any, expected_rows: int, expected_cols: int) -> np.ndarray:
    """Coerce *window* to a finite float32 array of shape ``(expected_rows, expected_cols)``.

    A flat 1-D input with exactly ``expected_rows * expected_cols`` values is
    reshaped row-major. Any other input must already have the exact 2-D shape.

    Args:
        window: Array-like of numbers (nested lists, NumPy array, ...).
        expected_rows: Required number of rows (time steps).
        expected_cols: Required number of columns (features).

    Returns:
        A float32 ``np.ndarray`` of shape ``(expected_rows, expected_cols)``.
        This may share memory with *window* if that was already a float32 array.

    Raises:
        ValueError: If the size or shape is wrong, or if any value is NaN or
            infinite. Values too large for float32 become inf during conversion
            and are therefore rejected too. Non-numeric input raises whatever
            NumPy's float conversion raises (typically ValueError).
    """
    array = np.asarray(window, dtype=np.float32)
    expected_size = expected_rows * expected_cols
    if array.ndim == 1:
        if array.size != expected_size:
            raise ValueError(f"window must contain {expected_size} values, got {array.size}")
        array = array.reshape(expected_rows, expected_cols)
    if array.shape != (expected_rows, expected_cols):
        raise ValueError(
            f"window must have shape ({expected_rows}, {expected_cols}), got {array.shape}"
        )
    # A NaN/inf input would otherwise flow through the scaler and model and
    # come back as a NaN prediction with no error.
    if not np.all(np.isfinite(array)):
        raise ValueError("window must contain only finite values (no NaN or inf)")
    return array
def _download_artifacts(timeout: float = 60.0) -> Path:
    """Download the model artifact zip from the Hugging Face repo and extract it.

    Each candidate archive name is tried in order, under ``MODEL_REPO_ID`` on
    the ``main`` revision. The first one that downloads successfully is
    extracted.

    Args:
        timeout: Socket timeout in seconds, passed to
            ``urllib.request.urlopen``. It applies to each blocking network
            operation and is not a limit on the total download time.

    Returns:
        Path to a newly created temporary directory holding the extracted
        archive contents. This function does not remove that directory.

    Raises:
        FileNotFoundError: If no candidate archive could be downloaded. The
            last download error is chained as ``__cause__``.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    import http.client
    import shutil

    archive_name_candidates = ["artifacts_v1.zip", "artifacts-v1.zip"]
    # A single scratch directory for the downloaded zip. It is removed in the
    # ``finally`` below whether or not the download/extraction succeeds.
    download_dir = Path(tempfile.mkdtemp(prefix="battery-archive-"))
    try:
        archive_path = None
        last_error = None
        for archive_name in archive_name_candidates:
            archive_url = f"https://huggingface.co/{MODEL_REPO_ID}/resolve/main/{archive_name}?download=true"
            candidate = download_dir / archive_name
            try:
                with urllib.request.urlopen(archive_url, timeout=timeout) as response, candidate.open("wb") as output:
                    # Stream to disk rather than holding the whole archive in memory.
                    shutil.copyfileobj(response, output)
            except (OSError, http.client.HTTPException, ValueError) as error:
                # URLError/HTTPError (e.g. a 404 for this name) and timeouts are
                # OSErrors; truncated transfers raise HTTPException; malformed
                # URLs raise ValueError. Remember the error and try the next name.
                last_error = error
                continue
            archive_path = candidate
            break
        if archive_path is None:
            raise FileNotFoundError(
                "Could not download the model artifact zip from the Hugging Face model repo"
            ) from last_error
        extract_dir = Path(tempfile.mkdtemp(prefix="battery-model-"))
        # ZipFile.extractall strips absolute paths and ".." components from
        # member names, so extraction stays inside extract_dir.
        with zipfile.ZipFile(archive_path) as archive:
            archive.extractall(extract_dir)
        return extract_dir
    finally:
        shutil.rmtree(download_dir, ignore_errors=True)
class BatteryPredictor:
    """Load the battery-capacity model artifacts once and serve predictions.

    Construction does the following:
    - downloads and extracts the artifact archive with ``_download_artifacts``;
    - reads ``config.json`` for the window shape and model hyper-parameters;
    - loads the per-battery input/target scalers with ``joblib``;
    - loads the ``BatteryTransformer`` weights on CPU and switches the model
      to eval mode.

    The extracted files are deleted once everything is loaded into memory.
    """

    def __init__(self) -> None:
        import shutil

        artifact_dir = _download_artifacts()
        try:
            config = json.loads((artifact_dir / "config.json").read_text(encoding="utf-8"))
            self.window_size = int(config["window_size"])
            self.num_features = int(config["num_features"])
            # SECURITY NOTE(review): joblib.load and torch.load (without
            # weights_only=True) unpickle the files, which can execute
            # arbitrary code. The archive is downloaded from MODEL_REPO_ID,
            # which can be overridden via the environment, so only point it at
            # trusted repos. Consider torch.load(..., weights_only=True) once
            # the minimum supported torch version allows it.
            self.x_scalers = joblib.load(artifact_dir / "x_scalers.pkl")
            self.y_scalers = joblib.load(artifact_dir / "y_scalers.pkl")
            self.model = BatteryTransformer(
                num_features=self.num_features,
                d_model=int(config["d_model"]),
                nhead=int(config["nhead"]),
                num_layers=int(config["num_layers"]),
                dim_feedforward=int(config["dim_feedforward"]),
                dropout=float(config["dropout"]),
                # Pooling settings fall back to BatteryTransformer's own
                # defaults when the config does not record them.
                last_frac=float(config.get("last_frac", 0.4)),
                last_weight=float(config.get("last_weight", 3.0)),
            ).to("cpu")
            state_dict = torch.load(artifact_dir / "pytorch_model.bin", map_location="cpu")
            self.model.load_state_dict(state_dict)
            self.model.eval()
        finally:
            # Everything above has been read into memory; the extracted
            # temporary copy would otherwise leak one directory per instance.
            shutil.rmtree(artifact_dir, ignore_errors=True)

    def predict(self, window: Any, battery_id: Union[str, int] = "B0005") -> Dict[str, Any]:
        """Predict capacity for one window of features from one battery.

        Args:
            window: Array-like of ``window_size * num_features`` numbers,
                either flat or already shaped ``(window_size, num_features)``.
            battery_id: Key from ``BATTERY_ORDER`` or its 0-based index. It
                selects which per-battery scalers are applied.

        Returns:
            Dict with these entries:
            - ``battery_id``: the resolved battery key;
            - ``window_size`` and ``num_features``: the expected window shape;
            - ``predicted_capacity``: the model output after the target
              scaler's ``inverse_transform``;
            - ``scaled_prediction``: the raw model output.

        Raises:
            ValueError: If *battery_id* or *window* fail validation.
            KeyError: If no scaler is stored for the resolved battery key.
        """
        battery_key = _resolve_battery_id(battery_id)
        window_array = _normalize_window(window, self.window_size, self.num_features)
        x_scaler = self.x_scalers[battery_key]
        y_scaler = self.y_scalers[battery_key]
        # Assumes the scaler returns a 2-D (window_size, num_features) array,
        # as sklearn-style transformers do. TODO confirm against training code.
        scaled_window = x_scaler.transform(window_array)
        # Add a leading batch axis of size 1: (1, window_size, num_features).
        tensor = torch.tensor(scaled_window[None, :, :], dtype=torch.float32)
        with torch.no_grad():
            scaled_prediction = float(self.model(tensor).item())
        predicted_capacity = float(y_scaler.inverse_transform([[scaled_prediction]])[0, 0])
        return {
            "battery_id": battery_key,
            "window_size": self.window_size,
            "num_features": self.num_features,
            "predicted_capacity": predicted_capacity,
            "scaled_prediction": scaled_prediction,
        }
@lru_cache(maxsize=1)
def get_predictor() -> BatteryPredictor:
    """Return the shared ``BatteryPredictor``, constructing it on the first call.

    ``lru_cache`` memoizes the result of this zero-argument function. Later
    calls therefore reuse the same instance, and the artifact download and
    model load are not repeated. ``get_predictor.cache_clear()`` forces a
    rebuild. Per the ``lru_cache`` docs, concurrent first calls may each
    construct an instance before one of them is cached.
    """
    predictor = BatteryPredictor()
    return predictor
def predict_from_request(payload: Mapping[str, Any]) -> Dict[str, Any]:
    """Validate a request payload and run a prediction with the shared predictor.

    Args:
        payload: Mapping with a required ``"window"`` entry and an optional
            ``"battery_id"`` entry. The battery id defaults to ``"B0005"``, and
            an explicit ``None`` is treated the same as a missing key.

    Returns:
        The result dict produced by ``BatteryPredictor.predict``.

    Raises:
        TypeError: If *payload* is not a mapping.
        ValueError: If *payload* has no ``"window"`` entry, or if the window
            or battery id fail validation inside ``BatteryPredictor.predict``.
    """
    if not isinstance(payload, Mapping):
        raise TypeError("payload must be a mapping with battery_id and window")
    if "window" not in payload:
        raise ValueError("payload must include a window field")
    battery_id = payload.get("battery_id")
    if battery_id is None:
        # JSON clients commonly send ``"battery_id": null`` to mean "unspecified".
        # Passed through, it would become str(None) and be rejected as unknown.
        battery_id = "B0005"
    return get_predictor().predict(payload["window"], battery_id=battery_id)
|