# NOTE(review): removed non-Python file-viewer artifacts (page chrome, file
# size, line-number gutter) that preceded the module docstring and broke parsing.
"""
Integration tests for FastAPI inference service.
"""
import json
from pathlib import Path
import pytest
from fastapi.testclient import TestClient
# Note: These tests require a trained model to be present
# Run training first: python src/models/train.py --data_path ...
@pytest.fixture
def api_client():
    """Build a TestClient bound to the FastAPI app under test."""
    # Deferred import: the app (and any model loading it triggers at
    # startup) is only touched when a test actually requests this fixture.
    from src.api.main import app

    return TestClient(app)
@pytest.fixture
def sample_request_data():
    """A single well-formed prediction request payload."""
    return dict(
        user_id="test_user_123",
        trans_date_trans_time="2020-06-15 14:30:00",
        amt=150.00,
        lat=40.7128,
        long=-74.0060,
        merch_lat=40.7200,
        merch_long=-74.0100,
        job="Engineer, biomedical",
        category="grocery_pos",
        gender="M",
        dob="1985-03-20",
    )
class TestHealthEndpoint:
    """Behavior of the health-check endpoint."""

    def test_health_endpoint_exists(self, api_client):
        """GET /health responds with 200 OK."""
        assert api_client.get("/health").status_code == 200

    def test_health_response_structure(self, api_client):
        """GET /health payload carries every expected status key."""
        payload = api_client.get("/health").json()
        for key in ("status", "model_loaded", "redis_connected", "version"):
            assert key in payload
class TestPredictEndpoint:
    """Behavior of the prediction endpoint."""

    @pytest.mark.skip(reason="Requires trained model - run after training")
    def test_predict_endpoint_returns_200(self, api_client, sample_request_data):
        """POST /v1/predict succeeds for a valid payload."""
        resp = api_client.post("/v1/predict", json=sample_request_data)
        assert resp.status_code == 200

    @pytest.mark.skip(reason="Requires trained model - run after training")
    def test_predict_response_structure(self, api_client, sample_request_data):
        """Response carries all required fields with values in range."""
        payload = api_client.post("/v1/predict", json=sample_request_data).json()
        # Presence of every contract field.
        for field in ("decision", "probability", "risk_score",
                      "latency_ms", "shadow_mode"):
            assert field in payload
        # Value constraints on each field.
        assert payload["decision"] in ("BLOCK", "APPROVE")
        assert 0 <= payload["probability"] <= 1
        assert 0 <= payload["risk_score"] <= 100
        assert payload["latency_ms"] > 0

    @pytest.mark.skip(reason="Requires trained model - run after training")
    def test_latency_within_target(self, api_client, sample_request_data):
        """Reported latency stays under the 50ms budget for one prediction."""
        payload = api_client.post("/v1/predict", json=sample_request_data).json()
        assert payload["latency_ms"] < 50.0

    def test_predict_invalid_request(self, api_client):
        """A type-invalid payload is rejected with 422 Unprocessable Entity."""
        resp = api_client.post("/v1/predict", json={"amt": "not_a_number"})
        assert resp.status_code == 422
class TestRootEndpoint:
    """Behavior of the root endpoint."""

    def test_root_endpoint(self, api_client):
        """GET / returns 200 and basic service metadata."""
        resp = api_client.get("/")
        assert resp.status_code == 200
        body = resp.json()
        for key in ("service", "version", "endpoints"):
            assert key in body
class TestShadowMode:
    """Behavior of shadow-mode predictions."""

    @pytest.mark.skip(reason="Requires trained model and shadow mode config")
    def test_shadow_mode_always_approves(self, api_client, sample_request_data):
        """When the response reports shadow mode, the decision is APPROVE."""
        # Only meaningful when the service runs with shadow_mode=True.
        payload = api_client.post("/v1/predict", json=sample_request_data).json()
        if payload["shadow_mode"]:
            assert payload["decision"] == "APPROVE"

    @pytest.mark.skip(reason="Requires log file inspection")
    def test_shadow_mode_logs_predictions(self, api_client, sample_request_data, tmp_path):
        """Placeholder: confirm shadow predictions are written to the log."""
        # Would need to read logs/shadow_predictions.jsonl to verify that
        # each request produced a log record; not implemented yet.
        pass
# Allow running this module directly (python tests/...py) in addition to
# collection via the pytest CLI.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])