File size: 10,689 Bytes
a89a7f1
2bf3d21
a89a7f1
 
 
 
b279884
d8ae160
c607869
ae347c6
9ac8f1c
a89a7f1
5bf7fd5
a89a7f1
 
db3d901
 
 
 
 
 
 
 
 
 
 
 
 
e8b0701
 
eb41f91
 
d8ae160
 
 
 
 
 
 
c607869
d8ae160
 
 
 
c607869
 
 
d8ae160
c607869
 
eb41f91
 
d8ae160
 
 
 
 
 
b279884
 
 
 
 
 
 
 
 
 
 
d8ae160
b279884
d8ae160
b279884
 
 
 
d8ae160
 
b279884
d8ae160
 
 
b279884
 
 
 
d8ae160
 
 
 
 
 
 
b279884
d8ae160
b279884
 
 
 
d8ae160
 
b279884
d8ae160
 
 
b279884
 
 
 
d8ae160
 
 
 
 
 
 
 
 
2bf3d21
 
 
 
 
 
 
 
c30bbc5
 
 
 
 
 
 
 
 
 
db3d901
c30bbc5
db3d901
c30bbc5
 
 
 
 
 
 
 
db3d901
c30bbc5
 
 
 
 
 
db3d901
c30bbc5
db3d901
c30bbc5
 
 
 
 
 
 
db3d901
c30bbc5
 
 
 
 
 
 
 
 
 
db3d901
c30bbc5
 
 
 
 
 
e8b0701
ae347c6
e8b0701
ae347c6
 
e8b0701
 
 
 
 
 
 
 
ae347c6
e8b0701
 
 
2bf3d21
a89a7f1
dc186e4
a89a7f1
db3d901
 
a89a7f1
db3d901
a89a7f1
 
 
 
 
f4259c0
a89a7f1
 
 
db3d901
a89a7f1
 
d8ae160
 
 
2bf3d21
 
db3d901
2bf3d21
 
 
 
 
 
 
 
 
a89a7f1
 
e8b0701
db3d901
a89a7f1
 
c30bbc5
a89a7f1
 
 
7ad2026
db3d901
a89a7f1
 
7ad2026
a89a7f1
 
 
 
 
db3d901
a89a7f1
 
 
2bf3d21
 
 
 
 
 
a89a7f1
 
 
 
 
12cdb17
9ac8f1c
12cdb17
2bf3d21
a89a7f1
2bf3d21
a89a7f1
 
2bf3d21
ecd19ae
d8ae160
a89a7f1
db3d901
d8ae160
 
 
 
a89a7f1
 
 
2bf3d21
a89a7f1
c607869
 
 
 
d8ae160
c607869
 
a89a7f1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import os
from dataclasses import dataclass

import streamlit as st
from dotenv import load_dotenv

from utils.analysis_sources import DEFAULT_COMPARE_MODEL, DEFAULT_HUB_REPO, SOURCE_HUB
from utils.helpers import DATASET_SOURCES, session_key, widget_key
from utils.preload import preload_once
from utils.runtime import configured_ndif_api_key, list_remote_models
from utils.theme import active_base, install_catppuccin_theme

# Load variables from a .env file (if any) before the environment lookups below.
load_dotenv()
# Default model ids for local and remote (NDIF) execution; each can be
# overridden through the environment variable of the same name.
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
# Session-state slots written after the model widgets render and read back as
# the default value the next time a model widget (or breadcrumb) is needed.
_LAST_LOCAL_MODEL_KEY = session_key("sidebar", "last_local_model")
_LAST_REMOTE_MODEL_KEY = session_key("sidebar", "last_remote_model")
# Session-state slot holding the name of the currently selected sidebar tab.
_SIDEBAR_ACTIVE_TAB_KEY = session_key("sidebar", "active_tab")
# Keys used as the `key=` of the sidebar widgets defined further down.
_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY = session_key(
    "sidebar", "remote_model_custom_value"
)
_SIDEBAR_REMOTE_MODEL_CUSTOM_ENABLED_KEY = session_key(
    "sidebar", "remote_model_custom_enabled"
)
_SIDEBAR_REMOTE_MODEL_KEY = session_key("sidebar", "remote_model")
_SIDEBAR_LOCAL_MODEL_KEY = session_key("sidebar", "local_model")
_SIDEBAR_REMOTE_KEY = session_key("sidebar", "remote")
_SIDEBAR_DATASET_SOURCE_KEY = session_key("sidebar", "dataset_source")
_SIDEBAR_NDIF_API_KEY = session_key("sidebar", "ndif_api_key")
# Link shown in the API-key help text and in the "Get one" caption.
NDIF_REGISTRATION_URL = "https://login.ndif.us/"


