shivapriyasom commited on
Commit
5608848
Β·
verified Β·
1 Parent(s): e4469ae

Upload 11 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,11 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ensemble_model_AGVHD.skops filter=lfs diff=lfs merge=lfs -text
37
+ ensemble_model_CGVHD.skops filter=lfs diff=lfs merge=lfs -text
38
+ ensemble_model_DEAD.skops filter=lfs diff=lfs merge=lfs -text
39
+ ensemble_model_DWOGF.skops filter=lfs diff=lfs merge=lfs -text
40
+ ensemble_model_GF.skops filter=lfs diff=lfs merge=lfs -text
41
+ ensemble_model_STROKEHI.skops filter=lfs diff=lfs merge=lfs -text
42
+ ensemble_model_VOCPSHI.skops filter=lfs diff=lfs merge=lfs -text
43
+ preprocessor.skops filter=lfs diff=lfs merge=lfs -text
app (2).py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import traceback
4
+
5
+ from inference import (
6
+ FEATURE_NAMES,
7
+ REPORTING_OUTCOMES,
8
+ OUTCOME_DESCRIPTIONS,
9
+ OUTCOMES,
10
+ SHAP_OUTCOMES,
11
+ predict_with_comparison,
12
+ create_all_shap_plots,
13
+ icon_array,
14
+ )
15
+
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Choice lists
19
+ # ---------------------------------------------------------------------------
20
+
21
+ AGEGPFF_CHOICES = ["<=10", "11-17", "18-29", "30-49", ">=50"]
22
+ SEX_CHOICES = ["Male", "Female"]
23
+ KPS_CHOICES = ["<90", "β‰₯ 90"]
24
+ DONORF_CHOICES = [
25
+ "HLA identical sibling",
26
+ "HLA mismatch relative",
27
+ "Matched unrelated donor",
28
+ "Mismatched unrelated donor or cord blood",
29
+ ]
30
+ GRAFTYPE_CHOICES = ["Bone marrow", "Peripheral blood", "Cord blood"]
31
+ CONDGRPF_CHOICES = ["MAC", "RIC", "NMA"]
32
+ CONDGRP_FINAL_CHOICES = [
33
+ "TBI/Cy", "TBI/Cy/Flu", "TBI/Cy/Flu/TT", "TBI/Mel", "TBI/Flu",
34
+ "TBI alone (300/400/600cGy)", "Bu/Cy", "Bu/Mel", "Flu/Bu/TT",
35
+ "Flu/Bu", "Flu/Mel/TT", "Flu/Mel", "Cy/Flu", "Treosulfan",
36
+ "Cy alone", "Flud", "TLI",
37
+ ]
38
+ ATGF_CHOICES = ["ATG", "Alemtuzumab", "None"]
39
+ GVHD_FINAL_CHOICES = [
40
+ "Ex-vivo T-cell depletion", "CD34 selection", "Post-CY + siro +- MMF",
41
+ "Post-CY + MMF + CNI", "CNI + MMF", "CNI + MTX", "CNI alone",
42
+ "CNI + siro", "Siro alone", "MMF + MTX", "MMF + siro", "MMF alone",
43
+ "MTX alone", "MTX + siro",
44
+ ]
45
+ HLA_FINAL_CHOICES = ["8/8", "7/8", "≀ 6/8"]
46
+ RCMVPR_CHOICES = ["Negative", "Positive"]
47
+ EXCHTFPR_CHOICES = ["No", "Yes"]
48
+ VOC2YPR_CHOICES = ["No", "Yes"]
49
+ VOCFRQPR_CHOICES = ["< 3/yr", "β‰₯ 3/yr"]
50
+ SCATXRSN_CHOICES = [
51
+ "CNS event", "Acute chest Syndrome", "Recurrent vaso-occlusive pain",
52
+ "Recurrent priapism", "Excessive transfusion requirements/iron overload",
53
+ "Cardio-pulmonary", "Chronic transfusion", "Asymptomatic",
54
+ "Renal insufficiency", "Splenic sequestration", "Avascular necrosis",
55
+ "Hodgkin lymphoma",
56
+ ]
57
+
58
+
59
+ # ---------------------------------------------------------------------------
60
+ # Grouped published-regimen dropdown
61
+ # ---------------------------------------------------------------------------
62
+
63
+ GROUPED_REGIMEN_CHOICES = [
64
+ ("── HLA IDENTICAL ──", "__header_hla_identical__"),
65
+ ("Hsieh et al 2014", "Hsieh et al 2014"),
66
+ ("Krishnamurti et al 2019", "Krishnamurti et al 2019"),
67
+ ("King et al 2015", "King et al 2015"),
68
+ ("Walters et al 1996", "Walters et al 1996"),
69
+ ("── HLA MISMATCHED ──", "__header_hla_mismatched__"),
70
+ ("Bolanos-Meade et al 2022 (HLA Mismatch)", "Bolanos-Meade et al 2022 (HLA Mismatch)"),
71
+ ("Patel et al 2020 (HLA Mismatch)", "Patel et al 2020 (HLA Mismatch)"),
72
+ ("── MATCHED UNRELATED ──", "__header_matched_unrelated__"),
73
+ ("L Krishnamurti et al 2019", "L Krishnamurti et al 2019"),
74
+ ("Shenoy et al 2016", "Shenoy et al 2016"),
75
+ ("── MISMATCHED UNRELATED / CORD BLOOD ──", "__header_mismatched_cord__"),
76
+ ("Bolanos-Meade et al 2022 (Mismatched/Cord)", "Bolanos-Meade et al 2022 (Mismatched/Cord)"),
77
+ ("Patel et al 2020 (Mismatched/Cord)", "Patel et al 2020 (Mismatched/Cord)"),
78
+ ]
79
+
80
+ HEADER_VALUES = {v for _, v in GROUPED_REGIMEN_CHOICES if v.startswith("__header_")}
81
+
82
+ PUBLISHED_PRESETS = {
83
+ # HLA Identical Sibling
84
+ "Hsieh et al 2014": {
85
+ "CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI alone (300/400/600cGy)",
86
+ "ATGF": "Alemtuzumab", "GVHD_FINAL": "Siro alone",
87
+ "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling",
88
+ },
89
+ "Krishnamurti et al 2019": {
90
+ "CONDGRPF": "MAC", "CONDGRP_FINAL": "Flu/Bu",
91
+ "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX",
92
+ "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling",
93
+ },
94
+ "King et al 2015": {
95
+ "CONDGRPF": "RIC", "CONDGRP_FINAL": "Flu/Mel",
96
+ "ATGF": "Alemtuzumab", "GVHD_FINAL": "CNI + MTX",
97
+ "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling",
98
+ },
99
+ "Walters et al 1996": {
100
+ "CONDGRPF": "MAC", "CONDGRP_FINAL": "Bu/Cy",
101
+ "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX",
102
+ "HLA_FINAL": "8/8", "DONORF": "HLA identical sibling",
103
+ },
104
+ # HLA Mismatch Relative
105
+ "Bolanos-Meade et al 2022 (HLA Mismatch)": {
106
+ "CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu",
107
+ "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF",
108
+ "HLA_FINAL": "7/8", "DONORF": "HLA mismatch relative",
109
+ },
110
+ "Patel et al 2020 (HLA Mismatch)": {
111
+ "CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT",
112
+ "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF",
113
+ "HLA_FINAL": "7/8", "DONORF": "HLA mismatch relative",
114
+ },
115
+ # Matched Unrelated Donor
116
+ "L Krishnamurti et al 2019": {
117
+ "CONDGRPF": "MAC", "CONDGRP_FINAL": "Flu/Bu",
118
+ "ATGF": "ATG", "GVHD_FINAL": "CNI + MTX",
119
+ "HLA_FINAL": "8/8", "DONORF": "Matched unrelated donor",
120
+ },
121
+ "Shenoy et al 2016": {
122
+ "CONDGRPF": "RIC", "CONDGRP_FINAL": "Flu/Mel",
123
+ "ATGF": "Alemtuzumab", "GVHD_FINAL": "CNI + MTX",
124
+ "HLA_FINAL": "8/8", "DONORF": "Matched unrelated donor",
125
+ },
126
+ # Mismatched Unrelated Donor or Cord Blood
127
+ "Bolanos-Meade et al 2022 (Mismatched/Cord)": {
128
+ "CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu",
129
+ "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF",
130
+ "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood",
131
+ },
132
+ "Patel et al 2020 (Mismatched/Cord)": {
133
+ "CONDGRPF": "NMA", "CONDGRP_FINAL": "TBI/Cy/Flu/TT",
134
+ "ATGF": "ATG", "GVHD_FINAL": "Post-CY + siro +- MMF",
135
+ "HLA_FINAL": "7/8", "DONORF": "Mismatched unrelated donor or cord blood",
136
+ },
137
+ }
138
+
139
+
140
+ # ---------------------------------------------------------------------------
141
+ # Feature groupings
142
+ # ---------------------------------------------------------------------------
143
+
144
+ PATIENT_FEATURES = ["AGE", "AGEGPFF", "SEX", "KPS", "RCMVPR"]
145
+ DONOR_FEATURES = ["DONORF", "GRAFTYPE", "HLA_FINAL",
146
+ "CONDGRPF", "CONDGRP_FINAL", "ATGF", "GVHD_FINAL"]
147
+ DISEASE_FEATURES = ["NACS2YR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN"]
148
+ ALL_FEATURES = PATIENT_FEATURES + DONOR_FEATURES + DISEASE_FEATURES
149
+
150
+
151
+ # ---------------------------------------------------------------------------
152
+ # Utility callbacks
153
+ # ---------------------------------------------------------------------------
154
+
155
+ def get_age_group(age):
156
+ if age is None or age == "":
157
+ return ""
158
+ try:
159
+ age = float(age)
160
+ if age <= 10:
161
+ return "<=10"
162
+ elif age <= 17:
163
+ return "11-17"
164
+ elif age <= 29:
165
+ return "18-29"
166
+ elif age <= 49:
167
+ return "30-49"
168
+ else:
169
+ return ">=50"
170
+ except (ValueError, TypeError):
171
+ return ""
172
+
173
+
174
+ def vocfrqpr_from_voc2ypr(voc_status):
175
+ if voc_status == "No":
176
+ return gr.update(value="< 3/yr", interactive=False)
177
+ else:
178
+ return gr.update(value=None, interactive=True)
179
+
180
+
181
+ def apply_grouped_preset(selected_value):
182
+ if not selected_value or selected_value in HEADER_VALUES:
183
+ return [gr.update(value=None)] + [gr.update()] * 6
184
+
185
+ preset = PUBLISHED_PRESETS.get(selected_value)
186
+ if not preset:
187
+ return [gr.update()] * 7
188
+
189
+ return [
190
+ gr.update(),
191
+ gr.update(value=preset["DONORF"]),
192
+ gr.update(value=preset["CONDGRPF"]),
193
+ gr.update(value=preset["CONDGRP_FINAL"]),
194
+ gr.update(value=preset["ATGF"]),
195
+ gr.update(value=preset["GVHD_FINAL"]),
196
+ gr.update(value=preset["HLA_FINAL"]),
197
+ ]
198
+
199
+
200
+ # ---------------------------------------------------------------------------
201
+ # Component factory
202
+ # ---------------------------------------------------------------------------
203
+
204
+ def make_component(name: str):
205
+ if name == "AGE":
206
+ return gr.Number(label="Age at transplant (years)", minimum=0, maximum=120)
207
+ elif name == "AGEGPFF":
208
+ return gr.Textbox(label="Age group (Auto-filled)", interactive=False)
209
+ elif name == "NACS2YR":
210
+ return gr.Number(
211
+ label="Number of Acute Chest Syndromes within 2 years pre-HCT",
212
+ minimum=0,
213
+ )
214
+ elif name == "SEX":
215
+ return gr.Dropdown(SEX_CHOICES, label="Sex")
216
+ elif name == "KPS":
217
+ return gr.Dropdown(KPS_CHOICES, label="Karnofsky/Lansky Performance Score at HCT")
218
+ elif name == "DONORF":
219
+ return gr.Dropdown(DONORF_CHOICES, label="Donor type")
220
+ elif name == "GRAFTYPE":
221
+ return gr.Dropdown(GRAFTYPE_CHOICES, label="Graft type")
222
+ elif name == "CONDGRPF":
223
+ return gr.Dropdown(CONDGRPF_CHOICES, label="Conditioning intensity")
224
+ elif name == "CONDGRP_FINAL":
225
+ return gr.Dropdown(CONDGRP_FINAL_CHOICES, label="Conditioning Regimen")
226
+ elif name == "ATGF":
227
+ return gr.Dropdown(ATGF_CHOICES, label="Serotherapy")
228
+ elif name == "GVHD_FINAL":
229
+ return gr.Dropdown(GVHD_FINAL_CHOICES, label="GVHD Prophylaxis")
230
+ elif name == "HLA_FINAL":
231
+ return gr.Dropdown(HLA_FINAL_CHOICES, label="Donor-Recipient HLA Matching")
232
+ elif name == "RCMVPR":
233
+ return gr.Dropdown(RCMVPR_CHOICES, label="Recipient CMV serostatus")
234
+ elif name == "EXCHTFPR":
235
+ return gr.Dropdown(EXCHTFPR_CHOICES, label="Exchange transfusion required?")
236
+ elif name == "VOC2YPR":
237
+ return gr.Dropdown(
238
+ VOC2YPR_CHOICES,
239
+ label="VOC requiring hospitalization within 2 years pre-HCT?",
240
+ )
241
+ elif name == "VOCFRQPR":
242
+ return gr.Dropdown(VOCFRQPR_CHOICES, label="Frequency of VOC hospitalizations")
243
+ elif name == "SCATXRSN":
244
+ return gr.Dropdown(SCATXRSN_CHOICES, label="Reason for Transplant")
245
+ else:
246
+ return gr.Textbox(label=name)
247
+
248
+
249
+ # ---------------------------------------------------------------------------
250
+ # Prediction callback
251
+ # ---------------------------------------------------------------------------
252
+
253
+ def predict_gradio(*values):
254
+ try:
255
+ user_vals = {f: v for f, v in zip(ALL_FEATURES, values)}
256
+
257
+ missing = []
258
+ for f, v in user_vals.items():
259
+ if v is None or v == "" or (isinstance(v, float) and pd.isna(v)):
260
+ missing.append(f)
261
+ if missing:
262
+ raise ValueError(
263
+ f"Please fill in all fields before predicting.\nMissing: {', '.join(missing)}"
264
+ )
265
+
266
+ calibrated, uncalibrated = predict_with_comparison(user_vals)
267
+ calibrated_probs, calibrated_intervals = calibrated
268
+
269
+ rows = []
270
+ for outcome in REPORTING_OUTCOMES:
271
+ desc = OUTCOME_DESCRIPTIONS[outcome]
272
+ calib_prob = calibrated_probs[outcome]
273
+ ci_low_c, ci_high_c = calibrated_intervals[outcome]
274
+ rows.append({
275
+ "Outcome": desc,
276
+ "Probability": f"{calib_prob * 100:.1f}%",
277
+ "95% CI": f"[{ci_low_c * 100:.1f}% - {ci_high_c * 100:.1f}%]",
278
+ })
279
+ df = pd.DataFrame(rows)
280
+
281
+ shap_plots = create_all_shap_plots(user_vals, max_display=10)
282
+
283
+ # Icon arrays for each outcome
284
+ icon_outcomes = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI"]
285
+ icon_plots = {o: icon_array(calibrated_probs[o], o) for o in icon_outcomes}
286
+
287
+ return (
288
+ df,
289
+ icon_plots["DEAD"],
290
+ icon_plots["GF"],
291
+ icon_plots["AGVHD"],
292
+ icon_plots["CGVHD"],
293
+ icon_plots["VOCPSHI"],
294
+ icon_plots["STROKEHI"],
295
+ shap_plots["DEAD"],
296
+ shap_plots["GF"],
297
+ shap_plots["AGVHD"],
298
+ shap_plots["CGVHD"],
299
+ shap_plots["VOCPSHI"],
300
+ shap_plots["EFS"],
301
+ shap_plots["STROKEHI"],
302
+ shap_plots["OS"],
303
+ )
304
+
305
+ except Exception as e:
306
+ tb = traceback.format_exc()
307
+ print("=" * 60)
308
+ print("ERROR IN predict_gradio:")
309
+ print(tb)
310
+ print("=" * 60)
311
+ raise gr.Error(f"{type(e).__name__}: {str(e)}\n\nSee terminal for full traceback.")
312
+
313
+
314
+ # ---------------------------------------------------------------------------
315
+ # CSS (passed to launch() in Gradio 6+)
316
+ # ---------------------------------------------------------------------------
317
+
318
+ custom_css = """
319
+ .predict-button {
320
+ background: linear-gradient(to right, #ff6b35, #ff8c42) !important;
321
+ border: none !important;
322
+ color: white !important;
323
+ font-weight: bold !important;
324
+ font-size: 16px !important;
325
+ padding: 12px !important;
326
+ }
327
+ .predict-button:hover {
328
+ background: linear-gradient(to right, #ff5722, #ff7b29) !important;
329
+ }
330
+ """
331
+
332
+ # ---------------------------------------------------------------------------
333
+ # Gradio UI
334
+ # ---------------------------------------------------------------------------
335
+
336
+ with gr.Blocks(title="HCT Outcome Prediction Model") as demo:
337
+ gr.Markdown(
338
+ """
339
+ # HCT Outcome Prediction Model
340
+
341
+ Enter patient, transplant, and disease characteristics to predict outcomes.
342
+ """
343
+ )
344
+
345
+ inputs_dict = {}
346
+
347
+ with gr.Row():
348
+ # ── Patient Characteristics ──────────────────────────────────────
349
+ with gr.Column(scale=1):
350
+ gr.Markdown("### Patient Characteristics")
351
+ for f in PATIENT_FEATURES:
352
+ inputs_dict[f] = make_component(f)
353
+
354
+ # ── Transplant Characteristics ───────────────────────────────────
355
+ with gr.Column(scale=1):
356
+ gr.Markdown("### Transplant Characteristics")
357
+
358
+ grouped_regimen_dropdown = gr.Dropdown(
359
+ choices=GROUPED_REGIMEN_CHOICES,
360
+ value=None,
361
+ label="Published conditioning regimen",
362
+ info="Auto-fills Donor Type, Conditioning Intensity, Conditioning Regimen, "
363
+ "Serotherapy and GVHD Prophylaxis",
364
+ )
365
+
366
+ donorf_comp = inputs_dict["DONORF"] = make_component("DONORF")
367
+ inputs_dict["GRAFTYPE"] = make_component("GRAFTYPE")
368
+ condgrpf = inputs_dict["CONDGRPF"] = make_component("CONDGRPF")
369
+ condgrp_final = inputs_dict["CONDGRP_FINAL"] = make_component("CONDGRP_FINAL")
370
+ atgf = inputs_dict["ATGF"] = make_component("ATGF")
371
+ gvhd_final = inputs_dict["GVHD_FINAL"] = make_component("GVHD_FINAL")
372
+ hla_final = inputs_dict["HLA_FINAL"] = make_component("HLA_FINAL")
373
+
374
+ # ── Disease Characteristics ──────────────────────────────────────
375
+ with gr.Column(scale=1):
376
+ gr.Markdown("### Disease Characteristics")
377
+ for f in DISEASE_FEATURES:
378
+ inputs_dict[f] = make_component(f)
379
+
380
+ # ── Reactive callbacks ───────────────────────────────────────────────
381
+ inputs_dict["AGE"].change(
382
+ fn=get_age_group,
383
+ inputs=inputs_dict["AGE"],
384
+ outputs=inputs_dict["AGEGPFF"],
385
+ )
386
+
387
+ inputs_dict["VOC2YPR"].change(
388
+ fn=vocfrqpr_from_voc2ypr,
389
+ inputs=inputs_dict["VOC2YPR"],
390
+ outputs=inputs_dict["VOCFRQPR"],
391
+ )
392
+
393
+ grouped_regimen_dropdown.change(
394
+ fn=apply_grouped_preset,
395
+ inputs=grouped_regimen_dropdown,
396
+ outputs=[
397
+ grouped_regimen_dropdown,
398
+ donorf_comp, condgrpf, condgrp_final, atgf, gvhd_final, hla_final,
399
+ ],
400
+ )
401
+
402
+ inputs_list = [inputs_dict[f] for f in ALL_FEATURES]
403
+
404
+ btn = gr.Button("Predict", elem_classes="predict-button", size="lg")
405
+
406
+ gr.Markdown("---")
407
+ gr.Markdown("## Prediction Results")
408
+ gr.Markdown("### Predicted Outcomes")
409
+
410
+ with gr.Column():
411
+ output_table = gr.Dataframe(
412
+ headers=["Outcome", "Probability", "95% CI"],
413
+ label="",
414
+ row_count=(len(REPORTING_OUTCOMES), "dynamic"),
415
+ column_count=(3, "fixed"), # fixed: col_count β†’ column_count (Gradio 6)
416
+ )
417
+
418
+ gr.Markdown("---")
419
+ gr.Markdown("## Icon Arrays")
420
+
421
+ with gr.Row():
422
+ with gr.Column():
423
+ icon_dead = gr.Plot(label="Death")
424
+ with gr.Column():
425
+ icon_gf = gr.Plot(label="Graft Failure")
426
+ with gr.Column():
427
+ icon_agvhd = gr.Plot(label="Acute Graft-versus-Host Disease")
428
+
429
+ with gr.Row():
430
+ with gr.Column():
431
+ icon_cgvhd = gr.Plot(label="Chronic Graft-versus-Host Disease")
432
+ with gr.Column():
433
+ icon_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
434
+ with gr.Column():
435
+ icon_stroke = gr.Plot(label="Stroke Post-HCT")
436
+
437
+ gr.Markdown("---")
438
+ gr.Markdown("## SHAP - Feature Importance")
439
+
440
+ with gr.Row():
441
+ with gr.Column():
442
+ shap_dead = gr.Plot(label="Death")
443
+ with gr.Column():
444
+ shap_gf = gr.Plot(label="Graft Failure")
445
+ with gr.Column():
446
+ shap_agvhd = gr.Plot(label="Acute Graft-versus-Host Disease")
447
+ with gr.Column():
448
+ shap_cgvhd = gr.Plot(label="Chronic Graft-versus-Host Disease")
449
+
450
+ with gr.Row():
451
+ with gr.Column():
452
+ shap_vocpshi = gr.Plot(label="Vaso-Occlusive Crisis Post-HCT")
453
+ with gr.Column():
454
+ shap_efs = gr.Plot(label="Event-Free Survival")
455
+ with gr.Column():
456
+ shap_stroke = gr.Plot(label="Stroke Post-HCT")
457
+ with gr.Column():
458
+ shap_os = gr.Plot(label="Overall Survival")
459
+
460
+ btn.click(
461
+ fn=predict_gradio,
462
+ inputs=inputs_list,
463
+ outputs=[
464
+ output_table,
465
+ icon_dead, icon_gf, icon_agvhd, icon_cgvhd, icon_vocpshi, icon_stroke,
466
+ shap_dead, shap_gf, shap_agvhd, shap_cgvhd,
467
+ shap_vocpshi, shap_efs, shap_stroke, shap_os,
468
+ ],
469
+ )
470
+
471
+
472
+ if __name__ == "__main__":
473
+ demo.launch(
474
+ ssr_mode=False,
475
+ css=custom_css, # css moved to launch() in Gradio 6
476
+ )
ensemble_model_AGVHD.skops ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d82b622617aed90a5925d450ce6ba40909f61f012570566e62c7581b86e9a2e
3
+ size 74999945
ensemble_model_CGVHD.skops ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d05b99633060d90ca94e48b332efd52301d0c87631f8c8c5f638a438bc68aece
3
+ size 92583386
ensemble_model_DEAD.skops ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5cd000ae7752e4852c3271c9ea434c34830f7e8007ec107b7715833852028ed
3
+ size 66979465
ensemble_model_DWOGF.skops ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:712b614d45498facc9cbd881ea63570f4a5130006877d8f765dbc52d6c787ae8
3
+ size 67038095
ensemble_model_GF.skops ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22d7e3b82157a0877134653b7f14d5320b8047811283b4ba4a51b9cf332c9e78
3
+ size 40044086
ensemble_model_STROKEHI.skops ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81f610182edda56a4ace74988212f3ad37888e32860d4531e2b0381914fcb7f2
3
+ size 41369307
ensemble_model_VOCPSHI.skops ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9d46bb9b61c8034ff400aeae5778e78755310f33eb11a7f3655f29dd7dcd853
3
+ size 46389465
inference (1).py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import skops.io as sio
4
+ import shap
5
+ import plotly.graph_objects as go
6
+ import os
7
+ import sys
8
+ import warnings
9
+
10
+ warnings.filterwarnings("ignore", category=UserWarning, module="sklearn")
11
+
12
+ print("===== Application Startup =====")
13
+ print(f"Working directory: {os.getcwd()}")
14
+ print(f"Files present: {os.listdir('.')}")
15
+
16
+ # ---------------------------------------------------------------------------
17
+ # Compatibility patch
18
+ # ---------------------------------------------------------------------------
19
+ import sklearn.compose._column_transformer as _ct
20
+ if not hasattr(_ct, "_RemainderColsList"):
21
+ class _RemainderColsList(list):
22
+ def __init__(self, lst=None, future_dtype=None):
23
+ super().__init__(lst or [])
24
+ self.future_dtype = future_dtype
25
+ _ct._RemainderColsList = _RemainderColsList
26
+ import sklearn.compose
27
+ sklearn.compose._RemainderColsList = _RemainderColsList
28
+ print("Patched _RemainderColsList into sklearn.compose")
29
+
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Column / feature definitions
33
+ # ---------------------------------------------------------------------------
34
+
35
+ NUM_COLUMNS = ["AGE", "NACS2YR"]
36
+ CATEG_COLUMNS = [
37
+ "AGEGPFF", "SEX", "KPS", "DONORF", "GRAFTYPE", "CONDGRPF",
38
+ "CONDGRP_FINAL", "ATGF", "GVHD_FINAL", "HLA_FINAL",
39
+ "RCMVPR", "EXCHTFPR", "VOC2YPR", "VOCFRQPR", "SCATXRSN",
40
+ ]
41
+
42
+ FEATURE_NAMES = NUM_COLUMNS + CATEG_COLUMNS
43
+
44
+ OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "DWOGF"]
45
+ CLASSIFICATION_OUTCOMES = OUTCOMES
46
+
47
+ REPORTING_OUTCOMES = [
48
+ "OS", "EFS", "GF", "DEAD",
49
+ "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI",
50
+ ]
51
+
52
+ OUTCOME_DESCRIPTIONS = {
53
+ "OS": "Overall Survival",
54
+ "EFS": "Event-Free Survival",
55
+ "DEAD": "Total Mortality",
56
+ "GF": "Graft Failure",
57
+ "AGVHD": "Acute Graft-versus-Host Disease",
58
+ "CGVHD": "Chronic Graft-versus-Host Disease",
59
+ "VOCPSHI": "Vaso-Occlusive Crisis Post-HCT",
60
+ "STROKEHI": "Stroke Post-HCT",
61
+ }
62
+
63
+ SHAP_OUTCOMES = ["DEAD", "GF", "AGVHD", "CGVHD", "VOCPSHI", "STROKEHI", "OS", "EFS"]
64
+
65
+ MODEL_DIR = "."
66
+ CONSENSUS_THRESHOLD = 0.5
67
+ DEFAULT_N_BOOT_CI = 500
68
+
69
+
70
+ # ---------------------------------------------------------------------------
71
+ # Model loading
72
+ # ---------------------------------------------------------------------------
73
+
74
+ def _load_skops_model(fname):
75
+ if not os.path.exists(fname):
76
+ raise RuntimeError(f"Model file not found: {fname}")
77
+ try:
78
+ untrusted = sio.get_untrusted_types(file=fname)
79
+ model = sio.load(fname, trusted=untrusted)
80
+ print(f" Loaded: {fname}")
81
+ return model
82
+ except Exception as e:
83
+ raise RuntimeError(f"Failed to load '{fname}': {type(e).__name__}: {e}") from e
84
+
85
+
86
+ print("Loading preprocessor...")
87
+ preprocessor = _load_skops_model(os.path.join(MODEL_DIR, "preprocessor.skops"))
88
+
89
+ print("Loading ensemble models...")
90
+ classification_model_data = {}
91
+ for _o in CLASSIFICATION_OUTCOMES:
92
+ _path = os.path.join(MODEL_DIR, f"ensemble_model_{_o}.skops")
93
+ if os.path.exists(_path):
94
+ classification_model_data[_o] = _load_skops_model(_path)
95
+ else:
96
+ print(f" Warning: Model for {_o} not found at {_path}. Skipping.")
97
+
98
+ classification_models = {o: d["models"] for o, d in classification_model_data.items()}
99
+ betas = {o: d["beta"] for o, d in classification_model_data.items()}
100
+ priors = {o: d["prior"] for o, d in classification_model_data.items()}
101
+ consensus_thresholds = {
102
+ o: d.get("consensus_threshold", CONSENSUS_THRESHOLD)
103
+ for o, d in classification_model_data.items()
104
+ }
105
+
106
+ calibrators = {}
107
+ for _o, _d in classification_model_data.items():
108
+ _cal = None
109
+ _cal_type = _d.get("calibrator_type", None)
110
+ if "calibrator" in _d and _d["calibrator"] is not None:
111
+ if _cal_type is None or _cal_type == "isotonic":
112
+ _cal = _d["calibrator"]
113
+ else:
114
+ print(f" Warning: outcome '{_o}' has calibrator_type='{_cal_type}'. Skipping.")
115
+ elif "isotonic_calibrator" in _d and _d["isotonic_calibrator"] is not None:
116
+ _cal = _d["isotonic_calibrator"]
117
+ calibrators[_o] = _cal
118
+
119
+ isotonic_calibrators = calibrators
120
+
121
+ oof_probs_calibrated = {
122
+ o: d.get("oof_probs_calibrated") for o, d in classification_model_data.items()
123
+ }
124
+
125
+ ohe = preprocessor.named_transformers_["cat"]
126
+ ohe_feature_names = ohe.get_feature_names_out(CATEG_COLUMNS)
127
+ processed_feature_names = np.concatenate([NUM_COLUMNS, ohe_feature_names])
128
+
129
+ print(f"Models loaded: {list(classification_models.keys())}")
130
+
131
+
132
+ # ---------------------------------------------------------------------------
133
+ # SHAP background data
134
+ # ---------------------------------------------------------------------------
135
+
136
+ print("Building SHAP background...")
137
+ np.random.seed(23)
138
+ _n_background = 500
139
+
140
+ _background_data = {
141
+ "AGE": np.random.uniform(5, 50, _n_background),
142
+ "NACS2YR": np.random.randint(0, 5, _n_background),
143
+ "AGEGPFF": np.random.choice(["<=10", "11-17", "18-29", "30-49", ">=50"], _n_background),
144
+ "SEX": np.random.choice(["Male", "Female"], _n_background),
145
+ "KPS": np.random.choice(["<90", "β‰₯ 90"], _n_background),
146
+ "DONORF": np.random.choice([
147
+ "HLA identical sibling", "HLA mismatch relative",
148
+ "Matched unrelated donor",
149
+ "Mismatched unrelated donor or cord blood",
150
+ ], _n_background),
151
+ "GRAFTYPE": np.random.choice(["Bone marrow", "Peripheral blood", "Cord blood"], _n_background),
152
+ "CONDGRPF": np.random.choice(["MAC", "RIC", "NMA"], _n_background),
153
+ "CONDGRP_FINAL": np.random.choice(["TBI/Cy", "Bu/Cy", "Flu/Bu", "Flu/Mel"], _n_background),
154
+ "ATGF": np.random.choice(["ATG", "Alemtuzumab", "None"], _n_background),
155
+ "GVHD_FINAL": np.random.choice(["CNI + MMF", "CNI + MTX", "Post-CY + siro +- MMF"], _n_background),
156
+ "HLA_FINAL": np.random.choice(["8/8", "7/8", "≀ 6/8"], _n_background),
157
+ "RCMVPR": np.random.choice(["Negative", "Positive"], _n_background),
158
+ "EXCHTFPR": np.random.choice(["No", "Yes"], _n_background),
159
+ "VOC2YPR": np.random.choice(["No", "Yes"], _n_background),
160
+ "VOCFRQPR": np.random.choice(["< 3/yr", "β‰₯ 3/yr"], _n_background),
161
+ "SCATXRSN": np.random.choice([
162
+ "CNS event", "Acute chest Syndrome",
163
+ "Recurrent vaso-occlusive pain", "Recurrent priapism",
164
+ "Excessive transfusion requirements/iron overload",
165
+ "Cardio-pulmonary", "Chronic transfusion", "Asymptomatic",
166
+ "Renal insufficiency", "Splenic sequestration",
167
+ "Avascular necrosis", "Hodgkin lymphoma",
168
+ ], _n_background),
169
+ }
170
+
171
+ _background_df = pd.DataFrame(_background_data)[FEATURE_NAMES]
172
+ _X_background = preprocessor.transform(_background_df)
173
+ shap_background = shap.maskers.Independent(_X_background)
174
+ print("SHAP background ready.")
175
+
176
+
177
+ # ---------------------------------------------------------------------------
178
+ # Calibration helpers
179
+ # ---------------------------------------------------------------------------
180
+
181
+ def calibrate_probabilities_undersampling(p_s, beta):
182
+ p_s = np.asarray(p_s, dtype=float)
183
+ numerator = beta * p_s
184
+ denominator = np.maximum((beta - 1.0) * p_s + 1.0, 1e-10)
185
+ return np.clip(numerator / denominator, 0.0, 1.0)
186
+
187
+
188
+ def predict_consensus_signed_voting(ensemble_models, X_test, threshold=0.5):
189
+ individual_probas = np.array(
190
+ [m.predict_proba(X_test)[:, 1] for m in ensemble_models]
191
+ )
192
+ binary_preds = (individual_probas >= threshold).astype(int)
193
+ signed_votes = np.where(binary_preds == 1, 1, -1)
194
+ avg_signed_vote = np.mean(signed_votes, axis=0)
195
+ consensus_pred = (avg_signed_vote > 0).astype(int)
196
+ avg_proba = np.mean(individual_probas, axis=0)
197
+ return consensus_pred, avg_proba, avg_signed_vote, individual_probas.flatten()
198
+
199
+
200
+ def predict_consensus_majority(ensemble_models, X_test, threshold=0.5):
201
+ individual_probas = np.array(
202
+ [m.predict_proba(X_test)[:, 1] for m in ensemble_models]
203
+ )
204
+ avg_proba = np.mean(individual_probas, axis=0)
205
+ return avg_proba, individual_probas.flatten()
206
+
207
+
208
+ # ---------------------------------------------------------------------------
209
+ # Bootstrap CI
210
+ # ---------------------------------------------------------------------------
211
+
212
+ def bootstrap_ci_from_oof(
213
+ point_estimate: float,
214
+ oof_probs: np.ndarray,
215
+ n_boot: int = DEFAULT_N_BOOT_CI,
216
+ confidence: float = 0.95,
217
+ random_state: int = 42,
218
+ ) -> tuple:
219
+ if oof_probs is None or len(oof_probs) == 0:
220
+ return float(point_estimate), float(point_estimate)
221
+
222
+ oof_probs = np.asarray(oof_probs, dtype=float)
223
+ rng = np.random.RandomState(random_state)
224
+ grand_mean = np.mean(oof_probs)
225
+ n = len(oof_probs)
226
+
227
+ boot_means = np.array([
228
+ np.mean(rng.choice(oof_probs, size=n, replace=True))
229
+ for _ in range(n_boot)
230
+ ])
231
+
232
+ shift = point_estimate - grand_mean
233
+ boot_means = boot_means + shift
234
+
235
+ alpha = 1.0 - confidence
236
+ lo = float(np.clip(np.percentile(boot_means, 100 * alpha / 2), 0.0, 1.0))
237
+ hi = float(np.clip(np.percentile(boot_means, 100 * (1 - alpha / 2)), 0.0, 1.0))
238
+ return lo, hi
239
+
240
+
241
+ # ---------------------------------------------------------------------------
242
+ # Calibration dispatch
243
+ # ---------------------------------------------------------------------------
244
+
245
+ def _calibrate_point(outcome: str, raw_prob: float, use_calibration: bool) -> float:
246
+ beta = betas[outcome]
247
+ p_beta = float(calibrate_probabilities_undersampling([raw_prob], beta)[0])
248
+ if not use_calibration:
249
+ return p_beta
250
+ cal = calibrators.get(outcome)
251
+ if cal is None:
252
+ return p_beta
253
+ return float(cal.transform([p_beta])[0])
254
+
255
+
256
+ # ---------------------------------------------------------------------------
257
+ # Main prediction functions
258
+ # ---------------------------------------------------------------------------
259
+
260
+ def predict_all_outcomes(
261
+ user_inputs,
262
+ use_calibration: bool = True,
263
+ use_signed_voting: bool = True,
264
+ n_boot_ci: int = DEFAULT_N_BOOT_CI,
265
+ ):
266
+ if isinstance(user_inputs, dict):
267
+ input_df = pd.DataFrame([user_inputs])
268
+ else:
269
+ input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
270
+
271
+ input_df = input_df[FEATURE_NAMES]
272
+ X = preprocessor.transform(input_df)
273
+
274
+ probs, intervals = {}, {}
275
+
276
+ for o in CLASSIFICATION_OUTCOMES:
277
+ if o not in classification_models:
278
+ continue
279
+
280
+ threshold = consensus_thresholds.get(o, CONSENSUS_THRESHOLD)
281
+
282
+ if use_signed_voting:
283
+ _, uncalib_arr, _, _ = predict_consensus_signed_voting(
284
+ classification_models[o], X, threshold
285
+ )
286
+ else:
287
+ uncalib_arr, _ = predict_consensus_majority(
288
+ classification_models[o], X, threshold
289
+ )
290
+
291
+ raw_prob = float(uncalib_arr[0])
292
+ event_prob = _calibrate_point(o, raw_prob, use_calibration)
293
+
294
+ lo, hi = bootstrap_ci_from_oof(
295
+ point_estimate=event_prob,
296
+ oof_probs=oof_probs_calibrated.get(o),
297
+ n_boot=n_boot_ci,
298
+ )
299
+
300
+ probs[o] = event_prob
301
+ intervals[o] = (lo, hi)
302
+
303
+ # OS = 1 - P(DEAD)
304
+ if "DEAD" in probs:
305
+ p_dead = probs["DEAD"]
306
+ probs["OS"] = float(1.0 - p_dead)
307
+ dead_lo, dead_hi = intervals["DEAD"]
308
+ intervals["OS"] = (
309
+ float(np.clip(1.0 - dead_hi, 0, 1)),
310
+ float(np.clip(1.0 - dead_lo, 0, 1)),
311
+ )
312
+
313
+ # EFS = 1 - P(DWOGF) - P(GF)
314
+ if "DWOGF" in probs and "GF" in probs:
315
+ p_dwogf = probs["DWOGF"]
316
+ p_gf = probs["GF"]
317
+ probs["EFS"] = float(np.clip(1.0 - p_dwogf - p_gf, 0.0, 1.0))
318
+
319
+ oof_dwogf = oof_probs_calibrated.get("DWOGF")
320
+ oof_gf = oof_probs_calibrated.get("GF")
321
+
322
+ if oof_dwogf is not None and oof_gf is not None:
323
+ oof_dwogf = np.asarray(oof_dwogf, dtype=float)
324
+ oof_gf = np.asarray(oof_gf, dtype=float)
325
+ n_min = min(len(oof_dwogf), len(oof_gf))
326
+ oof_dwogf = oof_dwogf[:n_min]
327
+ oof_gf = oof_gf[:n_min]
328
+
329
+ rng = np.random.RandomState(42)
330
+ grand_dwogf = np.mean(oof_dwogf)
331
+ grand_gf = np.mean(oof_gf)
332
+ shift_dwogf = p_dwogf - grand_dwogf
333
+ shift_gf = p_gf - grand_gf
334
+
335
+ efs_boot = np.array([
336
+ np.clip(
337
+ 1.0
338
+ - (np.mean(rng.choice(oof_dwogf, size=n_min, replace=True)) + shift_dwogf)
339
+ - (np.mean(rng.choice(oof_gf, size=n_min, replace=True)) + shift_gf),
340
+ 0.0, 1.0,
341
+ )
342
+ for _ in range(n_boot_ci)
343
+ ])
344
+ intervals["EFS"] = (
345
+ float(np.percentile(efs_boot, 2.5)),
346
+ float(np.percentile(efs_boot, 97.5)),
347
+ )
348
+ else:
349
+ intervals["EFS"] = (probs["EFS"], probs["EFS"])
350
+
351
+ return probs, intervals
352
+
353
+
354
+ def predict_with_comparison(user_inputs, n_boot_ci: int = DEFAULT_N_BOOT_CI):
355
+ cal_probs, cal_intervals = predict_all_outcomes(user_inputs, True, True, n_boot_ci)
356
+ uncal_probs, uncal_intervals = predict_all_outcomes(user_inputs, False, True, n_boot_ci)
357
+ return (cal_probs, cal_intervals), (uncal_probs, uncal_intervals)
358
+
359
+
360
+ # ---------------------------------------------------------------------------
361
+ # SHAP helpers
362
+ # ---------------------------------------------------------------------------
363
+
364
+ def _get_shap_values_for_model_outcome(user_inputs, model_outcome, invert, X_proc):
365
+ all_model_shap_vals = []
366
+ for rf_model in classification_models[model_outcome]:
367
+ explainer = shap.TreeExplainer(rf_model, model_output="probability", data=shap_background)
368
+ shap_vals = explainer.shap_values(X_proc)
369
+ if isinstance(shap_vals, list):
370
+ shap_vals = shap_vals[1]
371
+ elif shap_vals.ndim == 3 and shap_vals.shape[2] == 2:
372
+ shap_vals = shap_vals[:, :, 1]
373
+ sv = shap_vals[0]
374
+ if invert:
375
+ sv = -sv
376
+ all_model_shap_vals.append(sv)
377
+ return np.array(all_model_shap_vals)
378
+
379
+
380
+ def compute_shap_values_with_direction(user_inputs, outcome, max_display=10):
381
+ if isinstance(user_inputs, dict):
382
+ input_df = pd.DataFrame([user_inputs])
383
+ else:
384
+ input_df = pd.DataFrame([user_inputs], columns=FEATURE_NAMES)
385
+
386
+ X_proc = preprocessor.transform(input_df)
387
+
388
+ processed_to_orig = {f: f for f in NUM_COLUMNS}
389
+ for pf in ohe_feature_names:
390
+ processed_to_orig[pf] = pf.split("_", 1)[0]
391
+
392
+ if outcome == "OS":
393
+ raw_shap = _get_shap_values_for_model_outcome(user_inputs, "DEAD", invert=True, X_proc=X_proc)
394
+ elif outcome == "EFS":
395
+ shap_dwogf = _get_shap_values_for_model_outcome(user_inputs, "DWOGF", invert=True, X_proc=X_proc)
396
+ shap_gf = _get_shap_values_for_model_outcome(user_inputs, "GF", invert=True, X_proc=X_proc)
397
+ raw_shap = np.concatenate([shap_dwogf, shap_gf], axis=0)
398
+ else:
399
+ raw_shap = _get_shap_values_for_model_outcome(user_inputs, outcome, invert=False, X_proc=X_proc)
400
+
401
+ unique_orig_features = list(dict.fromkeys(processed_to_orig.values()))
402
+ n_models = len(raw_shap)
403
+
404
+ model_shap_by_orig = np.zeros((n_models, len(unique_orig_features)))
405
+ for model_idx in range(n_models):
406
+ agg_by_orig = {}
407
+ for i, pf in enumerate(processed_feature_names):
408
+ orig = processed_to_orig[pf]
409
+ agg_by_orig.setdefault(orig, 0.0)
410
+ agg_by_orig[orig] += raw_shap[model_idx, i]
411
+ for feat_idx, feat_name in enumerate(unique_orig_features):
412
+ model_shap_by_orig[model_idx, feat_idx] = agg_by_orig.get(feat_name, 0.0)
413
+
414
+ mean_shap_vals = np.mean(model_shap_by_orig, axis=0)
415
+
416
+ rng = np.random.RandomState(42)
417
+ bootstrap_shap_means = np.array([
418
+ np.mean(model_shap_by_orig[rng.choice(n_models, size=n_models, replace=True)], axis=0)
419
+ for _ in range(DEFAULT_N_BOOT_CI)
420
+ ])
421
+ shap_ci_low = np.percentile(bootstrap_shap_means, 2.5, axis=0)
422
+ shap_ci_high = np.percentile(bootstrap_shap_means, 97.5, axis=0)
423
+
424
+ order = np.argsort(-np.abs(mean_shap_vals))
425
+
426
+ top_feat_names = []
427
+ for i in order[:max_display]:
428
+ feat_name = unique_orig_features[i]
429
+ if feat_name in user_inputs:
430
+ val = user_inputs[feat_name]
431
+ if isinstance(val, float) and val != int(val):
432
+ display_name = f"{feat_name} = {val:.2f}"
433
+ elif isinstance(val, (int, float)):
434
+ display_name = f"{feat_name} = {int(val)}"
435
+ else:
436
+ val_str = str(val)
437
+ if len(val_str) > 20:
438
+ val_str = val_str[:17] + "..."
439
+ display_name = f"{feat_name} = {val_str}"
440
+ else:
441
+ display_name = feat_name
442
+ top_feat_names.append(display_name)
443
+
444
+ top_feat_names = top_feat_names[::-1]
445
+ top_shap_vals = mean_shap_vals[order][:max_display][::-1]
446
+ top_ci_low = shap_ci_low[order][:max_display][::-1]
447
+ top_ci_high = shap_ci_high[order][:max_display][::-1]
448
+
449
+ return top_feat_names, top_shap_vals, top_ci_low, top_ci_high
450
+
451
+
452
+ def create_shap_plot(user_inputs, outcome, max_display=10):
453
+ feat_names, shap_vals, ci_low, ci_high = compute_shap_values_with_direction(
454
+ user_inputs, outcome, max_display
455
+ )
456
+
457
+ colors = ["blue" if v >= 0 else "red" for v in shap_vals]
458
+ error_minus = shap_vals - ci_low
459
+ error_plus = ci_high - shap_vals
460
+
461
+ fig = go.Figure()
462
+ fig.add_trace(go.Bar(
463
+ y=feat_names,
464
+ x=shap_vals,
465
+ orientation="h",
466
+ marker=dict(color=colors),
467
+ showlegend=False,
468
+ error_x=dict(
469
+ type="data",
470
+ symmetric=False,
471
+ array=error_plus,
472
+ arrayminus=error_minus,
473
+ color="gray",
474
+ thickness=1.5,
475
+ width=4,
476
+ ),
477
+ ))
478
+ fig.add_vline(x=0, line_width=1, line_color="black")
479
+
480
+ fig.update_layout(
481
+ title=dict(
482
+ text=OUTCOME_DESCRIPTIONS.get(outcome, outcome),
483
+ x=0.5, xanchor="center",
484
+ font=dict(size=14, color="black"),
485
+ ),
486
+ xaxis_title="SHAP value",
487
+ yaxis_title="",
488
+ height=400,
489
+ margin=dict(l=120, r=60, t=50, b=50),
490
+ plot_bgcolor="white",
491
+ paper_bgcolor="white",
492
+ xaxis=dict(showgrid=True, gridcolor="lightgray", zeroline=True,
493
+ zerolinecolor="black", zerolinewidth=1),
494
+ yaxis=dict(showgrid=False),
495
+ )
496
+ return fig
497
+
498
+
499
+ def create_all_shap_plots(user_inputs, max_display=10):
500
+ return {o: create_shap_plot(user_inputs, o, max_display) for o in SHAP_OUTCOMES}
501
+
502
+
503
+ # ---------------------------------------------------------------------------
504
+ # Icon array
505
+ # ---------------------------------------------------------------------------
506
+ # Root cause of previous gaps / distortion:
507
+ # Plotly shape coords are in DATA units. If px-per-data-unit differs on
508
+ # x vs y axes the circle head becomes an ellipse and spacing looks uneven.
509
+ #
510
+ # Fix:
511
+ # β€’ Use EQUAL axis spans on x and y (both = cols + 2*pad = 10.3)
512
+ # β€’ Set width and height so that usable pixels are EQUAL on both axes:
513
+ # usable_w = W - margin_l - margin_r = W - 20
514
+ # usable_h = H - margin_t - margin_b = H - 100
515
+ # usable_w == usable_h β†’ H = W + 80
516
+ # β€’ This guarantees 1 data-unit = same number of pixels on both axes,
517
+ # so circles are round and spacing is perfectly uniform.
518
+ # ---------------------------------------------------------------------------
519
+
520
+ def _stick_figure(cx, cy, color, s):
521
+ """
522
+ Returns Plotly shape dicts for a stick figure centred at (cx, cy).
523
+ s = scale (data units). With a cell size of 1.0, s β‰ˆ 0.46 gives
524
+ a figure that fills ~75 % of the cell vertically.
525
+
526
+ Anatomy (all offsets relative to cy):
527
+ head centre : cy + s*0.55 radius s*0.18
528
+ neck top : cy + s*0.35
529
+ hip : cy - s*0.15
530
+ arm branch : cy + s*0.18
531
+ foot : cy - s*0.55
532
+ """
533
+ shapes = []
534
+ lw = dict(color=color, width=1.8) # fixed pixel width β€” looks consistent
535
+
536
+ # head
537
+ hr = s * 0.18
538
+ hy = cy + s * 0.55
539
+ shapes.append(dict(
540
+ type="circle", xref="x", yref="y",
541
+ x0=cx - hr, y0=hy - hr,
542
+ x1=cx + hr, y1=hy + hr,
543
+ fillcolor=color,
544
+ line=dict(color=color, width=0),
545
+ ))
546
+
547
+ neck_y = cy + s * 0.35
548
+ hip_y = cy - s * 0.15
549
+ arm_y = cy + s * 0.18
550
+ foot_y = cy - s * 0.55
551
+
552
+ # spine
553
+ shapes.append(dict(type="line", xref="x", yref="y",
554
+ x0=cx, y0=neck_y, x1=cx, y1=hip_y, line=lw))
555
+
556
+ # arms
557
+ adx = s * 0.32
558
+ ady = s * 0.15
559
+ shapes.append(dict(type="line", xref="x", yref="y",
560
+ x0=cx, y0=arm_y, x1=cx - adx, y1=arm_y - ady, line=lw))
561
+ shapes.append(dict(type="line", xref="x", yref="y",
562
+ x0=cx, y0=arm_y, x1=cx + adx, y1=arm_y - ady, line=lw))
563
+
564
+ # legs
565
+ ldx = s * 0.26
566
+ shapes.append(dict(type="line", xref="x", yref="y",
567
+ x0=cx, y0=hip_y, x1=cx - ldx, y1=foot_y, line=lw))
568
+ shapes.append(dict(type="line", xref="x", yref="y",
569
+ x0=cx, y0=hip_y, x1=cx + ldx, y1=foot_y, line=lw))
570
+
571
+ return shapes
572
+
573
+
574
+ def icon_array(probability, outcome):
575
+ outcome_labels = {
576
+ "DEAD": ("Death", "Overall Survival"),
577
+ "GF": ("Graft Failure", "No Graft Failure"),
578
+ "AGVHD": ("AGVHD", "No AGVHD"),
579
+ "CGVHD": ("CGVHD", "No CGVHD"),
580
+ "VOCPSHI": ("VOC Post-HCT", "No VOC Post-HCT"),
581
+ "STROKEHI": ("Stroke Post-HCT", "No Stroke Post-HCT"),
582
+ }
583
+
584
+ event_label, no_event_label = outcome_labels.get(outcome, ("Event", "No Event"))
585
+ n_event = round(probability * 100)
586
+ n_no_event = 100 - n_event
587
+ cols, rows = 10, 10
588
+
589
+ # ── Layout constants ──────────────────────────────────────────────────
590
+ # Icons sit on an integer grid 0..9 Γ— 0..9.
591
+ # Padding of 0.65 on each side β†’ axis span = 9 + 2*0.65 = 10.30
592
+ # Margins: left=10, right=10, top=95, bottom=10
593
+ # usable_w = W - 20 ; usable_h = H - 105
594
+ # To ensure px_per_unit identical on both axes: usable_w == usable_h
595
+ # β†’ H = W + 85
596
+ # We also enforce equal axis spans (both 10.30).
597
+
598
+ PAD = 0.65
599
+ W = 400
600
+ H = W + 85 # = 485 β†’ usable = 380 px on both axes
601
+ S = 0.46 # figure scale (β‰ˆ 75 % vertical fill per cell)
602
+
603
+ x_lo, x_hi = -PAD, (cols - 1) + PAD # -0.65 … 9.65 span=10.30
604
+ y_lo, y_hi = -PAD, (rows - 1) + PAD # -0.65 … 9.65 span=10.30
605
+
606
+ all_shapes = []
607
+ icon_idx = 0
608
+
609
+ for row in range(rows): # row 0 β†’ top of grid
610
+ for col in range(cols): # col 0 β†’ left
611
+ color = "#e05555" if icon_idx < n_event else "#3bbfad"
612
+ cx = col
613
+ cy = (rows - 1) - row # invert: row 0 β†’ cy=9 (top)
614
+ all_shapes.extend(_stick_figure(cx, cy, color, S))
615
+ icon_idx += 1
616
+
617
+ fig = go.Figure()
618
+ fig.update_layout(
619
+ title=dict(
620
+ text=(
621
+ f"<b>{OUTCOME_DESCRIPTIONS.get(outcome, outcome)}</b><br>"
622
+ f"<span style='font-size:12px;color:#e05555'>"
623
+ f"β–  {event_label}: {n_event}%</span>"
624
+ f"&nbsp;&nbsp;"
625
+ f"<span style='font-size:12px;color:#3bbfad'>"
626
+ f"β–  {no_event_label}: {n_no_event}%</span>"
627
+ ),
628
+ x=0.5, xanchor="center",
629
+ font=dict(size=14, color="black"),
630
+ ),
631
+ shapes=all_shapes,
632
+ xaxis=dict(
633
+ range=[x_lo, x_hi],
634
+ showgrid=False, zeroline=False, showticklabels=False,
635
+ fixedrange=True,
636
+ ),
637
+ yaxis=dict(
638
+ range=[y_lo, y_hi],
639
+ showgrid=False, zeroline=False, showticklabels=False,
640
+ fixedrange=True,
641
+ # scaleanchor / scaleratio intentionally OMITTED β€”
642
+ # equal spans + equal usable pixels already guarantee
643
+ # identical px/unit on both axes without distortion.
644
+ ),
645
+ width=W,
646
+ height=H,
647
+ margin=dict(l=10, r=10, t=95, b=10),
648
+ plot_bgcolor="white",
649
+ paper_bgcolor="white",
650
+ )
651
+ return fig
652
+
653
+
654
+ print("===== inference.py loaded successfully =====")
preprocessor.skops ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dda5bfd9eeac05a1c29b9c28d3dd10911e651002a7273e96848cd33fe0706b94
3
+ size 123373
requirements (3).txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ pandas
3
+ numpy
4
+ scikit-learn==1.8.0
5
+ shap
6
+ matplotlib
7
+ plotly
8
+ skops
9
+ openpyxl