Synav committed on
Commit
66e1586
·
verified ·
1 Parent(s): 1668c7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -52
app.py CHANGED
@@ -915,6 +915,7 @@ with tab_predict:
915
  )
916
 
917
  #SHAP BLOCK
 
918
  st.divider()
919
  st.subheader("Batch SHAP (first 200 rows)")
920
 
@@ -922,61 +923,116 @@ with tab_predict:
922
  n_rows = len(X_inf)
923
  batch_n = min(MAX_BATCH, n_rows)
924
 
925
- cA, cB = st.columns([1, 1])
926
  with cA:
927
  do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn")
928
  with cB:
929
  max_display = st.slider("Top features to display", 5, 40, 20, 1, key="batch_max_display")
 
 
930
 
931
  if do_batch:
932
  with st.spinner("Computing batch SHAP..."):
933
  pre = pipe.named_steps["preprocess"]
934
 
935
- # Use first N rows (fast + predictable memory)
936
  X_batch = X_inf.iloc[:batch_n].copy()
937
  X_batch_t = pre.transform(X_batch)
938
 
939
- # Build explainer once (cached)
940
- if st.session_state.get("explainer") is None:
 
941
  st.session_state.explainer = build_shap_explainer(pipe, X_inf)
 
942
 
943
- explainer = st.session_state.explainer
944
-
945
- shap_vals = explainer.shap_values(X_batch_t)
946
- if isinstance(shap_vals, list):
947
- shap_vals = shap_vals[1] # positive class
948
 
949
- # Cache batch results
950
- st.session_state.shap_batch_vals = shap_vals
951
- st.session_state.shap_batch_Xt = X_batch_t
952
- st.session_state.shap_batch_n = batch_n
953
 
 
954
  try:
955
- st.session_state.shap_batch_feature_names = list(pre.get_feature_names_out())
956
  except Exception:
957
- st.session_state.shap_batch_feature_names = [f"f{i}" for i in range(shap_vals.shape[1])]
 
 
 
 
 
 
958
 
959
  st.success(f"Batch SHAP computed for first {batch_n} rows.")
960
-
961
-
962
  if "shap_batch_vals" in st.session_state:
963
- shap_vals = st.session_state.shap_batch_vals
964
- X_batch_t = st.session_state.shap_batch_Xt
965
  batch_n = st.session_state.shap_batch_n
966
  names = st.session_state.shap_batch_feature_names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
967
 
968
- st.markdown("### Global SHAP summary (first {} rows)".format(batch_n))
969
-
970
- # Convert X to dense only once if needed (beeswarm often needs dense)
971
- try:
972
- X_dense = X_batch_t.toarray()
973
- except Exception:
974
- X_dense = np.array(X_batch_t)
975
 
976
  # BAR SUMMARY
977
  fig_bar = plt.figure()
978
  shap.summary_plot(
979
- shap_vals,
980
  features=X_dense,
981
  feature_names=names,
982
  plot_type="bar",
@@ -985,22 +1041,17 @@ with tab_predict:
985
  )
986
  st.pyplot(fig_bar, clear_figure=True)
987
 
988
- # BEESWARM SUMMARY
989
- fig_swarm = plt.figure()
990
- shap.summary_plot(
991
- shap_vals,
992
- features=X_dense,
993
- feature_names=names,
994
- max_display=max_display,
995
- show=False,
996
- )
997
- st.pyplot(fig_swarm, clear_figure=True)
998
-
999
- if "shap_batch_vals" in st.session_state:
1000
- shap_vals = st.session_state.shap_batch_vals
1001
- X_batch_t = st.session_state.shap_batch_Xt
1002
- batch_n = st.session_state.shap_batch_n
1003
- names = st.session_state.shap_batch_feature_names
1004
 
1005
  st.markdown("### Waterfall plots (batch)")
1006
 
@@ -1014,20 +1065,15 @@ with tab_predict:
1014
  max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls")
1015
  rows_to_plot = rows_to_plot[:max_waterfalls]
1016
 
1017
- base = st.session_state.explainer.expected_value
 
1018
  if not np.isscalar(base):
