Admin-Desk2 / app /utils /rate_limiter.py
Fred808's picture
Upload 83 files
4d3ce85 verified
import asyncio
import time
import uuid
from typing import Dict, List

from fastapi import HTTPException, Request
from redis.asyncio import Redis

from ..core.config import settings
from ..utils.logger import logger
class RateLimiter:
    """Per-client-IP sliding-window rate limiter, shared as a process singleton.

    Every call to :meth:`check_rate_limit` records one request for the
    caller's IP and raises HTTP 429 once more than ``rate_limit`` requests
    from that IP fall inside the last ``time_window`` seconds. Request
    timestamps live in a Redis sorted set while Redis is reachable, and in an
    in-process dict otherwise.

    The Redis connection is opened by the async :meth:`initialize`. If the
    singleton is created inside a running event loop, initialization is
    scheduled immediately; otherwise it runs on the first
    :meth:`check_rate_limit` call. It may also be awaited explicitly, e.g.
    from an application startup hook.
    """

    _instance = None
    # client IP -> time.time() stamps of that client's requests still in the window
    _memory_store: Dict[str, List[float]] = {}

    def __new__(cls):
        if cls._instance is None:
            instance = super(RateLimiter, cls).__new__(cls)
            # Defaults are set synchronously so no request ever sees missing
            # attributes while the async initialization is still pending.
            instance.rate_limit = 100  # requests
            instance.time_window = 60  # seconds
            instance.is_connected = False
            instance.redis = None
            instance._initialized = False
            instance._init_task = None
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop (e.g. a plain import at module load): creating
                # a task here would raise, so defer to the first awaited call.
                loop = None
            if loop is not None:
                # Keep a reference so the task cannot be garbage-collected mid-run.
                instance._init_task = loop.create_task(instance.initialize())
            cls._instance = instance
        return cls._instance

    async def initialize(self):
        """Initialize Redis connection with fallback to in-memory store.

        Sets ``is_connected`` according to whether a Redis ``PING`` succeeds.
        Connection errors are logged, not raised; the limiter then keeps using
        the in-process store.
        """
        self.is_connected = False
        try:
            self.redis = Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                decode_responses=True,
                socket_connect_timeout=1,
            )
            await self.redis.ping()
            self.is_connected = True
            logger.info("Rate limiter Redis connection initialized successfully")
        except Exception as e:
            self.is_connected = False
            logger.warning(f"Redis connection failed for rate limiter, using in-memory fallback: {str(e)}")
        finally:
            self._initialized = True

    async def _ensure_initialized(self):
        """Run (or wait for the already-scheduled) :meth:`initialize` once."""
        if self._initialized:
            return
        if self._init_task is None or self._init_task.cancelled():
            self._init_task = asyncio.ensure_future(self.initialize())
        # shield: cancelling one request must not cancel the shared init task.
        await asyncio.shield(self._init_task)

    async def check_rate_limit(self, request: Request):
        """Record this request and reject it if its client is over the limit.

        Args:
            request: Incoming request. The client is identified by
                ``request.client.host``; requests without client info share
                the single bucket ``"unknown"``.

        Raises:
            HTTPException: 429 when more than ``rate_limit`` requests from the
                same client (this one included) fall within the last
                ``time_window`` seconds.
        """
        await self._ensure_initialized()
        # request.client is None when the server provides no peer address.
        client_ip = request.client.host if request.client else "unknown"
        current = time.time()
        window_start = current - self.time_window

        # Use Redis if available; on any Redis error, fall back to memory
        # (and stay there until check_connection() succeeds again).
        if self.is_connected:
            try:
                request_count = await self._check_redis_rate_limit(client_ip, current, window_start)
            except Exception as e:
                logger.error(f"Redis rate limit error: {str(e)}")
                self.is_connected = False
                request_count = await self._check_memory_rate_limit(client_ip, current, window_start)
        else:
            request_count = await self._check_memory_rate_limit(client_ip, current, window_start)

        if request_count > self.rate_limit:
            raise HTTPException(
                status_code=429,
                detail="Too many requests. Please try again later."
            )

    async def _check_redis_rate_limit(self, client_ip: str, current: float, window_start: float) -> int:
        """Record a request in Redis and return the client's count in the window."""
        key = f"rate_limit:{client_ip}"
        # A sorted set stores each member once; a unique suffix ensures two
        # requests with the same timestamp are both counted.
        member = f"{current}:{uuid.uuid4().hex}"
        async with self.redis.pipeline() as pipeline:
            await pipeline.zremrangebyscore(key, 0, window_start)
            await pipeline.zadd(key, {member: current})
            await pipeline.zcard(key)
            await pipeline.expire(key, self.time_window)
            results = await pipeline.execute()
        return results[2]  # zcard result

    async def _check_memory_rate_limit(self, client_ip: str, current: float, window_start: float) -> int:
        """Record a request in the in-process store and return the window count.

        Args:
            client_ip: Key identifying the client.
            current: ``time.time()`` of this request.
            window_start: Entries at or before this instant are discarded.

        Returns:
            Number of requests from ``client_ip`` inside the window, this one
            included.
        """
        # A list rather than a dict keyed by the timestamp string, so requests
        # sharing an identical time.time() value are each counted.
        timestamps = [
            ts for ts in self._memory_store.get(client_ip, ()) if ts > window_start
        ]
        timestamps.append(current)
        self._memory_store[client_ip] = timestamps
        # Bound memory when many distinct IPs have been seen.
        if len(self._memory_store) > 10000:
            await self._cleanup_memory_store()
        return len(timestamps)

    async def _cleanup_memory_store(self):
        """Drop IPs whose recorded requests have all left the current window."""
        window_start = time.time() - self.time_window
        # Collect first: deleting while iterating the dict would raise.
        stale_ips = [
            ip for ip, timestamps in self._memory_store.items()
            if all(ts <= window_start for ts in timestamps)
        ]
        for ip in stale_ips:
            del self._memory_store[ip]

    async def check_connection(self) -> bool:
        """Ping Redis, update ``is_connected``, and return whether it answered."""
        if self.redis is None:
            self.is_connected = False
            return False
        try:
            await self.redis.ping()
        except Exception:
            # Exception (not a bare except) so task cancellation still propagates.
            self.is_connected = False
            return False
        self.is_connected = True
        return True
# Shared module-level instance; RateLimiter() always returns this same singleton.
rate_limiter: RateLimiter = RateLimiter()