Synav commited on
Commit
18343ab
·
verified ·
1 Parent(s): 5573147

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -84
app.py CHANGED
@@ -28,6 +28,45 @@ from sklearn.impute import SimpleImputer
28
  from sklearn.linear_model import LogisticRegression
29
  from sklearn.model_selection import train_test_split
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  # ============================================================
@@ -717,68 +756,85 @@ with tab_train:
717
  )
718
  st.markdown("**Confusion Matrix (threshold = 0.5)**")
719
  st.dataframe(cm_df)
 
 
 
 
 
 
 
 
 
 
 
720
 
721
  # ROC curve plot (matplotlib)
722
  roc = m["roc_curve"]
723
- fig = plt.figure()
724
  plt.plot(roc["fpr"], roc["tpr"])
725
  plt.plot([0, 1], [0, 1])
726
  plt.xlabel("False Positive Rate (1 - Specificity)")
727
  plt.ylabel("True Positive Rate (Sensitivity)")
728
  plt.title(f"ROC Curve (AUC = {m['roc_auc']:.3f})")
729
- st.pyplot(fig, clear_figure=True)
 
 
 
 
 
 
730
 
731
- st.divider()
732
- st.subheader("Precision–Recall (PR) Curve")
733
- if "pr_curve" not in m:
734
- st.warning("PR curve not available in this model metadata. Retrain the model to generate it.")
735
- else:
736
- pr = m["pr_curve"]
737
- c1, c2 = st.columns(2)
738
- c1.metric("Average Precision (AP)", f"{pr['average_precision']:.3f}")
739
-
740
- fig_pr = plt.figure()
741
- plt.plot(pr["recall"], pr["precision"])
742
- plt.xlabel("Recall")
743
- plt.ylabel("Precision")
744
- plt.title(f"PR Curve (AP = {pr['average_precision']:.3f})")
745
- st.pyplot(fig_pr, clear_figure=True)
746
 
747
- st.divider()
748
- st.subheader("Calibration (Reliability Plot)")
749
 
750
- if "calibration" not in m:
751
- st.warning("calibration curve not available in this model metadata. Retrain the model to generate it.")
752
- else:
753
- cal = m["calibration"]
754
- c1, c2 = st.columns(2)
755
- c1.metric("Brier score", f"{cal['brier']:.4f}")
756
- c2.write(f"Bins: {cal['n_bins']} | Strategy: {cal['strategy']}")
757
-
758
- fig_cal = plt.figure()
759
- plt.plot(cal["prob_pred"], cal["prob_true"])
760
- plt.plot([0, 1], [0, 1])
761
- plt.xlabel("Mean predicted probability")
762
- plt.ylabel("Observed event rate")
763
- plt.title("Calibration curve")
764
- st.pyplot(fig_cal, clear_figure=True)
765
 
766
 
