Synav commited on
Commit
6fd3b24
·
verified ·
1 Parent(s): ec93275

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +7 -17
src/streamlit_app.py CHANGED
@@ -463,25 +463,14 @@ if submitted:
463
  # If SELF donor case, prediction is forced to 0.0 and there is no meaningful SHAP
464
  if st.session_state.SELF:
465
  st.info("SHAP is not shown for SELF donor (prediction forced to 0).")
466
- elif model_to_use.endswith("ensemble"):
467
  ensemble_data = load_model_ensemble(model_to_use)
468
  models = ensemble_data["model"]
469
 
470
- # Call your ensemble SHAP helper
471
- shap_values, shap_base, feature_names = ensemble_shap(
472
- models=models,
473
- X=X_model,
474
- cat_features=cat_features,
475
- )
476
-
477
- # Build SHAP explanation object (single patient)
478
- one = shap.Explanation(
479
- values=shap_values[0],
480
- base_values=shap_base[0] if np.ndim(shap_base) > 0 else shap_base,
481
- data=X_model.iloc[0].values,
482
- feature_names=feature_names,
483
- )
484
 
 
485
  vals = one.values
486
  feats = np.array(one.feature_names)
487
 
@@ -493,14 +482,15 @@ if submitted:
493
  "SHAP value (pushes risk ↑ / ↓)": vals[top_idx],
494
  })
495
 
496
- st.subheader("Top features driving this patient’s prediction")
497
  st.dataframe(shap_table, use_container_width=True)
498
 
499
- st.subheader("Waterfall plot (ensemble)")
500
  plt.figure(figsize=(10, 6))
501
  shap.plots.waterfall(one, max_display=20, show=False)
502
  st.pyplot(plt.gcf(), bbox_inches="tight")
503
  plt.clf()
 
504
  st.caption("SHAP values shown are averaged across ensemble models.")
505
  else:
506
  model_dict = load_model(model_to_use)
 
463
  # If SELF donor case, prediction is forced to 0.0 and there is no meaningful SHAP
464
  if st.session_state.SELF:
465
  st.info("SHAP is not shown for SELF donor (prediction forced to 0).")
466
+ elif model_to_use.endswith("ensemble"):
467
  ensemble_data = load_model_ensemble(model_to_use)
468
  models = ensemble_data["model"]
469
 
470
+ # ensemble_shap() RETURNS shap.Explanation (already averaged across models)
471
+ ens_expl = ensemble_shap(models=models, X=X_model, positive_class=1)
 
 
 
 
 
 
 
 
 
 
 
 
472
 
473
+ one = ens_expl[0]
474
  vals = one.values
475
  feats = np.array(one.feature_names)
476
 
 
482
  "SHAP value (pushes risk ↑ / ↓)": vals[top_idx],
483
  })
484
 
485
+ st.subheader("Top features driving this patient’s prediction (ensemble-mean)")
486
  st.dataframe(shap_table, use_container_width=True)
487
 
488
+ st.subheader("Waterfall plot (single patient, ensemble-mean)")
489
  plt.figure(figsize=(10, 6))
490
  shap.plots.waterfall(one, max_display=20, show=False)
491
  st.pyplot(plt.gcf(), bbox_inches="tight")
492
  plt.clf()
493
+
494
  st.caption("SHAP values shown are averaged across ensemble models.")
495
  else:
496
  model_dict = load_model(model_to_use)