spriambada3 commited on
Commit
2a26805
·
1 Parent(s): d18fef3

add ai diagnosis and medication interaction

Browse files
Files changed (3) hide show
  1. cdss.py +22 -0
  2. diagnosis.py +117 -0
  3. simulator.py +40 -2
cdss.py CHANGED
@@ -35,6 +35,7 @@ from simulator import (
35
  )
36
  from editor import editor_ui, save_rules
37
  from validator import validator_ui, test_condition, add_rule_to_set
 
38
 
39
 
40
  # --- Build UI ---
@@ -67,6 +68,14 @@ with gr.Blocks(
67
  add_rule_button,
68
  add_rule_status,
69
  ) = validator_ui()
 
 
 
 
 
 
 
 
70
 
71
  with gr.Row():
72
  with gr.Column(scale=2):
@@ -111,6 +120,16 @@ with gr.Blocks(
111
  historic_box = gr.Textbox(label="Historic Text", lines=12, interactive=False)
112
 
113
  # --- Event Handlers ---
 
 
 
 
 
 
 
 
 
 
114
  ui_outputs = [
115
  state,
116
  scenario_lbl,
@@ -181,6 +200,9 @@ with gr.Blocks(
181
  gr.Timer(30.0).tick(tick_timer, timer_inputs, ui_outputs)
182
  gr.Timer(1.0).tick(countdown_tick, [last_tick_ts], [countdown_lbl])
183
 
 
 
 
184
  demo.load(inject_scenario, [gr.State("A0"), cdss_toggle, history_df, historic_text], ui_outputs)
185
 
186
  if __name__ == "__main__":
 
35
  )
36
  from editor import editor_ui, save_rules
37
  from validator import validator_ui, test_condition, add_rule_to_set
38
+ from diagnosis import diagnosis_ui, generate_diagnosis, check_medication_interaction
39
 
40
 
41
  # --- Build UI ---
 
68
  add_rule_button,
69
  add_rule_status,
70
  ) = validator_ui()
71
+ (
72
+ generate_button,
73
+ diagnosis_output,
74
+ medication_output,
75
+ medication_input,
76
+ check_button,
77
+ interaction_output,
78
+ ) = diagnosis_ui()
79
 
80
  with gr.Row():
81
  with gr.Column(scale=2):
 
120
  historic_box = gr.Textbox(label="Historic Text", lines=12, interactive=False)
121
 
122
  # --- Event Handlers ---
123
+ def update_medication_input(patient_type):
124
+ if patient_type == "Mother":
125
+ return gr.update(value="Aspirin, Ibuprofen")
126
+ elif patient_type == "Gyn":
127
+ return gr.update(value="Clopidogrel, Omeprazole")
128
+ elif patient_type == "Neonate":
129
+ return gr.update(value="Ceftriaxone, Calcium")
130
+
131
+ patient_type_radio.change(update_medication_input, inputs=patient_type_radio, outputs=medication_input)
132
+
133
  ui_outputs = [
134
  state,
135
  scenario_lbl,
 
200
  gr.Timer(30.0).tick(tick_timer, timer_inputs, ui_outputs)
201
  gr.Timer(1.0).tick(countdown_tick, [last_tick_ts], [countdown_lbl])
202
 
203
+ generate_button.click(generate_diagnosis, inputs=state, outputs=[diagnosis_output, medication_output])
204
+ check_button.click(check_medication_interaction, inputs=[patient_type_radio, medication_input], outputs=interaction_output)
205
+
206
  demo.load(inject_scenario, [gr.State("A0"), cdss_toggle, history_df, historic_text], ui_outputs)
207
 
208
  if __name__ == "__main__":
diagnosis.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CDSS Diagnosis Component
3
+ """
4
+
5
+ import gradio as gr
6
+ import os
7
+ import google.generativeai as genai
8
+ from google.generativeai.types import HarmCategory, HarmBlockThreshold
9
+ from models import PatientState, Vitals
10
+
11
+ GEMINI_MODEL_NAME = "gemini-2.5-flash"
12
+ # --- Gemini setup (simplified) ---
13
+ try:
14
+ genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
15
+ GEMINI_MODEL = genai.GenerativeModel(GEMINI_MODEL_NAME)
16
+ GEMINI_ERR = None
17
+ except Exception as e:
18
+ GEMINI_MODEL, GEMINI_ERR = None, f"Gemini import/config error: {e}"
19
+
20
+
21
+ def generate_diagnosis(patient_state: dict) -> tuple[str, str]:
22
+ if not GEMINI_MODEL:
23
+ return f"[CDSS AI ERROR] {GEMINI_ERR}", ""
24
+
25
+ ps = PatientState(**patient_state)
26
+ ps.vitals = Vitals(**ps.vitals)
27
+
28
+ prompt = f"""Generate a diagnosis in a medical record statement format (Subjective, Anamnese, Plan, Objective) and provide medication recommendations for the following patient data:
29
+
30
+ - Patient Type: {ps.patient_type}
31
+ - Vitals: {ps.vitals}
32
+ - Labs: {ps.labs}
33
+ - Notes: {ps.notes}
34
+ """
35
+
36
+ try:
37
+ response = GEMINI_MODEL.generate_content(
38
+ prompt,
39
+ safety_settings={
40
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
41
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
42
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
43
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
44
+ },
45
+ )
46
+ print(response)
47
+ if response.parts:
48
+ # simple parsing, assuming the response is well-formed
49
+ diagnosis = response.text.split("Medication Recommendations:")[0]
50
+ medication = response.text.split("Medication Recommendations:")[1]
51
+ return diagnosis, medication
52
+ else:
53
+ return "No response from AI.", ""
54
+ except Exception as e:
55
+ return f"[CDSS AI error] {e}", ""
56
+
57
+
58
+ def check_medication_interaction(patient_type: str, medications: str) -> str:
59
+ if not GEMINI_MODEL:
60
+ return f"[CDSS AI ERROR] {GEMINI_ERR}"
61
+
62
+ prompt = f"""Check for dangerous medication interactions in the following list of medications for a {patient_type} patient: {medications}.
63
+
64
+ Provide a clear warning if any dangerous interactions are found."""
65
+
66
+ try:
67
+ response = GEMINI_MODEL.generate_content(
68
+ prompt,
69
+ safety_settings={
70
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
71
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
72
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
73
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
74
+ },
75
+ )
76
+ print(response)
77
+ if response.parts:
78
+ return response.text
79
+ else:
80
+ return "No response from AI."
81
+ except Exception as e:
82
+ return f"[CDSS AI error] {e}"
83
+
84
+
85
+ def diagnosis_ui():
86
+ with gr.TabItem("Diagnosis"):
87
+ with gr.Row():
88
+ with gr.Column():
89
+ gr.Markdown("## Generate Diagnosis and Medication")
90
+ generate_button = gr.Button("Generate", variant="primary")
91
+ diagnosis_output = gr.Textbox(
92
+ label="Diagnosis (S.O.A.P)", lines=10, interactive=False
93
+ )
94
+ medication_output = gr.Textbox(
95
+ label="Medication Recommendations", lines=5, interactive=False
96
+ )
97
+
98
+ with gr.Column():
99
+ gr.Markdown("## Check Medication Interaction")
100
+ medication_input = gr.Textbox(
101
+ label="Medications (comma-separated)",
102
+ lines=3,
103
+ value="Aspirin, Ibuprofen", # Default for Mother
104
+ )
105
+ check_button = gr.Button("Check Interaction", variant="secondary")
106
+ interaction_output = gr.Textbox(
107
+ label="Interaction Result", lines=10, interactive=False
108
+ )
109
+
110
+ return (
111
+ generate_button,
112
+ diagnosis_output,
113
+ medication_output,
114
+ medication_input,
115
+ check_button,
116
+ interaction_output,
117
+ )
simulator.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  CDSS Simulator Component
3
  """
 
4
  import random
5
  import time
6
  from dataclasses import asdict
@@ -13,13 +14,15 @@ import plotly.express as px
13
 
14
  from models import Vitals, PatientState
15
  from rules import rule_based_cdss
 
16
 
17
  # --- Gemini setup (simplified) ---
18
  try:
19
  import google.generativeai as genai
20
  import os
 
21
  genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
22
- GEMINI_MODEL = genai.GenerativeModel("gemini-1.5-pro")
23
  GEMINI_ERR = None
24
  except Exception as e:
25
  GEMINI_MODEL, GEMINI_ERR = None, f"Gemini import/config error: {e}"
@@ -27,6 +30,7 @@ except Exception as e:
27
 
28
  # --- Data structures & Scenarios (Full list included) ---
29
 
 
30
  def scenario_A0_Normal() -> PatientState:
31
  return PatientState(
32
  "A0 Normal Case",
@@ -36,6 +40,7 @@ def scenario_A0_Normal() -> PatientState:
36
  Vitals(110, 70, 80, 16, 36.7, 99),
37
  )
38
 
 
39
  def scenario_A1_PPH() -> PatientState:
40
  return PatientState(
41
  "A1 PPH",
@@ -45,6 +50,7 @@ def scenario_A1_PPH() -> PatientState:
45
  Vitals(90, 60, 120, 24, 36.8, 96),
46
  )
47
 
 
48
  def scenario_A2_Preeclampsia() -> PatientState:
49
  return PatientState(
50
  "A2 Preeklampsia",
@@ -54,6 +60,7 @@ def scenario_A2_Preeclampsia() -> PatientState:
54
  Vitals(165, 105, 98, 20, 36.9, 98),
55
  )
56
 
 
57
  def scenario_A3_MaternalSepsis() -> PatientState:
58
  return PatientState(
59
  "A3 Sepsis Maternal",
@@ -63,6 +70,7 @@ def scenario_A3_MaternalSepsis() -> PatientState:
63
  Vitals(95, 60, 110, 24, 39.0, 96),
64
  )
65
 
 
66
  def scenario_B1_Prematurity() -> PatientState:
67
  return PatientState(
68
  "B1 Prematuritas/BBLR",
@@ -72,6 +80,7 @@ def scenario_B1_Prematurity() -> PatientState:
72
  Vitals(60, 35, 150, 50, 35.0, 90),
73
  )
74
 
 
75
  def scenario_B2_Asphyxia() -> PatientState:
76
  return PatientState(
77
  "B2 Asfiksia Perinatal",
@@ -81,6 +90,7 @@ def scenario_B2_Asphyxia() -> PatientState:
81
  Vitals(55, 30, 80, 10, 36.5, 82),
82
  )
83
 
 
84
  def scenario_B3_NeonatalSepsis() -> PatientState:
85
  return PatientState(
86
  "B3 Sepsis Neonatal",
@@ -90,6 +100,7 @@ def scenario_B3_NeonatalSepsis() -> PatientState:
90
  Vitals(60, 35, 170, 60, 38.5, 93),
91
  )
92
 
 
93
  def scenario_C1_GynSurgComp() -> PatientState:
94
  return PatientState(
95
  "C1 Komplikasi Bedah Ginekologis",
@@ -99,6 +110,7 @@ def scenario_C1_GynSurgComp() -> PatientState:
99
  Vitals(100, 65, 105, 20, 37.8, 98),
100
  )
101
 
 
102
  def scenario_C2_PostOpInfection() -> PatientState:
103
  return PatientState(
104
  "C2 Infeksi Pasca-Bedah",
@@ -108,6 +120,7 @@ def scenario_C2_PostOpInfection() -> PatientState:
108
  Vitals(105, 70, 108, 22, 38.0, 98),
109
  )
110
 
 
111
  def scenario_C3_DelayedGynCancer() -> PatientState:
112
  return PatientState(
113
  "C3 Keterlambatan Diagnostik Kanker Ginekologi",
@@ -117,6 +130,7 @@ def scenario_C3_DelayedGynCancer() -> PatientState:
117
  Vitals(120, 78, 86, 18, 36.8, 99),
118
  )
119
 
 
120
  SCENARIOS = {
121
  "A0": scenario_A0_Normal,
122
  "A1": scenario_A1_PPH,
@@ -145,13 +159,27 @@ def drift_vitals(state: PatientState) -> PatientState:
145
 
146
  # --- Rule-based fallback (no AI or AI disabled) ---
147
 
 
148
  def gemini_cdss(state: PatientState) -> str:
149
  if not GEMINI_MODEL:
150
  return f"[CDSS AI ERROR] {GEMINI_ERR}"
151
  try:
152
  v = state.vitals
153
  prompt = f"CDSS for {state.scenario}. Vitals: SBP {v.sbp}/{v.dbp}, HR {v.hr}. Analyze risks, give concise steps in Indonesian."
154
- return GEMINI_MODEL.generate_content(prompt).text or "[CDSS AI] No response."
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  except Exception as e:
156
  return f"[CDSS AI error] {e}"
157
 
@@ -180,9 +208,11 @@ def create_vital_plot(
180
  )
181
  return fig
182
 
 
183
  def _row_from_state(ps: PatientState) -> Dict[str, Any]:
184
  return {"timestamp": datetime.now(), "scenario": ps.scenario, **asdict(ps.vitals)}
185
 
 
186
  def prepare_df_for_display(df: pd.DataFrame) -> pd.DataFrame:
187
  if df is None or df.empty:
188
  return pd.DataFrame(
@@ -203,6 +233,7 @@ def prepare_df_for_display(df: pd.DataFrame) -> pd.DataFrame:
203
  df_display["timestamp"] = df_display["timestamp"].dt.strftime("%Y-%m-%d %H:%M:%S")
204
  return df_display
205
 
 
206
  def generate_all_plots(df: pd.DataFrame):
207
  """Helper to generate all 5 plot figures from a dataframe."""
208
  df_display = prepare_df_for_display(df)
@@ -257,6 +288,7 @@ def process_and_update(
257
  spo2_fig,
258
  )
259
 
 
260
  def state_to_panels(state: PatientState) -> Tuple:
261
  v = state.vitals
262
  return (
@@ -271,6 +303,7 @@ def state_to_panels(state: PatientState) -> Tuple:
271
  v.spo2,
272
  )
273
 
 
274
  def inject_scenario(
275
  tag: str, cdss_on: bool, history_df: pd.DataFrame, historic_text: str
276
  ):
@@ -283,6 +316,7 @@ def inject_scenario(
283
  )
284
  return process_and_update(ps, history_df, historic_text, cdss_on)
285
 
 
286
  def manual_edit(
287
  sbp,
288
  dbp,
@@ -313,6 +347,7 @@ def manual_edit(
313
  historic_text += f"\n[{datetime.now().strftime('%H:%M:%S')}] {ps.notes}"
314
  return process_and_update(ps, history_df, historic_text, cdss_on)
315
 
 
316
  def tick_timer(cdss_on, current_state, history_df, historic_text):
317
  if not current_state:
318
  return [gr.update()] * 22
@@ -321,6 +356,7 @@ def tick_timer(cdss_on, current_state, history_df, historic_text):
321
  ps = drift_vitals(ps)
322
  return process_and_update(ps, history_df, historic_text, cdss_on)
323
 
 
324
  def load_csv(file, history_df: pd.DataFrame):
325
  try:
326
  if file is not None:
@@ -338,11 +374,13 @@ def load_csv(file, history_df: pd.DataFrame):
338
  )
339
  return history_df, df_for_table, bp_fig, hr_fig, rr_fig, temp_fig, spo2_fig
340
 
 
341
  def countdown_tick(last_tick_ts: float):
342
  if not last_tick_ts:
343
  return "Next update in —"
344
  return f"Next update in {max(0, 30 - int(time.time() - last_tick_ts))}s"
345
 
 
346
  def simulator_ui():
347
  with gr.TabItem("CDSS Simulator"):
348
  with gr.Accordion("History, Trends, and Data Loading", open=True):
 
1
  """
2
  CDSS Simulator Component
3
  """
4
+
5
  import random
6
  import time
7
  from dataclasses import asdict
 
14
 
15
  from models import Vitals, PatientState
16
  from rules import rule_based_cdss
17
+ from google.generativeai.types import HarmCategory, HarmBlockThreshold
18
 
19
  # --- Gemini setup (simplified) ---
20
  try:
21
  import google.generativeai as genai
22
  import os
23
+
24
  genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
25
+ GEMINI_MODEL = genai.GenerativeModel("gemini-2.5-flash")
26
  GEMINI_ERR = None
27
  except Exception as e:
28
  GEMINI_MODEL, GEMINI_ERR = None, f"Gemini import/config error: {e}"
 
30
 
31
  # --- Data structures & Scenarios (Full list included) ---
32
 
33
+
34
  def scenario_A0_Normal() -> PatientState:
35
  return PatientState(
36
  "A0 Normal Case",
 
40
  Vitals(110, 70, 80, 16, 36.7, 99),
41
  )
42
 
43
+
44
  def scenario_A1_PPH() -> PatientState:
45
  return PatientState(
46
  "A1 PPH",
 
50
  Vitals(90, 60, 120, 24, 36.8, 96),
51
  )
52
 
53
+
54
  def scenario_A2_Preeclampsia() -> PatientState:
55
  return PatientState(
56
  "A2 Preeklampsia",
 
60
  Vitals(165, 105, 98, 20, 36.9, 98),
61
  )
62
 
63
+
64
  def scenario_A3_MaternalSepsis() -> PatientState:
65
  return PatientState(
66
  "A3 Sepsis Maternal",
 
70
  Vitals(95, 60, 110, 24, 39.0, 96),
71
  )
72
 
73
+
74
  def scenario_B1_Prematurity() -> PatientState:
75
  return PatientState(
76
  "B1 Prematuritas/BBLR",
 
80
  Vitals(60, 35, 150, 50, 35.0, 90),
81
  )
82
 
83
+
84
  def scenario_B2_Asphyxia() -> PatientState:
85
  return PatientState(
86
  "B2 Asfiksia Perinatal",
 
90
  Vitals(55, 30, 80, 10, 36.5, 82),
91
  )
92
 
93
+
94
  def scenario_B3_NeonatalSepsis() -> PatientState:
95
  return PatientState(
96
  "B3 Sepsis Neonatal",
 
100
  Vitals(60, 35, 170, 60, 38.5, 93),
101
  )
102
 
103
+
104
  def scenario_C1_GynSurgComp() -> PatientState:
105
  return PatientState(
106
  "C1 Komplikasi Bedah Ginekologis",
 
110
  Vitals(100, 65, 105, 20, 37.8, 98),
111
  )
112
 
113
+
114
  def scenario_C2_PostOpInfection() -> PatientState:
115
  return PatientState(
116
  "C2 Infeksi Pasca-Bedah",
 
120
  Vitals(105, 70, 108, 22, 38.0, 98),
121
  )
122
 
123
+
124
  def scenario_C3_DelayedGynCancer() -> PatientState:
125
  return PatientState(
126
  "C3 Keterlambatan Diagnostik Kanker Ginekologi",
 
130
  Vitals(120, 78, 86, 18, 36.8, 99),
131
  )
132
 
133
+
134
  SCENARIOS = {
135
  "A0": scenario_A0_Normal,
136
  "A1": scenario_A1_PPH,
 
159
 
160
  # --- Rule-based fallback (no AI or AI disabled) ---
161
 
162
+
163
  def gemini_cdss(state: PatientState) -> str:
164
  if not GEMINI_MODEL:
165
  return f"[CDSS AI ERROR] {GEMINI_ERR}"
166
  try:
167
  v = state.vitals
168
  prompt = f"CDSS for {state.scenario}. Vitals: SBP {v.sbp}/{v.dbp}, HR {v.hr}. Analyze risks, give concise steps in Indonesian."
169
+ response = GEMINI_MODEL.generate_content(
170
+ prompt,
171
+ safety_settings={
172
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
173
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
174
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
175
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
176
+ },
177
+ )
178
+ print(response)
179
+ if response.parts:
180
+ return response.text or "[CDSS AI] No response."
181
+ else:
182
+ return "[CDSS AI] No response due to safety settings."
183
  except Exception as e:
184
  return f"[CDSS AI error] {e}"
185
 
 
208
  )
209
  return fig
210
 
211
+
212
  def _row_from_state(ps: PatientState) -> Dict[str, Any]:
213
  return {"timestamp": datetime.now(), "scenario": ps.scenario, **asdict(ps.vitals)}
214
 
215
+
216
  def prepare_df_for_display(df: pd.DataFrame) -> pd.DataFrame:
217
  if df is None or df.empty:
218
  return pd.DataFrame(
 
233
  df_display["timestamp"] = df_display["timestamp"].dt.strftime("%Y-%m-%d %H:%M:%S")
234
  return df_display
235
 
236
+
237
  def generate_all_plots(df: pd.DataFrame):
238
  """Helper to generate all 5 plot figures from a dataframe."""
239
  df_display = prepare_df_for_display(df)
 
288
  spo2_fig,
289
  )
290
 
291
+
292
  def state_to_panels(state: PatientState) -> Tuple:
293
  v = state.vitals
294
  return (
 
303
  v.spo2,
304
  )
305
 
306
+
307
  def inject_scenario(
308
  tag: str, cdss_on: bool, history_df: pd.DataFrame, historic_text: str
309
  ):
 
316
  )
317
  return process_and_update(ps, history_df, historic_text, cdss_on)
318
 
319
+
320
  def manual_edit(
321
  sbp,
322
  dbp,
 
347
  historic_text += f"\n[{datetime.now().strftime('%H:%M:%S')}] {ps.notes}"
348
  return process_and_update(ps, history_df, historic_text, cdss_on)
349
 
350
+
351
  def tick_timer(cdss_on, current_state, history_df, historic_text):
352
  if not current_state:
353
  return [gr.update()] * 22
 
356
  ps = drift_vitals(ps)
357
  return process_and_update(ps, history_df, historic_text, cdss_on)
358
 
359
+
360
  def load_csv(file, history_df: pd.DataFrame):
361
  try:
362
  if file is not None:
 
374
  )
375
  return history_df, df_for_table, bp_fig, hr_fig, rr_fig, temp_fig, spo2_fig
376
 
377
+
378
  def countdown_tick(last_tick_ts: float):
379
  if not last_tick_ts:
380
  return "Next update in —"
381
  return f"Next update in {max(0, 30 - int(time.time() - last_tick_ts))}s"
382
 
383
+
384
  def simulator_ui():
385
  with gr.TabItem("CDSS Simulator"):
386
  with gr.Accordion("History, Trends, and Data Loading", open=True):