File size: 6,097 Bytes
871820a
2d8989f
 
 
 
 
871820a
2d8989f
 
 
 
 
 
2a4cb78
2d8989f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
871820a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d8989f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a4cb78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d8989f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""Request middleware: ID generation, logging, error handling, metrics, rate limiting."""

from __future__ import annotations

import math
import time
import uuid
from collections import defaultdict, deque

import structlog
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from agent_bench.core.provider import ProviderRateLimitError, ProviderTimeoutError

# Module-level structlog logger used by the middleware classes in this file.
logger = structlog.get_logger()


class MetricsCollector:
    """In-process metrics. Resets on restart.

    Keeps a bounded history of request latencies together with running
    totals for requests, errors and cost. Nothing is persisted; all state
    lives on the instance.
    """

    def __init__(self, maxlen: int = 1000) -> None:
        """Create an empty collector.

        Args:
            maxlen: Maximum number of latency samples retained; once full,
                the oldest sample is discarded for each new one.
        """
        # Bounded so percentile queries reflect only the most recent samples.
        self.latencies: deque[float] = deque(maxlen=maxlen)
        self.requests_total: int = 0
        self.errors_total: int = 0
        self.total_cost_usd: float = 0.0

    def record(self, latency_ms: float, cost_usd: float = 0.0, error: bool = False) -> None:
        """Record one finished request.

        Args:
            latency_ms: Request latency in milliseconds.
            cost_usd: Cost attributed to the request, added to the running total.
            error: When true, the request is also counted as an error.
        """
        self.latencies.append(latency_ms)
        self.requests_total += 1
        self.total_cost_usd += cost_usd
        if error:
            self.errors_total += 1

    def percentile(self, p: float) -> float:
        """Return the latency at percentile ``p`` of the retained samples.

        Args:
            p: Percentile on a 0-100 scale. Values outside that range are
                clamped, so ``p <= 0`` yields the smallest sample and
                ``p >= 100`` the largest.

        Returns:
            The selected latency, or ``0.0`` when no samples were recorded.
        """
        if not self.latencies:
            return 0.0
        sorted_latencies = sorted(self.latencies)
        last = len(sorted_latencies) - 1
        idx = int(len(sorted_latencies) * p / 100)
        # Clamp at both ends: without the lower bound a negative ``p`` would
        # become a negative index and silently read from the high end.
        idx = max(0, min(idx, last))
        return sorted_latencies[idx]

    @property
    def avg_cost(self) -> float:
        """Average cost per recorded request, or ``0.0`` before any request."""
        if self.requests_total == 0:
            return 0.0
        return self.total_cost_usd / self.requests_total


class RateLimitMiddleware(BaseHTTPMiddleware):
    """In-memory sliding window rate limiter, per client IP.

    Each client may make at most ``requests_per_minute`` requests within any
    ``WINDOW_SECONDS``-long window. Requests over the limit are answered with
    a 429 JSON response carrying a ``Retry-After`` header. Paths listed in
    ``EXEMPT_PATHS`` are never limited or counted.

    State is held on the middleware instance only, so limits are neither
    shared between instances nor preserved across restarts.
    """

    EXEMPT_PATHS = {"/health", "/metrics"}
    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60

    def __init__(self, app: object, requests_per_minute: int = 10) -> None:
        """Wrap ``app`` with a per-client request limit.

        Args:
            app: The ASGI application to wrap.
            requests_per_minute: Requests allowed per client per window.
        """
        super().__init__(app)  # type: ignore[arg-type]
        self.rpm = requests_per_minute
        # Client IP -> timestamps of that client's accepted requests.
        self.windows: dict[str, list[float]] = defaultdict(list)
        # Earliest time at which the next idle-client sweep may run.
        self._next_sweep: float = 0.0

    def _evict_idle_clients(self, now: float) -> None:
        """Drop clients with no request inside the current window.

        Without this, every distinct client IP ever seen would keep an entry
        in ``self.windows`` forever. The sweep runs at most once per window
        so its cost stays amortised.
        """
        if now < self._next_sweep:
            return
        self._next_sweep = now + self.WINDOW_SECONDS
        window_start = now - self.WINDOW_SECONDS
        idle = [
            ip
            for ip, stamps in self.windows.items()
            if all(t <= window_start for t in stamps)
        ]
        for ip in idle:
            del self.windows[ip]

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Forward the request, or reject it with 429 when the client is over its limit.

        Args:
            request: The incoming request.
            call_next: Callable that passes the request to the next handler.

        Returns:
            The downstream response, or a 429 ``JSONResponse`` with
            ``retry_after`` in the body and a ``Retry-After`` header.
        """
        if request.url.path in self.EXEMPT_PATHS:
            return await call_next(request)

        # Requests without client information share the "unknown" bucket.
        client_ip = request.client.host if request.client else "unknown"
        now = time.time()
        window_start = now - self.WINDOW_SECONDS

        self._evict_idle_clients(now)

        # Prune timestamps outside the window
        recent = [t for t in self.windows[client_ip] if t > window_start]
        self.windows[client_ip] = recent

        if len(recent) >= self.rpm:
            # ``recent`` is empty when the limit is zero; fall back to ``now``
            # so the client is told to wait one full window.
            oldest = recent[0] if recent else now
            # Round up: a truncated value could tell the client to retry
            # just before the oldest request has left the window.
            retry_after = max(1, math.ceil(self.WINDOW_SECONDS - (now - oldest)))
            logger.warning("rate_limited",
                           client_ip=client_ip,
                           requests_in_window=len(recent))
            return JSONResponse(
                status_code=429,
                content={"error": "Rate limit exceeded", "retry_after": retry_after},
                headers={"Retry-After": str(retry_after)},
            )

        recent.append(now)
        return await call_next(request)


