Spaces:
Running
Running
| """Request middleware: ID generation, logging, error handling, metrics, rate limiting.""" | |
| from __future__ import annotations | |
| import time | |
| import uuid | |
| from collections import defaultdict, deque | |
| import structlog | |
| from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint | |
| from starlette.requests import Request | |
| from starlette.responses import JSONResponse, Response | |
| from agent_bench.core.provider import ProviderRateLimitError, ProviderTimeoutError | |
| logger = structlog.get_logger() | |
| class MetricsCollector: | |
| """In-process metrics. Resets on restart.""" | |
| def __init__(self, maxlen: int = 1000) -> None: | |
| self.latencies: deque[float] = deque(maxlen=maxlen) | |
| self.requests_total: int = 0 | |
| self.errors_total: int = 0 | |
| self.total_cost_usd: float = 0.0 | |
| def record(self, latency_ms: float, cost_usd: float = 0.0, error: bool = False) -> None: | |
| self.latencies.append(latency_ms) | |
| self.requests_total += 1 | |
| self.total_cost_usd += cost_usd | |
| if error: | |
| self.errors_total += 1 | |
| def percentile(self, p: float) -> float: | |
| if not self.latencies: | |
| return 0.0 | |
| sorted_latencies = sorted(self.latencies) | |
| idx = int(len(sorted_latencies) * p / 100) | |
| idx = min(idx, len(sorted_latencies) - 1) | |
| return sorted_latencies[idx] | |
| def avg_cost(self) -> float: | |
| if self.requests_total == 0: | |
| return 0.0 | |
| return self.total_cost_usd / self.requests_total | |
| class RateLimitMiddleware(BaseHTTPMiddleware): | |
| """In-memory sliding window rate limiter, per client IP.""" | |
| EXEMPT_PATHS = {"/health", "/metrics"} | |
| def __init__(self, app: object, requests_per_minute: int = 10) -> None: | |
| super().__init__(app) # type: ignore[arg-type] | |
| self.rpm = requests_per_minute | |
| self.windows: dict[str, list[float]] = defaultdict(list) | |
| async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: | |
| if request.url.path in self.EXEMPT_PATHS: | |
| return await call_next(request) | |
| client_ip = request.client.host if request.client else "unknown" | |
| now = time.time() | |
| window_start = now - 60 | |
| # Prune timestamps outside the window | |
| self.windows[client_ip] = [ | |
| t for t in self.windows[client_ip] if t > window_start | |
| ] | |
| if len(self.windows[client_ip]) >= self.rpm: | |
| retry_after = max(1, int(60 - (now - self.windows[client_ip][0]))) | |
| logger.warning("rate_limited", | |
| client_ip=client_ip, | |
| requests_in_window=len(self.windows[client_ip])) | |
| return JSONResponse( | |
| status_code=429, | |
| content={"error": "Rate limit exceeded", "retry_after": retry_after}, | |
| headers={"Retry-After": str(retry_after)}, | |
| ) | |
| self.windows[client_ip].append(now) | |
| return await call_next(request) | |
| class RequestMiddleware(BaseHTTPMiddleware): | |
| """Adds request ID, timing, structured logging, and error handling.""" | |
| async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: | |
| request_id = str(uuid.uuid4()) | |
| request.state.request_id = request_id | |
| start = time.perf_counter() | |
| try: | |
| response = await call_next(request) | |
| latency_ms = (time.perf_counter() - start) * 1000 | |
| response.headers["X-Request-ID"] = request_id | |
| logger.info( | |
| "request_completed", | |
| method=request.method, | |
| path=str(request.url.path), | |
| status=response.status_code, | |
| latency_ms=round(latency_ms, 2), | |
| request_id=request_id, | |
| ) | |
| return response | |
| except ProviderTimeoutError: | |
| latency_ms = (time.perf_counter() - start) * 1000 | |
| logger.error( | |
| "provider_timeout", | |
| method=request.method, | |
| path=str(request.url.path), | |
| latency_ms=round(latency_ms, 2), | |
| request_id=request_id, | |
| ) | |
| metrics = getattr(request.app.state, "metrics", None) | |
| if metrics is not None: | |
| metrics.record(latency_ms, error=True) | |
| return JSONResponse( | |
| status_code=504, | |
| content={"detail": "Provider timed out", "request_id": request_id}, | |
| headers={"X-Request-ID": request_id}, | |
| ) | |
| except ProviderRateLimitError: | |
| latency_ms = (time.perf_counter() - start) * 1000 | |
| logger.error( | |
| "provider_rate_limit", | |
| method=request.method, | |
| path=str(request.url.path), | |
| latency_ms=round(latency_ms, 2), | |
| request_id=request_id, | |
| ) | |
| metrics = getattr(request.app.state, "metrics", None) | |
| if metrics is not None: | |
| metrics.record(latency_ms, error=True) | |
| return JSONResponse( | |
| status_code=503, | |
| content={ | |
| "detail": "Provider rate limit or quota exceeded", | |
| "request_id": request_id, | |
| }, | |
| headers={"X-Request-ID": request_id}, | |
| ) | |
| except Exception: | |
| latency_ms = (time.perf_counter() - start) * 1000 | |
| logger.exception( | |
| "unhandled_error", | |
| method=request.method, | |
| path=str(request.url.path), | |
| latency_ms=round(latency_ms, 2), | |
| request_id=request_id, | |
| ) | |
| metrics = getattr(request.app.state, "metrics", None) | |
| if metrics is not None: | |
| metrics.record(latency_ms, error=True) | |
| return JSONResponse( | |
| status_code=500, | |
| content={"detail": "Internal server error", "request_id": request_id}, | |
| headers={"X-Request-ID": request_id}, | |
| ) | |