mfarnas commited on
Commit
76f9b18
·
1 Parent(s): 4da4fcb

add shap for bulk preds

Browse files
Files changed (1) hide show
  1. src/pages/2_Bulk_Predictions.py +92 -4
src/pages/2_Bulk_Predictions.py CHANGED
@@ -1,11 +1,17 @@
1
  import streamlit as st
2
  import pandas as pd
 
 
3
  from model_utils import load_model, load_model_ensemble, ensemble_predict
4
  from preprocess_utils import load_train_features
5
  from preprocess_utils import preprocess_pipeline as preprocess
6
- from inference_utils import add_predictions, compute_metrics
7
  from sidebar import sidebar
8
 
 
 
 
 
9
  # Initialize sidebar
10
  sidebar()
11
 
@@ -31,7 +37,7 @@ if "selected_model" in st.session_state:
31
  target_col = model_dict.get("target_col", "UNKNOWN")
32
 
33
  st.session_state.target_col = target_col
34
- st.warning(f"The model selected will only predict the target \"{target_col}\". Please choose a different model if you want to predict a different target.")
35
 
36
  st.title("📊 Bulk Patient Predictions")
37
 
@@ -51,7 +57,7 @@ if uploaded_file:
51
 
52
  if st.button("Predict"):
53
  if "bulk_input_df" not in st.session_state:
54
- st.warning("Please preprocess data first.")
55
  else:
56
  df = st.session_state.bulk_input_df
57
 
@@ -74,10 +80,22 @@ if uploaded_file:
74
  if ensemble:
75
  preds = ensemble_predict(models, X, cat_features)
76
  metrics_result_ensemble = compute_metrics(y, preds)
 
77
  else:
78
- # single model prediction
79
  preds = model.predict_proba(X)[:, 1]
80
  metrics_result_single = compute_metrics(y, preds)
 
 
 
 
 
 
 
 
 
 
 
 
81
 
82
  st.session_state.targets_df = y
83
  styled = add_predictions(X.copy(), preds)
@@ -93,6 +111,76 @@ if uploaded_file:
93
  for metric, value in metrics_result_ensemble.items():
94
  st.write(f" **{metric}**: {value:.3f}")
95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  # Find difference in columns between uploaded data and training features
97
  missing_features = set(st.session_state.orig_train_cols).union(train_features) - set(df.columns)
98
  missing_features = set([i if i[-2:] != "_X" else '' for i in missing_features])
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
  from model_utils import load_model, load_model_ensemble, ensemble_predict
6
  from preprocess_utils import load_train_features
7
  from preprocess_utils import preprocess_pipeline as preprocess
8
+ from inference_utils import add_predictions, compute_metrics, st_shap, ensemble_shap
9
  from sidebar import sidebar
10
 
11
+ import shap
12
+ import lime
13
+ import lime.lime_tabular
14
+
15
  # Initialize sidebar
16
  sidebar()
17
 
 
37
  target_col = model_dict.get("target_col", "UNKNOWN")
38
 
39
  st.session_state.target_col = target_col
40
+ st.warning(f"The model selected will only predict the target \"{target_col}\". Please choose a different model if you wish to predict a different target.")
41
 
42
  st.title("📊 Bulk Patient Predictions")
43
 
 
57
 
58
  if st.button("Predict"):
59
  if "bulk_input_df" not in st.session_state:
60
+ st.warning("Please preprocess the data first.")
61
  else:
62
  df = st.session_state.bulk_input_df
63
 
 
80
  if ensemble:
81
  preds = ensemble_predict(models, X, cat_features)
82
  metrics_result_ensemble = compute_metrics(y, preds)
83
+ shap_values = ensemble_shap(models, X)
84
  else:
 
85
  preds = model.predict_proba(X)[:, 1]
86
  metrics_result_single = compute_metrics(y, preds)