1019
  base = float(np.array(base).reshape(-1)[0])
1020
 
1021
- # dense only if needed for data in Explanation
1022
- try:
1023
- X_dense = X_batch_t.toarray()
1024
- except Exception:
1025
- X_dense = np.array(X_batch_t)
1026
-
1027
  for r in rows_to_plot:
1028
  st.markdown(f"**Row {r} (within first {batch_n})**")
1029
  exp = shap.Explanation(
1030
- values=shap_vals[r],
1031
  base_values=float(base),
1032
  data=X_dense[r],
1033
  feature_names=names,
@@ -1036,7 +1082,8 @@ with tab_predict:
1036
  shap.plots.waterfall(exp, show=False, max_display=max_display)
1037
  st.pyplot(fig_w, clear_figure=True)
1038
 
1039
-
 
1040
  st.subheader("SHAP explanation")
1041
 
1042
  with st.form("shap_form"):
 
915
  )
916
 
917
  #SHAP BLOCK
918
+
919
  st.divider()
920
  st.subheader("Batch SHAP (first 200 rows)")
921
 
 
923
  n_rows = len(X_inf)
924
  batch_n = min(MAX_BATCH, n_rows)
925
 
926
+ cA, cB, cC = st.columns([1, 1, 1])
927
  with cA:
928
  do_batch = st.button(f"Compute batch SHAP for first {batch_n} rows", key="batch_shap_btn")
929
  with cB:
930
  max_display = st.slider("Top features to display", 5, 40, 20, 1, key="batch_max_display")
931
+ with cC:
932
+ show_beeswarm = st.checkbox("Show beeswarm (slower)", value=True, key="batch_beeswarm")
933
 
934
  if do_batch:
935
  with st.spinner("Computing batch SHAP..."):
936
  pre = pipe.named_steps["preprocess"]
937
 
 
938
  X_batch = X_inf.iloc[:batch_n].copy()
939
  X_batch_t = pre.transform(X_batch)
940
 
941
+ # Ensure explainer exists
942
+ explainer = st.session_state.get("explainer")
943
+ if explainer is None:
944
  st.session_state.explainer = build_shap_explainer(pipe, X_inf)
945
+ explainer = st.session_state.explainer
946
 
947
+ shap_vals_batch = explainer.shap_values(X_batch_t)
948
+ if isinstance(shap_vals_batch, list):
949
+ shap_vals_batch = shap_vals_batch[1] # positive class
 
 
950
 
951
+ try:
952
+ names = list(pre.get_feature_names_out())
953
+ except Exception:
954
+ names = [f"f{i}" for i in range(shap_vals_batch.shape[1])]
955
 
956
+ # Dense conversion once (used for summary + waterfalls)
957
  try:
958
+ X_dense = X_batch_t.toarray()
959
  except Exception:
960
+ X_dense = np.array(X_batch_t)
961
+
962
+ # Cache batch results
963
+ st.session_state.shap_batch_vals = shap_vals_batch
964
+ st.session_state.shap_batch_X_dense = X_dense
965
+ st.session_state.shap_batch_n = batch_n
966
+ st.session_state.shap_batch_feature_names = names
967
 
968
  st.success(f"Batch SHAP computed for first {batch_n} rows.")
969
+
 
970
  if "shap_batch_vals" in st.session_state:
971
+ shap_vals_batch = st.session_state.shap_batch_vals
972
+ X_dense = st.session_state.shap_batch_X_dense
973
  batch_n = st.session_state.shap_batch_n
974
  names = st.session_state.shap_batch_feature_names
