Dharun235 committed on
Commit
b5f4f2e
·
0 Parent(s):

Add FastAPI inference API

Browse files
Files changed (6) hide show
  1. .gitattributes +35 -0
  2. Dockerfile +14 -0
  3. README.md +18 -0
  4. app.py +43 -0
  5. predictor.py +170 -0
  6. requirements.txt +7 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11.11-slim-bookworm
2
+
3
+ RUN useradd -m -u 1000 user
4
+ USER user
5
+ ENV PATH="/home/user/.local/bin:$PATH"
6
+
7
+ WORKDIR /app
8
+
9
+ COPY --chown=user ./requirements.txt requirements.txt
10
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
+
12
+ COPY --chown=user . /app
13
+
14
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Battery Analytics
3
+ emoji: 🌍
4
+ colorFrom: yellow
5
+ colorTo: purple
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ ---
10
+
11
+ This Space serves the battery capacity prediction API.
12
+
13
+ The service exposes:
14
+
15
+ - `GET /health`
16
+ - `POST /predict`
17
+
18
+ The model weights and scalers are downloaded at runtime from the public model repo `Dharunkumar9/battery-capacity-predictor`, so the Space repo only needs code, not binary artifacts.
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List
2
+
3
+ from fastapi import FastAPI, HTTPException
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel, Field
6
+
7
+ from predictor import get_predictor
8
+
9
+
10
class PredictRequest(BaseModel):
    # Request body for POST /predict.
    #
    # battery_id is declared as "str | int" so that the documented
    # "0-based battery index" form is actually accepted: with a plain "str"
    # annotation pydantic v2 rejects JSON integers instead of coercing them.
    # (Comments are used instead of a docstring so the generated OpenAPI
    # description of the model is not altered.)
    battery_id: str | int = Field(
        default="B0005",
        description="Battery id or 0-based battery index",
    )
    # Nested list of floats; the shape is validated later by the predictor.
    window: List[List[float]] = Field(..., description="15 x 13 feature window")
13
+
14
+
15
app = FastAPI(title="Battery Capacity Predictor API", version="1.0.0")

# The API is public and unauthenticated, so any origin may call it.
# Credentials are deliberately disabled: wildcard origins must not be combined
# with allow_credentials=True (browsers reject "*" on credentialed requests,
# and echoing the caller's origin would let any site make credentialed calls).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Built once at import time so every request reuses the same loaded model.
predictor = get_predictor()
26
+
27
+
28
@app.get("/")
def root() -> dict[str, str]:
    # Landing endpoint: returns a fixed status/message pair.
    # (A comment rather than a docstring, so the OpenAPI description is unchanged.)
    banner = {
        "status": "ok",
        "message": "Battery capacity prediction API is running",
    }
    return banner
31
+
32
+
33
@app.get("/health")
def health() -> dict[str, str]:
    # Liveness probe: always reports "healthy" and performs no model check.
    status_payload = dict(status="healthy")
    return status_payload
36
+
37
+
38
@app.post("/predict")
def predict(request: PredictRequest) -> dict[str, Any]:
    # Runs the module-level predictor on the submitted window.
    # A ValueError from the predictor is reported to the client as HTTP 400
    # with the error text as detail; other exceptions propagate unchanged.
    try:
        outcome = predictor.predict(request.window, battery_id=request.battery_id)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return outcome
predictor.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from functools import lru_cache
4
+ from pathlib import Path
5
+ from typing import Any, Dict, Mapping, Union
6
+
7
+ import joblib
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from huggingface_hub import snapshot_download
12
+
13
+
14
# Hub repository holding the model artifacts; overridable via the
# MODEL_REPO_ID environment variable.
MODEL_REPO_ID = os.environ.get(
    "MODEL_REPO_ID",
    "Dharunkumar9/battery-capacity-predictor",
)

