Synav commited on
Commit
598d303
·
verified ·
1 Parent(s): 6bbc45c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -62
app.py CHANGED
@@ -145,7 +145,7 @@ def build_pipeline(
145
 
146
  cat_pipe = Pipeline([
147
  ("imputer", SimpleImputer(strategy="most_frequent")),
148
- ("onehot", OneHotEncoder(handle_unknown="ignore", sparse_output=True, drop="first"))
149
  ])
150
 
151
  preprocessor = ColumnTransformer(
@@ -813,7 +813,7 @@ st.warning(
813
  with st.expander("Admin controls", expanded=False):
814
  st.text_input("Admin key", type="password", key="admin_key")
815
  st.caption("Training and publishing are enabled only after admin authentication.")
816
-
817
 
818
 
819
  tab_train, tab_predict = st.tabs(["1️⃣ Train", "2️⃣ Predict + SHAP"])
@@ -823,34 +823,40 @@ if "pipe" not in st.session_state:
823
  if "explainer" not in st.session_state:
824
  st.session_state.explainer = None
825
 
 
 
826
 
 
 
 
827
 
 
828
 
829
- st.markdown("### Feature reduction options")
830
-
831
- use_feature_selection = st.checkbox(
832
- "Drop columns that do not affect prediction (L1 feature selection)",
833
- value=True
834
- )
835
-
836
- l1_C = st.slider(
837
- "L1 selection strength (lower = fewer features)",
838
- 0.01, 10.0, 1.0, 0.01
839
- ) if use_feature_selection else 1.0
840
-
841
- use_dimred = st.checkbox(
842
- "Dimensionality reduction (TruncatedSVD) — reduces interpretability",
843
- value=False
844
- )
 
 
 
845
 
846
- svd_components = st.slider(
847
- "SVD components (only used if enabled)",
848
- 5, 300, 50, 5
849
- ) if use_dimred else 50
850
 
851
  # ---------------- TRAIN ----------------
852
 
853
- with tab_train:
854
 
855
  st.subheader("Train model")
856
  if not is_admin():
@@ -1133,7 +1139,7 @@ with tab_predict:
1133
  st.divider()
1134
  if st.session_state.pipe is None:
1135
  st.warning("Load a model version above, then upload an inference Excel.")
1136
- st.stop()
1137
 
1138
  pipe = st.session_state.pipe
1139
 
@@ -1258,23 +1264,23 @@ with tab_predict:
1258
  # --- header dates ---
1259
  c1, c2 = st.columns(2)
1260
  with c1:
1261
- dob = st.date_input(
1262
- "Date of birth (DOB)",
1263
- value=None,
1264
- min_value=MIN_DOB,
1265
- max_value=date.today(),
1266
- key="sp_dob",
1267
- )
1268
  with c2:
1269
- dx_date = st.date_input(
1270
- "Date of Diagnosis / 1st Bone Marrow biopsy",
1271
- value=None,
1272
- min_value=MIN_DOB,
1273
- max_value=date.today(),
1274
- key="sp_dx_date",
1275
- )
 
1276
 
1277
- derived_age = age_years_at(dob, dx_date) # float or nan
 
1278
 
1279
  def yesno_to_01(v: str):
1280
  if v == "Yes":
@@ -1393,14 +1399,7 @@ with tab_predict:
1393
  v = st.text_input(f, value="", key=f"sp_{i}_other")
1394
  values_by_index[i] = np.nan if v.strip() == "" else v
1395
 
1396
- with tab_clin:
1397
- st.caption("Clinical flags: Yes=1, No=0")
1398
- for i, f in enumerate(feature_cols):
1399
- if f in YESNO_FIELDS:
1400
- v = st.selectbox(f, options=["", "No", "Yes"], index=0, key=f"sp_{i}_yn")
1401
- values_by_index[i] = yesno_to_01(v)
1402
-
1403
-
1404
  # Apply FISH/NGS selections to row
1405
  fish_set = set(fish_selected)
1406
  ngs_set = set(ngs_selected)
@@ -1529,11 +1528,8 @@ with tab_predict:
1529
  cls_ext = compute_classification_metrics(y_ext01, proba, threshold=float(thr_ext))
1530
 
1531
  pr_ext = compute_pr_curve(y_ext01, proba)
1532
- cal_ext = compute_calibration(
1533
- y_ext01, proba,
1534
- n_bins=int(n_bins) if "n_bins" in locals() else 10,
1535
- strategy=str(cal_strategy) if "cal_strategy" in locals() else "uniform"
1536
- )
1537
  dca_ext = decision_curve_analysis(y_ext01, proba)
1538
 
1539
  # Display headline metrics
@@ -1819,16 +1815,17 @@ with tab_predict:
1819
 
1820
 
1821
  # BEESWARM SUMMARY (optional)
1822
- plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
1823
- shap.summary_plot(
1824
- shap_vals_batch,
1825
- features=X_dense,
1826
- feature_names=names,
1827
- max_display=max_display,
1828
- show=False
1829
- )
1830
- fig_swarm = plt.gcf()
1831
- render_plot_with_download(fig_swarm, title="SHAP beeswarm", filename="shap_beeswarm.png", export_dpi=export_dpi)
 
1832
 
1833
 
1834
 
 
145
 
146
  cat_pipe = Pipeline([
147
  ("imputer", SimpleImputer(strategy="most_frequent")),
148
+ ("onehot", OneHotEncoder(handle_unknown="ignore", sparse=True, drop="first"))
149
  ])
150
 
151
  preprocessor = ColumnTransformer(
 
813
  with st.expander("Admin controls", expanded=False):
814
  st.text_input("Admin key", type="password", key="admin_key")
815
  st.caption("Training and publishing are enabled only after admin authentication.")
816
+
817
 
818
 
819
  tab_train, tab_predict = st.tabs(["1️⃣ Train", "2️⃣ Predict + SHAP"])
 
823
  if "explainer" not in st.session_state:
824
  st.session_state.explainer = None
825
 
826
+ with tab_train:
827
+ st.subheader("Train model")
828
 
829
+ if not is_admin():
830
+ st.info("Training and publishing are restricted. Use Predict + SHAP for inference.")
831
+ st.stop()
832
 
833
+ st.markdown("### Feature reduction options")
834
 
835
+ use_feature_selection = st.checkbox(
836
+ "Drop columns that do not affect prediction (L1 feature selection)",
837
+ value=True,
838
+ key="train_use_feature_selection"
839
+ )
840
+ l1_C = st.slider(
841
+ "L1 selection strength (lower = fewer features)",
842
+ 0.01, 10.0, 1.0, 0.01
843
+ ) if use_feature_selection else 1.0
844
+
845
+ use_dimred = st.checkbox(
846
+ "Dimensionality reduction (TruncatedSVD) — reduces interpretability",
847
+ value=False
848
+ )
849
+
850
+ svd_components = st.slider(
851
+ "SVD components (only used if enabled)",
852
+ 5, 300, 50, 5
853
+ ) if use_dimred else 50
854
 
855
+ st.divider()
 
 
 
856
 
857
  # ---------------- TRAIN ----------------
858
 
859
+
860
 
861
  st.subheader("Train model")
862
  if not is_admin():
 
1139
  st.divider()
1140
  if st.session_state.pipe is None:
1141
  st.warning("Load a model version above, then upload an inference Excel.")
1142
+
1143
 
1144
  pipe = st.session_state.pipe
1145
 
 
1264
  # --- header dates ---
1265
  c1, c2 = st.columns(2)
1266
  with c1:
1267
+ dob_unknown = st.checkbox("DOB unknown", value=False, key="dob_unknown")
1268
+ dob = None
1269
+ if not dob_unknown:
1270
+ dob = st.date_input("Date of birth (DOB)", min_value=MIN_DOB, max_value=date.today(), key="dob")
1271
+
 
 
1272
  with c2:
1273
+ dx_unknown = st.checkbox("Diagnosis date unknown", value=False, key="dx_unknown")
1274
+ dx_date = None
1275
+ if not dx_unknown:
1276
+ dx_date = st.date_input(
1277
+ "Date of Diagnosis / 1st Bone Marrow biopsy",
1278
+ min_value=MIN_DOB, max_value=date.today(),
1279
+ key="dx_date"
1280
+ )
1281
 
1282
+ derived_age = age_years_at(dob, dx_date)
1283
+
1284
 
1285
  def yesno_to_01(v: str):
1286
  if v == "Yes":
 
1399
  v = st.text_input(f, value="", key=f"sp_{i}_other")
1400
  values_by_index[i] = np.nan if v.strip() == "" else v
1401
 
1402
+
 
 
 
 
 
 
 
1403
  # Apply FISH/NGS selections to row
1404
  fish_set = set(fish_selected)
1405
  ngs_set = set(ngs_selected)
 
1528
  cls_ext = compute_classification_metrics(y_ext01, proba, threshold=float(thr_ext))
1529
 
1530
  pr_ext = compute_pr_curve(y_ext01, proba)
1531
+ cal_ext = compute_calibration(y_ext01, proba, n_bins=PRED_N_BINS, strategy=PRED_CAL_STRATEGY)
1532
+
 
 
 
1533
  dca_ext = decision_curve_analysis(y_ext01, proba)
1534
 
1535
  # Display headline metrics
 
1815
 
1816
 
1817
  # BEESWARM SUMMARY (optional)
1818
+ if show_beeswarm:
1819
+ plt.figure(figsize=FIGSIZE, dpi=plot_dpi_screen)
1820
+ shap.summary_plot(
1821
+ shap_vals_batch,
1822
+ features=X_dense,
1823
+ feature_names=names,
1824
+ max_display=max_display,
1825
+ show=False
1826
+ )
1827
+ fig_swarm = plt.gcf()
1828
+ render_plot_with_download(fig_swarm, title="SHAP beeswarm", filename="shap_beeswarm.png", export_dpi=export_dpi)
1829
 
1830
 
1831