File size: 24,494 Bytes
7c1fdd8
 
9a69528
 
 
7c1fdd8
cde0a42
7c1fdd8
 
 
fb3ebfd
 
288b056
31e94bf
 
 
320be0c
 
31e94bf
4803058
31e94bf
 
887a3cd
4803058
9fd2250
887a3cd
 
9fd2250
887a3cd
9fd2250
 
887a3cd
9fd2250
 
 
 
 
 
 
887a3cd
fb3ebfd
 
9fd2250
7c1fdd8
 
 
 
 
 
 
9a69528
 
7c1fdd8
 
 
 
 
 
9a69528
eaa285b
 
 
 
 
 
 
 
 
 
f883f36
 
 
 
 
 
 
 
 
 
 
f6b97e1
 
f883f36
1417aba
f883f36
1417aba
 
 
 
 
 
 
 
 
 
 
 
f883f36
 
 
 
f6b97e1
 
 
 
 
1417aba
 
f6b97e1
 
97e42d7
 
 
 
 
7f8103a
 
 
97e42d7
 
 
 
 
7f8103a
 
 
97e42d7
 
7f8103a
97e42d7
 
7f8103a
 
97e42d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f8103a
7c1fdd8
 
9a69528
eaa285b
8a6baf9
9a69528
 
7c1fdd8
 
110bdce
7c1fdd8
 
110bdce
a2f3dad
 
97e42d7
7c1fdd8
 
110bdce
7c1fdd8
 
 
 
 
 
 
110bdce
7c1fdd8
 
 
 
110bdce
 
7c1fdd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110bdce
7c1fdd8
 
 
 
 
 
 
97e42d7
 
476f261
97e42d7
 
 
 
 
 
 
 
7c1fdd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110bdce
 
7c1fdd8
 
 
 
 
 
 
 
 
 
 
 
cde0a42
 
 
 
 
 
 
 
 
 
 
 
e0ca17a
cde0a42
 
 
 
 
 
 
 
 
 
 
e0ca17a
cde0a42
 
 
 
7c1fdd8
 
cde0a42
 
 
 
 
 
7c1fdd8
 
 
 
 
 
 
 
 
 
 
 
 
97e42d7
 
 
 
 
 
 
 
 
 
 
7c1fdd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cde0a42
 
7c1fdd8
 
 
110bdce
cde0a42
7c1fdd8
 
 
 
 
 
 
 
 
 
110bdce
7c1fdd8
 
 
f6b97e1
 
320be0c
 
 
 
 
 
 
 
6fd3b24
ec93275
 
 
6fd3b24
 
ec93275
6fd3b24
ec93275
 
 
 
 
 
 
 
 
 
 
6fd3b24
ec93275
 
6fd3b24
ec93275
 
 
 
6fd3b24
ec93275
320be0c
cde0a42
320be0c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb3ebfd
 
288b056
 
 
9fd2250
 
887a3cd
2927869
887a3cd
288b056
887a3cd
 
 
288b056
4803058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7218012
4bfed28
 
 
 
 
 
 
 
 
2b230f4
 
 
 
4bfed28
2b230f4
 
4bfed28
2b230f4
4bfed28
2b230f4
 
 
 
 
4bfed28
 
2b230f4
4bfed28
 
 
 
 
 
 
 
 
fb3ebfd
2ffd3ab
1417aba
 
 
 
 
2ffd3ab
 
f6b97e1
 
 
 
1417aba
f6b97e1
 
 
1417aba
 
 
 
 
 
 
f6b97e1
 
 
1417aba
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
import datetime as dt

import pandas as pd
import streamlit as st

from sidebar import sidebar
from model_utils import load_model, load_model_ensemble, ensemble_predict, get_latest_model_name_hf
from preprocess_utils import load_train_features
from preprocess_utils import preprocess_pipeline as preprocess
from inference_utils import add_predictions
import numpy as np
import matplotlib.pyplot as plt


import sys
from pathlib import Path
import shap
from inference_utils import ensemble_shap

ROOT = Path(__file__).resolve().parents[1]  # /app
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.survival_utils import prepare_cox_df, fit_cox, make_patient_design_row, predict_patient_survival
from src.cox_persist import load_cox_artifacts, ensure_cox_artifacts_available

