File size: 5,618 Bytes
c4742ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Centralized exception handlers.

Goals:

* Never leak ``str(exception)`` directly to clients (the previous code
  echoed full Python tracebacks via the FastAPI 500 detail).
* Always include a ``request_id`` in the response body so support can
  correlate with logs.
* Return shape is OpenAI-compatible so existing client SDKs work:
  ``{"error": {"message": ..., "type": ..., "code": ...}}``.
"""

from __future__ import annotations

import logging
from typing import Any

from fastapi import FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.responses import ORJSONResponse

from ..core.registry import UnknownModelError
from ..core.service import ImageFetchError
from ..security.concurrency import ConcurrencyError

logger = logging.getLogger(__name__)


def _request_id(request: Request) -> str | None:
    """Return the id used to correlate this request with server logs.

    A truthy ``request_id`` attribute on ``request.state`` takes precedence;
    otherwise the ``x-request-id`` header is used. Returns ``None`` when
    neither is available.
    """
    from_state = getattr(request.state, "request_id", None)
    if from_state:
        return from_state
    # Fall back to whatever the client (or an upstream proxy) sent.
    return request.headers.get("x-request-id")


def _error_body(
    *,
    message: str,
    code: str,
    type_: str,
    request_id: str | None,
    extra: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build the ``{"error": {...}}`` response envelope.

    Args:
        message: Human-readable description of the failure.
        code: Machine-readable error code.
        type_: Error category, stored under the ``"type"`` key.
        request_id: Correlation id; the key is omitted when this is falsy.
        extra: Optional additional fields merged into the error object last,
            so a key that collides with one of the fields above replaces it.

    Returns:
        A dict with a single ``"error"`` key holding the error object.
    """
    payload: dict[str, Any] = {"message": message, "type": type_, "code": code}
    if request_id:
        payload["request_id"] = request_id
    # Merging an empty mapping is a no-op, so a missing ``extra`` needs no branch.
    payload.update(extra or {})
    return {"error": payload}


def install_handlers(app: FastAPI) -> None:
    """Register all error handlers on ``app``.

    Every handler answers with the ``{"error": {...}}`` envelope built by
    ``_error_body`` and attaches the request id when one is available.

    Args:
        app: The FastAPI application the handlers are attached to.
    """
    # Function-scope import: only the validation handler needs it, and it
    # keeps this function self-contained.
    from fastapi.encoders import jsonable_encoder

    # NOTE(review): this is registered for ``fastapi.HTTPException``. Errors
    # raised by Starlette itself (e.g. routing 404/405) use
    # ``starlette.exceptions.HTTPException`` and may bypass this handler --
    # confirm whether those responses should use the same envelope.
    @app.exception_handler(HTTPException)
    async def _http_handler(request: Request, exc: HTTPException) -> ORJSONResponse:
        rid = _request_id(request)
        detail = exc.detail
        detail_is_mapping = isinstance(detail, dict)
        # Preserve any caller-provided headers (e.g. Retry-After from the
        # rate limiter) but always set Content-Type via ORJSONResponse.
        return ORJSONResponse(
            status_code=exc.status_code,
            content=_error_body(
                # A dict detail is merged into the error object instead of
                # being stringified into the message.
                message="request failed" if detail_is_mapping else str(detail),
                code=str(exc.status_code),
                type_=_status_to_type(exc.status_code),
                request_id=rid,
                extra=detail if detail_is_mapping else None,
            ),
            headers=exc.headers,
        )

    @app.exception_handler(RequestValidationError)
    async def _validation_handler(request: Request, exc: RequestValidationError) -> ORJSONResponse:
        rid = _request_id(request)
        # ``exc.errors()`` can carry values that are not JSON-serialisable
        # (exception instances under ``ctx``, ``bytes`` input, ...). Encoding
        # them first -- as FastAPI's built-in handler does -- keeps a 422 from
        # turning into a serialisation failure.
        details = jsonable_encoder(exc.errors())
        return ORJSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content=_error_body(
                message="invalid request",
                code="invalid_request",
                type_="invalid_request_error",
                request_id=rid,
                extra={"details": details},
            ),
        )

    @app.exception_handler(UnknownModelError)
    async def _unknown_model(request: Request, exc: UnknownModelError) -> ORJSONResponse:
        rid = _request_id(request)
        return ORJSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=_error_body(
                message=f"unknown model: {exc.name!r}",
                code="unknown_model",
                type_="invalid_request_error",
                request_id=rid,
            ),
        )

    @app.exception_handler(ImageFetchError)
    async def _image_error(request: Request, exc: ImageFetchError) -> ORJSONResponse:
        rid = _request_id(request)
        return ORJSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content=_error_body(
                message=str(exc),
                code="image_fetch_error",
                type_="invalid_request_error",
                request_id=rid,
            ),
        )

    @app.exception_handler(ConcurrencyError)
    async def _concurrency_error(request: Request, exc: ConcurrencyError) -> ORJSONResponse:
        rid = _request_id(request)
        return ORJSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content=_error_body(
                message=str(exc),
                code="too_many_concurrent_requests",
                type_="rate_limit_error",
                request_id=rid,
            ),
            # Tell well-behaved clients to back off briefly before retrying.
            headers={"Retry-After": "1"},
        )

    # NOTE(review): this echoes ``str(exc)`` for any ValueError that escapes a
    # route, including ones raised by library code -- confirm that is
    # compatible with the "never leak str(exception)" goal of this module.
    @app.exception_handler(ValueError)
    async def _value_error(request: Request, exc: ValueError) -> ORJSONResponse:
        rid = _request_id(request)
        return ORJSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=_error_body(
                message=str(exc),
                code="invalid_request",
                type_="invalid_request_error",
                request_id=rid,
            ),
        )

    @app.exception_handler(Exception)
    async def _unhandled(request: Request, exc: Exception) -> ORJSONResponse:
        rid = _request_id(request)
        # Log the full traceback server-side; keep the public response generic.
        # ``exc_info=exc`` attaches the traceback of the exception we were
        # handed rather than relying on an ambient ``except`` context.
        logger.exception(
            "unhandled error on %s %s [request_id=%s]",
            request.method,
            request.url.path,
            rid,
            exc_info=exc,
        )
        return ORJSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=_error_body(
                message="internal server error",
                code="internal_error",
                type_="server_error",
                request_id=rid,
            ),
        )


def _status_to_type(code: int) -> str:
    """Map an HTTP status code to the ``type`` string used in error bodies.

    429 is reported as ``"rate_limit_error"``, any other 4xx as
    ``"invalid_request_error"``, any 5xx as ``"server_error"``, and
    everything else as ``"api_error"``.
    """
    # 429 is a 4xx code, so it has to be singled out before the class lookup.
    if code == 429:
        return "rate_limit_error"
    by_status_class = {4: "invalid_request_error", 5: "server_error"}
    return by_status_class.get(code // 100, "api_error")