enshicoolsoda commited on
Commit
82c62db
·
verified ·
1 Parent(s): cbfbd67

try with different prompts

Browse files
Files changed (1) hide show
  1. app.py +251 -61
app.py CHANGED
@@ -2,7 +2,6 @@ import json
2
  import os
3
  import requests
4
  import gradio as gr
5
- import pandas as pd
6
 
7
  # -----------------------------
8
  # 1. Configuration & Data Mapping
@@ -15,7 +14,6 @@ CANCER_MAP = {
15
  "Head and Neck Cancer": "data/hnsc_combined_data.json",
16
  }
17
 
18
- # Map for the Ground Truth JSON keys
19
  GT_MAP = {
20
  "Uterine Cancer": "UCEC",
21
  "Breast Cancer": "BRCA",
@@ -24,14 +22,36 @@ GT_MAP = {
24
  "Head and Neck Cancer": "HNSC",
25
  }
26
 
27
- COMMON_AGENTS = ["Carboplatin", "Paclitaxel", "Cisplatin", "Gemcitabine", "Doxorubicin", "Other"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # -----------------------------
30
  # 2. AI Backend Function
31
  # -----------------------------
32
  def ollama_chat(messages, temperature=0.1):
33
  endpoint = os.getenv("OLLAMA_ENDPOINT")
34
- if not endpoint: return "Error: Endpoint not set."
35
 
36
  url = f"{endpoint}/api/chat"
37
  headers = {"Content-Type": "application/json", "ngrok-skip-browser-warning": "true"}
@@ -44,101 +64,271 @@ def ollama_chat(messages, temperature=0.1):
44
  try:
45
  r = requests.post(url, json=payload, headers=headers, timeout=120)
46
  return r.json().get("message", {}).get("content", "")
47
- except: return "Connection Error"
 
48
 
49
  # -----------------------------
50
- # 3. Evaluation Logic
51
  # -----------------------------
52
- def run_evaluation(cancer_type):
53
- # 1. Load Data
54
  data_path = CANCER_MAP.get(cancer_type)
55
  gt_path = "data/ground_truth_5yr_recurrence.json"
56
 
57
  if not os.path.exists(data_path) or not os.path.exists(gt_path):
58
- return "Error: Missing data or ground truth files."
 
59
 
60
  with open(data_path, 'r') as f: patient_db = json.load(f)
61
  with open(gt_path, 'r') as f: all_gt = json.load(f)
62
 
63
  gt_labels = all_gt.get(GT_MAP[cancer_type], {})
64
-
65
- # 2. Filter patients present in both
66
  eval_ids = [pid for pid in gt_labels.keys() if pid in patient_db]
67
 
68
- results = []
 
 
69
  tp, tn, fp, fn = 0, 0, 0, 0
70
-
71
- yield f"Starting inference for {len(eval_ids)} patients in {cancer_type}..."
72
 
73
  for i, pid in enumerate(eval_ids):
74
- actual = gt_labels[pid] # "Yes" or "No"
75
  patient_json = json.dumps(patient_db[pid])
76
 
77
- # Zero-shot prompt
78
- eval_prompt = [
79
- {"role": "system", "content": "You are an oncology expert. Predict 5-year recurrence based ONLY on the provided JSON. Respond strictly with 'Yes' or 'No' and nothing else."},
80
- {"role": "user", "content": f"Patient Data: {patient_json}"}
81
- ]
82
 
83
- prediction_raw = ollama_chat(eval_prompt).strip()
84
- # Simple parser to find Yes/No in response
85
- prediction = "Yes" if "yes" in prediction_raw.lower() else "No"
 
 
 
 
86
 
87
- # Calculate Metrics
88
- if prediction == "Yes" and actual == "Yes": tp += 1
89
- elif prediction == "No" and actual == "No": tn += 1
90
- elif prediction == "Yes" and actual == "No": fp += 1
91
- elif prediction == "No" and actual == "Yes": fn += 1
92
 
93
- if i % 5 == 0:
94
- yield f"Processed {i+1}/{len(eval_ids)} patients..."
95
 
96
- # 3. Final Metric Calculation
97
- acc = (tp + tn) / len(eval_ids) if eval_ids else 0
 
98
  sens = tp / (tp + fn) if (tp + fn) > 0 else 0
99
  spec = tn / (tn + fp) if (tn + fp) > 0 else 0
100
 
101
- summary = f"""
102
- ### Evaluation Results: {cancer_type}
103
- - **Total Patients Processed:** {len(eval_ids)}
104
- - **Unweighted Accuracy:** {acc:.2%}
105
  - **Sensitivity (Recall):** {sens:.2%}
106
  - **Specificity:** {spec:.2%}
107
 
108
- *Confusion Matrix: TP={tp}, TN={tn}, FP={fp}, FN={fn}*
 
 
 
 
109
  """
110
- yield summary
111
 
112
  # -----------------------------
113
- # 4. UI Layout (Modified)
114
  # -----------------------------
115
- with gr.Blocks(title="OncoRisk Eval & Demo") as demo:
116
- gr.HTML('<div style="text-align:center"><h1>Oncology Risk Assistant</h1></div>')
117
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  with gr.Tabs():
119
- # Tab 1: Your original Chat/Simulation UI
120
- with gr.TabItem("Clinical Assistant"):
 
 
 
 
 
 
 
 
 
121
  with gr.Row():
122
  with gr.Column(scale=1):
123
- cancer_select = gr.Dropdown(label="Select Cancer Type", choices=list(CANCER_MAP.keys()))
124
- patient_select = gr.Dropdown(label="Select Patient ID")
125
- submit_btn = gr.Button("Analyze Case", variant="primary")
126
- missing_output = gr.HighlightedText(label="Completeness")
127
  with gr.Column(scale=2):
128
- chatbot = gr.Chatbot(height=500)
129
- msg_input = gr.Textbox(label="Input Box", lines=5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- # Tab 2: NEW Evaluation Engine
132
- with gr.TabItem("Performance Metrics (Zero-Shot)"):
133
- gr.Markdown("### Run Zero-Shot Inference on Ground Truth")
134
- eval_cancer_type = gr.Dropdown(label="Select Cancer for Evaluation", choices=list(CANCER_MAP.keys()))
135
- run_eval_btn = gr.Button("Start Experiment", variant="secondary")
136
- eval_results = gr.Markdown("Results will appear here after inference...")
137
 
138
- # Logic for Evaluation
139
- run_eval_btn.click(fn=run_evaluation, inputs=eval_cancer_type, outputs=eval_results)
140
 
141
- # (Keep your existing Event Logic for Chat/Data Selection here...)
142
- # ... [Same as your provided code] ...
143
 
144
- demo.launch()
 
2
  import os
3
  import requests
4
  import gradio as gr
 
5
 
6
  # -----------------------------
7
  # 1. Configuration & Data Mapping
 
14
  "Head and Neck Cancer": "data/hnsc_combined_data.json",
15
  }
16
 
 
17
  GT_MAP = {
18
  "Uterine Cancer": "UCEC",
19
  "Breast Cancer": "BRCA",
 
22
  "Head and Neck Cancer": "HNSC",
23
  }
24
 
25
+ COMMON_AGENTS = ["Carboplatin", "Paclitaxel", "Cisplatin", "Gemcitabine", "Doxorubicin", "Tamoxifen", "Other"]
26
+
27
+ # --- Prompt Templates ---
28
+ PROMPT_DIRECT = "You are an oncology expert. Predict 5-year recurrence based ONLY on the provided JSON. Respond strictly with 'Yes' or 'No' and nothing else."
29
+
30
+ PROMPT_COT = """You are an oncology expert. Predict 5-year cancer recurrence.
31
+ Process:
32
+ 1. Analyze demographics and tumor stage.
33
+ 2. Evaluate treatment timeline and dosages.
34
+ 3. Identify risk factors.
35
+ 4. State your final prediction.
36
+
37
+ Constraint: You must end your response with 'FINAL_PREDICTION: YES' or 'FINAL_PREDICTION: NO'."""
38
+
39
+ PROMPT_GRADING = """You are a clinical oncology researcher. Evaluate 5-year recurrence risk by grading:
40
+ - Tumor Burden (Stage/Grade)
41
+ - Treatment Adequacy (Agents/Duration)
42
+ - Patient Baseline
43
+
44
+ Prediction Rule: If cumulative evidence suggests >50% likelihood of recurrence, predict Yes.
45
+ Output Format:
46
+ [Reasoning]
47
+ Decision: [Yes/No]"""
48
 
49
  # -----------------------------
50
  # 2. AI Backend Function
51
  # -----------------------------
52
  def ollama_chat(messages, temperature=0.1):
53
  endpoint = os.getenv("OLLAMA_ENDPOINT")
54
+ if not endpoint: return "Error: OLLAMA_ENDPOINT not set."
55
 
56
  url = f"{endpoint}/api/chat"
57
  headers = {"Content-Type": "application/json", "ngrok-skip-browser-warning": "true"}
 
64
  try:
65
  r = requests.post(url, json=payload, headers=headers, timeout=120)
66
  return r.json().get("message", {}).get("content", "")
67
+ except Exception as e:
68
+ return f"Error: {str(e)}"
69
 
70
  # -----------------------------
71
+ # 3. Evaluation Engine Logic
72
  # -----------------------------
73
+ def run_evaluation(cancer_type, strategy):
 
74
  data_path = CANCER_MAP.get(cancer_type)
75
  gt_path = "data/ground_truth_5yr_recurrence.json"
76
 
77
  if not os.path.exists(data_path) or not os.path.exists(gt_path):
78
+ yield "Error: Required data files not found in /data folder."
79
+ return
80
 
81
  with open(data_path, 'r') as f: patient_db = json.load(f)
82
  with open(gt_path, 'r') as f: all_gt = json.load(f)
83
 
84
  gt_labels = all_gt.get(GT_MAP[cancer_type], {})
 
 
85
  eval_ids = [pid for pid in gt_labels.keys() if pid in patient_db]
86
 
87
+ # Map strategy to system prompt
88
+ sys_content = PROMPT_COT if strategy == "Chain-of-Thought" else (PROMPT_GRADING if strategy == "Evidence Grading" else PROMPT_DIRECT)
89
+
90
  tp, tn, fp, fn = 0, 0, 0, 0
91
+ yield f"🚀 Starting {strategy} inference for {len(eval_ids)} patients in {cancer_type}..."
 
92
 
93
  for i, pid in enumerate(eval_ids):
94
+ actual = gt_labels[pid]
95
  patient_json = json.dumps(patient_db[pid])
96
 
97
+ msgs = [{"role": "system", "content": sys_content}, {"role": "user", "content": f"Patient Data: {patient_json}"}]
98
+ raw_res = ollama_chat(msgs).strip().upper()
 
 
 
99
 
100
+ # Robust Parsing
101
+ if strategy == "Direct":
102
+ pred = "Yes" if "YES" in raw_res[:10] else "No"
103
+ elif strategy == "Chain-of-Thought":
104
+ pred = "Yes" if "FINAL_PREDICTION: YES" in raw_res else "No"
105
+ else: # Evidence Grading
106
+ pred = "Yes" if "DECISION: YES" in raw_res else "No"
107
 
108
+ if pred == "Yes" and actual == "Yes": tp += 1
109
+ elif pred == "No" and actual == "No": tn += 1
110
+ elif pred == "Yes" and actual == "No": fp += 1
111
+ else: fn += 1
 
112
 
113
+ if (i + 1) % 5 == 0:
114
+ yield f"🔄 Progress: {i+1}/{len(eval_ids)} patients processed..."
115
 
116
+ # Metrics
117
+ total = len(eval_ids)
118
+ acc = (tp + tn) / total if total > 0 else 0
119
  sens = tp / (tp + fn) if (tp + fn) > 0 else 0
120
  spec = tn / (tn + fp) if (tn + fp) > 0 else 0
121
 
122
+ yield f"""
123
+ ## {strategy} Strategy Results: {cancer_type}
124
+ - **Accuracy:** {acc:.2%}
 
125
  - **Sensitivity (Recall):** {sens:.2%}
126
  - **Specificity:** {spec:.2%}
127
 
128
+ **Confusion Matrix:**
129
+ | | Predicted YES | Predicted NO |
130
+ |---|---|---|
131
+ | **Actual YES** | {tp} (TP) | {fn} (FN) |
132
+ | **Actual NO** | {fp} (FP) | {tn} (TN) |
133
  """
 
134
 
135
  # -----------------------------
136
+ # 4. Helper UI Logic (Chat)
137
  # -----------------------------
138
+ def load_data(cancer_type):
139
+ path = CANCER_MAP.get(cancer_type)
140
+ with open(path, "r") as f: data = json.load(f)
141
+ ids = sorted([str(k) for k in data.keys()])
142
+ return gr.update(choices=ids, value=ids[0]), data
143
+
144
+ def respond(message, history):
145
+ history = history or []
146
+ # Standard System Prompt for Chat
147
+ sys = {"role": "system", "content": "You are an oncology assistant. Summarize the case and predict outcomes."}
148
+ res = ollama_chat([sys] + history + [{"role": "user", "content": message}])
149
+ history.append({"role": "user", "content": message})
150
+ history.append({"role": "assistant", "content": res})
151
+ return "", history
152
+
153
+ # -----------------------------
154
+ # 5. UI Layout
155
+ # -----------------------------
156
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
157
+ gr.Markdown("# Oncology Research Platform")
158
+ full_data_state = gr.State({})
159
+
160
  with gr.Tabs():
161
+ # TAB 1: Evaluation Engine
162
+ with gr.TabItem("🔬 Performance Metrics"):
163
+ gr.Markdown("### Zero-Shot Inference Experiments")
164
+ with gr.Row():
165
+ e_type = gr.Dropdown(label="Cancer Type", choices=list(CANCER_MAP.keys()), value="Uterine Cancer")
166
+ e_strat = gr.Dropdown(label="Prompting Strategy", choices=["Direct", "Chain-of-Thought", "Evidence Grading"], value="Direct")
167
+ run_btn = gr.Button("Start Experiment", variant="primary")
168
+ results_md = gr.Markdown("Select criteria and start to see metrics.")
169
+
170
+ # TAB 2: Clinical Assistant
171
+ with gr.TabItem("💬 Clinical Assistant"):
172
  with gr.Row():
173
  with gr.Column(scale=1):
174
+ c_select = gr.Dropdown(label="Cancer Type", choices=list(CANCER_MAP.keys()), value="Uterine Cancer")
175
+ p_select = gr.Dropdown(label="Patient ID")
 
 
176
  with gr.Column(scale=2):
177
+ chat = gr.Chatbot(height=400)
178
+ msg = gr.Textbox(label="Patient JSON / Message")
179
+ send = gr.Button("Analyze")
180
+
181
+ # Bindings
182
+ run_btn.click(run_evaluation, [e_type, e_strat], results_md)
183
+ c_select.change(load_data, c_select, [p_select, full_data_state])
184
+ p_select.change(lambda p, d: json.dumps(d.get(p), indent=2), [p_select, full_data_state], msg)
185
+ send.click(respond, [msg, chat], [msg, chat])
186
+
187
+ demo.load(load_data, c_select, [p_select, full_data_state])
188
+
189
+ demo.launch()
190
+
191
+ # import json
192
+ # import os
193
+ # import requests
194
+ # import gradio as gr
195
+ # import pandas as pd
196
+
197
+ # # -----------------------------
198
+ # # 1. Configuration & Data Mapping
199
+ # # -----------------------------
200
+ # CANCER_MAP = {
201
+ # "Uterine Cancer": "data/ucec_combined_data.json",
202
+ # "Breast Cancer": "data/brca_combined_data.json",
203
+ # "Lung Cancer": "data/luad_combined_data.json",
204
+ # "Bladder Cancer": "data/blca_combined_data.json",
205
+ # "Head and Neck Cancer": "data/hnsc_combined_data.json",
206
+ # }
207
+
208
+ # # Map for the Ground Truth JSON keys
209
+ # GT_MAP = {
210
+ # "Uterine Cancer": "UCEC",
211
+ # "Breast Cancer": "BRCA",
212
+ # "Lung Cancer": "LUAD",
213
+ # "Bladder Cancer": "BLCA",
214
+ # "Head and Neck Cancer": "HNSC",
215
+ # }
216
+
217
+ # COMMON_AGENTS = ["Carboplatin", "Paclitaxel", "Cisplatin", "Gemcitabine", "Doxorubicin", "Other"]
218
+
219
+ # # -----------------------------
220
+ # # 2. AI Backend Function
221
+ # # -----------------------------
222
+ # def ollama_chat(messages, temperature=0.1):
223
+ # endpoint = os.getenv("OLLAMA_ENDPOINT")
224
+ # if not endpoint: return "Error: Endpoint not set."
225
+
226
+ # url = f"{endpoint}/api/chat"
227
+ # headers = {"Content-Type": "application/json", "ngrok-skip-browser-warning": "true"}
228
+ # payload = {
229
+ # "model": "qwen2.5:7b",
230
+ # "messages": messages,
231
+ # "stream": False,
232
+ # "options": {"temperature": float(temperature), "num_ctx": 8192}
233
+ # }
234
+ # try:
235
+ # r = requests.post(url, json=payload, headers=headers, timeout=120)
236
+ # return r.json().get("message", {}).get("content", "")
237
+ # except: return "Connection Error"
238
+
239
+ # # -----------------------------
240
+ # # 3. Evaluation Logic
241
+ # # -----------------------------
242
+ # def run_evaluation(cancer_type):
243
+ # # 1. Load Data
244
+ # data_path = CANCER_MAP.get(cancer_type)
245
+ # gt_path = "data/ground_truth_5yr_recurrence.json"
246
+
247
+ # if not os.path.exists(data_path) or not os.path.exists(gt_path):
248
+ # return "Error: Missing data or ground truth files."
249
+
250
+ # with open(data_path, 'r') as f: patient_db = json.load(f)
251
+ # with open(gt_path, 'r') as f: all_gt = json.load(f)
252
+
253
+ # gt_labels = all_gt.get(GT_MAP[cancer_type], {})
254
+
255
+ # # 2. Filter patients present in both
256
+ # eval_ids = [pid for pid in gt_labels.keys() if pid in patient_db]
257
+
258
+ # results = []
259
+ # tp, tn, fp, fn = 0, 0, 0, 0
260
+
261
+ # yield f"Starting inference for {len(eval_ids)} patients in {cancer_type}..."
262
+
263
+ # for i, pid in enumerate(eval_ids):
264
+ # actual = gt_labels[pid] # "Yes" or "No"
265
+ # patient_json = json.dumps(patient_db[pid])
266
+
267
+ # # Zero-shot prompt
268
+ # eval_prompt = [
269
+ # {"role": "system", "content": "You are an oncology expert. Predict 5-year recurrence based ONLY on the provided JSON. Respond strictly with 'Yes' or 'No' and nothing else."},
270
+ # {"role": "user", "content": f"Patient Data: {patient_json}"}
271
+ # ]
272
+
273
+ # prediction_raw = ollama_chat(eval_prompt).strip()
274
+ # # Simple parser to find Yes/No in response
275
+ # prediction = "Yes" if "yes" in prediction_raw.lower() else "No"
276
+
277
+ # # Calculate Metrics
278
+ # if prediction == "Yes" and actual == "Yes": tp += 1
279
+ # elif prediction == "No" and actual == "No": tn += 1
280
+ # elif prediction == "Yes" and actual == "No": fp += 1
281
+ # elif prediction == "No" and actual == "Yes": fn += 1
282
+
283
+ # if i % 5 == 0:
284
+ # yield f"Processed {i+1}/{len(eval_ids)} patients..."
285
+
286
+ # # 3. Final Metric Calculation
287
+ # acc = (tp + tn) / len(eval_ids) if eval_ids else 0
288
+ # sens = tp / (tp + fn) if (tp + fn) > 0 else 0
289
+ # spec = tn / (tn + fp) if (tn + fp) > 0 else 0
290
+
291
+ # summary = f"""
292
+ # ### Evaluation Results: {cancer_type}
293
+ # - **Total Patients Processed:** {len(eval_ids)}
294
+ # - **Unweighted Accuracy:** {acc:.2%}
295
+ # - **Sensitivity (Recall):** {sens:.2%}
296
+ # - **Specificity:** {spec:.2%}
297
+
298
+ # *Confusion Matrix: TP={tp}, TN={tn}, FP={fp}, FN={fn}*
299
+ # """
300
+ # yield summary
301
+
302
+ # # -----------------------------
303
+ # # 4. UI Layout (Modified)
304
+ # # -----------------------------
305
+ # with gr.Blocks(title="OncoRisk Eval & Demo") as demo:
306
+ # gr.HTML('<div style="text-align:center"><h1>Oncology Risk Assistant</h1></div>')
307
+
308
+ # with gr.Tabs():
309
+ # # Tab 1: Your original Chat/Simulation UI
310
+ # with gr.TabItem("Clinical Assistant"):
311
+ # with gr.Row():
312
+ # with gr.Column(scale=1):
313
+ # cancer_select = gr.Dropdown(label="Select Cancer Type", choices=list(CANCER_MAP.keys()))
314
+ # patient_select = gr.Dropdown(label="Select Patient ID")
315
+ # submit_btn = gr.Button("Analyze Case", variant="primary")
316
+ # missing_output = gr.HighlightedText(label="Completeness")
317
+ # with gr.Column(scale=2):
318
+ # chatbot = gr.Chatbot(height=500)
319
+ # msg_input = gr.Textbox(label="Input Box", lines=5)
320
 
321
+ # # Tab 2: NEW Evaluation Engine
322
+ # with gr.TabItem("Performance Metrics (Zero-Shot)"):
323
+ # gr.Markdown("### Run Zero-Shot Inference on Ground Truth")
324
+ # eval_cancer_type = gr.Dropdown(label="Select Cancer for Evaluation", choices=list(CANCER_MAP.keys()))
325
+ # run_eval_btn = gr.Button("Start Experiment", variant="secondary")
326
+ # eval_results = gr.Markdown("Results will appear here after inference...")
327
 
328
+ # # Logic for Evaluation
329
+ # run_eval_btn.click(fn=run_evaluation, inputs=eval_cancer_type, outputs=eval_results)
330
 
331
+ # # (Keep your existing Event Logic for Chat/Data Selection here...)
332
+ # # ... [Same as your provided code] ...
333
 
334
+ # demo.launch()