Synav commited on
Commit
2f23595
·
verified ·
1 Parent(s): fd32ea6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -35
app.py CHANGED
@@ -1388,72 +1388,69 @@ with tab_predict:
1388
  st.pyplot(fig_w, clear_figure=True)
1389
 
1390
 
1391
- #Single row SHAP block
 
1392
  st.subheader("SHAP explanation")
1393
-
1394
  with st.form("shap_form"):
1395
  row = st.number_input("Row index", 0, len(X_inf) - 1, 0)
1396
  explain_btn = st.form_submit_button("Generate SHAP explanation")
1397
-
1398
  if explain_btn:
1399
  X_one = X_inf.iloc[[int(row)]]
1400
  X_one_t = transform_before_clf(pipe, X_one)
1401
-
1402
  explainer = st.session_state.get("explainer")
1403
  if explainer is None:
1404
  st.session_state.explainer = build_shap_explainer(pipe, X_inf)
1405
  explainer = st.session_state.explainer
1406
-
1407
  shap_vals = explainer.shap_values(X_one_t)
1408
  if isinstance(shap_vals, list):
1409
- shap_vals = shap_vals[1]
1410
-
1411
  names = get_final_feature_names(pipe)
1412
- if len(names) != len(shap_vals[0]):
1413
  st.warning(
1414
- f"Feature name mismatch: names={len(names)} vs shap_cols={len(shap_vals[0])}. "
1415
  "Using generic names."
1416
  )
1417
- names = [f"f{i}" for i in range(len(shap_vals[0]))]
1418
-
1419
  # Dense row vector for SHAP plots
1420
  try:
1421
  x_dense = X_one_t.toarray()[0]
1422
  except Exception:
1423
  x_dense = np.array(X_one_t)[0]
1424
-
1425
-
1426
- if len(names) != len(shap_vals[0]):
1427
- st.warning(
1428
- f"Feature name mismatch: names={len(names)} vs shap_cols={len(shap_vals[0])}. "
1429
- "Falling back to generic names."
1430
- )
1431
- names = [f"f{i}" for i in range(len(shap_vals[0]))]
1432
-
1433
-
1434
- try:
1435
- x_dense = X_one_t.toarray()[0]
1436
- except Exception:
1437
- x_dense = np.array(X_one_t)[0]
1438
-
1439
  exp = shap.Explanation(
1440
  values=shap_vals[0],
1441
- base_values=float(base) if np.isscalar(base) else float(np.array(base).reshape(-1)[0]),
1442
  data=x_dense,
1443
  feature_names=names,
1444
  )
1445
-
1446
  c1, c2 = st.columns(2)
1447
-
1448
  with c1:
1449
  st.markdown("**Waterfall**")
1450
- fig = plt.figure()
1451
  shap.plots.waterfall(exp, show=False, max_display=20)
1452
- st.pyplot(fig, clear_figure=True)
1453
-
 
 
1454
  with c2:
1455
  st.markdown("**Top features**")
1456
- fig2 = plt.figure()
1457
  shap.plots.bar(exp, show=False, max_display=20)
1458
- st.pyplot(fig2, clear_figure=True)
1459
- st.stop()
 
 
 
 
1388
  st.pyplot(fig_w, clear_figure=True)
1389
 
1390
 
1391
+
1392
+ # Single row SHAP block
1393
  st.subheader("SHAP explanation")
1394
+
1395
  with st.form("shap_form"):
1396
  row = st.number_input("Row index", 0, len(X_inf) - 1, 0)
1397
  explain_btn = st.form_submit_button("Generate SHAP explanation")
1398
+
1399
  if explain_btn:
1400
  X_one = X_inf.iloc[[int(row)]]
1401
  X_one_t = transform_before_clf(pipe, X_one)
1402
+
1403
  explainer = st.session_state.get("explainer")
1404
  if explainer is None:
1405
  st.session_state.explainer = build_shap_explainer(pipe, X_inf)
1406
  explainer = st.session_state.explainer
1407
+
1408
  shap_vals = explainer.shap_values(X_one_t)
1409
  if isinstance(shap_vals, list):
1410
+ shap_vals = shap_vals[1] # positive class
1411
+
1412
  names = get_final_feature_names(pipe)
1413
+ if len(names) != shap_vals.shape[1]:
1414
  st.warning(
1415
+ f"Feature name mismatch: names={len(names)} vs shap_cols={shap_vals.shape[1]}. "
1416
  "Using generic names."
1417
  )
1418
+ names = [f"f{i}" for i in range(shap_vals.shape[1])]
1419
+
1420
  # Dense row vector for SHAP plots
1421
  try:
1422
  x_dense = X_one_t.toarray()[0]
1423
  except Exception:
1424
  x_dense = np.array(X_one_t)[0]
1425
+
1426
+ # Base value
1427
+ base = explainer.expected_value
1428
+ if not np.isscalar(base):
1429
+ base = float(np.array(base).reshape(-1)[0])
1430
+
 
 
 
 
 
 
 
 
 
1431
  exp = shap.Explanation(
1432
  values=shap_vals[0],
1433
+ base_values=float(base),
1434
  data=x_dense,
1435
  feature_names=names,
1436
  )
1437
+
1438
  c1, c2 = st.columns(2)
1439
+
1440
  with c1:
1441
  st.markdown("**Waterfall**")
1442
+ fig = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1443
  shap.plots.waterfall(exp, show=False, max_display=20)
1444
+ render_plot_with_download(
1445
+ fig, title="SHAP waterfall", filename="shap_waterfall_row.png", export_dpi=export_dpi
1446
+ )
1447
+
1448
  with c2:
1449
  st.markdown("**Top features**")
1450
+ fig2 = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1451
  shap.plots.bar(exp, show=False, max_display=20)
1452
+ render_plot_with_download(
1453
+ fig2, title="SHAP bar", filename="shap_bar_row.png", export_dpi=export_dpi
1454
+ )
1455
+
1456
+ st.stop()