# apoorvrajdev's picture
# feat(api): build production-grade FastAPI inference backend
# 08f1adc
"""HTTP-layer logging glue.
The ML package already configures structlog (``captioning.utils.logging``).
The FastAPI process has two extra needs on top of that:
1. **Request correlation** — every log line emitted while handling a
request should carry the same ``request_id`` so logs can be grouped.
We bind it once via ``structlog.contextvars`` so any ``log.info(...)``
downstream automatically inherits it without threading the id through
function signatures.
2. **Access logs as structured events** — uvicorn's default access log is
a plain string. Re-emitting one structured ``request_finished`` event
per request keeps the log stream homogeneous and indexable.
"""
from __future__ import annotations
import time
import uuid
from collections.abc import Awaitable, Callable
import structlog
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
from captioning.utils.logging import configure_logging, get_logger
# Module-level logger obtained from the ML package's ``get_logger`` helper;
# the request middleware in this module emits its request_started /
# request_finished / request_failed events through it.
log = get_logger(__name__)
# Header name used both to read an inbound request id and to echo the chosen
# id back on the response.
REQUEST_ID_HEADER = "x-request-id"
def configure_app_logging() -> None:
    """Set up structlog for the FastAPI process.

    This is a thin pass-through to ``captioning.utils.logging.configure_logging``
    so the HTTP layer shares the ML package's logging setup instead of
    defining its own. Per the original notes on this function, that helper is
    idempotent and renders colourised output in development and JSON when
    ``APP_ENV=production`` (behaviour lives in the ML package, not here).
    """
    configure_logging()
class RequestContextMiddleware(BaseHTTPMiddleware):
    """Bind a request id to structlog and log start/finish events.

    The id is taken from the inbound ``x-request-id`` header when that value
    is usable (so an upstream gateway can stitch traces); otherwise a fresh
    ``uuid4`` hex is generated. The chosen id is echoed back on the response
    of every request that completes without raising.
    """

    # The inbound header is client-controlled and ends up on every log line
    # for the request plus the response headers, so its size is capped.
    _MAX_REQUEST_ID_LENGTH = 128

    @classmethod
    def _resolve_request_id(cls, inbound: str | None) -> str:
        """Return a safe request id, preferring the caller-supplied one.

        Args:
            inbound: Raw ``x-request-id`` header value, or ``None`` if absent.

        Returns:
            ``inbound`` (stripped of surrounding whitespace) when it is
            non-empty, at most ``_MAX_REQUEST_ID_LENGTH`` characters long and
            consists solely of printable ASCII; otherwise a new ``uuid4`` hex.
        """
        candidate = (inbound or "").strip()
        if (
            candidate
            and len(candidate) <= cls._MAX_REQUEST_ID_LENGTH
            and candidate.isascii()
            and candidate.isprintable()
        ):
            return candidate
        # Missing, oversized, or containing control / non-ASCII characters:
        # do not propagate it into logs or response headers.
        return uuid.uuid4().hex

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Handle one request with a bound logging context.

        Args:
            request: The incoming Starlette request.
            call_next: Callable that forwards the request down the stack and
                resolves to the response.

        Returns:
            The downstream response with the ``x-request-id`` header set to
            the id used for this request.

        Raises:
            Exception: Anything raised by ``call_next`` is logged as
                ``request_failed`` (with traceback and duration) and then
                re-raised unchanged.
        """
        request_id = self._resolve_request_id(request.headers.get(REQUEST_ID_HEADER))

        # Drop whatever was bound previously before binding this request's
        # fields, so stale values cannot show up on this request's log lines.
        structlog.contextvars.clear_contextvars()
        structlog.contextvars.bind_contextvars(
            request_id=request_id,
            method=request.method,
            path=request.url.path,
        )

        start = time.perf_counter()
        log.info("request_started")
        try:
            response = await call_next(request)
        except Exception:
            duration_ms = (time.perf_counter() - start) * 1000
            log.exception("request_failed", duration_ms=round(duration_ms, 2))
            raise

        duration_ms = (time.perf_counter() - start) * 1000
        log.info(
            "request_finished",
            status=response.status_code,
            duration_ms=round(duration_ms, 2),
        )
        response.headers[REQUEST_ID_HEADER] = request_id
        return response
def current_request_id() -> str:
    """Look up the ``request_id`` bound in structlog's contextvars.

    Returns:
        The bound value coerced to ``str``, or ``""`` when no ``request_id``
        is bound in the current context.
    """
    bound_vars = structlog.contextvars.get_contextvars()
    request_id = bound_vars.get("request_id", "")
    return str(request_id)