NeerajCodz's picture
feat: full project β€” ML simulation, dashboard UI, models on HF Hub
f381be8
"""
api.routers.predict
===================
Prediction & recommendation endpoints.
"""
from __future__ import annotations
from fastapi import APIRouter, HTTPException
from api.model_registry import registry, registry_v1, classify_degradation, soh_to_color
from api.schemas import (
PredictRequest, PredictResponse,
BatchPredictRequest, BatchPredictResponse,
RecommendationRequest, RecommendationResponse, SingleRecommendation,
)
router = APIRouter(prefix="/api", tags=["prediction"])
# v1-prefixed router (legacy, preserved for backward compatibility)
v1_router = APIRouter(prefix="/api/v1", tags=["v1-prediction"])
# ── Single prediction ────────────────────────────────────────────────────────
@router.post("/predict", response_model=PredictResponse)
async def predict(req: PredictRequest):
"""Predict SOH for a single cycle."""
features = req.model_dump(exclude={"battery_id"})
features["voltage_range"] = features["peak_voltage"] - features["min_voltage"]
# If avg_temp equals ambient_temperature exactly, apply the NASA data offset
# (cell temperature is always 8-10Β°C above ambient under load).
if abs(features["avg_temp"] - features["ambient_temperature"]) < 0.5:
features["avg_temp"] = features["ambient_temperature"] + 8.0
try:
result = registry.predict(features)
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
return PredictResponse(
battery_id=req.battery_id,
cycle_number=req.cycle_number,
soh_pct=result["soh_pct"],
rul_cycles=result["rul_cycles"],
degradation_state=result["degradation_state"],
confidence_lower=result["confidence_lower"],
confidence_upper=result["confidence_upper"],
model_used=result["model_used"],
)
# ── Batch prediction ─────────────────────────────────────────────────────────
@router.post("/predict/batch", response_model=BatchPredictResponse)
async def predict_batch(req: BatchPredictRequest):
"""Predict SOH for multiple cycles of one battery."""
results = registry.predict_batch(req.battery_id, req.cycles)
predictions = [
PredictResponse(
battery_id=req.battery_id,
cycle_number=r["cycle_number"],
soh_pct=r["soh_pct"],
rul_cycles=r["rul_cycles"],
degradation_state=r["degradation_state"],
confidence_lower=r.get("confidence_lower"),
confidence_upper=r.get("confidence_upper"),
model_used=r["model_used"], model_version=r.get("model_version"), )
for r in results
]
return BatchPredictResponse(battery_id=req.battery_id, predictions=predictions)
# ── Recommendations ──────────────────────────────────────────────────────────
@router.post("/recommend", response_model=RecommendationResponse)
async def recommend(req: RecommendationRequest):
"""Get operational recommendations for a battery based on physics-informed degradation model."""
import itertools
# **FIXED**: Use physics-based degradation rates instead of unreliable RUL prediction
# Empirical degradation rates from NASA PCoE data analysis
DEGRADATION_RATES = {
# Format: (temp_range, current_level): % SOH loss per cycle
"cold_light": 0.08, # 4Β°C, <=1A
"cold_moderate": 0.12, # 4Β°C, 1-2A
"cold_heavy": 0.18, # 4Β°C, >2A
"room_light": 0.15, # 24Β°C, <=1A
"room_moderate": 0.22, # 24Β°C, 1-2A
"room_heavy": 0.28, # 24Β°C, >2A
"warm_light": 0.35, # 43Β°C, <=1A
"warm_moderate": 0.48, # 43Β°C, 1-2A
"warm_heavy": 0.65, # 43Β°C, >2A
}
def get_degradation_rate(temp, current):
"""Return degradation rate (% SOH/cycle) given operating conditions."""
if temp <= 4:
if current <= 1.0:
return DEGRADATION_RATES["cold_light"]
elif current <= 2.0:
return DEGRADATION_RATES["cold_moderate"]
else:
return DEGRADATION_RATES["cold_heavy"]
elif temp <= 24:
if current <= 1.0:
return DEGRADATION_RATES["room_light"]
elif current <= 2.0:
return DEGRADATION_RATES["room_moderate"]
else:
return DEGRADATION_RATES["room_heavy"]
else:
if current <= 1.0:
return DEGRADATION_RATES["warm_light"]
elif current <= 2.0:
return DEGRADATION_RATES["warm_moderate"]
else:
return DEGRADATION_RATES["warm_heavy"]
def cycles_to_eol(current_soh, degradation_rate_pct_per_cycle, eol_threshold=70):
"""Calculate cycles until end-of-life."""
if degradation_rate_pct_per_cycle <= 0:
return 10000 # Unrealistic but prevents division by zero
soh_margin = current_soh - eol_threshold
if soh_margin <= 0:
return 0
return int(soh_margin / degradation_rate_pct_per_cycle)
# Generate recommendations for different operating conditions
temps = [4.0, 24.0, 43.0]
currents = [0.5, 1.0, 2.0, 4.0]
candidates = []
for t, c in itertools.product(temps, currents):
degradation = get_degradation_rate(t, c)
rul = cycles_to_eol(req.current_soh, degradation)
candidates.append((rul, t, c, degradation))
# Sort by RUL (cycles until EOL) in descending order
candidates.sort(reverse=True, key=lambda x: x[0])
top = candidates[:req.top_k]
# Calculate baseline (current operating conditions)
baseline_degradation = get_degradation_rate(req.ambient_temperature, 2.0)
baseline_rul = cycles_to_eol(req.current_soh, baseline_degradation)
recs = []
for rank, (rul, t, c, deg) in enumerate(top, 1):
improvement = rul - baseline_rul
improvement_pct = (improvement / baseline_rul * 100) if baseline_rul > 0 else 0.0
# Determine operational regime
if t <= 4:
temp_desc = "cold storage"
elif t <= 24:
temp_desc = "room temperature"
else:
temp_desc = "heated environment"
if c <= 1.0:
current_desc = "low current (trickle charge/light use)"
elif c <= 2.0:
current_desc = "moderate current (normal use)"
else:
current_desc = "high current (fast charging/heavy load)"
recs.append(SingleRecommendation(
rank=rank,
ambient_temperature=t,
discharge_current=c,
cutoff_voltage=2.5, # Standard cutoff
predicted_rul=int(rul),
rul_improvement=int(improvement),
rul_improvement_pct=round(improvement_pct, 1),
explanation=f"Rank #{rank}: Operate in {temp_desc} at {current_desc} β†’ ~{int(rul)} cycles until EOL (+{int(improvement)} cycles vs. baseline)",
))
return RecommendationResponse(
battery_id=req.battery_id,
current_soh=req.current_soh,
recommendations=recs,
)
# ── Model listing ────────────────────────────────────────────────────────────
@router.get("/models")
async def list_models():
"""List all registered models with metrics, version, and load status."""
return registry.list_models()
@router.get("/models/versions")
async def list_model_versions():
"""Return models grouped by semantic version family.
Groups:
* v1 β€” Classical ML models
* v2 β€” Deep sequence models (LSTM, Transformer)
* v2 patch β€” Ensemble / meta-models (v2.6)
"""
all_models = registry.list_models()
groups: dict[str, list] = {"v1": [], "v2": [], "v2_ensemble": [], "other": []}
for m in all_models:
ver = m.get("version", "")
if ver.startswith("1"):
groups["v1"].append(m)
elif ver.startswith("3") or "ensemble" in m.get("name", "").lower():
groups["v2_ensemble"].append(m)
elif ver.startswith("2"):
groups["v2"].append(m)
else:
groups["other"].append(m)
return {
"v1_classical": groups["v1"],
"v2_deep": groups["v2"],
"v2_ensemble": groups["v2_ensemble"],
"other": groups["other"],
"default_model": registry.default_model,
}
# ── v1-prefixed endpoints (legacy) ──────────────────────────────────────────
@v1_router.post("/predict", response_model=PredictResponse)
async def predict_v1(req: PredictRequest):
"""Predict SOH using v1 models (legacy, uses group-battery split models)."""
features = req.model_dump(exclude={"battery_id"})
features["voltage_range"] = features["peak_voltage"] - features["min_voltage"]
if abs(features["avg_temp"] - features["ambient_temperature"]) < 0.5:
features["avg_temp"] = features["ambient_temperature"] + 8.0
try:
result = registry_v1.predict(features)
except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc))
return PredictResponse(
battery_id=req.battery_id,
cycle_number=req.cycle_number,
soh_pct=result["soh_pct"],
rul_cycles=result["rul_cycles"],
degradation_state=result["degradation_state"],
confidence_lower=result["confidence_lower"],
confidence_upper=result["confidence_upper"],
model_used=result["model_used"],
model_version=result.get("model_version", "1.0.0"),
)
@v1_router.get("/models")
async def list_models_v1():
"""List all v1 registered models."""
return registry_v1.list_models()