"""Generic HTTP middleware: request-id propagation and body size limits.""" from __future__ import annotations import uuid from starlette.datastructures import MutableHeaders from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response from starlette.types import ASGIApp, Message, Receive, Scope, Send class RequestIDMiddleware(BaseHTTPMiddleware): """Read or generate ``X-Request-ID`` and stash it on ``request.state``.""" def __init__(self, app: ASGIApp, header_name: str = "X-Request-ID") -> None: super().__init__(app) self.header_name = header_name async def dispatch(self, request: Request, call_next): # type: ignore[override] rid = request.headers.get(self.header_name) or uuid.uuid4().hex request.state.request_id = rid response: Response = await call_next(request) response.headers[self.header_name] = rid return response class BodySizeLimitMiddleware: """Reject requests whose body exceeds ``max_bytes`` early. Inspects the ``Content-Length`` header (if present) and also enforces a streaming cap. Returns ``413`` with a JSON error body that matches the rest of the API. """ def __init__(self, app: ASGIApp, max_bytes: int) -> None: self._app = app self._max_bytes = max(1024, max_bytes) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": await self._app(scope, receive, send) return request_headers = MutableHeaders(scope=scope) cl = request_headers.get("content-length") if cl is not None: try: if int(cl) > self._max_bytes: await _send_413(send, self._max_bytes) return except ValueError: pass received = 0 too_large = False async def receive_capped() -> Message: nonlocal received, too_large message = await receive() if message.get("type") == "http.request": body: bytes = message.get("body", b"") or b"" received += len(body) if received > self._max_bytes: too_large = True return message sent_413 = False async def send_wrapper(message: Message) -> None: nonlocal sent_413 if too_large and not sent_413: sent_413 = True await _send_413(send, self._max_bytes) return if not too_large: await send(message) try: await self._app(scope, receive_capped, send_wrapper) except Exception: if too_large and not sent_413: await _send_413(send, self._max_bytes) return raise class RateLimitHeadersMiddleware(BaseHTTPMiddleware): """Copy ``request.state.rate_limit_headers`` (if set by a route) onto the outgoing response.""" async def dispatch(self, request: Request, call_next): # type: ignore[override] response: Response = await call_next(request) headers = getattr(request.state, "rate_limit_headers", None) if headers: for name, value in headers.items(): response.headers[name] = value return response async def _send_413(send: Send, max_bytes: int) -> None: body = ( b'{"error":{"message":"request body too large","type":"invalid_request_error",' b'"code":"payload_too_large","limit_bytes":' + str(max_bytes).encode() + b"}}" ) await send( { "type": "http.response.start", "status": 413, "headers": [ (b"content-type", b"application/json"), (b"content-length", str(len(body)).encode()), ], } ) await send({"type": "http.response.body", "body": body, "more_body": False})