975
+
976
+ st.divider()
977
+ st.subheader("Export: Top SHAP features per row (batch)")
978
+
979
+ top_k = st.slider("Top-K features per row", 3, 30, 10, 1, key="topk_export")
980
+
981
+ # Optional: include predicted probabilities for the same batch rows
982
+ # (Assumes you already computed proba for all X_inf earlier)
983
+ include_proba = st.checkbox("Include predicted probability", value=True, key="include_proba_export")
984
+
985
+ if st.button("Generate Top-K SHAP table", key="gen_topk_shap"):
986
+ shap_vals_batch = st.session_state.shap_batch_vals # shape: (batch_n, n_features)
987
+ names = st.session_state.shap_batch_feature_names
988
+ batch_n = st.session_state.shap_batch_n
989
+
990
+ rows = []
991
+ for i in range(batch_n):
992
+ sv = shap_vals_batch[i]
993
+ idx = np.argsort(np.abs(sv))[::-1][:top_k] # top-k by absolute SHAP
994
+
995
+ for j in idx:
996
+ val = float(sv[j])
997
+ rows.append({
998
+ "row_in_batch": int(i),
999
+ "feature": str(names[j]),
1000
+ "shap_value": val,
1001
+ "abs_shap_value": abs(val),
1002
+ "direction": "↑" if val > 0 else ("↓" if val < 0 else "0"),
1003
+ })
1004
+
1005
+ df_topk = pd.DataFrame(rows)
1006
+
1007
+ if include_proba:
1008
+ # Use the same batch rows from the previously computed proba vector
1009
+ # If you want absolute Excel row index, add + df_inf.index[0] logic as needed
1010
+ proba_batch = proba[:batch_n]
1011
+ df_proba = pd.DataFrame({"row_in_batch": list(range(batch_n)), "predicted_probability": proba_batch})
1012
+ df_topk = df_topk.merge(df_proba, on="row_in_batch", how="left")
1013
+
1014
+ # Sort nicely: each row block by importance
1015
+ df_topk = df_topk.sort_values(["row_in_batch", "abs_shap_value"], ascending=[True, False])
1016
+
1017
+ st.dataframe(df_topk, use_container_width=True)
1018
+
1019
+ st.download_button(
1020
+ "Download Top-K SHAP per row (CSV)",
1021
+ df_topk.to_csv(index=False).encode("utf-8"),
1022
+ file_name=f"shap_top{top_k}_per_row_first{batch_n}.csv",
1023
+ mime="text/csv",
1024
+ key="dl_topk_shap_csv"
1025
+ )
1026
+
1027
+
1028
+
1029
 
1030
+ st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
 
 
 
 
 
 
1031
 
1032
  # BAR SUMMARY
1033
  fig_bar = plt.figure()
1034
  shap.summary_plot(
1035
+ shap_vals_batch,
1036
  features=X_dense,
1037
  feature_names=names,
1038
  plot_type="bar",
 
1041
  )
1042
  st.pyplot(fig_bar, clear_figure=True)
1043
 
1044
+ # BEESWARM SUMMARY (optional)
1045
+ if show_beeswarm:
1046
+ fig_swarm = plt.figure()
1047
+ shap.summary_plot(
1048
+ shap_vals_batch,
1049
+ features=X_dense,
1050
+ feature_names=names,
1051
+ max_display=max_display,
1052
+ show=False,
1053
+ )
1054
+ st.pyplot(fig_swarm, clear_figure=True)
 
 
 
 
 
1055
 
1056
  st.markdown("### Waterfall plots (batch)")
1057
 
 
1065
  max_waterfalls = st.slider("Max waterfall plots to render", 1, 10, 3, 1, key="max_waterfalls")
1066
  rows_to_plot = rows_to_plot[:max_waterfalls]
1067
 
1068
+ explainer = st.session_state.get("explainer")
1069
+ base = explainer.expected_value
1070
  if not np.isscalar(base):
1071
  base = float(np.array(base).reshape(-1)[0])
1072
 
 
 
 
 
 
 
1073
  for r in rows_to_plot:
1074
  st.markdown(f"**Row {r} (within first {batch_n})**")
1075
  exp = shap.Explanation(
1076
+ values=shap_vals_batch[r],
1077
  base_values=float(base),
1078
  data=X_dense[r],
1079
  feature_names=names,
 
1082
  shap.plots.waterfall(exp, show=False, max_display=max_display)
1083
  st.pyplot(fig_w, clear_figure=True)
1084
 
1085
+
1086
+ #Single row SHAP block
1087
  st.subheader("SHAP explanation")
1088
 
1089
  with st.form("shap_form"):