"""Battery capacity prediction helpers.

Downloads a zipped model bundle (config, per-battery scalers, transformer
weights) from a Hugging Face model repository and exposes a lazily built,
cached predictor plus a small request-level entry point.
"""

import http.client
import json
import os
import shutil
import tempfile
import urllib.request
import zipfile
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Mapping, Union

import joblib
import numpy as np
import torch
import torch.nn as nn

MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "Dharunkumar9/battery-capacity-predictor")
BATTERY_ORDER = ["B0005", "B0006", "B0007", "B0018"]

# Timeout (seconds) applied to blocking socket operations while downloading
# the artifact archive, so an unresponsive server cannot hang start-up forever.
DOWNLOAD_TIMEOUT_SECONDS = 60.0


class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position signal to a batch-first sequence.

    Even feature columns hold ``sin`` terms and odd columns hold ``cos`` terms.
    The table is registered as the buffer ``pe`` with shape
    ``(1, max_len, d_model)``.
    """

    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cosine column than sine
        # column; trim the frequencies so the assignment shapes match. For an
        # even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))
        self.pe: torch.Tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encoding for its first ``x.size(1)`` positions.

        ``x`` is indexed as ``(batch, seq_len, d_model)``; ``seq_len`` must not
        exceed the ``max_len`` the module was built with, otherwise the
        addition fails to broadcast.
        """
        return x + self.pe[:, : x.size(1), :]


class BatteryTransformer(nn.Module):
    """Transformer encoder that regresses one scalar per input sequence.

    The encoder output is pooled with a weighted mean over time steps in which
    the trailing ``last_frac`` share of the sequence receives weight
    ``last_weight`` and every earlier step receives weight 1.
    """

    def __init__(
        self,
        num_features: int,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        last_frac: float = 0.4,
        last_weight: float = 3.0,
    ):
        super().__init__()
        # Attribute names below are part of the state_dict keys; keep them.
        self.input_proj = nn.Linear(num_features, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)
        self.regressor = nn.Linear(d_model, 1)
        self.last_frac = last_frac
        self.last_weight = last_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, seq_len, num_features)`` input to ``(batch,)`` output."""
        seq_len = x.size(1)
        x = self.input_proj(x)
        x = self.pos_encoder(x)
        x = self.encoder(x)

        # Normalised per-step pooling weights that emphasise the sequence tail.
        weights = torch.ones(seq_len, device=x.device)
        last_start = int(seq_len * (1 - self.last_frac))
        weights[last_start:] = self.last_weight
        weights = weights / weights.sum()

        # (seq_len, 1) broadcasts against (batch, seq_len, d_model).
        x = (x * weights.unsqueeze(1)).sum(dim=1)
        x = self.dropout(x)
        return self.regressor(x).squeeze(-1)


def _resolve_battery_id(battery_id: Union[str, int]) -> str:
    """Translate a battery name or 0-based index into a ``BATTERY_ORDER`` name.

    Raises:
        ValueError: If the index is out of range or the name is unknown.
    """
    # bool is a subclass of int; do not let True/False pass as indices 1/0.
    # A bool falls through to the name check below and is rejected there.
    if isinstance(battery_id, int) and not isinstance(battery_id, bool):
        if battery_id < 0 or battery_id >= len(BATTERY_ORDER):
            raise ValueError(f"battery_id index must be between 0 and {len(BATTERY_ORDER) - 1}")
        return BATTERY_ORDER[battery_id]

    battery_id = str(battery_id).strip()
    if battery_id not in BATTERY_ORDER:
        raise ValueError(f"battery_id must be one of {BATTERY_ORDER} or a 0-based index")
    return battery_id


def _normalize_window(window: Any, expected_rows: int, expected_cols: int) -> np.ndarray:
    """Coerce ``window`` into a float32 array of the expected 2-D shape.

    A flat input holding exactly ``expected_rows * expected_cols`` values is
    reshaped row-major; any other shape mismatch is rejected.

    Raises:
        ValueError: If the size or shape is wrong, or if any value is NaN or
            infinite (such values would otherwise yield a silent NaN result).
    """
    array = np.asarray(window, dtype=np.float32)
    if array.ndim == 1:
        if array.size != expected_rows * expected_cols:
            raise ValueError(f"window must contain {expected_rows * expected_cols} values")
        array = array.reshape(expected_rows, expected_cols)
    if array.shape != (expected_rows, expected_cols):
        raise ValueError(f"window must have shape ({expected_rows}, {expected_cols})")
    if not np.isfinite(array).all():
        raise ValueError("window must contain only finite values")
    return array


def _fetch_archive(download_dir: Path) -> Path:
    """Download the first reachable artifact archive into ``download_dir``.

    Raises:
        FileNotFoundError: If no candidate could be downloaded; the last
            underlying error is attached as ``__cause__``.
    """
    archive_name_candidates = ("artifacts_v1.zip", "artifacts-v1.zip")
    last_error = None
    for archive_name in archive_name_candidates:
        archive_url = f"https://huggingface.co/{MODEL_REPO_ID}/resolve/main/{archive_name}?download=true"
        archive_file = download_dir / archive_name
        try:
            with urllib.request.urlopen(
                archive_url, timeout=DOWNLOAD_TIMEOUT_SECONDS
            ) as response, archive_file.open("wb") as output:
                # Stream to disk instead of buffering the whole archive in RAM.
                shutil.copyfileobj(response, output)
        except (OSError, http.client.HTTPException) as exc:
            # URLError/HTTPError, timeouts and file errors are all OSError;
            # remember the failure and try the next candidate name.
            last_error = exc
            continue
        return archive_file
    raise FileNotFoundError(
        "Could not download the model artifact zip from the Hugging Face model repo"
    ) from last_error


def _download_artifacts() -> Path:
    """Download and unpack the model bundle, returning the extraction directory.

    The downloaded zip lives in a temporary directory that is removed before
    returning. The returned extraction directory is a fresh temporary
    directory that the caller owns; it is removed here only if unpacking fails.

    Raises:
        FileNotFoundError: If none of the candidate archives can be downloaded.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    with tempfile.TemporaryDirectory(prefix="battery-archive-") as download_dir:
        archive_file = _fetch_archive(Path(download_dir))
        extract_dir = Path(tempfile.mkdtemp(prefix="battery-model-"))
        try:
            with zipfile.ZipFile(archive_file) as archive:
                archive.extractall(extract_dir)
        except Exception:
            shutil.rmtree(extract_dir, ignore_errors=True)
            raise
    return extract_dir


