| from collections.abc import Awaitable |
| from datetime import datetime |
| from typing import Callable |
|
|
| from fastapi import Request, Response |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.middleware.gzip import GZipMiddleware |
| from fastapi.middleware.trustedhost import TrustedHostMiddleware |
| from sentry_sdk.integrations.asgi import SentryAsgiMiddleware |
| from sentry_sdk.integrations.httpx import HttpxIntegration |
| from starlette.datastructures import MutableHeaders |
|
|
| from hibiapi.utils.config import Config |
| from hibiapi.utils.exceptions import BaseServerException, UncaughtException |
| from hibiapi.utils.log import LoguruHandler, logger |
| from hibiapi.utils.routing import request_headers, response_headers |
|
|
| from .application import app |
| from .handlers import exception_handler |
|
|
# Signature of the ``call_next`` callable received by the HTTP middlewares
# defined in this module: takes a request, resolves to a response.
RequestHandler = Callable[[Request], Awaitable[Response]]


# Shorthand views into the server configuration used by the registrations below.
_server_config = Config["server"]
_cors_config = _server_config["cors"]

# NOTE: the registration order below is kept exactly as written; middleware
# registration order can affect how requests are wrapped.

# Response compression is opt-in via the ``server.gzip`` setting.
if _server_config["gzip"].as_bool():
    app.add_middleware(GZipMiddleware)

# CORS policy is taken entirely from the ``server.cors`` settings.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_config["origins"].get(list[str]),
    allow_credentials=_cors_config["credentials"].as_bool(),
    allow_methods=_cors_config["methods"].get(list[str]),
    allow_headers=_cors_config["headers"].get(list[str]),
)

# Host allow-list comes from the ``server.allowed`` setting.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=_server_config["allowed"].get(list[str]),
)

app.add_middleware(SentryAsgiMiddleware)


# One-time setup of the Sentry httpx integration.
HttpxIntegration.setup_once()
|
|
|
|
@app.middleware("http")
async def request_logger(request: Request, call_next: RequestHandler) -> Response:
    """Log one tag-marked-up access line per request and record its duration.

    The wall-clock time spent in ``call_next`` is stored, in milliseconds,
    under ``X-Process-Time`` in the ``response_headers`` context value (unless
    that key is already present) and is also part of the log line.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request and returns its response.

    Returns:
        The response returned by ``call_next``; this function does not
        modify the response object itself.
    """
    start_time = datetime.now()
    host, port = request.client or (None, None)
    response = await call_next(request)
    process_time = (datetime.now() - start_time).total_seconds() * 1000
    response_headers.get().setdefault("X-Process-Time", f"{process_time:.3f}")

    status_code, method = response.status_code, request.method.upper()

    # Pick the (background, foreground) colour tags from the status class.
    if status_code < 400:
        bg, fg = "green", "red"
    elif status_code < 500:
        bg, fg = "yellow", "blue"
    else:
        bg, fg = "red", "green"

    # Client-controlled text goes through ``escape_tag`` before it is embedded
    # in the markup below, so that "<...>" sequences supplied by the client
    # are not read as markup tags.
    raw_user_agent = request.headers.get("user-agent")
    user_agent = (
        "<d>Unknown</d>"
        if raw_user_agent is None
        else LoguruHandler.escape_tag(raw_user_agent)
    )
    # The URL is just as client-controlled as the user agent, so it gets the
    # same treatment. Escaping is applied after ``repr`` so the quoting done
    # by ``repr`` cannot alter the escape sequences.
    url = LoguruHandler.escape_tag(repr(str(request.url)))

    logger.info(
        f"<m><b>{host}</b>:{port}</m>"
        f" | <{bg.upper()}><b><{fg}>{method}</{fg}></b></{bg.upper()}>"
        f" | <n><b>{url}</b></n>"
        f" | <c>{process_time:.3f}ms</c>"
        f" | <e>{user_agent}</e>"
        f" | <b><{bg}>{status_code}</{bg}></b>"
    )
    return response
|
|
|
|
@app.middleware("http")
async def contextvar_setter(request: Request, call_next: RequestHandler):
    """Populate the header context values around ``call_next``.

    ``request_headers`` is set to the incoming request's headers and
    ``response_headers`` to a fresh, empty ``MutableHeaders``. Once the
    response is available, whatever ``response_headers`` holds at that point
    is merged into the response's own headers.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request and returns its response.

    Returns:
        The response from ``call_next`` with the collected headers applied.
    """
    request_headers.set(request.headers)
    response_headers.set(MutableHeaders())

    outgoing = await call_next(request)

    # Snapshot the collected headers as a plain dict before merging them in.
    collected = {**response_headers.get()}
    outgoing.headers.update(collected)
    return outgoing
|
|
|
|
@app.middleware("http")
async def uncaught_exception_handler(
    request: Request, call_next: RequestHandler
) -> Response:
    """Convert any exception raised by ``call_next`` into a response.

    Exceptions that are already ``BaseServerException`` instances are handed
    to ``exception_handler`` unchanged; anything else is first wrapped with
    ``UncaughtException.with_exception``.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request and returns its response.

    Returns:
        The response from ``call_next``, or the one built by
        ``exception_handler`` when ``call_next`` raised.
    """
    try:
        return await call_next(request)
    except Exception as error:
        if isinstance(error, BaseServerException):
            server_error = error
        else:
            server_error = UncaughtException.with_exception(error)
        return await exception_handler(request, exc=server_error)
|
|