ModPilot / api /middleware.py
ThejasRao's picture
Deploy ModPilot Investigation Engine
7302343
Raw
History Blame Contribute Delete
4.03 kB
"""HMAC + correlation-id middleware.
Spec: docs/Specs.md §10 (HMAC), §15 (correlation_id in every log).
Dev mode (HMAC not enforced): the middleware is permissive — it lets requests through
unverified, logging a warning when `ENGINE_SHARED_SECRET` is unset. Production (HMAC
enforced): requests without a valid signature are rejected with 401 UNAUTHORIZED.
"""
from __future__ import annotations

import hmac
import math
import time
import uuid
from hashlib import sha256
from typing import TYPE_CHECKING

import structlog
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from api.config import get_settings
from api.errors import ErrorBody, ErrorEnvelope

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from fastapi import Request, Response
logger = structlog.get_logger(__name__)

# Paths that skip HMAC verification entirely (health probe and the
# documentation/schema routes).
_PUBLIC_PATHS: frozenset[str] = frozenset(
    ["/health", "/docs", "/openapi.json", "/redoc"]
)

# Header names (lowercase); these must stay in sync with the Devvit-side client.
HEADER_SIGNATURE = "x-modpilot-signature"
HEADER_TIMESTAMP = "x-modpilot-timestamp"
HEADER_CORRELATION = "x-correlation-id"

# Largest accepted difference, in seconds, between a signed request's
# timestamp and the server clock (five minutes).
_MAX_SKEW_SECONDS = 5 * 60
# Client-supplied correlation ids are unauthenticated and get echoed into every
# log line and the response header, so only short ids from a conservative
# character set are accepted; anything else is replaced with a fresh UUID4.
_CORRELATION_ID_MAX_LEN = 128
_CORRELATION_ID_CHARS = frozenset(
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_.:"
)


def _resolve_correlation_id(raw: str | None) -> str:
    """Return *raw* if it is a well-formed correlation id, else a new UUID4 string."""
    if (
        raw
        and len(raw) <= _CORRELATION_ID_MAX_LEN
        and set(raw) <= _CORRELATION_ID_CHARS
    ):
        return raw
    return str(uuid.uuid4())


class CorrelationIdMiddleware(BaseHTTPMiddleware):
    """Bind `correlation_id` into structlog context for the request's lifetime."""

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Tag downstream logs with the request's correlation id and echo it back.

        The id comes from the ``x-correlation-id`` request header when it is
        well-formed, otherwise a new UUID4 is generated. ``correlation_id``,
        ``path`` and ``method`` are bound into structlog's contextvars while
        ``call_next`` runs, cleared afterwards (even on error), and the id is
        set on the response's ``x-correlation-id`` header.
        """
        correlation_id = _resolve_correlation_id(request.headers.get(HEADER_CORRELATION))
        # Start from an empty context so nothing bound earlier in this context
        # leaks into this request's log lines.
        structlog.contextvars.clear_contextvars()
        structlog.contextvars.bind_contextvars(
            correlation_id=correlation_id,
            path=request.url.path,
            method=request.method,
        )
        try:
            response = await call_next(request)
        finally:
            structlog.contextvars.clear_contextvars()
        response.headers[HEADER_CORRELATION] = correlation_id
        return response
class HmacMiddleware(BaseHTTPMiddleware):
    """Verify `X-Modpilot-Signature` against HMAC-SHA256 of the raw body.

    The signed message is ``f"{timestamp}.".encode() + raw_body``, where
    ``timestamp`` is the exact ``x-modpilot-timestamp`` header value; the
    signature header must carry the hex digest.

    Dev mode (``settings.hmac_enforced`` is false): log + pass through. This keeps
    the invariant "Engine refuses to start without a secret in prod" landing in
    F-0.7. Until then, enforced mode with no secret configured fails closed.
    """

    async def dispatch(  # noqa: PLR0911 — guard chain reads cleaner with explicit returns
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Verify the request signature, then forward to the next handler.

        Public paths and ``OPTIONS`` requests skip verification. When HMAC is
        enforced, a 401 ``UNAUTHORIZED`` envelope is returned if the secret is
        missing, either header is missing, the timestamp is malformed or
        non-finite, the timestamp is more than ``_MAX_SKEW_SECONDS`` away from
        the server clock, or the signature does not match.
        """
        if request.url.path in _PUBLIC_PATHS or request.method == "OPTIONS":
            return await call_next(request)

        settings = get_settings()
        if not settings.hmac_enforced:
            if not settings.engine_shared_secret:
                logger.warning(
                    "hmac.permissive",
                    reason="ENGINE_SHARED_SECRET not configured",
                    env=settings.env,
                )
            return await call_next(request)

        secret = settings.engine_shared_secret
        if not secret:
            # Enforced but unconfigured: an empty key would make signatures
            # trivially forgeable (and None would crash on .encode), so fail closed.
            logger.error(
                "hmac.misconfigured",
                reason="ENGINE_SHARED_SECRET not configured while HMAC is enforced",
                env=settings.env,
            )
            return _unauthorized("signature verification unavailable")

        signature = request.headers.get(HEADER_SIGNATURE, "")
        timestamp = request.headers.get(HEADER_TIMESTAMP, "")
        if not signature or not timestamp:
            return _unauthorized("missing signature or timestamp")

        try:
            sent_at = float(timestamp)
        except ValueError:
            return _unauthorized("malformed timestamp")
        # float() accepts "nan"/"inf"; a NaN skew compares False against the
        # limit below and would slip past the freshness check.
        if not math.isfinite(sent_at):
            return _unauthorized("malformed timestamp")
        skew = abs(time.time() - sent_at)
        if skew > _MAX_SKEW_SECONDS:
            return _unauthorized(f"timestamp skew {skew:.0f}s exceeds {_MAX_SKEW_SECONDS}s")

        # Only buffer the body once the cheap header/timestamp checks pass.
        body = await request.body()
        expected = hmac.new(
            secret.encode("utf-8"),
            f"{timestamp}.".encode() + body,
            sha256,
        ).hexdigest()
        # Compare bytes: compare_digest raises TypeError for non-ASCII str
        # input, which a client-controlled header could turn into a 500.
        if not hmac.compare_digest(expected.encode("ascii"), signature.encode("utf-8")):
            return _unauthorized("signature mismatch")

        return await call_next(request)
def _unauthorized(message: str) -> JSONResponse:
    """Build a 401 response wrapping an ``UNAUTHORIZED`` error envelope.

    Args:
        message: Human-readable rejection reason, stored in ``error.message``.

    Returns:
        A ``JSONResponse`` with status 401 whose content is the dumped
        ``ErrorEnvelope``; ``retryable`` is always ``False``.
    """
    error = ErrorBody(code="UNAUTHORIZED", message=message, retryable=False)
    envelope = ErrorEnvelope(error=error)
    payload = envelope.model_dump()
    return JSONResponse(content=payload, status_code=401)