Update app.py
Browse files
app.py
CHANGED
|
@@ -28,6 +28,45 @@ from sklearn.impute import SimpleImputer
|
|
| 28 |
from sklearn.linear_model import LogisticRegression
|
| 29 |
from sklearn.model_selection import train_test_split
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
# ============================================================
|
|
@@ -717,68 +756,85 @@ with tab_train:
|
|
| 717 |
)
|
| 718 |
st.markdown("**Confusion Matrix (threshold = 0.5)**")
|
| 719 |
st.dataframe(cm_df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 720 |
|
| 721 |
# ROC curve plot (matplotlib)
|
| 722 |
roc = m["roc_curve"]
|
| 723 |
-
fig =
|
| 724 |
plt.plot(roc["fpr"], roc["tpr"])
|
| 725 |
plt.plot([0, 1], [0, 1])
|
| 726 |
plt.xlabel("False Positive Rate (1 - Specificity)")
|
| 727 |
plt.ylabel("True Positive Rate (Sensitivity)")
|
| 728 |
plt.title(f"ROC Curve (AUC = {m['roc_auc']:.3f})")
|
| 729 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 730 |
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
|
| 739 |
-
|
| 740 |
-
fig_pr
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
st.pyplot(fig_pr, clear_figure=True)
|
| 746 |
|
| 747 |
-
st.divider()
|
| 748 |
-
st.subheader("Calibration (Reliability Plot)")
|
| 749 |
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
|
| 762 |
-
|
| 763 |
-
|
| 764 |
-
|
| 765 |
|
| 766 |
|
| 767 |
-
|
| 768 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 769 |
|
| 770 |
-
if "decision_curve" not in m:
|
| 771 |
-
st.warning("decision curve not available in this model metadata. Retrain the model to generate it.")
|
| 772 |
-
else:
|
| 773 |
-
dca = m["decision_curve"]
|
| 774 |
-
fig_dca = plt.figure()
|
| 775 |
-
plt.plot(dca["thresholds"], dca["net_benefit_model"])
|
| 776 |
-
plt.plot(dca["thresholds"], dca["net_benefit_all"])
|
| 777 |
-
plt.plot(dca["thresholds"], dca["net_benefit_none"])
|
| 778 |
-
plt.xlabel("Threshold probability")
|
| 779 |
-
plt.ylabel("Net benefit")
|
| 780 |
-
plt.title("Decision curve analysis")
|
| 781 |
-
st.pyplot(fig_dca, clear_figure=True)
|
| 782 |
|
| 783 |
st.caption(
|
| 784 |
"If the model curve is above Treat-all and Treat-none across a threshold range, "
|
|
@@ -967,48 +1023,72 @@ with tab_predict:
|
|
| 967 |
st.dataframe(cm_df)
|
| 968 |
|
| 969 |
# ROC plot
|
| 970 |
-
fig =
|
| 971 |
plt.plot(fpr, tpr)
|
| 972 |
plt.plot([0, 1], [0, 1])
|
| 973 |
plt.xlabel("False Positive Rate (1 - Specificity)")
|
| 974 |
plt.ylabel("True Positive Rate (Sensitivity)")
|
| 975 |
plt.title(f"External ROC Curve (AUC = {roc_auc_ext:.3f})")
|
| 976 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 977 |
|
| 978 |
# PR plot
|
| 979 |
-
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
fig_pr = plt.figure()
|
| 983 |
-
plt.plot(pr_ext["recall"], pr_ext["precision"])
|
| 984 |
plt.xlabel("Recall")
|
| 985 |
plt.ylabel("Precision")
|
| 986 |
-
plt.title(f"
|
| 987 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 988 |
|
| 989 |
# Calibration plot
|
| 990 |
-
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
c2.write(f"Bins: {cal_ext['n_bins']} | Strategy: {cal_ext['strategy']}")
|
| 994 |
-
fig_cal = plt.figure()
|
| 995 |
-
plt.plot(cal_ext["prob_pred"], cal_ext["prob_true"])
|
| 996 |
plt.plot([0, 1], [0, 1])
|
| 997 |
plt.xlabel("Mean predicted probability")
|
| 998 |
plt.ylabel("Observed event rate")
|
| 999 |
-
plt.title("
|
| 1000 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1001 |
|
| 1002 |
# DCA plot
|
| 1003 |
-
|
| 1004 |
-
fig_dca =
|
| 1005 |
-
plt.plot(
|
| 1006 |
-
plt.plot(
|
| 1007 |
-
plt.plot(
|
| 1008 |
plt.xlabel("Threshold probability")
|
| 1009 |
plt.ylabel("Net benefit")
|
| 1010 |
-
plt.title("
|
| 1011 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1012 |
|
| 1013 |
except Exception as e:
|
| 1014 |
st.error(f"Could not compute external validation metrics: {e}")
|
|
@@ -1190,7 +1270,7 @@ with tab_predict:
|
|
| 1190 |
st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
|
| 1191 |
|
| 1192 |
# BAR SUMMARY
|
| 1193 |
-
fig_bar =
|
| 1194 |
shap.summary_plot(
|
| 1195 |
shap_vals_batch,
|
| 1196 |
features=X_dense,
|
|
@@ -1198,20 +1278,33 @@ with tab_predict:
|
|
| 1198 |
plot_type="bar",
|
| 1199 |
max_display=max_display,
|
| 1200 |
show=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1201 |
)
|
| 1202 |
-
|
| 1203 |
|
| 1204 |
# BEESWARM SUMMARY (optional)
|
| 1205 |
-
|
| 1206 |
-
|
| 1207 |
-
|
| 1208 |
-
|
| 1209 |
-
|
| 1210 |
-
|
| 1211 |
-
|
| 1212 |
-
|
| 1213 |
-
|
| 1214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1215 |
|
| 1216 |
st.markdown("### Waterfall plots (batch)")
|
| 1217 |
|
|
|
|
| 28 |
from sklearn.linear_model import LogisticRegression
|
| 29 |
from sklearn.model_selection import train_test_split
|
| 30 |
|
| 31 |
+
# Figure settings block
|
| 32 |
+
import io
|
| 33 |
+
|
| 34 |
+
def make_fig(figsize=(5.5, 3.6), dpi=120):
    """
    Create a compact matplotlib figure sized for laptop screens.

    Parameters
    ----------
    figsize : tuple of (width, height) in inches; default 5.5 x 3.6.
    dpi : on-screen rendering resolution; default 120.

    Returns
    -------
    The newly created matplotlib figure (caller owns it).
    """
    # Delegate straight to pyplot; no state is kept here.
    return plt.figure(figsize=figsize, dpi=dpi)
|
| 40 |
+
|
| 41 |
+
def fig_to_png_bytes(fig, dpi=600):
    """
    Serialize *fig* to PNG bytes at publication-quality resolution.

    Parameters
    ----------
    fig : matplotlib figure to export.
    dpi : output resolution in dots per inch, coerced to int (default 600).

    Returns
    -------
    bytes : the PNG image data.
    """
    buffer = io.BytesIO()
    # Tight bounding box trims surrounding whitespace for print/export use.
    fig.savefig(buffer, format="png", dpi=int(dpi), bbox_inches="tight")
    # getvalue() returns the whole buffer regardless of stream position.
    return buffer.getvalue()
|
| 49 |
+
|
| 50 |
+
def render_plot_with_download(
    fig,
    *,
    title: str,
    filename: str,
    export_dpi: int = 600,
    key: str | None = None,
):
    """
    Show a compact plot in Streamlit and offer a high-DPI PNG download.

    Parameters
    ----------
    fig : matplotlib figure to display and export.
    title : human-readable plot name, used in the download-button label.
    filename : file name suggested to the browser for the downloaded PNG.
    export_dpi : resolution of the exported PNG (default 600).
    key : optional explicit Streamlit widget key. Defaults to a key derived
        from *filename*; pass a distinct key when the same filename is
        rendered more than once on a page.
    """
    st.pyplot(fig, clear_figure=True, use_container_width=True)

    png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)
    # BUG FIX: the key was the constant f"dl_(unknown)" for every call, so
    # rendering a second plot on the same page raised Streamlit's
    # duplicate-widget-key error. Derive a per-plot key from the filename.
    st.download_button(
        label=f"Download {title} (PNG {export_dpi} dpi)",
        data=png_bytes,
        file_name=filename,
        mime="image/png",
        key=key if key is not None else f"dl_{filename}",
    )
|
| 70 |
|
| 71 |
|
| 72 |
# ============================================================
|
|
|
|
| 756 |
)
|
| 757 |
st.markdown("**Confusion Matrix (threshold = 0.5)**")
|
| 758 |
st.dataframe(cm_df)
|
| 759 |
+
|
| 760 |
+
st.markdown("### Plot display settings")
|
| 761 |
+
|
| 762 |
+
plot_width = st.slider("Plot width (inches)", 4.0, 10.0, 5.5, 0.1)
|
| 763 |
+
plot_height = st.slider("Plot height (inches)", 2.5, 6.0, 3.6, 0.1)
|
| 764 |
+
plot_dpi_screen = st.slider("Screen DPI", 80, 200, 120, 10)
|
| 765 |
+
|
| 766 |
+
export_dpi = st.selectbox("Export DPI (PNG)", [300, 600, 900, 1200], index=1)
|
| 767 |
+
|
| 768 |
+
FIGSIZE = (plot_width, plot_height)
|
| 769 |
+
|
| 770 |
|
| 771 |
# ROC curve plot (matplotlib)
|
| 772 |
roc = m["roc_curve"]
|
| 773 |
+
fig = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 774 |
plt.plot(roc["fpr"], roc["tpr"])
|
| 775 |
plt.plot([0, 1], [0, 1])
|
| 776 |
plt.xlabel("False Positive Rate (1 - Specificity)")
|
| 777 |
plt.ylabel("True Positive Rate (Sensitivity)")
|
| 778 |
plt.title(f"ROC Curve (AUC = {m['roc_auc']:.3f})")
|
| 779 |
+
|
| 780 |
+
render_plot_with_download(
|
| 781 |
+
fig,
|
| 782 |
+
title="ROC curve",
|
| 783 |
+
filename="roc_curve.png",
|
| 784 |
+
export_dpi=export_dpi
|
| 785 |
+
)
|
| 786 |
|
| 787 |
+
#Precision recall curve
|
| 788 |
+
pr = m["pr_curve"]
|
| 789 |
+
fig_pr = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 790 |
+
plt.plot(pr["recall"], pr["precision"])
|
| 791 |
+
plt.xlabel("Recall")
|
| 792 |
+
plt.ylabel("Precision")
|
| 793 |
+
plt.title(f"PR Curve (AP = {pr['average_precision']:.3f})")
|
| 794 |
+
|
| 795 |
+
render_plot_with_download(
|
| 796 |
+
fig_pr,
|
| 797 |
+
title="PR curve",
|
| 798 |
+
filename="pr_curve.png",
|
| 799 |
+
export_dpi=export_dpi
|
| 800 |
+
)
|
|
|
|
| 801 |
|
|
|
|
|
|
|
| 802 |
|
| 803 |
+
#Calibration plot
|
| 804 |
+
cal = m["calibration"]
|
| 805 |
+
fig_cal = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 806 |
+
plt.plot(cal["prob_pred"], cal["prob_true"])
|
| 807 |
+
plt.plot([0, 1], [0, 1])
|
| 808 |
+
plt.xlabel("Mean predicted probability")
|
| 809 |
+
plt.ylabel("Observed event rate")
|
| 810 |
+
plt.title("Calibration curve")
|
| 811 |
+
|
| 812 |
+
render_plot_with_download(
|
| 813 |
+
fig_cal,
|
| 814 |
+
title="Calibration curve",
|
| 815 |
+
filename="calibration_curve.png",
|
| 816 |
+
export_dpi=export_dpi
|
| 817 |
+
)
|
| 818 |
|
| 819 |
|
| 820 |
+
#Decision curve
|
| 821 |
+
dca = m["decision_curve"]
|
| 822 |
+
fig_dca = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 823 |
+
plt.plot(dca["thresholds"], dca["net_benefit_model"], label="Model")
|
| 824 |
+
plt.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all")
|
| 825 |
+
plt.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none")
|
| 826 |
+
plt.xlabel("Threshold probability")
|
| 827 |
+
plt.ylabel("Net benefit")
|
| 828 |
+
plt.title("Decision curve analysis")
|
| 829 |
+
plt.legend()
|
| 830 |
+
|
| 831 |
+
render_plot_with_download(
|
| 832 |
+
fig_dca,
|
| 833 |
+
title="Decision curve",
|
| 834 |
+
filename="decision_curve.png",
|
| 835 |
+
export_dpi=export_dpi
|
| 836 |
+
)
|
| 837 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 838 |
|
| 839 |
st.caption(
|
| 840 |
"If the model curve is above Treat-all and Treat-none across a threshold range, "
|
|
|
|
| 1023 |
st.dataframe(cm_df)
|
| 1024 |
|
| 1025 |
# ROC plot
|
| 1026 |
+
fig = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1027 |
plt.plot(fpr, tpr)
|
| 1028 |
plt.plot([0, 1], [0, 1])
|
| 1029 |
plt.xlabel("False Positive Rate (1 - Specificity)")
|
| 1030 |
plt.ylabel("True Positive Rate (Sensitivity)")
|
| 1031 |
plt.title(f"External ROC Curve (AUC = {roc_auc_ext:.3f})")
|
| 1032 |
+
|
| 1033 |
+
render_plot_with_download(
|
| 1034 |
+
fig,
|
| 1035 |
+
title="External ROC curve",
|
| 1036 |
+
filename="external_roc_curve.png",
|
| 1037 |
+
export_dpi=export_dpi
|
| 1038 |
+
)
|
| 1039 |
+
|
| 1040 |
|
| 1041 |
# PR plot
|
| 1042 |
+
pr = m["pr_curve"]
|
| 1043 |
+
fig_pr = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1044 |
+
plt.plot(pr["recall"], pr["precision"])
|
|
|
|
|
|
|
| 1045 |
plt.xlabel("Recall")
|
| 1046 |
plt.ylabel("Precision")
|
| 1047 |
+
plt.title(f"PR Curve (AP = {pr['average_precision']:.3f})")
|
| 1048 |
+
|
| 1049 |
+
render_plot_with_download(
|
| 1050 |
+
fig_pr,
|
| 1051 |
+
title="PR curve",
|
| 1052 |
+
filename="pr_curve.png",
|
| 1053 |
+
export_dpi=export_dpi
|
| 1054 |
+
)
|
| 1055 |
+
|
| 1056 |
|
| 1057 |
# Calibration plot
|
| 1058 |
+
cal = m["calibration"]
|
| 1059 |
+
fig_cal = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1060 |
+
plt.plot(cal["prob_pred"], cal["prob_true"])
|
|
|
|
|
|
|
|
|
|
| 1061 |
plt.plot([0, 1], [0, 1])
|
| 1062 |
plt.xlabel("Mean predicted probability")
|
| 1063 |
plt.ylabel("Observed event rate")
|
| 1064 |
+
plt.title("Calibration curve")
|
| 1065 |
+
|
| 1066 |
+
render_plot_with_download(
|
| 1067 |
+
fig_cal,
|
| 1068 |
+
title="Calibration curve",
|
| 1069 |
+
filename="calibration_curve.png",
|
| 1070 |
+
export_dpi=export_dpi
|
| 1071 |
+
)
|
| 1072 |
+
|
| 1073 |
|
| 1074 |
# DCA plot
|
| 1075 |
+
dca = m["decision_curve"]
|
| 1076 |
+
fig_dca = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1077 |
+
plt.plot(dca["thresholds"], dca["net_benefit_model"], label="Model")
|
| 1078 |
+
plt.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all")
|
| 1079 |
+
plt.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none")
|
| 1080 |
plt.xlabel("Threshold probability")
|
| 1081 |
plt.ylabel("Net benefit")
|
| 1082 |
+
plt.title("Decision curve analysis")
|
| 1083 |
+
plt.legend()
|
| 1084 |
+
|
| 1085 |
+
render_plot_with_download(
|
| 1086 |
+
fig_dca,
|
| 1087 |
+
title="Decision curve",
|
| 1088 |
+
filename="decision_curve.png",
|
| 1089 |
+
export_dpi=export_dpi
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
|
| 1093 |
except Exception as e:
|
| 1094 |
st.error(f"Could not compute external validation metrics: {e}")
|
|
|
|
| 1270 |
st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
|
| 1271 |
|
| 1272 |
# BAR SUMMARY
|
| 1273 |
+
fig_bar = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1274 |
shap.summary_plot(
|
| 1275 |
shap_vals_batch,
|
| 1276 |
features=X_dense,
|
|
|
|
| 1278 |
plot_type="bar",
|
| 1279 |
max_display=max_display,
|
| 1280 |
show=False,
|
| 1281 |
+
plot_size=FIGSIZE
|
| 1282 |
+
)
|
| 1283 |
+
render_plot_with_download(
|
| 1284 |
+
fig_bar,
|
| 1285 |
+
title="SHAP bar summary",
|
| 1286 |
+
filename="shap_summary_bar.png",
|
| 1287 |
+
export_dpi=export_dpi
|
| 1288 |
)
|
| 1289 |
+
|
| 1290 |
|
| 1291 |
# BEESWARM SUMMARY (optional)
|
| 1292 |
+
fig_swarm = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1293 |
+
shap.summary_plot(
|
| 1294 |
+
shap_vals_batch,
|
| 1295 |
+
features=X_dense,
|
| 1296 |
+
feature_names=names,
|
| 1297 |
+
max_display=max_display,
|
| 1298 |
+
show=False,
|
| 1299 |
+
plot_size=FIGSIZE
|
| 1300 |
+
)
|
| 1301 |
+
render_plot_with_download(
|
| 1302 |
+
fig_swarm,
|
| 1303 |
+
title="SHAP beeswarm",
|
| 1304 |
+
filename="shap_beeswarm.png",
|
| 1305 |
+
export_dpi=export_dpi
|
| 1306 |
+
)
|
| 1307 |
+
|
| 1308 |
|
| 1309 |
st.markdown("### Waterfall plots (batch)")
|
| 1310 |
|