File size: 9,555 Bytes
8bab08d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""
Enterprise Rate Limiting for MCP Servers

Features:
- Token bucket algorithm for smooth rate limiting
- Per-client rate limiting
- Global rate limiting
- Different limits for different endpoints
- Distributed rate limiting with Redis (optional)
"""
import time
import logging
from typing import Dict, Optional
from collections import defaultdict
from dataclasses import dataclass, field
from aiohttp import web
import asyncio

logger = logging.getLogger(__name__)


@dataclass
class TokenBucket:
    """A token bucket that refills continuously at a fixed rate.

    Attributes:
        capacity: Maximum number of tokens the bucket can hold.
        refill_rate: Tokens credited per second.
        tokens: Current token balance (starts full; see __post_init__).
        last_refill: Timestamp (time.time()) of the most recent refill.
    """
    capacity: int
    refill_rate: float
    tokens: float = field(default=0)
    last_refill: float = field(default_factory=time.time)

    def __post_init__(self):
        # A freshly created bucket starts completely full.
        self.tokens = self.capacity

    def _refill(self):
        """Credit tokens earned since the last refill, capped at capacity."""
        now = time.time()
        earned = (now - self.last_refill) * self.refill_rate
        self.tokens = min(self.capacity, self.tokens + earned)
        self.last_refill = now

    def consume(self, tokens: int = 1) -> bool:
        """Attempt to take `tokens` from the bucket.

        Returns:
            True if enough tokens were available (they are deducted),
            False otherwise (balance unchanged apart from the refill).
        """
        self._refill()
        if self.tokens < tokens:
            return False
        self.tokens -= tokens
        return True

    def get_wait_time(self, tokens: int = 1) -> float:
        """Seconds until `tokens` tokens will be available.

        Returns:
            0.0 when the requested amount is already available, otherwise
            the delay implied by the refill rate.
        """
        self._refill()
        deficit = tokens - self.tokens
        if deficit <= 0:
            return 0.0
        return deficit / self.refill_rate


class RateLimiter:
    """
    In-memory rate limiter using the token bucket algorithm.

    Maintains one TokenBucket per (client, path) pair, plus an optional
    global bucket shared by all requests. State is process-local; for
    multi-instance deployments see RedisRateLimiter.
    """

    def __init__(self):
        # Client-specific buckets, keyed by "<client_id>:<path>"
        self.client_buckets: Dict[str, TokenBucket] = {}

        # Global bucket applied to all requests (None = disabled)
        self.global_bucket: Optional[TokenBucket] = None

        # Endpoint-specific limits
        self.endpoint_limits: Dict[str, Dict] = {
            "/rpc": {"capacity": 100, "refill_rate": 10.0},  # 100 requests, 10/sec refill
            "default": {"capacity": 50, "refill_rate": 5.0}  # Default for other endpoints
        }

        # Global rate limit (disabled by default)
        # self.global_bucket = TokenBucket(capacity=1000, refill_rate=100.0)

        # Background task that evicts idle buckets; started lazily by
        # start_cleanup_task() and stopped by stop_cleanup_task().
        self._cleanup_task = None
        logger.info("Rate limiter initialized")

    def _get_client_id(self, request: web.Request) -> str:
        """
        Get the client identifier used as the rate-limit key.

        Uses (in order):
        1. API key (``request["api_key"].key_id`` — presumably set by an
           auth middleware earlier in the chain; verify against caller)
        2. IP address from the transport peername
        3. "unknown" when neither is available
        """
        # Try API key first
        if "api_key" in request and hasattr(request["api_key"], "key_id"):
            return f"key:{request['api_key'].key_id}"

        # Fall back to IP address. request.transport can be None once the
        # connection has closed (or in unit tests), so guard before use
        # instead of raising AttributeError.
        transport = request.transport
        if transport is not None:
            peername = transport.get_extra_info('peername')
            if peername:
                return f"ip:{peername[0]}"

        return "unknown"

    def _get_endpoint_limits(self, path: str) -> Dict:
        """Return the limit config for `path`, falling back to "default"."""
        return self.endpoint_limits.get(path, self.endpoint_limits["default"])

    def _get_or_create_bucket(self, client_id: str, path: str) -> TokenBucket:
        """Get or lazily create the token bucket for a (client, path) pair."""
        bucket_key = f"{client_id}:{path}"

        if bucket_key not in self.client_buckets:
            limits = self._get_endpoint_limits(path)
            self.client_buckets[bucket_key] = TokenBucket(
                capacity=limits["capacity"],
                refill_rate=limits["refill_rate"]
            )

        return self.client_buckets[bucket_key]

    async def check_rate_limit(
        self,
        request: web.Request,
        tokens: int = 1
    ) -> tuple[bool, Optional[float]]:
        """
        Check whether `request` is within its rate limit.

        Consumes `tokens` from the global bucket (if enabled) and then from
        the client's per-endpoint bucket.

        Returns:
            Tuple of (allowed, retry_after_seconds); retry_after_seconds is
            None when the request is allowed.
        """
        client_id = self._get_client_id(request)
        path = request.path

        # Check global rate limit first (if enabled).
        # NOTE(review): if the global consume succeeds but the per-client
        # check below fails, the global tokens are not refunded — a small,
        # deliberate bias toward protecting the process as a whole.
        if self.global_bucket:
            if not self.global_bucket.consume(tokens):
                wait_time = self.global_bucket.get_wait_time(tokens)
                logger.warning(f"Global rate limit exceeded, retry after {wait_time:.2f}s")
                return False, wait_time

        # Check client-specific rate limit
        bucket = self._get_or_create_bucket(client_id, path)

        if not bucket.consume(tokens):
            wait_time = bucket.get_wait_time(tokens)
            logger.warning(f"Rate limit exceeded for {client_id} on {path}, retry after {wait_time:.2f}s")
            return False, wait_time

        return True, None

    async def start_cleanup_task(self):
        """Start the background cleanup task (idempotent)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("Rate limiter cleanup task started")

    async def stop_cleanup_task(self):
        """Cancel the background cleanup task and wait for it to finish.

        Safe to call multiple times; without this the task would keep a
        reference to the limiter and run until the event loop dies.
        """
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass  # expected on cancellation
            self._cleanup_task = None
            logger.info("Rate limiter cleanup task stopped")

    async def _cleanup_loop(self):
        """Periodically drop buckets that have been idle for 10+ minutes."""
        while True:
            await asyncio.sleep(300)  # Every 5 minutes

            # Remove buckets that haven't been used recently; last_refill is
            # updated on every consume, so it doubles as a last-used stamp.
            cutoff_time = time.time() - 600  # 10 minutes
            removed = 0

            for key in list(self.client_buckets.keys()):
                bucket = self.client_buckets[key]
                if bucket.last_refill < cutoff_time:
                    del self.client_buckets[key]
                    removed += 1

            if removed > 0:
                logger.info(f"Cleaned up {removed} unused rate limit buckets")