# Sidebar navigation tabs, in the order their buttons are rendered.
_TABS = ["Chat", "Analysis", "Probing", "Extract"]
# Button icons, paired positionally with _TABS (the two lists are zipped with
# strict=True, so they must stay the same length).
_TAB_ICONS = [
    ":material/chat:",
    ":material/search:",
    ":material/biotech:",
    ":material/tune:",
]
# Per active tab: module names passed to preload_once as `modules=` after the
# tab has rendered. NOTE(review): presumably these are imported ahead of time
# so switching tabs is faster -- confirm against utils.preload.
_TAB_PRELOAD_MODULES = {
    "Chat": ("tabs.analysis_core", "tabs.extract", "tabs.compare_chat", "tabs.probe"),
    "Analysis": ("tabs.chat", "tabs.extract", "tabs.probe"),
    "Probing": ("tabs.chat", "tabs.analysis_core", "tabs.extract"),
    "Extract": ("tabs.chat", "tabs.analysis_core", "tabs.probe"),
}
# Per active tab: "module:function" references passed to preload_once as
# `functions=`. "Analysis" has no entry; the lookup in main() falls back to ().
_TAB_PRELOAD_FUNCTIONS = {
    "Chat": ("utils.analysis_metadata:synth_persona_attribute_names",),
    "Probing": ("utils.analysis_metadata:synth_persona_attribute_names",),
    "Extract": ("utils.analysis_metadata:synth_persona_attribute_names",),
}


def _hub_metadata_preload_calls() -> tuple[
    tuple[str, tuple[str, str, str, str | None]], ...
]:
    """Collect the hub-metadata prefetch calls implied by the session state.

    One call is produced for the Analysis tab and one for the Probing tab,
    but only when the respective tab's source resolves to ``SOURCE_HUB``.
    Each call is a pair of the target reference
    ``"utils.analysis_sources:prefetch_hub_metadata"`` and its argument tuple
    ``(repo, model, mask_strategy, variant)``.

    Tab-specific session values win; otherwise the shared ``source:*`` values
    are used, and finally the module-level defaults.

    Returns:
        The calls in insertion order with exact duplicates removed.
    """
    state = st.session_state
    target = "utils.analysis_sources:prefetch_hub_metadata"

    # Fallbacks shared by both tabs.
    fallback_source = state.get("source:last_source", SOURCE_HUB)
    fallback_mask = state.get("source:last_mask_strategy", "answer_mean")

    pending: list[tuple[str, tuple[str, str, str, str | None]]] = []

    # --- Analysis tab ---
    if state.get("analysis:last_source", fallback_source) == SOURCE_HUB:
        shared_repo = state.get("source:hub_repo", DEFAULT_HUB_REPO)
        analysis_repo = state.get("analysis:hub_repo", shared_repo)
        analysis_mask = state.get("analysis:last_mask_strategy", fallback_mask)

        shared_model = state.get("source:hub_model", DEFAULT_COMPARE_MODEL)
        remembered_model = state.get("analysis:hub_model_fallback", shared_model)
        # The model widget key depends on the repo and mask strategy in use.
        analysis_model = state.get(
            widget_key("load", "hub_model", analysis_repo, analysis_mask),
            remembered_model,
        )

        # Prefer the projection variant, else the similarity variant (or None).
        similarity_variant = state.get("analysis:last_similarity_variant")
        analysis_variant = state.get(
            "analysis:last_projection_variant", similarity_variant
        )
        pending.append(
            (target, (analysis_repo, analysis_model, analysis_mask, analysis_variant))
        )

    # --- Probing tab ---
    if state.get(widget_key("probe", "source"), fallback_source) == SOURCE_HUB:
        shared_repo = state.get("source:hub_repo", DEFAULT_HUB_REPO)
        probe_repo = state.get("probe:hub_repo", shared_repo)
        probe_mask = state.get("probe:last_mask_strategy", fallback_mask)

        shared_model = state.get("source:hub_model", DEFAULT_COMPARE_MODEL)
        remembered_model = state.get("probe:hub_model_fallback", shared_model)
        probe_model = state.get(
            widget_key("probe", "hub_model", probe_repo, probe_mask),
            remembered_model,
        )
        pending.append(
            (target, (probe_repo, probe_model, probe_mask, state.get("probe:variant")))
        )

    # dict keys keep first-seen order, so this drops duplicates stably.
    return tuple(dict.fromkeys(pending))


