File size: 4,732 Bytes
8a08300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Integration tests for FastAPI inference service.
"""

import json
from pathlib import Path

import pytest
from fastapi.testclient import TestClient


# NOTE: Only the tests marked @pytest.mark.skip(reason="Requires trained model ...")
# need a trained model. Train one first (python src/models/train.py --data_path ...),
# then remove their skip markers to run them.


@pytest.fixture
def api_client():
    """Build a FastAPI ``TestClient`` around the service's ``app`` object.

    The application module is imported lazily, inside the fixture body, so
    that simply collecting this test module never imports ``src.api.main``
    (per the original author, this avoids startup issues before a model has
    been trained). The client is returned directly rather than entered as a
    context manager, so Starlette does not run the app's startup/lifespan
    handlers for it.
    """
    from src.api.main import app as service_app  # deferred on purpose, see docstring

    client = TestClient(service_app)
    return client


@pytest.fixture
def sample_request_data():
    """Return a well-formed ``/v1/predict`` request payload.

    Describes a 150.00 ``grocery_pos`` transaction whose customer and
    merchant coordinates lie a short distance apart. A new dict is built on
    every call, so tests may modify their copy freely.
    """
    payload = dict(
        user_id="test_user_123",
        trans_date_trans_time="2020-06-15 14:30:00",
        amt=150.00,
        lat=40.7128,
        long=-74.0060,
        merch_lat=40.7200,
        merch_long=-74.0100,
        job="Engineer, biomedical",
        category="grocery_pos",
        gender="M",
        dob="1985-03-20",
    )
    return payload


class TestHealthEndpoint:
    """Tests for health check endpoint."""

    def test_health_endpoint_exists(self, api_client):
        """Test that health endpoint is accessible."""
        response = api_client.get("/health")
        assert response.status_code == 200

    def test_health_response_structure(self, api_client):
        """Test health response has correct structure."""
        response = api_client.get("/health")
        data = response.json()

        assert "status" in data
        assert "model_loaded" in data
        assert "redis_connected" in data
        assert "version" in data


class TestPredictEndpoint:
    """Tests for prediction endpoint."""

    @pytest.mark.skip(reason="Requires trained model - run after training")
    def test_predict_endpoint_returns_200(self, api_client, sample_request_data):
        """Test that predict endpoint returns 200 OK."""
        response = api_client.post("/v1/predict", json=sample_request_data)
        assert response.status_code == 200

    @pytest.mark.skip(reason="Requires trained model - run after training")
    def test_predict_response_structure(self, api_client, sample_request_data):
        """Test prediction response has correct structure."""
        response = api_client.post("/v1/predict", json=sample_request_data)
        data = response.json()

        # Required fields
        assert "decision" in data
        assert "probability" in data
        assert "risk_score" in data
        assert "latency_ms" in data
        assert "shadow_mode" in data

        # Value constraints
        assert data["decision"] in ["BLOCK", "APPROVE"]
        assert 0 <= data["probability"] <= 1
        assert 0 <= data["risk_score"] <= 100
        assert data["latency_ms"] > 0

    @pytest.mark.skip(reason="Requires trained model - run after training")
    def test_latency_within_target(self, api_client, sample_request_data):
        """Test that latency is within 50ms target."""
        response = api_client.post("/v1/predict", json=sample_request_data)
        data = response.json()

        # Should be well under 50ms for single prediction
        assert data["latency_ms"] < 50.0

    def test_predict_invalid_request(self, api_client):
        """Test that invalid request returns 422 validation error."""
        invalid_data = {"amt": "not_a_number"}  # Invalid type
        response = api_client.post("/v1/predict", json=invalid_data)
        assert response.status_code == 422  # Unprocessable Entity


class TestRootEndpoint:
    """Tests for root endpoint."""

    def test_root_endpoint(self, api_client):
        """Test root endpoint returns API info."""
        response = api_client.get("/")
        assert response.status_code == 200

        data = response.json()
        assert "service" in data
        assert "version" in data
        assert "endpoints" in data


class TestShadowMode:
    """Tests for shadow mode functionality."""

    @pytest.mark.skip(reason="Requires trained model and shadow mode config")
    def test_shadow_mode_always_approves(self, api_client, sample_request_data):
        """With shadow mode enabled, the returned decision is always APPROVE.

        Requires the service to be configured with shadow mode on. When the
        response reports ``shadow_mode`` as false the test is skipped, rather
        than passing without having checked anything.
        """
        response = api_client.post("/v1/predict", json=sample_request_data)
        assert response.status_code == 200, response.text
        data = response.json()

        if not data["shadow_mode"]:
            # Previously this case passed silently; make it visible instead.
            pytest.skip("shadow mode is disabled in the service config")
        assert data["decision"] == "APPROVE"

    @pytest.mark.skip(reason="Requires log file inspection")
    def test_shadow_mode_logs_predictions(self, api_client, sample_request_data, tmp_path):
        """Placeholder: verify that shadow mode logs predictions to a file."""
        # TODO: inspect logs/shadow_predictions.jsonl after a prediction
        # to verify logging occurred; not implemented yet.
        pass


if __name__ == "__main__":
    # pytest.main returns the session's exit code; raise it via SystemExit so
    # running this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))