Synav commited on
Commit
0d384af
·
verified ·
1 Parent(s): 702891b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -33
app.py CHANGED
@@ -2930,45 +2930,47 @@ with tab_predict:
2930
  # ---- Batch survival probabilities (if survival model loaded) ----
2931
  # ---- Batch survival probabilities (if survival model loaded) ----
2932
  bundle = st.session_state.get("surv_model", None)
 
2933
  if isinstance(bundle, dict) and bundle.get("model") is not None:
2934
- cph = bundle["model"]
2935
- surv_cols = bundle["columns"]
2936
- imp = bundle.get("imputer", None)
 
2937
 
2938
- df_surv_in = X_inf[feature_cols].copy()
2939
- # coerce num/cat same as you already do
2940
- df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cat_cols, drop_first=True)
2941
 
2942
- # align columns exactly
2943
- for col in surv_cols:
2944
- if col not in df_surv_in_oh.columns:
2945
- df_surv_in_oh[col] = 0
2946
- df_surv_in_oh = df_surv_in_oh[surv_cols]
 
 
 
 
 
 
 
2947
 
2948
- # impute with training imputer
2949
- if imp is not None:
2950
- X_imp = imp.transform(df_surv_in_oh)
2951
- df_surv_in_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_surv_in_oh.index)
2952
- else:
2953
- df_surv_in_oh = df_surv_in_oh.fillna(0)
2954
 
2955
- surv_fn_all = cph.predict_survival_function(df_surv_in_oh)
2956
-
 
 
2957
 
2958
- def surv_vec_at(days: int):
2959
- idx = surv_fn_all.index.values
2960
- j = int(np.argmin(np.abs(idx - days)))
2961
- return surv_fn_all.iloc[j, :].values.astype(float)
2962
-
2963
- df_out["survival_6m"] = surv_vec_at(180)
2964
- df_out["survival_1y"] = surv_vec_at(365)
2965
- df_out["survival_2y"] = surv_vec_at(730)
2966
- df_out["survival_3y"] = surv_vec_at(1095)
2967
-
2968
- except Exception as e:
2969
- st.warning(f"Batch survival probabilities could not be computed: {e}")
2970
- else:
2971
- st.info("Survival model not loaded/published yet (survival bundle missing).")
2972
 
2973
 
2974
 
 
2930
  # ---- Batch survival probabilities (if survival model loaded) ----
2931
  # ---- Batch survival probabilities (if survival model loaded) ----
2932
  bundle = st.session_state.get("surv_model", None)
2933
+
2934
  if isinstance(bundle, dict) and bundle.get("model") is not None:
2935
+ try:
2936
+ cph = bundle["model"]
2937
+ surv_cols = bundle["columns"]
2938
+ imp = bundle.get("imputer", None)
2939
 
2940
+ df_surv_in = X_inf[feature_cols].copy()
2941
+ df_surv_in_oh = pd.get_dummies(df_surv_in, columns=cat_cols, drop_first=True)
 
2942
 
2943
+ # align columns
2944
+ for col in surv_cols:
2945
+ if col not in df_surv_in_oh.columns:
2946
+ df_surv_in_oh[col] = 0
2947
+ df_surv_in_oh = df_surv_in_oh[surv_cols]
2948
+
2949
+ # impute
2950
+ if imp is not None:
2951
+ X_imp = imp.transform(df_surv_in_oh)
2952
+ df_surv_in_oh = pd.DataFrame(X_imp, columns=surv_cols, index=df_surv_in_oh.index)
2953
+ else:
2954
+ df_surv_in_oh = df_surv_in_oh.fillna(0)
2955
 
2956
+ surv_fn_all = cph.predict_survival_function(df_surv_in_oh)
 
 
 
 
 
2957
 
2958
+ def surv_vec_at(days: int):
2959
+ idx = surv_fn_all.index.values
2960
+ j = int(np.argmin(np.abs(idx - days)))
2961
+ return surv_fn_all.iloc[j, :].values.astype(float)
2962
 
2963
+ df_out["survival_6m"] = surv_vec_at(180)
2964
+ df_out["survival_1y"] = surv_vec_at(365)
2965
+ df_out["survival_2y"] = surv_vec_at(730)
2966
+ df_out["survival_3y"] = surv_vec_at(1095)
2967
+
2968
+ except Exception as e:
2969
+ st.warning(f"Batch survival probabilities could not be computed: {e}")
2970
+
2971
+ else:
2972
+ st.info("Survival model not loaded/published yet (survival bundle missing).")
2973
+
 
 
 
2974
 
2975
 
2976