# agentbench/agent_bench/serving/middleware.py
# Nomearod's picture
# feat: add provider retry with backoff and API rate limiting
# 871820a
"""Request middleware: ID generation, logging, error handling, metrics, rate limiting."""
from __future__ import annotations

import math
import time
import uuid
from collections import defaultdict, deque

import structlog
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from agent_bench.core.provider import ProviderRateLimitError, ProviderTimeoutError

# Module-level structlog logger used by the middleware classes in this file.
logger = structlog.get_logger()
class MetricsCollector:
    """Process-local request metrics; all counters start from zero on construction.

    Keeps the most recent ``maxlen`` latencies in a bounded deque plus running
    totals for request count, error count, and accumulated cost in USD.
    """

    def __init__(self, maxlen: int = 1000) -> None:
        """Create an empty collector that retains at most ``maxlen`` latencies."""
        self.requests_total: int = 0
        self.errors_total: int = 0
        self.total_cost_usd: float = 0.0
        # Bounded: once full, appending discards the oldest sample.
        self.latencies: deque[float] = deque(maxlen=maxlen)

    def record(self, latency_ms: float, cost_usd: float = 0.0, error: bool = False) -> None:
        """Register one request: its latency, its cost, and whether it failed."""
        if error:
            self.errors_total += 1
        self.requests_total += 1
        self.total_cost_usd += cost_usd
        self.latencies.append(latency_ms)

    def percentile(self, p: float) -> float:
        """Return the ``p``-th percentile of retained latencies, or 0.0 if none.

        The sample is picked by truncating ``count * p / 100`` to an index into
        the sorted latencies, capped at the last element.
        """
        count = len(self.latencies)
        if count == 0:
            return 0.0
        ordered = sorted(self.latencies)
        rank = min(int(count * p / 100), count - 1)
        return ordered[rank]

    @property
    def avg_cost(self) -> float:
        """Mean cost in USD per recorded request; 0.0 before any request."""
        if self.requests_total != 0:
            return self.total_cost_usd / self.requests_total
        return 0.0
