Synav committed on
Commit
2b2b5ee
·
verified ·
1 Parent(s): 500b8f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -2
app.py CHANGED
@@ -761,11 +761,133 @@ with tab_predict:
761
 
762
 
763
  proba = pipe.predict_proba(X_inf)[:, 1]
764
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
765
  df_out = df_inf.copy()
766
  df_out["predicted_probability"] = proba
767
- st.dataframe(df_out.head())
768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
769
  st.download_button(
770
  "Download predictions",
771
  df_out.to_csv(index=False).encode(),
@@ -773,6 +895,7 @@ with tab_predict:
773
  "text/csv"
774
  )
775
 
 
776
  st.subheader("SHAP explanation")
777
 
778
  with st.form("shap_form"):
 
761
 
762
 
763
  proba = pipe.predict_proba(X_inf)[:, 1]
764
+ st.divider()
765
+ st.subheader("External validation (if AA label is present)")
766
+
767
+ if LABEL_COL in df_inf.columns:
768
+ try:
769
+ y_ext_raw = df_inf[LABEL_COL].copy()
770
+ y_ext01, _ = coerce_binary_label(y_ext_raw)
771
+
772
+ # Core metrics
773
+ roc_auc_ext = float(roc_auc_score(y_ext01, proba))
774
+ fpr, tpr, roc_thresholds = roc_curve(y_ext01, proba)
775
+
776
+ # Threshold metrics (user-controlled)
777
+ thr_ext = st.slider("External validation threshold", 0.0, 1.0, 0.5, 0.01, key="thr_ext")
778
+ cls_ext = compute_classification_metrics(y_ext01, proba, threshold=float(thr_ext))
779
+
780
+ pr_ext = compute_pr_curve(y_ext01, proba)
781
+ cal_ext = compute_calibration(
782
+ y_ext01, proba,
783
+ n_bins=int(n_bins) if "n_bins" in locals() else 10,
784
+ strategy=str(cal_strategy) if "cal_strategy" in locals() else "uniform"
785
+ )
786
+ dca_ext = decision_curve_analysis(y_ext01, proba)
787
+
788
+ # Display headline metrics
789
+ c1, c2, c3, c4 = st.columns(4)
790
+ c1.metric("ROC AUC (external)", f"{roc_auc_ext:.3f}")
791
+ c2.metric("Sensitivity", f"{cls_ext['sensitivity']:.3f}")
792
+ c3.metric("Specificity", f"{cls_ext['specificity']:.3f}")
793
+ c4.metric("F1", f"{cls_ext['f1']:.3f}")
794
+
795
+ # Confusion matrix
796
+ cm_df = pd.DataFrame(
797
+ [[cls_ext["tn"], cls_ext["fp"]], [cls_ext["fn"], cls_ext["tp"]]],
798
+ index=["Actual 0", "Actual 1"],
799
+ columns=["Pred 0", "Pred 1"],
800
+ )
801
+ st.markdown("**Confusion Matrix (external)**")
802
+ st.dataframe(cm_df)
803
+
804
+ # ROC plot
805
+ fig = plt.figure()
806
+ plt.plot(fpr, tpr)
807
+ plt.plot([0, 1], [0, 1])
808
+ plt.xlabel("False Positive Rate (1 - Specificity)")
809
+ plt.ylabel("True Positive Rate (Sensitivity)")
810
+ plt.title(f"External ROC Curve (AUC = {roc_auc_ext:.3f})")
811
+ st.pyplot(fig, clear_figure=True)
812
+
813
+ # PR plot
814
+ st.subheader("Precision–Recall (external)")
815
+ c1, c2 = st.columns(2)
816
+ c1.metric("Average Precision (AP)", f"{pr_ext['average_precision']:.3f}")
817
+ fig_pr = plt.figure()
818
+ plt.plot(pr_ext["recall"], pr_ext["precision"])
819
+ plt.xlabel("Recall")
820
+ plt.ylabel("Precision")
821
+ plt.title(f"External PR Curve (AP = {pr_ext['average_precision']:.3f})")
822
+ st.pyplot(fig_pr, clear_figure=True)
823
+
824
+ # Calibration plot
825
+ st.subheader("Calibration (external)")
826
+ c1, c2 = st.columns(2)
827
+ c1.metric("Brier score", f"{cal_ext['brier']:.4f}")
828
+ c2.write(f"Bins: {cal_ext['n_bins']} | Strategy: {cal_ext['strategy']}")
829
+ fig_cal = plt.figure()
830
+ plt.plot(cal_ext["prob_pred"], cal_ext["prob_true"])
831
+ plt.plot([0, 1], [0, 1])
832
+ plt.xlabel("Mean predicted probability")
833
+ plt.ylabel("Observed event rate")
834
+ plt.title("External Calibration curve")
835
+ st.pyplot(fig_cal, clear_figure=True)
836
+
837
+ # DCA plot
838
+ st.subheader("Decision Curve Analysis (external)")
839
+ fig_dca = plt.figure()
840
+ plt.plot(dca_ext["thresholds"], dca_ext["net_benefit_model"])
841
+ plt.plot(dca_ext["thresholds"], dca_ext["net_benefit_all"])
842
+ plt.plot(dca_ext["thresholds"], dca_ext["net_benefit_none"])
843
+ plt.xlabel("Threshold probability")
844
+ plt.ylabel("Net benefit")
845
+ plt.title("External Decision curve analysis")
846
+ st.pyplot(fig_dca, clear_figure=True)
847
+
848
+ except Exception as e:
849
+ st.error(f"Could not compute external validation metrics: {e}")
850
+ else:
851
+ st.info("No AA column found in the inference Excel, so external validation metrics cannot be computed.")
852
+
853
+ # Predict probabilities
854
+ proba = pipe.predict_proba(X_inf)[:, 1]
855
+
856
  df_out = df_inf.copy()
857
  df_out["predicted_probability"] = proba
858
+
859
 
860
+
861
+
862
+ # --- classification + risk bands ---
863
+ st.divider()
864
+ st.subheader("Risk stratification")
865
+
866
+ thr = st.slider(
867
+ "Decision threshold for classification",
868
+ 0.0, 1.0, 0.5, 0.01,
869
+ key="pred_thr"
870
+ )
871
+ df_out["predicted_class"] = (df_out["predicted_probability"] >= thr).astype(int)
872
+
873
+ low_cut, high_cut = st.slider(
874
+ "Risk band cutoffs (low, high)",
875
+ 0.0, 1.0, (0.2, 0.8), 0.01,
876
+ key="risk_cuts"
877
+ )
878
+
879
+ def band(p):
880
+ if p < low_cut:
881
+ return "Low"
882
+ if p >= high_cut:
883
+ return "High"
884
+ return "Intermediate"
885
+
886
+ df_out["risk_band"] = df_out["predicted_probability"].map(band)
887
+ # --- END ADD ---
888
+
889
+ st.dataframe(df_out.head())
890
+
891
  st.download_button(
892
  "Download predictions",
893
  df_out.to_csv(index=False).encode(),
 
895
  "text/csv"
896
  )
897
 
898
+
899
  st.subheader("SHAP explanation")
900
 
901
  with st.form("shap_form"):