Spaces:
Running
Running
| """ | |
Rate limiting middleware for GrantForge AI.
Protects the generator and auditor endpoints against abuse.
Uses a simple in-memory store; with multiple workers use Redis instead, because each worker process would otherwise keep its own independent counts.
| """ | |
| import time | |
| import logging | |
| from collections import defaultdict | |
| from fastapi import Request | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| from starlette.responses import JSONResponse | |
logger = logging.getLogger(__name__)

# Rate-limit configuration per endpoint group.
# Each limit allows at most `requests` requests within a sliding window of
# `window_seconds` seconds.
RATE_LIMITS = {
    # Generator stream: 5 requests per 5 minutes.
    "/api/generator/stream": dict(requests=5, window_seconds=300),
    # Project operations, keyed by operation name.
    "/api/projects/": {
        "audit": dict(requests=10, window_seconds=3600),  # 10 audits per hour
        "autofix": dict(requests=200, window_seconds=3600),  # 200 autofixes per hour
    },
    # Intended limit for every other endpoint: 120 requests per minute.
    # NOTE(review): no code in this module appears to read this entry --
    # confirm whether it is meant to be enforced.
    "default": dict(requests=120, window_seconds=60),
}

# In-memory request log: {user_id: {endpoint_key: [timestamp, ...]}}.
# Both nesting levels are created on first access. The store lives in this
# process only, so separate worker processes keep separate counts.
_request_log: defaultdict = defaultdict(lambda: defaultdict(list))
def _subject_from_token(token: str) -> "str | None":
    """
    Return the ``sub`` claim of ``token`` as a string, or None if unavailable.

    Best effort: any decode failure, or a missing/empty ``sub`` claim, yields
    None so the caller can fall back to another identity source. The literal
    development token ``"dev_test_token"`` maps to ``"dev_user"``.
    """
    import jwt

    if token == "dev_test_token":
        return "dev_user"
    try:
        # NOTE(review): the signature is NOT verified, so a client can mint a
        # token with any ``sub`` and get a fresh rate-limit bucket on every
        # request. Verify with the auth layer's key/algorithm before trusting it.
        claims = jwt.decode(token, options={"verify_signature": False})
        subject = claims.get("sub")
    except Exception:
        # Deliberately best effort: an undecodable token just means "no subject".
        return None
    if subject is None or subject == "":
        return None
    return str(subject)


def _get_user_id(request: Request) -> str:
    """
    Return the identity used as the rate-limit bucket key for ``request``.

    Resolution order:
    1. ``Authorization: Bearer <token>`` header -> the token's ``sub`` claim.
    2. ``?token=<token>`` query parameter (used by the generator SSE stream)
       -> the token's ``sub`` claim.
    3. The client's IP address, or ``"unknown"`` when the server provides no
       client information.
    """
    # Resolve the fallback once and safely: ``request.client`` may be None, and
    # evaluating ``request.client.host`` eagerly would raise and discard a
    # perfectly valid token subject.
    client_host = getattr(request.client, "host", "unknown")

    # The auth scheme is case-insensitive (RFC 7235), e.g. "bearer" is valid.
    scheme, _, bearer_token = request.headers.get("Authorization", "").partition(" ")
    if scheme.lower() == "bearer" and bearer_token:
        subject = _subject_from_token(bearer_token)
        if subject is not None:
            return subject

    # Fallback: token in the query string (generator SSE).
    query_token = request.query_params.get("token", "")
    if query_token:
        subject = _subject_from_token(query_token)
        if subject is not None:
            return subject

    return client_host
