File size: 5,408 Bytes
a2bc2a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import pytest
from fastapi.testclient import TestClient
from src.api.main import app

# Module-level HTTP client shared by every test in this file; requests are
# dispatched in-process to the FastAPI `app` imported from src.api.main.
# NOTE(review): per the Starlette TestClient docs, startup/lifespan handlers
# only run when the client is used as a context manager
# (`with TestClient(app) as client:`). This module-level instance is not, so
# it presumably relies on src.api.main loading the model at import time —
# confirm, otherwise test_health_model_loaded and /predict tests will fail.
client = TestClient(app)

# ---------------------------------------------------------------------------
# Request payloads for /predict.
# FRAUD_TRANSACTION and NORMAL_TRANSACTION are rows from creditcard.csv
# (the Time=406 fraud row and the first Time=0 row with Amount=149.62),
# with the PCA features rounded to 4 decimal places.
# ---------------------------------------------------------------------------
# NOTE(review): V20-V28 of FRAUD_TRANSACTION previously did not match the
# Time=406 fraud row in creditcard.csv; they were corrected here. Re-verify
# against your copy of the dataset if the model's verdict on this row changes.
FRAUD_TRANSACTION = {
    "Time": 406, "Amount": 0.0,
    "V1": -2.3122, "V2": 1.9519, "V3": -1.6097, "V4": 3.9979,
    "V5": -0.5222, "V6": -1.4265, "V7": -2.5374, "V8": 1.3914,
    "V9": -2.7700, "V10": -2.7722, "V11": 3.2020, "V12": -2.8992,
    "V13": -0.5950, "V14": -4.2895, "V15": 0.3898, "V16": -1.1407,
    "V17": -2.8300, "V18": -0.0168, "V19": 0.4165, "V20": 0.1269,
    "V21": 0.5172, "V22": -0.0350, "V23": -0.4652, "V24": 0.3202,
    "V25": 0.0445, "V26": 0.1778, "V27": 0.2611, "V28": -0.1433,
}

NORMAL_TRANSACTION = {
    "Time": 0, "Amount": 149.62,
    "V1": -1.3598, "V2": -0.0728, "V3": 2.5363, "V4": 1.3782,
    "V5": -0.3383, "V6": 0.4624, "V7": 0.2396, "V8": 0.0987,
    "V9": 0.3638, "V10": 0.0908, "V11": -0.5516, "V12": -0.6178,
    "V13": -0.9914, "V14": -0.3112, "V15": 1.4681, "V16": -0.4704,
    "V17": 0.2080, "V18": 0.0258, "V19": 0.4040, "V20": 0.2514,
    "V21": -0.0183, "V22": 0.2778, "V23": -0.1105, "V24": 0.0669,
    "V25": 0.1285, "V26": -0.1892, "V27": 0.1336, "V28": -0.0211,
}

# V1..V28 all set to 0.0. Each payload below spreads it into a fresh dict,
# so no payload shares mutable state with another.
_ZERO_V_FEATURES = {f"V{i}": 0.0 for i in range(1, 29)}

# Degenerate input: every field zero.
ALL_ZEROS = {
    "Time": 0, "Amount": 0,
    **_ZERO_V_FEATURES,
}

# Extreme amount with neutral (zero) PCA features.
LARGE_AMOUNT = {
    "Time": 100000, "Amount": 99999.99,
    **_ZERO_V_FEATURES,
}


# ---------------------------------------------------------------------------
# Health & root
# ---------------------------------------------------------------------------
def test_health_returns_200():
    """GET /health answers with HTTP 200."""
    status = client.get("/health").status_code
    assert status == 200


def test_health_model_loaded():
    """The health payload reports status "ok" and a loaded model."""
    health = client.get("/health")
    body = health.json()
    assert body["status"] == "ok"
    assert body["model_loaded"] is True


def test_root_returns_200():
    """GET / answers with HTTP 200."""
    status = client.get("/").status_code
    assert status == 200


