Synav committed on
Commit
cf97de2
·
verified ·
1 Parent(s): a4c4923

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -6
app.py CHANGED
@@ -17,7 +17,8 @@ from sklearn.metrics import (
17
  precision_recall_curve, average_precision_score,
18
  brier_score_loss
19
  )
20
-
 
21
 
22
 
23
  from sklearn.calibration import calibration_curve
@@ -39,7 +40,7 @@ from sklearn.model_selection import train_test_split
39
 
40
  # REPLACE make_fig with this (or add this and stop using plt.plot directly)
41
  def make_fig(figsize=(5.5, 3.6), dpi=120):
42
- import matplotlib.pyplot as plt
43
  fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
44
  return fig, ax
45
 
@@ -63,7 +64,7 @@ def render_plot_with_download(
63
  export_dpi: int = 600,
64
  key: Optional[str] = None
65
  ):
66
- import matplotlib.pyplot as plt # lazy
67
  png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
68
  st.pyplot(fig, clear_figure=False)
69
  st.download_button(
@@ -234,6 +235,9 @@ def train_survival_bundle(
234
 
235
  # one-hot
236
  df_surv_oh = pd.get_dummies(df_surv, columns=cat_cols, drop_first=True)
 
 
 
237
 
238
  # remove duplicate columns if any messy headers caused duplicates
239
  df_surv_oh = df_surv_oh.loc[:, ~df_surv_oh.columns.duplicated()].copy()
@@ -663,7 +667,7 @@ def train_and_save(
663
  l1_C: float,
664
  use_dimred: bool,
665
  svd_components: int,):
666
- from lifelines import CoxPHFitter
667
  X = df[feature_cols].copy()
668
  y_raw = df[LABEL_COL].copy()
669
 
@@ -863,7 +867,7 @@ def train_and_save(
863
  # SHAP
864
  # ============================================================
865
  def build_shap_explainer(pipe, X_bg, max_bg=200):
866
- import shap # lazy
867
 
868
  if X_bg is None or len(X_bg) == 0:
869
  raise ValueError("SHAP background is empty.")
@@ -882,7 +886,7 @@ def build_shap_explainer(pipe, X_bg, max_bg=200):
882
  return explainer
883
 
884
 
885
- def safe_dense(Xt, max_rows: int = 2000):
886
  """
887
  Convert sparse->dense carefully. Avoid converting huge matrices to dense.
888
  """
@@ -2612,6 +2616,9 @@ with tab_predict:
2612
  try:
2613
  cph = bundle["model"]
2614
  surv_cols = bundle.get("columns", [])
 
 
 
2615
 
2616
  # Build Cox input row (same preprocessing as Cox training)
2617
  df_one_surv = X_one[feature_cols].copy()
@@ -2627,12 +2634,21 @@ with tab_predict:
2627
  )
2628
 
2629
  df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cat_cols, drop_first=True)
 
 
2630
 
2631
  # Align to training predictor columns
2632
  for col in surv_cols:
2633
  if col not in df_one_surv_oh.columns:
2634
  df_one_surv_oh[col] = 0
2635
  df_one_surv_oh = df_one_surv_oh[surv_cols]
 
 
 
 
 
 
 
2636
 
2637
  # Predict survival function
2638
  surv_fn = cph.predict_survival_function(df_one_surv_oh)
@@ -2994,6 +3010,17 @@ with tab_predict:
2994
  imp = bundle.get("imputer", None)
2995
 
2996
  df_surv_in = X_inf[feature_cols].copy()
 
 
 
 
 
 
 
 
 
 
 
2997
  df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cat_cols, drop_first=True)
2998
 
2999
  # align columns
 
17
  precision_recall_curve, average_precision_score,
18
  brier_score_loss
19
  )
20
+ import shap
21
+ import matplotlib.pyplot as plt
22
 
23
 
24
  from sklearn.calibration import calibration_curve
 
40
 
41
  # REPLACE make_fig with this (or add this and stop using plt.plot directly)
42
  def make_fig(figsize=(5.5, 3.6), dpi=120):
43
+
44
  fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
45
  return fig, ax
46
 
 
64
  export_dpi: int = 600,
65
  key: Optional[str] = None
66
  ):
67
+
68
  png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
69
  st.pyplot(fig, clear_figure=False)
70
  st.download_button(
 
235
 
236
  # one-hot
237
  df_surv_oh = pd.get_dummies(df_surv, columns=cat_cols, drop_first=True)
238
+ if duration_col not in df_surv_oh.columns or event_col not in df_surv_oh.columns:
239
+ raise ValueError("Survival DF missing duration/event columns after one-hot encoding.")
240
+
241
 
242
  # remove duplicate columns if any messy headers caused duplicates
243
  df_surv_oh = df_surv_oh.loc[:, ~df_surv_oh.columns.duplicated()].copy()
 
667
  l1_C: float,
668
  use_dimred: bool,
669
  svd_components: int,):
670
+
671
  X = df[feature_cols].copy()
672
  y_raw = df[LABEL_COL].copy()
673
 
 
867
  # SHAP
868
  # ============================================================
869
  def build_shap_explainer(pipe, X_bg, max_bg=200):
870
+
871
 
872
  if X_bg is None or len(X_bg) == 0:
873
  raise ValueError("SHAP background is empty.")
 
886
  return explainer
887
 
888
 
889
+ def safe_dense(Xt, max_rows: int = 200):
890
  """
891
  Convert sparse->dense carefully. Avoid converting huge matrices to dense.
892
  """
 
2616
  try:
2617
  cph = bundle["model"]
2618
  surv_cols = bundle.get("columns", [])
2619
+
2620
+
2621
+
2622
 
2623
  # Build Cox input row (same preprocessing as Cox training)
2624
  df_one_surv = X_one[feature_cols].copy()
 
2634
  )
2635
 
2636
  df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cat_cols, drop_first=True)
2637
+
2638
+
2639
 
2640
  # Align to training predictor columns
2641
  for col in surv_cols:
2642
  if col not in df_one_surv_oh.columns:
2643
  df_one_surv_oh[col] = 0
2644
  df_one_surv_oh = df_one_surv_oh[surv_cols]
2645
+
2646
+ imp = bundle.get("imputer", None)
2647
+ if imp is not None:
2648
+ X_imp = imp.transform(df_one_surv_oh)
2649
+ df_one_surv_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_one_surv_oh.index)
2650
+ else:
2651
+ df_one_surv_oh = df_one_surv_oh.fillna(0)
2652
 
2653
  # Predict survival function
2654
  surv_fn = cph.predict_survival_function(df_one_surv_oh)
 
3010
  imp = bundle.get("imputer", None)
3011
 
3012
  df_surv_in = X_inf[feature_cols].copy()
3013
+ for c in num_cols:
3014
+ if c in df_surv_in.columns:
3015
+ df_surv_in[c] = pd.to_numeric(df_surv_in[c], errors="coerce")
3016
+
3017
+ for c in cat_cols:
3018
+ if c in df_surv_in.columns:
3019
+ df_surv_in[c] = df_surv_in[c].astype("object")
3020
+ df_surv_in.loc[df_surv_in[c].isna(), c] = np.nan
3021
+ df_surv_in[c] = df_surv_in[c].map(lambda v: v if pd.isna(v) else str(v))
3022
+
3023
+
3024
  df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cat_cols, drop_first=True)
3025
 
3026
  # align columns