@dataclass(frozen=True)
class SidebarState:
    """Immutable result of rendering the sidebar, as returned by `_sidebar_controls`."""

    # Value of the "Remote (NDIF)" toggle; always False on the Analysis and
    # Probing tabs, where the toggle is not rendered.
    remote: bool
    # Model id from the sidebar (local text input or remote picker); on the
    # Analysis and Probing tabs, the last remembered local model id.
    model_name: str
    # Selected entry of DATASET_SOURCES.
    dataset_source: str
    # Name of the selected tab (one of _TABS).
    active_tab: str

def _remote_model_input(remote_models: list[str]) -> str:
    """Render the remote-model picker and return the chosen NDIF model id.

    Args:
        remote_models: Ids of the NDIF models reported as running.

    Returns:
        The model id taken from the free-text input (when no models are
        running, or when the "Custom remote model" toggle is on) or from the
        select box of running models. The value is also stored under
        ``_LAST_REMOTE_MODEL_KEY`` so it can seed the widgets later.
    """

    remembered = st.session_state.get(_LAST_REMOTE_MODEL_KEY, REMOTE_DEFAULT_MODEL)

    def custom_model_box(help_text: str) -> str:
        # Both free-text variants share one widget key; only the help differs.
        return st.text_input(
            "Model",
            value=st.session_state.get(
                _SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, remembered
            ),
            key=_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY,
            help=help_text,
        )

    if remote_models:
        use_custom = st.toggle(
            "Custom remote model",
            value=False,
            key=_SIDEBAR_REMOTE_MODEL_CUSTOM_ENABLED_KEY,
            help="Enter any NDIF-loadable model id, even if it is not currently running.",
        )
        if not use_custom:
            preselected = st.session_state.get(_SIDEBAR_REMOTE_MODEL_KEY, remembered)
            if preselected not in remote_models:
                # The remembered model is not running: fall back to the
                # configured default when available, else the first entry.
                if REMOTE_DEFAULT_MODEL in remote_models:
                    preselected = REMOTE_DEFAULT_MODEL
                else:
                    preselected = remote_models[0]
            chosen = st.selectbox(
                "Model",
                options=remote_models,
                index=remote_models.index(preselected),
                key=_SIDEBAR_REMOTE_MODEL_KEY,
                help="Running NDIF model.",
            )
        else:
            chosen = custom_model_box("NDIF model id. Example: openai/gpt-oss-20b")
            st.caption(
                f"{len(remote_models)} running NDIF model(s) detected. "
                "Custom model ids can cold-load if your NDIF account allows it."
            )
    else:
        st.warning("No running NDIF models found.")
        chosen = custom_model_box(
            "NDIF model id. Use this to cold-load a remote model."
        )

    st.session_state[_LAST_REMOTE_MODEL_KEY] = chosen
    return chosen


def _ndif_api_key_input() -> None:
    """Show the NDIF API key status, asking for a key when none is configured.

    When ``configured_ndif_api_key()`` is truthy only a caption is shown.
    Otherwise a password input is rendered under ``_SIDEBAR_NDIF_API_KEY``,
    followed by a registration hint while the input is empty.
    """

    if not configured_ndif_api_key():
        entered_key = st.text_input(
            "NDIF API key",
            type="password",
            key=_SIDEBAR_NDIF_API_KEY,
            help=f"Required for remote (NDIF) execution. Register at {NDIF_REGISTRATION_URL}",
        )
        if not entered_key:
            st.caption(f"No NDIF API key found. [Get one]({NDIF_REGISTRATION_URL}).")
        return

    st.caption("Using NDIF API key from environment.")


