import os
from pathlib import Path
import streamlit as st
# When True, get_model_options() discovers models on local disk under
# SAVED_MODELS_DIR; when False it lists them from a Hugging Face Hub repo.
LOCAL = False
# Initial value for the "target_choice" and "target_col" session-state keys.
DEFAULT_TARGET = "GVHD"
# Initial value for the "threshold" session-state key (sidebar slider).
DEFAULT_THRESHOLD = 0.5
# Directory scanned for "*.pkl" model files when LOCAL is True.
SAVED_MODELS_DIR = Path("src/saved_models")
def _init_session_defaults():
    """Seed ``st.session_state`` with the keys shared by all pages.

    Every key is written with ``setdefault``, so a value that is already
    present in the session is left untouched.
    """
    # Raw columns expected from uploaded CSV (NOT engineered features)
    raw_csv_columns = [
        "EPI/ID numbers",
        "Recepient_gender",
        "Recepient_DOB",
        "Recepient_Nationality",
        "Hematological Diagnosis",
        "Date of first diagnosis/BMBx date",
        "Recepient_Blood group before HSCT",
        "Donor_DOB",
        "Donor_gender",
        "D_Blood group",
        "R_HLA_A", "R_HLA_B", "R_HLA_C", "R_HLA_DR", "R_HLA_DQ",
        "D_HLA_A", "D_HLA_B", "D_HLA_C", "D_HLA_DR", "D_HLA_DQ",
        "Number of lines of Rx before HSCT",
        "PreHSCT conditioning regimen+/-ATG+/-TBI",
        "HSCT_date",
        "Source of cells",
        "Donor_relation to recepient",
        "HLA match ratio",
        "First_GVHD prophylaxis",
        # Pro / survival raw columns
        "Last_followup_date",
        "Date_of_death",
        # Pro categorical columns
        "Donor_type",
        "Conditioning_intensity",
        "GVHD_Prophylaxis_Cat",
    ]

    shared_defaults = {
        "selected_model": None,
        "target_choice": DEFAULT_TARGET,  # bulk page uses this
        "target_col": DEFAULT_TARGET,  # individual/training set this too
        "threshold": DEFAULT_THRESHOLD,
        "orig_train_cols": raw_csv_columns,
    }
    for state_key, initial_value in shared_defaults.items():
        st.session_state.setdefault(state_key, initial_value)
def get_model_options():
    """Return the names of the available trained models, newest first.

    When ``LOCAL`` is true, the stems of the ``*.pkl`` files found in
    ``SAVED_MODELS_DIR`` are returned in descending name order (the directory
    is created if it does not exist yet).

    Otherwise the files of the Hugging Face dataset repo named by the
    ``HF_REPO_ID`` environment variable are listed using ``HF_TOKEN``, and the
    stems of the ``models/*.parquet`` entries are returned, ordered by their
    ``yyMMdd_HHMMSS`` name prefix, latest first.

    Returns:
        A list of model names (file stems). The list is empty when no model
        is found, or when ``HF_REPO_ID`` or ``HF_TOKEN`` is unset.
    """
    if LOCAL:
        # parents=True: the default path is nested ("src/saved_models"), so a
        # plain mkdir would raise FileNotFoundError when "src" is missing.
        SAVED_MODELS_DIR.mkdir(parents=True, exist_ok=True)
        return sorted(
            (p.stem for p in SAVED_MODELS_DIR.glob("*.pkl")), reverse=True
        )

    # Kept as a function-scope import, after the local branch has returned,
    # so the LOCAL code path never has to import huggingface_hub.
    from huggingface_hub import HfApi

    repo_id = os.environ.get("HF_REPO_ID", "")
    token = os.environ.get("HF_TOKEN", "")
    if not repo_id or not token:
        return []

    api = HfApi(token=token)
    all_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    parquet_files = [
        Path(f).stem
        for f in all_files
        if f.startswith("models/") and f.endswith(".parquet")
    ]

    def timestamp_key(name):
        """Return ``(date, time)`` taken from a ``yyMMdd_HHMMSS...`` prefix."""
        date_part, sep, remainder = name.partition("_")
        if not sep:
            # No underscore at all: no timestamp, sorts last when reversed.
            return ("", "")
        time_part = remainder.partition("_")[0]
        # The timestamp itself contains an underscore, so the date alone
        # cannot order models trained on the same day; include the time
        # field, but only when it actually looks numeric.
        if not time_part[:1].isdigit():
            time_part = ""
        return (date_part, time_part)

    parquet_files.sort(key=timestamp_key, reverse=True)
    return parquet_files
def sidebar():
    """Render the shared sidebar and store its selections in session state.

    Initialises the session defaults, then shows a model selector populated
    from ``get_model_options()``. If no model is available a warning is shown
    and nothing else is rendered. Behind the "Show advanced settings" toggle,
    the prediction target and the decision threshold can be changed.

    Session keys written: ``selected_model`` always (when models exist), and
    ``target_choice`` / ``threshold`` only while the advanced toggle is on.
    """
    _init_session_defaults()
    st.sidebar.title("GVHD-Intel")
    st.sidebar.subheader("Model")

    options = get_model_options()
    if not options:
        st.sidebar.warning("No trained models found.")
        return

    # No model selected yet, or the selected one is no longer listed:
    # fall back to the first entry, which is the latest model.
    if st.session_state.get("selected_model") not in options:
        st.session_state.selected_model = options[0]

    st.session_state.selected_model = st.sidebar.selectbox(
        "Model",
        options=options,
        index=options.index(st.session_state.selected_model),
    )

    # Advanced controls stay hidden unless the user asks for them.
    show_advanced = st.sidebar.toggle("Show advanced settings", value=False)
    if not show_advanced:
        return

    st.sidebar.subheader("Target")
    target_options = ["GVHD", "Acute GVHD(<100 days)", "Chronic GVHD>100 days"]
    current_target = st.session_state.get("target_choice", DEFAULT_TARGET)
    # An unexpected stored value would make list.index() raise ValueError;
    # preselect the first option instead of crashing the page.
    target_index = (
        target_options.index(current_target)
        if current_target in target_options
        else 0
    )
    st.session_state.target_choice = st.sidebar.radio(
        "Choose target:",
        options=target_options,
        index=target_index,
    )

    st.sidebar.subheader("Decision threshold")
    min_threshold, max_threshold = 0.05, 0.95
    stored_threshold = float(
        st.session_state.get("threshold", DEFAULT_THRESHOLD)
    )
    # The slider rejects an initial value outside its bounds, so clamp
    # whatever is stored in the session into the allowed range first.
    initial_threshold = min(max(stored_threshold, min_threshold), max_threshold)
    st.session_state.threshold = st.sidebar.slider(
        "Positive threshold",
        min_value=min_threshold,
        max_value=max_threshold,
        value=initial_threshold,
        step=0.05,
    )