PatientSim / app.py
dek924's picture
fix: doc/patient avatar location
ab96eda
import os
import re
import html
import json
import copy
import logging
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from patientsim import PatientAgent, DoctorAgent
from patientsim.utils.common_utils import detect_ed_termination
from rate_limiter import RateLimiter, get_client_key
# Load variables from a .env file, searching from the current working directory.
# override=False: variables already present in the real process environment take
# precedence over values in the file. This runs at import time, so the later
# os.environ lookups for API keys (in the callbacks) see the .env values.
load_dotenv(find_dotenv(usecwd=True), override=False)
# Process-wide logging configuration; StreamHandler() writes to stderr by default.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    handlers=[logging.StreamHandler()],
)
_logger = logging.getLogger("patientsim.app")
# ---------------------------------------------------------------------------
# Rate limiter (singleton — shared across all Gradio worker threads)
# ---------------------------------------------------------------------------
# A single module-level instance, so every request served by this process is
# checked against the same quotas. NOTE(review): thread-safety is assumed to be
# handled inside RateLimiter itself — not visible from this file.
_rate_limiter = RateLimiter()
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Option lists are (display label, value) pairs. build_recap_html maps a
# selected value back to its display label by scanning these lists.
# Language proficiency levels (CEFR-style A/B/C).
CEFR_CHOICES: list[tuple[str, str]] = [
    ("A β€” Beginner", "A"),
    ("B β€” Intermediate", "B"),
    ("C β€” Advanced", "C"),
]
# Patient personality presets; values are passed to PatientAgent(personality=...).
PERSONALITY_CHOICES: list[tuple[str, str]] = [
    ("Neutral", "plain"),
    ("Talkative", "verbose"),
    ("Distrustful", "distrust"),
    ("Pleasing", "pleasing"),
    ("Impatient", "impatient"),
    ("Overanxious", "overanxious"),
]
# How well the patient recalls their medical history (PatientAgent recall_level).
RECALL_CHOICES: list[tuple[str, str]] = [
    ("Low", "low"),
    ("High", "high"),
]
# Degree of cognitive confusion (PatientAgent confusion_level).
CONFUSION_CHOICES: list[tuple[str, str]] = [
    ("Normal", "normal"),
    ("High", "high"),
]
# Model names accepted by start_simulation. Names containing "gpt" use the
# OpenAI key; all others use the Google key.
BACKEND_MODELS: list[str] = [
    "gemini-3.1-flash-lite-preview",
    "gemini-3-flash-preview",
    "gemini-2.5-flash",
    "gpt-5.4-nano",
    "gpt-5.4-mini",
    "gpt-5.4",
]
# Maximum doctor/patient rounds in auto mode (also DoctorAgent's max_inferences).
MAX_AUTO_INFERENCES = 10
# Longest chat message accepted from the user in practice mode.
MAX_MESSAGE_CHARS = 2000
# ---------------------------------------------------------------------------
# Patient data
# ---------------------------------------------------------------------------
# Demo patient records, loaded once at import time from demo/data/demo_data.json
# next to this file. Any failure to read or parse the file aborts start-up with
# a RuntimeError that names the path.
_DATA_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "demo", "data", "demo_data.json")
try:
    # JSON is UTF-8 by specification; pin the encoding so loading does not depend
    # on the platform's locale default (e.g. cp1252 on Windows).
    with open(_DATA_PATH, encoding="utf-8") as _f:
        PATIENT_DATA: list[dict] = json.load(_f)
except (OSError, ValueError) as _e:
    # OSError covers missing/unreadable files; ValueError covers both
    # json.JSONDecodeError and UnicodeDecodeError.
    raise RuntimeError(f"Failed to load patient data from {_DATA_PATH}: {_e}") from _e
# Lookup table: hadm_id -> patient record (every record must carry "hadm_id").
PATIENT_DICT: dict[str, dict] = {p["hadm_id"]: p for p in PATIENT_DATA}
# Avatar for the doctor side of the chatbot (first entry of avatar_images).
_DOCTOR_AVATAR = "https://cdn-icons-png.flaticon.com/512/3774/3774299.png"
# Per-patient avatar URLs matched to demographics (gender + age group), keyed by
# hadm_id. Lookups in this module fall back to the first entry for unknown ids.
_PATIENT_AVATAR_URLS = {
    "patient_MI": "https://cdn-icons-png.flaticon.com/512/5488/5488324.png",  # 64yo Male — elderly man
    "patient_PNA": "https://cdn-icons-png.flaticon.com/512/4140/4140051.png",  # 47yo Female — middle-aged woman
    "patient_UTI": "https://cdn-icons-png.flaticon.com/512/4140/4140047.png",  # 29yo Female — young woman
}
# Patients ordered by hadm_id; computed once at import and reused.
_SORTED_PATIENTS: list[dict] = sorted(PATIENT_DATA, key=lambda x: x["hadm_id"])
def _sorted_patients() -> list[dict]:
    """Return the demo patients ordered by ``hadm_id``.

    The ordering is precomputed at import time and the same list object is
    handed out on every call, so callers should treat it as read-only.
    """
    ordered = _SORTED_PATIENTS
    return ordered
def _build_single_card_html(p: dict, selected: bool, avatar_url: str) -> str:
"""Build HTML for a single patient card (no select button β€” handled by gr.Button)."""
border_color = "#3b82f6" if selected else "#e5e7eb"
bg_color = "#eff6ff" if selected else "#ffffff"
shadow = "0 0 0 2px #3b82f6" if selected else "0 1px 3px rgba(0,0,0,0.06)"
age = html.escape(str(p.get("age", "?")))
gender = html.escape(str(p.get("gender", "Unknown")))
diagnosis = html.escape(str(p.get("diagnosis", "Unknown")))
chief = html.escape(str(p.get("chiefcomplaint", "β€”")))
transport = html.escape(str(p.get("arrival_transport", "β€”")))
safe_avatar = html.escape(avatar_url)
return (
f"<div style='background:{bg_color};border:2px solid {border_color};"
f"border-radius:12px;padding:24px 20px;box-shadow:{shadow};transition:all 0.15s ease'>"
f"<div style='text-align:center;margin-bottom:10px'>"
f"<img src='{safe_avatar}' style='width:56px;height:56px;border-radius:50%;"
f"background:#f3f4f6;padding:4px' alt='Patient'>"
f"</div>"
f"<div style='font-size:13px;color:#6b7280;line-height:1.6'>"
f"<div><b>Age:</b> {age} Β· <b>Gender:</b> {gender}</div>"
f"<div><b>Chief Complaint:</b> {chief}</div>"
f"<div><b>Transport:</b> {transport}</div>"
f"<div><b>Dx:</b> {diagnosis}</div>"
f"</div>"
f"</div>"
)
# ---------------------------------------------------------------------------
# HTML helpers
# ---------------------------------------------------------------------------
def build_recap_html(hadm_id: str, model: str, cefr: str, personality: str, recall: str, confusion: str) -> str:
    """Render the "Simulation Configuration" recap shown once setup succeeds.

    Each option value is translated back to its display label using the
    matching ``*_CHOICES`` list; a value with no matching entry is shown as-is.

    Args:
        hadm_id: Patient id; an unknown id renders the patient as a dash.
        model: Backend model name, shown verbatim.
        cefr, personality, recall, confusion: Selected option values.

    Returns:
        HTML fragment: a titled panel holding a two-column grid of
        label/value cards, with all text HTML-escaped.
    """
    def _label_for(choices, value):
        # First display label whose value matches; otherwise the raw value.
        for label, choice_value in choices:
            if choice_value == value:
                return label
        return value

    def _card(label, value) -> str:
        return (
            "<div style='padding:10px 14px;background:var(--background-fill-primary,#fff);"
            "border:1px solid var(--border-color-primary);border-radius:8px'>"
            "<div style='font-size:11px;font-weight:600;letter-spacing:0.06em;"
            "text-transform:uppercase;color:var(--body-text-color-subdued);margin-bottom:4px'>"
            f"{html.escape(str(label))}</div>"
            f"<div style='font-size:13px;font-weight:500;line-height:1.4'>{html.escape(str(value))}</div>"
            "</div>"
        )

    patient = PATIENT_DICT.get(hadm_id, {})
    if patient:
        patient_label = f"Age {patient.get('age')} Β· {patient.get('gender')} Β· {patient.get('diagnosis', 'Unknown')}"
    else:
        patient_label = "β€”"

    rows = (
        ("Patient", patient_label),
        ("Personality", _label_for(PERSONALITY_CHOICES, personality)),
        ("Model", model),
        ("Language Proficiency", _label_for(CEFR_CHOICES, cefr)),
        ("Medical History Recall", _label_for(RECALL_CHOICES, recall)),
        ("Cognitive Confusion", _label_for(CONFUSION_CHOICES, confusion)),
    )
    grid_items = "".join(_card(label, value) for label, value in rows)
    return (
        "<div style='background:var(--color-accent-soft,#f0f7ff);"
        "border:1px solid var(--border-color-primary);border-radius:10px;"
        "padding:16px 20px;margin-bottom:8px'>"
        "<div style='font-weight:600;font-size:14px;margin-bottom:12px'>πŸ“‹ Simulation Configuration</div>"
        f"<div style='display:grid;grid-template-columns:1fr 1fr;gap:8px'>{grid_items}</div>"
        "</div>"
    )
