Spaces:
Sleeping
Sleeping
File size: 5,140 Bytes
b44d852 f7f1df4 b44d852 f7f1df4 b44d852 f7f1df4 b44d852 3b83e25 b44d852 3b83e25 b44d852 f7f1df4 b44d852 9a76208 f7f1df4 9a76208 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import pytest
from fastapi.testclient import TestClient
from app.main import ALLOW_OUT_OF_RANGE_COLUMNS, app
@pytest.fixture(scope="session")
def client():
    """Yield a single ``TestClient`` for ``app`` shared by the whole test session.

    The client is entered as a context manager so the application's lifespan
    (startup/shutdown) runs around the session. Presumably that is where
    ``app.state.preprocessor`` gets populated — confirm in ``app.main``.
    """
    session_client = TestClient(app)
    with session_client:
        yield session_client
def _build_payload(preprocessor):
data = {}
for col in preprocessor.required_input_columns:
if col in preprocessor.numeric_medians:
data[col] = preprocessor.numeric_medians[col]
elif col in preprocessor.categorical_columns:
data[col] = "Unknown"
else:
data[col] = 0
data["SK_ID_CURR"] = int(data.get("SK_ID_CURR", 100001))
return {"data": data}
def _pick_required_column(preprocessor, exclude=None):
exclude = set(exclude or [])
for col in preprocessor.required_input_columns:
if col not in exclude:
return col
raise AssertionError("No required column available for test.")
def _pick_numeric_range(preprocessor, exclude=None):
exclude = set(exclude or [])
for col, bounds in preprocessor.numeric_ranges.items():
if col in preprocessor.numeric_required_columns and col not in exclude:
return col, bounds
raise AssertionError("No numeric range available for test.")
def _pick_numeric_required(preprocessor):
for col in preprocessor.numeric_required_columns:
if col != "SK_ID_CURR":
return col
raise AssertionError("No numeric required column available for test.")
def test_health(client):
    """``GET /health`` answers 200 with the body ``{"status": "ok"}``."""
    expected_body = {"status": "ok"}
    response = client.get("/health")
    assert response.status_code == 200
    body = response.json()
    assert body == expected_body
def test_features(client):
    """``GET /features`` exposes the feature metadata keys and includes ``SK_ID_CURR``."""
    response = client.get("/features")
    assert response.status_code == 200
    body = response.json()
    for key in ("input_features", "required_input_features", "feature_selection_method"):
        assert key in body
    features = body["input_features"]
    assert "SK_ID_CURR" in features
    # The ID alone is not enough: at least one real feature must be listed.
    assert len(features) >= 2
def test_predict(client):
    """A placeholder payload yields exactly one prediction with a probability in [0, 1]."""
    request_body = _build_payload(client.app.state.preprocessor)
    response = client.post("/predict", json=request_body)
    assert response.status_code == 200
    body = response.json()
    assert "predictions" in body
    predictions = body["predictions"]
    assert len(predictions) == 1
    first = predictions[0]
    for field in ("sk_id_curr", "prediction", "probability"):
        assert field in first
    assert 0.0 <= first["probability"] <= 1.0
def test_predict_missing_required_field(client):
    """Omitting a required (non-ID) column yields 422 with the missing-columns message."""
    preprocessor = client.app.state.preprocessor
    payload = _build_payload(preprocessor)
    missing_col = _pick_required_column(preprocessor, exclude={"SK_ID_CURR"})
    # _build_payload fills every required column, so ``del`` always succeeds.
    # If that invariant ever breaks, it fails loudly, whereas
    # ``pop(..., None)`` would silently send a complete payload.
    del payload["data"][missing_col]
    resp = client.post("/predict", json=payload)
    assert resp.status_code == 422
    detail = resp.json().get("detail", {})
    # This test expects a structured dict. FastAPI's default request-validation
    # handler returns a list under "detail", which would otherwise surface as an
    # opaque AttributeError on ``.get`` below.
    assert isinstance(detail, dict), detail
    assert detail.get("message") == "Missing required input columns."