def _check_rate_limit(user_id: str, endpoint_key: str, limit: dict) -> tuple[bool, int]:
    """
    Check one request against a sliding-window limit and record it if allowed.

    Args:
        user_id: Owner of the bucket in ``_request_log``.
        endpoint_key: Bucket name within that user's entries.
        limit: Mapping with ``"requests"`` (maximum requests per window) and
            ``"window_seconds"`` (window length in seconds).

    Returns:
        ``(True, 0)`` when the request is allowed; its timestamp is then
        recorded. ``(False, retry_after_seconds)`` when the bucket is full;
        nothing is recorded in that case.

    NOTE: buckets are never evicted, so ``_request_log`` grows with the number
    of distinct (user_id, endpoint_key) pairs seen by this process.
    """
    # Monotonic clock: wall-clock jumps (NTP sync, manual changes) must not
    # stretch or shrink the window.
    now = time.monotonic()
    window = limit["window_seconds"]
    max_requests = limit["requests"]

    # Drop timestamps that have left the window; the list stays oldest-first.
    recent = [t for t in _request_log[user_id][endpoint_key] if now - t < window]
    _request_log[user_id][endpoint_key] = recent

    if len(recent) >= max_requests:
        if recent:
            # The caller may retry once the oldest in-window request expires.
            retry_after = int(window - (now - recent[0])) + 1
        else:
            # Limit of zero: nothing will expire, so report the full window.
            retry_after = int(window)
        return False, retry_after

    recent.append(now)
    return True, 0
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Apply per-user sliding-window rate limits to selected endpoints.

    * Paths in ``STRICT_PATHS`` (exact match) use
      ``RATE_LIMITS["/api/generator/stream"]``.
    * Paths containing a fragment from ``PATTERN_LIMITS`` use that fragment's
      limit; only the first matching fragment is applied.
    * Every other path passes through without a limit
      (``RATE_LIMITS["default"]`` is not applied here).

    Users are identified by ``_get_user_id`` and counted in ``_request_log``
    through ``_check_rate_limit``.
    """

    # Paths that get the strict generator limit (exact match).
    STRICT_PATHS = {
        "/api/generator/stream",
    }

    # URL fragments (substring match) -> limit. The values come from
    # RATE_LIMITS so the two tables cannot drift apart.
    PATTERN_LIMITS = {
        "/audit": RATE_LIMITS["/api/projects/"]["audit"],
        "/autofix": RATE_LIMITS["/api/projects/"]["autofix"],
    }

    # Paths that are never rate limited (health checks, docs, schema).
    EXEMPT_PATHS = frozenset({"/health", "/api/health", "/", "/docs", "/openapi.json"})

    @staticmethod
    def _too_many_requests(user_id: str, path: str, retry_after: int, detail: str) -> JSONResponse:
        """Log a rejected request and build the 429 response for it."""
        logger.warning("Rate limit: %s @ %s (retry in %ss)", user_id, path, retry_after)
        return JSONResponse(
            status_code=429,
            content={
                "detail": detail,
                "retry_after": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )

    async def dispatch(self, request: Request, call_next):
        """
        Enforce the applicable limit before passing ``request`` downstream.

        Returns a 429 JSON response with a ``Retry-After`` header when the
        caller's bucket is full. Otherwise it returns the downstream response;
        generator-stream responses also get ``X-RateLimit-*`` headers.
        """
        path = request.url.path

        # Skip health checks/docs. Also skip CORS preflight: an OPTIONS request
        # does no work, and a 429 on it makes the browser drop the real request.
        if path in self.EXEMPT_PATHS or request.method == "OPTIONS":
            return await call_next(request)

        user_id = _get_user_id(request)

        # 1. Strict limit for the generator stream.
        if path in self.STRICT_PATHS:
            limit = RATE_LIMITS["/api/generator/stream"]
            allowed, retry_after = _check_rate_limit(user_id, path, limit)
            if not allowed:
                return self._too_many_requests(
                    user_id,
                    path,
                    retry_after,
                    f"Przekroczono limit zapytań. Spróbuj ponownie za {retry_after} sekund.",
                )

        # 2. Limits for the audit/autofix operations.
        for pattern, limit in self.PATTERN_LIMITS.items():
            if pattern in path:
                # One bucket per (operation, HTTP method), not per full path.
                # The budget is then per user rather than per project, and paths
                # with varying IDs cannot create an unbounded number of buckets.
                endpoint_key = f"{pattern}:{request.method}"
                allowed, retry_after = _check_rate_limit(user_id, endpoint_key, limit)
                if not allowed:
                    return self._too_many_requests(
                        user_id,
                        path,
                        retry_after,
                        f"Przekroczono limit operacji AI. Spróbuj ponownie za {retry_after} sekund.",
                    )
                break

        response = await call_next(request)

        # Informational limit headers for the generator stream.
        if path in self.STRICT_PATHS:
            limit = RATE_LIMITS["/api/generator/stream"]
            used = len(_request_log[user_id][path])
            response.headers["X-RateLimit-Limit"] = str(limit["requests"])
            response.headers["X-RateLimit-Remaining"] = str(max(0, limit["requests"] - used))
            response.headers["X-RateLimit-Window"] = str(limit["window_seconds"])
        return response