"""Monitoring middleware for FastAPI."""
import time
import logging
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Optional
from monitoring.prediction_logger import PredictionLogger
from monitoring.data_drift import DataDriftDetector
from monitoring.performance_monitor import PerformanceMonitor
logger = logging.getLogger(__name__)
class MonitoringMiddleware(BaseHTTPMiddleware):
    """
    Middleware that instruments API requests for monitoring.

    Measures per-request latency and tags classification responses so
    that prediction logging, drift detection, and performance monitoring
    can happen downstream (in the endpoint or a background task).
    """

    def __init__(
        self,
        app,
        prediction_logger: PredictionLogger,
        drift_detector: Optional[DataDriftDetector] = None,
        performance_monitor: Optional[PerformanceMonitor] = None,
    ):
        """
        Initialize monitoring middleware.

        Args:
            app: FastAPI application
            prediction_logger: Prediction logger instance
            drift_detector: Optional drift detector
            performance_monitor: Optional performance monitor
        """
        super().__init__(app)
        # Collaborators are held for use by endpoints/background tasks;
        # dispatch() itself only times requests and tags responses.
        self.prediction_logger = prediction_logger
        self.drift_detector = drift_detector
        self.performance_monitor = performance_monitor

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Time the request and tag classification responses for monitoring."""
        started = time.time()
        response = await call_next(request)
        latency_ms = (time.time() - started) * 1000

        # Only classification endpoints get monitoring headers; the
        # actual prediction logging is done by the endpoint or a
        # background task, not here.
        if request.url.path in ("/classify", "/classify/batch"):
            response.headers["X-Process-Time"] = str(latency_ms)
            response.headers["X-Monitored"] = "true"

        return response