""" Middleware Chain - Orchestration of multiple middleware layers. Provides utilities for managing and coordinating multiple middleware components in the request/response flow. Usage: # In app.py from services.base_service import MiddlewareChain # Add middleware in reverse order (last added = first executed) app.add_middleware(CreditMiddleware) app.add_middleware(AuthMiddleware) # Or use the chain helper chain = MiddlewareChain() chain.add(AuthMiddleware) chain.add(CreditMiddleware) chain.apply_to_app(app) """ import logging from typing import List, Type, Callable from starlette.middleware.base import BaseHTTPMiddleware from fastapi import FastAPI, Request, Response logger = logging.getLogger(__name__) class RequestContext: """ Shared context for passing data between middleware layers. Attached to request.state for access across middleware and routers. """ def __init__(self): """Initialize empty context.""" # Auth layer self.user = None self.is_authenticated = False # Credit layer self.credits_reserved = 0 self.credit_cost = 0 # General self.start_time = None self.service_flags = {} def set_user(self, user) -> None: """Set authenticated user.""" self.user = user self.is_authenticated = True def set_credits(self, reserved: int, cost: int) -> None: """Set credit information.""" self.credits_reserved = reserved self.credit_cost = cost def set_flag(self, key: str, value: any) -> None: """Set a service-specific flag.""" self.service_flags[key] = value def get_flag(self, key: str, default=None) -> any: """Get a service-specific flag.""" return self.service_flags.get(key, default) class MiddlewareChain: """ Helper for managing middleware registration order. FastAPI/Starlette middleware executes in REVERSE order of registration, so the LAST middleware added is the FIRST to execute. This class helps manage the order explicitly. """ def __init__(self): """Initialize empty middleware chain.""" self._middleware: List[Type[BaseHTTPMiddleware]] = [] def add(self, middleware_class: Type[BaseHTTPMiddleware], **kwargs) -> 'MiddlewareChain': """ Add middleware to the chain. Middleware is added to the END of the list, but will be registered in REVERSE order (so first added = first executed). Args: middleware_class: Middleware class to add **kwargs: Arguments to pass to middleware constructor Returns: Self for chaining """ self._middleware.append((middleware_class, kwargs)) logger.debug(f"Added middleware to chain: {middleware_class.__name__}") return self def apply_to_app(self, app: FastAPI) -> None: """ Apply all middleware to the FastAPI app in correct order. Middleware is registered in REVERSE order so that the first middleware added to the chain is the first to execute. Args: app: FastAPI application instance """ # Reverse the list so first added = first executed for middleware_class, kwargs in reversed(self._middleware): app.add_middleware(middleware_class, **kwargs) logger.info(f"Registered middleware: {middleware_class.__name__}") def get_middleware_list(self) -> List[Type[BaseHTTPMiddleware]]: """ Get the list of middleware in execution order. Returns: List of middleware classes in the order they will execute """ return [m[0] for m in self._middleware] def __len__(self) -> int: """Get number of middleware in chain.""" return len(self._middleware) def __repr__(self) -> str: """String representation for debugging.""" middleware_names = [m[0].__name__ for m in self._middleware] return f"MiddlewareChain({middleware_names})" async def initialize_request_context(request: Request) -> None: """ Initialize request context for middleware to use. This should be called early in the middleware chain to ensure request.state.ctx is available. Usage: class MyMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): await initialize_request_context(request) # Now request.state.ctx is available ... """ if not hasattr(request.state, "ctx"): request.state.ctx = RequestContext() def get_request_context(request: Request) -> RequestContext: """ Get request context from request.state. Creates context if it doesn't exist. Args: request: FastAPI request object Returns: RequestContext instance """ if not hasattr(request.state, "ctx"): request.state.ctx = RequestContext() return request.state.ctx class BaseServiceMiddleware(BaseHTTPMiddleware): """ Base class for service middleware. Provides common functionality for all service middleware: - Request context initialization - Error handling - Logging """ SERVICE_NAME = "base" async def dispatch(self, request: Request, call_next: Callable) -> Response: """ Process request through middleware. Override this in subclasses to implement service-specific logic. """ # Initialize context await initialize_request_context(request) # Call next middleware/route response = await call_next(request) return response def log_request(self, request: Request, message: str) -> None: """Log request with service context.""" logger.info(f"[{self.SERVICE_NAME}] {request.method} {request.url.path} - {message}") def log_error(self, request: Request, error: str) -> None: """Log error with service context.""" logger.error(f"[{self.SERVICE_NAME}] {request.method} {request.url.path} - ERROR: {error}") __all__ = [ 'MiddlewareChain', 'RequestContext', 'BaseServiceMiddleware', 'initialize_request_context', 'get_request_context', ]