# NOTE(review): the lines below were page-header residue from the file viewer
# this source was copied from; kept as comments so the module parses.
# Spaces: Paused
# File size: 4,213 Bytes
# Revisions shown in the line gutter: b70ff07, 4d3ce85 (original listing: lines 1-112)
from fastapi import HTTPException, Request
from redis.asyncio import Redis
from ..core.config import settings
from ..utils.logger import logger
import time
from typing import Dict
import asyncio
class RateLimiter:
    """Singleton sliding-window rate limiter keyed by client IP.

    Request timestamps are kept in a Redis sorted set per client when Redis
    is reachable, and in a class-level in-memory dict otherwise. A client is
    rejected with HTTP 429 once it has more than ``rate_limit`` requests
    inside the last ``time_window`` seconds.
    """

    _instance = None
    # client IP -> {str(timestamp): timestamp} for requests inside the window.
    _memory_store: Dict[str, Dict[str, float]] = {}
    # Above this many tracked IPs the in-memory store is swept for stale ones.
    _MAX_TRACKED_IPS = 10000

    # Class-level defaults so every attribute exists before ``initialize``
    # has had a chance to run (it is async and may be scheduled late).
    rate_limit = 100  # requests
    time_window = 60  # seconds
    is_connected = False
    redis = None
    _init_task = None

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            instance = super().__new__(cls)
            cls._instance = instance
            instance._schedule_initialize()
        return cls._instance

    def _schedule_initialize(self) -> None:
        """Start ``initialize`` as a background task, at most once.

        Does nothing when no event loop is running (e.g. the module is
        imported from synchronous code); ``check_rate_limit`` calls this
        again, so the connection attempt then happens on the first request.
        """
        if self._init_task is not None:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: creating a task here would raise.
            return
        # Keep a reference so the task is not garbage-collected mid-flight.
        self._init_task = loop.create_task(self.initialize())

    async def initialize(self):
        """Initialize Redis connection with fallback to in-memory store.

        ``rate_limit`` and ``time_window`` are class-level defaults and are
        not reset here, so values overridden on the instance are preserved.
        """
        self.is_connected = False
        try:
            self.redis = Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                decode_responses=True,
                socket_connect_timeout=1,
            )
            await self.redis.ping()
            self.is_connected = True
            logger.info("Rate limiter Redis connection initialized successfully")
        except Exception as e:
            # Best effort by design: any failure means "use the memory store".
            self.is_connected = False
            logger.warning(f"Redis connection failed for rate limiter, using in-memory fallback: {str(e)}")

    async def check_rate_limit(self, request: "Request"):
        """Check rate limit for a client IP.

        Args:
            request: Incoming request; ``request.client.host`` identifies
                the caller. Requests without client info share the key
                ``"unknown"``.

        Raises:
            HTTPException: status 429 when the client has more than
                ``rate_limit`` requests in the current window.
        """
        # Covers the case where the instance was created without a running loop.
        self._schedule_initialize()

        client = request.client
        client_ip = client.host if client is not None else "unknown"
        current = time.time()
        window_start = current - self.time_window

        # Use Redis if available
        if self.is_connected:
            try:
                key = f"rate_limit:{client_ip}"
                pipeline = self.redis.pipeline()
                await pipeline.zremrangebyscore(key, 0, window_start)
                await pipeline.zadd(key, {str(current): current})
                await pipeline.zcard(key)
                await pipeline.expire(key, self.time_window)
                results = await pipeline.execute()
                request_count = results[2]  # zcard result
            except Exception as e:
                logger.error(f"Redis rate limit error: {str(e)}")
                # Stop using Redis until check_connection()/initialize() succeeds.
                self.is_connected = False
                request_count = await self._check_memory_rate_limit(client_ip, current, window_start)
        else:
            request_count = await self._check_memory_rate_limit(client_ip, current, window_start)

        if request_count > self.rate_limit:
            raise HTTPException(
                status_code=429,
                detail="Too many requests. Please try again later."
            )

    async def _check_memory_rate_limit(self, client_ip: str, current: float, window_start: float) -> int:
        """Check rate limit using in-memory store.

        Drops the client's entries at or before ``window_start``, records
        ``current`` and returns how many entries the client now has.
        """
        previous = self._memory_store.get(client_ip, {})
        # Clean old entries
        entries = {ts: score for ts, score in previous.items() if score > window_start}
        # Add new request
        entries[str(current)] = current
        self._memory_store[client_ip] = entries

        # Clean up old IPs periodically
        if len(self._memory_store) > self._MAX_TRACKED_IPS:  # Prevent memory leak
            await self._cleanup_memory_store()

        # Use the local dict: the sweep above may have removed this IP's key.
        return len(entries)

    async def _cleanup_memory_store(self):
        """Remove IPs whose every recorded request is outside the window."""
        window_start = time.time() - self.time_window
        old_ips = [
            ip for ip, timestamps in self._memory_store.items()
            if all(score <= window_start for score in timestamps.values())
        ]
        for ip in old_ips:
            del self._memory_store[ip]

    async def check_connection(self) -> bool:
        """Check if Redis connection is alive.

        Returns:
            True if a Redis client exists and answers ``ping``; False
            otherwise. ``is_connected`` is updated to match.
        """
        if self.redis is None:
            # initialize() has not created a client (yet).
            self.is_connected = False
            return False
        try:
            await self.redis.ping()
        except Exception:
            # Not a bare except: task cancellation must still propagate.
            self.is_connected = False
            return False
        self.is_connected = True
        return True
# Module-level instance; RateLimiter() hands back the same object on every call.
rate_limiter = RateLimiter()