def _sidebar_controls() -> SidebarState:
    """Render the sidebar and return the selections made in it.

    The sidebar always shows the tab buttons. On the Analysis and Probing
    tabs nothing else is rendered and the returned state carries remembered
    values; on the other tabs the runtime controls (API key, remote toggle,
    model) and the dataset selector are rendered as well.

    Returns:
        A `SidebarState` describing the active tab and the chosen runtime
        and dataset settings.
    """
    state = st.session_state

    with st.sidebar:
        st.markdown("## Persona UI")

        # First run of the session: start on the Chat tab.
        if _SIDEBAR_ACTIVE_TAB_KEY not in state:
            state[_SIDEBAR_ACTIVE_TAB_KEY] = "Chat"
        active_tab = state[_SIDEBAR_ACTIVE_TAB_KEY]

        for tab_name, tab_icon in zip(_TABS, _TAB_ICONS, strict=True):
            button_type = "primary" if tab_name == active_tab else "secondary"
            clicked = st.button(
                tab_name,
                key=f"sidebar__tab__{tab_name.lower()}",
                width="stretch",
                type=button_type,
                icon=tab_icon,
            )
            if not clicked:
                continue
            # Store the new tab and restart the script so it renders at once.
            state[_SIDEBAR_ACTIVE_TAB_KEY] = tab_name
            st.rerun()

        if active_tab in {"Analysis", "Probing"}:
            # These tabs pick their model inside the tab itself; the sidebar
            # just hands over the last local model id for breadcrumbs.
            return SidebarState(
                remote=False,
                model_name=state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL),
                dataset_source=state.get(
                    _SIDEBAR_DATASET_SOURCE_KEY,
                    DATASET_SOURCES[0],
                ),
                active_tab=active_tab,
            )

        st.divider()
        st.caption("Runtime")
        _ndif_api_key_input()
        use_remote = st.toggle("Remote (NDIF)", value=False, key=_SIDEBAR_REMOTE_KEY)

        if not use_remote:
            chosen_model = st.text_input(
                "Model",
                value=state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL),
                key=_SIDEBAR_LOCAL_MODEL_KEY,
                help="Local model id or path.",
            )
            # Remember the local model for tabs that do not render this input.
            state[_LAST_LOCAL_MODEL_KEY] = chosen_model
        else:
            chosen_model = _remote_model_input(list_remote_models())

        st.caption("Data")
        chosen_dataset = st.selectbox(
            "Source",
            DATASET_SOURCES,
            key=_SIDEBAR_DATASET_SOURCE_KEY,
            help="Dataset for Chat and Extract.",
        )

    return SidebarState(
        remote=use_remote,
        model_name=chosen_model,
        dataset_source=chosen_dataset,
        active_tab=active_tab,
    )


def main() -> None:
    """Run the Streamlit app."""

    st.set_page_config(page_title="Persona UI", layout="wide")
    install_catppuccin_theme(active_base())

    sidebar = _sidebar_controls()

    if sidebar.active_tab == "Extract":
        from tabs.extract import render_extract_tab

        render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
    elif sidebar.active_tab == "Analysis":
        from tabs.analysis_core import render_analysis_tab

        render_analysis_tab()
    elif sidebar.active_tab == "Probing":
        from tabs.probe import render_probing_tab

        render_probing_tab()
    else:
        from tabs.chat import render_chat_tab

        render_chat_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)

    preload_once(
        f"after-{sidebar.active_tab.lower()}",
        modules=_TAB_PRELOAD_MODULES.get(sidebar.active_tab, ()),
        functions=_TAB_PRELOAD_FUNCTIONS.get(sidebar.active_tab, ()),
        calls=_hub_metadata_preload_calls(),
    )


# Entry point: build the app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()