Update app.py
Browse files
app.py
CHANGED
|
@@ -30,6 +30,7 @@ from sklearn.model_selection import train_test_split
|
|
| 30 |
|
| 31 |
#Figures setting block
|
| 32 |
import io
|
|
|
|
| 33 |
|
| 34 |
def make_fig(figsize=(5.5, 3.6), dpi=120):
|
| 35 |
"""
|
|
@@ -38,28 +39,22 @@ def make_fig(figsize=(5.5, 3.6), dpi=120):
|
|
| 38 |
fig = plt.figure(figsize=figsize, dpi=dpi)
|
| 39 |
return fig
|
| 40 |
|
|
|
|
|
|
|
| 41 |
def fig_to_png_bytes(fig, dpi=600):
|
| 42 |
-
"""
|
| 43 |
-
Export current figure as PNG bytes at high DPI (>=600).
|
| 44 |
-
"""
|
| 45 |
buf = io.BytesIO()
|
| 46 |
fig.savefig(buf, format="png", dpi=int(dpi), bbox_inches="tight")
|
| 47 |
buf.seek(0)
|
| 48 |
return buf.getvalue()
|
| 49 |
|
| 50 |
-
def render_plot_with_download(
|
| 51 |
-
|
| 52 |
-
*,
|
| 53 |
-
title: str,
|
| 54 |
-
filename: str,
|
| 55 |
-
export_dpi: int = 600
|
| 56 |
-
):
|
| 57 |
-
"""
|
| 58 |
-
Show a compact plot in Streamlit + provide high-DPI PNG download.
|
| 59 |
-
"""
|
| 60 |
-
st.pyplot(fig, clear_figure=True, use_container_width=True)
|
| 61 |
-
|
| 62 |
png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
st.download_button(
|
| 64 |
label=f"Download {title} (PNG {export_dpi} dpi)",
|
| 65 |
data=png_bytes,
|
|
@@ -68,6 +63,10 @@ def render_plot_with_download(
|
|
| 68 |
key=f"dl_{filename}"
|
| 69 |
)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# ============================================================
|
| 73 |
# Fixed schema definition (PLACEHOLDER FRAMEWORK)
|
|
@@ -1195,57 +1194,54 @@ with tab_predict:
|
|
| 1195 |
|
| 1196 |
|
| 1197 |
# PR plot
|
| 1198 |
-
|
| 1199 |
-
|
| 1200 |
-
|
| 1201 |
-
|
| 1202 |
-
|
| 1203 |
-
|
| 1204 |
|
| 1205 |
render_plot_with_download(
|
| 1206 |
-
|
| 1207 |
title="External PR curve",
|
| 1208 |
filename="external_pr_curve.png",
|
| 1209 |
export_dpi=export_dpi
|
| 1210 |
)
|
| 1211 |
-
|
| 1212 |
-
|
| 1213 |
-
|
| 1214 |
-
|
| 1215 |
-
|
| 1216 |
-
|
| 1217 |
-
|
| 1218 |
-
|
| 1219 |
-
plt.xlabel("Mean predicted probability")
|
| 1220 |
-
plt.ylabel("Observed event rate")
|
| 1221 |
-
plt.title("External calibration curve")
|
| 1222 |
|
| 1223 |
render_plot_with_download(
|
| 1224 |
-
|
| 1225 |
title="External calibration curve",
|
| 1226 |
filename="external_calibration_curve.png",
|
| 1227 |
export_dpi=export_dpi
|
| 1228 |
)
|
| 1229 |
|
| 1230 |
-
# DCA
|
| 1231 |
-
|
| 1232 |
-
|
| 1233 |
-
|
| 1234 |
-
|
| 1235 |
-
|
| 1236 |
-
|
| 1237 |
-
|
| 1238 |
-
|
| 1239 |
-
plt.legend()
|
| 1240 |
|
| 1241 |
render_plot_with_download(
|
| 1242 |
-
|
| 1243 |
title="External decision curve",
|
| 1244 |
filename="external_decision_curve.png",
|
| 1245 |
export_dpi=export_dpi
|
| 1246 |
)
|
| 1247 |
|
| 1248 |
|
|
|
|
| 1249 |
|
| 1250 |
except Exception as e:
|
| 1251 |
st.error(f"Could not compute external validation metrics: {e}")
|
|
@@ -1419,40 +1415,32 @@ with tab_predict:
|
|
| 1419 |
st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
|
| 1420 |
|
| 1421 |
# BAR SUMMARY
|
| 1422 |
-
|
| 1423 |
shap.summary_plot(
|
| 1424 |
shap_vals_batch,
|
| 1425 |
features=X_dense,
|
| 1426 |
feature_names=names,
|
| 1427 |
plot_type="bar",
|
| 1428 |
max_display=max_display,
|
| 1429 |
-
show=False
|
| 1430 |
-
plot_size=FIGSIZE
|
| 1431 |
-
)
|
| 1432 |
-
render_plot_with_download(
|
| 1433 |
-
fig_bar,
|
| 1434 |
-
title="SHAP bar summary",
|
| 1435 |
-
filename="shap_summary_bar.png",
|
| 1436 |
-
export_dpi=export_dpi
|
| 1437 |
)
|
|
|
|
|
|
|
|
|
|
| 1438 |
|
| 1439 |
|
| 1440 |
# BEESWARM SUMMARY (optional)
|
| 1441 |
-
|
| 1442 |
shap.summary_plot(
|
| 1443 |
shap_vals_batch,
|
| 1444 |
features=X_dense,
|
| 1445 |
feature_names=names,
|
| 1446 |
max_display=max_display,
|
| 1447 |
-
show=False
|
| 1448 |
-
plot_size=FIGSIZE
|
| 1449 |
-
)
|
| 1450 |
-
render_plot_with_download(
|
| 1451 |
-
fig_swarm,
|
| 1452 |
-
title="SHAP beeswarm",
|
| 1453 |
-
filename="shap_beeswarm.png",
|
| 1454 |
-
export_dpi=export_dpi
|
| 1455 |
)
|
|
|
|
|
|
|
|
|
|
| 1456 |
|
| 1457 |
|
| 1458 |
st.markdown("### Waterfall plots (batch)")
|
|
@@ -1536,18 +1524,16 @@ with tab_predict:
|
|
| 1536 |
|
| 1537 |
with c1:
|
| 1538 |
st.markdown("**Waterfall**")
|
| 1539 |
-
|
| 1540 |
shap.plots.waterfall(exp, show=False, max_display=20)
|
| 1541 |
-
|
| 1542 |
-
|
| 1543 |
-
)
|
| 1544 |
|
| 1545 |
with c2:
|
| 1546 |
st.markdown("**Top features**")
|
| 1547 |
-
|
| 1548 |
shap.plots.bar(exp, show=False, max_display=20)
|
| 1549 |
-
|
| 1550 |
-
|
| 1551 |
-
)
|
| 1552 |
|
| 1553 |
st.stop()
|
|
|
|
| 30 |
|
| 31 |
#Figures setting block
|
| 32 |
import io
|
| 33 |
+
import matplotlib.pyplot as plt
|
| 34 |
|
| 35 |
def make_fig(figsize=(5.5, 3.6), dpi=120):
|
| 36 |
"""
|
|
|
|
| 39 |
fig = plt.figure(figsize=figsize, dpi=dpi)
|
| 40 |
return fig
|
| 41 |
|
| 42 |
+
|
| 43 |
+
|
| 44 |
def fig_to_png_bytes(fig, dpi=600):
    """
    Render a figure to PNG and return the encoded image as bytes.

    Parameters
    ----------
    fig
        Object exposing ``savefig(file, format=..., dpi=..., bbox_inches=...)``;
        callers in this file pass matplotlib figures.
    dpi
        Export resolution. Coerced with ``int()`` before being handed to
        ``savefig`` (default 600).

    Returns
    -------
    bytes
        The full PNG payload written by ``fig.savefig``.
    """
    # The context manager releases the in-memory buffer as soon as the
    # payload has been copied out, instead of leaving it for the GC.
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=int(dpi), bbox_inches="tight")
        # getvalue() returns the whole buffer regardless of the current
        # stream position, so no seek(0) is required before reading.
        return buf.getvalue()
