File size: 6,517 Bytes
bcc8074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""
Middleware Chain - Orchestration of multiple middleware layers.

Provides utilities for managing and coordinating multiple middleware
components in the request/response flow.

Usage:
    # In app.py
    from services.base_service import MiddlewareChain
    
    # Add middleware in reverse order (last added = first executed)
    app.add_middleware(CreditMiddleware)
    app.add_middleware(AuthMiddleware)
    
    # Or use the chain helper
    chain = MiddlewareChain()
    chain.add(AuthMiddleware)
    chain.add(CreditMiddleware)
    chain.apply_to_app(app)
"""

import logging
from typing import Any, Callable, Dict, List, Tuple, Type

from fastapi import FastAPI, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class RequestContext:
    """
    Shared context for passing data between middleware layers.

    Stored on ``request.state.ctx`` (see ``get_request_context``) so that
    middleware and routers handling the same request see the same object.
    Every instance owns its own ``service_flags`` dict.
    """

    def __init__(self) -> None:
        """Initialize an empty, unauthenticated, zero-credit context."""
        # Auth layer
        self.user = None  # populated by set_user(); type depends on the auth layer
        self.is_authenticated = False

        # Credit layer
        self.credits_reserved = 0
        self.credit_cost = 0

        # General
        # NOTE(review): start_time is never assigned in this module;
        # presumably set by a timing layer elsewhere.
        self.start_time = None
        self.service_flags: Dict[str, Any] = {}

    def set_user(self, user) -> None:
        """Record the authenticated user and mark the context as authenticated."""
        self.user = user
        self.is_authenticated = True

    def set_credits(self, reserved: int, cost: int) -> None:
        """Record the reserved credit amount and the request's credit cost."""
        self.credits_reserved = reserved
        self.credit_cost = cost

    def set_flag(self, key: str, value: Any) -> None:
        """Store a service-specific flag under ``key`` (overwrites any existing value)."""
        self.service_flags[key] = value

    def get_flag(self, key: str, default: Any = None) -> Any:
        """Return the service-specific flag for ``key``, or ``default`` if unset."""
        return self.service_flags.get(key, default)


class MiddlewareChain:
    """
    Helper for managing middleware registration order.

    FastAPI/Starlette middleware executes in REVERSE order of registration,
    so the LAST middleware passed to ``app.add_middleware`` is the FIRST to
    execute. This class lets callers list middleware in the order it should
    execute; ``apply_to_app`` performs the reversal.
    """

    def __init__(self):
        """Initialize empty middleware chain."""
        # Each entry is (middleware class, constructor kwargs), kept in
        # execution order (first entry executes first).
        self._middleware: List[Tuple[Type[BaseHTTPMiddleware], Dict[str, Any]]] = []

    def add(self, middleware_class: Type[BaseHTTPMiddleware], **kwargs) -> 'MiddlewareChain':
        """
        Add middleware to the chain.

        Middleware is appended to the END of the list, but will be registered
        in REVERSE order by ``apply_to_app`` (so first added = first executed).

        Args:
            middleware_class: Middleware class to add
            **kwargs: Arguments to pass to middleware constructor

        Returns:
            Self for chaining
        """
        self._middleware.append((middleware_class, kwargs))
        logger.debug("Added middleware to chain: %s", middleware_class.__name__)
        return self

    def apply_to_app(self, app: FastAPI) -> None:
        """
        Apply all middleware to the FastAPI app in correct order.

        Middleware is registered in REVERSE order so that the first
        middleware added to the chain is the first to execute. Each call
        registers every entry again, so call this once per app.

        Args:
            app: FastAPI application instance
        """
        # add_middleware makes the most recently added middleware the
        # outermost layer, so register the chain back to front.
        for middleware_class, kwargs in reversed(self._middleware):
            app.add_middleware(middleware_class, **kwargs)
            logger.info("Registered middleware: %s", middleware_class.__name__)

    def get_middleware_list(self) -> List[Type[BaseHTTPMiddleware]]:
        """
        Get the list of middleware in execution order.

        Returns:
            List of middleware classes in the order they will execute
        """
        return [middleware_class for middleware_class, _kwargs in self._middleware]

    def __len__(self) -> int:
        """Get number of middleware in chain."""
        return len(self._middleware)

    def __repr__(self) -> str:
        """String representation for debugging."""
        middleware_names = [
            middleware_class.__name__ for middleware_class, _kwargs in self._middleware
        ]
        return f"MiddlewareChain({middleware_names})"


async def initialize_request_context(request: Request) -> None:
    """
    Make sure ``request.state.ctx`` holds a RequestContext.

    Call this early in a middleware's ``dispatch`` so that later layers can
    rely on ``request.state.ctx``. A context that is already present is kept
    as-is, so repeated calls are harmless.

    Usage:
        class MyMiddleware(BaseHTTPMiddleware):
            async def dispatch(self, request: Request, call_next):
                await initialize_request_context(request)
                # Now request.state.ctx is available
                ...
    """
    state = request.state
    if hasattr(state, "ctx"):
        # Already initialized; keep the existing context.
        return
    state.ctx = RequestContext()


def get_request_context(request: Request) -> RequestContext:
    """
    Return the RequestContext stored on ``request.state``.

    When no context has been attached yet, a fresh RequestContext is
    created, stored as ``request.state.ctx`` and returned.

    Args:
        request: FastAPI request object

    Returns:
        The RequestContext attached to this request
    """
    state = request.state
    try:
        return state.ctx
    except AttributeError:
        # Nothing attached yet: create the context lazily.
        state.ctx = RequestContext()
        return state.ctx


class BaseServiceMiddleware(BaseHTTPMiddleware):
    """
    Base class for service middleware.

    Provides common functionality for service middleware:
    - Request context initialization: the default ``dispatch`` ensures
      ``request.state.ctx`` exists before forwarding the request.
    - Logging: ``log_request`` / ``log_error`` prefix messages with
      ``SERVICE_NAME``, the HTTP method and the URL path.

    The default ``dispatch`` does not catch exceptions; anything raised by
    ``call_next`` propagates unchanged. Subclasses that override ``dispatch``
    should call ``initialize_request_context`` themselves if they need
    ``request.state.ctx``.
    """

    # Tag used in log messages; presumably overridden by each subclass.
    SERVICE_NAME = "base"

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request through middleware.

        The default implementation initializes the request context and then
        forwards the request. Override this in subclasses to implement
        service-specific logic.

        Args:
            request: Incoming request
            call_next: Callable that passes the request to the next
                middleware/route and returns its response

        Returns:
            The response produced downstream, unmodified
        """
        await initialize_request_context(request)
        return await call_next(request)

    def log_request(self, request: Request, message: str) -> None:
        """Log ``message`` at INFO level, tagged with service, method and path."""
        # Lazy %-style args: string formatting is skipped when INFO is disabled.
        logger.info(
            "[%s] %s %s - %s",
            self.SERVICE_NAME, request.method, request.url.path, message,
        )

    def log_error(self, request: Request, error: str) -> None:
        """Log ``error`` at ERROR level, tagged with service, method and path."""
        logger.error(
            "[%s] %s %s - ERROR: %s",
            self.SERVICE_NAME, request.method, request.url.path, error,
        )


# Public API of this module (the names exported by ``from ... import *``).
__all__ = [
    "MiddlewareChain",
    "RequestContext",
    "BaseServiceMiddleware",
    "initialize_request_context",
    "get_request_context",
]