Spaces:
Running
Running
File size: 5,652 Bytes
3b7f713 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 | """
Rate Limiting Middleware dla GrantForge AI.
Chroni endpointy generatora i audytora przed nadużyciami.
Używa prostego in-memory store (dla multi-workerów użyj Redis).
"""
import time
import logging
from collections import defaultdict
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
logger = logging.getLogger(__name__)
# Limit configuration per endpoint group. Every leaf entry means
# "at most `requests` calls within `window_seconds` seconds".
RATE_LIMITS = {
    # Generator stream: 5 requests / 5 minutes.
    "/api/generator/stream": dict(requests=5, window_seconds=300),
    "/api/projects/": {
        # 10 audits / hour.
        "audit": dict(requests=10, window_seconds=3600),
        # 200 autofixes / hour.
        "autofix": dict(requests=200, window_seconds=3600),
    },
    # 120 requests / minute for everything else.
    "default": dict(requests=120, window_seconds=60),
}

# In-memory store: {user_id: {endpoint_key: [timestamp, ...]}}.
# Both levels are created on first access.
_request_log: dict = defaultdict(lambda: defaultdict(list))
def _get_user_id(request: Request) -> str:
    """Return the identity used as the rate-limit bucket key for *request*.

    Resolution order:

    1. ``Authorization: Bearer <token>`` header: the literal
       ``dev_test_token`` maps to ``"dev_user"``; otherwise the token's
       ``sub`` claim is used.
    2. ``token`` query parameter (fallback used for the generator SSE
       endpoint), again via its ``sub`` claim.
    3. The client host, or ``"unknown"`` when no client address is available.

    A token that decodes but carries no ``sub`` claim resolves to the client
    host. A token that cannot be decoded is ignored and the next source is
    tried.

    SECURITY NOTE(review): the JWT signature is NOT verified here
    (``verify_signature=False``), so the ``sub`` claim is caller-controlled.
    A client can forge tokens to spread its requests over many buckets or to
    consume another user's quota. Verifying the signature needs the signing
    key, which is not available in this module -- confirm whether this
    trade-off is acceptable.
    """
    import jwt

    # Resolve the fallback once and defensively: the final fallback already
    # tolerated a missing ``request.client``, but the ``sub`` lookups did not,
    # so a valid ``sub`` was silently dropped whenever the client was absent.
    client_host = getattr(request.client, "host", "unknown")

    def _identity_from(token: str):
        """Return the bucket key for *token*, or None if it cannot be decoded."""
        try:
            claims = jwt.decode(token, options={"verify_signature": False})
        except Exception:
            # Deliberately best-effort: any undecodable token simply means
            # "no identity from this source".
            return None
        subject = claims.get("sub")
        if subject is None:
            return client_host
        # The claim is unverified and may be any JSON value; coerce it so the
        # result is always a hashable string usable as a dict key.
        return str(subject)

    auth = request.headers.get("Authorization", "")
    if auth.startswith("Bearer "):
        token = auth.split(" ", 1)[1]
        if token == "dev_test_token":
            return "dev_user"
        identity = _identity_from(token)
        if identity is not None:
            return identity

    # Fallback: token passed in the query string (generator SSE).
    query_token = request.query_params.get("token", "")
    if query_token:
        identity = _identity_from(query_token)
        if identity is not None:
            return identity

    return client_host
