Spaces:
Sleeping
Sleeping
| """ | |
| Rate limiting middleware for FastAPI. | |
| """ | |
import logging
from typing import Callable, Optional

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from .mongodb_service import get_mongodb_service
from .redis_service import get_redis_service
# Module-level logger, named after this module so handlers/levels can be
# configured per module via the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate limiting middleware using Redis.

    Policy: ``max_requests`` requests per identifier per ``window_seconds``
    (default: 2 requests per user per minute).
    Identifier priority: user_id > device_id > client IP.

    Requests over the limit receive a 429 JSON response (with ``Retry-After``
    and ``X-RateLimit-*`` headers), and a ``RATE_LIMIT_HIT`` event is sent to
    the MongoDB service. Allowed requests get ``X-RateLimit-*`` headers added
    to the downstream response.
    """

    # Exact paths that are never rate limited (health checks and API docs).
    EXEMPT_PATHS = frozenset({"/health", "/", "/docs", "/openapi.json"})
    # Any path starting with this prefix is never rate limited.
    EXEMPT_PREFIX = "/admin/"

    def __init__(self, app, max_requests: int = 2, window_seconds: int = 60):
        """
        Args:
            app: The ASGI application being wrapped.
            max_requests: Maximum requests allowed per identifier per window.
            window_seconds: Length of the rate-limit window, in seconds.
        """
        super().__init__(app)
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.redis_service = get_redis_service()
        self.mongodb_service = get_mongodb_service()

    @staticmethod
    def _extract_client_ids(request: Request) -> tuple[Optional[str], Optional[str]]:
        """Return ``(user_id, device_id)`` from headers, falling back to cookies.

        Either value is ``None`` (or empty) when neither source provides it.
        """
        # NOTE(review): these values are client-supplied and nothing in this
        # class authenticates them, so a client could rotate them to obtain a
        # fresh bucket. Confirm they are validated upstream.
        user_id = request.headers.get("X-User-Id") or request.cookies.get("user_id")
        device_id = request.headers.get("X-Device-Id") or request.cookies.get("device_id")
        return user_id, device_id

    def _get_identifier(self, request: Request) -> tuple[str, str]:
        """
        Get rate limit identifier with priority: user_id > device_id > IP.

        Returns:
            Tuple of (identifier, identifier_type), e.g. ``("user:42", "user_id")``.
        """
        user_id, device_id = self._extract_client_ids(request)
        if user_id:
            return f"user:{user_id}", "user_id"
        if device_id:
            return f"device:{device_id}", "device_id"
        # Fallback to the client IP; request.client can be None (e.g. some
        # test clients / transports), hence the "unknown" bucket.
        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}", "ip"

    def _is_exempt(self, path: str) -> bool:
        """Return True when ``path`` should bypass rate limiting."""
        return path.startswith(self.EXEMPT_PREFIX) or path in self.EXEMPT_PATHS

    def _rate_limit_headers(self, current_count) -> dict[str, str]:
        """Build the ``X-RateLimit-*`` headers for the given request count."""
        return {
            "X-RateLimit-Limit": str(self.max_requests),
            "X-RateLimit-Remaining": str(max(0, self.max_requests - current_count)),
            # NOTE(review): this is the window length, not the time remaining
            # until the Redis counter actually resets.
            "X-RateLimit-Reset": str(self.window_seconds),
        }

    def _log_rate_limit_hit(
        self,
        request: Request,
        identifier: str,
        identifier_type: str,
        current_count,
    ) -> None:
        """Record a RATE_LIMIT_HIT event; failures are logged, never raised."""
        user_id, device_id = self._extract_client_ids(request)
        try:
            self.mongodb_service.log_event(
                event_type="RATE_LIMIT_HIT",
                device_id=device_id or identifier,
                user_id=user_id,
                metadata={
                    "identifier": identifier,
                    "identifier_type": identifier_type,
                    "count": current_count,
                    "path": request.url.path,
                },
            )
        except Exception:
            # Audit logging is best-effort: a MongoDB outage must not turn a
            # 429 into a 500 for the client.
            logger.exception("Failed to log RATE_LIMIT_HIT event for %s", identifier)

    async def dispatch(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        """Process request with rate limiting.

        Exempt paths pass straight through. Otherwise the identifier's counter
        is checked in Redis; over-limit requests get a 429 JSON response, and
        allowed requests are forwarded with rate-limit headers attached.
        """
        if self._is_exempt(request.url.path):
            return await call_next(request)

        identifier, identifier_type = self._get_identifier(request)

        # NOTE(review): check_rate_limit (and log_event) are called without
        # await; if they perform blocking network I/O they stall the event
        # loop — consider run_in_threadpool or async clients. Confirm.
        is_allowed, current_count = self.redis_service.check_rate_limit(
            identifier,
            self.max_requests,
            self.window_seconds,
        )

        if not is_allowed:
            self._log_rate_limit_hit(request, identifier, identifier_type, current_count)
            logger.warning(
                "Rate limit exceeded for %s (count: %s, max: %s)",
                identifier,
                current_count,
                self.max_requests,
            )
            # Return the 429 directly instead of raising HTTPException: an
            # exception raised inside a BaseHTTPMiddleware is outside FastAPI's
            # exception handlers and would be served as a 500. The body keeps
            # the {"detail": ...} shape FastAPI's handler would have produced.
            headers = self._rate_limit_headers(current_count)
            headers["Retry-After"] = str(self.window_seconds)
            return JSONResponse(
                status_code=429,
                content={
                    "detail": {
                        "error": "Rate limit exceeded",
                        "message": f"Too many requests. Maximum {self.max_requests} requests per {self.window_seconds} seconds.",
                        "retry_after": self.window_seconds,
                    }
                },
                headers=headers,
            )

        response = await call_next(request)
        for header_name, header_value in self._rate_limit_headers(current_count).items():
            response.headers[header_name] = header_value
        return response