File size: 4,700 Bytes
3acb982
 
 
2b9ab6a
 
 
3acb982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b9ab6a
3acb982
 
 
 
 
 
 
 
2b9ab6a
3acb982
 
 
 
 
2b9ab6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3acb982
 
 
 
 
 
 
 
2b9ab6a
 
 
3acb982
2b9ab6a
3acb982
2b9ab6a
 
3acb982
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Knowledge Universe - API Middleware
Rate limiting and metrics collection

NOTE: RateLimitMiddleware was migrated from an in-memory defaultdict to
Redis-backed atomic counters so limits survive deployments and are shared across workers.
"""
import hashlib
import logging
import time

from prometheus_client import Counter, Histogram
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse

from config.settings import get_settings

# Module-level singletons: settings are read once at import time, so the
# middleware classes below see a consistent configuration snapshot.
settings = get_settings()
logger = logging.getLogger(__name__)


# Prometheus metrics
# Request count broken down by method/endpoint/status — the label names here
# are a scrape-side contract; do not rename without updating dashboards.
request_counter = Counter(
    'ku_requests_total',
    'Total requests',
    ['method', 'endpoint', 'status']
)

# Request latency histogram (seconds) by method/endpoint.
request_duration = Histogram(
    'ku_request_duration_seconds',
    'Request duration',
    ['method', 'endpoint']
)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Redis-backed rate limiting to survive deployments and multi-worker scaling.

    Uses a fixed-window counter: requests are bucketed into windows of
    ``RATE_LIMIT_PERIOD`` seconds and each client may make at most
    ``RATE_LIMIT_REQUESTS`` requests per window. When Redis is missing or
    failing, the middleware fails open so legitimate traffic is not blocked.
    """

    def __init__(self, app):
        super().__init__(app)
        self.limit = settings.RATE_LIMIT_REQUESTS   # max requests per window
        self.period = settings.RATE_LIMIT_PERIOD    # window length, seconds

    async def dispatch(self, request: Request, call_next):
        # Skip rate limiting for health checks and metrics scraping.
        if request.url.path in ['/health', '/ready', '/metrics']:
            return await call_next(request)

        client_id = self._get_client_id(request)

        try:
            # Safely get the redis client from app state; if it is absent or
            # not yet connected we skip limiting entirely (fail open).
            redis_manager = getattr(request.app.state, "redis", None)
            if redis_manager and redis_manager.client:
                redis_client = redis_manager.client

                # Fixed-window bucket: the key changes every `period` seconds,
                # so counts reset naturally at each window boundary.
                current_window = int(time.time() / self.period)
                limit_key = f"ku:ratelimit:{client_id}:{current_window}"

                # Atomic INCR + EXPIRE in one round trip; EXPIRE ensures stale
                # window keys are evicted from Redis after the window passes.
                pipe = redis_client.pipeline()
                pipe.incr(limit_key)
                pipe.expire(limit_key, self.period)
                results = await pipe.execute()

                requests_this_window = results[0]

                if requests_this_window > self.limit:
                    logger.warning(f"Rate limit exceeded for {client_id}")
                    return JSONResponse(
                        status_code=429,
                        content={
                            'error': 'Rate limit exceeded',
                            'limit': self.limit,
                            'period': self.period,
                            'message': 'Maximum requests reached. Please slow down.'
                        },
                        headers={"Retry-After": str(self.period)}
                    )
        except Exception as e:
            # Fail open if Redis crashes so legitimate users aren't blocked
            logger.error(f"Rate limiter Redis failure: {e}")

        # Process request
        response = await call_next(request)
        return response

    def _get_client_id(self, request: Request) -> str:
        """
        Derive a stable client identifier used as the rate-limit bucket key.

        API-keyed clients are identified by a short SHA-256 digest of the key.
        The previous scheme used the key's raw first 8 characters, which both
        leaked key material into Redis keys/log lines and collided distinct
        keys sharing a prefix, letting one client exhaust another's quota.
        """
        api_key_header = getattr(settings, "API_KEY_HEADER", "X-API-Key")
        api_key = request.headers.get(api_key_header)

        if api_key:
            # 16 hex chars (64 bits) of the digest: collision-safe for bucketing
            # without storing anything resembling the raw key.
            digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()[:16]
            return f"key:{digest}"

        # Fall back to IP address for public routes like /signup
        return f"ip:{request.client.host if request.client else 'unknown'}"


class MetricsMiddleware(BaseHTTPMiddleware):
    """
    Prometheus metrics collection: records a per-request counter and a
    latency histogram for every endpoint except the scrape endpoint itself.
    """

    async def dispatch(self, request: Request, call_next):
        # Don't instrument the metrics scrape endpoint itself.
        if request.url.path == '/metrics':
            return await call_next(request)

        started = time.time()
        # Assume a failure status until the downstream handler proves otherwise;
        # this way an exception is still counted as a 500 in the finally block.
        status_code = 500
        try:
            response = await call_next(request)
            status_code = response.status_code
        except Exception as e:
            logger.error(f"Request failed: {e}")
            raise
        finally:
            # Metrics are recorded whether the request succeeded or raised.
            elapsed = time.time() - started
            labels = {'method': request.method, 'endpoint': request.url.path}

            request_counter.labels(status=status_code, **labels).inc()
            request_duration.labels(**labels).observe(elapsed)

        return response