@st.cache_resource
def get_or_load_cox():
    """
    Return the persisted Cox overall-survival model together with its metadata.

    The artifacts are produced by the Bulk page. ``ensure_cox_artifacts_available``
    runs first so that, when the app is hosted without persistent storage (e.g.
    Spaces), missing files can be fetched before loading. The result is cached
    with ``st.cache_resource``.

    Returns:
        tuple: ``(cph, design_cols, cat_cols, meta)`` — the first three are read
        from the saved payload dict; ``meta`` is the metadata object returned by
        ``load_cox_artifacts``.
    """
    # Make sure the "cox_os" artifacts exist locally before trying to read them.
    ensure_cox_artifacts_available()

    payload, meta = load_cox_artifacts(prefix="cox_os")
    cph, design_cols, cat_cols = (payload[name] for name in ("cph", "design_cols", "cat_cols"))
    return cph, design_cols, cat_cols, meta



# --- Country options for UI (pycountry backed) ---
# Short-form / colloquial names offered on top of the official pycountry names.
_COUNTRY_ALIASES = {"UAE", "KSA", "UK", "USA", "US", "RUSSIA", "IRAN", "SYRIA", "PALESTINE"}
try:
    import pycountry

    COUNTRY_OPTIONS = sorted({country.name for country in pycountry.countries} | _COUNTRY_ALIASES)
except Exception:
    # pycountry unavailable or unusable: fall back to a small hard-coded list.
    COUNTRY_OPTIONS = sorted(["UAE", "KSA", "UK", "USA", "INDIA", "PAKISTAN", "EGYPT", "SAUDI ARABIA", "UNITED ARAB EMIRATES"])


st.set_page_config(
    page_title="GVHD-Intel | Individual Prediction",
    page_icon="🧬",
    layout="wide",
    initial_sidebar_state="expanded",
)

# CSS that hides the first entry of the multipage sidebar navigation.
_HIDE_FIRST_NAV_CSS = """
    <style>
    /* Hide the first page in the sidebar navigation (this is typically the main app.py page) */
    [data-testid="stSidebarNav"] ul li:first-child {display: none;}
    </style>
    """

# ---------- App Header ----------
_HEADER_HTML = """
    <h1 style='text-align: center; margin-bottom: 10px;'>
        GVHD-Intel Pro
    </h1>
    """

# Introductory description of the workflow, explainability and survival module.
_INTRO_HTML = """
    <div style='font-size:16px; line-height:1.6;'>
    A modular clinical decision-support framework for predicting <b>Acute GVHD</b> and <b>Chronic GVHD</b> risk after HSCT using machine learning models trained on transplant datasets.<br><br>

    <b>Prediction workflow:</b><br>
    • Use the <b>Acute GVHD</b> or <b>Chronic GVHD</b> tab below and click the corresponding <b>Predict</b> button.<br>
    • The app will automatically use the <b>latest ensemble model</b> available for the selected outcome (Acute = <b>A</b>, Chronic = <b>C</b>) based on the timestamp in the model name.<br>
    • Admin use: The sidebar model selector remains available for browsing models, but tab-based prediction will prioritize the latest outcome-specific ensemble.<br><br>

    <b>Explainability:</b><br>
    • Displays the <b>top 20 risk factors</b> contributing to each prediction.<br>
    • Provides a <b>SHAP waterfall plot</b> for the individual patient (for ensemble models, SHAP values are averaged across models).<br><br>

    <b>Survival module:</b><br>
    • Supports adjusted overall survival (all-cause mortality) risk estimation using <b>Cox Proportional Hazards</b> analysis.<br>
    • Generates an individualized adjusted survival curve with landmark estimates (e.g., 3-, 5-, and 10-year survival).<br><br>

    Enter recipient and donor details below to generate individualized risk predictions.<br><br>

    <b>Research tool only — not validated for clinical use.</b>
    </div>
    """

# Render in order: nav-hiding CSS, title, introduction.
for _html in (_HIDE_FIRST_NAV_CSS, _HEADER_HTML, _INTRO_HTML):
    st.markdown(_html, unsafe_allow_html=True)

