Synav commited on
Commit
18e34dc
·
verified ·
1 Parent(s): bba2c69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -2071,7 +2071,8 @@ with tab_predict:
2071
  with cA:
2072
  do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn")
2073
  with cB:
2074
- max_display = st.slider("Top features to display", 5, 40, 20, 1, key="batch_max_display")
 
2075
  with cC:
2076
  show_beeswarm = st.checkbox("Show beeswarm (slower)", value=True, key="batch_beeswarm")
2077
 
@@ -2186,7 +2187,7 @@ with tab_predict:
2186
  features=X_dense,
2187
  feature_names=names,
2188
  plot_type="bar",
2189
- max_display=max_display,
2190
  show=False
2191
  )
2192
  fig_bar = plt.gcf()
@@ -2201,7 +2202,7 @@ with tab_predict:
2201
  shap_vals_batch,
2202
  features=X_dense,
2203
  feature_names=names,
2204
- max_display=max_display,
2205
  show=False
2206
  )
2207
  fig_swarm = plt.gcf()
@@ -2217,7 +2218,8 @@ with tab_predict:
2217
  default=[0],
2218
  key="batch_rows_to_plot"
2219
  )
2220
-
 
2221
  max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls")
2222
  rows_to_plot = rows_to_plot[:max_waterfalls]
2223
 
@@ -2240,8 +2242,10 @@ with tab_predict:
2240
  fig_w, ax = plt.subplots(figsize=FIGSIZE, dpi=plot_dpi_screen)
2241
  plt.sca(ax) # important: set current axis for SHAP
2242
 
2243
- shap.plots.waterfall(exp, show=False, max_display=max_display)
2244
-
 
 
2245
  render_plot_with_download(
2246
  fig_w,
2247
  title=f"Batch SHAP waterfall (row {r})",
 
2071
  with cA:
2072
  do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn")
2073
  with cB:
2074
+ max_display_batch = st.slider("Top features to display (batch)",5, 40, 20, 1,key="batch_max_display")
2075
+
2076
  with cC:
2077
  show_beeswarm = st.checkbox("Show beeswarm (slower)", value=True, key="batch_beeswarm")
2078
 
 
2187
  features=X_dense,
2188
  feature_names=names,
2189
  plot_type="bar",
2190
+ max_display=max_display_batch,
2191
  show=False
2192
  )
2193
  fig_bar = plt.gcf()
 
2202
  shap_vals_batch,
2203
  features=X_dense,
2204
  feature_names=names,
2205
+ max_display=max_display_batch,
2206
  show=False
2207
  )
2208
  fig_swarm = plt.gcf()
 
2218
  default=[0],
2219
  key="batch_rows_to_plot"
2220
  )
2221
+ max_display_single = st.slider("Top features to display (single-row SHAP)",5, 40, 20, 1,key="single_max_display")
2222
+
2223
  max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls")
2224
  rows_to_plot = rows_to_plot[:max_waterfalls]
2225
 
 
2242
  fig_w, ax = plt.subplots(figsize=FIGSIZE, dpi=plot_dpi_screen)
2243
  plt.sca(ax) # important: set current axis for SHAP
2244
 
2245
+
2246
+ shap.plots.waterfall(exp, show=False, max_display=max_display_single)
2247
+ shap.plots.bar(exp, show=False, max_display=max_display_single)
2248
+
2249
  render_plot_with_download(
2250
  fig_w,
2251
  title=f"Batch SHAP waterfall (row {r})",