87
+
88
+ explainer = shap.TreeExplainer(model)
89
+ shap_values = explainer(X)
90
+
91
+ # Handle multi-class (use class 1)
92
+ if shap_values.values.ndim == 3:
93
+ shap_values = shap.Explanation(
94
+ values=shap_values.values[:, :, 1],
95
+ base_values=shap_values.base_values[:, 1] if shap_values.base_values.ndim == 2 else shap_values.base_values,
96
+ data=X,
97
+ feature_names=X.columns
98
+ )
99
 
100
  st.session_state.targets_df = y
101
  styled = add_predictions(X.copy(), preds)
 
111
  for metric, value in metrics_result_ensemble.items():
112
  st.write(f" **{metric}**: {value:.3f}")
113
 
114
+ def get_top_features(shap_values_array, feature_names, n=20):
115
+ import numpy as np
116
+ import shap
117
+
118
+ # If a shap.Explanation was passed, extract .values
119
+ if isinstance(shap_values_array, shap.Explanation):
120
+ shap_values_array = shap_values_array.values
121
+
122
+ mean_abs_shap = np.abs(shap_values_array).mean(0)
123
+ feature_importance = pd.DataFrame({
124
+ 'feature': feature_names,
125
+ 'importance': mean_abs_shap
126
+ })
127
+ return feature_importance.sort_values('importance', ascending=False)['feature'].tolist()[:n]
128
+
129
+
130
+ with st.expander("Show SHAP Explainability", expanded=True):
131
+ # Get top 20 features
132
+ top_features = get_top_features(shap_values, X.columns)
133
+
134
+ # Feature selection widget
135
+ selected_features = st.multiselect(
136
+ "Select features to display in plots",
137
+ options=list(X.columns),
138
+ default=top_features
139
+ )
140
+
141
+ if not selected_features:
142
+ st.warning("Please select at least one feature to display")
143
+ else:
144
+ # Filter data for selected features
145
+ X_selected = X[selected_features]
146
+ feature_indices = [list(X.columns).index(f) for f in selected_features]
147
+
148
+ # Slice features directly from the SHAP Explanation
149
+ shap_values_selected = shap_values[:, feature_indices]
150
+ shap_values_selected.feature_names = selected_features
151
+ shap_values_selected.data = X_selected
152
+
153
+ # ---- Beeswarm: overall feature impact ----
154
+ st.subheader("SHAP Feature Importance")
155
+ plt.figure(figsize=(10, 6))
156
+ shap.plots.beeswarm(shap_values_selected, max_display=20, show=False)
157
+ st.pyplot(plt.gcf(), bbox_inches='tight')
158
+ plt.clf()
159
+
160
+ # ---- Mean absolute SHAP bar chart ----
161
+ st.subheader("Mean(|SHAP value|) per Feature")
162
+ plt.figure(figsize=(10, 6))
163
+ shap.plots.bar(shap_values_selected, max_display=20, show=False)
164
+ st.pyplot(plt.gcf(), bbox_inches='tight')
165
+ plt.clf()
166
+
167
+ # ---- Dependence plot ----
168
+ st.subheader("SHAP Dependence Plot")
169
+ feature = st.selectbox("Select main feature", selected_features)
170
+ interaction_feature = st.selectbox(
171
+ "Select interaction feature (optional)",
172
+ ["None"] + selected_features
173
+ )
174
+
175
+ plt.figure(figsize=(10, 6))
176
+ if interaction_feature == "None":
177
+ shap.dependence_plot(feature, shap_values_selected.values, X_selected, show=False)
178
+ else:
179
+ shap.dependence_plot(feature, shap_values_selected.values, X_selected, interaction_index=interaction_feature, show=False)
180
+
181
+ st.pyplot(plt.gcf(), bbox_inches='tight')
182
+ plt.clf()
183
+
184
  # Find difference in columns between uploaded data and training features
185
  missing_features = set(st.session_state.orig_train_cols).union(train_features) - set(df.columns)
186
  missing_features = set([i if i[-2:] != "_X" else '' for i in missing_features])