import asyncio
import math
import time
from collections import defaultdict
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from app.core.config import settings


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to every response produced by the inner app.

    Existing values for these header names are overwritten.
    """

    async def dispatch(self, request: Request, call_next):
        response = await call_next(request)

        # Security headers
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        response.headers["Content-Security-Policy"] = (
            "default-src 'self'; "
            "img-src 'self' data: https:; "
            "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; "
            "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; "
            "font-src 'self' data: https:"
        )

        return response


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory sliding-window rate limiter with periodic cleanup.

    Each client, keyed by ``request.client.host``, may make at most ``calls``
    requests within any ``period``-second window. Excess requests get a 429
    JSON response carrying a ``Retry-After`` header.

    NOTE(review): state lives in this process only, so with several workers
    each one enforces its own limit. Behind a reverse proxy every client
    shares the proxy's address unless the server rewrites ``request.client``
    (e.g. via proxy headers) — confirm against the deployment.
    """

    def __init__(
        self,
        app,
        calls: Optional[int] = None,
        period: Optional[int] = None,
        cleanup_interval: int = 300,
    ):
        """
        Args:
            app: The ASGI app being wrapped.
            calls: Max requests per window; ``None`` uses ``settings.RATE_LIMIT_CALLS``.
            period: Window length in seconds; ``None`` uses ``settings.RATE_LIMIT_PERIOD``.
            cleanup_interval: Seconds between sweeps that forget idle clients.
        """
        super().__init__(app)
        # Compare against None (not ``or``) so an explicit 0 is honoured
        # instead of silently falling back to the configured default.
        self.calls = settings.RATE_LIMIT_CALLS if calls is None else calls
        self.period = settings.RATE_LIMIT_PERIOD if period is None else period
        self.cleanup_interval = cleanup_interval  # Cleanup every 5 minutes by default
        # client key -> monotonic timestamps of requests still inside the window,
        # oldest first (timestamps are only ever appended).
        self.clients: Dict[str, List[float]] = defaultdict(list)
        # Monotonic clock: immune to wall-clock adjustments that would
        # otherwise shrink or stretch the window.
        self.last_cleanup = time.monotonic()

    def _cleanup_inactive_clients(self, now: float) -> None:
        """Forget clients with no requests left inside the rate-limit window.

        The cutoff is ``period`` rather than ``cleanup_interval``. A client whose
        newest request is at least ``period`` seconds old has nothing that can
        count against its limit, so dropping it is lossless. Using
        ``cleanup_interval`` would reset still-active limits whenever
        ``period`` is longer than the sweep interval.

        Args:
            now: Current ``time.monotonic()`` reading.
        """
        stale = [
            key
            for key, stamps in self.clients.items()
            if not stamps or now - stamps[-1] >= self.period
        ]
        for key in stale:
            del self.clients[key]

    async def dispatch(self, request: Request, call_next):
        # request.client can be None (e.g. some test transports or unix sockets).
        client_key = request.client.host if request.client is not None else "unknown"
        now = time.monotonic()

        # Periodic cleanup of inactive clients
        if now - self.last_cleanup > self.cleanup_interval:
            self._cleanup_inactive_clients(now)
            self.last_cleanup = now

        # Keep only this client's timestamps that are still inside the window.
        window = [t for t in self.clients[client_key] if now - t < self.period]
        self.clients[client_key] = window

        if len(window) >= self.calls:
            # Return the response directly: an HTTPException raised inside a
            # BaseHTTPMiddleware bypasses FastAPI's exception handlers and
            # reaches the client as a 500 instead of a 429.
            oldest = window[0] if window else now
            retry_after = max(1, math.ceil(self.period - (now - oldest)))
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
                headers={"Retry-After": str(retry_after)},
            )

        # No await between the check above and this append, so the
        # check-then-record step cannot interleave with another request on
        # the same event loop.
        window.append(now)
        return await call_next(request)


def add_security_middleware(app: FastAPI) -> None:
    """Add all security middleware to the FastAPI app.

    ``app.add_middleware`` wraps each new middleware around the ones added
    before it, so the LAST call below becomes the OUTERMOST layer. Resulting
    request flow:

        SecurityHeaders -> TrustedHost -> CORS -> RateLimit -> application

    The rate limiter is added first (innermost) so its 429 responses still
    travel back through the CORS and security-header layers. If it were
    outermost, browsers would see a CORS failure instead of a readable 429.
    """
    # Rate limiting middleware (innermost)
    app.add_middleware(RateLimitMiddleware)

    # CORS middleware
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
        allow_headers=["*"],
    )

    # Trusted host middleware
    app.add_middleware(
        TrustedHostMiddleware,
        allowed_hosts=settings.ALLOWED_HOSTS,
    )

    # Security headers middleware (outermost)
    app.add_middleware(SecurityHeadersMiddleware)