Update app.py
Browse files
app.py
CHANGED
|
@@ -2071,7 +2071,8 @@ with tab_predict:
|
|
| 2071 |
with cA:
|
| 2072 |
do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn")
|
| 2073 |
with cB:
|
| 2074 |
-
|
|
|
|
| 2075 |
with cC:
|
| 2076 |
show_beeswarm = st.checkbox("Show beeswarm (slower)", value=True, key="batch_beeswarm")
|
| 2077 |
|
|
@@ -2186,7 +2187,7 @@ with tab_predict:
|
|
| 2186 |
features=X_dense,
|
| 2187 |
feature_names=names,
|
| 2188 |
plot_type="bar",
|
| 2189 |
-
max_display=
|
| 2190 |
show=False
|
| 2191 |
)
|
| 2192 |
fig_bar = plt.gcf()
|
|
@@ -2201,7 +2202,7 @@ with tab_predict:
|
|
| 2201 |
shap_vals_batch,
|
| 2202 |
features=X_dense,
|
| 2203 |
feature_names=names,
|
| 2204 |
-
max_display=
|
| 2205 |
show=False
|
| 2206 |
)
|
| 2207 |
fig_swarm = plt.gcf()
|
|
@@ -2217,7 +2218,8 @@ with tab_predict:
|
|
| 2217 |
default=[0],
|
| 2218 |
key="batch_rows_to_plot"
|
| 2219 |
)
|
| 2220 |
-
|
|
|
|
| 2221 |
max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls")
|
| 2222 |
rows_to_plot = rows_to_plot[:max_waterfalls]
|
| 2223 |
|
|
@@ -2240,8 +2242,10 @@ with tab_predict:
|
|
| 2240 |
fig_w, ax = plt.subplots(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 2241 |
plt.sca(ax) # important: set current axis for SHAP
|
| 2242 |
|
| 2243 |
-
|
| 2244 |
-
|
|
|
|
|
|
|
| 2245 |
render_plot_with_download(
|
| 2246 |
fig_w,
|
| 2247 |
title=f"Batch SHAP waterfall (row {r})",
|
|
|
|
| 2071 |
with cA:
|
| 2072 |
do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn")
|
| 2073 |
with cB:
|
| 2074 |
+
max_display_batch = st.slider("Top features to display (batch)",5, 40, 20, 1,key="batch_max_display")
|
| 2075 |
+
|
| 2076 |
with cC:
|
| 2077 |
show_beeswarm = st.checkbox("Show beeswarm (slower)", value=True, key="batch_beeswarm")
|
| 2078 |
|
|
|
|
| 2187 |
features=X_dense,
|
| 2188 |
feature_names=names,
|
| 2189 |
plot_type="bar",
|
| 2190 |
+
max_display=max_display_batch,
|
| 2191 |
show=False
|
| 2192 |
)
|
| 2193 |
fig_bar = plt.gcf()
|
|
|
|
| 2202 |
shap_vals_batch,
|
| 2203 |
features=X_dense,
|
| 2204 |
feature_names=names,
|
| 2205 |
+
max_display=max_display_batch,
|
| 2206 |
show=False
|
| 2207 |
)
|
| 2208 |
fig_swarm = plt.gcf()
|
|
|
|
| 2218 |
default=[0],
|
| 2219 |
key="batch_rows_to_plot"
|
| 2220 |
)
|
| 2221 |
+
max_display_single = st.slider("Top features to display (single-row SHAP)",5, 40, 20, 1,key="single_max_display")
|
| 2222 |
+
|
| 2223 |
max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls")
|
| 2224 |
rows_to_plot = rows_to_plot[:max_waterfalls]
|
| 2225 |
|
|
|
|
| 2242 |
fig_w, ax = plt.subplots(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 2243 |
plt.sca(ax) # important: set current axis for SHAP
|
| 2244 |
|
| 2245 |
+
|
| 2246 |
+
shap.plots.waterfall(exp, show=False, max_display=max_display_single)
|
| 2247 |
+
shap.plots.bar(exp, show=False, max_display=max_display_single)
|
| 2248 |
+
|
| 2249 |
render_plot_with_download(
|
| 2250 |
fig_w,
|
| 2251 |
title=f"Batch SHAP waterfall (row {r})",
|