# ---------------------------------------------------------------------------
# Profile HTML helpers (card-based layout matching recap style)
# ---------------------------------------------------------------------------
def _profile_item(label: str, val) -> str:
"""Build a single key-value item inside a profile card."""
safe_val = html.escape(str(val)) if val not in (None, "") else "N/A"
safe_label = html.escape(str(label))
return (
f"<div style='display:flex;gap:6px;padding:5px 0;"
f"border-bottom:1px solid #f3f4f6;font-size:13px;line-height:1.5'>"
f"<span style='color:#6b7280;white-space:nowrap;min-width:110px;"
f"font-weight:500'>{safe_label}</span>"
f"<span style='color:#1f2937;flex:1'>{safe_val}</span>"
f"</div>"
)
def _profile_card(icon: str, title: str, items_html: str, accent: str = "#3b82f6") -> str:
"""Build a styled card section for the profile panel."""
return (
f"<div style='background:var(--background-fill-primary,#fff);"
f"border:1px solid var(--border-color-primary,#e5e7eb);border-radius:10px;"
f"padding:14px 16px;margin-bottom:10px'>"
f"<div style='display:flex;align-items:center;gap:8px;margin-bottom:10px;"
f"padding-bottom:8px;border-bottom:2px solid {accent}'>"
f"<span style='font-size:16px'>{icon}</span>"
f"<span style='font-size:13px;font-weight:600;letter-spacing:0.03em;"
f"text-transform:uppercase;color:{accent}'>{html.escape(title)}</span>"
f"</div>"
f"{items_html}"
f"</div>"
)
def build_profile_html(p: dict) -> str:
    """Render the full patient profile panel (practice mode, full-profile view).

    The panel shows the patient's avatar and four cards in order:
    demographics, social history, previous medical history and current visit.
    Missing fields render as ``N/A``. The panel scrolls once it exceeds 400px.
    """
    fallback_avatar = list(_PATIENT_AVATAR_URLS.values())[0]
    avatar_src = html.escape(_PATIENT_AVATAR_URLS.get(p.get("hadm_id", ""), fallback_avatar))

    # (icon, title, accent colour, ((row label, patient key), ...)), rendered in order.
    sections = (
        ("πŸ‘€", "Demographics", "#3b82f6", (
            ("Age", "age"),
            ("Gender", "gender"),
            ("Race", "race"),
            ("Transport", "arrival_transport"),
        )),
        ("🏠", "Social History", "#10b981", (
            ("Tobacco", "tobacco"),
            ("Alcohol", "alcohol"),
            ("Illicit Drug", "illicit_drug"),
            ("Exercise", "exercise"),
            ("Marital Status", "marital_status"),
            ("Children", "children"),
            ("Living Situation", "living_situation"),
            ("Occupation", "occupation"),
            ("Insurance", "insurance"),
        )),
        ("πŸ“‹", "Previous Medical History", "#f59e0b", (
            ("Allergies", "allergies"),
            ("Family History", "family_medical_history"),
            ("Medical Devices", "medical_device"),
            ("Prior History", "medical_history"),
        )),
        ("🩺", "Current Visit", "#ef4444", (
            ("Present Illness (+)", "present_illness_positive"),
            ("Present Illness (βˆ’)", "present_illness_negative"),
            ("Chief Complaint", "chiefcomplaint"),
            ("Pain (0–10)", "pain"),
            ("Medications", "medication"),
            ("Disposition", "disposition"),
            ("Diagnosis", "diagnosis"),
        )),
    )
    cards = "".join(
        _profile_card(
            icon,
            title,
            "".join(_profile_item(label, p.get(key)) for label, key in rows),
            accent=accent,
        )
        for icon, title, accent, rows in sections
    )

    header = (
        "<div style='text-align:center;margin-bottom:14px'>"
        f"<img src='{avatar_src}' style='width:52px;height:52px;border-radius:50%;"
        "background:#f3f4f6;padding:3px;margin-bottom:6px' alt='Patient'>"
        "<div style='font-size:15px;font-weight:600;color:#1f2937'>Patient Profile</div>"
        "</div>"
    )
    return (
        "<div style='font-family:Noto Sans KR,Noto Sans,Malgun Gothic,Apple SD Gothic Neo,sans-serif;"
        "font-size:14px;line-height:1.5;background:var(--color-accent-soft,#f0f7ff);"
        "border:1px solid var(--border-color-primary,#e5e7eb);border-radius:12px;"
        "padding:18px 16px;max-height:400px;overflow-y:auto'>"
        f"{header}{cards}"
        "</div>"
    )
