ShreyasGosavi's picture
Upload 37 files
53bec59 verified
"""
Middleware for Production Security
Rate limiting, request logging, security headers, CORS
"""
import time
import uuid
from typing import Callable
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from src.core.config import settings
from src.core.logging import logger, log_api_request, log_error
from src.core.exceptions import RateLimitExceededError
from src.core.cache import cache
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag each request with a freshly generated UUID4 identifier.

    The identifier is stored on ``request.state.request_id`` for downstream
    code and echoed back to the client in the ``X-Request-ID`` header.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        rid = f"{uuid.uuid4()}"
        request.state.request_id = rid
        resp = await call_next(request)
        resp.headers["X-Request-ID"] = rid
        return resp
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log every API request with its status code and latency.

    Responses also receive an ``X-Response-Time`` header holding the measured
    duration in milliseconds. Requests whose downstream call raises are
    logged with status 500 before the exception is re-raised.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # perf_counter is monotonic: unlike time.time(), the measured duration
        # cannot go negative or jump when the system clock is adjusted.
        start = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            # Without this, requests failing with an unhandled exception would
            # never reach the request log; the exception still propagates to
            # outer handlers unchanged.
            self._log_request(request, 500, (time.perf_counter() - start) * 1000)
            raise

        duration_ms = (time.perf_counter() - start) * 1000
        self._log_request(request, response.status_code, duration_ms)

        # Add performance header
        response.headers["X-Response-Time"] = f"{duration_ms:.2f}ms"
        return response

    @staticmethod
    def _log_request(request: Request, status_code: int, duration_ms: float) -> None:
        """Emit one request-log entry via ``log_api_request``."""
        # request.state is read after the downstream call so values set by
        # inner layers during the request (request ID, user ID) are included.
        log_api_request(
            method=request.method,
            path=str(request.url.path),
            status_code=status_code,
            duration_ms=duration_ms,
            user_id=getattr(request.state, "user_id", None),
            request_id=getattr(request.state, "request_id", None),
        )
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-minute rate limiting keyed by user ID, falling back to client IP.

    Requests over ``settings.RATE_LIMIT_PER_MINUTE`` receive a JSON error
    response (HTTP 429 unless ``RateLimitExceededError`` defines its own
    ``status_code``). Allowed requests get ``X-RateLimit-Limit`` and
    ``X-RateLimit-Remaining`` headers. ``/health`` is never limited.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        if not settings.RATE_LIMIT_ENABLED:
            return await call_next(request)
        # Skip rate limiting for health check
        if request.url.path == "/health":
            return await call_next(request)

        # NOTE(review): request.state.user_id is only present if some layer
        # running before this one sets it; otherwise requests are keyed by IP.
        client_ip = request.client.host if request.client else "unknown"
        user_id = getattr(request.state, "user_id", None)
        identifier = f"user:{user_id}" if user_id else f"ip:{client_ip}"

        limit = settings.RATE_LIMIT_PER_MINUTE
        # Increment this identifier's counter for the 60-second window.
        count = await cache.increment_rate_limit(identifier, 60)
        remaining = str(max(0, limit - count))

        if count > limit:
            logger.warning(f"Rate limit exceeded for {identifier}: {count} requests")
            # Respond directly instead of raising. FastAPI's per-exception-class
            # handlers only wrap the routes, so an exception raised here (unless
            # an error-handling middleware wraps this one) reaches the client
            # as a generic 500 instead of a 429.
            exc = RateLimitExceededError(limit=limit, window="minute")
            return JSONResponse(
                status_code=getattr(exc, "status_code", status.HTTP_429_TOO_MANY_REQUESTS),
                content={
                    "error": type(exc).__name__,
                    "message": getattr(exc, "message", str(exc)),
                    "details": getattr(exc, "details", None),
                    "request_id": getattr(request.state, "request_id", None),
                },
                headers={
                    "X-RateLimit-Limit": str(limit),
                    "X-RateLimit-Remaining": remaining,
                    # Presumably an upper bound given the 60-second window;
                    # TODO confirm against cache.increment_rate_limit semantics.
                    "Retry-After": "60",
                },
            )

        # Add rate limit headers
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(limit)
        response.headers["X-RateLimit-Remaining"] = remaining
        return response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add defensive security headers to every response.

    A Content-Security-Policy header is added only when
    ``settings.is_production`` is true.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        response = await call_next(request)

        # Security headers
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        # "1; mode=block" enabled the legacy XSS auditor, which modern browsers
        # have removed and which could itself be abused for cross-site leaks.
        # Current guidance (OWASP, MDN) is to disable it and rely on CSP.
        response.headers["X-XSS-Protection"] = "0"
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        # Content Security Policy
        if settings.is_production:
            response.headers["Content-Security-Policy"] = (
                "default-src 'self'; "
                "script-src 'self' 'unsafe-inline'; "
                "style-src 'self' 'unsafe-inline'; "
                "img-src 'self' data: https:; "
                "font-src 'self' data:; "
                "connect-src 'self'"
            )
        return response