class BatteryPredictor:
    """Loaded model plus per-battery scalers, ready to score one window."""

    def __init__(self) -> None:
        """Download the artifacts and build the CPU model in eval mode."""
        # NOTE: the extraction directory is not deleted after loading.
        artifact_dir = _download_artifacts()
        config = json.loads((artifact_dir / "config.json").read_text(encoding="utf-8"))
        self.window_size = int(config["window_size"])
        self.num_features = int(config["num_features"])

        # SECURITY NOTE(review): joblib.load and torch.load unpickle data that
        # was just downloaded; unpickling can execute arbitrary code. Only
        # point MODEL_REPO_ID at a trusted repo, and consider passing
        # weights_only=True to torch.load where the installed torch has it.
        self.x_scalers = joblib.load(artifact_dir / "x_scalers.pkl")
        self.y_scalers = joblib.load(artifact_dir / "y_scalers.pkl")

        # last_frac / last_weight are left at the BatteryTransformer defaults.
        self.model = BatteryTransformer(
            num_features=self.num_features,
            d_model=int(config["d_model"]),
            nhead=int(config["nhead"]),
            num_layers=int(config["num_layers"]),
            dim_feedforward=int(config["dim_feedforward"]),
            dropout=float(config["dropout"]),
        ).to("cpu")
        state_dict = torch.load(artifact_dir / "pytorch_model.bin", map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def predict(self, window: Any, battery_id: Union[str, int] = "B0005") -> Dict[str, Any]:
        """Predict capacity for one window of readings.

        Args:
            window: Values shaped ``(window_size, num_features)``, or a flat
                sequence with that many values.
            battery_id: A name from ``BATTERY_ORDER`` or its 0-based index;
                selects which scalers are applied.

        Returns:
            A dict with ``battery_id``, ``window_size``, ``num_features``,
            ``predicted_capacity`` (after the inverse y-scaling) and
            ``scaled_prediction`` (the raw model output).

        Raises:
            ValueError: If ``battery_id`` or ``window`` fail validation.
        """
        battery_key = _resolve_battery_id(battery_id)
        window_array = _normalize_window(window, self.window_size, self.num_features)

        x_scaler = self.x_scalers[battery_key]
        y_scaler = self.y_scalers[battery_key]

        scaled_window = x_scaler.transform(window_array)
        # Add a leading batch axis: (1, window_size, num_features).
        tensor = torch.tensor(scaled_window[None, :, :], dtype=torch.float32)
        with torch.no_grad():
            scaled_prediction = float(self.model(tensor).item())
        predicted_capacity = float(y_scaler.inverse_transform([[scaled_prediction]])[0, 0])

        return {
            "battery_id": battery_key,
            "window_size": self.window_size,
            "num_features": self.num_features,
            "predicted_capacity": predicted_capacity,
            "scaled_prediction": scaled_prediction,
        }


@lru_cache(maxsize=1)
def get_predictor() -> BatteryPredictor:
    """Return the shared predictor, building it on first use."""
    return BatteryPredictor()


def predict_from_request(payload: Mapping[str, Any]) -> Dict[str, Any]:
    """Run a prediction from a request mapping.

    Args:
        payload: Mapping with a required ``window`` entry and an optional
            ``battery_id`` entry (defaults to ``"B0005"``).

    Raises:
        TypeError: If ``payload`` is not a mapping.
        ValueError: If ``window`` is missing or any field fails validation.
    """
    if not isinstance(payload, Mapping):
        raise TypeError("payload must be a mapping with battery_id and window")
    if "window" not in payload:
        raise ValueError("payload must include a window field")
    return get_predictor().predict(payload["window"], battery_id=payload.get("battery_id", "B0005"))