Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import urllib.request | |
| import tempfile | |
| import zipfile | |
| from functools import lru_cache | |
| from pathlib import Path | |
| from typing import Any, Dict, Mapping, Union | |
| import joblib | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
# Hugging Face repository that hosts the artifact archive; the MODEL_REPO_ID
# environment variable overrides the default.
MODEL_REPO_ID = os.environ.get(
    "MODEL_REPO_ID",
    "Dharunkumar9/battery-capacity-predictor",
)

# Accepted battery identifiers. An integer battery_id is interpreted as a
# 0-based index into this list.
BATTERY_ORDER = [
    "B0005",
    "B0006",
    "B0007",
    "B0018",
]
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    The table is computed once in the constructor and stored as the
    non-trainable buffer ``pe`` with shape ``(1, max_len, d_model)``.

    Args:
        d_model: Size of the last (embedding) dimension of the inputs.
            Both even and odd values are supported.
        max_len: Number of positions to precompute. Inputs passed to
            ``forward`` must not be longer than this along dimension 1.
    """

    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the angles to the number of odd-indexed slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer("pe", pe.unsqueeze(0))
        # Annotation only: tells type checkers the buffer is a tensor.
        self.pe: torch.Tensor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        return x + self.pe[:, : x.size(1), :]
class BatteryTransformer(nn.Module):
    """Transformer encoder that regresses one scalar per input sequence.

    Inputs are batch-first: ``(batch, seq_len, num_features)``. Each time step
    is projected to ``d_model``, position-encoded, passed through the encoder
    stack, then pooled over time with a weighted mean that emphasises the
    trailing part of the sequence before a linear head produces the output.

    Args:
        num_features: Number of features per time step.
        d_model: Width of the encoder.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        dropout: Dropout probability used in the encoder layers and before
            the regression head.
        last_frac: Fraction of the sequence, counted from the end, whose time
            steps receive ``last_weight`` instead of weight 1 during pooling.
        last_weight: Un-normalised pooling weight for those trailing steps.
    """

    def __init__(
        self,
        num_features: int,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        last_frac: float = 0.4,
        last_weight: float = 3.0,
    ):
        super().__init__()
        self.input_proj = nn.Linear(num_features, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)
        self.regressor = nn.Linear(d_model, 1)
        self.last_frac = last_frac
        self.last_weight = last_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one prediction per sequence, shape ``(batch,)``."""
        seq_len = x.size(1)
        x = self.input_proj(x)
        x = self.pos_encoder(x)
        x = self.encoder(x)

        # Per-time-step pooling weights: 1 for the leading steps and
        # ``last_weight`` for the trailing ``last_frac`` of the sequence.
        weights = torch.ones(seq_len, device=x.device)
        last_start = int(seq_len * (1 - self.last_frac))
        weights[last_start:] = self.last_weight
        # Normalise first, then match the activations' dtype. Without the
        # cast, a reduced-precision model (e.g. bfloat16) has its activations
        # promoted to float32 here and the regressor head rejects them.
        weights = (weights / weights.sum()).to(dtype=x.dtype)

        # (seq_len, 1) broadcasts against (batch, seq_len, d_model).
        x = (x * weights.unsqueeze(1)).sum(dim=1)
        x = self.dropout(x)
        return self.regressor(x).squeeze(-1)
def _resolve_battery_id(battery_id: Union[str, int]) -> str:
    """Translate a battery name or 0-based index into a name from BATTERY_ORDER.

    Args:
        battery_id: Either one of the names in ``BATTERY_ORDER`` (surrounding
            whitespace is ignored) or an integer index into that list. NumPy
            integers are accepted as indices; booleans are rejected.

    Returns:
        The matching entry of ``BATTERY_ORDER``.

    Raises:
        ValueError: If the index is out of range, the name is unknown, or a
            boolean is passed.
    """
    # bool is a subclass of int; without this guard True/False would silently
    # select BATTERY_ORDER[1] / BATTERY_ORDER[0].
    if isinstance(battery_id, bool):
        raise ValueError(f"battery_id must be one of {BATTERY_ORDER} or a 0-based index")

    if isinstance(battery_id, (int, np.integer)):
        index = int(battery_id)
        if not 0 <= index < len(BATTERY_ORDER):
            raise ValueError(f"battery_id index must be between 0 and {len(BATTERY_ORDER) - 1}")
        return BATTERY_ORDER[index]

    name = str(battery_id).strip()
    if name not in BATTERY_ORDER:
        raise ValueError(f"battery_id must be one of {BATTERY_ORDER} or a 0-based index")
    return name
