Update app.py
Browse files
app.py
CHANGED
|
@@ -915,6 +915,7 @@ with tab_predict:
|
|
| 915 |
)
|
| 916 |
|
| 917 |
#SHAP BLOCK
|
|
|
|
| 918 |
st.divider()
|
| 919 |
st.subheader("Batch SHAP (first 200 rows)")
|
| 920 |
|
|
@@ -922,61 +923,116 @@ with tab_predict:
|
|
| 922 |
n_rows = len(X_inf)
|
| 923 |
batch_n = min(MAX_BATCH, n_rows)
|
| 924 |
|
| 925 |
-
cA, cB = st.columns([1, 1])
|
| 926 |
with cA:
|
| 927 |
do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn")
|
| 928 |
with cB:
|
| 929 |
max_display = st.slider("Top features to display", 5, 40, 20, 1, key="batch_max_display")
|
|
|
|
|
|
|
| 930 |
|
| 931 |
if do_batch:
|
| 932 |
with st.spinner("Computing batch SHAP..."):
|
| 933 |
pre = pipe.named_steps["preprocess"]
|
| 934 |
|
| 935 |
-
# Use first N rows (fast + predictable memory)
|
| 936 |
X_batch = X_inf.iloc[:batch_n].copy()
|
| 937 |
X_batch_t = pre.transform(X_batch)
|
| 938 |
|
| 939 |
-
#
|
| 940 |
-
|
|
|
|
| 941 |
st.session_state.explainer = build_shap_explainer(pipe, X_inf)
|
|
|
|
| 942 |
|
| 943 |
-
|
| 944 |
-
|
| 945 |
-
|
| 946 |
-
if isinstance(shap_vals, list):
|
| 947 |
-
shap_vals = shap_vals[1] # positive class
|
| 948 |
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
|
| 953 |
|
|
|
|
| 954 |
try:
|
| 955 |
-
|
| 956 |
except Exception:
|
| 957 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 958 |
|
| 959 |
st.success(f"Batch SHAP computed for first {batch_n} rows.")
|
| 960 |
-
|
| 961 |
-
|
| 962 |
if "shap_batch_vals" in st.session_state:
|
| 963 |
-
|
| 964 |
-
|
| 965 |
batch_n = st.session_state.shap_batch_n
|
| 966 |
names = st.session_state.shap_batch_feature_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 967 |
|
| 968 |
-
st.markdown("### Global SHAP summary (first {} rows)"
|
| 969 |
-
|
| 970 |
-
# Convert X to dense only once if needed (beeswarm often needs dense)
|
| 971 |
-
try:
|
| 972 |
-
X_dense = X_batch_t.toarray()
|
| 973 |
-
except Exception:
|
| 974 |
-
X_dense = np.array(X_batch_t)
|
| 975 |
|
| 976 |
# BAR SUMMARY
|
| 977 |
fig_bar = plt.figure()
|
| 978 |
shap.summary_plot(
|
| 979 |
-
|
| 980 |
features=X_dense,
|
| 981 |
feature_names=names,
|
| 982 |
plot_type="bar",
|
|
@@ -985,22 +1041,17 @@ with tab_predict:
|
|
| 985 |
)
|
| 986 |
st.pyplot(fig_bar, clear_figure=True)
|
| 987 |
|
| 988 |
-
# BEESWARM SUMMARY
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
if "shap_batch_vals" in st.session_state:
|
| 1000 |
-
shap_vals = st.session_state.shap_batch_vals
|
| 1001 |
-
X_batch_t = st.session_state.shap_batch_Xt
|
| 1002 |
-
batch_n = st.session_state.shap_batch_n
|
| 1003 |
-
names = st.session_state.shap_batch_feature_names
|
| 1004 |
|
| 1005 |
st.markdown("### Waterfall plots (batch)")
|
| 1006 |
|
|
@@ -1014,20 +1065,15 @@ with tab_predict:
|
|
| 1014 |
max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls")
|
| 1015 |
rows_to_plot = rows_to_plot[:max_waterfalls]
|
| 1016 |
|
| 1017 |
-
|
|
|
|
| 1018 |
if not np.isscalar(base):
|
| 1019 |
base = float(np.array(base).reshape(-1)[0])
|
| 1020 |
|
| 1021 |
-
# dense only if needed for data in Explanation
|
| 1022 |
-
try:
|
| 1023 |
-
X_dense = X_batch_t.toarray()
|
| 1024 |
-
except Exception:
|
| 1025 |
-
X_dense = np.array(X_batch_t)
|
| 1026 |
-
|
| 1027 |
for r in rows_to_plot:
|
| 1028 |
st.markdown(f"**Row {r} (within first {batch_n})**")
|
| 1029 |
exp = shap.Explanation(
|
| 1030 |
-
values=
|
| 1031 |
base_values=float(base),
|
| 1032 |
data=X_dense[r],
|
| 1033 |
feature_names=names,
|
|
@@ -1036,7 +1082,8 @@ with tab_predict:
|
|
| 1036 |
shap.plots.waterfall(exp, show=False, max_display=max_display)
|
| 1037 |
st.pyplot(fig_w, clear_figure=True)
|
| 1038 |
|
| 1039 |
-
|
|
|
|
| 1040 |
st.subheader("SHAP explanation")
|
| 1041 |
|
| 1042 |
with st.form("shap_form"):
|
|
|
|
| 915 |
)
|
| 916 |
|
| 917 |
#SHAP BLOCK
|
| 918 |
+
|
| 919 |
st.divider()
|
| 920 |
st.subheader("Batch SHAP (first 200 rows)")
|
| 921 |
|
|
|
|
| 923 |
n_rows = len(X_inf)
|
| 924 |
batch_n = min(MAX_BATCH, n_rows)
|
| 925 |
|
| 926 |
+
cA, cB, cC = st.columns([1, 1, 1])
|
| 927 |
with cA:
|
| 928 |
do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn")
|
| 929 |
with cB:
|
| 930 |
max_display = st.slider("Top features to display", 5, 40, 20, 1, key="batch_max_display")
|
| 931 |
+
with cC:
|
| 932 |
+
show_beeswarm = st.checkbox("Show beeswarm (slower)", value=True, key="batch_beeswarm")
|
| 933 |
|
| 934 |
if do_batch:
|
| 935 |
with st.spinner("Computing batch SHAP..."):
|
| 936 |
pre = pipe.named_steps["preprocess"]
|
| 937 |
|
|
|
|
| 938 |
X_batch = X_inf.iloc[:batch_n].copy()
|
| 939 |
X_batch_t = pre.transform(X_batch)
|
| 940 |
|
| 941 |
+
# Ensure explainer exists
|
| 942 |
+
explainer = st.session_state.get("explainer")
|
| 943 |
+
if explainer is None:
|
| 944 |
st.session_state.explainer = build_shap_explainer(pipe, X_inf)
|
| 945 |
+
explainer = st.session_state.explainer
|
| 946 |
|
| 947 |
+
shap_vals_batch = explainer.shap_values(X_batch_t)
|
| 948 |
+
if isinstance(shap_vals_batch, list):
|
| 949 |
+
shap_vals_batch = shap_vals_batch[1] # positive class
|
|
|
|
|
|
|
| 950 |
|
| 951 |
+
try:
|
| 952 |
+
names = list(pre.get_feature_names_out())
|
| 953 |
+
except Exception:
|
| 954 |
+
names = [f"f{i}" for i in range(shap_vals_batch.shape[1])]
|
| 955 |
|
| 956 |
+
# Dense conversion once (used for summary + waterfalls)
|
| 957 |
try:
|
| 958 |
+
X_dense = X_batch_t.toarray()
|
| 959 |
except Exception:
|
| 960 |
+
X_dense = np.array(X_batch_t)
|
| 961 |
+
|
| 962 |
+
# Cache batch results
|
| 963 |
+
st.session_state.shap_batch_vals = shap_vals_batch
|
| 964 |
+
st.session_state.shap_batch_X_dense = X_dense
|
| 965 |
+
st.session_state.shap_batch_n = batch_n
|
| 966 |
+
st.session_state.shap_batch_feature_names = names
|
| 967 |
|
| 968 |
st.success(f"Batch SHAP computed for first {batch_n} rows.")
|
| 969 |
+
|
|
|
|
| 970 |
if "shap_batch_vals" in st.session_state:
|
| 971 |
+
shap_vals_batch = st.session_state.shap_batch_vals
|
| 972 |
+
X_dense = st.session_state.shap_batch_X_dense
|
| 973 |
batch_n = st.session_state.shap_batch_n
|
| 974 |
names = st.session_state.shap_batch_feature_names
|
| 975 |
+
|
| 976 |
+
st.divider()
|
| 977 |
+
st.subheader("Export: Top SHAP features per row (batch)")
|
| 978 |
+
|
| 979 |
+
top_k = st.slider("Top-K features per row", 3, 30, 10, 1, key="topk_export")
|
| 980 |
+
|
| 981 |
+
# Optional: include predicted probabilities for the same batch rows
|
| 982 |
+
# (Assumes you already computed proba for all X_inf earlier)
|
| 983 |
+
include_proba = st.checkbox("Include predicted probability", value=True, key="include_proba_export")
|
| 984 |
+
|
| 985 |
+
if st.button("Generate Top-K SHAP table", key="gen_topk_shap"):
|
| 986 |
+
shap_vals_batch = st.session_state.shap_batch_vals # shape: (batch_n, n_features)
|
| 987 |
+
names = st.session_state.shap_batch_feature_names
|
| 988 |
+
batch_n = st.session_state.shap_batch_n
|
| 989 |
+
|
| 990 |
+
rows = []
|
| 991 |
+
for i in range(batch_n):
|
| 992 |
+
sv = shap_vals_batch[i]
|
| 993 |
+
idx = np.argsort(np.abs(sv))[::-1][:top_k] # top-k by absolute SHAP
|
| 994 |
+
|
| 995 |
+
for j in idx:
|
| 996 |
+
val = float(sv[j])
|
| 997 |
+
rows.append({
|
| 998 |
+
"row_in_batch": int(i),
|
| 999 |
+
"feature": str(names[j]),
|
| 1000 |
+
"shap_value": val,
|
| 1001 |
+
"abs_shap_value": abs(val),
|
| 1002 |
+
"direction": "↑" if val > 0 else ("↓" if val < 0 else "0"),
|
| 1003 |
+
})
|
| 1004 |
+
|
| 1005 |
+
df_topk = pd.DataFrame(rows)
|
| 1006 |
+
|
| 1007 |
+
if include_proba:
|
| 1008 |
+
# Use the same batch rows from the previously computed proba vector
|
| 1009 |
+
# If you want absolute Excel row index, add + df_inf.index[0] logic as needed
|
| 1010 |
+
proba_batch = proba[:batch_n]
|
| 1011 |
+
df_proba = pd.DataFrame({"row_in_batch": list(range(batch_n)), "predicted_probability": proba_batch})
|
| 1012 |
+
df_topk = df_topk.merge(df_proba, on="row_in_batch", how="left")
|
| 1013 |
+
|
| 1014 |
+
# Sort nicely: each row block by importance
|
| 1015 |
+
df_topk = df_topk.sort_values(["row_in_batch", "abs_shap_value"], ascending=[True, False])
|
| 1016 |
+
|
| 1017 |
+
st.dataframe(df_topk, use_container_width=True)
|
| 1018 |
+
|
| 1019 |
+
st.download_button(
|
| 1020 |
+
"Download Top-K SHAP per row (CSV)",
|
| 1021 |
+
df_topk.to_csv(index=False).encode("utf-8"),
|
| 1022 |
+
file_name=f"shap_top{top_k}_per_row_first{batch_n}.csv",
|
| 1023 |
+
mime="text/csv",
|
| 1024 |
+
key="dl_topk_shap_csv"
|
| 1025 |
+
)
|
| 1026 |
+
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
|
| 1030 |
+
st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1031 |
|
| 1032 |
# BAR SUMMARY
|
| 1033 |
fig_bar = plt.figure()
|
| 1034 |
shap.summary_plot(
|
| 1035 |
+
shap_vals_batch,
|
| 1036 |
features=X_dense,
|
| 1037 |
feature_names=names,
|
| 1038 |
plot_type="bar",
|
|
|
|
| 1041 |
)
|
| 1042 |
st.pyplot(fig_bar, clear_figure=True)
|
| 1043 |
|
| 1044 |
+
# BEESWARM SUMMARY (optional)
|
| 1045 |
+
if show_beeswarm:
|
| 1046 |
+
fig_swarm = plt.figure()
|
| 1047 |
+
shap.summary_plot(
|
| 1048 |
+
shap_vals_batch,
|
| 1049 |
+
features=X_dense,
|
| 1050 |
+
feature_names=names,
|
| 1051 |
+
max_display=max_display,
|
| 1052 |
+
show=False,
|
| 1053 |
+
)
|
| 1054 |
+
st.pyplot(fig_swarm, clear_figure=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1055 |
|
| 1056 |
st.markdown("### Waterfall plots (batch)")
|
| 1057 |
|
|
|
|
| 1065 |
max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls")
|
| 1066 |
rows_to_plot = rows_to_plot[:max_waterfalls]
|
| 1067 |
|
| 1068 |
+
explainer = st.session_state.get("explainer")
|
| 1069 |
+
base = explainer.expected_value
|
| 1070 |
if not np.isscalar(base):
|
| 1071 |
base = float(np.array(base).reshape(-1)[0])
|
| 1072 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1073 |
for r in rows_to_plot:
|
| 1074 |
st.markdown(f"**Row {r} (within first {batch_n})**")
|
| 1075 |
exp = shap.Explanation(
|
| 1076 |
+
values=shap_vals_batch[r],
|
| 1077 |
base_values=float(base),
|
| 1078 |
data=X_dense[r],
|
| 1079 |
feature_names=names,
|
|
|
|
| 1082 |
shap.plots.waterfall(exp, show=False, max_display=max_display)
|
| 1083 |
st.pyplot(fig_w, clear_figure=True)
|
| 1084 |
|
| 1085 |
+
|
| 1086 |
+
#Single row SHAP block
|
| 1087 |
st.subheader("SHAP explanation")
|
| 1088 |
|
| 1089 |
with st.form("shap_form"):
|