Spaces:
Running
Running
"""
IP-based hard-cap rate limiter for PatientSim Gradio demo.

Each counter is a simple cumulative total — no time window, no reset.
Once a limit is reached the client stays blocked for that action until the
SQLite DB is cleared. Counters are persisted to SQLite, so a process restart
resets them only when the limiter has fallen back to in-memory storage.

Limits are configurable via environment variables:
    RATE_LIMIT_CHAT_MSGS        — max chat messages total per IP (default: 50)
    RATE_LIMIT_AUTO_RUNS        — max auto simulation runs total per IP (default: 5)
    RATE_LIMIT_TOTAL_API_CALLS  — max total LLM calls across all modes (default: 200)
    RATE_LIMIT_GLOBAL_TOTAL     — hard cap on total LLM calls globally (default: 10000)

Client identification priority (for HuggingFace Spaces):
    1. HF OAuth username (if the Space has OAuth enabled)
    2. CF-Connecting-IP header
    3. X-Forwarded-For header (entry counted from the right by
       TRUSTED_PROXY_COUNT, default: 1 — i.e. the rightmost IP, added by the
       trusted proxy)
    4. X-Real-IP header
    5. Direct client host

Callers that cannot be identified return None from get_client_key() and are
rejected by all check_* methods that take a client key.
"""
| from __future__ import annotations | |
| import os | |
| import stat | |
| import sqlite3 | |
| import threading | |
| from collections import defaultdict | |
| from typing import Dict, Optional, Tuple | |
| import gradio as gr | |
# ---------------------------------------------------------------------------
# Configuration — overridable via environment variables
# ---------------------------------------------------------------------------
# Hard caps, each read once at import time from the environment.
# The second argument to os.getenv is the default used when the variable is unset.
CHAT_MSGS_LIMIT: int = int(os.getenv("RATE_LIMIT_CHAT_MSGS", "50"))
AUTO_RUNS_LIMIT: int = int(os.getenv("RATE_LIMIT_AUTO_RUNS", "5"))
TOTAL_API_CALLS_LIMIT: int = int(os.getenv("RATE_LIMIT_TOTAL_API_CALLS", "200"))
GLOBAL_TOTAL_CALLS_LIMIT: int = int(os.getenv("RATE_LIMIT_GLOBAL_TOTAL", "10000"))

# Number of API-call slots charged to the total_calls counter when an auto
# simulation starts. One auto simulation consumes at most
# (2 agents × MAX_AUTO_INFERENCES) API calls, so the whole budget is reserved
# upfront rather than counted call by call.
_AUTO_RUN_CALL_RESERVATION: int = 20

# Upper bound on auto simulations a single client key may run at the same time.
_MAX_CONCURRENT_AUTO: int = 1
| # --------------------------------------------------------------------------- | |
| # Client identifier extraction | |
| # --------------------------------------------------------------------------- | |
def get_client_key(request: gr.Request | None) -> Optional[str]:
    """
    Return a stable string that identifies the caller, or ``None`` if no
    identifier can be extracted (caller will be denied by all check methods).

    The key is prefixed with ``"user:"`` for authenticated HF users and
    ``"ip:"`` for anonymous IP-based identification.

    Sources are tried in this order; the first non-empty one wins:

    1. ``request.username`` (HF OAuth username)
    2. ``CF-Connecting-IP`` header
    3. ``X-Forwarded-For`` header, counted from the right by the
       ``TRUSTED_PROXY_COUNT`` environment variable (default ``1``)
    4. ``X-Real-IP`` header
    5. ``request.client.host``

    Parameters
    ----------
    request:
        The :class:`gradio.Request` object injected by Gradio into event
        handler functions. Only the ``username``, ``headers`` and ``client``
        attributes are read, and each is optional.

    Returns
    -------
    str or None
        A non-empty identifier string, or None when identification fails.
    """
    if request is None:
        return None

    # 1. HuggingFace OAuth username (available when HF OAuth is enabled on the Space)
    username = getattr(request, "username", None)
    if username:
        return f"user:{username}"

    # Normalise headers to lowercase keys for consistent lookup. Header parsing
    # is best-effort: a malformed mapping just means "no headers available".
    headers: dict = {}
    raw = getattr(request, "headers", None)
    if raw:
        try:
            headers = {k.lower(): v for k, v in dict(raw).items()}
        except Exception:
            headers = {}

    def _header(name: str) -> str:
        """Return the stripped header value, or "" when absent or not a string."""
        value = headers.get(name, "")
        return value.strip() if isinstance(value, str) else ""

    # 2. Cloudflare-style real-IP header.
    # NOTE(review): this is only trustworthy if the edge proxy in front of the
    # app overwrites any client-supplied value — confirm for the deployment.
    cf_ip = _header("cf-connecting-ip")
    if cf_ip:
        return f"ip:{cf_ip}"

    # 3. X-Forwarded-For — index from the right by the number of trusted
    #    proxies to avoid client-controlled header spoofing.
    #    An unparsable or non-positive TRUSTED_PROXY_COUNT falls back to 1:
    #    a count of 0 would turn ``hops[-0]`` into ``hops[0]`` (the leftmost,
    #    client-supplied entry) and a negative count would index from the left.
    try:
        trusted_proxies = int(os.environ.get("TRUSTED_PROXY_COUNT", "1"))
    except ValueError:
        trusted_proxies = 1
    trusted_proxies = max(trusted_proxies, 1)

    forwarded = _header("x-forwarded-for")
    if forwarded:
        hops = [hop.strip() for hop in forwarded.split(",") if hop.strip()]
        # With fewer hops than trusted proxies the expected entry is missing,
        # so fall through to the next source instead of guessing.
        if len(hops) >= trusted_proxies:
            return f"ip:{hops[-trusted_proxies]}"

    # 4. X-Real-IP — set by some reverse proxies (nginx, etc.).
    #    Stripped before the emptiness check so a whitespace-only value cannot
    #    produce the shared key "ip:".
    real_ip = _header("x-real-ip")
    if real_ip:
        return f"ip:{real_ip}"

    # 5. Direct connection host (only reliable when not behind a proxy)
    client = getattr(request, "client", None)
    host = getattr(client, "host", None) if client else None
    if host:
        return f"ip:{host}"

    return None
