Spaces:
Running
Running
| """ | |
| api.routers.predict | |
| =================== | |
| Prediction & recommendation endpoints. | |
| """ | |
| from __future__ import annotations | |
| from fastapi import APIRouter, HTTPException | |
| from api.model_registry import registry, registry_v1, classify_degradation, soh_to_color | |
| from api.schemas import ( | |
| PredictRequest, PredictResponse, | |
| BatchPredictRequest, BatchPredictResponse, | |
| RecommendationRequest, RecommendationResponse, SingleRecommendation, | |
| ) | |
# Primary (unversioned) prediction routes, mounted under /api.
router = APIRouter(tags=["prediction"], prefix="/api")
# Legacy routes mounted under /api/v1, kept for backward compatibility.
v1_router = APIRouter(tags=["v1-prediction"], prefix="/api/v1")
# ── Single prediction ─────────────────────────────────────────────────────
async def predict(req: PredictRequest):
    """Predict SOH for a single cycle using the default model registry.

    The request's fields (minus ``battery_id``) are used as model features,
    plus a derived ``voltage_range`` (``peak_voltage - min_voltage``).

    Args:
        req: Single-cycle prediction request.

    Returns:
        PredictResponse with SOH, RUL, degradation state, confidence bounds,
        the model used and, when the registry reports it, the model version.

    Raises:
        HTTPException: 500 if the registry raises while predicting.
    """
    features = req.model_dump(exclude={"battery_id"})
    features["voltage_range"] = features["peak_voltage"] - features["min_voltage"]

    # An avg_temp within 0.5 °C of ambient is replaced by ambient + 8 °C: per the
    # NASA data, cell temperature sits ~8-10 °C above ambient under load, so a
    # value equal to ambient is treated as not actually measured.
    if abs(features["avg_temp"] - features["ambient_temperature"]) < 0.5:
        features["avg_temp"] = features["ambient_temperature"] + 8.0

    try:
        result = registry.predict(features)
    except Exception as exc:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    # Forward the model version when the registry reports one (the batch and
    # v1 endpoints already include it); otherwise keep the schema default.
    optional_fields = {}
    if result.get("model_version") is not None:
        optional_fields["model_version"] = result["model_version"]

    return PredictResponse(
        battery_id=req.battery_id,
        cycle_number=req.cycle_number,
        soh_pct=result["soh_pct"],
        rul_cycles=result["rul_cycles"],
        degradation_state=result["degradation_state"],
        confidence_lower=result["confidence_lower"],
        confidence_upper=result["confidence_upper"],
        model_used=result["model_used"],
        **optional_fields,
    )
# ── Batch prediction ──────────────────────────────────────────────────────
async def predict_batch(req: BatchPredictRequest):
    """Predict SOH for multiple cycles of one battery.

    Args:
        req: Battery id plus the per-cycle feature payloads.

    Returns:
        BatchPredictResponse with one PredictResponse per result returned by
        ``registry.predict_batch``.

    Raises:
        HTTPException: 500 if the registry raises, matching the single-cycle
            endpoints instead of surfacing an unhandled server error.
    """
    # NOTE(review): unlike the single-cycle endpoint, no voltage_range /
    # avg_temp preprocessing happens here — presumably done inside
    # registry.predict_batch; confirm.
    try:
        results = registry.predict_batch(req.battery_id, req.cycles)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    predictions = [
        PredictResponse(
            battery_id=req.battery_id,
            cycle_number=r["cycle_number"],
            soh_pct=r["soh_pct"],
            rul_cycles=r["rul_cycles"],
            degradation_state=r["degradation_state"],
            confidence_lower=r.get("confidence_lower"),
            confidence_upper=r.get("confidence_upper"),
            model_used=r["model_used"],
            model_version=r.get("model_version"),
        )
        for r in results
    ]
    return BatchPredictResponse(battery_id=req.battery_id, predictions=predictions)
