# battery-analytics / predictor.py
# Last commit 8b2f8e8 by Dharun235: "Load model artifacts from archive at startup"
import json
import os
import shutil
import tempfile
import urllib.request
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Mapping, Union

import joblib
import numpy as np
import torch
import torch.nn as nn
# Hugging Face model repo holding the artifact archive; override with the
# MODEL_REPO_ID environment variable.
MODEL_REPO_ID = os.environ.get("MODEL_REPO_ID", "Dharunkumar9/battery-capacity-predictor")

# Supported battery IDs; position in this list is the accepted 0-based index.
BATTERY_ORDER = [
    "B0005",
    "B0006",
    "B0007",
    "B0018",
]
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence tensor.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``, with geometrically spaced frequencies.

    Args:
        d_model: Width of the feature dimension the encoding is added to.
            Both even and odd widths are supported.
        max_len: Longest sequence length ``forward`` can handle.
    """

    # Registered (persistent) buffer of shape (1, max_len, d_model); declared
    # here so type checkers know the attribute created by register_buffer.
    pe: torch.Tensor

    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are ceil(d_model / 2) even columns but only floor(d_model / 2)
        # odd ones; for odd d_model the cosine half must drop the last
        # frequency or the assignment fails with a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Dimension 1 of ``x`` is treated as the sequence axis and the last
        dimension must equal ``d_model``; sequences longer than ``max_len``
        fail to broadcast.
        """
        return x + self.pe[:, : x.size(1), :]