class RateLimitMiddleware:
    """aiohttp middleware that rejects over-limit requests with HTTP 429."""

    def __init__(self, rate_limiter: RateLimiter, exempt_paths: Optional[set[str]] = None):
        """
        Args:
            rate_limiter: RateLimiter used for all non-exempt requests.
            exempt_paths: Paths to skip entirely; defaults to
                {"/health", "/metrics"} so probes are never throttled.
        """
        # Annotation fixed: the default is None, so the parameter must be
        # Optional[set[str]], not set[str].
        self.rate_limiter = rate_limiter
        self.exempt_paths = exempt_paths or {"/health", "/metrics"}
        logger.info("Rate limit middleware initialized")

    @web.middleware
    async def middleware(self, request: web.Request, handler):
        """Middleware handler: enforce the rate limit, then delegate."""

        # Skip rate limiting for exempt paths
        if request.path in self.exempt_paths:
            return await handler(request)

        # Check rate limit
        allowed, retry_after = await self.rate_limiter.check_rate_limit(request)

        if not allowed:
            # 429 with both a JSON body and a standard Retry-After header
            # (rounded up to the next whole second).
            return web.json_response(
                {
                    "error": "Rate limit exceeded",
                    "message": f"Too many requests. Please retry after {retry_after:.2f} seconds.",
                    "retry_after": retry_after
                },
                status=429,
                headers={"Retry-After": str(int(retry_after) + 1)}
            )

        # Add rate limit headers
        response = await handler(request)

        # TODO: Add X-RateLimit-* headers
        # response.headers["X-RateLimit-Limit"] = "100"
        # response.headers["X-RateLimit-Remaining"] = "95"

        return response


class RedisRateLimiter:
    """
    Distributed rate limiter using Redis.

    Sliding-window algorithm over a Redis sorted set, so the limit holds
    across multiple server instances. Fails open: if Redis is not
    configured or errors out, requests are allowed.
    """

    def __init__(self, redis_client=None):
        """
        Initialize with Redis client

        Args:
            redis_client: redis.asyncio.Redis client, or None to disable
        """
        self.redis = redis_client
        logger.info("Redis rate limiter initialized" if redis_client else "Redis rate limiter (disabled)")

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int
    ) -> tuple[bool, Optional[int]]:
        """
        Check rate limit using Redis

        Uses sliding window algorithm with Redis sorted sets. Note that a
        rejected request is still recorded in the window, so a client that
        keeps hammering stays throttled.

        Args:
            key: Redis key identifying the client/endpoint being limited.
            limit: Maximum requests allowed per window.
            window_seconds: Window length in seconds.

        Returns:
            Tuple of (allowed, retry_after_seconds)
        """
        if not self.redis:
            # If Redis is not available, allow all requests (fail open)
            return True, None

        now = time.time()
        window_start = now - window_seconds

        try:
            # Redis pipeline for atomic operations
            pipe = self.redis.pipeline()

            # Remove entries that have aged out of the window
            pipe.zremrangebyscore(key, 0, window_start)

            # Count requests currently in the window (before adding this one)
            pipe.zcard(key)

            # Record this request. The member must be unique: two requests
            # landing on the same time.time() value would share the member
            # str(now), and ZADD would overwrite instead of adding — silently
            # undercounting concurrent traffic. A random suffix fixes that;
            # the score is still the timestamp, so window trimming and the
            # retry-after lookup below are unaffected.
            pipe.zadd(key, {f"{now}:{uuid.uuid4().hex}": now})

            # Expire the whole key once a full window has passed idle
            pipe.expire(key, window_seconds)

            results = await pipe.execute()

            count = results[1]  # Result from ZCARD

            if count < limit:
                return True, None
            else:
                # Over the limit: retry once the oldest entry leaves the window
                oldest_entries = await self.redis.zrange(key, 0, 0, withscores=True)
                if oldest_entries:
                    oldest_time = oldest_entries[0][1]
                    retry_after = int(oldest_time + window_seconds - now) + 1
                    return False, retry_after

                return False, window_seconds

        except Exception as e:
            logger.error(f"Redis rate limit error: {e}")
            # On error, allow request (fail open)
            return True, None


# Process-wide singleton, created lazily by get_rate_limiter()
_rate_limiter: Optional[RateLimiter] = None


def get_rate_limiter() -> RateLimiter:
    """Return the shared module-level RateLimiter, creating it on first call."""
    global _rate_limiter
    limiter = _rate_limiter
    if limiter is None:
        limiter = RateLimiter()
        _rate_limiter = limiter
    return limiter