Spaces:
Running
Running
| import os | |
| import re | |
| import html | |
| import json | |
| import copy | |
| import logging | |
| import gradio as gr | |
| from dotenv import load_dotenv, find_dotenv | |
| from patientsim import PatientAgent, DoctorAgent | |
| from patientsim.utils.common_utils import detect_ed_termination | |
| from rate_limiter import RateLimiter, get_client_key | |
# Load a local .env (located relative to the current working directory) without
# overriding variables that are already present in the process environment.
load_dotenv(find_dotenv(usecwd=True), override=False)

# Application-wide logging: INFO and above, written by a single StreamHandler.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
)
_logger = logging.getLogger("patientsim.app")

# ---------------------------------------------------------------------------
# Rate limiter (singleton — shared across all Gradio worker threads)
# ---------------------------------------------------------------------------
_rate_limiter = RateLimiter()
| # --------------------------------------------------------------------------- | |
| # Constants | |
| # --------------------------------------------------------------------------- | |
# Radio choices are (display label, value) pairs; the value is what the
# callbacks receive and what start_simulation hands to PatientAgent.
CEFR_CHOICES = [
    (f"{level} β {name}", level)
    for level, name in (("A", "Beginner"), ("B", "Intermediate"), ("C", "Advanced"))
]
PERSONALITY_CHOICES = [
    ("Neutral", "plain"),
    ("Talkative", "verbose"),
    ("Distrustful", "distrust"),
    ("Pleasing", "pleasing"),
    ("Impatient", "impatient"),
    ("Overanxious", "overanxious"),
]
RECALL_CHOICES = [(value.capitalize(), value) for value in ("low", "high")]
CONFUSION_CHOICES = [(value.capitalize(), value) for value in ("normal", "high")]

# Models selectable in the setup form; start_simulation rejects anything else.
# Names containing "gpt" use the OpenAI key, all others the Google key.
BACKEND_MODELS = [
    # Google Gemini
    "gemini-3.1-flash-lite-preview",
    "gemini-3-flash-preview",
    "gemini-2.5-flash",
    # OpenAI
    "gpt-5.4-nano",
    "gpt-5.4-mini",
    "gpt-5.4",
]

# Maximum number of patient/doctor rounds in auto-simulation mode.
MAX_AUTO_INFERENCES: int = 10
# Longest doctor message (in characters) accepted in manual chat mode.
MAX_MESSAGE_CHARS: int = 2000
| # --------------------------------------------------------------------------- | |
| # Patient data | |
| # --------------------------------------------------------------------------- | |
# Demo patient cases shipped alongside this module.
_DATA_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "demo", "data", "demo_data.json")
try:
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(_DATA_PATH, encoding="utf-8") as _f:
        PATIENT_DATA: list[dict] = json.load(_f)
except (OSError, ValueError) as _e:
    # OSError covers missing/unreadable files (FileNotFoundError, PermissionError, ...);
    # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
    raise RuntimeError(f"Failed to load patient data from {_DATA_PATH}: {_e}") from _e
# Lookup by admission id; every record is expected to carry a "hadm_id" key.
PATIENT_DICT: dict[str, dict] = {p["hadm_id"]: p for p in PATIENT_DATA}
| _DOCTOR_AVATAR = "https://cdn-icons-png.flaticon.com/512/3774/3774299.png" | |
| # Per-patient avatar URLs matched to demographics (gender + age group) | |
| _PATIENT_AVATAR_URLS = { | |
| "patient_MI": "https://cdn-icons-png.flaticon.com/512/5488/5488324.png", # 64yo Male β elderly man | |
| "patient_PNA": "https://cdn-icons-png.flaticon.com/512/4140/4140051.png", # 47yo Female β middle-aged woman | |
| "patient_UTI": "https://cdn-icons-png.flaticon.com/512/4140/4140047.png", # 29yo Female β young woman | |
| } | |
# Patients ordered by hadm_id so the selection cards always render in the same
# order. Built once at import time.
_SORTED_PATIENTS: list[dict] = list(PATIENT_DATA)
_SORTED_PATIENTS.sort(key=lambda record: record["hadm_id"])


def _sorted_patients() -> list[dict]:
    """Return the module-level, hadm_id-sorted patient list (the same object, not a copy)."""
    return _SORTED_PATIENTS
