Spaces:
Sleeping
Sleeping
File size: 4,183 Bytes
"""
Rate limiting middleware for FastAPI.
"""
import logging
from typing import Callable, Optional

from fastapi import Request, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from .redis_service import get_redis_service
from .mongodb_service import get_mongodb_service
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate limiting middleware using Redis.

    Policy: at most ``max_requests`` requests per identifier within
    ``window_seconds`` (defaults: 2 requests per 60 seconds).
    Priority: user_id > device_id > IP

    Requests whose path starts with ``/admin/`` or is one of
    ``_EXEMPT_PATHS`` bypass the limiter entirely. Rejected requests are
    logged to MongoDB as ``RATE_LIMIT_HIT`` events and answered with a
    429 JSON response.
    """

    # Exact paths that are never rate limited (health check, root, API docs).
    _EXEMPT_PATHS = frozenset({"/health", "/", "/docs", "/openapi.json"})
    # Path prefix that is never rate limited.
    _EXEMPT_PREFIX = "/admin/"

    def __init__(self, app, max_requests: int = 2, window_seconds: int = 60):
        """
        Store the rate limit policy and acquire the backing services.

        Args:
            app: The ASGI application wrapped by this middleware.
            max_requests: Requests allowed per identifier per window.
            window_seconds: Length of the rate limit window in seconds.
        """
        super().__init__(app)
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self.redis_service = get_redis_service()
        self.mongodb_service = get_mongodb_service()

    @staticmethod
    def _get_client_ids(request: Request) -> tuple:
        """Return ``(user_id, device_id)`` read from headers, then cookies."""
        # NOTE(review): both values are supplied by the client; a caller can
        # change them to obtain a fresh bucket. Confirm that an upstream
        # layer authenticates these values before relying on them.
        user_id = request.headers.get("X-User-Id") or request.cookies.get("user_id")
        device_id = (
            request.headers.get("X-Device-Id") or request.cookies.get("device_id")
        )
        return user_id, device_id

    def _get_identifier(self, request: Request) -> tuple[str, str]:
        """
        Get rate limit identifier with priority: user_id > device_id > IP.

        Returns:
            Tuple of (identifier, identifier_type)
        """
        user_id, device_id = self._get_client_ids(request)
        if user_id:
            return f"user:{user_id}", "user_id"
        if device_id:
            return f"device:{device_id}", "device_id"

        # Fallback to IP; request.client may be absent, hence "unknown".
        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}", "ip"

    def _log_rate_limit_hit(
        self,
        request: Request,
        identifier: str,
        identifier_type: str,
        current_count,
    ) -> None:
        """Record a rejected request in MongoDB and in the application log."""
        user_id, device_id = self._get_client_ids(request)
        self.mongodb_service.log_event(
            event_type="RATE_LIMIT_HIT",
            # When no device id was sent, the rate limit identifier is
            # stored in its place.
            device_id=device_id or identifier,
            user_id=user_id,
            metadata={
                "identifier": identifier,
                "identifier_type": identifier_type,
                "count": current_count,
                "path": request.url.path,
            },
        )
        logger.warning(
            "Rate limit exceeded for %s (count: %s, max: %s)",
            identifier,
            current_count,
            self.max_requests,
        )

    def _rate_limited_response(self) -> Response:
        """Build the 429 response returned to a throttled client."""
        # The body is wrapped in "detail" so it has the same shape FastAPI
        # produces for an HTTPException carrying this payload.
        return JSONResponse(
            status_code=429,
            content={
                "detail": {
                    "error": "Rate limit exceeded",
                    "message": f"Too many requests. Maximum {self.max_requests} requests per {self.window_seconds} seconds.",
                    "retry_after": self.window_seconds,
                }
            },
            headers={
                "Retry-After": str(self.window_seconds),
                "X-RateLimit-Limit": str(self.max_requests),
                "X-RateLimit-Remaining": "0",
                "X-RateLimit-Reset": str(self.window_seconds),
            },
        )

    async def dispatch(
        self,
        request: Request,
        call_next: Callable
    ) -> Response:
        """
        Process request with rate limiting.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response with ``X-RateLimit-*`` headers added,
            or a 429 JSON response when the limit is exceeded.
        """
        # Skip rate limiting for admin endpoints, health checks and docs.
        path = request.url.path
        if path.startswith(self._EXEMPT_PREFIX) or path in self._EXEMPT_PATHS:
            return await call_next(request)

        identifier, identifier_type = self._get_identifier(request)

        # NOTE(review): this is called without await, so it is presumably a
        # synchronous call made on the event loop; confirm it is fast or
        # move it to a thread pool.
        is_allowed, current_count = self.redis_service.check_rate_limit(
            identifier,
            self.max_requests,
            self.window_seconds
        )

        if not is_allowed:
            self._log_rate_limit_hit(
                request, identifier, identifier_type, current_count
            )
            # Return the response instead of raising HTTPException: user
            # middleware runs outside the layer that turns HTTPException
            # into a response, so raising here surfaces as a 500 error.
            return self._rate_limited_response()

        # Add rate limit info to response headers
        response = await call_next(request)
        remaining = max(0, self.max_requests - current_count)
        response.headers["X-RateLimit-Limit"] = str(self.max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(self.window_seconds)
        return response
|