| | """
|
| | Middleware for Production Security
|
| | Rate limiting, request logging, security headers, CORS
|
| | """
|
| |
|
| | import time
|
| | import uuid
|
| | from typing import Callable
|
| | from fastapi import Request, Response, status
|
| | from fastapi.responses import JSONResponse
|
| | from starlette.middleware.base import BaseHTTPMiddleware
|
| | from starlette.middleware.cors import CORSMiddleware
|
| |
|
| | from src.core.config import settings
|
| | from src.core.logging import logger, log_api_request, log_error
|
| | from src.core.exceptions import RateLimitExceededError
|
| | from src.core.cache import cache
|
| |
|
| |
|
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag every request with a freshly generated UUID4 identifier.

    The identifier is stored on ``request.state.request_id`` so downstream
    layers can read it, and is echoed back to the client in the
    ``X-Request-ID`` response header.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # A new id is minted for every request; any incoming header is ignored.
        generated_id = str(uuid.uuid4())
        request.state.request_id = generated_id

        outgoing = await call_next(request)
        outgoing.headers["X-Request-ID"] = generated_id
        return outgoing
|
| |
|
| |
|
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log each API request with its status code and processing time.

    The elapsed time is also returned to the client in an
    ``X-Response-Time`` header (milliseconds, two decimals).
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # perf_counter() is monotonic. time.time() follows the wall clock and
        # can jump (NTP sync, manual changes), giving wrong or negative durations.
        start = time.perf_counter()
        request_id = getattr(request.state, "request_id", None)

        try:
            response = await call_next(request)
        except Exception:
            # No response object exists, so record the request as a 500 before
            # letting the exception propagate. Otherwise it would never be logged.
            self._log_request(
                request, status.HTTP_500_INTERNAL_SERVER_ERROR, start, request_id
            )
            raise

        duration_ms = self._log_request(
            request, response.status_code, start, request_id
        )
        response.headers["X-Response-Time"] = f"{duration_ms:.2f}ms"
        return response

    @staticmethod
    def _log_request(
        request: Request, status_code: int, start: float, request_id
    ) -> float:
        """Emit one ``log_api_request`` record and return the elapsed milliseconds."""
        duration_ms = (time.perf_counter() - start) * 1000
        log_api_request(
            method=request.method,
            path=str(request.url.path),
            status_code=status_code,
            duration_ms=duration_ms,
            # Read after the downstream call, so a value set by inner layers
            # (if any) is picked up.
            user_id=getattr(request.state, "user_id", None),
            request_id=request_id,
        )
        return duration_ms
|
| |
|
| |
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client request rate limiting.

    Clients are keyed by ``request.state.user_id`` when an earlier layer has
    set it, otherwise by client IP. Each request increments a counter through
    ``cache.increment_rate_limit`` with a ``WINDOW_SECONDS`` window. Once the
    count exceeds ``settings.RATE_LIMIT_PER_MINUTE``, the request is answered
    with a JSON error response and never reaches the application.

    Paths in ``EXEMPT_PATHS`` are never limited. The middleware is a no-op
    when ``settings.RATE_LIMIT_ENABLED`` is false.
    """

    # Window length passed to the cache counter (seconds).
    WINDOW_SECONDS = 60
    # Paths that bypass rate limiting entirely.
    EXEMPT_PATHS = frozenset({"/health"})

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        if not settings.RATE_LIMIT_ENABLED or request.url.path in self.EXEMPT_PATHS:
            return await call_next(request)

        limit = settings.RATE_LIMIT_PER_MINUTE
        client_ip = request.client.host if request.client else "unknown"
        user_id = getattr(request.state, "user_id", None)
        identifier = f"user:{user_id}" if user_id else f"ip:{client_ip}"

        count = await cache.increment_rate_limit(identifier, self.WINDOW_SECONDS)

        if count > limit:
            logger.warning(f"Rate limit exceeded for {identifier}: {count} requests")
            return self._rejection_response(request, limit)

        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = str(max(0, limit - count))
        return response

    def _rejection_response(self, request: Request, limit: int) -> JSONResponse:
        """Build the response returned to a client that exceeded its limit."""
        # Raising from dispatch() is not seen by the application's exception
        # handlers, which run inside the user-middleware stack. Nor is it seen
        # by any error-handling middleware registered inside this one. It
        # would surface as a generic 500, so the error response is built here.
        exc = RateLimitExceededError(limit=limit, window="minute")
        return JSONResponse(
            status_code=getattr(exc, "status_code", status.HTTP_429_TOO_MANY_REQUESTS),
            content={
                "error": type(exc).__name__,
                "message": getattr(exc, "message", str(exc)),
                "details": getattr(exc, "details", None),
                "request_id": getattr(request.state, "request_id", None),
            },
            headers={
                # Conservative upper bound: the counter window length.
                "Retry-After": str(self.WINDOW_SECONDS),
                "X-RateLimit-Limit": str(limit),
                "X-RateLimit-Remaining": "0",
            },
        )
|
| |
|
| |
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach defensive security headers to every response.

    The headers in ``_STATIC_HEADERS`` are always set and overwrite any value
    set downstream. A Content-Security-Policy is added only when
    ``settings.is_production`` is true.
    """

    _STATIC_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        # Explicitly disable the legacy browser XSS auditor. "1; mode=block"
        # is deprecated, ignored by modern browsers, and can itself be abused
        # for cross-site information leaks (OWASP recommends "0").
        "X-XSS-Protection": "0",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
        "Referrer-Policy": "strict-origin-when-cross-origin",
    }

    _PRODUCTION_CSP = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline'; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: https:; "
        "font-src 'self' data:; "
        "connect-src 'self'"
    )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        response = await call_next(request)

        for header_name, header_value in self._STATIC_HEADERS.items():
            response.headers[header_name] = header_value

        if settings.is_production:
            response.headers["Content-Security-Policy"] = self._PRODUCTION_CSP

        return response
