Spaces:
Sleeping
Sleeping
| import pytest | |
| from fastapi.testclient import TestClient | |
| from src.api.main import app | |
# Shared in-process HTTP client used by every test in this module.
# NOTE(review): it is created without `with`, so per the Starlette/FastAPI
# testing docs the app's startup/lifespan handlers are not run. If
# src.api.main loads the model inside such a handler, the /health and
# /predict tests may see it unloaded — confirm, or switch to
# `with TestClient(app) as client:` inside a module-scoped fixture.
client = TestClient(app)
# ---------------------------------------------------------------------------
# Request payloads for /predict.
# FRAUD_TRANSACTION and NORMAL_TRANSACTION are rows from creditcard.csv with
# every value written to 4 decimal places (presumably rounded from the CSV).
# NOTE(review): confirm the rounded fraud-row values still match the dataset
# row, since test_predict_fraud_transaction depends on them.
# ---------------------------------------------------------------------------
FRAUD_TRANSACTION = {
    "Time": 406, "Amount": 0.0,
    "V1": -2.3122, "V2": 1.9519, "V3": -1.6097, "V4": 3.9979,
    "V5": -0.5222, "V6": -1.4265, "V7": -2.5374, "V8": 1.3914,
    "V9": -2.7700, "V10": -2.7722, "V11": 3.2020, "V12": -2.8992,
    "V13": -0.5950, "V14": -4.2895, "V15": 0.3898, "V16": -1.1407,
    "V17": -2.8300, "V18": -0.0168, "V19": 0.4165, "V20": 0.3269,
    "V21": 0.1474, "V22": -0.1703, "V23": 0.0359, "V24": -0.4118,
    "V25": 0.0714, "V26": 0.0719, "V27": 0.2127, "V28": 0.0952,
}
NORMAL_TRANSACTION = {
    "Time": 0, "Amount": 149.62,
    "V1": -1.3598, "V2": -0.0728, "V3": 2.5363, "V4": 1.3782,
    "V5": -0.3383, "V6": 0.4624, "V7": 0.2396, "V8": 0.0987,
    "V9": 0.3638, "V10": 0.0908, "V11": -0.5516, "V12": -0.6178,
    "V13": -0.9914, "V14": -0.3112, "V15": 1.4681, "V16": -0.4704,
    "V17": 0.2080, "V18": 0.0258, "V19": 0.4040, "V20": 0.2514,
    "V21": -0.0183, "V22": 0.2778, "V23": -0.1105, "V24": 0.0669,
    "V25": 0.1285, "V26": -0.1892, "V27": 0.1336, "V28": -0.0211,
}

# Feature columns V1..V28 all set to 0.0. Shared by the synthetic payloads
# below; `**` unpacking copies it, so each payload owns an independent dict.
_ZERO_FEATURES = {f"V{i}": 0.0 for i in range(1, 29)}

# Synthetic payload: every field is zero.
ALL_ZEROS = {
    "Time": 0, "Amount": 0,
    **_ZERO_FEATURES,
}
# Synthetic payload: very large Time/Amount with all V-features at zero.
LARGE_AMOUNT = {
    "Time": 100000, "Amount": 99999.99,
    **_ZERO_FEATURES,
}
# ---------------------------------------------------------------------------
# Health & root
# ---------------------------------------------------------------------------
def test_health_returns_200():
    """GET /health answers with HTTP 200."""
    assert client.get("/health").status_code == 200
def test_health_model_loaded():
    """GET /health reports status "ok" and a loaded model.

    The status code is asserted first so that an error reply fails with a
    clear message rather than a KeyError while reading the JSON body.
    """
    response = client.get("/health")
    assert response.status_code == 200
    data = response.json()
    assert data["status"] == "ok"
    assert data["model_loaded"] is True
def test_root_returns_200():
    """GET / answers with HTTP 200."""
    assert client.get("/").status_code == 200
def test_root_contains_endpoints():
    """GET / succeeds and its JSON body includes an "endpoints" key.

    The status is checked first so an error reply is reported as such rather
    than as a confusing missing-key failure.
    """
    response = client.get("/")
    assert response.status_code == 200
    assert "endpoints" in response.json()
# ---------------------------------------------------------------------------
# /predict — correct responses
# ---------------------------------------------------------------------------
def test_predict_fraud_transaction():
    """The fraud sample row is flagged as fraud with probability above 0.5."""
    resp = client.post("/predict", json=FRAUD_TRANSACTION)
    assert resp.status_code == 200
    body = resp.json()
    assert body["is_fraud"] is True
    assert body["fraud_probability"] > 0.5
def test_predict_normal_transaction():
    """The normal sample row is not flagged; its probability stays below 0.5."""
    resp = client.post("/predict", json=NORMAL_TRANSACTION)
    assert resp.status_code == 200
    body = resp.json()
    assert body["is_fraud"] is False
    assert body["fraud_probability"] < 0.5
# ---------------------------------------------------------------------------
# /predict — response schema
# ---------------------------------------------------------------------------
def test_predict_response_has_required_fields():
    """A successful /predict reply carries every field the other tests read.

    All missing fields are reported in one failure message instead of
    stopping at the first absent key.
    """
    response = client.post("/predict", json=NORMAL_TRANSACTION)
    assert response.status_code == 200
    required = {"is_fraud", "fraud_probability", "inference_ms"}
    missing = required - response.json().keys()
    assert not missing, f"response is missing fields: {sorted(missing)}"
def test_predict_probability_in_range():
    """fraud_probability lies in the closed interval [0, 1].

    The status is checked first so an error reply fails clearly rather than
    with a KeyError on "fraud_probability".
    """
    response = client.post("/predict", json=NORMAL_TRANSACTION)
    assert response.status_code == 200
    prob = response.json()["fraud_probability"]
    assert 0.0 <= prob <= 1.0
def test_predict_inference_ms_is_positive():
    """A successful /predict reply reports a strictly positive inference_ms.

    The status is checked first so an error reply fails clearly rather than
    with a KeyError on "inference_ms".
    """
    response = client.post("/predict", json=NORMAL_TRANSACTION)
    assert response.status_code == 200
    assert response.json()["inference_ms"] > 0
# ---------------------------------------------------------------------------
# /predict — edge cases
# ---------------------------------------------------------------------------
def test_predict_all_zeros():
    """An all-zero payload is still scored: HTTP 200 with an is_fraud field."""
    reply = client.post("/predict", json=ALL_ZEROS)
    assert reply.status_code == 200
    assert "is_fraud" in reply.json()
def test_predict_large_amount():
    """An extreme Amount is still scored: HTTP 200 with an is_fraud field."""
    reply = client.post("/predict", json=LARGE_AMOUNT)
    assert reply.status_code == 200
    assert "is_fraud" in reply.json()
# ---------------------------------------------------------------------------
# /predict — bad input
# ---------------------------------------------------------------------------
def test_predict_missing_field_returns_422():
    """A payload without any V1..V28 features is rejected with HTTP 422."""
    payload = {"Time": 0, "Amount": 100}  # every V feature omitted
    assert client.post("/predict", json=payload).status_code == 422
def test_predict_negative_amount_returns_422():
    """A negative Amount is rejected with HTTP 422."""
    payload = dict(NORMAL_TRANSACTION, Amount=-50)
    assert client.post("/predict", json=payload).status_code == 422
def test_predict_wrong_type_returns_422():
    """A non-numeric string in the V1 feature is rejected with HTTP 422."""
    payload = dict(NORMAL_TRANSACTION)
    payload["V1"] = "not_a_number"
    assert client.post("/predict", json=payload).status_code == 422