File size: 2,035 Bytes
c59578d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Shared Streamlit UI helpers — the cross-page model toggle.

Keeping this in one place means every page shows the same picker and shares the
same selection (via st.session_state) and the same per-ref cache.
"""
import streamlit as st

import config
from lib.model import load_model

# Extra picker entry appended after the configured models; selecting it
# reveals a free-text field for typing a Hub repo id.
_CUSTOM_LABEL = "✏️ Custom HF model ID…"


@st.cache_resource(show_spinner="Loading NER model…")
def _load_cached(ref):
    """Load the model for ``ref`` and memoise it via ``st.cache_resource``.

    Each distinct ``ref`` gets its own cached entry, so repeated calls with
    the same value reuse the already-loaded object. ``None`` stands for the
    demo fallback.
    """
    model = load_model(ref=ref)
    return model


def model_selector():
    """Draw the model picker in the sidebar and hand back the loaded model.

    The sidebar gets, in order: a "Model" subheader, a select box listing the
    configured models plus a custom-ID entry, a text field (only when the
    custom entry is chosen), a reload button, and a status badge with the
    model's source as caption.

    The select box is keyed as ``"model_label"`` in ``st.session_state`` and is
    pre-positioned on whatever label is already stored there, so the choice
    carries over between pages. The reload button empties the whole model
    cache and reruns the script, which re-pulls the model on the next pass.

    Returns:
        The object produced by ``_load_cached`` for the chosen ref (a
        LoadedModel); the ref is ``None`` when the custom field is blank.
    """
    sidebar = st.sidebar

    registry = config.available_models()  # maps display label -> model ref
    menu = [*registry.keys(), _CUSTOM_LABEL]

    sidebar.subheader("Model")

    # Start the select box on the remembered label; fall back to the first
    # entry when the remembered label is no longer offered.
    remembered = st.session_state.get("model_label", menu[0])
    try:
        start_at = menu.index(remembered)
    except ValueError:
        start_at = 0

    picked = sidebar.selectbox(
        "Active NER model",
        menu,
        index=start_at,
        key="model_label",
        label_visibility="collapsed",
    )

    if picked != _CUSTOM_LABEL:
        ref = registry[picked]
    else:
        typed = sidebar.text_input(
            "HF model repo id",
            placeholder="e.g. Zeqhx/cv-parser-ner",
            key="custom_model_id",
        )
        # A blank (or whitespace-only) entry is passed on as None.
        ref = typed.strip() or None

    reload_clicked = sidebar.button(
        "🔄 Reload model",
        use_container_width=True,
        help="Clear the cache and re-pull (use after updating a model).",
    )
    if reload_clicked:
        # Drops every cached model, not just the current one, then restarts
        # the script so the load below runs against an empty cache.
        _load_cached.clear()
        st.rerun()

    loaded = _load_cached(ref)

    if not loaded.is_fallback:
        sidebar.success("Model loaded", icon="✅")
    else:
        sidebar.warning(
            "Demo mode — untrained head; predictions are not meaningful.",
            icon="⚠️",
        )
    sidebar.caption(loaded.source)

    return loaded