Spaces:
Sleeping
Sleeping
Jac-Zac commited on
Commit ·
f4259c0
1
Parent(s): 5bf7fd5
Cleaned up code with the new updates
Browse files- README.md +1 -4
- app.py +1 -1
- tabs/chat.py +327 -281
- tabs/compare.py +253 -166
- tabs/extract.py +32 -34
- utils/artifacts.py +0 -244
- utils/chat.py +0 -1
- utils/chat_export.py +8 -48
- utils/datasets.py +5 -1
- utils/helpers.py +15 -9
README.md
CHANGED
|
@@ -24,13 +24,10 @@ persona-ui/
|
|
| 24 |
│ ├── compare.py # Activation comparison tab
|
| 25 |
│ └── extract.py # Extraction tab
|
| 26 |
└── utils/
|
| 27 |
-
├── artifacts.py # Load saved activations metadata
|
| 28 |
├── chat.py # Chat generation logic
|
| 29 |
├── chat_export.py # Export chat logs to JSON
|
| 30 |
├── datasets.py # Dataset loader wrapper
|
| 31 |
-
├── extraction.py # Extraction orchestration
|
| 32 |
├── helpers.py # UI labels and slug helpers
|
| 33 |
-
├── local_dataset.py # Local JSONL dataset parsing
|
| 34 |
└── runtime.py # Model caching and NDIF queries
|
| 35 |
```
|
| 36 |
|
|
@@ -81,7 +78,7 @@ HF_HOME=... # Optional: HuggingFace cache directory
|
|
| 81 |
ARTIFACTS_DIR=... # Optional: where activations are read from (default: ./artifacts)
|
| 82 |
```
|
| 83 |
|
| 84 |
-
The app picks up this file automatically via `
|
| 85 |
|
| 86 |
## Saved Artifacts
|
| 87 |
|
|
|
|
| 24 |
│ ├── compare.py # Activation comparison tab
|
| 25 |
│ └── extract.py # Extraction tab
|
| 26 |
└── utils/
|
|
|
|
| 27 |
├── chat.py # Chat generation logic
|
| 28 |
├── chat_export.py # Export chat logs to JSON
|
| 29 |
├── datasets.py # Dataset loader wrapper
|
|
|
|
| 30 |
├── helpers.py # UI labels and slug helpers
|
|
|
|
| 31 |
└── runtime.py # Model caching and NDIF queries
|
| 32 |
```
|
| 33 |
|
|
|
|
| 78 |
ARTIFACTS_DIR=... # Optional: where activations are read from (default: ./artifacts)
|
| 79 |
```
|
| 80 |
|
| 81 |
+
The app picks up this file automatically via `load_dotenv()` on startup.
|
| 82 |
|
| 83 |
## Saved Artifacts
|
| 84 |
|
app.py
CHANGED
|
@@ -26,7 +26,7 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
|
|
| 26 |
if st.button(
|
| 27 |
tab_name,
|
| 28 |
key=f"sidebar__tab__{tab_name.lower()}",
|
| 29 |
-
|
| 30 |
type="primary" if is_selected else "secondary",
|
| 31 |
icon=icon,
|
| 32 |
):
|
|
|
|
| 26 |
if st.button(
|
| 27 |
tab_name,
|
| 28 |
key=f"sidebar__tab__{tab_name.lower()}",
|
| 29 |
+
width="stretch",
|
| 30 |
type="primary" if is_selected else "secondary",
|
| 31 |
icon=icon,
|
| 32 |
):
|
tabs/chat.py
CHANGED
|
@@ -1,10 +1,15 @@
|
|
| 1 |
-
import threading
|
| 2 |
from concurrent.futures import ThreadPoolExecutor
|
| 3 |
-
from
|
| 4 |
|
| 5 |
import streamlit as st
|
|
|
|
| 6 |
|
| 7 |
-
from state import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
|
| 9 |
from utils.chat_export import save_chat_export
|
| 10 |
from utils.datasets import load_dataset
|
|
@@ -12,14 +17,12 @@ from utils.helpers import (
|
|
| 12 |
MODE_LABEL_TO_KEY,
|
| 13 |
MODE_LABELS,
|
| 14 |
VARIANT_LABELS,
|
|
|
|
| 15 |
persona_label,
|
| 16 |
widget_key,
|
| 17 |
)
|
| 18 |
from utils.runtime import cached_model
|
| 19 |
|
| 20 |
-
_VISIBLE_MESSAGE_COUNT = 5
|
| 21 |
-
_model_lock = threading.Lock()
|
| 22 |
-
|
| 23 |
|
| 24 |
def _render_chat_message(message: dict[str, str]) -> None:
|
| 25 |
if not message.get("content"):
|
|
@@ -33,6 +36,21 @@ def _clear_chat_ui_state(*keys: str) -> None:
|
|
| 33 |
st.session_state.pop(key, None)
|
| 34 |
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
|
| 37 |
return {
|
| 38 |
"max_new_tokens": int(gen_kwargs["max_new_tokens"]),
|
|
@@ -46,186 +64,146 @@ def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, o
|
|
| 46 |
}
|
| 47 |
|
| 48 |
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
def _panel_state(panel_key: str) -> dict:
|
| 53 |
-
"""Get or initialise compare-panel chat state stored in session_state."""
|
| 54 |
-
if panel_key not in st.session_state:
|
| 55 |
-
st.session_state[panel_key] = {
|
| 56 |
-
"messages": [],
|
| 57 |
-
"persona_id": None,
|
| 58 |
-
"prompt_mode": "templated",
|
| 59 |
-
"past_key_values": None,
|
| 60 |
-
}
|
| 61 |
-
return st.session_state[panel_key]
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def _render_compare_panel(
|
| 65 |
-
side: str,
|
| 66 |
-
context_key: str,
|
| 67 |
-
personas: list,
|
| 68 |
-
remote: bool,
|
| 69 |
-
model_name: str,
|
| 70 |
-
dataset_source: str,
|
| 71 |
-
gen_kwargs: dict,
|
| 72 |
-
advanced_generation: bool,
|
| 73 |
-
) -> dict:
|
| 74 |
-
"""Render persona/prompt controls + chat log for one compare panel.
|
| 75 |
-
|
| 76 |
-
Returns a dict with keys needed by the generation step:
|
| 77 |
-
panel_key, state, active_system_prompt, selected_persona, chat_log
|
| 78 |
-
"""
|
| 79 |
-
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 80 |
-
state = _panel_state(panel_key)
|
| 81 |
-
|
| 82 |
-
# ── Per-panel selectors ──────────────────────────────────────────────────
|
| 83 |
-
p_col, m_col = st.columns([3, 2])
|
| 84 |
with p_col:
|
| 85 |
selected_index = next(
|
| 86 |
-
(i for i, p in enumerate(personas) if p.id ==
|
| 87 |
)
|
| 88 |
selected_persona = st.selectbox(
|
| 89 |
"Persona",
|
| 90 |
options=personas,
|
| 91 |
index=selected_index,
|
| 92 |
format_func=persona_label,
|
| 93 |
-
key=
|
| 94 |
)
|
| 95 |
with m_col:
|
| 96 |
-
current_label = VARIANT_LABELS.get(
|
| 97 |
prompt_mode_label = st.selectbox(
|
| 98 |
"Prompt",
|
| 99 |
options=MODE_LABELS,
|
| 100 |
index=MODE_LABELS.index(current_label),
|
| 101 |
-
key=
|
| 102 |
)
|
| 103 |
prompt_mode = MODE_LABEL_TO_KEY[prompt_mode_label]
|
| 104 |
-
|
| 105 |
-
# Reset state when persona or mode changes.
|
| 106 |
changed = (
|
| 107 |
-
|
| 108 |
-
or state["prompt_mode"] != prompt_mode
|
| 109 |
)
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
)
|
|
|
|
| 119 |
|
| 120 |
-
# ── System prompt ────────────────────────────────────────────────────────
|
| 121 |
-
active_system_prompt = resolve_system_prompt(
|
| 122 |
-
persona=selected_persona, mode=prompt_mode
|
| 123 |
-
)
|
| 124 |
-
custom_prompt_key = widget_key(panel_key, "custom_prompt")
|
| 125 |
-
if prompt_mode != "empty":
|
| 126 |
-
if custom_prompt_key not in st.session_state:
|
| 127 |
-
st.session_state[custom_prompt_key] = active_system_prompt
|
| 128 |
-
with st.expander("Edit prompt", expanded=False):
|
| 129 |
-
active_system_prompt = (
|
| 130 |
-
st.text_area(
|
| 131 |
-
"prompt",
|
| 132 |
-
key=custom_prompt_key,
|
| 133 |
-
height=150,
|
| 134 |
-
label_visibility="collapsed",
|
| 135 |
-
)
|
| 136 |
-
or None
|
| 137 |
-
)
|
| 138 |
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
model_name=model_name,
|
| 149 |
-
dataset_source=dataset_source,
|
| 150 |
-
persona_id=selected_persona.id,
|
| 151 |
-
persona_name=getattr(selected_persona, "name", None),
|
| 152 |
-
panel_label=side,
|
| 153 |
-
prompt_mode=prompt_mode,
|
| 154 |
-
system_prompt=active_system_prompt,
|
| 155 |
-
messages=state["messages"],
|
| 156 |
-
generation=_generation_dict(gen_kwargs, advanced_generation),
|
| 157 |
-
)
|
| 158 |
-
export_success_message = f"Saved chat export to {export_path}"
|
| 159 |
-
with action_col2:
|
| 160 |
-
if st.button(
|
| 161 |
-
"Reset chat",
|
| 162 |
-
key=widget_key(panel_key, "reset"),
|
| 163 |
-
use_container_width=True,
|
| 164 |
-
type="secondary",
|
| 165 |
-
):
|
| 166 |
-
state["messages"] = []
|
| 167 |
-
state["past_key_values"] = None
|
| 168 |
-
_clear_chat_ui_state(
|
| 169 |
-
widget_key(panel_key, "custom_prompt"),
|
| 170 |
-
widget_key(panel_key, "show_all"),
|
| 171 |
-
)
|
| 172 |
-
st.rerun()
|
| 173 |
|
| 174 |
-
|
| 175 |
-
st.
|
| 176 |
-
|
| 177 |
-
# ── Message history ──────────────────────────────────────────────────────
|
| 178 |
-
show_all_key = widget_key(panel_key, "show_all")
|
| 179 |
-
messages = state["messages"]
|
| 180 |
-
if len(messages) > _VISIBLE_MESSAGE_COUNT and not st.session_state.get(
|
| 181 |
-
show_all_key, False
|
| 182 |
-
):
|
| 183 |
-
hidden_count = len(messages) - _VISIBLE_MESSAGE_COUNT
|
| 184 |
-
if st.button(
|
| 185 |
-
f"Show earlier ({hidden_count} hidden)",
|
| 186 |
-
key=widget_key(panel_key, "show_all_btn"),
|
| 187 |
):
|
| 188 |
-
|
| 189 |
-
st.
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
for msg in visible:
|
| 197 |
-
_render_chat_message(msg)
|
| 198 |
|
| 199 |
-
return
|
| 200 |
-
"panel_key": panel_key,
|
| 201 |
-
"state": state,
|
| 202 |
-
"active_system_prompt": active_system_prompt,
|
| 203 |
-
"selected_persona": selected_persona,
|
| 204 |
-
"chat_log": chat_log,
|
| 205 |
-
}
|
| 206 |
|
| 207 |
|
| 208 |
-
def
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
|
| 231 |
def _render_compare_mode(
|
|
@@ -233,35 +211,90 @@ def _render_compare_mode(
|
|
| 233 |
model_name: str,
|
| 234 |
context_key: str,
|
| 235 |
dataset_source: str,
|
| 236 |
-
personas: list,
|
| 237 |
gen_kwargs: dict,
|
| 238 |
advanced_generation: bool,
|
| 239 |
) -> None:
|
| 240 |
"""Render the full side-by-side comparison UI."""
|
| 241 |
left_col, right_col = st.columns(2)
|
| 242 |
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
personas,
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
advanced_generation,
|
| 253 |
)
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
"
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
user_prompt = st.chat_input(
|
| 267 |
"Ask both...",
|
|
@@ -271,43 +304,73 @@ def _render_compare_mode(
|
|
| 271 |
return
|
| 272 |
|
| 273 |
model = cached_model(model_name=model_name, remote=remote)
|
| 274 |
-
panels = [
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
-
for
|
| 277 |
-
|
| 278 |
-
with
|
| 279 |
-
|
| 280 |
-
_render_chat_message({"role": "user", "content": user_prompt})
|
| 281 |
|
| 282 |
-
# Generate both responses in parallel (remote: truly concurrent; local: serialised via lock).
|
| 283 |
with st.spinner("Generating..."):
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
results = []
|
| 290 |
-
for
|
| 291 |
try:
|
| 292 |
-
results.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
except Exception as exc:
|
| 294 |
results.append(exc)
|
| 295 |
|
| 296 |
-
for (
|
| 297 |
if isinstance(result, Exception):
|
| 298 |
-
with
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
panel["state"]["messages"].pop()
|
| 302 |
continue
|
| 303 |
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
with col:
|
| 309 |
-
with panel["chat_log"]:
|
| 310 |
-
_render_chat_message({"role": "assistant", "content": result.text})
|
| 311 |
|
| 312 |
|
| 313 |
# ── Main tab entry point ───────────────────────────────────────────────────────
|
|
@@ -465,6 +528,12 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 465 |
# ── Single-chat mode ──────────────────────────────────────────────────────
|
| 466 |
persona_select_key = widget_key(context_key, "persona_select")
|
| 467 |
prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
|
| 469 |
col1, col2 = st.columns([2, 1])
|
| 470 |
with col1:
|
|
@@ -481,66 +550,35 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 481 |
)
|
| 482 |
with col2:
|
| 483 |
current_mode_label = VARIANT_LABELS.get(chat_state["prompt_mode"], "None")
|
| 484 |
-
|
| 485 |
"Prompt",
|
| 486 |
options=MODE_LABELS,
|
| 487 |
index=MODE_LABELS.index(current_mode_label),
|
| 488 |
key=prompt_mode_select_key,
|
| 489 |
)
|
| 490 |
-
prompt_mode = MODE_LABEL_TO_KEY[
|
| 491 |
|
| 492 |
active_system_prompt = resolve_system_prompt(
|
| 493 |
persona=selected_persona,
|
| 494 |
mode=prompt_mode,
|
| 495 |
)
|
| 496 |
|
| 497 |
-
chat_input_key = widget_key(context_key, "chat_input")
|
| 498 |
-
show_all_key = widget_key(context_key, "show_all_messages")
|
| 499 |
-
custom_prompt_key = widget_key(context_key, "custom_system_prompt")
|
| 500 |
-
pending_key = widget_key(context_key, "pending_prompt")
|
| 501 |
-
export_success_message: str | None = None
|
| 502 |
-
|
| 503 |
-
action_col1, action_col2 = st.columns(2)
|
| 504 |
-
with action_col1:
|
| 505 |
-
if st.button("Reset chat", use_container_width=True, type="secondary"):
|
| 506 |
-
reset_chat_state(model_name, remote, dataset_source)
|
| 507 |
-
_clear_chat_ui_state(
|
| 508 |
-
chat_input_key,
|
| 509 |
-
show_all_key,
|
| 510 |
-
custom_prompt_key,
|
| 511 |
-
pending_key,
|
| 512 |
-
)
|
| 513 |
-
st.rerun()
|
| 514 |
-
with action_col2:
|
| 515 |
-
if st.button("Export chat", use_container_width=True):
|
| 516 |
-
export_path = save_chat_export(
|
| 517 |
-
model_name=model_name,
|
| 518 |
-
dataset_source=dataset_source,
|
| 519 |
-
persona_id=selected_persona.id,
|
| 520 |
-
persona_name=getattr(selected_persona, "name", None),
|
| 521 |
-
prompt_mode=prompt_mode,
|
| 522 |
-
system_prompt=active_system_prompt,
|
| 523 |
-
messages=chat_state["messages"],
|
| 524 |
-
generation=_generation_dict(gen_kwargs, advanced_generation),
|
| 525 |
-
)
|
| 526 |
-
export_success_message = f"Saved chat export to {export_path}"
|
| 527 |
-
|
| 528 |
-
if export_success_message:
|
| 529 |
-
st.success(export_success_message)
|
| 530 |
-
|
| 531 |
changed_context = (
|
| 532 |
chat_state["persona_id"] != selected_persona.id
|
| 533 |
or chat_state["prompt_mode"] != prompt_mode
|
| 534 |
)
|
| 535 |
if changed_context:
|
| 536 |
had_history = bool(chat_state["messages"])
|
| 537 |
-
|
| 538 |
-
|
| 539 |
-
|
| 540 |
-
|
|
|
|
|
|
|
|
|
|
| 541 |
chat_input_key,
|
| 542 |
show_all_key,
|
| 543 |
-
|
| 544 |
pending_key,
|
| 545 |
)
|
| 546 |
if had_history:
|
|
@@ -548,40 +586,51 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 548 |
|
| 549 |
chat_log = st.container()
|
| 550 |
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
active_system_prompt = (
|
| 558 |
-
st.text_area(
|
| 559 |
-
"Prompt",
|
| 560 |
-
key=custom_prompt_key,
|
| 561 |
-
height=200,
|
| 562 |
-
label_visibility="collapsed",
|
| 563 |
-
)
|
| 564 |
-
or None
|
| 565 |
-
)
|
| 566 |
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
if
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
|
| 583 |
-
|
| 584 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 585 |
|
| 586 |
user_prompt = st.chat_input(
|
| 587 |
"Ask something...",
|
|
@@ -598,10 +647,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 598 |
if not st.session_state.pop(pending_key, False):
|
| 599 |
return
|
| 600 |
|
| 601 |
-
messages = []
|
| 602 |
-
if active_system_prompt:
|
| 603 |
-
messages.append({"role": "system", "content": active_system_prompt})
|
| 604 |
-
messages.extend(chat_state["messages"])
|
| 605 |
|
| 606 |
with st.spinner("Generating reply..."):
|
| 607 |
model = cached_model(model_name=model_name, remote=remote)
|
|
|
|
|
|
|
| 1 |
from concurrent.futures import ThreadPoolExecutor
|
| 2 |
+
from typing import Any
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
+
from persona_data.synth_persona import PersonaData
|
| 6 |
|
| 7 |
+
from state import (
|
| 8 |
+
_default_chat_state,
|
| 9 |
+
chat_session_key,
|
| 10 |
+
get_chat_state,
|
| 11 |
+
reset_chat_state,
|
| 12 |
+
)
|
| 13 |
from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
|
| 14 |
from utils.chat_export import save_chat_export
|
| 15 |
from utils.datasets import load_dataset
|
|
|
|
| 17 |
MODE_LABEL_TO_KEY,
|
| 18 |
MODE_LABELS,
|
| 19 |
VARIANT_LABELS,
|
| 20 |
+
VISIBLE_MESSAGE_COUNT,
|
| 21 |
persona_label,
|
| 22 |
widget_key,
|
| 23 |
)
|
| 24 |
from utils.runtime import cached_model
|
| 25 |
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def _render_chat_message(message: dict[str, str]) -> None:
|
| 28 |
if not message.get("content"):
|
|
|
|
| 36 |
st.session_state.pop(key, None)
|
| 37 |
|
| 38 |
|
| 39 |
+
def _reset_single_chat_context(
|
| 40 |
+
model_name: str,
|
| 41 |
+
remote: bool,
|
| 42 |
+
dataset_source: str,
|
| 43 |
+
chat_state: dict[str, object],
|
| 44 |
+
persona_id: str,
|
| 45 |
+
prompt_mode: str,
|
| 46 |
+
*ui_keys: str,
|
| 47 |
+
) -> None:
|
| 48 |
+
reset_chat_state(model_name, remote, dataset_source)
|
| 49 |
+
chat_state["persona_id"] = persona_id
|
| 50 |
+
chat_state["prompt_mode"] = prompt_mode
|
| 51 |
+
_clear_chat_ui_state(*ui_keys)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
|
| 55 |
return {
|
| 56 |
"max_new_tokens": int(gen_kwargs["max_new_tokens"]),
|
|
|
|
| 64 |
}
|
| 65 |
|
| 66 |
|
| 67 |
+
def _render_persona_prompt_controls(
|
| 68 |
+
personas: list[PersonaData],
|
| 69 |
+
current_persona_id: str | None,
|
| 70 |
+
current_prompt_mode: str,
|
| 71 |
+
persona_key: str,
|
| 72 |
+
prompt_key: str,
|
| 73 |
+
column_widths: tuple[int, int] = (3, 2),
|
| 74 |
+
) -> tuple[PersonaData, str, bool]:
|
| 75 |
+
"""Render persona and prompt selectors, returning the selected values."""
|
| 76 |
|
| 77 |
+
p_col, m_col = st.columns(list(column_widths))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
with p_col:
|
| 79 |
selected_index = next(
|
| 80 |
+
(i for i, p in enumerate(personas) if p.id == current_persona_id), 0
|
| 81 |
)
|
| 82 |
selected_persona = st.selectbox(
|
| 83 |
"Persona",
|
| 84 |
options=personas,
|
| 85 |
index=selected_index,
|
| 86 |
format_func=persona_label,
|
| 87 |
+
key=persona_key,
|
| 88 |
)
|
| 89 |
with m_col:
|
| 90 |
+
current_label = VARIANT_LABELS.get(current_prompt_mode, "None")
|
| 91 |
prompt_mode_label = st.selectbox(
|
| 92 |
"Prompt",
|
| 93 |
options=MODE_LABELS,
|
| 94 |
index=MODE_LABELS.index(current_label),
|
| 95 |
+
key=prompt_key,
|
| 96 |
)
|
| 97 |
prompt_mode = MODE_LABEL_TO_KEY[prompt_mode_label]
|
|
|
|
|
|
|
| 98 |
changed = (
|
| 99 |
+
current_persona_id != selected_persona.id or current_prompt_mode != prompt_mode
|
|
|
|
| 100 |
)
|
| 101 |
+
return selected_persona, prompt_mode, changed
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _render_system_prompt_editor(
|
| 105 |
+
prompt_key: str,
|
| 106 |
+
prompt_mode: str,
|
| 107 |
+
active_system_prompt: str | None,
|
| 108 |
+
*,
|
| 109 |
+
height: int,
|
| 110 |
+
label: str = "Prompt",
|
| 111 |
+
) -> str | None:
|
| 112 |
+
"""Render the editable system prompt area for a chat panel."""
|
| 113 |
+
|
| 114 |
+
if prompt_mode == "empty":
|
| 115 |
+
return active_system_prompt
|
| 116 |
+
|
| 117 |
+
if prompt_key not in st.session_state:
|
| 118 |
+
st.session_state[prompt_key] = active_system_prompt or ""
|
| 119 |
+
|
| 120 |
+
with st.expander("Edit prompt", expanded=False):
|
| 121 |
+
edited_prompt = (
|
| 122 |
+
st.text_area(
|
| 123 |
+
label,
|
| 124 |
+
key=prompt_key,
|
| 125 |
+
height=height,
|
| 126 |
+
label_visibility="collapsed",
|
| 127 |
+
)
|
| 128 |
+
or None
|
| 129 |
)
|
| 130 |
+
return edited_prompt
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
+
def _render_chat_window(
|
| 134 |
+
*,
|
| 135 |
+
chat_log: Any,
|
| 136 |
+
messages: list[dict[str, str]],
|
| 137 |
+
show_all_key: str,
|
| 138 |
+
show_all_btn_key: str,
|
| 139 |
+
show_earlier_label: str,
|
| 140 |
+
) -> Any:
|
| 141 |
+
"""Render the visible chat history inside one container."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
+
with chat_log:
|
| 144 |
+
if len(messages) > VISIBLE_MESSAGE_COUNT and not st.session_state.get(
|
| 145 |
+
show_all_key, False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
):
|
| 147 |
+
hidden_count = len(messages) - VISIBLE_MESSAGE_COUNT
|
| 148 |
+
if st.button(
|
| 149 |
+
f"{show_earlier_label} ({hidden_count} hidden)",
|
| 150 |
+
key=show_all_btn_key,
|
| 151 |
+
):
|
| 152 |
+
st.session_state[show_all_key] = True
|
| 153 |
+
st.rerun()
|
| 154 |
+
visible_messages = messages[-VISIBLE_MESSAGE_COUNT:]
|
| 155 |
+
else:
|
| 156 |
+
visible_messages = messages
|
| 157 |
|
| 158 |
+
for message in visible_messages:
|
| 159 |
+
_render_chat_message(message)
|
|
|
|
|
|
|
| 160 |
|
| 161 |
+
return chat_log
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
|
| 164 |
+
def _build_chat_messages(
|
| 165 |
+
system_prompt: str | None,
|
| 166 |
+
messages: list[dict[str, str]],
|
| 167 |
+
) -> list[dict[str, str]]:
|
| 168 |
+
return (
|
| 169 |
+
[{"role": "system", "content": system_prompt}] if system_prompt else []
|
| 170 |
+
) + messages
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _save_chat_export_message(
|
| 174 |
+
*,
|
| 175 |
+
model_name: str,
|
| 176 |
+
dataset_source: str,
|
| 177 |
+
persona_id: str,
|
| 178 |
+
persona_name: str | None,
|
| 179 |
+
prompt_mode: str,
|
| 180 |
+
system_prompt: str | None,
|
| 181 |
+
messages: list[dict[str, str]],
|
| 182 |
+
generation: dict[str, object],
|
| 183 |
+
panel_label: str | None = None,
|
| 184 |
+
) -> str:
|
| 185 |
+
export_path = save_chat_export(
|
| 186 |
+
model_name=model_name,
|
| 187 |
+
dataset_source=dataset_source,
|
| 188 |
+
persona_id=persona_id,
|
| 189 |
+
persona_name=persona_name,
|
| 190 |
+
panel_label=panel_label,
|
| 191 |
+
prompt_mode=prompt_mode,
|
| 192 |
+
system_prompt=system_prompt,
|
| 193 |
+
messages=messages,
|
| 194 |
+
generation=generation,
|
| 195 |
+
)
|
| 196 |
+
return f"Saved chat export to {export_path}"
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
# ── Compare mode helpers ───────────────────────────────────────────────────────
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _panel_state(panel_key: str) -> dict:
|
| 203 |
+
"""Get or initialise compare-panel chat state stored in session_state."""
|
| 204 |
+
if panel_key not in st.session_state:
|
| 205 |
+
st.session_state[panel_key] = _default_chat_state()
|
| 206 |
+
return st.session_state[panel_key]
|
| 207 |
|
| 208 |
|
| 209 |
def _render_compare_mode(
|
|
|
|
| 211 |
model_name: str,
|
| 212 |
context_key: str,
|
| 213 |
dataset_source: str,
|
| 214 |
+
personas: list[PersonaData],
|
| 215 |
gen_kwargs: dict,
|
| 216 |
advanced_generation: bool,
|
| 217 |
) -> None:
|
| 218 |
"""Render the full side-by-side comparison UI."""
|
| 219 |
left_col, right_col = st.columns(2)
|
| 220 |
|
| 221 |
+
def render_panel(side: str, column) -> tuple[dict[str, object], Any, str | None]:
|
| 222 |
+
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 223 |
+
state = st.session_state.get(panel_key)
|
| 224 |
+
if state is None:
|
| 225 |
+
state = _default_chat_state()
|
| 226 |
+
st.session_state[panel_key] = state
|
| 227 |
+
prompt_key = widget_key(panel_key, "custom_prompt")
|
| 228 |
+
show_all_key = widget_key(panel_key, "show_all")
|
| 229 |
+
|
| 230 |
+
selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
|
| 231 |
personas,
|
| 232 |
+
state["persona_id"],
|
| 233 |
+
state["prompt_mode"],
|
| 234 |
+
widget_key(panel_key, "persona"),
|
| 235 |
+
widget_key(panel_key, "prompt_mode"),
|
|
|
|
| 236 |
)
|
| 237 |
+
if changed:
|
| 238 |
+
state["messages"] = []
|
| 239 |
+
state["past_key_values"] = None
|
| 240 |
+
state["persona_id"] = selected_persona.id
|
| 241 |
+
state["prompt_mode"] = prompt_mode
|
| 242 |
+
_clear_chat_ui_state(prompt_key, show_all_key)
|
| 243 |
+
|
| 244 |
+
active_system_prompt = resolve_system_prompt(
|
| 245 |
+
persona=selected_persona, mode=prompt_mode
|
| 246 |
+
)
|
| 247 |
+
active_system_prompt = _render_system_prompt_editor(
|
| 248 |
+
prompt_key,
|
| 249 |
+
prompt_mode,
|
| 250 |
+
active_system_prompt,
|
| 251 |
+
height=150,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
btn_col1, btn_col2 = st.columns(2)
|
| 255 |
+
with btn_col1:
|
| 256 |
+
if st.button(
|
| 257 |
+
"Export chat", key=widget_key(panel_key, "export_chat"), width="stretch"
|
| 258 |
+
):
|
| 259 |
+
st.success(
|
| 260 |
+
_save_chat_export_message(
|
| 261 |
+
model_name=model_name,
|
| 262 |
+
dataset_source=dataset_source,
|
| 263 |
+
persona_id=selected_persona.id,
|
| 264 |
+
persona_name=getattr(selected_persona, "name", None),
|
| 265 |
+
prompt_mode=prompt_mode,
|
| 266 |
+
system_prompt=active_system_prompt,
|
| 267 |
+
messages=state["messages"],
|
| 268 |
+
generation=_generation_dict(gen_kwargs, advanced_generation),
|
| 269 |
+
panel_label=side,
|
| 270 |
+
)
|
| 271 |
+
)
|
| 272 |
+
with btn_col2:
|
| 273 |
+
if st.button(
|
| 274 |
+
"Reset chat",
|
| 275 |
+
key=widget_key(panel_key, "reset"),
|
| 276 |
+
width="stretch",
|
| 277 |
+
type="secondary",
|
| 278 |
+
):
|
| 279 |
+
state["messages"] = []
|
| 280 |
+
state["past_key_values"] = None
|
| 281 |
+
_clear_chat_ui_state(prompt_key, show_all_key)
|
| 282 |
+
st.rerun()
|
| 283 |
+
|
| 284 |
+
chat_log = st.container()
|
| 285 |
+
_render_chat_window(
|
| 286 |
+
chat_log=chat_log,
|
| 287 |
+
messages=state["messages"],
|
| 288 |
+
show_all_key=show_all_key,
|
| 289 |
+
show_all_btn_key=widget_key(panel_key, "show_all_btn"),
|
| 290 |
+
show_earlier_label="Show earlier",
|
| 291 |
)
|
| 292 |
+
return state, chat_log, active_system_prompt
|
| 293 |
+
|
| 294 |
+
with left_col:
|
| 295 |
+
left_state, left_log, left_prompt = render_panel("left", left_col)
|
| 296 |
+
with right_col:
|
| 297 |
+
right_state, right_log, right_prompt = render_panel("right", right_col)
|
| 298 |
|
| 299 |
user_prompt = st.chat_input(
|
| 300 |
"Ask both...",
|
|
|
|
| 304 |
return
|
| 305 |
|
| 306 |
model = cached_model(model_name=model_name, remote=remote)
|
| 307 |
+
panels = [
|
| 308 |
+
(left_state, left_log, left_prompt),
|
| 309 |
+
(right_state, right_log, right_prompt),
|
| 310 |
+
]
|
| 311 |
|
| 312 |
+
for panel_state, panel_log, _panel_prompt in panels:
|
| 313 |
+
panel_state["messages"].append({"role": "user", "content": user_prompt})
|
| 314 |
+
with panel_log:
|
| 315 |
+
_render_chat_message({"role": "user", "content": user_prompt})
|
|
|
|
| 316 |
|
|
|
|
| 317 |
with st.spinner("Generating..."):
|
| 318 |
+
if remote:
|
| 319 |
+
with ThreadPoolExecutor(max_workers=2) as executor:
|
| 320 |
+
futures = [
|
| 321 |
+
executor.submit(
|
| 322 |
+
generate_chat_reply,
|
| 323 |
+
model=model,
|
| 324 |
+
messages=(
|
| 325 |
+
[{"role": "system", "content": panel_prompt}]
|
| 326 |
+
if panel_prompt
|
| 327 |
+
else []
|
| 328 |
+
)
|
| 329 |
+
+ panel_state["messages"],
|
| 330 |
+
remote=remote,
|
| 331 |
+
past_key_values=panel_state["past_key_values"],
|
| 332 |
+
**gen_kwargs,
|
| 333 |
+
)
|
| 334 |
+
for panel_state, _panel_log, panel_prompt in panels
|
| 335 |
+
]
|
| 336 |
+
results: list[ChatReply | Exception] = []
|
| 337 |
+
for future in futures:
|
| 338 |
+
try:
|
| 339 |
+
results.append(future.result())
|
| 340 |
+
except Exception as exc:
|
| 341 |
+
results.append(exc)
|
| 342 |
+
else:
|
| 343 |
results = []
|
| 344 |
+
for panel_state, _panel_log, panel_prompt in panels:
|
| 345 |
try:
|
| 346 |
+
results.append(
|
| 347 |
+
generate_chat_reply(
|
| 348 |
+
model=model,
|
| 349 |
+
messages=(
|
| 350 |
+
[{"role": "system", "content": panel_prompt}]
|
| 351 |
+
if panel_prompt
|
| 352 |
+
else []
|
| 353 |
+
)
|
| 354 |
+
+ panel_state["messages"],
|
| 355 |
+
remote=remote,
|
| 356 |
+
past_key_values=panel_state["past_key_values"],
|
| 357 |
+
**gen_kwargs,
|
| 358 |
+
)
|
| 359 |
+
)
|
| 360 |
except Exception as exc:
|
| 361 |
results.append(exc)
|
| 362 |
|
| 363 |
+
for (panel_state, panel_log, _panel_prompt), result in zip(panels, results):
|
| 364 |
if isinstance(result, Exception):
|
| 365 |
+
with panel_log:
|
| 366 |
+
st.error(f"Generation failed: {result}")
|
| 367 |
+
panel_state["messages"].pop()
|
|
|
|
| 368 |
continue
|
| 369 |
|
| 370 |
+
panel_state["messages"].append({"role": "assistant", "content": result.text})
|
| 371 |
+
panel_state["past_key_values"] = result.past_key_values if not remote else None
|
| 372 |
+
with panel_log:
|
| 373 |
+
_render_chat_message({"role": "assistant", "content": result.text})
|
|
|
|
|
|
|
|
|
|
| 374 |
|
| 375 |
|
| 376 |
# ── Main tab entry point ───────────────────────────────────────────────────────
|
|
|
|
| 528 |
# ── Single-chat mode ──────────────────────────────────────────────────────
|
| 529 |
persona_select_key = widget_key(context_key, "persona_select")
|
| 530 |
prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
|
| 531 |
+
prompt_key = widget_key(context_key, "custom_system_prompt")
|
| 532 |
+
show_all_key = widget_key(context_key, "show_all_messages")
|
| 533 |
+
chat_input_key = widget_key(context_key, "chat_input")
|
| 534 |
+
pending_key = widget_key(context_key, "pending_prompt")
|
| 535 |
+
export_key = widget_key(context_key, "export_chat")
|
| 536 |
+
reset_key = widget_key(context_key, "reset")
|
| 537 |
|
| 538 |
col1, col2 = st.columns([2, 1])
|
| 539 |
with col1:
|
|
|
|
| 550 |
)
|
| 551 |
with col2:
|
| 552 |
current_mode_label = VARIANT_LABELS.get(chat_state["prompt_mode"], "None")
|
| 553 |
+
st.selectbox(
|
| 554 |
"Prompt",
|
| 555 |
options=MODE_LABELS,
|
| 556 |
index=MODE_LABELS.index(current_mode_label),
|
| 557 |
key=prompt_mode_select_key,
|
| 558 |
)
|
| 559 |
+
prompt_mode = MODE_LABEL_TO_KEY[st.session_state[prompt_mode_select_key]]
|
| 560 |
|
| 561 |
active_system_prompt = resolve_system_prompt(
|
| 562 |
persona=selected_persona,
|
| 563 |
mode=prompt_mode,
|
| 564 |
)
|
| 565 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
changed_context = (
|
| 567 |
chat_state["persona_id"] != selected_persona.id
|
| 568 |
or chat_state["prompt_mode"] != prompt_mode
|
| 569 |
)
|
| 570 |
if changed_context:
|
| 571 |
had_history = bool(chat_state["messages"])
|
| 572 |
+
_reset_single_chat_context(
|
| 573 |
+
model_name,
|
| 574 |
+
remote,
|
| 575 |
+
dataset_source,
|
| 576 |
+
chat_state,
|
| 577 |
+
selected_persona.id,
|
| 578 |
+
prompt_mode,
|
| 579 |
chat_input_key,
|
| 580 |
show_all_key,
|
| 581 |
+
prompt_key,
|
| 582 |
pending_key,
|
| 583 |
)
|
| 584 |
if had_history:
|
|
|
|
| 586 |
|
| 587 |
chat_log = st.container()
|
| 588 |
|
| 589 |
+
active_system_prompt = _render_system_prompt_editor(
|
| 590 |
+
prompt_key,
|
| 591 |
+
prompt_mode,
|
| 592 |
+
active_system_prompt,
|
| 593 |
+
height=200,
|
| 594 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 595 |
|
| 596 |
+
action_col1, action_col2 = st.columns(2)
|
| 597 |
+
with action_col1:
|
| 598 |
+
if st.button("Export chat", key=export_key, width="stretch"):
|
| 599 |
+
st.success(
|
| 600 |
+
_save_chat_export_message(
|
| 601 |
+
model_name=model_name,
|
| 602 |
+
dataset_source=dataset_source,
|
| 603 |
+
persona_id=selected_persona.id,
|
| 604 |
+
persona_name=getattr(selected_persona, "name", None),
|
| 605 |
+
prompt_mode=prompt_mode,
|
| 606 |
+
system_prompt=active_system_prompt,
|
| 607 |
+
messages=chat_state["messages"],
|
| 608 |
+
generation=_generation_dict(gen_kwargs, advanced_generation),
|
| 609 |
+
)
|
| 610 |
+
)
|
| 611 |
+
with action_col2:
|
| 612 |
+
if st.button("Reset chat", key=reset_key, width="stretch", type="secondary"):
|
| 613 |
+
_reset_single_chat_context(
|
| 614 |
+
model_name,
|
| 615 |
+
remote,
|
| 616 |
+
dataset_source,
|
| 617 |
+
chat_state,
|
| 618 |
+
selected_persona.id,
|
| 619 |
+
prompt_mode,
|
| 620 |
+
chat_input_key,
|
| 621 |
+
show_all_key,
|
| 622 |
+
prompt_key,
|
| 623 |
+
pending_key,
|
| 624 |
+
)
|
| 625 |
+
st.rerun()
|
| 626 |
|
| 627 |
+
_render_chat_window(
|
| 628 |
+
chat_log=chat_log,
|
| 629 |
+
messages=chat_state["messages"],
|
| 630 |
+
show_all_key=show_all_key,
|
| 631 |
+
show_all_btn_key=widget_key(context_key, "show_all_btn"),
|
| 632 |
+
show_earlier_label="Show earlier messages",
|
| 633 |
+
)
|
| 634 |
|
| 635 |
user_prompt = st.chat_input(
|
| 636 |
"Ask something...",
|
|
|
|
| 647 |
if not st.session_state.pop(pending_key, False):
|
| 648 |
return
|
| 649 |
|
| 650 |
+
messages = _build_chat_messages(active_system_prompt, chat_state["messages"])
|
|
|
|
|
|
|
|
|
|
| 651 |
|
| 652 |
with st.spinner("Generating reply..."):
|
| 653 |
model = cached_model(model_name=model_name, remote=remote)
|
tabs/compare.py
CHANGED
|
@@ -1,21 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
|
|
|
| 2 |
from persona_data.environment import get_artifacts_dir
|
| 3 |
from persona_vectors.analysis import build_embedding_figure, project_pca, project_umap
|
| 4 |
-
from persona_vectors.
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
|
| 10 |
-
from utils.artifacts import (
|
| 11 |
-
artifact_persona_options,
|
| 12 |
-
list_available_layers,
|
| 13 |
-
load_cosine_traces,
|
| 14 |
-
load_embedding_samples,
|
| 15 |
-
)
|
| 16 |
from utils.helpers import (
|
| 17 |
ANALYSIS_HELP_TEXT,
|
| 18 |
-
ANALYSIS_LABELS,
|
| 19 |
ANALYSIS_MODES,
|
| 20 |
PROMPT_VARIANTS,
|
| 21 |
persona_display_label,
|
|
@@ -29,15 +26,151 @@ def _filename(*parts: str) -> str:
|
|
| 29 |
return "__".join(slugify(part) for part in parts if part)
|
| 30 |
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
model_name: str,
|
| 35 |
variants: list[str],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
) -> tuple[list[str], dict[str, str]]:
|
| 37 |
-
persona_options
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
| 41 |
)
|
| 42 |
if not persona_options:
|
| 43 |
if len(variants) > 1:
|
|
@@ -55,15 +188,81 @@ def _select_artifact_personas(
|
|
| 55 |
format_func=lambda persona_id: persona_display_label(
|
| 56 |
persona_id, persona_names.get(persona_id)
|
| 57 |
),
|
| 58 |
-
key=widget_key("load", "personas", model_name, *variants),
|
| 59 |
)
|
| 60 |
return persona_ids, persona_names
|
| 61 |
|
| 62 |
|
| 63 |
-
def
|
| 64 |
-
|
| 65 |
-
|
|
|
|
| 66 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
col1, col2 = st.columns(2)
|
| 68 |
with col1:
|
| 69 |
variant_a = st.selectbox(
|
|
@@ -86,24 +285,16 @@ def _render_cosine_similarity(
|
|
| 86 |
st.warning("Choose two different variants to compare.")
|
| 87 |
return
|
| 88 |
|
| 89 |
-
persona_ids, _ = _select_artifact_personas(
|
| 90 |
-
artifacts_root,
|
| 91 |
-
model_name,
|
| 92 |
-
[variant_a, variant_b],
|
| 93 |
-
)
|
| 94 |
if not persona_ids:
|
| 95 |
return
|
| 96 |
|
| 97 |
-
cosine_fig_key = widget_key("load", "cosine_fig_state", model_name)
|
| 98 |
-
filename = _filename("compare", "cosine", model_name, variant_a, variant_b)
|
| 99 |
|
| 100 |
if st.button("Compare vectors", type="primary"):
|
| 101 |
-
traces, loaded_names, errors =
|
| 102 |
-
|
| 103 |
-
model_name,
|
| 104 |
-
persona_ids,
|
| 105 |
-
variant_a,
|
| 106 |
-
variant_b,
|
| 107 |
)
|
| 108 |
|
| 109 |
if errors:
|
|
@@ -125,7 +316,7 @@ def _render_cosine_similarity(
|
|
| 125 |
)
|
| 126 |
for persona_id, short, long in traces
|
| 127 |
]
|
| 128 |
-
fig =
|
| 129 |
display_traces,
|
| 130 |
title=f"{prompt_variant_label(variant_a)} vs {prompt_variant_label(variant_b)}",
|
| 131 |
show=False,
|
|
@@ -134,82 +325,27 @@ def _render_cosine_similarity(
|
|
| 134 |
|
| 135 |
if cosine_fig_key in st.session_state:
|
| 136 |
fig, n_traces = st.session_state[cosine_fig_key]
|
| 137 |
-
st.plotly_chart(fig,
|
| 138 |
-
|
| 139 |
-
with save_col1:
|
| 140 |
-
if st.button("Save HTML", key=widget_key("load", "save_cosine_html")):
|
| 141 |
-
output_path = save_plot_html(fig, filename)
|
| 142 |
-
st.success(f"Saved HTML to `{output_path}`")
|
| 143 |
-
with save_col2:
|
| 144 |
-
if st.button("Save PNG", key=widget_key("load", "save_cosine_png")):
|
| 145 |
-
try:
|
| 146 |
-
output_path = save_plot_png(fig, filename)
|
| 147 |
-
st.success(f"Saved PNG to `{output_path}`")
|
| 148 |
-
except Exception as exc:
|
| 149 |
-
st.error(f"Could not save PNG: {exc}")
|
| 150 |
st.success(f"Loaded {n_traces} personas for cosine comparison.")
|
| 151 |
|
| 152 |
|
| 153 |
-
def _render_embedding_analysis(
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
analysis_mode: str,
|
| 157 |
-
) -> None:
|
| 158 |
-
selected_variant = st.selectbox(
|
| 159 |
-
"Variant",
|
| 160 |
-
options=PROMPT_VARIANTS,
|
| 161 |
-
format_func=prompt_variant_label,
|
| 162 |
-
key=widget_key("load", "variant"),
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
-
persona_ids, persona_names = _select_artifact_personas(
|
| 166 |
-
artifacts_root,
|
| 167 |
-
model_name,
|
| 168 |
-
[selected_variant],
|
| 169 |
-
)
|
| 170 |
-
if not persona_ids:
|
| 171 |
-
return
|
| 172 |
-
|
| 173 |
-
layer_options = list_available_layers(
|
| 174 |
-
artifacts_root,
|
| 175 |
-
model_name,
|
| 176 |
-
[selected_variant],
|
| 177 |
-
persona_ids,
|
| 178 |
-
)
|
| 179 |
-
if not layer_options:
|
| 180 |
-
st.info(
|
| 181 |
-
"No shared layers are available for the selected personas. Try fewer personas or a different variant."
|
| 182 |
-
)
|
| 183 |
return
|
| 184 |
-
|
| 185 |
persona_key = "_".join(sorted(persona_ids))
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
for layer in st.session_state.get(layer_key, layer_options[:3])
|
| 190 |
-
if layer in layer_options
|
| 191 |
-
] or layer_options[:3]
|
| 192 |
-
selected_layers = st.multiselect(
|
| 193 |
-
"Layers",
|
| 194 |
-
options=layer_options,
|
| 195 |
-
default=default_layers,
|
| 196 |
-
key=layer_key,
|
| 197 |
-
)
|
| 198 |
-
if not selected_layers:
|
| 199 |
-
st.info("Select at least one layer.")
|
| 200 |
return
|
| 201 |
|
| 202 |
-
button_label = (
|
| 203 |
-
"Generate PCA projection"
|
| 204 |
-
if analysis_mode == "PCA"
|
| 205 |
-
else "Generate UMAP projection"
|
| 206 |
-
)
|
| 207 |
-
|
| 208 |
embedding_fig_key = widget_key(
|
| 209 |
-
"load", "embedding_fig_state", model_name, analysis_mode
|
| 210 |
)
|
| 211 |
|
| 212 |
-
if st.button(
|
| 213 |
progress = st.progress(0, text="Preparing projections...")
|
| 214 |
|
| 215 |
def update_progress(current: int, total: int, loaded: int) -> None:
|
|
@@ -219,15 +355,13 @@ def _render_embedding_analysis(
|
|
| 219 |
text=f"Processing layer {current}/{total} ({loaded} plot(s) ready)",
|
| 220 |
)
|
| 221 |
|
| 222 |
-
project_fn = project_pca if analysis_mode == "PCA" else project_umap
|
| 223 |
try:
|
| 224 |
-
plots, errors =
|
| 225 |
-
|
| 226 |
-
model_name,
|
| 227 |
persona_ids,
|
| 228 |
selected_variant,
|
| 229 |
selected_layers,
|
| 230 |
-
project_fn,
|
| 231 |
persona_names,
|
| 232 |
progress_fn=update_progress,
|
| 233 |
)
|
|
@@ -248,18 +382,7 @@ def _render_embedding_analysis(
|
|
| 248 |
st.info("Try fewer personas, fewer layers, or a different variant.")
|
| 249 |
st.session_state.pop(embedding_fig_key, None)
|
| 250 |
else:
|
| 251 |
-
|
| 252 |
-
rendered_figures: list[tuple[int, object]] = []
|
| 253 |
-
for layer_idx, coords, labels, hover_text in plots:
|
| 254 |
-
fig = build_embedding_figure(
|
| 255 |
-
coords=coords,
|
| 256 |
-
labels=labels,
|
| 257 |
-
title=f"{title_prefix}, layer {layer_idx}",
|
| 258 |
-
x_label=x_label,
|
| 259 |
-
y_label=y_label,
|
| 260 |
-
hover_text=hover_text,
|
| 261 |
-
)
|
| 262 |
-
rendered_figures.append((layer_idx, fig))
|
| 263 |
total_samples = sum(coords.shape[0] for _, coords, _, _ in plots)
|
| 264 |
st.session_state[embedding_fig_key] = (
|
| 265 |
rendered_figures,
|
|
@@ -274,52 +397,14 @@ def _render_embedding_analysis(
|
|
| 274 |
rendered_figures, saved_persona_key, saved_variant, total_samples = (
|
| 275 |
st.session_state[embedding_fig_key]
|
| 276 |
)
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
| 283 |
)
|
| 284 |
-
filenames = [
|
| 285 |
-
_filename(
|
| 286 |
-
"compare",
|
| 287 |
-
analysis_mode,
|
| 288 |
-
model_name,
|
| 289 |
-
saved_variant,
|
| 290 |
-
saved_persona_key,
|
| 291 |
-
str(layer_idx),
|
| 292 |
-
)
|
| 293 |
-
for layer_idx, _ in rendered_figures
|
| 294 |
-
]
|
| 295 |
-
save_col1, save_col2 = st.columns(2)
|
| 296 |
-
with save_col1:
|
| 297 |
-
if st.button(
|
| 298 |
-
"Save HTML",
|
| 299 |
-
key=widget_key("load", "save_embedding_html", analysis_mode),
|
| 300 |
-
):
|
| 301 |
-
saved_paths = [
|
| 302 |
-
save_plot_html(fig, fn)
|
| 303 |
-
for (_, fig), fn in zip(rendered_figures, filenames)
|
| 304 |
-
]
|
| 305 |
-
st.success(
|
| 306 |
-
f"Saved {len(saved_paths)} HTML plot(s) to `artifacts/plots`."
|
| 307 |
-
)
|
| 308 |
-
with save_col2:
|
| 309 |
-
if st.button(
|
| 310 |
-
"Save PNG",
|
| 311 |
-
key=widget_key("load", "save_embedding_png", analysis_mode),
|
| 312 |
-
):
|
| 313 |
-
try:
|
| 314 |
-
saved_paths = [
|
| 315 |
-
save_plot_png(fig, fn)
|
| 316 |
-
for (_, fig), fn in zip(rendered_figures, filenames)
|
| 317 |
-
]
|
| 318 |
-
st.success(
|
| 319 |
-
f"Saved {len(saved_paths)} PNG plot(s) to `artifacts/plots`."
|
| 320 |
-
)
|
| 321 |
-
except Exception as exc:
|
| 322 |
-
st.error(f"Could not save PNGs: {exc}")
|
| 323 |
|
| 324 |
|
| 325 |
def render_compare_tab(model_name: str) -> None:
|
|
@@ -336,6 +421,8 @@ def render_compare_tab(model_name: str) -> None:
|
|
| 336 |
value=str(get_artifacts_dir() / "activations"),
|
| 337 |
)
|
| 338 |
|
|
|
|
|
|
|
| 339 |
analysis_mode = st.segmented_control(
|
| 340 |
"Analysis mode",
|
| 341 |
options=ANALYSIS_MODES,
|
|
@@ -348,7 +435,7 @@ def render_compare_tab(model_name: str) -> None:
|
|
| 348 |
st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
|
| 349 |
|
| 350 |
if analysis_mode == "Cosine similarity":
|
| 351 |
-
_render_cosine_similarity(
|
| 352 |
return
|
| 353 |
|
| 354 |
-
_render_embedding_analysis(
|
|
|
|
| 1 |
+
from collections.abc import Callable
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
import streamlit as st
|
| 5 |
+
import torch
|
| 6 |
from persona_data.environment import get_artifacts_dir
|
| 7 |
from persona_vectors.analysis import build_embedding_figure, project_pca, project_umap
|
| 8 |
+
from persona_vectors.artifacts import ActivationStore
|
| 9 |
+
from persona_vectors.artifacts import list_layers as list_available_layers
|
| 10 |
+
from persona_vectors.artifacts import list_personas as list_available_personas
|
| 11 |
+
from persona_vectors.artifacts import load_mean_activations, load_persona_names
|
| 12 |
+
from persona_vectors.plots import plot_layer_similarity, save_plot_html, save_plot_png
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from utils.helpers import (
|
| 15 |
ANALYSIS_HELP_TEXT,
|
|
|
|
| 16 |
ANALYSIS_MODES,
|
| 17 |
PROMPT_VARIANTS,
|
| 18 |
persona_display_label,
|
|
|
|
| 26 |
return "__".join(slugify(part) for part in parts if part)
|
| 27 |
|
| 28 |
|
| 29 |
+
@dataclass(frozen=True)
|
| 30 |
+
class ProjectionConfig:
|
| 31 |
+
title_prefix: str
|
| 32 |
+
x_label: str
|
| 33 |
+
y_label: str
|
| 34 |
+
project_fn: Callable[[torch.Tensor], torch.Tensor]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
_PROJECTION_CONFIGS: dict[str, ProjectionConfig] = {
|
| 38 |
+
"PCA": ProjectionConfig("PCA", "PC1", "PC2", project_pca),
|
| 39 |
+
"UMAP": ProjectionConfig("UMAP", "UMAP 1", "UMAP 2", project_umap),
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@st.cache_data(show_spinner=False)
|
| 44 |
+
def _list_layers(
|
| 45 |
+
root_dir: str,
|
| 46 |
model_name: str,
|
| 47 |
variants: list[str],
|
| 48 |
+
persona_ids: list[str],
|
| 49 |
+
) -> list[int]:
|
| 50 |
+
return list_available_layers(root_dir, model_name, variants, persona_ids)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _load_embedding_samples(
|
| 54 |
+
store: ActivationStore,
|
| 55 |
+
persona_ids: list[str],
|
| 56 |
+
variant: str,
|
| 57 |
+
selected_layers: list[int],
|
| 58 |
+
project_fn: Callable[[torch.Tensor], torch.Tensor],
|
| 59 |
+
persona_names: dict[str, str],
|
| 60 |
+
progress_fn: Callable[[int, int, int], None] | None = None,
|
| 61 |
+
) -> tuple[list[tuple[int, torch.Tensor, list[str], list[str]]], list[str]]:
|
| 62 |
+
"""Load samples for 2D projections without re-reading each layer from disk."""
|
| 63 |
+
|
| 64 |
+
plots: list[tuple[int, torch.Tensor, list[str], list[str]]] = []
|
| 65 |
+
errors: list[str] = []
|
| 66 |
+
vectors_by_persona: dict[str, torch.Tensor] = {}
|
| 67 |
+
|
| 68 |
+
for persona_id in persona_ids:
|
| 69 |
+
try:
|
| 70 |
+
vectors, _ = store.load(variant, persona_id)
|
| 71 |
+
except (FileNotFoundError, KeyError, OSError, ValueError) as exc:
|
| 72 |
+
errors.append(f"{persona_id} / {variant}: {exc}")
|
| 73 |
+
continue
|
| 74 |
+
|
| 75 |
+
vectors_by_persona[persona_id] = vectors
|
| 76 |
+
|
| 77 |
+
total_layers = len(selected_layers)
|
| 78 |
+
for idx, layer_idx in enumerate(selected_layers, start=1):
|
| 79 |
+
samples: list[torch.Tensor] = []
|
| 80 |
+
labels: list[str] = []
|
| 81 |
+
hover_text: list[str] = []
|
| 82 |
+
|
| 83 |
+
for persona_id, vectors in vectors_by_persona.items():
|
| 84 |
+
if layer_idx >= vectors.shape[1]:
|
| 85 |
+
errors.append(f"{persona_id} / {variant}: missing layer {layer_idx}")
|
| 86 |
+
continue
|
| 87 |
+
|
| 88 |
+
layer_vectors = vectors[:, layer_idx, :]
|
| 89 |
+
samples.append(layer_vectors)
|
| 90 |
+
labels.extend([persona_id] * layer_vectors.shape[0])
|
| 91 |
+
display_name = persona_names.get(persona_id) or persona_id
|
| 92 |
+
hover_text.extend(
|
| 93 |
+
[f"<b>{display_name}</b><br>{variant}"] * layer_vectors.shape[0]
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
if not samples:
|
| 97 |
+
errors.append(f"Layer {layer_idx}: no selected personas have this layer")
|
| 98 |
+
else:
|
| 99 |
+
all_samples = torch.cat(samples, dim=0)
|
| 100 |
+
if all_samples.shape[0] < 2:
|
| 101 |
+
errors.append(
|
| 102 |
+
f"Layer {layer_idx}: need at least 2 samples after filtering selected personas"
|
| 103 |
+
)
|
| 104 |
+
else:
|
| 105 |
+
try:
|
| 106 |
+
coords = project_fn(all_samples)
|
| 107 |
+
plots.append((layer_idx, coords, labels, hover_text))
|
| 108 |
+
except Exception as exc:
|
| 109 |
+
errors.append(f"Layer {layer_idx}: {exc}")
|
| 110 |
+
|
| 111 |
+
if progress_fn is not None:
|
| 112 |
+
progress_fn(idx, total_layers, len(plots))
|
| 113 |
+
|
| 114 |
+
return plots, errors
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _build_embedding_figures(
|
| 118 |
+
plots: list[tuple[int, torch.Tensor, list[str], list[str]]],
|
| 119 |
+
config: ProjectionConfig,
|
| 120 |
+
) -> list[tuple[int, object]]:
|
| 121 |
+
return [
|
| 122 |
+
(
|
| 123 |
+
layer_idx,
|
| 124 |
+
build_embedding_figure(
|
| 125 |
+
coords=coords,
|
| 126 |
+
labels=labels,
|
| 127 |
+
title=f"{config.title_prefix}, layer {layer_idx}",
|
| 128 |
+
x_label=config.x_label,
|
| 129 |
+
y_label=config.y_label,
|
| 130 |
+
hover_text=hover_text,
|
| 131 |
+
),
|
| 132 |
+
)
|
| 133 |
+
for layer_idx, coords, labels, hover_text in plots
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def _render_embedding_results(
|
| 138 |
+
store: ActivationStore,
|
| 139 |
+
analysis_mode: str,
|
| 140 |
+
rendered_figures: list[tuple[int, object]],
|
| 141 |
+
saved_variant: str,
|
| 142 |
+
saved_persona_key: str,
|
| 143 |
+
total_samples: int,
|
| 144 |
+
) -> None:
|
| 145 |
+
cols = st.columns(2)
|
| 146 |
+
for idx, (_, fig) in enumerate(rendered_figures):
|
| 147 |
+
with cols[idx % 2]:
|
| 148 |
+
st.plotly_chart(fig, width="stretch")
|
| 149 |
+
|
| 150 |
+
st.success(f"Loaded {total_samples} samples across {len(rendered_figures)} layers.")
|
| 151 |
+
filenames = [
|
| 152 |
+
_filename(
|
| 153 |
+
"compare",
|
| 154 |
+
analysis_mode,
|
| 155 |
+
store.model_name,
|
| 156 |
+
saved_variant,
|
| 157 |
+
saved_persona_key,
|
| 158 |
+
str(layer_idx),
|
| 159 |
+
)
|
| 160 |
+
for layer_idx, _ in rendered_figures
|
| 161 |
+
]
|
| 162 |
+
_render_save_buttons([fig for _, fig in rendered_figures], filenames, analysis_mode)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _select_artifact_personas(
|
| 166 |
+
store: ActivationStore,
|
| 167 |
+
variants: list[str],
|
| 168 |
) -> tuple[list[str], dict[str, str]]:
|
| 169 |
+
persona_options = list_available_personas(
|
| 170 |
+
store.root_dir, store.model_name, variants
|
| 171 |
+
)
|
| 172 |
+
persona_names = load_persona_names(
|
| 173 |
+
store.root_dir, store.model_name, variants, persona_options
|
| 174 |
)
|
| 175 |
if not persona_options:
|
| 176 |
if len(variants) > 1:
|
|
|
|
| 188 |
format_func=lambda persona_id: persona_display_label(
|
| 189 |
persona_id, persona_names.get(persona_id)
|
| 190 |
),
|
| 191 |
+
key=widget_key("load", "personas", store.model_name, *variants),
|
| 192 |
)
|
| 193 |
return persona_ids, persona_names
|
| 194 |
|
| 195 |
|
| 196 |
+
def _render_save_buttons(
|
| 197 |
+
figs: list[object],
|
| 198 |
+
filenames: list[str],
|
| 199 |
+
key_suffix: str,
|
| 200 |
) -> None:
|
| 201 |
+
"""Render Save HTML / Save PNG column buttons for one or more figures."""
|
| 202 |
+
col1, col2 = st.columns(2)
|
| 203 |
+
with col1:
|
| 204 |
+
if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
|
| 205 |
+
paths = [save_plot_html(fig, fn) for fig, fn in zip(figs, filenames)]
|
| 206 |
+
st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
|
| 207 |
+
with col2:
|
| 208 |
+
if st.button("Save PNG", key=widget_key("load", "save_png", key_suffix)):
|
| 209 |
+
try:
|
| 210 |
+
paths = [save_plot_png(fig, fn) for fig, fn in zip(figs, filenames)]
|
| 211 |
+
st.success(f"Saved {len(paths)} PNG file(s) to `artifacts/plots`.")
|
| 212 |
+
except Exception as exc:
|
| 213 |
+
st.error(f"Could not save PNG: {exc}")
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _select_embedding_config(
|
| 217 |
+
store: ActivationStore,
|
| 218 |
+
) -> tuple[str, list[str], dict[str, str], list[int]] | None:
|
| 219 |
+
"""Render variant / persona / layer selectors and return the selection, or None on early exit."""
|
| 220 |
+
selected_variant = st.selectbox(
|
| 221 |
+
"Variant",
|
| 222 |
+
options=PROMPT_VARIANTS,
|
| 223 |
+
format_func=prompt_variant_label,
|
| 224 |
+
key=widget_key("load", "variant"),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
persona_ids, persona_names = _select_artifact_personas(store, [selected_variant])
|
| 228 |
+
if not persona_ids:
|
| 229 |
+
return None
|
| 230 |
+
|
| 231 |
+
layer_options = _list_layers(
|
| 232 |
+
str(store.root_dir),
|
| 233 |
+
store.model_name,
|
| 234 |
+
[selected_variant],
|
| 235 |
+
persona_ids,
|
| 236 |
+
)
|
| 237 |
+
if not layer_options:
|
| 238 |
+
st.info(
|
| 239 |
+
"No shared layers are available for the selected personas. Try fewer personas or a different variant."
|
| 240 |
+
)
|
| 241 |
+
return None
|
| 242 |
+
|
| 243 |
+
persona_key = "_".join(sorted(persona_ids))
|
| 244 |
+
layer_key = widget_key(
|
| 245 |
+
"load", "layers", store.model_name, selected_variant, persona_key
|
| 246 |
+
)
|
| 247 |
+
default_layers = [
|
| 248 |
+
layer
|
| 249 |
+
for layer in st.session_state.get(layer_key, layer_options[:3])
|
| 250 |
+
if layer in layer_options
|
| 251 |
+
] or layer_options[:3]
|
| 252 |
+
selected_layers = st.multiselect(
|
| 253 |
+
"Layers",
|
| 254 |
+
options=layer_options,
|
| 255 |
+
default=default_layers,
|
| 256 |
+
key=layer_key,
|
| 257 |
+
)
|
| 258 |
+
if not selected_layers:
|
| 259 |
+
st.info("Select at least one layer.")
|
| 260 |
+
return None
|
| 261 |
+
|
| 262 |
+
return selected_variant, persona_ids, persona_names, selected_layers
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def _render_cosine_similarity(store: ActivationStore) -> None:
|
| 266 |
col1, col2 = st.columns(2)
|
| 267 |
with col1:
|
| 268 |
variant_a = st.selectbox(
|
|
|
|
| 285 |
st.warning("Choose two different variants to compare.")
|
| 286 |
return
|
| 287 |
|
| 288 |
+
persona_ids, _ = _select_artifact_personas(store, [variant_a, variant_b])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
if not persona_ids:
|
| 290 |
return
|
| 291 |
|
| 292 |
+
cosine_fig_key = widget_key("load", "cosine_fig_state", store.model_name)
|
| 293 |
+
filename = _filename("compare", "cosine", store.model_name, variant_a, variant_b)
|
| 294 |
|
| 295 |
if st.button("Compare vectors", type="primary"):
|
| 296 |
+
traces, loaded_names, errors = load_mean_activations(
|
| 297 |
+
store.root_dir, store.model_name, persona_ids, variant_a, variant_b
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
)
|
| 299 |
|
| 300 |
if errors:
|
|
|
|
| 316 |
)
|
| 317 |
for persona_id, short, long in traces
|
| 318 |
]
|
| 319 |
+
fig = plot_layer_similarity(
|
| 320 |
display_traces,
|
| 321 |
title=f"{prompt_variant_label(variant_a)} vs {prompt_variant_label(variant_b)}",
|
| 322 |
show=False,
|
|
|
|
| 325 |
|
| 326 |
if cosine_fig_key in st.session_state:
|
| 327 |
fig, n_traces = st.session_state[cosine_fig_key]
|
| 328 |
+
st.plotly_chart(fig, width="stretch")
|
| 329 |
+
_render_save_buttons([fig], [filename], "cosine")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
st.success(f"Loaded {n_traces} personas for cosine comparison.")
|
| 331 |
|
| 332 |
|
| 333 |
+
def _render_embedding_analysis(store: ActivationStore, analysis_mode: str) -> None:
|
| 334 |
+
config = _select_embedding_config(store)
|
| 335 |
+
if config is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
return
|
| 337 |
+
selected_variant, persona_ids, persona_names, selected_layers = config
|
| 338 |
persona_key = "_".join(sorted(persona_ids))
|
| 339 |
+
projection_config = _PROJECTION_CONFIGS.get(analysis_mode)
|
| 340 |
+
if projection_config is None:
|
| 341 |
+
st.error(f"Unsupported analysis mode: {analysis_mode}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
return
|
| 343 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
embedding_fig_key = widget_key(
|
| 345 |
+
"load", "embedding_fig_state", store.model_name, analysis_mode
|
| 346 |
)
|
| 347 |
|
| 348 |
+
if st.button(f"Generate {analysis_mode} projection", type="primary"):
|
| 349 |
progress = st.progress(0, text="Preparing projections...")
|
| 350 |
|
| 351 |
def update_progress(current: int, total: int, loaded: int) -> None:
|
|
|
|
| 355 |
text=f"Processing layer {current}/{total} ({loaded} plot(s) ready)",
|
| 356 |
)
|
| 357 |
|
|
|
|
| 358 |
try:
|
| 359 |
+
plots, errors = _load_embedding_samples(
|
| 360 |
+
store,
|
|
|
|
| 361 |
persona_ids,
|
| 362 |
selected_variant,
|
| 363 |
selected_layers,
|
| 364 |
+
projection_config.project_fn,
|
| 365 |
persona_names,
|
| 366 |
progress_fn=update_progress,
|
| 367 |
)
|
|
|
|
| 382 |
st.info("Try fewer personas, fewer layers, or a different variant.")
|
| 383 |
st.session_state.pop(embedding_fig_key, None)
|
| 384 |
else:
|
| 385 |
+
rendered_figures = _build_embedding_figures(plots, projection_config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
total_samples = sum(coords.shape[0] for _, coords, _, _ in plots)
|
| 387 |
st.session_state[embedding_fig_key] = (
|
| 388 |
rendered_figures,
|
|
|
|
| 397 |
rendered_figures, saved_persona_key, saved_variant, total_samples = (
|
| 398 |
st.session_state[embedding_fig_key]
|
| 399 |
)
|
| 400 |
+
_render_embedding_results(
|
| 401 |
+
store,
|
| 402 |
+
analysis_mode,
|
| 403 |
+
rendered_figures,
|
| 404 |
+
saved_variant,
|
| 405 |
+
saved_persona_key,
|
| 406 |
+
total_samples,
|
| 407 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
|
| 409 |
|
| 410 |
def render_compare_tab(model_name: str) -> None:
|
|
|
|
| 421 |
value=str(get_artifacts_dir() / "activations"),
|
| 422 |
)
|
| 423 |
|
| 424 |
+
store = ActivationStore(model_name, artifacts_root)
|
| 425 |
+
|
| 426 |
analysis_mode = st.segmented_control(
|
| 427 |
"Analysis mode",
|
| 428 |
options=ANALYSIS_MODES,
|
|
|
|
| 435 |
st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
|
| 436 |
|
| 437 |
if analysis_mode == "Cosine similarity":
|
| 438 |
+
_render_cosine_similarity(store)
|
| 439 |
return
|
| 440 |
|
| 441 |
+
_render_embedding_analysis(store, analysis_mode)
|
tabs/extract.py
CHANGED
|
@@ -3,6 +3,7 @@ from persona_vectors.extraction import run_extraction
|
|
| 3 |
|
| 4 |
from utils.datasets import load_dataset
|
| 5 |
from utils.helpers import (
|
|
|
|
| 6 |
PROMPT_VARIANTS,
|
| 7 |
persona_label,
|
| 8 |
prompt_variant_label,
|
|
@@ -84,8 +85,8 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 84 |
st.info("Select at least one persona.")
|
| 85 |
return
|
| 86 |
|
| 87 |
-
|
| 88 |
-
|
| 89 |
|
| 90 |
with st.expander("Advanced", expanded=False):
|
| 91 |
st.caption("Filters")
|
|
@@ -114,35 +115,38 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 114 |
)
|
| 115 |
qa_filter_difficulty = difficulty_values if difficulty_values else None
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
| 121 |
)
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
| 127 |
st.warning(f"No QA pairs match filters for: {names}. They will be skipped.")
|
| 128 |
|
| 129 |
-
|
| 130 |
-
if not personas_to_run:
|
| 131 |
st.info("No personas have matching QA pairs. Widen the filters.")
|
| 132 |
return
|
| 133 |
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
"Max questions",
|
| 139 |
-
min_value=1,
|
| 140 |
-
max_value=min_qa_count,
|
| 141 |
-
value=min_qa_count,
|
| 142 |
-
key=_extract_widget_key(
|
| 143 |
-
model_name, remote, dataset_source, "max_questions"
|
| 144 |
-
),
|
| 145 |
-
)
|
| 146 |
|
| 147 |
run_clicked = st.button("Run extraction", type="primary")
|
| 148 |
if not run_clicked:
|
|
@@ -153,25 +157,19 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 153 |
progress = st.progress(0, text="Preparing extraction...")
|
| 154 |
ndif_status_box = st.empty() # shows live NDIF job status when remote=True
|
| 155 |
|
| 156 |
-
_STATUS_ICONS = {
|
| 157 |
-
"RECEIVED": "◉", "QUEUED": "◎", "DISPATCHED": "◈",
|
| 158 |
-
"RUNNING": "●", "COMPLETED": "✓", "ERROR": "✗",
|
| 159 |
-
}
|
| 160 |
-
|
| 161 |
def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
|
| 162 |
-
icon =
|
| 163 |
ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
|
| 164 |
|
| 165 |
with st.spinner("Loading model..."):
|
| 166 |
model = cached_model(model_name=model_name, remote=remote)
|
| 167 |
|
| 168 |
try:
|
| 169 |
-
total_steps = len(
|
| 170 |
step = 0
|
| 171 |
results = []
|
| 172 |
|
| 173 |
-
for persona in
|
| 174 |
-
qa_pairs = qa_by_persona[persona.id][:max_questions]
|
| 175 |
for variant in selected_variants:
|
| 176 |
progress.progress(
|
| 177 |
step / total_steps if total_steps else 1.0,
|
|
@@ -181,7 +179,7 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 181 |
model=model,
|
| 182 |
model_name=model_name,
|
| 183 |
persona=persona,
|
| 184 |
-
qa_pairs=qa_pairs,
|
| 185 |
variants=[variant],
|
| 186 |
remote=remote,
|
| 187 |
on_status=_on_ndif_status if remote else None,
|
|
|
|
| 3 |
|
| 4 |
from utils.datasets import load_dataset
|
| 5 |
from utils.helpers import (
|
| 6 |
+
NDIF_STATUS_ICONS,
|
| 7 |
PROMPT_VARIANTS,
|
| 8 |
persona_label,
|
| 9 |
prompt_variant_label,
|
|
|
|
| 85 |
st.info("Select at least one persona.")
|
| 86 |
return
|
| 87 |
|
| 88 |
+
runs = None
|
| 89 |
+
max_questions = 0
|
| 90 |
|
| 91 |
with st.expander("Advanced", expanded=False):
|
| 92 |
st.caption("Filters")
|
|
|
|
| 115 |
)
|
| 116 |
qa_filter_difficulty = difficulty_values if difficulty_values else None
|
| 117 |
|
| 118 |
+
runs, skipped = [], []
|
| 119 |
+
for persona in selected_personas:
|
| 120 |
+
qa = list(
|
| 121 |
+
dataset.get_qa(
|
| 122 |
+
persona.id, type=qa_filter_type, difficulty=qa_filter_difficulty
|
| 123 |
+
)
|
| 124 |
)
|
| 125 |
+
if qa:
|
| 126 |
+
runs.append((persona, qa))
|
| 127 |
+
else:
|
| 128 |
+
skipped.append(persona)
|
| 129 |
+
if skipped:
|
| 130 |
+
names = ", ".join(p.name for p in skipped)
|
| 131 |
st.warning(f"No QA pairs match filters for: {names}. They will be skipped.")
|
| 132 |
|
| 133 |
+
if not runs:
|
|
|
|
| 134 |
st.info("No personas have matching QA pairs. Widen the filters.")
|
| 135 |
return
|
| 136 |
|
| 137 |
+
max_q = min(len(qa_pairs) for _, qa_pairs in runs)
|
| 138 |
+
max_questions = st.slider(
|
| 139 |
+
"Max questions",
|
| 140 |
+
min_value=1,
|
| 141 |
+
max_value=max_q,
|
| 142 |
+
value=max_q,
|
| 143 |
+
key=_extract_widget_key(
|
| 144 |
+
model_name, remote, dataset_source, "max_questions"
|
| 145 |
+
),
|
| 146 |
+
)
|
| 147 |
|
| 148 |
+
if runs is None:
|
| 149 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
run_clicked = st.button("Run extraction", type="primary")
|
| 152 |
if not run_clicked:
|
|
|
|
| 157 |
progress = st.progress(0, text="Preparing extraction...")
|
| 158 |
ndif_status_box = st.empty() # shows live NDIF job status when remote=True
|
| 159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
|
| 161 |
+
icon = NDIF_STATUS_ICONS.get(status_name, "•")
|
| 162 |
ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
|
| 163 |
|
| 164 |
with st.spinner("Loading model..."):
|
| 165 |
model = cached_model(model_name=model_name, remote=remote)
|
| 166 |
|
| 167 |
try:
|
| 168 |
+
total_steps = len(runs) * len(selected_variants)
|
| 169 |
step = 0
|
| 170 |
results = []
|
| 171 |
|
| 172 |
+
for persona, qa_pairs in runs:
|
|
|
|
| 173 |
for variant in selected_variants:
|
| 174 |
progress.progress(
|
| 175 |
step / total_steps if total_steps else 1.0,
|
|
|
|
| 179 |
model=model,
|
| 180 |
model_name=model_name,
|
| 181 |
persona=persona,
|
| 182 |
+
qa_pairs=qa_pairs[:max_questions],
|
| 183 |
variants=[variant],
|
| 184 |
remote=remote,
|
| 185 |
on_status=_on_ndif_status if remote else None,
|
utils/artifacts.py
DELETED
|
@@ -1,244 +0,0 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
from collections.abc import Callable
|
| 3 |
-
from pathlib import Path
|
| 4 |
-
|
| 5 |
-
import streamlit as st
|
| 6 |
-
import torch
|
| 7 |
-
from persona_vectors.activation_io import (
|
| 8 |
-
load_activation_metadata,
|
| 9 |
-
load_per_question_vectors,
|
| 10 |
-
model_dir_name,
|
| 11 |
-
)
|
| 12 |
-
|
| 13 |
-
logger = logging.getLogger(__name__)
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
def list_available_personas(
|
| 17 |
-
artifacts_root: str | Path,
|
| 18 |
-
model_name: str,
|
| 19 |
-
variants: list[str],
|
| 20 |
-
) -> list[str]:
|
| 21 |
-
"""List persona ids available for every requested variant."""
|
| 22 |
-
|
| 23 |
-
shared_personas: set[str] | None = None
|
| 24 |
-
root = Path(artifacts_root)
|
| 25 |
-
for variant in variants:
|
| 26 |
-
model_dir = root / model_dir_name(model_name) / variant
|
| 27 |
-
if not model_dir.exists():
|
| 28 |
-
return []
|
| 29 |
-
|
| 30 |
-
variant_personas = {d.name for d in model_dir.iterdir() if d.is_dir()}
|
| 31 |
-
if shared_personas is None:
|
| 32 |
-
shared_personas = variant_personas
|
| 33 |
-
else:
|
| 34 |
-
shared_personas &= variant_personas
|
| 35 |
-
|
| 36 |
-
if not shared_personas:
|
| 37 |
-
return []
|
| 38 |
-
|
| 39 |
-
return sorted(shared_personas or set())
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def load_persona_names(
|
| 43 |
-
artifacts_root: str | Path,
|
| 44 |
-
model_name: str,
|
| 45 |
-
variants: list[str],
|
| 46 |
-
persona_ids: list[str],
|
| 47 |
-
) -> dict[str, str]:
|
| 48 |
-
"""Load display names from saved activation metadata."""
|
| 49 |
-
|
| 50 |
-
names: dict[str, str] = {}
|
| 51 |
-
for persona_id in persona_ids:
|
| 52 |
-
for variant in variants:
|
| 53 |
-
try:
|
| 54 |
-
metadata = load_activation_metadata(
|
| 55 |
-
root_dir=artifacts_root,
|
| 56 |
-
model_name=model_name,
|
| 57 |
-
prompt_variant=variant,
|
| 58 |
-
persona_id=persona_id,
|
| 59 |
-
)
|
| 60 |
-
except Exception:
|
| 61 |
-
logger.debug(
|
| 62 |
-
"Failed to load metadata for persona %s variant %s",
|
| 63 |
-
persona_id,
|
| 64 |
-
variant,
|
| 65 |
-
exc_info=True,
|
| 66 |
-
)
|
| 67 |
-
continue
|
| 68 |
-
|
| 69 |
-
persona_name = metadata.get("persona_name")
|
| 70 |
-
if isinstance(persona_name, str) and persona_name:
|
| 71 |
-
names[persona_id] = persona_name
|
| 72 |
-
break
|
| 73 |
-
|
| 74 |
-
return names
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
def artifact_persona_options(
|
| 78 |
-
artifacts_root: str | Path,
|
| 79 |
-
model_name: str,
|
| 80 |
-
variants: list[str],
|
| 81 |
-
) -> tuple[list[str], dict[str, str]]:
|
| 82 |
-
"""Return persona ids and names for the selected artifacts."""
|
| 83 |
-
|
| 84 |
-
persona_options = list_available_personas(artifacts_root, model_name, variants)
|
| 85 |
-
persona_names = load_persona_names(
|
| 86 |
-
artifacts_root,
|
| 87 |
-
model_name,
|
| 88 |
-
variants,
|
| 89 |
-
persona_options,
|
| 90 |
-
)
|
| 91 |
-
return persona_options, persona_names
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
@st.cache_data(show_spinner=False)
|
| 95 |
-
def list_available_layers(
|
| 96 |
-
artifacts_root: str,
|
| 97 |
-
model_name: str,
|
| 98 |
-
variants: list[str],
|
| 99 |
-
persona_ids: list[str],
|
| 100 |
-
) -> list[int]:
|
| 101 |
-
"""List layer indices shared by all matching saved activation files."""
|
| 102 |
-
|
| 103 |
-
shared_layers: set[int] | None = None
|
| 104 |
-
for variant in variants:
|
| 105 |
-
for persona_id in persona_ids:
|
| 106 |
-
try:
|
| 107 |
-
vectors, _ = load_per_question_vectors(
|
| 108 |
-
root_dir=artifacts_root,
|
| 109 |
-
model_name=model_name,
|
| 110 |
-
prompt_variant=variant,
|
| 111 |
-
persona_id=persona_id,
|
| 112 |
-
)
|
| 113 |
-
except Exception:
|
| 114 |
-
logger.debug(
|
| 115 |
-
"Failed to load vectors for persona %s variant %s",
|
| 116 |
-
persona_id,
|
| 117 |
-
variant,
|
| 118 |
-
exc_info=True,
|
| 119 |
-
)
|
| 120 |
-
continue
|
| 121 |
-
|
| 122 |
-
layers = set(range(vectors.shape[1]))
|
| 123 |
-
if shared_layers is None:
|
| 124 |
-
shared_layers = layers
|
| 125 |
-
else:
|
| 126 |
-
shared_layers &= layers
|
| 127 |
-
|
| 128 |
-
return sorted(shared_layers or set())
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def load_cosine_traces(
|
| 132 |
-
artifacts_root: str | Path,
|
| 133 |
-
model_name: str,
|
| 134 |
-
persona_ids: list[str],
|
| 135 |
-
variant_a: str,
|
| 136 |
-
variant_b: str,
|
| 137 |
-
) -> tuple[list[tuple[str, torch.Tensor, torch.Tensor]], dict[str, str], list[str]]:
|
| 138 |
-
"""Load mean activation traces for pairwise cosine-similarity plots."""
|
| 139 |
-
|
| 140 |
-
persona_names = load_persona_names(
|
| 141 |
-
artifacts_root,
|
| 142 |
-
model_name,
|
| 143 |
-
[variant_a, variant_b],
|
| 144 |
-
persona_ids,
|
| 145 |
-
)
|
| 146 |
-
traces: list[tuple[str, torch.Tensor, torch.Tensor]] = []
|
| 147 |
-
errors: list[str] = []
|
| 148 |
-
|
| 149 |
-
for persona_id in persona_ids:
|
| 150 |
-
try:
|
| 151 |
-
vectors_a, _ = load_per_question_vectors(
|
| 152 |
-
root_dir=artifacts_root,
|
| 153 |
-
model_name=model_name,
|
| 154 |
-
prompt_variant=variant_a,
|
| 155 |
-
persona_id=persona_id,
|
| 156 |
-
)
|
| 157 |
-
vectors_b, _ = load_per_question_vectors(
|
| 158 |
-
root_dir=artifacts_root,
|
| 159 |
-
model_name=model_name,
|
| 160 |
-
prompt_variant=variant_b,
|
| 161 |
-
persona_id=persona_id,
|
| 162 |
-
)
|
| 163 |
-
except Exception as exc:
|
| 164 |
-
errors.append(f"{persona_id}: {exc}")
|
| 165 |
-
continue
|
| 166 |
-
|
| 167 |
-
traces.append(
|
| 168 |
-
(persona_id, vectors_a.float().mean(dim=0), vectors_b.float().mean(dim=0))
|
| 169 |
-
)
|
| 170 |
-
|
| 171 |
-
return traces, persona_names, errors
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
def load_embedding_samples(
|
| 175 |
-
artifacts_root: str | Path,
|
| 176 |
-
model_name: str,
|
| 177 |
-
persona_ids: list[str],
|
| 178 |
-
variant: str,
|
| 179 |
-
selected_layers: list[int],
|
| 180 |
-
project_fn: Callable[[torch.Tensor], torch.Tensor],
|
| 181 |
-
persona_names: dict[str, str],
|
| 182 |
-
progress_fn: Callable[[int, int, int], None] | None = None,
|
| 183 |
-
) -> tuple[list[tuple[int, torch.Tensor, list[str], list[str]]], list[str]]:
|
| 184 |
-
"""Load samples for 2D projections without re-reading each layer from disk."""
|
| 185 |
-
|
| 186 |
-
plots: list[tuple[int, torch.Tensor, list[str], list[str]]] = []
|
| 187 |
-
errors: list[str] = []
|
| 188 |
-
vectors_by_persona: dict[str, torch.Tensor] = {}
|
| 189 |
-
|
| 190 |
-
for persona_id in persona_ids:
|
| 191 |
-
try:
|
| 192 |
-
vectors, _ = load_per_question_vectors(
|
| 193 |
-
root_dir=artifacts_root,
|
| 194 |
-
model_name=model_name,
|
| 195 |
-
prompt_variant=variant,
|
| 196 |
-
persona_id=persona_id,
|
| 197 |
-
)
|
| 198 |
-
except Exception as exc:
|
| 199 |
-
errors.append(f"{persona_id} / {variant}: {exc}")
|
| 200 |
-
continue
|
| 201 |
-
|
| 202 |
-
vectors_by_persona[persona_id] = vectors
|
| 203 |
-
|
| 204 |
-
total_layers = len(selected_layers)
|
| 205 |
-
for idx, layer_idx in enumerate(selected_layers, start=1):
|
| 206 |
-
samples: list[torch.Tensor] = []
|
| 207 |
-
labels: list[str] = []
|
| 208 |
-
hover_text: list[str] = []
|
| 209 |
-
|
| 210 |
-
for persona_id, vectors in vectors_by_persona.items():
|
| 211 |
-
if layer_idx >= vectors.shape[1]:
|
| 212 |
-
errors.append(f"{persona_id} / {variant}: missing layer {layer_idx}")
|
| 213 |
-
continue
|
| 214 |
-
|
| 215 |
-
layer_vectors = vectors[:, layer_idx, :]
|
| 216 |
-
samples.append(layer_vectors)
|
| 217 |
-
labels.extend([persona_id] * layer_vectors.shape[0])
|
| 218 |
-
display_name = persona_names.get(persona_id) or persona_id
|
| 219 |
-
hover_text.extend(
|
| 220 |
-
[
|
| 221 |
-
f"<b>{display_name}</b><br>{variant}",
|
| 222 |
-
]
|
| 223 |
-
* layer_vectors.shape[0]
|
| 224 |
-
)
|
| 225 |
-
|
| 226 |
-
if not samples:
|
| 227 |
-
errors.append(f"Layer {layer_idx}: no selected personas have this layer")
|
| 228 |
-
else:
|
| 229 |
-
all_samples = torch.cat(samples, dim=0)
|
| 230 |
-
if all_samples.shape[0] < 2:
|
| 231 |
-
errors.append(
|
| 232 |
-
f"Layer {layer_idx}: need at least 2 samples after filtering selected personas"
|
| 233 |
-
)
|
| 234 |
-
else:
|
| 235 |
-
try:
|
| 236 |
-
coords = project_fn(all_samples)
|
| 237 |
-
plots.append((layer_idx, coords, labels, hover_text))
|
| 238 |
-
except Exception as exc:
|
| 239 |
-
errors.append(f"Layer {layer_idx}: {exc}")
|
| 240 |
-
|
| 241 |
-
if progress_fn is not None:
|
| 242 |
-
progress_fn(idx, total_layers, len(plots))
|
| 243 |
-
|
| 244 |
-
return plots, errors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/chat.py
CHANGED
|
@@ -52,7 +52,6 @@ def resolve_system_prompt(
|
|
| 52 |
return format_biography_prompt(persona.biography_md)
|
| 53 |
if mode == "custom":
|
| 54 |
return format_empty_persona_prompt()
|
| 55 |
-
return ""
|
| 56 |
|
| 57 |
|
| 58 |
def _format_plain_messages(
|
|
|
|
| 52 |
return format_biography_prompt(persona.biography_md)
|
| 53 |
if mode == "custom":
|
| 54 |
return format_empty_persona_prompt()
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
def _format_plain_messages(
|
utils/chat_export.py
CHANGED
|
@@ -3,24 +3,23 @@ from datetime import datetime, timezone
|
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
from persona_data.environment import get_artifacts_dir
|
| 6 |
-
from persona_vectors.activation_io import model_dir_name
|
| 7 |
|
| 8 |
from utils.helpers import slugify
|
| 9 |
|
| 10 |
|
| 11 |
-
def
|
| 12 |
*,
|
| 13 |
model_name: str,
|
| 14 |
dataset_source: str,
|
| 15 |
persona_id: str,
|
| 16 |
persona_name: str | None,
|
| 17 |
-
panel_label: str | None,
|
| 18 |
prompt_mode: str,
|
| 19 |
system_prompt: str | None,
|
| 20 |
messages: list[dict[str, str]],
|
| 21 |
generation: dict[str, object],
|
| 22 |
-
|
| 23 |
-
|
|
|
|
| 24 |
|
| 25 |
Args:
|
| 26 |
model_name: Model identifier used for the chat.
|
|
@@ -28,14 +27,15 @@ def build_chat_export_payload(
|
|
| 28 |
persona_id: Selected persona id.
|
| 29 |
persona_name: Selected persona display name, if available.
|
| 30 |
prompt_mode: Active system prompt mode.
|
|
|
|
| 31 |
messages: Conversation messages without the system prompt.
|
| 32 |
generation: Generation settings used for the chat.
|
| 33 |
|
| 34 |
Returns:
|
| 35 |
-
|
| 36 |
"""
|
| 37 |
|
| 38 |
-
|
| 39 |
"model_name": model_name,
|
| 40 |
"dataset_source": dataset_source,
|
| 41 |
"persona": {
|
|
@@ -51,50 +51,10 @@ def build_chat_export_payload(
|
|
| 51 |
+ messages,
|
| 52 |
}
|
| 53 |
|
| 54 |
-
|
| 55 |
-
def save_chat_export(
|
| 56 |
-
*,
|
| 57 |
-
model_name: str,
|
| 58 |
-
dataset_source: str,
|
| 59 |
-
persona_id: str,
|
| 60 |
-
persona_name: str | None,
|
| 61 |
-
prompt_mode: str,
|
| 62 |
-
system_prompt: str | None,
|
| 63 |
-
messages: list[dict[str, str]],
|
| 64 |
-
generation: dict[str, object],
|
| 65 |
-
panel_label: str | None = None,
|
| 66 |
-
) -> Path:
|
| 67 |
-
"""Save the current chat session to ``artifacts/chats`` as JSON.
|
| 68 |
-
|
| 69 |
-
Args:
|
| 70 |
-
model_name: Model identifier used for the chat.
|
| 71 |
-
dataset_source: Human-readable dataset source label.
|
| 72 |
-
persona_id: Selected persona id.
|
| 73 |
-
persona_name: Selected persona display name, if available.
|
| 74 |
-
prompt_mode: Active system prompt mode.
|
| 75 |
-
system_prompt: Current system prompt text, if any.
|
| 76 |
-
messages: Conversation messages without the system prompt.
|
| 77 |
-
generation: Generation settings used for the chat.
|
| 78 |
-
|
| 79 |
-
Returns:
|
| 80 |
-
The path the export was written to.
|
| 81 |
-
"""
|
| 82 |
-
|
| 83 |
-
payload = build_chat_export_payload(
|
| 84 |
-
model_name=model_name,
|
| 85 |
-
dataset_source=dataset_source,
|
| 86 |
-
persona_id=persona_id,
|
| 87 |
-
persona_name=persona_name,
|
| 88 |
-
panel_label=panel_label,
|
| 89 |
-
prompt_mode=prompt_mode,
|
| 90 |
-
system_prompt=system_prompt,
|
| 91 |
-
messages=messages,
|
| 92 |
-
generation=generation,
|
| 93 |
-
)
|
| 94 |
export_dir = (
|
| 95 |
get_artifacts_dir()
|
| 96 |
/ "chats"
|
| 97 |
-
/
|
| 98 |
/ slugify(dataset_source)
|
| 99 |
/ slugify(persona_id)
|
| 100 |
)
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
from persona_data.environment import get_artifacts_dir
|
|
|
|
| 6 |
|
| 7 |
from utils.helpers import slugify
|
| 8 |
|
| 9 |
|
| 10 |
+
def save_chat_export(
|
| 11 |
*,
|
| 12 |
model_name: str,
|
| 13 |
dataset_source: str,
|
| 14 |
persona_id: str,
|
| 15 |
persona_name: str | None,
|
|
|
|
| 16 |
prompt_mode: str,
|
| 17 |
system_prompt: str | None,
|
| 18 |
messages: list[dict[str, str]],
|
| 19 |
generation: dict[str, object],
|
| 20 |
+
panel_label: str | None = None,
|
| 21 |
+
) -> Path:
|
| 22 |
+
"""Save the current chat session to ``artifacts/chats`` as JSON.
|
| 23 |
|
| 24 |
Args:
|
| 25 |
model_name: Model identifier used for the chat.
|
|
|
|
| 27 |
persona_id: Selected persona id.
|
| 28 |
persona_name: Selected persona display name, if available.
|
| 29 |
prompt_mode: Active system prompt mode.
|
| 30 |
+
system_prompt: Current system prompt text, if any.
|
| 31 |
messages: Conversation messages without the system prompt.
|
| 32 |
generation: Generation settings used for the chat.
|
| 33 |
|
| 34 |
Returns:
|
| 35 |
+
The path the export was written to.
|
| 36 |
"""
|
| 37 |
|
| 38 |
+
payload = {
|
| 39 |
"model_name": model_name,
|
| 40 |
"dataset_source": dataset_source,
|
| 41 |
"persona": {
|
|
|
|
| 51 |
+ messages,
|
| 52 |
}
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
export_dir = (
|
| 55 |
get_artifacts_dir()
|
| 56 |
/ "chats"
|
| 57 |
+
/ model_name.replace("/", "__")
|
| 58 |
/ slugify(dataset_source)
|
| 59 |
/ slugify(persona_id)
|
| 60 |
)
|
utils/datasets.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import atexit
|
|
|
|
| 2 |
import shutil
|
| 3 |
from pathlib import Path
|
| 4 |
from tempfile import mkdtemp
|
|
@@ -31,10 +32,13 @@ def _upload_cache_dir() -> Path:
|
|
| 31 |
def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
|
| 32 |
suffix = Path(uploaded_file.name).suffix or ".jsonl"
|
| 33 |
temp_path = _upload_cache_dir() / f"{stem}{suffix}"
|
|
|
|
| 34 |
data = uploaded_file.getvalue()
|
| 35 |
-
|
|
|
|
| 36 |
return temp_path
|
| 37 |
temp_path.write_bytes(data)
|
|
|
|
| 38 |
return temp_path
|
| 39 |
|
| 40 |
|
|
|
|
| 1 |
import atexit
|
| 2 |
+
import hashlib
|
| 3 |
import shutil
|
| 4 |
from pathlib import Path
|
| 5 |
from tempfile import mkdtemp
|
|
|
|
| 32 |
def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
|
| 33 |
suffix = Path(uploaded_file.name).suffix or ".jsonl"
|
| 34 |
temp_path = _upload_cache_dir() / f"{stem}{suffix}"
|
| 35 |
+
hash_path = temp_path.with_suffix(temp_path.suffix + ".sha256")
|
| 36 |
data = uploaded_file.getvalue()
|
| 37 |
+
digest = hashlib.sha256(data).hexdigest()
|
| 38 |
+
if temp_path.exists() and hash_path.exists() and hash_path.read_text() == digest:
|
| 39 |
return temp_path
|
| 40 |
temp_path.write_bytes(data)
|
| 41 |
+
hash_path.write_text(digest)
|
| 42 |
return temp_path
|
| 43 |
|
| 44 |
|
utils/helpers.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
|
|
|
|
|
|
| 1 |
from persona_data.synth_persona import PersonaData
|
| 2 |
-
from persona_vectors.
|
| 3 |
|
| 4 |
# Variant key -> human-readable label mapping
|
| 5 |
VARIANT_LABELS = {
|
|
@@ -18,25 +20,29 @@ MODE_LABELS = list(VARIANT_LABELS.values())
|
|
| 18 |
# Reverse lookup: label -> key
|
| 19 |
MODE_LABEL_TO_KEY = {v: k for k, v in VARIANT_LABELS.items()}
|
| 20 |
|
|
|
|
|
|
|
| 21 |
DATASET_SOURCES = ["HuggingFace: synth-persona", "Local JSONL upload"]
|
| 22 |
ANALYSIS_MODES = ["Cosine similarity", "PCA", "UMAP"]
|
| 23 |
|
| 24 |
-
ANALYSIS_LABELS = {
|
| 25 |
-
"PCA": ("PCA", "PC1", "PC2"),
|
| 26 |
-
"UMAP": ("UMAP", "UMAP 1", "UMAP 2"),
|
| 27 |
-
}
|
| 28 |
-
|
| 29 |
ANALYSIS_HELP_TEXT = {
|
| 30 |
"Cosine similarity": "Compare layer-wise alignment between variants.",
|
| 31 |
"PCA": "Project the selected layers into a global 2D view.",
|
| 32 |
"UMAP": "Project the selected layers into a local-neighborhood 2D view.",
|
| 33 |
}
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
def slugify(value: str) -> str:
|
| 37 |
-
"""Convert a string to a slug safe for filenames and URLs."""
|
| 38 |
|
| 39 |
-
|
|
|
|
| 40 |
|
| 41 |
return re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_") or "unknown"
|
| 42 |
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
from persona_data.synth_persona import PersonaData
|
| 4 |
+
from persona_vectors.artifacts import SUPPORTED_VARIANTS
|
| 5 |
|
| 6 |
# Variant key -> human-readable label mapping
|
| 7 |
VARIANT_LABELS = {
|
|
|
|
| 20 |
# Reverse lookup: label -> key
|
| 21 |
MODE_LABEL_TO_KEY = {v: k for k, v in VARIANT_LABELS.items()}
|
| 22 |
|
| 23 |
+
VISIBLE_MESSAGE_COUNT = 5
|
| 24 |
+
|
| 25 |
DATASET_SOURCES = ["HuggingFace: synth-persona", "Local JSONL upload"]
|
| 26 |
ANALYSIS_MODES = ["Cosine similarity", "PCA", "UMAP"]
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
ANALYSIS_HELP_TEXT = {
|
| 29 |
"Cosine similarity": "Compare layer-wise alignment between variants.",
|
| 30 |
"PCA": "Project the selected layers into a global 2D view.",
|
| 31 |
"UMAP": "Project the selected layers into a local-neighborhood 2D view.",
|
| 32 |
}
|
| 33 |
|
| 34 |
+
NDIF_STATUS_ICONS = {
|
| 35 |
+
"RECEIVED": "◉",
|
| 36 |
+
"QUEUED": "◎",
|
| 37 |
+
"DISPATCHED": "◈",
|
| 38 |
+
"RUNNING": "●",
|
| 39 |
+
"COMPLETED": "✓",
|
| 40 |
+
"ERROR": "✗",
|
| 41 |
+
}
|
| 42 |
|
|
|
|
|
|
|
| 43 |
|
| 44 |
+
def slugify(value: str) -> str:
|
| 45 |
+
"""Convert a string to a filesystem-safe slug."""
|
| 46 |
|
| 47 |
return re.sub(r"[^a-z0-9]+", "_", value.lower()).strip("_") or "unknown"
|
| 48 |
|