dek924 commited on
Commit
38b2a99
Β·
1 Parent(s): f64ac5a

feat: app initialize

Browse files
Files changed (2) hide show
  1. app.py +499 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import copy
4
+ import gradio as gr
5
+ from patientsim import PatientAgent
6
+
7
+
8
+ # ---------------------------------------------------------------------------
9
+ # Constants
10
+ # ---------------------------------------------------------------------------
11
+ CEFR_CHOICES = [
12
+ ("A β€” Beginner", "A"),
13
+ ("B β€” Intermediate", "B"),
14
+ ("C β€” Advanced", "C"),
15
+ ]
16
+ PERSONALITY_CHOICES = [
17
+ ("Neutral", "plain"),
18
+ ("Talkative", "verbose"),
19
+ ("Distrustful", "distrust"),
20
+ ("Pleasing", "pleasing"),
21
+ ("Impatient", "impatient"),
22
+ ("Overanxious", "overanxious"),
23
+ ]
24
+ RECALL_CHOICES = [
25
+ ("No History", "no_history"),
26
+ ("Low", "low"),
27
+ ("High", "high"),
28
+ ]
29
+ CONFUSION_CHOICES = [
30
+ ("Normal", "normal"),
31
+ ("Moderate", "moderate"),
32
+ ("High", "high"),
33
+ ]
34
+ BACKEND_MODELS = [
35
+ "gemini-2.5-flash",
36
+ "gemini-3-flash-preview",
37
+ "gemini-3.1-flash-lite-preview",
38
+ ]
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # Patient data
42
+ # ---------------------------------------------------------------------------
43
+ _DATA_PATH = os.path.join(os.path.dirname(__file__), "demo", "data", "demo_data.json")
44
+ with open(_DATA_PATH) as _f:
45
+ PATIENT_DATA: list[dict] = json.load(_f)
46
+ PATIENT_DICT: dict[str, dict] = {p["hadm_id"]: p for p in PATIENT_DATA}
47
+
48
+
49
+ def _patient_choices() -> list[tuple[str, str]]:
50
+ choices = []
51
+ for idx, p in enumerate(sorted(PATIENT_DATA, key=lambda x: x["hadm_id"])):
52
+ label = (
53
+ f"{idx + 1}. Age {p['age']} | {p['gender']} "
54
+ f"| {p.get('diagnosis', 'Unknown')}"
55
+ )
56
+ choices.append((label, p["hadm_id"]))
57
+ return choices
58
+
59
+
60
+ # ---------------------------------------------------------------------------
61
+ # HTML helpers
62
+ # ---------------------------------------------------------------------------
63
+ def _row(label: str, val) -> str:
64
+ val = str(val) if val not in (None, "") else "N/A"
65
+ return (
66
+ f"<tr>"
67
+ f"<td style='padding:3px 8px;white-space:nowrap;vertical-align:top'><b>{label}:</b></td>"
68
+ f"<td style='padding:3px 8px'>{val}</td>"
69
+ f"</tr>"
70
+ )
71
+
72
+
73
+ def _section(title: str, rows_html: str) -> str:
74
+ return (
75
+ f"<div style='font-weight:bold;color:#2c3e50;border-bottom:1px solid #ccc;"
76
+ f"margin:10px 0 4px;padding-bottom:2px'>{title}</div>"
77
+ f"<table style='border-collapse:collapse;font-size:13px;margin-bottom:4px'>"
78
+ f"{rows_html}</table>"
79
+ )
80
+
81
+
82
+ def build_profile_html(p: dict) -> str:
83
+ basic = (
84
+ "<h3 style='margin:0 0 6px'>Patient Profile</h3>"
85
+ f"<table style='border-collapse:collapse;font-size:14px;margin-bottom:4px'>"
86
+ f"{_row('Age', p.get('age'))}"
87
+ f"{_row('Gender', p.get('gender'))}"
88
+ f"{_row('Race', p.get('race'))}"
89
+ f"{_row('Arrival Transport', p.get('arrival_transport'))}"
90
+ f"</table>"
91
+ )
92
+ social = _section(
93
+ "Social History",
94
+ _row("Tobacco", p.get("tobacco"))
95
+ + _row("Alcohol", p.get("alcohol"))
96
+ + _row("Illicit Drug", p.get("illicit_drug"))
97
+ + _row("Exercise", p.get("exercise"))
98
+ + _row("Marital Status", p.get("marital_status"))
99
+ + _row("Children", p.get("children"))
100
+ + _row("Living Situation", p.get("living_situation"))
101
+ + _row("Occupation", p.get("occupation"))
102
+ + _row("Insurance", p.get("insurance")),
103
+ )
104
+ history = _section(
105
+ "Previous Medical History",
106
+ _row("Allergies", p.get("allergies"))
107
+ + _row("Family History", p.get("family_medical_history"))
108
+ + _row("Medical Devices", p.get("medical_device"))
109
+ + _row("Prior Medical History", p.get("medical_history")),
110
+ )
111
+ visit = _section(
112
+ "Current Visit",
113
+ _row("Present Illness (+)", p.get("present_illness_positive"))
114
+ + _row("Present Illness (βˆ’)", p.get("present_illness_negative"))
115
+ + _row("Chief Complaint", p.get("chiefcomplaint"))
116
+ + _row("Pain (0–10)", p.get("pain"))
117
+ + _row("Medications", p.get("medication"))
118
+ + _row("Disposition", p.get("disposition"))
119
+ + _row("Diagnosis", p.get("diagnosis")),
120
+ )
121
+ return f"<div style='font-family:Noto Sans KR,Noto Sans,Malgun Gothic,Apple SD Gothic Neo,sans-serif;font-size:14px;line-height:1.5'>{basic}{social}{history}{visit}</div>"
122
+
123
+
124
+ # ---------------------------------------------------------------------------
125
+ # Custom CSS
126
+ # ---------------------------------------------------------------------------
127
+ CUSTOM_CSS = """
128
+ /* ── Global gothic (sans-serif) font ───────────────────────────── */
129
+ @import url('https://fonts.googleapis.com/css2?family=Noto+Sans+KR:wght@300;400;500;600;700&family=Noto+Sans:wght@300;400;500;600;700&display=swap');
130
+
131
+ *, *::before, *::after {
132
+ font-family: 'Noto Sans KR', 'Noto Sans', 'Malgun Gothic', 'Apple SD Gothic Neo',
133
+ 'Segoe UI', sans-serif !important;
134
+ }
135
+ .prose h1, .prose h2, .prose h3, .prose h4,
136
+ .markdown-text h1, .markdown-text h2, .markdown-text h3, .markdown-text h4 {
137
+ font-family: 'Noto Sans KR', 'Noto Sans', 'Malgun Gothic', 'Apple SD Gothic Neo',
138
+ 'Segoe UI', sans-serif !important;
139
+ }
140
+
141
+ /* ── Compact radio pills ────────────────────────────────────────── */
142
+ .compact-radio .wrap {
143
+ gap: 6px !important;
144
+ }
145
+ .compact-radio label {
146
+ padding: 6px 14px !important;
147
+ border-radius: 8px !important;
148
+ font-size: 13px !important;
149
+ min-width: unset !important;
150
+ }
151
+
152
+ /* ── Tooltip descriptions ───────────────────────────────────────── */
153
+ .option-desc {
154
+ font-size: 12px;
155
+ color: var(--body-text-color-subdued);
156
+ margin-top: 2px;
157
+ line-height: 1.4;
158
+ }
159
+ """
160
+
161
+ # Tooltip maps for descriptions shown below each radio group
162
+ PERSONALITY_TIPS = {
163
+ "plain": "No strong emotions or noticeable behavior.",
164
+ "verbose": "Speaks a lot, gives highly detailed responses.",
165
+ "distrust": "Questions the doctor's expertise and care.",
166
+ "pleasing": "Overly positive, tends to minimize problems.",
167
+ "impatient": "Easily irritated and lacks patience.",
168
+ "overanxious": "Expresses concern beyond what is typical.",
169
+ }
170
+ RECALL_TIPS = {
171
+ "no_history": "No significant past medical history.",
172
+ "low": "Often forgets even major medical events.",
173
+ "high": "Usually recalls medical events accurately.",
174
+ }
175
+ CONFUSION_TIPS = {
176
+ "normal": "Clear mental status.",
177
+ "moderate": "Somewhat confused and off-topic at times.",
178
+ "high": "Highly dazed and extremely confused.",
179
+ }
180
+ CEFR_TIPS = {
181
+ "A": "Can make simple sentences.",
182
+ "B": "Can have daily conversations.",
183
+ "C": "Can freely use advanced medical terms.",
184
+ }
185
+
186
+
187
+ def _make_tip_html(tips: dict, selected: str) -> str:
188
+ desc = tips.get(selected, "")
189
+ return f"<p class='option-desc'>{desc}</p>" if desc else ""
190
+
191
+
192
+ # ---------------------------------------------------------------------------
193
+ # Callbacks
194
+ # ---------------------------------------------------------------------------
195
+ def start_simulation(
196
+ hadm_id: str,
197
+ api_key_input: str,
198
+ model: str,
199
+ cefr: str,
200
+ personality: str,
201
+ recall: str,
202
+ confusion: str,
203
+ ):
204
+ # Keep-current helper: returns all outputs unchanged so button resets cleanly
205
+ def _no_change():
206
+ return (
207
+ gr.update(), # patient_agent_state (keep)
208
+ gr.update(), # profile_display (keep)
209
+ gr.update(), # chatbot (keep)
210
+ gr.update(visible=True), # setup_section stays visible
211
+ gr.update(visible=False),# chat_section stays hidden
212
+ )
213
+
214
+ if not hadm_id:
215
+ gr.Warning("Please select a patient first.")
216
+ return _no_change()
217
+
218
+ # Resolve API key: user input β†’ env var
219
+ api_key = (
220
+ api_key_input.strip()
221
+ or os.environ.get("GOOGLE_API_KEY", "")
222
+ or os.environ.get("GENAI_API_KEY", "")
223
+ )
224
+ use_vertex = bool(os.environ.get("GOOGLE_PROJECT_ID", ""))
225
+
226
+ if not api_key and not use_vertex:
227
+ gr.Warning(
228
+ "API key is required. Enter your Google AI API key above, "
229
+ "or set the GOOGLE_API_KEY environment variable."
230
+ )
231
+ return _no_change()
232
+
233
+ patient = copy.deepcopy(PATIENT_DICT[hadm_id])
234
+
235
+ try:
236
+ agent = PatientAgent(
237
+ model=model,
238
+ visit_type="emergency_department",
239
+ personality=personality,
240
+ recall_level=recall,
241
+ confusion_level=confusion,
242
+ lang_proficiency_level=cefr,
243
+ api_key=api_key if not use_vertex else None,
244
+ use_vertex=use_vertex,
245
+ temperature=0.7,
246
+ num_word_sample=10,
247
+ random_seed=42,
248
+ log_verbose=False,
249
+ **patient,
250
+ )
251
+ except Exception as e:
252
+ gr.Warning(f"Failed to initialize patient agent: {e}")
253
+ return _no_change()
254
+
255
+ profile_html = build_profile_html(patient)
256
+ return (
257
+ agent,
258
+ profile_html,
259
+ [], # empty chat history
260
+ gr.update(visible=False), # hide setup section
261
+ gr.update(visible=True), # show chat section
262
+ )
263
+
264
+
265
+ def chat(message: str, history: list, agent):
266
+ if agent is None:
267
+ raise gr.Error("No simulation running. Please start a simulation first.")
268
+ if not message.strip():
269
+ return history, ""
270
+
271
+ response = agent(user_prompt=message, using_multi_turn=True, verbose=False)
272
+ history = history + [
273
+ {"role": "user", "content": message},
274
+ {"role": "assistant", "content": response},
275
+ ]
276
+ return history, ""
277
+
278
+
279
+ def reset_chat(agent):
280
+ if agent is not None:
281
+ agent.reset_history(verbose=False)
282
+ return []
283
+
284
+
285
+ def return_to_setup():
286
+ return (
287
+ None, # clear agent state
288
+ [], # clear chat
289
+ "", # clear profile html
290
+ gr.update(visible=True), # show setup section
291
+ gr.update(visible=False), # hide chat section
292
+ )
293
+
294
+
295
+ # ---------------------------------------------------------------------------
296
+ # UI
297
+ # ---------------------------------------------------------------------------
298
+ with gr.Blocks(title="PatientSim", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
299
+ patient_agent_state = gr.State(None)
300
+
301
+ # ── Header ──────────────────────────────────────────────────────────────
302
+ gr.Markdown(
303
+ "# πŸ₯ PatientSim β€” ED Consultation Demo\n"
304
+ "An interactive simulator for realistic doctor–patient interactions "
305
+ "([NeurIPS 2025](https://openreview.net/forum?id=1THAjdP4QJ)). "
306
+ "Configure a patient and persona below, then converse as the doctor."
307
+ )
308
+
309
+ # ── Setup section ────────────────────────────────────────────────────────
310
+ with gr.Column(visible=True) as setup_section:
311
+
312
+ # ── Row 1: Patient & Model & API Key ─────────────────────────────
313
+ with gr.Group():
314
+ gr.Markdown("#### πŸ”‘ Connection")
315
+ with gr.Row():
316
+ api_key_input = gr.Textbox(
317
+ label="Google AI API Key",
318
+ placeholder="Enter your API key (or set GOOGLE_API_KEY env var)",
319
+ type="password",
320
+ scale=2,
321
+ )
322
+ model_dd = gr.Dropdown(
323
+ choices=BACKEND_MODELS,
324
+ value=BACKEND_MODELS[0],
325
+ label="Backend Model",
326
+ scale=1,
327
+ )
328
+
329
+ # ── Row 2: Patient selector ──────────────────────────────────────
330
+ with gr.Group():
331
+ gr.Markdown("#### 🩺 Patient Case")
332
+ patient_dd = gr.Dropdown(
333
+ choices=_patient_choices(),
334
+ label="Select a patient case",
335
+ )
336
+
337
+ # ── Row 3: Persona settings (2 Γ— 2 grid) ────────────────────────
338
+ with gr.Group():
339
+ gr.Markdown("#### 🎭 Patient Persona")
340
+ with gr.Row():
341
+ with gr.Column(min_width=200):
342
+ personality_radio = gr.Radio(
343
+ choices=PERSONALITY_CHOICES,
344
+ value="plain",
345
+ label="Personality",
346
+ elem_classes=["compact-radio"],
347
+ )
348
+ personality_tip = gr.HTML(
349
+ value=_make_tip_html(PERSONALITY_TIPS, "plain")
350
+ )
351
+ with gr.Column(min_width=200):
352
+ cefr_radio = gr.Radio(
353
+ choices=CEFR_CHOICES,
354
+ value="C",
355
+ label="Language Proficiency (CEFR)",
356
+ elem_classes=["compact-radio"],
357
+ )
358
+ cefr_tip = gr.HTML(
359
+ value=_make_tip_html(CEFR_TIPS, "C")
360
+ )
361
+ with gr.Row():
362
+ with gr.Column(min_width=200):
363
+ recall_radio = gr.Radio(
364
+ choices=RECALL_CHOICES,
365
+ value="no_history",
366
+ label="Medical History Recall",
367
+ elem_classes=["compact-radio"],
368
+ )
369
+ recall_tip = gr.HTML(
370
+ value=_make_tip_html(RECALL_TIPS, "no_history")
371
+ )
372
+ with gr.Column(min_width=200):
373
+ confusion_radio = gr.Radio(
374
+ choices=CONFUSION_CHOICES,
375
+ value="normal",
376
+ label="Cognitive Confusion",
377
+ elem_classes=["compact-radio"],
378
+ )
379
+ confusion_tip = gr.HTML(
380
+ value=_make_tip_html(CONFUSION_TIPS, "normal")
381
+ )
382
+
383
+ start_btn = gr.Button("β–Ά Start Simulation", variant="primary", size="lg")
384
+
385
+ # ── Chat section ─────────────────────────────────────────────────────────
386
+ with gr.Column(visible=False) as chat_section:
387
+ with gr.Row():
388
+ return_btn = gr.Button("← Return to Patient Selection", scale=1)
389
+ reset_btn = gr.Button("β†Ί Reset Conversation", scale=1)
390
+
391
+ with gr.Row(equal_height=False):
392
+ with gr.Column(scale=1, min_width=280):
393
+ profile_display = gr.HTML()
394
+
395
+ with gr.Column(scale=2):
396
+ chatbot = gr.Chatbot(
397
+ label="Consultation",
398
+ height=560,
399
+ show_label=True,
400
+ avatar_images=(
401
+ "https://cdn-icons-png.flaticon.com/512/4140/4140048.png",
402
+ "https://cdn-icons-png.flaticon.com/512/3774/3774299.png",
403
+ ),
404
+ placeholder=(
405
+ "The conversation will appear here. "
406
+ "Type your message below to begin."
407
+ ),
408
+ )
409
+ with gr.Row():
410
+ msg_box = gr.Textbox(
411
+ placeholder="Type your message as the doctor…",
412
+ label="Doctor's Message",
413
+ lines=1,
414
+ scale=5,
415
+ show_label=False,
416
+ container=False,
417
+ )
418
+ send_btn = gr.Button(
419
+ "Send", variant="primary", scale=1, min_width=80
420
+ )
421
+
422
+ # ── Event wiring ─────────────────────────────────────────────────────────
423
+
424
+ # Tooltip updates when radio selection changes
425
+ personality_radio.change(
426
+ fn=lambda v: _make_tip_html(PERSONALITY_TIPS, v),
427
+ inputs=[personality_radio],
428
+ outputs=[personality_tip],
429
+ )
430
+ cefr_radio.change(
431
+ fn=lambda v: _make_tip_html(CEFR_TIPS, v),
432
+ inputs=[cefr_radio],
433
+ outputs=[cefr_tip],
434
+ )
435
+ recall_radio.change(
436
+ fn=lambda v: _make_tip_html(RECALL_TIPS, v),
437
+ inputs=[recall_radio],
438
+ outputs=[recall_tip],
439
+ )
440
+ confusion_radio.change(
441
+ fn=lambda v: _make_tip_html(CONFUSION_TIPS, v),
442
+ inputs=[confusion_radio],
443
+ outputs=[confusion_tip],
444
+ )
445
+
446
+ # Start simulation
447
+ _start_outputs = [
448
+ patient_agent_state,
449
+ profile_display,
450
+ chatbot,
451
+ setup_section,
452
+ chat_section,
453
+ ]
454
+ start_btn.click(
455
+ fn=start_simulation,
456
+ inputs=[
457
+ patient_dd,
458
+ api_key_input,
459
+ model_dd,
460
+ cefr_radio,
461
+ personality_radio,
462
+ recall_radio,
463
+ confusion_radio,
464
+ ],
465
+ outputs=_start_outputs,
466
+ )
467
+
468
+ # Chat
469
+ _chat_outputs = [chatbot, msg_box]
470
+ send_btn.click(
471
+ fn=chat,
472
+ inputs=[msg_box, chatbot, patient_agent_state],
473
+ outputs=_chat_outputs,
474
+ )
475
+ msg_box.submit(
476
+ fn=chat,
477
+ inputs=[msg_box, chatbot, patient_agent_state],
478
+ outputs=_chat_outputs,
479
+ )
480
+
481
+ # Reset & return
482
+ reset_btn.click(
483
+ fn=reset_chat,
484
+ inputs=[patient_agent_state],
485
+ outputs=[chatbot],
486
+ )
487
+
488
+ _return_outputs = [
489
+ patient_agent_state,
490
+ chatbot,
491
+ profile_display,
492
+ setup_section,
493
+ chat_section,
494
+ ]
495
+ return_btn.click(fn=return_to_setup, outputs=_return_outputs)
496
+
497
+
498
+ if __name__ == "__main__":
499
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ patientsim
2
+ gradio
3
+ jsonlines
4
+ pyyaml
5
+ numpy