dek924 committed on
Commit
08298d9
Β·
1 Parent(s): e646c57

feat: api call limit init

Browse files
Files changed (2) hide show
  1. app.py +25 -2
  2. rate_limiter.py +285 -0
app.py CHANGED
@@ -5,9 +5,14 @@ import gradio as gr
5
  from dotenv import load_dotenv, find_dotenv
6
  from patientsim import PatientAgent, DoctorAgent
7
  from patientsim.utils.common_utils import detect_ed_termination
 
8
 
9
  load_dotenv(find_dotenv(usecwd=True), override=True)
10
 
 
 
 
 
11
 
12
  # ---------------------------------------------------------------------------
13
  # Constants
@@ -510,10 +515,16 @@ def start_simulation(
510
  personality: str,
511
  recall: str,
512
  confusion: str,
 
513
  ):
514
  if not hadm_id:
515
  return _setup_error("Please select a patient first.")
516
 
 
 
 
 
 
517
  is_openai = "gpt" in model.lower()
518
 
519
  if is_openai:
@@ -611,12 +622,17 @@ def start_manual(profile_mode: str, agent, sim_config: dict):
611
  )
612
 
613
 
614
- def chat(message: str, history: list, agent):
615
  if agent is None:
616
  raise gr.Error("No simulation running. Please start a simulation first.")
617
  if not message.strip():
618
  return history, ""
619
 
 
 
 
 
 
620
  response = agent(user_prompt=message, using_multi_turn=True, verbose=False)
621
  history = history + [
622
  {"role": "user", "content": message},
@@ -663,13 +679,20 @@ def _auto_fallback_outputs():
663
  )
664
 
665
 
666
- def start_auto(agent, sim_config: dict):
667
  """Generator β€” yields chatbot updates turn-by-turn so the UI streams live."""
668
  if agent is None or sim_config is None:
669
  gr.Warning("Session expired. Please restart.")
670
  yield _auto_fallback_outputs()
671
  return
672
 
 
 
 
 
 
 
 
673
  agent.reset_history(verbose=False)
674
 
675
  # Show auto_section immediately; set per-patient avatar on first yield
 
5
  from dotenv import load_dotenv, find_dotenv
6
  from patientsim import PatientAgent, DoctorAgent
7
  from patientsim.utils.common_utils import detect_ed_termination
8
+ from rate_limiter import RateLimiter, get_client_key
9
 
10
  load_dotenv(find_dotenv(usecwd=True), override=True)
11
 
12
+ # ---------------------------------------------------------------------------
13
+ # Rate limiter (singleton β€” shared across all Gradio worker threads)
14
+ # ---------------------------------------------------------------------------
15
+ _rate_limiter = RateLimiter()
16
 
17
  # ---------------------------------------------------------------------------
18
  # Constants
 
515
  personality: str,
516
  recall: str,
517
  confusion: str,
518
+ request: gr.Request = None,
519
  ):
520
  if not hadm_id:
521
  return _setup_error("Please select a patient first.")
522
 
523
+ client_key = get_client_key(request)
524
+ allowed, limit_msg = _rate_limiter.check_simulation_start(client_key)
525
+ if not allowed:
526
+ return _setup_error(limit_msg)
527
+
528
  is_openai = "gpt" in model.lower()
529
 
530
  if is_openai:
 
622
  )
623
 
624
 
625
+ def chat(message: str, history: list, agent, request: gr.Request = None):
626
  if agent is None:
627
  raise gr.Error("No simulation running. Please start a simulation first.")
628
  if not message.strip():
629
  return history, ""
630
 
631
+ client_key = get_client_key(request)
632
+ allowed, limit_msg = _rate_limiter.check_chat_message(client_key)
633
+ if not allowed:
634
+ raise gr.Error(limit_msg)
635
+
636
  response = agent(user_prompt=message, using_multi_turn=True, verbose=False)
637
  history = history + [
638
  {"role": "user", "content": message},
 
679
  )
680
 
681
 
682
+ def start_auto(agent, sim_config: dict, request: gr.Request = None):
683
  """Generator β€” yields chatbot updates turn-by-turn so the UI streams live."""
684
  if agent is None or sim_config is None:
685
  gr.Warning("Session expired. Please restart.")
686
  yield _auto_fallback_outputs()
687
  return
688
 
689
+ client_key = get_client_key(request)
690
+ allowed, limit_msg = _rate_limiter.check_auto_run(client_key)
691
+ if not allowed:
692
+ gr.Warning(limit_msg)
693
+ yield _auto_fallback_outputs()
694
+ return
695
+
696
  agent.reset_history(verbose=False)
697
 
698
  # Show auto_section immediately; set per-patient avatar on first yield
