ABSA / src /utils /rate_limit_middleware.py
parthnuwal7's picture
Adding Mongo+Redis concept
4b62d23
"""
Rate limiting middleware for FastAPI.
"""
import logging
from typing import Callable, Optional

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from .mongodb_service import get_mongodb_service
from .redis_service import get_redis_service
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate limiting middleware using Redis.

    Policy: at most ``max_requests`` requests per client within each
    ``window_seconds`` window (defaults: 2 requests per 60 seconds).
    Client identity priority: user_id > device_id > IP.

    Over-limit requests are answered directly with a 429 JSON response
    (``{"detail": {...}}`` plus a ``Retry-After`` header). A
    ``RATE_LIMIT_HIT`` event is also written to MongoDB. Paths starting
    with ``EXEMPT_PREFIX`` and paths in ``EXEMPT_PATHS`` are never rate
    limited.

    NOTE(review): user_id / device_id are read from request headers and
    cookies, which the client controls. A client that sends a different
    X-User-Id or X-Device-Id on every request is counted under a different
    identifier each time. Derive identity from an authenticated session if
    this limit must be enforced against hostile clients.
    """

    # Exact paths that bypass rate limiting (health checks and API docs).
    EXEMPT_PATHS = frozenset({"/health", "/", "/docs", "/openapi.json"})
    # Any path starting with this prefix also bypasses rate limiting.
    EXEMPT_PREFIX = "/admin/"

    def __init__(self, app, max_requests: int = 2, window_seconds: int = 60):
        """
        Args:
            app: The ASGI application wrapped by this middleware.
            max_requests: Requests allowed per identifier within one window.
            window_seconds: Length of the rate-limit window, in seconds.
        """
        super().__init__(app)
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.redis_service = get_redis_service()
        self.mongodb_service = get_mongodb_service()

    @staticmethod
    def _client_ids(request: Request) -> tuple[Optional[str], Optional[str]]:
        """
        Read the client-supplied (user_id, device_id) pair.

        Headers take precedence over cookies; either value may be None.
        """
        user_id = request.headers.get("X-User-Id") or request.cookies.get("user_id")
        device_id = request.headers.get("X-Device-Id") or request.cookies.get("device_id")
        return user_id, device_id

    def _get_identifier(self, request: Request) -> tuple[str, str]:
        """
        Get rate limit identifier with priority: user_id > device_id > IP.

        Returns:
            Tuple of (identifier, identifier_type), e.g. ("user:42", "user_id").
        """
        user_id, device_id = self._client_ids(request)
        if user_id:
            return f"user:{user_id}", "user_id"
        if device_id:
            return f"device:{device_id}", "device_id"
        # Fallback to the socket peer address.
        # NOTE(review): behind a reverse proxy this is the proxy's address
        # unless the server is configured to trust forwarded headers.
        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}", "ip"

    def _is_exempt(self, path: str) -> bool:
        """Return True if requests to *path* bypass rate limiting."""
        return path.startswith(self.EXEMPT_PREFIX) or path in self.EXEMPT_PATHS

    def _rate_limit_headers(self, current_count: int) -> dict[str, str]:
        """Build the X-RateLimit-* headers for the given request count."""
        return {
            "X-RateLimit-Limit": str(self.max_requests),
            "X-RateLimit-Remaining": str(max(0, self.max_requests - current_count)),
            # NOTE(review): this is the full window length, not the number of
            # seconds until the counter actually resets; confirm against
            # redis_service.check_rate_limit if clients rely on it.
            "X-RateLimit-Reset": str(self.window_seconds),
        }

    def _log_rate_limit_hit(
        self,
        request: Request,
        identifier: str,
        identifier_type: str,
        current_count: int,
    ) -> None:
        """Record a RATE_LIMIT_HIT event in MongoDB (best effort)."""
        user_id, device_id = self._client_ids(request)
        try:
            self.mongodb_service.log_event(
                event_type="RATE_LIMIT_HIT",
                # Fall back to the rate-limit identifier when no device id was sent.
                device_id=device_id or identifier,
                user_id=user_id,
                metadata={
                    "identifier": identifier,
                    "identifier_type": identifier_type,
                    "count": current_count,
                    "path": request.url.path,
                },
            )
        except Exception:
            # A failure in the audit sink must not turn the 429 into a 500.
            logger.exception("Failed to log RATE_LIMIT_HIT for %s", identifier)

    def _rate_limited_response(self, current_count: int) -> JSONResponse:
        """Build the 429 response sent when a client exceeds its limit."""
        headers = self._rate_limit_headers(current_count)
        headers["Retry-After"] = str(self.window_seconds)
        return JSONResponse(
            status_code=429,
            # Same {"detail": ...} shape FastAPI's HTTPException handler emits.
            content={
                "detail": {
                    "error": "Rate limit exceeded",
                    "message": f"Too many requests. Maximum {self.max_requests} requests per {self.window_seconds} seconds.",
                    "retry_after": self.window_seconds,
                }
            },
            headers=headers,
        )

    async def dispatch(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        """
        Process request with rate limiting.

        Exempt paths are passed straight through. Otherwise the request is
        counted against its identifier. If the limit is exceeded, a 429
        response is returned without calling the downstream app. If not,
        the downstream response is returned with X-RateLimit-* headers added.
        """
        if self._is_exempt(request.url.path):
            return await call_next(request)

        identifier, identifier_type = self._get_identifier(request)

        # NOTE(review): check_rate_limit (and log_event below) are called
        # synchronously from async code, so any network I/O they do blocks the
        # event loop; consider starlette.concurrency.run_in_threadpool.
        is_allowed, current_count = self.redis_service.check_rate_limit(
            identifier,
            self.max_requests,
            self.window_seconds,
        )

        if not is_allowed:
            self._log_rate_limit_hit(request, identifier, identifier_type, current_count)
            logger.warning(
                "Rate limit exceeded for %s (count: %s, max: %s)",
                identifier,
                current_count,
                self.max_requests,
            )
            # Return the 429 directly. An HTTPException raised inside a
            # BaseHTTPMiddleware is outside FastAPI's exception handlers and
            # would reach the client as a 500 Internal Server Error.
            return self._rate_limited_response(current_count)

        response = await call_next(request)
        for name, value in self._rate_limit_headers(current_count).items():
            response.headers[name] = value
        return response