Spaces:
Sleeping
Sleeping
File size: 5,408 Bytes
a2bc2a9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 | import pytest
from fastapi.testclient import TestClient
from src.api.main import app
# Single TestClient shared by every test in this module.
# NOTE(review): the client is not entered with ``with``, so any app
# startup/lifespan handlers are not run; this relies on the model being
# loaded at import time (test_health_model_loaded checks it) — confirm.
client = TestClient(app=app)
# ---------------------------------------------------------------------------
# Request payloads for the /predict tests.
#
# FRAUD_TRANSACTION and NORMAL_TRANSACTION are based on rows of
# creditcard.csv, with the V1..V28 values rounded to 4 decimal places (so
# they approximate, rather than reproduce, the dataset rows).
# ALL_ZEROS and LARGE_AMOUNT are synthetic edge-case payloads, not dataset rows.
# ---------------------------------------------------------------------------
_N_V_FEATURES = 28  # payloads carry features V1..V28


def _zero_features():
    """Return a fresh dict mapping ``"V1"`` .. ``"V28"`` to ``0.0``."""
    return {f"V{i}": 0.0 for i in range(1, _N_V_FEATURES + 1)}


FRAUD_TRANSACTION = {
    "Time": 406, "Amount": 0.0,
    "V1": -2.3122, "V2": 1.9519, "V3": -1.6097, "V4": 3.9979,
    "V5": -0.5222, "V6": -1.4265, "V7": -2.5374, "V8": 1.3914,
    "V9": -2.7700, "V10": -2.7722, "V11": 3.2020, "V12": -2.8992,
    "V13": -0.5950, "V14": -4.2895, "V15": 0.3898, "V16": -1.1407,
    "V17": -2.8300, "V18": -0.0168, "V19": 0.4165, "V20": 0.3269,
    "V21": 0.1474, "V22": -0.1703, "V23": 0.0359, "V24": -0.4118,
    "V25": 0.0714, "V26": 0.0719, "V27": 0.2127, "V28": 0.0952,
}

NORMAL_TRANSACTION = {
    "Time": 0, "Amount": 149.62,
    "V1": -1.3598, "V2": -0.0728, "V3": 2.5363, "V4": 1.3782,
    "V5": -0.3383, "V6": 0.4624, "V7": 0.2396, "V8": 0.0987,
    "V9": 0.3638, "V10": 0.0908, "V11": -0.5516, "V12": -0.6178,
    "V13": -0.9914, "V14": -0.3112, "V15": 1.4681, "V16": -0.4704,
    "V17": 0.2080, "V18": 0.0258, "V19": 0.4040, "V20": 0.2514,
    "V21": -0.0183, "V22": 0.2778, "V23": -0.1105, "V24": 0.0669,
    "V25": 0.1285, "V26": -0.1892, "V27": 0.1336, "V28": -0.0211,
}

# Synthetic: every field zero.
ALL_ZEROS = {
    "Time": 0, "Amount": 0,
    **_zero_features(),
}

# Synthetic: neutral features with an extreme Time/Amount.
LARGE_AMOUNT = {
    "Time": 100000, "Amount": 99999.99,
    **_zero_features(),
}
# ---------------------------------------------------------------------------
# Health & root
# ---------------------------------------------------------------------------
def test_health_returns_200():
    """GET /health answers with HTTP 200."""
    assert client.get("/health").status_code == 200
def test_health_model_loaded():
    """GET /health reports ``status == "ok"`` and ``model_loaded is True``.

    The status code is asserted first, so an error response fails with its
    body as the message rather than an opaque KeyError on the JSON fields.
    """
    response = client.get("/health")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["status"] == "ok"
    assert data["model_loaded"] is True
def test_root_returns_200():
    """GET / answers with HTTP 200."""
    assert client.get("/").status_code == 200
def test_root_contains_endpoints():
    """GET / succeeds and its JSON body has an ``endpoints`` key.

    The status code is asserted first, so an error response fails with its
    body as the message instead of a misleading membership failure.
    """
    response = client.get("/")
    assert response.status_code == 200, response.text
    assert "endpoints" in response.json()
