Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.middleware.trustedhost import TrustedHostMiddleware | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.responses import Response | |
| import time | |
| from collections import defaultdict | |
| import asyncio | |
| from typing import Dict, List | |
| from app.core.config import settings | |
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to all responses.

    After the downstream app produces a response, every entry of
    ``SECURITY_HEADERS`` is written onto it, overwriting any value of the same
    name that the route may already have set.
    """

    # Header name -> value stamped on every response passing through.
    SECURITY_HEADERS = {
        "X-Content-Type-Options": "nosniff",
        "X-Frame-Options": "DENY",
        # The legacy browser XSS auditor is deprecated, and "1; mode=block" can
        # itself open XSS holes in otherwise safe pages. "0" disables it and
        # relies on the Content-Security-Policy below instead (OWASP/MDN advice).
        "X-XSS-Protection": "0",
        "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
        "Referrer-Policy": "strict-origin-when-cross-origin",
        # NOTE(review): 'unsafe-inline' + cdn.jsdelivr.net for scripts/styles is
        # presumably there for the interactive API docs UI — confirm it is still
        # needed, since it weakens the policy for every page.
        "Content-Security-Policy": (
            "default-src 'self'; "
            "img-src 'self' data: https:; "
            "style-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; "
            "script-src 'self' 'unsafe-inline' https://cdn.jsdelivr.net; "
            "font-src 'self' data: https:"
        ),
    }

    async def dispatch(self, request: Request, call_next):
        """Forward the request, then apply ``SECURITY_HEADERS`` to the response."""
        response = await call_next(request)
        for header_name, header_value in self.SECURITY_HEADERS.items():
            response.headers[header_name] = header_value
        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory sliding-window rate limiter with periodic memory cleanup.

    Each client, keyed by its peer host, may make at most ``calls`` requests in
    any ``period``-second window. A request over the limit is not forwarded to
    the app. Instead it receives a 429 JSON response (``{"detail": ...}``) with
    a ``Retry-After`` header. Counters are held in this instance's memory only,
    so separate worker processes each keep their own counts.
    """

    def __init__(self, app, calls: int = None, period: int = None, cleanup_interval: int = 300):
        """Configure the limiter.

        Args:
            app: The downstream ASGI application.
            calls: Max accepted requests per client per window. Falsy values
                fall back to ``settings.RATE_LIMIT_CALLS``.
            period: Window length in seconds. Falsy values fall back to
                ``settings.RATE_LIMIT_PERIOD``.
            cleanup_interval: Seconds between sweeps that forget idle clients.
        """
        super().__init__(app)
        self.calls = calls or settings.RATE_LIMIT_CALLS
        self.period = period or settings.RATE_LIMIT_PERIOD
        self.cleanup_interval = cleanup_interval  # Cleanup every 5 minutes by default
        # client key -> time.monotonic() stamps of accepted requests, oldest first.
        # Monotonic time is unaffected by wall-clock jumps (NTP adjustments,
        # manual clock changes) that would otherwise stretch or shrink windows.
        self.clients: Dict[str, List[float]] = defaultdict(list)
        self.last_cleanup = time.monotonic()

    @staticmethod
    def _client_key(request: Request) -> str:
        """Return the key under which ``request`` is counted."""
        # request.client is None when the server supplies no peer address
        # (e.g. some test clients or unix sockets). Those requests share one
        # bucket rather than crashing with AttributeError.
        # NOTE(review): behind a reverse proxy this is the proxy's address, so
        # all users share a bucket unless the server rewrites it from
        # X-Forwarded-For — confirm deployment configuration.
        return request.client.host if request.client else "unknown"

    def _cleanup_inactive_clients(self, now: float):
        """Forget clients with no accepted request recently, to bound memory."""
        # A client must be idle for at least one full window before removal.
        # Otherwise timestamps still inside the window would be thrown away
        # and the client's quota reset early (when period > cleanup_interval).
        idle_threshold = max(self.cleanup_interval, self.period)
        inactive_clients = [
            client_ip
            for client_ip, request_times in self.clients.items()
            # Stamps are appended in monotonic order, so [-1] is the newest.
            if not request_times or now - request_times[-1] > idle_threshold
        ]
        for client_ip in inactive_clients:
            del self.clients[client_ip]

    async def dispatch(self, request: Request, call_next):
        """Reply 429 if the client is over its quota, else record and forward."""
        client_ip = self._client_key(request)
        now = time.monotonic()

        # Periodic cleanup of inactive clients
        if now - self.last_cleanup > self.cleanup_interval:
            self._cleanup_inactive_clients(now)
            self.last_cleanup = now

        # Keep only this client's requests that are still inside the window
        recent = [req_time for req_time in self.clients[client_ip] if now - req_time < self.period]
        self.clients[client_ip] = recent

        if len(recent) >= self.calls:
            # Return the response directly. An HTTPException raised inside a
            # BaseHTTPMiddleware escapes FastAPI's exception handlers and
            # reaches the client as a 500, not a 429.
            from starlette.responses import JSONResponse

            # Seconds until the oldest counted request leaves the window.
            oldest = recent[0] if recent else now
            retry_after = max(1, int(self.period - (now - oldest)) + 1)
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
                headers={"Retry-After": str(retry_after)},
            )

        # Count this request, then hand it on
        recent.append(now)
        return await call_next(request)
def add_security_middleware(app: FastAPI):
    """Add all security middleware to the FastAPI app.

    ``app.add_middleware`` wraps the existing stack, so the LAST middleware
    added is the OUTERMOST. They are therefore registered innermost-first, to
    give this request path (outermost -> innermost):

        trusted host -> CORS -> security headers -> rate limit -> app

    Putting CORS (and security headers) outside the rate limiter has two
    effects. Its 429 responses carry CORS headers, so browser clients can
    actually read them instead of seeing an opaque CORS failure. CORS preflight
    requests are also answered before they count against the rate limit.
    Requests with a disallowed Host header are still rejected first.
    """
    # Rate limiting middleware (innermost)
    app.add_middleware(RateLimitMiddleware)

    # Security headers middleware
    app.add_middleware(SecurityHeadersMiddleware)

    # CORS middleware
    # NOTE(review): with allow_credentials=True, settings.CORS_ORIGINS should
    # list explicit origins rather than "*" — confirm.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.CORS_ORIGINS,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
        allow_headers=["*"],
    )

    # Trusted host middleware (outermost)
    app.add_middleware(
        TrustedHostMiddleware,
        allowed_hosts=settings.ALLOWED_HOSTS,
    )