def _check_rate_limit(user_id: str, endpoint_key: str, limit: dict) -> tuple[bool, int]:
    """Record a request and report whether it fits within the sliding window.

    Args:
        user_id: Bucket owner (first-level key of ``_request_log``).
        endpoint_key: Bucket name within the user (second-level key).
        limit: Mapping with ``"requests"`` (maximum calls allowed) and
            ``"window_seconds"`` (window length in seconds).

    Returns:
        ``(allowed, retry_after_seconds)``. When the request is allowed its
        timestamp is appended to the bucket and ``retry_after_seconds`` is 0.
        When it is rejected nothing is recorded and ``retry_after_seconds``
        is the time until the oldest recorded request leaves the window.

    Note:
        Timestamps older than the window are pruned on every call for the
        bucket being checked. Buckets of users who stop sending requests are
        not evicted by this function.
    """
    now = time.time()
    window = limit["window_seconds"]
    max_requests = limit["requests"]

    # Drop timestamps that have fallen out of the window.
    user_buckets = _request_log[user_id]
    recent = [t for t in user_buckets[endpoint_key] if now - t < window]
    user_buckets[endpoint_key] = recent

    if len(recent) >= max_requests:
        if not recent:
            # Only reachable with a non-positive limit: there is no oldest
            # entry to measure from, so report the full window instead of
            # raising IndexError.
            return False, max(1, int(window))
        oldest = recent[0]
        retry_after = int(window - (now - oldest)) + 1
        return False, retry_after

    recent.append(now)
    return True, 0
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Apply per-user sliding-window rate limits to selected endpoints.

    Two kinds of limits are enforced before the request is forwarded:

    * paths listed in ``STRICT_PATHS`` use the generator limit from
      ``RATE_LIMITS``;
    * paths containing a key of ``PATTERN_LIMITS`` use that pattern's limit,
      tracked separately per full path and HTTP method.

    A rejected request receives a 429 JSON response with a ``Retry-After``
    header and is not passed to the application.

    NOTE(review): ``RATE_LIMITS["default"]`` is not applied by this
    middleware; requests matching neither rule are forwarded unthrottled.
    Confirm whether that is intended.
    """

    # Endpoints that get the strict generator limit (exact path match).
    STRICT_PATHS = {
        "/api/generator/stream",
    }

    # URL fragments with their limit: a path matches when it contains the key.
    # The values mirror RATE_LIMITS["/api/projects/"].
    PATTERN_LIMITS = {
        "/audit": {"requests": 10, "window_seconds": 3600},
        "/autofix": {"requests": 200, "window_seconds": 3600},
    }

    @staticmethod
    def _too_many_requests(user_id, path, retry_after, message):
        """Log the rejection and build the 429 response.

        ``message`` is the leading sentence of the user-facing ``detail``
        text; the retry hint is appended to it.
        """
        logger.warning(
            "Rate limit: %s @ %s (retry in %ss)", user_id, path, retry_after
        )
        return JSONResponse(
            status_code=429,
            content={
                "detail": f"{message} Spróbuj ponownie za {retry_after} sekund.",
                "retry_after": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )

    async def dispatch(self, request: Request, call_next):
        """Enforce the configured limits, then forward the request.

        Returns the downstream response, or a 429 ``JSONResponse`` when a
        limit is exceeded. Responses for strict paths additionally carry
        ``X-RateLimit-*`` informational headers.
        """
        path = request.url.path

        # Health checks and static/documentation routes are never limited.
        if path in ("/health", "/api/health", "/", "/docs", "/openapi.json"):
            return await call_next(request)

        user_id = _get_user_id(request)
        is_strict = path in self.STRICT_PATHS

        # 1. Strict limit for the generator.
        if is_strict:
            limit = RATE_LIMITS["/api/generator/stream"]
            allowed, retry_after = _check_rate_limit(user_id, path, limit)
            if not allowed:
                return self._too_many_requests(
                    user_id, path, retry_after, "Przekroczono limit zapytań."
                )

        # 2. Limits for the audit/autofix patterns; only the first matching
        #    pattern is applied.
        for pattern, limit in self.PATTERN_LIMITS.items():
            if pattern not in path:
                continue
            endpoint_key = f"{path}:{request.method}"
            allowed, retry_after = _check_rate_limit(user_id, endpoint_key, limit)
            if not allowed:
                return self._too_many_requests(
                    user_id, path, retry_after, "Przekroczono limit operacji AI."
                )
            break

        response = await call_next(request)

        # Informational headers about the strict limit (optional extra).
        if is_strict:
            limit = RATE_LIMITS["/api/generator/stream"]
            timestamps = _request_log[user_id][path]
            remaining = max(0, limit["requests"] - len(timestamps))
            response.headers["X-RateLimit-Limit"] = str(limit["requests"])
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            response.headers["X-RateLimit-Window"] = str(limit["window_seconds"])

        return response
|