# Knowledge-Universe / src/api/middleware.py
# Author: vlsiddarth
# feat: Enterprise deployment - Decay velocity, cron persistence, and surgical crawler fixes
# commit 2b9ab6a
"""
Knowledge Universe - API Middleware
Rate limiting and metrics collection
RICK'S FIX: Upgraded RateLimitMiddleware from in-memory defaultdict
to Redis-backed atomic counters to survive deployments and scale across workers.
"""
import time
import logging
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from prometheus_client import Counter, Histogram
from config.settings import get_settings
settings = get_settings()
logger = logging.getLogger(__name__)
# Prometheus metrics
request_counter = Counter(
'ku_requests_total',
'Total requests',
['method', 'endpoint', 'status']
)
request_duration = Histogram(
'ku_request_duration_seconds',
'Request duration',
['method', 'endpoint']
)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Redis-backed fixed-window rate limiting.

    Counters live in Redis (one atomic INCR per client per time window), so
    limits survive deployments and are shared across worker processes.
    Fails open when Redis is unavailable so legitimate users aren't blocked.
    """

    def __init__(self, app):
        super().__init__(app)
        # Max requests allowed per window, and window length in seconds.
        self.limit = settings.RATE_LIMIT_REQUESTS
        self.period = settings.RATE_LIMIT_PERIOD

    async def dispatch(self, request: Request, call_next):
        """Enforce the per-client limit; return HTTP 429 when exceeded."""
        # Health/readiness probes and metrics scrapes are never rate limited.
        if request.url.path in ['/health', '/ready', '/metrics']:
            return await call_next(request)

        client_id = self._get_client_id(request)
        requests_this_window = None
        try:
            # Redis manager is attached to app state at startup; tolerate absence.
            redis_manager = getattr(request.app.state, "redis", None)
            if redis_manager and redis_manager.client:
                redis_client = redis_manager.client
                # Fixed-window bucket: all requests within the same
                # `period`-second slice share one counter key.
                current_window = int(time.time() / self.period)
                limit_key = f"ku:ratelimit:{client_id}:{current_window}"
                # INCR + EXPIRE in one pipeline round trip; INCR is atomic,
                # so concurrent workers cannot lose counts.
                pipe = redis_client.pipeline()
                pipe.incr(limit_key)
                pipe.expire(limit_key, self.period)
                results = await pipe.execute()
                requests_this_window = results[0]
        except Exception as e:
            # Fail open if Redis crashes so legitimate users aren't blocked.
            # Lazy %s args: the message is only formatted if the record is emitted.
            logger.error("Rate limiter Redis failure: %s", e)
            requests_this_window = None

        # Build the 429 response OUTSIDE the fail-open try so an error raised
        # here is never swallowed and misreported as a Redis failure.
        if requests_this_window is not None and requests_this_window > self.limit:
            logger.warning("Rate limit exceeded for %s", client_id)
            return JSONResponse(
                status_code=429,
                content={
                    'error': 'Rate limit exceeded',
                    'limit': self.limit,
                    'period': self.period,
                    'message': 'Maximum requests reached. Please slow down.'
                },
                headers={"Retry-After": str(self.period)}
            )

        # Within the limit (or Redis unavailable): process the request normally.
        return await call_next(request)

    def _get_client_id(self, request: Request) -> str:
        """Return a stable identifier for the caller (API-key prefix or IP)."""
        # Authenticated callers are keyed by their API key.
        api_key_header = getattr(settings, "API_KEY_HEADER", "X-API-Key")
        api_key = request.headers.get(api_key_header)
        if api_key:
            return f"key:{api_key[:8]}"  # Only use prefix for cache key security
        # Fall back to IP address for public routes like /signup.
        return f"ip:{request.client.host if request.client else 'unknown'}"
class MetricsMiddleware(BaseHTTPMiddleware):
    """
    Prometheus instrumentation: records a request counter and a latency
    histogram for every route except the /metrics scrape endpoint itself.
    """

    async def dispatch(self, request: Request, call_next):
        # Don't instrument the scrape endpoint — it would only measure itself.
        if request.url.path == '/metrics':
            return await call_next(request)

        started = time.time()
        status_code = 500  # assume failure until the handler completes
        try:
            response = await call_next(request)
            status_code = response.status_code
        except Exception as e:
            logger.error(f"Request failed: {e}")
            raise
        finally:
            # Metrics are recorded whether the handler succeeded or raised.
            elapsed = time.time() - started
            labels = {'method': request.method, 'endpoint': request.url.path}
            request_counter.labels(status=status_code, **labels).inc()
            request_duration.labels(**labels).observe(elapsed)
        return response