"""Centralized exception handlers.

Goals:

* Never leak ``str(exception)`` for unexpected errors (the previous code
  echoed full Python tracebacks via the FastAPI 500 detail). Catch-all
  errors get a generic message. Only explicitly handled exception types
  surface their message to the client: ``HTTPException``, ``ValueError``,
  and the service's own errors.
* Include a ``request_id`` in the response body whenever one is known, so
  support can correlate responses with logs.
* Return shape is OpenAI-compatible so existing client SDKs work:
  ``{"error": {"message": ..., "type": ..., "code": ...}}``.
"""
from __future__ import annotations

import logging
from typing import Any

from fastapi import FastAPI, HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import ORJSONResponse

from ..core.registry import UnknownModelError
from ..core.service import ImageFetchError
from ..security.concurrency import ConcurrencyError
# Module logger. The catch-all ``Exception`` handler in ``install_handlers``
# writes full tracebacks here, while the client only sees a generic message.
logger = logging.getLogger(__name__)
def _request_id(request: Request) -> str | None:
    """Return the correlation id for ``request``, or ``None`` if none is known.

    A truthy ``request.state.request_id`` wins; it is presumably set by
    middleware, which should be confirmed against the app setup. Otherwise
    the client-supplied ``x-request-id`` header is used, if present.
    """
    from_state = getattr(request.state, "request_id", None)
    if from_state:
        return from_state
    return request.headers.get("x-request-id")
