"""
Middleware for Production Security
Rate limiting, request logging, security headers, CORS
"""

import time
import uuid
from typing import Callable, Dict, Optional

from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware

from src.core.config import settings
from src.core.logging import logger, log_api_request, log_error
from src.core.exceptions import AppException, RateLimitExceededError
from src.core.cache import cache


# Length of the rate-limit counting window passed to cache.increment_rate_limit.
_RATE_LIMIT_WINDOW_SECONDS = 60


def _app_exception_response(
    exc: AppException,
    request_id: Optional[str],
    headers: Optional[Dict[str, str]] = None,
) -> JSONResponse:
    """Render an AppException as this module's JSON error envelope.

    Shared by ErrorHandlerMiddleware and RateLimitMiddleware so both emit the
    same body keys: ``error``, ``message``, ``details``, ``request_id``.

    Args:
        exc: The application exception; its ``status_code``, ``message`` and
            ``details`` attributes populate the response.
        request_id: Value of ``request.state.request_id`` (may be None).
        headers: Optional extra response headers.

    Returns:
        A JSONResponse with ``exc.status_code``.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": type(exc).__name__,
            "message": exc.message,
            "details": exc.details,
            "request_id": request_id,
        },
        headers=headers,
    )


class RequestIDMiddleware(BaseHTTPMiddleware):
    """Add unique request ID to each request.

    Generates a UUID4 string, stores it on ``request.state.request_id`` for
    downstream middleware/handlers, and echoes it in the ``X-Request-ID``
    response header.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id

        response = await call_next(request)
        response.headers["X-Request-ID"] = request_id
        return response


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log all API requests with performance metrics.

    Calls ``log_api_request`` once per completed request and adds an
    ``X-Response-Time`` header (milliseconds, two decimals).
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # perf_counter is monotonic, so durations are immune to wall-clock
        # adjustments (NTP steps, manual changes) that time.time() suffers.
        start = time.perf_counter()

        # Set by RequestIDMiddleware when it runs before this middleware.
        request_id = getattr(request.state, "request_id", None)

        response = await call_next(request)

        duration_ms = (time.perf_counter() - start) * 1000

        log_api_request(
            method=request.method,
            path=str(request.url.path),
            status_code=response.status_code,
            duration_ms=duration_ms,
            user_id=getattr(request.state, "user_id", None),
            request_id=request_id,
        )

        response.headers["X-Response-Time"] = f"{duration_ms:.2f}ms"
        return response


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting based on IP address or user ID.

    Counts requests per identifier via ``cache.increment_rate_limit`` over a
    60-second window and rejects requests above
    ``settings.RATE_LIMIT_PER_MINUTE`` with HTTP 429.

    The rejection is *returned* as a response rather than raised: an
    exception raised inside a BaseHTTPMiddleware is not seen by FastAPI's
    exception handlers (they run inside the router), and ErrorHandlerMiddleware
    is registered inside this middleware (see setup_middleware), so a raised
    RateLimitExceededError would surface as a generic 500 from Starlette's
    ServerErrorMiddleware instead of a 429.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        if not settings.RATE_LIMIT_ENABLED:
            return await call_next(request)

        # Skip rate limiting for health check
        if request.url.path == "/health":
            return await call_next(request)

        # NOTE(review): request.client.host is the direct peer; behind a
        # reverse proxy this is the proxy's address unless proxy headers are
        # honoured by the server — confirm deployment setup.
        client_ip = request.client.host if request.client else "unknown"
        # NOTE(review): user_id is only present if something that runs before
        # this middleware sets it; route-level auth dependencies run later, so
        # this may always fall back to the IP — confirm.
        user_id = getattr(request.state, "user_id", None)
        identifier = f"user:{user_id}" if user_id else f"ip:{client_ip}"

        limit = settings.RATE_LIMIT_PER_MINUTE
        count = await cache.increment_rate_limit(identifier, _RATE_LIMIT_WINDOW_SECONDS)
        limit_headers = {
            "X-RateLimit-Limit": str(limit),
            "X-RateLimit-Remaining": str(max(0, limit - count)),
        }

        if count > limit:
            logger.warning(f"Rate limit exceeded for {identifier}: {count} requests")
            return self._limit_exceeded_response(request, limit_headers)

        response = await call_next(request)
        for header_name, header_value in limit_headers.items():
            response.headers[header_name] = header_value
        return response

    @staticmethod
    def _limit_exceeded_response(request: Request, limit_headers: Dict[str, str]) -> JSONResponse:
        """Build the 429 response for a request over the limit."""
        exc = RateLimitExceededError(
            limit=settings.RATE_LIMIT_PER_MINUTE,
            window="minute",
        )
        # Retry-After uses the full window length as a conservative upper
        # bound; assumes the cache counter expires within one window.
        headers = {**limit_headers, "Retry-After": str(_RATE_LIMIT_WINDOW_SECONDS)}
        request_id = getattr(request.state, "request_id", None)

        if isinstance(exc, AppException):
            return _app_exception_response(exc, request_id, headers)

        # Fallback in case RateLimitExceededError is not an AppException.
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={
                "error": type(exc).__name__,
                "message": str(exc),
                "request_id": request_id,
            },
            headers=headers,
        )


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to responses.

    Always sets nosniff, frame-deny, XSS-protection, HSTS and referrer-policy
    headers; adds a Content-Security-Policy only when
    ``settings.is_production`` is true.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        response = await call_next(request)

        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        # NOTE(review): X-XSS-Protection is deprecated and ignored by modern
        # browsers; OWASP now recommends "0" (or omitting it) and relying on CSP.
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        if settings.is_production:
            response.headers["Content-Security-Policy"] = (
                "default-src 'self'; "
                "script-src 'self' 'unsafe-inline'; "
                "style-src 'self' 'unsafe-inline'; "
                "img-src 'self' data: https:; "
                "font-src 'self' data:; "
                "connect-src 'self'"
            )

        return response


