Update app.py
Browse files
app.py
CHANGED
|
@@ -761,11 +761,133 @@ with tab_predict:
|
|
| 761 |
|
| 762 |
|
| 763 |
proba = pipe.predict_proba(X_inf)[:, 1]
|
| 764 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 765 |
df_out = df_inf.copy()
|
| 766 |
df_out["predicted_probability"] = proba
|
| 767 |
-
|
| 768 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 769 |
st.download_button(
|
| 770 |
"Download predictions",
|
| 771 |
df_out.to_csv(index=False).encode(),
|
|
@@ -773,6 +895,7 @@ with tab_predict:
|
|
| 773 |
"text/csv"
|
| 774 |
)
|
| 775 |
|
|
|
|
| 776 |
st.subheader("SHAP explanation")
|
| 777 |
|
| 778 |
with st.form("shap_form"):
|
|
|
|
| 761 |
|
| 762 |
|
| 763 |
proba = pipe.predict_proba(X_inf)[:, 1]
|
| 764 |
+
st.divider()
st.subheader("External validation (if AA label is present)")

# External validation: when the uploaded inference sheet also contains the
# ground-truth label column, score the already-computed probabilities
# (`proba`) against it and render headline metrics plus ROC, PR, calibration
# and decision-curve plots. Without the label column only an info note is shown.
if LABEL_COL in df_inf.columns:
    try:
        y_ext_raw = df_inf[LABEL_COL].copy()
        # Second return value of coerce_binary_label is not needed here.
        y_ext01, _ = coerce_binary_label(y_ext_raw)

        # Core metrics (the ROC thresholds array is not used below).
        roc_auc_ext = float(roc_auc_score(y_ext01, proba))
        fpr, tpr, _ = roc_curve(y_ext01, proba)

        # Threshold metrics (user-controlled)
        thr_ext = st.slider("External validation threshold", 0.0, 1.0, 0.5, 0.01, key="thr_ext")
        cls_ext = compute_classification_metrics(y_ext01, proba, threshold=float(thr_ext))

        pr_ext = compute_pr_curve(y_ext01, proba)
        # NOTE(review): `n_bins` / `cal_strategy` are presumably set by
        # calibration controls elsewhere in the app; the defaults apply when
        # they are not defined in this scope -- confirm against the caller.
        cal_ext = compute_calibration(
            y_ext01, proba,
            n_bins=int(n_bins) if "n_bins" in locals() else 10,
            strategy=str(cal_strategy) if "cal_strategy" in locals() else "uniform"
        )
        dca_ext = decision_curve_analysis(y_ext01, proba)

        # Display headline metrics
        c1, c2, c3, c4 = st.columns(4)
        c1.metric("ROC AUC (external)", f"{roc_auc_ext:.3f}")
        c2.metric("Sensitivity", f"{cls_ext['sensitivity']:.3f}")
        c3.metric("Specificity", f"{cls_ext['specificity']:.3f}")
        c4.metric("F1", f"{cls_ext['f1']:.3f}")

        # Confusion matrix (rows = actual class, columns = predicted class)
        cm_df = pd.DataFrame(
            [[cls_ext["tn"], cls_ext["fp"]], [cls_ext["fn"], cls_ext["tp"]]],
            index=["Actual 0", "Actual 1"],
            columns=["Pred 0", "Pred 1"],
        )
        st.markdown("**Confusion Matrix (external)**")
        st.dataframe(cm_df)

        # All plots use explicit Figure/Axes objects instead of pyplot's
        # implicit "current figure", which is shared global state and unsafe
        # under Streamlit's threaded reruns. Each figure is closed after
        # rendering: `clear_figure=True` only clears it, so without the
        # close call figures accumulate in pyplot's registry on every rerun.

        # ROC plot
        fig, ax = plt.subplots()
        ax.plot(fpr, tpr)
        ax.plot([0, 1], [0, 1], linestyle="--")  # chance diagonal
        ax.set_xlabel("False Positive Rate (1 - Specificity)")
        ax.set_ylabel("True Positive Rate (Sensitivity)")
        ax.set_title(f"External ROC Curve (AUC = {roc_auc_ext:.3f})")
        st.pyplot(fig, clear_figure=True)
        plt.close(fig)

        # PR plot
        st.subheader("Precision–Recall (external)")
        # Two columns are created so the metric keeps its half-width layout;
        # the second column is intentionally left empty.
        c1, _ = st.columns(2)
        c1.metric("Average Precision (AP)", f"{pr_ext['average_precision']:.3f}")
        fig_pr, ax_pr = plt.subplots()
        ax_pr.plot(pr_ext["recall"], pr_ext["precision"])
        ax_pr.set_xlabel("Recall")
        ax_pr.set_ylabel("Precision")
        ax_pr.set_title(f"External PR Curve (AP = {pr_ext['average_precision']:.3f})")
        st.pyplot(fig_pr, clear_figure=True)
        plt.close(fig_pr)

        # Calibration plot
        st.subheader("Calibration (external)")
        c1, c2 = st.columns(2)
        c1.metric("Brier score", f"{cal_ext['brier']:.4f}")
        c2.write(f"Bins: {cal_ext['n_bins']} | Strategy: {cal_ext['strategy']}")
        fig_cal, ax_cal = plt.subplots()
        ax_cal.plot(cal_ext["prob_pred"], cal_ext["prob_true"])
        ax_cal.plot([0, 1], [0, 1], linestyle="--")  # perfect-calibration diagonal
        ax_cal.set_xlabel("Mean predicted probability")
        ax_cal.set_ylabel("Observed event rate")
        ax_cal.set_title("External Calibration curve")
        st.pyplot(fig_cal, clear_figure=True)
        plt.close(fig_cal)

        # DCA plot: three curves share one axis, so label them and show a
        # legend; otherwise they cannot be told apart.
        st.subheader("Decision Curve Analysis (external)")
        fig_dca, ax_dca = plt.subplots()
        ax_dca.plot(dca_ext["thresholds"], dca_ext["net_benefit_model"], label="Model")
        ax_dca.plot(dca_ext["thresholds"], dca_ext["net_benefit_all"], label="Treat all")
        ax_dca.plot(dca_ext["thresholds"], dca_ext["net_benefit_none"], label="Treat none")
        ax_dca.set_xlabel("Threshold probability")
        ax_dca.set_ylabel("Net benefit")
        ax_dca.set_title("External Decision curve analysis")
        ax_dca.legend()
        st.pyplot(fig_dca, clear_figure=True)
        plt.close(fig_dca)

    except Exception as e:
        # UI boundary: report any failure in this optional section to the
        # user instead of aborting the rest of the prediction tab.
        st.error(f"Could not compute external validation metrics: {e}")
