File size: 4,027 Bytes
7302343
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""HMAC + correlation-id middleware.

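Spec: docs/Specs.md §10 (HMAC), §15 (correlation_id in every log).
When HMAC is not enforced (`settings.hmac_enforced` is false, e.g. in dev), the
middleware is permissive: requests pass through unverified, and a warning is
logged if `ENGINE_SHARED_SECRET` is unset. When enforced (production), requests
without a valid signature and timestamp are rejected with UNAUTHORIZED.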
"""

from __future__ import annotations

import hmac
import math
import time
import uuid
from hashlib import sha256
from typing import TYPE_CHECKING

import structlog
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

from api.config import get_settings
from api.errors import ErrorBody, ErrorEnvelope

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from fastapi import Request, Response

# Module-level structlog logger used by the middleware classes in this file.
logger = structlog.get_logger(__name__)

# Paths that skip HMAC verification entirely: the health probe plus the API
# docs/schema routes. NOTE(review): the docs routes are reachable unsigned as
# well — confirm that is intended outside dev.
_PUBLIC_PATHS: frozenset[str] = frozenset(
    ("/health", "/docs", "/openapi.json", "/redoc"),
)

# Header names. These must stay in sync with the Devvit-side client.
HEADER_SIGNATURE = "x-modpilot-signature"
HEADER_TIMESTAMP = "x-modpilot-timestamp"
HEADER_CORRELATION = "x-correlation-id"

# Largest accepted |server clock - request timestamp| for a signed request,
# in seconds.
_MAX_SKEW_SECONDS = 300


class CorrelationIdMiddleware(BaseHTTPMiddleware):
    """Bind `correlation_id` into structlog context for the request's lifetime.

    The id comes from the `x-correlation-id` request header when that value is
    well-formed; otherwise a fresh UUID4 is generated. The chosen id is echoed
    back on the response under the same header.
    """

    # Longest client-supplied id that is accepted as-is.
    _MAX_ID_LENGTH = 128
    # Non-alphanumeric characters accepted in a client-supplied id; together
    # with ASCII letters/digits this covers UUIDs, ULIDs and nanoid-style ids.
    _ALLOWED_ID_PUNCTUATION = "-_.:"

    @classmethod
    def _sanitize_correlation_id(cls, raw: str | None) -> str:
        """Return `raw` if it is an acceptable correlation id, else a new UUID4.

        The header is client-controlled and not covered by the HMAC signature
        (which signs only timestamp + body), yet it is bound into every log
        line and echoed in the response, so its length and alphabet are bounded.
        """
        if (
            raw
            and len(raw) <= cls._MAX_ID_LENGTH
            and all(
                ch.isascii() and (ch.isalnum() or ch in cls._ALLOWED_ID_PUNCTUATION)
                for ch in raw
            )
        ):
            return raw
        return str(uuid.uuid4())

    async def dispatch(
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Bind request context, run the downstream app, echo the correlation id.

        Args:
            request: Incoming request.
            call_next: Downstream ASGI handler.

        Returns:
            The downstream response with the `x-correlation-id` header set.
        """
        correlation_id = self._sanitize_correlation_id(
            request.headers.get(HEADER_CORRELATION)
        )
        # Start from an empty context so nothing bound earlier leaks into this
        # request's log lines.
        structlog.contextvars.clear_contextvars()
        structlog.contextvars.bind_contextvars(
            correlation_id=correlation_id,
            path=request.url.path,
            method=request.method,
        )
        try:
            response = await call_next(request)
        finally:
            structlog.contextvars.clear_contextvars()
        response.headers[HEADER_CORRELATION] = correlation_id
        return response


class HmacMiddleware(BaseHTTPMiddleware):
    """Verify `X-Modpilot-Signature` against HMAC-SHA256 of the raw body.

    The signed message is ``f"{timestamp}.".encode() + body``, where
    ``timestamp`` is the raw `X-Modpilot-Timestamp` header value (Unix
    seconds). The expected signature is the lowercase hex digest. Requests
    whose timestamp is more than ``_MAX_SKEW_SECONDS`` from the server clock
    are rejected to bound replay.

    Paths in ``_PUBLIC_PATHS`` and OPTIONS requests (presumably CORS
    preflight) bypass verification. When ``settings.hmac_enforced`` is false,
    the middleware is permissive: it logs a warning if
    ``ENGINE_SHARED_SECRET`` is unset and passes the request through.
    Refusing to start without a secret in prod is deferred to F-0.7. Until
    then, enforced mode with no secret fails closed with UNAUTHORIZED.
    """

    async def dispatch(  # noqa: PLR0911 — guard chain reads cleaner with explicit returns
        self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        """Run the HMAC guard chain, then forward the request if it passes.

        Args:
            request: Incoming request.
            call_next: Downstream ASGI handler.

        Returns:
            The downstream response, or a 401 error envelope when verification
            fails.
        """
        if request.url.path in _PUBLIC_PATHS or request.method == "OPTIONS":
            return await call_next(request)

        settings = get_settings()

        if not settings.hmac_enforced:
            if not settings.engine_shared_secret:
                logger.warning(
                    "hmac.permissive",
                    reason="ENGINE_SHARED_SECRET not configured",
                    env=settings.env,
                )
            return await call_next(request)

        secret = settings.engine_shared_secret
        if not secret:
            # Fail closed: an empty key would make every signature forgeable,
            # and a missing one would crash on `.encode` with a 500.
            logger.error(
                "hmac.misconfigured",
                reason="HMAC enforced but ENGINE_SHARED_SECRET not configured",
                env=settings.env,
            )
            return _unauthorized("signature verification unavailable")

        signature = request.headers.get(HEADER_SIGNATURE, "")
        timestamp = request.headers.get(HEADER_TIMESTAMP, "")

        if not signature or not timestamp:
            return _unauthorized("missing signature or timestamp")

        try:
            sent_at = float(timestamp)
        except ValueError:
            return _unauthorized("malformed timestamp")
        # float() also accepts "nan"/"inf". NaN would slip past the skew check
        # below because every comparison with NaN is False.
        if not math.isfinite(sent_at):
            return _unauthorized("malformed timestamp")
        skew = abs(time.time() - sent_at)
        if skew > _MAX_SKEW_SECONDS:
            return _unauthorized(f"timestamp skew {skew:.0f}s exceeds {_MAX_SKEW_SECONDS}s")

        # Only read the body once the cheap header checks have passed.
        body = await request.body()
        expected = hmac.new(
            secret.encode("utf-8"),
            f"{timestamp}.".encode() + body,
            sha256,
        ).hexdigest()

        # Compare as bytes: compare_digest raises TypeError on non-ASCII str
        # input, and header values may contain arbitrary latin-1 characters.
        if not hmac.compare_digest(expected.encode("ascii"), signature.encode("utf-8")):
            return _unauthorized("signature mismatch")

        return await call_next(request)


def _unauthorized(message: str) -> JSONResponse:
    """Build the HTTP 401 response used when an HMAC check fails.

    Args:
        message: Human-readable reason placed in the error envelope.

    Returns:
        A JSON response carrying the standard error envelope with code
        ``UNAUTHORIZED`` and ``retryable=False``.
    """
    error = ErrorBody(code="UNAUTHORIZED", message=message, retryable=False)
    payload = ErrorEnvelope(error=error).model_dump()
    return JSONResponse(content=payload, status_code=401)