# apigateway/services/base_service/middleware_chain.py
# Author: jebin2 — commit bcc8074 ("refactor")
"""
Middleware Chain - Orchestration of multiple middleware layers.
Provides utilities for managing and coordinating multiple middleware
components in the request/response flow.
Usage:
# In app.py
from services.base_service import MiddlewareChain
# Add middleware in reverse order (last added = first executed)
app.add_middleware(CreditMiddleware)
app.add_middleware(AuthMiddleware)
# Or use the chain helper
chain = MiddlewareChain()
chain.add(AuthMiddleware)
chain.add(CreditMiddleware)
chain.apply_to_app(app)
"""
import logging
from typing import Any, Callable, ClassVar, Dict, List, Tuple, Type

from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
# Module-level logger used by MiddlewareChain (registration messages) and
# BaseServiceMiddleware (per-request log helpers).
logger = logging.getLogger(__name__)
class RequestContext:
    """
    Shared context for passing data between middleware layers.

    Stored on ``request.state.ctx`` so that middleware and routers handling
    the same request can read and write it.
    """

    def __init__(self) -> None:
        """Initialize an empty context: no user, zero credits, no flags."""
        # Auth layer
        self.user: Any = None
        self.is_authenticated: bool = False
        # Credit layer
        self.credits_reserved: int = 0
        self.credit_cost: int = 0
        # General
        # NOTE(review): start_time is never assigned in this class; its type
        # depends on whichever layer sets it — confirm with callers.
        self.start_time: Any = None
        self.service_flags: Dict[str, Any] = {}

    def set_user(self, user: Any) -> None:
        """
        Record the authenticated user.

        Args:
            user: The user object to attach; also marks the context as
                authenticated.
        """
        self.user = user
        self.is_authenticated = True

    def set_credits(self, reserved: int, cost: int) -> None:
        """
        Record credit information for this request.

        Args:
            reserved: Value stored in ``credits_reserved``.
            cost: Value stored in ``credit_cost``.
        """
        self.credits_reserved = reserved
        self.credit_cost = cost

    def set_flag(self, key: str, value: Any) -> None:
        """
        Set a service-specific flag.

        Args:
            key: Flag name.
            value: Arbitrary value to store under ``key`` (overwrites any
                previous value).
        """
        self.service_flags[key] = value

    def get_flag(self, key: str, default: Any = None) -> Any:
        """
        Get a service-specific flag.

        Args:
            key: Flag name.
            default: Returned when ``key`` has not been set.

        Returns:
            The stored value, or ``default`` if the flag is absent.
        """
        return self.service_flags.get(key, default)
class MiddlewareChain:
    """
    Helper for managing middleware registration order.

    FastAPI/Starlette middleware executes in REVERSE order of registration,
    so the LAST middleware passed to ``app.add_middleware`` is the FIRST to
    execute. This class lets you list middleware in execution order instead:
    the first class passed to ``add`` is the first to execute.
    """

    def __init__(self) -> None:
        """Initialize an empty middleware chain."""
        # Each entry is (middleware_class, constructor_kwargs), kept in the
        # order ``add`` was called, which is also the execution order.
        self._middleware: List[Tuple[Type[BaseHTTPMiddleware], Dict[str, Any]]] = []

    def add(self, middleware_class: Type[BaseHTTPMiddleware], **kwargs: Any) -> 'MiddlewareChain':
        """
        Add middleware to the end of the chain.

        Middleware executes in the order it is added: ``apply_to_app``
        registers the list in reverse, so the first one added executes first.

        Args:
            middleware_class: Middleware class to add.
            **kwargs: Keyword arguments forwarded to ``app.add_middleware``
                (i.e. to the middleware constructor).

        Returns:
            This chain, so calls can be chained: ``chain.add(A).add(B)``.
        """
        self._middleware.append((middleware_class, kwargs))
        logger.debug("Added middleware to chain: %s", middleware_class.__name__)
        return self

    def apply_to_app(self, app: FastAPI) -> None:
        """
        Register every middleware in the chain on ``app``.

        Because the last-registered middleware executes first, entries are
        registered in reverse so the first middleware added to the chain is
        the first to execute.

        Args:
            app: FastAPI application instance.
        """
        for middleware_class, kwargs in reversed(self._middleware):
            app.add_middleware(middleware_class, **kwargs)
            logger.info("Registered middleware: %s", middleware_class.__name__)

    def get_middleware_list(self) -> List[Type[BaseHTTPMiddleware]]:
        """
        Get the middleware classes in execution order.

        Returns:
            A new list of middleware classes, first-to-execute first
            (constructor kwargs are not included).
        """
        return [middleware_class for middleware_class, _ in self._middleware]

    def __len__(self) -> int:
        """Return the number of middleware entries in the chain."""
        return len(self._middleware)

    def __repr__(self) -> str:
        """Return a debug representation listing middleware class names in order."""
        middleware_names = [middleware_class.__name__ for middleware_class, _ in self._middleware]
        return f"MiddlewareChain({middleware_names})"