class RequestMiddleware(BaseHTTPMiddleware):
    """Adds request ID, timing, structured logging, and error handling.

    Every request gets a fresh UUID4, stored on ``request.state.request_id``
    and returned in the ``X-Request-ID`` response header. Exceptions raised
    downstream are converted to JSON error responses:

    * ``ProviderTimeoutError``   -> 504
    * ``ProviderRateLimitError`` -> 503
    * any other ``Exception``    -> 500 (logged with traceback)

    Failed requests are also recorded on ``request.app.state.metrics`` when
    that attribute is present.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Run the request through the rest of the stack with logging and error mapping.

        Args:
            request: The incoming request.
            call_next: Callable that passes the request to the next handler.

        Returns:
            The downstream response with ``X-Request-ID`` set, or a JSON
            error response when the downstream handler raised.
        """
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id
        start = time.perf_counter()

        try:
            response = await call_next(request)
            latency_ms = (time.perf_counter() - start) * 1000

            response.headers["X-Request-ID"] = request_id

            logger.info(
                "request_completed",
                method=request.method,
                path=str(request.url.path),
                status=response.status_code,
                latency_ms=round(latency_ms, 2),
                request_id=request_id,
            )

            return response

        except ProviderTimeoutError:
            return self._error_response(
                request,
                request_id,
                start,
                event="provider_timeout",
                status_code=504,
                detail="Provider timed out",
            )

        except ProviderRateLimitError:
            return self._error_response(
                request,
                request_id,
                start,
                event="provider_rate_limit",
                status_code=503,
                detail="Provider rate limit or quota exceeded",
            )

        except Exception:
            return self._error_response(
                request,
                request_id,
                start,
                event="unhandled_error",
                status_code=500,
                detail="Internal server error",
                include_traceback=True,
            )

    @staticmethod
    def _error_response(
        request: Request,
        request_id: str,
        start: float,
        *,
        event: str,
        status_code: int,
        detail: str,
        include_traceback: bool = False,
    ) -> JSONResponse:
        """Log a failed request, record it as an error metric, and build the JSON reply.

        Must be called from inside the ``except`` block handling the failure
        so that ``logger.exception`` can pick up the active exception.

        Args:
            request: The request that failed.
            request_id: ID assigned to the request.
            start: ``time.perf_counter()`` reading taken when handling began.
            event: Log event name.
            status_code: HTTP status of the error response.
            detail: Message placed in the response body.
            include_traceback: Log via ``logger.exception`` instead of
                ``logger.error``.

        Returns:
            A ``JSONResponse`` containing ``detail`` and ``request_id``, with
            the ``X-Request-ID`` header set.
        """
        latency_ms = (time.perf_counter() - start) * 1000
        log = logger.exception if include_traceback else logger.error
        log(
            event,
            method=request.method,
            path=str(request.url.path),
            latency_ms=round(latency_ms, 2),
            request_id=request_id,
        )
        # The collector is optional; skip recording when the app has none.
        metrics = getattr(request.app.state, "metrics", None)
        if metrics is not None:
            metrics.record(latency_ms, error=True)
        return JSONResponse(
            status_code=status_code,
            content={"detail": detail, "request_id": request_id},
            headers={"X-Request-ID": request_id},
        )