class RateLimitMiddleware(BaseHTTPMiddleware):
    """In-memory sliding window rate limiter, per client IP.

    Each client may make at most ``requests_per_minute`` non-exempt requests
    within any ``WINDOW_SECONDS`` span. Requests over the limit are answered
    with HTTP 429 and a ``Retry-After`` header instead of being forwarded.

    State is a plain dict on this instance, so it is not shared with other
    instances and is gone when the instance is.

    NOTE(review): clients are keyed by ``request.client.host``; behind a
    reverse proxy that is presumably the proxy's address, which would put all
    users in one bucket -- confirm against the deployment.
    """

    EXEMPT_PATHS = {"/health", "/metrics"}
    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60

    def __init__(self, app: object, requests_per_minute: int = 10) -> None:
        """Wrap ``app`` and allow ``requests_per_minute`` requests per client."""
        super().__init__(app)  # type: ignore[arg-type]
        self.rpm = requests_per_minute
        # client IP -> time.monotonic() stamps of accepted requests, oldest first.
        self.windows: dict[str, list[float]] = defaultdict(list)
        # When idle clients were last evicted from ``windows``.
        self._last_sweep: float = time.monotonic()

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Forward the request, or return 429 if the client is over its limit.

        Paths in ``EXEMPT_PATHS`` are forwarded without being counted.
        """
        if request.url.path in self.EXEMPT_PATHS:
            return await call_next(request)

        client_ip = request.client.host if request.client else "unknown"
        # Monotonic clock: window arithmetic must not be affected by wall-clock
        # adjustments.
        now = time.monotonic()
        window_start = now - self.WINDOW_SECONDS
        self._evict_idle_clients(now)

        # .get() rather than indexing, so the defaultdict does not create an
        # entry for a client that has nothing to store.
        recent = [t for t in self.windows.get(client_ip, ()) if t > window_start]

        if len(recent) >= self.rpm:
            if recent:
                self.windows[client_ip] = recent
            else:
                self.windows.pop(client_ip, None)
            retry_after = self._retry_after(recent, now)
            logger.warning(
                "rate_limited",
                client_ip=client_ip,
                requests_in_window=len(recent),
            )
            return JSONResponse(
                status_code=429,
                content={"error": "Rate limit exceeded", "retry_after": retry_after},
                headers={"Retry-After": str(retry_after)},
            )

        recent.append(now)
        self.windows[client_ip] = recent
        return await call_next(request)

    def _retry_after(self, recent: list[float], now: float) -> int:
        """Return whole seconds until the oldest request in ``recent`` expires.

        Rounded up so a client that waits exactly this long is not rejected
        again. With no recorded requests (only possible when the limit is
        zero or negative) the full window length is returned.
        """
        if not recent:
            return self.WINDOW_SECONDS
        remaining = self.WINDOW_SECONDS - (now - recent[0])
        return max(1, math.ceil(remaining))

    def _evict_idle_clients(self, now: float) -> None:
        """Drop clients with no request inside the window; at most once per window.

        Without this, every distinct client IP ever seen would keep an entry
        in ``windows`` for the lifetime of the instance.
        """
        if now - self._last_sweep < self.WINDOW_SECONDS:
            return
        self._last_sweep = now
        cutoff = now - self.WINDOW_SECONDS
        # Stamps are appended in increasing order, so the last one is the newest.
        idle = [
            ip for ip, stamps in self.windows.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for ip in idle:
            del self.windows[ip]
class RequestMiddleware(BaseHTTPMiddleware):
    """Adds request ID, timing, structured logging, and error handling.

    Every request gets a fresh UUID4 stored on ``request.state.request_id``
    and echoed back in the ``X-Request-ID`` response header. Exceptions that
    escape the downstream app are converted to JSON error responses:

    * ``ProviderTimeoutError``   -> 504
    * ``ProviderRateLimitError`` -> 503
    * any other ``Exception``    -> 500 (logged with traceback)
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Run the downstream app, logging the outcome and mapping failures to JSON."""
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id
        start = time.perf_counter()
        try:
            response = await call_next(request)
            latency_ms = (time.perf_counter() - start) * 1000
            response.headers["X-Request-ID"] = request_id
            logger.info(
                "request_completed",
                method=request.method,
                path=str(request.url.path),
                status=response.status_code,
                latency_ms=round(latency_ms, 2),
                request_id=request_id,
            )
            return response
        except ProviderTimeoutError:
            return self._error_response(
                request,
                request_id,
                start,
                event="provider_timeout",
                status_code=504,
                detail="Provider timed out",
            )
        except ProviderRateLimitError:
            return self._error_response(
                request,
                request_id,
                start,
                event="provider_rate_limit",
                status_code=503,
                detail="Provider rate limit or quota exceeded",
            )
        except Exception:
            return self._error_response(
                request,
                request_id,
                start,
                event="unhandled_error",
                status_code=500,
                detail="Internal server error",
                with_traceback=True,
            )

    @staticmethod
    def _error_response(
        request: Request,
        request_id: str,
        start: float,
        *,
        event: str,
        status_code: int,
        detail: str,
        with_traceback: bool = False,
    ) -> JSONResponse:
        """Log a failed request, record it in metrics, and build the JSON error.

        Must be called from inside the ``except`` block handling the failure,
        so that ``logger.exception`` is invoked while the exception is still
        being handled.

        Args:
            request: The request being handled.
            request_id: ID placed in the body and the ``X-Request-ID`` header.
            start: ``time.perf_counter()`` value taken when handling began.
            event: Log event name.
            status_code: HTTP status of the error response.
            detail: Message placed in the response body.
            with_traceback: Log via ``logger.exception`` instead of
                ``logger.error``.
        """
        latency_ms = (time.perf_counter() - start) * 1000
        log = logger.exception if with_traceback else logger.error
        log(
            event,
            method=request.method,
            path=str(request.url.path),
            latency_ms=round(latency_ms, 2),
            request_id=request_id,
        )
        # Metrics are optional: only recorded when the app has a collector on
        # its state.
        metrics = getattr(request.app.state, "metrics", None)
        if metrics is not None:
            metrics.record(latency_ms, error=True)
        return JSONResponse(
            status_code=status_code,
            content={"detail": detail, "request_id": request_id},
            headers={"X-Request-ID": request_id},
        )