st.divider()

def two_alleles_number_inputs(label: str, key_prefix: str, max_allele: int = 500, disabled: bool = False):
    """
    Draw two side-by-side integer inputs for a pair of alleles.

    A value of 0 means "not provided". The widgets are keyed
    ``f"{key_prefix}_1"`` and ``f"{key_prefix}_2"``.

    Returns:
        list[str]: the non-zero allele numbers as strings, in input order. A
        second allele equal to the first is dropped, so (12, 24) -> ["12", "24"],
        (12, 12) -> ["12"], (0, 7) -> ["7"] and (0, 0) -> [].
    """
    entered = []
    for slot, column in enumerate(st.columns(2), start=1):
        with column:
            entered.append(int(st.number_input(
                f"{label} allele {slot}",
                min_value=0,
                max_value=max_allele,
                step=1,
                value=0,
                key=f"{key_prefix}_{slot}",
                disabled=disabled,
            )))

    first, second = entered
    alleles = [str(first)] if first != 0 else []
    # Skip an empty second slot and a duplicate of the first allele.
    if second != 0 and second != first:
        alleles.append(str(second))
    return alleles


def hla_block(person: str, key_prefix: str, enabled: bool):
    """
    Render (or skip) the HLA-A/B/C/DR/DQ allele inputs for one person.

    Args:
        person: display name used in the heading and widget labels
            (e.g. "Recipient" or "Donor").
        key_prefix: prefix for both the returned dict keys and the widget keys.
        enabled: when False, no widgets are drawn and every locus is ["Unknown"].

    Returns:
        dict mapping ``f"{key_prefix}_{locus}"`` (locus in A, B, C, DR, DQ, in
        that order) to a list of allele strings, or ["Unknown"] when nothing
        was entered for that locus.
    """
    loci = ("A", "B", "C", "DR", "DQ")

    if not enabled:
        return {f"{key_prefix}_{locus}": ["Unknown"] for locus in loci}

    st.markdown(f"###### {person} HLA Alleles")

    collected = {}
    for locus in loci:
        alleles = two_alleles_number_inputs(f"{person} HLA-{locus}", f"{key_prefix}_{locus}")
        # An empty entry is reported as the explicit "Unknown" placeholder.
        collected[f"{key_prefix}_{locus}"] = alleles if alleles else ["Unknown"]
    return collected

# Build the shared sidebar first (model / target / threshold defaults, etc.).
sidebar()

st.title("👤 Individual Patient Risk Prediction")
st.caption("Enter recipient, donor, and transplant variables including HLA alleles to generate individualized risk estimates")

# Reset to an empty frame on every run so add_predictions() can join against it safely.
st.session_state["targets_df"] = pd.DataFrame()

# Bounds shared by every date picker on this page.
min_d, max_d = dt.date(1950, 1, 1), dt.date(2050, 12, 31)



