GVHD_Prediction / src /sidebar.py
mfarnas's picture
add shap & local testing
470000a
import streamlit as st
from pathlib import Path
import os
import glob
import pyarrow.parquet as pq
# Switch between the Hugging Face Hub backend (False) and local-disk storage (True).
LOCAL = False

if not LOCAL:
    from huggingface_hub import HfApi, hf_hub_download

    def get_model_options():
        """Return the sorted, de-duplicated model names to offer in the sidebar.

        ``"Default_GVHD_ensemble"`` is always included. Further names are read
        from the ``filename`` column of every ``models/*.parquet`` file in the
        dataset repo named by the ``HF_REPO_ID`` environment variable
        (authenticated with ``HF_TOKEN``).

        Failures are best-effort: if the environment is not configured or the
        repo cannot be listed, a Streamlit warning is shown and only the
        default model is returned; a file that cannot be downloaded or parsed
        is skipped with a warning.
        """
        models = ["Default_GVHD_ensemble"]
        parquet_suffix = ".parquet"

        try:
            token = os.environ["HF_TOKEN"]
            repo_id = os.environ["HF_REPO_ID"]
        except KeyError as e:
            st.warning(
                f"Environment variable {e} is not set; "
                "only the default model is available."
            )
            return models

        try:
            all_files = HfApi(token=token).list_repo_files(
                repo_id=repo_id, repo_type="dataset"
            )
        except Exception as e:
            # Keep the app usable with the default model when the Hub is unreachable.
            st.warning(f"Could not list saved models ({e}); only the default model is available.")
            return models

        parquet_files = [
            f for f in all_files
            if f.startswith("models/") and f.endswith(parquet_suffix)
        ]
        for f in parquet_files:
            try:
                downloaded = hf_hub_download(
                    repo_id=repo_id,
                    repo_type="dataset",
                    filename=f,
                    token=token,
                )
                # Only the name is needed, so read just that column instead of
                # converting every column of the table to Python objects.
                table = pq.read_table(downloaded, columns=["filename"])
                stored_name = table.column("filename")[0].as_py()
                # Strip only a trailing extension, not every ".parquet" substring.
                if stored_name.endswith(parquet_suffix):
                    stored_name = stored_name[: -len(parquet_suffix)]
                models.append(stored_name)
            except Exception as e:
                st.warning(f"Skipping model file due to error: {f} ({e})")
        return sorted(set(models))

else:
    import pickle  # not used in this block; kept in case other code in the file relies on it

    SAVED_MODELS_DIR = Path("src/saved_models")
    # parents=True so a missing "src" directory does not raise FileNotFoundError.
    SAVED_MODELS_DIR.mkdir(parents=True, exist_ok=True)

    def get_model_options():
        """Return the sorted, de-duplicated model names found in local storage.

        ``"Default_GVHD_ensemble"`` is always included; every ``*.pkl`` file in
        ``SAVED_MODELS_DIR`` contributes its stem (file name without extension).
        """
        models = ["Default_GVHD_ensemble"]
        if SAVED_MODELS_DIR.exists():
            models.extend(p.stem for p in SAVED_MODELS_DIR.glob("*.pkl"))
        return sorted(set(models))
# Name of the built-in model that the sidebar preselects.
_DEFAULT_MODEL = "Default_GVHD_ensemble"

