# Aoun-Ai / app/core/rate_limiter.py
# MuhammadMahmoud's picture
# feat: Hardened routing layer with Redis-backed rate limiting, distributed circuit breakers, bulkheads, and full OTel/Prom observability with Grafana alerts
# 78e0e85
import asyncio
import math
import time
import uuid
from typing import Dict, List, Tuple, Optional

from app.core.redis_client import redis_client
class SlidingWindowRateLimiter:
    """Sliding-window rate limiter with a Redis backend and a local fallback.

    When ``redis_client`` reports a live connection, each check is a single
    Lua script call against a per-key sorted set. When Redis is unavailable,
    or the script call fails, the check is answered from an in-process
    dictionary instead, so limits are then enforced per process only.

    Args:
        max_requests: Number of requests admitted per key inside one window.
        window_seconds: Length of the sliding window, in seconds.
    """

    def __init__(self, max_requests: int, window_seconds: int):
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._lua_script = None
        # Memory fallback state: user_key -> timestamps of admitted requests,
        # in insertion order.
        self._requests: Dict[str, List[float]] = {}
        # Earliest timestamp at which the fallback store is swept again.
        self._next_sweep: float = 0.0

    @property
    def _lua(self):
        """Return the registered Lua script, registering it on first use.

        The script object is cached on this limiter instance. ``None`` is
        returned while Redis is not connected or if registration failed, in
        which case registration is retried on the next access.
        """
        if self._lua_script or not (redis_client.is_connected and redis_client.redis):
            return self._lua_script
        # The member is supplied by the caller (ARGV[4]) rather than derived
        # from the timestamp plus math.random(): Redis' script PRNG can yield
        # the same sequence on every invocation, and ZADD silently merges a
        # duplicate member, which would under-count requests.
        lua = r"""
        -- KEYS[1] = rate limit key
        -- ARGV[1] = now (epoch seconds, float allowed)
        -- ARGV[2] = window_seconds
        -- ARGV[3] = max_requests
        -- ARGV[4] = unique member id for this request
        local key = KEYS[1]
        local now = tonumber(ARGV[1])
        local window = tonumber(ARGV[2])
        local limit = tonumber(ARGV[3])
        local member = ARGV[4]
        -- trim old
        redis.call('ZREMRANGEBYSCORE', key, '-inf', now - window)
        local count = redis.call('ZCARD', key)
        if count < limit then
          redis.call('ZADD', key, now, member)
          redis.call('EXPIRE', key, window)
          return {1, 0}
        else
          local oldest = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')
          local retry = window
          if oldest and #oldest >= 2 then
            local oldest_ts = tonumber(oldest[2])
            retry = math.floor(window - (now - oldest_ts) + 0.999)
          end
          return {0, retry}
        end
        """
        try:
            self._lua_script = redis_client.redis.register_script(lua)
        except Exception:
            # Best effort: leave the script unset so callers fall back to memory.
            self._lua_script = None
        return self._lua_script

    async def _eval_redis(self, redis_key: str, now: float) -> Optional[Tuple[bool, int]]:
        """Run the Lua limiter for ``redis_key``.

        Returns:
            ``(allowed, retry_after)``, or ``None`` when the script is not
            available or the call failed. ``retry_after`` is 0 when allowed
            and at least 1 when denied.
        """
        script = self._lua
        if script is None:
            return None
        try:
            allowed, retry = await script(
                keys=[redis_key],
                args=[now, self.window_seconds, self.max_requests, uuid.uuid4().hex],
            )
            if allowed:
                return True, 0
            # Never tell a throttled caller to retry in 0 seconds.
            return False, max(1, int(retry))
        except Exception:
            # Deliberately broad: any Redis or parsing failure degrades to the
            # memory fallback instead of failing the request.
            return None

    def _sweep_memory(self, now: float) -> None:
        """Drop fallback keys whose newest timestamp has left the window."""
        cutoff = now - self.window_seconds
        # Timestamps are appended in call order, so the last one is the newest
        # as long as the clock does not jump backwards.
        stale = [
            key
            for key, stamps in self._requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for key in stale:
            del self._requests[key]
        self._next_sweep = now + self.window_seconds

    def _check_memory(self, user_key: str, now: float) -> Tuple[bool, int]:
        """Apply the sliding window to ``user_key`` using in-process state.

        Args:
            user_key: Identifier being limited.
            now: Current time in epoch seconds.

        Returns:
            ``(True, 0)`` when the request is admitted, otherwise
            ``(False, retry_after)`` with ``retry_after`` >= 1 second.
        """
        # Sweep at most once per window so idle keys cannot accumulate forever.
        if now >= self._next_sweep:
            self._sweep_memory(now)

        cutoff = now - self.window_seconds
        active = [t for t in self._requests.get(user_key, ()) if t > cutoff]

        if len(active) < self.max_requests:
            active.append(now)
            self._requests[user_key] = active
            return True, 0

        if not active:
            # Only reachable when max_requests <= 0: there is no oldest entry
            # to wait for, so report the whole window like the Lua script does.
            self._requests.pop(user_key, None)
            return False, max(1, int(self.window_seconds))

        self._requests[user_key] = active
        # Round up so that waiting the advertised time is always sufficient.
        remaining = self.window_seconds - (now - active[0])
        return False, max(1, math.ceil(remaining))

    async def is_allowed(self, user_key: str) -> Tuple[bool, int]:
        """Check whether ``user_key`` may make another request right now.

        Returns:
            ``(allowed, retry_after_seconds)``; ``retry_after_seconds`` is 0
            when the request is allowed.
        """
        now = time.time()

        # If Redis is connected, use the distributed sorted-set logic.
        if redis_client.is_connected and redis_client.redis:
            # NOTE(review): the key is namespaced only by window length, so two
            # limiters with the same window share counters if they are given
            # the same user_key; confirm callers use distinct keys.
            redis_key = f"rl:{self.window_seconds}:{user_key}"
            verdict = await self._eval_redis(redis_key, now)
            if verdict is not None:
                return verdict
            # Redis failed mid-operation: degrade safely to memory.

        return self._check_memory(user_key, now)
# Per-session limiter: at most 15 messages per 60-second window.
chat_rate_limiter = SlidingWindowRateLimiter(
    window_seconds=60,
    max_requests=15,
)

# Per-IP limiter: at most 30 messages per 60-second window per address,
# intended to blunt floods coming from a single client.
ip_rate_limiter = SlidingWindowRateLimiter(
    window_seconds=60,
    max_requests=30,
)

# System-wide limiter: caps total traffic at 1000 requests per 60-second
# window so the upstream provider is not exhausted.
global_rate_limiter = SlidingWindowRateLimiter(
    window_seconds=60,
    max_requests=1000,
)