|
| 49 |
|
| 50 |
+
def render_plot_with_download(fig, *, title: str, filename: str, export_dpi: int = 600):
|
| 51 |
+
# 1) SAVE FIRST (before any clear)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
|
| 53 |
+
|
| 54 |
+
# 2) DISPLAY (do NOT clear)
|
| 55 |
+
st.pyplot(fig, clear_figure=False, use_container_width=False)
|
| 56 |
+
|
| 57 |
+
# 3) DOWNLOAD
|
| 58 |
st.download_button(
|
| 59 |
label=f"Download {title} (PNG {export_dpi} dpi)",
|
| 60 |
data=png_bytes,
|
|
|
|
| 63 |
key=f"dl_{filename}"
|
| 64 |
)
|
| 65 |
|
| 66 |
+
# 4) Prevent figure leakage / overlaps
|
| 67 |
+
plt.close(fig)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
|
| 71 |
# ============================================================
|
| 72 |
# Fixed schema definition (PLACEHOLDER FRAMEWORK)
|
|
|
|
| 1194 |
|
| 1195 |
|
| 1196 |
# PR plot
|
| 1197 |
+
# PR curve (external)
|
| 1198 |
+
fig, ax = plt.subplots(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1199 |
+
ax.plot(pr_ext["recall"], pr_ext["precision"])
|
| 1200 |
+
ax.set_xlabel("Recall")
|
| 1201 |
+
ax.set_ylabel("Precision")
|
| 1202 |
+
ax.set_title(f"External PR Curve (AP = {pr_ext['average_precision']:.3f})")
|
| 1203 |
|
| 1204 |
render_plot_with_download(
|
| 1205 |
+
fig,
|
| 1206 |
title="External PR curve",
|
| 1207 |
filename="external_pr_curve.png",
|
| 1208 |
export_dpi=export_dpi
|
| 1209 |
)
|
| 1210 |
+
|
| 1211 |
+
# Calibration curve (external)
|
| 1212 |
+
fig, ax = plt.subplots(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1213 |
+
ax.plot(cal_ext["prob_pred"], cal_ext["prob_true"])
|
| 1214 |
+
ax.plot([0, 1], [0, 1]) # diagonal reference
|
| 1215 |
+
ax.set_xlabel("Mean predicted probability")
|
| 1216 |
+
ax.set_ylabel("Observed event rate")
|
| 1217 |
+
ax.set_title("External calibration curve")
|
|
|
|
|
|
|
|
|
|
| 1218 |
|
| 1219 |
render_plot_with_download(
|
| 1220 |
+
fig,
|
| 1221 |
title="External calibration curve",
|
| 1222 |
filename="external_calibration_curve.png",
|
| 1223 |
export_dpi=export_dpi
|
| 1224 |
)
|
| 1225 |
|
| 1226 |
+
# DCA (external)
|
| 1227 |
+
fig, ax = plt.subplots(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1228 |
+
ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_model"], label="Model")
|
| 1229 |
+
ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_all"], label="Treat all")
|
| 1230 |
+
ax.plot(dca_ext["thresholds"], dca_ext["net_benefit_none"], label="Treat none")
|
| 1231 |
+
ax.set_xlabel("Threshold probability")
|
| 1232 |
+
ax.set_ylabel("Net benefit")
|
| 1233 |
+
ax.set_title("External decision curve analysis")
|
| 1234 |
+
ax.legend()
|
|
|
|
| 1235 |
|
| 1236 |
render_plot_with_download(
|
| 1237 |
+
fig,
|
| 1238 |
title="External decision curve",
|
| 1239 |
filename="external_decision_curve.png",
|
| 1240 |
export_dpi=export_dpi
|
| 1241 |
)
|
| 1242 |
|
| 1243 |
|
| 1244 |
+
|
| 1245 |
|
| 1246 |
except Exception as e:
|
| 1247 |
st.error(f"Could not compute external validation metrics: {e}")
|
|
|
|
| 1415 |
st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
|
| 1416 |
|
| 1417 |
# BAR SUMMARY
|
| 1418 |
+
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1419 |
shap.summary_plot(
|
| 1420 |
shap_vals_batch,
|
| 1421 |
features=X_dense,
|
| 1422 |
feature_names=names,
|
| 1423 |
plot_type="bar",
|
| 1424 |
max_display=max_display,
|
| 1425 |
+
show=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1426 |
)
|
| 1427 |
+
fig_bar = plt.gcf()
|
| 1428 |
+
render_plot_with_download(fig_bar, title="SHAP bar summary", filename="shap_summary_bar.png", export_dpi=export_dpi)
|
| 1429 |
+
|
| 1430 |
|
| 1431 |
|
| 1432 |
# BEESWARM SUMMARY (optional)
|
| 1433 |
+
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1434 |
shap.summary_plot(
|
| 1435 |
shap_vals_batch,
|
| 1436 |
features=X_dense,
|
| 1437 |
feature_names=names,
|
| 1438 |
max_display=max_display,
|
| 1439 |
+
show=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1440 |
)
|
| 1441 |
+
fig_swarm = plt.gcf()
|
| 1442 |
+
render_plot_with_download(fig_swarm, title="SHAP beeswarm", filename="shap_beeswarm.png", export_dpi=export_dpi)
|
| 1443 |
+
|
| 1444 |
|
| 1445 |
|
| 1446 |
st.markdown("### Waterfall plots (batch)")
|
|
|
|
| 1524 |
|
| 1525 |
with c1:
|
| 1526 |
st.markdown("**Waterfall**")
|
| 1527 |
+
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1528 |
shap.plots.waterfall(exp, show=False, max_display=20)
|
| 1529 |
+
fig_w = plt.gcf()
|
| 1530 |
+
render_plot_with_download(fig_w, title="SHAP waterfall", filename="shap_waterfall_row.png", export_dpi=export_dpi)
|
|
|
|
| 1531 |
|
| 1532 |
with c2:
|
| 1533 |
st.markdown("**Top features**")
|
| 1534 |
+
plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1535 |
shap.plots.bar(exp, show=False, max_display=20)
|
| 1536 |
+
fig_b = plt.gcf()
|
| 1537 |
+
render_plot_with_download(fig_b, title="SHAP bar", filename="shap_bar_row.png", export_dpi=export_dpi)
|
|
|
|
| 1538 |
|
| 1539 |
st.stop()
|