def test_root_contains_endpoints():
    """The root payload advertises an "endpoints" entry."""
    root = client.get("/")
    assert "endpoints" in root.json()


# ---------------------------------------------------------------------------
# /predict — correct responses
# ---------------------------------------------------------------------------
def test_predict_fraud_transaction():
    """The known fraud row is classified as fraud with probability above 0.5."""
    resp = client.post("/predict", json=FRAUD_TRANSACTION)
    assert resp.status_code == 200
    body = resp.json()
    assert body["is_fraud"] is True
    assert body["fraud_probability"] > 0.5


def test_predict_normal_transaction():
    """The known legitimate row is classified as non-fraud (probability below 0.5)."""
    resp = client.post("/predict", json=NORMAL_TRANSACTION)
    assert resp.status_code == 200
    body = resp.json()
    assert body["is_fraud"] is False
    assert body["fraud_probability"] < 0.5


# ---------------------------------------------------------------------------
# /predict — response schema
# ---------------------------------------------------------------------------
def test_predict_response_has_required_fields():
    """A successful /predict response exposes every expected output field."""
    response = client.post("/predict", json=NORMAL_TRANSACTION)
    # Check the status first: an error body (e.g. {"detail": ...}) would
    # otherwise fail below as a confusing "missing field" assertion.
    assert response.status_code == 200
    data = response.json()
    missing = {"is_fraud", "fraud_probability", "inference_ms"} - data.keys()
    assert not missing, f"response is missing fields: {sorted(missing)}"


def test_predict_probability_in_range():
    """fraud_probability lies in [0, 1] for normal, fraud and extreme inputs."""
    # Cover both classes and the edge-case payloads: out-of-range outputs are
    # most likely to surface on unusual inputs, not the typical normal row.
    for payload in (NORMAL_TRANSACTION, FRAUD_TRANSACTION, ALL_ZEROS, LARGE_AMOUNT):
        response = client.post("/predict", json=payload)
        assert response.status_code == 200
        prob = response.json()["fraud_probability"]
        assert 0.0 <= prob <= 1.0


def test_predict_inference_ms_is_positive():
    """A successful prediction reports a strictly positive inference time."""
    response = client.post("/predict", json=NORMAL_TRANSACTION)
    # Check the status first so an error response fails clearly, not as a KeyError.
    assert response.status_code == 200
    # NOTE(review): if the API rounds inference_ms, a very fast model could
    # report 0.0 and make this assertion flaky — confirm the rounding in src.api.
    assert response.json()["inference_ms"] > 0


# ---------------------------------------------------------------------------
# /predict — edge cases
# ---------------------------------------------------------------------------
def test_predict_all_zeros():
    """An all-zero payload is accepted and still yields a prediction."""
    resp = client.post("/predict", json=ALL_ZEROS)
    assert resp.status_code == 200
    body = resp.json()
    assert "is_fraud" in body


def test_predict_large_amount():
    """An extremely large Amount is accepted and still yields a prediction."""
    resp = client.post("/predict", json=LARGE_AMOUNT)
    assert resp.status_code == 200
    body = resp.json()
    assert "is_fraud" in body


# ---------------------------------------------------------------------------
# /predict — bad input
# ---------------------------------------------------------------------------
def test_predict_missing_field_returns_422():
    """A payload lacking the V1..V28 features is rejected with HTTP 422."""
    payload = {"Time": 0, "Amount": 100}  # no V features at all
    status = client.post("/predict", json=payload).status_code
    assert status == 422


def test_predict_negative_amount_returns_422():
    """A negative Amount violates the >= 0 constraint and yields HTTP 422."""
    payload = dict(NORMAL_TRANSACTION, Amount=-50)
    status = client.post("/predict", json=payload).status_code
    assert status == 422


def test_predict_wrong_type_returns_422():
    """A non-numeric string in a float field yields HTTP 422."""
    payload = dict(NORMAL_TRANSACTION, V1="not_a_number")
    status = client.post("/predict", json=payload).status_code
    assert status == 422