# Rossmann-Store-Sales/tests/test_api.py
# Commit 0269b4b by ymlin105: feat: add ci checks and model metadata versioning
import numpy as np
import pytest
from fastapi.testclient import TestClient
from src.serving import api
class DummyBooster:
    """Minimal booster stand-in used by ``DummyModel.get_booster``.

    Only ``predict`` is provided; it returns fixed values so the tests do not
    need a trained model.
    """

    def predict(self, matrix, pred_contribs=False):
        """Return a single row of ``0.1`` values as a 2-D array.

        Args:
            matrix: Any object exposing ``num_col()``.
            pred_contribs: Accepted for call compatibility; the value is
                ignored.

        Returns:
            An array of shape ``(1, matrix.num_col() + 1)`` filled with 0.1.
        """
        # One extra column beyond the feature count.
        # NOTE(review): presumably this mirrors the bias column XGBoost adds
        # to contribution output -- confirm against src.serving.api.
        width = matrix.num_col() + 1
        single_row = [0.1 for _ in range(width)]
        return np.array([single_row])
class DummyModel:
    """Model stand-in that always predicts ``log1p(100.0)``.

    It offers the two methods the tests rely on: ``predict`` and
    ``get_booster``.
    """

    def predict(self, X):
        """Return ``log1p(100.0)`` once per row of ``X``.

        Args:
            X: Any sized container; only ``len(X)`` is used.

        Returns:
            A 1-D float array of length ``len(X)``.
        """
        # Build the constant 100.0 vector first, then move it to log1p space.
        constant_sales = np.full(shape=len(X), fill_value=100.0)
        return np.log1p(constant_sales)

    def get_booster(self):
        """Return a fresh ``DummyBooster`` instance."""
        booster = DummyBooster()
        return booster
def test_predict_endpoint_uses_defaults(monkeypatch, tmp_path):
    """POST /predict with only ``Store`` and ``Date`` and check the response.

    The model, model version and store lookup on ``api`` are replaced with
    test doubles, so the assertions depend only on the endpoint's own logic.
    """
    # Point MODEL_PATH at a file that is never created under tmp_path.
    absent_model_file = tmp_path / "missing.json"
    monkeypatch.setenv("MODEL_PATH", str(absent_model_file))
    client = TestClient(api.app)

    store_record = {
        "StoreType": "a",
        "Assortment": "c",
        "CompetitionDistance": 1200.0,
        "Promo2": 0,
        "Promo2SinceWeek": 0,
        "Promo2SinceYear": 0,
    }
    monkeypatch.setattr(api, "model", DummyModel())
    monkeypatch.setattr(api, "model_version", "test-version")
    monkeypatch.setattr(api, "store_lookup", {1: store_record})

    request_body = {"Store": 1, "Date": "2015-07-31"}
    response = client.post("/predict", json=request_body)

    assert response.status_code == 200
    payload = response.json()
    assert payload["Store"] == 1
    assert payload["Status"] == "success"
    # DummyModel emits log1p(100); a value of 100 here presumably means the
    # API inverts that transform -- confirm against src.serving.api.
    assert payload["PredictedSales"] == pytest.approx(100.0)
    assert payload["ModelVersion"] == "test-version"
    assert len(payload["Forecast"]) == 1
    explanation = payload["Explanation"]
    assert len(explanation) == 5
    assert "score" in explanation[0]
def test_predict_endpoint_returns_404_for_unknown_store(monkeypatch):
    """POST /predict for a store absent from the lookup and expect HTTP 404."""
    client = TestClient(api.app)
    monkeypatch.setattr(api, "model", DummyModel())
    # An empty lookup guarantees the requested store id cannot be found.
    monkeypatch.setattr(api, "store_lookup", {})

    unknown_store_request = {"Store": 9999, "Date": "2015-07-31"}
    response = client.post("/predict", json=unknown_store_request)

    assert response.status_code == 404
def test_health_endpoint_is_stateless(monkeypatch):
    """GET /health and check the reported status and model version.

    Args:
        monkeypatch: pytest fixture used to override ``api.model_version``
            for the duration of this test only.
    """
    client = TestClient(api.app)
    # Patch through monkeypatch rather than assigning to the module directly:
    # the fixture restores the previous value at teardown, so the override
    # cannot leak into tests that run afterwards.
    monkeypatch.setattr(api, "model_version", "test-version")

    response = client.get("/health")

    assert response.status_code == 200
    # Parse the body once and reuse it for both field checks.
    body = response.json()
    assert body["status"] == "healthy"
    assert body["model_version"] == "test-version"