Jac-Zac commited on
Commit
a89a7f1
·
0 Parent(s):

First commit

Browse files
.env.example ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copy this file to .env and fill in the values.
2
+
3
+ # NDIF API key for remote nnsight execution
4
+ # Required only when REMOTE=True in notebook.py
5
+ # Get yours at https://login.ndif.us
6
+ NDIF_API_KEY=your-ndif-api-key-here
7
+
8
+ # HuggingFace model cache directory
9
+ # Defaults to ~/.cache/huggingface if unset
10
+ # Useful when working on a cluster with a shared cache or limited home quota
11
+ HF_HOME=/path/to/your/hf/cache
12
+
13
+ # Root directory for all generated artifacts (activations, plots, etc.)
14
+ # Defaults to artifacts if unset
15
+ ARTIFACTS_DIR=artifacts
16
+
17
+ # Default model IDs shown in the sidebar (optional — change to override the built-in defaults)
18
+ # DEFAULT_MODEL=google/gemma-2-2b-it
19
+ # REMOTE_DEFAULT_MODEL=google/gemma-2-9b-it
.gitignore ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ .venv/
25
+ venv/
26
+ ENV/
27
+ env/
28
+
29
+ # Environment variables — .env.example is intentionally tracked
30
+ .env
31
+ .env.*
32
+ !.env.example
33
+
34
+ # IDE
35
+ .idea/
36
+ .vscode/
37
+ *.swp
38
+ *.swo
39
+ *~
40
+
41
+ # Jupyter
42
+ .ipynb_checkpoints/
43
+
44
+ # Testing
45
+ .pytest_cache/
46
+ .coverage
47
+ htmlcov/
48
+
49
+ # OS
50
+ .DS_Store
51
+ Thumbs.db
52
+
53
+ # Project specific
54
+ results/
55
+ outputs/
56
+ artifacts/
57
+ *.json.bak
58
+ *.jsonl
59
+ *.jsonl.bak
60
+
61
+ # Tmp to avoid pushing things I'm testing
62
+ __marimo__/
63
+ AGENTS.md
64
+ # notebook_marimo.py
README.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Persona UI
2
+
3
+ Streamlit interface for persona vector extraction, analysis, and chat.
4
+
5
+ > [!WARNING]
6
+ > This is a proof-of-concept UI, mostly vibe-coded. It will likely be replaced by a proper frontend/backend in the future.
7
+
8
+ ## Overview
9
+
10
+ A web app built on top of [persona-vectors](../persona-vectors) that provides three tabs:
11
+
12
+ - **Chat** — interactive conversations with a model using persona-based system prompts (templated or biography)
13
+ - **Compare** — load saved activations and explore layer-wise cosine similarity, PCA, and UMAP projections
14
+ - **Extract** — run activation extraction from HuggingFace or a local JSONL dataset directly from the browser
15
+
16
+ ## Repository Layout
17
+
18
+ ```
19
+ persona-ui/
20
+ ├── app.py # Main entry point (Streamlit)
21
+ ├── state.py # Session state management (chat history, KV cache)
22
+ ├── tabs/
23
+ │ ├── chat.py # Chat tab
24
+ │ ├── compare.py # Activation comparison tab
25
+ │ └── extract.py # Extraction tab
26
+ └── utils/
27
+ ├── artifacts.py # Load saved activations metadata
28
+ ├── chat.py # Chat generation logic
29
+ ├── chat_export.py # Export chat logs to JSON
30
+ ├── datasets.py # Dataset loader wrapper
31
+ ├── extraction.py # Extraction orchestration
32
+ ├── helpers.py # UI labels and slug helpers
33
+ ├── local_dataset.py # Local JSONL dataset parsing
34
+ └── runtime.py # Model caching and NDIF queries
35
+ ```
36
+
37
+ Dataset loading and environment helpers are provided by the sibling
38
+ [persona-data](../persona-data) package. Core extraction, analysis, and
39
+ steering logic comes from [persona-vectors](../persona-vectors).
40
+
41
+ ## Installation
42
+
43
+ ```bash
44
+ uv sync
45
+ cp .env.example .env
46
+ ```
47
+
48
+ ## Quickstart
49
+
50
+ ```bash
51
+ streamlit run app.py
52
+ ```
53
+
54
+ ## Configuration
55
+
56
+ Copy `.env.example` to `.env` and fill in:
57
+
58
+ ```bash
59
+ NDIF_API_KEY=... # Required for remote (NDIF) model execution
60
+ HF_HOME=... # Optional: HuggingFace cache directory
61
+ ARTIFACTS_DIR=... # Optional: where activations are read from (default: ./artifacts)
62
+ ```
63
+
64
+ The app picks up this file automatically via `load_env()` on startup.
65
+
66
+ ## Saved Artifacts
67
+
68
+ The Compare and Extract tabs read from / write to:
69
+
70
+ ```
71
+ artifacts/
72
+ ├── activations/<model_dir>/<prompt_variant>/<persona_id>/
73
+ │ ├── activations.safetensors
74
+ │ └── metadata.json
75
+ └── chats/<model_dir>/<prompt_variant>/
76
+ └── <export>.json
77
+ ```
78
+
79
+ `<model_dir>` is the model name with `/` replaced by `__` (e.g. `google__gemma-2-9b-it`).
WARNING.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # WARNING 🚨
2
+
3
+ This part of the project is majorly vibe-coded. Mostly becuase it will probably be changed in the future to support an actual interace backhand / frontand without streamlit. And is as of now mostly a proof of concept and an easy development part of the project.
app.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import streamlit as st
5
+ from dotenv import load_dotenv
6
+
7
+ # Load .env early so DEFAULT_MODEL / REMOTE_DEFAULT_MODEL can be overridden via env
8
+ load_dotenv(Path(__file__).parent / ".env")
9
+
10
+ from utils.helpers import DATASET_SOURCES
11
+
12
+ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
13
+ REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
14
+
15
+
16
+ def _sidebar_controls() -> tuple[bool, str, str, str]:
17
+ from utils.runtime import list_remote_models
18
+
19
+ with st.sidebar:
20
+ st.markdown("# Persona UI")
21
+ st.caption("Chat, extract, and compare persona runs.")
22
+
23
+ if "sidebar__active_tab" not in st.session_state:
24
+ st.session_state["sidebar__active_tab"] = _TABS[0]
25
+
26
+ active_tab = st.session_state["sidebar__active_tab"]
27
+ for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
28
+ is_selected = tab_name == active_tab
29
+ if st.button(
30
+ tab_name,
31
+ key=f"sidebar__tab__{tab_name.lower()}",
32
+ use_container_width=True,
33
+ type="primary" if is_selected else "secondary",
34
+ icon=icon,
35
+ ):
36
+ st.session_state["sidebar__active_tab"] = tab_name
37
+ st.rerun()
38
+
39
+ st.divider()
40
+ st.caption("Runtime")
41
+ remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")
42
+
43
+ if remote:
44
+ remote_models = list_remote_models()
45
+ if remote_models:
46
+ default_model = (
47
+ REMOTE_DEFAULT_MODEL
48
+ if REMOTE_DEFAULT_MODEL in remote_models
49
+ else remote_models[0]
50
+ )
51
+ model_name = st.selectbox(
52
+ "Model",
53
+ options=remote_models,
54
+ index=remote_models.index(default_model),
55
+ key="sidebar__remote_model",
56
+ help="Running NDIF model.",
57
+ )
58
+ else:
59
+ st.error("No running NDIF models found.")
60
+ model_name = REMOTE_DEFAULT_MODEL
61
+ else:
62
+ model_name = st.text_input(
63
+ "Model",
64
+ value=DEFAULT_MODEL,
65
+ key="sidebar__local_model",
66
+ help="Local model id or path.",
67
+ )
68
+
69
+ st.caption("Data")
70
+ dataset_source = st.selectbox(
71
+ "Source",
72
+ DATASET_SOURCES,
73
+ key="sidebar__dataset_source",
74
+ help="Dataset for Chat and Extract.",
75
+ )
76
+
77
+ return remote, model_name, dataset_source, active_tab
78
+
79
+
80
+ _TABS = ["Chat", "Compare", "Extract"]
81
+ _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
82
+
83
+
84
+ def main() -> None:
85
+ """Run the Streamlit app."""
86
+
87
+ # Deferred: importing torch is slow; keep it after dotenv load (done at
88
+ # module level above) so the Streamlit page config renders immediately.
89
+ import torch
90
+
91
+ torch.set_grad_enabled(False)
92
+
93
+ st.set_page_config(page_title="Persona UI", layout="wide")
94
+ remote, model_name, dataset_source, active_tab = _sidebar_controls()
95
+
96
+ if active_tab == "Extract":
97
+ from tabs.extract import render_extract_tab
98
+
99
+ render_extract_tab(remote, model_name, dataset_source)
100
+ elif active_tab == "Compare":
101
+ from tabs.compare import render_compare_tab
102
+
103
+ render_compare_tab(model_name)
104
+ else:
105
+ from tabs.chat import render_chat_tab
106
+
107
+ render_chat_tab(remote, model_name, dataset_source)
108
+
109
+
110
+ if __name__ == "__main__":
111
+ main()
pyproject.toml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "persona-ui"
3
+ version = "0.1.0"
4
+ description = "Streamlit UI for persona-vectors"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "persona-vectors",
9
+ "persona-data",
10
+ "nnterp>=1.3.0",
11
+ "streamlit>=1.44.0",
12
+ "plotly>=6.6.0",
13
+ "kaleido>=1.0.0",
14
+ "python-dotenv>=1.2.2",
15
+ "torch>=2.10.0",
16
+ "transformers>=5.2.0",
17
+ ]
18
+
19
+ [tool.uv.sources]
20
+ # NOTE: Switch to git sources after pushing the new package structure
21
+ persona-vectors = { git = "ssh://git@github.com/implicit-personalization/persona-vectors.git" }
22
+ persona-data = { git = "ssh://git@github.com/implicit-personalization/persona-data.git" }
23
+ # persona-vectors = { path = "../persona-vectors", editable = true }
24
+ # persona-data = { path = "../persona-data", editable = true }
25
+
26
+ # [build-system]
27
+ # requires = ["uv_build>=0.11.3,<0.12"]
28
+ # build-backend = "uv_build"
state.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ _CHAT_STATE_PREFIX = "chat_state::"
4
+
5
+
6
+ def chat_session_key(model_name: str, dataset_source: str) -> str:
7
+ """Build the session-state key for a chat context."""
8
+
9
+ return f"{_CHAT_STATE_PREFIX}{model_name}::{dataset_source}"
10
+
11
+
12
+ def _default_chat_state() -> dict[str, object]:
13
+ return {
14
+ "messages": [],
15
+ "persona_id": None,
16
+ "prompt_mode": "templated",
17
+ "past_key_values": None,
18
+ }
19
+
20
+
21
+ def _evict_inactive_kv_caches(active_key: str) -> None:
22
+ """Drop past_key_values from every chat context except the active one."""
23
+
24
+ for key in st.session_state:
25
+ if (
26
+ isinstance(key, str)
27
+ and key.startswith(_CHAT_STATE_PREFIX)
28
+ and key != active_key
29
+ ):
30
+ state = st.session_state[key]
31
+ if isinstance(state, dict) and state.get("past_key_values") is not None:
32
+ state["past_key_values"] = None
33
+
34
+
35
+ def get_chat_state(
36
+ model_name: str, remote: bool, dataset_source: str
37
+ ) -> dict[str, object]:
38
+ """Return the mutable chat state for the active context."""
39
+
40
+ key = chat_session_key(model_name, dataset_source)
41
+ state = st.session_state.get(key)
42
+ if state is None:
43
+ state = _default_chat_state()
44
+ st.session_state[key] = state
45
+ else:
46
+ for default_key, default_value in _default_chat_state().items():
47
+ state.setdefault(default_key, default_value)
48
+ _evict_inactive_kv_caches(key)
49
+ if remote and state.get("past_key_values") is not None:
50
+ state["past_key_values"] = None
51
+ return state
52
+
53
+
54
+ def reset_chat_state(model_name: str, remote: bool, dataset_source: str) -> None:
55
+ """Reset chat history and cache for the active context."""
56
+
57
+ state = get_chat_state(model_name, remote, dataset_source)
58
+ state["messages"] = []
59
+ state["past_key_values"] = None
tabs/__init__.py ADDED
File without changes
tabs/chat.py ADDED
@@ -0,0 +1,636 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ from contextlib import nullcontext
4
+
5
+ import streamlit as st
6
+
7
+ from state import chat_session_key, get_chat_state, reset_chat_state
8
+ from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
9
+ from utils.chat_export import save_chat_export
10
+ from utils.datasets import load_dataset
11
+ from utils.helpers import (
12
+ MODE_LABEL_TO_KEY,
13
+ MODE_LABELS,
14
+ VARIANT_LABELS,
15
+ persona_label,
16
+ widget_key,
17
+ )
18
+ from utils.runtime import cached_model
19
+
20
+ _VISIBLE_MESSAGE_COUNT = 5
21
+ _model_lock = threading.Lock()
22
+
23
+
24
+ def _render_chat_message(message: dict[str, str]) -> None:
25
+ if not message.get("content"):
26
+ return
27
+ with st.chat_message(message["role"]):
28
+ st.markdown(message["content"])
29
+
30
+
31
+ def _clear_chat_ui_state(*keys: str) -> None:
32
+ for key in keys:
33
+ st.session_state.pop(key, None)
34
+
35
+
36
+ def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
37
+ return {
38
+ "max_new_tokens": int(gen_kwargs["max_new_tokens"]),
39
+ "advanced_generation": bool(advanced_generation),
40
+ "use_sampling": bool(gen_kwargs["do_sample"]),
41
+ "temperature": float(gen_kwargs["temperature"]),
42
+ "top_p": float(gen_kwargs["top_p"]),
43
+ "top_k": int(gen_kwargs["top_k"]),
44
+ "repetition_penalty": float(gen_kwargs["repetition_penalty"]),
45
+ "seed": gen_kwargs["seed"],
46
+ }
47
+
48
+
49
+ # ── Compare mode helpers ───────────────────────────────────────────────────────
50
+
51
+
52
+ def _panel_state(panel_key: str) -> dict:
53
+ """Get or initialise compare-panel chat state stored in session_state."""
54
+ if panel_key not in st.session_state:
55
+ st.session_state[panel_key] = {
56
+ "messages": [],
57
+ "persona_id": None,
58
+ "prompt_mode": "templated",
59
+ "past_key_values": None,
60
+ }
61
+ return st.session_state[panel_key]
62
+
63
+
64
+ def _render_compare_panel(
65
+ side: str,
66
+ context_key: str,
67
+ personas: list,
68
+ remote: bool,
69
+ model_name: str,
70
+ dataset_source: str,
71
+ gen_kwargs: dict,
72
+ advanced_generation: bool,
73
+ ) -> dict:
74
+ """Render persona/prompt controls + chat log for one compare panel.
75
+
76
+ Returns a dict with keys needed by the generation step:
77
+ panel_key, state, active_system_prompt, selected_persona, chat_log
78
+ """
79
+ panel_key = widget_key(context_key, f"cmp_{side}")
80
+ state = _panel_state(panel_key)
81
+
82
+ # ── Per-panel selectors ──────────────────────────────────────────────────
83
+ p_col, m_col = st.columns([3, 2])
84
+ with p_col:
85
+ selected_index = next(
86
+ (i for i, p in enumerate(personas) if p.id == state["persona_id"]), 0
87
+ )
88
+ selected_persona = st.selectbox(
89
+ "Persona",
90
+ options=personas,
91
+ index=selected_index,
92
+ format_func=persona_label,
93
+ key=widget_key(panel_key, "persona"),
94
+ )
95
+ with m_col:
96
+ current_label = VARIANT_LABELS.get(state["prompt_mode"], "None")
97
+ prompt_mode_label = st.selectbox(
98
+ "Prompt",
99
+ options=MODE_LABELS,
100
+ index=MODE_LABELS.index(current_label),
101
+ key=widget_key(panel_key, "prompt_mode"),
102
+ )
103
+ prompt_mode = MODE_LABEL_TO_KEY[prompt_mode_label]
104
+
105
+ # Reset state when persona or mode changes.
106
+ changed = (
107
+ state["persona_id"] != selected_persona.id
108
+ or state["prompt_mode"] != prompt_mode
109
+ )
110
+ if changed:
111
+ state["messages"] = []
112
+ state["past_key_values"] = None
113
+ state["persona_id"] = selected_persona.id
114
+ state["prompt_mode"] = prompt_mode
115
+ _clear_chat_ui_state(
116
+ widget_key(panel_key, "custom_prompt"),
117
+ widget_key(panel_key, "show_all"),
118
+ )
119
+
120
+ # ── System prompt ────────────────────────────────────────────────────────
121
+ active_system_prompt = resolve_system_prompt(
122
+ persona=selected_persona, mode=prompt_mode
123
+ )
124
+ custom_prompt_key = widget_key(panel_key, "custom_prompt")
125
+ if prompt_mode != "empty":
126
+ if custom_prompt_key not in st.session_state:
127
+ st.session_state[custom_prompt_key] = active_system_prompt
128
+ with st.expander("Edit prompt", expanded=False):
129
+ active_system_prompt = (
130
+ st.text_area(
131
+ "prompt",
132
+ key=custom_prompt_key,
133
+ height=150,
134
+ label_visibility="collapsed",
135
+ )
136
+ or None
137
+ )
138
+
139
+ export_success_message: str | None = None
140
+ action_col1, action_col2 = st.columns(2)
141
+ with action_col1:
142
+ if st.button(
143
+ "Export chat",
144
+ key=widget_key(panel_key, "export_chat"),
145
+ use_container_width=True,
146
+ ):
147
+ export_path = save_chat_export(
148
+ model_name=model_name,
149
+ dataset_source=dataset_source,
150
+ persona_id=selected_persona.id,
151
+ persona_name=getattr(selected_persona, "name", None),
152
+ panel_label=side,
153
+ prompt_mode=prompt_mode,
154
+ system_prompt=active_system_prompt,
155
+ messages=state["messages"],
156
+ generation=_generation_dict(gen_kwargs, advanced_generation),
157
+ )
158
+ export_success_message = f"Saved chat export to {export_path}"
159
+ with action_col2:
160
+ if st.button(
161
+ "Reset chat",
162
+ key=widget_key(panel_key, "reset"),
163
+ use_container_width=True,
164
+ type="secondary",
165
+ ):
166
+ state["messages"] = []
167
+ state["past_key_values"] = None
168
+ _clear_chat_ui_state(
169
+ widget_key(panel_key, "custom_prompt"),
170
+ widget_key(panel_key, "show_all"),
171
+ )
172
+ st.rerun()
173
+
174
+ if export_success_message:
175
+ st.success(export_success_message)
176
+
177
+ # ── Message history ──────────────────────────────────────────────────────
178
+ show_all_key = widget_key(panel_key, "show_all")
179
+ messages = state["messages"]
180
+ if len(messages) > _VISIBLE_MESSAGE_COUNT and not st.session_state.get(
181
+ show_all_key, False
182
+ ):
183
+ hidden_count = len(messages) - _VISIBLE_MESSAGE_COUNT
184
+ if st.button(
185
+ f"Show earlier ({hidden_count} hidden)",
186
+ key=widget_key(panel_key, "show_all_btn"),
187
+ ):
188
+ st.session_state[show_all_key] = True
189
+ st.rerun()
190
+ visible = messages[-_VISIBLE_MESSAGE_COUNT:]
191
+ else:
192
+ visible = messages
193
+
194
+ chat_log = st.container()
195
+ with chat_log:
196
+ for msg in visible:
197
+ _render_chat_message(msg)
198
+
199
+ return {
200
+ "panel_key": panel_key,
201
+ "state": state,
202
+ "active_system_prompt": active_system_prompt,
203
+ "selected_persona": selected_persona,
204
+ "chat_log": chat_log,
205
+ }
206
+
207
+
208
+ def _generate_for_panel(
209
+ panel: dict,
210
+ model,
211
+ remote: bool,
212
+ gen_kwargs: dict,
213
+ ) -> ChatReply:
214
+ """Run generate_chat_reply for one compare panel. Thread-safe."""
215
+ messages = []
216
+ if panel["active_system_prompt"]:
217
+ messages.append({"role": "system", "content": panel["active_system_prompt"]})
218
+ messages.extend(panel["state"]["messages"])
219
+
220
+ ctx = nullcontext() if remote else _model_lock
221
+ with ctx:
222
+ return generate_chat_reply(
223
+ model=model,
224
+ messages=messages,
225
+ remote=remote,
226
+ past_key_values=panel["state"]["past_key_values"],
227
+ **gen_kwargs,
228
+ )
229
+
230
+
231
+ def _render_compare_mode(
232
+ remote: bool,
233
+ model_name: str,
234
+ context_key: str,
235
+ dataset_source: str,
236
+ personas: list,
237
+ gen_kwargs: dict,
238
+ advanced_generation: bool,
239
+ ) -> None:
240
+ """Render the full side-by-side comparison UI."""
241
+ left_col, right_col = st.columns(2)
242
+
243
+ with left_col:
244
+ left = _render_compare_panel(
245
+ "left",
246
+ context_key,
247
+ personas,
248
+ remote,
249
+ model_name,
250
+ dataset_source,
251
+ gen_kwargs,
252
+ advanced_generation,
253
+ )
254
+ with right_col:
255
+ right = _render_compare_panel(
256
+ "right",
257
+ context_key,
258
+ personas,
259
+ remote,
260
+ model_name,
261
+ dataset_source,
262
+ gen_kwargs,
263
+ advanced_generation,
264
+ )
265
+
266
+ user_prompt = st.chat_input(
267
+ "Ask both...",
268
+ key=widget_key(context_key, "cmp_input"),
269
+ )
270
+ if not user_prompt:
271
+ return
272
+
273
+ model = cached_model(model_name=model_name, remote=remote)
274
+ panels = [(left, left_col), (right, right_col)]
275
+
276
+ for panel, col in panels:
277
+ panel["state"]["messages"].append({"role": "user", "content": user_prompt})
278
+ with col:
279
+ with panel["chat_log"]:
280
+ _render_chat_message({"role": "user", "content": user_prompt})
281
+
282
+ # Generate both responses in parallel (remote: truly concurrent; local: serialised via lock).
283
+ with st.spinner("Generating..."):
284
+ with ThreadPoolExecutor(max_workers=2) as executor:
285
+ futures = [
286
+ executor.submit(_generate_for_panel, panel, model, remote, gen_kwargs)
287
+ for panel, col in panels
288
+ ]
289
+ results = []
290
+ for future in futures:
291
+ try:
292
+ results.append(future.result())
293
+ except Exception as exc:
294
+ results.append(exc)
295
+
296
+ for (panel, col), result in zip(panels, results):
297
+ if isinstance(result, Exception):
298
+ with col:
299
+ with panel["chat_log"]:
300
+ st.error(f"Generation failed: {result}")
301
+ panel["state"]["messages"].pop()
302
+ continue
303
+
304
+ panel["state"]["messages"].append({"role": "assistant", "content": result.text})
305
+ panel["state"]["past_key_values"] = (
306
+ result.past_key_values if not remote else None
307
+ )
308
+ with col:
309
+ with panel["chat_log"]:
310
+ _render_chat_message({"role": "assistant", "content": result.text})
311
+
312
+
313
+ # ── Main tab entry point ───────────────────────────────────────────────────────
314
+
315
+
316
+ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
317
+ """Render the chat tab."""
318
+
319
+ st.title("Chat")
320
+
321
+ context_key = chat_session_key(model_name, dataset_source)
322
+ chat_state = get_chat_state(model_name, remote, dataset_source)
323
+ try:
324
+ dataset, dataset_status = load_dataset(dataset_source)
325
+ st.caption(dataset_status)
326
+ except Exception as exc:
327
+ st.error(f"Could not load data: {exc}")
328
+ st.info("Check the selected dataset source or upload both JSONL files.")
329
+ return
330
+
331
+ personas = list(dataset)
332
+ if not personas:
333
+ st.warning("No personas found in the selected dataset.")
334
+ st.info("Try a different dataset source or upload a non-empty personas file.")
335
+ return
336
+
337
+ # ── Generation settings ───────────────────────────────────────────────────
338
+ with st.expander("Advanced", expanded=False):
339
+ config_col1, config_col2 = st.columns([2, 1])
340
+ with config_col1:
341
+ max_new_tokens = st.slider(
342
+ "Max new tokens",
343
+ min_value=16,
344
+ max_value=512,
345
+ value=256,
346
+ step=16,
347
+ key=widget_key(context_key, "max_new_tokens"),
348
+ )
349
+ with config_col2:
350
+ repetition_penalty = st.slider(
351
+ "Repetition penalty",
352
+ min_value=0.5,
353
+ max_value=2.0,
354
+ value=1.0,
355
+ step=0.05,
356
+ key=widget_key(context_key, "repetition_penalty"),
357
+ )
358
+
359
+ use_sampling = st.checkbox(
360
+ "Random sampling",
361
+ value=False,
362
+ key=widget_key(context_key, "use_sampling"),
363
+ )
364
+
365
+ sampling_disabled = not use_sampling
366
+ sampling_col1, sampling_col2, sampling_col3 = st.columns(3)
367
+ with sampling_col1:
368
+ temperature = st.slider(
369
+ "Temperature",
370
+ min_value=0.01,
371
+ max_value=2.0,
372
+ value=1.0,
373
+ step=0.01,
374
+ disabled=sampling_disabled,
375
+ key=widget_key(context_key, "temperature"),
376
+ )
377
+ with sampling_col2:
378
+ top_p = st.slider(
379
+ "Top-p",
380
+ min_value=0.01,
381
+ max_value=1.0,
382
+ value=1.0,
383
+ step=0.01,
384
+ disabled=sampling_disabled,
385
+ key=widget_key(context_key, "top_p"),
386
+ )
387
+ with sampling_col3:
388
+ top_k = st.slider(
389
+ "Top-k (0 = off)",
390
+ min_value=0,
391
+ max_value=100,
392
+ value=50,
393
+ step=1,
394
+ disabled=sampling_disabled,
395
+ key=widget_key(context_key, "top_k"),
396
+ )
397
+
398
+ seed_disabled = sampling_disabled or remote
399
+ seed_enabled = st.checkbox(
400
+ "Fix seed",
401
+ value=False,
402
+ disabled=seed_disabled,
403
+ key=widget_key(context_key, "seed_enabled"),
404
+ )
405
+ if seed_enabled:
406
+ seed = int(
407
+ st.number_input(
408
+ "Seed",
409
+ min_value=0,
410
+ max_value=2_147_483_647,
411
+ value=0,
412
+ step=1,
413
+ disabled=seed_disabled,
414
+ key=widget_key(context_key, "seed"),
415
+ )
416
+ )
417
+ else:
418
+ seed = None
419
+
420
+ if remote:
421
+ st.caption("Seed is local-only and disabled for remote runs.")
422
+
423
+ advanced_generation = (
424
+ max_new_tokens != 256
425
+ or use_sampling
426
+ or temperature != 1.0
427
+ or top_p != 1.0
428
+ or top_k != 50
429
+ or repetition_penalty != 1.0
430
+ or seed is not None
431
+ )
432
+
433
+ do_sample = bool(use_sampling)
434
+ generation_seed = seed if do_sample and seed is not None and not remote else None
435
+ gen_kwargs = dict(
436
+ max_new_tokens=int(max_new_tokens),
437
+ do_sample=do_sample,
438
+ temperature=temperature,
439
+ top_p=top_p,
440
+ top_k=top_k,
441
+ repetition_penalty=repetition_penalty,
442
+ seed=generation_seed,
443
+ )
444
+
445
+ # ── Mode toggle ───────────────────────────────────────────────────────────
446
+ compare_mode = st.toggle(
447
+ "Compare mode",
448
+ value=False,
449
+ key=widget_key(context_key, "compare_mode"),
450
+ help="Side-by-side: send one message to two independent persona/prompt configurations.",
451
+ )
452
+
453
+ if compare_mode:
454
+ _render_compare_mode(
455
+ remote,
456
+ model_name,
457
+ context_key,
458
+ dataset_source,
459
+ personas,
460
+ gen_kwargs,
461
+ advanced_generation,
462
+ )
463
+ return
464
+
465
+ # ── Single-chat mode ──────────────────────────────────────────────────────
466
+ persona_select_key = widget_key(context_key, "persona_select")
467
+ prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
468
+
469
+ col1, col2 = st.columns([2, 1])
470
+ with col1:
471
+ selected_index = next(
472
+ (i for i, p in enumerate(personas) if p.id == chat_state["persona_id"]),
473
+ 0,
474
+ )
475
+ selected_persona = st.selectbox(
476
+ "Persona",
477
+ options=personas,
478
+ index=selected_index,
479
+ format_func=persona_label,
480
+ key=persona_select_key,
481
+ )
482
+ with col2:
483
+ current_mode_label = VARIANT_LABELS.get(chat_state["prompt_mode"], "None")
484
+ prompt_mode_label = st.selectbox(
485
+ "Prompt",
486
+ options=MODE_LABELS,
487
+ index=MODE_LABELS.index(current_mode_label),
488
+ key=prompt_mode_select_key,
489
+ )
490
+ prompt_mode = MODE_LABEL_TO_KEY[prompt_mode_label]
491
+
492
+ active_system_prompt = resolve_system_prompt(
493
+ persona=selected_persona,
494
+ mode=prompt_mode,
495
+ )
496
+
497
+ chat_input_key = widget_key(context_key, "chat_input")
498
+ show_all_key = widget_key(context_key, "show_all_messages")
499
+ custom_prompt_key = widget_key(context_key, "custom_system_prompt")
500
+ pending_key = widget_key(context_key, "pending_prompt")
501
+ export_success_message: str | None = None
502
+
503
+ action_col1, action_col2 = st.columns(2)
504
+ with action_col1:
505
+ if st.button("Reset chat", use_container_width=True, type="secondary"):
506
+ reset_chat_state(model_name, remote, dataset_source)
507
+ _clear_chat_ui_state(
508
+ chat_input_key,
509
+ show_all_key,
510
+ custom_prompt_key,
511
+ pending_key,
512
+ )
513
+ st.rerun()
514
+ with action_col2:
515
+ if st.button("Export chat", use_container_width=True):
516
+ export_path = save_chat_export(
517
+ model_name=model_name,
518
+ dataset_source=dataset_source,
519
+ persona_id=selected_persona.id,
520
+ persona_name=getattr(selected_persona, "name", None),
521
+ prompt_mode=prompt_mode,
522
+ system_prompt=active_system_prompt,
523
+ messages=chat_state["messages"],
524
+ generation=_generation_dict(gen_kwargs, advanced_generation),
525
+ )
526
+ export_success_message = f"Saved chat export to {export_path}"
527
+
528
+ if export_success_message:
529
+ st.success(export_success_message)
530
+
531
+ changed_context = (
532
+ chat_state["persona_id"] != selected_persona.id
533
+ or chat_state["prompt_mode"] != prompt_mode
534
+ )
535
+ if changed_context:
536
+ had_history = bool(chat_state["messages"])
537
+ chat_state["persona_id"] = selected_persona.id
538
+ chat_state["prompt_mode"] = prompt_mode
539
+ reset_chat_state(model_name, remote, dataset_source)
540
+ _clear_chat_ui_state(
541
+ chat_input_key,
542
+ show_all_key,
543
+ custom_prompt_key,
544
+ pending_key,
545
+ )
546
+ if had_history:
547
+ st.info("Chat history reset because the persona or system prompt changed.")
548
+
549
+ chat_log = st.container()
550
+
551
+ with chat_log:
552
+ # System prompt as first item in conversation — collapsed by default, editable.
553
+ if prompt_mode != "empty":
554
+ if custom_prompt_key not in st.session_state:
555
+ st.session_state[custom_prompt_key] = active_system_prompt
556
+ with st.expander("Edit prompt", expanded=False):
557
+ active_system_prompt = (
558
+ st.text_area(
559
+ "Prompt",
560
+ key=custom_prompt_key,
561
+ height=200,
562
+ label_visibility="collapsed",
563
+ )
564
+ or None
565
+ )
566
+
567
+ # Collapse older messages, show only the most recent ones.
568
+ messages = chat_state["messages"]
569
+ if len(messages) > _VISIBLE_MESSAGE_COUNT and not st.session_state.get(
570
+ show_all_key, False
571
+ ):
572
+ hidden_count = len(messages) - _VISIBLE_MESSAGE_COUNT
573
+ if st.button(
574
+ f"Show earlier messages ({hidden_count} hidden)",
575
+ key=widget_key(context_key, "show_all_btn"),
576
+ ):
577
+ st.session_state[show_all_key] = True
578
+ st.rerun()
579
+ visible_messages = messages[-_VISIBLE_MESSAGE_COUNT:]
580
+ else:
581
+ visible_messages = messages
582
+
583
+ for message in visible_messages:
584
+ _render_chat_message(message)
585
+
586
+ user_prompt = st.chat_input(
587
+ "Ask something...",
588
+ key=chat_input_key,
589
+ )
590
+
591
+ # Pass 1: user submitted — append message and rerun so it renders before generation.
592
+ if user_prompt:
593
+ chat_state["messages"].append({"role": "user", "content": user_prompt})
594
+ st.session_state[pending_key] = True
595
+ st.rerun()
596
+
597
+ # Pass 2: message is already rendered above; now run generation.
598
+ if not st.session_state.pop(pending_key, False):
599
+ return
600
+
601
+ messages = []
602
+ if active_system_prompt:
603
+ messages.append({"role": "system", "content": active_system_prompt})
604
+ messages.extend(chat_state["messages"])
605
+
606
+ with st.spinner("Generating reply..."):
607
+ model = cached_model(model_name=model_name, remote=remote)
608
+ try:
609
+ reply: ChatReply = generate_chat_reply(
610
+ model=model,
611
+ messages=messages,
612
+ remote=remote,
613
+ past_key_values=chat_state["past_key_values"],
614
+ **gen_kwargs,
615
+ )
616
+ except Exception as exc:
617
+ with chat_log:
618
+ st.error(f"Could not generate a reply: {exc}")
619
+ st.info("Try a shorter prompt, reset the chat, or switch personas.")
620
+ chat_state["messages"].pop()
621
+ return
622
+
623
+ chat_state["messages"].append({"role": "assistant", "content": reply.text})
624
+ chat_state["past_key_values"] = reply.past_key_values if not remote else None
625
+
626
+ save_chat_export(
627
+ model_name=model_name,
628
+ dataset_source=dataset_source,
629
+ persona_id=selected_persona.id,
630
+ persona_name=getattr(selected_persona, "name", None),
631
+ prompt_mode=prompt_mode,
632
+ system_prompt=active_system_prompt,
633
+ messages=chat_state["messages"],
634
+ generation=_generation_dict(gen_kwargs, advanced_generation),
635
+ )
636
+ st.rerun()
tabs/compare.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from persona_data.environment import get_artifacts_dir
3
+ from persona_vectors.analysis import build_embedding_figure, project_pca, project_umap
4
+ from persona_vectors.plots import (
5
+ plot_multiple_layer_similarities,
6
+ save_plot_html,
7
+ save_plot_png,
8
+ )
9
+
10
+ from utils.artifacts import (
11
+ artifact_persona_options,
12
+ list_available_layers,
13
+ load_cosine_traces,
14
+ load_embedding_samples,
15
+ )
16
+ from utils.helpers import (
17
+ ANALYSIS_HELP_TEXT,
18
+ ANALYSIS_LABELS,
19
+ ANALYSIS_MODES,
20
+ PROMPT_VARIANTS,
21
+ persona_display_label,
22
+ prompt_variant_label,
23
+ slugify,
24
+ widget_key,
25
+ )
26
+
27
+
28
+ def _filename(*parts: str) -> str:
29
+ return "__".join(slugify(part) for part in parts if part)
30
+
31
+
32
+ def _select_artifact_personas(
33
+ artifacts_root: str,
34
+ model_name: str,
35
+ variants: list[str],
36
+ ) -> tuple[list[str], dict[str, str]]:
37
+ persona_options, persona_names = artifact_persona_options(
38
+ artifacts_root,
39
+ model_name,
40
+ variants,
41
+ )
42
+ if not persona_options:
43
+ if len(variants) > 1:
44
+ st.info(
45
+ "No personas have saved activations for all selected variants. Run extraction for both variants first."
46
+ )
47
+ else:
48
+ st.info("No personas found for this model yet. Run extraction first.")
49
+ return [], persona_names
50
+
51
+ persona_ids = st.multiselect(
52
+ "Personas",
53
+ options=persona_options,
54
+ default=persona_options[:1] if len(persona_options) > 1 else persona_options,
55
+ format_func=lambda persona_id: persona_display_label(
56
+ persona_id, persona_names.get(persona_id)
57
+ ),
58
+ key=widget_key("load", "personas", model_name, *variants),
59
+ )
60
+ return persona_ids, persona_names
61
+
62
+
63
+ def _render_cosine_similarity(
64
+ artifacts_root: str,
65
+ model_name: str,
66
+ ) -> None:
67
+ col1, col2 = st.columns(2)
68
+ with col1:
69
+ variant_a = st.selectbox(
70
+ "Variant A",
71
+ options=PROMPT_VARIANTS,
72
+ index=0,
73
+ format_func=prompt_variant_label,
74
+ key=widget_key("load", "variant_a"),
75
+ )
76
+ with col2:
77
+ variant_b = st.selectbox(
78
+ "Variant B",
79
+ options=PROMPT_VARIANTS,
80
+ index=min(1, len(PROMPT_VARIANTS) - 1),
81
+ format_func=prompt_variant_label,
82
+ key=widget_key("load", "variant_b"),
83
+ )
84
+
85
+ if variant_a == variant_b:
86
+ st.warning("Choose two different variants to compare.")
87
+ return
88
+
89
+ persona_ids, _ = _select_artifact_personas(
90
+ artifacts_root,
91
+ model_name,
92
+ [variant_a, variant_b],
93
+ )
94
+ if not persona_ids:
95
+ return
96
+
97
+ cosine_fig_key = widget_key("load", "cosine_fig_state", model_name)
98
+ filename = _filename("compare", "cosine", model_name, variant_a, variant_b)
99
+
100
+ if st.button("Compare vectors", type="primary"):
101
+ traces, loaded_names, errors = load_cosine_traces(
102
+ artifacts_root,
103
+ model_name,
104
+ persona_ids,
105
+ variant_a,
106
+ variant_b,
107
+ )
108
+
109
+ if errors:
110
+ for err in errors:
111
+ st.error(f"Failed to load vectors: `{err}`")
112
+ if not traces:
113
+ st.error("No personas loaded successfully.")
114
+ st.info(
115
+ "Check that extraction has been run for both variants and selected personas."
116
+ )
117
+ st.session_state.pop(cosine_fig_key, None)
118
+ return
119
+
120
+ display_traces = [
121
+ (
122
+ persona_display_label(persona_id, loaded_names.get(persona_id)),
123
+ short,
124
+ long,
125
+ )
126
+ for persona_id, short, long in traces
127
+ ]
128
+ fig = plot_multiple_layer_similarities(
129
+ display_traces,
130
+ title=f"{prompt_variant_label(variant_a)} vs {prompt_variant_label(variant_b)}",
131
+ show=False,
132
+ )
133
+ st.session_state[cosine_fig_key] = (fig, len(traces))
134
+
135
+ if cosine_fig_key in st.session_state:
136
+ fig, n_traces = st.session_state[cosine_fig_key]
137
+ st.plotly_chart(fig, use_container_width=True)
138
+ save_col1, save_col2 = st.columns(2)
139
+ with save_col1:
140
+ if st.button("Save HTML", key=widget_key("load", "save_cosine_html")):
141
+ output_path = save_plot_html(fig, filename)
142
+ st.success(f"Saved HTML to `{output_path}`")
143
+ with save_col2:
144
+ if st.button("Save PNG", key=widget_key("load", "save_cosine_png")):
145
+ try:
146
+ output_path = save_plot_png(fig, filename)
147
+ st.success(f"Saved PNG to `{output_path}`")
148
+ except Exception as exc:
149
+ st.error(f"Could not save PNG: {exc}")
150
+ st.success(f"Loaded {n_traces} personas for cosine comparison.")
151
+
152
+
153
+ def _render_embedding_analysis(
154
+ artifacts_root: str,
155
+ model_name: str,
156
+ analysis_mode: str,
157
+ ) -> None:
158
+ selected_variant = st.selectbox(
159
+ "Variant",
160
+ options=PROMPT_VARIANTS,
161
+ format_func=prompt_variant_label,
162
+ key=widget_key("load", "variant"),
163
+ )
164
+
165
+ persona_ids, persona_names = _select_artifact_personas(
166
+ artifacts_root,
167
+ model_name,
168
+ [selected_variant],
169
+ )
170
+ if not persona_ids:
171
+ return
172
+
173
+ layer_options = list_available_layers(
174
+ artifacts_root,
175
+ model_name,
176
+ [selected_variant],
177
+ persona_ids,
178
+ )
179
+ if not layer_options:
180
+ st.info(
181
+ "No shared layers are available for the selected personas. Try fewer personas or a different variant."
182
+ )
183
+ return
184
+
185
+ persona_key = "_".join(sorted(persona_ids))
186
+ layer_key = widget_key("load", "layers", model_name, selected_variant, persona_key)
187
+ default_layers = [
188
+ layer
189
+ for layer in st.session_state.get(layer_key, layer_options[:3])
190
+ if layer in layer_options
191
+ ] or layer_options[:3]
192
+ selected_layers = st.multiselect(
193
+ "Layers",
194
+ options=layer_options,
195
+ default=default_layers,
196
+ key=layer_key,
197
+ )
198
+ if not selected_layers:
199
+ st.info("Select at least one layer.")
200
+ return
201
+
202
+ button_label = (
203
+ "Generate PCA projection"
204
+ if analysis_mode == "PCA"
205
+ else "Generate UMAP projection"
206
+ )
207
+
208
+ embedding_fig_key = widget_key(
209
+ "load", "embedding_fig_state", model_name, analysis_mode
210
+ )
211
+
212
+ if st.button(button_label, type="primary"):
213
+ progress = st.progress(0, text="Preparing projections...")
214
+
215
+ def update_progress(current: int, total: int, loaded: int) -> None:
216
+ fraction = current / total if total else 1.0
217
+ progress.progress(
218
+ fraction,
219
+ text=f"Processing layer {current}/{total} ({loaded} plot(s) ready)",
220
+ )
221
+
222
+ project_fn = project_pca if analysis_mode == "PCA" else project_umap
223
+ try:
224
+ plots, errors = load_embedding_samples(
225
+ artifacts_root,
226
+ model_name,
227
+ persona_ids,
228
+ selected_variant,
229
+ selected_layers,
230
+ project_fn,
231
+ persona_names,
232
+ progress_fn=update_progress,
233
+ )
234
+
235
+ if errors:
236
+ for err in errors:
237
+ if (
238
+ "missing layer" in err
239
+ or "no selected personas have this layer" in err
240
+ ):
241
+ st.warning(f"Skipping unavailable data: `{err}`")
242
+ else:
243
+ st.error(f"Failed to load vectors: `{err}`")
244
+ if not plots:
245
+ st.warning(
246
+ "No projections could be built for the current persona/layer selection."
247
+ )
248
+ st.info("Try fewer personas, fewer layers, or a different variant.")
249
+ st.session_state.pop(embedding_fig_key, None)
250
+ else:
251
+ title_prefix, x_label, y_label = ANALYSIS_LABELS[analysis_mode]
252
+ rendered_figures: list[tuple[int, object]] = []
253
+ for layer_idx, coords, labels, hover_text in plots:
254
+ fig = build_embedding_figure(
255
+ coords=coords,
256
+ labels=labels,
257
+ title=f"{title_prefix}, layer {layer_idx}",
258
+ x_label=x_label,
259
+ y_label=y_label,
260
+ hover_text=hover_text,
261
+ )
262
+ rendered_figures.append((layer_idx, fig))
263
+ total_samples = sum(coords.shape[0] for _, coords, _, _ in plots)
264
+ st.session_state[embedding_fig_key] = (
265
+ rendered_figures,
266
+ persona_key,
267
+ selected_variant,
268
+ total_samples,
269
+ )
270
+ finally:
271
+ progress.empty()
272
+
273
+ if embedding_fig_key in st.session_state:
274
+ rendered_figures, saved_persona_key, saved_variant, total_samples = (
275
+ st.session_state[embedding_fig_key]
276
+ )
277
+ cols = st.columns(2)
278
+ for idx, (layer_idx, fig) in enumerate(rendered_figures):
279
+ with cols[idx % 2]:
280
+ st.plotly_chart(fig, use_container_width=True)
281
+ st.success(
282
+ f"Loaded {total_samples} samples across {len(rendered_figures)} layers."
283
+ )
284
+ filenames = [
285
+ _filename(
286
+ "compare",
287
+ analysis_mode,
288
+ model_name,
289
+ saved_variant,
290
+ saved_persona_key,
291
+ str(layer_idx),
292
+ )
293
+ for layer_idx, _ in rendered_figures
294
+ ]
295
+ save_col1, save_col2 = st.columns(2)
296
+ with save_col1:
297
+ if st.button(
298
+ "Save HTML",
299
+ key=widget_key("load", "save_embedding_html", analysis_mode),
300
+ ):
301
+ saved_paths = [
302
+ save_plot_html(fig, fn)
303
+ for (_, fig), fn in zip(rendered_figures, filenames)
304
+ ]
305
+ st.success(
306
+ f"Saved {len(saved_paths)} HTML plot(s) to `artifacts/plots`."
307
+ )
308
+ with save_col2:
309
+ if st.button(
310
+ "Save PNG",
311
+ key=widget_key("load", "save_embedding_png", analysis_mode),
312
+ ):
313
+ try:
314
+ saved_paths = [
315
+ save_plot_png(fig, fn)
316
+ for (_, fig), fn in zip(rendered_figures, filenames)
317
+ ]
318
+ st.success(
319
+ f"Saved {len(saved_paths)} PNG plot(s) to `artifacts/plots`."
320
+ )
321
+ except Exception as exc:
322
+ st.error(f"Could not save PNGs: {exc}")
323
+
324
+
325
+ def render_compare_tab(model_name: str) -> None:
326
+ """Render the compare tab."""
327
+
328
+ st.title("Compare")
329
+ st.caption("Compare saved activations by cosine similarity, PCA, or UMAP.")
330
+
331
+ st.subheader("Analysis")
332
+
333
+ with st.expander("Advanced", expanded=False):
334
+ artifacts_root = st.text_input(
335
+ "Artifacts root",
336
+ value=str(get_artifacts_dir() / "activations"),
337
+ )
338
+
339
+ analysis_mode = st.segmented_control(
340
+ "Analysis mode",
341
+ options=ANALYSIS_MODES,
342
+ default=ANALYSIS_MODES[0],
343
+ key=widget_key("load", "analysis_mode"),
344
+ label_visibility="collapsed",
345
+ )
346
+ if analysis_mode is None:
347
+ analysis_mode = ANALYSIS_MODES[0]
348
+ st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
349
+
350
+ if analysis_mode == "Cosine similarity":
351
+ _render_cosine_similarity(artifacts_root, model_name)
352
+ return
353
+
354
+ _render_embedding_analysis(artifacts_root, model_name, analysis_mode)
tabs/extract.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from utils.datasets import load_dataset
4
+ from utils.extraction import run_extraction
5
+ from utils.helpers import (
6
+ PROMPT_VARIANTS,
7
+ persona_label,
8
+ prompt_variant_label,
9
+ widget_key,
10
+ )
11
+ from utils.runtime import cached_model
12
+
13
+
14
+ def _extract_widget_key(
15
+ model_name: str, remote: bool, dataset_source: str, suffix: str
16
+ ) -> str:
17
+ return widget_key("extract", str(remote), model_name, dataset_source, suffix)
18
+
19
+
20
+ def _render_local_dataset_uploads() -> None:
21
+ """Render file inputs for local dataset uploads."""
22
+
23
+ with st.expander("Local dataset upload", expanded=True):
24
+ st.file_uploader(
25
+ "personas.jsonl",
26
+ type=["jsonl"],
27
+ key="extract__personas_file",
28
+ help="Expected fields: id, persona, templated_prompt, biography_md",
29
+ )
30
+ st.file_uploader(
31
+ "qa.jsonl",
32
+ type=["jsonl"],
33
+ key="extract__qa_file",
34
+ help="Expected fields: id, qid, type, question, answer, difficulty",
35
+ )
36
+
37
+
38
+ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> None:
39
+ """Render the extraction tab."""
40
+
41
+ st.title("Extract")
42
+
43
+ if dataset_source == "Local JSONL upload":
44
+ _render_local_dataset_uploads()
45
+
46
+ selected_variants = st.multiselect(
47
+ "Prompt variants",
48
+ options=PROMPT_VARIANTS,
49
+ default=PROMPT_VARIANTS,
50
+ format_func=prompt_variant_label,
51
+ key=_extract_widget_key(model_name, remote, dataset_source, "prompt_variants"),
52
+ )
53
+ if not selected_variants:
54
+ st.info("Select at least one prompt variant.")
55
+ return
56
+
57
+ try:
58
+ dataset, dataset_status = load_dataset(dataset_source)
59
+ st.caption(dataset_status)
60
+ except Exception as exc:
61
+ st.error(f"Could not load data: {exc}")
62
+ st.info(
63
+ "Upload both JSONL files or switch to the built-in SynthPersona source."
64
+ )
65
+ return
66
+
67
+ personas = list(dataset)
68
+ if not personas:
69
+ st.warning("No personas found in the selected dataset.")
70
+ st.info(
71
+ "Try another dataset source or check that the personas file is not empty."
72
+ )
73
+ return
74
+
75
+ selected_personas = st.multiselect(
76
+ "Personas",
77
+ options=personas,
78
+ default=[personas[0]] if personas else [],
79
+ format_func=persona_label,
80
+ key=_extract_widget_key(model_name, remote, dataset_source, "persona_select"),
81
+ )
82
+
83
+ if not selected_personas:
84
+ st.info("Select at least one persona.")
85
+ return
86
+
87
+ qa_filter_type: str | None
88
+ qa_filter_difficulty: list[int] | None
89
+
90
+ with st.expander("Advanced", expanded=False):
91
+ st.caption("Filters")
92
+
93
+ col1, col2, col3 = st.columns([2, 2, 1])
94
+ with col1:
95
+ qa_type_select = st.selectbox(
96
+ "QA type",
97
+ options=["all", "explicit", "implicit"],
98
+ index=0,
99
+ key=_extract_widget_key(
100
+ model_name, remote, dataset_source, "qa_type_select"
101
+ ),
102
+ )
103
+ qa_filter_type = (
104
+ qa_type_select if qa_type_select in ("explicit", "implicit") else None
105
+ )
106
+ with col2:
107
+ difficulty_values = st.multiselect(
108
+ "Difficulty",
109
+ options=[1, 2, 3],
110
+ default=[1, 2, 3],
111
+ key=_extract_widget_key(
112
+ model_name, remote, dataset_source, "difficulty_select"
113
+ ),
114
+ )
115
+ qa_filter_difficulty = difficulty_values if difficulty_values else None
116
+
117
+ # Pre-load QA pairs for all selected personas to validate filters and set slider range.
118
+ qa_by_persona = {
119
+ p.id: dataset.get_qa(
120
+ p.id, type=qa_filter_type, difficulty=qa_filter_difficulty
121
+ )
122
+ for p in selected_personas
123
+ }
124
+ personas_without_qa = [p for p in selected_personas if not qa_by_persona[p.id]]
125
+ if personas_without_qa:
126
+ names = ", ".join(p.name for p in personas_without_qa)
127
+ st.warning(f"No QA pairs match filters for: {names}. They will be skipped.")
128
+
129
+ personas_to_run = [p for p in selected_personas if qa_by_persona[p.id]]
130
+ if not personas_to_run:
131
+ st.info("No personas have matching QA pairs. Widen the filters.")
132
+ return
133
+
134
+ min_qa_count = min(len(qa_by_persona[p.id]) for p in personas_to_run)
135
+
136
+ with col3:
137
+ max_questions = st.slider(
138
+ "Max questions",
139
+ min_value=1,
140
+ max_value=min_qa_count,
141
+ value=min_qa_count,
142
+ key=_extract_widget_key(
143
+ model_name, remote, dataset_source, "max_questions"
144
+ ),
145
+ )
146
+
147
+ run_clicked = st.button("Run extraction", type="primary")
148
+ if not run_clicked:
149
+ return
150
+
151
+ status_box = st.empty()
152
+ status_box.info("Extraction in progress...")
153
+ progress = st.progress(0, text="Preparing extraction...")
154
+
155
+ with st.spinner("Loading model..."):
156
+ model = cached_model(model_name=model_name, remote=remote)
157
+
158
+ try:
159
+ total_steps = len(personas_to_run) * len(selected_variants)
160
+ step = 0
161
+ results = []
162
+
163
+ for persona in personas_to_run:
164
+ qa_pairs = qa_by_persona[persona.id][:max_questions]
165
+ for variant in selected_variants:
166
+ progress.progress(
167
+ step / total_steps if total_steps else 1.0,
168
+ text=f"{persona.name} · {prompt_variant_label(variant)} ({step + 1}/{total_steps})",
169
+ )
170
+ variant_results = run_extraction(
171
+ model=model,
172
+ model_name=model_name,
173
+ persona=persona,
174
+ qa_pairs=qa_pairs,
175
+ variants=[variant],
176
+ remote=remote,
177
+ )
178
+ results.extend(variant_results)
179
+ step += 1
180
+
181
+ progress.progress(1.0, text="Extraction complete")
182
+ except Exception as exc:
183
+ st.error(f"Extraction failed: {exc}")
184
+ return
185
+ finally:
186
+ progress.empty()
187
+
188
+ status_box.success("Extraction complete")
189
+ st.success(f"Saved {len(results)} artifact set(s)")
190
+
191
+ for result in results:
192
+ st.markdown(
193
+ f"- **{result.persona_name}** · {prompt_variant_label(result.variant)}: "
194
+ f"{result.n_questions} questions, {result.n_layers} layers, {result.d_model} hidden size"
195
+ )
utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Utility helpers for the Streamlit UI."""
utils/artifacts.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from collections.abc import Callable
3
+ from pathlib import Path
4
+
5
+ import streamlit as st
6
+ import torch
7
+ from persona_vectors.activation_io import (
8
+ load_activation_metadata,
9
+ load_per_question_vectors,
10
+ )
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def model_dir_name(model_name: str) -> str:
16
+ """Encode a model name for use in artifact paths."""
17
+
18
+ return model_name.replace("/", "__")
19
+
20
+
21
+ def list_available_personas(
22
+ artifacts_root: str | Path,
23
+ model_name: str,
24
+ variants: list[str],
25
+ ) -> list[str]:
26
+ """List persona ids available for every requested variant."""
27
+
28
+ shared_personas: set[str] | None = None
29
+ root = Path(artifacts_root)
30
+ for variant in variants:
31
+ model_dir = root / model_dir_name(model_name) / variant
32
+ if not model_dir.exists():
33
+ return []
34
+
35
+ variant_personas = {d.name for d in model_dir.iterdir() if d.is_dir()}
36
+ if shared_personas is None:
37
+ shared_personas = variant_personas
38
+ else:
39
+ shared_personas &= variant_personas
40
+
41
+ if not shared_personas:
42
+ return []
43
+
44
+ return sorted(shared_personas or set())
45
+
46
+
47
+ def load_persona_names(
48
+ artifacts_root: str | Path,
49
+ model_name: str,
50
+ variants: list[str],
51
+ persona_ids: list[str],
52
+ ) -> dict[str, str]:
53
+ """Load display names from saved activation metadata."""
54
+
55
+ names: dict[str, str] = {}
56
+ for persona_id in persona_ids:
57
+ for variant in variants:
58
+ try:
59
+ metadata = load_activation_metadata(
60
+ root_dir=artifacts_root,
61
+ model_name=model_name,
62
+ prompt_variant=variant,
63
+ persona_id=persona_id,
64
+ )
65
+ except Exception:
66
+ logger.debug(
67
+ "Failed to load metadata for persona %s variant %s",
68
+ persona_id,
69
+ variant,
70
+ exc_info=True,
71
+ )
72
+ continue
73
+
74
+ persona_name = metadata.get("persona_name")
75
+ if isinstance(persona_name, str) and persona_name:
76
+ names[persona_id] = persona_name
77
+ break
78
+
79
+ return names
80
+
81
+
82
+ def artifact_persona_options(
83
+ artifacts_root: str | Path,
84
+ model_name: str,
85
+ variants: list[str],
86
+ ) -> tuple[list[str], dict[str, str]]:
87
+ """Return persona ids and names for the selected artifacts."""
88
+
89
+ persona_options = list_available_personas(artifacts_root, model_name, variants)
90
+ persona_names = load_persona_names(
91
+ artifacts_root,
92
+ model_name,
93
+ variants,
94
+ persona_options,
95
+ )
96
+ return persona_options, persona_names
97
+
98
+
99
+ @st.cache_data(show_spinner=False)
100
+ def list_available_layers(
101
+ artifacts_root: str,
102
+ model_name: str,
103
+ variants: list[str],
104
+ persona_ids: list[str],
105
+ ) -> list[int]:
106
+ """List layer indices shared by all matching saved activation files."""
107
+
108
+ shared_layers: set[int] | None = None
109
+ for variant in variants:
110
+ for persona_id in persona_ids:
111
+ try:
112
+ vectors, _ = load_per_question_vectors(
113
+ root_dir=artifacts_root,
114
+ model_name=model_name,
115
+ prompt_variant=variant,
116
+ persona_id=persona_id,
117
+ )
118
+ except Exception:
119
+ logger.debug(
120
+ "Failed to load vectors for persona %s variant %s",
121
+ persona_id,
122
+ variant,
123
+ exc_info=True,
124
+ )
125
+ continue
126
+
127
+ layers = set(range(vectors.shape[1]))
128
+ if shared_layers is None:
129
+ shared_layers = layers
130
+ else:
131
+ shared_layers &= layers
132
+
133
+ return sorted(shared_layers or set())
134
+
135
+
136
+ def load_cosine_traces(
137
+ artifacts_root: str | Path,
138
+ model_name: str,
139
+ persona_ids: list[str],
140
+ variant_a: str,
141
+ variant_b: str,
142
+ ) -> tuple[list[tuple[str, torch.Tensor, torch.Tensor]], dict[str, str], list[str]]:
143
+ """Load mean activation traces for pairwise cosine-similarity plots."""
144
+
145
+ persona_names = load_persona_names(
146
+ artifacts_root,
147
+ model_name,
148
+ [variant_a, variant_b],
149
+ persona_ids,
150
+ )
151
+ traces: list[tuple[str, torch.Tensor, torch.Tensor]] = []
152
+ errors: list[str] = []
153
+
154
+ for persona_id in persona_ids:
155
+ try:
156
+ vectors_a, _ = load_per_question_vectors(
157
+ root_dir=artifacts_root,
158
+ model_name=model_name,
159
+ prompt_variant=variant_a,
160
+ persona_id=persona_id,
161
+ )
162
+ vectors_b, _ = load_per_question_vectors(
163
+ root_dir=artifacts_root,
164
+ model_name=model_name,
165
+ prompt_variant=variant_b,
166
+ persona_id=persona_id,
167
+ )
168
+ except Exception as exc:
169
+ errors.append(f"{persona_id}: {exc}")
170
+ continue
171
+
172
+ traces.append(
173
+ (persona_id, vectors_a.float().mean(dim=0), vectors_b.float().mean(dim=0))
174
+ )
175
+
176
+ return traces, persona_names, errors
177
+
178
+
179
+ def load_embedding_samples(
180
+ artifacts_root: str | Path,
181
+ model_name: str,
182
+ persona_ids: list[str],
183
+ variant: str,
184
+ selected_layers: list[int],
185
+ project_fn: Callable[[torch.Tensor], torch.Tensor],
186
+ persona_names: dict[str, str],
187
+ progress_fn: Callable[[int, int, int], None] | None = None,
188
+ ) -> tuple[list[tuple[int, torch.Tensor, list[str], list[str]]], list[str]]:
189
+ """Load samples for 2D projections without re-reading each layer from disk."""
190
+
191
+ plots: list[tuple[int, torch.Tensor, list[str], list[str]]] = []
192
+ errors: list[str] = []
193
+ vectors_by_persona: dict[str, torch.Tensor] = {}
194
+
195
+ for persona_id in persona_ids:
196
+ try:
197
+ vectors, _ = load_per_question_vectors(
198
+ root_dir=artifacts_root,
199
+ model_name=model_name,
200
+ prompt_variant=variant,
201
+ persona_id=persona_id,
202
+ )
203
+ except Exception as exc:
204
+ errors.append(f"{persona_id} / {variant}: {exc}")
205
+ continue
206
+
207
+ vectors_by_persona[persona_id] = vectors
208
+
209
+ total_layers = len(selected_layers)
210
+ for idx, layer_idx in enumerate(selected_layers, start=1):
211
+ samples: list[torch.Tensor] = []
212
+ labels: list[str] = []
213
+ hover_text: list[str] = []
214
+
215
+ for persona_id, vectors in vectors_by_persona.items():
216
+ if layer_idx >= vectors.shape[1]:
217
+ errors.append(f"{persona_id} / {variant}: missing layer {layer_idx}")
218
+ continue
219
+
220
+ layer_vectors = vectors[:, layer_idx, :]
221
+ samples.append(layer_vectors)
222
+ labels.extend([persona_id] * layer_vectors.shape[0])
223
+ display_name = persona_names.get(persona_id) or persona_id
224
+ hover_text.extend(
225
+ [
226
+ f"<b>{display_name}</b><br>{variant}",
227
+ ]
228
+ * layer_vectors.shape[0]
229
+ )
230
+
231
+ if not samples:
232
+ errors.append(f"Layer {layer_idx}: no selected personas have this layer")
233
+ else:
234
+ all_samples = torch.cat(samples, dim=0)
235
+ if all_samples.shape[0] < 2:
236
+ errors.append(
237
+ f"Layer {layer_idx}: need at least 2 samples after filtering selected personas"
238
+ )
239
+ else:
240
+ try:
241
+ coords = project_fn(all_samples)
242
+ plots.append((layer_idx, coords, labels, hover_text))
243
+ except Exception as exc:
244
+ errors.append(f"Layer {layer_idx}: {exc}")
245
+
246
+ if progress_fn is not None:
247
+ progress_fn(idx, total_layers, len(plots))
248
+
249
+ return plots, errors
utils/chat.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from contextlib import contextmanager, nullcontext
3
+ from dataclasses import dataclass
4
+ from typing import Literal
5
+
6
+ import torch
7
+ from nnterp import StandardizedTransformer
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ from persona_data.synth_persona import PersonaData
12
+ from persona_data.prompts import (
13
+ format_biography_prompt,
14
+ format_templated_prompt,
15
+ normalize_messages,
16
+ )
17
+
18
+ SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
19
+
20
+ _CUSTOM_PROMPT_DEFAULT = "You are a helpful assistant."
21
+
22
+
23
+ @dataclass
24
+ class ChatReply:
25
+ text: str
26
+ prompt_tokens: int
27
+ output_tokens: int
28
+ past_key_values: object | None
29
+
30
+
31
+ def resolve_system_prompt(
32
+ persona: PersonaData | None,
33
+ mode: SystemPromptMode,
34
+ ) -> str:
35
+ """Resolve the active system prompt for chat.
36
+
37
+ Args:
38
+ persona: Selected persona, if any.
39
+ mode: Prompt mode selected in the UI.
40
+
41
+ Returns:
42
+ The rendered system prompt string.
43
+ """
44
+
45
+ if persona is None:
46
+ return ""
47
+
48
+ if mode == "templated":
49
+ return format_templated_prompt(persona.templated_prompt)
50
+ if mode == "biography":
51
+ return format_biography_prompt(persona.biography_md)
52
+ if mode == "custom":
53
+ return _CUSTOM_PROMPT_DEFAULT
54
+ return ""
55
+
56
+
57
+ def _format_plain_messages(
58
+ messages: list[dict[str, str]], add_generation_prompt: bool
59
+ ) -> str:
60
+ """Format messages as plain ``Role: content`` text, used as a last-resort fallback."""
61
+ lines: list[str] = []
62
+
63
+ for message in messages:
64
+ role = message["role"]
65
+ content = message["content"]
66
+
67
+ if role == "system":
68
+ if content:
69
+ lines.append(f"System: {content}")
70
+ elif role == "user":
71
+ lines.append(f"User: {content}")
72
+ elif role == "assistant":
73
+ lines.append(f"Assistant: {content}")
74
+ else:
75
+ lines.append(f"{role.title()}: {content}")
76
+
77
+ if add_generation_prompt and (not lines or not lines[-1].startswith("Assistant:")):
78
+ lines.append("Assistant:")
79
+
80
+ return "\n\n".join(lines)
81
+
82
+
83
+ def _format_generation_prompt(
84
+ messages: list[dict[str, str]], tokenizer: object
85
+ ) -> tuple[str, int]:
86
+ """Render messages into a single prompt string and count prompt tokens.
87
+
88
+ Tries the tokenizer's chat template first, falls back to normalized messages,
89
+ then to a plain-text format if both template attempts fail.
90
+ """
91
+ normalized_messages = messages
92
+
93
+ try:
94
+ prompt = tokenizer.apply_chat_template(
95
+ normalized_messages,
96
+ tokenize=False,
97
+ add_generation_prompt=True,
98
+ )
99
+ except Exception:
100
+ logger.debug(
101
+ "Chat template failed on raw messages, trying normalized", exc_info=True
102
+ )
103
+ normalized_messages = normalize_messages(messages)
104
+
105
+ try:
106
+ prompt = tokenizer.apply_chat_template(
107
+ normalized_messages,
108
+ tokenize=False,
109
+ add_generation_prompt=True,
110
+ )
111
+ except Exception:
112
+ logger.debug(
113
+ "Chat template failed on normalized messages, falling back to plain format",
114
+ exc_info=True,
115
+ )
116
+ prompt = _format_plain_messages(
117
+ normalized_messages,
118
+ add_generation_prompt=True,
119
+ )
120
+
121
+ prompt_token_count = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
122
+ return prompt, prompt_token_count
123
+
124
+
125
+ @contextmanager
126
+ def _seeded_rng(seed: int | None):
127
+ """Context manager that forks the RNG state and sets a deterministic seed."""
128
+ if seed is None:
129
+ yield
130
+ return
131
+
132
+ cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
133
+ mps_ctx = (
134
+ torch.random.fork_rng(devices=range(1), device_type="mps")
135
+ if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
136
+ else nullcontext()
137
+ )
138
+
139
+ with cuda_ctx, mps_ctx:
140
+ torch.manual_seed(seed)
141
+ yield
142
+
143
+
144
+ def generate_chat_reply(
145
+ model: StandardizedTransformer,
146
+ messages: list[dict[str, str]],
147
+ remote: bool,
148
+ past_key_values: object | None = None,
149
+ max_new_tokens: int = 256,
150
+ do_sample: bool = False,
151
+ temperature: float = 1.0,
152
+ top_p: float = 1.0,
153
+ top_k: int = 50,
154
+ repetition_penalty: float = 1.0,
155
+ seed: int | None = None,
156
+ ) -> ChatReply:
157
+ """Generate one assistant reply from a full chat history.
158
+
159
+ The helper uses ``model.generate`` so it works with both local and NDIF-backed
160
+ nnsight models. The full conversation is re-rendered each turn and the cache from
161
+ the previous turn is reused when available.
162
+
163
+ Args:
164
+ model: Loaded standardized nnterp model.
165
+ messages: Full chat history, including any system prompt as the first message.
166
+ remote: Whether to execute the generation on NDIF.
167
+ past_key_values: Cache returned by the previous generation step.
168
+ max_new_tokens: Maximum number of assistant tokens to generate.
169
+ do_sample: Whether to sample from the model distribution.
170
+ temperature: Sampling temperature, used only when sampling is enabled.
171
+ top_p: Nucleus sampling threshold, used only when sampling is enabled.
172
+ top_k: Top-k cutoff, used only when sampling is enabled.
173
+ repetition_penalty: Repetition penalty applied during decoding.
174
+ seed: Optional local RNG seed for sampled generation.
175
+
176
+ Returns:
177
+ ChatReply with generated text and the updated cache.
178
+ """
179
+
180
+ tokenizer = model.tokenizer
181
+ prompt, prompt_token_count = _format_generation_prompt(messages, tokenizer)
182
+
183
+ generation_kwargs: dict[str, object] = {
184
+ "max_new_tokens": max_new_tokens,
185
+ "return_dict_in_generate": True,
186
+ "use_cache": True,
187
+ }
188
+ if do_sample:
189
+ generation_kwargs["do_sample"] = True
190
+ generation_kwargs["temperature"] = temperature
191
+ generation_kwargs["top_p"] = top_p
192
+ generation_kwargs["top_k"] = top_k
193
+ if repetition_penalty != 1.0:
194
+ generation_kwargs["repetition_penalty"] = repetition_penalty
195
+ if past_key_values is not None and not remote:
196
+ generation_kwargs["past_key_values"] = past_key_values
197
+ if remote:
198
+ generation_kwargs["remote"] = True
199
+ # WARNING: NDIF returns caches on CPU, so cross-turn cache reuse is not stable.
200
+
201
+ with _seeded_rng(seed if do_sample and not remote else None):
202
+ with model.generate(prompt, **generation_kwargs) as tracer:
203
+ generated = tracer.result.save()
204
+
205
+ if hasattr(generated, "value") and getattr(generated, "value") is not None:
206
+ generated = generated.value
207
+
208
+ if not hasattr(generated, "sequences"):
209
+ raise ValueError("Generation did not return token sequences")
210
+
211
+ sequences = generated.sequences
212
+ if not isinstance(sequences, torch.Tensor):
213
+ raise TypeError("Generated sequences must be a tensor")
214
+
215
+ generated_ids = sequences[0, prompt_token_count:]
216
+ text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
217
+ output_tokens = int(sequences.shape[1] - prompt_token_count)
218
+
219
+ return ChatReply(
220
+ text=text,
221
+ prompt_tokens=prompt_token_count,
222
+ output_tokens=max(0, output_tokens),
223
+ past_key_values=(
224
+ getattr(generated, "past_key_values", None) if not remote else None
225
+ ),
226
+ )
utils/chat_export.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from datetime import datetime, timezone
3
+ from pathlib import Path
4
+
5
+ from persona_data.environment import get_artifacts_dir
6
+
7
+ from utils.artifacts import model_dir_name
8
+ from utils.helpers import slugify
9
+
10
+
11
+ def build_chat_export_payload(
12
+ *,
13
+ model_name: str,
14
+ dataset_source: str,
15
+ persona_id: str,
16
+ persona_name: str | None,
17
+ panel_label: str | None,
18
+ prompt_mode: str,
19
+ system_prompt: str | None,
20
+ messages: list[dict[str, str]],
21
+ generation: dict[str, object],
22
+ ) -> dict[str, object]:
23
+ """Build a JSON-serializable snapshot of the current chat session.
24
+
25
+ Args:
26
+ model_name: Model identifier used for the chat.
27
+ dataset_source: Human-readable dataset source label.
28
+ persona_id: Selected persona id.
29
+ persona_name: Selected persona display name, if available.
30
+ prompt_mode: Active system prompt mode.
31
+ messages: Conversation messages without the system prompt.
32
+ generation: Generation settings used for the chat.
33
+
34
+ Returns:
35
+ A JSON-serializable dictionary.
36
+ """
37
+
38
+ return {
39
+ "model_name": model_name,
40
+ "dataset_source": dataset_source,
41
+ "persona": {
42
+ "id": persona_id,
43
+ "name": persona_name,
44
+ },
45
+ "panel_label": panel_label,
46
+ "prompt_mode": prompt_mode,
47
+ "generation": generation,
48
+ "messages": (
49
+ [{"role": "system", "content": system_prompt}] if system_prompt else []
50
+ )
51
+ + messages,
52
+ }
53
+
54
+
55
+ def save_chat_export(
56
+ *,
57
+ model_name: str,
58
+ dataset_source: str,
59
+ persona_id: str,
60
+ persona_name: str | None,
61
+ prompt_mode: str,
62
+ system_prompt: str | None,
63
+ messages: list[dict[str, str]],
64
+ generation: dict[str, object],
65
+ panel_label: str | None = None,
66
+ ) -> Path:
67
+ """Save the current chat session to ``artifacts/chats`` as JSON.
68
+
69
+ Args:
70
+ model_name: Model identifier used for the chat.
71
+ dataset_source: Human-readable dataset source label.
72
+ persona_id: Selected persona id.
73
+ persona_name: Selected persona display name, if available.
74
+ prompt_mode: Active system prompt mode.
75
+ system_prompt: Current system prompt text, if any.
76
+ messages: Conversation messages without the system prompt.
77
+ generation: Generation settings used for the chat.
78
+
79
+ Returns:
80
+ The path the export was written to.
81
+ """
82
+
83
+ payload = build_chat_export_payload(
84
+ model_name=model_name,
85
+ dataset_source=dataset_source,
86
+ persona_id=persona_id,
87
+ persona_name=persona_name,
88
+ panel_label=panel_label,
89
+ prompt_mode=prompt_mode,
90
+ system_prompt=system_prompt,
91
+ messages=messages,
92
+ generation=generation,
93
+ )
94
+ export_dir = (
95
+ get_artifacts_dir()
96
+ / "chats"
97
+ / model_dir_name(model_name)
98
+ / slugify(dataset_source)
99
+ / slugify(persona_id)
100
+ )
101
+ export_dir.mkdir(parents=True, exist_ok=True)
102
+
103
+ timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
104
+ filename_parts = [
105
+ timestamp,
106
+ slugify(persona_name or persona_id),
107
+ slugify(prompt_mode),
108
+ ]
109
+ if panel_label:
110
+ filename_parts.append(slugify(panel_label))
111
+ export_path = export_dir / f"{'__'.join(filename_parts)}.json"
112
+ export_path.write_text(
113
+ f"{json.dumps(payload, indent=2, ensure_ascii=False)}\n",
114
+ encoding="utf-8",
115
+ )
116
+
117
+ return export_path
utils/datasets.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ import shutil
3
+ from pathlib import Path
4
+ from tempfile import mkdtemp
5
+ from typing import Any
6
+
7
+ import streamlit as st
8
+ from persona_data.synth_persona import SynthPersonaDataset
9
+
10
+ from .helpers import DATASET_SOURCES
11
+ from .local_dataset import LocalPersonaDataset
12
+
13
+
14
+ @st.cache_resource(show_spinner=False)
15
+ def cached_hf_dataset() -> SynthPersonaDataset:
16
+ """Load the default SynthPersona HuggingFace dataset once."""
17
+
18
+ return SynthPersonaDataset()
19
+
20
+
21
+ def _upload_cache_dir() -> Path:
22
+ cache_dir = st.session_state.get("_upload_cache_dir")
23
+ if cache_dir is None:
24
+ cache_dir = mkdtemp(prefix="persona_vectors_uploads_")
25
+ st.session_state["_upload_cache_dir"] = cache_dir
26
+ # Register cleanup so the temp dir is removed when the server process exits.
27
+ atexit.register(shutil.rmtree, cache_dir, ignore_errors=True)
28
+ return Path(cache_dir)
29
+
30
+
31
+ def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
32
+ suffix = Path(uploaded_file.name).suffix or ".jsonl"
33
+ temp_path = _upload_cache_dir() / f"{stem}{suffix}"
34
+ data = uploaded_file.getvalue()
35
+ if temp_path.exists() and temp_path.stat().st_size == len(data):
36
+ return temp_path
37
+ temp_path.write_bytes(data)
38
+ return temp_path
39
+
40
+
41
+ def load_dataset(
42
+ dataset_source: str,
43
+ ) -> tuple[SynthPersonaDataset | LocalPersonaDataset, str]:
44
+ """Load the selected dataset source for the UI."""
45
+
46
+ if dataset_source == DATASET_SOURCES[0]:
47
+ return cached_hf_dataset(), "SynthPersona"
48
+
49
+ personas_file = st.session_state.get("extract__personas_file")
50
+ qa_file = st.session_state.get("extract__qa_file")
51
+ if personas_file is None or qa_file is None:
52
+ raise ValueError("Upload both personas.jsonl and qa.jsonl files")
53
+
54
+ personas_path = _uploaded_file_to_temp_path(personas_file, stem="personas")
55
+ qa_path = _uploaded_file_to_temp_path(qa_file, stem="qa")
56
+ return (
57
+ LocalPersonaDataset(personas_path=personas_path, qa_path=qa_path),
58
+ "Local upload",
59
+ )
utils/extraction.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import logging
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from nnterp import StandardizedTransformer
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ from persona_data.environment import get_artifacts_dir
11
+ from persona_data.synth_persona import PersonaData, QAPair
12
+ from persona_vectors.activation_io import save_per_question_vectors
13
+ from persona_vectors.activations import extract_activations
14
+ from persona_data.prompts import (
15
+ format_biography_prompt,
16
+ format_messages,
17
+ format_templated_prompt,
18
+ )
19
+
20
+
21
+ @dataclass
22
+ class VariantExtractionResult:
23
+ variant: str
24
+ output_dir: str
25
+ n_questions: int
26
+ n_layers: int
27
+ d_model: int
28
+ persona_name: str = ""
29
+
30
+
31
+ def _prepare_inputs(
32
+ tokenizer: object,
33
+ system_prompt: str,
34
+ qa_pairs: list[QAPair],
35
+ ) -> tuple[list[str], list[torch.Tensor], list[str]]:
36
+ """Format QA pairs into tokenized prompts with answer-token masks.
37
+
38
+ Args:
39
+ tokenizer: HuggingFace-compatible tokenizer from the model.
40
+ system_prompt: System prompt to prepend to each conversation.
41
+ qa_pairs: List of question-answer pairs to format.
42
+
43
+ Returns:
44
+ A tuple of (full_texts, token_masks, questions) where full_texts are
45
+ the rendered prompt strings, token_masks are boolean tensors marking
46
+ answer tokens, and questions are the raw question strings.
47
+ """
48
+ full_texts: list[str] = []
49
+ token_masks: list[torch.Tensor] = []
50
+ questions: list[str] = []
51
+
52
+ for qa in qa_pairs:
53
+ messages = [
54
+ {"role": "system", "content": system_prompt},
55
+ {"role": "user", "content": qa.question},
56
+ {"role": "assistant", "content": qa.answer},
57
+ ]
58
+ full_prompt, answer_start = format_messages(messages, tokenizer)
59
+ seq_len = tokenizer(full_prompt, return_tensors="pt").input_ids.shape[1]
60
+
61
+ full_texts.append(full_prompt)
62
+ token_masks.append(torch.arange(seq_len) >= answer_start)
63
+ questions.append(qa.question)
64
+
65
+ return full_texts, token_masks, questions
66
+
67
+
68
+ def run_extraction(
69
+ model: StandardizedTransformer,
70
+ model_name: str,
71
+ persona: PersonaData,
72
+ qa_pairs: list[QAPair],
73
+ variants: list[str],
74
+ remote: bool,
75
+ ) -> list[VariantExtractionResult]:
76
+ """Run activation extraction and save outputs for selected variants.
77
+
78
+ Args:
79
+ model: Loaded standardized nnterp model.
80
+ model_name: HuggingFace model identifier used for artifact paths.
81
+ persona: The persona whose QA pairs are being extracted.
82
+ qa_pairs: Question-answer pairs to run extraction on.
83
+ variants: Prompt variants to extract (e.g. ``"templated"``, ``"biography"``).
84
+ remote: Whether to execute on NDIF.
85
+
86
+ Returns:
87
+ A list of extraction results, one per variant.
88
+
89
+ Raises:
90
+ ValueError: If ``qa_pairs`` is empty or an unsupported variant is given.
91
+ """
92
+ if not qa_pairs:
93
+ raise ValueError("No QA pairs selected for extraction")
94
+
95
+ tokenizer = model.tokenizer
96
+ activations_dir = get_artifacts_dir() / "activations"
97
+
98
+ system_prompt_by_variant = {
99
+ "templated": format_templated_prompt(persona.templated_prompt),
100
+ "biography": format_biography_prompt(persona.biography_md),
101
+ }
102
+
103
+ results: list[VariantExtractionResult] = []
104
+
105
+ for variant in variants:
106
+ if variant not in system_prompt_by_variant:
107
+ raise ValueError(f"Unsupported variant: {variant}")
108
+
109
+ full_texts, token_masks, questions = _prepare_inputs(
110
+ tokenizer=tokenizer,
111
+ system_prompt=system_prompt_by_variant[variant],
112
+ qa_pairs=qa_pairs,
113
+ )
114
+
115
+ per_question_vectors = extract_activations(
116
+ model=model,
117
+ full_texts=full_texts,
118
+ token_masks=token_masks,
119
+ remote=remote,
120
+ )
121
+
122
+ artifact_dir = save_per_question_vectors(
123
+ root_dir=activations_dir,
124
+ model_name=model_name,
125
+ prompt_variant=variant,
126
+ persona_id=persona.id,
127
+ persona_name=persona.name,
128
+ per_question_vectors=per_question_vectors,
129
+ questions=questions,
130
+ )
131
+
132
+ results.append(
133
+ VariantExtractionResult(
134
+ variant=variant,
135
+ output_dir=str(artifact_dir),
136
+ n_questions=per_question_vectors.shape[0],
137
+ n_layers=per_question_vectors.shape[1],
138
+ d_model=per_question_vectors.shape[2],
139
+ persona_name=persona.name,
140
+ )
141
+ )
142
+
143
+ # Free activation tensors between variants to keep memory bounded.
144
+ del per_question_vectors, full_texts, token_masks
145
+ gc.collect()
146
+ if torch.cuda.is_available():
147
+ torch.cuda.empty_cache()
148
+ if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
149
+ torch.mps.empty_cache()
150
+
151
+ return results
utils/helpers.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from persona_data.synth_persona import PersonaData
2
+
3
+ # Variant key -> human-readable label mapping
4
+ VARIANT_LABELS = {
5
+ "empty": "None",
6
+ "templated": "Template",
7
+ "biography": "Biography",
8
+ "custom": "Custom",
9
+ }
10
+
11
+ # Variants that correspond to actual system prompts (excludes "empty")
12
+ PROMPT_VARIANTS = ["templated", "biography"]
13
+
14
+ # For selectbox options: list of labels in definition order
15
+ MODE_LABELS = list(VARIANT_LABELS.values())
16
+
17
+ # Reverse lookup: label -> key
18
+ MODE_LABEL_TO_KEY = {v: k for k, v in VARIANT_LABELS.items()}
19
+
20
+ DATASET_SOURCES = ["HuggingFace: synth-persona", "Local JSONL upload"]
21
+ ANALYSIS_MODES = ["Cosine similarity", "PCA", "UMAP"]
22
+
23
+ ANALYSIS_LABELS = {
24
+ "PCA": ("PCA", "PC1", "PC2"),
25
+ "UMAP": ("UMAP", "UMAP 1", "UMAP 2"),
26
+ }
27
+
28
+ ANALYSIS_HELP_TEXT = {
29
+ "Cosine similarity": "Compare layer-wise alignment between variants.",
30
+ "PCA": "Project the selected layers into a global 2D view.",
31
+ "UMAP": "Project the selected layers into a local-neighborhood 2D view.",
32
+ }
33
+
34
+
35
+ def slugify(value: str) -> str:
36
+ """Convert a string to a slug safe for filenames and URLs."""
37
+
38
+ import re
39
+
40
+ return re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_") or "unknown"
41
+
42
+
43
+ def widget_key(*parts: str) -> str:
44
+ """Generate a namespaced Streamlit widget key from parts."""
45
+
46
+ return "::".join(parts)
47
+
48
+
49
+ def prompt_variant_label(variant: str) -> str:
50
+ """Return a human-friendly prompt-variant label."""
51
+
52
+ return VARIANT_LABELS.get(variant, variant.title())
53
+
54
+
55
+ def persona_label(persona: PersonaData) -> str:
56
+ """Format a persona for selection widgets."""
57
+
58
+ return f"{persona.name} ({persona.id})"
59
+
60
+
61
+ def persona_display_label(persona_id: str, persona_name: str | None) -> str:
62
+ """Format a persona id with an optional display name."""
63
+
64
+ if persona_name:
65
+ return f"{persona_name} ({persona_id})"
66
+ return persona_id
utils/local_dataset.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Iterator, Literal
6
+
7
+ from persona_data.synth_persona import PersonaData, QAPair
8
+
9
+
10
+ @dataclass
11
+ class LocalPersonaDataset:
12
+ """Dataset loaded from local JSONL files."""
13
+
14
+ personas_path: Path
15
+ qa_path: Path
16
+
17
+ def __post_init__(self) -> None:
18
+ with self.personas_path.open() as f:
19
+ self._personas: list[PersonaData] = []
20
+ for line in f:
21
+ if not line.strip():
22
+ continue
23
+ data = json.loads(line)
24
+ self._personas.append(
25
+ PersonaData(
26
+ id=data["id"],
27
+ persona=data["persona"],
28
+ templated_prompt=data["templated_prompt"],
29
+ biography_md=data["biography_md"],
30
+ )
31
+ )
32
+
33
+ self._qa: dict[str, list[QAPair]] = defaultdict(list)
34
+ with self.qa_path.open() as f:
35
+ for line in f:
36
+ if not line.strip():
37
+ continue
38
+ data = json.loads(line)
39
+ self._qa[data["id"]].append(
40
+ QAPair(
41
+ qid=data["qid"],
42
+ type=data["type"],
43
+ question=data["question"],
44
+ answer=data["answer"],
45
+ difficulty=data["difficulty"],
46
+ )
47
+ )
48
+
49
+ def __len__(self) -> int:
50
+ return len(self._personas)
51
+
52
+ def __iter__(self) -> Iterator[PersonaData]:
53
+ return iter(self._personas)
54
+
55
+ def __getitem__(self, idx: int) -> PersonaData:
56
+ return self._personas[idx]
57
+
58
+ def get_qa(
59
+ self,
60
+ persona_id: str,
61
+ type: Literal["explicit", "implicit"] | None = None,
62
+ difficulty: int | list[int] | None = None,
63
+ ) -> list[QAPair]:
64
+ pairs = self._qa.get(persona_id, [])
65
+ if type is not None:
66
+ pairs = [pair for pair in pairs if pair.type == type]
67
+
68
+ if difficulty is not None:
69
+ levels = {difficulty} if isinstance(difficulty, int) else set(difficulty)
70
+ pairs = [pair for pair in pairs if pair.difficulty in levels]
71
+
72
+ return pairs
utils/runtime.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import streamlit as st
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+
8
+ @st.cache_data(show_spinner=False, ttl=30)
9
+ def list_remote_models() -> list[str]:
10
+ """Return the NDIF language models that are currently running."""
11
+
12
+ import nnsight
13
+
14
+ try:
15
+ status = nnsight.ndif_status()
16
+ except Exception:
17
+ logger.warning("Failed to fetch NDIF status", exc_info=True)
18
+ return []
19
+
20
+ model_names: list[str] = []
21
+
22
+ for entry in status.values():
23
+ if not isinstance(entry, dict):
24
+ continue
25
+ if entry.get("model_class") not in {"LanguageModel", "StandardizedTransformer"}:
26
+ continue
27
+
28
+ state = entry.get("state")
29
+ state_name = getattr(state, "name", None) or getattr(state, "value", None)
30
+ if state_name != "RUNNING":
31
+ continue
32
+
33
+ repo_id = entry.get("repo_id")
34
+ if isinstance(repo_id, str):
35
+ model_names.append(repo_id)
36
+
37
+ return sorted(set(model_names))
38
+
39
+
40
+ @st.cache_resource(show_spinner=False, max_entries=1)
41
+ def cached_model(model_name: str, remote: bool):
42
+ """Load and cache a standardized nnterp model.
43
+
44
+ Streamlit reruns this app on every interaction, so caching keeps one loaded
45
+ model instance per ``(model_name, remote)`` instead of reloading weights on
46
+ every widget change.
47
+ """
48
+
49
+ from nnterp import StandardizedTransformer
50
+
51
+ # HACK: For now do it like this because of the bug.
52
+ # model = StandardizedTransformer(model_name, remote=True)
53
+ return StandardizedTransformer(model_name)
uv.lock ADDED
The diff for this file is too large to render. See raw diff