Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +14 -36
src/streamlit_app.py
CHANGED
|
@@ -20,48 +20,25 @@ if str(ROOT) not in sys.path:
|
|
| 20 |
sys.path.insert(0, str(ROOT))
|
| 21 |
|
| 22 |
from src.survival_utils import prepare_cox_df, fit_cox, make_patient_design_row, predict_patient_survival
|
| 23 |
-
|
| 24 |
-
COX_TRAIN_PATH = ROOT / "assets" / "cox_training.parquet"
|
| 25 |
|
| 26 |
@st.cache_resource
|
| 27 |
-
def
|
| 28 |
"""
|
| 29 |
-
|
|
|
|
| 30 |
"""
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
covariates = [
|
| 39 |
-
"R_Age_at_transplant_cutoff18",
|
| 40 |
-
"HLA_total_mismatch",
|
| 41 |
-
"Hematological Diagnosis_Grouped",
|
| 42 |
-
"Donor_relation to recepient",
|
| 43 |
-
"Source of cells",
|
| 44 |
-
"Donor_type",
|
| 45 |
-
"Conditioning_intensity",
|
| 46 |
-
"GVHD_Prophylaxis_Cat",
|
| 47 |
-
]
|
| 48 |
-
|
| 49 |
-
# df_train already should contain Predicted_GVHD_Risk.
|
| 50 |
-
# If it doesn't, you must generate it offline and store it in this parquet.
|
| 51 |
-
preds = df_train["Predicted_GVHD_Risk"].values
|
| 52 |
-
|
| 53 |
-
df_cox, design_cols, cat_cols = prepare_cox_df(df_train, preds, covariates)
|
| 54 |
-
|
| 55 |
-
cph = fit_cox(df_cox)
|
| 56 |
-
meta = {
|
| 57 |
-
"n": len(df_cox),
|
| 58 |
-
"n_events": int(df_cox["Event_clean"].sum()),
|
| 59 |
-
"c_index": float(cph.concordance_index_),
|
| 60 |
-
"covariates": covariates,
|
| 61 |
-
}
|
| 62 |
return cph, design_cols, cat_cols, meta
|
| 63 |
|
| 64 |
|
|
|
|
| 65 |
# --- Country options for UI (pycountry backed) ---
|
| 66 |
try:
|
| 67 |
import pycountry
|
|
@@ -447,7 +424,8 @@ if submitted:
|
|
| 447 |
|
| 448 |
# ---- load saved Cox artifacts (trained once in Bulk page) ----
|
| 449 |
try:
|
| 450 |
-
cph, design_cols, cat_cols, meta =
|
|
|
|
| 451 |
covariates = meta["covariates"]
|
| 452 |
st.success("Cox model ready (trained on app start).")
|
| 453 |
st.caption(f"N={meta['n']} | events={meta['n_events']} | C-index={meta['c_index']:.3f}")
|
|
|
|
| 20 |
sys.path.insert(0, str(ROOT))
|
| 21 |
|
| 22 |
from src.survival_utils import prepare_cox_df, fit_cox, make_patient_design_row, predict_patient_survival
|
| 23 |
+
from src.cox_persist import load_cox_artifacts, ensure_cox_artifacts_available
|
|
|
|
| 24 |
|
| 25 |
@st.cache_resource
|
| 26 |
+
def get_or_load_cox():
|
| 27 |
"""
|
| 28 |
+
Load Cox artifacts trained in Bulk page.
|
| 29 |
+
If running on Spaces without persistence, download from HF dataset repo when missing.
|
| 30 |
"""
|
| 31 |
+
# Ensure artifacts exist locally (downloads to /tmp/saved_models by default)
|
| 32 |
+
ensure_cox_artifacts_available() # downloads cox_os_model.joblib + cox_os_meta.json
|
| 33 |
+
|
| 34 |
+
payload, meta = load_cox_artifacts(prefix="cox_os") # loads from DEFAULT_DIR (/tmp/saved_models)
|
| 35 |
+
cph = payload["cph"]
|
| 36 |
+
design_cols = payload["design_cols"]
|
| 37 |
+
cat_cols = payload["cat_cols"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
return cph, design_cols, cat_cols, meta
|
| 39 |
|
| 40 |
|
| 41 |
+
|
| 42 |
# --- Country options for UI (pycountry backed) ---
|
| 43 |
try:
|
| 44 |
import pycountry
|
|
|
|
| 424 |
|
| 425 |
# ---- load saved Cox artifacts (trained once in Bulk page) ----
|
| 426 |
try:
|
| 427 |
+
cph, design_cols, cat_cols, meta = get_or_load_cox()
|
| 428 |
+
st.success("Cox model ready (loaded from saved artifacts).")
|
| 429 |
covariates = meta["covariates"]
|
| 430 |
st.success("Cox model ready (trained on app start).")
|
| 431 |
st.caption(f"N={meta['n']} | events={meta['n_events']} | C-index={meta['c_index']:.3f}")
|