def _error_body(
    *,
    message: str,
    code: str,
    type_: str,
    request_id: str | None,
    extra: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build the OpenAI-compatible error envelope.

    The result is ``{"error": {"message": ..., "type": ..., "code": ...}}``.
    A truthy ``request_id`` is added to the inner object under
    ``"request_id"``. Keys from ``extra`` are merged in last, so they can
    override any of the standard fields.
    """
    payload: dict[str, Any] = {"message": message, "type": type_, "code": code}
    if request_id:
        payload["request_id"] = request_id
    payload.update(extra or {})
    return {"error": payload}
def install_handlers(app: FastAPI) -> None:
    """Register all error handlers on ``app``.

    Every handler responds with the OpenAI-style envelope built by
    ``_error_body``. It includes the request id from ``_request_id`` when
    one is available. The mapping is:

    * ``HTTPException`` -> the exception's own status code and headers.
    * ``RequestValidationError`` -> 422 ``invalid_request``, with the
      per-field errors under ``details``.
    * ``UnknownModelError`` -> 400 ``unknown_model``.
    * ``ImageFetchError`` -> 422 ``image_fetch_error``.
    * ``ConcurrencyError`` -> 429 ``too_many_concurrent_requests`` with
      ``Retry-After: 1``.
    * ``ValueError`` -> 400 ``invalid_request``.
    * Any other ``Exception`` -> 500 ``internal_error``. The traceback is
      logged, and the client only sees a generic message.
    """

    # NOTE(review): FastAPI's HTTPException subclasses Starlette's
    # HTTPException. Per the FastAPI "Handling Errors" docs, errors raised
    # by Starlette itself (e.g. 404/405 from routing) are the Starlette base
    # class, so they do not reach this handler. They are answered by
    # FastAPI's default handler instead, without this envelope or a
    # request_id. Consider registering on
    # ``starlette.exceptions.HTTPException`` instead.
    @app.exception_handler(HTTPException)
    async def _http_handler(request: Request, exc: HTTPException) -> ORJSONResponse:
        rid = _request_id(request)
        detail_is_dict = isinstance(exc.detail, dict)
        # A dict ``detail`` is merged into the error object, and its keys may
        # override message/type/code. Any other detail becomes the message.
        # Caller-provided headers (e.g. Retry-After from the rate limiter)
        # are preserved; ORJSONResponse still sets Content-Type itself.
        return ORJSONResponse(
            status_code=exc.status_code,
            content=_error_body(
                message="request failed" if detail_is_dict else str(exc.detail),
                code=str(exc.status_code),
                type_=_status_to_type(exc.status_code),
                request_id=rid,
                extra=exc.detail if detail_is_dict else None,
            ),
            headers=exc.headers,
        )

    @app.exception_handler(RequestValidationError)
    async def _validation_handler(request: Request, exc: RequestValidationError) -> ORJSONResponse:
        rid = _request_id(request)
        # exc.errors() can hold values orjson cannot serialize. Examples are
        # the exception object pydantic stores under ``ctx`` when a custom
        # validator raises, or raw ``bytes`` input. FastAPI's own default
        # handler passes them through jsonable_encoder, and so do we.
        # Otherwise serialization fails inside this handler and the client
        # gets a 500 instead of a 422.
        return ORJSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content=_error_body(
                message="invalid request",
                code="invalid_request",
                type_="invalid_request_error",
                request_id=rid,
                extra={"details": jsonable_encoder(exc.errors())},
            ),
        )

    @app.exception_handler(UnknownModelError)
    async def _unknown_model(request: Request, exc: UnknownModelError) -> ORJSONResponse:
        rid = _request_id(request)
        # Only the requested model name is echoed back, never str(exc).
        return ORJSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=_error_body(
                message=f"unknown model: {exc.name!r}",
                code="unknown_model",
                type_="invalid_request_error",
                request_id=rid,
            ),
        )

    @app.exception_handler(ImageFetchError)
    async def _image_error(request: Request, exc: ImageFetchError) -> ORJSONResponse:
        rid = _request_id(request)
        # str(exc) is returned to the client; assumes ImageFetchError messages
        # are written to be client-safe. TODO confirm in core.service.
        return ORJSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content=_error_body(
                message=str(exc),
                code="image_fetch_error",
                type_="invalid_request_error",
                request_id=rid,
            ),
        )

    @app.exception_handler(ConcurrencyError)
    async def _concurrency_error(request: Request, exc: ConcurrencyError) -> ORJSONResponse:
        rid = _request_id(request)
        # Fixed one-second Retry-After hint for clients that honour it.
        return ORJSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content=_error_body(
                message=str(exc),
                code="too_many_concurrent_requests",
                type_="rate_limit_error",
                request_id=rid,
            ),
            headers={"Retry-After": "1"},
        )

    # NOTE(review): this catches every ValueError that escapes a route and
    # echoes str(exc) to the client. That includes errors raised by
    # third-party code; pydantic's ValidationError and json.JSONDecodeError
    # are both ValueError subclasses. Nothing is logged here either. Confirm
    # that is acceptable, or narrow this to a dedicated domain exception.
    @app.exception_handler(ValueError)
    async def _value_error(request: Request, exc: ValueError) -> ORJSONResponse:
        rid = _request_id(request)
        return ORJSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=_error_body(
                message=str(exc),
                code="invalid_request",
                type_="invalid_request_error",
                request_id=rid,
            ),
        )

    @app.exception_handler(Exception)
    async def _unhandled(request: Request, exc: Exception) -> ORJSONResponse:
        rid = _request_id(request)
        # Log the full traceback server-side; keep the public response generic.
        logger.exception("unhandled error on %s %s [request_id=%s]",
                         request.method, request.url.path, rid)
        return ORJSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=_error_body(
                message="internal server error",
                code="internal_error",
                type_="server_error",
                request_id=rid,
            ),
        )
def _status_to_type(code: int) -> str:
    """Map an HTTP status code to an OpenAI-style error ``type`` string.

    429 is singled out as ``rate_limit_error``. Any other 4xx maps to
    ``invalid_request_error`` and 5xx maps to ``server_error``. Everything
    else falls back to ``api_error``.
    """
    if code == 429:
        return "rate_limit_error"
    for low, high, error_type in (
        (400, 500, "invalid_request_error"),
        (500, 600, "server_error"),
    ):
        if low <= code < high:
            return error_type
    return "api_error"
|