def build_blind_profile_html(p: dict) -> str:
    """Render the reduced ("blind") profile panel for practice mode.

    Only age, gender, race and arrival transport are shown. Everything else is
    replaced by a placeholder telling the user to find it out during the
    interview.
    """
    avatar_url = _PATIENT_AVATAR_URLS.get(
        p.get("hadm_id", ""), list(_PATIENT_AVATAR_URLS.values())[0]
    )
    demographic_rows = "".join(
        _profile_item(label, p.get(key))
        for label, key in (
            ("Age", "age"),
            ("Gender", "gender"),
            ("Race", "race"),
            ("Transport", "arrival_transport"),
        )
    )
    pieces = [
        # Outer panel.
        "<div style='font-family:Noto Sans KR,Noto Sans,Malgun Gothic,Apple SD Gothic Neo,sans-serif;"
        "font-size:14px;line-height:1.5;background:var(--color-accent-soft,#f0f7ff);"
        "border:1px solid var(--border-color-primary,#e5e7eb);border-radius:12px;"
        "padding:18px 16px'>",
        # Header: avatar, title and a note that details are withheld.
        "<div style='text-align:center;margin-bottom:14px'>",
        f"<img src='{html.escape(avatar_url)}' style='width:52px;height:52px;border-radius:50%;"
        "background:#f3f4f6;padding:3px;margin-bottom:6px' alt='Patient'>",
        "<div style='font-size:15px;font-weight:600;color:#1f2937'>Patient Info</div>",
        "<div style='font-size:12px;color:#9ca3af;font-style:italic;margin-top:4px'>"
        "Basic demographics only β€” gather the rest through consultation.</div>",
        "</div>",
        _profile_card("πŸ‘€", "Demographics", demographic_rows, accent="#3b82f6"),
        # Placeholder standing in for the hidden sections.
        "<div style='background:var(--background-fill-primary,#fff);"
        "border:1px dashed #d1d5db;border-radius:10px;padding:16px;text-align:center;"
        "color:#9ca3af;font-size:13px'>"
        "πŸ” Additional information is hidden.<br>"
        "Interview the patient to uncover their history."
        "</div>",
        "</div>",
    ]
    return "".join(pieces)
# ---------------------------------------------------------------------------
# Custom CSS
# ---------------------------------------------------------------------------
# Global stylesheet passed to gr.Blocks(css=...). Browsers silently drop any
# declaration with an invalid value, so every non-zero length carries a unit.
CUSTOM_CSS = """
/* β”€β”€ Global gothic (sans-serif) font β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ */
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+KR:wght@300;400;500;600;700&family=Noto+Sans:wght@300;400;500;600;700&display=swap');
*, *::before, *::after {
    font-family: 'Noto Sans KR', 'Noto Sans', 'Malgun Gothic', 'Apple SD Gothic Neo',
                 'Segoe UI', sans-serif !important;
}
.prose h1, .prose h2, .prose h3, .prose h4,
.markdown-text h1, .markdown-text h2, .markdown-text h3, .markdown-text h4 {
    font-family: 'Noto Sans KR', 'Noto Sans', 'Malgun Gothic', 'Apple SD Gothic Neo',
                 'Segoe UI', sans-serif !important;
}
/* β”€β”€ Page background β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ */
body, .gradio-container, .main, footer {
    background-color: #f3f4f6 !important;
}
/* β”€β”€ All form labels β€” no blue/accent background β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ */
label span,
label > span,
.block label span,
.label-wrap span,
.gradio-container label span,
.gradio-dropdown label span,
.wrap > label > span,
[class*="label"] > span {
    background: transparent !important;
    background-color: transparent !important;
}
/* β”€β”€ White shadow cards for form sections β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ */
.form-card {
    background: #ffffff !important;
    border-radius: 14px !important;
    box-shadow: 0 1px 3px rgba(0,0,0,0.08), 0 0 0 1px rgba(0,0,0,0.05) !important;
    border: none !important;
    padding: 16px 20px 20px !important;
}
.form-card > .gap,
.form-card > div {
    background: transparent !important;
    border: none !important;
    gap: 12px !important;
    padding: 8px !important;
}
/* β”€β”€ Card section title β€” vertically centered with background β”€β”€β”€β”€ */
.card-title {
    font-size: 18px;
    font-weight: 600;
    color: #374151;
    padding: 6px 0;
    margin-bottom: 4px;
    border-bottom: 1px solid #e5e7eb;
    display: flex;
    align-items: center;
    justify-content: center;
    min-height: 32px;
    line-height: 1;
}
/* β”€β”€ Tooltip descriptions β€” plain text inside the white cell β”€β”€β”€β”€β”€ */
.option-desc {
    font-size: 15px;
    color: #4b5563;
    margin-top: 8px;
    line-height: 1.4;
}
/* β”€β”€ Transparent HTML tip containers β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ */
.tip-html,
.tip-html > div,
.tip-html > .prose {
    background: transparent !important;
    border: none !important;
    padding: 0 !important;
    margin: 0 !important;
}
/* β”€β”€ CEFR radio & Persona grid β€” thin border on buttons β”€β”€β”€β”€β”€β”€β”€β”€β”€ */
.compact-radio .wrap {
    gap: 6px !important;
    background: transparent !important;
}
.compact-radio label {
    padding: 5px 14px !important;
    border-radius: 8px !important;
    font-size: 13px !important;
    min-width: unset !important;
    background: transparent !important;
    border: 1px solid #d1d5db !important;
}
.compact-radio label:has(input:checked) {
    border-color: #3b82f6 !important;
    background: #eff6ff !important;
}
.persona-grid {
    display: grid !important;
    grid-template-columns: 1fr 1fr !important;
    gap: 16px !important;
}
/* β”€β”€ Persona cell β€” white box containing buttons + description β”€β”€β”€β”€β”€ */
.persona-cell {
    background: #ffffff !important;
    border: 1px solid #e5e7eb !important;
    border-radius: 10px !important;
    padding: 14px 16px !important;
    display: flex !important;
    flex-direction: column !important;
}
/* Force options to use black bold font without background */
.persona-cell label span,
.persona-cell span {
    color: black !important;
    font-weight: bold !important;
    background: transparent !important;
}
/* β”€β”€ Personality radio β€” 2-column grid for 6 choices β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ */
.personality-radio .wrap {
    display: grid !important;
    grid-template-columns: 1fr 1fr !important;
    gap: 6px !important;
}
.persona-cell .compact-radio {
    margin-bottom: 4px !important;
}
/* β”€β”€ Mode cards β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ */
.mode-card {
    background: #ffffff !important;
    border: 1px solid #e5e7eb !important;
    border-radius: 14px !important;
    box-shadow: 0 1px 3px rgba(0,0,0,0.06) !important;
    padding: 20px !important;
}
/* β”€β”€ Start simulation button shadow β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ */
#start-btn > button {
    box-shadow: 0 4px 14px rgba(59, 130, 246, 0.30) !important;
    font-size: 15px !important;
}
/* β”€β”€ Patient card HTML wrapper β€” remove Gradio's default padding β”€β”€ */
.patient-card-html > div,
.patient-card-html .prose {
    padding: 10px !important;
    margin: 0 !important;
    border: none !important;
    background: #ffffff !important;
}
/* β”€β”€ Patient profile panel β€” strip wrapper chrome β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€ */
.profile-display,
.profile-display > div,
.profile-display .prose {
    background: transparent !important;
    border: none !important;
    padding: 0 !important;
    margin: 0 !important;
}
/* White background + padding for each patient column cell */
.form-card .gradio-column {
    background: #ffffff !important;
    border-radius: 10px !important;
    border: 1px solid #e5e7eb !important;
    padding: 12px !important;
    gap: 8px !important;
}
/* Spacing between columns and from the card border */
.form-card .gradio-row {
    gap: 12px !important;
    padding: 12px 12px !important;
    background: transparent !important;
}
/* Fallback using data-testid if gradio-column class doesn't apply */
.form-card [data-testid="column"] {
    background: #ffffff !important;
    border-radius: 10px !important;
    border: 1px solid #e5e7eb !important;
    padding: 12px !important;
}
.form-card [data-testid="row"] {
    gap: 24px !important;
    padding: 12px 8px !important;
}
.patient-row {
    gap: 16px !important;
    margin-bottom: 16px !important;
    padding: 0 24px !important;
}
/* Gradio 6 wraps row flex content in an inner div β€” target it directly */
.patient-row > div {
    gap: 16px !important;
    display: flex !important;
    flex-wrap: wrap !important;
    align-items: stretch !important;
}
.patient-card-column {
    flex: 1 1 calc(33.333% - 12px) !important;
    flex-direction: column !important;
    min-width: 100px !important;
    max-width: calc(33.333% - 12px) !important;
    display: flex !important;
    height: 100% !important;
    gap: 0 !important;
}
.patient-card-html {
    flex: 1 !important;
    height: 100% !important;
}
"""
# Tooltip maps: option value -> one-line description shown below each radio group.
PERSONALITY_TIPS = dict(
    plain="No strong emotions or noticeable behavior.",
    verbose="Speaks a lot, gives highly detailed responses.",
    distrust="Questions the doctor's expertise and care.",
    pleasing="Overly positive, tends to minimize problems.",
    impatient="Easily irritated and lacks patience.",
    overanxious="Expresses concern beyond what is typical.",
)
RECALL_TIPS = dict(
    low="Often forgets even major medical events.",
    high="Usually recalls medical events accurately.",
)
CONFUSION_TIPS = dict(
    normal="Clear mental status.",
    high="Highly dazed and extremely confused.",
)
CEFR_TIPS = dict(
    A="Can make simple sentences.",
    B="Can have daily conversations.",
    C="Can freely use advanced medical terms.",
)
def _make_tip_html(tips: dict, selected: str) -> str:
desc = tips.get(selected, "")
return f"<p class='option-desc'>{desc}</p>" if desc else ""
# ---------------------------------------------------------------------------
# Callbacks β€” Setup
# ---------------------------------------------------------------------------
def go_to_setup():
    """Leave the intro screen: hide the first output and reveal the second.

    Returns:
        ``(gr.update(visible=False), gr.update(visible=True))``. Presumably
        wired to (intro_section, setup_section); the event binding is not
        visible here.
    """
    hide = gr.update(visible=False)
    show = gr.update(visible=True)
    return hide, show
