# grantforge-api / backend/core/rate_limiter.py
# GrantForge Bot
# Deploy to Hugging Face
# afd56bc
"""
Rate Limiting Middleware dla GrantForge AI.
Chroni endpointy generatora i audytora przed nadużyciami.
Używa prostego in-memory store (dla multi-workerów użyj Redis).
"""
import time
import logging
from collections import defaultdict
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
logger = logging.getLogger(__name__)

# Rate-limit configuration per endpoint group.
# Each limit is {"requests": max requests allowed, "window_seconds": sliding-window length}.
RATE_LIMITS = {
    "/api/generator/stream": {"requests": 5, "window_seconds": 300}, # 5 req / 5 min
    # Project sub-operations: keyed by operation name, not by a full path.
    "/api/projects/": {
        "audit": {"requests": 10, "window_seconds": 3600}, # 10 audits / hour
        "autofix": {"requests": 200, "window_seconds": 3600}, # 200 autofixes / hour
    },
    # NOTE(review): the middleware in this file never reads "default", so paths
    # outside the groups above are currently unlimited. Confirm that is intended.
    "default": {"requests": 120, "window_seconds": 60}, # 120 req / min for everything else
}

# In-memory sliding-window log: {user_id: {endpoint_key: [timestamp, ...]}}.
# Lives in this process only (not shared between workers; see the module docstring).
# Nothing in this module deletes keys, so it grows with the number of distinct
# (user_id, endpoint_key) pairs seen since startup.
_request_log: defaultdict[str, defaultdict[str, list[float]]] = defaultdict(lambda: defaultdict(list))
def _unverified_jwt_subject(token: str, default: str) -> "str | None":
    """Return the ``sub`` claim of *token* WITHOUT verifying its signature.

    Returns *default* when the token decodes but has no (or a null) ``sub``,
    and ``None`` when the token cannot be decoded at all, so the caller can
    fall through to its next identity source.
    """
    import jwt  # imported lazily: requests without a token never need it

    try:
        sub = jwt.decode(token, options={"verify_signature": False}).get("sub")
    except Exception:
        # Deliberately best-effort: a malformed token only means "no identity
        # from this source"; it must never fail the request itself.
        return None
    return default if sub is None else str(sub)


def _get_user_id(request: Request) -> str:
    """Return the key under which *request* is rate limited.

    Identity sources, tried in order:

    1. ``Authorization: Bearer <token>`` header. The literal ``dev_test_token``
       maps to ``"dev_user"``; any other token is decoded as a JWT.
    2. ``?token=<jwt>`` query parameter (used by the generator SSE stream).
    3. The client's host, or ``"unknown"`` when the server reports no client.

    A JWT that decodes but lacks ``sub`` yields the client host; a token that
    fails to decode is ignored and the next source is tried.

    SECURITY: the JWT signature is not verified, so ``sub`` is whatever the
    client wrote. A caller can send a fresh ``sub`` on every request to evade
    the limits, or reuse someone else's ``sub`` to exhaust their quota. Key on
    a verified identity if the limits must hold against deliberate abuse.

    NOTE(review): the hard-coded ``dev_test_token`` is accepted unconditionally
    (all its users share one bucket) — confirm it is disabled in production.
    """
    # ``request.client`` may be None; resolve the host once, defensively, so a
    # missing client can never discard an otherwise valid ``sub``.
    host = getattr(request.client, "host", "unknown")

    auth = request.headers.get("Authorization", "")
    if auth.startswith("Bearer "):
        token = auth.split(" ", 1)[1]
        if token == "dev_test_token":
            return "dev_user"
        subject = _unverified_jwt_subject(token, host)
        if subject is not None:
            return subject

    # Fallback: token in the query string (generator SSE).
    token = request.query_params.get("token", "")
    if token:
        subject = _unverified_jwt_subject(token, host)
        if subject is not None:
            return subject

    return host
