dek924 commited on
Commit
5f9fc5a
Β·
1 Parent(s): e42bc71

feat: add user api input

Browse files
Files changed (1) hide show
  1. app.py +52 -24
app.py CHANGED
@@ -540,6 +540,7 @@ def start_simulation(
540
  personality: str,
541
  recall: str,
542
  confusion: str,
 
543
  request: gr.Request = None,
544
  ):
545
  if not hadm_id:
@@ -548,14 +549,20 @@ def start_simulation(
548
  if model not in BACKEND_MODELS:
549
  return _setup_error("Invalid model selection.")
550
 
551
- client_key = get_client_key(request)
552
- allowed, limit_msg = _rate_limiter.check_simulation_start(client_key)
553
- if not allowed:
554
- return _setup_error(limit_msg)
 
 
 
 
555
 
556
  is_openai = "gpt" in model.lower()
557
 
558
- if is_openai:
 
 
559
  api_key = os.environ.get("OPENAI_API_KEY", "")
560
  else:
561
  api_key = os.environ.get("GENAI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "")
@@ -597,6 +604,7 @@ def start_simulation(
597
  "patient": patient,
598
  "model": model,
599
  "recap_html": recap,
 
600
  }
601
 
602
  return (
@@ -660,7 +668,7 @@ _INJECTION_PATTERNS = re.compile(
660
  )
661
 
662
 
663
- def chat(message: str, history: list, agent, request: gr.Request = None):
664
  if agent is None:
665
  raise gr.Error("No simulation running. Please start a simulation first.")
666
  if not message.strip():
@@ -676,10 +684,12 @@ def chat(message: str, history: list, agent, request: gr.Request = None):
676
  _logger.warning("Prompt injection attempt detected from key=%s", get_client_key(request))
677
  raise gr.Error("Invalid input detected. Please enter a valid clinical question.")
678
 
679
- client_key = get_client_key(request)
680
- allowed, limit_msg = _rate_limiter.check_chat_message(client_key)
681
- if not allowed:
682
- raise gr.Error(limit_msg)
 
 
683
 
684
  response = agent(user_prompt=message, using_multi_turn=True, verbose=False)
685
  history = history + [
@@ -736,11 +746,13 @@ def start_auto(agent, sim_config: dict, request: gr.Request = None):
736
  yield _auto_fallback_outputs()
737
  return
738
 
739
- allowed, limit_msg = _rate_limiter.check_auto_run(client_key)
740
- if not allowed:
741
- gr.Warning(limit_msg)
742
- yield _auto_fallback_outputs()
743
- return
 
 
744
 
745
  try:
746
  agent.reset_history(verbose=False)
@@ -758,7 +770,9 @@ def start_auto(agent, sim_config: dict, request: gr.Request = None):
758
 
759
  model = sim_config["model"]
760
  is_openai = "gpt" in model.lower()
761
- if is_openai:
 
 
762
  api_key = os.environ.get("OPENAI_API_KEY", "")
763
  else:
764
  api_key = os.environ.get("GENAI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "")
@@ -914,12 +928,26 @@ with gr.Blocks(title="PatientSim", theme=gr.themes.Soft(), css=CUSTOM_CSS) as de
914
 
915
  # ── Connection card ──────────────────────────────────────────────────
916
  with gr.Group(elem_classes=["form-card"]):
917
- # gr.HTML("<span class='card-title'>Model Selection</span>")
918
- model_dd = gr.Dropdown(
919
- choices=BACKEND_MODELS,
920
- value=BACKEND_MODELS[0],
921
- label="Model Selection",
 
922
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
923
 
924
  # ── Patient Case card (one gr.Button per patient) ────────────────────
925
  with gr.Group(elem_classes=["form-card"]):
@@ -1165,7 +1193,7 @@ with gr.Blocks(title="PatientSim", theme=gr.themes.Soft(), css=CUSTOM_CSS) as de
1165
  # Start simulation β†’ mode selection
1166
  start_btn.click(
1167
  fn=start_simulation,
1168
- inputs=[selected_patient_state, model_dd, cefr_radio, personality_dd, recall_radio, confusion_radio],
1169
  outputs=[patient_agent_state, sim_config_state, setup_section, mode_section, recap_display, setup_error_display],
1170
  )
1171
 
@@ -1203,12 +1231,12 @@ with gr.Blocks(title="PatientSim", theme=gr.themes.Soft(), css=CUSTOM_CSS) as de
1203
  )
1204
  chat_event_send = send_btn.click(
1205
  fn=chat,
1206
- inputs=[msg_box, chatbot, patient_agent_state],
1207
  outputs=[chatbot, msg_box],
1208
  )
1209
  chat_event_submit = msg_box.submit(
1210
  fn=chat,
1211
- inputs=[msg_box, chatbot, patient_agent_state],
1212
  outputs=[chatbot, msg_box],
1213
  )
1214
  back_from_chat_to_mode_btn.click(
 
540
  personality: str,
541
  recall: str,
542
  confusion: str,
543
+ user_api_key: str = "",
544
  request: gr.Request = None,
545
  ):
546
  if not hadm_id:
 
549
  if model not in BACKEND_MODELS:
550
  return _setup_error("Invalid model selection.")
551
 
552
+ using_own_key = bool(user_api_key.strip())
553
+
554
+ # Only apply rate limiting when using the shared demo key
555
+ if not using_own_key:
556
+ client_key = get_client_key(request)
557
+ allowed, limit_msg = _rate_limiter.check_simulation_start(client_key)
558
+ if not allowed:
559
+ return _setup_error(limit_msg)
560
 
561
  is_openai = "gpt" in model.lower()
562
 
563
+ if using_own_key:
564
+ api_key = user_api_key.strip()
565
+ elif is_openai:
566
  api_key = os.environ.get("OPENAI_API_KEY", "")
567
  else:
568
  api_key = os.environ.get("GENAI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "")
 
604
  "patient": patient,
605
  "model": model,
606
  "recap_html": recap,
607
+ "user_api_key": user_api_key.strip(), # empty string = using shared demo key
608
  }
609
 
610
  return (
 
668
  )
669
 
670
 
671
+ def chat(message: str, history: list, agent, sim_config: dict, request: gr.Request = None):
672
  if agent is None:
673
  raise gr.Error("No simulation running. Please start a simulation first.")
674
  if not message.strip():
 
684
  _logger.warning("Prompt injection attempt detected from key=%s", get_client_key(request))
685
  raise gr.Error("Invalid input detected. Please enter a valid clinical question.")
686
 
687
+ using_own_key = bool(sim_config and sim_config.get("user_api_key"))
688
+ if not using_own_key:
689
+ client_key = get_client_key(request)
690
+ allowed, limit_msg = _rate_limiter.check_chat_message(client_key)
691
+ if not allowed:
692
+ raise gr.Error(limit_msg)
693
 
694
  response = agent(user_prompt=message, using_multi_turn=True, verbose=False)
695
  history = history + [
 
746
  yield _auto_fallback_outputs()
747
  return
748
 
749
+ using_own_key = bool(sim_config.get("user_api_key"))
750
+ if not using_own_key:
751
+ allowed, limit_msg = _rate_limiter.check_auto_run(client_key)
752
+ if not allowed:
753
+ gr.Warning(limit_msg)
754
+ yield _auto_fallback_outputs()
755
+ return
756
 
757
  try:
758
  agent.reset_history(verbose=False)
 
770
 
771
  model = sim_config["model"]
772
  is_openai = "gpt" in model.lower()
773
+ if using_own_key:
774
+ api_key = sim_config["user_api_key"]
775
+ elif is_openai:
776
  api_key = os.environ.get("OPENAI_API_KEY", "")
777
  else:
778
  api_key = os.environ.get("GENAI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "")
 
928
 
929
  # ── Connection card ──────────────────────────────────────────────────
930
  with gr.Group(elem_classes=["form-card"]):
931
+ gr.Markdown(
932
+ "**πŸ”‘ Model & API Key**\n\n"
933
+ "This demo runs on a shared API key with a limited number of free calls. "
934
+ "If the free quota has been exhausted, please enter your own API key below "
935
+ "(OpenAI or Google Gemini) to continue without restrictions. "
936
+ "Your key is used only for this session and is never stored on our servers."
937
  )
938
+ with gr.Row(equal_height=True):
939
+ model_dd = gr.Dropdown(
940
+ choices=BACKEND_MODELS,
941
+ value=BACKEND_MODELS[0],
942
+ label="Model",
943
+ scale=1,
944
+ )
945
+ api_key_input = gr.Textbox(
946
+ label="API Key (optional)",
947
+ placeholder="Leave blank to use the shared demo key Β· sk-... or paste your Gemini key",
948
+ type="password",
949
+ scale=2,
950
+ )
951
 
952
  # ── Patient Case card (one gr.Button per patient) ────────────────────
953
  with gr.Group(elem_classes=["form-card"]):
 
1193
  # Start simulation β†’ mode selection
1194
  start_btn.click(
1195
  fn=start_simulation,
1196
+ inputs=[selected_patient_state, model_dd, cefr_radio, personality_dd, recall_radio, confusion_radio, api_key_input],
1197
  outputs=[patient_agent_state, sim_config_state, setup_section, mode_section, recap_display, setup_error_display],
1198
  )
1199
 
 
1231
  )
1232
  chat_event_send = send_btn.click(
1233
  fn=chat,
1234
+ inputs=[msg_box, chatbot, patient_agent_state, sim_config_state],
1235
  outputs=[chatbot, msg_box],
1236
  )
1237
  chat_event_submit = msg_box.submit(
1238
  fn=chat,
1239
+ inputs=[msg_box, chatbot, patient_agent_state, sim_config_state],
1240
  outputs=[chatbot, msg_box],
1241
  )
1242
  back_from_chat_to_mode_btn.click(