async def initialize_request_context(request: Request) -> None:
    """
    Ensure ``request.state.ctx`` holds a RequestContext.

    Call this early in a middleware's ``dispatch`` so that later layers and
    routers can rely on ``request.state.ctx``. A context that already exists
    is left untouched, so repeated calls for the same request are harmless.

    Usage:
        class MyMiddleware(BaseHTTPMiddleware):
            async def dispatch(self, request: Request, call_next):
                await initialize_request_context(request)
                # request.state.ctx is now available
                ...

    Args:
        request: Incoming request whose ``state`` receives the context.
    """
    state = request.state
    if hasattr(state, "ctx"):
        # Already initialized (e.g. by an earlier middleware); keep it.
        return
    state.ctx = RequestContext()
def get_request_context(request: Request) -> RequestContext:
    """
    Return the RequestContext stored on ``request.state.ctx``.

    When no context has been attached yet, a new empty RequestContext is
    created, stored on ``request.state.ctx``, and returned, so callers always
    receive a context.

    Args:
        request: FastAPI request object.

    Returns:
        The request's RequestContext instance.
    """
    state = request.state
    try:
        return state.ctx
    except AttributeError:
        # No context yet: create one lazily and attach it.
        pass
    state.ctx = RequestContext()
    return state.ctx
class BaseServiceMiddleware(BaseHTTPMiddleware):
    """
    Base class for service middleware.

    Provides common functionality for service middleware:
    - Request context initialization: ``dispatch`` makes sure
      ``request.state.ctx`` exists before the request is forwarded.
    - Logging helpers (``log_request`` / ``log_error``) that prefix each
      message with ``SERVICE_NAME``, the HTTP method, and the URL path.

    This base class does not catch or translate exceptions itself;
    subclasses that need error handling should implement it in their own
    ``dispatch`` and can report failures through ``log_error``.
    """

    # Tag shown in log lines; subclasses override it with their service name.
    SERVICE_NAME: ClassVar[str] = "base"

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Initialize the request context, then forward the request.

        Override this in subclasses to implement service-specific logic.

        Args:
            request: Incoming request.
            call_next: Awaitable callable that passes the request to the next
                middleware/route and returns its response.

        Returns:
            The downstream response, unchanged.
        """
        await initialize_request_context(request)
        return await call_next(request)

    def log_request(self, request: Request, message: str) -> None:
        """Log ``message`` at INFO level, tagged with service name, method and path."""
        # %-style args: formatting is deferred until the record is emitted.
        logger.info("[%s] %s %s - %s", self.SERVICE_NAME, request.method, request.url.path, message)

    def log_error(self, request: Request, error: str) -> None:
        """Log ``error`` at ERROR level, tagged with service name, method and path."""
        logger.error("[%s] %s %s - ERROR: %s", self.SERVICE_NAME, request.method, request.url.path, error)
# Public names exported by ``from <this module> import *``.
__all__ = [
    'MiddlewareChain',
    'RequestContext',
    'BaseServiceMiddleware',
    'initialize_request_context',
    'get_request_context',
]