def _sanitize_error(msg: str) -> str:
"""Remove API key patterns and internal paths from error messages."""
msg = re.sub(r'sk-[A-Za-z0-9_-]{10,}', '[REDACTED]', msg)
msg = re.sub(r'AIza[A-Za-z0-9_-]{20,}', '[REDACTED]', msg)
msg = re.sub(r'AQ\.[A-Za-z0-9_-]{10,}', '[REDACTED]', msg)
msg = re.sub(r'/(?:home|data|tmp|var|usr|etc)/[^\s\'"]+', '[PATH_REDACTED]', msg)
return msg
def _setup_error(msg: str):
    """Build ``start_simulation``'s output tuple for a setup failure.

    The setup screen is left as it is; only the inline error banner is filled
    in and made visible. ``msg`` is sanitized (secrets/paths redacted) and then
    HTML-escaped before it is embedded.
    """
    safe_msg = html.escape(_sanitize_error(msg))
    banner = (
        "<div style='background:#fef2f2;border:1px solid #fca5a5;border-radius:8px;"
        f"padding:10px 14px;color:#b91c1c;font-size:13px'>⚠️ {safe_msg}</div>"
    )
    # No change to: patient_agent_state, sim_config_state, setup_section
    # (already visible), mode_section (already hidden), recap_display.
    untouched = tuple(gr.update() for _ in range(5))
    return untouched + (gr.update(value=banner, visible=True),)  # setup_error_display
