File size: 3,006 Bytes
1aa566a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
FastAPI endpoint tests using httpx + TestClient.
"""
from __future__ import annotations

import pytest
from fastapi.testclient import TestClient


@pytest.fixture(scope="module")
def client():
    from src.api.main import create_app
    app = create_app()
    with TestClient(app) as c:
        yield c


class TestHealth:
    def test_health_endpoint(self, client: TestClient) -> None:
        r = client.get("/health")
        assert r.status_code == 200
        data = r.json()
        assert data["status"] == "ok"
        assert "model_loaded" in data
        assert "uptime_seconds" in data

    def test_root_redirect(self, client: TestClient) -> None:
        r = client.get("/")
        assert r.status_code == 200
        data = r.json()
        assert "service" in data


class TestPrediction:
    _valid_payload = {
        "passenger_count": 2,
        "trip_distance": 3.5,
        "pickup_hour": 8,
        "pickup_dow": 1,
        "pickup_month": 3,
        "pickup_is_weekend": 0,
        "rate_code_id": 1,
        "payment_type": 1,
        "pu_location_zone": 10,
        "do_location_zone": 25,
        "vendor_id": 1,
    }

    def test_predict_valid_input(self, client: TestClient) -> None:
        r = client.post("/predict", json=self._valid_payload)
        # If model not trained yet, 503 is expected; otherwise 200
        assert r.status_code in (200, 503)
        if r.status_code == 200:
            data = r.json()
            assert "request_id" in data
            assert "predicted_duration_min" in data
            assert data["predicted_duration_min"] > 0

    def test_predict_invalid_input(self, client: TestClient) -> None:
        bad = dict(self._valid_payload, passenger_count=99)  # out of range
        r = client.post("/predict", json=bad)
        assert r.status_code == 422  # validation error

    def test_predict_missing_field(self, client: TestClient) -> None:
        bad = {k: v for k, v in self._valid_payload.items() if k != "trip_distance"}
        r = client.post("/predict", json=bad)
        assert r.status_code == 422

    def test_feedback_unknown_id(self, client: TestClient) -> None:
        r = client.post(
            "/predict/feedback",
            json={"request_id": "nonexistent-id", "actual_duration_min": 12.5},
        )
        assert r.status_code == 200
        data = r.json()
        assert data["matched"] is False


class TestMonitoring:
    def test_metrics_endpoint(self, client: TestClient) -> None:
        r = client.get("/monitor/metrics")
        assert r.status_code == 200
        data = r.json()
        assert "n_pending" in data

    def test_history_endpoint(self, client: TestClient) -> None:
        r = client.get("/monitor/history?log_type=drift&limit=10")
        assert r.status_code == 200
        assert isinstance(r.json(), list)

    def test_history_invalid_type(self, client: TestClient) -> None:
        r = client.get("/monitor/history?log_type=INVALID")
        assert r.status_code == 400