else:
    st.info("No AA column found in the inference Excel, so external validation metrics cannot be computed.")
|
| 852 |
+
|
| 853 |
+
# Build the output table from the predicted probabilities. `proba` is already
# computed from `pipe.predict_proba(X_inf)[:, 1]` before the external-validation
# section and is not rebound in between, so the model is not run a second time.
df_out = df_inf.copy()
df_out["predicted_probability"] = proba

# --- classification + risk bands ---
st.divider()
st.subheader("Risk stratification")

# Hard class label: 1 when the probability reaches the chosen threshold.
thr = st.slider(
    "Decision threshold for classification",
    0.0, 1.0, 0.5, 0.01,
    key="pred_thr"
)
df_out["predicted_class"] = (df_out["predicted_probability"] >= thr).astype(int)

# A tuple default makes this a range slider that returns (low, high).
low_cut, high_cut = st.slider(
    "Risk band cutoffs (low, high)",
    0.0, 1.0, (0.2, 0.8), 0.01,
    key="risk_cuts"
)


def band(p):
    """Map probability *p* to "Low", "High" or "Intermediate".

    Values below ``low_cut`` are "Low", values at or above ``high_cut`` are
    "High", and everything else (including values for which both
    comparisons are false) is "Intermediate".
    """
    if p < low_cut:
        return "Low"
    if p >= high_cut:
        return "High"
    return "Intermediate"


df_out["risk_band"] = df_out["predicted_probability"].map(band)

# Preview only the first rows.
st.dataframe(df_out.head())
|
| 890 |
+
|
| 891 |
st.download_button(
|
| 892 |
"Download predictions",
|
| 893 |
df_out.to_csv(index=False).encode(),
|
|
|
|
| 895 |
"text/csv"
|
| 896 |
)
|
| 897 |
|
| 898 |
+
|
| 899 |
st.subheader("SHAP explanation")
|
| 900 |
|
| 901 |
with st.form("shap_form"):
|