def _check_rate_limit(
    user_id: str,
    endpoint_key: str,
    limit: dict,
    *,
    now: "float | None" = None,
    store: "dict | None" = None,
) -> tuple[bool, int]:
    """Apply a sliding-window limit and record the request if it is allowed.

    Args:
        user_id: Identity of the caller (first-level key in the store).
        endpoint_key: Bucket within that user (a path, or ``"path:METHOD"``).
        limit: Mapping with ``"requests"`` (max requests per window) and
            ``"window_seconds"`` (window length).
        now: Current time in ``time.time()`` seconds; defaults to the real
            clock. Lets callers and tests evaluate the limit deterministically.
        store: Two-level ``{user_id: {endpoint_key: [timestamps]}}`` mapping
            whose missing keys auto-create (e.g. nested ``defaultdict``);
            defaults to the module-wide ``_request_log``.

    Returns:
        ``(allowed, retry_after_seconds)``. ``retry_after_seconds`` is 0 when
        allowed; otherwise it is the whole seconds (at least 1) until the
        oldest recorded request leaves the window.

    Rejected requests are not recorded, so retrying against a full bucket
    does not extend the lockout.

    NOTE(review): only the accessed bucket is pruned; buckets of idle users or
    paths are never removed, so the store grows with distinct (user, key)
    pairs. Consider a periodic sweep or a TTL-based backend (e.g. Redis).
    """
    if now is None:
        now = time.time()
    if store is None:
        store = _request_log
    window = limit["window_seconds"]
    max_requests = limit["requests"]

    # Drop timestamps that have slid out of the window.
    recent = [t for t in store[user_id][endpoint_key] if now - t < window]
    store[user_id][endpoint_key] = recent

    if len(recent) >= max_requests:
        # The oldest surviving entry expires first. With a zero (or negative)
        # limit the bucket is empty, so the caller must wait a whole window
        # instead of crashing on recent[0].
        elapsed = now - recent[0] if recent else 0
        retry_after = int(window - elapsed) + 1
        return False, retry_after

    recent.append(now)
    return True, 0
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Middleware applying per-user sliding-window rate limits to selected endpoints.

    * ``STRICT_PATHS``: exact path match, limited by
      ``RATE_LIMITS["/api/generator/stream"]``, one bucket per path.
    * ``PATTERN_LIMITS``: path contains the fragment. The first matching
      fragment applies, with one bucket per ``"<path>:<METHOD>"`` (so distinct
      paths, e.g. per-resource URLs, get separate buckets).
    * ``EXEMPT_PATHS`` and CORS preflight (``OPTIONS``) requests are never limited.

    Rejected requests get HTTP 429 with a ``Retry-After`` header. Successful
    strict-path responses carry informational ``X-RateLimit-*`` headers.

    NOTE(review): ``RATE_LIMITS["default"]`` is not applied here, so all other
    paths are unlimited. Substring matching also catches any path that merely
    contains ``/audit`` or ``/autofix``. Confirm both are intended.
    """

    # Endpoints with strict limits (exact path match).
    STRICT_PATHS = {
        "/api/generator/stream",
    }

    # URL fragment -> limit; applied when the path contains the fragment.
    # Taken from RATE_LIMITS so the numbers are configured in one place.
    PATTERN_LIMITS = {
        "/audit": RATE_LIMITS["/api/projects/"]["audit"],
        "/autofix": RATE_LIMITS["/api/projects/"]["autofix"],
    }

    # Health checks, docs and root: never limited.
    EXEMPT_PATHS = frozenset({"/health", "/api/health", "/", "/docs", "/openapi.json"})

    @staticmethod
    def _too_many_requests(user_id: str, path: str, retry_after: int, detail: str) -> JSONResponse:
        """Log a rejection and build the HTTP 429 response for it."""
        logger.warning("Rate limit: %s @ %s (retry in %ss)", user_id, path, retry_after)
        return JSONResponse(
            status_code=429,
            content={"detail": detail, "retry_after": retry_after},
            headers={"Retry-After": str(retry_after)},
        )

    async def dispatch(self, request: Request, call_next):
        """Enforce the configured limits, then forward the request downstream."""
        path = request.url.path

        # CORS preflights do no work. Counting them (in their own
        # "<path>:OPTIONS" bucket) could 429 the preflight and make the
        # browser drop the real request.
        if path in self.EXEMPT_PATHS or request.method == "OPTIONS":
            return await call_next(request)

        user_id = _get_user_id(request)
        strict_limit = RATE_LIMITS["/api/generator/stream"] if path in self.STRICT_PATHS else None

        # 1. Strict limit for the generator.
        if strict_limit is not None:
            allowed, retry_after = _check_rate_limit(user_id, path, strict_limit)
            if not allowed:
                return self._too_many_requests(
                    user_id,
                    path,
                    retry_after,
                    f"Przekroczono limit zapytań. Spróbuj ponownie za {retry_after} sekund.",
                )

        # 2. Limits for audit/autofix fragments; only the first match applies.
        for pattern, limit in self.PATTERN_LIMITS.items():
            if pattern in path:
                endpoint_key = f"{path}:{request.method}"
                allowed, retry_after = _check_rate_limit(user_id, endpoint_key, limit)
                if not allowed:
                    return self._too_many_requests(
                        user_id,
                        path,
                        retry_after,
                        f"Przekroczono limit operacji AI. Spróbuj ponownie za {retry_after} sekund.",
                    )
                break

        response = await call_next(request)

        # Informational limit headers for strict paths. The bucket was pruned
        # by the check above and already includes this request.
        if strict_limit is not None:
            used = len(_request_log[user_id][path])
            response.headers["X-RateLimit-Limit"] = str(strict_limit["requests"])
            response.headers["X-RateLimit-Remaining"] = str(max(0, strict_limit["requests"] - used))
            response.headers["X-RateLimit-Window"] = str(strict_limit["window_seconds"])
        return response