def _normalize_window(window: Any, expected_rows: int, expected_cols: int) -> np.ndarray:
    """Coerce ``window`` into a float32 array of shape (expected_rows, expected_cols).

    A one-dimensional input is treated as a flattened window and reshaped
    row-major; any other input must already have the expected shape.

    Raises:
        ValueError: If a flat input has the wrong number of values, or the
            resulting array does not have the expected shape.
    """
    target_shape = (expected_rows, expected_cols)
    values = np.asarray(window, dtype=np.float32)

    is_flat = values.ndim == 1
    if is_flat and values.size != expected_rows * expected_cols:
        raise ValueError(f"window must contain {expected_rows * expected_cols} values")
    if is_flat:
        values = values.reshape(target_shape)

    if values.shape != target_shape:
        raise ValueError(f"window must have shape ({expected_rows}, {expected_cols})")
    return values
def _download_artifacts() -> Path:
    """Download the model artifact archive and extract it into a new temp directory.

    Each candidate archive name is requested in turn from the Hugging Face
    repository named by ``MODEL_REPO_ID``; the first one that downloads
    successfully is extracted.

    Returns:
        Path of a freshly created temporary directory holding the extracted
        archive contents. The caller owns this directory.

    Raises:
        FileNotFoundError: If none of the candidate archives could be
            downloaded. The last download error is attached as ``__cause__``.
    """
    import shutil  # local import so this block does not rely on new file-level imports

    archive_name_candidates = ["artifacts_v1.zip", "artifacts-v1.zip"]
    timeout_seconds = 60  # per blocking socket operation, not for the whole download
    last_error = None

    # The downloaded zip is only needed until extraction finishes, so keep it
    # in a self-cleaning directory instead of leaking one per attempt.
    with tempfile.TemporaryDirectory(prefix="battery-archive-") as download_dir:
        archive_file = None
        for archive_name in archive_name_candidates:
            archive_url = f"https://huggingface.co/{MODEL_REPO_ID}/resolve/main/{archive_name}?download=true"
            candidate = Path(download_dir) / archive_name
            try:
                with urllib.request.urlopen(archive_url, timeout=timeout_seconds) as response, candidate.open("wb") as output:
                    # Stream to disk rather than holding the whole archive in memory.
                    shutil.copyfileobj(response, output)
            except Exception as exc:
                # Best effort across candidates: remember why this one failed
                # and try the next name.
                last_error = exc
                continue
            archive_file = candidate
            break

        if archive_file is None:
            raise FileNotFoundError(
                "Could not download the model artifact zip from the Hugging Face model repo"
            ) from last_error

        extract_dir = Path(tempfile.mkdtemp(prefix="battery-model-"))
        try:
            with zipfile.ZipFile(archive_file) as archive:
                archive.extractall(extract_dir)
        except Exception:
            # Do not leave a half-populated directory behind on a bad archive.
            shutil.rmtree(extract_dir, ignore_errors=True)
            raise

    return extract_dir