with st.container():
    # ---------- Recipient ----------
    # NOTE: user-facing labels read "Recipient"; widget keys, variable names and
    # option values keep their original spelling because they may be relied on
    # elsewhere (session state, preprocessing / training-data categories).
    st.subheader("Recipient Information", divider=True)
    gender = st.radio("Recipient Gender", ["MALE", "FEMALE"], index=None)

    dob = st.date_input(
        "Recipient DOB",
        value=dt.date(2000, 1, 31),
        min_value=min_d,
        max_value=max_d,
        format="DD/MM/YYYY",
    )

    nationality = st.selectbox(
        "Recipient Nationality (Country)",
        COUNTRY_OPTIONS,
        index=None,
    )

    # Option strings are passed straight into the model input; their spellings
    # presumably match training-data categories — do not "correct" them here.
    diagnosis = st.selectbox(
        "Hematological Diagnosis",
        sorted([
            "ACUTE MYELOID LEUKEMIA", "ALPHA THALSSEMIA", "AMYLOIDOSIS",
            "APLASTIC ANEMIA", "BALL", "BETA THALESSEMIA",
            "BLASTIC PLASMACYTOID DENDRITRIC CELL NEOPLASM",
            "CHRONIC GRANULOMATOUS DISEASE", "CHRONIC LYMPHOCYTIC LEUKEMIA",
            "CML", "COMBINED VARIABLE IMMUNODEFICIENCY",
            "DYSKERATOSIS CONGENTIA", "FANCONI ANEMIA",
            "GLANZMANN THROMBASTHENIA",
            "HEMOPHAGOCYTIC LYMPHOHISTIOCYTOSIS (HLH)",
            "HEREDITARY SPHEROCYTOSIS", "HODGKIN LYMPHOMA",
            "HYPOGAMMAGLOBULINEMIA", "LANGERHANS CELL HISTIOCYTOSIS",
            "MYELODYSPLASTIC SYNDROME", "MEDULLOBLASTOMA", "MULTIPLE MYELOMA",
            "MYELOFIBROSIS", "MYELOPROLIFERATIVE DISORDER", "NEUROBLASTOMA",
            "NON HODGKIN LYMPHOMA", "OTHER",
            "PAROXYSMAL NOCTURNAL HEMOGLOBINURIA", "PLASMA CELL LEUKEMIA",
            "SCID", "SICKLE CELL DISEASE", "TALL",
            "X-LINKED HYPER IGM SYNDROME",
        ]),
        index=None,
    )

    diagnosis_date = st.date_input(
        "Date of First Diagnosis / BMBx",
        value=dt.date(2000, 1, 31),
        min_value=min_d,
        max_value=max_d,
        format="DD/MM/YYYY",
    )

    recepient_blood_group = st.radio(
        "Recipient Blood Group",
        ["A+", "A-", "B+", "B-", "O+", "O-", "AB+", "AB-", "Unknown"],
        index=None,
        key="recepient_blood_group",  # key unchanged so existing session state stays valid
    )

    # ---------- Donor ----------
    st.subheader("Donor Information", divider=True)
    donor_relation = st.radio(
        "Donor Relation to Recipient",
        ["SELF", "SIBLING", "FIRST DEGREE RELATIVE", "SECOND DEGREE RELATIVE", "RELATED", "UNRELATED", "Unknown"],
        index=None,
    )
    # Read by the prediction step: a SELF donor bypasses the model.
    st.session_state.SELF = (donor_relation == "SELF")

    donor_gender = st.radio("Donor Gender", ["MALE", "FEMALE"], index=None)

    donor_dob = st.date_input(
        "Donor DOB",
        value=dt.date(2000, 1, 31),
        min_value=min_d,
        max_value=max_d,
        format="DD/MM/YYYY",
    )

    donor_blood_group = st.radio(
        "Donor Blood Group",
        ["A+", "A-", "B+", "B-", "O+", "O-", "AB+", "AB-", "Unknown"],
        index=None,
        key="donor_blood_group",
    )

    # ---------- Treatment ----------
    st.subheader("Treatment Details", divider=True)
    lines_of_rx = st.selectbox("Number of Lines of Rx Before HSCT", [0, 1, 2, 3, 4, 5, 6, 7, "Unknown"], index=None)

    conditioning = st.multiselect(
        "Pre-HSCT Conditioning Regimen",
        sorted([
            "ALEMTUZUMAB", "ATG", "BEAM", "BUSULFAN", "CAMPATH", "CARMUSTINE",
            "CLOFARABINE", "CYCLOPHOSPHAMIDE", "CYCLOSPORIN", "CYTARABINE",
            "ETOPOSIDE", "FLUDARABINE", "GEMCITABINE", "MELPHALAN",
            "METHOTREXATE", "OTHER", "RANIMUSTINE", "REDUCED_CONDITIONING",
            "RITUXIMAB", "SIROLIMUS", "TBI", "THIOTEPA", "TREOSULFAN",
            "UA", "VORINOSTAT",
        ]),
        placeholder="Choose option(s)",
    )

    # ---------- HSCT ----------
    st.subheader("HSCT Details", divider=True)
    hsct_date = st.date_input(
        "HSCT Date",
        value=dt.date(2000, 1, 31),
        min_value=min_d,
        max_value=max_d,
        format="DD/MM/YYYY",
    )

    cell_source = st.radio(
        "Source of Cells",
        sorted(["BONE MARROW", "PERIPHERAL BLOOD", "UMBILICAL CORD", "PBSC", "Unknown"]),
        index=None,
    )

    hla_match = st.radio("HLA Match Ratio", ["FULL", "PARTIAL", "HAPLOIDENTICAL", "Unknown"], index=None)

    # ---------- Risk modifiers ----------
    st.subheader("Transplant Risk Modifiers", divider=True)
    donor_type = st.selectbox(
        "Donor Type",
        ["MRD", "MMRD", "MUD", "MMUD", "HAPLOIDENTICAL", "CORD", "Unknown"],
        index=None,
    )

    # Fallback HLA values; replaced by the expander contents below.
    r_hla = {"R_HLA_A": ["Unknown"], "R_HLA_B": ["Unknown"], "R_HLA_C": ["Unknown"], "R_HLA_DR": ["Unknown"], "R_HLA_DQ": ["Unknown"]}
    d_hla = {"D_HLA_A": ["Unknown"], "D_HLA_B": ["Unknown"], "D_HLA_C": ["Unknown"], "D_HLA_DR": ["Unknown"], "D_HLA_DQ": ["Unknown"]}
    with st.expander("Enter: HLA Allele Details", expanded=True):
        st.caption("Enter HLA allele data if available to enhance prediction accuracy.")

        enter_recipient_hla = st.toggle("Enter Recipient HLA alleles", value=False, key="enter_recipient_hla")
        r_hla = hla_block("Recipient", "R_HLA", enabled=enter_recipient_hla)

        enter_donor_hla = st.toggle("Enter Donor HLA alleles", value=False, key="enter_donor_hla")
        d_hla = hla_block("Donor", "D_HLA", enabled=enter_donor_hla)

    conditioning_intensity = st.radio("Conditioning Intensity", ["MAC", "RIC", "NMA", "Unknown"], index=None)

    last_followup_date = st.date_input(
        "Last Follow-up Date (optional)",
        value=None,
        min_value=min_d,
        max_value=max_d,
        format="DD/MM/YYYY",
    )

    date_of_death = st.date_input(
        "Date of Death (optional)",
        value=None,
        min_value=min_d,
        max_value=max_d,
        format="DD/MM/YYYY",
    )

    gvhd_proph_cat = st.radio(
        "GVHD Prophylaxis Category",
        ["CNI_BASED", "PTCY_BASED", "ATG_BASED", "TCD", "OTHER", "Unknown"],
        index=None,
    )

    # ---------- Post-HSCT ----------
    st.subheader("Post-HSCT Treatment and GVHD Prophylaxis", divider=True)
    post_hsct_regimen = st.radio("Post-HSCT Regimen", ["YES", "NO", "IVIG", "Unknown"], index=None)

    gvhd_prophylaxis = st.multiselect(
        "First GVHD Prophylaxis",
        sorted([
            "ABATACEPT", "ALEMTUZUMAB", "ATG", "CYCLOPHOSPHAMIDE", "CYCLOSPORIN",
            "IMATINIB", "LEFLUNOMIDE", "METHOTREXATE", "MMF", "NONE",
            "RUXOLITINIB", "SIROLIMUS", "STEROID", "TAC", "TACROLIMUS",
        ]),
        placeholder="Choose option(s)",
    )

    # --- Predict buttons as tabs (sidebar selection remains available) ---
    # A successful click sets `submitted` and the outcome-specific model name
    # (target initial "A" = Acute, "C" = Chronic) for this run only.
    tab_a, tab_c = st.tabs(["Acute GVHD", "Chronic GVHD"])

    override_model_name = None
    submitted = False

    with tab_a:
        if st.button("PREDICT ACUTE GVHD (latest ensemble)", type="primary", use_container_width=True, key="pred_acute"):
            try:
                override_model_name = get_latest_model_name_hf(
                    target_initial="A",
                    mode="ensemble",
                    name_contains=None,
                )
                submitted = True
            except Exception as e:
                st.error(f"Could not locate latest Acute ensemble model: {e}")

    with tab_c:
        if st.button("PREDICT CHRONIC GVHD (latest ensemble)", type="primary", use_container_width=True, key="pred_chronic"):
            try:
                override_model_name = get_latest_model_name_hf(
                    target_initial="C",
                    mode="ensemble",
                    name_contains=None,
                )
                submitted = True
            except Exception as e:
                st.error(f"Could not locate latest Chronic ensemble model: {e}")

