Zalaid's picture
Step 8: add 28 automated tests covering preprocessing, models, and API
a2bc2a9
import pytest
from fastapi.testclient import TestClient
from src.api.main import app
# Module-level FastAPI TestClient wrapping the imported `app`; every test in
# this file sends its requests through this single shared instance.
client = TestClient(app)
# ---------------------------------------------------------------------------
# Real rows taken directly from creditcard.csv for testing
# ---------------------------------------------------------------------------
# Request payload for the transaction that the /predict test expects to be
# flagged as fraud: Time, Amount, and the V1..V28 feature columns.
FRAUD_TRANSACTION = dict(
    Time=406,
    Amount=0.0,
    V1=-2.3122, V2=1.9519, V3=-1.6097, V4=3.9979,
    V5=-0.5222, V6=-1.4265, V7=-2.5374, V8=1.3914,
    V9=-2.7700, V10=-2.7722, V11=3.2020, V12=-2.8992,
    V13=-0.5950, V14=-4.2895, V15=0.3898, V16=-1.1407,
    V17=-2.8300, V18=-0.0168, V19=0.4165, V20=0.3269,
    V21=0.1474, V22=-0.1703, V23=0.0359, V24=-0.4118,
    V25=0.0714, V26=0.0719, V27=0.2127, V28=0.0952,
)
# Request payload for the transaction that the /predict test expects to be
# classified as not fraud: Time, Amount, and the V1..V28 feature columns.
NORMAL_TRANSACTION = dict(
    Time=0,
    Amount=149.62,
    V1=-1.3598, V2=-0.0728, V3=2.5363, V4=1.3782,
    V5=-0.3383, V6=0.4624, V7=0.2396, V8=0.0987,
    V9=0.3638, V10=0.0908, V11=-0.5516, V12=-0.6178,
    V13=-0.9914, V14=-0.3112, V15=1.4681, V16=-0.4704,
    V17=0.2080, V18=0.0258, V19=0.4040, V20=0.2514,
    V21=-0.0183, V22=0.2778, V23=-0.1105, V24=0.0669,
    V25=0.1285, V26=-0.1892, V27=0.1336, V28=-0.0211,
)
# Edge-case payload: Time and Amount are integer 0, and each of V1..V28 is 0.0.
ALL_ZEROS = {
    "Time": 0,
    "Amount": 0,
    **dict.fromkeys((f"V{idx}" for idx in range(1, 29)), 0.0),
}
# Edge-case payload: a very large Amount (and Time) with each of V1..V28 at 0.0.
LARGE_AMOUNT = {
    "Time": 100000,
    "Amount": 99999.99,
    **dict.fromkeys((f"V{idx}" for idx in range(1, 29)), 0.0),
}
# ---------------------------------------------------------------------------
# Health & root
# ---------------------------------------------------------------------------
def test_health_returns_200():
    """GET /health must answer with HTTP 200."""
    status = client.get("/health").status_code
    assert status == 200
def test_health_model_loaded():
    """The /health body must report status "ok" and model_loaded as True."""
    health_response = client.get("/health")
    payload = health_response.json()
    assert payload["status"] == "ok"
    assert payload["model_loaded"] is True
def test_root_returns_200():
    """GET / must answer with HTTP 200."""
    status = client.get("/").status_code
    assert status == 200
def test_root_contains_endpoints():
    """The JSON body served at / must contain "endpoints"."""
    root_response = client.get("/")
    payload = root_response.json()
    assert "endpoints" in payload
# ---------------------------------------------------------------------------
# /predict — correct responses
# ---------------------------------------------------------------------------
def test_predict_fraud_transaction():
    """POSTing the known-fraud row must give HTTP 200 and a fraud verdict.

    The verdict must have is_fraud set to True and a fraud_probability
    strictly above 0.5.
    """
    reply = client.post("/predict", json=FRAUD_TRANSACTION)
    assert reply.status_code == 200

    verdict = reply.json()
    assert verdict["is_fraud"] is True
    assert verdict["fraud_probability"] > 0.5
def test_predict_normal_transaction():
    """POSTing the known-normal row must give HTTP 200 and a non-fraud verdict.

    The verdict must have is_fraud set to False and a fraud_probability
    strictly below 0.5.
    """
    reply = client.post("/predict", json=NORMAL_TRANSACTION)
    assert reply.status_code == 200

    verdict = reply.json()
    assert verdict["is_fraud"] is False
    assert verdict["fraud_probability"] < 0.5
# ---------------------------------------------------------------------------
# /predict — response schema
# ---------------------------------------------------------------------------
def test_predict_response_has_required_fields():
    """The /predict body must contain is_fraud, fraud_probability and inference_ms."""
    body = client.post("/predict", json=NORMAL_TRANSACTION).json()
    # Checked in a fixed order so the first missing field is the one reported.
    for field in ("is_fraud", "fraud_probability", "inference_ms"):
        assert field in body
def test_predict_probability_in_range():
    """fraud_probability for the normal row must lie within [0.0, 1.0]."""
    body = client.post("/predict", json=NORMAL_TRANSACTION).json()
    probability = body["fraud_probability"]
    assert 0.0 <= probability <= 1.0
def test_predict_inference_ms_is_positive():
    """inference_ms reported for the normal row must be greater than zero."""
    reply = client.post("/predict", json=NORMAL_TRANSACTION)
    elapsed_ms = reply.json()["inference_ms"]
    assert elapsed_ms > 0
# ---------------------------------------------------------------------------
# /predict — edge cases
# ---------------------------------------------------------------------------
def test_predict_all_zeros():
    """An all-zero payload must still give HTTP 200 and a body with is_fraud."""
    reply = client.post("/predict", json=ALL_ZEROS)
    assert reply.status_code == 200

    body = reply.json()
    assert "is_fraud" in body
def test_predict_large_amount():
    """A very large Amount must still give HTTP 200 and a body with is_fraud."""
    reply = client.post("/predict", json=LARGE_AMOUNT)
    assert reply.status_code == 200

    body = reply.json()
    assert "is_fraud" in body
# ---------------------------------------------------------------------------
# /predict — bad input
# ---------------------------------------------------------------------------
def test_predict_missing_field_returns_422():
    """A payload without any V feature must be rejected with HTTP 422."""
    # Only Time and Amount are supplied; every V column is deliberately omitted.
    partial_payload = dict(Time=0, Amount=100)
    status = client.post("/predict", json=partial_payload).status_code
    assert status == 422
def test_predict_negative_amount_returns_422():
    """A negative Amount must be rejected with HTTP 422 (Amount must be >= 0)."""
    # Start from the valid normal row and override only Amount.
    negative_payload = dict(NORMAL_TRANSACTION, Amount=-50)
    status = client.post("/predict", json=negative_payload).status_code
    assert status == 422
def test_predict_wrong_type_returns_422():
    """A string where a float is expected (V1) must be rejected with HTTP 422."""
    # Start from the valid normal row and override only V1.
    mistyped_payload = dict(NORMAL_TRANSACTION, V1="not_a_number")
    status = client.post("/predict", json=mistyped_payload).status_code
    assert status == 422