""" Rate Limiting Middleware dla GrantForge AI. Chroni endpointy generatora i audytora przed nadużyciami. Używa prostego in-memory store (dla multi-workerów użyj Redis). """ import time import logging from collections import defaultdict from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse logger = logging.getLogger(__name__) # Konfiguracja limitów per endpoint-gruppe RATE_LIMITS = { "/api/generator/stream": {"requests": 5, "window_seconds": 300}, # 5 req / 5 min "/api/projects/": { "audit": {"requests": 10, "window_seconds": 3600}, # 10 audytów / godz. "autofix": {"requests": 200, "window_seconds": 3600}, # 200 autofixów / godz. }, "default": {"requests": 120, "window_seconds": 60}, # 120 req / min dla reszty } # In-memory store: {user_id: {endpoint_key: [(timestamp), ...]}} _request_log: dict = defaultdict(lambda: defaultdict(list)) def _get_user_id(request: Request) -> str: """Wyciąga user_id z tokena JWT lub używa IP jako fallback.""" import jwt auth = request.headers.get("Authorization", "") if auth.startswith("Bearer "): token = auth.split(" ", 1)[1] try: if token == "dev_test_token": return "dev_user" decoded = jwt.decode(token, options={"verify_signature": False}) return decoded.get("sub", request.client.host) except Exception: pass # Fallback: token w query string (generator SSE) token = request.query_params.get("token", "") if token: try: decoded = jwt.decode(token, options={"verify_signature": False}) return decoded.get("sub", request.client.host) except Exception: pass return getattr(request.client, "host", "unknown") def _check_rate_limit(user_id: str, endpoint_key: str, limit: dict) -> tuple[bool, int]: """ Sprawdza czy użytkownik nie przekroczył limitu. Zwraca (allowed, retry_after_seconds). """ now = time.time() window = limit["window_seconds"] max_requests = limit["requests"] # Wyczyść stare wpisy timestamps = _request_log[user_id][endpoint_key] _request_log[user_id][endpoint_key] = [t for t in timestamps if now - t < window] current_count = len(_request_log[user_id][endpoint_key]) if current_count >= max_requests: oldest = _request_log[user_id][endpoint_key][0] retry_after = int(window - (now - oldest)) + 1 return False, retry_after _request_log[user_id][endpoint_key].append(now) return True, 0 class RateLimitMiddleware(BaseHTTPMiddleware): """ Middleware aplikujący rate limiting do wybranych endpointów. Styl: sliding window per user. """ # Endpointy do których stosujemy ścisłe limity STRICT_PATHS = { "/api/generator/stream", } # Wzorce URL z kluczem (ścieżka zawiera te fragmenty) PATTERN_LIMITS = { "/audit": {"requests": 10, "window_seconds": 3600}, "/autofix": {"requests": 200, "window_seconds": 3600}, } async def dispatch(self, request: Request, call_next): path = request.url.path # Pomijamy health check i statyczne zasoby if path in ("/health", "/api/health", "/", "/docs", "/openapi.json"): return await call_next(request) user_id = _get_user_id(request) # 1. Ścisłe limity dla generatora if path in self.STRICT_PATHS: limit = RATE_LIMITS["/api/generator/stream"] allowed, retry_after = _check_rate_limit(user_id, path, limit) if not allowed: logger.warning( f"Rate limit: {user_id} @ {path} (retry in {retry_after}s)" ) return JSONResponse( status_code=429, content={ "detail": f"Przekroczono limit zapytań. Spróbuj ponownie za {retry_after} sekund.", "retry_after": retry_after, }, headers={"Retry-After": str(retry_after)}, ) # 2. Limity dla wzorców audit/autofix for pattern, limit in self.PATTERN_LIMITS.items(): if pattern in path: endpoint_key = f"{path}:{request.method}" allowed, retry_after = _check_rate_limit(user_id, endpoint_key, limit) if not allowed: logger.warning( f"Rate limit: {user_id} @ {path} (retry in {retry_after}s)" ) return JSONResponse( status_code=429, content={ "detail": f"Przekroczono limit operacji AI. Spróbuj ponownie za {retry_after} sekund.", "retry_after": retry_after, }, headers={"Retry-After": str(retry_after)}, ) break response = await call_next(request) # Dodaj nagłówki informacyjne o limitach (opcjonalnie) if path in self.STRICT_PATHS: limit = RATE_LIMITS["/api/generator/stream"] timestamps = _request_log[user_id][path] remaining = max(0, limit["requests"] - len(timestamps)) response.headers["X-RateLimit-Limit"] = str(limit["requests"]) response.headers["X-RateLimit-Remaining"] = str(remaining) response.headers["X-RateLimit-Window"] = str(limit["window_seconds"]) return response