# Source: AIDA / app/middleware/rate_limiter.py
# Last commit: b91dd33 ("mvp fix") by destinyebuka
# ============================================================
# app/middleware/rate_limiter.py - API Rate Limiting
# ============================================================
#
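# Production-grade rate limiting using slowapi, with an optional Redis backend
# (enabled via REDIS_URL) for distributed rate limiting across multiple server
# instances. Without Redis, counters are kept in per-process memory.
#
# Rate Limits (counted per client: authenticated user ID when available,
# otherwise IP address):
# - AI endpoints (/ai/*): 10 requests/minute (expensive LLM calls)
# - Auth endpoints (/api/auth/*): 20 requests/minute (security)
# - Search endpoints: 30 requests/minute
# - Heavy operations (file uploads, etc.): 5 requests/minute
# - Standard API (default): 100 requests/minute
# - WebSocket connections: 10/minute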
# ============================================================
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from fastapi import Request
from fastapi.responses import JSONResponse
import logging
import os
logger = logging.getLogger(__name__)

# ============================================================
# Configuration
# ============================================================
# Optional Redis connection string. When set, rate-limit counters are stored
# in Redis so every server instance shares them; when unset (None) the limiter
# below falls back to per-process in-memory storage, intended for development.
REDIS_URL = os.environ.get("REDIS_URL")
def get_client_identifier(request: Request) -> str:
    """
    Return the key that identifies *request*'s client for rate limiting.

    Priority:
    1. Authenticated user ID, when ``request.state.user_id`` is set and truthy
       (returned as ``"user:<id>"``).
    2. The leftmost entry of the ``X-Forwarded-For`` header, when non-empty
       (requests arriving through a load balancer/proxy).
    3. The direct client address reported by slowapi's ``get_remote_address``.

    NOTE(security): ``X-Forwarded-For`` can be set by the client itself. Unless
    a trusted proxy in front of this app strips or overwrites it, a caller can
    rotate the leftmost value to dodge per-IP limits. Confirm the deployment's
    proxy configuration before relying on this for abuse protection.
    """
    # getattr with a default: request.state may not carry a user_id at all.
    user_id = getattr(request.state, "user_id", None)
    if user_id:
        return f"user:{user_id}"

    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        # The header may list several hops ("client, proxy1, proxy2");
        # the leftmost one is the originating client as reported upstream.
        client_ip = forwarded.split(",")[0].strip()
        # A blank leftmost entry (e.g. ", 10.0.0.1" or "  ") would otherwise
        # yield "" and lump unrelated clients into one shared bucket.
        if client_ip:
            return client_ip

    # Fall back to the direct connection's address.
    return get_remote_address(request)
# ============================================================
# Limiter Instance
# ============================================================
# Choose the storage backend: Redis when REDIS_URL is configured (counters
# shared across instances), otherwise per-process memory (counters are not
# shared between workers/instances).
if REDIS_URL:
    storage_uri = REDIS_URL
    logger.info("Rate limiter using Redis backend")
else:
    storage_uri = "memory://"
    logger.warning("Rate limiter using in-memory backend (not suitable for production)")

# Module-wide limiter. Requests are bucketed by get_client_identifier
# (user ID when authenticated, otherwise client IP).
# NOTE(review): with headers_enabled=True, slowapi expects endpoints decorated
# with @limiter.limit to accept a `response: Response` parameter when they
# return a non-Response value — confirm the routes satisfy this.
limiter = Limiter(
    key_func=get_client_identifier,
    default_limits=["100/minute"],  # Default for all endpoints
    storage_uri=storage_uri,
    strategy="fixed-window",  # Counters reset at each window boundary; simple and cheap
    headers_enabled=True,  # Include X-RateLimit-* headers in response
)
# ============================================================
# Custom Rate Limit Exceeded Handler
# ============================================================
def _exceeded_limit_item(exc: RateLimitExceeded):
    """Return the parsed rate-limit item behind *exc*, or None if unavailable.

    slowapi attaches its ``Limit`` wrapper as ``exc.limit``; the wrapper's own
    ``.limit`` attribute is the parsed limit (e.g. "10 per 1 minute").
    Missing attributes resolve to None rather than raising.
    """
    wrapper = getattr(exc, "limit", None)
    return getattr(wrapper, "limit", None)


async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    """
    Custom handler for rate limit exceeded errors.

    Returns a user-friendly 429 JSON response with retry information, plus
    ``Retry-After`` and ``X-RateLimit-Limit`` headers.

    Args:
        request: The request that breached its rate limit.
        exc: The RateLimitExceeded raised by slowapi.

    Returns:
        JSONResponse with status code 429.
    """
    logger.warning(
        "Rate limit exceeded for %s: %s %s",
        get_client_identifier(request),
        request.method,
        request.url.path,
    )

    item = _exceeded_limit_item(exc)

    # Honour an explicit ``retry_after`` if the exception carries one.
    # slowapi's RateLimitExceeded does not appear to set it, so fall back to
    # the breached limit's window length — an upper bound on the wait under a
    # fixed-window strategy — and finally to 60 seconds.
    retry_after = getattr(exc, "retry_after", None)
    if retry_after is None:
        retry_after = item.get_expiry() if hasattr(item, "get_expiry") else 60

    # Report the numeric quota; str() of slowapi's Limit wrapper object is
    # not a meaningful header value.
    limit_amount = getattr(item, "amount", "unknown")

    return JSONResponse(
        status_code=429,
        content={
            "success": False,
            "error_code": "RATE_LIMIT_EXCEEDED",
            "message": "Too many requests. Please slow down.",
            "retry_after_seconds": retry_after,
            "detail": str(exc.detail) if hasattr(exc, "detail") else "Rate limit exceeded",
        },
        headers={
            "Retry-After": str(retry_after),
            "X-RateLimit-Limit": str(limit_amount),
        },
    )
# ============================================================
# Per-endpoint rate limit strings
# ============================================================
# Each value uses the "<count>/<period>" notation and is presumably applied
# to individual routes via ``@limiter.limit(...)``.

AI_RATE_LIMIT = "10/minute"  # AI endpoints: every call hits an expensive LLM
AUTH_RATE_LIMIT = "20/minute"  # Auth endpoints: security-sensitive
SEARCH_RATE_LIMIT = "30/minute"  # Search endpoints: moderately expensive
STANDARD_RATE_LIMIT = "100/minute"  # Ordinary API endpoints
WEBSOCKET_RATE_LIMIT = "10/minute"  # New WebSocket connections
HEAVY_RATE_LIMIT = "5/minute"  # Heavy operations such as file uploads
# ============================================================
# Accessor for the shared limiter
# ============================================================
def get_limiter():
    """Return the module-level rate limiter instance (``limiter``)."""
    shared_limiter = limiter
    return shared_limiter
# ============================================================
# Exempt paths (no rate limiting)
# ============================================================
EXEMPT_PATHS = {
    "/health",
    "/",
    "/docs",
    "/openapi.json",
    "/redoc",
}


def is_exempt(path: str) -> bool:
    """
    Return True if *path* is exempt from rate limiting.

    Trailing slashes are ignored, so "/health/" matches "/health" and
    "/docs/" matches "/docs"; the root path "/" is kept as "/".

    Args:
        path: The request URL path (no query string).

    Returns:
        True when the normalised path is listed in EXEMPT_PATHS.
    """
    # rstrip would turn "/" into "" — map that back to the root path.
    normalized = path.rstrip("/") or "/"
    return normalized in EXEMPT_PATHS