Synav commited on
Commit
a27119e
·
verified ·
1 Parent(s): f30098c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -20
app.py CHANGED
@@ -657,6 +657,16 @@ def build_shap_explainer(pipe, X_bg, max_bg=200):
657
  )
658
  return explainer
659
 
 
 
 
 
 
 
 
 
 
 
660
 
661
 
662
  def publish_to_hub(model_repo_id: str, version_tag: str):
@@ -1483,26 +1493,28 @@ with tab_train:
1483
  # ---------------- PUBLISH (only after training) ----------------
1484
 
1485
 
1486
- if st.session_state.get("pipe") is not None:
1487
- st.divider()
1488
- st.subheader("Publish trained model to Hugging Face Hub")
1489
-
1490
- default_version = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
1491
- version_tag = st.text_input(
1492
- "Version tag",
1493
- value=default_version,
1494
- help="Used as releases/<version>/ in the model repository",
1495
- )
1496
-
1497
- if st.button("Publish model.joblib + meta.json to Model Repo"):
1498
- try:
1499
- with st.spinner("Uploading to Hugging Face Model repo..."):
1500
- paths = publish_to_hub(MODEL_REPO_ID, version_tag)
1501
 
1502
- st.success("Uploaded successfully to your model repository.")
1503
- st.json(paths)
1504
- except Exception as e:
1505
- st.error(f"Upload failed: {e}")
 
 
 
 
 
 
1506
 
1507
 
1508
 
@@ -2078,7 +2090,22 @@ with tab_predict:
2078
  st.divider()
2079
  st.subheader("Predict single patient")
2080
 
2081
- thr_single = st.slider("Classification threshold", 0.0, 1.0, 0.5, 0.01, key="sp_thr")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2082
  low_cut_s, high_cut_s = st.slider(
2083
  "Risk band cutoffs (low, high)",
2084
  0.0, 1.0, (0.2, 0.8), 0.01,
 
657
  )
658
  return explainer
659
 
660
+ def ensure_model_repo_exists(model_repo_id: str, token: str):
661
+ """
662
+ Optional helper: create the model repo if it doesn't exist.
663
+ Safe to call; if it exists, it will error -> you can ignore.
664
+ """
665
+ api = HfApi(token=token)
666
+ try:
667
+ api.create_repo(repo_id=model_repo_id, repo_type="model", private=False, exist_ok=True)
668
+ except Exception:
669
+ pass
670
 
671
 
672
  def publish_to_hub(model_repo_id: str, version_tag: str):
 
1493
  # ---------------- PUBLISH (only after training) ----------------
1494
 
1495
 
1496
+ # ---------------- PUBLISH (only after training) ----------------
1497
+ if st.session_state.get("pipe") is not None:
1498
+ st.divider()
1499
+ st.subheader("Publish trained model to Hugging Face Hub")
1500
+
1501
+ default_version = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
1502
+ version_tag = st.text_input(
1503
+ "Version tag",
1504
+ value=default_version,
1505
+ help="Used as releases/<version>/ in the model repository",
1506
+ )
 
 
 
 
1507
 
1508
+ if st.button("Publish model.joblib + meta.json to Model Repo", key="publish_btn"):
1509
+ try:
1510
+ with st.spinner("Uploading to Hugging Face Model repo..."):
1511
+ paths = publish_to_hub(MODEL_REPO_ID, version_tag)
1512
+
1513
+ st.success("Uploaded successfully to your model repository.")
1514
+ st.json(paths)
1515
+ except Exception as e:
1516
+ st.error(f"Upload failed: {e}")
1517
+
1518
 
1519
 
1520
 
 
2090
  st.divider()
2091
  st.subheader("Predict single patient")
2092
 
2093
+ m = meta.get("metrics", {})
2094
+ default_thr = float(m.get("best_threshold", 0.5))
2095
+
2096
+ thr_single = st.slider(
2097
+ "Classification threshold",
2098
+ 0.0, 1.0, default_thr, 0.01,
2099
+ key="sp_thr"
2100
+ )
2101
+
2102
+ # External validation threshold
2103
+ thr_ext = st.slider(
2104
+ "External validation threshold",
2105
+ 0.0, 1.0, default_thr, 0.01,
2106
+ key="thr_ext"
2107
+ )
2108
+
2109
  low_cut_s, high_cut_s = st.slider(
2110
  "Risk band cutoffs (low, high)",
2111
  0.0, 1.0, (0.2, 0.8), 0.01,