"""Integration test for the batch ``/predict/`` endpoint on a throwaway SQLite DB."""

from datetime import datetime, timezone
import uuid

import pandas as pd
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.compiler import compiles

from src.main import app
from src.config.db import get_db
from src.models.ml import MLModel
from src.models.ml_inputs import MLInput
from src.models.ml_output import MLOutput
import src.controllers.predict_controller as pc


@compiles(JSONB, "sqlite")
def compile_jsonb_sqlite(type_, compiler, **kw):
    """Render the PostgreSQL ``JSONB`` type as ``JSON`` on the SQLite dialect.

    Registered globally at import time so any table using ``JSONB`` can be
    created on the SQLite database used by the test below.
    """
    return "JSON"


def test_batch_predict_simple(tmp_path, monkeypatch):
    """POST a single applicant to ``/predict/`` and check the response shape.

    The model loader and feature pipeline in the predict controller are
    monkeypatched with fakes, and the ``get_db`` dependency is redirected to a
    session on a temporary SQLite file, so no real model or Postgres is needed.
    """
    db_path = tmp_path / "testing.db"
    engine = create_engine(
        f"sqlite:///{db_path}",
        # The app may use the session from a different thread than the one
        # that opened the connection; SQLite refuses that by default.
        connect_args={"check_same_thread": False},
        future=True,
    )
    SQLSession = sessionmaker(
        bind=engine,
        autoflush=False,
        autocommit=False,
        future=True,
    )
    MLModel.__table__.create(bind=engine)
    MLInput.__table__.create(bind=engine)
    MLOutput.__table__.create(bind=engine)

    session = SQLSession()
    try:
        def get_db_override():
            # Hand every request the same test session; it is closed in the
            # outer ``finally`` rather than per request.
            yield session

        # monkeypatch restores the previous ``get_db`` entry (or removes it)
        # at teardown, even on failure, and leaves other overrides untouched.
        monkeypatch.setitem(app.dependency_overrides, get_db, get_db_override)
        # Server-side errors come back as a 500 response instead of raising,
        # so the status assertion below can report ``resp.text``.
        client = TestClient(app, raise_server_exceptions=False)

        created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
        model_row = MLModel(
            id=uuid.uuid4(),
            name="best_model",
            description="XGB v1",
            created_at=created,
            is_active=True,
        )
        session.add(model_row)
        session.commit()

        class FakeModel:
            """Stand-in classifier returning the same probabilities for every row."""

            classes_ = [0, 1]

            def predict_proba(self, X: pd.DataFrame):
                return [[0.3, 0.7] for _ in range(len(X))]

        def fake_load_model(name: str):
            assert name == "best_model"
            return FakeModel()

        def fake_compute_features(df: pd.DataFrame) -> pd.DataFrame:
            # Identity: skip feature engineering and pass raw inputs through.
            return df

        monkeypatch.setattr(pc, "load_model", fake_load_model)
        monkeypatch.setattr(pc, "compute_features", fake_compute_features)

        payload = {
            "model_name": "best_model",
            "inputs": [
                {
                    "SK_ID_CURR": 100005,
                    "NAME_CONTRACT_TYPE": "Cash loans",
                    "CODE_GENDER": "M",
                    "FLAG_OWN_CAR": "N",
                    "FLAG_OWN_REALTY": "Y",
                    "CNT_CHILDREN": 0,
                    "AMT_INCOME_TOTAL": 99000.0,
                    "AMT_CREDIT": 222768.0,
                    "AMT_ANNUITY": 17370.0,
                    "AMT_GOODS_PRICE": 180000.0,
                    "NAME_TYPE_SUITE": "Unaccompanied",
                    "NAME_INCOME_TYPE": "Working",
                    "NAME_EDUCATION_TYPE": "Secondary / secondary special",
                    "NAME_FAMILY_STATUS": "Married",
                    "NAME_HOUSING_TYPE": "House / apartment",
                    "REGION_POPULATION_RELATIVE": 0.035792000000000004,
                    "DAYS_BIRTH": -18064,
                    "DAYS_EMPLOYED": -4469,
                    "DAYS_REGISTRATION": -9118,
                    "DAYS_ID_PUBLISH": -1623,
                    "OCCUPATION_TYPE": "Low-skill Laborers",
                    "WEEKDAY_APPR_PROCESS_START": "FRIDAY",
                    "HOUR_APPR_PROCESS_START": 9,
                    "ORGANIZATION_TYPE": "Self-employed",
                    "EXT_SOURCE_1": 0.5649902017969249,
                    "EXT_SOURCE_2": 0.2916555320093651,
                    "EXT_SOURCE_3": 0.4329616670974407,
                    "AMT_REQ_CREDIT_BUREAU_YEAR": 3.0,
                }
            ],
        }

        resp = client.post("/predict/", json=payload)
    finally:
        # Always release the session and the pooled SQLite connections,
        # even if setup or the request itself raised.
        session.close()
        engine.dispose()

    assert resp.status_code == 200, resp.text
    body = resp.json()
    assert body["model_name"] == "best_model"
    assert "results" in body
    assert len(body["results"]) == 1
    item = body["results"][0]
    assert item["label"] in ("solvable", "non_solvable")
    assert 0.0 <= item["proba"] <= 1.0