File size: 5,140 Bytes
b44d852
 
 
f7f1df4
b44d852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7f1df4
 
b44d852
f7f1df4
b44d852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b83e25
b44d852
3b83e25
b44d852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7f1df4
 
 
 
b44d852
 
 
 
 
9a76208
 
f7f1df4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a76208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import pytest
from fastapi.testclient import TestClient

from app.main import ALLOW_OUT_OF_RANGE_COLUMNS, app


@pytest.fixture(scope="session")
def client():
    """Provide one ``TestClient`` shared by every test in the session.

    The client is used as a context manager so the application's startup and
    shutdown handlers run around the session; the tests rely on this to find
    ``client.app.state.preprocessor`` populated.
    """
    with TestClient(app) as api_client:
        yield api_client


def _build_payload(preprocessor):
    data = {}
    for col in preprocessor.required_input_columns:
        if col in preprocessor.numeric_medians:
            data[col] = preprocessor.numeric_medians[col]
        elif col in preprocessor.categorical_columns:
            data[col] = "Unknown"
        else:
            data[col] = 0
    data["SK_ID_CURR"] = int(data.get("SK_ID_CURR", 100001))
    return {"data": data}


def _pick_required_column(preprocessor, exclude=None):
    exclude = set(exclude or [])
    for col in preprocessor.required_input_columns:
        if col not in exclude:
            return col
    raise AssertionError("No required column available for test.")


def _pick_numeric_range(preprocessor, exclude=None):
    exclude = set(exclude or [])
    for col, bounds in preprocessor.numeric_ranges.items():
        if col in preprocessor.numeric_required_columns and col not in exclude:
            return col, bounds
    raise AssertionError("No numeric range available for test.")


def _pick_numeric_required(preprocessor):
    for col in preprocessor.numeric_required_columns:
        if col != "SK_ID_CURR":
            return col
    raise AssertionError("No numeric required column available for test.")


def test_health(client):
    """``GET /health`` returns 200 with the body ``{"status": "ok"}``."""
    expected_body = {"status": "ok"}
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json() == expected_body


def test_features(client):
    """``GET /features`` describes the model inputs.

    The body must expose the input feature list (including ``SK_ID_CURR``
    plus at least one other entry), the required input features, and the
    feature-selection method.
    """
    response = client.get("/features")
    assert response.status_code == 200
    body = response.json()
    for key in ("input_features", "required_input_features", "feature_selection_method"):
        assert key in body
    features = body["input_features"]
    assert "SK_ID_CURR" in features
    assert len(features) >= 2


def test_predict(client):
    """A placeholder payload yields exactly one well-formed prediction.

    The single result must carry ``sk_id_curr``, ``prediction`` and a
    ``probability`` within ``[0, 1]``.
    """
    request_body = _build_payload(client.app.state.preprocessor)
    response = client.post("/predict", json=request_body)
    assert response.status_code == 200
    body = response.json()
    assert "predictions" in body
    predictions = body["predictions"]
    assert len(predictions) == 1
    first = predictions[0]
    for field in ("sk_id_curr", "prediction", "probability"):
        assert field in first
    assert 0.0 <= first["probability"] <= 1.0


def test_predict_missing_required_field(client):
    """Dropping a required non-ID column makes ``/predict`` answer 422.

    The error detail must carry the app's "missing columns" message.
    """
    preprocessor = client.app.state.preprocessor
    request_body = _build_payload(preprocessor)
    dropped = _pick_required_column(preprocessor, exclude={"SK_ID_CURR"})
    request_body["data"].pop(dropped, None)

    response = client.post("/predict", json=request_body)

    assert response.status_code == 422
    error_detail = response.json().get("detail", {})
    assert error_detail.get("message") == "Missing required input columns."


