|
|
""" |
|
|
Enterprise Rate Limiting for MCP Servers |
|
|
|
|
|
Features: |
|
|
- Token bucket algorithm for smooth rate limiting |
|
|
- Per-client rate limiting |
|
|
- Global rate limiting |
|
|
- Different limits for different endpoints |
|
|
- Distributed rate limiting with Redis (optional) |
|
|
""" |
|
|
import asyncio
import logging
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional

from aiohttp import web
|
|
|
|
|
# Module-level logger, named after this module per the standard convention
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
@dataclass
class TokenBucket:
    """Token bucket for rate limiting.

    The bucket starts full (``capacity`` tokens) and refills continuously
    at ``refill_rate`` tokens per second, never exceeding ``capacity``.
    """

    capacity: int        # maximum number of tokens the bucket can hold
    refill_rate: float   # tokens credited per second; <= 0 means "never refills"
    tokens: float = field(default=0)  # current balance; reset to capacity in __post_init__
    last_refill: float = field(default_factory=time.time)  # timestamp of the last refill

    def __post_init__(self):
        # Always start with a full bucket; any value passed for `tokens`
        # is intentionally overwritten.
        self.tokens = self.capacity

    def _refill(self):
        """Credit tokens for the time elapsed since the last refill."""
        now = time.time()
        elapsed = now - self.last_refill

        # Cap at capacity so idle periods cannot bank unlimited burst
        self.tokens = min(
            self.capacity,
            self.tokens + (elapsed * self.refill_rate)
        )
        self.last_refill = now

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to consume tokens.

        Returns:
            True if tokens were available, False otherwise
        """
        self._refill()

        if self.tokens >= tokens:
            self.tokens -= tokens
            return True

        return False

    def get_wait_time(self, tokens: int = 1) -> float:
        """
        Get time to wait until tokens are available.

        Returns:
            Seconds to wait (0.0 if already available; ``inf`` if the
            bucket cannot refill because refill_rate <= 0)
        """
        self._refill()

        if self.tokens >= tokens:
            return 0.0

        # Bug fix: with refill_rate <= 0 the division below raised
        # ZeroDivisionError; a non-refilling bucket waits forever.
        if self.refill_rate <= 0:
            return float("inf")

        tokens_needed = tokens - self.tokens
        return tokens_needed / self.refill_rate
|
|
|
|
|
|
|
|
class RateLimiter:
    """
    In-memory rate limiter with token bucket algorithm.

    Keeps one bucket per (client, endpoint) pair, plus an optional global
    bucket that caps aggregate throughput across all clients.
    """

    def __init__(self):
        # Per-client buckets, keyed by "<client_id>:<path>"
        self.client_buckets: Dict[str, TokenBucket] = {}

        # Optional bucket shared by all clients; None disables the global cap
        self.global_bucket: Optional[TokenBucket] = None

        # Per-endpoint limits; "default" applies to any path not listed
        self.endpoint_limits: Dict[str, Dict] = {
            "/rpc": {"capacity": 100, "refill_rate": 10.0},
            "default": {"capacity": 50, "refill_rate": 5.0}
        }

        # Background task that evicts idle buckets (see start_cleanup_task)
        self._cleanup_task = None
        logger.info("Rate limiter initialized")

    def _get_client_id(self, request: web.Request) -> str:
        """
        Get client identifier for rate limiting.

        Uses (in order):
        1. API key
        2. IP address
        """
        if "api_key" in request and hasattr(request["api_key"], "key_id"):
            return f"key:{request['api_key'].key_id}"

        # Bug fix: request.transport can be None once the connection is
        # closed; the original dereferenced it unconditionally and could
        # raise AttributeError here.
        transport = request.transport
        if transport is not None:
            peername = transport.get_extra_info('peername')
            if peername:
                return f"ip:{peername[0]}"

        return "unknown"

    def _get_endpoint_limits(self, path: str) -> Dict:
        """Return the limit config for `path`, falling back to "default"."""
        return self.endpoint_limits.get(path, self.endpoint_limits["default"])

    def _get_or_create_bucket(self, client_id: str, path: str) -> TokenBucket:
        """Get or lazily create the token bucket for a client/endpoint pair."""
        bucket_key = f"{client_id}:{path}"

        if bucket_key not in self.client_buckets:
            limits = self._get_endpoint_limits(path)
            self.client_buckets[bucket_key] = TokenBucket(
                capacity=limits["capacity"],
                refill_rate=limits["refill_rate"]
            )

        return self.client_buckets[bucket_key]

    async def check_rate_limit(
        self,
        request: web.Request,
        tokens: int = 1
    ) -> tuple[bool, Optional[float]]:
        """
        Check if request is within rate limit.

        Returns:
            Tuple of (allowed, retry_after_seconds); retry_after is None
            when the request is allowed.
        """
        client_id = self._get_client_id(request)
        path = request.path

        # Global cap first: cheaper to reject before touching per-client state
        if self.global_bucket:
            if not self.global_bucket.consume(tokens):
                wait_time = self.global_bucket.get_wait_time(tokens)
                logger.warning(f"Global rate limit exceeded, retry after {wait_time:.2f}s")
                return False, wait_time

        bucket = self._get_or_create_bucket(client_id, path)

        if not bucket.consume(tokens):
            # Bug fix: refund the tokens already taken from the global
            # bucket — a request rejected by its per-client limit must not
            # also count against the global budget (previously leaked).
            if self.global_bucket:
                self.global_bucket.tokens = min(
                    self.global_bucket.capacity,
                    self.global_bucket.tokens + tokens
                )
            wait_time = bucket.get_wait_time(tokens)
            logger.warning(f"Rate limit exceeded for {client_id} on {path}, retry after {wait_time:.2f}s")
            return False, wait_time

        return True, None

    async def start_cleanup_task(self):
        """Start background cleanup task (idempotent)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
            logger.info("Rate limiter cleanup task started")

    async def stop_cleanup_task(self):
        """Cancel and await the cleanup task, if running (for clean shutdown)."""
        if self._cleanup_task is not None:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

    async def _cleanup_loop(self):
        """Periodically clean up old buckets."""
        while True:
            await asyncio.sleep(300)

            # A bucket untouched for 10 minutes has refilled to capacity
            # anyway, so dropping it cannot change any limiting decision.
            cutoff_time = time.time() - 600
            removed = 0

            for key in list(self.client_buckets.keys()):
                bucket = self.client_buckets[key]
                if bucket.last_refill < cutoff_time:
                    del self.client_buckets[key]
                    removed += 1

            if removed > 0:
                logger.info(f"Cleaned up {removed} unused rate limit buckets")
