Synav commited on
Commit
956abc0
·
verified ·
1 Parent(s): f2f0624

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -14
app.py CHANGED
@@ -2314,7 +2314,9 @@ with tab_predict:
2314
 
2315
  out["risk_band"] = band_one(proba_one)
2316
 
2317
- # ---- SHAP for single patient (works even without inference Excel) ----
 
 
2318
  X_one_t = transform_before_clf(pipe, X_one)
2319
 
2320
  explainer = st.session_state.get("explainer")
@@ -2340,14 +2342,12 @@ with tab_predict:
2340
  shap_vals = shap_vals[1]
2341
 
2342
  names = get_final_feature_names(pipe)
2343
- if len(names) != shap_vals.shape[1]:
2344
- names = [f"f{i}" for i in range(shap_vals.shape[1])]
2345
 
2346
  try:
2347
  x_dense = X_one_t.toarray()[0]
2348
  except Exception:
2349
  x_dense = np.array(X_one_t)[0]
2350
-
2351
  base = explainer.expected_value
2352
  if not np.isscalar(base):
2353
  base = float(np.array(base).reshape(-1)[0])
@@ -2359,16 +2359,8 @@ with tab_predict:
2359
  feature_names=names,
2360
  )
2361
 
2362
- # Plot
2363
- plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
2364
- shap.plots.waterfall(exp, show=False, max_display=20)
2365
- fig_w = plt.gcf()
2366
- render_plot_with_download(fig_w, title="Single-patient SHAP waterfall", filename="single_patient_shap_waterfall.png", export_dpi=export_dpi, key="dl_sp_wf")
2367
-
2368
- plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
2369
- shap.plots.bar(exp, show=False, max_display=20)
2370
- fig_b = plt.gcf()
2371
- render_plot_with_download(fig_b, title="Single-patient SHAP bar", filename="single_patient_shap_bar.png", export_dpi=export_dpi, key="dl_sp_bar")
2372
 
2373
 
2374
  st.dataframe(out, use_container_width=True)
@@ -2381,6 +2373,42 @@ with tab_predict:
2381
  key="dl_sp_csv",
2382
  )
2383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2384
 
2385
 
2386
  # -----------------------------
 
2314
 
2315
  out["risk_band"] = band_one(proba_one)
2316
 
2317
+
2318
+
2319
+ # ---- SHAP compute only (cache) ----
2320
  X_one_t = transform_before_clf(pipe, X_one)
2321
 
2322
  explainer = st.session_state.get("explainer")
 
2342
  shap_vals = shap_vals[1]
2343
 
2344
  names = get_final_feature_names(pipe)
 
 
2345
 
2346
  try:
2347
  x_dense = X_one_t.toarray()[0]
2348
  except Exception:
2349
  x_dense = np.array(X_one_t)[0]
2350
+
2351
  base = explainer.expected_value
2352
  if not np.isscalar(base):
2353
  base = float(np.array(base).reshape(-1)[0])
 
2359
  feature_names=names,
2360
  )
2361
 
2362
+ # CACHE ONLY
2363
+ st.session_state.shap_single_exp = exp
 
 
 
 
 
 
 
 
2364
 
2365
 
2366
  st.dataframe(out, use_container_width=True)
 
2373
  key="dl_sp_csv",
2374
  )
2375
 
2376
+ # ---- Always render cached SHAP ----
2377
+ if "shap_single_exp" in st.session_state:
2378
+ exp = st.session_state.shap_single_exp
2379
+
2380
+ max_display_single = st.slider(
2381
+ "Top features to display (single patient)",
2382
+ 5, 40, 20, 1,
2383
+ key="sp_single_max_display"
2384
+ )
2385
+
2386
+ c1, c2 = st.columns(2)
2387
+
2388
+ with c1:
2389
+ plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
2390
+ shap.plots.waterfall(exp, show=False, max_display=max_display_single)
2391
+ fig_w = plt.gcf()
2392
+ render_plot_with_download(
2393
+ fig_w,
2394
+ title="Single-patient SHAP waterfall",
2395
+ filename="single_patient_shap_waterfall.png",
2396
+ export_dpi=export_dpi,
2397
+ key="dl_sp_wf"
2398
+ )
2399
+
2400
+ with c2:
2401
+ plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
2402
+ shap.plots.bar(exp, show=False, max_display=max_display_single)
2403
+ fig_b = plt.gcf()
2404
+ render_plot_with_download(
2405
+ fig_b,
2406
+ title="Single-patient SHAP bar",
2407
+ filename="single_patient_shap_bar.png",
2408
+ export_dpi=export_dpi,
2409
+ key="dl_sp_bar"
2410
+ )
2411
+
2412
 
2413
 
2414
  # -----------------------------