# ---------------------------------------------------------------------------
# /predict — correct responses
# ---------------------------------------------------------------------------
def test_predict_fraud_transaction():
    """The fraud fixture row is classified as fraud (flag true, probability > 0.5).

    NOTE(review): both checks assume the API's decision threshold is 0.5 and
    depend on the currently trained model — confirm if either changes.
    """
    resp = client.post("/predict", json=FRAUD_TRANSACTION)
    assert resp.status_code == 200
    body = resp.json()
    assert body["is_fraud"] is True
    assert body["fraud_probability"] > 0.5
def test_predict_normal_transaction():
    """The normal fixture row is classified as legitimate (flag false, probability < 0.5)."""
    resp = client.post("/predict", json=NORMAL_TRANSACTION)
    assert resp.status_code == 200
    body = resp.json()
    assert body["is_fraud"] is False
    assert body["fraud_probability"] < 0.5
# ---------------------------------------------------------------------------
# /predict — response schema
# ---------------------------------------------------------------------------
def test_predict_response_has_required_fields():
    """A successful /predict response has is_fraud, fraud_probability and inference_ms.

    The status code is asserted first, so a 422/500 error body is reported as
    such. All missing fields are then listed in one failure message.
    """
    response = client.post("/predict", json=NORMAL_TRANSACTION)
    assert response.status_code == 200, response.text
    required = {"is_fraud", "fraud_probability", "inference_ms"}
    missing = required - response.json().keys()
    assert not missing, f"response is missing fields: {sorted(missing)}"
def test_predict_probability_in_range():
    """``fraud_probability`` in a successful /predict response lies in [0, 1].

    The status code is asserted first, so an error response fails with its
    body as the message instead of a KeyError.
    """
    response = client.post("/predict", json=NORMAL_TRANSACTION)
    assert response.status_code == 200, response.text
    prob = response.json()["fraud_probability"]
    assert 0.0 <= prob <= 1.0, prob
def test_predict_inference_ms_is_positive():
    """``inference_ms`` in a successful /predict response is strictly positive.

    The status code is asserted first, so an error response fails with its
    body as the message instead of a KeyError.
    NOTE(review): if the API rounds inference_ms, a very fast model could
    report 0 — confirm the rounding before relying on a strict ``> 0``.
    """
    response = client.post("/predict", json=NORMAL_TRANSACTION)
    assert response.status_code == 200, response.text
    inference_ms = response.json()["inference_ms"]
    assert inference_ms > 0, inference_ms
# ---------------------------------------------------------------------------
# /predict — edge cases
# ---------------------------------------------------------------------------
def test_predict_all_zeros():
    """An all-zero payload is accepted and produces a prediction rather than an error."""
    reply = client.post("/predict", json=ALL_ZEROS)
    assert reply.status_code == 200
    assert "is_fraud" in reply.json()
def test_predict_large_amount():
    """An extreme Time/Amount payload is accepted and produces a prediction."""
    reply = client.post("/predict", json=LARGE_AMOUNT)
    assert reply.status_code == 200
    assert "is_fraud" in reply.json()
# ---------------------------------------------------------------------------
# /predict — bad input
# ---------------------------------------------------------------------------
def test_predict_missing_field_returns_422():
    """A payload with only Time and Amount (no V features) is rejected with HTTP 422."""
    payload = {"Time": 0, "Amount": 100}  # every V1..V28 key deliberately omitted
    assert client.post("/predict", json=payload).status_code == 422
def test_predict_negative_amount_returns_422():
    """The API must enforce Amount >= 0: a negative Amount is rejected with HTTP 422."""
    payload = dict(NORMAL_TRANSACTION)  # copy so the shared fixture stays untouched
    payload["Amount"] = -50
    assert client.post("/predict", json=payload).status_code == 422
def test_predict_wrong_type_returns_422():
    """A non-numeric string for a float feature (V1) is rejected with HTTP 422."""
    payload = dict(NORMAL_TRANSACTION)  # copy so the shared fixture stays untouched
    payload["V1"] = "not_a_number"
    assert client.post("/predict", json=payload).status_code == 422
|