File size: 4,005 Bytes
1b04efb
e763894
 
1b04efb
 
 
25217b5
e763894
 
1b04efb
e763894
1b04efb
e763894
 
 
81c48d1
e763894
 
 
 
 
1b04efb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e763894
1b04efb
 
 
e763894
1b04efb
 
 
e763894
 
 
 
25217b5
e763894
 
 
 
25217b5
e763894
 
 
 
 
 
 
25217b5
e763894
 
 
 
25217b5
 
 
 
 
 
 
 
 
 
 
e763894
25217b5
e763894
 
 
 
1b04efb
6dafc80
 
 
e763894
25217b5
 
 
 
 
 
 
 
 
 
 
e763894
 
6dafc80
25217b5
e763894
25217b5
 
 
e763894
6dafc80
 
 
 
 
 
 
 
 
 
 
 
 
 
e763894
 
 
 
 
6dafc80
e763894
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
from pathlib import Path
import streamlit as st

# Model source switch read by get_model_options(): True scans SAVED_MODELS_DIR
# on disk, False lists files from the Hugging Face Hub repository instead.
LOCAL = False

# Directory that is scanned for "*.pkl" models when LOCAL is enabled.
SAVED_MODELS_DIR = Path("src") / "saved_models"

# Fallback values used to seed st.session_state and the sidebar widgets.
DEFAULT_TARGET = "GVHD"
DEFAULT_THRESHOLD = 0.5


def _init_session_defaults():
    """Seed ``st.session_state`` with the keys shared by all pages.

    Every key is written with ``setdefault``, so a value that is already
    present in the session state is left untouched.
    """
    # Raw columns expected from uploaded CSV (NOT engineered features)
    clinical_cols = [
        "EPI/ID numbers",
        "Recepient_gender",
        "Recepient_DOB",
        "Recepient_Nationality",
        "Hematological Diagnosis",
        "Date of first diagnosis/BMBx date",
        "Recepient_Blood group before HSCT",
        "Donor_DOB",
        "Donor_gender",
        "D_Blood group",
        "R_HLA_A", "R_HLA_B", "R_HLA_C", "R_HLA_DR", "R_HLA_DQ",
        "D_HLA_A", "D_HLA_B", "D_HLA_C", "D_HLA_DR", "D_HLA_DQ",
        "Number of lines of Rx before HSCT",
        "PreHSCT conditioning regimen+/-ATG+/-TBI",
        "HSCT_date",
        "Source of cells",
        "Donor_relation to recepient",
        "HLA match ratio",
        "First_GVHD prophylaxis",
    ]
    # Pro / survival raw columns
    survival_cols = [
        "Last_followup_date",
        "Date_of_death",
    ]
    # Pro categorical columns
    categorical_cols = [
        "Donor_type",
        "Conditioning_intensity",
        "GVHD_Prophylaxis_Cat",
    ]

    session_defaults = {
        "selected_model": None,
        "target_choice": DEFAULT_TARGET,   # bulk page uses this
        "target_col": DEFAULT_TARGET,      # individual/training set this too
        "threshold": DEFAULT_THRESHOLD,
        "orig_train_cols": clinical_cols + survival_cols + categorical_cols,
    }
    for state_key, fallback in session_defaults.items():
        st.session_state.setdefault(state_key, fallback)


def get_model_options():
    """Return the names of the available trained models, newest first.

    When ``LOCAL`` is true, the stems of the ``*.pkl`` files found in
    ``SAVED_MODELS_DIR`` are returned, sorted in descending name order. The
    directory (including missing parents) is created if it does not exist.

    Otherwise the files of the Hugging Face Hub dataset repository named by
    the ``HF_REPO_ID`` environment variable are listed (authenticated with
    ``HF_TOKEN``), and the stems of the ``models/*.parquet`` entries are
    returned. If either environment variable is unset or empty, an empty
    list is returned without contacting the Hub.

    Returns:
        list[str]: model names (file stems); may be empty.
    """
    if LOCAL:
        # parents=True: "src/" itself may be missing (e.g. a different working
        # directory), in which case a plain mkdir raises FileNotFoundError.
        SAVED_MODELS_DIR.mkdir(parents=True, exist_ok=True)
        return sorted((p.stem for p in SAVED_MODELS_DIR.glob("*.pkl")), reverse=True)

    from huggingface_hub import HfApi

    repo_id = os.environ.get("HF_REPO_ID", "")
    token = os.environ.get("HF_TOKEN", "")

    if not repo_id or not token:
        return []

    api = HfApi(token=token)
    all_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")

    parquet_files = [
        Path(f).stem
        for f in all_files
        if f.startswith("models/") and f.endswith(".parquet")
    ]

    # Sort by timestamp prefix (yyMMdd_HHMMSS...).
    # NOTE(review): the naming scheme is taken from the comment above, not
    # from code visible here — confirm against whatever writes these files.
    def extract_timestamp(name):
        date_part, sep, rest = name.partition("_")
        if not sep:
            # No underscore at all: no recognisable timestamp, sort last.
            return ("", "")
        time_part = rest.partition("_")[0]
        # Include the HHMMSS component so that several models saved on the
        # same day are still ordered newest-first; ignore it when the second
        # token is not numeric (i.e. not a time).
        return (date_part, time_part if time_part.isdigit() else "")

    parquet_files.sort(key=extract_timestamp, reverse=True)

    return parquet_files


def sidebar():
    """Render the shared sidebar and store its selections in session state.

    Writes the following ``st.session_state`` entries:

    * ``selected_model`` - the model picked in the "Model" select box.
    * ``target_choice`` - the prediction target; only editable while the
      "Show advanced settings" toggle is on.
    * ``threshold`` - the decision threshold chosen with the slider.

    If ``get_model_options()`` returns no models, a warning is shown and the
    function returns before any of the remaining controls are rendered.
    """
    _init_session_defaults()

    st.sidebar.title("GVHD-Intel")
    st.sidebar.subheader("Model")
    options = get_model_options()

    if not options:
        st.sidebar.warning("No trained models found.")
        return

    # No model selected yet, or the remembered one is no longer offered:
    # fall back to the first option (get_model_options sorts newest first).
    if st.session_state.get("selected_model") not in options:
        st.session_state.selected_model = options[0]

    st.session_state.selected_model = st.sidebar.selectbox(
        "Model",
        options=options,
        index=options.index(st.session_state.selected_model),
    )

    # Optional: hide advanced controls behind toggle (recommended)
    show_advanced = st.sidebar.toggle("Show advanced settings", value=False)

    if show_advanced:
        target_options = ["GVHD", "Acute GVHD(<100 days)", "Chronic GVHD>100 days"]
        current_target = st.session_state.get("target_choice", DEFAULT_TARGET)
        # Guard against a stale/unknown value in session state: list.index
        # would raise ValueError and take the whole page down.
        if current_target in target_options:
            target_index = target_options.index(current_target)
        else:
            target_index = 0

        st.sidebar.subheader("Target")
        st.session_state.target_choice = st.sidebar.radio(
            "Choose target:",
            options=target_options,
            index=target_index,
        )

    min_threshold, max_threshold = 0.05, 0.95
    current_threshold = float(st.session_state.get("threshold", DEFAULT_THRESHOLD))
    # Keep the initial value inside the slider bounds; an out-of-range value
    # is rejected by the slider widget.
    current_threshold = min(max(current_threshold, min_threshold), max_threshold)

    st.sidebar.subheader("Decision threshold")
    st.session_state.threshold = st.sidebar.slider(
        "Positive threshold",
        min_value=min_threshold,
        max_value=max_threshold,
        value=current_threshold,
        step=0.05,
    )