Update app.py
Browse files
app.py
CHANGED
|
@@ -30,7 +30,7 @@ from sklearn.model_selection import train_test_split
|
|
| 30 |
|
| 31 |
#Figures setting block
|
| 32 |
import io
|
| 33 |
-
|
| 34 |
|
| 35 |
# REPLACE make_fig with this (or add this and stop using plt.plot directly)
|
| 36 |
def make_fig(figsize=(5.5, 3.6), dpi=120):
|
|
@@ -931,6 +931,8 @@ if "pipe" not in st.session_state:
|
|
| 931 |
if "explainer" not in st.session_state:
|
| 932 |
st.session_state.explainer = None
|
| 933 |
|
|
|
|
|
|
|
| 934 |
with tab_train:
|
| 935 |
st.subheader("Train model")
|
| 936 |
|
|
@@ -962,250 +964,243 @@ with tab_train:
|
|
| 962 |
st.divider()
|
| 963 |
|
| 964 |
# then keep your file uploader + training button + publish block here
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
|
| 968 |
-
# ---------------- TRAIN ----------------
|
| 969 |
-
|
| 970 |
-
|
| 971 |
|
| 972 |
-
with tab_train:
|
| 973 |
-
st.subheader("Train model")
|
| 974 |
|
| 975 |
-
|
| 976 |
-
st.info("Training and publishing are restricted. Use Predict + SHAP for inference.")
|
| 977 |
-
else:
|
| 978 |
-
train_file = st.file_uploader("Upload training Excel (.xlsx)", type=["xlsx"])
|
| 979 |
|
| 980 |
-
|
| 981 |
-
|
| 982 |
-
|
| 983 |
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
feature_cols = get_feature_cols_from_df(df)
|
| 987 |
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
|
| 991 |
-
num_cols = st.multiselect(
|
| 992 |
-
"Numeric variables (will be median-imputed + scaled)",
|
| 993 |
-
options=feature_cols,
|
| 994 |
-
default=default_numeric
|
| 995 |
-
)
|
| 996 |
-
|
| 997 |
-
# Everything not selected as numeric becomes categorical
|
| 998 |
-
cat_cols = [c for c in feature_cols if c not in num_cols]
|
| 999 |
-
|
| 1000 |
-
st.write(f"Categorical variables (will be most-frequent-imputed + one-hot): {len(cat_cols)}")
|
| 1001 |
-
st.caption("Note: The selected schema is stored with the trained model and must match inference files.")
|
| 1002 |
|
| 1003 |
-
|
| 1004 |
-
|
| 1005 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1006 |
|
| 1007 |
-
|
| 1008 |
-
|
| 1009 |
|
| 1010 |
-
|
| 1011 |
-
|
| 1012 |
-
|
| 1013 |
-
|
| 1014 |
-
|
| 1015 |
-
|
| 1016 |
-
l1_C=l1_C,
|
| 1017 |
-
use_dimred=use_dimred,
|
| 1018 |
-
svd_components=svd_components
|
| 1019 |
-
)
|
| 1020 |
-
|
| 1021 |
-
explainer = build_shap_explainer(pipe, X_bg)
|
| 1022 |
-
|
| 1023 |
-
st.session_state.pipe = pipe
|
| 1024 |
-
st.session_state.explainer = explainer
|
| 1025 |
-
st.session_state.meta = meta
|
| 1026 |
|
| 1027 |
-
st.success("Training complete. model.joblib and meta.json created.")
|
| 1028 |
|
| 1029 |
-
|
| 1030 |
-
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
|
| 1038 |
-
|
| 1039 |
-
|
| 1040 |
-
|
| 1041 |
-
|
| 1042 |
-
|
| 1043 |
-
|
| 1044 |
-
|
| 1045 |
-
|
| 1046 |
-
|
| 1047 |
-
|
| 1048 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1049 |
|
| 1050 |
|
| 1051 |
|
| 1052 |
|
| 1053 |
|
| 1054 |
-
|
| 1055 |
-
|
| 1056 |
-
|
| 1057 |
-
|
| 1058 |
-
|
| 1059 |
-
|
| 1060 |
-
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
|
| 1064 |
|
| 1065 |
|
| 1066 |
|
| 1067 |
|
| 1068 |
-
|
| 1069 |
-
|
| 1070 |
-
|
| 1071 |
-
|
| 1072 |
-
|
| 1073 |
-
|
| 1074 |
-
|
| 1075 |
-
|
| 1076 |
-
|
| 1077 |
-
|
| 1078 |
-
render_plot_with_download(
|
| 1079 |
-
fig,
|
| 1080 |
-
title="ROC curve",
|
| 1081 |
-
filename="roc_curve.png",
|
| 1082 |
-
export_dpi=export_dpi,
|
| 1083 |
-
key="dl_train_roc"
|
| 1084 |
-
)
|
| 1085 |
-
|
| 1086 |
-
|
| 1087 |
-
#Precision recall curve
|
| 1088 |
-
# =========================
|
| 1089 |
-
# TRAINING: PR curve plot
|
| 1090 |
-
# =========================
|
| 1091 |
-
pr = m["pr_curve"]
|
| 1092 |
-
fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1093 |
-
ax.plot(pr["recall"], pr["precision"])
|
| 1094 |
-
ax.set_xlabel("Recall")
|
| 1095 |
-
ax.set_ylabel("Precision")
|
| 1096 |
-
ax.set_title(f"PR Curve (AP = {pr['average_precision']:.3f})")
|
| 1097 |
|
| 1098 |
-
|
| 1099 |
-
|
| 1100 |
-
|
| 1101 |
-
|
| 1102 |
-
|
| 1103 |
-
|
| 1104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1105 |
|
| 1106 |
|
| 1107 |
|
| 1108 |
-
|
| 1109 |
-
|
| 1110 |
-
|
| 1111 |
-
|
| 1112 |
-
|
| 1113 |
-
|
| 1114 |
-
|
| 1115 |
-
|
| 1116 |
-
|
| 1117 |
-
|
| 1118 |
-
|
| 1119 |
-
|
| 1120 |
-
|
| 1121 |
-
|
| 1122 |
-
|
| 1123 |
-
|
| 1124 |
-
|
| 1125 |
-
|
| 1126 |
-
|
| 1127 |
|
| 1128 |
|
| 1129 |
|
| 1130 |
-
|
| 1131 |
-
|
| 1132 |
-
|
| 1133 |
-
|
| 1134 |
-
|
| 1135 |
-
|
| 1136 |
-
|
| 1137 |
-
|
| 1138 |
-
|
| 1139 |
-
|
| 1140 |
-
|
| 1141 |
-
|
| 1142 |
-
|
| 1143 |
-
|
| 1144 |
-
|
| 1145 |
-
|
| 1146 |
-
|
| 1147 |
-
|
| 1148 |
-
|
| 1149 |
-
|
| 1150 |
-
|
| 1151 |
-
|
| 1152 |
-
|
| 1153 |
-
|
| 1154 |
-
|
| 1155 |
-
|
| 1156 |
-
|
| 1157 |
|
| 1158 |
|
| 1159 |
|
|
|
|
|
|
|
|
|
|
| 1160 |
|
| 1161 |
-
|
| 1162 |
-
st.subheader("Threshold analysis")
|
| 1163 |
|
| 1164 |
-
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
st.session_state.proba_last =
|
| 1170 |
-
|
| 1171 |
-
|
| 1172 |
-
|
| 1173 |
-
|
| 1174 |
-
|
| 1175 |
-
|
| 1176 |
-
|
| 1177 |
-
|
| 1178 |
-
|
| 1179 |
-
"Balanced Accuracy": cls["balanced_accuracy"],
|
| 1180 |
-
})
|
| 1181 |
|
| 1182 |
|
| 1183 |
|
| 1184 |
|
| 1185 |
|
| 1186 |
-
|
| 1187 |
-
|
| 1188 |
-
|
| 1189 |
-
if st.session_state.get("pipe") is not None:
|
| 1190 |
-
st.divider()
|
| 1191 |
-
st.subheader("Publish trained model to Hugging Face Hub")
|
| 1192 |
-
|
| 1193 |
-
default_version = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
|
| 1194 |
-
version_tag = st.text_input(
|
| 1195 |
-
"Version tag",
|
| 1196 |
-
value=default_version,
|
| 1197 |
-
help="Used as releases/<version>/ in the model repository",
|
| 1198 |
-
)
|
| 1199 |
|
| 1200 |
-
|
| 1201 |
-
|
| 1202 |
-
|
| 1203 |
-
|
| 1204 |
-
|
| 1205 |
-
|
| 1206 |
-
|
| 1207 |
-
|
| 1208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1209 |
|
| 1210 |
|
| 1211 |
# ---------------- PREDICT ----------------
|
|
@@ -1490,8 +1485,7 @@ with tab_predict:
|
|
| 1490 |
TAB_MANAGED_FIELDS.add(NGS_COUNT_COL)
|
| 1491 |
|
| 1492 |
# Age + key dates are handled by DOB/Dx inputs, not generic UI
|
| 1493 |
-
|
| 1494 |
-
TAB_MANAGED_FIELDS.add(DX_DATE_FEATURE)
|
| 1495 |
if "Date of 1st CR" in feature_cols:
|
| 1496 |
TAB_MANAGED_FIELDS.add("Date of 1st CR")
|
| 1497 |
|
|
@@ -1613,7 +1607,8 @@ with tab_predict:
|
|
| 1613 |
continue
|
| 1614 |
|
| 1615 |
# Age auto-calc (display integer, store float)
|
| 1616 |
-
|
|
|
|
| 1617 |
if np.isnan(derived_age):
|
| 1618 |
st.number_input(
|
| 1619 |
f"{f} (auto from DOB & Dx date)",
|
|
@@ -1635,28 +1630,32 @@ with tab_predict:
|
|
| 1635 |
)
|
| 1636 |
values_by_index[i] = float(derived_age)
|
| 1637 |
continue
|
| 1638 |
-
|
|
|
|
|
|
|
| 1639 |
values_by_index[i] = np.nan if dx_date is None else dx_date.isoformat()
|
| 1640 |
-
st.text_input(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1641 |
continue
|
|
|
|
| 1642 |
|
| 1643 |
if f.strip() == "Date of 1st CR".strip():
|
| 1644 |
values_by_index[i] = np.nan if cr1_date is None else cr1_date.isoformat()
|
| 1645 |
-
st.text_input(f, value="" if cr1_date is None else cr1_date.isoformat(), disabled=True, key=f"sp_{i}_cr_show")
|
| 1646 |
-
continue
|
| 1647 |
-
|
| 1648 |
-
|
| 1649 |
-
# Dx date text (store string)
|
| 1650 |
-
if f.strip() == DX_DATE_FEATURE.strip():
|
| 1651 |
st.text_input(
|
| 1652 |
f"{f} (auto)",
|
| 1653 |
-
value="" if
|
| 1654 |
-
key=f"sp_{i}_dx",
|
| 1655 |
disabled=True,
|
|
|
|
| 1656 |
)
|
| 1657 |
-
values_by_index[i] = np.nan if dx_date is None else str(dx_date)
|
| 1658 |
continue
|
| 1659 |
|
|
|
|
|
|
|
|
|
|
| 1660 |
# ECOG mapped to int
|
| 1661 |
if f.strip() == "ECOG":
|
| 1662 |
values_by_index[i] = int(ecog)
|
|
|
|
| 30 |
|
| 31 |
#Figures setting block
|
| 32 |
import io
|
| 33 |
+
|
| 34 |
|
| 35 |
# REPLACE make_fig with this (or add this and stop using plt.plot directly)
|
| 36 |
def make_fig(figsize=(5.5, 3.6), dpi=120):
|
|
|
|
| 931 |
if "explainer" not in st.session_state:
|
| 932 |
st.session_state.explainer = None
|
| 933 |
|
| 934 |
+
|
| 935 |
+
# ---------------- TRAIN ----------------
|
| 936 |
with tab_train:
|
| 937 |
st.subheader("Train model")
|
| 938 |
|
|
|
|
| 964 |
st.divider()
|
| 965 |
|
| 966 |
# then keep your file uploader + training button + publish block here
|
| 967 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 968 |
|
|
|
|
|
|
|
| 969 |
|
| 970 |
+
train_file = st.file_uploader("Upload training Excel (.xlsx)", type=["xlsx"])
|
|
|
|
|
|
|
|
|
|
| 971 |
|
| 972 |
+
if train_file is None:
|
| 973 |
+
st.info("Upload a training Excel file to enable training.")
|
| 974 |
+
else:
|
| 975 |
|
| 976 |
+
df = pd.read_excel(train_file, engine="openpyxl")
|
| 977 |
+
feature_cols = get_feature_cols_from_df(df)
|
|
|
|
| 978 |
|
| 979 |
+
st.dataframe(df.head(), use_container_width=True)
|
| 980 |
+
feature_cols = get_feature_cols_from_df(df)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 981 |
|
| 982 |
+
st.markdown("### Choose variable types (saved into the model)")
|
| 983 |
+
default_numeric = feature_cols[:13] # initial suggestion
|
| 984 |
+
num_cols = st.multiselect(
|
| 985 |
+
"Numeric variables (will be median-imputed + scaled)",
|
| 986 |
+
options=feature_cols,
|
| 987 |
+
default=default_numeric
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
# Everything not selected as numeric becomes categorical
|
| 991 |
+
cat_cols = [c for c in feature_cols if c not in num_cols]
|
| 992 |
|
| 993 |
+
st.write(f"Categorical variables (will be most-frequent-imputed + one-hot): {len(cat_cols)}")
|
| 994 |
+
st.caption("Note: The selected schema is stored with the trained model and must match inference files.")
|
| 995 |
|
| 996 |
+
st.markdown("### Evaluation settings")
|
| 997 |
+
n_bins = st.slider("Calibration bins", 5, 20, 10, 1)
|
| 998 |
+
cal_strategy = st.selectbox("Calibration binning strategy", ["uniform", "quantile"], index=0)
|
| 999 |
+
|
| 1000 |
+
dca_points = st.slider("Decision curve points", 25, 200, 99, 1)
|
| 1001 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1002 |
|
|
|
|
| 1003 |
|
| 1004 |
+
if st.button("Train model"):
|
| 1005 |
+
with st.spinner("Training model..."):
|
| 1006 |
+
pipe, meta, X_bg, y_test, proba = train_and_save(
|
| 1007 |
+
df, feature_cols, num_cols, cat_cols,
|
| 1008 |
+
n_bins=n_bins, cal_strategy=cal_strategy, dca_points=dca_points,
|
| 1009 |
+
use_feature_selection=use_feature_selection,
|
| 1010 |
+
l1_C=l1_C,
|
| 1011 |
+
use_dimred=use_dimred,
|
| 1012 |
+
svd_components=svd_components
|
| 1013 |
+
)
|
| 1014 |
+
|
| 1015 |
+
explainer = build_shap_explainer(pipe, X_bg)
|
| 1016 |
+
|
| 1017 |
+
st.session_state.pipe = pipe
|
| 1018 |
+
st.session_state.explainer = explainer
|
| 1019 |
+
st.session_state.meta = meta
|
| 1020 |
+
|
| 1021 |
+
st.success("Training complete. model.joblib and meta.json created.")
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
st.divider()
|
| 1025 |
+
st.subheader("Training performance (test split)")
|
| 1026 |
+
|
| 1027 |
+
m = meta["metrics"]
|
| 1028 |
+
|
| 1029 |
+
# Show key metrics at threshold 0.5
|
| 1030 |
+
c1, c2, c3, c4 = st.columns(4)
|
| 1031 |
+
c1.metric("ROC AUC", f"{m['roc_auc']:.3f}")
|
| 1032 |
+
c2.metric("Sensitivity (best F1 thr)", f"{m['sensitivity@best']:.3f}")
|
| 1033 |
+
c3.metric("Specificity (best F1 thr)", f"{m['specificity@best']:.3f}")
|
| 1034 |
+
c4.metric("F1 (best)", f"{m['f1@best']:.3f}")
|
| 1035 |
+
st.caption(f"Best threshold (max F1): {m['best_threshold']:.2f}")
|
| 1036 |
+
|
| 1037 |
+
|
| 1038 |
+
c5, c6, c7, c8 = st.columns(4)
|
| 1039 |
+
c5.metric("Precision", f"{m['precision@0.5']:.3f}")
|
| 1040 |
+
c6.metric("Accuracy", f"{m['accuracy@0.5']:.3f}")
|
| 1041 |
+
c7.metric("Balanced Acc", f"{m['balanced_accuracy@0.5']:.3f}")
|
| 1042 |
+
c8.metric("Test N", m["n_test"])
|
| 1043 |
|
| 1044 |
|
| 1045 |
|
| 1046 |
|
| 1047 |
|
| 1048 |
+
# Confusion matrix display
|
| 1049 |
+
cm = m["confusion_matrix@0.5"]
|
| 1050 |
+
cm_df = pd.DataFrame(
|
| 1051 |
+
[[cm["tn"], cm["fp"]], [cm["fn"], cm["tp"]]],
|
| 1052 |
+
index=["Actual 0", "Actual 1"],
|
| 1053 |
+
columns=["Pred 0", "Pred 1"]
|
| 1054 |
+
)
|
| 1055 |
+
st.markdown("**Confusion Matrix (threshold = 0.5)**")
|
| 1056 |
+
st.dataframe(cm_df)
|
| 1057 |
|
| 1058 |
|
| 1059 |
|
| 1060 |
|
| 1061 |
|
| 1062 |
+
# TRAINING: ROC curve plot
|
| 1063 |
+
# =========================
|
| 1064 |
+
roc = m["roc_curve"]
|
| 1065 |
+
fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1066 |
+
ax.plot(roc["fpr"], roc["tpr"])
|
| 1067 |
+
ax.plot([0, 1], [0, 1])
|
| 1068 |
+
ax.set_xlabel("False Positive Rate (1 - Specificity)")
|
| 1069 |
+
ax.set_ylabel("True Positive Rate (Sensitivity)")
|
| 1070 |
+
ax.set_title(f"ROC Curve (AUC = {m['roc_auc']:.3f})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1071 |
|
| 1072 |
+
render_plot_with_download(
|
| 1073 |
+
fig,
|
| 1074 |
+
title="ROC curve",
|
| 1075 |
+
filename="roc_curve.png",
|
| 1076 |
+
export_dpi=export_dpi,
|
| 1077 |
+
key="dl_train_roc"
|
| 1078 |
+
)
|
| 1079 |
+
|
| 1080 |
+
|
| 1081 |
+
#Precision recall curve
|
| 1082 |
+
# =========================
|
| 1083 |
+
# TRAINING: PR curve plot
|
| 1084 |
+
# =========================
|
| 1085 |
+
pr = m["pr_curve"]
|
| 1086 |
+
fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1087 |
+
ax.plot(pr["recall"], pr["precision"])
|
| 1088 |
+
ax.set_xlabel("Recall")
|
| 1089 |
+
ax.set_ylabel("Precision")
|
| 1090 |
+
ax.set_title(f"PR Curve (AP = {pr['average_precision']:.3f})")
|
| 1091 |
+
|
| 1092 |
+
render_plot_with_download(
|
| 1093 |
+
fig,
|
| 1094 |
+
title="PR curve",
|
| 1095 |
+
filename="pr_curve.png",
|
| 1096 |
+
export_dpi=export_dpi,
|
| 1097 |
+
key="dl_train_pr"
|
| 1098 |
+
)
|
| 1099 |
|
| 1100 |
|
| 1101 |
|
| 1102 |
+
#Calibration plot
|
| 1103 |
+
# =========================
|
| 1104 |
+
# TRAINING: Calibration plot
|
| 1105 |
+
# =========================
|
| 1106 |
+
cal = m["calibration"]
|
| 1107 |
+
fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1108 |
+
ax.plot(cal["prob_pred"], cal["prob_true"])
|
| 1109 |
+
ax.plot([0, 1], [0, 1])
|
| 1110 |
+
ax.set_xlabel("Mean predicted probability")
|
| 1111 |
+
ax.set_ylabel("Observed event rate")
|
| 1112 |
+
ax.set_title("Calibration curve")
|
| 1113 |
+
|
| 1114 |
+
render_plot_with_download(
|
| 1115 |
+
fig,
|
| 1116 |
+
title="Calibration curve",
|
| 1117 |
+
filename="calibration_curve.png",
|
| 1118 |
+
export_dpi=export_dpi,
|
| 1119 |
+
key="dl_train_cal"
|
| 1120 |
+
)
|
| 1121 |
|
| 1122 |
|
| 1123 |
|
| 1124 |
+
#Decision curve
|
| 1125 |
+
# =========================
|
| 1126 |
+
# TRAINING: Decision curve analysis plot
|
| 1127 |
+
# =========================
|
| 1128 |
+
dca = m["decision_curve"]
|
| 1129 |
+
fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
|
| 1130 |
+
ax.plot(dca["thresholds"], dca["net_benefit_model"], label="Model")
|
| 1131 |
+
ax.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all")
|
| 1132 |
+
ax.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none")
|
| 1133 |
+
ax.set_xlabel("Threshold probability")
|
| 1134 |
+
ax.set_ylabel("Net benefit")
|
| 1135 |
+
ax.set_title("Decision curve analysis")
|
| 1136 |
+
ax.legend()
|
| 1137 |
+
|
| 1138 |
+
render_plot_with_download(
|
| 1139 |
+
fig,
|
| 1140 |
+
title="Decision curve",
|
| 1141 |
+
filename="decision_curve.png",
|
| 1142 |
+
export_dpi=export_dpi,
|
| 1143 |
+
key="dl_train_dca"
|
| 1144 |
+
)
|
| 1145 |
+
|
| 1146 |
+
|
| 1147 |
+
st.caption(
|
| 1148 |
+
"If the model curve is above Treat-all and Treat-none across a threshold range, "
|
| 1149 |
+
"the model provides net clinical benefit in that range."
|
| 1150 |
+
)
|
| 1151 |
|
| 1152 |
|
| 1153 |
|
| 1154 |
+
|
| 1155 |
+
st.divider()
|
| 1156 |
+
st.subheader("Threshold analysis")
|
| 1157 |
|
| 1158 |
+
thr = st.slider("Decision threshold", 0.0, 1.0, 0.5, 0.01)
|
|
|
|
| 1159 |
|
| 1160 |
+
# Recompute threshold-based metrics quickly using stored probabilities
|
| 1161 |
+
# You need y_test and proba in scope. Easiest is to store them in session_state during training.
|
| 1162 |
+
st.session_state.y_test_last = y_test
|
| 1163 |
+
st.session_state.proba_last = proba
|
| 1164 |
+
if "y_test_last" in st.session_state and "proba_last" in st.session_state:
|
| 1165 |
+
cls = compute_classification_metrics(st.session_state.y_test_last, st.session_state.proba_last, threshold=thr)
|
| 1166 |
+
st.write({
|
| 1167 |
+
"Sensitivity": cls["sensitivity"],
|
| 1168 |
+
"Specificity": cls["specificity"],
|
| 1169 |
+
"Precision": cls["precision"],
|
| 1170 |
+
"Recall": cls["recall"],
|
| 1171 |
+
"F1": cls["f1"],
|
| 1172 |
+
"Accuracy": cls["accuracy"],
|
| 1173 |
+
"Balanced Accuracy": cls["balanced_accuracy"],
|
| 1174 |
+
})
|
|
|
|
|
|
|
| 1175 |
|
| 1176 |
|
| 1177 |
|
| 1178 |
|
| 1179 |
|
| 1180 |
+
# ---------------- PUBLISH (only after training) ----------------
|
| 1181 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1182 |
|
| 1183 |
+
if st.session_state.get("pipe") is not None:
|
| 1184 |
+
st.divider()
|
| 1185 |
+
st.subheader("Publish trained model to Hugging Face Hub")
|
| 1186 |
+
|
| 1187 |
+
default_version = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
|
| 1188 |
+
version_tag = st.text_input(
|
| 1189 |
+
"Version tag",
|
| 1190 |
+
value=default_version,
|
| 1191 |
+
help="Used as releases/<version>/ in the model repository",
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
if st.button("Publish model.joblib + meta.json to Model Repo"):
|
| 1195 |
+
try:
|
| 1196 |
+
with st.spinner("Uploading to Hugging Face Model repo..."):
|
| 1197 |
+
paths = publish_to_hub(MODEL_REPO_ID, version_tag)
|
| 1198 |
+
|
| 1199 |
+
st.success("Uploaded successfully to your model repository.")
|
| 1200 |
+
st.json(paths)
|
| 1201 |
+
except Exception as e:
|
| 1202 |
+
st.error(f"Upload failed: {e}")
|
| 1203 |
+
|
| 1204 |
|
| 1205 |
|
| 1206 |
# ---------------- PREDICT ----------------
|
|
|
|
| 1485 |
TAB_MANAGED_FIELDS.add(NGS_COUNT_COL)
|
| 1486 |
|
| 1487 |
# Age + key dates are handled by DOB/Dx inputs, not generic UI
|
| 1488 |
+
|
|
|
|
| 1489 |
if "Date of 1st CR" in feature_cols:
|
| 1490 |
TAB_MANAGED_FIELDS.add("Date of 1st CR")
|
| 1491 |
|
|
|
|
| 1607 |
continue
|
| 1608 |
|
| 1609 |
# Age auto-calc (display integer, store float)
|
| 1610 |
+
# --- Age (auto from DOB & Dx date) ---
|
| 1611 |
+
if f.strip() == AGE_FEATURE.strip():
|
| 1612 |
if np.isnan(derived_age):
|
| 1613 |
st.number_input(
|
| 1614 |
f"{f} (auto from DOB & Dx date)",
|
|
|
|
| 1630 |
)
|
| 1631 |
values_by_index[i] = float(derived_age)
|
| 1632 |
continue
|
| 1633 |
+
|
| 1634 |
+
# --- Diagnosis date (auto from dx_date input) ---
|
| 1635 |
+
if f.strip() == DX_DATE_FEATURE.strip():
|
| 1636 |
values_by_index[i] = np.nan if dx_date is None else dx_date.isoformat()
|
| 1637 |
+
st.text_input(
|
| 1638 |
+
f"{f} (auto)",
|
| 1639 |
+
value="" if dx_date is None else dx_date.isoformat(),
|
| 1640 |
+
disabled=True,
|
| 1641 |
+
key=f"sp_{i}_dx_show"
|
| 1642 |
+
)
|
| 1643 |
continue
|
| 1644 |
+
|
| 1645 |
|
| 1646 |
if f.strip() == "Date of 1st CR".strip():
|
| 1647 |
values_by_index[i] = np.nan if cr1_date is None else cr1_date.isoformat()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1648 |
st.text_input(
|
| 1649 |
f"{f} (auto)",
|
| 1650 |
+
value="" if cr1_date is None else cr1_date.isoformat(),
|
|
|
|
| 1651 |
disabled=True,
|
| 1652 |
+
key=f"sp_{i}_cr_show"
|
| 1653 |
)
|
|
|
|
| 1654 |
continue
|
| 1655 |
|
| 1656 |
+
|
| 1657 |
+
|
| 1658 |
+
|
| 1659 |
# ECOG mapped to int
|
| 1660 |
if f.strip() == "ECOG":
|
| 1661 |
values_by_index[i] = int(ecog)
|