| # --------------------------------------------------------------------------- | |
| # Rate limiter | |
| # --------------------------------------------------------------------------- | |
class RateLimiter:
    """
    Thread-safe hard-cap rate limiter keyed by client identifier.

    Counters are cumulative totals with no time window — once a limit is
    reached the client is permanently blocked for that action.

    Per-key counters are persisted to SQLite so they survive process restarts
    (OOM kills, HF Space sleeps, etc.). Falls back to in-memory storage if
    SQLite cannot be initialised, and per write if a SQLite write fails.

    Tracks three independent counters per key:

    * **chat_msgs**   — individual chat messages (1 LLM call each)
    * **auto_runs**   — auto simulation runs (each reserved as
      ``_AUTO_RUN_CALL_RESERVATION`` LLM calls in ``total_calls``)
    * **total_calls** — aggregate LLM API calls across all modes

    Plus two further counters:

    * **_global_calls**     — total LLM calls across all clients (hard global
      cap); stored like a per-key counter under the reserved key
      ``"__global__"`` of ``total_calls``
    * **_active_auto_runs** — concurrent auto runs per key (burst
      prevention); in-memory only

    Example
    -------
    >>> limiter = RateLimiter()
    >>> allowed, msg = limiter.check_chat_message("ip:1.2.3.4")
    >>> if not allowed:
    ...     raise gr.Error(msg)
    """

    _UNIDENTIFIED_MSG = "Unable to identify your session. Please reload the page."

    def __init__(self) -> None:
        """Create the lock, the in-memory fallback store and open the SQLite DB."""
        self._lock = threading.Lock()
        # SQLite-backed persistent counters; fall back to in-memory on failure
        self._db: Optional[sqlite3.Connection] = None
        self._mem: Dict[str, Dict[str, int]] = {
            "chat_msgs": defaultdict(int),
            "auto_runs": defaultdict(int),
            "total_calls": defaultdict(int),
        }
        self._init_db()
        # _active_auto_runs is intentionally in-memory — concurrent run slots
        # should reset on restart.
        self._active_auto_runs: Dict[str, int] = defaultdict(int)

    # _global_calls is persisted under the special key "__global__" so the hard
    # cap survives process restarts (HF Space sleep/wake cycles). It must be a
    # property: the check methods read it as ``self._global_calls + n`` and
    # assign to it, which fails with TypeError if it is a plain method.
    @property
    def _global_calls(self) -> int:
        """Total LLM calls recorded across all clients. Call within self._lock."""
        return self._get("total_calls", "__global__")

    @_global_calls.setter
    def _global_calls(self, value: int) -> None:
        self._set("total_calls", "__global__", value)

    # ------------------------------------------------------------------
    # SQLite helpers (must be called within self._lock)
    # ------------------------------------------------------------------
    def _init_db(self) -> None:
        """
        Attempt to open a persistent SQLite DB for counter storage.

        Candidates are tried in order; the first one that opens and accepts the
        schema is kept in ``self._db``. If none works ``self._db`` stays
        ``None`` and all counters live in ``self._mem``.
        """
        for candidate in ("/data/rate_limits.db", "/tmp/rate_limits.db"):
            db: Optional[sqlite3.Connection] = None
            try:
                db = sqlite3.connect(candidate, check_same_thread=False)
                db.execute(
                    "CREATE TABLE IF NOT EXISTS counters "
                    "(key TEXT, counter_type TEXT, count INTEGER DEFAULT 0, "
                    "PRIMARY KEY (key, counter_type))"
                )
                db.commit()
            except Exception:
                # Best-effort: this location is unusable, try the next one —
                # but do not leak a half-initialised connection.
                if db is not None:
                    try:
                        db.close()
                    except sqlite3.Error:
                        pass
                continue
            # Restrict file permissions to owner only (rw-------)
            try:
                os.chmod(candidate, stat.S_IRUSR | stat.S_IWUSR)
            except OSError:
                pass
            self._db = db
            return

    def _get(self, counter_type: str, key: str) -> int:
        """
        Read a counter value. Must be called within self._lock.

        Returns the larger of the SQLite value and the in-memory value. The
        in-memory store is only written when a SQLite write failed, and
        counters never decrease, so the larger value is always the most recent
        one. Returning the SQLite value alone would let a client exceed its
        cap whenever writes fail but reads still succeed.
        """
        stored = 0
        if self._db is not None:
            try:
                row = self._db.execute(
                    "SELECT count FROM counters WHERE key=? AND counter_type=?",
                    (key, counter_type),
                ).fetchone()
                if row:
                    stored = row[0]
            except Exception:
                # Best-effort: an unreadable DB is treated as "no stored value".
                stored = 0
        # .get() avoids creating an entry for every key that is merely read.
        return max(stored, self._mem[counter_type].get(key, 0))

    def _set(self, counter_type: str, key: str, count: int) -> None:
        """
        Write a counter value. Must be called within self._lock.

        Writes to SQLite when available; if that write fails (or no DB is
        open) the value is kept in the in-memory store instead.
        """
        if self._db is not None:
            try:
                self._db.execute(
                    "INSERT OR REPLACE INTO counters (key, counter_type, count) "
                    "VALUES (?, ?, ?)",
                    (key, counter_type, count),
                )
                self._db.commit()
                return
            except Exception:
                # Best-effort: fall through to the in-memory store.
                pass
        self._mem[counter_type][key] = count

    # ------------------------------------------------------------------
    # Public check methods
    # ------------------------------------------------------------------
    def check_chat_message(self, key: Optional[str]) -> Tuple[bool, str]:
        """
        Check whether sending a chat message is allowed (= 1 LLM API call).

        Reads, checks and increments ``chat_msgs``, ``total_calls`` and the
        global counter while holding the lock, so concurrent callers cannot
        both pass the check on the last remaining slot.

        Returns
        -------
        (allowed, message)
            ``(True, "")`` when the message was counted, otherwise ``False``
            and a user-facing explanation. Nothing is counted on denial.
        """
        if not key:
            return False, self._UNIDENTIFIED_MSG
        with self._lock:
            chat_count = self._get("chat_msgs", key) + 1
            total_count = self._get("total_calls", key) + 1
            new_global = self._global_calls + 1
            if chat_count > CHAT_MSGS_LIMIT:
                return False, (
                    f"Chat message limit reached "
                    f"(maximum {CHAT_MSGS_LIMIT} messages per session)."
                )
            if total_count > TOTAL_API_CALLS_LIMIT:
                return False, (
                    f"Total API call limit reached "
                    f"(maximum {TOTAL_API_CALLS_LIMIT} API calls per session)."
                )
            if new_global > GLOBAL_TOTAL_CALLS_LIMIT:
                return False, "Service capacity reached. Please try again later."
            # All checks passed — record the new totals
            self._set("chat_msgs", key, chat_count)
            self._set("total_calls", key, total_count)
            self._global_calls = new_global
            return True, ""

    def check_auto_run(self, key: Optional[str]) -> Tuple[bool, str]:
        """
        Check whether starting an auto simulation is allowed.

        Reserves ``_AUTO_RUN_CALL_RESERVATION`` slots in the ``total_calls``
        counter upfront because each auto run may issue up to that many LLM
        calls before it finishes. Also enforces a per-key concurrent run limit.

        On success one concurrent-run slot is taken for *key*; the caller must
        give it back with ``release_auto_slot()`` when the run ends.

        Returns
        -------
        (allowed, message)
            ``(True, "")`` when the run was counted, otherwise ``False`` and a
            user-facing explanation. Nothing is counted on denial.
        """
        if not key:
            return False, self._UNIDENTIFIED_MSG
        with self._lock:
            active = self._active_auto_runs.get(key, 0)
            if active >= _MAX_CONCURRENT_AUTO:
                return False, "An auto simulation is already running. Please wait."
            run_count = self._get("auto_runs", key) + 1
            total_count = self._get("total_calls", key) + _AUTO_RUN_CALL_RESERVATION
            new_global = self._global_calls + _AUTO_RUN_CALL_RESERVATION
            if run_count > AUTO_RUNS_LIMIT:
                return False, (
                    f"Auto simulation limit reached "
                    f"(maximum {AUTO_RUNS_LIMIT} auto runs per session)."
                )
            if total_count > TOTAL_API_CALLS_LIMIT:
                return False, (
                    f"Total API call limit reached "
                    f"(maximum {TOTAL_API_CALLS_LIMIT} API calls per session)."
                )
            if new_global > GLOBAL_TOTAL_CALLS_LIMIT:
                return False, "Service capacity reached. Please try again later."
            # All checks passed — record the new totals
            self._set("auto_runs", key, run_count)
            self._set("total_calls", key, total_count)
            self._global_calls = new_global
            self._active_auto_runs[key] = active + 1
            return True, ""

    def check_global_capacity(self) -> Tuple[bool, str]:
        """
        Lightweight global-capacity check for users supplying their own API keys.

        Per-IP quotas (chat_msgs, auto_runs, total_calls) are intentionally
        skipped — own-key users are not billed against the shared pool.
        However, the hard global cap still applies to prevent the server from
        being overwhelmed regardless of who is calling.

        Like the per-IP check methods, this method increments
        ``_global_calls`` (by one) so that the counter reflects all LLM calls,
        not just those made through the shared key.

        Returns
        -------
        (allowed, message)
            ``(True, "")`` when the call was counted, otherwise ``False`` and a
            user-facing explanation.
        """
        with self._lock:
            new_global = self._global_calls + 1
            if new_global > GLOBAL_TOTAL_CALLS_LIMIT:
                return False, "Service capacity reached. Please try again later."
            self._global_calls = new_global
            return True, ""

    def check_own_key_auto_run(self, key: Optional[str]) -> Tuple[bool, str]:
        """
        Concurrent-run and global-capacity check for own-key auto simulations.

        Per-IP auto-run quota and total-call quota are intentionally skipped.
        The concurrent run cap (``_MAX_CONCURRENT_AUTO``) **is** enforced to
        prevent a single client from spawning many parallel simulations and
        exhausting server threads. The global hard cap is also applied and the
        global counter is advanced by ``_AUTO_RUN_CALL_RESERVATION``.

        Must be paired with a ``release_auto_slot()`` call in a ``finally``
        block, just like ``check_auto_run()``.

        Returns
        -------
        (allowed, message)
            ``(True, "")`` when the run may start, otherwise ``False`` and a
            user-facing explanation. Nothing is counted on denial.
        """
        if not key:
            return False, self._UNIDENTIFIED_MSG
        with self._lock:
            active = self._active_auto_runs.get(key, 0)
            if active >= _MAX_CONCURRENT_AUTO:
                return False, "An auto simulation is already running. Please wait."
            new_global = self._global_calls + _AUTO_RUN_CALL_RESERVATION
            if new_global > GLOBAL_TOTAL_CALLS_LIMIT:
                return False, "Service capacity reached. Please try again later."
            # All checks passed — record the new totals
            self._global_calls = new_global
            self._active_auto_runs[key] = active + 1
            return True, ""

    def release_auto_slot(self, key: Optional[str]) -> None:
        """
        Release one concurrent auto run slot for *key*.

        Must be called when an auto simulation finishes (or fails) so that
        the same client can start another run later. Releasing a key that
        holds no slot is a no-op; the count never goes below zero.
        """
        if not key:
            return
        with self._lock:
            remaining = self._active_auto_runs.get(key, 0) - 1
            if remaining > 0:
                self._active_auto_runs[key] = remaining
            else:
                # Drop idle keys so the map does not grow with every client seen.
                self._active_auto_runs.pop(key, None)

    # ------------------------------------------------------------------
    # Diagnostic
    # ------------------------------------------------------------------
    def status(self, key: str) -> dict:
        """
        Return current counter snapshots for *key*.

        Useful for debugging or exposing quota information in the UI.

        Returns
        -------
        dict with keys ``chat_messages``, ``auto_runs``,
        ``total_api_calls``; each value is a dict with ``used`` and ``limit``.
        """
        with self._lock:
            return {
                "chat_messages": {"used": self._get("chat_msgs", key), "limit": CHAT_MSGS_LIMIT},
                "auto_runs": {"used": self._get("auto_runs", key), "limit": AUTO_RUNS_LIMIT},
                "total_api_calls": {"used": self._get("total_calls", key), "limit": TOTAL_API_CALLS_LIMIT},
            }