Spaces:
Running
Running
| """ | |
| Rate limiting and request throttling utilities. | |
| Thread-safe for single-instance deployments via threading.Lock. | |
| For multi-instance deployments, replace _store with Redis. | |
| """ | |
| import time | |
| import threading | |
| import logging | |
| from collections import defaultdict | |
| from typing import Dict, Tuple | |
| import os | |
| logger = logging.getLogger(__name__) | |
| # Rate limiting configuration | |
| RATE_LIMIT_REQUESTS_PER_MINUTE = int(os.getenv("RATE_LIMIT_REQUESTS_PER_MINUTE", "60")) | |
| RATE_LIMIT_UPLOADS_PER_MINUTE = int(os.getenv("RATE_LIMIT_UPLOADS_PER_MINUTE", "10")) | |
| RATE_LIMIT_ANALYSIS_PER_MINUTE = int(os.getenv("RATE_LIMIT_ANALYSIS_PER_MINUTE", "20")) | |
| RATE_LIMIT_ANALYSIS_CONCURRENT = int(os.getenv("RATE_LIMIT_ANALYSIS_CONCURRENT", "5")) | |
| # Thread-safe in-memory rate limit state | |
| # WARNING: per-process only — does not share state across multiple Uvicorn workers. | |
| _lock = threading.Lock() | |
| _rate_limit_store: Dict[str, list] = defaultdict(list) | |
| _concurrent_jobs: Dict[str, int] = defaultdict(int) | |
| def check_rate_limit(client_id: str, limit: int, window_seconds: int = 60) -> Tuple[bool, str]: | |
| """ | |
| Check if a client has exceeded its rate limit (thread-safe). | |
| Slide-window counter — automatically expires old timestamps. | |
| Returns (is_allowed, message). | |
| """ | |
| now = time.time() | |
| window_start = now - window_seconds | |
| with _lock: | |
| # Evict expired timestamps then check + record atomically | |
| _rate_limit_store[client_id] = [ | |
| ts for ts in _rate_limit_store[client_id] | |
| if ts > window_start | |
| ] | |
| request_count = len(_rate_limit_store[client_id]) | |
| if request_count >= limit: | |
| return False, f"Rate limit exceeded: {limit} requests per {window_seconds}s" | |
| _rate_limit_store[client_id].append(now) | |
| return True, "" | |
| def can_start_analysis(client_id: str) -> Tuple[bool, str]: | |
| """ | |
| Check if a client can start a new analysis job (thread-safe concurrent limit). | |
| The check and the subsequent increment must be performed as an atomic pair | |
| at the call site to avoid a race; use increment_concurrent_job immediately after. | |
| Returns (is_allowed, message). | |
| """ | |
| with _lock: | |
| if _concurrent_jobs[client_id] >= RATE_LIMIT_ANALYSIS_CONCURRENT: | |
| return False, f"Too many concurrent analysis jobs. Max: {RATE_LIMIT_ANALYSIS_CONCURRENT}" | |
| return True, "" | |
| def increment_concurrent_job(client_id: str) -> None: | |
| """Record that a client started a new analysis job (thread-safe).""" | |
| with _lock: | |
| _concurrent_jobs[client_id] += 1 | |
| def decrement_concurrent_job(client_id: str) -> None: | |
| """Record that a client's analysis job finished (thread-safe, never goes negative).""" | |
| with _lock: | |
| if _concurrent_jobs[client_id] > 0: | |
| _concurrent_jobs[client_id] -= 1 | |
| def extract_client_id(request) -> str: | |
| """ | |
| Extract client identifier securely from a FastAPI Request. | |
| Protects against X-Forwarded-For spoofing by requiring explicit proxy trust. | |
| """ | |
| trust_proxy = os.getenv("TRUST_FORWARDED_IP", "false").lower() == "true" | |
| if trust_proxy: | |
| forwarded = request.headers.get("X-Forwarded-For") | |
| if forwarded: | |
| return forwarded.split(",")[0].strip() | |
| return request.client.host if request.client else "unknown" | |