File size: 4,213 Bytes
b70ff07
4d3ce85
b70ff07
 
 
 
 
 
 
 
 
 
 
 
 
4d3ce85
b70ff07
 
4d3ce85
b70ff07
 
 
4d3ce85
b70ff07
 
 
 
 
 
4d3ce85
b70ff07
4d3ce85
b70ff07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d3ce85
 
 
 
 
 
b70ff07
 
 
4d3ce85
b70ff07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d3ce85
b70ff07
 
4d3ce85
b70ff07
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import asyncio
import time
import uuid
from collections import deque
from typing import Deque, Dict

from fastapi import HTTPException, Request
from redis.asyncio import Redis

from ..core.config import settings
from ..utils.logger import logger

class RateLimiter:
    """Per-client-IP sliding-window rate limiter, exposed as a singleton.

    Every call to ``check_rate_limit`` records the request's timestamp and
    raises HTTP 429 once more than ``rate_limit`` requests from the same IP
    fall within the trailing ``time_window`` seconds (rejected requests are
    recorded too). When Redis is reachable the log is kept there, one sorted
    set per IP under ``rate_limit:<ip>``; otherwise, and after any Redis error,
    an in-process store is used until ``check_connection()`` succeeds again.
    """

    _instance = None
    # client_ip -> time.time() stamps of that IP's requests, oldest on the left.
    _memory_store: Dict[str, Deque[float]] = {}
    # Once more IPs than this are tracked, fully expired IPs are purged.
    _MAX_TRACKED_IPS = 10000

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            # Set defaults synchronously so check_rate_limit never touches
            # missing attributes while the async Redis setup is still pending
            # or was never scheduled.
            instance.rate_limit = 100  # requests
            instance.time_window = 60  # seconds
            instance.is_connected = False
            instance.redis = None
            instance._init_task = None
            cls._instance = instance
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # No running loop (e.g. instantiated at module import time, where
                # asyncio.create_task would raise): initialize lazily instead.
                loop = None
            if loop is not None:
                # Keep a reference so the task is not garbage-collected mid-run.
                instance._init_task = loop.create_task(instance.initialize())
        return cls._instance

    async def initialize(self):
        """Connect to Redis, falling back to the in-memory store on failure.

        Sets ``self.redis`` to a client and ``self.is_connected`` to whether a
        PING succeeded. Connection errors are logged, never raised.
        """
        self.is_connected = False
        try:
            self.redis = Redis(
                host=settings.REDIS_HOST,
                port=settings.REDIS_PORT,
                decode_responses=True,
                socket_connect_timeout=1,
            )
            await self.redis.ping()
            self.is_connected = True
            logger.info("Rate limiter Redis connection initialized successfully")
        except Exception as e:
            self.is_connected = False
            logger.warning(f"Redis connection failed for rate limiter, using in-memory fallback: {str(e)}")

    async def _ensure_initialized(self):
        """Run ``initialize()`` once if it was not (successfully) scheduled yet."""
        task = self._init_task
        if task is None or task.cancelled():
            task = self._init_task = asyncio.ensure_future(self.initialize())
        if not task.done():
            await task

    async def check_rate_limit(self, request: Request):
        """Record this request and reject it if its client IP is over the limit.

        Args:
            request: Incoming request; the limit is keyed on ``request.client.host``.

        Raises:
            HTTPException: Status 429 when more than ``rate_limit`` requests from
                this IP (the current one included) were seen in the last
                ``time_window`` seconds.
        """
        await self._ensure_initialized()

        # request.client can be None (e.g. some ASGI servers / test transports);
        # such requests share one bucket instead of crashing with AttributeError.
        client_ip = request.client.host if request.client else "unknown"
        current = time.time()
        window_start = current - self.time_window

        if self.is_connected:
            try:
                request_count = await self._check_redis_rate_limit(client_ip, current, window_start)
            except Exception as e:
                logger.error(f"Redis rate limit error: {str(e)}")
                self.is_connected = False
                request_count = await self._check_memory_rate_limit(client_ip, current, window_start)
        else:
            request_count = await self._check_memory_rate_limit(client_ip, current, window_start)

        if request_count > self.rate_limit:
            raise HTTPException(
                status_code=429,
                detail="Too many requests. Please try again later."
            )

    async def _check_redis_rate_limit(self, client_ip: str, current: float, window_start: float) -> int:
        """Record the request in Redis and return the IP's count within the window."""
        key = f"rate_limit:{client_ip}"
        # Members must be unique: two requests with the same time.time() value
        # would otherwise collapse into one sorted-set entry and be under-counted.
        member = f"{current}:{uuid.uuid4().hex}"
        async with self.redis.pipeline(transaction=True) as pipe:
            pipe.zremrangebyscore(key, 0, window_start)
            pipe.zadd(key, {member: current})
            pipe.zcard(key)
            pipe.expire(key, self.time_window)
            results = await pipe.execute()
        return results[2]  # ZCARD result

    async def _check_memory_rate_limit(self, client_ip: str, current: float, window_start: float) -> int:
        """Record the request in the in-process store and return the IP's count within the window."""
        timestamps = self._memory_store.get(client_ip)
        if timestamps is None:
            timestamps = self._memory_store[client_ip] = deque()

        # Stamps are appended in arrival order, so expired ones sit at the left;
        # popping them is O(expired) instead of rebuilding the whole log.
        # (Assumes time.time() does not step backwards; if it does, the
        # out-of-order entries merely expire a little late.)
        while timestamps and timestamps[0] <= window_start:
            timestamps.popleft()
        # Each request is its own entry, so identical timestamps are all counted.
        timestamps.append(current)

        # Bound memory by purging IPs with no recent requests.
        if len(self._memory_store) > self._MAX_TRACKED_IPS:
            await self._cleanup_memory_store()

        return len(timestamps)

    async def _cleanup_memory_store(self):
        """Drop IPs whose recorded requests have all fallen out of the window."""
        window_start = time.time() - self.time_window
        stale_ips = [
            ip for ip, timestamps in self._memory_store.items()
            if all(ts <= window_start for ts in timestamps)
        ]
        for ip in stale_ips:
            del self._memory_store[ip]

    async def check_connection(self) -> bool:
        """Ping Redis, update ``is_connected`` accordingly, and return it."""
        if self.redis is None:
            # initialize() has not run yet, or the client could not be created.
            self.is_connected = False
            return False
        try:
            await self.redis.ping()
        except Exception:
            self.is_connected = False
            return False
        self.is_connected = True
        return True

# Shared module-level limiter; RateLimiter.__new__ returns this same object
# on every later RateLimiter() call.
rate_limiter: RateLimiter = RateLimiter()