from fastapi import HTTPException, Request
from redis.asyncio import Redis
from ..core.config import settings
from ..utils.logger import logger
import time
import uuid
from collections import deque
from typing import Deque, Dict, Optional
import asyncio


class RateLimiter:
    """Sliding-window rate limiter keyed by client IP.

    While Redis is reachable, each IP's request timestamps live in a Redis
    sorted set (``rate_limit:<ip>``). Every process pointing at the same Redis
    therefore shares the limit. When Redis is unavailable, an in-process store
    is used instead; that fallback only counts requests seen by this process.

    The class is a singleton: every ``RateLimiter()`` call returns the same
    instance. Construction is synchronous and does no I/O, so the module-level
    ``rate_limiter`` can safely be created at import time (no event loop is
    needed). The Redis connection is opened lazily on the first
    ``check_rate_limit`` call, or by awaiting ``initialize()`` explicitly
    (e.g. from an application startup hook).
    """

    _instance: Optional["RateLimiter"] = None
    # client IP -> time.time() stamps of its requests still inside the window,
    # appended in arrival order (oldest first).
    _memory_store: Dict[str, Deque[float]] = {}

    # When the in-memory store tracks more IPs than this, idle IPs are purged.
    _MAX_TRACKED_IPS = 10000
    # Minimum seconds between Redis reconnection attempts after a failure.
    _RECONNECT_INTERVAL = 30.0

    def __new__(cls):
        if cls._instance is None:
            instance = super(RateLimiter, cls).__new__(cls)
            # Set every attribute synchronously so the instance is usable
            # before (and independently of) the async Redis initialization.
            instance.rate_limit = 100  # requests
            instance.time_window = 60  # seconds
            instance.is_connected = False
            instance.redis = None
            instance._initialized = False
            instance._init_lock = None
            instance._next_reconnect_at = 0.0
            instance._last_cleanup = 0.0
            cls._instance = instance
        return cls._instance

    async def initialize(self):
        """Initialize Redis connection with fallback to in-memory store.

        Reuses an existing client if one was already created. Never raises:
        on failure it logs a warning and leaves ``is_connected`` False, and
        ``check_rate_limit`` retries later.
        """
        try:
            if self.redis is None:
                self.redis = Redis(
                    host=settings.REDIS_HOST,
                    port=settings.REDIS_PORT,
                    decode_responses=True,
                    socket_connect_timeout=1,
                )
            await self.redis.ping()
            self.is_connected = True
            logger.info("Rate limiter Redis connection initialized successfully")
        except Exception as e:
            self.is_connected = False
            logger.warning(
                f"Redis connection failed for rate limiter, using in-memory fallback: {str(e)}"
            )
        finally:
            self._initialized = True
            self._next_reconnect_at = time.time() + self._RECONNECT_INTERVAL

    async def _ensure_initialized(self):
        """Run ``initialize()`` once, even when many requests arrive together."""
        if self._initialized:
            return
        if self._init_lock is None:
            # Created lazily so the lock is made inside the running event loop.
            self._init_lock = asyncio.Lock()
        async with self._init_lock:
            if not self._initialized:
                await self.initialize()

    async def _maybe_reconnect(self):
        """Retry Redis at most once per ``_RECONNECT_INTERVAL`` seconds."""
        now = time.time()
        if now < self._next_reconnect_at:
            return
        # Reserve the slot before awaiting, so concurrent requests don't also retry.
        self._next_reconnect_at = now + self._RECONNECT_INTERVAL
        await self.initialize()

    async def check_rate_limit(self, request: Request):
        """Check rate limit for a client IP.

        The current request is recorded before the comparison, so it counts
        toward the limit.

        Raises:
            HTTPException: status 429 when the client has made more than
                ``rate_limit`` requests within the last ``time_window`` seconds.
        """
        await self._ensure_initialized()
        if not self.is_connected:
            await self._maybe_reconnect()

        # request.client can be None (e.g. some ASGI servers / test clients);
        # all such requests share one bucket.
        client_ip = request.client.host if request.client else "unknown"
        current = time.time()
        window_start = current - self.time_window

        request_count: Optional[int] = None
        if self.is_connected:
            try:
                request_count = await self._check_redis_rate_limit(
                    client_ip, current, window_start
                )
            except Exception as e:
                logger.error(f"Redis rate limit error: {str(e)}")
                self.is_connected = False
                # Don't hammer a failing Redis; retry after the interval.
                self._next_reconnect_at = current + self._RECONNECT_INTERVAL
        if request_count is None:
            request_count = await self._check_memory_rate_limit(
                client_ip, current, window_start
            )

        if request_count > self.rate_limit:
            raise HTTPException(
                status_code=429,
                detail="Too many requests. Please try again later.",
            )

    async def _check_redis_rate_limit(
        self, client_ip: str, current: float, window_start: float
    ) -> int:
        """Record the request in Redis and return the client's count in the window."""
        key = f"rate_limit:{client_ip}"
        # A unique member ensures two requests with the same timestamp
        # (possible with coarse clocks or across processes) both count.
        member = f"{current}:{uuid.uuid4().hex}"
        async with self.redis.pipeline(transaction=True) as pipe:
            # Pipeline commands are buffered; only execute() does I/O.
            pipe.zremrangebyscore(key, 0, window_start)
            pipe.zadd(key, {member: current})
            pipe.zcard(key)
            pipe.expire(key, self.time_window)
            results = await pipe.execute()
        return results[2]  # zcard result

    async def _check_memory_rate_limit(
        self, client_ip: str, current: float, window_start: float
    ) -> int:
        """Check rate limit using in-memory store.

        Records ``current`` for ``client_ip`` and returns how many of its
        timestamps are newer than ``window_start``, the current one included.
        """
        timestamps = self._memory_store.get(client_ip)
        if timestamps is None:
            timestamps = self._memory_store[client_ip] = deque()

        # Stamps are appended in arrival order, so expired ones sit at the
        # front (assumes the wall clock does not jump backwards).
        while timestamps and timestamps[0] <= window_start:
            timestamps.popleft()
        timestamps.append(current)

        # Prevent unbounded growth from many distinct IPs. The scan is
        # throttled to once per window, so a large active population doesn't
        # trigger a full scan on every request.
        if (
            len(self._memory_store) > self._MAX_TRACKED_IPS
            and current - self._last_cleanup >= self.time_window
        ):
            await self._cleanup_memory_store()

        return len(timestamps)

    async def _cleanup_memory_store(self):
        """Remove IPs that have no requests inside the current window."""
        current = time.time()
        window_start = current - self.time_window
        self._last_cleanup = current

        old_ips = [
            ip
            for ip, timestamps in self._memory_store.items()
            if all(ts <= window_start for ts in timestamps)
        ]
        for ip in old_ips:
            del self._memory_store[ip]

    async def check_connection(self) -> bool:
        """Check if Redis connection is alive; updates and returns ``is_connected``."""
        if self.redis is None:
            self.is_connected = False
            return False
        try:
            await self.redis.ping()
        except Exception:
            # Deliberately broad: any failure means "not connected".
            # Cancellation (a BaseException) still propagates.
            self.is_connected = False
            return False
        self.is_connected = True
        return True


rate_limiter = RateLimiter()