File size: 3,985 Bytes
4937cba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from __future__ import annotations

from pathlib import Path

from fastapi import HTTPException
from fastapi.testclient import TestClient

from api.app import app, get_inference_service


class DummyService:
    """Deterministic stand-in for the inference service used by these tests.

    A record is scored purely from its ``"Amount"`` field: amounts above 200
    get a fraud probability of 0.9, everything else gets 0.1.
    """

    # Decision cut-off compared against the fraud probability.
    threshold = 0.74
    # Artifact locations exposed as attributes; nothing in this class reads them.
    model_path = Path("models/model.pkl")
    preprocessor_path = Path("models/preprocessor.pkl")

    def predict_records(self, records):
        """Return one prediction dict per record, in input order.

        Each result holds ``is_fraud``, ``fraud_probability``, ``risk_level``
        and ``threshold``. A record without an ``"Amount"`` key raises
        ``KeyError``; a non-numeric amount raises from ``float()``.
        """

        def _score(record):
            # Two fixed probabilities keep the expected outputs trivial to assert.
            probability = 0.9 if float(record["Amount"]) > 200 else 0.1
            risk = "high" if probability >= 0.7 else "low"
            return {
                "is_fraud": probability >= self.threshold,
                "fraud_probability": probability,
                "risk_level": risk,
                "threshold": self.threshold,
            }

        return [_score(record) for record in records]


def _transaction(amount: float = 10.0) -> dict[str, float]:
    """Build a transaction payload with the given amount and zeroed features.

    The result contains ``Time`` (0.0), ``Amount`` and the keys ``V1`` through
    ``V28``, each set to 0.0.
    """
    feature_names = (f"V{index}" for index in range(1, 29))
    return {
        "Time": 0.0,
        "Amount": amount,
        **dict.fromkeys(feature_names, 0.0),
    }


def test_health_endpoint() -> None:
    """GET /health returns 200 with status "ok" and ``model_loaded`` true.

    The inference-service dependency is replaced with ``DummyService`` for the
    duration of the test.
    """
    app.dependency_overrides[get_inference_service] = lambda: DummyService()
    try:
        client = TestClient(app)

        response = client.get("/health")

        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "ok"
        assert body["model_loaded"] is True
    finally:
        # Clear the override even when an assertion fails, so a failing test
        # cannot leak the stub service into tests that run afterwards.
        app.dependency_overrides.clear()


def test_predict_endpoint_valid_payload() -> None:
    """POST /predict with a complete payload returns a high-risk fraud result.

    The amount of 350.0 is above the stub service's 200 cut-off, so the
    response is expected to flag fraud. The response must also carry a
    non-empty ``X-Request-ID`` header.
    """
    app.dependency_overrides[get_inference_service] = lambda: DummyService()
    try:
        client = TestClient(app)

        response = client.post("/predict", json=_transaction(amount=350.0))

        assert response.status_code == 200
        body = response.json()
        assert body["is_fraud"] is True
        assert body["risk_level"] == "high"
        assert response.headers.get("X-Request-ID")
    finally:
        # Clear the override even when an assertion fails, so a failing test
        # cannot leak the stub service into tests that run afterwards.
        app.dependency_overrides.clear()


def test_predict_endpoint_invalid_payload() -> None:
    """POST /predict with the ``V28`` field missing is rejected with 422."""
    app.dependency_overrides[get_inference_service] = lambda: DummyService()
    try:
        client = TestClient(app)

        payload = _transaction()
        # Drop one feature so the request body is incomplete.
        payload.pop("V28")
        response = client.post("/predict", json=payload)

        assert response.status_code == 422
    finally:
        # Clear the override even when an assertion fails, so a failing test
        # cannot leak the stub service into tests that run afterwards.
        app.dependency_overrides.clear()


def test_batch_prediction_endpoint() -> None:
    """POST /predict/batch returns one prediction per transaction, in order.

    The two amounts (20.0 and 300.0) sit on either side of the stub service's
    200 cut-off, so the first is expected to be clean and the second fraud.
    """
    app.dependency_overrides[get_inference_service] = lambda: DummyService()
    try:
        client = TestClient(app)

        response = client.post(
            "/predict/batch",
            json={"transactions": [_transaction(20.0), _transaction(300.0)]},
        )

        assert response.status_code == 200
        body = response.json()
        assert len(body["predictions"]) == 2
        assert body["predictions"][0]["is_fraud"] is False
        assert body["predictions"][1]["is_fraud"] is True
    finally:
        # Clear the override even when an assertion fails, so a failing test
        # cannot leak the stub service into tests that run afterwards.
        app.dependency_overrides.clear()


def test_metrics_endpoint_tracks_predictions_and_requests() -> None:
    """GET /metrics counters grow after a prediction and rates stay in [0, 1].

    Counters are compared against a snapshot taken at the start of the test
    using ``>=``, so the check does not depend on their absolute values.
    """
    app.dependency_overrides[get_inference_service] = lambda: DummyService()
    try:
        client = TestClient(app)

        before = client.get("/metrics")
        assert before.status_code == 200
        before_body = before.json()

        predict_response = client.post("/predict", json=_transaction(amount=350.0))
        assert predict_response.status_code == 200

        after = client.get("/metrics")
        assert after.status_code == 200
        after_body = after.json()

        # At least two more requests and one more prediction than the snapshot.
        assert after_body["total_requests"] >= before_body["total_requests"] + 2
        assert after_body["total_predictions"] >= before_body["total_predictions"] + 1
        assert 0.0 <= after_body["error_rate"] <= 1.0
        assert 0.0 <= after_body["fraud_prediction_rate"] <= 1.0
    finally:
        # Clear the override even when an assertion fails, so a failing test
        # cannot leak the stub service into tests that run afterwards.
        app.dependency_overrides.clear()


def test_health_returns_503_when_service_unavailable() -> None:
    """GET /health returns 503 when the service dependency raises.

    The dependency is replaced with a callable that raises ``HTTPException``
    with status 503; the response must carry that status and detail message.
    """

    def _raise():
        raise HTTPException(status_code=503, detail="Model artifact not found")

    app.dependency_overrides[get_inference_service] = _raise
    try:
        client = TestClient(app)

        response = client.get("/health")

        assert response.status_code == 503
        assert "Model artifact not found" in response.json()["detail"]
    finally:
        # Clear the override even when an assertion fails, so the always-failing
        # dependency cannot leak into tests that run afterwards.
        app.dependency_overrides.clear()