Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import pytest | |
| from fastapi import HTTPException | |
| from api.routers import predict as predict_router | |
| from api.schemas import BatchPredictRequest, PredictRequest, RecommendationRequest | |
class StubRegistry:
    """In-memory stand-in for the model registry used by the predict router.

    Every method returns canned data. ``predict`` also keeps a shallow copy of
    the last feature mapping it was given (``last_features``) so tests can
    check what the router forwarded.
    """

    def __init__(self):
        self.default_model = "random_forest"
        # Shallow copy of the features from the most recent ``predict`` call;
        # stays None until ``predict`` is called.
        self.last_features = None

    def predict(self, features, model_name=None):
        """Return a fixed single-cycle prediction.

        The copy of ``features`` is recorded before anything else, so it is
        kept even on the error path. A ``cycle_number`` of -1 raises
        ``RuntimeError("boom")``. ``model_used`` echoes ``model_name`` when it
        is truthy, otherwise ``"random_forest"``.
        """
        self.last_features = dict(features)
        if features.get("cycle_number") == -1:
            raise RuntimeError("boom")
        chosen_model = model_name if model_name else "random_forest"
        return dict(
            soh_pct=91.2,
            rul_cycles=123.4,
            degradation_state="Healthy",
            confidence_lower=89.0,
            confidence_upper=93.0,
            model_used=chosen_model,
            model_version="3.0.0",
        )

    def predict_batch(self, battery_id, cycles):
        """Return one canned prediction row per entry in ``cycles``.

        Each row's ``cycle_number`` comes from the input mapping when present.
        Otherwise it falls back to the entry's 1-based position.
        ``battery_id`` is accepted but unused.
        """
        rows = []
        for position, cycle in enumerate(cycles, start=1):
            rows.append(
                {
                    "cycle_number": cycle.get("cycle_number", position),
                    "soh_pct": 90.0,
                    "rul_cycles": 100.0,
                    "degradation_state": "Healthy",
                    "model_used": "random_forest",
                    "model_version": "3.0.0",
                }
            )
        return rows

    def list_models(self):
        """Advertise a single model, random_forest v3.0.0 (fresh list per call)."""
        only_model = {"name": "random_forest", "version": "3.0.0"}
        return [only_model]
@pytest.fixture
def stubbed_predict_registry(monkeypatch):
    """Replace the predict router's registries with one shared ``StubRegistry``.

    Both ``predict_router.registry`` and ``predict_router.registry_v1`` are
    patched to the same instance. Whichever one the router consults, the test
    can inspect the returned stub (e.g. ``last_features``). ``monkeypatch``
    undoes both patches when the test finishes.

    Returns:
        The ``StubRegistry`` instance installed on the router.
    """
    # The decorator is required: the tests in this module request this by
    # name as a fixture argument, and pytest only resolves decorated fixtures.
    reg = StubRegistry()
    monkeypatch.setattr(predict_router, "registry", reg)
    monkeypatch.setattr(predict_router, "registry_v1", reg)
    return reg
def test_predict_applies_avg_temp_offset(stubbed_predict_registry, run_async):
    """``predict`` should forward an adjusted ``avg_temp`` to the registry.

    The request carries ``avg_temp=24.0``. The test expects the features
    reaching the registry to hold ``32.0``, alongside the stub's canned SOH
    and default model name.
    """
    payload = {
        "battery_id": "B0005",
        "cycle_number": 12,
        "ambient_temperature": 24.0,
        "avg_temp": 24.0,
    }
    # NOTE(review): run_async presumably comes from conftest and drives the
    # coroutine to completion — confirm.
    response = run_async(predict_router.predict(PredictRequest(**payload)))

    assert response.soh_pct == 91.2
    assert response.model_used == "random_forest"
    forwarded = stubbed_predict_registry.last_features
    assert forwarded["avg_temp"] == 32.0
def test_predict_batch_returns_predictions(stubbed_predict_registry, run_async):
    """The batch endpoint echoes the battery id and returns one row per cycle."""
    cycle_inputs = [{"cycle_number": number} for number in (1, 2)]
    batch_request = BatchPredictRequest(battery_id="B0006", cycles=cycle_inputs)

    result = run_async(predict_router.predict_batch(batch_request))

    assert result.battery_id == "B0006"
    assert len(result.predictions) == len(cycle_inputs)
def test_recommend_returns_top_k(stubbed_predict_registry, run_async):
    """``recommend`` returns exactly ``top_k`` items, the first ranked 1."""
    wanted = 3
    request = RecommendationRequest(
        battery_id="B0007",
        current_cycle=100,
        current_soh=85.0,
        ambient_temperature=24.0,
        top_k=wanted,
    )

    recommendations = run_async(predict_router.recommend(request)).recommendations

    assert len(recommendations) == wanted
    assert recommendations[0].rank == 1
def test_list_models(stubbed_predict_registry, run_async):
    """``list_models`` surfaces the stub registry's catalogue."""
    catalogue = run_async(predict_router.list_models())
    first_entry = catalogue[0]
    assert first_entry["name"] == "random_forest"
def test_list_models_v1(stubbed_predict_registry, run_async):
    """``list_models_v1`` surfaces the stub registry's catalogue."""
    catalogue = run_async(predict_router.list_models_v1())
    first_entry = catalogue[0]
    assert first_entry["name"] == "random_forest"
def test_list_model_versions_grouping(monkeypatch, run_async):
    """``list_model_versions`` buckets models by family and reports the default.

    Four models are advertised, and each bucket is expected to hold exactly
    one. The mapping is presumably ridge→v1_classical, bilstm→v2_deep,
    best_ensemble→v2_ensemble and mystery→other, but the test only checks
    counts. Only ``predict_router.registry`` is patched here.
    """

    class GroupRegistry(StubRegistry):
        def list_models(self):
            advertised = (
                ("ridge", "1.2.0"),
                ("bilstm", "2.0.0"),
                ("best_ensemble", "3.0.0"),
                ("mystery", "?"),
            )
            return [{"name": name, "version": version} for name, version in advertised]

    monkeypatch.setattr(predict_router, "registry", GroupRegistry())
    grouped = run_async(predict_router.list_model_versions())

    for bucket in ("v1_classical", "v2_deep", "v2_ensemble", "other"):
        assert len(grouped[bucket]) == 1, bucket
    # default_model is inherited from StubRegistry.__init__.
    assert grouped["default_model"] == "random_forest"
def test_predict_v1(stubbed_predict_registry, run_async):
    """The v1 predict endpoint reports the stub's model version."""
    outcome = run_async(
        predict_router.predict_v1(PredictRequest(battery_id="B0005", cycle_number=12))
    )
    assert outcome.model_version == "3.0.0"
def test_predict_wraps_exceptions(monkeypatch, run_async):
    """A registry failure inside ``predict`` must surface as HTTP 500."""

    class FailingRegistry(StubRegistry):
        def predict(self, features, model_name=None):
            raise RuntimeError("bad")

    monkeypatch.setattr(predict_router, "registry", FailingRegistry())
    request = PredictRequest(battery_id="B0005", cycle_number=1)

    with pytest.raises(HTTPException) as caught:
        run_async(predict_router.predict(request))

    # Checked after the `with` block: anything after the raising call inside
    # it would never execute.
    assert caught.value.status_code == 500