Spaces:
Sleeping
Sleeping
| """ | |
| Middleware Chain - Orchestration of multiple middleware layers. | |
| Provides utilities for managing and coordinating multiple middleware | |
| components in the request/response flow. | |
| Usage: | |
| # In app.py | |
| from services.base_service import MiddlewareChain | |
| # Add middleware in reverse order (last added = first executed) | |
| app.add_middleware(CreditMiddleware) | |
| app.add_middleware(AuthMiddleware) | |
| # Or use the chain helper | |
| chain = MiddlewareChain() | |
| chain.add(AuthMiddleware) | |
| chain.add(CreditMiddleware) | |
| chain.apply_to_app(app) | |
| """ | |
import logging
from typing import Any, Callable, Dict, List, Tuple, Type

from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
# Module-level logger shared by every helper in this file.
logger: logging.Logger = logging.getLogger(__name__)
| class RequestContext: | |
| """ | |
| Shared context for passing data between middleware layers. | |
| Attached to request.state for access across middleware and routers. | |
| """ | |
| def __init__(self): | |
| """Initialize empty context.""" | |
| # Auth layer | |
| self.user = None | |
| self.is_authenticated = False | |
| # Credit layer | |
| self.credits_reserved = 0 | |
| self.credit_cost = 0 | |
| # General | |
| self.start_time = None | |
| self.service_flags = {} | |
| def set_user(self, user) -> None: | |
| """Set authenticated user.""" | |
| self.user = user | |
| self.is_authenticated = True | |
| def set_credits(self, reserved: int, cost: int) -> None: | |
| """Set credit information.""" | |
| self.credits_reserved = reserved | |
| self.credit_cost = cost | |
| def set_flag(self, key: str, value: any) -> None: | |
| """Set a service-specific flag.""" | |
| self.service_flags[key] = value | |
| def get_flag(self, key: str, default=None) -> any: | |
| """Get a service-specific flag.""" | |
| return self.service_flags.get(key, default) | |
class MiddlewareChain:
    """
    Helper for managing middleware registration order.

    FastAPI/Starlette middleware executes in REVERSE order of registration,
    so the LAST middleware added via ``app.add_middleware`` is the FIRST to
    execute. This class lets callers list middleware in execution order
    instead: the first class passed to ``add`` is the first to execute once
    ``apply_to_app`` has run.
    """

    def __init__(self):
        """Initialize an empty middleware chain."""
        # Entries are (middleware class, constructor kwargs), kept in execution order.
        self._middleware: List[Tuple[Type[BaseHTTPMiddleware], Dict[str, Any]]] = []

    def add(self, middleware_class: Type[BaseHTTPMiddleware], **kwargs) -> 'MiddlewareChain':
        """
        Add middleware to the chain.

        Middleware is appended to the END of the list, but will be registered
        in REVERSE order (so first added = first executed).

        Args:
            middleware_class: Middleware class to add.
            **kwargs: Arguments to pass to the middleware constructor.

        Returns:
            Self, so calls can be chained.
        """
        self._middleware.append((middleware_class, kwargs))
        logger.debug("Added middleware to chain: %s", middleware_class.__name__)
        return self

    def apply_to_app(self, app: FastAPI) -> None:
        """
        Apply all middleware to the FastAPI app in the correct order.

        Middleware is registered in REVERSE order so that the first
        middleware added to the chain is the first to execute.

        Args:
            app: FastAPI application instance.
        """
        # The last registration executes first, so walk the list backwards to
        # leave the first-added middleware as the first to run.
        for middleware_class, kwargs in reversed(self._middleware):
            app.add_middleware(middleware_class, **kwargs)
            logger.info("Registered middleware: %s", middleware_class.__name__)

    def get_middleware_list(self) -> List[Type[BaseHTTPMiddleware]]:
        """
        Get the list of middleware in execution order.

        Returns:
            A new list of middleware classes in the order they will execute.
        """
        return [middleware_class for middleware_class, _ in self._middleware]

    def __len__(self) -> int:
        """Return the number of middleware in the chain."""
        return len(self._middleware)

    def __repr__(self) -> str:
        """Return a debugging representation listing middleware names in execution order."""
        middleware_names = [middleware_class.__name__ for middleware_class, _ in self._middleware]
        return f"MiddlewareChain({middleware_names})"
async def initialize_request_context(request: Request) -> None:
    """
    Ensure ``request.state.ctx`` holds a ``RequestContext``.

    Call this early in the middleware chain so later layers and routers can
    rely on ``request.state.ctx``. An existing context is left untouched, so
    calling this more than once per request is harmless.

    Usage:
        class MyMiddleware(BaseHTTPMiddleware):
            async def dispatch(self, request: Request, call_next):
                await initialize_request_context(request)
                # Now request.state.ctx is available
                ...

    Args:
        request: Incoming request whose ``state`` receives the context.
    """
    state = request.state
    if hasattr(state, "ctx"):
        # Already initialized by an earlier layer; keep its data.
        return
    state.ctx = RequestContext()
def get_request_context(request: Request) -> RequestContext:
    """
    Return the per-request context stored on ``request.state.ctx``.

    If no context has been attached yet, a fresh ``RequestContext`` is
    created, stored on ``request.state``, and returned.

    Args:
        request: FastAPI request object.

    Returns:
        The ``RequestContext`` attached to this request.
    """
    state = request.state
    try:
        return state.ctx
    except AttributeError:
        # No context yet: create one lazily so callers never see a missing attribute.
        state.ctx = RequestContext()
        return state.ctx
class BaseServiceMiddleware(BaseHTTPMiddleware):
    """
    Base class for service middleware.

    Provides common functionality for service middleware:
    - Request context initialization (``request.state.ctx``) in the default
      ``dispatch``
    - Logging helpers (``log_request`` / ``log_error``) tagged with
      ``SERVICE_NAME``

    No exceptions are caught here: errors raised by downstream middleware or
    routes propagate unchanged. Subclasses that override ``dispatch`` should
    ``await initialize_request_context(request)`` themselves if they rely on
    ``request.state.ctx``.
    """

    # Tag included in every log line emitted by the helpers; override per service.
    SERVICE_NAME = "base"

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process a request through this middleware.

        The default implementation only ensures ``request.state.ctx`` exists
        and then forwards to the next middleware/route. Override this in
        subclasses to implement service-specific logic.

        Args:
            request: Incoming request.
            call_next: Callable whose awaited result is the downstream response.

        Returns:
            The response produced downstream, unchanged.
        """
        await initialize_request_context(request)
        return await call_next(request)

    def log_request(self, request: Request, message: str) -> None:
        """Log ``message`` at INFO level, prefixed with the service name, HTTP method and path."""
        # %-style args let logging skip formatting when INFO is disabled.
        logger.info(
            "[%s] %s %s - %s",
            self.SERVICE_NAME, request.method, request.url.path, message,
        )

    def log_error(self, request: Request, error: str) -> None:
        """Log ``error`` at ERROR level, prefixed with the service name, HTTP method and path."""
        logger.error(
            "[%s] %s %s - ERROR: %s",
            self.SERVICE_NAME, request.method, request.url.path, error,
        )
# Public API of this module (what ``from ... import *`` exposes).
__all__ = [
    # Registration-order helper
    "MiddlewareChain",
    # Shared per-request state
    "RequestContext",
    # Base class for service middleware
    "BaseServiceMiddleware",
    # Context accessors
    "initialize_request_context",
    "get_request_context",
]