class BatteryPredictor:
    """Load the downloaded model artifacts and predict battery capacity on CPU.

    Attributes set by the constructor:
        window_size: Number of rows expected in an input window (from config.json).
        num_features: Number of columns expected in an input window (from config.json).
        x_scalers: Object loaded from ``x_scalers.pkl``; indexed by battery name,
            each entry is used via ``.transform``.
        y_scalers: Object loaded from ``y_scalers.pkl``; indexed by battery name,
            each entry is used via ``.inverse_transform``.
        model: ``BatteryTransformer`` in eval mode with the downloaded weights.
    """

    def __init__(self) -> None:
        import shutil  # local import so this block does not rely on new file-level imports

        artifact_dir = _download_artifacts()
        try:
            config = json.loads((artifact_dir / "config.json").read_text(encoding="utf-8"))
            self.window_size = int(config["window_size"])
            self.num_features = int(config["num_features"])

            # NOTE(review): joblib.load and torch.load unpickle files fetched
            # over the network; this is only safe if the model repo is trusted.
            self.x_scalers = joblib.load(artifact_dir / "x_scalers.pkl")
            self.y_scalers = joblib.load(artifact_dir / "y_scalers.pkl")

            self.model = BatteryTransformer(
                num_features=self.num_features,
                d_model=int(config["d_model"]),
                nhead=int(config["nhead"]),
                num_layers=int(config["num_layers"]),
                dim_feedforward=int(config["dim_feedforward"]),
                dropout=float(config["dropout"]),
            ).to("cpu")
            state_dict = torch.load(artifact_dir / "pytorch_model.bin", map_location="cpu")
            self.model.load_state_dict(state_dict)
            self.model.eval()
        finally:
            # Everything needed is in memory now (or loading failed), so the
            # extracted temp directory would otherwise be leaked per instance.
            shutil.rmtree(artifact_dir, ignore_errors=True)

    def predict(self, window: Any, battery_id: Union[str, int] = "B0005") -> Dict[str, Any]:
        """Predict capacity for one window of measurements.

        Args:
            window: Array-like of shape ``(window_size, num_features)``, or a
                flat sequence with ``window_size * num_features`` values.
            battery_id: Battery name or 0-based index selecting which pair of
                scalers to use.

        Returns:
            Dict with keys ``battery_id``, ``window_size``, ``num_features``,
            ``predicted_capacity`` (after inverse scaling) and
            ``scaled_prediction`` (raw model output).

        Raises:
            ValueError: If ``battery_id`` or the window shape is invalid.
        """
        battery_key = _resolve_battery_id(battery_id)
        window_array = _normalize_window(window, self.window_size, self.num_features)
        x_scaler = self.x_scalers[battery_key]
        y_scaler = self.y_scalers[battery_key]

        scaled_window = x_scaler.transform(window_array)
        # Add a leading batch dimension of 1 for the batch-first model.
        tensor = torch.tensor(scaled_window[None, :, :], dtype=torch.float32)
        with torch.no_grad():
            scaled_prediction = float(self.model(tensor).item())
        predicted_capacity = float(y_scaler.inverse_transform([[scaled_prediction]])[0, 0])

        return {
            "battery_id": battery_key,
            "window_size": self.window_size,
            "num_features": self.num_features,
            "predicted_capacity": predicted_capacity,
            "scaled_prediction": scaled_prediction,
        }
@lru_cache(maxsize=1)
def get_predictor() -> BatteryPredictor:
    """Return the shared ``BatteryPredictor``, constructing it on first use.

    Constructing a predictor downloads and loads the model artifacts, so the
    instance is memoised and reused by later calls. If construction raises,
    nothing is cached and the next call tries again.
    """
    return BatteryPredictor()
def predict_from_request(payload: Mapping[str, Any]) -> Dict[str, Any]:
    """Validate a request payload and run a prediction on it.

    Args:
        payload: Mapping with a required ``window`` entry and an optional
            ``battery_id`` entry (defaults to ``"B0005"``).

    Returns:
        The result dict produced by the predictor's ``predict`` method.

    Raises:
        TypeError: If ``payload`` is not a mapping.
        ValueError: If ``payload`` has no ``window`` entry.
    """
    if isinstance(payload, Mapping):
        if "window" in payload:
            predictor = get_predictor()
            window = payload["window"]
            battery = payload.get("battery_id", "B0005")
            return predictor.predict(window, battery_id=battery)
        raise ValueError("payload must include a window field")
    raise TypeError("payload must be a mapping with battery_id and window")