|
|
""" |
|
|
Middleware for correlation ID and request tracking |
|
|
""" |
|
|
import contextvars |
|
|
import time |
|
|
import uuid |
|
|
from typing import Callable |
|
|
|
|
|
from fastapi import Request, Response |
|
|
from starlette.middleware.base import BaseHTTPMiddleware |
|
|
from starlette.types import ASGIApp |
|
|
|
|
|
from .config import settings |
|
|
from .logging import get_logger |
|
|
|
|
|
correlation_id_ctx = contextvars.ContextVar("correlation_id", default=None) |
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
class CorrelationIdMiddleware(BaseHTTPMiddleware): |
|
|
"""Add correlation ID to requests and responses""" |
|
|
|
|
|
async def dispatch(self, request: Request, call_next: Callable) -> Response: |
|
|
|
|
|
correlation_id = request.headers.get( |
|
|
settings.correlation_id_header, |
|
|
str(uuid.uuid4()) |
|
|
) |
|
|
|
|
|
|
|
|
correlation_id_ctx.set(correlation_id) |
|
|
|
|
|
|
|
|
response = await call_next(request) |
|
|
|
|
|
|
|
|
response.headers[settings.correlation_id_header] = correlation_id |
|
|
|
|
|
return response |
|
|
|
|
|
|
|
|
class RequestLoggingMiddleware(BaseHTTPMiddleware): |
|
|
"""Log all requests with timing""" |
|
|
|
|
|
async def dispatch(self, request: Request, call_next: Callable) -> Response: |
|
|
start_time = time.time() |
|
|
|
|
|
|
|
|
logger.info( |
|
|
"request_started", |
|
|
method=request.method, |
|
|
path=request.url.path, |
|
|
query_params=str(request.query_params), |
|
|
) |
|
|
|
|
|
|
|
|
response = await call_next(request) |
|
|
|
|
|
|
|
|
duration = time.time() - start_time |
|
|
|
|
|
|
|
|
logger.info( |
|
|
"request_completed", |
|
|
method=request.method, |
|
|
path=request.url.path, |
|
|
status_code=response.status_code, |
|
|
duration_ms=round(duration * 1000, 2), |
|
|
) |
|
|
|
|
|
return response |
|
|
|