File size: 5,864 Bytes
4a2ab42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b14fa0a
 
 
4a2ab42
 
 
 
 
 
 
b14fa0a
4a2ab42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b14fa0a
4a2ab42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
Performance Monitoring and Metrics Collection
Tracks API response times, database queries, fraud detection latency
"""

import functools
import logging
import time
from collections.abc import Callable
from typing import Any

from fastapi import Request, Response
from prometheus_client import Counter, Gauge, Histogram, generate_latest
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)

# Prometheus Metrics
# None of these constructors pass a ``registry=`` argument, so every metric is
# registered in prometheus_client's default registry when this module is
# imported; get_metrics() below serializes that registry.

# Request count, labelled by HTTP method, endpoint and response status.
# Incremented by PerformanceMonitoringMiddleware.
http_requests_total = Counter("http_requests_total", "Total HTTP requests", ["method", "endpoint", "status"])

# Request latency in seconds, labelled by method and endpoint. No ``buckets=``
# argument is given, so the library's default buckets apply.
http_request_duration_seconds = Histogram(
    "http_request_duration_seconds", "HTTP request latency", ["method", "endpoint"]
)

# Fraud detections by risk level (see record_fraud_detection()).
fraud_detections_total = Counter("fraud_detections_total", "Total fraud detections", ["risk_level"])

# AI predictions by model type (see record_ai_prediction()).
ai_predictions_total = Counter("ai_predictions_total", "Total AI predictions made", ["model_type"])

# Current number of pending fraud cases; overwritten by update_pending_cases().
pending_cases = Gauge("pending_cases", "Number of pending fraud cases")

# Shared DB-query histogram defined in core.query_monitoring; record_db_query()
# observes it with ``operation`` and ``table`` labels.
# NOTE(review): this import intentionally sits after the metric definitions
# (hence ``noqa: E402``), presumably to avoid a circular import — confirm
# before moving it to the top of the file.
from core.query_monitoring import query_duration  # noqa: E402

# Cache hits / misses by cache type (see record_cache_hit() / record_cache_miss()).
cache_hits_total = Counter("cache_hits_total", "Total cache hits", ["cache_type"])

cache_misses_total = Counter("cache_misses_total", "Total cache misses", ["cache_type"])

# Currently open WebSocket connections; adjusted by increment_ws_connections()
# and decrement_ws_connections().
websocket_connections = Gauge("websocket_connections", "Active WebSocket connections")


class PerformanceMonitoringMiddleware(BaseHTTPMiddleware):
    """Middleware that records per-request Prometheus metrics.

    For every HTTP request it:

    * increments ``http_requests_total``;
    * observes ``http_request_duration_seconds``;
    * adds an ``X-Response-Time`` header to the response;
    * logs a warning when the request took longer than
      ``slow_request_threshold`` seconds.

    Requests whose downstream handler raises are recorded with status 500
    before the exception propagates.
    """

    # Requests slower than this many seconds are logged as slow; override in a
    # subclass to tune.
    slow_request_threshold: float = 1.0

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Time ``call_next(request)`` and record metrics for the outcome."""
        # perf_counter() is monotonic. time.time() can jump when the system
        # clock is adjusted, which would give negative or inflated durations.
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            # An unhandled error produces no Response here. Outer server-error
            # handling normally turns it into a 500, so record it as such;
            # otherwise failing requests would be missing from the metrics.
            self._record(request, 500, time.perf_counter() - start_time)
            raise

        duration = time.perf_counter() - start_time
        self._record(request, response.status_code, duration)
        response.headers["X-Response-Time"] = f"{duration:.4f}s"
        return response

    def _record(self, request: Request, status: int, duration: float) -> None:
        """Update the request counter and latency histogram; log slow requests."""
        method = request.method
        endpoint = self._endpoint_label(request)
        http_requests_total.labels(method=method, endpoint=endpoint, status=status).inc()
        http_request_duration_seconds.labels(method=method, endpoint=endpoint).observe(duration)
        if duration > self.slow_request_threshold:
            # Log the concrete path (not the route template) so the slow
            # request can be identified.
            logger.warning("Slow request: %s %s took %.2fs", method, request.url.path, duration)

    @staticmethod
    def _endpoint_label(request: Request) -> str:
        """Return the matched route template, falling back to the raw URL path.

        FastAPI's routing stores the matched route in ``request.scope["route"]``.
        Its ``path`` (e.g. ``/cases/{case_id}``) keeps the ``endpoint`` label
        bounded, instead of creating one time series per distinct ID in the
        URL. Requests that matched no route fall back to ``request.url.path``.
        """
        route = request.scope.get("route")
        return getattr(route, "path", None) or request.url.path