def test_predict_invalid_type(client):
    """A non-numeric string in a numeric column makes ``/predict`` answer 422."""
    preprocessor = client.app.state.preprocessor
    request_body = _build_payload(preprocessor)
    target = _pick_numeric_required(preprocessor)
    request_body["data"][target] = "not_a_number"

    response = client.post("/predict", json=request_body)

    assert response.status_code == 422
    error_detail = response.json().get("detail", {})
    assert error_detail.get("message") == "Invalid numeric values provided."


def test_predict_out_of_range(client):
    """A value above a column's recorded maximum makes ``/predict`` answer 422.

    Columns in ``ALLOW_OUT_OF_RANGE_COLUMNS`` are excluded from the choice,
    since the app tolerates them (see the companion "allowed" test).
    """
    preprocessor = client.app.state.preprocessor
    request_body = _build_payload(preprocessor)
    target, (_, upper) = _pick_numeric_range(
        preprocessor,
        exclude=ALLOW_OUT_OF_RANGE_COLUMNS,
    )
    request_body["data"][target] = upper + 1

    response = client.post("/predict", json=request_body)

    assert response.status_code == 422
    error_detail = response.json().get("detail", {})
    assert error_detail.get("message") == "Input contains values outside expected ranges."


def test_predict_out_of_range_allowed_ext_source(client):
    """Columns listed in ``ALLOW_OUT_OF_RANGE_COLUMNS`` may exceed their max.

    Sets one such column to one unit above the maximum recorded in
    ``preprocessor.numeric_ranges`` and expects ``/predict`` to answer 200.
    This contrasts with ``test_predict_out_of_range``, which expects 422.
    The test is skipped when none of the allowed columns has a recorded range.
    """
    preprocessor = client.app.state.preprocessor
    payload = _build_payload(preprocessor)
    # sorted(): if ALLOW_OUT_OF_RANGE_COLUMNS is a set, its iteration order
    # depends on string-hash randomisation. Taking element 0 unsorted could
    # then test a different column on each run.
    allowed = sorted(
        col
        for col in ALLOW_OUT_OF_RANGE_COLUMNS
        if col in preprocessor.numeric_ranges
    )
    if not allowed:
        pytest.skip("No EXT_SOURCE ranges available for test.")
    col = allowed[0]
    _, max_val = preprocessor.numeric_ranges[col]
    payload["data"][col] = max_val + 1
    resp = client.post("/predict", json=payload)
    assert resp.status_code == 200


def test_predict_normalizes_categoricals(client):
    """Free-form categorical spellings are accepted by ``/predict``.

    Sends ``"female"`` for CODE_GENDER and ``"true"`` for FLAG_OWN_CAR
    (whichever of the two is an input column) and expects 200. If neither
    column is part of the payload, the request would be identical to
    ``test_predict`` and prove nothing about normalisation, so the test is
    skipped rather than passing vacuously.
    """
    preprocessor = client.app.state.preprocessor
    payload = _build_payload(preprocessor)
    overrides = {"CODE_GENDER": "female", "FLAG_OWN_CAR": "true"}
    applicable = {
        col: value for col, value in overrides.items() if col in payload["data"]
    }
    if not applicable:
        pytest.skip("No CODE_GENDER/FLAG_OWN_CAR input columns available for test.")
    payload["data"].update(applicable)
    resp = client.post("/predict", json=payload)
    assert resp.status_code == 200


def test_predict_days_employed_sentinel(client):
    """``DAYS_EMPLOYED = 365243`` (the sentinel value) is accepted by ``/predict``.

    If DAYS_EMPLOYED is not an input column, the request would be the plain
    ``test_predict`` payload and say nothing about sentinel handling. The test
    is skipped in that case instead of passing vacuously.
    """
    preprocessor = client.app.state.preprocessor
    payload = _build_payload(preprocessor)
    if "DAYS_EMPLOYED" not in payload["data"]:
        pytest.skip("DAYS_EMPLOYED is not an input column.")
    payload["data"]["DAYS_EMPLOYED"] = 365243
    resp = client.post("/predict", json=payload)
    assert resp.status_code == 200