# Column names published to ``st.session_state.orig_train_cols`` by ``sidebar()``.
# NOTE(review): presumably the columns of the original training frame (going by
# the session-state key name) — confirm against the training pipeline.
# fmt: off
_ORIG_TRAIN_COLS = [
    'EPI/ID numbers', 'Recipient_gender', 'Recepient_DOB', 'Recepient_Nationality', 'Hematological Diagnosis', 'Date of first diagnosis/BMBx date', 'Recepient_Blood group before HSCT', 'Donor_DOB', 'Donor_gender', 'D_Blood group', 'R_HLA_A', 'R_HLA_B', 'R_HLA_C', 'R_HLA_DR', 'R_HLA_DQ', 'D_HLA_A', 'D_HLA_B', 'D_HLA_C', 'D_HLA_DR', 'D_HLA_DQ', 'Number of lines of Rx before HSCT', 'PreHSCT conditioning regimen+/-ATG+/-TBI', 'HSCT_date', 'Source of cells', 'Donor_relation to recipient', 'HLA match ratio', 'Post HSCT regimen', 'First_GVHD prophylaxis', 'GVHD', 'Acute GVHD(<100 days)', 'Chronic GVHD>100 days', 'Acute+Chronic', 'GVHD severity', 'R_HLA_A1', 'R_HLA_A2', 'R_HLA_B1', 'R_HLA_B2', 'R_HLA_C1', 'R_HLA_C2', 'R_HLA_DR1', 'R_HLA_DR2', 'R_HLA_DQ1', 'R_HLA_DQ2', 'D_HLA_A1', 'D_HLA_A2', 'D_HLA_B1', 'D_HLA_B2', 'D_HLA_C1', 'D_HLA_C2', 'D_HLA_DR1', 'D_HLA_DR2', 'D_HLA_DQ1', 'D_HLA_DQ2', 'R_HLA_A_1', 'R_HLA_A_11', 'R_HLA_A_12', 'R_HLA_A_2', 'R_HLA_A_20', 'R_HLA_A_23', 'R_HLA_A_24', 'R_HLA_A_25', 'R_HLA_A_26', 'R_HLA_A_29', 'R_HLA_A_3', 'R_HLA_A_30', 'R_HLA_A_31', 'R_HLA_A_32', 'R_HLA_A_33', 'R_HLA_A_34', 'R_HLA_A_4', 'R_HLA_A_66', 'R_HLA_A_68', 'R_HLA_A_69', 'R_HLA_A_7', 'R_HLA_A_74', 'R_HLA_A_8', 'R_HLA_A_X', 'R_HLA_B_13', 'R_HLA_B_14', 'R_HLA_B_15', 'R_HLA_B_18', 'R_HLA_B_23', 'R_HLA_B_24', 'R_HLA_B_27', 'R_HLA_B_35', 'R_HLA_B_37', 'R_HLA_B_38', 'R_HLA_B_39', 'R_HLA_B_40', 'R_HLA_B_41', 'R_HLA_B_42', 'R_HLA_B_44', 'R_HLA_B_45', 'R_HLA_B_46', 'R_HLA_B_49', 'R_HLA_B_50', 'R_HLA_B_51', 'R_HLA_B_52', 'R_HLA_B_53', 'R_HLA_B_55', 'R_HLA_B_56', 'R_HLA_B_57', 'R_HLA_B_58', 'R_HLA_B_7', 'R_HLA_B_73', 'R_HLA_B_8', 'R_HLA_B_81', 'R_HLA_B_X', 'R_HLA_C_1', 'R_HLA_C_12', 'R_HLA_C_14', 'R_HLA_C_15', 'R_HLA_C_16', 'R_HLA_C_17', 'R_HLA_C_18', 'R_HLA_C_2', 'R_HLA_C_3', 'R_HLA_C_38', 'R_HLA_C_4', 'R_HLA_C_49', 'R_HLA_C_5', 'R_HLA_C_50', 'R_HLA_C_6', 'R_HLA_C_7', 'R_HLA_C_8', 'R_HLA_C_X', 'R_HLA_DR_1', 'R_HLA_DR_10', 'R_HLA_DR_11', 'R_HLA_DR_12',
    'R_HLA_DR_13', 'R_HLA_DR_14', 'R_HLA_DR_15', 'R_HLA_DR_16', 'R_HLA_DR_17', 'R_HLA_DR_2', 'R_HLA_DR_3', 'R_HLA_DR_4', 'R_HLA_DR_5', 'R_HLA_DR_6', 'R_HLA_DR_7', 'R_HLA_DR_8', 'R_HLA_DR_9', 'R_HLA_DR_X', 'R_HLA_DQ_1', 'R_HLA_DQ_11', 'R_HLA_DQ_15', 'R_HLA_DQ_16', 'R_HLA_DQ_2', 'R_HLA_DQ_3', 'R_HLA_DQ_301', 'R_HLA_DQ_4', 'R_HLA_DQ_5', 'R_HLA_DQ_6', 'R_HLA_DQ_7', 'R_HLA_DQ_X', 'D_HLA_A_1', 'D_HLA_A_11', 'D_HLA_A_12', 'D_HLA_A_2', 'D_HLA_A_23', 'D_HLA_A_24', 'D_HLA_A_25', 'D_HLA_A_26', 'D_HLA_A_29', 'D_HLA_A_3', 'D_HLA_A_30', 'D_HLA_A_31', 'D_HLA_A_32', 'D_HLA_A_33', 'D_HLA_A_34', 'D_HLA_A_66', 'D_HLA_A_68', 'D_HLA_A_69', 'D_HLA_A_7', 'D_HLA_A_74', 'D_HLA_A_8', 'D_HLA_A_X', 'D_HLA_B_13', 'D_HLA_B_14', 'D_HLA_B_15', 'D_HLA_B_17', 'D_HLA_B_18', 'D_HLA_B_23', 'D_HLA_B_24', 'D_HLA_B_27', 'D_HLA_B_35', 'D_HLA_B_37', 'D_HLA_B_38', 'D_HLA_B_39', 'D_HLA_B_40', 'D_HLA_B_41', 'D_HLA_B_42', 'D_HLA_B_44', 'D_HLA_B_45', 'D_HLA_B_48', 'D_HLA_B_49', 'D_HLA_B_50', 'D_HLA_B_51', 'D_HLA_B_52', 'D_HLA_B_53', 'D_HLA_B_55', 'D_HLA_B_56', 'D_HLA_B_57', 'D_HLA_B_58', 'D_HLA_B_7', 'D_HLA_B_73', 'D_HLA_B_8', 'D_HLA_B_81', 'D_HLA_B_X', 'D_HLA_C_1', 'D_HLA_C_12', 'D_HLA_C_14', 'D_HLA_C_15', 'D_HLA_C_16', 'D_HLA_C_17', 'D_HLA_C_18', 'D_HLA_C_2', 'D_HLA_C_3', 'D_HLA_C_38', 'D_HLA_C_4', 'D_HLA_C_49', 'D_HLA_C_5', 'D_HLA_C_50', 'D_HLA_C_6', 'D_HLA_C_7', 'D_HLA_C_8', 'D_HLA_C_X', 'D_HLA_DR_1', 'D_HLA_DR_10', 'D_HLA_DR_11', 'D_HLA_DR_12', 'D_HLA_DR_13', 'D_HLA_DR_14', 'D_HLA_DR_15', 'D_HLA_DR_16', 'D_HLA_DR_17', 'D_HLA_DR_2', 'D_HLA_DR_3', 'D_HLA_DR_4', 'D_HLA_DR_5', 'D_HLA_DR_6', 'D_HLA_DR_7', 'D_HLA_DR_8', 'D_HLA_DR_9', 'D_HLA_DR_X', 'D_HLA_DQ_1', 'D_HLA_DQ_11', 'D_HLA_DQ_15', 'D_HLA_DQ_16', 'D_HLA_DQ_2', 'D_HLA_DQ_3', 'D_HLA_DQ_301', 'D_HLA_DQ_4', 'D_HLA_DQ_5', 'D_HLA_DQ_6', 'D_HLA_DQ_7', 'D_HLA_DQ_X', 'Recepient_DOB_Year', 'Donor_DOB_Year', 'HSCT_date_Year', 'R_Age_at_transplant', 'D_Age_at_transplant', 'Age_Gap_R_D', 'PreHSCT_ALEMTUZUMAB', 'PreHSCT_ATG', 'PreHSCT_BEAM', 'PreHSCT_BUSULFAN',
    'PreHSCT_CAMPATH', 'PreHSCT_CARMUSTINE', 'PreHSCT_CLOFARABINE', 'PreHSCT_CYCLOPHOSPHAMIDE', 'PreHSCT_CYCLOSPORIN', 'PreHSCT_CYTARABINE', 'PreHSCT_ETOPOSIDE', 'PreHSCT_FLUDARABINE', 'PreHSCT_GEMCITABINE', 'PreHSCT_MELPHALAN', 'PreHSCT_METHOTREXATE', 'PreHSCT_OTHER', 'PreHSCT_RANIMUSTINE', 'PreHSCT_REDUCEDCONDITIONING', 'PreHSCT_RITUXIMAB', 'PreHSCT_SIROLIMUS', 'PreHSCT_TBI', 'PreHSCT_THIOTEPA', 'PreHSCT_TREOSULFAN', 'PreHSCT_UA', 'PreHSCT_VORNOSTAT', 'PreHSCT_X', 'First_GVHD_prophylaxis_ABATACEPT', 'First_GVHD_prophylaxis_ALEMTUZUMAB', 'First_GVHD_prophylaxis_ATG', 'First_GVHD_prophylaxis_CYCLOPHOSPHAMIDE', 'First_GVHD_prophylaxis_CYCLOSPORIN', 'First_GVHD_prophylaxis_IMATINIB', 'First_GVHD_prophylaxis_LEFLUNOMIDE', 'First_GVHD_prophylaxis_METHOTREXATE', 'First_GVHD_prophylaxis_MMF', 'First_GVHD_prophylaxis_NONE', 'First_GVHD_prophylaxis_RUXOLITINIB', 'First_GVHD_prophylaxis_SIROLIMUS', 'First_GVHD_prophylaxis_STEROID', 'First_GVHD_prophylaxis_TAC', 'First_GVHD_prophylaxis_TACROLIMUS', 'First_GVHD_prophylaxis_X', 'Recepient_Blood group before HSCT_MergePlusMinus', 'D_Blood group_MergePlusMinus', 'R_Age_at_transplant_cutoff16', 'R_Age_at_transplant_cutoff18', 'D_Age_at_transplant_cutoff16', 'D_Age_at_transplant_cutoff18', 'Relation_and_Recipient_Gender', 'Relation_and_Donor_Gender', 'Relation_and_Recipient_and_Donor_Gender', 'Recepient_Nationality_Geographical', 'Recepient_Nationality_Cultural', 'Recepient_Nationality_Regional_Income', 'Recepient_Nationality_Regional_WHO', 'Hematological Diagnosis_Grouped', 'Hematological Diagnosis_Malignant', 'PreHSCT_MTX', 'First_GVHD_prophylaxis_MTX',
]
# fmt: on