767
- st.divider()
768
- st.subheader("Decision Curve Analysis (Clinical Usefulness)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769
 
770
- if "decision_curve" not in m:
771
- st.warning("decision curve not available in this model metadata. Retrain the model to generate it.")
772
- else:
773
- dca = m["decision_curve"]
774
- fig_dca = plt.figure()
775
- plt.plot(dca["thresholds"], dca["net_benefit_model"])
776
- plt.plot(dca["thresholds"], dca["net_benefit_all"])
777
- plt.plot(dca["thresholds"], dca["net_benefit_none"])
778
- plt.xlabel("Threshold probability")
779
- plt.ylabel("Net benefit")
780
- plt.title("Decision curve analysis")
781
- st.pyplot(fig_dca, clear_figure=True)
782
 
783
  st.caption(
784
  "If the model curve is above Treat-all and Treat-none across a threshold range, "
@@ -967,48 +1023,72 @@ with tab_predict:
967
  st.dataframe(cm_df)
968
 
969
  # ROC plot
970
- fig = plt.figure()
971
  plt.plot(fpr, tpr)
972
  plt.plot([0, 1], [0, 1])
973
  plt.xlabel("False Positive Rate (1 - Specificity)")
974
  plt.ylabel("True Positive Rate (Sensitivity)")
975
  plt.title(f"External ROC Curve (AUC = {roc_auc_ext:.3f})")
976
- st.pyplot(fig, clear_figure=True)
 
 
 
 
 
 
 
977
 
978
  # PR plot
979
- st.subheader("Precision–Recall (external)")
980
- c1, c2 = st.columns(2)
981
- c1.metric("Average Precision (AP)", f"{pr_ext['average_precision']:.3f}")
982
- fig_pr = plt.figure()
983
- plt.plot(pr_ext["recall"], pr_ext["precision"])
984
  plt.xlabel("Recall")
985
  plt.ylabel("Precision")
986
- plt.title(f"External PR Curve (AP = {pr_ext['average_precision']:.3f})")
987
- st.pyplot(fig_pr, clear_figure=True)
 
 
 
 
 
 
 
988
 
989
  # Calibration plot
990
- st.subheader("Calibration (external)")
991
- c1, c2 = st.columns(2)
992
- c1.metric("Brier score", f"{cal_ext['brier']:.4f}")
993
- c2.write(f"Bins: {cal_ext['n_bins']} | Strategy: {cal_ext['strategy']}")
994
- fig_cal = plt.figure()
995
- plt.plot(cal_ext["prob_pred"], cal_ext["prob_true"])
996
  plt.plot([0, 1], [0, 1])
997
  plt.xlabel("Mean predicted probability")
998
  plt.ylabel("Observed event rate")
999
- plt.title("External Calibration curve")
1000
- st.pyplot(fig_cal, clear_figure=True)
 
 
 
 
 
 
 
1001
 
1002
  # DCA plot
1003
- st.subheader("Decision Curve Analysis (external)")
1004
- fig_dca = plt.figure()
1005
- plt.plot(dca_ext["thresholds"], dca_ext["net_benefit_model"])
1006
- plt.plot(dca_ext["thresholds"], dca_ext["net_benefit_all"])
1007
- plt.plot(dca_ext["thresholds"], dca_ext["net_benefit_none"])
1008
  plt.xlabel("Threshold probability")
1009
  plt.ylabel("Net benefit")
1010
- plt.title("External Decision curve analysis")
1011
- st.pyplot(fig_dca, clear_figure=True)
 
 
 
 
 
 
 
 
1012
 
1013
  except Exception as e:
1014
  st.error(f"Could not compute external validation metrics: {e}")
@@ -1190,7 +1270,7 @@ with tab_predict:
1190
  st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
1191
 
1192
  # BAR SUMMARY
1193
- fig_bar = plt.figure()
1194
  shap.summary_plot(
1195
  shap_vals_batch,
1196
  features=X_dense,
@@ -1198,20 +1278,33 @@ with tab_predict:
1198
  plot_type="bar",
1199
  max_display=max_display,
1200
  show=False,
 
 
 
 
 
 
 
1201
  )
1202
- st.pyplot(fig_bar, clear_figure=True)
1203
 
1204
  # BEESWARM SUMMARY (optional)
1205
- if show_beeswarm:
1206
- fig_swarm = plt.figure()
1207
- shap.summary_plot(
1208
- shap_vals_batch,
1209
- features=X_dense,
1210
- feature_names=names,
1211
- max_display=max_display,
1212
- show=False,
1213
- )
1214
- st.pyplot(fig_swarm, clear_figure=True)
 
 
 
 
 
 
1215
 
1216
  st.markdown("### Waterfall plots (batch)")
1217
 
 
28
  from sklearn.linear_model import LogisticRegression
29
  from sklearn.model_selection import train_test_split
30
 
31
+ # Figure settings block
32
+ import io
33
+
34
def make_fig(figsize=(5.5, 3.6), dpi=120):
    """Create a compact Matplotlib figure sized for laptop screens.

    Args:
        figsize: (width, height) in inches; defaults to a small footprint.
        dpi: on-screen rendering resolution.

    Returns:
        A new ``matplotlib.figure.Figure`` (also becomes the current figure,
        so subsequent ``plt.plot`` calls draw onto it).
    """
    return plt.figure(figsize=figsize, dpi=dpi)
40
+
41
def fig_to_png_bytes(fig, dpi=600):
    """Serialize a figure to PNG bytes at publication-quality resolution.

    Args:
        fig: a Matplotlib figure (anything exposing ``savefig``).
        dpi: export resolution; coerced to ``int`` before saving.

    Returns:
        The PNG image as a ``bytes`` object.
    """
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", dpi=int(dpi), bbox_inches="tight")
    # getvalue() returns the whole buffer regardless of stream position,
    # so no explicit seek(0) is needed.
    return buffer.getvalue()
49
+
50
def render_plot_with_download(
    fig,
    *,
    title: str,
    filename: str,
    export_dpi: int = 600,
):
    """Show a compact plot in Streamlit and offer a high-DPI PNG download.

    Args:
        fig: Matplotlib figure to display; it is cleared after rendering.
        title: human-readable plot name, used in the download-button label
            and to build a unique widget key.
        filename: suggested file name for the downloaded PNG.
        export_dpi: resolution of the exported PNG (default 600).
    """
    # BUG FIX: export the PNG BEFORE st.pyplot(..., clear_figure=True)
    # wipes the figure — the previous order saved an already-cleared
    # (blank) figure into the download.
    png_bytes = fig_to_png_bytes(fig, dpi=export_dpi)

    st.pyplot(fig, clear_figure=True, use_container_width=True)

    st.download_button(
        label=f"Download {title} (PNG {export_dpi} dpi)",
        data=png_bytes,
        file_name=filename,
        mime="image/png",
        # BUG FIX: the key was the constant f"dl_(unknown)" (a broken
        # template substitution with no placeholder), so every plot's
        # button shared one key and Streamlit raised DuplicateWidgetID
        # on any page with more than one plot. Derive a per-plot key.
        key=f"dl_{title}_{filename}",
    )
70
 
71
 
72
  # ============================================================
 
756
  )
757
  st.markdown("**Confusion Matrix (threshold = 0.5)**")
758
  st.dataframe(cm_df)
759
+
760
+ st.markdown("### Plot display settings")
761
+
762
+ plot_width = st.slider("Plot width (inches)", 4.0, 10.0, 5.5, 0.1)
763
+ plot_height = st.slider("Plot height (inches)", 2.5, 6.0, 3.6, 0.1)
764
+ plot_dpi_screen = st.slider("Screen DPI", 80, 200, 120, 10)
765
+
766
+ export_dpi = st.selectbox("Export DPI (PNG)", [300, 600, 900, 1200], index=1)
767
+
768
+ FIGSIZE = (plot_width, plot_height)
769
+
770
 
771
  # ROC curve plot (matplotlib)
772
  roc = m["roc_curve"]
773
+ fig = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
774
  plt.plot(roc["fpr"], roc["tpr"])
775
  plt.plot([0, 1], [0, 1])
776
  plt.xlabel("False Positive Rate (1 - Specificity)")
777
  plt.ylabel("True Positive Rate (Sensitivity)")
778
  plt.title(f"ROC Curve (AUC = {m['roc_auc']:.3f})")
779
+
780
+ render_plot_with_download(
781
+ fig,
782
+ title="ROC curve",
783
+ filename="roc_curve.png",
784
+ export_dpi=export_dpi
785
+ )
786
 
787
# Precision–recall curve.
# BUG FIX: the refactor dropped the availability guards the previous code
# had — older saved models may lack "pr_curve"/"calibration"/"decision_curve"
# metadata, and unguarded m[...] access raises KeyError. Guards restored.
if "pr_curve" not in m:
    st.warning("PR curve not available in this model metadata. Retrain the model to generate it.")
else:
    pr = m["pr_curve"]
    fig_pr = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
    plt.plot(pr["recall"], pr["precision"])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"PR Curve (AP = {pr['average_precision']:.3f})")

    render_plot_with_download(
        fig_pr,
        title="PR curve",
        filename="pr_curve.png",
        export_dpi=export_dpi,
    )

