Spaces:
Running
Running
"""Manage Model — password-gated page for teammates to update the live model.

Upload an exported model (a .zip of the folder produced by the training
notebooks, i.e. config.json + model.safetensors + tokenizer files). It is
validated against the project's 7-label scheme and pushed to the Hugging Face
Hub repo the app loads from (config.PRIMARY_MODEL_ID). The Hub repo is the
durable store — a Space's own disk is wiped on restart, so we never rely on it.

Secrets (set in the Space settings, never in git):
    HF_TOKEN          a Hugging Face *write* token
    MANAGE_PASSWORD   password protecting this page
"""
import hmac
import io
import json
import os
import tempfile
import zipfile
from pathlib import Path

import streamlit as st

import config
# Page chrome: browser-tab title/icon, then the on-page heading and subtitle.
st.set_page_config(page_title="Manage Model", page_icon="🔐", layout="centered")
st.title("🔐 Manage Model")
st.caption("Update the live NER model. Restricted to the team.")
| def _secret(name, default=""): | |
| # st.secrets raises if no secrets file exists; fall back to env. | |
| try: | |
| if name in st.secrets: | |
| return st.secrets[name] | |
| except Exception: # noqa: BLE001 | |
| pass | |
| return os.environ.get(name, default) | |
# ---- Password gate ----------------------------------------------------------
PASSWORD = _secret("MANAGE_PASSWORD")
if not PASSWORD:
    st.error("This page is disabled: no `MANAGE_PASSWORD` secret is configured.")
    st.stop()

# The unlock flag lives in session_state so it survives st.rerun() and later
# interactions within the same browser session.
if not st.session_state.get("manage_authed"):
    pw = st.text_input("Password", type="password")
    if st.button("Unlock"):
        # Constant-time comparison so response timing doesn't reveal how much
        # of the guess matched. str() also tolerates a non-string secret
        # (e.g. a numeric value in secrets.toml), which `==` against the text
        # input's string would never match.
        if hmac.compare_digest(
            str(pw).encode("utf-8"), str(PASSWORD).encode("utf-8")
        ):
            st.session_state["manage_authed"] = True
            st.rerun()
        else:
            st.error("Wrong password.")
    st.stop()
# ---- Authed --------------------------------------------------------------
# Strip so whitespace pasted into the box doesn't end up in the repo id.
TARGET_REPO = (
    st.text_input(
        "Target Hub model repo",
        value=config.PRIMARY_MODEL_ID,
        help="The repo the app loads. Overwriting it updates the live model.",
    )
    or ""
).strip()
private = st.checkbox(
    "Keep repo private",
    value=True,
    # Publishing uses create_repo(..., exist_ok=True), which reuses an
    # existing repo as-is, so this setting only matters for a new repo.
    help="Applies only when the repo is created; an existing repo keeps its "
         "current visibility.",
)
st.markdown(
    "Upload a **.zip of your exported model folder** "
    "(`config.json`, `model.safetensors`, tokenizer files, ideally `label_config.json`). "
    "This is the folder the training notebooks write to `exported_models/…`."
)
up = st.file_uploader("Model .zip", type=["zip"])

# Files an exported model folder must contain.
# NOTE(review): not referenced elsewhere on this page (the zip scan looks for
# config.json directly) — confirm before removing.
REQUIRED = ["config.json"]
# Accepted weight-file names; at least one must be present in the model folder.
WEIGHTS = ["model.safetensors", "pytorch_model.bin"]
| def _find_model_dir(root: Path): | |
| """Locate the directory holding config.json (zip may have a wrapper folder).""" | |
| for cfg in root.rglob("config.json"): | |
| return cfg.parent | |
| return None | |
| def _validate(model_dir: Path): | |
| files = {p.name for p in model_dir.iterdir()} | |
| if not any(w in files for w in WEIGHTS): | |
| return f"No weights file found ({' or '.join(WEIGHTS)})." | |
| # Label-scheme check: prefer label_config.json, else config.json id2label. | |
| labels = None | |
| if (model_dir / "label_config.json").exists(): | |
| labels = set(json.loads((model_dir / "label_config.json").read_text()) | |
| .get("id2label", {}).values()) | |
| else: | |
| cfg = json.loads((model_dir / "config.json").read_text()) | |
| labels = set(cfg.get("id2label", {}).values()) | |
| if "B-SKILL" not in labels: | |
| return ("Label scheme mismatch β model does not use the project's BIO tags " | |
| f"(expected B-SKILL/JOB_TITLE/EDUCATION, got: {sorted(labels) or 'none'}).") | |
| return None | |
if up is not None and st.button("Validate & publish", type="primary"):
    # ---- Preconditions: credentials, target repo, Hub client ---------------
    token = _secret("HF_TOKEN")
    if not token:
        st.error("No `HF_TOKEN` secret configured — cannot push to the Hub.")
        st.stop()
    # Whitespace from the text box would otherwise produce an invalid repo id.
    repo_id = (TARGET_REPO or "").strip()
    if not repo_id:
        st.error("Enter a target Hub model repo (e.g. `user/model-name`).")
        st.stop()
    try:
        from huggingface_hub import HfApi, create_repo
    except ModuleNotFoundError:
        st.error("`huggingface_hub` is not installed (add it to requirements.txt).")
        st.stop()

    # Everything extracted here is deleted when the with-block exits; st.stop()
    # works by raising, so early exits still clean up.
    with tempfile.TemporaryDirectory() as td:
        tdp = Path(td)
        try:
            with zipfile.ZipFile(io.BytesIO(up.getvalue())) as zf:
                zf.extractall(tdp)
        except zipfile.BadZipFile:
            st.error("That file is not a valid .zip.")
            st.stop()
        except (RuntimeError, NotImplementedError) as err:
            # zipfile raises these for password-protected members and
            # unsupported compression methods respectively.
            st.error(f"Could not extract the .zip: {err}")
            st.stop()

        model_dir = _find_model_dir(tdp)
        if model_dir is None:
            st.error("No `config.json` found anywhere in the zip.")
            st.stop()
        err = _validate(model_dir)
        if err:
            st.error(err)
            st.stop()
        st.success(f"Validated: {sorted(p.name for p in model_dir.iterdir())}")

        try:
            with st.spinner(f"Publishing to {repo_id}…"):
                # exist_ok=True reuses an existing repo unchanged, so `private`
                # only takes effect when the repo is created here.
                create_repo(repo_id, repo_type="model", private=private,
                            exist_ok=True, token=token)
                HfApi(token=token).upload_folder(
                    folder_path=str(model_dir), repo_id=repo_id, repo_type="model",
                    commit_message="Update model via Manage Model page",
                    # Keep macOS zip metadata out of the published repo when
                    # it ends up inside model_dir (zip of loose files).
                    ignore_patterns=["__MACOSX/*", "*.DS_Store"],
                )
        except Exception as exc:  # noqa: BLE001 -- UI boundary: report, don't crash
            st.error(f"Publishing to `{repo_id}` failed: {exc}")
            st.stop()

    st.success(f"✅ Published to https://huggingface.co/{repo_id}")
    st.info("Go to any page and click **Reload model** in the sidebar to use it.")