| def _build_single_card_html(p: dict, selected: bool, avatar_url: str) -> str: | |
| """Build HTML for a single patient card (no select button β handled by gr.Button).""" | |
| border_color = "#3b82f6" if selected else "#e5e7eb" | |
| bg_color = "#eff6ff" if selected else "#ffffff" | |
| shadow = "0 0 0 2px #3b82f6" if selected else "0 1px 3px rgba(0,0,0,0.06)" | |
| age = html.escape(str(p.get("age", "?"))) | |
| gender = html.escape(str(p.get("gender", "Unknown"))) | |
| diagnosis = html.escape(str(p.get("diagnosis", "Unknown"))) | |
| chief = html.escape(str(p.get("chiefcomplaint", "β"))) | |
| transport = html.escape(str(p.get("arrival_transport", "β"))) | |
| safe_avatar = html.escape(avatar_url) | |
| return ( | |
| f"<div style='background:{bg_color};border:2px solid {border_color};" | |
| f"border-radius:12px;padding:24px 20px;box-shadow:{shadow};transition:all 0.15s ease'>" | |
| f"<div style='text-align:center;margin-bottom:10px'>" | |
| f"<img src='{safe_avatar}' style='width:56px;height:56px;border-radius:50%;" | |
| f"background:#f3f4f6;padding:4px' alt='Patient'>" | |
| f"</div>" | |
| f"<div style='font-size:13px;color:#6b7280;line-height:1.6'>" | |
| f"<div><b>Age:</b> {age} Β· <b>Gender:</b> {gender}</div>" | |
| f"<div><b>Chief Complaint:</b> {chief}</div>" | |
| f"<div><b>Transport:</b> {transport}</div>" | |
| f"<div><b>Dx:</b> {diagnosis}</div>" | |
| f"</div>" | |
| f"</div>" | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # HTML helpers | |
| # --------------------------------------------------------------------------- | |
def build_recap_html(hadm_id: str, model: str, cefr: str, personality: str, recall: str, confusion: str) -> str:
    """Render the "Simulation Configuration" summary as a two-column card grid.

    Option values are shown with their human-readable labels from the matching
    *_CHOICES list (falling back to the raw value when no label matches). An
    unknown hadm_id renders a placeholder in the Patient card. All labels and
    values are HTML-escaped.
    """
    patient = PATIENT_DICT.get(hadm_id, {})
    if patient:
        patient_label = (
            f"Age {patient.get('age')} Β· {patient.get('gender')} Β· {patient.get('diagnosis', 'Unknown')}"
        )
    else:
        patient_label = "β"

    def label_for(choices, value):
        # First display label whose value matches; otherwise the raw value.
        return next((label for label, v in choices if v == value), value)

    def card(title, value):
        return (
            "<div style='padding:10px 14px;background:var(--background-fill-primary,#fff);"
            "border:1px solid var(--border-color-primary);border-radius:8px'>"
            "<div style='font-size:11px;font-weight:600;letter-spacing:0.06em;"
            "text-transform:uppercase;color:var(--body-text-color-subdued);margin-bottom:4px'>"
            f"{html.escape(str(title))}</div>"
            f"<div style='font-size:13px;font-weight:500;line-height:1.4'>{html.escape(str(value))}</div>"
            "</div>"
        )

    rows = (
        ("Patient", patient_label),
        ("Personality", label_for(PERSONALITY_CHOICES, personality)),
        ("Model", model),
        ("Language Proficiency", label_for(CEFR_CHOICES, cefr)),
        ("Medical History Recall", label_for(RECALL_CHOICES, recall)),
        ("Cognitive Confusion", label_for(CONFUSION_CHOICES, confusion)),
    )
    grid = "".join(card(title, value) for title, value in rows)
    return (
        "<div style='background:var(--color-accent-soft,#f0f7ff);"
        "border:1px solid var(--border-color-primary);border-radius:10px;"
        "padding:16px 20px;margin-bottom:8px'>"
        "<div style='font-weight:600;font-size:14px;margin-bottom:12px'>π Simulation Configuration</div>"
        f"<div style='display:grid;grid-template-columns:1fr 1fr;gap:8px'>{grid}</div>"
        "</div>"
    )
| # --------------------------------------------------------------------------- | |
| # Profile HTML helpers (card-based layout matching recap style) | |
| # --------------------------------------------------------------------------- | |
| def _profile_item(label: str, val) -> str: | |
| """Build a single key-value item inside a profile card.""" | |
| safe_val = html.escape(str(val)) if val not in (None, "") else "N/A" | |
| safe_label = html.escape(str(label)) | |
| return ( | |
| f"<div style='display:flex;gap:6px;padding:5px 0;" | |
| f"border-bottom:1px solid #f3f4f6;font-size:13px;line-height:1.5'>" | |
| f"<span style='color:#6b7280;white-space:nowrap;min-width:110px;" | |
| f"font-weight:500'>{safe_label}</span>" | |
| f"<span style='color:#1f2937;flex:1'>{safe_val}</span>" | |
| f"</div>" | |
| ) | |
| def _profile_card(icon: str, title: str, items_html: str, accent: str = "#3b82f6") -> str: | |
| """Build a styled card section for the profile panel.""" | |
| return ( | |
| f"<div style='background:var(--background-fill-primary,#fff);" | |
| f"border:1px solid var(--border-color-primary,#e5e7eb);border-radius:10px;" | |
| f"padding:14px 16px;margin-bottom:10px'>" | |
| f"<div style='display:flex;align-items:center;gap:8px;margin-bottom:10px;" | |
| f"padding-bottom:8px;border-bottom:2px solid {accent}'>" | |
| f"<span style='font-size:16px'>{icon}</span>" | |
| f"<span style='font-size:13px;font-weight:600;letter-spacing:0.03em;" | |
| f"text-transform:uppercase;color:{accent}'>{html.escape(title)}</span>" | |
| f"</div>" | |
| f"{items_html}" | |
| f"</div>" | |
| ) | |
def build_profile_html(p: dict) -> str:
    """Render the full patient profile panel used in "full" profile mode.

    An avatar header is followed by four cards (demographics, social history,
    previous medical history, current visit). Missing fields render as "N/A"
    via ``_profile_item``. The panel is capped at 400px tall and scrolls.
    """
    avatar = html.escape(
        _PATIENT_AVATAR_URLS.get(p.get("hadm_id", ""), list(_PATIENT_AVATAR_URLS.values())[0])
    )
    header = (
        "<div style='text-align:center;margin-bottom:14px'>"
        f"<img src='{avatar}' style='width:52px;height:52px;border-radius:50%;"
        "background:#f3f4f6;padding:3px;margin-bottom:6px' alt='Patient'>"
        "<div style='font-size:15px;font-weight:600;color:#1f2937'>Patient Profile</div>"
        "</div>"
    )
    # (icon, title, accent colour, ((row label, patient key), ...))
    sections = (
        ("π€", "Demographics", "#3b82f6", (
            ("Age", "age"),
            ("Gender", "gender"),
            ("Race", "race"),
            ("Transport", "arrival_transport"),
        )),
        ("π ", "Social History", "#10b981", (
            ("Tobacco", "tobacco"),
            ("Alcohol", "alcohol"),
            ("Illicit Drug", "illicit_drug"),
            ("Exercise", "exercise"),
            ("Marital Status", "marital_status"),
            ("Children", "children"),
            ("Living Situation", "living_situation"),
            ("Occupation", "occupation"),
            ("Insurance", "insurance"),
        )),
        ("π", "Previous Medical History", "#f59e0b", (
            ("Allergies", "allergies"),
            ("Family History", "family_medical_history"),
            ("Medical Devices", "medical_device"),
            ("Prior History", "medical_history"),
        )),
        ("π©Ί", "Current Visit", "#ef4444", (
            ("Present Illness (+)", "present_illness_positive"),
            ("Present Illness (β)", "present_illness_negative"),
            ("Chief Complaint", "chiefcomplaint"),
            ("Pain (0β10)", "pain"),
            ("Medications", "medication"),
            ("Disposition", "disposition"),
            ("Diagnosis", "diagnosis"),
        )),
    )
    cards = "".join(
        _profile_card(
            icon,
            title,
            "".join(_profile_item(label, p.get(key)) for label, key in fields),
            accent=accent,
        )
        for icon, title, accent, fields in sections
    )
    return (
        "<div style='font-family:Noto Sans KR,Noto Sans,Malgun Gothic,Apple SD Gothic Neo,sans-serif;"
        "font-size:14px;line-height:1.5;background:var(--color-accent-soft,#f0f7ff);"
        "border:1px solid var(--border-color-primary,#e5e7eb);border-radius:12px;"
        "padding:18px 16px;max-height:400px;overflow-y:auto'>"
        f"{header}{cards}"
        "</div>"
    )
def build_blind_profile_html(p: dict) -> str:
    """Render the reduced "blind" profile panel for practice mode.

    Only the demographics card is shown, followed by a dashed placeholder that
    tells the user to uncover everything else by interviewing the patient.
    """
    avatar = html.escape(
        _PATIENT_AVATAR_URLS.get(p.get("hadm_id", ""), list(_PATIENT_AVATAR_URLS.values())[0])
    )
    header = (
        "<div style='text-align:center;margin-bottom:14px'>"
        f"<img src='{avatar}' style='width:52px;height:52px;border-radius:50%;"
        "background:#f3f4f6;padding:3px;margin-bottom:6px' alt='Patient'>"
        "<div style='font-size:15px;font-weight:600;color:#1f2937'>Patient Info</div>"
        "<div style='font-size:12px;color:#9ca3af;font-style:italic;margin-top:4px'>"
        "Basic demographics only β gather the rest through consultation.</div>"
        "</div>"
    )
    demographic_rows = "".join(
        _profile_item(label, p.get(key))
        for label, key in (
            ("Age", "age"),
            ("Gender", "gender"),
            ("Race", "race"),
            ("Transport", "arrival_transport"),
        )
    )
    demographics = _profile_card("π€", "Demographics", demographic_rows, accent="#3b82f6")
    hint = (
        "<div style='background:var(--background-fill-primary,#fff);"
        "border:1px dashed #d1d5db;border-radius:10px;padding:16px;text-align:center;"
        "color:#9ca3af;font-size:13px'>"
        "π Additional information is hidden.<br>"
        "Interview the patient to uncover their history."
        "</div>"
    )
    return (
        "<div style='font-family:Noto Sans KR,Noto Sans,Malgun Gothic,Apple SD Gothic Neo,sans-serif;"
        "font-size:14px;line-height:1.5;background:var(--color-accent-soft,#f0f7ff);"
        "border:1px solid var(--border-color-primary,#e5e7eb);border-radius:12px;"
        "padding:18px 16px'>"
        f"{header}{demographics}{hint}"
        "</div>"
    )
| # --------------------------------------------------------------------------- | |
| # Custom CSS | |
| # --------------------------------------------------------------------------- | |
# Global stylesheet injected into the Gradio app.
# Two declarations used to be invalid CSS and were silently dropped by browsers:
# `.patient-card-html` padding had a unitless non-zero length (`10`), and the
# `.patient-row > div` flex container had `gap: flex`. They now carry real
# lengths (10px, and 16px to match `.patient-row`'s own gap; three cards of
# calc(33.333% - 12px) plus two 16px gaps still fit on one row).
CUSTOM_CSS = """
/* ββ Global gothic (sans-serif) font βββββββββββββββββββββββββββββ */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+KR:wght@300;400;500;600;700&family=Noto+Sans:wght@300;400;500;600;700&display=swap');
*, *::before, *::after {
font-family: 'Noto Sans KR', 'Noto Sans', 'Malgun Gothic', 'Apple SD Gothic Neo',
'Segoe UI', sans-serif !important;
}
.prose h1, .prose h2, .prose h3, .prose h4,
.markdown-text h1, .markdown-text h2, .markdown-text h3, .markdown-text h4 {
font-family: 'Noto Sans KR', 'Noto Sans', 'Malgun Gothic', 'Apple SD Gothic Neo',
'Segoe UI', sans-serif !important;
}
/* ββ Page background ββββββββββββββββββββββββββββββββββββββββββββββ */
body, .gradio-container, .main, footer {
background-color: #f3f4f6 !important;
}
/* ββ All form labels β no blue/accent background ββββββββββββββββββ */
label span,
label > span,
.block label span,
.label-wrap span,
.gradio-container label span,
.gradio-dropdown label span,
.wrap > label > span,
[class*="label"] > span {
background: transparent !important;
background-color: transparent !important;
}
/* ββ White shadow cards for form sections βββββββββββββββββββββββββ */
.form-card {
background: #ffffff !important;
border-radius: 14px !important;
box-shadow: 0 1px 3px rgba(0,0,0,0.08), 0 0 0 1px rgba(0,0,0,0.05) !important;
border: none !important;
padding: 16px 20px 20px !important;
}
.form-card > .gap,
.form-card > div {
background: transparent !important;
border: none !important;
gap: 12px !important;
padding: 8px !important;
}
/* ββ Card section title β vertically centered with background ββββ */
.card-title {
font-size: 18px;
font-weight: 600;
color: #374151;
padding: 6px 0;
margin-bottom: 4px;
border-bottom: 1px solid #e5e7eb;
display: flex;
align-items: center;
justify-content: center;
min-height: 32px;
line-height: 1;
}
/* ββ Tooltip descriptions β plain text inside the white cell βββββββ */
.option-desc {
font-size: 15px;
color: #4b5563;
margin-top: 8px;
line-height: 1.4;
}
/* ββ Transparent HTML tip containers ββββββββββββββββββββββββββββββ */
.tip-html,
.tip-html > div,
.tip-html > .prose {
background: transparent !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
}
/* ββ CEFR radio & Persona grid β thin border on buttons βββββββββββ */
.compact-radio .wrap {
gap: 6px !important;
background: transparent !important;
}
.compact-radio label {
padding: 5px 14px !important;
border-radius: 8px !important;
font-size: 13px !important;
min-width: unset !important;
background: transparent !important;
border: 1px solid #d1d5db !important;
}
.compact-radio label:has(input:checked) {
border-color: #3b82f6 !important;
background: #eff6ff !important;
}
.persona-grid {
display: grid !important;
grid-template-columns: 1fr 1fr !important;
gap: 16px !important;
}
/* ββ Persona cell β white box containing buttons + description βββββ */
.persona-cell {
background: #ffffff !important;
border: 1px solid #e5e7eb !important;
border-radius: 10px !important;
padding: 14px 16px !important;
display: flex !important;
flex-direction: column !important;
}
/* Force options to use black bold font without background */
.persona-cell label span,
.persona-cell span {
color: black !important;
font-weight: bold !important;
background: transparent !important;
}
/* ββ Personality radio β 2-column grid for 6 choices ββββββββββββββ */
.personality-radio .wrap {
display: grid !important;
grid-template-columns: 1fr 1fr !important;
gap: 6px !important;
}
.persona-cell .compact-radio {
margin-bottom: 4px !important;
}
/* ββ Mode cards βββββββββββββββββββββββββββββββββββββββββββββββββββ */
.mode-card {
background: #ffffff !important;
border: 1px solid #e5e7eb !important;
border-radius: 14px !important;
box-shadow: 0 1px 3px rgba(0,0,0,0.06) !important;
padding: 20px !important;
}
/* ββ Start simulation button shadow βββββββββββββββββββββββββββββββ */
#start-btn > button {
box-shadow: 0 4px 14px rgba(59, 130, 246, 0.30) !important;
font-size: 15px !important;
}
/* ββ Patient card HTML wrapper β remove Gradio's default padding ββ */
.patient-card-html > div,
.patient-card-html .prose {
padding: 10px !important;
margin: 0 !important;
border: none !important;
background: #ffffff !important;
}
/* ββ Patient profile panel β strip wrapper chrome βββββββββββββββββ */
.profile-display,
.profile-display > div,
.profile-display .prose {
background: transparent !important;
border: none !important;
padding: 0 !important;
margin: 0 !important;
}
/* White background + padding for each patient column cell */
.form-card .gradio-column {
background: #ffffff !important;
border-radius: 10px !important;
border: 1px solid #e5e7eb !important;
padding: 12px !important;
gap: 8px !important;
}
/* Spacing between columns and from the card border */
.form-card .gradio-row {
gap: 12px !important;
padding: 12px 12px !important;
background: transparent !important;
}
/* Fallback using data-testid if gradio-column class doesn't apply */
.form-card [data-testid="column"] {
background: #ffffff !important;
border-radius: 10px !important;
border: 1px solid #e5e7eb !important;
padding: 12px !important;
}
.form-card [data-testid="row"] {
gap: 24px !important;
padding: 12px 8px !important;
}
.patient-row {
gap: 16px !important;
margin-bottom: 16px !important;
padding: 0 24px !important;
}
/* Gradio 6 wraps row flex content in an inner div β target it directly */
.patient-row > div {
gap: 16px !important;
display: flex !important;
flex-wrap: wrap !important;
align-items: stretch !important;
}
.patient-card-column {
flex: 1 1 calc(33.333% - 12px) !important;
flex-direction: column !important;
min-width: 100px !important;
max-width: calc(33.333% - 12px) !important;
display: flex !important;
height: 100% !important;
gap: 0 !important;
}
.patient-card-html {
flex: 1 !important;
height: 100% !important;
}
"""
| # Tooltip maps for descriptions shown below each radio group | |
PERSONALITY_TIPS = dict(
    plain="No strong emotions or noticeable behavior.",
    verbose="Speaks a lot, gives highly detailed responses.",
    distrust="Questions the doctor's expertise and care.",
    pleasing="Overly positive, tends to minimize problems.",
    impatient="Easily irritated and lacks patience.",
    overanxious="Expresses concern beyond what is typical.",
)
RECALL_TIPS = dict(
    low="Often forgets even major medical events.",
    high="Usually recalls medical events accurately.",
)
CONFUSION_TIPS = dict(
    normal="Clear mental status.",
    high="Highly dazed and extremely confused.",
)
CEFR_TIPS = dict(
    A="Can make simple sentences.",
    B="Can have daily conversations.",
    C="Can freely use advanced medical terms.",
)
| def _make_tip_html(tips: dict, selected: str) -> str: | |
| desc = tips.get(selected, "") | |
| return f"<p class='option-desc'>{desc}</p>" if desc else "" | |
| # --------------------------------------------------------------------------- | |
# Callbacks — Setup
| # --------------------------------------------------------------------------- | |
def go_to_setup():
    """Navigate to the setup step.

    Returns two visibility updates: the first output is hidden and the second
    is shown (presumably the landing section and the setup section; the event
    wiring is not visible here).
    """
    hidden = gr.update(visible=False)
    shown = gr.update(visible=True)
    return hidden, shown
| def _sanitize_error(msg: str) -> str: | |
| """Remove API key patterns and internal paths from error messages.""" | |
| msg = re.sub(r'sk-[A-Za-z0-9_-]{10,}', '[REDACTED]', msg) | |
| msg = re.sub(r'AIza[A-Za-z0-9_-]{20,}', '[REDACTED]', msg) | |
| msg = re.sub(r'AQ\.[A-Za-z0-9_-]{10,}', '[REDACTED]', msg) | |
| msg = re.sub(r'/(?:home|data|tmp|var|usr|etc)/[^\s\'"]+', '[PATH_REDACTED]', msg) | |
| return msg | |
def _setup_error(msg: str):
    """Build the 6-output tuple returned by start_simulation on failure.

    Every output except the error banner gets a no-op ``gr.update()``, so the
    setup form stays on screen. The banner shows ``msg`` after credential/path
    redaction and HTML escaping.
    """
    banner = (
        "<div style='background:#fef2f2;border:1px solid #fca5a5;border-radius:8px;"
        f"padding:10px 14px;color:#b91c1c;font-size:13px'>β οΈ {html.escape(_sanitize_error(msg))}</div>"
    )
    # patient_agent_state, sim_config_state, setup_section, mode_section,
    # recap_display: each gets its own untouched update.
    unchanged = tuple(gr.update() for _ in range(5))
    return (*unchanged, gr.update(value=banner, visible=True))
def start_simulation(
    hadm_id: str,
    model: str,
    cefr: str,
    personality: str,
    recall: str,
    confusion: str,
    user_api_key: str = "",
    request: gr.Request = None,
):
    """Validate the setup form, build the PatientAgent and move to mode selection.

    Args:
        hadm_id: selected patient id (a key of PATIENT_DICT).
        model: backend model name; must be in BACKEND_MODELS.
        cefr, personality, recall, confusion: persona options forwarded to PatientAgent.
        user_api_key: optional user-supplied key. Blank, whitespace or ``None``
            means "use the server's shared key" for the model's provider.
        request: Gradio request; not used in this function.

    Returns:
        On success, a 6-tuple for (patient_agent_state, sim_config_state,
        setup_section, mode_section, recap_display, setup_error_display).
        On any validation or construction failure, the ``_setup_error`` tuple
        is returned instead, which keeps the setup form visible.
    """
    if not hadm_id:
        return _setup_error("Please select a patient first.")
    if model not in BACKEND_MODELS:
        return _setup_error("Invalid model selection.")

    # The textbox value may be None (e.g. when cleared or called via the API).
    own_key = (user_api_key or "").strip()
    is_openai = "gpt" in model.lower()
    if own_key:
        api_key = own_key
    elif is_openai:
        api_key = os.environ.get("OPENAI_API_KEY", "")
    else:
        api_key = os.environ.get("GENAI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "")
    if not api_key:
        key_name = "OPENAI_API_KEY" if is_openai else "GENAI_API_KEY / GOOGLE_API_KEY"
        return _setup_error(
            f"API key not configured ({key_name}). "
            "Run with your own key, or contact the demo administrator."
        )

    patient = PATIENT_DICT.get(hadm_id)
    if patient is None:
        return _setup_error("Invalid patient selection. Please select a valid patient from the list.")
    # Deep copy so the agent / session can never mutate the shared demo record.
    patient = copy.deepcopy(patient)

    try:
        agent = PatientAgent(
            model=model,
            visit_type="emergency_department",
            personality=personality,
            recall_level=recall,
            confusion_level=confusion,
            lang_proficiency_level=cefr,
            api_key=api_key,
            temperature=0.7,
            num_word_sample=10,
            random_seed=42,
            log_verbose=False,
            use_vertex=True,  # Force using Vertex AI for all models to ensure consistent behavior and better error handling
            **patient,
        )
    except Exception as e:
        safe_msg = _sanitize_error(str(e))
        _logger.error("Failed to initialize patient agent: %s", safe_msg)
        return _setup_error(f"Failed to initialize patient agent: {safe_msg}")

    recap = build_recap_html(hadm_id, model, cefr, personality, recall, confusion)
    sim_config = {
        "patient": patient,
        "model": model,
        "recap_html": recap,
        "user_api_key": own_key,  # empty string = using shared demo key
    }
    return (
        agent,
        sim_config,
        gr.update(visible=False),  # hide setup_section
        gr.update(visible=True),   # show mode_section
        gr.update(value=recap),    # recap_display
        gr.update(visible=False),  # hide setup_error_display
    )
def back_to_setup(agent):
    """Leave the mode picker and return to the setup form.

    Clears the agent's conversation history (when an agent exists), drops both
    session states, shows setup, hides the mode/chat/auto sections, and clears
    the recap and error banner.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    # mode_section, chat_section, auto_section (in that order) are all hidden.
    hidden_sections = [gr.update(visible=False) for _ in range(3)]
    return (
        None,                                # patient_agent_state
        None,                                # sim_config_state
        gr.update(visible=True),             # setup_section
        *hidden_sections,
        gr.update(value=""),                 # recap_display
        gr.update(value="", visible=False),  # setup_error_display
    )
| # --------------------------------------------------------------------------- | |
# Callbacks — Manual practice mode
| # --------------------------------------------------------------------------- | |
def start_manual(profile_mode: str, agent, sim_config: dict):
    """Enter manual practice mode, where the user plays the doctor.

    Clears the agent's history, renders either the full profile
    (``profile_mode == "full"``) or the demographics-only "blind" profile, and
    switches from the mode picker to the chat section with a patient-specific
    avatar.

    Returns:
        7-tuple: (profile HTML, recap HTML, chatbot update, empty chat log,
        mode_section update, auto_section update, chat_section update).

    Raises:
        gr.Error: if the session state has been lost.
    """
    if agent is None or sim_config is None:
        # gr.Warning is a notification *function* (it returns None), so raising
        # its result fails with a TypeError. gr.Error is the exception to raise.
        raise gr.Error("Session expired. Please restart.")
    agent.reset_history(verbose=False)

    patient = sim_config["patient"]
    if profile_mode == "full":
        profile_html = build_profile_html(patient)
    else:
        profile_html = build_blind_profile_html(patient)

    hadm_id = patient.get("hadm_id", "")
    patient_avatar = _PATIENT_AVATAR_URLS.get(hadm_id, list(_PATIENT_AVATAR_URLS.values())[0])
    return (
        profile_html,
        sim_config.get("recap_html", ""),  # show recap in left panel
        gr.update(value=[], avatar_images=(_DOCTOR_AVATAR, patient_avatar)),  # chatbot with matched avatar
        [],                                # clear chat log
        gr.update(visible=False),          # hide mode_section
        gr.update(visible=False),          # hide auto_section
        gr.update(visible=True),           # show chat_section
    )
| _INJECTION_PATTERNS = re.compile( | |
| r'ignore\s+(?:\w+\s+){0,3}(instructions?|prompts?|rules?)' | |
| r'|(?:forget|disregard)\s+(?:\w+\s+){0,3}instructions?' | |
| r'|you\s+are\s+(?:now|actually)\s+(?:a\s+|an\s+)?(?:DAN|jailbreak)' | |
| r'|reveal\s+(?:\w+\s+){0,2}(system\s+)?prompt' | |
| r'|act\s+as\s+(?:if\s+)?you\s+(?:have\s+no|are\s+without)\s+(?:restrictions?|guidelines?)', | |
| re.IGNORECASE, | |
| ) | |
def chat(message: str, history: list, agent, sim_config: dict, request: gr.Request = None):
    """Handle one doctor (user) turn in manual practice mode.

    Validates the message (non-empty, length limit, injection heuristic),
    enforces rate limits, asks the patient agent for a reply, and appends both
    sides to the history.

    Returns:
        3-tuple ``(history, history, "")``: the updated history twice and an
        empty string, presumably for the chatbot, the chat-log state and the
        input textbox (the event wiring is not visible here). Every path
        returns this same arity so it matches the event's outputs.

    Raises:
        gr.Error: no simulation running, message too long, suspected prompt
            injection, rate/capacity limit reached, or the agent call failed.
    """
    if agent is None:
        raise gr.Error("No simulation running. Please start a simulation first.")
    history = history or []
    if not message or not message.strip():
        # Nothing to send: keep the history and clear the input.
        return history, history, ""
    if len(message) > MAX_MESSAGE_CHARS:
        raise gr.Error(
            f"Message too long ({len(message)} characters). "
            f"Please keep messages under {MAX_MESSAGE_CHARS} characters."
        )

    client_key = get_client_key(request)
    if _INJECTION_PATTERNS.search(message):
        _logger.warning("Prompt injection attempt detected from key=%s", client_key)
        raise gr.Error("Invalid input detected. Please enter a valid clinical question.")

    using_own_key = bool(sim_config and sim_config.get("user_api_key"))
    if using_own_key:
        # Own-key users bypass per-IP quotas but still respect global capacity.
        allowed, limit_msg = _rate_limiter.check_global_capacity()
    else:
        allowed, limit_msg = _rate_limiter.check_chat_message(client_key)
    if not allowed:
        raise gr.Error(limit_msg)

    try:
        response = agent(user_prompt=message, using_multi_turn=True, verbose=False)
    except Exception as e:
        # Surface a readable, redacted message instead of a generic failure.
        safe_msg = _sanitize_error(str(e))
        _logger.error("Patient agent failed to respond: %s", safe_msg)
        raise gr.Error(f"Patient agent error: {safe_msg}") from e

    history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response},
    ]
    return history, history, ""
def reset_chat(agent):
    """Wipe the manual-mode conversation in both the agent and the UI.

    Returns an empty-value update for the chatbot and a fresh empty list for
    the chat-log state.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    cleared_chatbot = gr.update(value=[])
    cleared_log: list = []
    return cleared_chatbot, cleared_log
def back_to_setup_from_chat(agent):
    """Leave manual chat mode and return all the way to the setup form.

    Clears the agent history (when an agent exists), empties the chat log,
    profile and recap panels, drops both session states, shows setup, hides
    the mode/auto/chat sections, and clears the recap and error banner.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    cleared_panels = ([], "", "")    # chat log, profile html, chat recap
    cleared_states = (None, None)    # patient_agent_state, sim_config_state
    # mode_section, auto_section, chat_section (in that order) are all hidden.
    hidden_sections = tuple(gr.update(visible=False) for _ in range(3))
    return (
        *cleared_panels,
        *cleared_states,
        gr.update(visible=True),             # setup_section
        *hidden_sections,
        gr.update(value=""),                 # recap_display
        gr.update(value="", visible=False),  # setup_error_display
    )
| # --------------------------------------------------------------------------- | |
# Callbacks — Auto simulation mode
| # --------------------------------------------------------------------------- | |
def _auto_fallback_outputs():
    """Return the 5 updates that send start_auto back to the mode picker.

    Output order: auto_chatbot, mode_section, auto_section, chat_section,
    auto_recap. The chatbot and recap are left untouched. A fresh set of
    ``gr.update()`` objects is built on every call.
    """
    unchanged_chatbot = gr.update()
    show_modes = gr.update(visible=True)
    hide_auto = gr.update(visible=False)
    hide_chat = gr.update(visible=False)
    unchanged_recap = gr.update()
    return unchanged_chatbot, show_modes, hide_auto, hide_chat, unchanged_recap
def start_auto(agent, sim_config: dict, request: gr.Request = None):
    """Generator β yields chatbot updates turn-by-turn so the UI streams live.

    Runs an automated doctor/patient consultation. Every ``yield`` is a 5-tuple
    matching the outputs ``[auto_chatbot, mode_section, auto_section,
    chat_section, auto_recap]``.

    Flow:
      1. Reject expired sessions and enforce rate limits: shared-key users go
         through ``check_auto_run``; users who supplied their own key go
         through ``check_own_key_auto_run``. On rejection a warning is shown
         and ``_auto_fallback_outputs()`` is yielded.
      2. Reset the patient agent, switch to the auto section with the
         per-patient avatar, and build a ``DoctorAgent`` for the chosen model.
      3. Alternate patient and doctor turns for up to ``MAX_AUTO_INFERENCES``
         rounds. On the last round the doctor is asked for a top-5
         differential. The loop stops early when ``detect_ed_termination``
         fires on a doctor message.

    Once the rate-limit check has passed, the auto-run slot is released in
    ``finally``. That also happens if the generator is closed early (e.g. the
    run is cancelled and Gradio closes the generator).

    Args:
        agent: Patient agent from ``patient_agent_state``; ``None`` when the
            session has expired.
        sim_config: Session config dict. Reads ``user_api_key``, ``patient``,
            ``model`` and ``recap_html``.
        request: Gradio request used to derive the rate-limiter client key.
    """
    client_key = get_client_key(request)
    if agent is None or sim_config is None:
        gr.Warning("Session expired. Please restart.")
        yield _auto_fallback_outputs()
        return

    using_own_key = bool(sim_config.get("user_api_key"))
    if using_own_key:
        # Own-key users bypass per-IP quotas but still enforce the concurrent
        # run cap and the hard global capacity limit.
        allowed, limit_msg = _rate_limiter.check_own_key_auto_run(client_key)
    else:
        allowed, limit_msg = _rate_limiter.check_auto_run(client_key)
    if not allowed:
        gr.Warning(limit_msg)
        yield _auto_fallback_outputs()
        return

    try:
        agent.reset_history(verbose=False)
        recap_html = sim_config.get("recap_html", "")

        # Show auto_section immediately with an empty dialogue, the recap, and
        # the per-patient avatar. Unknown patients fall back to the first
        # configured avatar.
        hadm_id = sim_config["patient"].get("hadm_id", "")
        patient_avatar = _PATIENT_AVATAR_URLS.get(hadm_id, list(_PATIENT_AVATAR_URLS.values())[0])
        yield (
            gr.update(value=[], avatar_images=(_DOCTOR_AVATAR, patient_avatar)),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            recap_html,
        )

        # Pick the API key: the user's own key if given, otherwise the shared
        # key for the model family.
        model = sim_config["model"]
        if using_own_key:
            api_key = sim_config["user_api_key"]
        elif "gpt" in model.lower():
            api_key = os.environ.get("OPENAI_API_KEY", "")
        else:
            api_key = os.environ.get("GENAI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "")

        try:
            # Forward only the demographic fields the patient record actually has.
            patient = sim_config["patient"]
            demographics = {k: patient[k] for k in ("age", "gender", "arrival_transport") if k in patient}
            doctor = DoctorAgent(
                model=model,
                api_key=api_key,
                temperature=0.2,
                random_seed=42,
                max_inferences=MAX_AUTO_INFERENCES,
                use_vertex=True,
                **demographics,
            )
        except Exception as e:
            safe_msg = _sanitize_error(str(e))
            _logger.error("Failed to initialize doctor agent: %s", safe_msg)
            # gr.Error is only shown when *raised*. gr.Warning displays the
            # message without aborting, so the fallback outputs still apply.
            gr.Warning(f"Failed to initialize doctor agent: {safe_msg}")
            yield _auto_fallback_outputs()
            return

        chat_history = []

        def _append(role: str, content: str):
            chat_history.append({"role": role, "content": content})

        def _yield():
            # The recap was already set by the first yield; leave it unchanged.
            return chat_history, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update()

        try:
            # The doctor greets first; afterwards each round is patient -> doctor.
            _append("user", doctor.doctor_greet)
            yield _yield()
            for inference_idx in range(MAX_AUTO_INFERENCES):
                is_last = inference_idx == MAX_AUTO_INFERENCES - 1
                # Patient answers the most recent doctor message.
                patient_response = agent(
                    user_prompt=chat_history[-1]["content"],
                    using_multi_turn=True,
                    verbose=False,
                )
                _append("assistant", patient_response)
                yield _yield()
                # Doctor answers the patient; on the final round, force a differential.
                doctor_input = chat_history[-1]["content"]
                if is_last:
                    doctor_input += "\nThis is the final turn. Now, you must provide your top 5 differential diagnoses."
                doctor_response = doctor(
                    user_prompt=doctor_input,
                    using_multi_turn=True,
                    verbose=False,
                )
                _append("user", doctor_response)
                yield _yield()
                if detect_ed_termination(doctor_response):
                    break
        except Exception as e:
            safe_msg = _sanitize_error(str(e))
            _logger.error("Simulation error: %s", safe_msg)
            # Warn (not a bare gr.Error call, which is a no-op) and keep the partial dialogue visible.
            gr.Warning(f"Simulation error: {safe_msg}")
            yield _yield()
    finally:
        _rate_limiter.release_auto_slot(client_key)
def back_to_mode_from_auto(agent):
    """Leave the auto-simulation view and return to mode selection.

    Clears the patient agent's conversation history when an agent is present.

    Returns:
        A 4-tuple for ``[auto_chatbot, auto_recap, mode_section, auto_section]``:
        an empty dialogue, an empty recap, mode_section shown, and auto_section
        hidden.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    show_mode = gr.update(visible=True)
    hide_auto = gr.update(visible=False)
    return [], "", show_mode, hide_auto
def back_to_mode_from_chat(agent):
    """Leave practice mode and return to mode selection.

    Clears the patient agent's conversation history when an agent is present.

    Returns:
        A 5-tuple for ``[chatbot, profile_display, chat_recap, mode_section,
        chat_section]``: an empty chat log, empty profile HTML, an empty recap,
        mode_section shown, and chat_section hidden.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    emptied = ([], "", "")  # chat log, profile html, chat recap
    return (*emptied, gr.update(visible=True), gr.update(visible=False))
def back_to_setup_from_auto(agent):
    """Leave the auto-simulation view and go all the way back to setup.

    Clears the patient agent's history when present. Empties the auto
    dialogue and recap, drops the agent and simulation config from session
    state, and shows only the setup section. The mode recap is emptied and the
    setup error panel is emptied and hidden.

    Returns:
        A 10-tuple for ``[auto_chatbot, auto_recap, patient_agent_state,
        sim_config_state, setup_section, mode_section, auto_section,
        chat_section, recap_display, setup_error_display]``.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    # Auto chatbot log, auto recap, patient agent state, sim config state.
    cleared = ([], "", None, None)
    sections = (
        gr.update(visible=True),   # setup_section
        gr.update(visible=False),  # mode_section
        gr.update(visible=False),  # auto_section
        gr.update(visible=False),  # chat_section
    )
    resets = (
        gr.update(value=""),                 # recap_display
        gr.update(value="", visible=False),  # setup_error_display
    )
    return cleared + sections + resets
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# NOTE(review): several UI strings below contain sequences such as "β", "π₯",
# "Β·" that look like UTF-8 arrows/emoji/dashes decoded with the wrong codec.
# They are kept verbatim here; confirm against the original source encoding.
with gr.Blocks(title="PatientSim", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    patient_agent_state = gr.State(None)     # patient agent produced by start_simulation
    sim_config_state = gr.State(None)        # session config dict produced by start_simulation
    selected_patient_state = gr.State(None)  # hadm_id of the currently selected patient card
    chat_history_state = gr.State([])        # practice-mode chat history

    # -- Intro section -------------------------------------------------------
    with gr.Column(visible=True) as intro_section:
        gr.Markdown(
            "# π₯ PatientSim β ED Consultation Demo\n\n"
            "**PatientSim** is a research framework for simulating realistic emergency department "
            "doctorβpatient interactions, presented at "
            "[NeurIPS 2025 (Datasets & Benchmarks)](https://openreview.net/forum?id=1THAjdP4QJ).\n\n"
            "Large language models act as patients with **controllable personas** β you choose their "
            "personality, language proficiency, medical history recall, and cognitive state β producing "
            "diverse and realistic consultation scenarios for training and evaluation.\n\n"
            "---\n\n"
            "### What you can do here\n\n"
            "| Mode | Description |\n"
            "|------|-------------|\n"
            "| π€ **Auto Simulation** | Watch an AI doctor conduct a full consultation with the simulated patient and arrive at a differential diagnosis β no input required. |\n"
            "| π©Ί **Practice Mode** | You play the doctor. Consult the simulated patient, gather medical history, and work toward a diagnosis yourself. |\n\n"
            "---\n\n"
            "### How to get started\n"
            "1. Click **Get Started** below.\n"
            "2. Enter your API key and select a patient case and persona.\n"
            "3. Choose a simulation mode and begin."
        )
        get_started_btn = gr.Button("Get Started β", variant="primary", size="lg")

    # -- Setup section -------------------------------------------------------
    with gr.Column(visible=False) as setup_section:
        gr.Markdown(
            "# π₯ PatientSim β Setup\n\n"
            "Configure your session below. Select the patient case you want to simulate, "
            "choose the AI model to power the patient, and define the patient's persona across "
            "four behavioral axes. When ready, click **Start Simulation**."
        )

        # Connection card: model choice + optional user-supplied API key.
        with gr.Group(elem_classes=["form-card"]):
            gr.Markdown(
                "**π Model & API Key**\n\n"
                "This demo runs on a shared API key with a limited number of free calls. "
                "If the free quota has been exhausted, please enter your own API key below "
                "(OpenAI or Google Gemini) to continue without restrictions. "
                "Your key is used only for this session and is never stored on our servers."
            )
            with gr.Row(equal_height=True):
                model_dd = gr.Dropdown(
                    choices=BACKEND_MODELS,
                    value=BACKEND_MODELS[0],
                    label="Model",
                    scale=1,
                )
                api_key_input = gr.Textbox(
                    label="API Key (optional)",
                    placeholder="Leave blank to use the shared demo key Β· sk-... or paste your Gemini key",
                    type="password",
                    scale=2,
                )

        # Patient Case card: one HTML card + one "Select" button per patient,
        # laid out three per row.
        with gr.Group(elem_classes=["form-card"]):
            gr.HTML("<span class='card-title'>π©Ί Patient Case</span>")
            _all_cards: list[tuple[str, str, gr.HTML]] = []  # (hadm_id, avatar_url, html_component)
            _all_select_btns: list[tuple[str, gr.Button]] = []  # (hadm_id, button)
            _patients = _sorted_patients()
            for _i in range(0, len(_patients), 3):
                with gr.Row(elem_classes=["patient-row"]):
                    for _j, _p in enumerate(_patients[_i:_i + 3]):
                        _idx = _i + _j
                        _avatar = _PATIENT_AVATAR_URLS.get(_p["hadm_id"], _DOCTOR_AVATAR)
                        with gr.Column(scale=1, elem_classes=["patient-card-column"]):
                            _card_html = gr.HTML(
                                value=_build_single_card_html(_p, False, _avatar),
                                elem_classes=["patient-card-html"],
                            )
                            _select_btn = gr.Button("Select", size="sm", variant="secondary")
                        _all_cards.append((_p["hadm_id"], _avatar, _card_html))
                        _all_select_btns.append((_p["hadm_id"], _select_btn))

        # Persona card: four radios, each with a tip panel under it.
        with gr.Group(elem_classes=["form-card"]):
            gr.HTML("<span class='card-title'>π Patient Persona</span>")
            # Row 1: Personality + CEFR
            with gr.Row(equal_height=True):
                with gr.Column(min_width=200, elem_classes=["persona-cell"]):
                    personality_dd = gr.Radio(
                        choices=PERSONALITY_CHOICES,
                        value="plain",
                        label="Personality",
                        elem_classes=["compact-radio", "personality-radio"],
                    )
                    personality_tip = gr.HTML(
                        value=_make_tip_html(PERSONALITY_TIPS, "plain"),
                        elem_classes=["tip-html"],
                    )
                with gr.Column(min_width=200, elem_classes=["persona-cell"]):
                    cefr_radio = gr.Radio(
                        choices=CEFR_CHOICES,
                        value="B",
                        label="Language Proficiency (CEFR)",
                        elem_classes=["compact-radio"],
                    )
                    cefr_tip = gr.HTML(
                        value=_make_tip_html(CEFR_TIPS, "B"),
                        elem_classes=["tip-html"],
                    )
            # Row 2: Medical History Recall + Cognitive Confusion
            with gr.Row(equal_height=True):
                with gr.Column(min_width=200, elem_classes=["persona-cell"]):
                    recall_radio = gr.Radio(
                        choices=RECALL_CHOICES,
                        value="high",
                        label="Medical History Recall",
                        elem_classes=["compact-radio"],
                    )
                    recall_tip = gr.HTML(
                        value=_make_tip_html(RECALL_TIPS, "high"),
                        elem_classes=["tip-html"],
                    )
                with gr.Column(min_width=200, elem_classes=["persona-cell"]):
                    confusion_radio = gr.Radio(
                        choices=CONFUSION_CHOICES,
                        value="normal",
                        label="Cognitive Confusion",
                        elem_classes=["compact-radio"],
                    )
                    confusion_tip = gr.HTML(
                        value=_make_tip_html(CONFUSION_TIPS, "normal"),
                        elem_classes=["tip-html"],
                    )

        start_btn = gr.Button("βΆ Start Simulation", variant="primary", size="lg", elem_id="start-btn")
        setup_error_display = gr.HTML("", visible=False)

    # -- Mode selection section ----------------------------------------------
    with gr.Column(visible=False) as mode_section:
        gr.Markdown(
            "## Choose Simulation Mode\n\n"
            "Your patient agent is ready. Review your configuration below, then pick a mode. "
            "**Auto Simulation** runs a fully automated consultation so you can observe the "
            "system in action. **Practice Mode** puts you in the doctor's seat for hands-on "
            "training."
        )
        recap_display = gr.HTML()
        with gr.Row(equal_height=True):
            # Auto simulation card
            with gr.Column(elem_classes=["mode-card"], min_width=280):
                gr.Markdown(
                    "### π€ Auto Simulation\n"
                    "Watch a fully automated consultation between an AI doctor and "
                    "the simulated patient. The doctor conducts the interview and "
                    "arrives at a differential diagnosis β no input required."
                )
                auto_btn = gr.Button("Run Auto Simulation", variant="primary")
            # Manual practice card
            with gr.Column(elem_classes=["mode-card"], min_width=280):
                gr.Markdown(
                    "### π©Ί Practice Mode\n"
                    "You play the doctor. Choose how much patient information "
                    "to display on the side panel."
                )
                profile_mode_radio = gr.Radio(
                    choices=[
                        ("Full Profile β all case details visible", "full"),
                        ("Basic Info only β practice without prior knowledge", "blind"),
                    ],
                    value="full",
                    label="Patient Profile Visibility",
                    elem_classes=["compact-radio"],
                )
                manual_btn = gr.Button("Start Practice", variant="primary")
        with gr.Row():
            back_from_mode_btn = gr.Button("β Back to Setup")

    # -- Auto simulation section ---------------------------------------------
    with gr.Column(visible=False) as auto_section:
        with gr.Row():
            back_from_auto_btn = gr.Button("β Back to Setup")
            back_from_auto_to_mode_btn = gr.Button("β Mode Selection")
        gr.Markdown(
            "### π€ Auto Simulation\n\n"
            "An AI doctor is conducting the consultation. The doctor will ask questions, "
            "gather medical history, and conclude with a differential diagnosis. "
            "**Doctor** (user) messages appear on the right; **Patient** (assistant) responses on the left."
        )
        with gr.Row(equal_height=False):
            # Left column: session recap. This is the only auto_recap component;
            # a second, orphaned gr.HTML() used to be created above this row.
            with gr.Column(scale=1, min_width=280):
                auto_recap = gr.HTML()
            with gr.Column(scale=2):
                auto_chatbot = gr.Chatbot(
                    label="DoctorβPatient Dialogue",
                    height=700,
                    show_label=True,
                    avatar_images=(
                        _DOCTOR_AVATAR,
                        list(_PATIENT_AVATAR_URLS.values())[0],
                    ),
                    placeholder="Run the auto simulation to see the dialogue here.",
                )

    # -- Manual chat section -------------------------------------------------
    with gr.Column(visible=False) as chat_section:
        gr.Markdown(
            "### π©Ί Practice Mode\n\n"
            "You are the attending physician. Type your questions and responses in the box "
            "below to consult the simulated patient. Use the patient profile panel on the "
            "left for reference. Try to gather sufficient history to reach a differential diagnosis."
        )
        with gr.Row():
            back_from_chat_btn = gr.Button("β Back to Setup", scale=1)
            back_from_chat_to_mode_btn = gr.Button("β Mode Selection", scale=1)
            reset_btn = gr.Button("βΊ Reset Conversation", scale=1)
        with gr.Row(equal_height=False):
            with gr.Column(scale=1, min_width=280):
                profile_display = gr.HTML(elem_classes=["profile-display"])
                chat_recap = gr.HTML()
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Consultation",
                    height=700,
                    show_label=True,
                    avatar_images=(
                        _DOCTOR_AVATAR,
                        list(_PATIENT_AVATAR_URLS.values())[0],  # overridden dynamically on start
                    ),
                    placeholder=(
                        "The conversation will appear here. "
                        "Type your message below to begin."
                    ),
                )
                with gr.Row():
                    msg_box = gr.Textbox(
                        placeholder="Type your message as the doctorβ¦",
                        label="Doctor's Message",
                        lines=1,
                        max_lines=5,
                        scale=5,
                        show_label=False,
                        container=False,
                    )
                    send_btn = gr.Button(
                        "Send", variant="primary", scale=1, min_width=80
                    )

    # -- Event wiring --------------------------------------------------------
    # Intro -> Setup
    get_started_btn.click(
        fn=go_to_setup,
        outputs=[intro_section, setup_section],
    )

    # Patient card selection: one handler per button. The factory binds
    # target_id per button, avoiding the late-binding closure pitfall.
    def _make_select_fn(target_id: str):
        def _fn():
            # Re-render every card, highlighting only the selected one, then
            # store the selected hadm_id in selected_patient_state.
            html_updates = [
                gr.update(value=_build_single_card_html(PATIENT_DICT[hid], hid == target_id, av))
                for hid, av, _ in _all_cards
            ]
            return html_updates + [target_id]
        return _fn

    for _hadm_id, _select_btn in _all_select_btns:
        _select_btn.click(
            fn=_make_select_fn(_hadm_id),
            outputs=[c for _, _, c in _all_cards] + [selected_patient_state],
        )

    # Tooltip updates
    personality_dd.change(
        fn=lambda v: _make_tip_html(PERSONALITY_TIPS, v),
        inputs=[personality_dd],
        outputs=[personality_tip],
    )
    cefr_radio.change(
        fn=lambda v: _make_tip_html(CEFR_TIPS, v),
        inputs=[cefr_radio],
        outputs=[cefr_tip],
    )
    recall_radio.change(
        fn=lambda v: _make_tip_html(RECALL_TIPS, v),
        inputs=[recall_radio],
        outputs=[recall_tip],
    )
    confusion_radio.change(
        fn=lambda v: _make_tip_html(CONFUSION_TIPS, v),
        inputs=[confusion_radio],
        outputs=[confusion_tip],
    )

    # Start simulation -> mode selection
    start_btn.click(
        fn=start_simulation,
        inputs=[selected_patient_state, model_dd, cefr_radio, personality_dd, recall_radio, confusion_radio, api_key_input],
        outputs=[patient_agent_state, sim_config_state, setup_section, mode_section, recap_display, setup_error_display],
    )

    # Back to setup from mode selection
    back_from_mode_btn.click(
        fn=back_to_setup,
        inputs=[patient_agent_state],
        outputs=[patient_agent_state, sim_config_state, setup_section, mode_section, chat_section, auto_section, recap_display, setup_error_display],
    )

    # Auto simulation (streamed generator). Both "back" buttons cancel a run in progress.
    auto_event = auto_btn.click(
        fn=start_auto,
        inputs=[patient_agent_state, sim_config_state],
        outputs=[auto_chatbot, mode_section, auto_section, chat_section, auto_recap],
    )
    back_from_auto_to_mode_btn.click(
        fn=back_to_mode_from_auto,
        inputs=[patient_agent_state],
        outputs=[auto_chatbot, auto_recap, mode_section, auto_section],
        cancels=[auto_event],
    )
    back_from_auto_btn.click(
        fn=back_to_setup_from_auto,
        inputs=[patient_agent_state],
        outputs=[auto_chatbot, auto_recap, patient_agent_state, sim_config_state, setup_section, mode_section, auto_section, chat_section, recap_display, setup_error_display],
        cancels=[auto_event],
    )

    # Manual practice
    manual_btn.click(
        fn=start_manual,
        inputs=[profile_mode_radio, patient_agent_state, sim_config_state],
        outputs=[profile_display, chat_recap, chatbot, chat_history_state, mode_section, auto_section, chat_section],
    )
    chat_event_send = send_btn.click(
        fn=chat,
        inputs=[msg_box, chat_history_state, patient_agent_state, sim_config_state],
        outputs=[chatbot, chat_history_state, msg_box],
    )
    chat_event_submit = msg_box.submit(
        fn=chat,
        inputs=[msg_box, chat_history_state, patient_agent_state, sim_config_state],
        outputs=[chatbot, chat_history_state, msg_box],
    )
    back_from_chat_to_mode_btn.click(
        fn=back_to_mode_from_chat,
        inputs=[patient_agent_state],
        outputs=[chatbot, profile_display, chat_recap, mode_section, chat_section],
        cancels=[chat_event_send, chat_event_submit],
    )
    back_from_chat_btn.click(
        fn=back_to_setup_from_chat,
        inputs=[patient_agent_state],
        outputs=[chatbot, profile_display, chat_recap, patient_agent_state, sim_config_state, setup_section, mode_section, auto_section, chat_section, recap_display, setup_error_display],
        cancels=[chat_event_send, chat_event_submit],
    )
    # NOTE(review): in Gradio the `js` function's return value is forwarded as
    # the handler's inputs rather than acting as a cancel signal, so returning
    # null on "Cancel" may not stop reset_chat from running. Verify the
    # confirm/cancel behaviour in the deployed Gradio version.
    reset_btn.click(
        fn=reset_chat,
        inputs=[patient_agent_state],
        outputs=[chatbot, chat_history_state],
        js="() => { if (!confirm('Reset the entire conversation? This cannot be undone.')) return null; }",
    )
if __name__ == "__main__":
    # Serve the Blocks app on every network interface, fixed port 7860.
    demo.launch(server_port=7860, server_name="0.0.0.0")