Spaces:
Sleeping
Sleeping
Jac-Zac commited on
Commit ·
b279884
1
Parent(s): 9edffb7
Big refactoring
Browse files- Speed gains
- Improved dendogram figures
- Better information while chatting with models or loading datasets
- Faster overall ui
- Probin UI imrpovements
- Default values changed for better user experiennce
- Code structure refactoring
- .env.example +5 -0
- README.md +3 -1
- app.py +32 -17
- pyproject.toml +1 -1
- state.py +12 -2
- tabs/analysis/_shared.py +90 -10
- tabs/analysis/_state.py +29 -15
- tabs/analysis/cosine.py +2 -3
- tabs/analysis/dendrogram.py +136 -55
- tabs/analysis/layered.py +17 -17
- tabs/analysis_core.py +23 -168
- tabs/chat.py +24 -1
- tabs/chat_shared.py +19 -0
- tabs/chat_ui.py +1 -0
- tabs/compare_chat.py +26 -1
- tabs/extract.py +2 -3
- tabs/probe.py +72 -197
- tabs/probe_sweep.py +94 -0
- tabs/probe_ui.py +58 -38
- tests/test_datasets.py +129 -0
- tests/test_probe_cache_bounds.py +80 -0
- tests/test_probe_sweep.py +95 -0
- tests/test_probes.py +3 -6
- tests/test_state.py +16 -0
- utils/analysis_sources.py +1 -1
- utils/chat.py +41 -1
- utils/contrast.py +1 -3
- utils/controls.py +7 -1
- utils/datasets.py +85 -3
- utils/helpers.py +20 -0
- utils/probe_files.py +162 -0
- utils/probe_overlay.py +3 -8
- utils/probe_trace.py +28 -9
- utils/probes.py +18 -167
- utils/selection_controls.py +35 -0
- utils/source_controls.py +230 -0
- uv.lock +7 -7
.env.example
CHANGED
|
@@ -25,3 +25,8 @@ ARTIFACTS_DIR=artifacts
|
|
| 25 |
# PERSONA_UI_STORE_CACHE_ENTRIES=4
|
| 26 |
# PERSONA_UI_VECTOR_CACHE_ENTRIES=4
|
| 27 |
# PERSONA_UI_PREPARED_CACHE_ENTRIES=8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# PERSONA_UI_STORE_CACHE_ENTRIES=4
|
| 26 |
# PERSONA_UI_VECTOR_CACHE_ENTRIES=4
|
| 27 |
# PERSONA_UI_PREPARED_CACHE_ENTRIES=8
|
| 28 |
+
# PERSONA_UI_FIGURE_STATE_ENTRIES=2
|
| 29 |
+
# PERSONA_UI_PREPARED_STATE_ENTRIES=4
|
| 30 |
+
# PERSONA_UI_PROBE_CACHE_ENTRIES=8
|
| 31 |
+
# PERSONA_UI_PROBE_SWEEP_CACHE_ENTRIES=4
|
| 32 |
+
# PERSONA_UI_PROBE_DERIVED_CACHE_ENTRIES=12
|
README.md
CHANGED
|
@@ -118,6 +118,8 @@ ARTIFACTS_DIR=... # Optional: where persona vectors are read from (default:
|
|
| 118 |
PERSONA_VECTORS_HUB_REPO=... # Optional: default Analysis/Probing Hub dataset repo
|
| 119 |
PERSONA_UI_VECTOR_CACHE_ENTRIES=4 # Optional: loaded analysis datasets kept warm
|
| 120 |
PERSONA_UI_PREPARED_CACHE_ENTRIES=8 # Optional: prepared projections / k-means groups kept warm
|
|
|
|
|
|
|
| 121 |
```
|
| 122 |
|
| 123 |
The app picks up this file automatically via `load_dotenv()` on startup.
|
|
@@ -153,4 +155,4 @@ The store classes are `PersonaVectorStore` (local) and `HFPersonaVectorStore`
|
|
| 153 |
|
| 154 |
## Analysis responsiveness
|
| 155 |
|
| 156 |
-
The Analysis tab keeps
|
|
|
|
| 118 |
PERSONA_VECTORS_HUB_REPO=... # Optional: default Analysis/Probing Hub dataset repo
|
| 119 |
PERSONA_UI_VECTOR_CACHE_ENTRIES=4 # Optional: loaded analysis datasets kept warm
|
| 120 |
PERSONA_UI_PREPARED_CACHE_ENTRIES=8 # Optional: prepared projections / k-means groups kept warm
|
| 121 |
+
PERSONA_UI_FIGURE_STATE_ENTRIES=2 # Optional: recent rendered Analysis figures kept in-session
|
| 122 |
+
PERSONA_UI_PREPARED_STATE_ENTRIES=4 # Optional: recent projection-ready markers kept in-session
|
| 123 |
```
|
| 124 |
|
| 125 |
The app picks up this file automatically via `load_dotenv()` on startup.
|
|
|
|
| 155 |
|
| 156 |
## Analysis responsiveness
|
| 157 |
|
| 158 |
+
The Analysis tab keeps small bounded caches of loaded vector datasets, prepared projection data, and a tiny MRU window of rendered figures. Once a projection has been computed, recoloring it by persona, attribute, or k-means group reuses the same coordinates; nearby method switches can reuse the last couple of figures instead of rebuilding immediately, while the caps keep RAM bounded. Tune `PERSONA_UI_VECTOR_CACHE_ENTRIES` if RAM is tight or you regularly switch among many selections, `PERSONA_UI_PREPARED_CACHE_ENTRIES` if you revisit several projection configurations in one session, and `PERSONA_UI_FIGURE_STATE_ENTRIES` if you want more or less method-switch warmth. Probe loading, probe sweeps, and per-trace probe outputs are bounded separately via `PERSONA_UI_PROBE_CACHE_ENTRIES`, `PERSONA_UI_PROBE_SWEEP_CACHE_ENTRIES`, and `PERSONA_UI_PROBE_DERIVED_CACHE_ENTRIES`; the derived-output cache defaults to a wider MRU window because those tensors are small compared with traced activations and are cheap wins to keep warm.
|
app.py
CHANGED
|
@@ -4,11 +4,7 @@ from dataclasses import dataclass
|
|
| 4 |
import streamlit as st
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
|
| 7 |
-
from utils.analysis_sources import
|
| 8 |
-
DEFAULT_COMPARE_MODEL,
|
| 9 |
-
DEFAULT_HUB_REPO,
|
| 10 |
-
SOURCE_HUB,
|
| 11 |
-
)
|
| 12 |
from utils.helpers import DATASET_SOURCES, session_key, widget_key
|
| 13 |
from utils.preload import preload_once
|
| 14 |
from utils.runtime import list_remote_models
|
|
@@ -60,21 +56,34 @@ def _hub_metadata_preload_calls() -> tuple[
|
|
| 60 |
calls: list[tuple[str, tuple[str, str, str, str | None]]] = []
|
| 61 |
|
| 62 |
def add(repo: str, model: str, mask_strategy: str, variant: str | None) -> None:
|
| 63 |
-
calls.append(
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
-
analysis_source = st.session_state.get("analysis:last_source",
|
| 69 |
if analysis_source == SOURCE_HUB:
|
| 70 |
-
repo = st.session_state.get(
|
|
|
|
|
|
|
|
|
|
| 71 |
mask_strategy = st.session_state.get(
|
| 72 |
"analysis:last_mask_strategy",
|
| 73 |
-
|
| 74 |
)
|
| 75 |
model = st.session_state.get(
|
| 76 |
widget_key("load", "hub_model", repo, mask_strategy),
|
| 77 |
-
st.session_state.get(
|
|
|
|
|
|
|
|
|
|
| 78 |
)
|
| 79 |
variant = st.session_state.get(
|
| 80 |
"analysis:last_projection_variant",
|
|
@@ -82,16 +91,22 @@ def _hub_metadata_preload_calls() -> tuple[
|
|
| 82 |
)
|
| 83 |
add(repo, model, mask_strategy, variant)
|
| 84 |
|
| 85 |
-
probe_source = st.session_state.get(widget_key("probe", "source"),
|
| 86 |
if probe_source == SOURCE_HUB:
|
| 87 |
-
repo = st.session_state.get(
|
|
|
|
|
|
|
|
|
|
| 88 |
mask_strategy = st.session_state.get(
|
| 89 |
"probe:last_mask_strategy",
|
| 90 |
-
|
| 91 |
)
|
| 92 |
model = st.session_state.get(
|
| 93 |
widget_key("probe", "hub_model", repo, mask_strategy),
|
| 94 |
-
st.session_state.get(
|
|
|
|
|
|
|
|
|
|
| 95 |
)
|
| 96 |
add(repo, model, mask_strategy, st.session_state.get("probe:variant"))
|
| 97 |
|
|
|
|
| 4 |
import streamlit as st
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
|
| 7 |
+
from utils.analysis_sources import DEFAULT_COMPARE_MODEL, DEFAULT_HUB_REPO, SOURCE_HUB
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
from utils.helpers import DATASET_SOURCES, session_key, widget_key
|
| 9 |
from utils.preload import preload_once
|
| 10 |
from utils.runtime import list_remote_models
|
|
|
|
| 56 |
calls: list[tuple[str, tuple[str, str, str, str | None]]] = []
|
| 57 |
|
| 58 |
def add(repo: str, model: str, mask_strategy: str, variant: str | None) -> None:
|
| 59 |
+
calls.append(
|
| 60 |
+
(
|
| 61 |
+
"utils.analysis_sources:prefetch_hub_metadata",
|
| 62 |
+
(repo, model, mask_strategy, variant),
|
| 63 |
+
)
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
shared_source = st.session_state.get("source:last_source", SOURCE_HUB)
|
| 67 |
+
shared_mask_strategy = st.session_state.get(
|
| 68 |
+
"source:last_mask_strategy", "answer_mean"
|
| 69 |
+
)
|
| 70 |
|
| 71 |
+
analysis_source = st.session_state.get("analysis:last_source", shared_source)
|
| 72 |
if analysis_source == SOURCE_HUB:
|
| 73 |
+
repo = st.session_state.get(
|
| 74 |
+
"analysis:hub_repo",
|
| 75 |
+
st.session_state.get("source:hub_repo", DEFAULT_HUB_REPO),
|
| 76 |
+
)
|
| 77 |
mask_strategy = st.session_state.get(
|
| 78 |
"analysis:last_mask_strategy",
|
| 79 |
+
shared_mask_strategy,
|
| 80 |
)
|
| 81 |
model = st.session_state.get(
|
| 82 |
widget_key("load", "hub_model", repo, mask_strategy),
|
| 83 |
+
st.session_state.get(
|
| 84 |
+
"analysis:hub_model_fallback",
|
| 85 |
+
st.session_state.get("source:hub_model", DEFAULT_COMPARE_MODEL),
|
| 86 |
+
),
|
| 87 |
)
|
| 88 |
variant = st.session_state.get(
|
| 89 |
"analysis:last_projection_variant",
|
|
|
|
| 91 |
)
|
| 92 |
add(repo, model, mask_strategy, variant)
|
| 93 |
|
| 94 |
+
probe_source = st.session_state.get(widget_key("probe", "source"), shared_source)
|
| 95 |
if probe_source == SOURCE_HUB:
|
| 96 |
+
repo = st.session_state.get(
|
| 97 |
+
"probe:hub_repo",
|
| 98 |
+
st.session_state.get("source:hub_repo", DEFAULT_HUB_REPO),
|
| 99 |
+
)
|
| 100 |
mask_strategy = st.session_state.get(
|
| 101 |
"probe:last_mask_strategy",
|
| 102 |
+
shared_mask_strategy,
|
| 103 |
)
|
| 104 |
model = st.session_state.get(
|
| 105 |
widget_key("probe", "hub_model", repo, mask_strategy),
|
| 106 |
+
st.session_state.get(
|
| 107 |
+
"probe:hub_model_fallback",
|
| 108 |
+
st.session_state.get("source:hub_model", DEFAULT_COMPARE_MODEL),
|
| 109 |
+
),
|
| 110 |
)
|
| 111 |
add(repo, model, mask_strategy, st.session_state.get("probe:variant"))
|
| 112 |
|
pyproject.toml
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
[project]
|
| 2 |
name = "persona-ui"
|
| 3 |
-
version = "0.
|
| 4 |
description = "Streamlit UI for persona-vectors"
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
|
|
|
| 1 |
[project]
|
| 2 |
name = "persona-ui"
|
| 3 |
+
version = "0.5.0"
|
| 4 |
description = "Streamlit UI for persona-vectors"
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
state.py
CHANGED
|
@@ -21,9 +21,19 @@ class ChatState(TypedDict):
|
|
| 21 |
|
| 22 |
|
| 23 |
def chat_session_key(model_name: str, dataset_source: str) -> str:
|
| 24 |
-
"""Build the session-state key for a chat
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
|
| 29 |
def default_chat_state() -> ChatState:
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def chat_session_key(model_name: str, dataset_source: str) -> str:
|
| 24 |
+
"""Build the session-state key for a chat conversation.
|
| 25 |
|
| 26 |
+
A model/backend switch changes *how* the next turn is generated, not which
|
| 27 |
+
conversation the user is looking at. Keeping the model out of the key means
|
| 28 |
+
toggling local/remote execution (or selecting another model) no longer makes
|
| 29 |
+
an existing thread appear to vanish behind a fresh empty state.
|
| 30 |
+
|
| 31 |
+
``model_name`` stays in the signature for call-site compatibility and to
|
| 32 |
+
make the intent explicit where chat state is requested.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
_ = model_name
|
| 36 |
+
return session_key("chat_state", dataset_source)
|
| 37 |
|
| 38 |
|
| 39 |
def default_chat_state() -> ChatState:
|
tabs/analysis/_shared.py
CHANGED
|
@@ -261,6 +261,7 @@ def _render_persona_count_controls(
|
|
| 261 |
*,
|
| 262 |
default_count: int,
|
| 263 |
include_assistant_default: bool,
|
|
|
|
| 264 |
) -> tuple[int, bool]:
|
| 265 |
count_key = widget_key(
|
| 266 |
"load",
|
|
@@ -280,11 +281,16 @@ def _render_persona_count_controls(
|
|
| 280 |
)
|
| 281 |
|
| 282 |
if options.regular_ids:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
persona_count = st.slider(
|
| 284 |
"Personas",
|
| 285 |
min_value=0 if options.assistant_id is not None else 1,
|
| 286 |
-
max_value=
|
| 287 |
-
value=default_count,
|
| 288 |
key=count_key,
|
| 289 |
help="Use the first N available non-assistant personas.",
|
| 290 |
)
|
|
@@ -310,6 +316,7 @@ def _select_artifact_personas(
|
|
| 310 |
remember_key: str,
|
| 311 |
default_all: bool = False,
|
| 312 |
default_count_limit: int | None = None,
|
|
|
|
| 313 |
) -> list[str]:
|
| 314 |
empty_message = _personas_empty_message(variants)
|
| 315 |
options = _load_persona_options(
|
|
@@ -336,6 +343,7 @@ def _select_artifact_personas(
|
|
| 336 |
options,
|
| 337 |
default_count=default_count,
|
| 338 |
include_assistant_default=include_assistant_default,
|
|
|
|
| 339 |
)
|
| 340 |
|
| 341 |
persona_ids = options.regular_ids[:persona_count]
|
|
@@ -361,6 +369,48 @@ def _select_artifact_personas(
|
|
| 361 |
return persona_ids
|
| 362 |
|
| 363 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
def _render_save_buttons(
|
| 365 |
figs: list[object],
|
| 366 |
filenames: list[str],
|
|
@@ -398,6 +448,7 @@ def _render_mask_strategy_select(scope: str) -> MaskStrategy:
|
|
| 398 |
return render_mask_strategy_select(
|
| 399 |
key=widget_key("load", "mask_strategy", scope),
|
| 400 |
last_key=_LAST_MASK_STRATEGY_KEY,
|
|
|
|
| 401 |
help_text="Which extracted activation set to load.",
|
| 402 |
)
|
| 403 |
|
|
@@ -410,6 +461,8 @@ def _select_single_variant_samples(
|
|
| 410 |
remember_key: str,
|
| 411 |
variant_remember_key: str,
|
| 412 |
default_count_limit: int,
|
|
|
|
|
|
|
| 413 |
) -> tuple[str, list[str], str, list[int]] | None:
|
| 414 |
variants = available_variants(store, mask_strategy)
|
| 415 |
if not variants:
|
|
@@ -425,14 +478,41 @@ def _select_single_variant_samples(
|
|
| 425 |
default=default_variant,
|
| 426 |
format_func=prompt_variant_label,
|
| 427 |
)
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
if not persona_ids:
|
| 437 |
return None
|
| 438 |
|
|
|
|
| 261 |
*,
|
| 262 |
default_count: int,
|
| 263 |
include_assistant_default: bool,
|
| 264 |
+
max_count_limit: int | None = None,
|
| 265 |
) -> tuple[int, bool]:
|
| 266 |
count_key = widget_key(
|
| 267 |
"load",
|
|
|
|
| 281 |
)
|
| 282 |
|
| 283 |
if options.regular_ids:
|
| 284 |
+
max_count = (
|
| 285 |
+
min(max_count_limit, len(options.regular_ids))
|
| 286 |
+
if max_count_limit is not None
|
| 287 |
+
else len(options.regular_ids)
|
| 288 |
+
)
|
| 289 |
persona_count = st.slider(
|
| 290 |
"Personas",
|
| 291 |
min_value=0 if options.assistant_id is not None else 1,
|
| 292 |
+
max_value=max_count,
|
| 293 |
+
value=min(default_count, max_count),
|
| 294 |
key=count_key,
|
| 295 |
help="Use the first N available non-assistant personas.",
|
| 296 |
)
|
|
|
|
| 316 |
remember_key: str,
|
| 317 |
default_all: bool = False,
|
| 318 |
default_count_limit: int | None = None,
|
| 319 |
+
max_count_limit: int | None = None,
|
| 320 |
) -> list[str]:
|
| 321 |
empty_message = _personas_empty_message(variants)
|
| 322 |
options = _load_persona_options(
|
|
|
|
| 343 |
options,
|
| 344 |
default_count=default_count,
|
| 345 |
include_assistant_default=include_assistant_default,
|
| 346 |
+
max_count_limit=max_count_limit,
|
| 347 |
)
|
| 348 |
|
| 349 |
persona_ids = options.regular_ids[:persona_count]
|
|
|
|
| 369 |
return persona_ids
|
| 370 |
|
| 371 |
|
| 372 |
+
def _render_persona_select_controls(
|
| 373 |
+
options: PersonaOptions,
|
| 374 |
+
widget_scope: str,
|
| 375 |
+
*,
|
| 376 |
+
max_selections: int | None = None,
|
| 377 |
+
) -> list[str]:
|
| 378 |
+
select_key = widget_key("load", "persona_select", widget_scope)
|
| 379 |
+
assistant_key = widget_key("load", "persona_select_assistant", widget_scope)
|
| 380 |
+
|
| 381 |
+
label_map = {
|
| 382 |
+
persona_id: f"{options.persona_names.get(persona_id, persona_id)} ({persona_id})"
|
| 383 |
+
for persona_id in options.regular_ids
|
| 384 |
+
}
|
| 385 |
+
sorted_labels = sorted(label_map.values())
|
| 386 |
+
selected_labels = st.multiselect(
|
| 387 |
+
"Select personas",
|
| 388 |
+
options=sorted_labels,
|
| 389 |
+
key=select_key,
|
| 390 |
+
placeholder="Search and select personas...",
|
| 391 |
+
max_selections=max_selections,
|
| 392 |
+
)
|
| 393 |
+
label_to_id = {label: persona_id for persona_id, label in label_map.items()}
|
| 394 |
+
selected_ids = [label_to_id[label] for label in selected_labels]
|
| 395 |
+
|
| 396 |
+
if options.assistant_id is not None:
|
| 397 |
+
include_assistant = st.checkbox(
|
| 398 |
+
"Include Assistant persona",
|
| 399 |
+
key=assistant_key,
|
| 400 |
+
)
|
| 401 |
+
if include_assistant:
|
| 402 |
+
selected_ids.append(options.assistant_id)
|
| 403 |
+
|
| 404 |
+
st.session_state[_persona_names_state_key(widget_scope)] = dict(
|
| 405 |
+
options.persona_names
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
if not selected_ids:
|
| 409 |
+
st.info("Select at least one persona.")
|
| 410 |
+
|
| 411 |
+
return selected_ids
|
| 412 |
+
|
| 413 |
+
|
| 414 |
def _render_save_buttons(
|
| 415 |
figs: list[object],
|
| 416 |
filenames: list[str],
|
|
|
|
| 448 |
return render_mask_strategy_select(
|
| 449 |
key=widget_key("load", "mask_strategy", scope),
|
| 450 |
last_key=_LAST_MASK_STRATEGY_KEY,
|
| 451 |
+
remember_key="source:last_mask_strategy",
|
| 452 |
help_text="Which extracted activation set to load.",
|
| 453 |
)
|
| 454 |
|
|
|
|
| 461 |
remember_key: str,
|
| 462 |
variant_remember_key: str,
|
| 463 |
default_count_limit: int,
|
| 464 |
+
max_count_limit: int | None = None,
|
| 465 |
+
allow_specific_personas: bool = False,
|
| 466 |
) -> tuple[str, list[str], str, list[int]] | None:
|
| 467 |
variants = available_variants(store, mask_strategy)
|
| 468 |
if not variants:
|
|
|
|
| 478 |
default=default_variant,
|
| 479 |
format_func=prompt_variant_label,
|
| 480 |
)
|
| 481 |
+
widget_scope = f"{scope}:{store_id(store)}"
|
| 482 |
+
select_specific = False
|
| 483 |
+
if allow_specific_personas:
|
| 484 |
+
select_specific = st.toggle(
|
| 485 |
+
"Select specific personas",
|
| 486 |
+
value=False,
|
| 487 |
+
key=widget_key("load", "select_specific_personas", scope, store_id(store)),
|
| 488 |
+
help="Search and select specific personas instead of using the first N.",
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if select_specific:
|
| 492 |
+
options = _load_persona_options(
|
| 493 |
+
store,
|
| 494 |
+
[variant],
|
| 495 |
+
mask_strategy,
|
| 496 |
+
empty_message=_personas_empty_message([variant]),
|
| 497 |
+
)
|
| 498 |
+
if options is None:
|
| 499 |
+
st.session_state.pop(_persona_names_state_key(widget_scope), None)
|
| 500 |
+
return None
|
| 501 |
+
persona_ids = _render_persona_select_controls(
|
| 502 |
+
options,
|
| 503 |
+
widget_scope,
|
| 504 |
+
max_selections=max_count_limit,
|
| 505 |
+
)
|
| 506 |
+
else:
|
| 507 |
+
persona_ids = _select_artifact_personas(
|
| 508 |
+
store,
|
| 509 |
+
[variant],
|
| 510 |
+
mask_strategy,
|
| 511 |
+
widget_scope=widget_scope,
|
| 512 |
+
remember_key=remember_key,
|
| 513 |
+
default_count_limit=default_count_limit,
|
| 514 |
+
max_count_limit=max_count_limit,
|
| 515 |
+
)
|
| 516 |
if not persona_ids:
|
| 517 |
return None
|
| 518 |
|
tabs/analysis/_state.py
CHANGED
|
@@ -4,7 +4,7 @@ import streamlit as st
|
|
| 4 |
from persona_data.synth_persona import BASELINE_PERSONA_ID
|
| 5 |
from persona_vectors.attributes import DEFAULT_MAX_ATTRIBUTE_CATEGORIES
|
| 6 |
|
| 7 |
-
from utils.helpers import slugify, widget_key
|
| 8 |
|
| 9 |
|
| 10 |
def _filename(*parts: str) -> str:
|
|
@@ -30,11 +30,15 @@ _LAST_LAYER_FRAMES_KEY = "analysis:last_layer_frames"
|
|
| 30 |
|
| 31 |
_DEFAULT_LAYER_FRAMES = 16
|
| 32 |
_DEFAULT_PERSONA_LIMITS = {
|
| 33 |
-
"similarity":
|
| 34 |
"pca": 500,
|
| 35 |
"umap": 500,
|
| 36 |
"isomap": 500,
|
| 37 |
-
"dendro":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
}
|
| 39 |
_MAX_SIMILARITY_CELLS = 4_000_000
|
| 40 |
_MAX_PAIR_TRAJECTORY_TRACES = 500
|
|
@@ -136,28 +140,38 @@ def _sequence_to_list(value: object) -> list[object] | None:
|
|
| 136 |
|
| 137 |
|
| 138 |
_TRACKED_STATE_KEYS_KEY = "analysis:_tracked_state_keys"
|
|
|
|
|
|
|
| 139 |
|
| 140 |
|
| 141 |
-
def
|
| 142 |
-
#
|
| 143 |
-
#
|
| 144 |
-
#
|
| 145 |
-
|
| 146 |
-
tracked: dict[str, set[str]] = st.session_state.setdefault(
|
| 147 |
_TRACKED_STATE_KEYS_KEY, {}
|
| 148 |
)
|
| 149 |
-
for key in tracked.get(suffix,
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
| 153 |
|
| 154 |
|
| 155 |
def _clear_old_figure_states(current_key: str) -> None:
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
|
| 159 |
def _clear_old_prepared_states(current_key: str) -> None:
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
|
| 163 |
def _store_figure_state(key: str, value: object) -> None:
|
|
|
|
| 4 |
from persona_data.synth_persona import BASELINE_PERSONA_ID
|
| 5 |
from persona_vectors.attributes import DEFAULT_MAX_ATTRIBUTE_CATEGORIES
|
| 6 |
|
| 7 |
+
from utils.helpers import env_int, slugify, widget_key
|
| 8 |
|
| 9 |
|
| 10 |
def _filename(*parts: str) -> str:
|
|
|
|
| 30 |
|
| 31 |
_DEFAULT_LAYER_FRAMES = 16
|
| 32 |
_DEFAULT_PERSONA_LIMITS = {
|
| 33 |
+
"similarity": 20,
|
| 34 |
"pca": 500,
|
| 35 |
"umap": 500,
|
| 36 |
"isomap": 500,
|
| 37 |
+
"dendro": 20,
|
| 38 |
+
}
|
| 39 |
+
_MAX_PERSONA_COUNTS = {
|
| 40 |
+
"similarity": 100,
|
| 41 |
+
"dendro": 100,
|
| 42 |
}
|
| 43 |
_MAX_SIMILARITY_CELLS = 4_000_000
|
| 44 |
_MAX_PAIR_TRAJECTORY_TRACES = 500
|
|
|
|
| 140 |
|
| 141 |
|
| 142 |
_TRACKED_STATE_KEYS_KEY = "analysis:_tracked_state_keys"
|
| 143 |
+
_FIGURE_STATE_ENTRIES = env_int("PERSONA_UI_FIGURE_STATE_ENTRIES", 2)
|
| 144 |
+
_PREPARED_STATE_ENTRIES = env_int("PERSONA_UI_PREPARED_STATE_ENTRIES", 4)
|
| 145 |
|
| 146 |
|
| 147 |
+
def _touch_load_state(current_key: str, suffix: str, *, max_entries: int) -> None:
|
| 148 |
+
# Keep a tiny MRU window of heavy state instead of scanning all of
|
| 149 |
+
# session_state or retaining every figure forever. This makes nearby
|
| 150 |
+
# method-switching feel warm while still giving RAM a hard ceiling.
|
| 151 |
+
tracked: dict[str, list[str]] = st.session_state.setdefault(
|
|
|
|
| 152 |
_TRACKED_STATE_KEYS_KEY, {}
|
| 153 |
)
|
| 154 |
+
keys = [key for key in tracked.get(suffix, []) if key != current_key]
|
| 155 |
+
keys.append(current_key)
|
| 156 |
+
while len(keys) > max(1, max_entries):
|
| 157 |
+
st.session_state.pop(keys.pop(0), None)
|
| 158 |
+
tracked[suffix] = keys
|
| 159 |
|
| 160 |
|
| 161 |
def _clear_old_figure_states(current_key: str) -> None:
|
| 162 |
+
_touch_load_state(
|
| 163 |
+
current_key,
|
| 164 |
+
"_fig_state",
|
| 165 |
+
max_entries=_FIGURE_STATE_ENTRIES,
|
| 166 |
+
)
|
| 167 |
|
| 168 |
|
| 169 |
def _clear_old_prepared_states(current_key: str) -> None:
|
| 170 |
+
_touch_load_state(
|
| 171 |
+
current_key,
|
| 172 |
+
"_projection_ready",
|
| 173 |
+
max_entries=_PREPARED_STATE_ENTRIES,
|
| 174 |
+
)
|
| 175 |
|
| 176 |
|
| 177 |
def _store_figure_state(key: str, value: object) -> None:
|
tabs/analysis/cosine.py
CHANGED
|
@@ -4,9 +4,6 @@ import streamlit as st
|
|
| 4 |
from persona_vectors.extraction import MaskStrategy
|
| 5 |
from persona_vectors.plots import plot_layer_similarity
|
| 6 |
|
| 7 |
-
from utils.analysis_sources import Store, available_variants, store_id
|
| 8 |
-
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
| 9 |
-
|
| 10 |
from tabs.analysis._shared import (
|
| 11 |
_load_variant_vectors,
|
| 12 |
_plotly_chart,
|
|
@@ -21,6 +18,8 @@ from tabs.analysis._state import (
|
|
| 21 |
_filename,
|
| 22 |
_store_figure_state,
|
| 23 |
)
|
|
|
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def _render_cosine_selection(
|
|
|
|
| 4 |
from persona_vectors.extraction import MaskStrategy
|
| 5 |
from persona_vectors.plots import plot_layer_similarity
|
| 6 |
|
|
|
|
|
|
|
|
|
|
| 7 |
from tabs.analysis._shared import (
|
| 8 |
_load_variant_vectors,
|
| 9 |
_plotly_chart,
|
|
|
|
| 18 |
_filename,
|
| 19 |
_store_figure_state,
|
| 20 |
)
|
| 21 |
+
from utils.analysis_sources import Store, available_variants, store_id
|
| 22 |
+
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
| 23 |
|
| 24 |
|
| 25 |
def _render_cosine_selection(
|
tabs/analysis/dendrogram.py
CHANGED
|
@@ -1,15 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from persona_vectors.extraction import MaskStrategy
|
| 3 |
from persona_vectors.plots import plot_persona_dendrogram
|
| 4 |
-
|
| 5 |
-
from utils.analysis_sources import (
|
| 6 |
-
Store,
|
| 7 |
-
available_variants,
|
| 8 |
-
store_cache_parts,
|
| 9 |
-
store_id,
|
| 10 |
-
store_layers_cached,
|
| 11 |
-
)
|
| 12 |
-
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
| 13 |
|
| 14 |
from tabs.analysis._shared import (
|
| 15 |
_load_persona_options,
|
|
@@ -17,60 +12,113 @@ from tabs.analysis._shared import (
|
|
| 17 |
_plotly_chart,
|
| 18 |
_release_vector_memory,
|
| 19 |
_render_layer_frame_controls,
|
|
|
|
| 20 |
_render_save_buttons,
|
| 21 |
_select_artifact_personas,
|
| 22 |
)
|
| 23 |
from tabs.analysis._state import (
|
| 24 |
_DEFAULT_PERSONA_LIMITS,
|
| 25 |
-
|
| 26 |
_clear_old_figure_states,
|
| 27 |
_filename,
|
| 28 |
_persona_names_state_key,
|
| 29 |
_personas_empty_message,
|
| 30 |
_store_figure_state,
|
| 31 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
_LAST_DENDRO_PERSONAS_KEY = "analysis:last_personas:dendro"
|
| 34 |
_DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"]
|
| 35 |
|
| 36 |
|
| 37 |
-
def
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
options=sorted_labels,
|
| 52 |
-
key=select_key,
|
| 53 |
-
placeholder="Search and select personas...",
|
| 54 |
)
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
)
|
| 63 |
-
if include_assistant:
|
| 64 |
-
selected_ids.append(options.assistant_id)
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
)
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
def _render_dendrogram_analysis(
|
|
@@ -132,6 +180,7 @@ def _render_dendrogram_analysis(
|
|
| 132 |
persona_ids = _render_persona_select_controls(
|
| 133 |
options,
|
| 134 |
widget_scope=f"dendro:{store_id(store)}",
|
|
|
|
| 135 |
)
|
| 136 |
if not persona_ids:
|
| 137 |
return
|
|
@@ -143,6 +192,7 @@ def _render_dendrogram_analysis(
|
|
| 143 |
widget_scope=f"dendro:{store_id(store)}",
|
| 144 |
remember_key=_LAST_DENDRO_PERSONAS_KEY,
|
| 145 |
default_count_limit=_DEFAULT_PERSONA_LIMITS["dendro"],
|
|
|
|
| 146 |
)
|
| 147 |
if not persona_ids:
|
| 148 |
return
|
|
@@ -221,7 +271,6 @@ def _render_dendrogram_analysis(
|
|
| 221 |
title=f"Dendrogram — {prompt_variant_label(variant_a)}",
|
| 222 |
)
|
| 223 |
fig_a.update_layout(height=750)
|
| 224 |
-
del samples_a
|
| 225 |
fig_b = None
|
| 226 |
if variant_a != variant_b:
|
| 227 |
progress.progress(60, text="Building second dendrogram…")
|
|
@@ -236,10 +285,26 @@ def _render_dendrogram_analysis(
|
|
| 236 |
)
|
| 237 |
fig_b.update_layout(height=750)
|
| 238 |
del samples_b
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
progress.progress(90, text="Storing figure state…")
|
| 240 |
_store_figure_state(
|
| 241 |
fig_key,
|
| 242 |
-
(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
)
|
| 244 |
progress.progress(100, text="Done.")
|
| 245 |
except Exception as exc:
|
|
@@ -250,8 +315,16 @@ def _render_dendrogram_analysis(
|
|
| 250 |
progress.empty()
|
| 251 |
|
| 252 |
if fig_key in st.session_state:
|
| 253 |
-
|
| 254 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
col_a, col_b = st.columns(2)
|
| 256 |
with col_a:
|
| 257 |
st.subheader(prompt_variant_label(va))
|
|
@@ -262,14 +335,22 @@ def _render_dendrogram_analysis(
|
|
| 262 |
else:
|
| 263 |
_plotly_chart(fig_a)
|
| 264 |
|
| 265 |
-
figs =
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
_render_save_buttons(figs, filenames, "dendro")
|
| 275 |
st.success(f"Generated dendrogram(s) for {n_personas} persona(s).")
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
|
| 3 |
+
import plotly.graph_objects as go
|
| 4 |
import streamlit as st
|
| 5 |
from persona_vectors.extraction import MaskStrategy
|
| 6 |
from persona_vectors.plots import plot_persona_dendrogram
|
| 7 |
+
from plotly.subplots import make_subplots
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
from tabs.analysis._shared import (
|
| 10 |
_load_persona_options,
|
|
|
|
| 12 |
_plotly_chart,
|
| 13 |
_release_vector_memory,
|
| 14 |
_render_layer_frame_controls,
|
| 15 |
+
_render_persona_select_controls,
|
| 16 |
_render_save_buttons,
|
| 17 |
_select_artifact_personas,
|
| 18 |
)
|
| 19 |
from tabs.analysis._state import (
|
| 20 |
_DEFAULT_PERSONA_LIMITS,
|
| 21 |
+
_MAX_PERSONA_COUNTS,
|
| 22 |
_clear_old_figure_states,
|
| 23 |
_filename,
|
| 24 |
_persona_names_state_key,
|
| 25 |
_personas_empty_message,
|
| 26 |
_store_figure_state,
|
| 27 |
)
|
| 28 |
+
from utils.analysis_sources import (
|
| 29 |
+
Store,
|
| 30 |
+
available_variants,
|
| 31 |
+
store_cache_parts,
|
| 32 |
+
store_id,
|
| 33 |
+
store_layers_cached,
|
| 34 |
+
)
|
| 35 |
+
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
| 36 |
|
| 37 |
_LAST_DENDRO_PERSONAS_KEY = "analysis:last_personas:dendro"
|
| 38 |
_DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"]
|
| 39 |
|
| 40 |
|
| 41 |
+
def _comparison_dendrogram_figure(
|
| 42 |
+
fig_a: go.Figure,
|
| 43 |
+
fig_b: go.Figure,
|
| 44 |
+
*,
|
| 45 |
+
title_a: str,
|
| 46 |
+
title_b: str,
|
| 47 |
+
) -> go.Figure:
|
| 48 |
+
"""Merge two layered dendrograms so one slider drives both panels."""
|
| 49 |
+
combined = make_subplots(
|
| 50 |
+
rows=1,
|
| 51 |
+
cols=2,
|
| 52 |
+
subplot_titles=(title_a, title_b),
|
| 53 |
+
shared_yaxes=True,
|
| 54 |
+
horizontal_spacing=0.05,
|
|
|
|
|
|
|
|
|
|
| 55 |
)
|
| 56 |
+
for trace in fig_a.data:
|
| 57 |
+
combined.add_trace(deepcopy(trace), row=1, col=1)
|
| 58 |
+
for trace in fig_b.data:
|
| 59 |
+
combined.add_trace(deepcopy(trace), row=1, col=2)
|
| 60 |
|
| 61 |
+
frames: list[go.Frame] = []
|
| 62 |
+
for frame_a, frame_b in zip(fig_a.frames, fig_b.frames, strict=True):
|
| 63 |
+
right_data = []
|
| 64 |
+
for trace in frame_b.data:
|
| 65 |
+
copied = deepcopy(trace)
|
| 66 |
+
copied.update(xaxis="x2", yaxis="y2")
|
| 67 |
+
right_data.append(copied)
|
| 68 |
+
frame_xaxis = frame_a.layout.xaxis.to_plotly_json()
|
| 69 |
+
frame_xaxis2 = frame_b.layout.xaxis.to_plotly_json()
|
| 70 |
+
frame_xaxis2["matches"] = None
|
| 71 |
+
frame_xaxis2["anchor"] = "y2"
|
| 72 |
+
frame_yaxis = frame_a.layout.yaxis.to_plotly_json()
|
| 73 |
+
frame_yaxis2 = frame_b.layout.yaxis.to_plotly_json()
|
| 74 |
+
frame_yaxis2["matches"] = "y"
|
| 75 |
+
frame_yaxis2["anchor"] = "x2"
|
| 76 |
+
frames.append(
|
| 77 |
+
go.Frame(
|
| 78 |
+
name=frame_a.name,
|
| 79 |
+
data=[*deepcopy(frame_a.data), *right_data],
|
| 80 |
+
layout={
|
| 81 |
+
"title": {"text": f"Dendrogram comparison - Layer {frame_a.name}"},
|
| 82 |
+
"xaxis": frame_xaxis,
|
| 83 |
+
"xaxis2": frame_xaxis2,
|
| 84 |
+
"yaxis": frame_yaxis,
|
| 85 |
+
"yaxis2": frame_yaxis2,
|
| 86 |
+
},
|
| 87 |
+
)
|
| 88 |
)
|
|
|
|
|
|
|
| 89 |
|
| 90 |
+
y_ranges = [
|
| 91 |
+
fig_a.layout.yaxis.range,
|
| 92 |
+
fig_b.layout.yaxis.range,
|
| 93 |
+
]
|
| 94 |
+
max_y = max(float(axis_range[1]) for axis_range in y_ranges if axis_range)
|
| 95 |
+
first_layer = fig_a.frames[0].name if fig_a.frames else ""
|
| 96 |
+
combined.frames = frames
|
| 97 |
+
combined.update_layout(
|
| 98 |
+
title={
|
| 99 |
+
"text": f"Dendrogram comparison - Layer {first_layer}",
|
| 100 |
+
"font": {"size": 24},
|
| 101 |
+
"y": 0.98,
|
| 102 |
+
"yanchor": "top",
|
| 103 |
+
},
|
| 104 |
+
template="plotly_white",
|
| 105 |
+
height=750,
|
| 106 |
+
margin=dict(t=140, b=260),
|
| 107 |
+
updatemenus=fig_a.layout.updatemenus,
|
| 108 |
+
sliders=fig_a.layout.sliders,
|
| 109 |
)
|
| 110 |
+
left_xaxis = fig_a.layout.xaxis.to_plotly_json()
|
| 111 |
+
right_xaxis = fig_b.layout.xaxis.to_plotly_json()
|
| 112 |
+
right_xaxis["matches"] = None
|
| 113 |
+
right_xaxis["anchor"] = "y2"
|
| 114 |
+
combined.update_layout(xaxis=left_xaxis, xaxis2=right_xaxis)
|
| 115 |
+
combined.update_xaxes(tickangle=-45, automargin=True)
|
| 116 |
+
combined.update_yaxes(
|
| 117 |
+
title_text=fig_a.layout.yaxis.title.text,
|
| 118 |
+
range=[0.0, max_y],
|
| 119 |
+
automargin=True,
|
| 120 |
+
)
|
| 121 |
+
return combined
|
| 122 |
|
| 123 |
|
| 124 |
def _render_dendrogram_analysis(
|
|
|
|
| 180 |
persona_ids = _render_persona_select_controls(
|
| 181 |
options,
|
| 182 |
widget_scope=f"dendro:{store_id(store)}",
|
| 183 |
+
max_selections=_MAX_PERSONA_COUNTS["dendro"],
|
| 184 |
)
|
| 185 |
if not persona_ids:
|
| 186 |
return
|
|
|
|
| 192 |
widget_scope=f"dendro:{store_id(store)}",
|
| 193 |
remember_key=_LAST_DENDRO_PERSONAS_KEY,
|
| 194 |
default_count_limit=_DEFAULT_PERSONA_LIMITS["dendro"],
|
| 195 |
+
max_count_limit=_MAX_PERSONA_COUNTS["dendro"],
|
| 196 |
)
|
| 197 |
if not persona_ids:
|
| 198 |
return
|
|
|
|
| 271 |
title=f"Dendrogram — {prompt_variant_label(variant_a)}",
|
| 272 |
)
|
| 273 |
fig_a.update_layout(height=750)
|
|
|
|
| 274 |
fig_b = None
|
| 275 |
if variant_a != variant_b:
|
| 276 |
progress.progress(60, text="Building second dendrogram…")
|
|
|
|
| 285 |
)
|
| 286 |
fig_b.update_layout(height=750)
|
| 287 |
del samples_b
|
| 288 |
+
del samples_a
|
| 289 |
+
comparison_fig = None
|
| 290 |
+
if fig_b is not None and layered_mode:
|
| 291 |
+
comparison_fig = _comparison_dendrogram_figure(
|
| 292 |
+
fig_a,
|
| 293 |
+
fig_b,
|
| 294 |
+
title_a=prompt_variant_label(variant_a),
|
| 295 |
+
title_b=prompt_variant_label(variant_b),
|
| 296 |
+
)
|
| 297 |
progress.progress(90, text="Storing figure state…")
|
| 298 |
_store_figure_state(
|
| 299 |
fig_key,
|
| 300 |
+
(
|
| 301 |
+
None if comparison_fig is not None else fig_a,
|
| 302 |
+
None if comparison_fig is not None else fig_b,
|
| 303 |
+
comparison_fig,
|
| 304 |
+
len(persona_ids),
|
| 305 |
+
variant_a,
|
| 306 |
+
variant_b,
|
| 307 |
+
),
|
| 308 |
)
|
| 309 |
progress.progress(100, text="Done.")
|
| 310 |
except Exception as exc:
|
|
|
|
| 315 |
progress.empty()
|
| 316 |
|
| 317 |
if fig_key in st.session_state:
|
| 318 |
+
saved = st.session_state[fig_key]
|
| 319 |
+
if len(saved) == 5:
|
| 320 |
+
# Drop pre-refactor state so hot-reloaded sessions do not unpack the
|
| 321 |
+
# old two-figure payload shape.
|
| 322 |
+
st.session_state.pop(fig_key, None)
|
| 323 |
+
return
|
| 324 |
+
fig_a, fig_b, comparison_fig, n_personas, va, vb = saved
|
| 325 |
+
if comparison_fig is not None:
|
| 326 |
+
_plotly_chart(comparison_fig)
|
| 327 |
+
elif fig_b is not None:
|
| 328 |
col_a, col_b = st.columns(2)
|
| 329 |
with col_a:
|
| 330 |
st.subheader(prompt_variant_label(va))
|
|
|
|
| 335 |
else:
|
| 336 |
_plotly_chart(fig_a)
|
| 337 |
|
| 338 |
+
figs = (
|
| 339 |
+
[comparison_fig]
|
| 340 |
+
if comparison_fig is not None
|
| 341 |
+
else [fig_a] + ([fig_b] if fig_b else [])
|
| 342 |
+
)
|
| 343 |
+
filenames = (
|
| 344 |
+
[_filename("dendro_compare", store.model_name, mask_strategy.value, va, vb)]
|
| 345 |
+
if comparison_fig is not None
|
| 346 |
+
else [
|
| 347 |
+
_filename("dendro", store.model_name, mask_strategy.value, va),
|
| 348 |
+
*(
|
| 349 |
+
[_filename("dendro", store.model_name, mask_strategy.value, vb)]
|
| 350 |
+
if fig_b
|
| 351 |
+
else []
|
| 352 |
+
),
|
| 353 |
+
]
|
| 354 |
+
)
|
| 355 |
_render_save_buttons(figs, filenames, "dendro")
|
| 356 |
st.success(f"Generated dendrogram(s) for {n_personas} persona(s).")
|
tabs/analysis/layered.py
CHANGED
|
@@ -2,10 +2,7 @@ from collections.abc import Callable
|
|
| 2 |
|
| 3 |
import plotly.graph_objects as go
|
| 4 |
import streamlit as st
|
| 5 |
-
from persona_vectors.attributes import
|
| 6 |
-
attribute_color_kwargs,
|
| 7 |
-
attribute_display_label,
|
| 8 |
-
)
|
| 9 |
from persona_vectors.extraction import MaskStrategy
|
| 10 |
from persona_vectors.plots import (
|
| 11 |
build_layered_figure,
|
|
@@ -13,19 +10,6 @@ from persona_vectors.plots import (
|
|
| 13 |
build_similarity_figures,
|
| 14 |
)
|
| 15 |
|
| 16 |
-
from utils.analysis_metadata import (
|
| 17 |
-
synth_persona_attribute_names,
|
| 18 |
-
synth_persona_dataset_cached,
|
| 19 |
-
)
|
| 20 |
-
from utils.analysis_sources import (
|
| 21 |
-
Store,
|
| 22 |
-
kmeans_groups_cached,
|
| 23 |
-
projection_data_cached,
|
| 24 |
-
store_cache_parts,
|
| 25 |
-
store_id,
|
| 26 |
-
)
|
| 27 |
-
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
| 28 |
-
|
| 29 |
from tabs.analysis._shared import (
|
| 30 |
_gray_out_unselected_personas,
|
| 31 |
_load_persona_vectors,
|
|
@@ -61,6 +45,18 @@ from tabs.analysis._state import (
|
|
| 61 |
_remembered_selectbox,
|
| 62 |
_store_figure_state,
|
| 63 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
|
| 66 |
def _render_pair_trajectory_control(
|
|
@@ -446,6 +442,8 @@ def _render_layered_figure_analysis(
|
|
| 446 |
n_components: int = 2,
|
| 447 |
remember_key: str = _LAST_PROJECTION_PERSONAS_KEY,
|
| 448 |
default_count_limit: int = 500,
|
|
|
|
|
|
|
| 449 |
) -> None:
|
| 450 |
"""Render a single-variant layered analysis: select → button → figure(s).
|
| 451 |
|
|
@@ -463,6 +461,8 @@ def _render_layered_figure_analysis(
|
|
| 463 |
else _LAST_SIMILARITY_VARIANT_KEY
|
| 464 |
),
|
| 465 |
default_count_limit=default_count_limit,
|
|
|
|
|
|
|
| 466 |
)
|
| 467 |
if selected is None:
|
| 468 |
return
|
|
|
|
| 2 |
|
| 3 |
import plotly.graph_objects as go
|
| 4 |
import streamlit as st
|
| 5 |
+
from persona_vectors.attributes import attribute_color_kwargs, attribute_display_label
|
|
|
|
|
|
|
|
|
|
| 6 |
from persona_vectors.extraction import MaskStrategy
|
| 7 |
from persona_vectors.plots import (
|
| 8 |
build_layered_figure,
|
|
|
|
| 10 |
build_similarity_figures,
|
| 11 |
)
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
from tabs.analysis._shared import (
|
| 14 |
_gray_out_unselected_personas,
|
| 15 |
_load_persona_vectors,
|
|
|
|
| 45 |
_remembered_selectbox,
|
| 46 |
_store_figure_state,
|
| 47 |
)
|
| 48 |
+
from utils.analysis_metadata import (
|
| 49 |
+
synth_persona_attribute_names,
|
| 50 |
+
synth_persona_dataset_cached,
|
| 51 |
+
)
|
| 52 |
+
from utils.analysis_sources import (
|
| 53 |
+
Store,
|
| 54 |
+
kmeans_groups_cached,
|
| 55 |
+
projection_data_cached,
|
| 56 |
+
store_cache_parts,
|
| 57 |
+
store_id,
|
| 58 |
+
)
|
| 59 |
+
from utils.helpers import personas_fingerprint, prompt_variant_label, widget_key
|
| 60 |
|
| 61 |
|
| 62 |
def _render_pair_trajectory_control(
|
|
|
|
| 442 |
n_components: int = 2,
|
| 443 |
remember_key: str = _LAST_PROJECTION_PERSONAS_KEY,
|
| 444 |
default_count_limit: int = 500,
|
| 445 |
+
max_count_limit: int | None = None,
|
| 446 |
+
allow_specific_personas: bool = False,
|
| 447 |
) -> None:
|
| 448 |
"""Render a single-variant layered analysis: select → button → figure(s).
|
| 449 |
|
|
|
|
| 461 |
else _LAST_SIMILARITY_VARIANT_KEY
|
| 462 |
),
|
| 463 |
default_count_limit=default_count_limit,
|
| 464 |
+
max_count_limit=max_count_limit,
|
| 465 |
+
allow_specific_personas=allow_specific_personas,
|
| 466 |
)
|
| 467 |
if selected is None:
|
| 468 |
return
|
tabs/analysis_core.py
CHANGED
|
@@ -1,27 +1,4 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
|
| 3 |
import streamlit as st
|
| 4 |
-
from persona_data.environment import get_artifacts_dir
|
| 5 |
-
from persona_vectors.extraction import MaskStrategy
|
| 6 |
-
|
| 7 |
-
from utils.analysis_sources import (
|
| 8 |
-
DEFAULT_COMPARE_MODEL,
|
| 9 |
-
DEFAULT_HUB_REPO,
|
| 10 |
-
SOURCE_HUB,
|
| 11 |
-
SOURCE_LOCAL,
|
| 12 |
-
SOURCES,
|
| 13 |
-
Store,
|
| 14 |
-
activation_store_cached,
|
| 15 |
-
hub_models_by_mask_strategy,
|
| 16 |
-
local_model_matches,
|
| 17 |
-
local_model_options_cached,
|
| 18 |
-
)
|
| 19 |
-
from utils.helpers import (
|
| 20 |
-
ANALYSIS_HELP_TEXT,
|
| 21 |
-
ANALYSIS_MODES,
|
| 22 |
-
prompt_variant_label,
|
| 23 |
-
widget_key,
|
| 24 |
-
)
|
| 25 |
|
| 26 |
from tabs.analysis._shared import _render_mask_strategy_select
|
| 27 |
from tabs.analysis._state import (
|
|
@@ -29,153 +6,18 @@ from tabs.analysis._state import (
|
|
| 29 |
_LAST_PROJECTION_DIMS_KEY,
|
| 30 |
_LAST_SIMILARITY_PERSONAS_KEY,
|
| 31 |
_LAST_SOURCE_KEY,
|
|
|
|
| 32 |
)
|
| 33 |
from tabs.analysis.cosine import _render_cosine_similarity
|
| 34 |
from tabs.analysis.dendrogram import _render_dendrogram_analysis
|
| 35 |
from tabs.analysis.layered import _render_layered_figure_analysis
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
default=last_source if last_source in SOURCES else SOURCE_HUB,
|
| 44 |
-
key=widget_key("load", "source"),
|
| 45 |
-
label_visibility="collapsed",
|
| 46 |
-
)
|
| 47 |
-
if source is None:
|
| 48 |
-
source = SOURCE_HUB
|
| 49 |
-
st.session_state[_LAST_SOURCE_KEY] = source
|
| 50 |
-
return source
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def _render_hub_model_select(
|
| 54 |
-
repo_id: str,
|
| 55 |
-
mask_strategy: MaskStrategy,
|
| 56 |
-
) -> str:
|
| 57 |
-
fallback_model = st.session_state.get(
|
| 58 |
-
"analysis:hub_model_fallback",
|
| 59 |
-
DEFAULT_COMPARE_MODEL,
|
| 60 |
-
)
|
| 61 |
-
try:
|
| 62 |
-
models_by_strategy = hub_models_by_mask_strategy(repo_id)
|
| 63 |
-
except Exception as exc:
|
| 64 |
-
st.warning(f"Could not load Hub configs for `{repo_id}`: {exc}")
|
| 65 |
-
return st.text_input(
|
| 66 |
-
"Hub model",
|
| 67 |
-
value=fallback_model,
|
| 68 |
-
key="analysis:hub_model_fallback",
|
| 69 |
-
help="Analysis-only model id to use if Hub config discovery is unavailable.",
|
| 70 |
-
)
|
| 71 |
-
|
| 72 |
-
model_options = models_by_strategy.get(mask_strategy, [])
|
| 73 |
-
if not model_options:
|
| 74 |
-
st.warning(
|
| 75 |
-
f"No Hub vector configs found for `{mask_strategy.value}` in `{repo_id}`."
|
| 76 |
-
)
|
| 77 |
-
return st.text_input(
|
| 78 |
-
"Hub model",
|
| 79 |
-
value=fallback_model,
|
| 80 |
-
key="analysis:hub_model_fallback",
|
| 81 |
-
help="Analysis-only model id to use for this Hub repo.",
|
| 82 |
-
)
|
| 83 |
-
|
| 84 |
-
previous_model = st.session_state.get(
|
| 85 |
-
widget_key("load", "hub_model", repo_id, mask_strategy.value),
|
| 86 |
-
fallback_model,
|
| 87 |
-
)
|
| 88 |
-
default_model = (
|
| 89 |
-
previous_model if previous_model in model_options else model_options[0]
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
return st.selectbox(
|
| 93 |
-
"Hub model",
|
| 94 |
-
options=model_options,
|
| 95 |
-
index=model_options.index(default_model),
|
| 96 |
-
key=widget_key("load", "hub_model", repo_id, mask_strategy.value),
|
| 97 |
-
help="Models with vectors in the selected Hub repo and mask strategy.",
|
| 98 |
-
)
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
def _render_local_model_select(
|
| 102 |
-
artifacts_root: str,
|
| 103 |
-
mask_strategy: MaskStrategy,
|
| 104 |
-
) -> str:
|
| 105 |
-
fallback_model = st.session_state.get("analysis:local_model", DEFAULT_COMPARE_MODEL)
|
| 106 |
-
model_options = local_model_options_cached(artifacts_root, mask_strategy.value)
|
| 107 |
-
if not model_options:
|
| 108 |
-
return st.text_input(
|
| 109 |
-
"Local model",
|
| 110 |
-
value=fallback_model,
|
| 111 |
-
key="analysis:local_model",
|
| 112 |
-
help="Analysis-only local model id or path.",
|
| 113 |
-
)
|
| 114 |
-
|
| 115 |
-
custom = st.toggle(
|
| 116 |
-
"Custom local model",
|
| 117 |
-
value=False,
|
| 118 |
-
key="analysis:local_model_custom_enabled",
|
| 119 |
-
help="Enter a model id/path manually instead of choosing from activation directories.",
|
| 120 |
-
)
|
| 121 |
-
if custom:
|
| 122 |
-
return st.text_input(
|
| 123 |
-
"Local model",
|
| 124 |
-
value=fallback_model,
|
| 125 |
-
key="analysis:local_model",
|
| 126 |
-
help="Analysis-only local model id or path.",
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
previous_model = st.session_state.get("analysis:local_model_select", fallback_model)
|
| 130 |
-
if not any(local_model_matches(previous_model, option) for option in model_options):
|
| 131 |
-
previous_model = fallback_model
|
| 132 |
-
default_model = next(
|
| 133 |
-
(
|
| 134 |
-
option
|
| 135 |
-
for option in model_options
|
| 136 |
-
if local_model_matches(option, previous_model)
|
| 137 |
-
),
|
| 138 |
-
model_options[0],
|
| 139 |
-
)
|
| 140 |
-
selected = st.selectbox(
|
| 141 |
-
"Local model",
|
| 142 |
-
options=model_options,
|
| 143 |
-
index=model_options.index(default_model),
|
| 144 |
-
key="analysis:local_model_select",
|
| 145 |
-
help="Models discovered under the selected artifacts root.",
|
| 146 |
-
)
|
| 147 |
-
st.session_state["analysis:local_model"] = selected
|
| 148 |
-
return selected
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
|
| 152 |
-
if source == SOURCE_HUB:
|
| 153 |
-
repo = st.text_input(
|
| 154 |
-
"Hub repo",
|
| 155 |
-
value=st.session_state.get("analysis:hub_repo", DEFAULT_HUB_REPO),
|
| 156 |
-
key="analysis:hub_repo",
|
| 157 |
-
help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
|
| 158 |
-
)
|
| 159 |
-
hub_model_name = _render_hub_model_select(repo, mask_strategy)
|
| 160 |
-
return activation_store_cached(
|
| 161 |
-
SOURCE_HUB,
|
| 162 |
-
repo,
|
| 163 |
-
hub_model_name,
|
| 164 |
-
mask_strategy.value,
|
| 165 |
-
)
|
| 166 |
-
artifacts_root = st.text_input(
|
| 167 |
-
"Artifacts root",
|
| 168 |
-
value=str(get_artifacts_dir() / "activations"),
|
| 169 |
-
key="analysis:artifacts_root",
|
| 170 |
-
)
|
| 171 |
-
artifacts_root = str(Path(artifacts_root).expanduser())
|
| 172 |
-
local_model_name = _render_local_model_select(artifacts_root, mask_strategy)
|
| 173 |
-
return activation_store_cached(
|
| 174 |
-
SOURCE_LOCAL,
|
| 175 |
-
artifacts_root,
|
| 176 |
-
local_model_name,
|
| 177 |
-
mask_strategy.value,
|
| 178 |
-
)
|
| 179 |
|
| 180 |
|
| 181 |
def render_analysis_tab() -> None:
|
|
@@ -186,7 +28,7 @@ def render_analysis_tab() -> None:
|
|
| 186 |
"Analyse persona vectors by cosine similarity, PCA, UMAP, Isomap, or hierarchical clustering."
|
| 187 |
)
|
| 188 |
|
| 189 |
-
source =
|
| 190 |
|
| 191 |
analysis_mode = st.segmented_control(
|
| 192 |
"Analysis mode",
|
|
@@ -201,7 +43,18 @@ def render_analysis_tab() -> None:
|
|
| 201 |
|
| 202 |
with st.expander("Source settings", expanded=True):
|
| 203 |
mask_strategy = _render_mask_strategy_select(analysis_mode)
|
| 204 |
-
store =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
if analysis_mode == "Cosine similarity":
|
| 207 |
_render_cosine_similarity(store, mask_strategy)
|
|
@@ -219,6 +72,8 @@ def render_analysis_tab() -> None:
|
|
| 219 |
include_pair_trajectories=True,
|
| 220 |
remember_key=_LAST_SIMILARITY_PERSONAS_KEY,
|
| 221 |
default_count_limit=_DEFAULT_PERSONA_LIMITS["similarity"],
|
|
|
|
|
|
|
| 222 |
)
|
| 223 |
return
|
| 224 |
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
from tabs.analysis._shared import _render_mask_strategy_select
|
| 4 |
from tabs.analysis._state import (
|
|
|
|
| 6 |
_LAST_PROJECTION_DIMS_KEY,
|
| 7 |
_LAST_SIMILARITY_PERSONAS_KEY,
|
| 8 |
_LAST_SOURCE_KEY,
|
| 9 |
+
_MAX_PERSONA_COUNTS,
|
| 10 |
)
|
| 11 |
from tabs.analysis.cosine import _render_cosine_similarity
|
| 12 |
from tabs.analysis.dendrogram import _render_dendrogram_analysis
|
| 13 |
from tabs.analysis.layered import _render_layered_figure_analysis
|
| 14 |
+
from utils.helpers import (
|
| 15 |
+
ANALYSIS_HELP_TEXT,
|
| 16 |
+
ANALYSIS_MODES,
|
| 17 |
+
prompt_variant_label,
|
| 18 |
+
widget_key,
|
| 19 |
+
)
|
| 20 |
+
from utils.source_controls import render_source_select, render_store_select
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
def render_analysis_tab() -> None:
|
|
|
|
| 28 |
"Analyse persona vectors by cosine similarity, PCA, UMAP, Isomap, or hierarchical clustering."
|
| 29 |
)
|
| 30 |
|
| 31 |
+
source = render_source_select(widget_scope="load", last_source_key=_LAST_SOURCE_KEY)
|
| 32 |
|
| 33 |
analysis_mode = st.segmented_control(
|
| 34 |
"Analysis mode",
|
|
|
|
| 43 |
|
| 44 |
with st.expander("Source settings", expanded=True):
|
| 45 |
mask_strategy = _render_mask_strategy_select(analysis_mode)
|
| 46 |
+
store = render_store_select(
|
| 47 |
+
source,
|
| 48 |
+
mask_strategy,
|
| 49 |
+
state_prefix="analysis",
|
| 50 |
+
widget_scope="load",
|
| 51 |
+
artifacts_root_key="analysis:artifacts_root",
|
| 52 |
+
model_label="Hub model",
|
| 53 |
+
local_model_label="Local model",
|
| 54 |
+
allow_custom_local_model=True,
|
| 55 |
+
repo_help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
|
| 56 |
+
fallback_help="Analysis-only model id to use if Hub config discovery is unavailable.",
|
| 57 |
+
)
|
| 58 |
|
| 59 |
if analysis_mode == "Cosine similarity":
|
| 60 |
_render_cosine_similarity(store, mask_strategy)
|
|
|
|
| 72 |
include_pair_trajectories=True,
|
| 73 |
remember_key=_LAST_SIMILARITY_PERSONAS_KEY,
|
| 74 |
default_count_limit=_DEFAULT_PERSONA_LIMITS["similarity"],
|
| 75 |
+
max_count_limit=_MAX_PERSONA_COUNTS["similarity"],
|
| 76 |
+
allow_specific_personas=True,
|
| 77 |
)
|
| 78 |
return
|
| 79 |
|
tabs/chat.py
CHANGED
|
@@ -15,6 +15,8 @@ from tabs.chat_shared import (
|
|
| 15 |
generate_chat_reply_result,
|
| 16 |
hydrate_chat_state,
|
| 17 |
load_chat_personas,
|
|
|
|
|
|
|
| 18 |
render_chat_selection,
|
| 19 |
)
|
| 20 |
from tabs.chat_ui import (
|
|
@@ -25,7 +27,7 @@ from tabs.chat_ui import (
|
|
| 25 |
)
|
| 26 |
from utils.chat import build_chat_messages, resolve_system_prompt
|
| 27 |
from utils.chat_export import save_chat_export
|
| 28 |
-
from utils.helpers import session_key, widget_key
|
| 29 |
from utils.runtime import cached_model
|
| 30 |
|
| 31 |
if TYPE_CHECKING:
|
|
@@ -94,9 +96,26 @@ def _handle_single_chat_generation(
|
|
| 94 |
chat_log,
|
| 95 |
) -> None:
|
| 96 |
messages = build_chat_messages(active_system_prompt, chat_state["messages"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
with st.spinner("Generating reply..."):
|
|
|
|
| 99 |
model = cached_model(model_name=model_name)
|
|
|
|
|
|
|
| 100 |
|
| 101 |
def _show_error(exc: Exception) -> None:
|
| 102 |
with chat_log:
|
|
@@ -108,15 +127,19 @@ def _handle_single_chat_generation(
|
|
| 108 |
messages=messages,
|
| 109 |
remote=remote,
|
| 110 |
generation=generation,
|
|
|
|
| 111 |
on_error=_show_error,
|
| 112 |
)
|
| 113 |
if error is not None:
|
|
|
|
| 114 |
if pending_action == "new_user_prompt" and chat_state["messages"]:
|
| 115 |
chat_state["messages"].pop()
|
| 116 |
return
|
| 117 |
if reply is None:
|
|
|
|
| 118 |
return
|
| 119 |
|
|
|
|
| 120 |
chat_state["messages"].append({"role": "assistant", "content": reply.text})
|
| 121 |
st.rerun()
|
| 122 |
|
|
|
|
| 15 |
generate_chat_reply_result,
|
| 16 |
hydrate_chat_state,
|
| 17 |
load_chat_personas,
|
| 18 |
+
mark_model_loaded,
|
| 19 |
+
model_load_status,
|
| 20 |
render_chat_selection,
|
| 21 |
)
|
| 22 |
from tabs.chat_ui import (
|
|
|
|
| 27 |
)
|
| 28 |
from utils.chat import build_chat_messages, resolve_system_prompt
|
| 29 |
from utils.chat_export import save_chat_export
|
| 30 |
+
from utils.helpers import format_ndif_status, session_key, widget_key
|
| 31 |
from utils.runtime import cached_model
|
| 32 |
|
| 33 |
if TYPE_CHECKING:
|
|
|
|
| 96 |
chat_log,
|
| 97 |
) -> None:
|
| 98 |
messages = build_chat_messages(active_system_prompt, chat_state["messages"])
|
| 99 |
+
status_box = st.empty()
|
| 100 |
+
|
| 101 |
+
def _show_phase(text: str) -> None:
|
| 102 |
+
status_box.caption(text)
|
| 103 |
+
|
| 104 |
+
def _show_ndif_status(job_id: str, status_name: str, description: str) -> None:
|
| 105 |
+
status_box.caption(
|
| 106 |
+
format_ndif_status(
|
| 107 |
+
job_id,
|
| 108 |
+
status_name,
|
| 109 |
+
description,
|
| 110 |
+
completed_detail="Downloading result...",
|
| 111 |
+
)
|
| 112 |
+
)
|
| 113 |
|
| 114 |
with st.spinner("Generating reply..."):
|
| 115 |
+
_show_phase(model_load_status(model_name))
|
| 116 |
model = cached_model(model_name=model_name)
|
| 117 |
+
mark_model_loaded(model_name)
|
| 118 |
+
_show_phase("Submitting to NDIF..." if remote else "Generating locally...")
|
| 119 |
|
| 120 |
def _show_error(exc: Exception) -> None:
|
| 121 |
with chat_log:
|
|
|
|
| 127 |
messages=messages,
|
| 128 |
remote=remote,
|
| 129 |
generation=generation,
|
| 130 |
+
on_status=_show_ndif_status if remote else None,
|
| 131 |
on_error=_show_error,
|
| 132 |
)
|
| 133 |
if error is not None:
|
| 134 |
+
status_box.empty()
|
| 135 |
if pending_action == "new_user_prompt" and chat_state["messages"]:
|
| 136 |
chat_state["messages"].pop()
|
| 137 |
return
|
| 138 |
if reply is None:
|
| 139 |
+
status_box.empty()
|
| 140 |
return
|
| 141 |
|
| 142 |
+
status_box.empty()
|
| 143 |
chat_state["messages"].append({"role": "assistant", "content": reply.text})
|
| 144 |
st.rerun()
|
| 145 |
|
tabs/chat_shared.py
CHANGED
|
@@ -23,6 +23,9 @@ class ChatSelection:
|
|
| 23 |
changed: bool
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
| 26 |
def load_chat_personas(dataset_source: str) -> list[PersonaData] | None:
|
| 27 |
personas_file_key = session_key("extract", "personas_file")
|
| 28 |
qa_file_key = session_key("extract", "qa_file")
|
|
@@ -84,12 +87,27 @@ def render_chat_selection(
|
|
| 84 |
return ChatSelection(selected_persona, prompt_mode, changed)
|
| 85 |
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
def generate_chat_reply_result(
|
| 88 |
*,
|
| 89 |
model: object,
|
| 90 |
messages: list[dict[str, str]],
|
| 91 |
remote: bool,
|
| 92 |
generation: GenerationConfig,
|
|
|
|
| 93 |
on_error: Callable[[Exception], None] | None = None,
|
| 94 |
) -> tuple[ChatReply | None, Exception | None]:
|
| 95 |
try:
|
|
@@ -98,6 +116,7 @@ def generate_chat_reply_result(
|
|
| 98 |
model=model,
|
| 99 |
messages=messages,
|
| 100 |
remote=remote,
|
|
|
|
| 101 |
**generation.to_generate_kwargs(),
|
| 102 |
),
|
| 103 |
None,
|
|
|
|
| 23 |
changed: bool
|
| 24 |
|
| 25 |
|
| 26 |
+
_LOADED_MODEL_NAMES_KEY = session_key("chat", "loaded_model_names")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
def load_chat_personas(dataset_source: str) -> list[PersonaData] | None:
|
| 30 |
personas_file_key = session_key("extract", "personas_file")
|
| 31 |
qa_file_key = session_key("extract", "qa_file")
|
|
|
|
| 87 |
return ChatSelection(selected_persona, prompt_mode, changed)
|
| 88 |
|
| 89 |
|
| 90 |
+
def model_load_status(model_name: str) -> str:
|
| 91 |
+
"""Return an honest coarse-grained loading label for the current session."""
|
| 92 |
+
|
| 93 |
+
loaded_names = st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set())
|
| 94 |
+
return "Using cached model..." if model_name in loaded_names else "Loading model..."
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def mark_model_loaded(model_name: str) -> None:
|
| 98 |
+
"""Remember that this session has already requested a model once."""
|
| 99 |
+
|
| 100 |
+
loaded_names = st.session_state.setdefault(_LOADED_MODEL_NAMES_KEY, set())
|
| 101 |
+
loaded_names.add(model_name)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
def generate_chat_reply_result(
|
| 105 |
*,
|
| 106 |
model: object,
|
| 107 |
messages: list[dict[str, str]],
|
| 108 |
remote: bool,
|
| 109 |
generation: GenerationConfig,
|
| 110 |
+
on_status: Callable[[str, str, str], None] | None = None,
|
| 111 |
on_error: Callable[[Exception], None] | None = None,
|
| 112 |
) -> tuple[ChatReply | None, Exception | None]:
|
| 113 |
try:
|
|
|
|
| 116 |
model=model,
|
| 117 |
messages=messages,
|
| 118 |
remote=remote,
|
| 119 |
+
on_status=on_status,
|
| 120 |
**generation.to_generate_kwargs(),
|
| 121 |
),
|
| 122 |
None,
|
tabs/chat_ui.py
CHANGED
|
@@ -16,6 +16,7 @@ from utils.helpers import (
|
|
| 16 |
|
| 17 |
if TYPE_CHECKING:
|
| 18 |
from persona_data.synth_persona import PersonaData
|
|
|
|
| 19 |
from utils.contrast import TokenContrast
|
| 20 |
|
| 21 |
GENERATION_DEFAULTS = {
|
|
|
|
| 16 |
|
| 17 |
if TYPE_CHECKING:
|
| 18 |
from persona_data.synth_persona import PersonaData
|
| 19 |
+
|
| 20 |
from utils.contrast import TokenContrast
|
| 21 |
|
| 22 |
GENERATION_DEFAULTS = {
|
tabs/compare_chat.py
CHANGED
|
@@ -14,7 +14,7 @@ from tabs.chat_shared import (
|
|
| 14 |
from utils.chat import ChatReply, build_chat_messages, resolve_system_prompt
|
| 15 |
from utils.chat_export import save_chat_export
|
| 16 |
from utils.contrast import compute_contrast, compute_contrast_pair
|
| 17 |
-
from utils.helpers import persona_label, session_key, widget_key
|
| 18 |
from utils.runtime import cached_model
|
| 19 |
|
| 20 |
from .chat_ui import (
|
|
@@ -142,15 +142,40 @@ def _generate_panels(
|
|
| 142 |
spinner_label: str,
|
| 143 |
) -> list[ChatReply | Exception]:
|
| 144 |
results: list[ChatReply | Exception] = []
|
|
|
|
| 145 |
with st.spinner(spinner_label):
|
| 146 |
for panel in panels:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
reply, error = generate_chat_reply_result(
|
| 148 |
model=model,
|
| 149 |
messages=build_chat_messages(panel.prompt, panel.state["messages"]),
|
| 150 |
remote=remote,
|
| 151 |
generation=generation,
|
|
|
|
| 152 |
)
|
| 153 |
results.append(reply if error is None else error)
|
|
|
|
| 154 |
return results
|
| 155 |
|
| 156 |
|
|
|
|
| 14 |
from utils.chat import ChatReply, build_chat_messages, resolve_system_prompt
|
| 15 |
from utils.chat_export import save_chat_export
|
| 16 |
from utils.contrast import compute_contrast, compute_contrast_pair
|
| 17 |
+
from utils.helpers import format_ndif_status, persona_label, session_key, widget_key
|
| 18 |
from utils.runtime import cached_model
|
| 19 |
|
| 20 |
from .chat_ui import (
|
|
|
|
| 142 |
spinner_label: str,
|
| 143 |
) -> list[ChatReply | Exception]:
|
| 144 |
results: list[ChatReply | Exception] = []
|
| 145 |
+
status_box = st.empty()
|
| 146 |
with st.spinner(spinner_label):
|
| 147 |
for panel in panels:
|
| 148 |
+
panel_label = panel.side.title()
|
| 149 |
+
status_box.caption(
|
| 150 |
+
f"{panel_label}: {'Submitting to NDIF...' if remote else 'Generating locally...'}"
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def _show_ndif_status(
|
| 154 |
+
job_id: str,
|
| 155 |
+
status_name: str,
|
| 156 |
+
description: str,
|
| 157 |
+
*,
|
| 158 |
+
label: str = panel_label,
|
| 159 |
+
) -> None:
|
| 160 |
+
status_box.caption(
|
| 161 |
+
format_ndif_status(
|
| 162 |
+
job_id,
|
| 163 |
+
status_name,
|
| 164 |
+
description,
|
| 165 |
+
prefix=label,
|
| 166 |
+
completed_detail="Downloading result...",
|
| 167 |
+
)
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
reply, error = generate_chat_reply_result(
|
| 171 |
model=model,
|
| 172 |
messages=build_chat_messages(panel.prompt, panel.state["messages"]),
|
| 173 |
remote=remote,
|
| 174 |
generation=generation,
|
| 175 |
+
on_status=_show_ndif_status if remote else None,
|
| 176 |
)
|
| 177 |
results.append(reply if error is None else error)
|
| 178 |
+
status_box.empty()
|
| 179 |
return results
|
| 180 |
|
| 181 |
|
tabs/extract.py
CHANGED
|
@@ -20,7 +20,7 @@ from utils.datasets import (
|
|
| 20 |
warm_qa_in_background,
|
| 21 |
)
|
| 22 |
from utils.helpers import (
|
| 23 |
-
|
| 24 |
persona_label,
|
| 25 |
prompt_variant_label,
|
| 26 |
session_key,
|
|
@@ -353,8 +353,7 @@ def _run_extraction_plan(
|
|
| 353 |
ndif_status_box = st.empty()
|
| 354 |
|
| 355 |
def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
|
| 356 |
-
|
| 357 |
-
ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
|
| 358 |
|
| 359 |
with st.spinner("Loading model..."):
|
| 360 |
model = cached_model(model_name=model_name)
|
|
|
|
| 20 |
warm_qa_in_background,
|
| 21 |
)
|
| 22 |
from utils.helpers import (
|
| 23 |
+
format_ndif_status,
|
| 24 |
persona_label,
|
| 25 |
prompt_variant_label,
|
| 26 |
session_key,
|
|
|
|
| 353 |
ndif_status_box = st.empty()
|
| 354 |
|
| 355 |
def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
|
| 356 |
+
ndif_status_box.caption(format_ndif_status(job_id, status_name, description))
|
|
|
|
| 357 |
|
| 358 |
with st.spinner("Loading model..."):
|
| 359 |
model = cached_model(model_name=model_name)
|
tabs/probe.py
CHANGED
|
@@ -11,43 +11,28 @@ is a thin Streamlit wrapper around them.
|
|
| 11 |
|
| 12 |
from __future__ import annotations
|
| 13 |
|
| 14 |
-
from dataclasses import dataclass
|
| 15 |
-
from pathlib import Path
|
| 16 |
-
|
| 17 |
import streamlit as st
|
| 18 |
-
from persona_data.environment import get_artifacts_dir
|
| 19 |
from persona_vectors.analysis import LayeredSamples
|
| 20 |
from persona_vectors.attributes import attribute_display_label
|
| 21 |
from persona_vectors.extraction import MaskStrategy
|
| 22 |
from persona_vectors.plots import plot_metric_comparison, plot_metric_over_layers
|
| 23 |
from persona_vectors.probes import (
|
| 24 |
AttributeLabels,
|
| 25 |
-
attribute_probe_labels,
|
| 26 |
default_probe_kinds,
|
| 27 |
-
filter_attribute_samples_min_count,
|
| 28 |
infer_probe_task,
|
| 29 |
layer_matrix,
|
| 30 |
save_probe_artifact,
|
| 31 |
shuffle_label_baseline,
|
| 32 |
-
sweep_attribute,
|
| 33 |
)
|
| 34 |
|
|
|
|
| 35 |
from utils.analysis_metadata import (
|
| 36 |
synth_persona_attribute_names,
|
| 37 |
synth_persona_dataset_cached,
|
| 38 |
)
|
| 39 |
from utils.analysis_sources import (
|
| 40 |
-
DEFAULT_COMPARE_MODEL,
|
| 41 |
-
DEFAULT_HUB_REPO,
|
| 42 |
-
SOURCE_HUB,
|
| 43 |
-
SOURCE_LOCAL,
|
| 44 |
-
SOURCES,
|
| 45 |
Store,
|
| 46 |
-
activation_store_cached,
|
| 47 |
available_variants,
|
| 48 |
-
hub_models_by_mask_strategy,
|
| 49 |
-
load_persona_vectors_cached,
|
| 50 |
-
local_model_options_cached,
|
| 51 |
persona_names_cached,
|
| 52 |
personas_cached,
|
| 53 |
store_cache_parts,
|
|
@@ -55,6 +40,7 @@ from utils.analysis_sources import (
|
|
| 55 |
)
|
| 56 |
from utils.controls import render_mask_strategy_select
|
| 57 |
from utils.helpers import widget_key
|
|
|
|
| 58 |
|
| 59 |
# ---------------------------------------------------------------------------
|
| 60 |
# Constants and config
|
|
@@ -78,94 +64,6 @@ _SECONDARY_METRIC = {
|
|
| 78 |
}
|
| 79 |
|
| 80 |
|
| 81 |
-
@dataclass(frozen=True)
|
| 82 |
-
class _SweepInputs:
|
| 83 |
-
source: str
|
| 84 |
-
location: str
|
| 85 |
-
model_name: str
|
| 86 |
-
mask_value: str
|
| 87 |
-
variant: str
|
| 88 |
-
persona_ids: tuple[str, ...]
|
| 89 |
-
attributes: tuple[str, ...]
|
| 90 |
-
task: str
|
| 91 |
-
probe_kinds: tuple[str, ...]
|
| 92 |
-
n_pca_components: int | None
|
| 93 |
-
layers: tuple[int, ...]
|
| 94 |
-
min_class_count: int
|
| 95 |
-
seed: int
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
# ---------------------------------------------------------------------------
|
| 99 |
-
# Source / store selection (slim mirror of the analysis tab pattern)
|
| 100 |
-
# ---------------------------------------------------------------------------
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def _select_source() -> str:
|
| 104 |
-
key = widget_key("probe", "source")
|
| 105 |
-
source = st.segmented_control(
|
| 106 |
-
"Source",
|
| 107 |
-
options=SOURCES,
|
| 108 |
-
default=st.session_state.get(key, SOURCE_HUB),
|
| 109 |
-
key=key,
|
| 110 |
-
label_visibility="collapsed",
|
| 111 |
-
)
|
| 112 |
-
return source or SOURCE_HUB
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def _select_store(source: str, mask_strategy: MaskStrategy) -> Store:
|
| 116 |
-
if source == SOURCE_HUB:
|
| 117 |
-
repo = st.text_input(
|
| 118 |
-
"Hub repo",
|
| 119 |
-
value=st.session_state.get("probe:hub_repo", DEFAULT_HUB_REPO),
|
| 120 |
-
key="probe:hub_repo",
|
| 121 |
-
)
|
| 122 |
-
models = hub_models_by_mask_strategy(repo).get(mask_strategy, [])
|
| 123 |
-
if not models:
|
| 124 |
-
st.warning(
|
| 125 |
-
f"No Hub vector configs for `{mask_strategy.value}` in `{repo}`."
|
| 126 |
-
)
|
| 127 |
-
model_name = st.text_input(
|
| 128 |
-
"Model",
|
| 129 |
-
value=st.session_state.get("probe:hub_model_fallback", DEFAULT_COMPARE_MODEL),
|
| 130 |
-
key="probe:hub_model_fallback",
|
| 131 |
-
)
|
| 132 |
-
else:
|
| 133 |
-
previous = st.session_state.get(
|
| 134 |
-
widget_key("probe", "hub_model", repo, mask_strategy.value),
|
| 135 |
-
models[0],
|
| 136 |
-
)
|
| 137 |
-
model_name = st.selectbox(
|
| 138 |
-
"Model",
|
| 139 |
-
options=models,
|
| 140 |
-
index=models.index(previous) if previous in models else 0,
|
| 141 |
-
key=widget_key("probe", "hub_model", repo, mask_strategy.value),
|
| 142 |
-
)
|
| 143 |
-
return activation_store_cached(SOURCE_HUB, repo, model_name, mask_strategy.value)
|
| 144 |
-
|
| 145 |
-
root = st.text_input(
|
| 146 |
-
"Artifacts root",
|
| 147 |
-
value=str(get_artifacts_dir() / "activations"),
|
| 148 |
-
key="probe:local_root",
|
| 149 |
-
)
|
| 150 |
-
root = str(Path(root).expanduser())
|
| 151 |
-
models = local_model_options_cached(root, mask_strategy.value)
|
| 152 |
-
if models:
|
| 153 |
-
previous = st.session_state.get("probe:local_model", models[0])
|
| 154 |
-
model_name = st.selectbox(
|
| 155 |
-
"Model",
|
| 156 |
-
options=models,
|
| 157 |
-
index=models.index(previous) if previous in models else 0,
|
| 158 |
-
key="probe:local_model",
|
| 159 |
-
)
|
| 160 |
-
else:
|
| 161 |
-
model_name = st.text_input(
|
| 162 |
-
"Model",
|
| 163 |
-
value=st.session_state.get("probe:local_model_fallback", DEFAULT_COMPARE_MODEL),
|
| 164 |
-
key="probe:local_model_fallback",
|
| 165 |
-
)
|
| 166 |
-
return activation_store_cached(SOURCE_LOCAL, root, model_name, mask_strategy.value)
|
| 167 |
-
|
| 168 |
-
|
| 169 |
def _select_variant(store: Store, mask_strategy: MaskStrategy) -> str | None:
|
| 170 |
variants = available_variants(store, mask_strategy)
|
| 171 |
if not variants:
|
|
@@ -184,7 +82,9 @@ def _select_personas(
|
|
| 184 |
store: Store, variant: str, mask_strategy: MaskStrategy
|
| 185 |
) -> list[str]:
|
| 186 |
source, location, model_name = store_cache_parts(store)
|
| 187 |
-
all_ids = personas_cached(
|
|
|
|
|
|
|
| 188 |
if not all_ids:
|
| 189 |
st.info("No personas found for this variant.")
|
| 190 |
return []
|
|
@@ -225,7 +125,12 @@ def _select_personas(
|
|
| 225 |
st.session_state["probe:persona_count"] = count
|
| 226 |
persona_ids = regular[:count]
|
| 227 |
persona_names_cached(
|
| 228 |
-
source,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
)
|
| 230 |
st.caption(f"Probing {len(persona_ids)} of {len(regular)} non-assistant personas.")
|
| 231 |
return persona_ids
|
|
@@ -323,13 +228,15 @@ def _select_layers(num_layers: int) -> list[int]:
|
|
| 323 |
)
|
| 324 |
if not fast:
|
| 325 |
return list(range(num_layers))
|
| 326 |
-
return sorted(
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
|
|
|
|
|
|
| 333 |
|
| 334 |
|
| 335 |
# ---------------------------------------------------------------------------
|
|
@@ -337,66 +244,12 @@ def _select_layers(num_layers: int) -> list[int]:
|
|
| 337 |
# ---------------------------------------------------------------------------
|
| 338 |
|
| 339 |
|
| 340 |
-
@st.cache_resource(show_spinner=False)
|
| 341 |
-
def _cached_sweep(
|
| 342 |
-
inputs: _SweepInputs,
|
| 343 |
-
) -> tuple[
|
| 344 |
-
dict[str, list[dict[str, object]]],
|
| 345 |
-
dict[str, tuple[AttributeLabels, LayeredSamples]],
|
| 346 |
-
]:
|
| 347 |
-
samples = load_persona_vectors_cached(
|
| 348 |
-
inputs.source, inputs.location, inputs.model_name,
|
| 349 |
-
inputs.mask_value, inputs.variant, inputs.persona_ids,
|
| 350 |
-
)
|
| 351 |
-
dataset = synth_persona_dataset_cached()
|
| 352 |
-
# The min-count filter drops personas per attribute, so each attribute keeps
|
| 353 |
-
# its own (labels, samples) pair for the downstream selectivity/save tools.
|
| 354 |
-
per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]] = {}
|
| 355 |
-
|
| 356 |
-
def _labels_and_samples(attribute: str) -> tuple[AttributeLabels, LayeredSamples]:
|
| 357 |
-
if attribute not in per_attr:
|
| 358 |
-
labels = attribute_probe_labels(
|
| 359 |
-
dataset, attribute, list(inputs.persona_ids), task=inputs.task, # type: ignore[arg-type]
|
| 360 |
-
)
|
| 361 |
-
probe_samples, labels = filter_attribute_samples_min_count(
|
| 362 |
-
samples, labels, min_count=inputs.min_class_count
|
| 363 |
-
)
|
| 364 |
-
per_attr[attribute] = (labels, probe_samples)
|
| 365 |
-
return per_attr[attribute]
|
| 366 |
-
|
| 367 |
-
def _sweep(attribute: str, n_pca: int | None) -> list[dict[str, object]]:
|
| 368 |
-
labels, probe_samples = _labels_and_samples(attribute)
|
| 369 |
-
return sweep_attribute(
|
| 370 |
-
probe_samples, labels,
|
| 371 |
-
layers=list(inputs.layers),
|
| 372 |
-
probe_kinds=list(inputs.probe_kinds), # type: ignore[arg-type]
|
| 373 |
-
n_pca_components=n_pca,
|
| 374 |
-
seed=inputs.seed,
|
| 375 |
-
)
|
| 376 |
-
|
| 377 |
-
def _sweep_all(n_pca: int | None) -> list[dict[str, object]]:
|
| 378 |
-
rows: list[dict[str, object]] = []
|
| 379 |
-
for attribute in inputs.attributes:
|
| 380 |
-
rows.extend(_sweep(attribute, n_pca))
|
| 381 |
-
return rows
|
| 382 |
-
|
| 383 |
-
if inputs.n_pca_components is not None:
|
| 384 |
-
# Always overlay the compressed sweep against full activations.
|
| 385 |
-
rows_by_label = {
|
| 386 |
-
"full": _sweep_all(None),
|
| 387 |
-
f"pca{inputs.n_pca_components}": _sweep_all(inputs.n_pca_components),
|
| 388 |
-
}
|
| 389 |
-
else:
|
| 390 |
-
rows_by_label = {"full": _sweep_all(None)}
|
| 391 |
-
return rows_by_label, per_attr
|
| 392 |
-
|
| 393 |
-
|
| 394 |
def _show_sweep(
|
| 395 |
rows_by_label: dict[str, list[dict[str, object]]],
|
| 396 |
per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]],
|
| 397 |
attributes: tuple[str, ...],
|
| 398 |
task: str,
|
| 399 |
-
inputs:
|
| 400 |
) -> None:
|
| 401 |
primary = _PRIMARY_METRIC[task]
|
| 402 |
secondary = _SECONDARY_METRIC.get(task)
|
|
@@ -442,8 +295,7 @@ def _show_sweep(
|
|
| 442 |
for label, label_rows in rows_by_label.items():
|
| 443 |
for attribute in attributes:
|
| 444 |
attr_rows = [
|
| 445 |
-
row for row in label_rows
|
| 446 |
-
if row.get("attribute") == attribute
|
| 447 |
]
|
| 448 |
label_best = _best_row(attr_rows)
|
| 449 |
if label_best is None:
|
|
@@ -451,22 +303,23 @@ def _show_sweep(
|
|
| 451 |
summary_row: dict[str, object] = {}
|
| 452 |
if multi_attr:
|
| 453 |
summary_row["attribute"] = attribute
|
| 454 |
-
summary_row.update(
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
|
|
|
|
|
|
|
|
|
| 463 |
summary_rows.append(summary_row)
|
| 464 |
if summary_rows:
|
| 465 |
st.dataframe(summary_rows, width="stretch", hide_index=True)
|
| 466 |
|
| 467 |
-
feature_desc =
|
| 468 |
-
f" · pca{inputs.n_pca_components}" if inputs.n_pca_components else ""
|
| 469 |
-
)
|
| 470 |
|
| 471 |
best_attr = str(best["attribute"])
|
| 472 |
labels, samples = per_attr[best_attr]
|
|
@@ -495,7 +348,7 @@ def _render_selectivity_control(
|
|
| 495 |
labels: AttributeLabels,
|
| 496 |
samples: LayeredSamples,
|
| 497 |
task: str,
|
| 498 |
-
inputs:
|
| 499 |
) -> None:
|
| 500 |
if task == "numeric":
|
| 501 |
return # selectivity control is classification-only
|
|
@@ -507,14 +360,18 @@ def _render_selectivity_control(
|
|
| 507 |
"dataset artifacts, not the property."
|
| 508 |
)
|
| 509 |
n_repeats = st.slider(
|
| 510 |
-
"Shuffle repeats",
|
|
|
|
|
|
|
|
|
|
| 511 |
key="probe:shuffle_repeats",
|
| 512 |
)
|
| 513 |
if st.button("Run selectivity control", key="probe:run_shuffle"):
|
| 514 |
with st.spinner("Running shuffled-label control..."):
|
| 515 |
X = layer_matrix(samples, int(best["layer"]))
|
| 516 |
shuffled = shuffle_label_baseline(
|
| 517 |
-
X,
|
|
|
|
| 518 |
task=task, # type: ignore[arg-type]
|
| 519 |
layer=int(best["layer"]),
|
| 520 |
probe_kind=best["probe_kind"], # type: ignore[arg-type]
|
|
@@ -539,7 +396,7 @@ def _render_save_artifact(
|
|
| 539 |
labels: AttributeLabels,
|
| 540 |
samples: LayeredSamples,
|
| 541 |
task: str,
|
| 542 |
-
inputs:
|
| 543 |
) -> None:
|
| 544 |
def synced_default(key: str, default: str) -> str:
|
| 545 |
default_key = f"{key}:default"
|
|
@@ -575,7 +432,9 @@ def _render_save_artifact(
|
|
| 575 |
if st.button("Save", key="probe:save_artifact"):
|
| 576 |
X = layer_matrix(samples, int(best["layer"]))
|
| 577 |
directory = save_probe_artifact(
|
| 578 |
-
X=X,
|
|
|
|
|
|
|
| 579 |
task=task, # type: ignore[arg-type]
|
| 580 |
probe_kind=best["probe_kind"], # type: ignore[arg-type]
|
| 581 |
n_pca_components=inputs.n_pca_components,
|
|
@@ -601,14 +460,21 @@ def _render_save_artifact(
|
|
| 601 |
def render_probing_tab() -> None:
|
| 602 |
st.title("Probing")
|
| 603 |
|
| 604 |
-
source =
|
| 605 |
with st.expander("Source", expanded=True):
|
| 606 |
mask_strategy = render_mask_strategy_select(
|
| 607 |
key=widget_key("probe", "mask_strategy"),
|
| 608 |
last_key="probe:last_mask_strategy",
|
|
|
|
| 609 |
help_text="Which extracted activation set to load.",
|
| 610 |
)
|
| 611 |
-
store =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 612 |
variant = _select_variant(store, mask_strategy)
|
| 613 |
if variant is None:
|
| 614 |
return
|
|
@@ -644,13 +510,19 @@ def render_probing_tab() -> None:
|
|
| 644 |
min_class_count = _MIN_CLASS_COUNT
|
| 645 |
seed = 0
|
| 646 |
|
| 647 |
-
inputs =
|
| 648 |
-
source=source,
|
| 649 |
-
|
| 650 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 651 |
probe_kinds=tuple(probe_kinds),
|
| 652 |
n_pca_components=n_pca_components,
|
| 653 |
-
layers=tuple(layers),
|
|
|
|
| 654 |
seed=int(seed),
|
| 655 |
)
|
| 656 |
|
|
@@ -659,7 +531,7 @@ def render_probing_tab() -> None:
|
|
| 659 |
if run:
|
| 660 |
with st.spinner("Evaluating probes across layers..."):
|
| 661 |
try:
|
| 662 |
-
sweep, per_attr =
|
| 663 |
except Exception as exc:
|
| 664 |
st.error(f"Sweep failed: {exc}")
|
| 665 |
st.session_state.pop(state_key, None)
|
|
@@ -674,6 +546,9 @@ def render_probing_tab() -> None:
|
|
| 674 |
else:
|
| 675 |
sweep, per_attr, result_inputs = saved_result
|
| 676 |
_show_sweep(
|
| 677 |
-
sweep,
|
| 678 |
-
|
|
|
|
|
|
|
|
|
|
| 679 |
)
|
|
|
|
| 11 |
|
| 12 |
from __future__ import annotations
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
import streamlit as st
|
|
|
|
| 15 |
from persona_vectors.analysis import LayeredSamples
|
| 16 |
from persona_vectors.attributes import attribute_display_label
|
| 17 |
from persona_vectors.extraction import MaskStrategy
|
| 18 |
from persona_vectors.plots import plot_metric_comparison, plot_metric_over_layers
|
| 19 |
from persona_vectors.probes import (
|
| 20 |
AttributeLabels,
|
|
|
|
| 21 |
default_probe_kinds,
|
|
|
|
| 22 |
infer_probe_task,
|
| 23 |
layer_matrix,
|
| 24 |
save_probe_artifact,
|
| 25 |
shuffle_label_baseline,
|
|
|
|
| 26 |
)
|
| 27 |
|
| 28 |
+
from tabs.probe_sweep import SweepInputs, cached_sweep
|
| 29 |
from utils.analysis_metadata import (
|
| 30 |
synth_persona_attribute_names,
|
| 31 |
synth_persona_dataset_cached,
|
| 32 |
)
|
| 33 |
from utils.analysis_sources import (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
Store,
|
|
|
|
| 35 |
available_variants,
|
|
|
|
|
|
|
|
|
|
| 36 |
persona_names_cached,
|
| 37 |
personas_cached,
|
| 38 |
store_cache_parts,
|
|
|
|
| 40 |
)
|
| 41 |
from utils.controls import render_mask_strategy_select
|
| 42 |
from utils.helpers import widget_key
|
| 43 |
+
from utils.source_controls import render_source_select, render_store_select
|
| 44 |
|
| 45 |
# ---------------------------------------------------------------------------
|
| 46 |
# Constants and config
|
|
|
|
| 64 |
}
|
| 65 |
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def _select_variant(store: Store, mask_strategy: MaskStrategy) -> str | None:
|
| 68 |
variants = available_variants(store, mask_strategy)
|
| 69 |
if not variants:
|
|
|
|
| 82 |
store: Store, variant: str, mask_strategy: MaskStrategy
|
| 83 |
) -> list[str]:
|
| 84 |
source, location, model_name = store_cache_parts(store)
|
| 85 |
+
all_ids = personas_cached(
|
| 86 |
+
source, location, model_name, mask_strategy.value, (variant,)
|
| 87 |
+
)
|
| 88 |
if not all_ids:
|
| 89 |
st.info("No personas found for this variant.")
|
| 90 |
return []
|
|
|
|
| 125 |
st.session_state["probe:persona_count"] = count
|
| 126 |
persona_ids = regular[:count]
|
| 127 |
persona_names_cached(
|
| 128 |
+
source,
|
| 129 |
+
location,
|
| 130 |
+
model_name,
|
| 131 |
+
mask_strategy.value,
|
| 132 |
+
(variant,),
|
| 133 |
+
tuple(persona_ids),
|
| 134 |
)
|
| 135 |
st.caption(f"Probing {len(persona_ids)} of {len(regular)} non-assistant personas.")
|
| 136 |
return persona_ids
|
|
|
|
| 228 |
)
|
| 229 |
if not fast:
|
| 230 |
return list(range(num_layers))
|
| 231 |
+
return sorted(
|
| 232 |
+
{
|
| 233 |
+
0,
|
| 234 |
+
num_layers // 4,
|
| 235 |
+
num_layers // 2,
|
| 236 |
+
(3 * num_layers) // 4,
|
| 237 |
+
num_layers - 1,
|
| 238 |
+
}
|
| 239 |
+
)
|
| 240 |
|
| 241 |
|
| 242 |
# ---------------------------------------------------------------------------
|
|
|
|
| 244 |
# ---------------------------------------------------------------------------
|
| 245 |
|
| 246 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
def _show_sweep(
|
| 248 |
rows_by_label: dict[str, list[dict[str, object]]],
|
| 249 |
per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]],
|
| 250 |
attributes: tuple[str, ...],
|
| 251 |
task: str,
|
| 252 |
+
inputs: SweepInputs,
|
| 253 |
) -> None:
|
| 254 |
primary = _PRIMARY_METRIC[task]
|
| 255 |
secondary = _SECONDARY_METRIC.get(task)
|
|
|
|
| 295 |
for label, label_rows in rows_by_label.items():
|
| 296 |
for attribute in attributes:
|
| 297 |
attr_rows = [
|
| 298 |
+
row for row in label_rows if row.get("attribute") == attribute
|
|
|
|
| 299 |
]
|
| 300 |
label_best = _best_row(attr_rows)
|
| 301 |
if label_best is None:
|
|
|
|
| 303 |
summary_row: dict[str, object] = {}
|
| 304 |
if multi_attr:
|
| 305 |
summary_row["attribute"] = attribute
|
| 306 |
+
summary_row.update(
|
| 307 |
+
{
|
| 308 |
+
"features": label,
|
| 309 |
+
"best_layer": label_best["layer"],
|
| 310 |
+
"probe": label_best["probe_kind"],
|
| 311 |
+
primary: round(float(label_best[primary]), 3),
|
| 312 |
+
f"baseline_{primary}": round(
|
| 313 |
+
float(label_best.get(f"baseline_{primary}", float("nan"))),
|
| 314 |
+
3,
|
| 315 |
+
),
|
| 316 |
+
}
|
| 317 |
+
)
|
| 318 |
summary_rows.append(summary_row)
|
| 319 |
if summary_rows:
|
| 320 |
st.dataframe(summary_rows, width="stretch", hide_index=True)
|
| 321 |
|
| 322 |
+
feature_desc = f" · pca{inputs.n_pca_components}" if inputs.n_pca_components else ""
|
|
|
|
|
|
|
| 323 |
|
| 324 |
best_attr = str(best["attribute"])
|
| 325 |
labels, samples = per_attr[best_attr]
|
|
|
|
| 348 |
labels: AttributeLabels,
|
| 349 |
samples: LayeredSamples,
|
| 350 |
task: str,
|
| 351 |
+
inputs: SweepInputs,
|
| 352 |
) -> None:
|
| 353 |
if task == "numeric":
|
| 354 |
return # selectivity control is classification-only
|
|
|
|
| 360 |
"dataset artifacts, not the property."
|
| 361 |
)
|
| 362 |
n_repeats = st.slider(
|
| 363 |
+
"Shuffle repeats",
|
| 364 |
+
min_value=3,
|
| 365 |
+
max_value=15,
|
| 366 |
+
value=5,
|
| 367 |
key="probe:shuffle_repeats",
|
| 368 |
)
|
| 369 |
if st.button("Run selectivity control", key="probe:run_shuffle"):
|
| 370 |
with st.spinner("Running shuffled-label control..."):
|
| 371 |
X = layer_matrix(samples, int(best["layer"]))
|
| 372 |
shuffled = shuffle_label_baseline(
|
| 373 |
+
X,
|
| 374 |
+
labels.y,
|
| 375 |
task=task, # type: ignore[arg-type]
|
| 376 |
layer=int(best["layer"]),
|
| 377 |
probe_kind=best["probe_kind"], # type: ignore[arg-type]
|
|
|
|
| 396 |
labels: AttributeLabels,
|
| 397 |
samples: LayeredSamples,
|
| 398 |
task: str,
|
| 399 |
+
inputs: SweepInputs,
|
| 400 |
) -> None:
|
| 401 |
def synced_default(key: str, default: str) -> str:
|
| 402 |
default_key = f"{key}:default"
|
|
|
|
| 432 |
if st.button("Save", key="probe:save_artifact"):
|
| 433 |
X = layer_matrix(samples, int(best["layer"]))
|
| 434 |
directory = save_probe_artifact(
|
| 435 |
+
X=X,
|
| 436 |
+
y=labels.y,
|
| 437 |
+
labels=labels,
|
| 438 |
task=task, # type: ignore[arg-type]
|
| 439 |
probe_kind=best["probe_kind"], # type: ignore[arg-type]
|
| 440 |
n_pca_components=inputs.n_pca_components,
|
|
|
|
| 460 |
def render_probing_tab() -> None:
|
| 461 |
st.title("Probing")
|
| 462 |
|
| 463 |
+
source = render_source_select(widget_scope="probe")
|
| 464 |
with st.expander("Source", expanded=True):
|
| 465 |
mask_strategy = render_mask_strategy_select(
|
| 466 |
key=widget_key("probe", "mask_strategy"),
|
| 467 |
last_key="probe:last_mask_strategy",
|
| 468 |
+
remember_key="source:last_mask_strategy",
|
| 469 |
help_text="Which extracted activation set to load.",
|
| 470 |
)
|
| 471 |
+
store = render_store_select(
|
| 472 |
+
source,
|
| 473 |
+
mask_strategy,
|
| 474 |
+
state_prefix="probe",
|
| 475 |
+
widget_scope="probe",
|
| 476 |
+
artifacts_root_key="probe:local_root",
|
| 477 |
+
)
|
| 478 |
variant = _select_variant(store, mask_strategy)
|
| 479 |
if variant is None:
|
| 480 |
return
|
|
|
|
| 510 |
min_class_count = _MIN_CLASS_COUNT
|
| 511 |
seed = 0
|
| 512 |
|
| 513 |
+
inputs = SweepInputs(
|
| 514 |
+
source=source,
|
| 515 |
+
location=location,
|
| 516 |
+
model_name=model_name,
|
| 517 |
+
mask_value=mask_strategy.value,
|
| 518 |
+
variant=variant,
|
| 519 |
+
persona_ids=tuple(persona_ids),
|
| 520 |
+
attributes=tuple(attributes),
|
| 521 |
+
task=task,
|
| 522 |
probe_kinds=tuple(probe_kinds),
|
| 523 |
n_pca_components=n_pca_components,
|
| 524 |
+
layers=tuple(layers),
|
| 525 |
+
min_class_count=min_class_count,
|
| 526 |
seed=int(seed),
|
| 527 |
)
|
| 528 |
|
|
|
|
| 531 |
if run:
|
| 532 |
with st.spinner("Evaluating probes across layers..."):
|
| 533 |
try:
|
| 534 |
+
sweep, per_attr = cached_sweep(inputs)
|
| 535 |
except Exception as exc:
|
| 536 |
st.error(f"Sweep failed: {exc}")
|
| 537 |
st.session_state.pop(state_key, None)
|
|
|
|
| 546 |
else:
|
| 547 |
sweep, per_attr, result_inputs = saved_result
|
| 548 |
_show_sweep(
|
| 549 |
+
sweep,
|
| 550 |
+
per_attr,
|
| 551 |
+
result_inputs.attributes,
|
| 552 |
+
result_inputs.task,
|
| 553 |
+
result_inputs,
|
| 554 |
)
|
tabs/probe_sweep.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
from persona_vectors.analysis import LayeredSamples
|
| 7 |
+
from persona_vectors.probes import (
|
| 8 |
+
AttributeLabels,
|
| 9 |
+
attribute_probe_labels,
|
| 10 |
+
filter_attribute_samples_min_count,
|
| 11 |
+
sweep_attribute,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
from utils.analysis_metadata import synth_persona_dataset_cached
|
| 15 |
+
from utils.analysis_sources import load_persona_vectors_cached
|
| 16 |
+
from utils.helpers import env_int
|
| 17 |
+
|
| 18 |
+
_SWEEP_CACHE_ENTRIES = env_int("PERSONA_UI_PROBE_SWEEP_CACHE_ENTRIES", 4)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass(frozen=True)
|
| 22 |
+
class SweepInputs:
|
| 23 |
+
source: str
|
| 24 |
+
location: str
|
| 25 |
+
model_name: str
|
| 26 |
+
mask_value: str
|
| 27 |
+
variant: str
|
| 28 |
+
persona_ids: tuple[str, ...]
|
| 29 |
+
attributes: tuple[str, ...]
|
| 30 |
+
task: str
|
| 31 |
+
probe_kinds: tuple[str, ...]
|
| 32 |
+
n_pca_components: int | None
|
| 33 |
+
layers: tuple[int, ...]
|
| 34 |
+
min_class_count: int
|
| 35 |
+
seed: int
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@st.cache_resource(show_spinner=False, max_entries=_SWEEP_CACHE_ENTRIES)
|
| 39 |
+
def cached_sweep(
|
| 40 |
+
inputs: SweepInputs,
|
| 41 |
+
) -> tuple[
|
| 42 |
+
dict[str, list[dict[str, object]]],
|
| 43 |
+
dict[str, tuple[AttributeLabels, LayeredSamples]],
|
| 44 |
+
]:
|
| 45 |
+
samples = load_persona_vectors_cached(
|
| 46 |
+
inputs.source,
|
| 47 |
+
inputs.location,
|
| 48 |
+
inputs.model_name,
|
| 49 |
+
inputs.mask_value,
|
| 50 |
+
inputs.variant,
|
| 51 |
+
inputs.persona_ids,
|
| 52 |
+
)
|
| 53 |
+
dataset = synth_persona_dataset_cached()
|
| 54 |
+
per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]] = {}
|
| 55 |
+
|
| 56 |
+
def labels_and_samples(attribute: str) -> tuple[AttributeLabels, LayeredSamples]:
|
| 57 |
+
if attribute not in per_attr:
|
| 58 |
+
labels = attribute_probe_labels(
|
| 59 |
+
dataset,
|
| 60 |
+
attribute,
|
| 61 |
+
list(inputs.persona_ids),
|
| 62 |
+
task=inputs.task, # type: ignore[arg-type]
|
| 63 |
+
)
|
| 64 |
+
probe_samples, labels = filter_attribute_samples_min_count(
|
| 65 |
+
samples,
|
| 66 |
+
labels,
|
| 67 |
+
min_count=inputs.min_class_count,
|
| 68 |
+
)
|
| 69 |
+
per_attr[attribute] = (labels, probe_samples)
|
| 70 |
+
return per_attr[attribute]
|
| 71 |
+
|
| 72 |
+
def sweep_one(attribute: str, n_pca: int | None) -> list[dict[str, object]]:
|
| 73 |
+
labels, probe_samples = labels_and_samples(attribute)
|
| 74 |
+
return sweep_attribute(
|
| 75 |
+
probe_samples,
|
| 76 |
+
labels,
|
| 77 |
+
layers=list(inputs.layers),
|
| 78 |
+
probe_kinds=list(inputs.probe_kinds), # type: ignore[arg-type]
|
| 79 |
+
n_pca_components=n_pca,
|
| 80 |
+
seed=inputs.seed,
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
def sweep_all(n_pca: int | None) -> list[dict[str, object]]:
|
| 84 |
+
rows: list[dict[str, object]] = []
|
| 85 |
+
for attribute in inputs.attributes:
|
| 86 |
+
rows.extend(sweep_one(attribute, n_pca))
|
| 87 |
+
return rows
|
| 88 |
+
|
| 89 |
+
rows_by_label = {"full": sweep_all(None)}
|
| 90 |
+
if inputs.n_pca_components is not None:
|
| 91 |
+
rows_by_label[f"pca{inputs.n_pca_components}"] = sweep_all(
|
| 92 |
+
inputs.n_pca_components
|
| 93 |
+
)
|
| 94 |
+
return rows_by_label, per_attr
|
tabs/probe_ui.py
CHANGED
|
@@ -6,7 +6,15 @@ import streamlit as st
|
|
| 6 |
import torch
|
| 7 |
|
| 8 |
from utils.chat import build_chat_messages
|
| 9 |
-
from utils.helpers import session_key, widget_key
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from utils.probe_overlay import (
|
| 11 |
attach_overlays,
|
| 12 |
build_classification_overlays,
|
|
@@ -15,24 +23,25 @@ from utils.probe_overlay import (
|
|
| 15 |
)
|
| 16 |
from utils.probe_trace import ConversationTrace, trace_conversation
|
| 17 |
from utils.probes import (
|
| 18 |
-
DEFAULT_LOCAL_PROBE_DIR,
|
| 19 |
-
DEFAULT_PROBE_REPO,
|
| 20 |
LoadedProbe,
|
| 21 |
-
list_local_probe_files,
|
| 22 |
-
list_probe_files,
|
| 23 |
load_local_probe,
|
| 24 |
load_probe,
|
| 25 |
load_probe_from_bytes,
|
| 26 |
-
model_probe_dir_name,
|
| 27 |
-
parse_probe_filename,
|
| 28 |
)
|
| 29 |
from utils.runtime import cached_model
|
|
|
|
| 30 |
|
| 31 |
_LAST_SOURCE_KEY = session_key("probe", "last_source")
|
| 32 |
_LAST_LOCAL_FILE_KEY = session_key("probe", "last_local_file")
|
| 33 |
_LAST_HUB_FILE_KEY = session_key("probe", "last_hub_file")
|
| 34 |
|
| 35 |
_PROBE_SOURCES = ("Local artifact", "Hugging Face repo", "Upload .pt")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
|
| 38 |
# ---------------------------------------------------------------------------
|
|
@@ -62,23 +71,16 @@ def _default_file(files: list[str], remembered: str | None) -> str:
|
|
| 62 |
return files[0]
|
| 63 |
|
| 64 |
|
| 65 |
-
def _render_probe_selector(
|
| 66 |
-
*, context_key: str, model_name: str
|
| 67 |
-
) -> LoadedProbe | None:
|
| 68 |
"""Inline source + file selector. Returns the loaded probe or None."""
|
| 69 |
-
|
| 70 |
-
if source_key not in st.session_state:
|
| 71 |
-
st.session_state[source_key] = st.session_state.get(
|
| 72 |
-
_LAST_SOURCE_KEY, _PROBE_SOURCES[0]
|
| 73 |
-
)
|
| 74 |
-
source = st.segmented_control(
|
| 75 |
"Probe source",
|
| 76 |
options=_PROBE_SOURCES,
|
| 77 |
-
key=
|
|
|
|
|
|
|
| 78 |
label_visibility="collapsed",
|
| 79 |
)
|
| 80 |
-
source = source or _PROBE_SOURCES[0]
|
| 81 |
-
st.session_state[_LAST_SOURCE_KEY] = source
|
| 82 |
|
| 83 |
if source == "Local artifact":
|
| 84 |
return _render_local_probe(context_key=context_key, model_name=model_name)
|
|
@@ -87,9 +89,7 @@ def _render_probe_selector(
|
|
| 87 |
return _render_upload_probe(context_key=context_key)
|
| 88 |
|
| 89 |
|
| 90 |
-
def _render_local_probe(
|
| 91 |
-
*, context_key: str, model_name: str
|
| 92 |
-
) -> LoadedProbe | None:
|
| 93 |
root_dir = st.text_input(
|
| 94 |
"Probe directory",
|
| 95 |
value=st.session_state.get(
|
|
@@ -118,9 +118,7 @@ def _render_local_probe(
|
|
| 118 |
return None
|
| 119 |
|
| 120 |
|
| 121 |
-
def _render_hub_probe(
|
| 122 |
-
*, context_key: str, model_name: str
|
| 123 |
-
) -> LoadedProbe | None:
|
| 124 |
repo_id = st.text_input(
|
| 125 |
"Probe repo",
|
| 126 |
value=st.session_state.get(
|
|
@@ -249,15 +247,43 @@ def _validate(
|
|
| 249 |
# ---------------------------------------------------------------------------
|
| 250 |
|
| 251 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
def _classification_predictions(
|
| 253 |
probe: LoadedProbe, activations: torch.Tensor, cache_key: str
|
| 254 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 255 |
full_key = widget_key("probe_predictions", cache_key, str(id(probe)))
|
| 256 |
-
cached =
|
| 257 |
if cached is not None:
|
| 258 |
return cached
|
| 259 |
_, probs, predicted = probe.run_batch(activations)
|
| 260 |
-
|
| 261 |
return probs, predicted
|
| 262 |
|
| 263 |
|
|
@@ -265,11 +291,11 @@ def _regression_values(
|
|
| 265 |
probe: LoadedProbe, activations: torch.Tensor, cache_key: str
|
| 266 |
) -> torch.Tensor:
|
| 267 |
full_key = widget_key("probe_values", cache_key, str(id(probe)))
|
| 268 |
-
cached =
|
| 269 |
if cached is not None:
|
| 270 |
return cached
|
| 271 |
values = probe.predict_batch(activations)
|
| 272 |
-
|
| 273 |
return values
|
| 274 |
|
| 275 |
|
|
@@ -297,9 +323,7 @@ def _apply_overlays(
|
|
| 297 |
probs, predicted = _classification_predictions(
|
| 298 |
probe, trace.activations, trace.cache_key
|
| 299 |
)
|
| 300 |
-
binary = probs.shape[1] == 1 or (
|
| 301 |
-
probs.shape[1] == 2 and len(probe.labels) == 2
|
| 302 |
-
)
|
| 303 |
overlays = build_classification_overlays(
|
| 304 |
trace=trace,
|
| 305 |
probs=probs,
|
|
@@ -332,9 +356,7 @@ def render_probe_inspector(
|
|
| 332 |
def _conversation_sig() -> int:
|
| 333 |
return hash(
|
| 334 |
tuple(
|
| 335 |
-
(m.get("role"), m.get("content"))
|
| 336 |
-
for m in messages
|
| 337 |
-
if m.get("content")
|
| 338 |
)
|
| 339 |
)
|
| 340 |
|
|
@@ -349,9 +371,7 @@ def render_probe_inspector(
|
|
| 349 |
st.caption("Probe overlay shows up after the first assistant reply.")
|
| 350 |
return
|
| 351 |
|
| 352 |
-
probe = _render_probe_selector(
|
| 353 |
-
context_key=context_key, model_name=model_name
|
| 354 |
-
)
|
| 355 |
if probe is None:
|
| 356 |
_reset()
|
| 357 |
return
|
|
|
|
| 6 |
import torch
|
| 7 |
|
| 8 |
from utils.chat import build_chat_messages
|
| 9 |
+
from utils.helpers import env_int, session_key, widget_key
|
| 10 |
+
from utils.probe_files import (
|
| 11 |
+
DEFAULT_LOCAL_PROBE_DIR,
|
| 12 |
+
DEFAULT_PROBE_REPO,
|
| 13 |
+
list_local_probe_files,
|
| 14 |
+
list_probe_files,
|
| 15 |
+
model_probe_dir_name,
|
| 16 |
+
parse_probe_filename,
|
| 17 |
+
)
|
| 18 |
from utils.probe_overlay import (
|
| 19 |
attach_overlays,
|
| 20 |
build_classification_overlays,
|
|
|
|
| 23 |
)
|
| 24 |
from utils.probe_trace import ConversationTrace, trace_conversation
|
| 25 |
from utils.probes import (
|
|
|
|
|
|
|
| 26 |
LoadedProbe,
|
|
|
|
|
|
|
| 27 |
load_local_probe,
|
| 28 |
load_probe,
|
| 29 |
load_probe_from_bytes,
|
|
|
|
|
|
|
| 30 |
)
|
| 31 |
from utils.runtime import cached_model
|
| 32 |
+
from utils.selection_controls import remembered_segmented_control
|
| 33 |
|
| 34 |
_LAST_SOURCE_KEY = session_key("probe", "last_source")
|
| 35 |
_LAST_LOCAL_FILE_KEY = session_key("probe", "last_local_file")
|
| 36 |
_LAST_HUB_FILE_KEY = session_key("probe", "last_hub_file")
|
| 37 |
|
| 38 |
_PROBE_SOURCES = ("Local artifact", "Hugging Face repo", "Upload .pt")
|
| 39 |
+
_DERIVED_CACHE_TRACKER_KEY = session_key("probe", "derived_cache_keys")
|
| 40 |
+
# Keep enough room for the three retained traces plus a few recently explored
|
| 41 |
+
# probes per trace. Derived outputs are much smaller than the trace activations
|
| 42 |
+
# themselves, so this avoids needless recomputation without reopening
|
| 43 |
+
# unbounded growth.
|
| 44 |
+
_DERIVED_CACHE_ENTRIES = env_int("PERSONA_UI_PROBE_DERIVED_CACHE_ENTRIES", 12)
|
| 45 |
|
| 46 |
|
| 47 |
# ---------------------------------------------------------------------------
|
|
|
|
| 71 |
return files[0]
|
| 72 |
|
| 73 |
|
| 74 |
+
def _render_probe_selector(*, context_key: str, model_name: str) -> LoadedProbe | None:
|
|
|
|
|
|
|
| 75 |
"""Inline source + file selector. Returns the loaded probe or None."""
|
| 76 |
+
source = remembered_segmented_control(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
"Probe source",
|
| 78 |
options=_PROBE_SOURCES,
|
| 79 |
+
key=widget_key(context_key, "probe_source"),
|
| 80 |
+
remember_key=_LAST_SOURCE_KEY,
|
| 81 |
+
default=_PROBE_SOURCES[0],
|
| 82 |
label_visibility="collapsed",
|
| 83 |
)
|
|
|
|
|
|
|
| 84 |
|
| 85 |
if source == "Local artifact":
|
| 86 |
return _render_local_probe(context_key=context_key, model_name=model_name)
|
|
|
|
| 89 |
return _render_upload_probe(context_key=context_key)
|
| 90 |
|
| 91 |
|
| 92 |
+
def _render_local_probe(*, context_key: str, model_name: str) -> LoadedProbe | None:
|
|
|
|
|
|
|
| 93 |
root_dir = st.text_input(
|
| 94 |
"Probe directory",
|
| 95 |
value=st.session_state.get(
|
|
|
|
| 118 |
return None
|
| 119 |
|
| 120 |
|
| 121 |
+
def _render_hub_probe(*, context_key: str, model_name: str) -> LoadedProbe | None:
|
|
|
|
|
|
|
| 122 |
repo_id = st.text_input(
|
| 123 |
"Probe repo",
|
| 124 |
value=st.session_state.get(
|
|
|
|
| 247 |
# ---------------------------------------------------------------------------
|
| 248 |
|
| 249 |
|
| 250 |
+
def _store_derived_cache(key: str, value: object) -> None:
|
| 251 |
+
"""Store one derived probe result while keeping a small MRU window."""
|
| 252 |
+
|
| 253 |
+
tracked = st.session_state.setdefault(_DERIVED_CACHE_TRACKER_KEY, [])
|
| 254 |
+
if not isinstance(tracked, list):
|
| 255 |
+
tracked = []
|
| 256 |
+
tracked = [existing for existing in tracked if existing != key]
|
| 257 |
+
tracked.append(key)
|
| 258 |
+
while len(tracked) > _DERIVED_CACHE_ENTRIES:
|
| 259 |
+
st.session_state.pop(tracked.pop(0), None)
|
| 260 |
+
st.session_state[_DERIVED_CACHE_TRACKER_KEY] = tracked
|
| 261 |
+
st.session_state[key] = value
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def _get_derived_cache(key: str) -> object | None:
|
| 265 |
+
"""Return a derived probe result and refresh its MRU position."""
|
| 266 |
+
|
| 267 |
+
cached = st.session_state.get(key)
|
| 268 |
+
if cached is None:
|
| 269 |
+
return None
|
| 270 |
+
tracked = st.session_state.get(_DERIVED_CACHE_TRACKER_KEY)
|
| 271 |
+
if isinstance(tracked, list) and key in tracked:
|
| 272 |
+
tracked = [existing for existing in tracked if existing != key]
|
| 273 |
+
tracked.append(key)
|
| 274 |
+
st.session_state[_DERIVED_CACHE_TRACKER_KEY] = tracked
|
| 275 |
+
return cached
|
| 276 |
+
|
| 277 |
+
|
| 278 |
def _classification_predictions(
|
| 279 |
probe: LoadedProbe, activations: torch.Tensor, cache_key: str
|
| 280 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 281 |
full_key = widget_key("probe_predictions", cache_key, str(id(probe)))
|
| 282 |
+
cached = _get_derived_cache(full_key)
|
| 283 |
if cached is not None:
|
| 284 |
return cached
|
| 285 |
_, probs, predicted = probe.run_batch(activations)
|
| 286 |
+
_store_derived_cache(full_key, (probs, predicted))
|
| 287 |
return probs, predicted
|
| 288 |
|
| 289 |
|
|
|
|
| 291 |
probe: LoadedProbe, activations: torch.Tensor, cache_key: str
|
| 292 |
) -> torch.Tensor:
|
| 293 |
full_key = widget_key("probe_values", cache_key, str(id(probe)))
|
| 294 |
+
cached = _get_derived_cache(full_key)
|
| 295 |
if cached is not None:
|
| 296 |
return cached
|
| 297 |
values = probe.predict_batch(activations)
|
| 298 |
+
_store_derived_cache(full_key, values)
|
| 299 |
return values
|
| 300 |
|
| 301 |
|
|
|
|
| 323 |
probs, predicted = _classification_predictions(
|
| 324 |
probe, trace.activations, trace.cache_key
|
| 325 |
)
|
| 326 |
+
binary = probs.shape[1] == 1 or (probs.shape[1] == 2 and len(probe.labels) == 2)
|
|
|
|
|
|
|
| 327 |
overlays = build_classification_overlays(
|
| 328 |
trace=trace,
|
| 329 |
probs=probs,
|
|
|
|
| 356 |
def _conversation_sig() -> int:
|
| 357 |
return hash(
|
| 358 |
tuple(
|
| 359 |
+
(m.get("role"), m.get("content")) for m in messages if m.get("content")
|
|
|
|
|
|
|
| 360 |
)
|
| 361 |
)
|
| 362 |
|
|
|
|
| 371 |
st.caption("Probe overlay shows up after the first assistant reply.")
|
| 372 |
return
|
| 373 |
|
| 374 |
+
probe = _render_probe_selector(context_key=context_key, model_name=model_name)
|
|
|
|
|
|
|
| 375 |
if probe is None:
|
| 376 |
_reset()
|
| 377 |
return
|
tests/test_datasets.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from utils import datasets
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class _Progress:
|
| 7 |
+
def __init__(self) -> None:
|
| 8 |
+
self.updates: list[tuple[float, str | None]] = []
|
| 9 |
+
|
| 10 |
+
def progress(self, value: float, *, text: str | None = None) -> None:
|
| 11 |
+
self.updates.append((value, text))
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def test_download_missing_startup_files_only_fetches_uncached_files(monkeypatch):
|
| 15 |
+
warnings: list[str] = []
|
| 16 |
+
progress = _Progress()
|
| 17 |
+
downloads: list[tuple[str, str, str]] = []
|
| 18 |
+
|
| 19 |
+
monkeypatch.setattr(
|
| 20 |
+
datasets,
|
| 21 |
+
"_is_cached",
|
| 22 |
+
lambda _repo, filename: filename == "already.jsonl",
|
| 23 |
+
)
|
| 24 |
+
monkeypatch.setattr(datasets.st, "warning", warnings.append)
|
| 25 |
+
monkeypatch.setattr(
|
| 26 |
+
datasets.st,
|
| 27 |
+
"progress",
|
| 28 |
+
lambda value, *, text=None: progress,
|
| 29 |
+
)
|
| 30 |
+
monkeypatch.setattr(
|
| 31 |
+
datasets,
|
| 32 |
+
"hf_hub_download",
|
| 33 |
+
lambda repo, filename, *, repo_type: downloads.append(
|
| 34 |
+
(repo, filename, repo_type)
|
| 35 |
+
),
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
datasets._download_missing_startup_files_if_needed(
|
| 39 |
+
"org/repo",
|
| 40 |
+
("already.jsonl", "missing.jsonl"),
|
| 41 |
+
"Example",
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
assert warnings and "First-time setup for Example" in warnings[0]
|
| 45 |
+
assert downloads == [("org/repo", "missing.jsonl", "dataset")]
|
| 46 |
+
assert progress.updates[-1] == (1.0, "Downloaded missing.jsonl (1/1)")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def test_download_missing_startup_files_stays_quiet_when_cached(monkeypatch):
|
| 50 |
+
monkeypatch.setattr(datasets, "_is_cached", lambda *_args: True)
|
| 51 |
+
|
| 52 |
+
def unexpected(*_args, **_kwargs):
|
| 53 |
+
raise AssertionError("cold-download UI should not render for warm cache")
|
| 54 |
+
|
| 55 |
+
monkeypatch.setattr(datasets.st, "warning", unexpected)
|
| 56 |
+
monkeypatch.setattr(datasets.st, "progress", unexpected)
|
| 57 |
+
monkeypatch.setattr(datasets, "hf_hub_download", unexpected)
|
| 58 |
+
|
| 59 |
+
datasets._download_missing_startup_files_if_needed(
|
| 60 |
+
"org/repo",
|
| 61 |
+
("cached.jsonl",),
|
| 62 |
+
"Example",
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def test_prepare_nemotron_prefetches_first_parquet_shard(monkeypatch):
|
| 67 |
+
calls: list[tuple[str, tuple[str, ...], str]] = []
|
| 68 |
+
monkeypatch.setattr(
|
| 69 |
+
datasets,
|
| 70 |
+
"list_repo_files",
|
| 71 |
+
lambda *_args, **_kwargs: (
|
| 72 |
+
"README.md",
|
| 73 |
+
"data/train-00001-of-00002.parquet",
|
| 74 |
+
"data/train-00000-of-00002.parquet",
|
| 75 |
+
),
|
| 76 |
+
)
|
| 77 |
+
monkeypatch.setattr(
|
| 78 |
+
datasets,
|
| 79 |
+
"_download_missing_startup_files_if_needed",
|
| 80 |
+
lambda repo, filenames, label: calls.append((repo, filenames, label)),
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
datasets._prepare_nemotron_startup_download(
|
| 84 |
+
datasets.DatasetSource.NEMOTRON_USA.value,
|
| 85 |
+
"Nemotron USA",
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
assert calls == [
|
| 89 |
+
(
|
| 90 |
+
"nvidia/Nemotron-Personas-USA",
|
| 91 |
+
("data/train-00000-of-00002.parquet",),
|
| 92 |
+
"Nemotron USA",
|
| 93 |
+
)
|
| 94 |
+
]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_warm_qa_makes_synth_qa_download_visible_before_thread(monkeypatch):
|
| 98 |
+
calls: list[tuple[str, tuple[str, ...], str]] = []
|
| 99 |
+
started: list[bool] = []
|
| 100 |
+
|
| 101 |
+
class DummySynth:
|
| 102 |
+
def prefetch_qa(self) -> None:
|
| 103 |
+
pass
|
| 104 |
+
|
| 105 |
+
class DummyThread:
|
| 106 |
+
def __init__(self, *args, **kwargs) -> None:
|
| 107 |
+
pass
|
| 108 |
+
|
| 109 |
+
def start(self) -> None:
|
| 110 |
+
started.append(True)
|
| 111 |
+
|
| 112 |
+
monkeypatch.setattr(datasets, "SynthPersonaDataset", DummySynth)
|
| 113 |
+
monkeypatch.setattr(
|
| 114 |
+
datasets,
|
| 115 |
+
"_download_missing_startup_files_if_needed",
|
| 116 |
+
lambda repo, filenames, label: calls.append((repo, filenames, label)),
|
| 117 |
+
)
|
| 118 |
+
monkeypatch.setattr(datasets.threading, "Thread", DummyThread)
|
| 119 |
+
|
| 120 |
+
datasets.warm_qa_in_background(DummySynth())
|
| 121 |
+
|
| 122 |
+
assert calls == [
|
| 123 |
+
(
|
| 124 |
+
"implicit-personalization/synth-persona",
|
| 125 |
+
("dataset_qa.jsonl",),
|
| 126 |
+
"SynthPersona QA",
|
| 127 |
+
)
|
| 128 |
+
]
|
| 129 |
+
assert started == [True]
|
tests/test_probe_cache_bounds.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from tabs import probe_ui
|
| 6 |
+
from utils import probe_trace
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def test_store_derived_cache_evicts_oldest(monkeypatch):
|
| 10 |
+
session_state: dict[str, object] = {}
|
| 11 |
+
monkeypatch.setattr(probe_ui.st, "session_state", session_state)
|
| 12 |
+
monkeypatch.setattr(probe_ui, "_DERIVED_CACHE_ENTRIES", 2)
|
| 13 |
+
|
| 14 |
+
probe_ui._store_derived_cache("k1", 1)
|
| 15 |
+
probe_ui._store_derived_cache("k2", 2)
|
| 16 |
+
probe_ui._store_derived_cache("k3", 3)
|
| 17 |
+
|
| 18 |
+
assert "k1" not in session_state
|
| 19 |
+
assert session_state["k2"] == 2
|
| 20 |
+
assert session_state["k3"] == 3
|
| 21 |
+
assert session_state[probe_ui._DERIVED_CACHE_TRACKER_KEY] == ["k2", "k3"]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_get_derived_cache_refreshes_recently_used_entry(monkeypatch):
|
| 25 |
+
session_state: dict[str, object] = {}
|
| 26 |
+
monkeypatch.setattr(probe_ui.st, "session_state", session_state)
|
| 27 |
+
monkeypatch.setattr(probe_ui, "_DERIVED_CACHE_ENTRIES", 2)
|
| 28 |
+
|
| 29 |
+
probe_ui._store_derived_cache("k1", 1)
|
| 30 |
+
probe_ui._store_derived_cache("k2", 2)
|
| 31 |
+
|
| 32 |
+
assert probe_ui._get_derived_cache("k1") == 1
|
| 33 |
+
probe_ui._store_derived_cache("k3", 3)
|
| 34 |
+
|
| 35 |
+
assert "k1" in session_state
|
| 36 |
+
assert "k2" not in session_state
|
| 37 |
+
assert session_state[probe_ui._DERIVED_CACHE_TRACKER_KEY] == ["k1", "k3"]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_trace_eviction_drops_derived_results(monkeypatch):
|
| 41 |
+
session_state: dict[str, object] = {}
|
| 42 |
+
monkeypatch.setattr(probe_trace.st, "session_state", session_state)
|
| 43 |
+
monkeypatch.setattr(probe_trace, "_MAX_CACHED_TRACES", 1)
|
| 44 |
+
|
| 45 |
+
trace = probe_trace.ConversationTrace(
|
| 46 |
+
cache_key="old",
|
| 47 |
+
model_name="m",
|
| 48 |
+
remote=False,
|
| 49 |
+
prompt_text="p",
|
| 50 |
+
prompt_hash="h",
|
| 51 |
+
layer=0,
|
| 52 |
+
location="post_reasoning",
|
| 53 |
+
input_ids=torch.tensor([1]),
|
| 54 |
+
activations=torch.zeros((1, 1)),
|
| 55 |
+
tokens=["x"],
|
| 56 |
+
assistant_spans=[],
|
| 57 |
+
is_special=torch.tensor([False]),
|
| 58 |
+
)
|
| 59 |
+
old_prediction_key = "probe_predictions::old::probe"
|
| 60 |
+
kept_prediction_key = "probe_predictions::new::probe"
|
| 61 |
+
session_state[probe_trace._DERIVED_CACHE_TRACKER_KEY] = [
|
| 62 |
+
old_prediction_key,
|
| 63 |
+
kept_prediction_key,
|
| 64 |
+
]
|
| 65 |
+
session_state[old_prediction_key] = object()
|
| 66 |
+
session_state[kept_prediction_key] = object()
|
| 67 |
+
|
| 68 |
+
probe_trace._store_cached_trace("old", trace)
|
| 69 |
+
probe_trace._store_cached_trace(
|
| 70 |
+
"new",
|
| 71 |
+
probe_trace.ConversationTrace(
|
| 72 |
+
**{**trace.__dict__, "cache_key": "new"},
|
| 73 |
+
),
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
assert old_prediction_key not in session_state
|
| 77 |
+
assert kept_prediction_key in session_state
|
| 78 |
+
assert session_state[probe_trace._DERIVED_CACHE_TRACKER_KEY] == [
|
| 79 |
+
kept_prediction_key
|
| 80 |
+
]
|
tests/test_probe_sweep.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from types import SimpleNamespace
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from persona_vectors.analysis import LayeredSamples
|
| 7 |
+
from persona_vectors.probes import AttributeLabels
|
| 8 |
+
|
| 9 |
+
from tabs import probe_sweep
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_cached_sweep_keeps_per_attribute_samples_and_full_plus_pca(monkeypatch):
|
| 13 |
+
samples = LayeredSamples(
|
| 14 |
+
vectors=torch.zeros((3, 2, 4)),
|
| 15 |
+
labels=["p0", "p1", "p2"],
|
| 16 |
+
hover_text=["p0", "p1", "p2"],
|
| 17 |
+
)
|
| 18 |
+
sweep_calls: list[tuple[str, int | None]] = []
|
| 19 |
+
|
| 20 |
+
monkeypatch.setattr(
|
| 21 |
+
probe_sweep,
|
| 22 |
+
"load_persona_vectors_cached",
|
| 23 |
+
lambda *args: samples,
|
| 24 |
+
)
|
| 25 |
+
monkeypatch.setattr(
|
| 26 |
+
probe_sweep,
|
| 27 |
+
"synth_persona_dataset_cached",
|
| 28 |
+
lambda: SimpleNamespace(),
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
def labels_for(_dataset, attribute, _persona_ids, *, task):
|
| 32 |
+
return AttributeLabels(
|
| 33 |
+
attribute_name=attribute,
|
| 34 |
+
task=task,
|
| 35 |
+
y=torch.tensor([0, 1, 0]).numpy(),
|
| 36 |
+
labels=["a", "b", "a"],
|
| 37 |
+
class_names=["a", "b"],
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
monkeypatch.setattr(probe_sweep, "attribute_probe_labels", labels_for)
|
| 41 |
+
|
| 42 |
+
def filtered(input_samples, labels, *, min_count):
|
| 43 |
+
assert min_count == 2
|
| 44 |
+
return input_samples, labels
|
| 45 |
+
|
| 46 |
+
monkeypatch.setattr(
|
| 47 |
+
probe_sweep,
|
| 48 |
+
"filter_attribute_samples_min_count",
|
| 49 |
+
filtered,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def sweep(input_samples, labels, *, layers, probe_kinds, n_pca_components, seed):
|
| 53 |
+
assert input_samples is samples
|
| 54 |
+
assert layers == [0, 1]
|
| 55 |
+
assert probe_kinds == ["logistic_regression"]
|
| 56 |
+
assert seed == 0
|
| 57 |
+
sweep_calls.append((labels.attribute_name, n_pca_components))
|
| 58 |
+
return [
|
| 59 |
+
{
|
| 60 |
+
"attribute": labels.attribute_name,
|
| 61 |
+
"layer": 0,
|
| 62 |
+
"probe_kind": probe_kinds[0],
|
| 63 |
+
"balanced_accuracy": 0.5,
|
| 64 |
+
}
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
monkeypatch.setattr(probe_sweep, "sweep_attribute", sweep)
|
| 68 |
+
|
| 69 |
+
inputs = probe_sweep.SweepInputs(
|
| 70 |
+
source="src",
|
| 71 |
+
location="loc",
|
| 72 |
+
model_name="model",
|
| 73 |
+
mask_value="answer_mean",
|
| 74 |
+
variant="templated",
|
| 75 |
+
persona_ids=("p0", "p1", "p2"),
|
| 76 |
+
attributes=("sex", "gender"),
|
| 77 |
+
task="binary",
|
| 78 |
+
probe_kinds=("logistic_regression",),
|
| 79 |
+
n_pca_components=2,
|
| 80 |
+
layers=(0, 1),
|
| 81 |
+
min_class_count=2,
|
| 82 |
+
seed=0,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
rows_by_label, per_attr = probe_sweep.cached_sweep.__wrapped__(inputs)
|
| 86 |
+
|
| 87 |
+
assert list(rows_by_label) == ["full", "pca2"]
|
| 88 |
+
assert [row["attribute"] for row in rows_by_label["full"]] == ["sex", "gender"]
|
| 89 |
+
assert set(per_attr) == {"sex", "gender"}
|
| 90 |
+
assert sweep_calls == [
|
| 91 |
+
("sex", None),
|
| 92 |
+
("gender", None),
|
| 93 |
+
("sex", 2),
|
| 94 |
+
("gender", 2),
|
| 95 |
+
]
|
tests/test_probes.py
CHANGED
|
@@ -11,17 +11,16 @@ two correctness fixes:
|
|
| 11 |
|
| 12 |
import pytest
|
| 13 |
import torch
|
| 14 |
-
|
| 15 |
from persona_vectors.probes import ProbeArtifact
|
|
|
|
|
|
|
| 16 |
from utils.probes import (
|
| 17 |
LoadedProbe,
|
| 18 |
_LinearProbe,
|
| 19 |
_loaded_probe_from_artifact,
|
| 20 |
_normalize_labels,
|
| 21 |
-
parse_probe_filename,
|
| 22 |
)
|
| 23 |
|
| 24 |
-
|
| 25 |
# --------------------------------------------------------------------------- #
|
| 26 |
# parse_probe_filename
|
| 27 |
# --------------------------------------------------------------------------- #
|
|
@@ -123,9 +122,7 @@ def test_normalize_batch_pca_only_applies_pca():
|
|
| 123 |
probe = _probe(
|
| 124 |
2,
|
| 125 |
pca_mean=torch.ones(3),
|
| 126 |
-
pca_components=torch.tensor(
|
| 127 |
-
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
|
| 128 |
-
),
|
| 129 |
)
|
| 130 |
batch = torch.tensor([[2.0, 4.0, 9.0]])
|
| 131 |
out = probe._normalize_batch(batch)
|
|
|
|
| 11 |
|
| 12 |
import pytest
|
| 13 |
import torch
|
|
|
|
| 14 |
from persona_vectors.probes import ProbeArtifact
|
| 15 |
+
|
| 16 |
+
from utils.probe_files import parse_probe_filename
|
| 17 |
from utils.probes import (
|
| 18 |
LoadedProbe,
|
| 19 |
_LinearProbe,
|
| 20 |
_loaded_probe_from_artifact,
|
| 21 |
_normalize_labels,
|
|
|
|
| 22 |
)
|
| 23 |
|
|
|
|
| 24 |
# --------------------------------------------------------------------------- #
|
| 25 |
# parse_probe_filename
|
| 26 |
# --------------------------------------------------------------------------- #
|
|
|
|
| 122 |
probe = _probe(
|
| 123 |
2,
|
| 124 |
pca_mean=torch.ones(3),
|
| 125 |
+
pca_components=torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
|
|
|
|
|
|
|
| 126 |
)
|
| 127 |
batch = torch.tensor([[2.0, 4.0, 9.0]])
|
| 128 |
out = probe._normalize_batch(batch)
|
tests/test_state.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from state import chat_session_key
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def test_chat_session_key_is_stable_across_model_switches() -> None:
|
| 5 |
+
dataset = "HuggingFace: synth-persona"
|
| 6 |
+
|
| 7 |
+
assert chat_session_key("google/gemma-2-2b-it", dataset) == chat_session_key(
|
| 8 |
+
"google/gemma-2-9b-it",
|
| 9 |
+
dataset,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_chat_session_key_still_separates_datasets() -> None:
|
| 14 |
+
model = "google/gemma-2-2b-it"
|
| 15 |
+
|
| 16 |
+
assert chat_session_key(model, "dataset-a") != chat_session_key(model, "dataset-b")
|
utils/analysis_sources.py
CHANGED
|
@@ -7,8 +7,8 @@ from persona_vectors.analysis import (
|
|
| 7 |
load_analysis_dataset,
|
| 8 |
)
|
| 9 |
from persona_vectors.artifacts import (
|
| 10 |
-
PersonaVectorStore,
|
| 11 |
HFPersonaVectorStore,
|
|
|
|
| 12 |
discover_activation_models,
|
| 13 |
model_dir_name,
|
| 14 |
)
|
|
|
|
| 7 |
load_analysis_dataset,
|
| 8 |
)
|
| 9 |
from persona_vectors.artifacts import (
|
|
|
|
| 10 |
HFPersonaVectorStore,
|
| 11 |
+
PersonaVectorStore,
|
| 12 |
discover_activation_models,
|
| 13 |
model_dir_name,
|
| 14 |
)
|
utils/chat.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import logging
|
|
|
|
| 4 |
from contextlib import contextmanager, nullcontext
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import TYPE_CHECKING, Any, Literal
|
|
@@ -185,6 +186,7 @@ def generate_chat_reply(
|
|
| 185 |
top_k: int = 50,
|
| 186 |
repetition_penalty: float = 1.0,
|
| 187 |
seed: int | None = None,
|
|
|
|
| 188 |
) -> ChatReply:
|
| 189 |
"""Generate one assistant reply from a full chat history.
|
| 190 |
|
|
@@ -228,9 +230,16 @@ def generate_chat_reply(
|
|
| 228 |
generation_kwargs["repetition_penalty"] = repetition_penalty
|
| 229 |
# `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
|
| 230 |
# forwarded to the underlying model's generate
|
|
|
|
|
|
|
| 231 |
with (
|
| 232 |
_seeded_rng(seed if do_sample and not remote else None),
|
| 233 |
-
model.generate(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
):
|
| 235 |
generated = tracer.result.save()
|
| 236 |
|
|
@@ -247,3 +256,34 @@ def generate_chat_reply(
|
|
| 247 |
text=text,
|
| 248 |
generated_ids=generated_ids.detach().cpu(),
|
| 249 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import logging
|
| 4 |
+
from collections.abc import Callable
|
| 5 |
from contextlib import contextmanager, nullcontext
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from typing import TYPE_CHECKING, Any, Literal
|
|
|
|
| 186 |
top_k: int = 50,
|
| 187 |
repetition_penalty: float = 1.0,
|
| 188 |
seed: int | None = None,
|
| 189 |
+
on_status: Callable[[str, str, str], None] | None = None,
|
| 190 |
) -> ChatReply:
|
| 191 |
"""Generate one assistant reply from a full chat history.
|
| 192 |
|
|
|
|
| 230 |
generation_kwargs["repetition_penalty"] = repetition_penalty
|
| 231 |
# `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
|
| 232 |
# forwarded to the underlying model's generate
|
| 233 |
+
backend = _build_remote_backend(model, on_status) if remote else None
|
| 234 |
+
|
| 235 |
with (
|
| 236 |
_seeded_rng(seed if do_sample and not remote else None),
|
| 237 |
+
model.generate(
|
| 238 |
+
prompt,
|
| 239 |
+
remote=remote,
|
| 240 |
+
backend=backend,
|
| 241 |
+
**generation_kwargs,
|
| 242 |
+
) as tracer,
|
| 243 |
):
|
| 244 |
generated = tracer.result.save()
|
| 245 |
|
|
|
|
| 256 |
text=text,
|
| 257 |
generated_ids=generated_ids.detach().cpu(),
|
| 258 |
)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def _build_remote_backend(
|
| 262 |
+
model: StandardizedTransformer,
|
| 263 |
+
on_status: Callable[[str, str, str], None] | None,
|
| 264 |
+
):
|
| 265 |
+
"""Build an NDIF backend that can surface lifecycle updates to callers."""
|
| 266 |
+
|
| 267 |
+
if on_status is None:
|
| 268 |
+
return None
|
| 269 |
+
|
| 270 |
+
from nnsight.intervention.backends.remote import JobStatusDisplay, RemoteBackend
|
| 271 |
+
|
| 272 |
+
class _CallbackJobStatusDisplay(JobStatusDisplay):
|
| 273 |
+
def update(
|
| 274 |
+
self,
|
| 275 |
+
job_id: str = "",
|
| 276 |
+
status_name: str = "",
|
| 277 |
+
description: str = "",
|
| 278 |
+
):
|
| 279 |
+
super().update(job_id, status_name, description)
|
| 280 |
+
if status_name:
|
| 281 |
+
on_status(job_id, status_name, description)
|
| 282 |
+
|
| 283 |
+
backend = RemoteBackend(model.to_model_key())
|
| 284 |
+
backend.CONNECT_TIMEOUT = 300.0
|
| 285 |
+
backend.status_display = _CallbackJobStatusDisplay(
|
| 286 |
+
enabled=True,
|
| 287 |
+
verbose=backend.verbose,
|
| 288 |
+
)
|
| 289 |
+
return backend
|
utils/contrast.py
CHANGED
|
@@ -247,9 +247,7 @@ def render_contrast_html(result: TokenContrast) -> str:
|
|
| 247 |
# those render as blank lines before the first word. Drop leading
|
| 248 |
# whitespace-only tokens (and left-trim the first visible one) so the
|
| 249 |
# contrast starts at real content. Display-only — weights stay aligned.
|
| 250 |
-
items = list(
|
| 251 |
-
zip(result.tokens, result.weights, result.raw_diffs, strict=True)
|
| 252 |
-
)
|
| 253 |
start = 0
|
| 254 |
while start < len(items) and not items[start][0].strip():
|
| 255 |
start += 1
|
|
|
|
| 247 |
# those render as blank lines before the first word. Drop leading
|
| 248 |
# whitespace-only tokens (and left-trim the first visible one) so the
|
| 249 |
# contrast starts at real content. Display-only — weights stay aligned.
|
| 250 |
+
items = list(zip(result.tokens, result.weights, result.raw_diffs, strict=True))
|
|
|
|
|
|
|
| 251 |
start = 0
|
| 252 |
while start < len(items) and not items[start][0].strip():
|
| 253 |
start += 1
|
utils/controls.py
CHANGED
|
@@ -7,8 +7,12 @@ def render_mask_strategy_select(
|
|
| 7 |
key: str,
|
| 8 |
last_key: str,
|
| 9 |
help_text: str,
|
|
|
|
| 10 |
) -> MaskStrategy:
|
| 11 |
-
last_strategy = st.session_state.get(
|
|
|
|
|
|
|
|
|
|
| 12 |
strategies = list(MaskStrategy)
|
| 13 |
selected = st.selectbox(
|
| 14 |
"Mask strategy",
|
|
@@ -26,4 +30,6 @@ def render_mask_strategy_select(
|
|
| 26 |
help=help_text,
|
| 27 |
)
|
| 28 |
st.session_state[last_key] = selected.value
|
|
|
|
|
|
|
| 29 |
return selected
|
|
|
|
| 7 |
key: str,
|
| 8 |
last_key: str,
|
| 9 |
help_text: str,
|
| 10 |
+
remember_key: str | None = None,
|
| 11 |
) -> MaskStrategy:
|
| 12 |
+
last_strategy = st.session_state.get(
|
| 13 |
+
remember_key,
|
| 14 |
+
st.session_state.get(last_key, MaskStrategy.ANSWER_MEAN.value),
|
| 15 |
+
)
|
| 16 |
strategies = list(MaskStrategy)
|
| 17 |
selected = st.selectbox(
|
| 18 |
"Mask strategy",
|
|
|
|
| 30 |
help=help_text,
|
| 31 |
)
|
| 32 |
st.session_state[last_key] = selected.value
|
| 33 |
+
if remember_key is not None:
|
| 34 |
+
st.session_state[remember_key] = selected.value
|
| 35 |
return selected
|
utils/datasets.py
CHANGED
|
@@ -7,6 +7,7 @@ from tempfile import mkdtemp
|
|
| 7 |
from typing import Any
|
| 8 |
|
| 9 |
import streamlit as st
|
|
|
|
| 10 |
from persona_data.nemotron_personas import (
|
| 11 |
NemotronPersonasFranceDataset,
|
| 12 |
NemotronPersonasUSADataset,
|
|
@@ -16,6 +17,17 @@ from persona_data.synth_persona import SynthPersonaDataset
|
|
| 16 |
|
| 17 |
from .helpers import DatasetSource
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
@st.cache_resource(show_spinner=False)
|
| 21 |
def _cached_dataset(cls: type) -> Any:
|
|
@@ -39,13 +51,19 @@ def warm_qa_in_background(dataset: Any) -> None:
|
|
| 39 |
warm = getattr(dataset, "prefetch_qa", None)
|
| 40 |
if warm is None:
|
| 41 |
return # persona-only dataset (e.g. Nemotron) has no QA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
with _qa_warm_lock:
|
| 43 |
if getattr(dataset, "_qa_warm_started", False):
|
| 44 |
return
|
| 45 |
dataset._qa_warm_started = True
|
| 46 |
-
threading.Thread(
|
| 47 |
-
target=warm, name="persona-ui-warm-qa", daemon=True
|
| 48 |
-
).start()
|
| 49 |
|
| 50 |
|
| 51 |
@st.cache_resource(show_spinner=False)
|
|
@@ -118,12 +136,19 @@ def load_dataset(
|
|
| 118 |
"""Load the selected dataset source for the UI."""
|
| 119 |
|
| 120 |
if dataset_source == DatasetSource.SYNTH_PERSONA.value:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
return _cached_dataset(SynthPersonaDataset), "SynthPersona"
|
| 122 |
|
| 123 |
if dataset_source == DatasetSource.NEMOTRON_FRANCE.value:
|
|
|
|
| 124 |
return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"
|
| 125 |
|
| 126 |
if dataset_source == DatasetSource.NEMOTRON_USA.value:
|
|
|
|
| 127 |
return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"
|
| 128 |
|
| 129 |
if personas_file is None or qa_file is None:
|
|
@@ -132,3 +157,60 @@ def load_dataset(
|
|
| 132 |
personas_path = _uploaded_file_to_temp_path(personas_file, stem="personas")
|
| 133 |
qa_path = _uploaded_file_to_temp_path(qa_file, stem="qa")
|
| 134 |
return _cached_local_dataset(str(personas_path), str(qa_path)), "Local upload"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from typing import Any
|
| 8 |
|
| 9 |
import streamlit as st
|
| 10 |
+
from huggingface_hub import hf_hub_download, list_repo_files, try_to_load_from_cache
|
| 11 |
from persona_data.nemotron_personas import (
|
| 12 |
NemotronPersonasFranceDataset,
|
| 13 |
NemotronPersonasUSADataset,
|
|
|
|
| 17 |
|
| 18 |
from .helpers import DatasetSource
|
| 19 |
|
| 20 |
+
_SYNTH_PERSONA_REPO = "implicit-personalization/synth-persona"
|
| 21 |
+
_SYNTH_PERSONA_STARTUP_FILES = (
|
| 22 |
+
"implicit_shared_mc_bank.json",
|
| 23 |
+
"dataset_personas.jsonl",
|
| 24 |
+
)
|
| 25 |
+
_SYNTH_PERSONA_QA_FILE = "dataset_qa.jsonl"
|
| 26 |
+
_NEMOTRON_REPOS = {
|
| 27 |
+
DatasetSource.NEMOTRON_FRANCE.value: "nvidia/Nemotron-Personas-France",
|
| 28 |
+
DatasetSource.NEMOTRON_USA.value: "nvidia/Nemotron-Personas-USA",
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
|
| 32 |
@st.cache_resource(show_spinner=False)
|
| 33 |
def _cached_dataset(cls: type) -> Any:
|
|
|
|
| 51 |
warm = getattr(dataset, "prefetch_qa", None)
|
| 52 |
if warm is None:
|
| 53 |
return # persona-only dataset (e.g. Nemotron) has no QA
|
| 54 |
+
if isinstance(dataset, SynthPersonaDataset):
|
| 55 |
+
# Extract will need QA soon. Make the one-time large transfer explicit,
|
| 56 |
+
# then leave the CPU-heavy parse on the existing background thread.
|
| 57 |
+
_download_missing_startup_files_if_needed(
|
| 58 |
+
_SYNTH_PERSONA_REPO,
|
| 59 |
+
(_SYNTH_PERSONA_QA_FILE,),
|
| 60 |
+
"SynthPersona QA",
|
| 61 |
+
)
|
| 62 |
with _qa_warm_lock:
|
| 63 |
if getattr(dataset, "_qa_warm_started", False):
|
| 64 |
return
|
| 65 |
dataset._qa_warm_started = True
|
| 66 |
+
threading.Thread(target=warm, name="persona-ui-warm-qa", daemon=True).start()
|
|
|
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
@st.cache_resource(show_spinner=False)
|
|
|
|
| 136 |
"""Load the selected dataset source for the UI."""
|
| 137 |
|
| 138 |
if dataset_source == DatasetSource.SYNTH_PERSONA.value:
|
| 139 |
+
_download_missing_startup_files_if_needed(
|
| 140 |
+
_SYNTH_PERSONA_REPO,
|
| 141 |
+
_SYNTH_PERSONA_STARTUP_FILES,
|
| 142 |
+
"SynthPersona",
|
| 143 |
+
)
|
| 144 |
return _cached_dataset(SynthPersonaDataset), "SynthPersona"
|
| 145 |
|
| 146 |
if dataset_source == DatasetSource.NEMOTRON_FRANCE.value:
|
| 147 |
+
_prepare_nemotron_startup_download(dataset_source, "Nemotron France")
|
| 148 |
return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"
|
| 149 |
|
| 150 |
if dataset_source == DatasetSource.NEMOTRON_USA.value:
|
| 151 |
+
_prepare_nemotron_startup_download(dataset_source, "Nemotron USA")
|
| 152 |
return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"
|
| 153 |
|
| 154 |
if personas_file is None or qa_file is None:
|
|
|
|
| 157 |
personas_path = _uploaded_file_to_temp_path(personas_file, stem="personas")
|
| 158 |
qa_path = _uploaded_file_to_temp_path(qa_file, stem="qa")
|
| 159 |
return _cached_local_dataset(str(personas_path), str(qa_path)), "Local upload"
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _is_cached(repo_id: str, filename: str) -> bool:
|
| 163 |
+
"""Return whether a Hub dataset file already exists in the local HF cache."""
|
| 164 |
+
|
| 165 |
+
cached = try_to_load_from_cache(repo_id, filename, repo_type="dataset")
|
| 166 |
+
return isinstance(cached, str)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def _download_missing_startup_files_if_needed(
|
| 170 |
+
repo_id: str,
|
| 171 |
+
filenames: tuple[str, ...],
|
| 172 |
+
label: str,
|
| 173 |
+
) -> None:
|
| 174 |
+
"""Make first-time Hub downloads visible before dataset construction blocks.
|
| 175 |
+
|
| 176 |
+
Hugging Face handles byte-level transfer internally. We expose file-level
|
| 177 |
+
progress here, which is the useful unit this UI can know in advance.
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
missing = tuple(
|
| 181 |
+
filename for filename in filenames if not _is_cached(repo_id, filename)
|
| 182 |
+
)
|
| 183 |
+
if not missing:
|
| 184 |
+
return
|
| 185 |
+
|
| 186 |
+
st.warning(
|
| 187 |
+
f"First-time setup for {label}: downloading dataset files from Hugging Face. "
|
| 188 |
+
"Later loads should use the local cache."
|
| 189 |
+
)
|
| 190 |
+
progress = st.progress(0.0, text=f"Preparing {label} download…")
|
| 191 |
+
total = len(missing)
|
| 192 |
+
for index, filename in enumerate(missing, start=1):
|
| 193 |
+
progress.progress(
|
| 194 |
+
(index - 1) / total,
|
| 195 |
+
text=f"Downloading {filename} ({index}/{total})",
|
| 196 |
+
)
|
| 197 |
+
hf_hub_download(repo_id, filename, repo_type="dataset")
|
| 198 |
+
progress.progress(
|
| 199 |
+
index / total,
|
| 200 |
+
text=f"Downloaded {filename} ({index}/{total})",
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _prepare_nemotron_startup_download(dataset_source: str, label: str) -> None:
|
| 205 |
+
"""Prefetch the first parquet shard used by the default Nemotron sample."""
|
| 206 |
+
|
| 207 |
+
repo_id = _NEMOTRON_REPOS[dataset_source]
|
| 208 |
+
parquet_files = tuple(
|
| 209 |
+
sorted(
|
| 210 |
+
filename
|
| 211 |
+
for filename in list_repo_files(repo_id, repo_type="dataset")
|
| 212 |
+
if filename.startswith("data/train-") and filename.endswith(".parquet")
|
| 213 |
+
)
|
| 214 |
+
)
|
| 215 |
+
if parquet_files:
|
| 216 |
+
_download_missing_startup_files_if_needed(repo_id, (parquet_files[0],), label)
|
utils/helpers.py
CHANGED
|
@@ -64,6 +64,26 @@ NDIF_STATUS_ICONS = {
|
|
| 64 |
}
|
| 65 |
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
def slugify(value: str) -> str:
|
| 68 |
"""Convert a string to a filesystem-safe slug."""
|
| 69 |
|
|
|
|
| 64 |
}
|
| 65 |
|
| 66 |
|
| 67 |
+
def format_ndif_status(
|
| 68 |
+
job_id: str,
|
| 69 |
+
status_name: str,
|
| 70 |
+
description: str,
|
| 71 |
+
*,
|
| 72 |
+
prefix: str | None = None,
|
| 73 |
+
completed_detail: str | None = None,
|
| 74 |
+
) -> str:
|
| 75 |
+
"""Build the shared one-line NDIF status label used across the UI."""
|
| 76 |
+
|
| 77 |
+
icon = NDIF_STATUS_ICONS.get(status_name, "•")
|
| 78 |
+
detail = (
|
| 79 |
+
completed_detail
|
| 80 |
+
if completed_detail is not None and status_name == "COMPLETED"
|
| 81 |
+
else description
|
| 82 |
+
)
|
| 83 |
+
label = f"{icon} `{job_id}` **{status_name}** — {detail}"
|
| 84 |
+
return f"{prefix}: {label}" if prefix else label
|
| 85 |
+
|
| 86 |
+
|
| 87 |
def slugify(value: str) -> str:
|
| 88 |
"""Convert a string to a filesystem-safe slug."""
|
| 89 |
|
utils/probe_files.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
import streamlit as st
|
| 9 |
+
|
| 10 |
+
PROBE_FILENAME_RE = re.compile(
|
| 11 |
+
r"^cognitive_map_probe_layer(?P<layer>\d+)_(?P<model_type>[a-z0-9]+)_"
|
| 12 |
+
r"(?P<location>pre_reasoning|post_reasoning)_all_(?P<scope>general|size\d+)\.pt$"
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
PERSONA_PROBE_DIR_RE = re.compile(
|
| 16 |
+
r"^(?P<probe_kind>[a-z_]+?)(?:_pca(?P<pca>\d+))?_layer(?P<layer>\d+)$"
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
DEFAULT_PROBE_REPO = "project-telos/cognitive_map_probes"
|
| 20 |
+
DEFAULT_LOCAL_PROBE_DIR = os.environ.get("PERSONA_PROBES_DIR", "artifacts/probes")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass(frozen=True)
|
| 24 |
+
class ProbeFileMetadata:
|
| 25 |
+
filename: str
|
| 26 |
+
layer: int | None
|
| 27 |
+
model_type: str
|
| 28 |
+
location: str | None
|
| 29 |
+
scope: str | None
|
| 30 |
+
label: str
|
| 31 |
+
model_name: str | None = None
|
| 32 |
+
attribute_name: str | None = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def model_probe_dir_name(model_name: str) -> str:
|
| 36 |
+
return model_name.replace("/", "__")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def parse_probe_filename(filename: str) -> ProbeFileMetadata:
|
| 40 |
+
path = Path(filename)
|
| 41 |
+
match = PROBE_FILENAME_RE.match(path.name)
|
| 42 |
+
if match:
|
| 43 |
+
layer = int(match.group("layer"))
|
| 44 |
+
model_type = match.group("model_type")
|
| 45 |
+
location = match.group("location")
|
| 46 |
+
scope = match.group("scope")
|
| 47 |
+
scope_label = scope.replace("size", "size ")
|
| 48 |
+
return ProbeFileMetadata(
|
| 49 |
+
filename=filename,
|
| 50 |
+
layer=layer,
|
| 51 |
+
model_type=model_type,
|
| 52 |
+
location=location,
|
| 53 |
+
scope=scope,
|
| 54 |
+
label=(
|
| 55 |
+
f"Layer {layer} - {model_type.upper()} - "
|
| 56 |
+
f"{location.replace('_', ' ')} - {scope_label}"
|
| 57 |
+
),
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
parent_match = PERSONA_PROBE_DIR_RE.match(path.parent.name)
|
| 61 |
+
if parent_match and path.name in {"probe.json", "weights.safetensors"}:
|
| 62 |
+
layer = int(parent_match.group("layer"))
|
| 63 |
+
probe_kind = parent_match.group("probe_kind")
|
| 64 |
+
pca = parent_match.group("pca")
|
| 65 |
+
scope = f"pca{pca}" if pca else None
|
| 66 |
+
attribute = path.parent.parent.name or "attribute"
|
| 67 |
+
model_name = path.parts[0].replace("__", "/") if len(path.parts) >= 5 else None
|
| 68 |
+
label = f"{attribute} - layer {layer} - {probe_kind}"
|
| 69 |
+
if pca:
|
| 70 |
+
label += f" (pca{pca})"
|
| 71 |
+
return ProbeFileMetadata(
|
| 72 |
+
filename=filename,
|
| 73 |
+
layer=layer,
|
| 74 |
+
model_type=probe_kind,
|
| 75 |
+
location=None,
|
| 76 |
+
scope=scope,
|
| 77 |
+
label=label,
|
| 78 |
+
model_name=model_name,
|
| 79 |
+
attribute_name=attribute,
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
return ProbeFileMetadata(
|
| 83 |
+
filename=filename,
|
| 84 |
+
layer=None,
|
| 85 |
+
model_type="unknown",
|
| 86 |
+
location=None,
|
| 87 |
+
scope=None,
|
| 88 |
+
label=path.stem.replace("_", " "),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@st.cache_data(show_spinner=False, ttl=300)
|
| 93 |
+
def list_probe_files(repo_id: str) -> list[str]:
|
| 94 |
+
from huggingface_hub import list_repo_files
|
| 95 |
+
|
| 96 |
+
return _dedupe_probe_entries(list_repo_files(repo_id, repo_type="model"))
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@st.cache_data(show_spinner=False, ttl=30)
|
| 100 |
+
def list_local_probe_files(root_dir: str) -> list[str]:
|
| 101 |
+
root = Path(root_dir).expanduser()
|
| 102 |
+
if not root.is_dir():
|
| 103 |
+
return []
|
| 104 |
+
files = _dedupe_probe_entries(
|
| 105 |
+
[
|
| 106 |
+
str(path.relative_to(root))
|
| 107 |
+
for path in root.rglob("*")
|
| 108 |
+
if path.is_file()
|
| 109 |
+
and path.name in {"probe.pt", "probe.json", "weights.safetensors"}
|
| 110 |
+
]
|
| 111 |
+
)
|
| 112 |
+
return sorted(files, key=_probe_sort_key)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@st.cache_data(show_spinner=False, ttl=300)
|
| 116 |
+
def download_probe_file(repo_id: str, filename: str) -> str:
|
| 117 |
+
from huggingface_hub import hf_hub_download
|
| 118 |
+
|
| 119 |
+
return hf_hub_download(repo_id, filename, repo_type="model")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
@st.cache_data(show_spinner=False, ttl=300)
|
| 123 |
+
def download_probe_json_and_weights(repo_id: str, filename: str) -> tuple[str, str]:
|
| 124 |
+
from huggingface_hub import hf_hub_download
|
| 125 |
+
|
| 126 |
+
metadata_path = hf_hub_download(repo_id, filename, repo_type="model")
|
| 127 |
+
weights_name = str(Path(filename).with_name("weights.safetensors"))
|
| 128 |
+
weights_path = hf_hub_download(repo_id, weights_name, repo_type="model")
|
| 129 |
+
return metadata_path, weights_path
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def _probe_sort_key(filename: str) -> tuple[str, str, int, str]:
|
| 133 |
+
metadata = parse_probe_filename(filename)
|
| 134 |
+
return (
|
| 135 |
+
metadata.model_name or "",
|
| 136 |
+
metadata.attribute_name or "",
|
| 137 |
+
metadata.layer if metadata.layer is not None else 10**9,
|
| 138 |
+
filename,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _dedupe_probe_entries(files: list[str]) -> list[str]:
|
| 143 |
+
by_dir: dict[str, set[str]] = {}
|
| 144 |
+
standalone: list[str] = []
|
| 145 |
+
for filename in files:
|
| 146 |
+
path = Path(filename)
|
| 147 |
+
if path.name in {"probe.pt", "probe.json", "weights.safetensors"}:
|
| 148 |
+
by_dir.setdefault(str(path.parent), set()).add(path.name)
|
| 149 |
+
elif filename.endswith(".pt"):
|
| 150 |
+
standalone.append(filename)
|
| 151 |
+
|
| 152 |
+
entries = list(standalone)
|
| 153 |
+
for directory, names in by_dir.items():
|
| 154 |
+
selected = (
|
| 155 |
+
"probe.json"
|
| 156 |
+
if "probe.json" in names
|
| 157 |
+
else "probe.pt"
|
| 158 |
+
if "probe.pt" in names
|
| 159 |
+
else "weights.safetensors"
|
| 160 |
+
)
|
| 161 |
+
entries.append(str(Path(directory) / selected))
|
| 162 |
+
return sorted(entries, key=_probe_sort_key)
|
utils/probe_overlay.py
CHANGED
|
@@ -124,18 +124,14 @@ def build_regression_overlays(
|
|
| 124 |
return overlays
|
| 125 |
|
| 126 |
|
| 127 |
-
def attach_overlays(
|
| 128 |
-
messages: list[dict], overlays: list[ProbeOverlay]
|
| 129 |
-
) -> None:
|
| 130 |
"""Attach one overlay to each assistant message, in order.
|
| 131 |
|
| 132 |
Requires a 1:1 match. If the counts don't line up (e.g. the chat template
|
| 133 |
doesn't mark assistant tokens), clear overlays so the caller can show a
|
| 134 |
clear status instead of painting the wrong message.
|
| 135 |
"""
|
| 136 |
-
assistant_idxs = [
|
| 137 |
-
i for i, m in enumerate(messages) if m.get("role") == "assistant"
|
| 138 |
-
]
|
| 139 |
clear_overlays(messages)
|
| 140 |
if not assistant_idxs or len(overlays) != len(assistant_idxs):
|
| 141 |
return
|
|
@@ -189,8 +185,7 @@ def _tooltip(probs_row: list[float], labels: list[str | None]) -> str:
|
|
| 189 |
# Single-output sigmoid: synthesize the complementary class so the
|
| 190 |
# hover shows both label probabilities, not just one.
|
| 191 |
return escape(
|
| 192 |
-
f"{positive_label} {positive:.2f} · "
|
| 193 |
-
f"not {positive_label} {1 - positive:.2f}"
|
| 194 |
)
|
| 195 |
ranked = sorted(enumerate(probs_row), key=lambda item: item[1], reverse=True)
|
| 196 |
parts = [f"{_label_for(labels, idx)} {prob:.2f}" for idx, prob in ranked]
|
|
|
|
| 124 |
return overlays
|
| 125 |
|
| 126 |
|
| 127 |
+
def attach_overlays(messages: list[dict], overlays: list[ProbeOverlay]) -> None:
|
|
|
|
|
|
|
| 128 |
"""Attach one overlay to each assistant message, in order.
|
| 129 |
|
| 130 |
Requires a 1:1 match. If the counts don't line up (e.g. the chat template
|
| 131 |
doesn't mark assistant tokens), clear overlays so the caller can show a
|
| 132 |
clear status instead of painting the wrong message.
|
| 133 |
"""
|
| 134 |
+
assistant_idxs = [i for i, m in enumerate(messages) if m.get("role") == "assistant"]
|
|
|
|
|
|
|
| 135 |
clear_overlays(messages)
|
| 136 |
if not assistant_idxs or len(overlays) != len(assistant_idxs):
|
| 137 |
return
|
|
|
|
| 185 |
# Single-output sigmoid: synthesize the complementary class so the
|
| 186 |
# hover shows both label probabilities, not just one.
|
| 187 |
return escape(
|
| 188 |
+
f"{positive_label} {positive:.2f} · not {positive_label} {1 - positive:.2f}"
|
|
|
|
| 189 |
)
|
| 190 |
ranked = sorted(enumerate(probs_row), key=lambda item: item[1], reverse=True)
|
| 191 |
parts = [f"{_label_for(labels, idx)} {prob:.2f}" for idx, prob in ranked]
|
utils/probe_trace.py
CHANGED
|
@@ -11,6 +11,7 @@ from persona_data.prompts import normalize_messages, supports_system_role
|
|
| 11 |
from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor
|
| 12 |
|
| 13 |
_TRACE_CACHE_KEY = "probe:trace_cache"
|
|
|
|
| 14 |
_MAX_CACHED_TRACES = 3
|
| 15 |
|
| 16 |
|
|
@@ -92,9 +93,7 @@ def trace_conversation(
|
|
| 92 |
|
| 93 |
n_tokens = int(input_ids.shape[0])
|
| 94 |
assistant_spans = _clip_spans(
|
| 95 |
-
_assistant_spans_from_offsets(
|
| 96 |
-
model.tokenizer, prompt_text, messages, n_tokens
|
| 97 |
-
),
|
| 98 |
n_tokens,
|
| 99 |
)
|
| 100 |
if not assistant_spans and assistant_mask_seq is not None:
|
|
@@ -182,6 +181,30 @@ def _store_cached_trace(cache_key: str, trace: ConversationTrace) -> None:
|
|
| 182 |
while len(cache) > _MAX_CACHED_TRACES:
|
| 183 |
oldest_key = next(iter(cache))
|
| 184 |
cache.pop(oldest_key, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
|
| 187 |
def _compute_assistant_mask(
|
|
@@ -302,9 +325,7 @@ def _assistant_spans_from_prefixes(
|
|
| 302 |
for i, message in enumerate(messages):
|
| 303 |
if message.get("role") != "assistant":
|
| 304 |
continue
|
| 305 |
-
prefix_ids = apply(
|
| 306 |
-
messages[:i], tokenize=True, add_generation_prompt=True
|
| 307 |
-
)
|
| 308 |
through_ids = apply(
|
| 309 |
messages[: i + 1], tokenize=True, add_generation_prompt=False
|
| 310 |
)
|
|
@@ -332,9 +353,7 @@ def _flatten_ids(value: object) -> list[int] | None:
|
|
| 332 |
return None
|
| 333 |
|
| 334 |
|
| 335 |
-
def _clip_spans(
|
| 336 |
-
spans: list[tuple[int, int]], n_tokens: int
|
| 337 |
-
) -> list[tuple[int, int]]:
|
| 338 |
clipped: list[tuple[int, int]] = []
|
| 339 |
for start, end in spans:
|
| 340 |
s = max(0, min(start, n_tokens))
|
|
|
|
| 11 |
from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor
|
| 12 |
|
| 13 |
_TRACE_CACHE_KEY = "probe:trace_cache"
|
| 14 |
+
_DERIVED_CACHE_TRACKER_KEY = "probe:derived_cache_keys"
|
| 15 |
_MAX_CACHED_TRACES = 3
|
| 16 |
|
| 17 |
|
|
|
|
| 93 |
|
| 94 |
n_tokens = int(input_ids.shape[0])
|
| 95 |
assistant_spans = _clip_spans(
|
| 96 |
+
_assistant_spans_from_offsets(model.tokenizer, prompt_text, messages, n_tokens),
|
|
|
|
|
|
|
| 97 |
n_tokens,
|
| 98 |
)
|
| 99 |
if not assistant_spans and assistant_mask_seq is not None:
|
|
|
|
| 181 |
while len(cache) > _MAX_CACHED_TRACES:
|
| 182 |
oldest_key = next(iter(cache))
|
| 183 |
cache.pop(oldest_key, None)
|
| 184 |
+
_drop_derived_results_for_trace(oldest_key)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def _drop_derived_results_for_trace(cache_key: str) -> None:
|
| 188 |
+
"""Remove probe predictions tied to a trace that just aged out."""
|
| 189 |
+
|
| 190 |
+
prefixes = (
|
| 191 |
+
f"probe_predictions::{cache_key}::",
|
| 192 |
+
f"probe_values::{cache_key}::",
|
| 193 |
+
)
|
| 194 |
+
tracked = st.session_state.get(_DERIVED_CACHE_TRACKER_KEY)
|
| 195 |
+
if isinstance(tracked, list):
|
| 196 |
+
kept: list[str] = []
|
| 197 |
+
for key in tracked:
|
| 198 |
+
if isinstance(key, str) and key.startswith(prefixes):
|
| 199 |
+
st.session_state.pop(key, None)
|
| 200 |
+
else:
|
| 201 |
+
kept.append(key)
|
| 202 |
+
st.session_state[_DERIVED_CACHE_TRACKER_KEY] = kept
|
| 203 |
+
return
|
| 204 |
+
|
| 205 |
+
for key in list(st.session_state):
|
| 206 |
+
if isinstance(key, str) and key.startswith(prefixes):
|
| 207 |
+
st.session_state.pop(key, None)
|
| 208 |
|
| 209 |
|
| 210 |
def _compute_assistant_mask(
|
|
|
|
| 325 |
for i, message in enumerate(messages):
|
| 326 |
if message.get("role") != "assistant":
|
| 327 |
continue
|
| 328 |
+
prefix_ids = apply(messages[:i], tokenize=True, add_generation_prompt=True)
|
|
|
|
|
|
|
| 329 |
through_ids = apply(
|
| 330 |
messages[: i + 1], tokenize=True, add_generation_prompt=False
|
| 331 |
)
|
|
|
|
| 353 |
return None
|
| 354 |
|
| 355 |
|
| 356 |
+
def _clip_spans(spans: list[tuple[int, int]], n_tokens: int) -> list[tuple[int, int]]:
|
|
|
|
|
|
|
| 357 |
clipped: list[tuple[int, int]] = []
|
| 358 |
for start, end in spans:
|
| 359 |
s = max(0, min(start, n_tokens))
|
utils/probes.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import io
|
| 4 |
-
import os
|
| 5 |
-
import re
|
| 6 |
from dataclasses import dataclass
|
| 7 |
from pathlib import Path
|
| 8 |
from typing import Any
|
|
@@ -13,33 +11,14 @@ import torch.nn as nn
|
|
| 13 |
import torch.nn.functional as F
|
| 14 |
from persona_vectors.probes import ProbeArtifact, load_probe_artifact
|
| 15 |
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
| 19 |
)
|
| 20 |
|
| 21 |
-
|
| 22 |
-
PERSONA_PROBE_DIR_RE = re.compile(
|
| 23 |
-
r"^(?P<probe_kind>[a-z_]+?)(?:_pca(?P<pca>\d+))?_layer(?P<layer>\d+)$"
|
| 24 |
-
)
|
| 25 |
-
|
| 26 |
-
DEFAULT_PROBE_REPO = "project-telos/cognitive_map_probes"
|
| 27 |
-
DEFAULT_LOCAL_PROBE_DIR = os.environ.get(
|
| 28 |
-
"PERSONA_PROBES_DIR",
|
| 29 |
-
"artifacts/probes",
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
@dataclass(frozen=True)
|
| 34 |
-
class ProbeFileMetadata:
|
| 35 |
-
filename: str
|
| 36 |
-
layer: int | None
|
| 37 |
-
model_type: str
|
| 38 |
-
location: str | None
|
| 39 |
-
scope: str | None
|
| 40 |
-
label: str
|
| 41 |
-
model_name: str | None = None
|
| 42 |
-
attribute_name: str | None = None
|
| 43 |
|
| 44 |
|
| 45 |
@dataclass(frozen=True)
|
|
@@ -195,9 +174,7 @@ class LoadedProbe:
|
|
| 195 |
predicted = torch.argmax(probs, dim=-1)
|
| 196 |
return logits, probs, predicted
|
| 197 |
|
| 198 |
-
def _forward_batch(
|
| 199 |
-
self, batch: torch.Tensor
|
| 200 |
-
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 201 |
normalized = self._normalize_batch(batch)
|
| 202 |
with torch.no_grad():
|
| 203 |
logits = self.model(normalized).detach().cpu()
|
|
@@ -233,104 +210,7 @@ class LoadedProbe:
|
|
| 233 |
return batch
|
| 234 |
|
| 235 |
|
| 236 |
-
|
| 237 |
-
return model_name.replace("/", "__")
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
def parse_probe_filename(filename: str) -> ProbeFileMetadata:
|
| 241 |
-
path = Path(filename)
|
| 242 |
-
match = PROBE_FILENAME_RE.match(path.name)
|
| 243 |
-
if match:
|
| 244 |
-
layer = int(match.group("layer"))
|
| 245 |
-
model_type = match.group("model_type")
|
| 246 |
-
location = match.group("location")
|
| 247 |
-
scope = match.group("scope")
|
| 248 |
-
scope_label = scope.replace("size", "size ")
|
| 249 |
-
return ProbeFileMetadata(
|
| 250 |
-
filename=filename,
|
| 251 |
-
layer=layer,
|
| 252 |
-
model_type=model_type,
|
| 253 |
-
location=location,
|
| 254 |
-
scope=scope,
|
| 255 |
-
label=(
|
| 256 |
-
f"Layer {layer} - {model_type.upper()} - "
|
| 257 |
-
f"{location.replace('_', ' ')} - {scope_label}"
|
| 258 |
-
),
|
| 259 |
-
)
|
| 260 |
-
|
| 261 |
-
# persona-vectors layout: parent dir holds probe_kind[_pca{K}]_layer{N},
|
| 262 |
-
# and the dir above that is the attribute name.
|
| 263 |
-
parent_match = PERSONA_PROBE_DIR_RE.match(path.parent.name)
|
| 264 |
-
if parent_match and path.name in {"probe.json", "weights.safetensors"}:
|
| 265 |
-
layer = int(parent_match.group("layer"))
|
| 266 |
-
probe_kind = parent_match.group("probe_kind")
|
| 267 |
-
pca = parent_match.group("pca")
|
| 268 |
-
scope = f"pca{pca}" if pca else None
|
| 269 |
-
attribute = path.parent.parent.name or "attribute"
|
| 270 |
-
model_name = path.parts[0].replace("__", "/") if len(path.parts) >= 5 else None
|
| 271 |
-
label = f"{attribute} - layer {layer} - {probe_kind}"
|
| 272 |
-
if pca:
|
| 273 |
-
label += f" (pca{pca})"
|
| 274 |
-
return ProbeFileMetadata(
|
| 275 |
-
filename=filename,
|
| 276 |
-
layer=layer,
|
| 277 |
-
model_type=probe_kind,
|
| 278 |
-
location=None,
|
| 279 |
-
scope=scope,
|
| 280 |
-
label=label,
|
| 281 |
-
model_name=model_name,
|
| 282 |
-
attribute_name=attribute,
|
| 283 |
-
)
|
| 284 |
-
|
| 285 |
-
return ProbeFileMetadata(
|
| 286 |
-
filename=filename,
|
| 287 |
-
layer=None,
|
| 288 |
-
model_type="unknown",
|
| 289 |
-
location=None,
|
| 290 |
-
scope=None,
|
| 291 |
-
label=path.stem.replace("_", " "),
|
| 292 |
-
)
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
@st.cache_data(show_spinner=False, ttl=300)
|
| 296 |
-
def list_probe_files(repo_id: str) -> list[str]:
|
| 297 |
-
from huggingface_hub import list_repo_files
|
| 298 |
-
|
| 299 |
-
files = list_repo_files(repo_id, repo_type="model")
|
| 300 |
-
return _dedupe_probe_entries(files)
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
@st.cache_data(show_spinner=False, ttl=30)
|
| 304 |
-
def list_local_probe_files(root_dir: str) -> list[str]:
|
| 305 |
-
root = Path(root_dir).expanduser()
|
| 306 |
-
if not root.is_dir():
|
| 307 |
-
return []
|
| 308 |
-
files = _dedupe_probe_entries([
|
| 309 |
-
str(path.relative_to(root))
|
| 310 |
-
for path in root.rglob("*")
|
| 311 |
-
if path.is_file() and path.name in {"probe.pt", "probe.json", "weights.safetensors"}
|
| 312 |
-
])
|
| 313 |
-
return sorted(files, key=_probe_sort_key)
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
@st.cache_data(show_spinner=False, ttl=300)
|
| 317 |
-
def download_probe_file(repo_id: str, filename: str) -> str:
|
| 318 |
-
from huggingface_hub import hf_hub_download
|
| 319 |
-
|
| 320 |
-
return hf_hub_download(repo_id, filename, repo_type="model")
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
@st.cache_data(show_spinner=False, ttl=300)
|
| 324 |
-
def download_probe_json_and_weights(repo_id: str, filename: str) -> tuple[str, str]:
|
| 325 |
-
from huggingface_hub import hf_hub_download
|
| 326 |
-
|
| 327 |
-
metadata_path = hf_hub_download(repo_id, filename, repo_type="model")
|
| 328 |
-
weights_name = str(Path(filename).with_name("weights.safetensors"))
|
| 329 |
-
weights_path = hf_hub_download(repo_id, weights_name, repo_type="model")
|
| 330 |
-
return metadata_path, weights_path
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
@st.cache_resource(show_spinner=False)
|
| 334 |
def load_probe(repo_id: str, filename: str) -> LoadedProbe:
|
| 335 |
if filename.endswith("probe.json"):
|
| 336 |
metadata_path, weights_path = download_probe_json_and_weights(repo_id, filename)
|
|
@@ -346,7 +226,7 @@ def load_probe(repo_id: str, filename: str) -> LoadedProbe:
|
|
| 346 |
)
|
| 347 |
|
| 348 |
|
| 349 |
-
@st.cache_resource(show_spinner=False)
|
| 350 |
def load_local_probe(root_dir: str, filename: str) -> LoadedProbe:
|
| 351 |
root = Path(root_dir).expanduser()
|
| 352 |
path = (root / filename).resolve()
|
|
@@ -370,7 +250,7 @@ def load_local_probe(root_dir: str, filename: str) -> LoadedProbe:
|
|
| 370 |
)
|
| 371 |
|
| 372 |
|
| 373 |
-
@st.cache_resource(show_spinner=False)
|
| 374 |
def load_probe_from_bytes(filename: str, data: bytes) -> LoadedProbe:
|
| 375 |
return _load_probe_payload(
|
| 376 |
filename=filename,
|
|
@@ -432,16 +312,20 @@ def _load_probe_payload(
|
|
| 432 |
_optional_str(payload.get("attribute_name")) or metadata.attribute_name
|
| 433 |
),
|
| 434 |
feature_space=(
|
| 435 |
-
(
|
| 436 |
-
|
| 437 |
-
|
|
|
|
|
|
|
| 438 |
or _optional_str(payload.get("feature_space"))
|
| 439 |
or metadata.scope
|
| 440 |
),
|
| 441 |
task=_optional_str(payload.get("task")),
|
| 442 |
probe_kind=_optional_str(payload.get("probe_kind")),
|
| 443 |
scaler_mean=_as_cpu_tensor(payload.get("scaler_mean")),
|
| 444 |
-
scaler_std=_as_cpu_tensor(
|
|
|
|
|
|
|
| 445 |
pca_mean=_as_cpu_tensor(payload.get("pca_mean")),
|
| 446 |
pca_components=_as_cpu_tensor(payload.get("pca_components")),
|
| 447 |
)
|
|
@@ -617,39 +501,6 @@ def _first_present(payload: dict[str, Any], *keys: str) -> Any:
|
|
| 617 |
return None
|
| 618 |
|
| 619 |
|
| 620 |
-
def _probe_sort_key(filename: str) -> tuple[str, str, int, str]:
|
| 621 |
-
metadata = parse_probe_filename(filename)
|
| 622 |
-
return (
|
| 623 |
-
metadata.model_name or "",
|
| 624 |
-
metadata.attribute_name or "",
|
| 625 |
-
metadata.layer if metadata.layer is not None else 10**9,
|
| 626 |
-
filename,
|
| 627 |
-
)
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
def _dedupe_probe_entries(files: list[str]) -> list[str]:
|
| 631 |
-
by_dir: dict[str, set[str]] = {}
|
| 632 |
-
standalone: list[str] = []
|
| 633 |
-
for filename in files:
|
| 634 |
-
path = Path(filename)
|
| 635 |
-
if path.name in {"probe.pt", "probe.json", "weights.safetensors"}:
|
| 636 |
-
by_dir.setdefault(str(path.parent), set()).add(path.name)
|
| 637 |
-
elif filename.endswith(".pt"):
|
| 638 |
-
standalone.append(filename)
|
| 639 |
-
|
| 640 |
-
entries = list(standalone)
|
| 641 |
-
for directory, names in by_dir.items():
|
| 642 |
-
selected = (
|
| 643 |
-
"probe.json"
|
| 644 |
-
if "probe.json" in names
|
| 645 |
-
else "probe.pt"
|
| 646 |
-
if "probe.pt" in names
|
| 647 |
-
else "weights.safetensors"
|
| 648 |
-
)
|
| 649 |
-
entries.append(str(Path(directory) / selected))
|
| 650 |
-
return sorted(entries, key=_probe_sort_key)
|
| 651 |
-
|
| 652 |
-
|
| 653 |
def _normalize_labels(raw_labels: Any, num_classes: int) -> list[str | None]:
|
| 654 |
if isinstance(raw_labels, (list, tuple)):
|
| 655 |
labels = [str(label) for label in raw_labels[:num_classes]]
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import io
|
|
|
|
|
|
|
| 4 |
from dataclasses import dataclass
|
| 5 |
from pathlib import Path
|
| 6 |
from typing import Any
|
|
|
|
| 11 |
import torch.nn.functional as F
|
| 12 |
from persona_vectors.probes import ProbeArtifact, load_probe_artifact
|
| 13 |
|
| 14 |
+
from utils.helpers import env_int
|
| 15 |
+
from utils.probe_files import (
|
| 16 |
+
download_probe_file,
|
| 17 |
+
download_probe_json_and_weights,
|
| 18 |
+
parse_probe_filename,
|
| 19 |
)
|
| 20 |
|
| 21 |
+
_PROBE_CACHE_ENTRIES = env_int("PERSONA_UI_PROBE_CACHE_ENTRIES", 8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
@dataclass(frozen=True)
|
|
|
|
| 174 |
predicted = torch.argmax(probs, dim=-1)
|
| 175 |
return logits, probs, predicted
|
| 176 |
|
| 177 |
+
def _forward_batch(self, batch: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
|
|
| 178 |
normalized = self._normalize_batch(batch)
|
| 179 |
with torch.no_grad():
|
| 180 |
logits = self.model(normalized).detach().cpu()
|
|
|
|
| 210 |
return batch
|
| 211 |
|
| 212 |
|
| 213 |
+
@st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
def load_probe(repo_id: str, filename: str) -> LoadedProbe:
|
| 215 |
if filename.endswith("probe.json"):
|
| 216 |
metadata_path, weights_path = download_probe_json_and_weights(repo_id, filename)
|
|
|
|
| 226 |
)
|
| 227 |
|
| 228 |
|
| 229 |
+
@st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES)
|
| 230 |
def load_local_probe(root_dir: str, filename: str) -> LoadedProbe:
|
| 231 |
root = Path(root_dir).expanduser()
|
| 232 |
path = (root / filename).resolve()
|
|
|
|
| 250 |
)
|
| 251 |
|
| 252 |
|
| 253 |
+
@st.cache_resource(show_spinner=False, max_entries=_PROBE_CACHE_ENTRIES)
|
| 254 |
def load_probe_from_bytes(filename: str, data: bytes) -> LoadedProbe:
|
| 255 |
return _load_probe_payload(
|
| 256 |
filename=filename,
|
|
|
|
| 312 |
_optional_str(payload.get("attribute_name")) or metadata.attribute_name
|
| 313 |
),
|
| 314 |
feature_space=(
|
| 315 |
+
(
|
| 316 |
+
f"pca{payload['n_pca_components']}"
|
| 317 |
+
if payload.get("n_pca_components")
|
| 318 |
+
else None
|
| 319 |
+
)
|
| 320 |
or _optional_str(payload.get("feature_space"))
|
| 321 |
or metadata.scope
|
| 322 |
),
|
| 323 |
task=_optional_str(payload.get("task")),
|
| 324 |
probe_kind=_optional_str(payload.get("probe_kind")),
|
| 325 |
scaler_mean=_as_cpu_tensor(payload.get("scaler_mean")),
|
| 326 |
+
scaler_std=_as_cpu_tensor(
|
| 327 |
+
_first_present(payload, "scaler_std", "scaler_scale")
|
| 328 |
+
),
|
| 329 |
pca_mean=_as_cpu_tensor(payload.get("pca_mean")),
|
| 330 |
pca_components=_as_cpu_tensor(payload.get("pca_components")),
|
| 331 |
)
|
|
|
|
| 501 |
return None
|
| 502 |
|
| 503 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
def _normalize_labels(raw_labels: Any, num_classes: int) -> list[str | None]:
|
| 505 |
if isinstance(raw_labels, (list, tuple)):
|
| 506 |
labels = [str(label) for label in raw_labels[:num_classes]]
|
utils/selection_controls.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections.abc import Sequence
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def remembered_segmented_control(
|
| 9 |
+
label: str,
|
| 10 |
+
*,
|
| 11 |
+
options: Sequence[str],
|
| 12 |
+
key: str,
|
| 13 |
+
remember_key: str | None = None,
|
| 14 |
+
default: str | None = None,
|
| 15 |
+
label_visibility: str = "visible",
|
| 16 |
+
) -> str:
|
| 17 |
+
"""Render a segmented control with one small, reusable memory pattern."""
|
| 18 |
+
fallback = default or options[0]
|
| 19 |
+
remembered = st.session_state.get(
|
| 20 |
+
remember_key,
|
| 21 |
+
st.session_state.get(key, fallback),
|
| 22 |
+
)
|
| 23 |
+
selected = (
|
| 24 |
+
st.segmented_control(
|
| 25 |
+
label,
|
| 26 |
+
options=options,
|
| 27 |
+
default=remembered if remembered in options else fallback,
|
| 28 |
+
key=key,
|
| 29 |
+
label_visibility=label_visibility,
|
| 30 |
+
)
|
| 31 |
+
or fallback
|
| 32 |
+
)
|
| 33 |
+
if remember_key is not None:
|
| 34 |
+
st.session_state[remember_key] = selected
|
| 35 |
+
return selected
|
utils/source_controls.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
from persona_data.environment import get_artifacts_dir
|
| 7 |
+
from persona_vectors.extraction import MaskStrategy
|
| 8 |
+
|
| 9 |
+
from utils.analysis_sources import (
|
| 10 |
+
DEFAULT_COMPARE_MODEL,
|
| 11 |
+
DEFAULT_HUB_REPO,
|
| 12 |
+
SOURCE_HUB,
|
| 13 |
+
SOURCE_LOCAL,
|
| 14 |
+
SOURCES,
|
| 15 |
+
Store,
|
| 16 |
+
activation_store_cached,
|
| 17 |
+
hub_models_by_mask_strategy,
|
| 18 |
+
local_model_matches,
|
| 19 |
+
local_model_options_cached,
|
| 20 |
+
)
|
| 21 |
+
from utils.helpers import widget_key
|
| 22 |
+
from utils.selection_controls import remembered_segmented_control
|
| 23 |
+
|
| 24 |
+
_SHARED_SOURCE_KEY = "source:last_source"
|
| 25 |
+
_SHARED_HUB_REPO_KEY = "source:hub_repo"
|
| 26 |
+
_SHARED_HUB_MODEL_KEY = "source:hub_model"
|
| 27 |
+
_SHARED_LOCAL_ROOT_KEY = "source:local_root"
|
| 28 |
+
_SHARED_LOCAL_MODEL_KEY = "source:local_model"
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def render_source_select(
|
| 32 |
+
*,
|
| 33 |
+
widget_scope: str,
|
| 34 |
+
last_source_key: str | None = None,
|
| 35 |
+
) -> str:
|
| 36 |
+
key = widget_key(widget_scope, "source")
|
| 37 |
+
if last_source_key is not None and last_source_key not in st.session_state:
|
| 38 |
+
shared_source = st.session_state.get(_SHARED_SOURCE_KEY)
|
| 39 |
+
if shared_source is not None:
|
| 40 |
+
st.session_state[last_source_key] = shared_source
|
| 41 |
+
selected = remembered_segmented_control(
|
| 42 |
+
"Source",
|
| 43 |
+
options=SOURCES,
|
| 44 |
+
key=key,
|
| 45 |
+
remember_key=last_source_key or _SHARED_SOURCE_KEY,
|
| 46 |
+
default=SOURCE_HUB,
|
| 47 |
+
label_visibility="collapsed",
|
| 48 |
+
)
|
| 49 |
+
st.session_state[_SHARED_SOURCE_KEY] = selected
|
| 50 |
+
if last_source_key is not None:
|
| 51 |
+
st.session_state[last_source_key] = selected
|
| 52 |
+
return selected
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _render_hub_model_select(
|
| 56 |
+
*,
|
| 57 |
+
state_prefix: str,
|
| 58 |
+
widget_scope: str,
|
| 59 |
+
repo_id: str,
|
| 60 |
+
mask_strategy: MaskStrategy,
|
| 61 |
+
model_label: str,
|
| 62 |
+
fallback_help: str,
|
| 63 |
+
selection_help: str,
|
| 64 |
+
) -> str:
|
| 65 |
+
fallback_key = f"{state_prefix}:hub_model_fallback"
|
| 66 |
+
fallback_model = st.session_state.get(
|
| 67 |
+
fallback_key,
|
| 68 |
+
st.session_state.get(_SHARED_HUB_MODEL_KEY, DEFAULT_COMPARE_MODEL),
|
| 69 |
+
)
|
| 70 |
+
try:
|
| 71 |
+
models_by_strategy = hub_models_by_mask_strategy(repo_id)
|
| 72 |
+
except Exception as exc:
|
| 73 |
+
st.warning(f"Could not load Hub configs for `{repo_id}`: {exc}")
|
| 74 |
+
model = st.text_input(
|
| 75 |
+
model_label,
|
| 76 |
+
value=fallback_model,
|
| 77 |
+
key=fallback_key,
|
| 78 |
+
help=fallback_help,
|
| 79 |
+
)
|
| 80 |
+
st.session_state[_SHARED_HUB_MODEL_KEY] = model
|
| 81 |
+
return model
|
| 82 |
+
|
| 83 |
+
model_options = models_by_strategy.get(mask_strategy, [])
|
| 84 |
+
if not model_options:
|
| 85 |
+
st.warning(
|
| 86 |
+
f"No Hub vector configs found for `{mask_strategy.value}` in `{repo_id}`."
|
| 87 |
+
)
|
| 88 |
+
model = st.text_input(
|
| 89 |
+
model_label,
|
| 90 |
+
value=fallback_model,
|
| 91 |
+
key=fallback_key,
|
| 92 |
+
help=fallback_help,
|
| 93 |
+
)
|
| 94 |
+
st.session_state[_SHARED_HUB_MODEL_KEY] = model
|
| 95 |
+
return model
|
| 96 |
+
|
| 97 |
+
select_key = widget_key(widget_scope, "hub_model", repo_id, mask_strategy.value)
|
| 98 |
+
previous_model = st.session_state.get(
|
| 99 |
+
select_key,
|
| 100 |
+
st.session_state.get(_SHARED_HUB_MODEL_KEY, fallback_model),
|
| 101 |
+
)
|
| 102 |
+
default_model = (
|
| 103 |
+
previous_model if previous_model in model_options else model_options[0]
|
| 104 |
+
)
|
| 105 |
+
selected = st.selectbox(
|
| 106 |
+
model_label,
|
| 107 |
+
options=model_options,
|
| 108 |
+
index=model_options.index(default_model),
|
| 109 |
+
key=select_key,
|
| 110 |
+
help=selection_help,
|
| 111 |
+
)
|
| 112 |
+
st.session_state[fallback_key] = selected
|
| 113 |
+
st.session_state[_SHARED_HUB_MODEL_KEY] = selected
|
| 114 |
+
return selected
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _render_local_model_select(
|
| 118 |
+
*,
|
| 119 |
+
state_prefix: str,
|
| 120 |
+
artifacts_root: str,
|
| 121 |
+
mask_strategy: MaskStrategy,
|
| 122 |
+
allow_custom_toggle: bool,
|
| 123 |
+
model_label: str,
|
| 124 |
+
) -> str:
|
| 125 |
+
fallback_key = f"{state_prefix}:local_model"
|
| 126 |
+
fallback_model = st.session_state.get(
|
| 127 |
+
fallback_key,
|
| 128 |
+
st.session_state.get(_SHARED_LOCAL_MODEL_KEY, DEFAULT_COMPARE_MODEL),
|
| 129 |
+
)
|
| 130 |
+
model_options = local_model_options_cached(artifacts_root, mask_strategy.value)
|
| 131 |
+
if not model_options:
|
| 132 |
+
model = st.text_input(model_label, value=fallback_model, key=fallback_key)
|
| 133 |
+
st.session_state[_SHARED_LOCAL_MODEL_KEY] = model
|
| 134 |
+
return model
|
| 135 |
+
|
| 136 |
+
if allow_custom_toggle:
|
| 137 |
+
custom = st.toggle(
|
| 138 |
+
"Custom local model",
|
| 139 |
+
value=False,
|
| 140 |
+
key=f"{state_prefix}:local_model_custom_enabled",
|
| 141 |
+
help="Enter a model id/path manually instead of choosing from activation directories.",
|
| 142 |
+
)
|
| 143 |
+
if custom:
|
| 144 |
+
model = st.text_input("Local model", value=fallback_model, key=fallback_key)
|
| 145 |
+
st.session_state[_SHARED_LOCAL_MODEL_KEY] = model
|
| 146 |
+
return model
|
| 147 |
+
|
| 148 |
+
select_key = f"{state_prefix}:local_model_select"
|
| 149 |
+
previous_model = st.session_state.get(
|
| 150 |
+
select_key,
|
| 151 |
+
st.session_state.get(_SHARED_LOCAL_MODEL_KEY, fallback_model),
|
| 152 |
+
)
|
| 153 |
+
if not any(local_model_matches(previous_model, option) for option in model_options):
|
| 154 |
+
previous_model = fallback_model
|
| 155 |
+
default_model = next(
|
| 156 |
+
(
|
| 157 |
+
option
|
| 158 |
+
for option in model_options
|
| 159 |
+
if local_model_matches(option, previous_model)
|
| 160 |
+
),
|
| 161 |
+
model_options[0],
|
| 162 |
+
)
|
| 163 |
+
selected = st.selectbox(
|
| 164 |
+
model_label,
|
| 165 |
+
options=model_options,
|
| 166 |
+
index=model_options.index(default_model),
|
| 167 |
+
key=select_key,
|
| 168 |
+
help="Models discovered under the selected artifacts root.",
|
| 169 |
+
)
|
| 170 |
+
st.session_state[fallback_key] = selected
|
| 171 |
+
st.session_state[_SHARED_LOCAL_MODEL_KEY] = selected
|
| 172 |
+
return selected
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def render_store_select(
|
| 176 |
+
source: str,
|
| 177 |
+
mask_strategy: MaskStrategy,
|
| 178 |
+
*,
|
| 179 |
+
state_prefix: str,
|
| 180 |
+
widget_scope: str,
|
| 181 |
+
artifacts_root_key: str,
|
| 182 |
+
model_label: str = "Model",
|
| 183 |
+
local_model_label: str = "Model",
|
| 184 |
+
allow_custom_local_model: bool = False,
|
| 185 |
+
repo_help: str | None = None,
|
| 186 |
+
fallback_help: str = "Model id to use if Hub config discovery is unavailable.",
|
| 187 |
+
) -> Store:
|
| 188 |
+
if source == SOURCE_HUB:
|
| 189 |
+
repo_key = f"{state_prefix}:hub_repo"
|
| 190 |
+
repo = st.text_input(
|
| 191 |
+
"Hub repo",
|
| 192 |
+
value=st.session_state.get(
|
| 193 |
+
repo_key,
|
| 194 |
+
st.session_state.get(_SHARED_HUB_REPO_KEY, DEFAULT_HUB_REPO),
|
| 195 |
+
),
|
| 196 |
+
key=repo_key,
|
| 197 |
+
help=repo_help,
|
| 198 |
+
)
|
| 199 |
+
st.session_state[_SHARED_HUB_REPO_KEY] = repo
|
| 200 |
+
model_name = _render_hub_model_select(
|
| 201 |
+
state_prefix=state_prefix,
|
| 202 |
+
widget_scope=widget_scope,
|
| 203 |
+
repo_id=repo,
|
| 204 |
+
mask_strategy=mask_strategy,
|
| 205 |
+
model_label=model_label,
|
| 206 |
+
fallback_help=fallback_help,
|
| 207 |
+
selection_help="Models with vectors in the selected Hub repo and mask strategy.",
|
| 208 |
+
)
|
| 209 |
+
return activation_store_cached(
|
| 210 |
+
SOURCE_HUB, repo, model_name, mask_strategy.value
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
root = st.text_input(
|
| 214 |
+
"Artifacts root",
|
| 215 |
+
value=st.session_state.get(
|
| 216 |
+
_SHARED_LOCAL_ROOT_KEY,
|
| 217 |
+
str(get_artifacts_dir() / "activations"),
|
| 218 |
+
),
|
| 219 |
+
key=artifacts_root_key,
|
| 220 |
+
)
|
| 221 |
+
root = str(Path(root).expanduser())
|
| 222 |
+
st.session_state[_SHARED_LOCAL_ROOT_KEY] = root
|
| 223 |
+
model_name = _render_local_model_select(
|
| 224 |
+
state_prefix=state_prefix,
|
| 225 |
+
artifacts_root=root,
|
| 226 |
+
mask_strategy=mask_strategy,
|
| 227 |
+
allow_custom_toggle=allow_custom_local_model,
|
| 228 |
+
model_label=local_model_label,
|
| 229 |
+
)
|
| 230 |
+
return activation_store_cached(SOURCE_LOCAL, root, model_name, mask_strategy.value)
|
uv.lock
CHANGED
|
@@ -464,11 +464,11 @@ wheels = [
|
|
| 464 |
|
| 465 |
[[package]]
|
| 466 |
name = "decorator"
|
| 467 |
-
version = "5.3.
|
| 468 |
source = { registry = "https://pypi.org/simple" }
|
| 469 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 470 |
wheels = [
|
| 471 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 472 |
]
|
| 473 |
|
| 474 |
[[package]]
|
|
@@ -1585,7 +1585,7 @@ wheels = [
|
|
| 1585 |
|
| 1586 |
[[package]]
|
| 1587 |
name = "persona-ui"
|
| 1588 |
-
version = "0.
|
| 1589 |
source = { virtual = "." }
|
| 1590 |
dependencies = [
|
| 1591 |
{ name = "catppuccin" },
|
|
@@ -2145,11 +2145,11 @@ wheels = [
|
|
| 2145 |
|
| 2146 |
[[package]]
|
| 2147 |
name = "python-multipart"
|
| 2148 |
-
version = "0.0.
|
| 2149 |
source = { registry = "https://pypi.org/simple" }
|
| 2150 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 2151 |
wheels = [
|
| 2152 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 2153 |
]
|
| 2154 |
|
| 2155 |
[[package]]
|
|
|
|
| 464 |
|
| 465 |
[[package]]
|
| 466 |
name = "decorator"
|
| 467 |
+
version = "5.3.1"
|
| 468 |
source = { registry = "https://pypi.org/simple" }
|
| 469 |
+
sdist = { url = "https://files.pythonhosted.org/packages/60/8b/32f9823da46cde7df2087faa08cd98d01b908f8dcab982cdba9c84e85355/decorator-5.3.1.tar.gz", hash = "sha256:4cbcdd55a6efadb9dbea26b858f4fb3264567b52d69ca0d25b721b553f60ea82", size = 58084, upload-time = "2026-05-18T06:03:28.057Z" }
|
| 470 |
wheels = [
|
| 471 |
+
{ url = "https://files.pythonhosted.org/packages/05/7f/798705f5296a58ca505d600456748d1be48078eac8a7050d8a98bc9edb89/decorator-5.3.1-py3-none-any.whl", hash = "sha256:f47fe6fdbd2edd623ecfe36875d37aba411624e2670dd395dddae1358689bb3c", size = 10365, upload-time = "2026-05-18T06:03:26.517Z" },
|
| 472 |
]
|
| 473 |
|
| 474 |
[[package]]
|
|
|
|
| 1585 |
|
| 1586 |
[[package]]
|
| 1587 |
name = "persona-ui"
|
| 1588 |
+
version = "0.5.0"
|
| 1589 |
source = { virtual = "." }
|
| 1590 |
dependencies = [
|
| 1591 |
{ name = "catppuccin" },
|
|
|
|
| 2145 |
|
| 2146 |
[[package]]
|
| 2147 |
name = "python-multipart"
|
| 2148 |
+
version = "0.0.29"
|
| 2149 |
source = { registry = "https://pypi.org/simple" }
|
| 2150 |
+
sdist = { url = "https://files.pythonhosted.org/packages/4e/fe/70bd71a6738b09a0bdf6480ca6436b167469ca4578b2a0efbe390b4b0e70/python_multipart-0.0.29.tar.gz", hash = "sha256:643e93849196645e2dbdd81a0f8829a23123ad7f797a84a364c6fb3563f18904", size = 45678, upload-time = "2026-05-17T17:29:47.654Z" }
|
| 2151 |
wheels = [
|
| 2152 |
+
{ url = "https://files.pythonhosted.org/packages/8f/cb/769cfc37177252872a45a71f3fbdde9d51b471a3f3c14bfe95dde3407386/python_multipart-0.0.29-py3-none-any.whl", hash = "sha256:2ddcc971cef266225f54f552d8fa10bcfbb1f14446caec199060daac59ff2d69", size = 29640, upload-time = "2026-05-17T17:29:45.69Z" },
|
| 2153 |
]
|
| 2154 |
|
| 2155 |
[[package]]
|