# Known battery ids; the position in this list is the 0-based battery index.
BATTERY_ORDER = [
    "B0005",
    "B0006",
    "B0007",
    "B0018",
]
16
+
17
+
18
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional encoding to a batch-first sequence.

    The table is precomputed for ``max_len`` positions and stored as the
    buffer ``pe`` with shape ``(1, max_len, d_model)``. Even feature columns
    hold sines and odd columns hold cosines.

    Args:
        d_model: Size of the feature dimension. Odd values are supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so trim the angles; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encoding for its first ``x.size(1)`` positions."""
        return x + self.pe[:, : x.size(1), :]
30
+
31
+
32
class BatteryTransformer(nn.Module):
    """Transformer-encoder regressor that maps a feature window to one value.

    The input is projected to ``d_model``, given a positional encoding, run
    through a batch-first encoder stack, pooled over the sequence axis with a
    weighted mean that emphasises the trailing steps, and passed to a linear
    head.

    Args:
        num_features: Size of the last input dimension.
        d_model: Width of the encoder.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden size of each layer's feed-forward part.
        dropout: Dropout rate for the encoder layers and before the head.
        last_frac: Fraction of trailing steps that receive ``last_weight``.
        last_weight: Unnormalised pooling weight of those trailing steps
            (all other steps have weight 1).
    """

    def __init__(
        self,
        num_features: int,
        d_model: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 256,
        dropout: float = 0.1,
        last_frac: float = 0.4,
        last_weight: float = 3.0,
    ):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay as-is.
        self.input_proj = nn.Linear(num_features, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs),
            num_layers=num_layers,
        )
        self.dropout = nn.Dropout(dropout)
        self.regressor = nn.Linear(d_model, 1)
        self.last_frac = last_frac
        self.last_weight = last_weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one value per batch row for a ``(batch, seq, num_features)`` input."""
        steps = x.size(1)
        encoded = self.encoder(self.pos_encoder(self.input_proj(x)))

        # Pooling weights: 1 for early steps, last_weight from boost_from on,
        # normalised so they sum to one.
        pool = torch.ones(steps, device=encoded.device)
        boost_from = int(steps * (1 - self.last_frac))
        pool[boost_from:] = self.last_weight
        pool = pool / pool.sum()

        # (seq, 1) weights broadcast over batch and feature axes.
        pooled = (encoded * pool.unsqueeze(1)).sum(dim=1)
        return self.regressor(self.dropout(pooled)).squeeze(-1)
74
+
75
+
76
def _resolve_battery_id(battery_id: Union[str, int]) -> str:
    """Translate a battery name or 0-based index into a name from BATTERY_ORDER.

    Args:
        battery_id: A battery name (surrounding whitespace is ignored), an
            integer index into BATTERY_ORDER, or a string of ASCII digits
            holding such an index (e.g. ``"2"``).

    Returns:
        The matching entry of BATTERY_ORDER.

    Raises:
        ValueError: If the name is unknown or the index is out of range.
    """
    if isinstance(battery_id, int):
        index = battery_id
    else:
        battery_id = str(battery_id).strip()
        if battery_id in BATTERY_ORDER:
            return battery_id
        # The HTTP layer may deliver the id as a string, so a numeric string
        # has to be honoured as an index too. Signs and non-ASCII digits are
        # intentionally not accepted.
        if not (battery_id.isascii() and battery_id.isdigit()):
            raise ValueError(f"battery_id must be one of {BATTERY_ORDER} or a 0-based index")
        index = int(battery_id)

    if index < 0 or index >= len(BATTERY_ORDER):
        raise ValueError(f"battery_id index must be between 0 and {len(BATTERY_ORDER) - 1}")
    return BATTERY_ORDER[index]
86
+
87
+
88
+ def _normalize_window(window: Any, expected_rows: int, expected_cols: int) -> np.ndarray:
89
+ array = np.asarray(window, dtype=np.float32)
90
+ if array.ndim == 1:
91
+ if array.size != expected_rows * expected_cols:
92
+ raise ValueError(f"window must contain {expected_rows * expected_cols} values")
93
+ array = array.reshape(expected_rows, expected_cols)
94
+
95
+ if array.shape != (expected_rows, expected_cols):
96
+ raise ValueError(f"window must have shape ({expected_rows}, {expected_cols})")
97
+
98
+ return array
99
+
100
+
101
def _download_artifacts() -> Path:
    """Fetch the model artifacts from the Hub and return the local snapshot directory.

    Only the config, the weights file and the two scaler pickles are
    requested from the MODEL_REPO_ID model repository.
    """
    wanted_files = ["config.json", "pytorch_model.bin", "x_scalers.pkl", "y_scalers.pkl"]
    snapshot_dir = snapshot_download(
        repo_id=MODEL_REPO_ID,
        repo_type="model",
        allow_patterns=wanted_files,
    )
    return Path(snapshot_dir)