|
|
|
|
|
|
|
|
class RateLimitMiddleware:
    """aiohttp middleware that enforces rate limits before dispatching requests."""

    def __init__(self, rate_limiter: RateLimiter, exempt_paths: Optional[set[str]] = None):
        """
        Args:
            rate_limiter: RateLimiter used for every check.
            exempt_paths: Paths that bypass limiting; defaults to
                {"/health", "/metrics"} so probes are never throttled.
                (Annotation fixed: the default is None, so the parameter
                must be Optional.)
        """
        self.rate_limiter = rate_limiter
        self.exempt_paths = exempt_paths or {"/health", "/metrics"}
        logger.info("Rate limit middleware initialized")

    @web.middleware
    async def middleware(self, request: web.Request, handler):
        """Reject over-limit requests with 429 + Retry-After; else pass through."""
        # Exempt paths (health checks, metrics scrapes) always go straight through
        if request.path in self.exempt_paths:
            return await handler(request)

        allowed, retry_after = await self.rate_limiter.check_rate_limit(request)

        if not allowed:
            return web.json_response(
                {
                    "error": "Rate limit exceeded",
                    "message": f"Too many requests. Please retry after {retry_after:.2f} seconds.",
                    "retry_after": retry_after
                },
                status=429,
                # Round up so clients never retry a moment too early
                headers={"Retry-After": str(int(retry_after) + 1)}
            )

        response = await handler(request)

        return response
|
|
|
|
|
|
|
|
class RedisRateLimiter:
    """
    Distributed rate limiter using Redis.

    Suitable for multi-instance deployments. Fails open: traffic is
    allowed when no Redis client is configured or Redis is unreachable.
    """

    def __init__(self, redis_client=None):
        """
        Initialize with Redis client.

        Args:
            redis_client: redis.asyncio.Redis client, or None to disable
        """
        self.redis = redis_client
        logger.info("Redis rate limiter initialized" if redis_client else "Redis rate limiter (disabled)")

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window_seconds: int
    ) -> tuple[bool, Optional[int]]:
        """
        Check rate limit using Redis.

        Uses sliding window algorithm with Redis sorted sets.

        Returns:
            Tuple of (allowed, retry_after_seconds)
        """
        if not self.redis:
            # Limiting disabled — allow everything
            return True, None

        now = time.time()
        window_start = now - window_seconds

        try:
            # One round-trip: trim, count, record, refresh TTL
            pipe = self.redis.pipeline()

            # Drop entries that fell out of the sliding window
            pipe.zremrangebyscore(key, 0, window_start)

            # Count the requests still inside the window (pre-add)
            pipe.zcard(key)

            # Record this request. Bug fix: the member must be unique per
            # request — using str(now) alone let two requests with the same
            # timestamp (likely across instances) collide in the sorted set,
            # silently overwriting one entry and undercounting.
            pipe.zadd(key, {f"{now}:{uuid.uuid4().hex}": now})

            # Keep the key from lingering after traffic stops
            pipe.expire(key, window_seconds)

            results = await pipe.execute()

            count = results[1]

            if count < limit:
                return True, None
            else:
                # Over limit: estimate when the oldest entry ages out
                oldest_entries = await self.redis.zrange(key, 0, 0, withscores=True)
                if oldest_entries:
                    oldest_time = oldest_entries[0][1]
                    retry_after = int(oldest_time + window_seconds - now) + 1
                    return False, retry_after

                return False, window_seconds

        except Exception as e:
            logger.error(f"Redis rate limit error: {e}")
            # Fail open: prefer availability over strict limiting
            return True, None
|
|
|
|
|
|
|
|
|
|
|
# Process-wide singleton instance, lazily created by get_rate_limiter()
_rate_limiter: Optional[RateLimiter] = None
|
|
|
|
|
|
|
|
def get_rate_limiter() -> RateLimiter:
    """Return the process-wide RateLimiter, constructing it on first access."""
    global _rate_limiter

    if _rate_limiter is None:
        # Lazy singleton: nothing is built until someone actually needs it
        _rate_limiter = RateLimiter()

    return _rate_limiter
|
|
|