def start_simulation(
    hadm_id: str,
    model: str,
    cefr: str,
    personality: str,
    recall: str,
    confusion: str,
    user_api_key: str = "",
    request: gr.Request = None,
):
    """Validate the setup choices, create the PatientAgent and move to mode selection.

    Args:
        hadm_id: Selected patient id (key into PATIENT_DICT).
        model: Backend model name; must be one of BACKEND_MODELS.
        cefr, personality, recall, confusion: Persona option values, passed to
            PatientAgent as lang_proficiency_level / personality / recall_level
            / confusion_level.
        user_api_key: Optional key typed in by the user. Blank or ``None`` means
            the shared key from the environment is used.
        request: Gradio request; accepted for interface compatibility, unused here.

    Returns:
        6-tuple for (patient_agent_state, sim_config_state, setup_section,
        mode_section, recap_display, setup_error_display). On any failure the
        ``_setup_error`` tuple is returned instead, so the setup screen stays
        visible with an inline error.
    """
    if not hadm_id:
        return _setup_error("Please select a patient first.")
    if model not in BACKEND_MODELS:
        return _setup_error("Invalid model selection.")

    # A cleared Textbox or a direct API call may deliver None rather than "".
    own_key = (user_api_key or "").strip()
    is_openai = "gpt" in model.lower()
    if own_key:
        api_key = own_key
    elif is_openai:
        api_key = os.environ.get("OPENAI_API_KEY", "")
    else:
        api_key = os.environ.get("GENAI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "")
    if not api_key:
        key_name = "OPENAI_API_KEY" if is_openai else "GENAI_API_KEY / GOOGLE_API_KEY"
        return _setup_error(
            f"API key not configured ({key_name}). "
            "Run with your own key, or contact the demo administrator."
        )

    patient = PATIENT_DICT.get(hadm_id)
    if patient is None:
        return _setup_error("Invalid patient selection. Please select a valid patient from the list.")
    # Work on a private copy so nothing downstream can mutate the shared record.
    patient = copy.deepcopy(patient)
    try:
        agent = PatientAgent(
            model=model,
            visit_type="emergency_department",
            personality=personality,
            recall_level=recall,
            confusion_level=confusion,
            lang_proficiency_level=cefr,
            api_key=api_key,
            temperature=0.7,
            num_word_sample=10,
            random_seed=42,
            log_verbose=False,
            use_vertex=True,  # Force using Vertex AI for all models to ensure consistent behavior and better error handling
            **patient,
        )
    except Exception as e:
        safe_err = _sanitize_error(str(e))
        _logger.error("Failed to initialize patient agent: %s", safe_err)
        return _setup_error(f"Failed to initialize patient agent: {safe_err}")

    recap = build_recap_html(hadm_id, model, cefr, personality, recall, confusion)
    sim_config = {
        "patient": patient,
        "model": model,
        "recap_html": recap,
        "user_api_key": own_key,  # empty string = using shared demo key
    }
    return (
        agent,
        sim_config,
        gr.update(visible=False),  # hide setup_section
        gr.update(visible=True),  # show mode_section
        gr.update(value=recap),  # recap_display
        gr.update(visible=False),  # hide setup_error_display
    )
def back_to_setup(agent):
    """Return from mode selection to the setup screen.

    Clears the agent's conversation history (when an agent exists), drops both
    session states, shows only the setup section, and blanks the recap and
    error displays.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    cleared_states = (None, None)  # patient_agent_state, sim_config_state
    visibility = (
        gr.update(visible=True),  # setup_section
        gr.update(visible=False),  # mode_section
        gr.update(visible=False),  # chat_section
        gr.update(visible=False),  # auto_section
    )
    displays = (
        gr.update(value=""),  # recap_display
        gr.update(value="", visible=False),  # setup_error_display
    )
    return cleared_states + visibility + displays
# ---------------------------------------------------------------------------
# Callbacks β€” Manual practice mode
# ---------------------------------------------------------------------------
def start_manual(profile_mode: str, agent, sim_config: dict):
    """Enter practice mode: reset the patient's conversation and show the chat UI.

    Args:
        profile_mode: ``"full"`` shows the complete case profile; any other
            value shows only basic demographics.
        agent: The session's PatientAgent (from state).
        sim_config: Session config built by ``start_simulation``.

    Returns:
        7-tuple: (profile html, recap html, chatbot update with per-patient
        avatars, cleared chat log, mode_section hidden, auto_section hidden,
        chat_section shown).

    Raises:
        gr.Error: If the session state is missing.
    """
    if agent is None or sim_config is None:
        # gr.Warning is a notification function rather than an exception class,
        # so raising its result fails with a TypeError; gr.Error is the
        # exception Gradio expects here.
        raise gr.Error("Session expired. Please restart.")
    agent.reset_history(verbose=False)
    patient = sim_config["patient"]
    if profile_mode == "full":
        profile_html = build_profile_html(patient)
    else:
        profile_html = build_blind_profile_html(patient)
    # Unknown ids fall back to the first configured patient avatar.
    patient_avatar = _PATIENT_AVATAR_URLS.get(
        patient.get("hadm_id", ""), list(_PATIENT_AVATAR_URLS.values())[0]
    )
    return (
        profile_html,
        sim_config.get("recap_html", ""),  # show recap in left panel
        gr.update(value=[], avatar_images=(_DOCTOR_AVATAR, patient_avatar)),  # chatbot with matched avatar
        [],  # clear chat log
        gr.update(visible=False),  # hide mode_section
        gr.update(visible=False),  # hide auto_section
        gr.update(visible=True),  # show chat_section
    )
_INJECTION_PATTERNS = re.compile(
r'ignore\s+(?:\w+\s+){0,3}(instructions?|prompts?|rules?)'
r'|(?:forget|disregard)\s+(?:\w+\s+){0,3}instructions?'
r'|you\s+are\s+(?:now|actually)\s+(?:a\s+|an\s+)?(?:DAN|jailbreak)'
r'|reveal\s+(?:\w+\s+){0,2}(system\s+)?prompt'
r'|act\s+as\s+(?:if\s+)?you\s+(?:have\s+no|are\s+without)\s+(?:restrictions?|guidelines?)',
re.IGNORECASE,
)
def chat(message: str, history: list, agent, sim_config: dict, request: gr.Request = None):
    """Handle one practice-mode turn: validate the user's message and get the patient's reply.

    Args:
        message: Text typed by the user (playing the doctor).
        history: Current chat messages as ``{"role", "content"}`` dicts.
        agent: The session's PatientAgent.
        sim_config: Session config; a non-empty ``user_api_key`` switches
            rate limiting from per-client quotas to the global-capacity check.
        request: Gradio request, used to derive the rate-limit client key.

    Returns:
        ``(history, history, "")``: the updated message list twice (presumably
        for the chatbot and the chat-log state) and an empty string to clear
        the input box. Every return path yields this same 3-tuple shape.

    Raises:
        gr.Error: No simulation running, message too long, suspected prompt
            injection, rate or capacity limit reached, or the model call failed.
    """
    if agent is None:
        raise gr.Error("No simulation running. Please start a simulation first.")
    history = list(history or [])
    message = message or ""
    if not message.strip():
        # Must produce as many outputs as the success path below.
        return history, history, ""
    if len(message) > MAX_MESSAGE_CHARS:
        raise gr.Error(
            f"Message too long ({len(message)} characters). "
            f"Please keep messages under {MAX_MESSAGE_CHARS} characters."
        )
    if _INJECTION_PATTERNS.search(message):
        _logger.warning("Prompt injection attempt detected from key=%s", get_client_key(request))
        raise gr.Error("Invalid input detected. Please enter a valid clinical question.")

    using_own_key = bool(sim_config and sim_config.get("user_api_key"))
    client_key = get_client_key(request)
    if using_own_key:
        # Own-key users bypass per-IP quotas but still respect global capacity.
        allowed, limit_msg = _rate_limiter.check_global_capacity()
    else:
        allowed, limit_msg = _rate_limiter.check_chat_message(client_key)
    if not allowed:
        raise gr.Error(limit_msg)

    try:
        response = agent(user_prompt=message, using_multi_turn=True, verbose=False)
    except Exception as e:
        # Surface a readable, secret-free message instead of a raw traceback.
        safe_err = _sanitize_error(str(e))
        _logger.error("Patient agent call failed: %s", safe_err)
        raise gr.Error(f"Failed to get patient response: {safe_err}") from e

    history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response},
    ]
    return history, history, ""
def reset_chat(agent):
    """Start the practice conversation over.

    Clears the agent's conversation history (when an agent exists) and returns
    an empty chatbot update plus an empty chat-log list.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    empty_chatbot = gr.update(value=[])
    return empty_chatbot, []
def back_to_setup_from_chat(agent):
    """Leave practice mode and return all the way to the setup screen.

    Clears the agent's conversation history (when an agent exists), empties the
    chat log, profile and recap, drops both session states, and shows only the
    setup section with a blank recap and hidden error banner.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    cleared_chat_panel = ([], "", "")  # chat log, profile html, chat recap
    cleared_states = (None, None)  # patient_agent_state, sim_config_state
    visibility = (
        gr.update(visible=True),  # setup_section
        gr.update(visible=False),  # mode_section
        gr.update(visible=False),  # auto_section
        gr.update(visible=False),  # chat_section
    )
    setup_displays = (
        gr.update(value=""),  # recap_display
        gr.update(value="", visible=False),  # setup_error_display
    )
    return cleared_chat_panel + cleared_states + visibility + setup_displays
# ---------------------------------------------------------------------------
# Callbacks β€” Auto simulation mode
# ---------------------------------------------------------------------------
def _auto_fallback_outputs():
    """Build the 5-tuple that sends the user from auto mode back to mode selection.

    A fresh set of ``gr.update()`` objects is created on every call, in the
    order (auto_chatbot, mode_section, auto_section, chat_section, auto_recap).
    The chatbot and recap are left untouched, mode_section is shown, and the
    auto and chat sections are hidden.
    """
    untouched_chatbot = gr.update()
    show_modes = gr.update(visible=True)
    hide_auto = gr.update(visible=False)
    hide_chat = gr.update(visible=False)
    untouched_recap = gr.update()
    return untouched_chatbot, show_modes, hide_auto, hide_chat, untouched_recap
def start_auto(agent, sim_config: dict, request: gr.Request = None):
    """Run an automatic doctor-patient consultation, streaming each turn to the UI.

    This is a generator. Every ``yield`` is a 5-tuple for (auto_chatbot,
    mode_section, auto_section, chat_section, auto_recap).

    A DoctorAgent interviews the session's PatientAgent for at most
    MAX_AUTO_INFERENCES rounds. On the last round the doctor is told to give
    its top-5 differential diagnoses. The loop stops early when
    ``detect_ed_termination`` flags a doctor reply.

    Rate limiting: shared-key sessions go through ``check_auto_run``, own-key
    sessions through ``check_own_key_auto_run``. Once a check passes,
    ``release_auto_slot`` runs on every exit path, including errors and the
    generator being closed early.
    """
    client_key = get_client_key(request)
    if agent is None or sim_config is None:
        gr.Warning("Session expired. Please restart.")
        yield _auto_fallback_outputs()
        return

    using_own_key = bool(sim_config.get("user_api_key"))
    if using_own_key:
        # Own-key users bypass per-IP quotas but still enforce the concurrent
        # run cap and the hard global capacity limit.
        allowed, limit_msg = _rate_limiter.check_own_key_auto_run(client_key)
    else:
        allowed, limit_msg = _rate_limiter.check_auto_run(client_key)
    if not allowed:
        gr.Warning(limit_msg)
        yield _auto_fallback_outputs()
        return

    try:
        agent.reset_history(verbose=False)
        patient = sim_config["patient"]
        patient_avatar = _PATIENT_AVATAR_URLS.get(
            patient.get("hadm_id", ""), list(_PATIENT_AVATAR_URLS.values())[0]
        )
        # Switch to auto_section right away: empty chatbot with per-patient
        # avatars, plus the recap (set once here).
        yield (
            gr.update(value=[], avatar_images=(_DOCTOR_AVATAR, patient_avatar)),
            gr.update(visible=False),
            gr.update(visible=True),
            gr.update(visible=False),
            sim_config.get("recap_html", ""),
        )

        model = sim_config["model"]
        if using_own_key:
            api_key = sim_config["user_api_key"]
        elif "gpt" in model.lower():
            api_key = os.environ.get("OPENAI_API_KEY", "")
        else:
            api_key = os.environ.get("GENAI_API_KEY", "") or os.environ.get("GOOGLE_API_KEY", "")

        try:
            doctor = DoctorAgent(
                model=model,
                api_key=api_key,
                temperature=0.2,
                random_seed=42,
                max_inferences=MAX_AUTO_INFERENCES,
                use_vertex=True,
                **{k: patient[k] for k in ("age", "gender", "arrival_transport") if k in patient},
            )
        except Exception as e:
            safe_err = _sanitize_error(str(e))
            _logger.error("Failed to initialize doctor agent: %s", safe_err)
            # gr.Error only has an effect when raised. Here we want a toast plus a
            # graceful return to mode selection, so use gr.Warning.
            gr.Warning(f"Failed to initialize doctor agent: {safe_err}")
            yield _auto_fallback_outputs()
            return

        chat_history = []

        def _append(role: str, content: str):
            chat_history.append({"role": role, "content": content})

        def _yield():
            # Chatbot gets the running transcript; the recap is left untouched.
            return chat_history, gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update()

        try:
            # Doctor greets first; doctor turns use the "user" role.
            _append("user", doctor.doctor_greet)
            yield _yield()
            for inference_idx in range(MAX_AUTO_INFERENCES):
                is_last = inference_idx == MAX_AUTO_INFERENCES - 1
                # Patient responds to the last doctor message.
                patient_response = agent(
                    user_prompt=chat_history[-1]["content"],
                    using_multi_turn=True,
                    verbose=False,
                )
                _append("assistant", patient_response)
                yield _yield()
                # Doctor responds; on the final round, force a differential.
                doctor_input = chat_history[-1]["content"]
                if is_last:
                    doctor_input += "\nThis is the final turn. Now, you must provide your top 5 differential diagnoses."
                doctor_response = doctor(
                    user_prompt=doctor_input,
                    using_multi_turn=True,
                    verbose=False,
                )
                _append("user", doctor_response)
                yield _yield()
                if detect_ed_termination(doctor_response):
                    break
        except Exception as e:
            # Keep whatever transcript was produced and tell the user what failed.
            safe_err = _sanitize_error(str(e))
            _logger.error("Auto simulation failed: %s", safe_err)
            gr.Warning(f"Simulation error: {safe_err}")
            yield _yield()
    finally:
        _rate_limiter.release_auto_slot(client_key)
def back_to_mode_from_auto(agent):
    """Leave auto mode and return to mode selection.

    Clears the agent's conversation history (when an agent exists), empties the
    auto chatbot and recap, shows mode_section and hides auto_section.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    cleared_chatbot, cleared_recap = [], ""
    show_modes = gr.update(visible=True)
    hide_auto = gr.update(visible=False)
    return cleared_chatbot, cleared_recap, show_modes, hide_auto
def back_to_mode_from_chat(agent):
    """Leave practice mode and return to mode selection.

    Clears the agent's conversation history (when an agent exists), empties the
    chat log, profile and chat recap, shows mode_section and hides chat_section.
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    cleared_chat_panel = ([], "", "")  # chat log, profile html, chat recap
    visibility = (
        gr.update(visible=True),  # mode_section
        gr.update(visible=False),  # chat_section
    )
    return cleared_chat_panel + visibility
def back_to_setup_from_auto(agent):
    """Leave the auto-simulation view and return all the way to the setup form.

    Clears the auto transcript and recap, drops the patient agent and the
    simulation config from session state, and toggles section visibility so
    that only the setup section is shown.

    Args:
        agent: Patient agent held in ``patient_agent_state``, or ``None``.
            When present, its history is reset before the reference is dropped.

    Returns:
        A 10-tuple for the outputs (auto_chatbot, auto_recap,
        patient_agent_state, sim_config_state, setup_section, mode_section,
        auto_section, chat_section, recap_display, setup_error_display).
    """
    if agent is not None:
        agent.reset_history(verbose=False)
    cleared_transcript = ([], "")          # auto_chatbot, auto_recap
    cleared_session = (None, None)         # patient_agent_state, sim_config_state
    section_toggles = (
        gr.update(visible=True),           # setup_section
        gr.update(visible=False),          # mode_section
        gr.update(visible=False),          # auto_section
        gr.update(visible=False),          # chat_section
    )
    cleared_displays = (
        gr.update(value=""),               # recap_display
        gr.update(value="", visible=False),  # setup_error_display
    )
    return (*cleared_transcript, *cleared_session, *section_toggles, *cleared_displays)
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="PatientSim", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    # Per-session state, written by the event handlers wired at the bottom.
    patient_agent_state = gr.State(None)     # patient agent returned by start_simulation
    sim_config_state = gr.State(None)        # simulation config returned by start_simulation
    selected_patient_state = gr.State(None)  # hadm_id of the currently selected patient card
    chat_history_state = gr.State([])        # Practice Mode message history

    # Default patient avatar for the chatbots. Fall back to the doctor avatar
    # (as the card loop does) instead of raising IndexError when the mapping
    # is empty; next(iter(...)) also avoids materialising a throwaway list.
    _default_patient_avatar = next(iter(_PATIENT_AVATAR_URLS.values()), _DOCTOR_AVATAR)

    # ── Intro section ────────────────────────────────────────────────────────
    with gr.Column(visible=True) as intro_section:
        gr.Markdown(
            "# 🏥 PatientSim — ED Consultation Demo\n\n"
            "**PatientSim** is a research framework for simulating realistic emergency department "
            "doctor–patient interactions, presented at "
            "[NeurIPS 2025 (Datasets & Benchmarks)](https://openreview.net/forum?id=1THAjdP4QJ).\n\n"
            "Large language models act as patients with **controllable personas** — you choose their "
            "personality, language proficiency, medical history recall, and cognitive state — producing "
            "diverse and realistic consultation scenarios for training and evaluation.\n\n"
            "---\n\n"
            "### What you can do here\n\n"
            "| Mode | Description |\n"
            "|------|-------------|\n"
            "| 🤖 **Auto Simulation** | Watch an AI doctor conduct a full consultation with the simulated patient and arrive at a differential diagnosis — no input required. |\n"
            "| 🩺 **Practice Mode** | You play the doctor. Consult the simulated patient, gather medical history, and work toward a diagnosis yourself. |\n\n"
            "---\n\n"
            "### How to get started\n"
            "1. Click **Get Started** below.\n"
            "2. Enter your API key and select a patient case and persona.\n"
            "3. Choose a simulation mode and begin."
        )
        get_started_btn = gr.Button("Get Started →", variant="primary", size="lg")

    # ── Setup section ────────────────────────────────────────────────────────
    with gr.Column(visible=False) as setup_section:
        gr.Markdown(
            "# 🏥 PatientSim — Setup\n\n"
            "Configure your session below. Select the patient case you want to simulate, "
            "choose the AI model to power the patient, and define the patient's persona across "
            "four behavioral axes. When ready, click **Start Simulation**."
        )

        # ── Connection card ──────────────────────────────────────────────────
        with gr.Group(elem_classes=["form-card"]):
            gr.Markdown(
                "**🔑 Model & API Key**\n\n"
                "This demo runs on a shared API key with a limited number of free calls. "
                "If the free quota has been exhausted, please enter your own API key below "
                "(OpenAI or Google Gemini) to continue without restrictions. "
                "Your key is used only for this session and is never stored on our servers."
            )
            with gr.Row(equal_height=True):
                model_dd = gr.Dropdown(
                    choices=BACKEND_MODELS,
                    value=BACKEND_MODELS[0],
                    label="Model",
                    scale=1,
                )
                api_key_input = gr.Textbox(
                    label="API Key (optional)",
                    placeholder="Leave blank to use the shared demo key · sk-... or paste your Gemini key",
                    type="password",
                    scale=2,
                )

        # ── Patient Case card (one gr.Button per patient) ────────────────────
        with gr.Group(elem_classes=["form-card"]):
            gr.HTML("<span class='card-title'>🩺 Patient Case</span>")
            _all_cards: list[tuple[str, str, gr.HTML]] = []  # (hadm_id, avatar_url, html_component)
            _all_select_btns: list[tuple[str, gr.Button]] = []  # (hadm_id, button)
            _patients = _sorted_patients()
            # Lay the cards out three per row.
            for _i in range(0, len(_patients), 3):
                with gr.Row(elem_classes=["patient-row"]):
                    for _p in _patients[_i:_i + 3]:
                        _avatar = _PATIENT_AVATAR_URLS.get(_p["hadm_id"], _DOCTOR_AVATAR)
                        with gr.Column(scale=1, elem_classes=["patient-card-column"]):
                            _card_html = gr.HTML(
                                value=_build_single_card_html(_p, False, _avatar),
                                elem_classes=["patient-card-html"],
                            )
                            _select_btn = gr.Button("Select", size="sm", variant="secondary")
                        _all_cards.append((_p["hadm_id"], _avatar, _card_html))
                        _all_select_btns.append((_p["hadm_id"], _select_btn))

        # ── Persona card (redesigned) ────────────────────────────────────────
        with gr.Group(elem_classes=["form-card"]):
            gr.HTML("<span class='card-title'>🎭 Patient Persona</span>")
            # Row 1: Personality + CEFR
            with gr.Row(equal_height=True):
                with gr.Column(min_width=200, elem_classes=["persona-cell"]):
                    personality_dd = gr.Radio(
                        choices=PERSONALITY_CHOICES,
                        value="plain",
                        label="Personality",
                        elem_classes=["compact-radio", "personality-radio"],
                    )
                    personality_tip = gr.HTML(
                        value=_make_tip_html(PERSONALITY_TIPS, "plain"),
                        elem_classes=["tip-html"],
                    )
                with gr.Column(min_width=200, elem_classes=["persona-cell"]):
                    cefr_radio = gr.Radio(
                        choices=CEFR_CHOICES,
                        value="B",
                        label="Language Proficiency (CEFR)",
                        elem_classes=["compact-radio"],
                    )
                    cefr_tip = gr.HTML(
                        value=_make_tip_html(CEFR_TIPS, "B"),
                        elem_classes=["tip-html"],
                    )
            # Row 2: Medical History Recall + Cognitive Confusion (now radios)
            with gr.Row(equal_height=True):
                with gr.Column(min_width=200, elem_classes=["persona-cell"]):
                    recall_radio = gr.Radio(
                        choices=RECALL_CHOICES,
                        value="high",
                        label="Medical History Recall",
                        elem_classes=["compact-radio"],
                    )
                    recall_tip = gr.HTML(
                        value=_make_tip_html(RECALL_TIPS, "high"),
                        elem_classes=["tip-html"],
                    )
                with gr.Column(min_width=200, elem_classes=["persona-cell"]):
                    confusion_radio = gr.Radio(
                        choices=CONFUSION_CHOICES,
                        value="normal",
                        label="Cognitive Confusion",
                        elem_classes=["compact-radio"],
                    )
                    confusion_tip = gr.HTML(
                        value=_make_tip_html(CONFUSION_TIPS, "normal"),
                        elem_classes=["tip-html"],
                    )

        start_btn = gr.Button("▶ Start Simulation", variant="primary", size="lg", elem_id="start-btn")
        setup_error_display = gr.HTML("", visible=False)

    # ── Mode selection section ───────────────────────────────────────────────
    with gr.Column(visible=False) as mode_section:
        gr.Markdown(
            "## Choose Simulation Mode\n\n"
            "Your patient agent is ready. Review your configuration below, then pick a mode. "
            "**Auto Simulation** runs a fully automated consultation so you can observe the "
            "system in action. **Practice Mode** puts you in the doctor's seat for hands-on "
            "training."
        )
        recap_display = gr.HTML()
        with gr.Row(equal_height=True):
            # Auto simulation card
            with gr.Column(elem_classes=["mode-card"], min_width=280):
                gr.Markdown(
                    "### 🤖 Auto Simulation\n"
                    "Watch a fully automated consultation between an AI doctor and "
                    "the simulated patient. The doctor conducts the interview and "
                    "arrives at a differential diagnosis — no input required."
                )
                auto_btn = gr.Button("Run Auto Simulation", variant="primary")
            # Manual practice card
            with gr.Column(elem_classes=["mode-card"], min_width=280):
                gr.Markdown(
                    "### 🩺 Practice Mode\n"
                    "You play the doctor. Choose how much patient information "
                    "to display on the side panel."
                )
                profile_mode_radio = gr.Radio(
                    choices=[
                        ("Full Profile — all case details visible", "full"),
                        ("Basic Info only — practice without prior knowledge", "blind"),
                    ],
                    value="full",
                    label="Patient Profile Visibility",
                    elem_classes=["compact-radio"],
                )
                manual_btn = gr.Button("Start Practice", variant="primary")
        with gr.Row():
            back_from_mode_btn = gr.Button("← Back to Setup")

    # ── Auto simulation section ──────────────────────────────────────────────
    with gr.Column(visible=False) as auto_section:
        with gr.Row():
            back_from_auto_btn = gr.Button("← Back to Setup")
            back_from_auto_to_mode_btn = gr.Button("← Mode Selection")
        gr.Markdown(
            "### 🤖 Auto Simulation\n\n"
            "An AI doctor is conducting the consultation. The doctor will ask questions, "
            "gather medical history, and conclude with a differential diagnosis. "
            "**Doctor** (user) messages appear on the right; **Patient** (assistant) responses on the left."
        )
        # A single recap panel. A second, orphaned gr.HTML used to be created
        # here and then shadowed; it rendered an empty element nobody updated.
        with gr.Row(equal_height=False):
            with gr.Column(scale=1, min_width=280):
                auto_recap = gr.HTML()
            with gr.Column(scale=2):
                # NOTE(review): start_auto yields plain message lists to this
                # chatbot, so the patient avatar stays at this build-time default
                # regardless of the selected case. Confirm whether that is intended.
                auto_chatbot = gr.Chatbot(
                    label="Doctor–Patient Dialogue",
                    height=700,
                    show_label=True,
                    avatar_images=(
                        _DOCTOR_AVATAR,
                        _default_patient_avatar,
                    ),
                    placeholder="Run the auto simulation to see the dialogue here.",
                )

    # ── Manual chat section ──────────────────────────────────────────────────
    with gr.Column(visible=False) as chat_section:
        gr.Markdown(
            "### 🩺 Practice Mode\n\n"
            "You are the attending physician. Type your questions and responses in the box "
            "below to consult the simulated patient. Use the patient profile panel on the "
            "left for reference. Try to gather sufficient history to reach a differential diagnosis."
        )
        with gr.Row():
            back_from_chat_btn = gr.Button("← Back to Setup", scale=1)
            back_from_chat_to_mode_btn = gr.Button("← Mode Selection", scale=1)
            reset_btn = gr.Button("↺ Reset Conversation", scale=1)
        with gr.Row(equal_height=False):
            with gr.Column(scale=1, min_width=280):
                profile_display = gr.HTML(elem_classes=["profile-display"])
                chat_recap = gr.HTML()
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Consultation",
                    height=700,
                    show_label=True,
                    avatar_images=(
                        _DOCTOR_AVATAR,
                        _default_patient_avatar,  # overridden dynamically on start
                    ),
                    placeholder=(
                        "The conversation will appear here. "
                        "Type your message below to begin."
                    ),
                )
                with gr.Row():
                    msg_box = gr.Textbox(
                        placeholder="Type your message as the doctor…",
                        label="Doctor's Message",
                        lines=1,
                        max_lines=5,
                        scale=5,
                        show_label=False,
                        container=False,
                    )
                    send_btn = gr.Button(
                        "Send", variant="primary", scale=1, min_width=80
                    )

    # ── Event wiring ─────────────────────────────────────────────────────────
    # Intro → Setup
    get_started_btn.click(
        fn=go_to_setup,
        outputs=[intro_section, setup_section],
    )

    # Patient card selection (one handler per button)
    def _make_select_fn(target_id: str):
        """Build a click handler that highlights ``target_id``'s card.

        The factory binds ``target_id`` per button, which avoids the
        late-binding closure pitfall inside the loop below.
        """
        def _fn():
            # Re-render every card and mark only the target as selected, then
            # store the selected hadm_id in selected_patient_state.
            html_updates = [
                gr.update(value=_build_single_card_html(PATIENT_DICT[hid], hid == target_id, av))
                for hid, av, _ in _all_cards
            ]
            return html_updates + [target_id]
        return _fn

    for _hadm_id, _select_btn in _all_select_btns:
        _select_btn.click(
            fn=_make_select_fn(_hadm_id),
            outputs=[c for _, _, c in _all_cards] + [selected_patient_state],
        )

    # Tooltip updates
    personality_dd.change(
        fn=lambda v: _make_tip_html(PERSONALITY_TIPS, v),
        inputs=[personality_dd],
        outputs=[personality_tip],
    )
    cefr_radio.change(
        fn=lambda v: _make_tip_html(CEFR_TIPS, v),
        inputs=[cefr_radio],
        outputs=[cefr_tip],
    )
    recall_radio.change(
        fn=lambda v: _make_tip_html(RECALL_TIPS, v),
        inputs=[recall_radio],
        outputs=[recall_tip],
    )
    confusion_radio.change(
        fn=lambda v: _make_tip_html(CONFUSION_TIPS, v),
        inputs=[confusion_radio],
        outputs=[confusion_tip],
    )

    # Start simulation → mode selection
    start_btn.click(
        fn=start_simulation,
        inputs=[selected_patient_state, model_dd, cefr_radio, personality_dd, recall_radio, confusion_radio, api_key_input],
        outputs=[patient_agent_state, sim_config_state, setup_section, mode_section, recap_display, setup_error_display],
    )

    # Back to setup from mode selection
    back_from_mode_btn.click(
        fn=back_to_setup,
        inputs=[patient_agent_state],
        outputs=[patient_agent_state, sim_config_state, setup_section, mode_section, chat_section, auto_section, recap_display, setup_error_display],
    )

    # Auto simulation. Both back buttons cancel the streaming auto_event.
    auto_event = auto_btn.click(
        fn=start_auto,
        inputs=[patient_agent_state, sim_config_state],
        outputs=[auto_chatbot, mode_section, auto_section, chat_section, auto_recap],
    )
    back_from_auto_to_mode_btn.click(
        fn=back_to_mode_from_auto,
        inputs=[patient_agent_state],
        outputs=[auto_chatbot, auto_recap, mode_section, auto_section],
        cancels=[auto_event],
    )
    back_from_auto_btn.click(
        fn=back_to_setup_from_auto,
        inputs=[patient_agent_state],
        outputs=[auto_chatbot, auto_recap, patient_agent_state, sim_config_state, setup_section, mode_section, auto_section, chat_section, recap_display, setup_error_display],
        cancels=[auto_event],
    )

    # Manual practice
    manual_btn.click(
        fn=start_manual,
        inputs=[profile_mode_radio, patient_agent_state, sim_config_state],
        outputs=[profile_display, chat_recap, chatbot, chat_history_state, mode_section, auto_section, chat_section],
    )
    chat_event_send = send_btn.click(
        fn=chat,
        inputs=[msg_box, chat_history_state, patient_agent_state, sim_config_state],
        outputs=[chatbot, chat_history_state, msg_box],
    )
    chat_event_submit = msg_box.submit(
        fn=chat,
        inputs=[msg_box, chat_history_state, patient_agent_state, sim_config_state],
        outputs=[chatbot, chat_history_state, msg_box],
    )
    back_from_chat_to_mode_btn.click(
        fn=back_to_mode_from_chat,
        inputs=[patient_agent_state],
        outputs=[chatbot, profile_display, chat_recap, mode_section, chat_section],
        cancels=[chat_event_send, chat_event_submit],
    )
    back_from_chat_btn.click(
        fn=back_to_setup_from_chat,
        inputs=[patient_agent_state],
        outputs=[chatbot, profile_display, chat_recap, patient_agent_state, sim_config_state, setup_section, mode_section, auto_section, chat_section, recap_display, setup_error_display],
        cancels=[chat_event_send, chat_event_submit],
    )
    # NOTE(review): the js step only changes what it returns when the dialog is
    # dismissed (null vs. undefined). Confirm that the installed Gradio version
    # actually skips the Python reset_chat call in that case.
    reset_btn.click(
        fn=reset_chat,
        inputs=[patient_agent_state],
        outputs=[chatbot, chat_history_state],
        js="() => { if (!confirm('Reset the entire conversation? This cannot be undone.')) return null; }",
    )
if __name__ == "__main__":
    # Bind to all interfaces (not just localhost) on a fixed port.
    _launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**_launch_kwargs)