# Spaces: Running — page chrome captured during extraction; not part of the module.
| """ | |
| Knowledge Universe - API Middleware | |
| Rate limiting and metrics collection | |
| RICK'S FIX: Upgraded RateLimitMiddleware from in-memory defaultdict | |
| to Redis-backed atomic counters to survive deployments and scale across workers. | |
| """ | |
| import time | |
| import logging | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.requests import Request | |
| from starlette.responses import JSONResponse | |
| from prometheus_client import Counter, Histogram | |
| from config.settings import get_settings | |
| settings = get_settings() | |
| logger = logging.getLogger(__name__) | |
| # Prometheus metrics | |
| request_counter = Counter( | |
| 'ku_requests_total', | |
| 'Total requests', | |
| ['method', 'endpoint', 'status'] | |
| ) | |
| request_duration = Histogram( | |
| 'ku_request_duration_seconds', | |
| 'Request duration', | |
| ['method', 'endpoint'] | |
| ) | |
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Redis-backed, fixed-window rate limiting.

    Counters live in Redis (atomic INCR + EXPIRE per time-window bucket), so
    limits survive deployments and are shared across all workers. If Redis is
    unavailable or errors, the middleware fails open so legitimate traffic is
    never blocked by an infrastructure outage.
    """

    # Operational endpoints that must never be throttled.
    EXEMPT_PATHS = frozenset({'/health', '/ready', '/metrics'})

    def __init__(self, app):
        super().__init__(app)
        # Max requests allowed per client within one window of `period` seconds.
        self.limit = settings.RATE_LIMIT_REQUESTS
        self.period = settings.RATE_LIMIT_PERIOD

    async def dispatch(self, request: Request, call_next):
        """Enforce the per-client limit, then forward the request downstream.

        Returns a 429 JSONResponse (with an accurate Retry-After header) when
        the client exceeds the limit; otherwise the downstream response.
        """
        # Skip rate limiting for health checks and metrics scrapes.
        if request.url.path in self.EXEMPT_PATHS:
            return await call_next(request)

        client_id = self._get_client_id(request)
        try:
            # Safely get the redis client from app state; absent/None means
            # rate limiting is effectively disabled (fail open).
            redis_manager = getattr(request.app.state, "redis", None)
            if redis_manager and redis_manager.client:
                redis_client = redis_manager.client

                # Fixed-window bucket: all requests in the same `period`-second
                # slice of wall-clock time share one counter key.
                now = time.time()
                current_window = int(now / self.period)
                limit_key = f"ku:ratelimit:{client_id}:{current_window}"

                # Atomic INCR + EXPIRE in a single round trip. The TTL only
                # needs to outlive the window; keys from past windows expire
                # on their own.
                pipe = redis_client.pipeline()
                pipe.incr(limit_key)
                pipe.expire(limit_key, self.period)
                results = await pipe.execute()
                requests_this_window = results[0]

                if requests_this_window > self.limit:
                    # Lazy %-style args: the message is only formatted if the
                    # record is actually emitted.
                    logger.warning("Rate limit exceeded for %s", client_id)
                    # Tell the client how long until the current window resets,
                    # instead of always quoting the full period.
                    retry_after = max(1, int((current_window + 1) * self.period - now))
                    return JSONResponse(
                        status_code=429,
                        content={
                            'error': 'Rate limit exceeded',
                            'limit': self.limit,
                            'period': self.period,
                            'message': 'Maximum requests reached. Please slow down.'
                        },
                        headers={"Retry-After": str(retry_after)}
                    )
        except Exception:
            # Fail open if Redis crashes so legitimate users aren't blocked.
            # logger.exception preserves the traceback for diagnosis.
            logger.exception("Rate limiter Redis failure")

        # Within limits (or limiter unavailable): process the request normally.
        return await call_next(request)

    def _get_client_id(self, request: Request) -> str:
        """Identify the caller: API-key prefix when present, else client IP."""
        api_key_header = getattr(settings, "API_KEY_HEADER", "X-API-Key")
        api_key = request.headers.get(api_key_header)
        if api_key:
            # Only a short prefix goes into the Redis key so full secrets
            # never end up in cache keys or logs.
            return f"key:{api_key[:8]}"  # Only use prefix for cache key security
        # Fall back to IP address for public routes like /signup.
        return f"ip:{request.client.host if request.client else 'unknown'}"
class MetricsMiddleware(BaseHTTPMiddleware):
    """
    Prometheus metrics collection.

    Records a request counter (method/endpoint/status) and a latency
    histogram (method/endpoint) for every request except the /metrics
    scrape endpoint itself. Metrics are recorded even when the downstream
    handler raises; in that case the status is reported as 500 and the
    exception is re-raised for the server's error handling.
    """

    async def dispatch(self, request: Request, call_next):
        # Skip the metrics endpoint itself so scrapes don't pollute the data.
        if request.url.path == '/metrics':
            return await call_next(request)

        # perf_counter is monotonic: unlike time.time(), it cannot go backwards
        # or jump when the system clock is adjusted, so durations stay valid.
        start_time = time.perf_counter()

        try:
            response = await call_next(request)
            status_code = response.status_code
        except Exception:
            # Record the failure, then propagate. logger.exception keeps the
            # full traceback in the log record.
            logger.exception("Request failed")
            status_code = 500
            raise
        finally:
            # Runs on both success and failure so no request goes unmeasured.
            duration = time.perf_counter() - start_time
            # NOTE(review): raw request.url.path as a label can explode
            # cardinality with parameterized routes — confirm this is acceptable.
            request_counter.labels(
                method=request.method,
                endpoint=request.url.path,
                status=status_code
            ).inc()
            request_duration.labels(
                method=request.method,
                endpoint=request.url.path
            ).observe(duration)

        return response