Synav committed on
Commit
529c5f8
·
verified ·
1 Parent(s): fc931be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -221
app.py CHANGED
@@ -30,7 +30,7 @@ from sklearn.model_selection import train_test_split
30
 
31
  #Figures setting block
32
  import io
33
- import matplotlib.pyplot as plt
34
 
35
  # REPLACE make_fig with this (or add this and stop using plt.plot directly)
36
  def make_fig(figsize=(5.5, 3.6), dpi=120):
@@ -931,6 +931,8 @@ if "pipe" not in st.session_state:
931
  if "explainer" not in st.session_state:
932
  st.session_state.explainer = None
933
 
 
 
934
  with tab_train:
935
  st.subheader("Train model")
936
 
@@ -962,250 +964,243 @@ with tab_train:
962
  st.divider()
963
 
964
  # then keep your file uploader + training button + publish block here
965
- ...
966
-
967
-
968
- # ---------------- TRAIN ----------------
969
-
970
-
971
 
972
- with tab_train:
973
- st.subheader("Train model")
974
 
975
- if not is_admin():
976
- st.info("Training and publishing are restricted. Use Predict + SHAP for inference.")
977
- else:
978
- train_file = st.file_uploader("Upload training Excel (.xlsx)", type=["xlsx"])
979
 
980
- if train_file is not None:
981
- df = pd.read_excel(train_file, engine="openpyxl")
982
-
983
 
984
- feature_cols = get_feature_cols_from_df(df)
985
- st.dataframe(df.head())
986
- feature_cols = get_feature_cols_from_df(df)
987
 
988
- st.markdown("### Choose variable types (saved into the model)")
989
- default_numeric = feature_cols[:13] # initial suggestion
990
-
991
- num_cols = st.multiselect(
992
- "Numeric variables (will be median-imputed + scaled)",
993
- options=feature_cols,
994
- default=default_numeric
995
- )
996
-
997
- # Everything not selected as numeric becomes categorical
998
- cat_cols = [c for c in feature_cols if c not in num_cols]
999
-
1000
- st.write(f"Categorical variables (will be most-frequent-imputed + one-hot): {len(cat_cols)}")
1001
- st.caption("Note: The selected schema is stored with the trained model and must match inference files.")
1002
 
1003
- st.markdown("### Evaluation settings")
1004
- n_bins = st.slider("Calibration bins", 5, 20, 10, 1)
1005
- cal_strategy = st.selectbox("Calibration binning strategy", ["uniform", "quantile"], index=0)
 
 
 
 
 
 
 
1006
 
1007
- dca_points = st.slider("Decision curve points", 25, 200, 99, 1)
1008
-
1009
 
1010
- if st.button("Train model"):
1011
- with st.spinner("Training model..."):
1012
- pipe, meta, X_bg, y_test, proba = train_and_save(
1013
- df, feature_cols, num_cols, cat_cols,
1014
- n_bins=n_bins, cal_strategy=cal_strategy, dca_points=dca_points,
1015
- use_feature_selection=use_feature_selection,
1016
- l1_C=l1_C,
1017
- use_dimred=use_dimred,
1018
- svd_components=svd_components
1019
- )
1020
-
1021
- explainer = build_shap_explainer(pipe, X_bg)
1022
-
1023
- st.session_state.pipe = pipe
1024
- st.session_state.explainer = explainer
1025
- st.session_state.meta = meta
1026
 
1027
- st.success("Training complete. model.joblib and meta.json created.")
1028
 
