"""
Rate limiting middleware for FastAPI.
"""
import logging
from typing import Callable, Optional

from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from .redis_service import get_redis_service
from .mongodb_service import get_mongodb_service

logger = logging.getLogger(__name__)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate limiting middleware using Redis.

    Policy: ``max_requests`` requests per identifier per ``window_seconds``
    (defaults: 2 requests per 60 seconds).
    Priority: user_id > device_id > IP

    NOTE(review): user_id and device_id are read from client-supplied headers
    or cookies, so a client can rotate them to obtain a fresh quota. Consider
    deriving the identity from an authenticated session instead.
    """

    # Paths that are never rate limited (exact match). Paths starting with
    # "/admin/" are also skipped; see ``dispatch``.
    _EXEMPT_PATHS = frozenset({"/health", "/", "/docs", "/openapi.json"})

    def __init__(self, app, max_requests: int = 2, window_seconds: int = 60):
        """
        Args:
            app: The ASGI application being wrapped.
            max_requests: Maximum number of requests allowed per window.
            window_seconds: Length of the rate-limit window, in seconds.
        """
        super().__init__(app)
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.redis_service = get_redis_service()
        self.mongodb_service = get_mongodb_service()

    @staticmethod
    def _get_client_ids(request: Request) -> tuple[Optional[str], Optional[str]]:
        """Return ``(user_id, device_id)``; each header takes precedence over its cookie."""
        user_id = request.headers.get("X-User-Id") or request.cookies.get("user_id")
        device_id = request.headers.get("X-Device-Id") or request.cookies.get("device_id")
        return user_id, device_id

    def _get_identifier(self, request: Request) -> tuple[str, str]:
        """
        Get rate limit identifier with priority: user_id > device_id > IP.

        Returns:
            Tuple of (identifier, identifier_type), e.g. ("user:42", "user_id").
        """
        user_id, device_id = self._get_client_ids(request)

        if user_id:
            return f"user:{user_id}", "user_id"
        if device_id:
            return f"device:{device_id}", "device_id"

        # Fallback to the client IP; request.client can be None (e.g. in some
        # test clients or non-socket transports).
        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}", "ip"

    def _log_rate_limit_hit(
        self,
        request: Request,
        identifier: str,
        identifier_type: str,
        current_count,
    ) -> None:
        """Record a RATE_LIMIT_HIT event in MongoDB on a best-effort basis."""
        user_id, device_id = self._get_client_ids(request)
        try:
            self.mongodb_service.log_event(
                event_type="RATE_LIMIT_HIT",
                # Fall back to the rate-limit identifier so the event always
                # carries some device key.
                device_id=device_id or identifier,
                user_id=user_id,
                metadata={
                    "identifier": identifier,
                    "identifier_type": identifier_type,
                    "count": current_count,
                    "path": request.url.path,
                },
            )
        except Exception:
            # Audit logging must not turn a 429 into a 500.
            logger.exception("Failed to log RATE_LIMIT_HIT event for %s", identifier)

    async def dispatch(
        self, request: Request, call_next: Callable
    ) -> Response:
        """
        Process request with rate limiting.

        Returns a 429 JSON response when the identifier is over its limit.
        Otherwise, forwards the request and adds X-RateLimit-* headers to the
        downstream response.
        """
        path = request.url.path

        # Skip rate limiting for admin endpoints, health checks and docs.
        if path.startswith("/admin/") or path in self._EXEMPT_PATHS:
            return await call_next(request)

        identifier, identifier_type = self._get_identifier(request)

        # NOTE(review): check_rate_limit (and log_event below) are called
        # without ``await``. Any network I/O they perform blocks the event
        # loop; consider async clients or starlette's run_in_threadpool.
        is_allowed, current_count = self.redis_service.check_rate_limit(
            identifier,
            self.max_requests,
            self.window_seconds,
        )

        if not is_allowed:
            self._log_rate_limit_hit(request, identifier, identifier_type, current_count)

            logger.warning(
                "Rate limit exceeded for %s (count: %s, max: %s)",
                identifier,
                current_count,
                self.max_requests,
            )

            # Return the response directly instead of raising HTTPException.
            # Middleware added via add_middleware runs outside FastAPI's
            # ExceptionMiddleware, so a raised HTTPException would surface as
            # a 500. The body mirrors FastAPI's {"detail": ...} error shape.
            return JSONResponse(
                status_code=429,
                content={
                    "detail": {
                        "error": "Rate limit exceeded",
                        "message": f"Too many requests. Maximum {self.max_requests} requests per {self.window_seconds} seconds.",
                        "retry_after": self.window_seconds,
                    }
                },
                headers={"Retry-After": str(self.window_seconds)},
            )

        # Add rate limit info to response headers.
        response = await call_next(request)
        response.headers["X-RateLimit-Limit"] = str(self.max_requests)
        response.headers["X-RateLimit-Remaining"] = str(max(0, self.max_requests - current_count))
        # NOTE: this reports the full window length, not the time remaining in
        # the current window (no TTL is obtained from Redis here).
        response.headers["X-RateLimit-Reset"] = str(self.window_seconds)
        return response