Synav committed on
Commit
9fd2250
·
verified ·
1 Parent(s): a4430dd

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +14 -36
src/streamlit_app.py CHANGED
@@ -20,48 +20,25 @@ if str(ROOT) not in sys.path:
20
  sys.path.insert(0, str(ROOT))
21
 
22
  from src.survival_utils import prepare_cox_df, fit_cox, make_patient_design_row, predict_patient_survival
23
-
24
- COX_TRAIN_PATH = ROOT / "assets" / "cox_training.parquet"
25
 
26
  @st.cache_resource
27
- def get_or_train_cox():
28
  """
29
- Train Cox once per app start (cached). No persistence needed across restarts.
 
30
  """
31
- if not COX_TRAIN_PATH.exists():
32
- raise FileNotFoundError(f"Missing Cox training file: {COX_TRAIN_PATH}")
33
-
34
- df_train = pd.read_parquet(COX_TRAIN_PATH)
35
-
36
- # REQUIRED columns must exist in df_train:
37
- # OS_time_days, Event_clean, Predicted_GVHD_Risk + your covariates
38
- covariates = [
39
- "R_Age_at_transplant_cutoff18",
40
- "HLA_total_mismatch",
41
- "Hematological Diagnosis_Grouped",
42
- "Donor_relation to recepient",
43
- "Source of cells",
44
- "Donor_type",
45
- "Conditioning_intensity",
46
- "GVHD_Prophylaxis_Cat",
47
- ]
48
-
49
- # df_train already should contain Predicted_GVHD_Risk.
50
- # If it doesn't, you must generate it offline and store it in this parquet.
51
- preds = df_train["Predicted_GVHD_Risk"].values
52
-
53
- df_cox, design_cols, cat_cols = prepare_cox_df(df_train, preds, covariates)
54
-
55
- cph = fit_cox(df_cox)
56
- meta = {
57
- "n": len(df_cox),
58
- "n_events": int(df_cox["Event_clean"].sum()),
59
- "c_index": float(cph.concordance_index_),
60
- "covariates": covariates,
61
- }
62
  return cph, design_cols, cat_cols, meta
63
 
64
 
 
65
  # --- Country options for UI (pycountry backed) ---
66
  try:
67
  import pycountry
@@ -447,7 +424,8 @@ if submitted:
447
 
448
  # ---- load saved Cox artifacts (trained once in Bulk page) ----
449
  try:
450
- cph, design_cols, cat_cols, meta = get_or_train_cox()
 
451
  covariates = meta["covariates"]
452
  st.success("Cox model ready (trained on app start).")
453
  st.caption(f"N={meta['n']} | events={meta['n_events']} | C-index={meta['c_index']:.3f}")
 
20
  sys.path.insert(0, str(ROOT))
21
 
22
  from src.survival_utils import prepare_cox_df, fit_cox, make_patient_design_row, predict_patient_survival
23
+ from src.cox_persist import load_cox_artifacts, ensure_cox_artifacts_available
 
24
 
25
  @st.cache_resource
26
+ def get_or_load_cox():
27
  """
28
+ Load Cox artifacts trained in Bulk page.
29
+ If running on Spaces without persistence, download from HF dataset repo when missing.
30
  """
31
+ # Ensure artifacts exist locally (downloads to /tmp/saved_models by default)
32
+ ensure_cox_artifacts_available() # downloads cox_os_model.joblib + cox_os_meta.json
33
+
34
+ payload, meta = load_cox_artifacts(prefix="cox_os") # loads from DEFAULT_DIR (/tmp/saved_models)
35
+ cph = payload["cph"]
36
+ design_cols = payload["design_cols"]
37
+ cat_cols = payload["cat_cols"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  return cph, design_cols, cat_cols, meta
39
 
40
 
41
+
42
  # --- Country options for UI (pycountry backed) ---
43
  try:
44
  import pycountry
 
424
 
425
  # ---- load saved Cox artifacts (trained once in Bulk page) ----
426
  try:
427
+ cph, design_cols, cat_cols, meta = get_or_load_cox()
428
+ st.success("Cox model ready (loaded from saved artifacts).")
429
  covariates = meta["covariates"]
430
  st.success("Cox model ready (trained on app start).")
431
  st.caption(f"N={meta['n']} | events={meta['n_events']} | C-index={meta['c_index']:.3f}")