Synav commited on
Commit
6e4dce9
·
verified ·
1 Parent(s): 746ab4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -71
app.py CHANGED
@@ -30,6 +30,7 @@ from sklearn.model_selection import train_test_split
30
 
31
  #Figures setting block
32
  import io
 
33
 
34
  def make_fig(figsize=(5.5, 3.6), dpi=120):
35
  """
@@ -38,28 +39,22 @@ def make_fig(figsize=(5.5, 3.6), dpi=120):
38
  fig = plt.figure(figsize=figsize, dpi=dpi)
39
  return fig
40
 
 
 
41
  def fig_to_png_bytes(fig, dpi=600):
42
- """
43
- Export current figure as PNG bytes at high DPI (>=600).
44
- """
45
  buf = io.BytesIO()
46
  fig.savefig(buf, format="png", dpi=int(dpi), bbox_inches="tight")
47
  buf.seek(0)
48
  return buf.getvalue()
49
 
50
- def render_plot_with_download(
51
- fig,
52
- *,
53
- title: str,
54
- filename: str,
55
- export_dpi: int = 600
56
- ):
57
- """
58
- Show a compact plot in Streamlit + provide high-DPI PNG download.
59
- """
60
- st.pyplot(fig, clear_figure=True, use_container_width=True)
61
-
62
  png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
 
 
 
 
 
63
  st.download_button(
64
  label=f"Download {title} (PNG {export_dpi} dpi)",
65
  data=png_bytes,
@@ -68,6 +63,10 @@ def render_plot_with_download(
68
  key=f"dl_{filename}"
69
  )
70
 
 
 
 
 
71
 
72
  # ============================================================
73
  # Fixed schema definition (PLACEHOLDER FRAMEWORK)
@@ -1195,57 +1194,54 @@ with tab_predict:
1195
 
1196
 
1197
  # PR plot
1198
- pr = pr_ext
1199
- fig_pr = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1200
- plt.plot(pr_ext["recall"], pr_ext["precision"])
1201
- plt.xlabel("Recall")
1202
- plt.ylabel("Precision")
1203
- plt.title(f"External PR Curve (AP = {pr_ext['average_precision']:.3f})")
1204
 
1205
  render_plot_with_download(
1206
- fig_pr,
1207
  title="External PR curve",
1208
  filename="external_pr_curve.png",
1209
  export_dpi=export_dpi
1210
  )
1211
-
1212
-
1213
-
1214
- # Calibration plot (external)
1215
- cal = cal_ext
1216
- fig_cal = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1217
- plt.plot(cal["prob_pred"], cal["prob_true"])
1218
- plt.plot([0, 1], [0, 1])
1219
- plt.xlabel("Mean predicted probability")
1220
- plt.ylabel("Observed event rate")
1221
- plt.title("External calibration curve")
1222
 
1223
  render_plot_with_download(
1224
- fig_cal,
1225
  title="External calibration curve",
1226
  filename="external_calibration_curve.png",
1227
  export_dpi=export_dpi
1228
  )
1229
 
1230
- # DCA plot (external)
1231
- dca = dca_ext
1232
- fig_dca = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1233
- plt.plot(dca["thresholds"], dca["net_benefit_model"], label="Model")
1234
- plt.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all")
1235
- plt.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none")
1236
- plt.xlabel("Threshold probability")
1237
- plt.ylabel("Net benefit")
1238
- plt.title("External decision curve analysis")
1239
- plt.legend()
1240
 
1241
  render_plot_with_download(
1242
- fig_dca,
1243
  title="External decision curve",
1244
  filename="external_decision_curve.png",
1245
  export_dpi=export_dpi
1246
  )
1247
 
1248
 
 
1249
 
1250
  except Exception as e:
1251
  st.error(f"Could not compute external validation metrics: {e}")
@@ -1419,40 +1415,32 @@ with tab_predict:
1419
  st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
1420
 
1421
  # BAR SUMMARY
1422
- fig_bar = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1423
  shap.summary_plot(
1424
  shap_vals_batch,
1425
  features=X_dense,
1426
  feature_names=names,
1427
  plot_type="bar",
1428
  max_display=max_display,
1429
- show=False,
1430
- plot_size=FIGSIZE
1431
- )
1432
- render_plot_with_download(
1433
- fig_bar,
1434
- title="SHAP bar summary",
1435
- filename="shap_summary_bar.png",
1436
- export_dpi=export_dpi
1437
  )
 
 
 
1438
 
1439
 
1440
  # BEESWARM SUMMARY (optional)
1441
- fig_swarm = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1442
  shap.summary_plot(
1443
  shap_vals_batch,
1444
  features=X_dense,
1445
  feature_names=names,
1446
  max_display=max_display,
1447
- show=False,
1448
- plot_size=FIGSIZE
1449
- )
1450
- render_plot_with_download(
1451
- fig_swarm,
1452
- title="SHAP beeswarm",
1453
- filename="shap_beeswarm.png",
1454
- export_dpi=export_dpi
1455
  )
 
 
 
1456
 
1457
 
1458
  st.markdown("### Waterfall plots (batch)")
@@ -1536,18 +1524,16 @@ with tab_predict:
1536
 
1537
  with c1:
1538
  st.markdown("**Waterfall**")
1539
- fig = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1540
  shap.plots.waterfall(exp, show=False, max_display=20)
1541
- render_plot_with_download(
1542
- fig, title="SHAP waterfall", filename="shap_waterfall_row.png", export_dpi=export_dpi
1543
- )
1544
 
1545
  with c2:
1546
  st.markdown("**Top features**")
1547
- fig2 = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1548
  shap.plots.bar(exp, show=False, max_display=20)
1549
- render_plot_with_download(
1550
- fig2, title="SHAP bar", filename="shap_bar_row.png", export_dpi=export_dpi
1551
- )
1552
 
1553
  st.stop()
 
30
 
31
  #Figures setting block
32
  import io
33
+ import matplotlib.pyplot as plt
34
 
35
  def make_fig(figsize=(5.5, 3.6), dpi=120):
36
  """
 