|
| |
|
| |
|
class ErrorHandlerMiddleware(BaseHTTPMiddleware):
    """Turn exceptions escaping the wrapped app into JSON error responses.

    Every caught exception is first reported through ``log_error`` with the
    request method, path and client host. ``AppException`` subclasses are
    rendered with their own ``status_code``, ``message`` and ``details``. Any
    other exception becomes a generic 500 whose body does not expose the
    exception text. Both bodies carry ``request.state.request_id`` (or None).
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        try:
            return await call_next(request)
        except Exception as exc:
            current_request_id = getattr(request.state, "request_id", None)
            client_host = request.client.host if request.client else None
            log_error(
                error=exc,
                context={
                    "method": request.method,
                    "path": str(request.url.path),
                    "client": client_host,
                },
                request_id=current_request_id,
            )

            # Imported at call time, only once an exception has occurred.
            from src.core.exceptions import AppException

            if not isinstance(exc, AppException):
                # Unknown failure: generic message, no internal details leaked.
                return JSONResponse(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    content={
                        "error": "InternalServerError",
                        "message": "An unexpected error occurred",
                        "request_id": getattr(request.state, "request_id", None),
                    },
                )

            return JSONResponse(
                status_code=exc.status_code,
                content={
                    "error": type(exc).__name__,
                    "message": exc.message,
                    "details": exc.details,
                    "request_id": getattr(request.state, "request_id", None),
                },
            )
|
| |
|
| |
|
def setup_cors(app):
    """Register Starlette's CORS middleware using ``settings.BACKEND_CORS_ORIGINS``.

    Credentials (cookies, Authorization headers) are allowed only for an
    explicit origin list. If the list contains the ``"*"`` wildcard,
    Starlette combines ``allow_credentials=True`` with echoing back whatever
    origin made the request. That would let any site send authenticated
    cross-origin requests and read the responses, so credentials are
    disabled in that case and a warning is logged.

    Args:
        app: The application whose ``add_middleware`` is called.
    """
    origins = settings.BACKEND_CORS_ORIGINS
    # Same membership test Starlette uses to detect "allow all origins".
    allow_credentials = "*" not in origins
    if not allow_credentials:
        logger.warning(
            "CORS origins include '*'; disabling allow_credentials to avoid "
            "reflecting arbitrary origins with credentials"
        )

    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials=allow_credentials,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=[
            "X-Request-ID",
            "X-Response-Time",
            "X-RateLimit-Limit",
            "X-RateLimit-Remaining",
        ],
    )
|
| |
|
| |
|
def setup_middleware(app):
    """Register all middleware on *app* in the intended wrapping order.

    Starlette's ``add_middleware`` wraps the stack built so far. The
    middleware added LAST is the OUTERMOST layer and sees the request first.
    The calls below therefore run from innermost to outermost, producing
    this request path::

        CORS -> RequestID -> SecurityHeaders -> RequestLogging
             -> ErrorHandler -> RateLimit -> application

    Rationale:
    - CORS is outermost, so preflights are answered early and every
      response, including error responses, passes back through it.
    - RequestID runs before every layer that reads ``request.state.request_id``.
    - SecurityHeaders and RequestLogging sit outside ErrorHandler, so error
      responses produced by ErrorHandler still get security headers and are
      logged with their real status code.
    - ErrorHandler wraps RateLimit and the application, so exceptions from
      either are converted into JSON responses instead of a bare 500.

    Args:
        app: The application to configure.
    """
    app.add_middleware(RateLimitMiddleware)
    app.add_middleware(ErrorHandlerMiddleware)
    app.add_middleware(RequestLoggingMiddleware)
    app.add_middleware(SecurityHeadersMiddleware)
    app.add_middleware(RequestIDMiddleware)

    # Added last, so it is the outermost layer.
    setup_cors(app)

    logger.info("Middleware configured successfully")
|
| |
|
| |
|
if __name__ == "__main__":
    # Smoke check: confirm the module imports and show the key settings.
    print("Middleware module loaded")
    rate_limit_label = "Enabled" if settings.RATE_LIMIT_ENABLED else "Disabled"
    print(f"Rate limiting: {rate_limit_label}")
    print(f"CORS origins: {settings.BACKEND_CORS_ORIGINS}")
|
| |
|