# Calibration (reliability) plot — diagonal is the perfectly calibrated line.
if "calibration" not in m:
    st.warning("calibration curve not available in this model metadata. Retrain the model to generate it.")
else:
    cal = m["calibration"]
    fig_cal = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
    plt.plot(cal["prob_pred"], cal["prob_true"])
    plt.plot([0, 1], [0, 1])
    plt.xlabel("Mean predicted probability")
    plt.ylabel("Observed event rate")
    plt.title("Calibration curve")

    render_plot_with_download(
        fig_cal,
        title="Calibration curve",
        filename="calibration_curve.png",
        export_dpi=export_dpi,
    )

# Decision curve analysis (clinical usefulness): model net benefit vs
# the treat-all / treat-none reference strategies.
if "decision_curve" not in m:
    st.warning("decision curve not available in this model metadata. Retrain the model to generate it.")
else:
    dca = m["decision_curve"]
    fig_dca = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
    plt.plot(dca["thresholds"], dca["net_benefit_model"], label="Model")
    plt.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all")
    plt.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none")
    plt.xlabel("Threshold probability")
    plt.ylabel("Net benefit")
    plt.title("Decision curve analysis")
    plt.legend()

    render_plot_with_download(
        fig_dca,
        title="Decision curve",
        filename="decision_curve.png",
        export_dpi=export_dpi,
    )