39
  fig = plt.figure(figsize=figsize, dpi=dpi)
40
  return fig
41
 
42
+
43
+
44
  def fig_to_png_bytes(fig, dpi=600):
 
 
 
45
  buf = io.BytesIO()
46
  fig.savefig(buf, format="png", dpi=int(dpi), bbox_inches="tight")
47
  buf.seek(0)
48
  return buf.getvalue()
49
 
50
+ def render_plot_with_download(fig, *, title: str, filename: str, export_dpi: int = 600):
51
+ # 1) SAVE FIRST (before any clear)
 
 
 
 
 
 
 
 
 
 
52
  png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
53
+
54
+ # 2) DISPLAY (do NOT clear)
55
+ st.pyplot(fig, clear_figure=False, use_container_width=False)
56
+
57
+ # 3) DOWNLOAD
58
  st.download_button(
59
  label=f"Download {title} (PNG {export_dpi} dpi)",
60
  data=png_bytes,
 
63
  key=f"dl_{filename}"
64
  )
65
 
66
+ # 4) Prevent figure leakage / overlaps
67
+ plt.close(fig)
68
+
69
+
70
 
71
  # ============================================================
72
  # Fixed schema definition (PLACEHOLDER FRAMEWORK)
 
1194
 
1195
 
1196
  # PR plot
1197
+ # PR curve (external)
1198
+ fig, ax = plt.subplots(figsize=FIGSIZE, dpi=plot_dpi_screen)
1199
+ ax.plot(pr_ext["recall"], pr_ext["precision"])
1200
+ ax.set_xlabel("Recall")
1201
+ ax.set_ylabel("Precision")
1202
+ ax.set_title(f"External PR Curve (AP = {pr_ext['average_precision']:.3f})")
1203
 
