import os import re import html import json import copy import logging import gradio as gr from dotenv import load_dotenv, find_dotenv from patientsim import PatientAgent, DoctorAgent from patientsim.utils.common_utils import detect_ed_termination from rate_limiter import RateLimiter, get_client_key load_dotenv(find_dotenv(usecwd=True), override=False) logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s [%(name)s] %(message)s", handlers=[logging.StreamHandler()], ) _logger = logging.getLogger("patientsim.app") # --------------------------------------------------------------------------- # Rate limiter (singleton — shared across all Gradio worker threads) # --------------------------------------------------------------------------- _rate_limiter = RateLimiter() # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- CEFR_CHOICES = [ ("A — Beginner", "A"), ("B — Intermediate", "B"), ("C — Advanced", "C"), ] PERSONALITY_CHOICES = [ ("Neutral", "plain"), ("Talkative", "verbose"), ("Distrustful", "distrust"), ("Pleasing", "pleasing"), ("Impatient", "impatient"), ("Overanxious", "overanxious"), ] RECALL_CHOICES = [ ("Low", "low"), ("High", "high"), ] CONFUSION_CHOICES = [ ("Normal", "normal"), ("High", "high"), ] BACKEND_MODELS = [ "gemini-3.1-flash-lite-preview", "gemini-3-flash-preview", "gemini-2.5-flash", "gpt-5.4-nano", "gpt-5.4-mini", "gpt-5.4", ] MAX_AUTO_INFERENCES = 10 MAX_MESSAGE_CHARS = 2000 # --------------------------------------------------------------------------- # Patient data # --------------------------------------------------------------------------- _DATA_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "demo", "data", "demo_data.json") try: with open(_DATA_PATH) as _f: PATIENT_DATA: list[dict] = json.load(_f) except (FileNotFoundError, json.JSONDecodeError) as _e: raise RuntimeError(f"Failed to load patient data from {_DATA_PATH}: {_e}") from _e PATIENT_DICT: dict[str, dict] = {p["hadm_id"]: p for p in PATIENT_DATA} _DOCTOR_AVATAR = "https://cdn-icons-png.flaticon.com/512/3774/3774299.png" # Per-patient avatar URLs matched to demographics (gender + age group) _PATIENT_AVATAR_URLS = { "patient_MI": "https://cdn-icons-png.flaticon.com/512/5488/5488324.png", # 64yo Male — elderly man "patient_PNA": "https://cdn-icons-png.flaticon.com/512/4140/4140051.png", # 47yo Female — middle-aged woman "patient_UTI": "https://cdn-icons-png.flaticon.com/512/4140/4140047.png", # 29yo Female — young woman } _SORTED_PATIENTS: list[dict] = sorted(PATIENT_DATA, key=lambda x: x["hadm_id"]) def _sorted_patients() -> list[dict]: return _SORTED_PATIENTS def _build_single_card_html(p: dict, selected: bool, avatar_url: str) -> str: """Build HTML for a single patient card (no select button — handled by gr.Button).""" border_color = "#3b82f6" if selected else "#e5e7eb" bg_color = "#eff6ff" if selected else "#ffffff" shadow = "0 0 0 2px #3b82f6" if selected else "0 1px 3px rgba(0,0,0,0.06)" age = html.escape(str(p.get("age", "?"))) gender = html.escape(str(p.get("gender", "Unknown"))) diagnosis = html.escape(str(p.get("diagnosis", "Unknown"))) chief = html.escape(str(p.get("chiefcomplaint", "—"))) transport = html.escape(str(p.get("arrival_transport", "—"))) safe_avatar = html.escape(avatar_url) return ( f"
" f"
" f"Patient" f"
" f"
" f"
Age: {age} · Gender: {gender}
" f"
Chief Complaint: {chief}
" f"
Transport: {transport}
" f"
Dx: {diagnosis}
" f"
" f"
" ) # --------------------------------------------------------------------------- # HTML helpers # --------------------------------------------------------------------------- def build_recap_html(hadm_id: str, model: str, cefr: str, personality: str, recall: str, confusion: str) -> str: patient = PATIENT_DICT.get(hadm_id, {}) patient_label = ( f"Age {patient.get('age')} · {patient.get('gender')} · {patient.get('diagnosis', 'Unknown')}" if patient else "—" ) personality_label = next((l for l, v in PERSONALITY_CHOICES if v == personality), personality) cefr_label = next((l for l, v in CEFR_CHOICES if v == cefr), cefr) recall_label = next((l for l, v in RECALL_CHOICES if v == recall), recall) confusion_label = next((l for l, v in CONFUSION_CHOICES if v == confusion), confusion) def _card(label, value): safe_label = html.escape(str(label)) safe_value = html.escape(str(value)) return ( "
" f"
{safe_label}
" f"
{safe_value}
" "
" ) items = [ _card("Patient", patient_label), _card("Personality", personality_label), _card("Model", model), _card("Language Proficiency", cefr_label), _card("Medical History Recall", recall_label), _card("Cognitive Confusion", confusion_label), ] grid_items = "".join(items) return ( "
" "
📋 Simulation Configuration
" f"
{grid_items}
" "
" ) # --------------------------------------------------------------------------- # Profile HTML helpers (card-based layout matching recap style) # --------------------------------------------------------------------------- def _profile_item(label: str, val) -> str: """Build a single key-value item inside a profile card.""" safe_val = html.escape(str(val)) if val not in (None, "") else "N/A" safe_label = html.escape(str(label)) return ( f"
" f"{safe_label}" f"{safe_val}" f"
" ) def _profile_card(icon: str, title: str, items_html: str, accent: str = "#3b82f6") -> str: """Build a styled card section for the profile panel.""" return ( f"
" f"
" f"{icon}" f"{html.escape(title)}" f"
" f"{items_html}" f"
" ) def build_profile_html(p: dict) -> str: hadm_id = p.get("hadm_id", "") avatar_url = _PATIENT_AVATAR_URLS.get(hadm_id, list(_PATIENT_AVATAR_URLS.values())[0]) safe_avatar = html.escape(avatar_url) # Header with avatar header = ( f"
" f"Patient" f"
Patient Profile
" f"
" ) basic = _profile_card("👤", "Demographics", _profile_item("Age", p.get("age")) + _profile_item("Gender", p.get("gender")) + _profile_item("Race", p.get("race")) + _profile_item("Transport", p.get("arrival_transport")), accent="#3b82f6", ) social = _profile_card("🏠", "Social History", _profile_item("Tobacco", p.get("tobacco")) + _profile_item("Alcohol", p.get("alcohol")) + _profile_item("Illicit Drug", p.get("illicit_drug")) + _profile_item("Exercise", p.get("exercise")) + _profile_item("Marital Status", p.get("marital_status")) + _profile_item("Children", p.get("children")) + _profile_item("Living Situation", p.get("living_situation")) + _profile_item("Occupation", p.get("occupation")) + _profile_item("Insurance", p.get("insurance")), accent="#10b981", ) history = _profile_card("📋", "Previous Medical History", _profile_item("Allergies", p.get("allergies")) + _profile_item("Family History", p.get("family_medical_history")) + _profile_item("Medical Devices", p.get("medical_device")) + _profile_item("Prior History", p.get("medical_history")), accent="#f59e0b", ) visit = _profile_card("🩺", "Current Visit", _profile_item("Present Illness (+)", p.get("present_illness_positive")) + _profile_item("Present Illness (−)", p.get("present_illness_negative")) + _profile_item("Chief Complaint", p.get("chiefcomplaint")) + _profile_item("Pain (0–10)", p.get("pain")) + _profile_item("Medications", p.get("medication")) + _profile_item("Disposition", p.get("disposition")) + _profile_item("Diagnosis", p.get("diagnosis")), accent="#ef4444", ) return ( f"
" f"{header}{basic}{social}{history}{visit}" f"
" ) def build_blind_profile_html(p: dict) -> str: """Show only basic demographic info for practice mode without full case details.""" hadm_id = p.get("hadm_id", "") avatar_url = _PATIENT_AVATAR_URLS.get(hadm_id, list(_PATIENT_AVATAR_URLS.values())[0]) safe_avatar = html.escape(avatar_url) header = ( f"
" f"Patient" f"
Patient Info
" f"
" f"Basic demographics only — gather the rest through consultation.
" f"
" ) basic = _profile_card("👤", "Demographics", _profile_item("Age", p.get("age")) + _profile_item("Gender", p.get("gender")) + _profile_item("Race", p.get("race")) + _profile_item("Transport", p.get("arrival_transport")), accent="#3b82f6", ) hint = ( "
" "🔍 Additional information is hidden.
" "Interview the patient to uncover their history." "
" ) return ( f"
" f"{header}{basic}{hint}" f"
" ) # --------------------------------------------------------------------------- # Custom CSS # --------------------------------------------------------------------------- CUSTOM_CSS = """ /* ── Global gothic (sans-serif) font ───────────────────────────── */ @import url('https://fonts.googleapis.com/css2?family=Noto+Sans+KR:wght@300;400;500;600;700&family=Noto+Sans:wght@300;400;500;600;700&display=swap'); *, *::before, *::after { font-family: 'Noto Sans KR', 'Noto Sans', 'Malgun Gothic', 'Apple SD Gothic Neo', 'Segoe UI', sans-serif !important; } .prose h1, .prose h2, .prose h3, .prose h4, .markdown-text h1, .markdown-text h2, .markdown-text h3, .markdown-text h4 { font-family: 'Noto Sans KR', 'Noto Sans', 'Malgun Gothic', 'Apple SD Gothic Neo', 'Segoe UI', sans-serif !important; } /* ── Page background ────────────────────────────────────────────── */ body, .gradio-container, .main, footer { background-color: #f3f4f6 !important; } /* ── All form labels — no blue/accent background ────────────────── */ label span, label > span, .block label span, .label-wrap span, .gradio-container label span, .gradio-dropdown label span, .wrap > label > span, [class*="label"] > span { background: transparent !important; background-color: transparent !important; } /* ── White shadow cards for form sections ───────────────────────── */ .form-card { background: #ffffff !important; border-radius: 14px !important; box-shadow: 0 1px 3px rgba(0,0,0,0.08), 0 0 0 1px rgba(0,0,0,0.05) !important; border: none !important; padding: 16px 20px 20px !important; } .form-card > .gap, .form-card > div { background: transparent !important; border: none !important; gap: 12px !important; padding: 8px !important; } /* ── Card section title — vertically centered with background ──── */ .card-title { font-size: 18px; font-weight: 600; color: #374151; padding: 6px 0; margin-bottom: 4px; border-bottom: 1px solid #e5e7eb; display: flex; align-items: center; justify-content: center; min-height: 32px; line-height: 1; } /* ── Tooltip descriptions — plain text inside the white cell ─────── */ .option-desc { font-size: 15px; color: #4b5563; margin-top: 8px; line-height: 1.4; } /* ── Transparent HTML tip containers ────────────────────────────── */ .tip-html, .tip-html > div, .tip-html > .prose { background: transparent !important; border: none !important; padding: 0 !important; margin: 0 !important; } /* ── CEFR radio & Persona grid — thin border on buttons ─────────── */ .compact-radio .wrap { gap: 6px !important; background: transparent !important; } .compact-radio label { padding: 5px 14px !important; border-radius: 8px !important; font-size: 13px !important; min-width: unset !important; background: transparent !important; border: 1px solid #d1d5db !important; } .compact-radio label:has(input:checked) { border-color: #3b82f6 !important; background: #eff6ff !important; } .persona-grid { display: grid !important; grid-template-columns: 1fr 1fr !important; gap: 16px !important; } /* ── Persona cell — white box containing buttons + description ───── */ .persona-cell { background: #ffffff !important; border: 1px solid #e5e7eb !important; border-radius: 10px !important; padding: 14px 16px !important; display: flex !important; flex-direction: column !important; } /* Force options to use black bold font without background */ .persona-cell label span, .persona-cell span { color: black !important; font-weight: bold !important; background: transparent !important; } /* ── Personality radio — 2-column grid for 6 choices ────────────── */ .personality-radio .wrap { display: grid !important; grid-template-columns: 1fr 1fr !important; gap: 6px !important; } .persona-cell .compact-radio { margin-bottom: 4px !important; } /* ── Mode cards ─────────────────────────────────────────────────── */ .mode-card { background: #ffffff !important; border: 1px solid #e5e7eb !important; border-radius: 14px !important; box-shadow: 0 1px 3px rgba(0,0,0,0.06) !important; padding: 20px !important; } /* ── Start simulation button shadow ─────────────────────────────── */ #start-btn > button { box-shadow: 0 4px 14px rgba(59, 130, 246, 0.30) !important; font-size: 15px !important; } /* ── Patient card HTML wrapper — remove Gradio's default padding ── */ .patient-card-html > div, .patient-card-html .prose { padding: 10 !important; margin: 0 !important; border: none !important; background: #ffffff !important; } /* ── Patient profile panel — strip wrapper chrome ───────────────── */ .profile-display, .profile-display > div, .profile-display .prose { background: transparent !important; border: none !important; padding: 0 !important; margin: 0 !important; } /* White background + padding for each patient column cell */ .form-card .gradio-column { background: #ffffff !important; border-radius: 10px !important; border: 1px solid #e5e7eb !important; padding: 12px !important; gap: 8px !important; } /* Spacing between columns and from the card border */ .form-card .gradio-row { gap: 12px !important; padding: 12px 12px !important; background: transparent !important; } /* Fallback using data-testid if gradio-column class doesn't apply */ .form-card [data-testid="column"] { background: #ffffff !important; border-radius: 10px !important; border: 1px solid #e5e7eb !important; padding: 12px !important; } .form-card [data-testid="row"] { gap: 24px !important; padding: 12px 8px !important; } .patient-row { gap: 16px !important; margin-bottom: 16px !important; padding: 0 24px !important; } /* Gradio 6 wraps row flex content in an inner div — target it directly */ .patient-row > div { gap: flex !important; display: flex !important; flex-wrap: wrap !important; align-items: stretch !important; } .patient-card-column { flex: 1 1 calc(33.333% - 12px) !important; flex-direction: column !important; min-width: 100px !important; max-width: calc(33.333% - 12px) !important; display: flex !important; height: 100% !important; gap: 0 !important; } .patient-card-html { flex: 1 !important; height: 100% !important; } """ # Tooltip maps for descriptions shown below each radio group PERSONALITY_TIPS = { "plain": "No strong emotions or noticeable behavior.", "verbose": "Speaks a lot, gives highly detailed responses.", "distrust": "Questions the doctor's expertise and care.", "pleasing": "Overly positive, tends to minimize problems.", "impatient": "Easily irritated and lacks patience.", "overanxious": "Expresses concern beyond what is typical.", } RECALL_TIPS = { "low": "Often forgets even major medical events.", "high": "Usually recalls medical events accurately.", } CONFUSION_TIPS = { "normal": "Clear mental status.", "high": "Highly dazed and extremely confused.", } CEFR_TIPS = { "A": "Can make simple sentences.", "B": "Can have daily conversations.", "C": "Can freely use advanced medical terms.", } def _make_tip_html(tips: dict, selected: str) -> str: desc = tips.get(selected, "") return f"

{desc}

" if desc else "" # --------------------------------------------------------------------------- # Callbacks — Setup # --------------------------------------------------------------------------- def go_to_setup(): return gr.update(visible=False), gr.update(visible=True) def _sanitize_error(msg: str) -> str: """Remove API key patterns and internal paths from error messages.""" msg = re.sub(r'sk-[A-Za-z0-9_-]{10,}', '[REDACTED]', msg) msg = re.sub(r'AIza[A-Za-z0-9_-]{20,}', '[REDACTED]', msg) msg = re.sub(r'AQ\.[A-Za-z0-9_-]{10,}', '[REDACTED]', msg) msg = re.sub(r'/(?:home|data|tmp|var|usr|etc)/[^\s\'"]+', '[PATH_REDACTED]', msg) return msg def _setup_error(msg: str): """Return output tuple that keeps setup visible and shows an inline error.""" safe_msg = html.escape(_sanitize_error(msg)) error_html = ( f"
⚠️ {safe_msg}
" ) return ( gr.update(), # patient_agent_state — no change gr.update(), # sim_config_state — no change gr.update(), # setup_section — no change (already visible) gr.update(), # mode_section — no change (already hidden) gr.update(), # recap_display — no change gr.update(value=error_html, visible=True), # show setup_error_display ) def start_simulation( hadm_id: str, model: str, cefr: str, personality: str, recall: str, confusion: str, user_api_key: str = "", request: gr.Request = None, ): if not hadm_id: return _setup_error("Please select a patient first.") if model not in BACKEND_MODELS: return _setup_error("Invalid model selection.") using_own_key = bool(user_api_key.strip()) is_openai = "gpt" in model.lower() if using_own_key: api_key = user_api_key.strip() elif is_openai: api_key = os.environ.get("OPENAI_API_KEY", "") else: api_key = os.environ.get("GENAI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "") if not api_key: key_name = "OPENAI_API_KEY" if is_openai else "GENAI_API_KEY / GOOGLE_API_KEY" return _setup_error( f"API key not configured ({key_name}). " "Run with your own key, or contact the demo administrator." ) patient = PATIENT_DICT.get(hadm_id) if patient is None: return _setup_error("Invalid patient selection. Please select a valid patient from the list.") patient = copy.deepcopy(patient) try: agent = PatientAgent( model=model, visit_type="emergency_department", personality=personality, recall_level=recall, confusion_level=confusion, lang_proficiency_level=cefr, api_key=api_key, temperature=0.7, num_word_sample=10, random_seed=42, log_verbose=False, use_vertex=True, # Force using Vertex AI for all models to ensure consistent behavior and better error handling **patient, ) except Exception as e: _logger.error("Failed to initialize patient agent: %s", _sanitize_error(str(e))) return _setup_error(f"Failed to initialize patient agent: {_sanitize_error(str(e))}") recap = build_recap_html(hadm_id, model, cefr, personality, recall, confusion) sim_config = { "patient": patient, "model": model, "recap_html": recap, "user_api_key": user_api_key.strip(), # empty string = using shared demo key } return ( agent, sim_config, gr.update(visible=False), # hide setup_section gr.update(visible=True), # show mode_section gr.update(value=recap), # recap_display gr.update(visible=False), # hide setup_error_display ) def back_to_setup(agent): if agent is not None: agent.reset_history(verbose=False) return ( None, # clear patient_agent_state None, # clear sim_config_state gr.update(visible=True), # show setup_section gr.update(visible=False), # hide mode_section gr.update(visible=False), # hide chat_section gr.update(visible=False), # hide auto_section gr.update(value=""), # clear recap_display gr.update(value="", visible=False), # clear setup_error_display ) # --------------------------------------------------------------------------- # Callbacks — Manual practice mode # --------------------------------------------------------------------------- def start_manual(profile_mode: str, agent, sim_config: dict): if agent is None or sim_config is None: raise gr.Warning("Session expired. Please restart.") agent.reset_history(verbose=False) patient = sim_config["patient"] profile_html = ( build_profile_html(patient) if profile_mode == "full" else build_blind_profile_html(patient) ) hadm_id = sim_config["patient"].get("hadm_id", "") patient_avatar = _PATIENT_AVATAR_URLS.get(hadm_id, list(_PATIENT_AVATAR_URLS.values())[0]) return ( profile_html, sim_config.get("recap_html", ""), # show recap in left panel gr.update(value=[], avatar_images=(_DOCTOR_AVATAR, patient_avatar)), # chatbot with matched avatar [], # clear chat log gr.update(visible=False), # hide mode_section gr.update(visible=False), # hide auto_section gr.update(visible=True), # show chat_section ) _INJECTION_PATTERNS = re.compile( r'ignore\s+(?:\w+\s+){0,3}(instructions?|prompts?|rules?)' r'|(?:forget|disregard)\s+(?:\w+\s+){0,3}instructions?' r'|you\s+are\s+(?:now|actually)\s+(?:a\s+|an\s+)?(?:DAN|jailbreak)' r'|reveal\s+(?:\w+\s+){0,2}(system\s+)?prompt' r'|act\s+as\s+(?:if\s+)?you\s+(?:have\s+no|are\s+without)\s+(?:restrictions?|guidelines?)', re.IGNORECASE, ) def chat(message: str, history: list, agent, sim_config: dict, request: gr.Request = None): if agent is None: raise gr.Error("No simulation running. Please start a simulation first.") if not message.strip(): return history, "" if len(message) > MAX_MESSAGE_CHARS: raise gr.Error( f"Message too long ({len(message)} characters). " f"Please keep messages under {MAX_MESSAGE_CHARS} characters." ) if _INJECTION_PATTERNS.search(message): _logger.warning("Prompt injection attempt detected from key=%s", get_client_key(request)) raise gr.Error("Invalid input detected. Please enter a valid clinical question.") using_own_key = bool(sim_config and sim_config.get("user_api_key")) client_key = get_client_key(request) if not using_own_key: allowed, limit_msg = _rate_limiter.check_chat_message(client_key) if not allowed: raise gr.Error(limit_msg) else: # Own-key users bypass per-IP quotas but still respect global capacity. allowed, limit_msg = _rate_limiter.check_global_capacity() if not allowed: raise gr.Error(limit_msg) response = agent(user_prompt=message, using_multi_turn=True, verbose=False) history = history + [ {"role": "user", "content": message}, {"role": "assistant", "content": response}, ] return history, history, "" def reset_chat(agent): if agent is not None: agent.reset_history(verbose=False) return gr.update(value=[]), [] def back_to_setup_from_chat(agent): if agent is not None: agent.reset_history(verbose=False) return ( [], # clear chat log "", # clear profile html "", # clear chat recap None, # clear patient_agent_state None, # clear sim_config_state gr.update(visible=True), # show setup_section gr.update(visible=False), # hide mode_section gr.update(visible=False), # hide auto_section gr.update(visible=False), # hide chat_section gr.update(value=""), # clear recap_display gr.update(value="", visible=False), # clear setup_error_display ) # --------------------------------------------------------------------------- # Callbacks — Auto simulation mode # --------------------------------------------------------------------------- def _auto_fallback_outputs(): """Return fresh gr.update() instances for error/exit paths in start_auto.""" return ( gr.update(), # auto_chatbot gr.update(visible=True), # mode_section gr.update(visible=False), # auto_section gr.update(visible=False), # chat_section gr.update(), # auto_recap ) def start_auto(agent, sim_config: dict, request: gr.Request = None): """Generator — yields chatbot updates turn-by-turn so the UI streams live.""" client_key = get_client_key(request) if agent is None or sim_config is None: gr.Warning("Session expired. Please restart.") yield _auto_fallback_outputs() return using_own_key = bool(sim_config.get("user_api_key")) if not using_own_key: allowed, limit_msg = _rate_limiter.check_auto_run(client_key) if not allowed: gr.Warning(limit_msg) yield _auto_fallback_outputs() return else: # Own-key users bypass per-IP quotas but still enforce the concurrent # run cap and the hard global capacity limit. allowed, limit_msg = _rate_limiter.check_own_key_auto_run(client_key) if not allowed: gr.Warning(limit_msg) yield _auto_fallback_outputs() return try: agent.reset_history(verbose=False) # Show auto_section immediately; set per-patient avatar on first yield _hadm_id = sim_config["patient"].get("hadm_id", "") _patient_avatar = _PATIENT_AVATAR_URLS.get(_hadm_id, list(_PATIENT_AVATAR_URLS.values())[0]) yield ( gr.update(value=[], avatar_images=(_DOCTOR_AVATAR, _patient_avatar)), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), sim_config.get("recap_html", ""), ) model = sim_config["model"] is_openai = "gpt" in model.lower() if using_own_key: api_key = sim_config["user_api_key"] elif is_openai: api_key = os.environ.get("OPENAI_API_KEY", "") else: api_key = os.environ.get("GENAI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "") try: doctor = DoctorAgent( model=model, api_key=api_key, temperature=0.2, random_seed=42, max_inferences=MAX_AUTO_INFERENCES, use_vertex=True, **{k: sim_config["patient"][k] for k in ("age", "gender", "arrival_transport") if k in sim_config["patient"]}, ) except Exception as e: _logger.error("Failed to initialize doctor agent: %s", _sanitize_error(str(e))) gr.Error(f"Failed to initialize doctor agent: {_sanitize_error(str(e))}") yield _auto_fallback_outputs() return # Switch to auto_section immediately with empty chatbot; set recap once chat_history = [] recap_html = sim_config.get("recap_html", "") yield chat_history, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), recap_html def _append(role: str, content: str): chat_history.append({"role": role, "content": content}) def _yield(): return chat_history, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update() try: # Doctor greets first doctor_greet = doctor.doctor_greet _append("user", doctor_greet) yield _yield() for inference_idx in range(MAX_AUTO_INFERENCES): is_last = inference_idx == MAX_AUTO_INFERENCES - 1 # Patient responds to the last doctor message patient_response = agent( user_prompt=chat_history[-1]["content"], using_multi_turn=True, verbose=False, ) _append("assistant", patient_response) yield _yield() # Doctor responds doctor_input = chat_history[-1]["content"] if is_last: doctor_input += "\nThis is the final turn. Now, you must provide your top 5 differential diagnoses." doctor_response = doctor( user_prompt=doctor_input, using_multi_turn=True, verbose=False, ) _append("user", doctor_response) yield _yield() if detect_ed_termination(doctor_response): break except Exception as e: gr.Error(f"Simulation error: {_sanitize_error(str(e))}") yield _yield() finally: _rate_limiter.release_auto_slot(client_key) def back_to_mode_from_auto(agent): if agent is not None: agent.reset_history(verbose=False) return ( [], # clear auto chatbot log "", # clear auto recap gr.update(visible=True), # show mode_section gr.update(visible=False), # hide auto_section ) def back_to_mode_from_chat(agent): if agent is not None: agent.reset_history(verbose=False) return ( [], # clear chat log "", # clear profile html "", # clear chat recap gr.update(visible=True), # show mode_section gr.update(visible=False), # hide chat_section ) def back_to_setup_from_auto(agent): if agent is not None: agent.reset_history(verbose=False) return ( [], # clear auto chatbot log "", # clear auto recap None, # clear patient_agent_state None, # clear sim_config_state gr.update(visible=True), # show setup_section gr.update(visible=False), # hide mode_section gr.update(visible=False), # hide auto_section gr.update(visible=False), # hide chat_section gr.update(value=""), # clear recap_display gr.update(value="", visible=False), # clear setup_error_display ) # --------------------------------------------------------------------------- # UI # --------------------------------------------------------------------------- with gr.Blocks(title="PatientSim", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo: patient_agent_state = gr.State(None) sim_config_state = gr.State(None) selected_patient_state = gr.State(None) chat_history_state = gr.State([]) # ── Intro section ──────────────────────────────────────────────────────── with gr.Column(visible=True) as intro_section: gr.Markdown( "# 🏥 PatientSim — ED Consultation Demo\n\n" "**PatientSim** is a research framework for simulating realistic emergency department " "doctor–patient interactions, presented at " "[NeurIPS 2025 (Datasets & Benchmarks)](https://openreview.net/forum?id=1THAjdP4QJ).\n\n" "Large language models act as patients with **controllable personas** — you choose their " "personality, language proficiency, medical history recall, and cognitive state — producing " "diverse and realistic consultation scenarios for training and evaluation.\n\n" "---\n\n" "### What you can do here\n\n" "| Mode | Description |\n" "|------|-------------|\n" "| 🤖 **Auto Simulation** | Watch an AI doctor conduct a full consultation with the simulated patient and arrive at a differential diagnosis — no input required. |\n" "| 🩺 **Practice Mode** | You play the doctor. Consult the simulated patient, gather medical history, and work toward a diagnosis yourself. |\n\n" "---\n\n" "### How to get started\n" "1. Click **Get Started** below.\n" "2. Enter your API key and select a patient case and persona.\n" "3. Choose a simulation mode and begin." ) get_started_btn = gr.Button("Get Started →", variant="primary", size="lg") # ── Setup section ──────────────────────────────────────────────────────── with gr.Column(visible=False) as setup_section: gr.Markdown( "# 🏥 PatientSim — Setup\n\n" "Configure your session below. Select the patient case you want to simulate, " "choose the AI model to power the patient, and define the patient's persona across " "four behavioral axes. When ready, click **Start Simulation**." ) # ── Connection card ────────────────────────────────────────────────── with gr.Group(elem_classes=["form-card"]): gr.Markdown( "**🔑 Model & API Key**\n\n" "This demo runs on a shared API key with a limited number of free calls. " "If the free quota has been exhausted, please enter your own API key below " "(OpenAI or Google Gemini) to continue without restrictions. " "Your key is used only for this session and is never stored on our servers." ) with gr.Row(equal_height=True): model_dd = gr.Dropdown( choices=BACKEND_MODELS, value=BACKEND_MODELS[0], label="Model", scale=1, ) api_key_input = gr.Textbox( label="API Key (optional)", placeholder="Leave blank to use the shared demo key · sk-... or paste your Gemini key", type="password", scale=2, ) # ── Patient Case card (one gr.Button per patient) ──────────────────── with gr.Group(elem_classes=["form-card"]): gr.HTML("🩺 Patient Case") _all_cards: list[tuple[str, str, gr.HTML]] = [] # (hadm_id, avatar_url, html_component) _all_select_btns: list[tuple[str, gr.Button]] = [] # (hadm_id, button) _patients = _sorted_patients() for _i in range(0, len(_patients), 3): with gr.Row(elem_classes=["patient-row"]): for _j, _p in enumerate(_patients[_i:_i + 3]): _idx = _i + _j _avatar = _PATIENT_AVATAR_URLS.get(_p["hadm_id"], _DOCTOR_AVATAR) with gr.Column(scale=1, elem_classes=["patient-card-column"]): _card_html = gr.HTML( value=_build_single_card_html(_p, False, _avatar), elem_classes=["patient-card-html"], ) _select_btn = gr.Button("Select", size="sm", variant="secondary") _all_cards.append((_p["hadm_id"], _avatar, _card_html)) _all_select_btns.append((_p["hadm_id"], _select_btn)) # ── Persona card (redesigned) ──────────────────────────────────────── with gr.Group(elem_classes=["form-card"]): gr.HTML("🎭 Patient Persona") # Row 1: Personality + CEFR with gr.Row(equal_height=True): with gr.Column(min_width=200, elem_classes=["persona-cell"]): personality_dd = gr.Radio( choices=PERSONALITY_CHOICES, value="plain", label="Personality", elem_classes=["compact-radio", "personality-radio"], ) personality_tip = gr.HTML( value=_make_tip_html(PERSONALITY_TIPS, "plain"), elem_classes=["tip-html"], ) with gr.Column(min_width=200, elem_classes=["persona-cell"]): cefr_radio = gr.Radio( choices=CEFR_CHOICES, value="B", label="Language Proficiency (CEFR)", elem_classes=["compact-radio"], ) cefr_tip = gr.HTML( value=_make_tip_html(CEFR_TIPS, "B"), elem_classes=["tip-html"], ) # Row 2: Medical History Recall + Cognitive Confusion (now radios) with gr.Row(equal_height=True): with gr.Column(min_width=200, elem_classes=["persona-cell"]): recall_radio = gr.Radio( choices=RECALL_CHOICES, value="high", label="Medical History Recall", elem_classes=["compact-radio"], ) recall_tip = gr.HTML( value=_make_tip_html(RECALL_TIPS, "high"), elem_classes=["tip-html"], ) with gr.Column(min_width=200, elem_classes=["persona-cell"]): confusion_radio = gr.Radio( choices=CONFUSION_CHOICES, value="normal", label="Cognitive Confusion", elem_classes=["compact-radio"], ) confusion_tip = gr.HTML( value=_make_tip_html(CONFUSION_TIPS, "normal"), elem_classes=["tip-html"], ) start_btn = gr.Button("▶ Start Simulation", variant="primary", size="lg", elem_id="start-btn") setup_error_display = gr.HTML("", visible=False) # ── Mode selection section ─────────────────────────────────────────────── with gr.Column(visible=False) as mode_section: gr.Markdown( "## Choose Simulation Mode\n\n" "Your patient agent is ready. Review your configuration below, then pick a mode. " "**Auto Simulation** runs a fully automated consultation so you can observe the " "system in action. **Practice Mode** puts you in the doctor's seat for hands-on " "training." ) recap_display = gr.HTML() with gr.Row(equal_height=True): # Auto simulation card with gr.Column(elem_classes=["mode-card"], min_width=280): gr.Markdown( "### 🤖 Auto Simulation\n" "Watch a fully automated consultation between an AI doctor and " "the simulated patient. The doctor conducts the interview and " "arrives at a differential diagnosis — no input required." ) auto_btn = gr.Button("Run Auto Simulation", variant="primary") # Manual practice card with gr.Column(elem_classes=["mode-card"], min_width=280): gr.Markdown( "### 🩺 Practice Mode\n" "You play the doctor. Choose how much patient information " "to display on the side panel." ) profile_mode_radio = gr.Radio( choices=[ ("Full Profile — all case details visible", "full"), ("Basic Info only — practice without prior knowledge", "blind"), ], value="full", label="Patient Profile Visibility", elem_classes=["compact-radio"], ) manual_btn = gr.Button("Start Practice", variant="primary") with gr.Row(): back_from_mode_btn = gr.Button("← Back to Setup") # ── Auto simulation section ────────────────────────────────────────────── with gr.Column(visible=False) as auto_section: with gr.Row(): back_from_auto_btn = gr.Button("← Back to Setup") back_from_auto_to_mode_btn = gr.Button("← Mode Selection") gr.Markdown( "### 🤖 Auto Simulation\n\n" "An AI doctor is conducting the consultation. The doctor will ask questions, " "gather medical history, and conclude with a differential diagnosis. " "**Doctor** (user) messages appear on the right; **Patient** (assistant) responses on the left." ) auto_recap = gr.HTML() with gr.Row(equal_height=False): with gr.Column(scale=1, min_width=280): auto_recap = gr.HTML() with gr.Column(scale=2): auto_chatbot = gr.Chatbot( label="Doctor–Patient Dialogue", height=700, show_label=True, avatar_images=( _DOCTOR_AVATAR, list(_PATIENT_AVATAR_URLS.values())[0], ), placeholder="Run the auto simulation to see the dialogue here.", ) # ── Manual chat section ────────────────────────────────────────────────── with gr.Column(visible=False) as chat_section: gr.Markdown( "### 🩺 Practice Mode\n\n" "You are the attending physician. Type your questions and responses in the box " "below to consult the simulated patient. Use the patient profile panel on the " "left for reference. Try to gather sufficient history to reach a differential diagnosis." ) with gr.Row(): back_from_chat_btn = gr.Button("← Back to Setup", scale=1) back_from_chat_to_mode_btn = gr.Button("← Mode Selection", scale=1) reset_btn = gr.Button("↺ Reset Conversation", scale=1) with gr.Row(equal_height=False): with gr.Column(scale=1, min_width=280): profile_display = gr.HTML(elem_classes=["profile-display"]) chat_recap = gr.HTML() with gr.Column(scale=2): chatbot = gr.Chatbot( label="Consultation", height=700, show_label=True, avatar_images=( _DOCTOR_AVATAR, list(_PATIENT_AVATAR_URLS.values())[0], # overridden dynamically on start ), placeholder=( "The conversation will appear here. " "Type your message below to begin." ), ) with gr.Row(): msg_box = gr.Textbox( placeholder="Type your message as the doctor…", label="Doctor's Message", lines=1, max_lines=5, scale=5, show_label=False, container=False, ) send_btn = gr.Button( "Send", variant="primary", scale=1, min_width=80 ) # ── Event wiring ───────────────────────────────────────────────────────── # Intro → Setup get_started_btn.click( fn=go_to_setup, outputs=[intro_section, setup_section], ) # Patient card selection (one handler per button) def _make_select_fn(target_id: str): def _fn(): html_updates = [ gr.update(value=_build_single_card_html(PATIENT_DICT[hid], hid == target_id, av)) for hid, av, _ in _all_cards ] return html_updates + [target_id] return _fn for _hadm_id, _select_btn in _all_select_btns: _select_btn.click( fn=_make_select_fn(_hadm_id), outputs=[c for _, _, c in _all_cards] + [selected_patient_state], ) # Tooltip updates personality_dd.change( fn=lambda v: _make_tip_html(PERSONALITY_TIPS, v), inputs=[personality_dd], outputs=[personality_tip], ) cefr_radio.change( fn=lambda v: _make_tip_html(CEFR_TIPS, v), inputs=[cefr_radio], outputs=[cefr_tip], ) recall_radio.change( fn=lambda v: _make_tip_html(RECALL_TIPS, v), inputs=[recall_radio], outputs=[recall_tip], ) confusion_radio.change( fn=lambda v: _make_tip_html(CONFUSION_TIPS, v), inputs=[confusion_radio], outputs=[confusion_tip], ) # Start simulation → mode selection start_btn.click( fn=start_simulation, inputs=[selected_patient_state, model_dd, cefr_radio, personality_dd, recall_radio, confusion_radio, api_key_input], outputs=[patient_agent_state, sim_config_state, setup_section, mode_section, recap_display, setup_error_display], ) # Back to setup from mode selection back_from_mode_btn.click( fn=back_to_setup, inputs=[patient_agent_state], outputs=[patient_agent_state, sim_config_state, setup_section, mode_section, chat_section, auto_section, recap_display, setup_error_display], ) # Auto simulation auto_event = auto_btn.click( fn=start_auto, inputs=[patient_agent_state, sim_config_state], outputs=[auto_chatbot, mode_section, auto_section, chat_section, auto_recap], ) back_from_auto_to_mode_btn.click( fn=back_to_mode_from_auto, inputs=[patient_agent_state], outputs=[auto_chatbot, auto_recap, mode_section, auto_section], cancels=[auto_event], ) back_from_auto_btn.click( fn=back_to_setup_from_auto, inputs=[patient_agent_state], outputs=[auto_chatbot, auto_recap, patient_agent_state, sim_config_state, setup_section, mode_section, auto_section, chat_section, recap_display, setup_error_display], cancels=[auto_event], ) # Manual practice manual_btn.click( fn=start_manual, inputs=[profile_mode_radio, patient_agent_state, sim_config_state], outputs=[profile_display, chat_recap, chatbot, chat_history_state, mode_section, auto_section, chat_section], ) chat_event_send = send_btn.click( fn=chat, inputs=[msg_box, chat_history_state, patient_agent_state, sim_config_state], outputs=[chatbot, chat_history_state, msg_box], ) chat_event_submit = msg_box.submit( fn=chat, inputs=[msg_box, chat_history_state, patient_agent_state, sim_config_state], outputs=[chatbot, chat_history_state, msg_box], ) back_from_chat_to_mode_btn.click( fn=back_to_mode_from_chat, inputs=[patient_agent_state], outputs=[chatbot, profile_display, chat_recap, mode_section, chat_section], cancels=[chat_event_send, chat_event_submit], ) back_from_chat_btn.click( fn=back_to_setup_from_chat, inputs=[patient_agent_state], outputs=[chatbot, profile_display, chat_recap, patient_agent_state, sim_config_state, setup_section, mode_section, auto_section, chat_section, recap_display, setup_error_display], cancels=[chat_event_send, chat_event_submit], ) reset_btn.click( fn=reset_chat, inputs=[patient_agent_state], outputs=[chatbot, chat_history_state], js="() => { if (!confirm('Reset the entire conversation? This cannot be undone.')) return null; }", ) if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860)