Update app.py
Browse files
app.py
CHANGED
|
@@ -2930,45 +2930,47 @@ with tab_predict:
|
|
| 2930 |
# ---- Batch survival probabilities (if survival model loaded) ----
|
| 2931 |
# ---- Batch survival probabilities (if survival model loaded) ----
|
| 2932 |
bundle = st.session_state.get("surv_model", None)
|
|
|
|
| 2933 |
if isinstance(bundle, dict) and bundle.get("model") is not None:
|
| 2934 |
-
|
| 2935 |
-
|
| 2936 |
-
|
|
|
|
| 2937 |
|
| 2938 |
-
|
| 2939 |
-
|
| 2940 |
-
df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cat_cols, drop_first=True)
|
| 2941 |
|
| 2942 |
-
|
| 2943 |
-
|
| 2944 |
-
|
| 2945 |
-
|
| 2946 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2947 |
|
| 2948 |
-
|
| 2949 |
-
if imp is not None:
|
| 2950 |
-
X_imp = imp.transform(df_surv_in_oh)
|
| 2951 |
-
df_surv_in_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_surv_in_oh.index)
|
| 2952 |
-
else:
|
| 2953 |
-
df_surv_in_oh = df_surv_in_oh.fillna(0)
|
| 2954 |
|
| 2955 |
-
|
| 2956 |
-
|
|
|
|
|
|
|
| 2957 |
|
| 2958 |
-
|
| 2959 |
-
|
| 2960 |
-
|
| 2961 |
-
|
| 2962 |
-
|
| 2963 |
-
|
| 2964 |
-
|
| 2965 |
-
|
| 2966 |
-
|
| 2967 |
-
|
| 2968 |
-
|
| 2969 |
-
st.warning(f"Batch survival probabilities could not be computed: {e}")
|
| 2970 |
-
else:
|
| 2971 |
-
st.info("Survival model not loaded/published yet (survival bundle missing).")
|
| 2972 |
|
| 2973 |
|
| 2974 |
|
|
|
|
| 2930 |
# ---- Batch survival probabilities (if survival model loaded) ----
|
| 2931 |
# ---- Batch survival probabilities (if survival model loaded) ----
|
| 2932 |
bundle = st.session_state.get("surv_model", None)
|
| 2933 |
+
|
| 2934 |
if isinstance(bundle, dict) and bundle.get("model") is not None:
|
| 2935 |
+
try:
|
| 2936 |
+
cph = bundle["model"]
|
| 2937 |
+
surv_cols = bundle["columns"]
|
| 2938 |
+
imp = bundle.get("imputer", None)
|
| 2939 |
|
| 2940 |
+
df_surv_in = X_inf[feature_cols].copy()
|
| 2941 |
+
df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cat_cols, drop_first=True)
|
|
|
|
| 2942 |
|
| 2943 |
+
# align columns
|
| 2944 |
+
for col in surv_cols:
|
| 2945 |
+
if col not in df_surv_in_oh.columns:
|
| 2946 |
+
df_surv_in_oh[col] = 0
|
| 2947 |
+
df_surv_in_oh = df_surv_in_oh[surv_cols]
|
| 2948 |
+
|
| 2949 |
+
# impute
|
| 2950 |
+
if imp is not None:
|
| 2951 |
+
X_imp = imp.transform(df_surv_in_oh)
|
| 2952 |
+
df_surv_in_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_surv_in_oh.index)
|
| 2953 |
+
else:
|
| 2954 |
+
df_surv_in_oh = df_surv_in_oh.fillna(0)
|
| 2955 |
|
| 2956 |
+
surv_fn_all = cph.predict_survival_function(df_surv_in_oh)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2957 |
|
| 2958 |
+
def surv_vec_at(days: int):
|
| 2959 |
+
idx = surv_fn_all.index.values
|
| 2960 |
+
j = int(np.argmin(np.abs(idx - days)))
|
| 2961 |
+
return surv_fn_all.iloc[j, :].values.astype(float)
|
| 2962 |
|
| 2963 |
+
df_out["survival_6m"] = surv_vec_at(180)
|
| 2964 |
+
df_out["survival_1y"] = surv_vec_at(365)
|
| 2965 |
+
df_out["survival_2y"] = surv_vec_at(730)
|
| 2966 |
+
df_out["survival_3y"] = surv_vec_at(1095)
|
| 2967 |
+
|
| 2968 |
+
except Exception as e:
|
| 2969 |
+
st.warning(f"Batch survival probabilities could not be computed: {e}")
|
| 2970 |
+
|
| 2971 |
+
else:
|
| 2972 |
+
st.info("Survival model not loaded/published yet (survival bundle missing).")
|
| 2973 |
+
|
|
|
|
|
|
|
|
|
|
| 2974 |
|
| 2975 |
|
| 2976 |
|