837
 
 
 
 
 
 
 
 
 
 
 
 
 
838
 
839
  st.caption(
840
  "If the model curve is above Treat-all and Treat-none across a threshold range, "
 
1023
  st.dataframe(cm_df)
1024
 
1025
  # ROC plot
1026
+ fig = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1027
  plt.plot(fpr, tpr)
1028
  plt.plot([0, 1], [0, 1])
1029
  plt.xlabel("False Positive Rate (1 - Specificity)")
1030
  plt.ylabel("True Positive Rate (Sensitivity)")
1031
  plt.title(f"External ROC Curve (AUC = {roc_auc_ext:.3f})")
1032
+
1033
+ render_plot_with_download(
1034
+ fig,
1035
+ title="External ROC curve",
1036
+ filename="external_roc_curve.png",
1037
+ export_dpi=export_dpi
1038
+ )
1039
+
1040
 
1041
  # PR plot
1042
# Precision–recall on the EXTERNAL cohort.
# BUG FIX: this section re-plotted the training-tab metrics
# (m["pr_curve"], m["calibration"], m["decision_curve"]) instead of the
# external-validation results (pr_ext / cal_ext / dca_ext) computed above
# in this tab — a copy-paste regression from the training tab. Titles and
# download filenames restored to "External ..." so they no longer collide
# with the training tab's files.
pr = pr_ext
fig_pr = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
plt.plot(pr["recall"], pr["precision"])
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title(f"External PR Curve (AP = {pr['average_precision']:.3f})")

render_plot_with_download(
    fig_pr,
    title="External PR curve",
    filename="external_pr_curve.png",
    export_dpi=export_dpi,
)

# Calibration on the external cohort (diagonal = perfect calibration).
cal = cal_ext
fig_cal = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
plt.plot(cal["prob_pred"], cal["prob_true"])
plt.plot([0, 1], [0, 1])
plt.xlabel("Mean predicted probability")
plt.ylabel("Observed event rate")
plt.title("External Calibration curve")

render_plot_with_download(
    fig_cal,
    title="External calibration curve",
    filename="external_calibration_curve.png",
    export_dpi=export_dpi,
)

# Decision curve analysis on the external cohort.
dca = dca_ext
fig_dca = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
plt.plot(dca["thresholds"], dca["net_benefit_model"], label="Model")
plt.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all")
plt.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none")
plt.xlabel("Threshold probability")
plt.ylabel("Net benefit")
plt.title("External Decision curve analysis")
plt.legend()

render_plot_with_download(
    fig_dca,
    title="External decision curve",
    filename="external_decision_curve.png",
    export_dpi=export_dpi,
)
+ )
1091
+
1092
 
1093
  except Exception as e:
1094
  st.error(f"Could not compute external validation metrics: {e}")
 
1270
  st.markdown(f"### Global SHAP summary (first {batch_n} rows)")
1271
 
1272
  # BAR SUMMARY
1273
+ fig_bar = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1274
  shap.summary_plot(
1275
  shap_vals_batch,
1276
  features=X_dense,
 
1278
  plot_type="bar",
1279
  max_display=max_display,
1280
  show=False,
1281
+ plot_size=FIGSIZE
1282
+ )
1283
+ render_plot_with_download(
1284
+ fig_bar,
1285
+ title="SHAP bar summary",
1286
+ filename="shap_summary_bar.png",
1287
+ export_dpi=export_dpi
1288
  )
1289
+
1290
 
1291
  # BEESWARM SUMMARY (optional)
1292
# BUG FIX: the refactor dropped the `if show_beeswarm:` gate (the section
# comment still says "(optional)"), so the slow beeswarm always rendered.
# Gate restored. NOTE(review): assumes the show_beeswarm toggle is still
# defined upstream in this tab — confirm against the widget section.
if show_beeswarm:
    fig_swarm = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
    shap.summary_plot(
        shap_vals_batch,
        features=X_dense,
        feature_names=names,
        max_display=max_display,
        show=False,
        plot_size=FIGSIZE,
    )
    render_plot_with_download(
        fig_swarm,
        title="SHAP beeswarm",
        filename="shap_beeswarm.png",
        export_dpi=export_dpi,
    )
1307
+
1308
 
1309
  st.markdown("### Waterfall plots (batch)")
1310