Spaces:
Running
Running
| """Centralized exception handlers. | |
| Goals: | |
| * Never leak ``str(exception)`` directly to clients (the previous code | |
| echoed full Python tracebacks via the FastAPI 500 detail). | |
| * Always include a ``request_id`` in the response body so support can | |
| correlate with logs. | |
| * Return shape is OpenAI-compatible so existing client SDKs work: | |
| ``{"error": {"message": ..., "type": ..., "code": ...}}``. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| from typing import Any | |
| from fastapi import FastAPI, HTTPException, Request, status | |
| from fastapi.exceptions import RequestValidationError | |
| from fastapi.responses import ORJSONResponse | |
| from ..core.registry import UnknownModelError | |
| from ..core.service import ImageFetchError | |
| from ..security.concurrency import ConcurrencyError | |
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
def _request_id(request: Request) -> str | None:
    """Return the correlation id for ``request``, or ``None`` if none is known.

    A truthy ``request.state.request_id`` wins. Otherwise, fall back to the
    client-supplied ``x-request-id`` header, which may also be absent.
    """
    assigned = getattr(request.state, "request_id", None)
    if assigned:
        return assigned
    return request.headers.get("x-request-id")
def _error_body(
    *,
    message: str,
    code: str,
    type_: str,
    request_id: str | None,
    extra: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Build the OpenAI-style error envelope ``{"error": {...}}``.

    The inner object always carries ``message``, ``type`` and ``code``.
    ``request_id`` is added only when it is truthy. Keys from ``extra`` are
    merged last, so they can overwrite any of the preceding keys. ``extra``
    itself is not modified.
    """
    payload: dict[str, Any] = {"message": message, "type": type_, "code": code}
    if request_id:
        payload["request_id"] = request_id
    payload.update(extra or {})
    return {"error": payload}