def _known_or_x(value):
    """
    Normalise an optional categorical UI value for the raw model input.

    Returns ``value`` unchanged, or the missing-value marker ``"X"`` when the
    widget was left unselected (``None``/empty) or set to ``"Unknown"``.
    """
    return value if (value and value != "Unknown") else "X"


def _fmt_date_or_x(value):
    """Format a date as DD/MM/YYYY, or return the marker ``"X"`` when no date was given."""
    return value.strftime("%d/%m/%Y") if value else "X"


def _render_shap_explanation(one, X_model, table_title, plot_title):
    """
    Show the top-20 SHAP contributions for a single patient (table + waterfall).

    Args:
        one: single-row SHAP explanation; its ``values`` and ``feature_names``
            are read and it is passed to ``shap.plots.waterfall``.
        X_model: model-ready feature frame whose first row supplies the
            displayed feature values.
        table_title: subheader shown above the table.
        plot_title: subheader shown above the waterfall plot.
    """
    vals = one.values
    feats = np.array(one.feature_names)
    # Indices of the 20 largest |SHAP| values, largest first.
    top_idx = np.argsort(np.abs(vals))[::-1][:20]

    shap_table = pd.DataFrame({
        "Feature": feats[top_idx],
        "Feature value": X_model.iloc[0][feats[top_idx]].values,
        "SHAP value (pushes risk ↑ / ↓)": vals[top_idx],
    })
    st.subheader(table_title)
    st.dataframe(shap_table, use_container_width=True)

    st.subheader(plot_title)
    base_fig = plt.figure(figsize=(10, 6))
    drawn_fig = base_fig
    try:
        shap.plots.waterfall(one, max_display=20, show=False)
        # Display whatever figure is current after plotting (as before).
        drawn_fig = plt.gcf()
        st.pyplot(drawn_fig, bbox_inches="tight")
    finally:
        # Close (not just clear) so pyplot figures do not pile up across reruns.
        plt.close(drawn_fig)
        if drawn_fig is not base_fig:
            plt.close(base_fig)