rate_limiter.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ IP-based hard-cap rate limiter for PatientSim Gradio demo.
3
+
4
+ Each counter is a simple cumulative total β€” no time window, no reset.
5
+ Once a limit is reached the client is permanently blocked for that action
6
+ until the process is restarted.
7
+
8
+ Limits are configurable via environment variables:
9
+
10
+ RATE_LIMIT_SIM_STARTS β€” max simulation setups total per IP (default: 5)
11
+ RATE_LIMIT_CHAT_MSGS β€” max chat messages total per IP (default: 50)
12
+ RATE_LIMIT_AUTO_RUNS β€” max auto simulation runs total per IP (default: 5)
13
+ RATE_LIMIT_TOTAL_API_CALLS β€” max total LLM calls across all modes (default: 200)
14
+
15
+ Client identification priority (for HuggingFace Spaces):
16
+ 1. HF OAuth username (if the Space has OAuth enabled)
17
+ 2. X-Forwarded-For header (first IP in the proxy chain)
18
+ 3. X-Real-IP header
19
+ 4. Direct client host
20
+ """
21
+
22
+ from __future__ import annotations
23
+
24
+ import os
25
+ import threading
26
+ from collections import defaultdict
27
+ from typing import Dict, Tuple
28
+
29
+ import gradio as gr
30
+
31
+
32
# ---------------------------------------------------------------------------
# Configuration — overridable via environment variables
# ---------------------------------------------------------------------------

def _env_int(name: str, default: str) -> int:
    """Read an integer limit from the environment, falling back to *default*."""
    return int(os.environ.get(name, default))


SIM_STARTS_LIMIT: int = _env_int("RATE_LIMIT_SIM_STARTS", "5")
CHAT_MSGS_LIMIT: int = _env_int("RATE_LIMIT_CHAT_MSGS", "50")
AUTO_RUNS_LIMIT: int = _env_int("RATE_LIMIT_AUTO_RUNS", "5")
TOTAL_API_CALLS_LIMIT: int = _env_int("RATE_LIMIT_TOTAL_API_CALLS", "200")

# Each auto simulation consumes at most (2 agents x MAX_AUTO_INFERENCES) API calls,
# so this many slots are reserved upfront in the total_calls counter per auto run.
_AUTO_RUN_CALL_RESERVATION: int = 20
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Client identifier extraction
47
+ # ---------------------------------------------------------------------------
48
+
49
def get_client_key(request: gr.Request | None) -> str:
    """
    Derive a stable identifier string for the calling client.

    Resolution order (tailored to HuggingFace Spaces):

    1. HF OAuth username          -> ``"user:<name>"``
    2. ``X-Forwarded-For`` header -> ``"ip:<first hop>"``
    3. ``X-Real-IP`` header       -> ``"ip:<addr>"``
    4. Direct connection host     -> ``"ip:<host>"``

    Parameters
    ----------
    request:
        The :class:`gradio.Request` object injected by Gradio into event
        handler functions.

    Returns
    -------
    str
        A non-empty identifier; ``"unknown"`` when nothing can be extracted.
    """
    if request is None:
        return "unknown"

    # Authenticated HF users (Space with OAuth enabled) get a username key.
    username = getattr(request, "username", None)
    if username:
        return f"user:{username}"

    # Best-effort header extraction; keys lowercased for uniform lookup.
    headers: dict = {}
    request_headers = getattr(request, "headers", None)
    if request_headers:
        try:
            headers = {name.lower(): value for name, value in dict(request_headers).items()}
        except Exception:
            pass

    # Proxy / CDN chain — the leftmost X-Forwarded-For entry is the client.
    forwarded = headers.get("x-forwarded-for", "")
    if forwarded:
        first_hop = forwarded.split(",")[0].strip()
        if first_hop:
            return f"ip:{first_hop}"

    # Some reverse proxies (nginx, etc.) set X-Real-IP instead.
    real_ip = headers.get("x-real-ip", "")
    if real_ip:
        return f"ip:{real_ip.strip()}"

    # Fall back to the direct peer address (only reliable without a proxy).
    peer = getattr(request, "client", None)
    if peer and getattr(peer, "host", None):
        return f"ip:{peer.host}"

    return "unknown"
102
+
103
+
104
+ # ---------------------------------------------------------------------------
105
+ # Rate limiter
106
+ # ---------------------------------------------------------------------------
107
+
108
class RateLimiter:
    """
    Thread-safe hard-cap rate limiter keyed by client identifier.

    Counters are cumulative totals with no time window — once a limit is
    reached the client is permanently blocked for that action until the
    process restarts.

    Tracks four independent counters per key:

    * **sim_starts** — calls to ``start_simulation()``
    * **chat_msgs** — individual chat messages (1 LLM call each)
    * **auto_runs** — auto simulation runs (each reserved as
      ``_AUTO_RUN_CALL_RESERVATION`` LLM calls in ``total_calls``)
    * **total_calls** — aggregate LLM API calls across all modes

    Example
    -------
    >>> limiter = RateLimiter()
    >>> allowed, msg = limiter.check_simulation_start("ip:1.2.3.4")
    >>> if not allowed:
    ...     raise gr.Error(msg)
    """

    def __init__(self) -> None:
        # A single lock guards all four stores so cross-counter updates
        # (e.g. chat_msgs + total_calls) never interleave between threads.
        self._lock = threading.Lock()
        self._sim_starts: Dict[str, int] = defaultdict(int)
        self._chat_msgs: Dict[str, int] = defaultdict(int)
        self._auto_runs: Dict[str, int] = defaultdict(int)
        self._total_calls: Dict[str, int] = defaultdict(int)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _increment(
        self,
        store: Dict[str, int],
        key: str,
        limit: int,
        *,
        n: int = 1,
    ) -> Tuple[bool, int]:
        """
        Increment counter by *n* and check whether the new total exceeds *limit*.

        Returns
        -------
        (allowed, new_count)
        """
        with self._lock:
            store[key] += n
            count = store[key]
        return count <= limit, count

    def _decrement(self, store: Dict[str, int], key: str, n: int = 1) -> None:
        """Roll back a previous increment (used when a subsequent check fails)."""
        with self._lock:
            # Clamp at zero so a stray double-rollback can never go negative.
            store[key] = max(0, store[key] - n)

    # ------------------------------------------------------------------
    # Public check methods
    # ------------------------------------------------------------------

    def check_simulation_start(self, key: str) -> Tuple[bool, str]:
        """
        Check whether a new simulation setup is allowed.

        Called once when the user clicks **Start Simulation**.

        Parameters
        ----------
        key:
            Client identifier returned by :func:`get_client_key`.

        Returns
        -------
        (True, "") — allowed
        (False, human-readable message) — denied
        """
        allowed, count = self._increment(self._sim_starts, key, SIM_STARTS_LIMIT)
        if not allowed:
            return False, (
                f"Simulation setup limit reached "
                f"(maximum {SIM_STARTS_LIMIT} simulations per session)."
            )
        return True, ""

    def check_chat_message(self, key: str) -> Tuple[bool, str]:
        """
        Check whether sending a chat message is allowed (= 1 LLM API call).

        Increments both ``chat_msgs`` and ``total_calls``; the ``chat_msgs``
        increment is rolled back when the total-call budget is exhausted so
        the two counters stay consistent.

        Parameters
        ----------
        key:
            Client identifier returned by :func:`get_client_key`.

        Returns
        -------
        (True, "") — allowed
        (False, human-readable message) — denied
        """
        allowed_msg, _ = self._increment(self._chat_msgs, key, CHAT_MSGS_LIMIT)
        if not allowed_msg:
            return False, (
                f"Chat message limit reached "
                f"(maximum {CHAT_MSGS_LIMIT} messages per session)."
            )

        allowed_total, _ = self._increment(self._total_calls, key, TOTAL_API_CALLS_LIMIT)
        if not allowed_total:
            self._decrement(self._chat_msgs, key)
            return False, (
                f"Total API call limit reached "
                f"(maximum {TOTAL_API_CALLS_LIMIT} API calls per session)."
            )
        return True, ""

    def check_auto_run(self, key: str) -> Tuple[bool, str]:
        """
        Check whether starting an auto simulation is allowed.

        Reserves ``_AUTO_RUN_CALL_RESERVATION`` slots in the ``total_calls``
        counter upfront because each auto run may issue up to that many LLM
        calls before it finishes. The ``auto_runs`` increment is rolled back
        when the reservation does not fit in the total-call budget.

        Parameters
        ----------
        key:
            Client identifier returned by :func:`get_client_key`.

        Returns
        -------
        (True, "") — allowed
        (False, human-readable message) — denied
        """
        allowed_run, _ = self._increment(self._auto_runs, key, AUTO_RUNS_LIMIT)
        if not allowed_run:
            return False, (
                f"Auto simulation limit reached "
                f"(maximum {AUTO_RUNS_LIMIT} auto runs per session)."
            )

        allowed_total, _ = self._increment(
            self._total_calls, key, TOTAL_API_CALLS_LIMIT,
            n=_AUTO_RUN_CALL_RESERVATION,
        )
        if not allowed_total:
            self._decrement(self._auto_runs, key)
            return False, (
                f"Total API call limit reached "
                f"(maximum {TOTAL_API_CALLS_LIMIT} API calls per session)."
            )
        return True, ""

    # ------------------------------------------------------------------
    # Diagnostic
    # ------------------------------------------------------------------

    def status(self, key: str) -> dict:
        """
        Return current counter snapshots for *key*.

        Useful for debugging or exposing quota information in the UI.

        Returns
        -------
        dict with keys ``sim_starts``, ``chat_messages``, ``auto_runs``,
        ``total_api_calls``; each value is a dict with ``used`` and ``limit``.
        """
        with self._lock:
            # Use .get() rather than indexing: __getitem__ on a defaultdict
            # INSERTS a zero entry for unseen keys, so a read-only status
            # query would otherwise mutate state and grow the stores
            # without bound.
            return {
                "sim_starts": {"used": self._sim_starts.get(key, 0), "limit": SIM_STARTS_LIMIT},
                "chat_messages": {"used": self._chat_msgs.get(key, 0), "limit": CHAT_MSGS_LIMIT},
                "auto_runs": {"used": self._auto_runs.get(key, 0), "limit": AUTO_RUNS_LIMIT},
                "total_api_calls": {"used": self._total_calls.get(key, 0), "limit": TOTAL_API_CALLS_LIMIT},
            }