def sidebar():
    """Render the model-selection sidebar and record the choice in session state.

    Side effects on ``st.session_state``:

    * ``orig_train_cols`` is reset to a fresh list copy of ``_ORIG_TRAIN_COLS``
      on every call.
    * ``selected_model`` is initialised to ``"Default_GVHD_ensemble"`` when
      absent, then set to the value returned by the sidebar select box.

    The select box preselects ``"Default_GVHD_ensemble"`` rather than whichever
    model name happens to sort first, so the session-state default and the
    widget's initial value agree.
    """
    # Hand out a copy so code that mutates the session-state list cannot
    # alter the module-level constant.
    st.session_state.orig_train_cols = list(_ORIG_TRAIN_COLS)

    if 'selected_model' not in st.session_state:
        st.session_state.selected_model = _DEFAULT_MODEL

    st.sidebar.title("Model Selection")

    options = get_model_options()
    # The options are sorted, so index 0 is not necessarily the default model.
    # The index is tied to the default model, not to the current selection:
    # NOTE(review): for a widget without a key, Streamlit appears to treat a
    # changed ``index`` as a new widget and reset it — confirm before making
    # this follow the selection.
    default_index = options.index(_DEFAULT_MODEL) if _DEFAULT_MODEL in options else 0
    st.session_state.selected_model = st.sidebar.selectbox(
        "Model", options, index=default_index
    )