Spaces:
Paused
Paused
import asyncio
import time
import uuid
from collections import deque
from typing import Deque, Dict

from fastapi import HTTPException, Request
from redis.asyncio import Redis

from ..core.config import settings
from ..utils.logger import logger
class RateLimiter:
    """Per-client-IP sliding-window rate limiter.

    Request timestamps are kept in a Redis sorted set per client while Redis
    is reachable, and in an in-process store otherwise. The class is a
    singleton: every ``RateLimiter()`` call returns the same instance.
    """

    _instance = None
    # client IP -> time.time() values of that client's requests still inside
    # the window, in arrival order. Class-level, so shared by the singleton.
    _memory_store: Dict[str, Deque[float]] = {}

    DEFAULT_RATE_LIMIT = 100  # requests
    DEFAULT_TIME_WINDOW = 60  # seconds
    # Once more IPs than this are tracked, idle ones are swept from memory.
    _MAX_TRACKED_CLIENTS = 10000

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Set every attribute synchronously so check_rate_limit() works even
            # before (or without) the asynchronous Redis initialisation.
            instance.rate_limit = cls.DEFAULT_RATE_LIMIT
            instance.time_window = cls.DEFAULT_TIME_WINDOW
            instance.is_connected = False
            instance.redis = None
            instance._initialized = False
            instance._init_task = None
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop (e.g. at module import): create_task would
                # raise here, so initialize() is deferred to the first request.
                loop = None
            if loop is not None:
                # Keep a reference: the event loop only holds tasks weakly.
                instance._init_task = loop.create_task(instance.initialize())
            cls._instance = instance
        return cls._instance

    async def initialize(self):
        """Connect to Redis, falling back to the in-memory store if it is unreachable."""
        self.is_connected = False
        try:
            self.redis = Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                decode_responses=True,
                socket_connect_timeout=1,
            )
            await self.redis.ping()
            self.is_connected = True
            logger.info("Rate limiter Redis connection initialized successfully")
        except Exception as e:
            self.is_connected = False
            logger.warning(f"Redis connection failed for rate limiter, using in-memory fallback: {e}")
        finally:
            self._initialized = True

    async def _ensure_initialized(self) -> None:
        """Run initialize() once, sharing one task between concurrent callers."""
        if self._initialized:
            return
        task = self._init_task
        # A task created under a different (possibly closed) loop cannot be
        # awaited from this one; start a fresh one in the current loop instead.
        if task is None or task.get_loop() is not asyncio.get_running_loop():
            task = self._init_task = asyncio.ensure_future(self.initialize())
        # shield: a cancelled request must not cancel the shared initialisation.
        await asyncio.shield(task)

    async def check_rate_limit(self, request: Request):
        """Record this request and enforce the limit for the client's IP.

        The request is recorded before the comparison, so requests rejected
        while over the limit still count toward the window.

        Raises:
            HTTPException: status 429 when the client has made more than
                ``rate_limit`` requests in the last ``time_window`` seconds.
        """
        await self._ensure_initialized()
        # Starlette's Request.client may be None; such requests share one bucket.
        client_ip = request.client.host if request.client else "unknown"
        current = time.time()
        window_start = current - self.time_window

        request_count = None
        if self.is_connected:
            try:
                request_count = await self._check_redis_rate_limit(client_ip, current, window_start)
            except Exception as e:
                logger.error(f"Redis rate limit error: {e}")
                # Stay on the in-memory store until check_connection() succeeds.
                self.is_connected = False
        if request_count is None:
            request_count = await self._check_memory_rate_limit(client_ip, current, window_start)

        if request_count > self.rate_limit:
            raise HTTPException(
                status_code=429,
                detail="Too many requests. Please try again later."
            )

    async def _check_redis_rate_limit(self, client_ip: str, current: float, window_start: float) -> int:
        """Record this request in Redis and return the client's count inside the window."""
        key = f"rate_limit:{client_ip}"
        # Unique member: with str(current) alone, requests sharing a timestamp
        # collapse into a single sorted-set entry and are under-counted.
        member = f"{current}:{uuid.uuid4().hex}"
        async with self.redis.pipeline(transaction=True) as pipe:
            # Pipeline commands are buffered; only execute() talks to Redis.
            pipe.zremrangebyscore(key, 0, window_start)
            pipe.zadd(key, {member: current})
            pipe.zcard(key)
            pipe.expire(key, self.time_window)
            _, _, request_count, _ = await pipe.execute()
        return request_count

    async def _check_memory_rate_limit(self, client_ip: str, current: float, window_start: float) -> int:
        """Record this request in the in-memory store and return the client's count inside the window."""
        timestamps = self._memory_store.get(client_ip)
        if timestamps is None:
            timestamps = self._memory_store[client_ip] = deque()
        # Timestamps are appended in arrival order, so expired ones sit at the left end.
        while timestamps and timestamps[0] <= window_start:
            timestamps.popleft()
        timestamps.append(current)
        # Bound memory: with many tracked IPs, drop those with no recent requests.
        if len(self._memory_store) > self._MAX_TRACKED_CLIENTS:
            await self._cleanup_memory_store()
        return len(timestamps)

    async def _cleanup_memory_store(self):
        """Remove IPs whose recorded requests have all left the window."""
        window_start = time.time() - self.time_window
        stale_ips = [
            ip for ip, timestamps in self._memory_store.items()
            if all(ts <= window_start for ts in timestamps)
        ]
        for ip in stale_ips:
            del self._memory_store[ip]

    async def check_connection(self) -> bool:
        """Ping Redis, update ``is_connected`` accordingly, and return it."""
        if self.redis is None:
            self.is_connected = False
            return False
        try:
            await self.redis.ping()
        except Exception:
            self.is_connected = False
            return False
        self.is_connected = True
        return True
# Module-level shared limiter. RateLimiter.__new__ caches the first instance,
# so any later RateLimiter() call returns this same object.
rate_limiter: RateLimiter = RateLimiter()