Synav commited on
Commit
ec93275
·
verified ·
1 Parent(s): e0ca17a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +39 -2
src/streamlit_app.py CHANGED
@@ -463,8 +463,45 @@ if submitted:
463
  # If SELF donor case, prediction is forced to 0.0 and there is no meaningful SHAP
464
  if st.session_state.SELF:
465
  st.info("SHAP is not shown for SELF donor (prediction forced to 0).")
466
- elif model_to_use.endswith("ensemble"):
467
- st.info("SHAP for ensemble is available if your ensemble_shap() is configured. Skipping here for now.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  else:
469
  model_dict = load_model(model_to_use)
470
  model = model_dict["model"] # XGBoost model
 
463
  # If SELF donor case, prediction is forced to 0.0 and there is no meaningful SHAP
464
  if st.session_state.SELF:
465
  st.info("SHAP is not shown for SELF donor (prediction forced to 0).")
466
+ elif model_to_use.endswith("ensemble"):
467
+ ensemble_data = load_model_ensemble(model_to_use)
468
+ models = ensemble_data["model"]
469
+
470
+ # Call your ensemble SHAP helper
471
+ shap_values, shap_base, feature_names = ensemble_shap(
472
+ models=models,
473
+ X=X_model,
474
+ cat_features=cat_features,
475
+ )
476
+
477
+ # Build SHAP explanation object (single patient)
478
+ one = shap.Explanation(
479
+ values=shap_values[0],
480
+ base_values=shap_base[0] if np.ndim(shap_base) > 0 else shap_base,
481
+ data=X_model.iloc[0].values,
482
+ feature_names=feature_names,
483
+ )
484
+
485
+ vals = one.values
486
+ feats = np.array(one.feature_names)
487
+
488
+ top_idx = np.argsort(np.abs(vals))[::-1][:20]
489
+
490
+ shap_table = pd.DataFrame({
491
+ "Feature": feats[top_idx],
492
+ "Feature value": X_model.iloc[0][feats[top_idx]].values,
493
+ "SHAP value (pushes risk ↑ / ↓)": vals[top_idx],
494
+ })
495
+
496
+ st.subheader("Top features driving this patient’s prediction")
497
+ st.dataframe(shap_table, use_container_width=True)
498
+
499
+ st.subheader("Waterfall plot (ensemble)")
500
+ plt.figure(figsize=(10, 6))
501
+ shap.plots.waterfall(one, max_display=20, show=False)
502
+ st.pyplot(plt.gcf(), bbox_inches="tight")
503
+ plt.clf()
504
+ st.caption("SHAP values shown are averaged across ensemble models.")
505
  else:
506
  model_dict = load_model(model_to_use)
507
  model = model_dict["model"] # XGBoost model