Spaces:
Sleeping
Sleeping
File size: 3,984 Bytes
94dd2ae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 | from datetime import datetime, timezone
import uuid
import pandas as pd
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.compiler import compiles
from src.main import app
from src.config.db import get_db
from src.models.ml import MLModel
from src.models.ml_inputs import MLInput
from src.models.ml_output import MLOutput
@compiles(JSONB, "sqlite")
def compile_jsonb_sqlite(type_, compiler, **kw):
    """Render the PostgreSQL ``JSONB`` type as ``JSON`` for the SQLite dialect.

    Registered via ``compiles`` for the ``"sqlite"`` dialect so that tables
    declaring ``JSONB`` columns can be created on the SQLite database used
    by the tests in this module. The arguments are accepted to satisfy the
    hook signature but are not used.
    """
    sqlite_type_name = "JSON"
    return sqlite_type_name
def test_batch_predict_simple(tmp_path, monkeypatch):
    """POST a single input to ``/predict/`` and check the response shape.

    The test runs against a throwaway SQLite file under ``tmp_path``,
    overrides the app's ``get_db`` dependency with a session bound to that
    file, and replaces ``load_model`` / ``compute_features`` on
    ``src.controllers.predict_controller`` with in-memory fakes so no real
    model artefact is needed.

    Teardown (dependency override removal, session close, engine disposal)
    runs in a ``finally`` block so that a failure during setup or during the
    request cannot leak the override into other tests sharing ``app``.
    """
    db_path = tmp_path / "testing.db"
    engine = create_engine(
        f"sqlite:///{db_path}",
        # The TestClient may serve the request from another thread.
        connect_args={"check_same_thread": False},
        future=True,
    )
    SQLSession = sessionmaker(
        bind=engine,
        autoflush=False,
        autocommit=False,
        future=True,
    )
    session = SQLSession()

    def get_db_override():
        # Hand the app the test-owned session; its lifetime is managed by
        # the ``finally`` block below, not by the dependency.
        yield session

    class FakeModel:
        """Stand-in classifier returning a fixed probability per row."""

        classes_ = [0, 1]

        def predict_proba(self, X: pd.DataFrame):
            return [[0.3, 0.7] for _ in range(len(X))]

    def fake_load_model(name: str):
        assert name == "best_model"
        return FakeModel()

    def fake_compute_features(df: pd.DataFrame) -> pd.DataFrame:
        # Identity: feed the raw payload columns straight to the fake model.
        return df

    payload = {
        "model_name": "best_model",
        "inputs": [
            {
                "SK_ID_CURR": 100005,
                "NAME_CONTRACT_TYPE": "Cash loans",
                "CODE_GENDER": "M",
                "FLAG_OWN_CAR": "N",
                "FLAG_OWN_REALTY": "Y",
                "CNT_CHILDREN": 0,
                "AMT_INCOME_TOTAL": 99000.0,
                "AMT_CREDIT": 222768.0,
                "AMT_ANNUITY": 17370.0,
                "AMT_GOODS_PRICE": 180000.0,
                "NAME_TYPE_SUITE": "Unaccompanied",
                "NAME_INCOME_TYPE": "Working",
                "NAME_EDUCATION_TYPE": "Secondary / secondary special",
                "NAME_FAMILY_STATUS": "Married",
                "NAME_HOUSING_TYPE": "House / apartment",
                "REGION_POPULATION_RELATIVE": 0.035792000000000004,
                "DAYS_BIRTH": -18064,
                "DAYS_EMPLOYED": -4469,
                "DAYS_REGISTRATION": -9118,
                "DAYS_ID_PUBLISH": -1623,
                "OCCUPATION_TYPE": "Low-skill Laborers",
                "WEEKDAY_APPR_PROCESS_START": "FRIDAY",
                "HOUR_APPR_PROCESS_START": 9,
                "ORGANIZATION_TYPE": "Self-employed",
                "EXT_SOURCE_1": 0.5649902017969249,
                "EXT_SOURCE_2": 0.2916555320093651,
                "EXT_SOURCE_3": 0.4329616670974407,
                "AMT_REQ_CREDIT_BUREAU_YEAR": 3.0,
            }
        ],
    }

    try:
        MLModel.__table__.create(bind=engine)
        MLInput.__table__.create(bind=engine)
        MLOutput.__table__.create(bind=engine)

        app.dependency_overrides[get_db] = get_db_override
        # raise_server_exceptions=False: server errors come back as a
        # response, so the status assertion below can report resp.text.
        client = TestClient(app, raise_server_exceptions=False)

        created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
        model_row = MLModel(
            id=uuid.uuid4(),
            name="best_model",
            description="XGB v1",
            created_at=created,
            is_active=True,
        )
        session.add(model_row)
        session.commit()

        import src.controllers.predict_controller as pc

        monkeypatch.setattr(pc, "load_model", fake_load_model)
        monkeypatch.setattr(pc, "compute_features", fake_compute_features)

        resp = client.post("/predict/", json=payload)
    finally:
        # Remove only the override this test installed, then release the
        # session and the engine's pooled connections to the temp database.
        app.dependency_overrides.pop(get_db, None)
        session.close()
        engine.dispose()

    assert resp.status_code == 200, resp.text
    body = resp.json()
    assert body["model_name"] == "best_model"
    assert "results" in body
    assert len(body["results"]) == 1
    item = body["results"][0]
    assert item["label"] in ("solvable", "non_solvable")
    assert 0.0 <= item["proba"] <= 1.0
|