Update app.py
Browse files
app.py
CHANGED
|
@@ -2314,7 +2314,9 @@ with tab_predict:
|
|
| 2314 |
|
| 2315 |
out["risk_band"] = band_one(proba_one)
|
| 2316 |
|
| 2317 |
-
|
|
|
|
|
|
|
| 2318 |
X_one_t = transform_before_clf(pipe, X_one)
|
| 2319 |
|
| 2320 |
explainer = st.session_state.get("explainer")
|
|
@@ -2340,14 +2342,12 @@ with tab_predict:
|
|
| 2340 |
shap_vals = shap_vals[1]
|
| 2341 |
|
| 2342 |
names = get_final_feature_names(pipe)
|
| 2343 |
-
if len(names) != shap_vals.shape[1]:
|
| 2344 |
-
names = [f"f{i}" for i in range(shap_vals.shape[1])]
|
| 2345 |
|
| 2346 |
try:
|
| 2347 |
x_dense = X_one_t.toarray()[0]
|
| 2348 |
except Exception:
|
| 2349 |
x_dense = np.array(X_one_t)[0]
|
| 2350 |
-
|
| 2351 |
base = explainer.expected_value
|
| 2352 |
if not np.isscalar(base):
|
| 2353 |
base = float(np.array(base).reshape(-1)[0])
|
|
@@ -2359,16 +2359,8 @@ with tab_predict:
|
|
| 2359 |
feature_names=names,
|
| 2360 |
)
|
| 2361 |
|
| 2362 |
-
#
|
| 2363 |
-
|
| 2364 |
-
shap.plots.waterfall(exp, show=False, max_display=20)
|
| 2365 |
-
fig_w = plt.gcf()
|
| 2366 |
-
render_plot_with_download(fig_w, title="Single-patient SHAP waterfall", filename="single_patient_shap_waterfall.png", export_dpi=export_dpi, key="dl_sp_wf")
|
| 2367 |
-
|
| 2368 |
-
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 2369 |
-
shap.plots.bar(exp, show=False, max_display=20)
|
| 2370 |
-
fig_b = plt.gcf()
|
| 2371 |
-
render_plot_with_download(fig_b, title="Single-patient SHAP bar", filename="single_patient_shap_bar.png", export_dpi=export_dpi, key="dl_sp_bar")
|
| 2372 |
|
| 2373 |
|
| 2374 |
st.dataframe(out, use_container_width=True)
|
|
@@ -2381,6 +2373,42 @@ with tab_predict:
|
|
| 2381 |
key="dl_sp_csv",
|
| 2382 |
)
|
| 2383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2384 |
|
| 2385 |
|
| 2386 |
# -----------------------------
|
|
|
|
| 2314 |
|
| 2315 |
out["risk_band"] = band_one(proba_one)
|
| 2316 |
|
| 2317 |
+
|
| 2318 |
+
|
| 2319 |
+
# ---- SHAP compute only (cache) ----
|
| 2320 |
X_one_t = transform_before_clf(pipe, X_one)
|
| 2321 |
|
| 2322 |
explainer = st.session_state.get("explainer")
|
|
|
|
| 2342 |
shap_vals = shap_vals[1]
|
| 2343 |
|
| 2344 |
names = get_final_feature_names(pipe)
|
|
|
|
|
|
|
| 2345 |
|
| 2346 |
try:
|
| 2347 |
x_dense = X_one_t.toarray()[0]
|
| 2348 |
except Exception:
|
| 2349 |
x_dense = np.array(X_one_t)[0]
|
| 2350 |
+
|
| 2351 |
base = explainer.expected_value
|
| 2352 |
if not np.isscalar(base):
|
| 2353 |
base = float(np.array(base).reshape(-1)[0])
|
|
|
|
| 2359 |
feature_names=names,
|
| 2360 |
)
|
| 2361 |
|
| 2362 |
+
# CACHE ONLY
|
| 2363 |
+
st.session_state.shap_single_exp = exp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2364 |
|
| 2365 |
|
| 2366 |
st.dataframe(out, use_container_width=True)
|
|
|
|
| 2373 |
key="dl_sp_csv",
|
| 2374 |
)
|
| 2375 |
|
| 2376 |
+
# ---- Always render cached SHAP ----
|
| 2377 |
+
if "shap_single_exp" in st.session_state:
|
| 2378 |
+
exp = st.session_state.shap_single_exp
|
| 2379 |
+
|
| 2380 |
+
max_display_single = st.slider(
|
| 2381 |
+
"Top features to display (single patient)",
|
| 2382 |
+
5, 40, 20, 1,
|
| 2383 |
+
key="sp_single_max_display"
|
| 2384 |
+
)
|
| 2385 |
+
|
| 2386 |
+
c1, c2 = st.columns(2)
|
| 2387 |
+
|
| 2388 |
+
with c1:
|
| 2389 |
+
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 2390 |
+
shap.plots.waterfall(exp, show=False, max_display=max_display_single)
|
| 2391 |
+
fig_w = plt.gcf()
|
| 2392 |
+
render_plot_with_download(
|
| 2393 |
+
fig_w,
|
| 2394 |
+
title="Single-patient SHAP waterfall",
|
| 2395 |
+
filename="single_patient_shap_waterfall.png",
|
| 2396 |
+
export_dpi=export_dpi,
|
| 2397 |
+
key="dl_sp_wf"
|
| 2398 |
+
)
|
| 2399 |
+
|
| 2400 |
+
with c2:
|
| 2401 |
+
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 2402 |
+
shap.plots.bar(exp, show=False, max_display=max_display_single)
|
| 2403 |
+
fig_b = plt.gcf()
|
| 2404 |
+
render_plot_with_download(
|
| 2405 |
+
fig_b,
|
| 2406 |
+
title="Single-patient SHAP bar",
|
| 2407 |
+
filename="single_patient_shap_bar.png",
|
| 2408 |
+
export_dpi=export_dpi,
|
| 2409 |
+
key="dl_sp_bar"
|
| 2410 |
+
)
|
| 2411 |
+
|
| 2412 |
|
| 2413 |
|
| 2414 |
# -----------------------------
|