Update app.py
Browse files
app.py
CHANGED
|
@@ -2313,6 +2313,63 @@ with tab_predict:
|
|
| 2313 |
out["predicted_class"] = pred_class
|
| 2314 |
|
| 2315 |
out["risk_band"] = band_one(proba_one)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2316 |
|
| 2317 |
st.dataframe(out, use_container_width=True)
|
| 2318 |
|
|
|
|
| 2313 |
out["predicted_class"] = pred_class
|
| 2314 |
|
| 2315 |
out["risk_band"] = band_one(proba_one)
|
| 2316 |
+
|
| 2317 |
+
# ---- SHAP for single patient (works even without inference Excel) ----
|
| 2318 |
+
X_one_t = transform_before_clf(pipe, X_one)
|
| 2319 |
+
|
| 2320 |
+
explainer = st.session_state.get("explainer")
|
| 2321 |
+
explainer_sig = st.session_state.get("explainer_sig")
|
| 2322 |
+
|
| 2323 |
+
current_sig = (
|
| 2324 |
+
selected,
|
| 2325 |
+
None if st.session_state.get("X_bg_for_shap") is None else int(len(st.session_state["X_bg_for_shap"]))
|
| 2326 |
+
)
|
| 2327 |
+
|
| 2328 |
+
if explainer is None or explainer_sig != current_sig:
|
| 2329 |
+
X_bg = st.session_state.get("X_bg_for_shap")
|
| 2330 |
+
if X_bg is None:
|
| 2331 |
+
st.error("SHAP background not available. Admin must publish latest/background.csv.")
|
| 2332 |
+
st.stop()
|
| 2333 |
+
|
| 2334 |
+
st.session_state.explainer = build_shap_explainer(pipe, X_bg)
|
| 2335 |
+
st.session_state.explainer_sig = current_sig
|
| 2336 |
+
explainer = st.session_state.explainer
|
| 2337 |
+
|
| 2338 |
+
shap_vals = explainer.shap_values(X_one_t)
|
| 2339 |
+
if isinstance(shap_vals, list):
|
| 2340 |
+
shap_vals = shap_vals[1]
|
| 2341 |
+
|
| 2342 |
+
names = get_final_feature_names(pipe)
|
| 2343 |
+
if len(names) != shap_vals.shape[1]:
|
| 2344 |
+
names = [f"f{i}" for i in range(shap_vals.shape[1])]
|
| 2345 |
+
|
| 2346 |
+
try:
|
| 2347 |
+
x_dense = X_one_t.toarray()[0]
|
| 2348 |
+
except Exception:
|
| 2349 |
+
x_dense = np.array(X_one_t)[0]
|
| 2350 |
+
|
| 2351 |
+
base = explainer.expected_value
|
| 2352 |
+
if not np.isscalar(base):
|
| 2353 |
+
base = float(np.array(base).reshape(-1)[0])
|
| 2354 |
+
|
| 2355 |
+
exp = shap.Explanation(
|
| 2356 |
+
values=shap_vals[0],
|
| 2357 |
+
base_values=float(base),
|
| 2358 |
+
data=x_dense,
|
| 2359 |
+
feature_names=names,
|
| 2360 |
+
)
|
| 2361 |
+
|
| 2362 |
+
# Plot
|
| 2363 |
+
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 2364 |
+
shap.plots.waterfall(exp, show=False, max_display=20)
|
| 2365 |
+
fig_w = plt.gcf()
|
| 2366 |
+
render_plot_with_download(fig_w, title="Single-patient SHAP waterfall", filename="single_patient_shap_waterfall.png", export_dpi=export_dpi, key="dl_sp_wf")
|
| 2367 |
+
|
| 2368 |
+
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 2369 |
+
shap.plots.bar(exp, show=False, max_display=20)
|
| 2370 |
+
fig_b = plt.gcf()
|
| 2371 |
+
render_plot_with_download(fig_b, title="Single-patient SHAP bar", filename="single_patient_shap_bar.png", export_dpi=export_dpi, key="dl_sp_bar")
|
| 2372 |
+
|
| 2373 |
|
| 2374 |
st.dataframe(out, use_container_width=True)
|
| 2375 |
|