"""
Knowledge Universe - API Middleware

Rate limiting and metrics collection.

RICK'S FIX: Upgraded RateLimitMiddleware from an in-memory defaultdict to
Redis-backed atomic counters so limits survive deployments and are shared
across workers.
"""
import hashlib
import logging
import time

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from prometheus_client import Counter, Histogram

from config.settings import get_settings

settings = get_settings()
logger = logging.getLogger(__name__)

# Prometheus metrics
request_counter = Counter(
    'ku_requests_total',
    'Total requests',
    ['method', 'endpoint', 'status']
)

request_duration = Histogram(
    'ku_request_duration_seconds',
    'Request duration',
    ['method', 'endpoint']
)

# Paths that are never rate limited (health/readiness probes and scraping).
_RATE_LIMIT_EXEMPT_PATHS = frozenset({'/health', '/ready', '/metrics'})


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Redis-backed fixed-window rate limiting.

    Counters are kept in Redis (via ``request.app.state.redis.client``), so
    they survive deployments and are shared by every worker. Each client gets
    at most ``settings.RATE_LIMIT_REQUESTS`` requests per
    ``settings.RATE_LIMIT_PERIOD``-second window; excess requests receive a
    429 response. If Redis is missing or errors, the middleware fails open.
    """

    def __init__(self, app):
        super().__init__(app)
        self.limit = settings.RATE_LIMIT_REQUESTS
        self.period = settings.RATE_LIMIT_PERIOD

    async def dispatch(self, request: Request, call_next):
        """Reject the request with 429 when the client is over its limit."""
        # Skip rate limiting for health checks and metrics
        if request.url.path in _RATE_LIMIT_EXEMPT_PATHS:
            return await call_next(request)

        client_id = self._get_client_id(request)

        try:
            requests_this_window = await self._count_request(request, client_id)
        except Exception as e:
            # Fail open if Redis crashes so legitimate users aren't blocked
            logger.error("Rate limiter Redis failure: %s", e)
            requests_this_window = None

        if requests_this_window is not None and requests_this_window > self.limit:
            logger.warning("Rate limit exceeded for %s", client_id)
            return JSONResponse(
                status_code=429,
                content={
                    'error': 'Rate limit exceeded',
                    'limit': self.limit,
                    'period': self.period,
                    'message': 'Maximum requests reached. Please slow down.'
                },
                headers={"Retry-After": str(self.period)}
            )

        # Process request
        return await call_next(request)

    async def _count_request(self, request: Request, client_id: str):
        """
        Increment and return the client's counter for the current window.

        Returns None when no Redis client is available on ``app.state``.
        Any Redis error propagates to the caller.
        """
        # Safely get redis client from app state
        redis_manager = getattr(request.app.state, "redis", None)
        if not redis_manager or not redis_manager.client:
            return None
        redis_client = redis_manager.client

        # Wall-clock window id, so all workers agree on the current bucket.
        current_window = int(time.time() / self.period)
        limit_key = f"ku:ratelimit:{client_id}:{current_window}"

        # INCR + EXPIRE sent together in one pipeline (presumably redis-py,
        # whose pipelines default to MULTI/EXEC — confirm client library).
        pipe = redis_client.pipeline()
        pipe.incr(limit_key)
        pipe.expire(limit_key, self.period)
        results = await pipe.execute()
        return results[0]

    def _get_client_id(self, request: Request) -> str:
        """
        Return the rate-limit bucket identifier for this request.

        Requests carrying an API key are bucketed by a SHA-256 digest of the
        full key: the raw key never reaches Redis, and distinct keys that
        share a common prefix do not collapse into one bucket. Other requests
        fall back to the client IP address.
        """
        # Check for API key
        api_key_header = getattr(settings, "API_KEY_HEADER", "X-API-Key")
        api_key = request.headers.get(api_key_header)
        if api_key:
            # NOTE(review): the key is not validated here, so a caller can send
            # arbitrary header values to get fresh buckets and sidestep the
            # per-IP limit — consider validating before bucketing.
            digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()[:16]
            return f"key:{digest}"

        # Fall back to IP address for public routes like /signup
        host = request.client.host if request.client else 'unknown'
        return f"ip:{host}"


class MetricsMiddleware(BaseHTTPMiddleware):
    """
    Prometheus metrics collection: request count and request duration,
    labelled by method, path and (for the counter) status code.
    """

    async def dispatch(self, request: Request, call_next):
        """Record count and latency metrics around the downstream call."""
        # Skip metrics endpoint itself
        if request.url.path == '/metrics':
            return await call_next(request)

        # Monotonic clock: durations are immune to wall-clock adjustments.
        start_time = time.perf_counter()
        # Default to 500 so `finally` always has a value, even when call_next
        # raises a BaseException (e.g. asyncio.CancelledError) that the
        # `except Exception` clause below does not catch.
        status_code = 500

        # Process request
        try:
            response = await call_next(request)
            status_code = response.status_code
        except Exception as e:
            logger.error("Request failed: %s", e)
            raise
        finally:
            # Record metrics
            duration = time.perf_counter() - start_time
            # NOTE(review): the raw URL path is used as the label, so paths
            # containing IDs create one time series per distinct URL —
            # consider labelling by the matched route template instead.
            request_counter.labels(
                method=request.method,
                endpoint=request.url.path,
                status=status_code
            ).inc()
            request_duration.labels(
                method=request.method,
                endpoint=request.url.path
            ).observe(duration)

        return response