if submitted:
    # A tab button supplies the model for THIS run; otherwise fall back to the
    # sidebar-selected model.
    model_to_use = override_model_name or st.session_state.selected_model
    st.caption(f"Model used: {model_to_use}")

    # Single raw input row; keys presumably mirror the raw training columns
    # (including their original spellings) — verify against preprocess().
    # Unset / "Unknown" optional categoricals and missing dates become "X".
    input_dict = {
        "Recepient_gender": gender,
        "Recepient_DOB": _fmt_date_or_x(dob),
        "Recepient_Nationality": _known_or_x(nationality),
        "Hematological Diagnosis": diagnosis,
        "Date of first diagnosis/BMBx date": _fmt_date_or_x(diagnosis_date),
        "Recepient_Blood group before HSCT": _known_or_x(recepient_blood_group),

        "Donor_DOB": _fmt_date_or_x(donor_dob),
        "Donor_gender": donor_gender,
        "D_Blood group": _known_or_x(donor_blood_group),

        "R_HLA_A":  r_hla["R_HLA_A"],
        "R_HLA_B":  r_hla["R_HLA_B"],
        "R_HLA_C":  r_hla["R_HLA_C"],
        "R_HLA_DR": r_hla["R_HLA_DR"],
        "R_HLA_DQ": r_hla["R_HLA_DQ"],

        "D_HLA_A":  d_hla["D_HLA_A"],
        "D_HLA_B":  d_hla["D_HLA_B"],
        "D_HLA_C":  d_hla["D_HLA_C"],
        "D_HLA_DR": d_hla["D_HLA_DR"],
        "D_HLA_DQ": d_hla["D_HLA_DQ"],

        "Number of lines of Rx before HSCT": lines_of_rx,
        "PreHSCT conditioning regimen+/-ATG+/-TBI": conditioning,

        "HSCT_date": _fmt_date_or_x(hsct_date),
        "Last_followup_date": _fmt_date_or_x(last_followup_date),
        "Date_of_death": _fmt_date_or_x(date_of_death),

        "Source of cells": cell_source,
        "Donor_relation to recepient": donor_relation,
        "HLA match ratio": hla_match,

        "Donor_type": _known_or_x(donor_type),
        "Conditioning_intensity": _known_or_x(conditioning_intensity),
        "GVHD_Prophylaxis_Cat": _known_or_x(gvhd_proph_cat),

        "Post HSCT regimen": post_hsct_regimen,
        "First_GVHD prophylaxis": gvhd_prophylaxis,
    }

    X_raw = pd.DataFrame([input_dict])

    # Preprocess, then align to the training feature order (absent columns -> 0).
    train_features, cat_features = load_train_features()
    df_full, df_model = preprocess(X_raw)
    X_model = df_model.reindex(columns=train_features, fill_value=0).copy()

    # Categorical features are fed to the model(s) as strings.
    for c in cat_features:
        if c in X_model.columns:
            X_model[c] = X_model[c].astype(str)

    # ---- Predict (model objects are kept and reused for SHAP below) ----
    is_ensemble = False
    models = None        # ensemble members (ensemble path only)
    single_model = None  # fitted single model (non-ensemble path only)
    if st.session_state.SELF:
        # SELF donor: the prediction is fixed at 0.0 and no model is loaded.
        pred_prob = 0.0
        target_col = "GVHD"
    elif model_to_use.endswith("ensemble"):
        is_ensemble = True
        ensemble_data = load_model_ensemble(model_to_use)
        target_col = ensemble_data.get("target_col", "UNKNOWN")
        models = ensemble_data["model"]
        pred_prob = float(ensemble_predict(models, X_model, cat_features)[0])
    else:
        model_dict = load_model(model_to_use)
        target_col = model_dict.get("target_col", "UNKNOWN")
        single_model = model_dict["model"]
        pred_prob = float(single_model.predict_proba(X_model)[0][1])

    st.session_state.target_col = target_col

    st.info(
        f"Model predicts target: **{target_col}**. "
        f"Threshold: **{float(st.session_state.get('threshold', 0.5)):.2f}**."
    )

    # Display result
    result_df = add_predictions(pd.DataFrame(index=[0]), [pred_prob])
    st.subheader("Prediction")
    st.dataframe(result_df, use_container_width=False, width=420)

    st.divider()
    st.header("Explainability (SHAP)")

    try:
        if st.session_state.SELF:
            # Prediction is forced to 0.0, so there is nothing meaningful to explain.
            st.info("SHAP is not shown for SELF donor (prediction forced to 0).")
        elif is_ensemble:
            # ensemble_shap() returns a shap.Explanation already averaged across models.
            ens_expl = ensemble_shap(models=models, X=X_model, positive_class=1)
            _render_shap_explanation(
                ens_expl[0],
                X_model,
                table_title="Top features driving this patient’s prediction (ensemble-mean)",
                plot_title="Waterfall plot (single patient, ensemble-mean)",
            )
            st.caption("SHAP values shown are averaged across ensemble models.")
        else:
            explainer = shap.TreeExplainer(single_model)
            shap_expl = explainer(X_model)  # shap.Explanation for the single row

            # Some wrappers return [n, p, classes]; keep the positive class only.
            if shap_expl.values.ndim == 3:
                shap_expl = shap.Explanation(
                    values=shap_expl.values[:, :, 1],
                    base_values=shap_expl.base_values[:, 1] if np.ndim(shap_expl.base_values) == 2 else shap_expl.base_values,
                    data=X_model,
                    feature_names=X_model.columns,
                )

            _render_shap_explanation(
                shap_expl[0],
                X_model,
                table_title="Top features driving this patient’s prediction",
                plot_title="Waterfall plot (single patient)",
            )
    except Exception as e:
        st.error(f"SHAP explanation failed: {e}")

    st.divider()
    st.header("Adjusted Overall Survival Prediction (Cox Proportional Hazards)")

    # ---- load saved Cox artifacts (trained once in Bulk page) ----
    try:
        cph, design_cols, cat_cols, meta = get_or_load_cox()
        # Read covariates before announcing success so "ready" is only shown
        # when the metadata is usable.
        covariates = meta["covariates"]
        st.success("Cox model ready (loaded from saved artifacts).")
        st.caption(f"N={meta['n']} | events={meta['n_events']} | C-index={meta['c_index']:.3f}")
    except Exception as e:
        st.error(f"Cox unavailable: {e}")
        cph = None

    if cph is not None:
        # Type-safe fallbacks for covariates absent from the preprocessed row
        # (any other missing covariate defaults to 0).
        categorical_defaults = {
            "Hematological Diagnosis_Grouped": "OTHER",
            "Donor_relation to recepient": "Unknown",
            "Source of cells": "Unknown",
            "Donor_type": "Unknown",
            "Conditioning_intensity": "Unknown",
            "GVHD_Prophylaxis_Cat": "Unknown",
        }

        try:
            # Covariate row for THIS patient, taken positionally from df_full so
            # it does not depend on the preprocessed frame's index labels.
            present = [c for c in covariates if c in df_full.columns]
            patient_cov = df_full.iloc[[0]][present].copy()
            patient_cov["Predicted_GVHD_Risk"] = float(pred_prob)
            for c in covariates:
                if c not in patient_cov.columns:
                    patient_cov[c] = categorical_defaults.get(c, 0)

            # One-hot design row matching the training design columns.
            patient_design = make_patient_design_row(patient_cov, design_cols, cat_cols)
            if isinstance(patient_design, tuple):
                patient_design = patient_design[0]  # keep only the DataFrame row

            # Predicted survival curve + landmark probabilities.
            surv_fn, landmarks = predict_patient_survival(cph, patient_design, years=[3, 5, 10])

            st.subheader("Individual adjusted survival curve")
            # Convert days -> years for plotting.
            surv_years = surv_fn.copy()
            surv_years.index = surv_years.index / 365.25

            fig, ax = plt.subplots(figsize=(8, 5))
            try:
                surv_years.plot(ax=ax, legend=False)
                ax.set_title("Predicted adjusted survival curve")
                ax.set_xlabel("Years")
                ax.set_ylabel("Survival probability")
                # Start at 0 and use integer year ticks.
                ax.set_xlim(left=0)
                ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
                st.pyplot(fig, bbox_inches="tight")
            finally:
                # Close (not just clear) so figures do not accumulate across reruns.
                plt.close(fig)

            st.write("Landmark survival probabilities:")
            for y in [3, 5, 10]:
                s = landmarks.get(y, None)
                if s is not None:
                    st.write(f"{y}-year survival: {100*s:.1f}%")

        except Exception as e:
            st.error(f"Survival prediction failed: {e}")

    st.caption(
        "Use the tabs above to run prediction with the latest ensemble model "
        "(Acute GVHD or Chronic GVHD). "
        "The sidebar model selector remains available for manual model exploration if needed. "
        "Model naming convention: A = Acute GVHD, C = Chronic GVHD, G = Overall GVHD "
        "(e.g., 260222_062257A_Acute GVHD_ensemble)."
    )

st.divider()

# Static page footer: ownership, partners and contact details.
_FOOTER_HTML = """
    <div style="text-align:center; font-size:14px; color:grey; opacity:0.85; padding-top:20px; line-height:1.6;">
        <br>
        ©2025 Department of Health – Abu Dhabi<br>
        Partnership: SSMC (PureHealth) & MBZUAI<br>
        Data Collaborators: SKMC, Tawam Hospital, KHCC (Jordan)<br><br>
        <span style="font-size:13.5px;">
        GVHD-Intel Pro — Developed and maintained by <b>Dr Syed Naveed</b><br>
        Contact: <a href="mailto:naveed3642003@gmail.com" style="color:inherit; text-decoration:none;">
        naveed3642003@gmail.com
        </a>
        </span>
    </div>
    """

st.markdown(_FOOTER_HTML, unsafe_allow_html=True)