Spaces:
Running
Running
| # ============================================================ | |
| # app/middleware/rate_limiter.py - API Rate Limiting | |
| # ============================================================ | |
| # | |
| # Production-grade rate limiting using slowapi with Redis backend | |
| # for distributed rate limiting across multiple server instances. | |
| # | |
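| # Rate Limits (keyed per client: user ID when authenticated, otherwise IP): | |
| # - AI endpoints (/ai/*): 10 requests/minute (expensive LLM calls) | |
| # - Auth endpoints (/api/auth/*): 20 requests/minute (security) | |
| # - Search endpoints: 30 requests/minute (moderately expensive) | |
| # - Standard API: 100 requests/minute (default for all endpoints) | |
| # - WebSocket connections: 10/minute per client | |
| # - Heavy operations (e.g. file uploads): 5 requests/minute | |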
| # ============================================================ | |
| from slowapi import Limiter | |
| from slowapi.util import get_remote_address | |
| from slowapi.errors import RateLimitExceeded | |
| from slowapi.middleware import SlowAPIMiddleware | |
| from fastapi import Request | |
| from fastapi.responses import JSONResponse | |
| import logging | |
| import os | |
# Module-level logger for this file's diagnostics.
logger = logging.getLogger(__name__)

# ============================================================
# Configuration
# ============================================================
# Connection URL for a shared Redis store, so limits hold across server
# instances in production. Left unset in development, in which case an
# in-process memory store is used instead.
REDIS_URL = os.environ.get("REDIS_URL")
def get_client_identifier(request: Request, trusted_proxies: int = 1) -> str:
    """
    Get a unique identifier for the client, used as the rate-limit key.

    Priority:
    1. Authenticated user ID (if set on ``request.state.user_id``)
    2. Client IP from the X-Forwarded-For header (for proxied requests)
    3. Direct peer address, via slowapi's ``get_remote_address``

    Each proxy *appends* the address it received the request from to
    X-Forwarded-For; everything to the left of those entries is whatever the
    client chose to send. Keying on the leftmost entry would let a client
    dodge every limit by sending a new fake IP per request, so only the entry
    added by the outermost trusted proxy (``trusted_proxies`` positions from
    the right) is used.

    Args:
        request: The incoming request.
        trusted_proxies: How many reverse proxies / load balancers sit in
            front of the app and append to X-Forwarded-For. ``0`` ignores
            the header entirely. NOTE(review): the default of 1 assumes a
            single proxy hop; confirm against the deployment topology.

    Returns:
        ``"user:<id>"`` for authenticated requests, otherwise a client address.
    """
    # Authenticated users are limited per account, not per address.
    user_id = getattr(request.state, "user_id", None)
    if user_id:
        return f"user:{user_id}"

    if trusted_proxies > 0:
        forwarded = request.headers.get("X-Forwarded-For", "")
        # Drop blank entries so headers like "," never yield an empty key.
        hops = [hop.strip() for hop in forwarded.split(",") if hop.strip()]
        # A chain shorter than the number of trusted proxies did not pass
        # through all of them, so none of its entries can be trusted.
        if len(hops) >= trusted_proxies:
            return hops[-trusted_proxies]

    # Fall back to the direct peer address.
    return get_remote_address(request)
| # ============================================================ | |
| # Limiter Instance | |
| # ============================================================ | |
# Configure storage backend
if REDIS_URL:
    storage_uri = REDIS_URL
    # The URL itself is deliberately not logged: it may embed credentials.
    logger.info("Rate limiter using Redis backend")
else:
    storage_uri = "memory://"
    # Memory storage is per process, so each worker/instance keeps its own
    # independent counters.
    logger.warning("Rate limiter using in-memory backend (not suitable for production)")

# Shared limiter instance, keyed by get_client_identifier.
# NOTE(review): with headers_enabled=True, slowapi expects endpoints decorated
# with @limiter.limit to either return a Response or accept a
# `response: Response` parameter so it can inject headers; confirm callers do.
limiter = Limiter(
    key_func=get_client_identifier,
    default_limits=["100/minute"],  # Default for all endpoints
    storage_uri=storage_uri,
    strategy="fixed-window",  # Simple and efficient
    headers_enabled=True,  # Include X-RateLimit-* headers in response
)
| # ============================================================ | |
| # Custom Rate Limit Exceeded Handler | |
| # ============================================================ | |
async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    """
    Custom handler for rate limit exceeded errors.

    Returns a user-friendly JSON 429 response with retry information plus
    ``Retry-After`` and ``X-RateLimit-Limit`` headers.

    Args:
        request: The request that was rejected.
        exc: The exception raised when a limit is hit.

    Returns:
        A ``JSONResponse`` with status 429.
    """
    logger.warning(
        "Rate limit exceeded for %s: %s %s",
        get_client_identifier(request),
        request.method,
        request.url.path,
    )

    # NOTE(review): in slowapi, ``exc.limit`` is a ``slowapi.wrappers.Limit``
    # (no useful ``__str__``) whose ``.limit`` is a ``limits.RateLimitItem``
    # exposing ``amount`` and ``get_expiry()``. It is read defensively so an
    # unexpected exception shape degrades to the fallbacks instead of
    # crashing; confirm against the installed slowapi version.
    limit_item = getattr(getattr(exc, "limit", None), "limit", None)

    retry_after = getattr(exc, "retry_after", None)
    if retry_after is None:
        # The full window length is an upper bound on the wait (with the
        # fixed-window strategy the counter may reset sooner); 60s if unknown.
        get_expiry = getattr(limit_item, "get_expiry", None)
        retry_after = int(get_expiry()) if callable(get_expiry) else 60

    # Number of requests allowed per window, or None if unavailable.
    limit_amount = getattr(limit_item, "amount", None)

    return JSONResponse(
        status_code=429,
        content={
            "success": False,
            "error_code": "RATE_LIMIT_EXCEEDED",
            "message": "Too many requests. Please slow down.",
            "retry_after_seconds": retry_after,
            "detail": str(exc.detail) if hasattr(exc, "detail") else "Rate limit exceeded",
        },
        headers={
            "Retry-After": str(retry_after),
            "X-RateLimit-Limit": str(limit_amount) if limit_amount is not None else "unknown",
        },
    )
| # ============================================================ | |
| # Rate Limit Decorators for Specific Endpoints | |
| # ============================================================ | |
# Per-endpoint limit strings in slowapi/limits "<count>/<period>" notation,
# intended for @limiter.limit(...). STANDARD_RATE_LIMIT mirrors the limiter's
# default_limits value ("100/minute"); keep the two in sync.
AI_RATE_LIMIT = "10/minute"         # AI endpoints - expensive LLM calls, limit strictly
AUTH_RATE_LIMIT = "20/minute"       # Auth endpoints - security-sensitive
SEARCH_RATE_LIMIT = "30/minute"     # Search endpoints - moderately expensive
STANDARD_RATE_LIMIT = "100/minute"  # Standard API endpoints
WEBSOCKET_RATE_LIMIT = "10/minute"  # WebSocket connections
HEAVY_RATE_LIMIT = "5/minute"       # Heavy operations (file uploads, etc.)
| # ============================================================ | |
| # Helper function to apply rate limiting | |
| # ============================================================ | |
def get_limiter() -> "Limiter":
    """Return the shared module-level ``limiter`` instance (not a copy)."""
    return limiter
| # ============================================================ | |
| # Exempt paths (no rate limiting) | |
| # ============================================================ | |
# Paths that bypass rate limiting: health check, root, and the API docs.
EXEMPT_PATHS = {
    "/health",
    "/",
    "/docs",
    "/openapi.json",
    "/redoc",
}


def is_exempt(path: str) -> bool:
    """
    Check if a path is exempt from rate limiting.

    Trailing slashes are ignored, so ``"/health/"`` matches ``"/health"``;
    the root path ``"/"`` still matches itself.

    NOTE(review): this only answers the question; whatever applies limits
    must call it for the exemption to take effect.

    Args:
        path: URL path without query string (e.g. ``request.url.path``).

    Returns:
        True if the normalized path is listed in ``EXEMPT_PATHS``.
    """
    # "/" strips to "" and maps back to "/"; "/docs/" becomes "/docs".
    normalized = path.rstrip("/") or "/"
    return normalized in EXEMPT_PATHS