def test_predict_invalid_type(client):
    """A non-numeric string in a numeric feature yields 422 with the invalid-numeric message."""
    preprocessor = client.app.state.preprocessor
    payload = _build_payload(preprocessor)
    invalid_col = _pick_numeric_required(preprocessor)
    payload["data"][invalid_col] = "not_a_number"
    resp = client.post("/predict", json=payload)
    assert resp.status_code == 422
    detail = resp.json().get("detail", {})
    # This test expects a structured dict. FastAPI's default request-validation
    # handler returns a list under "detail", which would otherwise surface as an
    # opaque AttributeError on ``.get`` below.
    assert isinstance(detail, dict), detail
    assert detail.get("message") == "Invalid numeric values provided."
def test_predict_out_of_range(client):
    """A value above a column's recorded maximum is rejected with 422."""
    preprocessor = client.app.state.preprocessor
    payload = _build_payload(preprocessor)
    # Columns in ALLOW_OUT_OF_RANGE_COLUMNS accept out-of-range values (see the
    # companion test below), so they cannot trigger the rejection.
    col, (_, max_val) = _pick_numeric_range(
        preprocessor,
        exclude=ALLOW_OUT_OF_RANGE_COLUMNS,
    )
    # ``max_val + 1`` does not exceed a float maximum of 2**53 or more (the +1
    # is lost to rounding). Scaling the step with the bound's magnitude lands
    # strictly above any finite maximum and keeps ints as ints.
    payload["data"][col] = max_val + max(1, abs(max_val))
    resp = client.post("/predict", json=payload)
    assert resp.status_code == 422
    detail = resp.json().get("detail", {})
    # This test expects a structured dict. FastAPI's default request-validation
    # handler returns a list under "detail", which would otherwise surface as an
    # opaque AttributeError on ``.get`` below.
    assert isinstance(detail, dict), detail
    assert detail.get("message") == "Input contains values outside expected ranges."
def test_predict_out_of_range_allowed_ext_source(client):
    """Columns in ALLOW_OUT_OF_RANGE_COLUMNS accept values above their recorded maximum."""
    preprocessor = client.app.state.preprocessor
    payload = _build_payload(preprocessor)
    # Sort so the same column is exercised on every run. If
    # ALLOW_OUT_OF_RANGE_COLUMNS is a set (it is also passed as ``exclude=``
    # elsewhere), its iteration order over strings depends on hash randomization.
    allowed = sorted(
        col
        for col in ALLOW_OUT_OF_RANGE_COLUMNS
        if col in preprocessor.numeric_ranges
    )
    if not allowed:
        pytest.skip("No EXT_SOURCE ranges available for test.")
    col = allowed[0]
    _, max_val = preprocessor.numeric_ranges[col]
    # Magnitude-scaled step: ``max_val + 1`` can round back to ``max_val`` for
    # very large floats, which would make this test vacuous.
    payload["data"][col] = max_val + max(1, abs(max_val))
    resp = client.post("/predict", json=payload)
    assert resp.status_code == 200
def test_predict_normalizes_categoricals(client):
    """Lower-case / textual categorical spellings are accepted by ``/predict``.

    Sends ``CODE_GENDER="female"`` and ``FLAG_OWN_CAR="true"`` for whichever of
    those columns are model inputs and expects a 200 response.
    """
    preprocessor = client.app.state.preprocessor
    payload = _build_payload(preprocessor)
    overrides = {"CODE_GENDER": "female", "FLAG_OWN_CAR": "true"}
    applicable = {
        column: value
        for column, value in overrides.items()
        if column in payload["data"]
    }
    # Without any applicable column the request is just the baseline payload,
    # so a pass would claim coverage that was never exercised.
    if not applicable:
        pytest.skip("CODE_GENDER and FLAG_OWN_CAR are not model inputs.")
    payload["data"].update(applicable)
    resp = client.post("/predict", json=payload)
    assert resp.status_code == 200
def test_predict_days_employed_sentinel(client):
    """``DAYS_EMPLOYED=365243`` (the sentinel value) is accepted by ``/predict``."""
    preprocessor = client.app.state.preprocessor
    payload = _build_payload(preprocessor)
    # Without the column the request is just the baseline payload, so a pass
    # would claim coverage that was never exercised.
    if "DAYS_EMPLOYED" not in payload["data"]:
        pytest.skip("DAYS_EMPLOYED is not a model input.")
    payload["data"]["DAYS_EMPLOYED"] = 365243
    resp = client.post("/predict", json=payload)
    assert resp.status_code == 200
|