Spaces:
Running
Running
| from collections.abc import Awaitable | |
| from datetime import datetime | |
| from typing import Callable | |
| from fastapi import Request, Response | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.middleware.gzip import GZipMiddleware | |
| from fastapi.middleware.trustedhost import TrustedHostMiddleware | |
| from sentry_sdk.integrations.asgi import SentryAsgiMiddleware | |
| from sentry_sdk.integrations.httpx import HttpxIntegration | |
| from starlette.datastructures import MutableHeaders | |
| from hibiapi.utils.config import Config | |
| from hibiapi.utils.exceptions import BaseServerException, UncaughtException | |
| from hibiapi.utils.log import LoguruHandler, logger | |
| from hibiapi.utils.routing import request_headers, response_headers | |
| from .application import app | |
| from .handlers import exception_handler | |
# Type of the ``call_next`` callable received by the HTTP middlewares in this
# module: it takes the Request and resolves to the downstream Response.
RequestHandler = Callable[[Request], Awaitable[Response]]

# --- ASGI middleware registration (runs once, at import time) ---------------
# NOTE(review): the order of add_middleware() calls determines how the stack
# nests (in Starlette, later registrations generally wrap earlier ones) —
# confirm against the installed Starlette version before reordering.

# Response compression, enabled only when the server.gzip config flag is true.
if Config["server"]["gzip"].as_bool():
    app.add_middleware(GZipMiddleware)
# CORS policy; origins, credentials, methods and headers all come from the
# server.cors config section.
app.add_middleware(
    CORSMiddleware,
    allow_origins=Config["server"]["cors"]["origins"].get(list[str]),
    allow_credentials=Config["server"]["cors"]["credentials"].as_bool(),
    allow_methods=Config["server"]["cors"]["methods"].get(list[str]),
    allow_headers=Config["server"]["cors"]["headers"].get(list[str]),
)
# Host-header allow-list taken from the server.allowed config list.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=Config["server"]["allowed"].get(list[str]),
)
# Sentry's ASGI middleware wraps the application unconditionally.
app.add_middleware(SentryAsgiMiddleware)
# Install Sentry's httpx integration directly at import time (rather than via
# an integrations list passed elsewhere).
HttpxIntegration.setup_once()
async def request_logger(request: Request, call_next: RequestHandler) -> Response:
    """Time the downstream handler and write one coloured access-log line.

    Records the elapsed time in milliseconds (three decimals) as
    ``X-Process-Time`` in the ``response_headers`` context variable, unless
    that header is already present. Then logs the client address, method,
    URL, elapsed time, user agent and status code.

    NOTE(review): no ``@app.middleware("http")`` (or other registration) for
    this function is visible in this chunk — confirm it is wired up elsewhere.

    Args:
        request: The incoming request.
        call_next: Downstream handler that produces the response.

    Returns:
        The response returned by ``call_next``.
    """
    # Function-scope import keeps this block self-contained.
    from time import perf_counter

    # perf_counter() is monotonic. Wall-clock datetime.now() (naive local
    # time) can jump on NTP corrections or DST changes and skew the duration.
    start = perf_counter()
    host, port = request.client or (None, None)
    response = await call_next(request)
    process_time = (perf_counter() - start) * 1000

    # Assumes response_headers was set for this request (e.g. by
    # contextvar_setter running outside this middleware) — TODO confirm.
    # ContextVar.get() without a default raises LookupError otherwise.
    response_headers.get().setdefault("X-Process-Time", f"{process_time:.3f}")

    status_code, method = response.status_code, request.method.upper()
    # (bg, fg) colour pair chosen by status class: <400, 4xx, 5xx+.
    if status_code < 400:
        bg, fg = "green", "red"
    elif status_code < 500:
        bg, fg = "yellow", "blue"
    else:
        bg, fg = "red", "green"

    # The URL and User-Agent are client-controlled and end up inside colour
    # markup, so both are escaped. An unescaped "<...>" in the (decoded) URL
    # path could otherwise be parsed as a markup tag.
    url = LoguruHandler.escape_tag(repr(str(request.url)))
    user_agent = (
        LoguruHandler.escape_tag(request.headers["user-agent"])
        if "user-agent" in request.headers
        else "<d>Unknown</d>"
    )
    logger.info(
        f"<m><b>{host}</b>:{port}</m>"
        f" | <{bg.upper()}><b><{fg}>{method}</{fg}></b></{bg.upper()}>"
        f" | <n><b>{url}</b></n>"
        f" | <c>{process_time:.3f}ms</c>"
        f" | <e>{user_agent}</e>"
        f" | <b><{bg}>{status_code}</{bg}></b>"
    )
    return response
async def contextvar_setter(request: Request, call_next: RequestHandler):
    """Publish per-request header state through context variables.

    Before calling downstream, stores the incoming headers in
    ``request_headers`` and a fresh, empty ``MutableHeaders`` in
    ``response_headers``. After the downstream response is produced, merges
    whatever was collected in ``response_headers`` into the response's
    headers via ``update``.

    NOTE(review): no ``@app.middleware("http")`` (or other registration) for
    this function is visible in this chunk — confirm it is wired up elsewhere.

    Args:
        request: The incoming request.
        call_next: Downstream handler that produces the response.

    Returns:
        The downstream response, with the collected headers merged in.
    """
    request_headers.set(request.headers)
    response_headers.set(MutableHeaders())
    downstream = await call_next(request)
    # dict(...) builds a plain mapping from the collected headers exactly as
    # a {**...} spread would (keys() + item lookup), collapsing repeats.
    collected = dict(response_headers.get())
    downstream.headers.update(collected)
    return downstream
async def uncaught_exception_handler(
    request: Request, call_next: RequestHandler
) -> Response:
    """Convert any ``Exception`` escaping downstream into an error response.

    ``BaseServerException`` instances are passed to ``exception_handler``
    unchanged. Any other ``Exception`` is first wrapped with
    ``UncaughtException.with_exception``. Non-``Exception`` errors (e.g.
    cancellation) propagate untouched.

    NOTE(review): no ``@app.middleware("http")`` (or other registration) for
    this function is visible in this chunk — confirm it is wired up elsewhere.

    Args:
        request: The incoming request.
        call_next: Downstream handler that produces the response.

    Returns:
        The downstream response, or the response built by
        ``exception_handler`` when downstream raised.
    """
    try:
        return await call_next(request)
    except Exception as error:
        if isinstance(error, BaseServerException):
            server_error = error
        else:
            server_error = UncaughtException.with_exception(error)
        return await exception_handler(request, exc=server_error)