aiBatteryLifeCycle / tests /routers /test_predict_router.py
NeerajCodz's picture
Fix metrics loading and restructure pytest suite
536ba3d
from __future__ import annotations
import pytest
from fastapi import HTTPException
from api.routers import predict as predict_router
from api.schemas import BatchPredictRequest, PredictRequest, RecommendationRequest
class StubRegistry:
    """Deterministic in-memory stand-in for ``predict_router``'s model registry.

    Every method returns fixed, canned values. Router tests can therefore check
    request handling and response shaping without loading any real model.
    """

    # Version string reported by every canned prediction and listing.
    MODEL_VERSION = "3.0.0"

    def __init__(self):
        # Model name reported whenever the caller does not pick one explicitly.
        self.default_model = "random_forest"
        # Copy of the feature mapping passed to the most recent ``predict`` call.
        self.last_features = None

    def predict(self, features, model_name=None):
        """Record *features* and return a canned single-cycle prediction.

        Args:
            features: mapping of feature name to value; a copy is kept in
                ``self.last_features`` for later inspection by tests.
            model_name: optional model to report as used; falls back to
                ``self.default_model`` when falsy.

        Returns:
            A dict with fixed SOH/RUL/confidence values.

        Raises:
            RuntimeError: when ``features["cycle_number"] == -1``, so tests can
                drive the caller's error-handling path.
        """
        # Record before the failure check so tests can still see what was sent.
        self.last_features = dict(features)
        if features.get("cycle_number") == -1:
            raise RuntimeError("boom")
        return {
            "soh_pct": 91.2,
            "rul_cycles": 123.4,
            "degradation_state": "Healthy",
            "confidence_lower": 89.0,
            "confidence_upper": 93.0,
            "model_used": model_name or self.default_model,
            "model_version": self.MODEL_VERSION,
        }

    def predict_batch(self, battery_id, cycles):
        """Return one canned prediction per entry of *cycles*.

        Each row reuses the entry's ``cycle_number`` if present. Otherwise it
        falls back to the entry's 1-based position. *battery_id* is accepted
        only to match the call signature and is not used.
        """
        return [
            {
                "cycle_number": cycle.get("cycle_number", position),
                "soh_pct": 90.0,
                "rul_cycles": 100.0,
                "degradation_state": "Healthy",
                "model_used": self.default_model,
                "model_version": self.MODEL_VERSION,
            }
            for position, cycle in enumerate(cycles, start=1)
        ]

    def list_models(self):
        """Advertise a single entry: the default model at ``MODEL_VERSION``."""
        return [{"name": self.default_model, "version": self.MODEL_VERSION}]
@pytest.fixture
def stubbed_predict_registry(monkeypatch):
    """Install one shared ``StubRegistry`` as both router registry handles.

    The same instance is bound to ``predict_router.registry`` and
    ``predict_router.registry_v1``. It is returned so tests can inspect what
    the router forwarded, e.g. via ``last_features``.
    """
    stub = StubRegistry()
    for handle in ("registry", "registry_v1"):
        monkeypatch.setattr(predict_router, handle, stub)
    return stub
def test_predict_applies_avg_temp_offset(stubbed_predict_registry, run_async):
    """A request carrying ``avg_temp=24.0`` must reach the registry as 32.0.

    The stub's canned SOH and model name must also be passed back unchanged.
    """
    request = PredictRequest(
        battery_id="B0005",
        cycle_number=12,
        ambient_temperature=24.0,
        avg_temp=24.0,
    )
    response = run_async(predict_router.predict(request))

    assert response.soh_pct == 91.2
    assert response.model_used == "random_forest"
    # The router is expected to add an +8.0 offset before forwarding features.
    forwarded = stubbed_predict_registry.last_features
    assert forwarded["avg_temp"] == 32.0
def test_predict_batch_returns_predictions(stubbed_predict_registry, run_async):
    """Batch prediction echoes the battery id and yields one row per input cycle."""
    cycles = [{"cycle_number": number} for number in (1, 2)]
    batch = BatchPredictRequest(battery_id="B0006", cycles=cycles)

    result = run_async(predict_router.predict_batch(batch))

    assert result.battery_id == "B0006"
    assert len(result.predictions) == len(cycles)
def test_recommend_returns_top_k(stubbed_predict_registry, run_async):
    """``recommend`` returns exactly ``top_k`` entries, the first ranked 1."""
    wanted = 3
    request = RecommendationRequest(
        battery_id="B0007",
        current_cycle=100,
        current_soh=85.0,
        ambient_temperature=24.0,
        top_k=wanted,
    )

    recommendations = run_async(predict_router.recommend(request)).recommendations

    assert len(recommendations) == wanted
    assert recommendations[0].rank == 1
def test_list_models(stubbed_predict_registry, run_async):
    """The first listed model is the stub's ``random_forest`` entry."""
    models = run_async(predict_router.list_models())
    top = models[0]
    assert top["name"] == "random_forest"
def test_list_models_v1(stubbed_predict_registry, run_async):
    """The v1 listing endpoint also leads with the stub's ``random_forest`` entry."""
    models = run_async(predict_router.list_models_v1())
    top = models[0]
    assert top["name"] == "random_forest"
def test_list_model_versions_grouping(monkeypatch, run_async):
    """Four sample models must land one apiece in the four version buckets.

    The response must also report the registry's default model.
    """

    class GroupRegistry(StubRegistry):
        # One model aimed at each bucket; "mystery" should hit the catch-all.
        def list_models(self):
            return [
                {"name": "ridge", "version": "1.2.0"},
                {"name": "bilstm", "version": "2.0.0"},
                {"name": "best_ensemble", "version": "3.0.0"},
                {"name": "mystery", "version": "?"},
            ]

    monkeypatch.setattr(predict_router, "registry", GroupRegistry())
    grouped = run_async(predict_router.list_model_versions())

    for bucket in ("v1_classical", "v2_deep", "v2_ensemble", "other"):
        assert len(grouped[bucket]) == 1
    assert grouped["default_model"] == "random_forest"
def test_predict_v1(stubbed_predict_registry, run_async):
    """The v1 predict endpoint reports the stub's model version."""
    response = run_async(
        predict_router.predict_v1(PredictRequest(battery_id="B0005", cycle_number=12))
    )
    assert response.model_version == "3.0.0"
def test_predict_wraps_exceptions(monkeypatch, run_async):
    """A registry failure inside ``predict`` must surface as an HTTP 500.

    Both ``registry`` and ``registry_v1`` are replaced with the failing stub,
    matching the ``stubbed_predict_registry`` fixture. The test then cannot
    fall through to the real registry, whichever handle the router consults.
    """

    class ErrorRegistry(StubRegistry):
        def predict(self, features, model_name=None):
            raise RuntimeError("bad")

    failing = ErrorRegistry()
    monkeypatch.setattr(predict_router, "registry", failing)
    monkeypatch.setattr(predict_router, "registry_v1", failing)

    req = PredictRequest(battery_id="B0005", cycle_number=1)
    with pytest.raises(HTTPException) as ex:
        run_async(predict_router.predict(req))
    # Must sit outside the ``with``: code after the raising call inside it never runs.
    assert ex.value.status_code == 500