# ── Recommendations ───────────────────────────────────────────────────────
async def recommend(req: RecommendationRequest):
    """Recommend operating conditions that maximise remaining cycle life.

    Instead of the learned RUL model, this uses a lookup table of empirical
    SOH-loss rates (% per cycle) by temperature band and current level.
    Every combination of the candidate temperatures (4, 24, 43 °C) and
    currents (0.5, 1, 2, 4 A) is scored by how many cycles it takes to fall
    from ``req.current_soh`` to the 70 % end-of-life threshold. The
    ``req.top_k`` longest-lived options are returned, each compared with a
    baseline of ``req.ambient_temperature`` at a fixed 2.0 A.

    Args:
        req: Recommendation request (battery id, current SOH in %, ambient
            temperature and the number of recommendations ``top_k``).

    Returns:
        RecommendationResponse with up to ``top_k`` ranked recommendations.
    """
    candidates = []
    # Same order as itertools.product(temps, currents): temperature outermost.
    for temp in _CANDIDATE_TEMPS:
        for current in _CANDIDATE_CURRENTS:
            rate = _degradation_rate(temp, current)
            candidates.append((_cycles_to_eol(req.current_soh, rate), temp, current, rate))
    # Longest projected life first; the sort is stable, so ties keep grid order.
    candidates.sort(reverse=True, key=lambda cand: cand[0])
    top = candidates[:req.top_k]

    # NOTE(review): the baseline current is hard-coded to 2.0 A — confirm
    # whether the request should carry the battery's actual operating current.
    baseline_rul = _cycles_to_eol(
        req.current_soh, _degradation_rate(req.ambient_temperature, 2.0)
    )

    recs = []
    for rank, (rul, temp, current, _rate) in enumerate(top, 1):
        improvement = rul - baseline_rul
        improvement_pct = (improvement / baseline_rul * 100) if baseline_rul > 0 else 0.0
        recs.append(SingleRecommendation(
            rank=rank,
            ambient_temperature=temp,
            discharge_current=current,
            cutoff_voltage=2.5,  # standard cutoff voltage
            predicted_rul=int(rul),
            rul_improvement=int(improvement),
            rul_improvement_pct=round(improvement_pct, 1),
            explanation=_explain(rank, temp, current, rul, improvement),
        ))

    return RecommendationResponse(
        battery_id=req.battery_id,
        current_soh=req.current_soh,
        recommendations=recs,
    )


# Empirical SOH loss (% of SOH per cycle) keyed by (temperature band, current
# level); source per the original note: NASA PCoE data analysis. Reference
# temperatures: cold 4 °C, room 24 °C, warm 43 °C.
_DEGRADATION_RATES = {
    ("cold", "light"): 0.08,
    ("cold", "moderate"): 0.12,
    ("cold", "heavy"): 0.18,
    ("room", "light"): 0.15,
    ("room", "moderate"): 0.22,
    ("room", "heavy"): 0.28,
    ("warm", "light"): 0.35,
    ("warm", "moderate"): 0.48,
    ("warm", "heavy"): 0.65,
}

# Operating points evaluated by ``recommend`` (°C and A respectively).
_CANDIDATE_TEMPS = (4.0, 24.0, 43.0)
_CANDIDATE_CURRENTS = (0.5, 1.0, 2.0, 4.0)

# Phrases used in the recommendation explanation text.
_TEMP_BAND_TEXT = {
    "cold": "cold storage",
    "room": "room temperature",
    "warm": "heated environment",
}
_CURRENT_LEVEL_TEXT = {
    "light": "low current (trickle charge/light use)",
    "moderate": "moderate current (normal use)",
    "heavy": "high current (fast charging/heavy load)",
}


def _temp_band(temp):
    """Map a temperature in °C to "cold" (<= 4), "room" (<= 24) or "warm"."""
    if temp <= 4:
        return "cold"
    if temp <= 24:
        return "room"
    return "warm"


def _current_level(current):
    """Map a current in A to "light" (<= 1), "moderate" (<= 2) or "heavy"."""
    if current <= 1.0:
        return "light"
    if current <= 2.0:
        return "moderate"
    return "heavy"


def _degradation_rate(temp, current):
    """Return the tabulated SOH loss (% per cycle) for the given conditions."""
    return _DEGRADATION_RATES[(_temp_band(temp), _current_level(current))]


def _cycles_to_eol(current_soh, degradation_rate_pct_per_cycle, eol_threshold=70):
    """Return whole cycles until SOH falls to ``eol_threshold`` percent.

    Returns 0 when SOH is already at or below the threshold, and a
    10000-cycle sentinel for a non-positive rate (avoids division by zero).
    """
    if degradation_rate_pct_per_cycle <= 0:
        return 10000
    soh_margin = current_soh - eol_threshold
    if soh_margin <= 0:
        return 0
    return int(soh_margin / degradation_rate_pct_per_cycle)