1029
-
1030
- st.divider()
1031
- st.subheader("Training performance (test split)")
1032
-
1033
- m = meta["metrics"]
1034
-
1035
- # Show key metrics at threshold 0.5
1036
- c1, c2, c3, c4 = st.columns(4)
1037
- c1.metric("ROC AUC", f"{m['roc_auc']:.3f}")
1038
- c2.metric("Sensitivity (best F1 thr)", f"{m['sensitivity@best']:.3f}")
1039
- c3.metric("Specificity (best F1 thr)", f"{m['specificity@best']:.3f}")
1040
- c4.metric("F1 (best)", f"{m['f1@best']:.3f}")
1041
- st.caption(f"Best threshold (max F1): {m['best_threshold']:.2f}")
1042
-
1043
-
1044
- c5, c6, c7, c8 = st.columns(4)
1045
- c5.metric("Precision", f"{m['precision@0.5']:.3f}")
1046
- c6.metric("Accuracy", f"{m['accuracy@0.5']:.3f}")
1047
- c7.metric("Balanced Acc", f"{m['balanced_accuracy@0.5']:.3f}")
1048
- c8.metric("Test N", m["n_test"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1049
 
1050
 
1051
 
1052
 
1053
 
1054
- # Confusion matrix display
1055
- cm = m["confusion_matrix@0.5"]
1056
- cm_df = pd.DataFrame(
1057
- [[cm["tn"], cm["fp"]], [cm["fn"], cm["tp"]]],
1058
- index=["Actual 0", "Actual 1"],
1059
- columns=["Pred 0", "Pred 1"]
1060
- )
1061
- st.markdown("**Confusion Matrix (threshold = 0.5)**")
1062
- st.dataframe(cm_df)
1063
 
1064
 
1065
 
1066
 
1067
 
1068
- # TRAINING: ROC curve plot
1069
- # =========================
1070
- roc = m["roc_curve"]
1071
- fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1072
- ax.plot(roc["fpr"], roc["tpr"])
1073
- ax.plot([0, 1], [0, 1])
1074
- ax.set_xlabel("False Positive Rate (1 - Specificity)")
1075
- ax.set_ylabel("True Positive Rate (Sensitivity)")
1076
- ax.set_title(f"ROC Curve (AUC = {m['roc_auc']:.3f})")
1077
-
1078
- render_plot_with_download(
1079
- fig,
1080
- title="ROC curve",
1081
- filename="roc_curve.png",
1082
- export_dpi=export_dpi,
1083
- key="dl_train_roc"
1084
- )
1085
-
1086
-
1087
- #Precision recall curve
1088
- # =========================
1089
- # TRAINING: PR curve plot
1090
- # =========================
1091
- pr = m["pr_curve"]
1092
- fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1093
- ax.plot(pr["recall"], pr["precision"])
1094
- ax.set_xlabel("Recall")
1095
- ax.set_ylabel("Precision")
1096
- ax.set_title(f"PR Curve (AP = {pr['average_precision']:.3f})")
1097
 
1098
- render_plot_with_download(
1099
- fig,
1100
- title="PR curve",
1101
- filename="pr_curve.png",
1102
- export_dpi=export_dpi,
1103
- key="dl_train_pr"
1104
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1105
 
1106
 
1107
 
1108
- #Calibration plot
1109
- # =========================
1110
- # TRAINING: Calibration plot
1111
- # =========================
1112
- cal = m["calibration"]
1113
- fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1114
- ax.plot(cal["prob_pred"], cal["prob_true"])
1115
- ax.plot([0, 1], [0, 1])
1116
- ax.set_xlabel("Mean predicted probability")
1117
- ax.set_ylabel("Observed event rate")
1118
- ax.set_title("Calibration curve")
1119
-
1120
- render_plot_with_download(
1121
- fig,
1122
- title="Calibration curve",
1123
- filename="calibration_curve.png",
1124
- export_dpi=export_dpi,
1125
- key="dl_train_cal"
1126
- )
1127
 
1128
 
1129
 
1130
- #Decision curve
1131
- # =========================
1132
- # TRAINING: Decision curve analysis plot
1133
- # =========================
1134
- dca = m["decision_curve"]
1135
- fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1136
- ax.plot(dca["thresholds"], dca["net_benefit_model"], label="Model")
1137
- ax.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all")
1138
- ax.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none")
1139
- ax.set_xlabel("Threshold probability")
1140
- ax.set_ylabel("Net benefit")
1141
- ax.set_title("Decision curve analysis")
1142
- ax.legend()
1143
-
1144
- render_plot_with_download(
1145
- fig,
1146
- title="Decision curve",
1147
- filename="decision_curve.png",
1148
- export_dpi=export_dpi,
1149
- key="dl_train_dca"
1150
- )
1151
-
1152
-
1153
- st.caption(
1154
- "If the model curve is above Treat-all and Treat-none across a threshold range, "
1155
- "the model provides net clinical benefit in that range."
1156
- )
1157
 
1158
 
1159
 
 
 
 
1160
 
1161
- st.divider()
1162
- st.subheader("Threshold analysis")
1163
 
1164
- thr = st.slider("Decision threshold", 0.0, 1.0, 0.5, 0.01)
1165
-
1166
- # Recompute threshold-based metrics quickly using stored probabilities
1167
- # You need y_test and proba in scope. Easiest is to store them in session_state during training.
1168
- st.session_state.y_test_last = y_test
1169
- st.session_state.proba_last = proba
1170
- if "y_test_last" in st.session_state and "proba_last" in st.session_state:
1171
- cls = compute_classification_metrics(st.session_state.y_test_last, st.session_state.proba_last, threshold=thr)
1172
- st.write({
1173
- "Sensitivity": cls["sensitivity"],
1174
- "Specificity": cls["specificity"],
1175
- "Precision": cls["precision"],
1176
- "Recall": cls["recall"],
1177
- "F1": cls["f1"],
1178
- "Accuracy": cls["accuracy"],
1179
- "Balanced Accuracy": cls["balanced_accuracy"],
1180
- })
1181
 
1182
 
1183
 
1184
 
1185
 
1186
- # ---------------- PUBLISH (only after training) ----------------
1187
-
1188
-
1189
- if st.session_state.get("pipe") is not None:
1190
- st.divider()
1191
- st.subheader("Publish trained model to Hugging Face Hub")
1192
-
1193
- default_version = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
1194
- version_tag = st.text_input(
1195
- "Version tag",
1196
- value=default_version,
1197
- help="Used as releases/<version>/ in the model repository",
1198
- )
1199
 
1200
- if st.button("Publish model.joblib + meta.json to Model Repo"):
1201
- try:
1202
- with st.spinner("Uploading to Hugging Face Model repo..."):
1203
- paths = publish_to_hub(MODEL_REPO_ID, version_tag)
1204
-
1205
- st.success("Uploaded successfully to your model repository.")
1206
- st.json(paths)
1207
- except Exception as e:
1208
- st.error(f"Upload failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
1209
 
1210
 
1211
  # ---------------- PREDICT ----------------
@@ -1490,8 +1485,7 @@ with tab_predict:
1490
  TAB_MANAGED_FIELDS.add(NGS_COUNT_COL)
1491
 
1492
  # Age + key dates are handled by DOB/Dx inputs, not generic UI
1493
- TAB_MANAGED_FIELDS.add(AGE_FEATURE)
1494
- TAB_MANAGED_FIELDS.add(DX_DATE_FEATURE)
1495
  if "Date of 1st CR" in feature_cols:
1496
  TAB_MANAGED_FIELDS.add("Date of 1st CR")
1497
 
@@ -1613,7 +1607,8 @@ with tab_predict:
1613
  continue
1614
 
1615
  # Age auto-calc (display integer, store float)
1616
- if f == AGE_FEATURE:
 
1617
  if np.isnan(derived_age):
1618
  st.number_input(
1619
  f"{f} (auto from DOB & Dx date)",
@@ -1635,28 +1630,32 @@ with tab_predict:
1635
  )
1636
  values_by_index[i] = float(derived_age)
1637
  continue
1638
- if f.strip() == "Date of 1st Bone Marrow biopsy (Date of Diagnosis)".strip():
 
 
1639
  values_by_index[i] = np.nan if dx_date is None else dx_date.isoformat()
1640
- st.text_input(f, value="" if dx_date is None else dx_date.isoformat(), disabled=True, key=f"sp_{i}_dx_show")
 
 
 
 
 
1641
  continue
 
1642
 
1643
  if f.strip() == "Date of 1st CR".strip():
1644
  values_by_index[i] = np.nan if cr1_date is None else cr1_date.isoformat()
1645
- st.text_input(f, value="" if cr1_date is None else cr1_date.isoformat(), disabled=True, key=f"sp_{i}_cr_show")
1646
- continue
1647
-
1648
-
1649
- # Dx date text (store string)
1650
- if f.strip() == DX_DATE_FEATURE.strip():
1651
  st.text_input(
1652
  f"{f} (auto)",
1653
- value="" if dx_date is None else str(dx_date),
1654
- key=f"sp_{i}_dx",
1655
  disabled=True,
 
1656
  )
1657
- values_by_index[i] = np.nan if dx_date is None else str(dx_date)
1658
  continue
1659
 
 
 
 
1660
  # ECOG mapped to int
1661
  if f.strip() == "ECOG":
1662
  values_by_index[i] = int(ecog)
 
30
 
31
  #Figures setting block
32
  import io
33
+
34
 
35
  # REPLACE make_fig with this (or add this and stop using plt.plot directly)
36
  def make_fig(figsize=(5.5, 3.6), dpi=120):
 
931
  if "explainer" not in st.session_state:
932
  st.session_state.explainer = None
933
 
934
+
935
+ # ---------------- TRAIN ----------------
936
  with tab_train:
937
  st.subheader("Train model")
938
 
 
964
  st.divider()
965
 
966
  # then keep your file uploader + training button + publish block here
967
+
 
 
 
 
 
968
 
 
 
969
 
970
+ train_file = st.file_uploader("Upload training Excel (.xlsx)", type=["xlsx"])
 
 
 
971
 
972
+ if train_file is None:
973
+ st.info("Upload a training Excel file to enable training.")
974
+ else:
975
 
976
+ df = pd.read_excel(train_file, engine="openpyxl")
977
+ feature_cols = get_feature_cols_from_df(df)
 
978
 
979
+ st.dataframe(df.head(), use_container_width=True)
980
+ feature_cols = get_feature_cols_from_df(df)
 
 
 
 
 
 
 
 
 
 
 
 
981
 
982
+ st.markdown("### Choose variable types (saved into the model)")
983
+ default_numeric = feature_cols[:13] # initial suggestion
984
+ num_cols = st.multiselect(
985
+ "Numeric variables (will be median-imputed + scaled)",
986
+ options=feature_cols,
987
+ default=default_numeric
988
+ )
989
+
990
+ # Everything not selected as numeric becomes categorical
991
+ cat_cols = [c for c in feature_cols if c not in num_cols]
992
 
993
+ st.write(f"Categorical variables (will be most-frequent-imputed + one-hot): {len(cat_cols)}")
994
+ st.caption("Note: The selected schema is stored with the trained model and must match inference files.")
995
 
996
+ st.markdown("### Evaluation settings")
997
+ n_bins = st.slider("Calibration bins", 5, 20, 10, 1)
998
+ cal_strategy = st.selectbox("Calibration binning strategy", ["uniform", "quantile"], index=0)
999
+
1000
+ dca_points = st.slider("Decision curve points", 25, 200, 99, 1)
1001
+
 
 
 
 
 
 
 
 
 
 
1002
 
 
1003
 
1004
+ if st.button("Train model"):
1005
+ with st.spinner("Training model..."):
1006
+ pipe, meta, X_bg, y_test, proba = train_and_save(
1007
+ df, feature_cols, num_cols, cat_cols,
1008
+ n_bins=n_bins, cal_strategy=cal_strategy, dca_points=dca_points,
1009
+ use_feature_selection=use_feature_selection,
1010
+ l1_C=l1_C,
1011
+ use_dimred=use_dimred,
1012
+ svd_components=svd_components
1013
+ )
1014
+
1015
+ explainer = build_shap_explainer(pipe, X_bg)
1016
+
1017
+ st.session_state.pipe = pipe
1018
+ st.session_state.explainer = explainer
1019
+ st.session_state.meta = meta
1020
+
1021
+ st.success("Training complete. model.joblib and meta.json created.")
1022
+
1023
+
1024
+ st.divider()
1025
+ st.subheader("Training performance (test split)")
1026
+
1027
+ m = meta["metrics"]
1028
+
1029
+ # Show key metrics at threshold 0.5
1030
+ c1, c2, c3, c4 = st.columns(4)
1031
+ c1.metric("ROC AUC", f"{m['roc_auc']:.3f}")
1032
+ c2.metric("Sensitivity (best F1 thr)", f"{m['sensitivity@best']:.3f}")
1033
+ c3.metric("Specificity (best F1 thr)", f"{m['specificity@best']:.3f}")
1034
+ c4.metric("F1 (best)", f"{m['f1@best']:.3f}")
1035
+ st.caption(f"Best threshold (max F1): {m['best_threshold']:.2f}")
1036
+
1037
+
1038
+ c5, c6, c7, c8 = st.columns(4)
1039
+ c5.metric("Precision", f"{m['precision@0.5']:.3f}")
1040
+ c6.metric("Accuracy", f"{m['accuracy@0.5']:.3f}")
1041
+ c7.metric("Balanced Acc", f"{m['balanced_accuracy@0.5']:.3f}")
1042
+ c8.metric("Test N", m["n_test"])
1043
 
1044
 
1045
 
1046
 
1047
 
1048
+ # Confusion matrix display
1049
+ cm = m["confusion_matrix@0.5"]
1050
+ cm_df = pd.DataFrame(
1051
+ [[cm["tn"], cm["fp"]], [cm["fn"], cm["tp"]]],
1052
+ index=["Actual 0", "Actual 1"],
1053
+ columns=["Pred 0", "Pred 1"]
1054
+ )
1055
+ st.markdown("**Confusion Matrix (threshold = 0.5)**")
1056
+ st.dataframe(cm_df)
1057
 
1058
 
1059
 
1060
 
1061
 
1062
+ # TRAINING: ROC curve plot
1063
+ # =========================
1064
+ roc = m["roc_curve"]
1065
+ fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1066
+ ax.plot(roc["fpr"], roc["tpr"])
1067
+ ax.plot([0, 1], [0, 1])
1068
+ ax.set_xlabel("False Positive Rate (1 - Specificity)")
1069
+ ax.set_ylabel("True Positive Rate (Sensitivity)")
1070
+ ax.set_title(f"ROC Curve (AUC = {m['roc_auc']:.3f})")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1071
 
1072
+ render_plot_with_download(
1073
+ fig,
1074
+ title="ROC curve",
1075
+ filename="roc_curve.png",
1076
+ export_dpi=export_dpi,
1077
+ key="dl_train_roc"
1078
+ )
1079
+
1080
+
1081
+ #Precision recall curve
1082
+ # =========================
1083
+ # TRAINING: PR curve plot
1084
+ # =========================
1085
+ pr = m["pr_curve"]
1086
+ fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1087
+ ax.plot(pr["recall"], pr["precision"])
1088
+ ax.set_xlabel("Recall")
1089
+ ax.set_ylabel("Precision")
1090
+ ax.set_title(f"PR Curve (AP = {pr['average_precision']:.3f})")
1091
+
1092
+ render_plot_with_download(
1093
+ fig,
1094
+ title="PR curve",
1095
+ filename="pr_curve.png",
1096
+ export_dpi=export_dpi,
1097
+ key="dl_train_pr"
1098
+ )
1099
 
1100
 
1101
 
1102
+ #Calibration plot
1103
+ # =========================
1104
+ # TRAINING: Calibration plot
1105
+ # =========================
1106
+ cal = m["calibration"]
1107
+ fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1108
+ ax.plot(cal["prob_pred"], cal["prob_true"])
1109
+ ax.plot([0, 1], [0, 1])
1110
+ ax.set_xlabel("Mean predicted probability")
1111
+ ax.set_ylabel("Observed event rate")
1112
+ ax.set_title("Calibration curve")
1113
+
1114
+ render_plot_with_download(
1115
+ fig,
1116
+ title="Calibration curve",
1117
+ filename="calibration_curve.png",
1118
+ export_dpi=export_dpi,
1119
+ key="dl_train_cal"
1120
+ )
1121
 
1122
 
1123
 
1124
+ #Decision curve
1125
+ # =========================
1126
+ # TRAINING: Decision curve analysis plot
1127
+ # =========================
1128
+ dca = m["decision_curve"]
1129
+ fig, ax = make_fig(figsize=FIGSIZE, dpi=plot_dpi_screen)
1130
+ ax.plot(dca["thresholds"], dca["net_benefit_model"], label="Model")
1131
+ ax.plot(dca["thresholds"], dca["net_benefit_all"], label="Treat all")
1132
+ ax.plot(dca["thresholds"], dca["net_benefit_none"], label="Treat none")
1133
+ ax.set_xlabel("Threshold probability")
1134
+ ax.set_ylabel("Net benefit")
1135
+ ax.set_title("Decision curve analysis")
1136
+ ax.legend()
1137
+
1138
+ render_plot_with_download(
1139
+ fig,
1140
+ title="Decision curve",
1141
+ filename="decision_curve.png",
1142
+ export_dpi=export_dpi,
1143
+ key="dl_train_dca"
1144
+ )
1145
+
1146
+
1147
+ st.caption(
1148
+ "If the model curve is above Treat-all and Treat-none across a threshold range, "
1149
+ "the model provides net clinical benefit in that range."
1150
+ )
1151
 
1152
 
1153
 
1154
+
1155
+ st.divider()
1156
+ st.subheader("Threshold analysis")
1157
 
1158
+ thr = st.slider("Decision threshold", 0.0, 1.0, 0.5, 0.01)
 
1159
 
1160
+ # Recompute threshold-based metrics quickly using stored probabilities
1161
+ # You need y_test and proba in scope. Easiest is to store them in session_state during training.
1162
+ st.session_state.y_test_last = y_test
1163
+ st.session_state.proba_last = proba
1164
+ if "y_test_last" in st.session_state and "proba_last" in st.session_state:
1165
+ cls = compute_classification_metrics(st.session_state.y_test_last, st.session_state.proba_last, threshold=thr)
1166
+ st.write({
1167
+ "Sensitivity": cls["sensitivity"],
1168
+ "Specificity": cls["specificity"],
1169
+ "Precision": cls["precision"],
1170
+ "Recall": cls["recall"],
1171
+ "F1": cls["f1"],
1172
+ "Accuracy": cls["accuracy"],
1173
+ "Balanced Accuracy": cls["balanced_accuracy"],
1174
+ })
 
 
1175
 
1176
 
1177
 
1178
 
1179
 
1180
+ # ---------------- PUBLISH (only after training) ----------------
1181
+
 
 
 
 
 
 
 
 
 
 
 
1182
 
1183
+ if st.session_state.get("pipe") is not None:
1184
+ st.divider()
1185
+ st.subheader("Publish trained model to Hugging Face Hub")
1186
+
1187
+ default_version = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
1188
+ version_tag = st.text_input(
1189
+ "Version tag",
1190
+ value=default_version,
1191
+ help="Used as releases/<version>/ in the model repository",
1192
+ )
1193
+
1194
+ if st.button("Publish model.joblib + meta.json to Model Repo"):
1195
+ try:
1196
+ with st.spinner("Uploading to Hugging Face Model repo..."):
1197
+ paths = publish_to_hub(MODEL_REPO_ID, version_tag)
1198
+
1199
+ st.success("Uploaded successfully to your model repository.")
1200
+ st.json(paths)
1201
+ except Exception as e:
1202
+ st.error(f"Upload failed: {e}")
1203
+
1204
 
1205
 
1206
  # ---------------- PREDICT ----------------
 
1485
  TAB_MANAGED_FIELDS.add(NGS_COUNT_COL)
1486
 
1487
  # Age + key dates are handled by DOB/Dx inputs, not generic UI
1488
+
 
1489
  if "Date of 1st CR" in feature_cols:
1490
  TAB_MANAGED_FIELDS.add("Date of 1st CR")
1491
 
 
1607
  continue
1608
 
1609
  # Age auto-calc (display integer, store float)
1610
+ # --- Age (auto from DOB & Dx date) ---
1611
+ if f.strip() == AGE_FEATURE.strip():
1612
  if np.isnan(derived_age):
1613
  st.number_input(
1614
  f"{f} (auto from DOB & Dx date)",
 
1630
  )
1631
  values_by_index[i] = float(derived_age)
1632
  continue
1633
+
1634
+ # --- Diagnosis date (auto from dx_date input) ---
1635
+ if f.strip() == DX_DATE_FEATURE.strip():
1636
  values_by_index[i] = np.nan if dx_date is None else dx_date.isoformat()
1637
+ st.text_input(
1638
+ f"{f} (auto)",
1639
+ value="" if dx_date is None else dx_date.isoformat(),
1640
+ disabled=True,
1641
+ key=f"sp_{i}_dx_show"
1642
+ )
1643
  continue
1644
+
1645
 
1646
  if f.strip() == "Date of 1st CR".strip():
1647
  values_by_index[i] = np.nan if cr1_date is None else cr1_date.isoformat()
 
 
 
 
 
 
1648
  st.text_input(
1649
  f"{f} (auto)",
1650
+ value="" if cr1_date is None else cr1_date.isoformat(),
 
1651
  disabled=True,
1652
+ key=f"sp_{i}_cr_show"
1653
  )
 
1654
  continue
1655
 
1656
+
1657
+
1658
+
1659
  # ECOG mapped to int
1660
  if f.strip() == "ECOG":
1661
  values_by_index[i] = int(ecog)