class BatteryTransformer(nn.Module):
    """Transformer-encoder regressor that maps a feature window to one scalar.

    Input features are linearly projected to ``d_model``, position-encoded and
    run through a batch-first ``nn.TransformerEncoder``. The encoded sequence is
    pooled with a weighted mean in which the final ``last_frac`` share of time
    steps counts ``last_weight`` times as much as earlier steps; the pooled
    vector goes through dropout and a linear head.
    """

    def __init__(
        self,
        num_features: int,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        last_frac: float = 0.4,
        last_weight: float = 3.0,
    ):
        super().__init__()
        # Attribute names below are the state_dict key prefixes of saved
        # checkpoints, and creation order fixes the parameter-init RNG order;
        # keep both stable.
        self.input_proj = nn.Linear(num_features, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )
        self.dropout = nn.Dropout(dropout)
        self.regressor = nn.Linear(d_model, 1)
        self.last_frac = last_frac
        self.last_weight = last_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one regression output per batch item.

        ``x`` is expected batch-first as (batch, seq_len, num_features); the
        trailing size-1 output dimension is squeezed away.
        """
        steps = x.size(1)
        hidden = self.encoder(self.pos_encoder(self.input_proj(x)))

        # Time-step weights: 1 everywhere, last_weight on the tail, normalised
        # to sum to 1 so the pooling is a weighted mean.
        tail_start = int(steps * (1 - self.last_frac))
        step_weights = torch.ones(steps, device=hidden.device)
        step_weights[tail_start:] = self.last_weight
        step_weights = step_weights / step_weights.sum()

        pooled = (hidden * step_weights[:, None]).sum(dim=1)
        return self.regressor(self.dropout(pooled)).squeeze(-1)
def _resolve_battery_id(battery_id: Union[str, int]) -> str:
    """Map a battery name or 0-based index to its canonical ID in ``BATTERY_ORDER``.

    Args:
        battery_id: Either a name from ``BATTERY_ORDER`` (surrounding
            whitespace is ignored; any non-integer is converted with ``str``)
            or an integer index into it. NumPy integer scalars are accepted as
            indices too.

    Returns:
        The matching entry of ``BATTERY_ORDER``.

    Raises:
        ValueError: if the index is out of range, the name is unknown, or a
            ``bool`` is passed.
    """
    # bool is a subclass of int, so without this guard JSON ``true``/``false``
    # would silently select index 1/0.
    if isinstance(battery_id, bool):
        raise ValueError(f"battery_id must be one of {BATTERY_ORDER} or a 0-based index")
    # np.int64 and friends are not ``int`` subclasses; treat them as indices
    # rather than falling through to the name lookup (which would reject them).
    if isinstance(battery_id, (int, np.integer)):
        index = int(battery_id)
        if not 0 <= index < len(BATTERY_ORDER):
            raise ValueError(f"battery_id index must be between 0 and {len(BATTERY_ORDER) - 1}")
        return BATTERY_ORDER[index]
    name = str(battery_id).strip()
    if name not in BATTERY_ORDER:
        raise ValueError(f"battery_id must be one of {BATTERY_ORDER} or a 0-based index")
    return name
def _normalize_window(window: Any, expected_rows: int, expected_cols: int) -> np.ndarray:
array = np.asarray(window, dtype=np.float32)
if array.ndim == 1:
if array.size != expected_rows * expected_cols:
raise ValueError(f"window must contain {expected_rows * expected_cols} values")
array = array.reshape(expected_rows, expected_cols)
if array.shape != (expected_rows, expected_cols):
raise ValueError(f"window must have shape ({expected_rows}, {expected_cols})")
return array
def _download_artifacts(timeout: float = 60.0) -> Path:
    """Download the model artifact zip from the Hugging Face repo and extract it.

    Each candidate archive name is tried in order against
    ``https://huggingface.co/{MODEL_REPO_ID}/resolve/main/<name>``; the first
    one that downloads successfully is extracted.

    Args:
        timeout: Socket timeout in seconds passed to ``urlopen`` for each
            blocking network operation, so an unresponsive server cannot hang
            startup indefinitely.

    Returns:
        A fresh temporary directory holding the extracted archive contents.
        The caller owns this directory and is responsible for removing it.

    Raises:
        FileNotFoundError: if no candidate archive could be downloaded; the
            error from the last attempt is attached as ``__cause__``.
    """
    archive_name_candidates = ["artifacts_v1.zip", "artifacts-v1.zip"]
    last_error = None
    # The zip is only needed until it has been extracted, so it lives in a
    # directory that is removed automatically (including partial downloads).
    with tempfile.TemporaryDirectory(prefix="battery-archive-") as download_dir:
        archive_path = None
        for archive_name in archive_name_candidates:
            archive_url = f"https://huggingface.co/{MODEL_REPO_ID}/resolve/main/{archive_name}?download=true"
            candidate = Path(download_dir) / archive_name
            try:
                with urllib.request.urlopen(archive_url, timeout=timeout) as response, candidate.open("wb") as output:
                    # Stream to disk instead of buffering the whole archive in memory.
                    shutil.copyfileobj(response, output)
            except Exception as err:  # best-effort: remember why, then try the next name
                last_error = err
                continue
            archive_path = candidate
            break
        if archive_path is None:
            raise FileNotFoundError(
                "Could not download the model artifact zip from the Hugging Face model repo"
            ) from last_error
        extract_dir = Path(tempfile.mkdtemp(prefix="battery-model-"))
        try:
            with zipfile.ZipFile(archive_path) as archive:
                archive.extractall(extract_dir)
        except BaseException:
            # Do not leave a half-populated directory behind on failure.
            shutil.rmtree(extract_dir, ignore_errors=True)
            raise
    return extract_dir
class BatteryPredictor:
    """Capacity predictor backed by the transformer checkpoint in the model repo.

    Construction downloads and unpacks the artifact archive, reads
    ``config.json``, loads the input/target scalers (looked up per battery ID in
    ``predict``) and the model weights onto the CPU, then deletes the unpacked
    files since everything needed is held in memory.
    """

    def __init__(self) -> None:
        artifact_dir = _download_artifacts()
        try:
            self._load_artifacts(artifact_dir)
        finally:
            # The extracted files are fully loaded (or loading failed); either
            # way the temporary directory is no longer needed.
            shutil.rmtree(artifact_dir, ignore_errors=True)

    def _load_artifacts(self, artifact_dir: Path) -> None:
        """Populate config values, scalers and the eval-mode model from ``artifact_dir``."""
        config = json.loads((artifact_dir / "config.json").read_text())
        self.window_size = int(config["window_size"])
        self.num_features = int(config["num_features"])
        # SECURITY NOTE(review): joblib.load unpickles the file and can execute
        # arbitrary code; only point MODEL_REPO_ID at a repo you trust.
        self.x_scalers = joblib.load(artifact_dir / "x_scalers.pkl")
        self.y_scalers = joblib.load(artifact_dir / "y_scalers.pkl")
        # Pooling hyperparameters were previously ignored; honour them when the
        # config records them, otherwise keep BatteryTransformer's defaults.
        pooling = {
            key: float(config[key]) for key in ("last_frac", "last_weight") if key in config
        }
        self.model = BatteryTransformer(
            num_features=self.num_features,
            d_model=int(config["d_model"]),
            nhead=int(config["nhead"]),
            num_layers=int(config["num_layers"]),
            dim_feedforward=int(config["dim_feedforward"]),
            dropout=float(config["dropout"]),
            **pooling,
        ).to("cpu")
        # SECURITY NOTE(review): torch.load may unpickle arbitrary objects unless
        # weights_only=True (the default only in recent PyTorch releases);
        # consider passing it explicitly if the installed version supports it.
        state_dict = torch.load(artifact_dir / "pytorch_model.bin", map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def predict(self, window: Any, battery_id: Union[str, int] = "B0005") -> Dict[str, Any]:
        """Predict capacity for a single window of features.

        Args:
            window: ``window_size`` x ``num_features`` values, either 2-D or
                flattened row by row.
            battery_id: Battery name from ``BATTERY_ORDER`` or its 0-based
                index; selects which scaler pair is applied.

        Returns:
            Dict with the resolved ``battery_id``, the expected ``window_size``
            and ``num_features``, ``predicted_capacity`` (the model output after
            ``y_scaler.inverse_transform``) and the raw ``scaled_prediction``.

        Raises:
            ValueError: for an unknown ``battery_id`` or a malformed window.
        """
        battery_key = _resolve_battery_id(battery_id)
        window_array = _normalize_window(window, self.window_size, self.num_features)
        x_scaler = self.x_scalers[battery_key]
        y_scaler = self.y_scalers[battery_key]
        scaled_window = x_scaler.transform(window_array)
        # Add a leading batch axis of 1 for the batch-first model.
        tensor = torch.tensor(scaled_window[None, :, :], dtype=torch.float32)
        with torch.no_grad():
            scaled_prediction = float(self.model(tensor).item())
        predicted_capacity = float(y_scaler.inverse_transform([[scaled_prediction]])[0, 0])
        return {
            "battery_id": battery_key,
            "window_size": self.window_size,
            "num_features": self.num_features,
            "predicted_capacity": predicted_capacity,
            "scaled_prediction": scaled_prediction,
        }
@lru_cache(maxsize=1)
def get_predictor() -> BatteryPredictor:
    """Return the shared BatteryPredictor, building it on first use.

    The first call downloads the artifacts and loads the model; because the
    function takes no arguments and is cached with one slot, later calls return
    that same instance. ``get_predictor.cache_clear()`` forces a rebuild.
    ``lru_cache`` does not serialise concurrent first calls, so simultaneous
    callers may each construct a predictor before one is cached.
    """
    predictor = BatteryPredictor()
    return predictor
def predict_from_request(payload: Mapping[str, Any]) -> Dict[str, Any]:
    """Run a prediction from a request-style mapping.

    Args:
        payload: Mapping with a required ``window`` entry and an optional
            ``battery_id`` entry (``"B0005"`` when absent).

    Returns:
        The result dict produced by ``BatteryPredictor.predict``.

    Raises:
        TypeError: if ``payload`` is not a mapping.
        ValueError: if ``payload`` has no ``window`` key, or the predictor
            rejects the window / battery ID.
    """
    if not isinstance(payload, Mapping):
        raise TypeError("payload must be a mapping with battery_id and window")
    if "window" not in payload:
        raise ValueError("payload must include a window field")
    # Build (or fetch the cached) predictor before reading the payload values,
    # matching the original evaluation order.
    predictor = get_predictor()
    window = payload["window"]
    requested_battery = payload.get("battery_id", "B0005")
    return predictor.predict(window, battery_id=requested_battery)