Update app.py
Browse files
app.py
CHANGED
|
@@ -1388,72 +1388,69 @@ with tab_predict:
|
|
| 1388 |
st.pyplot(fig_w, clear_figure=True)
|
| 1389 |
|
| 1390 |
|
| 1391 |
-
|
|
|
|
| 1392 |
st.subheader("SHAP explanation")
|
| 1393 |
-
|
| 1394 |
with st.form("shap_form"):
|
| 1395 |
row = st.number_input("Row index", 0, len(X_inf) - 1, 0)
|
| 1396 |
explain_btn = st.form_submit_button("Generate SHAP explanation")
|
| 1397 |
-
|
| 1398 |
if explain_btn:
|
| 1399 |
X_one = X_inf.iloc[[int(row)]]
|
| 1400 |
X_one_t = transform_before_clf(pipe, X_one)
|
| 1401 |
-
|
| 1402 |
explainer = st.session_state.get("explainer")
|
| 1403 |
if explainer is None:
|
| 1404 |
st.session_state.explainer = build_shap_explainer(pipe, X_inf)
|
| 1405 |
explainer = st.session_state.explainer
|
| 1406 |
-
|
| 1407 |
shap_vals = explainer.shap_values(X_one_t)
|
| 1408 |
if isinstance(shap_vals, list):
|
| 1409 |
-
shap_vals = shap_vals[1]
|
| 1410 |
-
|
| 1411 |
names = get_final_feature_names(pipe)
|
| 1412 |
-
if len(names) !=
|
| 1413 |
st.warning(
|
| 1414 |
-
f"Feature name mismatch: names={len(names)} vs shap_cols={
|
| 1415 |
"Using generic names."
|
| 1416 |
)
|
| 1417 |
-
names = [f"f{i}" for i in range(
|
| 1418 |
-
|
| 1419 |
# Dense row vector for SHAP plots
|
| 1420 |
try:
|
| 1421 |
x_dense = X_one_t.toarray()[0]
|
| 1422 |
except Exception:
|
| 1423 |
x_dense = np.array(X_one_t)[0]
|
| 1424 |
-
|
| 1425 |
-
|
| 1426 |
-
|
| 1427 |
-
|
| 1428 |
-
|
| 1429 |
-
|
| 1430 |
-
)
|
| 1431 |
-
names = [f"f{i}" for i in range(len(shap_vals[0]))]
|
| 1432 |
-
|
| 1433 |
-
|
| 1434 |
-
try:
|
| 1435 |
-
x_dense = X_one_t.toarray()[0]
|
| 1436 |
-
except Exception:
|
| 1437 |
-
x_dense = np.array(X_one_t)[0]
|
| 1438 |
-
|
| 1439 |
exp = shap.Explanation(
|
| 1440 |
values=shap_vals[0],
|
| 1441 |
-
base_values=float(base)
|
| 1442 |
data=x_dense,
|
| 1443 |
feature_names=names,
|
| 1444 |
)
|
| 1445 |
-
|
| 1446 |
c1, c2 = st.columns(2)
|
| 1447 |
-
|
| 1448 |
with c1:
|
| 1449 |
st.markdown("**Waterfall**")
|
| 1450 |
-
fig =
|
| 1451 |
shap.plots.waterfall(exp, show=False, max_display=20)
|
| 1452 |
-
|
| 1453 |
-
|
|
|
|
|
|
|
| 1454 |
with c2:
|
| 1455 |
st.markdown("**Top features**")
|
| 1456 |
-
fig2 =
|
| 1457 |
shap.plots.bar(exp, show=False, max_display=20)
|
| 1458 |
-
|
| 1459 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1388 |
st.pyplot(fig_w, clear_figure=True)
|
| 1389 |
|
| 1390 |
|
| 1391 |
+
|
| 1392 |
+
# Single row SHAP block
|
| 1393 |
st.subheader("SHAP explanation")
|
| 1394 |
+
|
| 1395 |
with st.form("shap_form"):
|
| 1396 |
row = st.number_input("Row index", 0, len(X_inf) - 1, 0)
|
| 1397 |
explain_btn = st.form_submit_button("Generate SHAP explanation")
|
| 1398 |
+
|
| 1399 |
if explain_btn:
|
| 1400 |
X_one = X_inf.iloc[[int(row)]]
|
| 1401 |
X_one_t = transform_before_clf(pipe, X_one)
|
| 1402 |
+
|
| 1403 |
explainer = st.session_state.get("explainer")
|
| 1404 |
if explainer is None:
|
| 1405 |
st.session_state.explainer = build_shap_explainer(pipe, X_inf)
|
| 1406 |
explainer = st.session_state.explainer
|
| 1407 |
+
|
| 1408 |
shap_vals = explainer.shap_values(X_one_t)
|
| 1409 |
if isinstance(shap_vals, list):
|
| 1410 |
+
shap_vals = shap_vals[1] # positive class
|
| 1411 |
+
|
| 1412 |
names = get_final_feature_names(pipe)
|
| 1413 |
+
if len(names) != shap_vals.shape[1]:
|
| 1414 |
st.warning(
|
| 1415 |
+
f"Feature name mismatch: names={len(names)} vs shap_cols={shap_vals.shape[1]}. "
|
| 1416 |
"Using generic names."
|
| 1417 |
)
|
| 1418 |
+
names = [f"f{i}" for i in range(shap_vals.shape[1])]
|
| 1419 |
+
|
| 1420 |
# Dense row vector for SHAP plots
|
| 1421 |
try:
|
| 1422 |
x_dense = X_one_t.toarray()[0]
|
| 1423 |
except Exception:
|
| 1424 |
x_dense = np.array(X_one_t)[0]
|
| 1425 |
+
|
| 1426 |
+
# Base value
|
| 1427 |
+
base = explainer.expected_value
|
| 1428 |
+
if not np.isscalar(base):
|
| 1429 |
+
base = float(np.array(base).reshape(-1)[0])
|
| 1430 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1431 |
exp = shap.Explanation(
|
| 1432 |
values=shap_vals[0],
|
| 1433 |
+
base_values=float(base),
|
| 1434 |
data=x_dense,
|
| 1435 |
feature_names=names,
|
| 1436 |
)
|
| 1437 |
+
|
| 1438 |
c1, c2 = st.columns(2)
|
| 1439 |
+
|
| 1440 |
with c1:
|
| 1441 |
st.markdown("**Waterfall**")
|
| 1442 |
+
fig = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1443 |
shap.plots.waterfall(exp, show=False, max_display=20)
|
| 1444 |
+
render_plot_with_download(
|
| 1445 |
+
fig, title="SHAP waterfall", filename="shap_waterfall_row.png", export_dpi=export_dpi
|
| 1446 |
+
)
|
| 1447 |
+
|
| 1448 |
with c2:
|
| 1449 |
st.markdown("**Top features**")
|
| 1450 |
+
fig2 = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1451 |
shap.plots.bar(exp, show=False, max_display=20)
|
| 1452 |
+
render_plot_with_download(
|
| 1453 |
+
fig2, title="SHAP bar", filename="shap_bar_row.png", export_dpi=export_dpi
|
| 1454 |
+
)
|
| 1455 |
+
|
| 1456 |
+
st.stop()
|