class ErrorHandlerMiddleware(BaseHTTPMiddleware):
    """Global error handler.

    Converts exceptions escaping the wrapped application into JSON responses:
    AppException subclasses keep their own status code/message/details; any
    other exception becomes a generic 500. Every caught exception is logged
    via ``log_error``.

    With the registration order in setup_middleware this is the innermost
    user middleware, so it only sees exceptions raised below it (routes and
    anything FastAPI's own exception handlers did not handle) — not
    exceptions raised by the middleware that wrap it.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        try:
            return await call_next(request)
        except Exception as exc:  # top-level boundary: convert, never re-raise
            request_id = getattr(request.state, "request_id", None)

            log_error(
                error=exc,
                context={
                    "method": request.method,
                    "path": str(request.url.path),
                    "client": request.client.host if request.client else None,
                },
                request_id=request_id,
            )

            if isinstance(exc, AppException):
                return _app_exception_response(exc, request_id)

            # Generic error response: no internal details leak to the client.
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "error": "InternalServerError",
                    "message": "An unexpected error occurred",
                    "request_id": request_id,
                },
            )


def setup_cors(app):
    """Configure CORS middleware.

    Allows the origins in ``settings.BACKEND_CORS_ORIGINS`` with credentials,
    all methods and headers, and exposes this module's custom response
    headers to browser JavaScript.
    """
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.BACKEND_CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=["X-Request-ID", "X-Response-Time", "X-RateLimit-Limit", "X-RateLimit-Remaining"],
    )


def setup_middleware(app):
    """Register all middleware on *app* in the intended order.

    Each ``add_middleware`` call wraps everything registered before it, so
    the LAST middleware added is the OUTERMOST (first to see the request,
    last to touch the response). The resulting request flow is:

        CORS -> RequestID -> RequestLogging -> RateLimit
             -> SecurityHeaders -> ErrorHandler -> application

    Consequences of this order:
      * request_id is set before logging, rate limiting and error handling.
      * Error responses produced by ErrorHandler still receive security
        headers, rate-limit headers, logging, X-Request-ID and CORS headers.
    """
    # Innermost: converts exceptions from the application into JSON errors.
    app.add_middleware(ErrorHandlerMiddleware)

    # Security headers (also applied to ErrorHandler's responses)
    app.add_middleware(SecurityHeadersMiddleware)

    # Rate limiting (returns 429 itself; does not rely on ErrorHandler)
    app.add_middleware(RateLimitMiddleware)

    # Request logging (sees final status codes, including 429s)
    app.add_middleware(RequestLoggingMiddleware)

    # Request ID (outer, so every inner layer can read request.state.request_id)
    app.add_middleware(RequestIDMiddleware)

    # CORS (outermost; added last)
    setup_cors(app)

    logger.info("Middleware configured successfully")


if __name__ == "__main__":
    print("Middleware module loaded")
    print(f"Rate limiting: {'Enabled' if settings.RATE_LIMIT_ENABLED else 'Disabled'}")
    print(f"CORS origins: {settings.BACKEND_CORS_ORIGINS}")