def install_handlers(app: FastAPI) -> None:
    """Register all error handlers on ``app``.

    Each handler turns one exception family into an OpenAI-style
    ``{"error": {...}}`` JSON body (built by ``_error_body``). The body carries
    the request id returned by ``_request_id``. Mapping:

    * ``HTTPException``          -> ``exc.status_code`` (headers preserved)
    * ``RequestValidationError`` -> 422, code ``invalid_request``
    * ``UnknownModelError``      -> 400, code ``unknown_model``
    * ``ImageFetchError``        -> 422, code ``image_fetch_error``
    * ``ConcurrencyError``       -> 429, code ``too_many_concurrent_requests``
    * ``ValueError``             -> 400, code ``invalid_request``
    * any other ``Exception``    -> 500, code ``internal_error`` (traceback logged)

    Args:
        app: Application to configure. Handlers are attached in place via
            ``app.add_exception_handler``.
    """
    from fastapi.encoders import jsonable_encoder

    async def _http_handler(request: Request, exc: HTTPException) -> ORJSONResponse:
        """Wrap an ``HTTPException`` in the standard error envelope."""
        rid = _request_id(request)
        # Preserve any caller-provided headers (e.g. Retry-After from the
        # rate limiter) but always set Content-Type via ORJSONResponse.
        # A dict ``detail`` is merged into the body as extra keys; anything
        # else is stringified into ``message``.
        return ORJSONResponse(
            status_code=exc.status_code,
            content=_error_body(
                message=str(exc.detail) if not isinstance(exc.detail, dict) else "request failed",
                code=str(exc.status_code),
                type_=_status_to_type(exc.status_code),
                request_id=rid,
                extra=exc.detail if isinstance(exc.detail, dict) else None,
            ),
            headers=exc.headers,
        )

    async def _validation_handler(request: Request, exc: RequestValidationError) -> ORJSONResponse:
        """Report request-validation failures as 422 with per-field details."""
        rid = _request_id(request)
        # ``exc.errors()`` can contain values orjson cannot serialize (pydantic
        # v2 stores the raised exception object under ``ctx``). Serializing them
        # as-is would raise and turn this 422 into a 500. ``jsonable_encoder``
        # normalizes them, the same way FastAPI's built-in handler does.
        return ORJSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content=_error_body(
                message="invalid request",
                code="invalid_request",
                type_="invalid_request_error",
                request_id=rid,
                extra={"details": jsonable_encoder(exc.errors())},
            ),
        )

    async def _unknown_model(request: Request, exc: UnknownModelError) -> ORJSONResponse:
        """Report a request for a model name the registry does not know (400)."""
        rid = _request_id(request)
        return ORJSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=_error_body(
                message=f"unknown model: {exc.name!r}",
                code="unknown_model",
                type_="invalid_request_error",
                request_id=rid,
            ),
        )

    async def _image_error(request: Request, exc: ImageFetchError) -> ORJSONResponse:
        """Report an ``ImageFetchError`` as 422, echoing its message."""
        rid = _request_id(request)
        # NOTE(review): echoes str(exc). Assumes ImageFetchError messages are
        # written to be client-safe; confirm in core.service.
        return ORJSONResponse(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            content=_error_body(
                message=str(exc),
                code="image_fetch_error",
                type_="invalid_request_error",
                request_id=rid,
            ),
        )

    async def _concurrency_error(request: Request, exc: ConcurrencyError) -> ORJSONResponse:
        """Report a ``ConcurrencyError`` as 429 with a one-second Retry-After."""
        rid = _request_id(request)
        return ORJSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content=_error_body(
                message=str(exc),
                code="too_many_concurrent_requests",
                type_="rate_limit_error",
                request_id=rid,
            ),
            headers={"Retry-After": "1"},
        )

    async def _value_error(request: Request, exc: ValueError) -> ORJSONResponse:
        """Report any otherwise-unhandled ``ValueError`` as 400, echoing its message."""
        rid = _request_id(request)
        # NOTE(review): this catches *every* ValueError that escapes a route,
        # including ones raised deep in library/internal code. Their text is
        # returned to the client as-is, which is at odds with the module's
        # "never leak str(exception)" goal. Consider a dedicated exception type.
        return ORJSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=_error_body(
                message=str(exc),
                code="invalid_request",
                type_="invalid_request_error",
                request_id=rid,
            ),
        )

    async def _unhandled(request: Request, exc: Exception) -> ORJSONResponse:
        """Log an unexpected exception and return a generic 500 body."""
        rid = _request_id(request)
        # Log the full traceback server-side; keep the public response generic.
        # ``exc_info=exc`` attaches the traceback even if the handler is not
        # invoked from inside an active ``except`` block.
        logger.error(
            "unhandled error on %s %s [request_id=%s]",
            request.method,
            request.url.path,
            rid,
            exc_info=exc,
        )
        return ORJSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content=_error_body(
                message="internal server error",
                code="internal_error",
                type_="server_error",
                request_id=rid,
            ),
        )

    # Attach the handlers. Without these calls the functions above are never
    # used and FastAPI's default responses are sent instead.
    # NOTE(review): FastAPI's docs recommend registering on Starlette's base
    # HTTPException so router-raised 404/405 also use this envelope;
    # fastapi.HTTPException is a subclass and only matches app-raised errors.
    app.add_exception_handler(HTTPException, _http_handler)
    app.add_exception_handler(RequestValidationError, _validation_handler)
    app.add_exception_handler(UnknownModelError, _unknown_model)
    app.add_exception_handler(ImageFetchError, _image_error)
    app.add_exception_handler(ConcurrencyError, _concurrency_error)
    app.add_exception_handler(ValueError, _value_error)
    app.add_exception_handler(Exception, _unhandled)
def _status_to_type(code: int) -> str:
    """Map an HTTP status code to an OpenAI-style error ``type`` string.

    429 is ``rate_limit_error``, other 4xx are ``invalid_request_error``,
    5xx are ``server_error``, and anything else falls back to ``api_error``.
    """
    if code == 429:
        return "rate_limit_error"
    status_families = (
        (400, 500, "invalid_request_error"),
        (500, 600, "server_error"),
    )
    for low, high, label in status_families:
        if low <= code < high:
            return label
    return "api_error"