109
+
110
+
111
class BatteryPredictor:
    """Load the published artifacts once and serve single-window predictions.

    Attributes set by ``__init__``:
        window_size / num_features: Expected window shape, read from config.json.
        x_scalers / y_scalers: Objects indexed by battery id that provide
            ``transform`` / ``inverse_transform`` (presumably scikit-learn
            scalers; verify against the training code).
        model: The BatteryTransformer on CPU, in eval mode.
    """

    def __init__(self) -> None:
        """Download the artifacts, rebuild the network from config and load its weights."""
        artifact_dir = _download_artifacts()
        config = json.loads((artifact_dir / "config.json").read_text(encoding="utf-8"))
        self.window_size = int(config["window_size"])
        self.num_features = int(config["num_features"])

        # SECURITY NOTE: joblib.load unpickles arbitrary objects, which can
        # execute code. This is only safe while MODEL_REPO_ID points at a
        # trusted repository.
        self.x_scalers = joblib.load(artifact_dir / "x_scalers.pkl")
        self.y_scalers = joblib.load(artifact_dir / "y_scalers.pkl")

        self.model = BatteryTransformer(
            num_features=self.num_features,
            d_model=int(config["d_model"]),
            nhead=int(config["nhead"]),
            num_layers=int(config["num_layers"]),
            dim_feedforward=int(config["dim_feedforward"]),
            dropout=float(config["dropout"]),
        ).to("cpu")

        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run code on load. Older
        # torch releases default to full unpickling.
        state_dict = torch.load(
            artifact_dir / "pytorch_model.bin",
            map_location="cpu",
            weights_only=True,
        )
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def predict(self, window: Any, battery_id: Union[str, int] = "B0005") -> Dict[str, Any]:
        """Predict the capacity for one feature window.

        Args:
            window: Array-like of shape ``(window_size, num_features)``, or a
                flat sequence with that many values.
            battery_id: Battery name or 0-based index; selects the scalers.

        Returns:
            A dict with the resolved ``battery_id``, the expected
            ``window_size`` and ``num_features``, the de-scaled
            ``predicted_capacity`` and the raw ``scaled_prediction``.

        Raises:
            ValueError: If the battery id or the window is invalid.
        """
        battery_key = _resolve_battery_id(battery_id)
        window_array = _normalize_window(window, self.window_size, self.num_features)

        x_scaler = self.x_scalers[battery_key]
        y_scaler = self.y_scalers[battery_key]

        scaled_window = x_scaler.transform(window_array)
        # Add a leading batch axis of size one.
        tensor = torch.tensor(scaled_window[None, :, :], dtype=torch.float32)

        with torch.no_grad():
            scaled_prediction = float(self.model(tensor).item())

        # The target scaler works on 2-D input, hence the nested list.
        predicted_capacity = float(y_scaler.inverse_transform([[scaled_prediction]])[0, 0])

        return {
            "battery_id": battery_key,
            "window_size": self.window_size,
            "num_features": self.num_features,
            "predicted_capacity": predicted_capacity,
            "scaled_prediction": scaled_prediction,
        }
156
+
157
+
158
@lru_cache(maxsize=1)
def get_predictor() -> BatteryPredictor:
    """Return the shared BatteryPredictor, constructing it on the first call only."""
    instance = BatteryPredictor()
    return instance
161
+
162
+
163
def predict_from_request(payload: Mapping[str, Any]) -> Dict[str, Any]:
    """Run a prediction from a dict-like request payload.

    Args:
        payload: Mapping with a required ``window`` entry and an optional
            ``battery_id`` entry (defaults to ``"B0005"``).

    Returns:
        The result dict produced by the shared predictor.

    Raises:
        TypeError: If ``payload`` is not a mapping.
        ValueError: If ``window`` is missing from the payload.
    """
    if isinstance(payload, Mapping):
        if "window" in payload:
            engine = get_predictor()
            return engine.predict(
                payload["window"],
                battery_id=payload.get("battery_id", "B0005"),
            )
        raise ValueError("payload must include a window field")
    raise TypeError("payload must be a mapping with battery_id and window")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ fastapi>=0.110.0
2
+ uvicorn[standard]>=0.29.0
3
+ torch>=2.2.0
4
+ numpy>=1.24.0
5
+ joblib>=1.3.0
6
+ scikit-learn==1.7.0
7
+ huggingface_hub>=0.23.0