"""
Knowledge Universe - API Middleware
Rate limiting and metrics collection
NOTE: RateLimitMiddleware was upgraded from an in-memory defaultdict to
Redis-backed atomic counters so limits survive deployments and scale across workers.
"""
import time
import logging
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from prometheus_client import Counter, Histogram
from config.settings import get_settings
# Cached application settings (see config.settings.get_settings).
settings = get_settings()
logger = logging.getLogger(__name__)

# Prometheus metrics
# Total request count, labeled by HTTP method, URL path, and response status.
request_counter = Counter(
    'ku_requests_total',
    'Total requests',
    ['method', 'endpoint', 'status']
)
# Request latency histogram in seconds, labeled by HTTP method and URL path.
request_duration = Histogram(
    'ku_request_duration_seconds',
    'Request duration',
    ['method', 'endpoint']
)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Redis-backed rate limiting to survive deployments and multi-worker scaling.

    Requests are counted per client in fixed time-window buckets stored in
    Redis, so the limit is shared across workers and persists through
    restarts. If the Redis client is missing or errors out, the request is
    allowed through (fail-open) so limiter outages never block real traffic.
    """

    def __init__(self, app):
        super().__init__(app)
        # Maximum requests allowed per window, and window length in seconds.
        self.limit = settings.RATE_LIMIT_REQUESTS
        self.period = settings.RATE_LIMIT_PERIOD

    async def dispatch(self, request: Request, call_next):
        # Health probes and the metrics endpoint are never rate limited.
        if request.url.path in ['/health', '/ready', '/metrics']:
            return await call_next(request)

        caller = self._get_client_id(request)
        try:
            # The Redis manager is attached to app state elsewhere; it may be
            # absent (e.g. during startup), in which case limiting is skipped.
            manager = getattr(request.app.state, "redis", None)
            if manager and manager.client:
                conn = manager.client
                # Fixed-window bucket index derived from the current time.
                window = int(time.time() / self.period)
                bucket_key = f"ku:ratelimit:{caller}:{window}"
                # INCR + EXPIRE issued together through a pipeline so the
                # counter and its TTL are applied as one round trip.
                pipeline = conn.pipeline()
                pipeline.incr(bucket_key)
                pipeline.expire(bucket_key, self.period)
                outcomes = await pipeline.execute()
                seen_this_window = outcomes[0]
                if seen_this_window > self.limit:
                    logger.warning(f"Rate limit exceeded for {caller}")
                    return JSONResponse(
                        status_code=429,
                        content={
                            'error': 'Rate limit exceeded',
                            'limit': self.limit,
                            'period': self.period,
                            'message': 'Maximum requests reached. Please slow down.'
                        },
                        headers={"Retry-After": str(self.period)}
                    )
        except Exception as e:
            # Fail open if Redis crashes so legitimate users aren't blocked
            logger.error(f"Rate limiter Redis failure: {e}")

        # Under the limit (or limiter unavailable): forward to the app.
        return await call_next(request)

    def _get_client_id(self, request: Request) -> str:
        """Get client identifier from request"""
        # Prefer the configured API-key header when one is present.
        header_name = getattr(settings, "API_KEY_HEADER", "X-API-Key")
        key = request.headers.get(header_name)
        if key:
            # Only the key's 8-char prefix goes into the bucket key.
            return f"key:{key[:8]}"
        # Fall back to the client IP for public routes like /signup.
        host = request.client.host if request.client else 'unknown'
        return f"ip:{host}"
class MetricsMiddleware(BaseHTTPMiddleware):
    """
    Prometheus metrics collection.

    Records a per-request counter (method/path/status) and a latency
    histogram (method/path) for every request except /metrics itself.
    Metrics are recorded even when the downstream handler raises: the
    request is counted with status 500 and the exception is re-raised.
    """

    async def dispatch(self, request: Request, call_next):
        # Don't measure the scrape endpoint itself.
        if request.url.path == '/metrics':
            return await call_next(request)

        started = time.time()
        try:
            response = await call_next(request)
            status_code = response.status_code
        except Exception as e:
            logger.error(f"Request failed: {e}")
            status_code = 500
            raise
        finally:
            # Always record, whether the handler succeeded or raised.
            elapsed = time.time() - started
            request_counter.labels(
                method=request.method,
                endpoint=request.url.path,
                status=status_code
            ).inc()
            request_duration.labels(
                method=request.method,
                endpoint=request.url.path
            ).observe(elapsed)
        # Only reached on the success path; the except branch re-raises.
        return response