File size: 2,118 Bytes
198ccb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""Monitoring middleware for FastAPI."""

import time
import logging
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

from typing import Optional
from monitoring.prediction_logger import PredictionLogger
from monitoring.data_drift import DataDriftDetector
from monitoring.performance_monitor import PerformanceMonitor

logger = logging.getLogger(__name__)


class MonitoringMiddleware(BaseHTTPMiddleware):
    """
    Middleware that times requests and tags classification responses.

    Every request is forwarded to the next handler and its latency is
    measured. Responses for the paths listed in ``MONITORED_PATHS`` receive
    two extra headers: ``X-Process-Time`` (latency in milliseconds) and
    ``X-Monitored`` (always ``"true"``).

    The prediction logger, drift detector and performance monitor are stored
    on the instance; ``dispatch`` as written here does not call them.
    """

    # Exact request paths whose responses are tagged with monitoring headers.
    # Matching is by equality, so "/classify/" (trailing slash) is not matched.
    MONITORED_PATHS = frozenset({"/classify", "/classify/batch"})

    def __init__(
        self,
        app,
        prediction_logger: PredictionLogger,
        drift_detector: Optional[DataDriftDetector] = None,
        performance_monitor: Optional[PerformanceMonitor] = None,
    ):
        """
        Initialize monitoring middleware.

        Args:
            app: Application passed through to ``BaseHTTPMiddleware``.
            prediction_logger: Prediction logger instance, stored as
                ``self.prediction_logger``.
            drift_detector: Optional drift detector, stored as
                ``self.drift_detector``.
            performance_monitor: Optional performance monitor, stored as
                ``self.performance_monitor``.
        """
        super().__init__(app)
        self.prediction_logger = prediction_logger
        self.drift_detector = drift_detector
        self.performance_monitor = performance_monitor

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Forward the request, measure its latency, and tag monitored responses.

        Args:
            request: Incoming request; only ``request.url.path`` is inspected.
            call_next: Awaitable callable that produces the response.

        Returns:
            The response from ``call_next``. For monitored paths it carries
            the additional ``X-Process-Time`` and ``X-Monitored`` headers.
        """
        # perf_counter() is monotonic, so the interval cannot be skewed (or
        # go negative) when the system wall clock is adjusted mid-request.
        start_time = time.perf_counter()

        response = await call_next(request)

        latency_ms = (time.perf_counter() - start_time) * 1000

        # Only classification endpoints are tagged. No logging happens here;
        # the headers merely expose the timing to whoever reads the response.
        if request.url.path in self.MONITORED_PATHS:
            response.headers["X-Process-Time"] = str(latency_ms)
            response.headers["X-Monitored"] = "true"

        return response