def _explain(rank, temp, current, rul, improvement):
    """Build the human-readable explanation for one ranked recommendation."""
    temp_desc = _TEMP_BAND_TEXT[_temp_band(temp)]
    current_desc = _CURRENT_LEVEL_TEXT[_current_level(current)]
    # ":+d" emits the sign itself, so options worse than the baseline read
    # "-N cycles" rather than the previous "+-N cycles".
    return (
        f"Rank #{rank}: Operate in {temp_desc} at {current_desc} → "
        f"~{int(rul)} cycles until EOL ({int(improvement):+d} cycles vs. baseline)"
    )
# ── Model listing ─────────────────────────────────────────────────────────
async def list_models():
    """Return the default registry's model listing unchanged.

    Per the registry, each entry carries the model's metrics, version and
    load status.
    """
    models = registry.list_models()
    return models
async def list_model_versions():
    """Return registered models grouped by semantic-version family.

    Groups (rules applied in this order):

    * ``v1_classical`` – version starting with "1" (classical ML models)
    * ``v2_ensemble`` – version starting with "3", or a name containing
      "ensemble" (ensemble / meta-models)
    * ``v2_deep`` – remaining versions starting with "2" (LSTM, Transformer)
    * ``other`` – everything else, including models without a version

    Returns:
        dict with the four groups above plus ``default_model``.
    """
    groups = _group_models_by_version(registry.list_models())
    return {
        "v1_classical": groups["v1"],
        "v2_deep": groups["v2"],
        "v2_ensemble": groups["v2_ensemble"],
        "other": groups["other"],
        "default_model": registry.default_model,
    }


def _group_models_by_version(models):
    """Bucket model-info dicts into "v1", "v2", "v2_ensemble" and "other" lists."""
    groups: dict[str, list] = {"v1": [], "v2": [], "v2_ensemble": [], "other": []}
    for model in models:
        # ``or ""`` plus str() lets a missing, None or non-string version/name
        # fall through the prefix checks instead of raising AttributeError.
        version = str(model.get("version") or "")
        name = str(model.get("name") or "").lower()
        if version.startswith("1"):
            groups["v1"].append(model)
        elif version.startswith("3") or "ensemble" in name:
            groups["v2_ensemble"].append(model)
        elif version.startswith("2"):
            groups["v2"].append(model)
        else:
            groups["other"].append(model)
    return groups
# ── v1-prefixed endpoints (legacy) ────────────────────────────────────────
async def predict_v1(req: PredictRequest):
    """Predict SOH using v1 models (legacy, uses group-battery split models).

    The request's fields (minus ``battery_id``) are used as features, plus a
    derived ``voltage_range`` (``peak_voltage - min_voltage``), and sent to
    ``registry_v1``.

    Args:
        req: Single-cycle prediction request.

    Returns:
        PredictResponse; ``model_version`` falls back to "1.0.0" when the v1
        registry does not report one.

    Raises:
        HTTPException: 500 if the v1 registry raises while predicting.
    """
    features = req.model_dump(exclude={"battery_id"})
    features["voltage_range"] = features["peak_voltage"] - features["min_voltage"]

    # An avg_temp within 0.5 °C of ambient is replaced by ambient + 8 °C, since
    # cell temperature in the NASA data runs ~8-10 °C above ambient under load.
    if abs(features["avg_temp"] - features["ambient_temperature"]) < 0.5:
        features["avg_temp"] = features["ambient_temperature"] + 8.0

    try:
        result = registry_v1.predict(features)
    except Exception as exc:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return PredictResponse(
        battery_id=req.battery_id,
        cycle_number=req.cycle_number,
        soh_pct=result["soh_pct"],
        rul_cycles=result["rul_cycles"],
        degradation_state=result["degradation_state"],
        confidence_lower=result["confidence_lower"],
        confidence_upper=result["confidence_upper"],
        model_used=result["model_used"],
        model_version=result.get("model_version", "1.0.0"),
    )
async def list_models_v1():
    """Return the legacy v1 registry's model listing unchanged."""
    v1_models = registry_v1.list_models()
    return v1_models