1204
  render_plot_with_download(
1205
+ fig,
1206
  title="External PR curve",
1207
  filename="external_pr_curve.png",
1208
  export_dpi=export_dpi
1209
  )
1210
+
1211
+ # Calibration curve (external)
1212
+ fig, ax = plt.subplots(figsize=FIGSIZE, dpi=plot_dpi_screen)
1213
+ ax.plot(cal_ext["prob_pred"], cal_ext["prob_true"])
1214
+ ax.plot([0, 1], [0, 1]) # diagonal reference
1215
+ ax.set_xlabel("Mean predicted probability")
1216
+ ax.set_ylabel("Observed event rate")
1217
+ ax.set_title("External calibration curve")
 
 
 
1218
 
1219
  render_plot_with_download(
1220
+ fig,
1221
  title="External calibration curve",
1222
  filename="external_calibration_curve.png",
1223
  export_dpi=export_dpi
1224
  )
1225
 
1226
+ # DCA (external)
1227
+ fig, ax = plt.subplots(figsize=FIGSIZE, dpi=plot_dpi_screen)
1228
+ ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_model"], label="Model")
1229
+ ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_all"], label="Treat all")
1230
+ ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_none"], label="Treat none")
1231
+ ax.set_xlabel("Threshold probability")
1232
+ ax.set_ylabel("Net benefit")
1233
+ ax.set_title("External decision curve analysis")
1234
+ ax.legend()
 
1235
 
1236
  render_plot_with_download(
1237
+ fig,
1238
  title="External decision curve",
1239
  filename="external_decision_curve.png",
1240
  export_dpi=export_dpi
1241
  )
1242
 
1243
 
1244
+
1245
 
1246
  except Exception as e:
1247
  st.error(f"Could not compute external validation metrics: {e}")
 
1415
  st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
1416
 
1417
  # BAR SUMMARY
1418
+ plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
1419
  shap.summary_plot(
1420
  shap_vals_batch,
1421
  features=X_dense,
1422
  feature_names=names,
1423
  plot_type="bar",
1424
  max_display=max_display,
1425
+ show=False
 
 
 
 
 
 
 
1426
  )
1427
+ fig_bar = plt.gcf()
1428
+ render_plot_with_download(fig_bar, title="SHAP bar summary", filename="shap_summary_bar.png", export_dpi=export_dpi)
1429
+
1430
 
1431
 
1432
  # BEESWARM SUMMARY (optional)
1433
+ plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
1434
  shap.summary_plot(
1435
  shap_vals_batch,
1436
  features=X_dense,
1437
  feature_names=names,
1438
  max_display=max_display,
1439
+ show=False
 
 
 
 
 
 
 
1440
  )
1441
+ fig_swarm = plt.gcf()
1442
+ render_plot_with_download(fig_swarm, title="SHAP beeswarm", filename="shap_beeswarm.png", export_dpi=export_dpi)
1443
+
1444
 
1445
 
1446
  st.markdown("### Waterfall plots (batch)")
 
1524
 
1525
  with c1:
1526
  st.markdown("**Waterfall**")
1527
+ plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
1528
  shap.plots.waterfall(exp, show=False, max_display=20)
1529
+ fig_w = plt.gcf()
1530
+ render_plot_with_download(fig_w, title="SHAP waterfall", filename="shap_waterfall_row.png", export_dpi=export_dpi)
 
1531
 
1532
  with c2:
1533
  st.markdown("**Top features**")
1534
+ plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
1535
  shap.plots.bar(exp, show=False, max_display=20)
1536
+ fig_b = plt.gcf()
1537
+ render_plot_with_download(fig_b, title="SHAP bar", filename="shap_bar_row.png", export_dpi=export_dpi)
 
1538
 
1539
  st.stop()