File size: 3,984 Bytes
94dd2ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from datetime import datetime, timezone
import uuid

import pandas as pd
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.compiler import compiles

from src.main import app
from src.config.db import get_db
from src.models.ml import MLModel
from src.models.ml_inputs import MLInput
from src.models.ml_output import MLOutput


@compiles(JSONB, "sqlite")
def compile_jsonb_sqlite(type_, compiler, **kw):
    """Emit ``JSON`` as the column type whenever ``JSONB`` is compiled for SQLite.

    ``JSONB`` is a PostgreSQL dialect type. Registering this hook for the
    ``"sqlite"`` dialect lets the SQLite test database render it (presumably
    for JSONB columns on the models whose tables are created below). The type
    instance, the compiler and any extra keyword arguments are ignored.
    """
    sqlite_type_name = "JSON"
    return sqlite_type_name


def test_batch_predict_simple(tmp_path, monkeypatch):
    """POST one application to ``/predict/`` and check the response shape.

    The app is pointed at a throwaway SQLite database under ``tmp_path`` by
    overriding the ``get_db`` dependency. ``load_model`` and
    ``compute_features`` in ``src.controllers.predict_controller`` are
    monkeypatched:

    - ``load_model`` returns a stub whose ``predict_proba`` yields
      ``[0.3, 0.7]`` for every row.
    - ``compute_features`` passes the input frame through unchanged.

    The override, the session and the engine are always torn down in
    ``finally``. A failure partway through therefore cannot leak the override
    into other tests or leave the SQLite file open.
    """
    db_path = tmp_path / "testing.db"
    engine = create_engine(
        f"sqlite:///{db_path}",
        # The app may use the session from a worker thread while serving the
        # request; SQLite's default same-thread check would reject that.
        connect_args={"check_same_thread": False},
        future=True,
    )
    SQLSession = sessionmaker(
        bind=engine,
        autoflush=False,
        autocommit=False,
        future=True,
    )
    # Creating a Session does not open a connection yet. Doing it before the
    # `try` lets the `finally` below close it unconditionally.
    session = SQLSession()

    def get_db_override():
        # Every request gets the same session, so the row committed below is
        # visible to the endpoint. Closing it is handled by the test itself.
        yield session

    try:
        # Create only the three tables this test needs. MLModel goes first,
        # presumably because the other two reference it.
        MLModel.__table__.create(bind=engine)
        MLInput.__table__.create(bind=engine)
        MLOutput.__table__.create(bind=engine)

        app.dependency_overrides[get_db] = get_db_override

        # Server errors come back as HTTP 500 responses instead of raising,
        # so the status assertion below is what reports them.
        client = TestClient(app, raise_server_exceptions=False)

        created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
        model_row = MLModel(
            id=uuid.uuid4(),
            name="best_model",
            description="XGB v1",
            created_at=created,
            is_active=True,
        )
        session.add(model_row)
        session.commit()

        class FakeModel:
            # sklearn-style attribute; presumably read by the controller to
            # map probability columns to labels.
            classes_ = [0, 1]

            def predict_proba(self, X: pd.DataFrame):
                return [[0.3, 0.7] for _ in range(len(X))]

        import src.controllers.predict_controller as pc

        def fake_load_model(name: str):
            assert name == "best_model"
            return FakeModel()

        def fake_compute_features(df: pd.DataFrame) -> pd.DataFrame:
            return df

        # Patch the names inside the controller module, which is where the
        # endpoint looks them up; monkeypatch restores them after the test.
        monkeypatch.setattr(pc, "load_model", fake_load_model)
        monkeypatch.setattr(pc, "compute_features", fake_compute_features)

        payload = {
            "model_name": "best_model",
            "inputs": [
                {
                    "SK_ID_CURR": 100005,
                    "NAME_CONTRACT_TYPE": "Cash loans",
                    "CODE_GENDER": "M",
                    "FLAG_OWN_CAR": "N",
                    "FLAG_OWN_REALTY": "Y",
                    "CNT_CHILDREN": 0,
                    "AMT_INCOME_TOTAL": 99000.0,
                    "AMT_CREDIT": 222768.0,
                    "AMT_ANNUITY": 17370.0,
                    "AMT_GOODS_PRICE": 180000.0,
                    "NAME_TYPE_SUITE": "Unaccompanied",
                    "NAME_INCOME_TYPE": "Working",
                    "NAME_EDUCATION_TYPE": "Secondary / secondary special",
                    "NAME_FAMILY_STATUS": "Married",
                    "NAME_HOUSING_TYPE": "House / apartment",
                    "REGION_POPULATION_RELATIVE": 0.035792000000000004,
                    "DAYS_BIRTH": -18064,
                    "DAYS_EMPLOYED": -4469,
                    "DAYS_REGISTRATION": -9118,
                    "DAYS_ID_PUBLISH": -1623,
                    "OCCUPATION_TYPE": "Low-skill Laborers",
                    "WEEKDAY_APPR_PROCESS_START": "FRIDAY",
                    "HOUR_APPR_PROCESS_START": 9,
                    "ORGANIZATION_TYPE": "Self-employed",
                    "EXT_SOURCE_1": 0.5649902017969249,
                    "EXT_SOURCE_2": 0.2916555320093651,
                    "EXT_SOURCE_3": 0.4329616670974407,
                    "AMT_REQ_CREDIT_BUREAU_YEAR": 3.0,
                }
            ],
        }

        resp = client.post("/predict/", json=payload)
    finally:
        # Remove only the override this test installed, leaving any set by
        # other fixtures intact. Then release the session and the SQLite file.
        app.dependency_overrides.pop(get_db, None)
        session.close()
        engine.dispose()

    assert resp.status_code == 200, resp.text
    body = resp.json()

    assert body["model_name"] == "best_model"
    assert "results" in body
    assert len(body["results"]) == 1

    item = body["results"][0]
    assert item["label"] in ("solvable", "non_solvable")
    assert 0.0 <= item["proba"] <= 1.0