shivapriyasom commited on
Commit
4a16c34
·
verified ·
1 Parent(s): 7de8b0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -3
app.py CHANGED
@@ -6,8 +6,6 @@ from inference import (
6
  FEATURE_NAMES,
7
  REPORTING_OUTCOMES,
8
  OUTCOME_DESCRIPTIONS,
9
- OUTCOMES,
10
- SHAP_OUTCOMES,
11
  predict_with_comparison,
12
  create_all_shap_plots,
13
  icon_array,
@@ -144,4 +142,194 @@ DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL",
144
  DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
145
  ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES
146
 
147
- # ------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  FEATURE_NAMES,
7
  REPORTING_OUTCOMES,
8
  OUTCOME_DESCRIPTIONS,
 
 
9
  predict_with_comparison,
10
  create_all_shap_plots,
11
  icon_array,
 
142
  DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
143
  ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES
144
 
145
+ # ---------------------------------------------------------------------------
146
+ # Utility callbacks
147
+ # ---------------------------------------------------------------------------
148
+
149
+ def get_age_group(age):
150
+ if age is None or age == "":
151
+ return ""
152
+ try:
153
+ age = float(age)
154
+ if age <= 10:
155
+ return "<=10"
156
+ elif age <= 17:
157
+ return "11-17"
158
+ elif age <= 29:
159
+ return "18-29"
160
+ elif age <= 49:
161
+ return "30-49"
162
+ else:
163
+ return ">=50"
164
+ except (ValueError, TypeError):
165
+ return ""
166
+
167
+
168
+ def vocfrqpr_from_voc2ypr(voc_status):
169
+ if voc_status == "No":
170
+ return gr.update(value="< 3/yr", interactive=False)
171
+ else:
172
+ return gr.update(value=None, interactive=True)
173
+
174
+
175
+ def apply_grouped_preset(selected_value):
176
+ if not selected_value or selected_value in HEADER_VALUES:
177
+ return [gr.update(value=None)] + [gr.update()] * 6
178
+
179
+ preset = PUBLISHED_PRESETS.get(selected_value)
180
+ if not preset:
181
+ return [gr.update()] * 7
182
+
183
+ return [
184
+ gr.update(),
185
+ gr.update(value=preset["DONORF"]),
186
+ gr.update(value=preset["CONDGRPF"]),
187
+ gr.update(value=preset["CONDGRP_FINAL"]),
188
+ gr.update(value=preset["ATGF"]),
189
+ gr.update(value=preset["GVHD_FINAL"]),
190
+ gr.update(value=preset["HLA_FINAL"]),
191
+ ]
192
+
193
+ # ---------------------------------------------------------------------------
194
+ # Component factory
195
+ # ---------------------------------------------------------------------------
196
+
197
+ def make_component(name: str):
198
+ if name == "AGE":
199
+ return gr.Number(label="Age at transplant (years)", minimum=0, maximum=120)
200
+ elif name == "AGEGPFF":
201
+ return gr.Textbox(label="Age group (Auto-filled)", interactive=False)
202
+ elif name == "NACS2YR":
203
+ return gr.Number(
204
+ label="Number of Acute Chest Syndromes within 2 years pre-HCT",
205
+ minimum=0,
206
+ )
207
+ elif name == "SEX":
208
+ return gr.Dropdown(SEX_CHOICES, label="Sex")
209
+ elif name == "KPS":
210
+ return gr.Dropdown(KPS_CHOICES, label="Karnofsky/Lansky Performance Score at HCT")
211
+ elif name == "DONORF":
212
+ return gr.Dropdown(DONORF_CHOICES, label="Donor type")
213
+ elif name == "GRAFTYPE":
214
+ return gr.Dropdown(GRAFTYPE_CHOICES, label="Graft type")
215
+ elif name == "CONDGRPF":
216
+ return gr.Dropdown(CONDGRPF_CHOICES, label="Conditioning intensity")
217
+ elif name == "CONDGRP_FINAL":
218
+ return gr.Dropdown(CONDGRP_FINAL_CHOICES, label="Conditioning Regimen")
219
+ elif name == "ATGF":
220
+ return gr.Dropdown(ATGF_CHOICES, label="Serotherapy")
221
+ elif name == "GVHD_FINAL":
222
+ return gr.Dropdown(GVHD_FINAL_CHOICES, label="GVHD Prophylaxis")
223
+ elif name == "HLA_FINAL":
224
+ return gr.Dropdown(HLA_FINAL_CHOICES, label="Donor-Recipient HLA Matching")
225
+ elif name == "RCMVPR":
226
+ return gr.Dropdown(RCMVPR_CHOICES, label="Recipient CMV serostatus")
227
+ elif name == "EXCHTFPR":
228
+ return gr.Dropdown(EXCHTFPR_CHOICES, label="Exchange transfusion required?")
229
+ elif name == "VOC2YPR":
230
+ return gr.Dropdown(
231
+ VOC2YPR_CHOICES,
232
+ label="VOC requiring hospitalization within 2 years pre-HCT?",
233
+ )
234
+ elif name == "VOCFRQPR":
235
+ return gr.Dropdown(VOCFRQPR_CHOICES, label="Frequency of VOC hospitalizations")
236
+ elif name == "SCATXRSN":
237
+ return gr.Dropdown(SCATXRSN_CHOICES, label="Reason for Transplant")
238
+ else:
239
+ return gr.Textbox(label=name)
240
+
241
+ # ---------------------------------------------------------------------------
242
+ # Prediction callback
243
+ # ---------------------------------------------------------------------------
244
+
245
+ def predict_gradio(*values):
246
+ try:
247
+ user_vals = {f: v for f, v in zip(ALL_FEATURES, values)}
248
+
249
+ missing = []
250
+ for f, v in user_vals.items():
251
+ if v is None or v == "" or (isinstance(v, float) and pd.isna(v)):
252
+ missing.append(f)
253
+ if missing:
254
+ raise ValueError(
255
+ f"Please fill in all fields before predicting.\nMissing: {', '.join(missing)}"
256
+ )
257
+
258
+ calibrated, _ = predict_with_comparison(user_vals)
259
+ calibrated_probs, calibrated_intervals = calibrated
260
+
261
+ rows = []
262
+ for outcome in REPORTING_OUTCOMES:
263
+ desc = OUTCOME_DESCRIPTIONS[outcome]
264
+ calib_prob = calibrated_probs[outcome]
265
+ ci_low_c, ci_high_c = calibrated_intervals[outcome]
266
+ rows.append({
267
+ "Outcome": desc,
268
+ "Probability": f"{calib_prob * 100:.1f}%",
269
+ "95% CI": f"[{ci_low_c * 100:.1f}% - {ci_high_c * 100:.1f}%]",
270
+ })
271
+ df = pd.DataFrame(rows)
272
+
273
+ shap_plots = create_all_shap_plots(user_vals, max_display=10)
274
+
275
+ icon_outcomes = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
276
+ icon_plots = {o: icon_array(calibrated_probs[o], o) for o in icon_outcomes}
277
+
278
+ return (
279
+ df,
280
+ icon_plots["DEAD"],
281
+ icon_plots["GF"],
282
+ icon_plots["AGVHD"],
283
+ icon_plots["CGVHD"],
284
+ icon_plots["VOCPSHI"],
285
+ icon_plots["STROKEHI"],
286
+ shap_plots["DEAD"],
287
+ shap_plots["GF"],
288
+ shap_plots["AGVHD"],
289
+ shap_plots["CGVHD"],
290
+ shap_plots["VOCPSHI"],
291
+ shap_plots["EFS"],
292
+ shap_plots["STROKEHI"],
293
+ shap_plots["OS"],
294
+ )
295
+
296
+ except Exception as e:
297
+ tb = traceback.format_exc()
298
+ print("=" * 60)
299
+ print("ERROR IN predict_gradio:")
300
+ print(tb)
301
+ print("=" * 60)
302
+ raise gr.Error(f"{type(e).__name__}: {str(e)}\n\nSee terminal for full traceback.")
303
+
304
+ # ---------------------------------------------------------------------------
305
+ # CSS
306
+ # ---------------------------------------------------------------------------
307
+
308
+ custom_css = """
309
+ .predict-button {
310
+ background: linear-gradient(to right, #ff6b35, #ff8c42) !important;
311
+ border: none !important;
312
+ color: white !important;
313
+ font-weight: bold !important;
314
+ font-size: 16px !important;
315
+ padding: 12px !important;
316
+ }
317
+ .predict-button:hover {
318
+ background: linear-gradient(to right, #ff5722, #ff7b29) !important;
319
+ }
320
+ """
321
+
322
+ # ---------------------------------------------------------------------------
323
+ # Gradio UI
324
+ # ---------------------------------------------------------------------------
325
+
326
+ with gr.Blocks(title="HCT Outcome Prediction Model", css=custom_css) as demo:
327
+ gr.Markdown(
328
+ """
329
+ # HCT Outcome Prediction Model
330
+
331
+ Enter patient, transplant, and disease characteristics to predict outcomes.
332
+ """
333
+ )
334
+
335
+ inputs_dict =