def track_performance(metric_name: str | None = None):
    """Decorator to track function performance"""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs) -> Any:
            start_time = time.time()
            try:
                result = await func(*args, **kwargs)
                return result
            finally:
                duration = time.time() - start_time
                name = metric_name or func.__name__
                logger.debug(f"{name} took {duration:.4f}s")

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs) -> Any:
            start_time = time.time()
            try:
                result = func(*args, **kwargs)
                return result
            finally:
                duration = time.time() - start_time
                name = metric_name or func.__name__
                logger.debug(f"{name} took {duration:.4f}s")

        # Return appropriate wrapper based on function type
        import asyncio

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    return decorator


def record_fraud_detection(risk_level: str):
    """Count one fraud detection under the given risk level.

    Args:
        risk_level: Value for the ``risk_level`` label of ``fraud_detections_total``.
    """
    series = fraud_detections_total.labels(risk_level=risk_level)
    series.inc(1)


def record_ai_prediction(model_type: str):
    """Count one AI prediction for the given model type.

    Args:
        model_type: Value for the ``model_type`` label of ``ai_predictions_total``.
    """
    series = ai_predictions_total.labels(model_type=model_type)
    series.inc(1)


def update_pending_cases(count: int):
    """Overwrite the ``pending_cases`` gauge with the current count.

    Args:
        count: Number of fraud cases currently pending.
    """
    gauge = pending_cases
    gauge.set(count)


def record_db_query(query_type: str, duration: float):
    """Observe one database query duration on the shared ``query_duration`` histogram.

    The ``operation`` label is always ``"general"``; ``query_type`` is reported
    under the ``table`` label.

    Args:
        query_type: Value used for the ``table`` label.
        duration: Observed query time (the value passed to ``observe``).
    """
    series = query_duration.labels(table=query_type, operation="general")
    series.observe(duration)


def record_cache_hit(cache_type: str = "default"):
    """Count one cache hit for the given cache type.

    Args:
        cache_type: Value for the ``cache_type`` label of ``cache_hits_total``.
    """
    series = cache_hits_total.labels(cache_type=cache_type)
    series.inc(1)


def record_cache_miss(cache_type: str = "default"):
    """Count one cache miss for the given cache type.

    Args:
        cache_type: Value for the ``cache_type`` label of ``cache_misses_total``.
    """
    series = cache_misses_total.labels(cache_type=cache_type)
    series.inc(1)


def increment_ws_connections():
    """Add one to the ``websocket_connections`` gauge (a connection opened)."""
    gauge = websocket_connections
    gauge.inc(1)


def decrement_ws_connections():
    """Subtract one from the ``websocket_connections`` gauge (a connection closed)."""
    gauge = websocket_connections
    gauge.dec(1)


class PerformanceTracker:
    """Context manager that times a block of code and logs the result.

    On exit, the elapsed time is logged at INFO level. It is reported as
    "completed" on success and "failed" when the ``with`` body raised.

    If ``operation_name`` contains ``"database"`` (case-insensitive) and not
    ``"fraud_detection"``, the duration is also recorded with
    ``record_db_query("general", ...)``, whether or not the body raised.

    Exceptions raised inside the ``with`` body are never suppressed.

    Args:
        operation_name: Label used in the log line and for metric routing.
    """

    def __init__(self, operation_name: str):
        self.operation_name = operation_name
        # Wall-clock (time.time()) start, set by __enter__; kept as a public
        # timestamp for any caller that reads it.
        self.start_time: float | None = None
        # Monotonic start used for the actual duration measurement.
        self._perf_start: float | None = None

    def __enter__(self):
        self.start_time = time.time()
        self._perf_start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # perf_counter() is monotonic, so system clock adjustments cannot skew
        # the measured duration the way time.time() differences can.
        duration = time.perf_counter() - self._perf_start
        if exc_type is None:
            logger.info("Performance: %s completed in %.4fs", self.operation_name, duration)
        else:
            # Do not report an operation that raised as "completed".
            logger.info(
                "Performance: %s failed after %.4fs (%s)",
                self.operation_name,
                duration,
                exc_type.__name__,
            )

        # Record to Prometheus if it's a known metric.
        operation = self.operation_name.lower()
        if "fraud_detection" in operation:
            # TODO: no fraud-detection latency metric exists yet, so these
            # timings are only logged above.
            pass
        elif "database" in operation:
            record_db_query("general", duration)
        # Returning None (falsy) lets any exception from the body propagate.


# Helper function to get metrics
def get_metrics():
    """Serialize the Prometheus metrics for a scrape endpoint.

    Returns whatever ``generate_latest()`` produces when called with no
    registry argument, i.e. for prometheus_client's default registry.
    """
    exposition = generate_latest()
    return exposition


# Example usage in endpoints. The triple-quoted string below is a bare
# module-level expression: it is evaluated and discarded, never executed as
# code, and serves only to show how the helpers above combine.
"""
@app.get("/api/fraud/analyze")
@track_performance("fraud_analysis")
async def analyze_transaction(transaction: Transaction):
    with PerformanceTracker("fraud_detection_complete"):
        result = await fraud_service.analyze(transaction)
        record_fraud_detection(result.risk_level)
        return result
"""