class ErrorHandlerMiddleware(BaseHTTPMiddleware):
    """Turn exceptions escaping the inner stack into JSON error responses.

    Every exception is reported through ``log_error`` with the HTTP method,
    path and client host. ``AppException`` instances are rendered with their
    own status code, message and details. Anything else becomes a generic
    500 response that exposes no internal information.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        try:
            return await call_next(request)
        except Exception as exc:
            rid = getattr(request.state, "request_id", None)
            client_host = request.client.host if request.client else None
            log_error(
                error=exc,
                context={
                    "method": request.method,
                    "path": str(request.url.path),
                    "client": client_host,
                },
                request_id=rid,
            )

            # Imported only on the error path.
            from src.core.exceptions import AppException

            if not isinstance(exc, AppException):
                # Unknown failure: hide the details from the client.
                return JSONResponse(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    content={
                        "error": "InternalServerError",
                        "message": "An unexpected error occurred",
                        "request_id": rid,
                    },
                )

            return JSONResponse(
                status_code=exc.status_code,
                content={
                    "error": type(exc).__name__,
                    "message": exc.message,
                    "details": exc.details,
                    "request_id": rid,
                },
            )
def setup_cors(app):
    """Register Starlette's CORSMiddleware on *app*.

    Origins come from ``settings.BACKEND_CORS_ORIGINS``. Credentials, all
    methods and all request headers are allowed. The request-ID, timing and
    rate-limit response headers are exposed to browser scripts.
    """
    # Custom headers set by this module's middleware that browsers may read.
    exposed = [
        "X-Request-ID",
        "X-Response-Time",
        "X-RateLimit-Limit",
        "X-RateLimit-Remaining",
    ]
    # NOTE(review): allow_credentials=True is only safe with an explicit origin
    # list — confirm BACKEND_CORS_ORIGINS never contains "*".
    cors_options = dict(
        allow_origins=settings.BACKEND_CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
        expose_headers=exposed,
    )
    app.add_middleware(CORSMiddleware, **cors_options)
def setup_middleware(app):
    """Register the application's middleware stack in the correct order.

    Starlette's ``add_middleware`` wraps the existing stack. The middleware
    added LAST therefore runs FIRST (outermost), so registration below goes
    from innermost to outermost. The resulting request flow is:

        CORS -> RequestID -> RequestLogging -> SecurityHeaders
             -> ErrorHandler -> RateLimit -> application
    """
    # Innermost: rate limiting sits inside the error handler so any exception
    # it (or the cache it calls) raises is converted into a JSON response.
    app.add_middleware(RateLimitMiddleware)
    # Error handling wraps rate limiting and the routes.
    app.add_middleware(ErrorHandlerMiddleware)
    # Security headers are applied to normal and error responses alike.
    app.add_middleware(SecurityHeadersMiddleware)
    # Logging sits outside the error handler so it sees the final status code.
    app.add_middleware(RequestLoggingMiddleware)
    # The request ID is assigned before every other custom layer runs.
    app.add_middleware(RequestIDMiddleware)
    # CORS is registered last, making it the outermost layer: preflight
    # requests are answered before the rest of the stack runs.
    setup_cors(app)
    logger.info("Middleware configured successfully")
if __name__ == "__main__":
print("Middleware module loaded")
print(f"Rate limiting: {'Enabled' if settings.RATE_LIMIT_ENABLED else 'Disabled'}")
print(f"CORS origins: {settings.BACKEND_CORS_ORIGINS}")