Update app.py
Browse files
app.py
CHANGED
|
@@ -17,7 +17,8 @@ from sklearn.metrics import (
|
|
| 17 |
precision_recall_curve, average_precision_score,
|
| 18 |
brier_score_loss
|
| 19 |
)
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
from sklearn.calibration import calibration_curve
|
|
@@ -39,7 +40,7 @@ from sklearn.model_selection import train_test_split
|
|
| 39 |
|
| 40 |
# REPLACE make_fig with this (or add this and stop using plt.plot directly)
|
| 41 |
def make_fig(figsize=(5.5, 3.6), dpi=120):
|
| 42 |
-
|
| 43 |
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
| 44 |
return fig, ax
|
| 45 |
|
|
@@ -63,7 +64,7 @@ def render_plot_with_download(
|
|
| 63 |
export_dpi: int = 600,
|
| 64 |
key: Optional[str] = None
|
| 65 |
):
|
| 66 |
-
|
| 67 |
png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
|
| 68 |
st.pyplot(fig, clear_figure=False)
|
| 69 |
st.download_button(
|
|
@@ -234,6 +235,9 @@ def train_survival_bundle(
|
|
| 234 |
|
| 235 |
# one-hot
|
| 236 |
df_surv_oh = pd.get_dummies(df_surv, columns=cat_cols, drop_first=True)
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
# remove duplicate columns if any messy headers caused duplicates
|
| 239 |
df_surv_oh = df_surv_oh.loc[:, ~df_surv_oh.columns.duplicated()].copy()
|
|
@@ -663,7 +667,7 @@ def train_and_save(
|
|
| 663 |
l1_C: float,
|
| 664 |
use_dimred: bool,
|
| 665 |
svd_components: int,):
|
| 666 |
-
|
| 667 |
X = df[feature_cols].copy()
|
| 668 |
y_raw = df[LABEL_COL].copy()
|
| 669 |
|
|
@@ -863,7 +867,7 @@ def train_and_save(
|
|
| 863 |
# SHAP
|
| 864 |
# ============================================================
|
| 865 |
def build_shap_explainer(pipe, X_bg, max_bg=200):
|
| 866 |
-
|
| 867 |
|
| 868 |
if X_bg is None or len(X_bg) == 0:
|
| 869 |
raise ValueError("SHAP background is empty.")
|
|
@@ -882,7 +886,7 @@ def build_shap_explainer(pipe, X_bg, max_bg=200):
|
|
| 882 |
return explainer
|
| 883 |
|
| 884 |
|
| 885 |
-
def safe_dense(Xt, max_rows: int =
|
| 886 |
"""
|
| 887 |
Convert sparse->dense carefully. Avoid converting huge matrices to dense.
|
| 888 |
"""
|
|
@@ -2612,6 +2616,9 @@ with tab_predict:
|
|
| 2612 |
try:
|
| 2613 |
cph = bundle["model"]
|
| 2614 |
surv_cols = bundle.get("columns", [])
|
|
|
|
|
|
|
|
|
|
| 2615 |
|
| 2616 |
# Build Cox input row (same preprocessing as Cox training)
|
| 2617 |
df_one_surv = X_one[feature_cols].copy()
|
|
@@ -2627,12 +2634,21 @@ with tab_predict:
|
|
| 2627 |
)
|
| 2628 |
|
| 2629 |
df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cat_cols, drop_first=True)
|
|
|
|
|
|
|
| 2630 |
|
| 2631 |
# Align to training predictor columns
|
| 2632 |
for col in surv_cols:
|
| 2633 |
if col not in df_one_surv_oh.columns:
|
| 2634 |
df_one_surv_oh[col] = 0
|
| 2635 |
df_one_surv_oh = df_one_surv_oh[surv_cols]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2636 |
|
| 2637 |
# Predict survival function
|
| 2638 |
surv_fn = cph.predict_survival_function(df_one_surv_oh)
|
|
@@ -2994,6 +3010,17 @@ with tab_predict:
|
|
| 2994 |
imp = bundle.get("imputer", None)
|
| 2995 |
|
| 2996 |
df_surv_in = X_inf[feature_cols].copy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2997 |
df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cat_cols, drop_first=True)
|
| 2998 |
|
| 2999 |
# align columns
|
|
|
|
| 17 |
precision_recall_curve, average_precision_score,
|
| 18 |
brier_score_loss
|
| 19 |
)
|
| 20 |
+
import shap
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
|
| 23 |
|
| 24 |
from sklearn.calibration import calibration_curve
|
|
|
|
| 40 |
|
| 41 |
# REPLACE make_fig with this (or add this and stop using plt.plot directly)
|
| 42 |
def make_fig(figsize=(5.5, 3.6), dpi=120):
|
| 43 |
+
|
| 44 |
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
| 45 |
return fig, ax
|
| 46 |
|
|
|
|
| 64 |
export_dpi: int = 600,
|
| 65 |
key: Optional[str] = None
|
| 66 |
):
|
| 67 |
+
|
| 68 |
png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
|
| 69 |
st.pyplot(fig, clear_figure=False)
|
| 70 |
st.download_button(
|
|
|
|
| 235 |
|
| 236 |
# one-hot
|
| 237 |
df_surv_oh = pd.get_dummies(df_surv, columns=cat_cols, drop_first=True)
|
| 238 |
+
if duration_col not in df_surv_oh.columns or event_col not in df_surv_oh.columns:
|
| 239 |
+
raise ValueError("Survival DF missing duration/event columns after one-hot encoding.")
|
| 240 |
+
|
| 241 |
|
| 242 |
# remove duplicate columns if any messy headers caused duplicates
|
| 243 |
df_surv_oh = df_surv_oh.loc[:, ~df_surv_oh.columns.duplicated()].copy()
|
|
|
|
| 667 |
l1_C: float,
|
| 668 |
use_dimred: bool,
|
| 669 |
svd_components: int,):
|
| 670 |
+
|
| 671 |
X = df[feature_cols].copy()
|
| 672 |
y_raw = df[LABEL_COL].copy()
|
| 673 |
|
|
|
|
| 867 |
# SHAP
|
| 868 |
# ============================================================
|
| 869 |
def build_shap_explainer(pipe, X_bg, max_bg=200):
|
| 870 |
+
|
| 871 |
|
| 872 |
if X_bg is None or len(X_bg) == 0:
|
| 873 |
raise ValueError("SHAP background is empty.")
|
|
|
|
| 886 |
return explainer
|
| 887 |
|
| 888 |
|
| 889 |
+
def safe_dense(Xt, max_rows: int = 200):
|
| 890 |
"""
|
| 891 |
Convert sparse->dense carefully. Avoid converting huge matrices to dense.
|
| 892 |
"""
|
|
|
|
| 2616 |
try:
|
| 2617 |
cph = bundle["model"]
|
| 2618 |
surv_cols = bundle.get("columns", [])
|
| 2619 |
+
|
| 2620 |
+
|
| 2621 |
+
|
| 2622 |
|
| 2623 |
# Build Cox input row (same preprocessing as Cox training)
|
| 2624 |
df_one_surv = X_one[feature_cols].copy()
|
|
|
|
| 2634 |
)
|
| 2635 |
|
| 2636 |
df_one_surv_oh = pd.get_dummies(df_one_surv, columns=cat_cols, drop_first=True)
|
| 2637 |
+
|
| 2638 |
+
|
| 2639 |
|
| 2640 |
# Align to training predictor columns
|
| 2641 |
for col in surv_cols:
|
| 2642 |
if col not in df_one_surv_oh.columns:
|
| 2643 |
df_one_surv_oh[col] = 0
|
| 2644 |
df_one_surv_oh = df_one_surv_oh[surv_cols]
|
| 2645 |
+
|
| 2646 |
+
imp = bundle.get("imputer", None)
|
| 2647 |
+
if imp is not None:
|
| 2648 |
+
X_imp = imp.transform(df_one_surv_oh)
|
| 2649 |
+
df_one_surv_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_one_surv_oh.index)
|
| 2650 |
+
else:
|
| 2651 |
+
df_one_surv_oh = df_one_surv_oh.fillna(0)
|
| 2652 |
|
| 2653 |
# Predict survival function
|
| 2654 |
surv_fn = cph.predict_survival_function(df_one_surv_oh)
|
|
|
|
| 3010 |
imp = bundle.get("imputer", None)
|
| 3011 |
|
| 3012 |
df_surv_in = X_inf[feature_cols].copy()
|
| 3013 |
+
for c in num_cols:
|
| 3014 |
+
if c in df_surv_in.columns:
|
| 3015 |
+
df_surv_in[c] = pd.to_numeric(df_surv_in[c], errors="coerce")
|
| 3016 |
+
|
| 3017 |
+
for c in cat_cols:
|
| 3018 |
+
if c in df_surv_in.columns:
|
| 3019 |
+
df_surv_in[c] = df_surv_in[c].astype("object")
|
| 3020 |
+
df_surv_in.loc[df_surv_in[c].isna(), c] = np.nan
|
| 3021 |
+
df_surv_in[c] = df_surv_in[c].map(lambda v: v if pd.isna(v) else str(v))
|
| 3022 |
+
|
| 3023 |
+
|
| 3024 |
df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cat_cols, drop_first=True)
|
| 3025 |
|
| 3026 |
# align columns
|