Jac-Zac committed on
Commit ·
db3d901
1
Parent(s): 99c28ab
Big refactor and feature addition to analyses + support latest persona-vectors
Browse files
- README.md +2 -2
- app.py +35 -26
- pyproject.toml +1 -2
- state.py +5 -5
- tabs/analysis.py +1 -0
- tabs/{compare.py → analysis_core.py} +649 -198
- tabs/chat.py +63 -68
- tabs/chat_shared.py +105 -0
- tabs/chat_ui.py +16 -12
- tabs/compare_chat.py +26 -28
- tabs/extract.py +20 -18
- tabs/probe_ui.py +1 -1
- utils/{compare_sources.py → analysis_sources.py} +44 -192
- utils/chat.py +6 -10
- utils/chat_export.py +2 -2
- utils/contrast.py +3 -1
- utils/datasets.py +11 -5
- utils/helpers.py +26 -16
- uv.lock +13 -15
README.md
CHANGED
|
@@ -20,7 +20,7 @@ Streamlit interface for persona vector extraction, analysis, and chat.
|
|
| 20 |
A web app built on top of [persona-vectors](../persona-vectors) that provides three tabs:
|
| 21 |
|
| 22 |
- **Chat** — interactive conversations with a model using persona-based system prompts (templated or biography)
|
| 23 |
-
- **Compare** — load local or Hub persona vectors and explore cosine similarity, PCA, UMAP,
|
| 24 |
- **Extract** — run activation extraction from HuggingFace persona datasets or a local JSONL dataset directly from the browser
|
| 25 |
|
| 26 |
## Repository Layout
|
|
@@ -31,7 +31,7 @@ persona-ui/
|
|
| 31 |
├── state.py # Session state management (chat history, KV cache)
|
| 32 |
├── tabs/
|
| 33 |
│ ├── chat.py # Chat tab
|
| 34 |
-
│ ├──
|
| 35 |
│ ├── compare_chat.py # Side-by-side chat comparison mode
|
| 36 |
│ ├── extract.py # Extraction tab
|
| 37 |
│ └── probe_ui.py # Probe upload and tracing controls
|
|
|
|
| 20 |
A web app built on top of [persona-vectors](../persona-vectors) that provides three tabs:
|
| 21 |
|
| 22 |
- **Chat** — interactive conversations with a model using persona-based system prompts (templated or biography)
|
| 23 |
+
- **Compare** — load local or Hub persona vectors and explore cosine similarity, PCA, UMAP, attribute-colored projections, and dendrograms
|
| 24 |
- **Extract** — run activation extraction from HuggingFace persona datasets or a local JSONL dataset directly from the browser
|
| 25 |
|
| 26 |
## Repository Layout
|
|
|
|
| 31 |
├── state.py # Session state management (chat history, KV cache)
|
| 32 |
├── tabs/
|
| 33 |
│ ├── chat.py # Chat tab
|
| 34 |
+
│ ├── analysis.py # Analysis tab (cosine similarity, PCA, UMAP, Isomap, dendrogram)
|
| 35 |
│ ├── compare_chat.py # Side-by-side chat comparison mode
|
| 36 |
│ ├── extract.py # Extraction tab
|
| 37 |
│ └── probe_ui.py # Probe upload and tracing controls
|
app.py
CHANGED
|
@@ -4,13 +4,26 @@ from dataclasses import dataclass
|
|
| 4 |
import streamlit as st
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
|
| 7 |
-
from utils.helpers import DATASET_SOURCES
|
|
|
|
|
|
|
| 8 |
|
| 9 |
load_dotenv()
|
| 10 |
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
|
| 11 |
REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
|
| 12 |
-
_LAST_LOCAL_MODEL_KEY = "sidebar
|
| 13 |
-
_LAST_REMOTE_MODEL_KEY = "sidebar
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
_TABS = ["Chat", "Analysis", "Extract"]
|
|
@@ -35,9 +48,9 @@ def _remote_model_input(remote_models: list[str]) -> str:
|
|
| 35 |
model_name = st.text_input(
|
| 36 |
"Model",
|
| 37 |
value=st.session_state.get(
|
| 38 |
-
|
| 39 |
),
|
| 40 |
-
key=
|
| 41 |
help="NDIF model id. Use this to cold-load a remote model.",
|
| 42 |
)
|
| 43 |
st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
|
|
@@ -46,16 +59,16 @@ def _remote_model_input(remote_models: list[str]) -> str:
|
|
| 46 |
custom = st.toggle(
|
| 47 |
"Custom remote model",
|
| 48 |
value=False,
|
| 49 |
-
key=
|
| 50 |
help="Enter any NDIF-loadable model id, even if it is not currently running.",
|
| 51 |
)
|
| 52 |
if custom:
|
| 53 |
model_name = st.text_input(
|
| 54 |
"Model",
|
| 55 |
value=st.session_state.get(
|
| 56 |
-
|
| 57 |
),
|
| 58 |
-
key=
|
| 59 |
help="NDIF model id. Example: openai/gpt-oss-20b",
|
| 60 |
)
|
| 61 |
st.caption(
|
|
@@ -63,20 +76,20 @@ def _remote_model_input(remote_models: list[str]) -> str:
|
|
| 63 |
"Custom model ids can cold-load if your NDIF account allows it."
|
| 64 |
)
|
| 65 |
else:
|
| 66 |
-
default_model = st.session_state.get(
|
| 67 |
if default_model not in remote_models:
|
| 68 |
default_model = (
|
| 69 |
REMOTE_DEFAULT_MODEL
|
| 70 |
if REMOTE_DEFAULT_MODEL in remote_models
|
| 71 |
else remote_models[0]
|
| 72 |
)
|
| 73 |
-
if st.session_state.get(
|
| 74 |
-
st.session_state[
|
| 75 |
model_name = st.selectbox(
|
| 76 |
"Model",
|
| 77 |
options=remote_models,
|
| 78 |
index=remote_models.index(default_model),
|
| 79 |
-
key=
|
| 80 |
help="Running NDIF model.",
|
| 81 |
)
|
| 82 |
st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
|
|
@@ -84,15 +97,13 @@ def _remote_model_input(remote_models: list[str]) -> str:
|
|
| 84 |
|
| 85 |
|
| 86 |
def _sidebar_controls() -> SidebarState:
|
| 87 |
-
from utils.runtime import list_remote_models
|
| 88 |
-
|
| 89 |
with st.sidebar:
|
| 90 |
st.markdown("## Persona UI")
|
| 91 |
|
| 92 |
-
if
|
| 93 |
-
st.session_state[
|
| 94 |
|
| 95 |
-
active_tab = st.session_state[
|
| 96 |
for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
|
| 97 |
is_selected = tab_name == active_tab
|
| 98 |
if st.button(
|
|
@@ -102,13 +113,13 @@ def _sidebar_controls() -> SidebarState:
|
|
| 102 |
type="primary" if is_selected else "secondary",
|
| 103 |
icon=icon,
|
| 104 |
):
|
| 105 |
-
st.session_state[
|
| 106 |
st.rerun()
|
| 107 |
|
| 108 |
if active_tab == "Analysis":
|
| 109 |
model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
|
| 110 |
dataset_source = st.session_state.get(
|
| 111 |
-
|
| 112 |
DATASET_SOURCES[0],
|
| 113 |
)
|
| 114 |
return SidebarState(
|
|
@@ -120,7 +131,7 @@ def _sidebar_controls() -> SidebarState:
|
|
| 120 |
|
| 121 |
st.divider()
|
| 122 |
st.caption("Runtime")
|
| 123 |
-
remote = st.toggle("Remote (NDIF)", value=False, key=
|
| 124 |
|
| 125 |
if remote:
|
| 126 |
model_name = _remote_model_input(list_remote_models())
|
|
@@ -128,7 +139,7 @@ def _sidebar_controls() -> SidebarState:
|
|
| 128 |
model_name = st.text_input(
|
| 129 |
"Model",
|
| 130 |
value=st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL),
|
| 131 |
-
key=
|
| 132 |
help="Local model id or path.",
|
| 133 |
)
|
| 134 |
st.session_state[_LAST_LOCAL_MODEL_KEY] = model_name
|
|
@@ -137,7 +148,7 @@ def _sidebar_controls() -> SidebarState:
|
|
| 137 |
dataset_source = st.selectbox(
|
| 138 |
"Source",
|
| 139 |
DATASET_SOURCES,
|
| 140 |
-
key=
|
| 141 |
help="Dataset for Chat and Extract.",
|
| 142 |
)
|
| 143 |
|
|
@@ -153,8 +164,6 @@ def main() -> None:
|
|
| 153 |
"""Run the Streamlit app."""
|
| 154 |
|
| 155 |
st.set_page_config(page_title="Persona UI", layout="wide")
|
| 156 |
-
from utils.theme import install_catppuccin_theme
|
| 157 |
-
|
| 158 |
install_catppuccin_theme(st.get_option("theme.base"))
|
| 159 |
|
| 160 |
sidebar = _sidebar_controls()
|
|
@@ -164,9 +173,9 @@ def main() -> None:
|
|
| 164 |
|
| 165 |
render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
|
| 166 |
elif sidebar.active_tab == "Analysis":
|
| 167 |
-
from tabs.
|
| 168 |
|
| 169 |
-
|
| 170 |
else:
|
| 171 |
from tabs.chat import render_chat_tab
|
| 172 |
|
|
|
|
| 4 |
import streamlit as st
|
| 5 |
from dotenv import load_dotenv
|
| 6 |
|
| 7 |
+
from utils.helpers import DATASET_SOURCES, session_key
|
| 8 |
+
from utils.runtime import list_remote_models
|
| 9 |
+
from utils.theme import install_catppuccin_theme
|
| 10 |
|
| 11 |
load_dotenv()
|
| 12 |
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
|
| 13 |
REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
|
| 14 |
+
_LAST_LOCAL_MODEL_KEY = session_key("sidebar", "last_local_model")
|
| 15 |
+
_LAST_REMOTE_MODEL_KEY = session_key("sidebar", "last_remote_model")
|
| 16 |
+
_SIDEBAR_ACTIVE_TAB_KEY = session_key("sidebar", "active_tab")
|
| 17 |
+
_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY = session_key(
|
| 18 |
+
"sidebar", "remote_model_custom_value"
|
| 19 |
+
)
|
| 20 |
+
_SIDEBAR_REMOTE_MODEL_CUSTOM_ENABLED_KEY = session_key(
|
| 21 |
+
"sidebar", "remote_model_custom_enabled"
|
| 22 |
+
)
|
| 23 |
+
_SIDEBAR_REMOTE_MODEL_KEY = session_key("sidebar", "remote_model")
|
| 24 |
+
_SIDEBAR_LOCAL_MODEL_KEY = session_key("sidebar", "local_model")
|
| 25 |
+
_SIDEBAR_REMOTE_KEY = session_key("sidebar", "remote")
|
| 26 |
+
_SIDEBAR_DATASET_SOURCE_KEY = session_key("sidebar", "dataset_source")
|
| 27 |
|
| 28 |
|
| 29 |
_TABS = ["Chat", "Analysis", "Extract"]
|
|
|
|
| 48 |
model_name = st.text_input(
|
| 49 |
"Model",
|
| 50 |
value=st.session_state.get(
|
| 51 |
+
_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, last_remote
|
| 52 |
),
|
| 53 |
+
key=_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY,
|
| 54 |
help="NDIF model id. Use this to cold-load a remote model.",
|
| 55 |
)
|
| 56 |
st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
|
|
|
|
| 59 |
custom = st.toggle(
|
| 60 |
"Custom remote model",
|
| 61 |
value=False,
|
| 62 |
+
key=_SIDEBAR_REMOTE_MODEL_CUSTOM_ENABLED_KEY,
|
| 63 |
help="Enter any NDIF-loadable model id, even if it is not currently running.",
|
| 64 |
)
|
| 65 |
if custom:
|
| 66 |
model_name = st.text_input(
|
| 67 |
"Model",
|
| 68 |
value=st.session_state.get(
|
| 69 |
+
_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, last_remote
|
| 70 |
),
|
| 71 |
+
key=_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY,
|
| 72 |
help="NDIF model id. Example: openai/gpt-oss-20b",
|
| 73 |
)
|
| 74 |
st.caption(
|
|
|
|
| 76 |
"Custom model ids can cold-load if your NDIF account allows it."
|
| 77 |
)
|
| 78 |
else:
|
| 79 |
+
default_model = st.session_state.get(_SIDEBAR_REMOTE_MODEL_KEY, last_remote)
|
| 80 |
if default_model not in remote_models:
|
| 81 |
default_model = (
|
| 82 |
REMOTE_DEFAULT_MODEL
|
| 83 |
if REMOTE_DEFAULT_MODEL in remote_models
|
| 84 |
else remote_models[0]
|
| 85 |
)
|
| 86 |
+
if st.session_state.get(_SIDEBAR_REMOTE_MODEL_KEY) not in remote_models:
|
| 87 |
+
st.session_state[_SIDEBAR_REMOTE_MODEL_KEY] = default_model
|
| 88 |
model_name = st.selectbox(
|
| 89 |
"Model",
|
| 90 |
options=remote_models,
|
| 91 |
index=remote_models.index(default_model),
|
| 92 |
+
key=_SIDEBAR_REMOTE_MODEL_KEY,
|
| 93 |
help="Running NDIF model.",
|
| 94 |
)
|
| 95 |
st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
def _sidebar_controls() -> SidebarState:
|
|
|
|
|
|
|
| 100 |
with st.sidebar:
|
| 101 |
st.markdown("## Persona UI")
|
| 102 |
|
| 103 |
+
if _SIDEBAR_ACTIVE_TAB_KEY not in st.session_state:
|
| 104 |
+
st.session_state[_SIDEBAR_ACTIVE_TAB_KEY] = "Chat"
|
| 105 |
|
| 106 |
+
active_tab = st.session_state[_SIDEBAR_ACTIVE_TAB_KEY]
|
| 107 |
for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
|
| 108 |
is_selected = tab_name == active_tab
|
| 109 |
if st.button(
|
|
|
|
| 113 |
type="primary" if is_selected else "secondary",
|
| 114 |
icon=icon,
|
| 115 |
):
|
| 116 |
+
st.session_state[_SIDEBAR_ACTIVE_TAB_KEY] = tab_name
|
| 117 |
st.rerun()
|
| 118 |
|
| 119 |
if active_tab == "Analysis":
|
| 120 |
model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
|
| 121 |
dataset_source = st.session_state.get(
|
| 122 |
+
_SIDEBAR_DATASET_SOURCE_KEY,
|
| 123 |
DATASET_SOURCES[0],
|
| 124 |
)
|
| 125 |
return SidebarState(
|
|
|
|
| 131 |
|
| 132 |
st.divider()
|
| 133 |
st.caption("Runtime")
|
| 134 |
+
remote = st.toggle("Remote (NDIF)", value=False, key=_SIDEBAR_REMOTE_KEY)
|
| 135 |
|
| 136 |
if remote:
|
| 137 |
model_name = _remote_model_input(list_remote_models())
|
|
|
|
| 139 |
model_name = st.text_input(
|
| 140 |
"Model",
|
| 141 |
value=st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL),
|
| 142 |
+
key=_SIDEBAR_LOCAL_MODEL_KEY,
|
| 143 |
help="Local model id or path.",
|
| 144 |
)
|
| 145 |
st.session_state[_LAST_LOCAL_MODEL_KEY] = model_name
|
|
|
|
| 148 |
dataset_source = st.selectbox(
|
| 149 |
"Source",
|
| 150 |
DATASET_SOURCES,
|
| 151 |
+
key=_SIDEBAR_DATASET_SOURCE_KEY,
|
| 152 |
help="Dataset for Chat and Extract.",
|
| 153 |
)
|
| 154 |
|
|
|
|
| 164 |
"""Run the Streamlit app."""
|
| 165 |
|
| 166 |
st.set_page_config(page_title="Persona UI", layout="wide")
|
|
|
|
|
|
|
| 167 |
install_catppuccin_theme(st.get_option("theme.base"))
|
| 168 |
|
| 169 |
sidebar = _sidebar_controls()
|
|
|
|
| 173 |
|
| 174 |
render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
|
| 175 |
elif sidebar.active_tab == "Analysis":
|
| 176 |
+
from tabs.analysis import render_analysis_tab
|
| 177 |
|
| 178 |
+
render_analysis_tab()
|
| 179 |
else:
|
| 180 |
from tabs.chat import render_chat_tab
|
| 181 |
|
pyproject.toml
CHANGED
|
@@ -5,8 +5,7 @@ description = "Streamlit UI for persona-vectors"
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
-
"persona-vectors>=0.
|
| 9 |
-
"persona-data>=0.4.2",
|
| 10 |
"datasets>=4.8.5",
|
| 11 |
"huggingface-hub>=1.14.0",
|
| 12 |
"streamlit>=1.44.0",
|
|
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
+
"persona-vectors>=0.8.0",
|
|
|
|
| 9 |
"datasets>=4.8.5",
|
| 10 |
"huggingface-hub>=1.14.0",
|
| 11 |
"streamlit>=1.44.0",
|
state.py
CHANGED
|
@@ -2,7 +2,8 @@ from typing import Literal, NotRequired, TypedDict
|
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
|
| 5 |
-
|
|
|
|
| 6 |
PendingChatAction = Literal["new_user_prompt", "regenerate_after_edit"]
|
| 7 |
|
| 8 |
|
|
@@ -22,7 +23,7 @@ class ChatState(TypedDict):
|
|
| 22 |
def chat_session_key(model_name: str, dataset_source: str) -> str:
|
| 23 |
"""Build the session-state key for a chat context."""
|
| 24 |
|
| 25 |
-
return
|
| 26 |
|
| 27 |
|
| 28 |
def default_chat_state() -> ChatState:
|
|
@@ -48,9 +49,8 @@ def reset_chat_context_state(
|
|
| 48 |
st.session_state.pop(key, None)
|
| 49 |
|
| 50 |
|
| 51 |
-
def get_chat_state(model_name: str,
|
| 52 |
"""Return the mutable chat state for the active context."""
|
| 53 |
|
| 54 |
key = chat_session_key(model_name, dataset_source)
|
| 55 |
-
|
| 56 |
-
return state
|
|
|
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
|
| 5 |
+
from utils.helpers import session_key
|
| 6 |
+
|
| 7 |
PendingChatAction = Literal["new_user_prompt", "regenerate_after_edit"]
|
| 8 |
|
| 9 |
|
|
|
|
| 23 |
def chat_session_key(model_name: str, dataset_source: str) -> str:
|
| 24 |
"""Build the session-state key for a chat context."""
|
| 25 |
|
| 26 |
+
return session_key("chat_state", model_name, dataset_source)
|
| 27 |
|
| 28 |
|
| 29 |
def default_chat_state() -> ChatState:
|
|
|
|
| 49 |
st.session_state.pop(key, None)
|
| 50 |
|
| 51 |
|
| 52 |
+
def get_chat_state(model_name: str, dataset_source: str) -> ChatState:
|
| 53 |
"""Return the mutable chat state for the active context."""
|
| 54 |
|
| 55 |
key = chat_session_key(model_name, dataset_source)
|
| 56 |
+
return st.session_state.setdefault(key, default_chat_state())
|
|
|
tabs/analysis.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .analysis_core import render_analysis_tab
|
tabs/{compare.py → analysis_core.py}
RENAMED
|
@@ -7,7 +7,12 @@ from pathlib import Path
|
|
| 7 |
import plotly.graph_objects as go
|
| 8 |
import streamlit as st
|
| 9 |
from persona_data.environment import get_artifacts_dir
|
| 10 |
-
from persona_data.synth_persona import BASELINE_PERSONA_ID
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from persona_vectors.extraction import MaskStrategy
|
| 12 |
from persona_vectors.plots import (
|
| 13 |
build_layered_figure,
|
|
@@ -15,10 +20,11 @@ from persona_vectors.plots import (
|
|
| 15 |
build_similarity_figures,
|
| 16 |
plot_layer_similarity,
|
| 17 |
plot_persona_dendrogram,
|
|
|
|
| 18 |
save_plot_html,
|
| 19 |
)
|
| 20 |
|
| 21 |
-
from utils.
|
| 22 |
DEFAULT_COMPARE_MODEL,
|
| 23 |
DEFAULT_HUB_REPO,
|
| 24 |
SOURCE_HUB,
|
|
@@ -28,13 +34,13 @@ from utils.compare_sources import (
|
|
| 28 |
activation_store_cached,
|
| 29 |
available_variants,
|
| 30 |
hub_models_by_mask_strategy,
|
| 31 |
-
|
| 32 |
-
|
| 33 |
local_model_matches,
|
| 34 |
local_model_options_cached,
|
| 35 |
persona_names_cached,
|
| 36 |
personas_cached,
|
| 37 |
-
|
| 38 |
store_cache_parts,
|
| 39 |
store_id,
|
| 40 |
store_layers_cached,
|
|
@@ -57,32 +63,45 @@ def _filename(*parts: str) -> str:
|
|
| 57 |
|
| 58 |
# Keep compare-tab selection state separate so projection defaults do not
|
| 59 |
# overwrite cosine similarity defaults.
|
| 60 |
-
_LAST_COSINE_PERSONAS_KEY = "
|
| 61 |
-
_LAST_PROJECTION_PERSONAS_KEY = "
|
| 62 |
-
_LAST_SIMILARITY_PERSONAS_KEY = "
|
| 63 |
-
_LAST_MASK_STRATEGY_KEY = "
|
| 64 |
-
_LAST_SOURCE_KEY = "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
_DEFAULT_LAYER_FRAMES = 16
|
| 67 |
_DEFAULT_PERSONA_LIMITS = {
|
| 68 |
"similarity": 120,
|
| 69 |
"pca": 500,
|
| 70 |
"umap": 500,
|
|
|
|
| 71 |
"dendro": 160,
|
| 72 |
}
|
| 73 |
_MAX_SIMILARITY_CELLS = 4_000_000
|
| 74 |
_MAX_PAIR_TRAJECTORY_TRACES = 500
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
"Agglomerative": "agglomerative",
|
| 78 |
-
"HDBSCAN": "hdbscan",
|
| 79 |
-
}
|
| 80 |
_CLUSTER_MODES = {
|
| 81 |
"Mean across layers": "mean_across_layers",
|
| 82 |
"First selected layer": "first_layer",
|
| 83 |
"Per layer": "per_layer",
|
| 84 |
}
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
|
|
@@ -107,6 +126,98 @@ class CosineSelection:
|
|
| 107 |
class PersonaOptions:
|
| 108 |
regular_ids: list[str]
|
| 109 |
assistant_id: str | None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
|
| 112 |
def _layers_for_variant(
|
|
@@ -133,7 +244,7 @@ def _load_persona_vectors(
|
|
| 133 |
persona_ids: list[str],
|
| 134 |
):
|
| 135 |
source, location, model_name = store_cache_parts(store)
|
| 136 |
-
return
|
| 137 |
source,
|
| 138 |
location,
|
| 139 |
model_name,
|
|
@@ -150,7 +261,7 @@ def _load_variant_vectors(
|
|
| 150 |
persona_ids: list[str],
|
| 151 |
):
|
| 152 |
source, location, model_name = store_cache_parts(store)
|
| 153 |
-
return
|
| 154 |
source,
|
| 155 |
location,
|
| 156 |
model_name,
|
|
@@ -160,22 +271,55 @@ def _load_variant_vectors(
|
|
| 160 |
)
|
| 161 |
|
| 162 |
|
| 163 |
-
def
|
| 164 |
for key in list(st.session_state):
|
| 165 |
if key == current_key or not isinstance(key, str):
|
| 166 |
continue
|
| 167 |
parts = key.split("::", 2)
|
| 168 |
-
if len(parts) >= 2 and parts[0] == "load" and parts[1].endswith(
|
| 169 |
st.session_state.pop(key, None)
|
| 170 |
|
| 171 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
def _store_figure_state(key: str, value: object) -> None:
|
| 173 |
_clear_old_figure_states(key)
|
| 174 |
st.session_state[key] = value
|
| 175 |
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
def _release_vector_memory(store: Store, variants: list[str] | tuple[str, ...]) -> None:
|
| 178 |
-
|
| 179 |
gc.collect()
|
| 180 |
|
| 181 |
|
|
@@ -203,10 +347,22 @@ def _render_layer_frame_controls(
|
|
| 203 |
"Layer frames",
|
| 204 |
min_value=2,
|
| 205 |
max_value=len(layers),
|
| 206 |
-
value=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
key=widget_key("load", "layer_frames", scope, store_id(store)),
|
| 208 |
help="Limit animated Plotly frames to keep browser and RAM usage bounded.",
|
| 209 |
)
|
|
|
|
| 210 |
selected = _evenly_spaced_layers(layers, frame_count)
|
| 211 |
st.caption(f"Using {len(selected)} of {len(layers)} layers.")
|
| 212 |
return selected
|
|
@@ -259,7 +415,11 @@ def _load_persona_options(
|
|
| 259 |
if not regular_ids and assistant_id is None:
|
| 260 |
st.info("No personas found for this model and variant.")
|
| 261 |
return None
|
| 262 |
-
return PersonaOptions(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
|
| 265 |
def _seed_persona_memory(
|
|
@@ -366,6 +526,7 @@ def _select_artifact_personas(
|
|
| 366 |
empty_message=empty_message,
|
| 367 |
)
|
| 368 |
if options is None:
|
|
|
|
| 369 |
return []
|
| 370 |
|
| 371 |
default_count, include_assistant_default = _seed_persona_memory(
|
|
@@ -393,6 +554,7 @@ def _select_artifact_personas(
|
|
| 393 |
st.session_state[remembered_count_key] = persona_count
|
| 394 |
st.session_state[remembered_assistant_key] = include_assistant
|
| 395 |
st.session_state[remember_key] = persona_ids
|
|
|
|
| 396 |
|
| 397 |
if not persona_ids:
|
| 398 |
st.info("Select at least one persona or include the Assistant persona.")
|
|
@@ -415,7 +577,9 @@ def _render_save_buttons(
|
|
| 415 |
if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
|
| 416 |
try:
|
| 417 |
_style_plotly_figures(figs)
|
| 418 |
-
paths = [
|
|
|
|
|
|
|
| 419 |
st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
|
| 420 |
except Exception as exc:
|
| 421 |
st.error(f"Could not save HTML: {exc}")
|
|
@@ -430,7 +594,11 @@ def _style_plotly_figures(figs: list[object]) -> None:
|
|
| 430 |
|
| 431 |
def _plotly_chart(fig: object) -> None:
|
| 432 |
_style_plotly_figures([fig])
|
| 433 |
-
st.plotly_chart(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 434 |
|
| 435 |
|
| 436 |
def _render_mask_strategy_select(scope: str) -> MaskStrategy:
|
|
@@ -584,7 +752,7 @@ def _render_cosine_similarity(
|
|
| 584 |
selection.persona_key,
|
| 585 |
)
|
| 586 |
filename = _filename(
|
| 587 |
-
"
|
| 588 |
"cosine",
|
| 589 |
store.model_name,
|
| 590 |
mask_strategy.value,
|
|
@@ -592,7 +760,7 @@ def _render_cosine_similarity(
|
|
| 592 |
selection.variant_b,
|
| 593 |
)
|
| 594 |
pairs_filename = _filename(
|
| 595 |
-
"
|
| 596 |
"cosine_pairs",
|
| 597 |
store.model_name,
|
| 598 |
mask_strategy.value,
|
|
@@ -605,7 +773,7 @@ def _render_cosine_similarity(
|
|
| 605 |
type="primary",
|
| 606 |
key=widget_key(
|
| 607 |
"load",
|
| 608 |
-
"
|
| 609 |
store_id(store),
|
| 610 |
store.model_name,
|
| 611 |
mask_strategy.value,
|
|
@@ -650,19 +818,29 @@ def _select_single_variant_samples(
|
|
| 650 |
scope: str,
|
| 651 |
*,
|
| 652 |
remember_key: str,
|
|
|
|
| 653 |
default_count_limit: int,
|
| 654 |
) -> tuple[str, list[str], str, list[int]] | None:
|
| 655 |
variants = available_variants(store, mask_strategy)
|
| 656 |
if not variants:
|
| 657 |
st.info("No variants with saved vectors for this model.")
|
| 658 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
variant = st.selectbox(
|
| 660 |
"Variant",
|
| 661 |
options=variants,
|
| 662 |
-
index=variants.index(
|
| 663 |
format_func=prompt_variant_label,
|
| 664 |
-
key=
|
| 665 |
)
|
|
|
|
| 666 |
persona_ids = _select_artifact_personas(
|
| 667 |
store,
|
| 668 |
[variant],
|
|
@@ -684,6 +862,352 @@ def _select_single_variant_samples(
|
|
| 684 |
return variant, persona_ids, persona_key, selected_layers
|
| 685 |
|
| 686 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
def _render_layered_figure_analysis(
|
| 688 |
store: Store,
|
| 689 |
mask_strategy: MaskStrategy,
|
|
@@ -707,124 +1231,60 @@ def _render_layered_figure_analysis(
|
|
| 707 |
mask_strategy,
|
| 708 |
scope,
|
| 709 |
remember_key=remember_key,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 710 |
default_count_limit=default_count_limit,
|
| 711 |
)
|
| 712 |
if selected is None:
|
| 713 |
return
|
| 714 |
variant, persona_ids, persona_key, selected_layers = selected
|
| 715 |
|
| 716 |
-
pair_trajectories =
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
|
| 725 |
-
|
| 726 |
-
"Pair trajectories",
|
| 727 |
-
value=False,
|
| 728 |
-
key=widget_key("load", "pair_trajectories", scope, store_id(store)),
|
| 729 |
-
help="Adds one line per persona pair. Keep this off for larger selections.",
|
| 730 |
-
)
|
| 731 |
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
"Reduce personas or layer frames before generating the similarity "
|
| 737 |
-
f"matrix ({similarity_cells:,} cells selected)."
|
| 738 |
-
)
|
| 739 |
return
|
| 740 |
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
)
|
| 753 |
-
if use_clusters:
|
| 754 |
-
method_label = st.selectbox(
|
| 755 |
-
"Cluster algorithm",
|
| 756 |
-
options=list(_CLUSTER_METHODS),
|
| 757 |
-
index=0,
|
| 758 |
-
key=widget_key("load", "cluster_method", scope, store_id(store)),
|
| 759 |
-
)
|
| 760 |
-
cluster_method = _CLUSTER_METHODS[method_label]
|
| 761 |
-
if cluster_method in {"kmeans", "agglomerative"}:
|
| 762 |
-
n_clusters = st.slider(
|
| 763 |
-
"K (clusters)",
|
| 764 |
-
min_value=2,
|
| 765 |
-
max_value=min(10, len(persona_ids)),
|
| 766 |
-
value=min(3, len(persona_ids)),
|
| 767 |
-
key=widget_key("load", "cluster_k", scope, store_id(store)),
|
| 768 |
-
)
|
| 769 |
-
if cluster_method == "agglomerative":
|
| 770 |
-
cluster_linkage = st.selectbox(
|
| 771 |
-
"Linkage",
|
| 772 |
-
options=_CLUSTER_LINKAGES,
|
| 773 |
-
index=0,
|
| 774 |
-
key=widget_key("load", "cluster_linkage", scope, store_id(store)),
|
| 775 |
-
)
|
| 776 |
-
if cluster_method == "hdbscan":
|
| 777 |
-
min_cluster_size = st.slider(
|
| 778 |
-
"Minimum cluster size",
|
| 779 |
-
min_value=2,
|
| 780 |
-
max_value=len(persona_ids),
|
| 781 |
-
value=min(5, len(persona_ids)),
|
| 782 |
-
key=widget_key(
|
| 783 |
-
"load",
|
| 784 |
-
"cluster_min_cluster_size",
|
| 785 |
-
scope,
|
| 786 |
-
store_id(store),
|
| 787 |
-
),
|
| 788 |
-
)
|
| 789 |
-
mode_label = st.selectbox(
|
| 790 |
-
"Cluster fit",
|
| 791 |
-
options=list(_CLUSTER_MODES),
|
| 792 |
-
index=0,
|
| 793 |
-
key=widget_key("load", "cluster_mode", scope, store_id(store)),
|
| 794 |
-
help=(
|
| 795 |
-
"Mean across layers is the previous behavior. First selected "
|
| 796 |
-
"layer keeps one fixed clustering from the first frame. Per layer "
|
| 797 |
-
"recomputes clustering for each animation frame."
|
| 798 |
-
),
|
| 799 |
-
)
|
| 800 |
-
cluster_mode = _CLUSTER_MODES[mode_label]
|
| 801 |
-
|
| 802 |
-
fig_key = widget_key(
|
| 803 |
-
"load",
|
| 804 |
-
f"{scope}_fig_state",
|
| 805 |
-
store_id(store),
|
| 806 |
-
store.model_name,
|
| 807 |
-
mask_strategy.value,
|
| 808 |
-
figure_kind,
|
| 809 |
-
str(n_components),
|
| 810 |
-
str(n_clusters),
|
| 811 |
-
str(cluster_mode),
|
| 812 |
-
str(cluster_method),
|
| 813 |
-
str(cluster_linkage),
|
| 814 |
-
str(min_cluster_size),
|
| 815 |
-
variant,
|
| 816 |
-
"persona_vector",
|
| 817 |
-
persona_key,
|
| 818 |
-
"_".join(map(str, selected_layers)),
|
| 819 |
-
str(pair_trajectories),
|
| 820 |
)
|
|
|
|
|
|
|
| 821 |
filename = scope
|
| 822 |
-
_clear_old_figure_states(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 823 |
|
| 824 |
if st.button(button_label, type="primary"):
|
| 825 |
build_label = {
|
| 826 |
"umap": "Computing UMAP projections…",
|
| 827 |
"pca": "Computing PCA projections…",
|
|
|
|
| 828 |
"similarity": "Computing similarity matrices…",
|
| 829 |
}.get(figure_kind, "Building figure…")
|
| 830 |
progress = st.progress(0, text="Loading activation vectors…")
|
|
@@ -837,63 +1297,44 @@ def _render_layered_figure_analysis(
|
|
| 837 |
persona_ids,
|
| 838 |
)
|
| 839 |
progress.progress(55, text=build_label)
|
| 840 |
-
build_kwargs =
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
figure_kind,
|
| 865 |
-
layers=selected_layers,
|
| 866 |
-
title=title_fn(variant),
|
| 867 |
-
**build_kwargs,
|
| 868 |
-
)
|
| 869 |
-
if figure_kind in {"umap", "pca"}:
|
| 870 |
-
main_fig.update_layout(height=700)
|
| 871 |
-
extra_fig = (
|
| 872 |
-
build_pair_similarity_figure(
|
| 873 |
-
samples,
|
| 874 |
-
layers=selected_layers,
|
| 875 |
-
title=(
|
| 876 |
-
"Pair similarity trajectories - "
|
| 877 |
-
f"{prompt_variant_label(variant)} - persona vectors"
|
| 878 |
-
),
|
| 879 |
-
)
|
| 880 |
-
if pair_trajectories
|
| 881 |
-
else None
|
| 882 |
-
)
|
| 883 |
progress.progress(90, text="Storing figure state…")
|
| 884 |
n_samples = samples.vectors.shape[0]
|
| 885 |
del samples
|
| 886 |
-
_store_figure_state(
|
| 887 |
progress.progress(100, text="Done.")
|
| 888 |
except Exception as exc:
|
| 889 |
st.error(f"Could not build figure: {exc}")
|
| 890 |
-
st.session_state.pop(
|
| 891 |
finally:
|
| 892 |
_release_vector_memory(store, [variant])
|
| 893 |
progress.empty()
|
| 894 |
|
| 895 |
-
if
|
| 896 |
-
main_fig, extra_fig, n_samples = st.session_state[
|
| 897 |
_plotly_chart(main_fig)
|
| 898 |
figs = [main_fig]
|
| 899 |
filenames = [filename]
|
|
@@ -906,7 +1347,7 @@ def _render_layered_figure_analysis(
|
|
| 906 |
st.success(f"Loaded {n_samples} samples.")
|
| 907 |
|
| 908 |
|
| 909 |
-
_LAST_DENDRO_PERSONAS_KEY = "
|
| 910 |
_DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"]
|
| 911 |
|
| 912 |
|
|
@@ -1108,7 +1549,7 @@ def _render_hub_model_select(
|
|
| 1108 |
mask_strategy: MaskStrategy,
|
| 1109 |
) -> str:
|
| 1110 |
fallback_model = st.session_state.get(
|
| 1111 |
-
"
|
| 1112 |
DEFAULT_COMPARE_MODEL,
|
| 1113 |
)
|
| 1114 |
try:
|
|
@@ -1118,7 +1559,7 @@ def _render_hub_model_select(
|
|
| 1118 |
return st.text_input(
|
| 1119 |
"Hub model",
|
| 1120 |
value=fallback_model,
|
| 1121 |
-
key="
|
| 1122 |
help="Compare-only model id to use if Hub config discovery is unavailable.",
|
| 1123 |
)
|
| 1124 |
|
|
@@ -1130,7 +1571,7 @@ def _render_hub_model_select(
|
|
| 1130 |
return st.text_input(
|
| 1131 |
"Hub model",
|
| 1132 |
value=fallback_model,
|
| 1133 |
-
key="
|
| 1134 |
help="Compare-only model id to use for this Hub repo.",
|
| 1135 |
)
|
| 1136 |
|
|
@@ -1155,31 +1596,31 @@ def _render_local_model_select(
|
|
| 1155 |
artifacts_root: str,
|
| 1156 |
mask_strategy: MaskStrategy,
|
| 1157 |
) -> str:
|
| 1158 |
-
fallback_model = st.session_state.get("
|
| 1159 |
model_options = local_model_options_cached(artifacts_root, mask_strategy.value)
|
| 1160 |
if not model_options:
|
| 1161 |
return st.text_input(
|
| 1162 |
"Local model",
|
| 1163 |
value=fallback_model,
|
| 1164 |
-
key="
|
| 1165 |
help="Compare-only local model id or path.",
|
| 1166 |
)
|
| 1167 |
|
| 1168 |
custom = st.toggle(
|
| 1169 |
"Custom local model",
|
| 1170 |
value=False,
|
| 1171 |
-
key="
|
| 1172 |
help="Enter a model id/path manually instead of choosing from activation directories.",
|
| 1173 |
)
|
| 1174 |
if custom:
|
| 1175 |
return st.text_input(
|
| 1176 |
"Local model",
|
| 1177 |
value=fallback_model,
|
| 1178 |
-
key="
|
| 1179 |
help="Compare-only local model id or path.",
|
| 1180 |
)
|
| 1181 |
|
| 1182 |
-
previous_model = st.session_state.get("
|
| 1183 |
if not any(local_model_matches(previous_model, option) for option in model_options):
|
| 1184 |
previous_model = fallback_model
|
| 1185 |
default_model = next(
|
|
@@ -1194,10 +1635,10 @@ def _render_local_model_select(
|
|
| 1194 |
"Local model",
|
| 1195 |
options=model_options,
|
| 1196 |
index=model_options.index(default_model),
|
| 1197 |
-
key="
|
| 1198 |
help="Models discovered under the selected artifacts root.",
|
| 1199 |
)
|
| 1200 |
-
st.session_state["
|
| 1201 |
return selected
|
| 1202 |
|
| 1203 |
|
|
@@ -1205,8 +1646,8 @@ def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
|
|
| 1205 |
if source == SOURCE_HUB:
|
| 1206 |
repo = st.text_input(
|
| 1207 |
"Hub repo",
|
| 1208 |
-
value=st.session_state.get("
|
| 1209 |
-
key="
|
| 1210 |
help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
|
| 1211 |
)
|
| 1212 |
hub_model_name = _render_hub_model_select(repo, mask_strategy)
|
|
@@ -1219,7 +1660,7 @@ def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
|
|
| 1219 |
artifacts_root = st.text_input(
|
| 1220 |
"Artifacts root",
|
| 1221 |
value=str(get_artifacts_dir() / "activations"),
|
| 1222 |
-
key="
|
| 1223 |
)
|
| 1224 |
artifacts_root = str(Path(artifacts_root).expanduser())
|
| 1225 |
local_model_name = _render_local_model_select(artifacts_root, mask_strategy)
|
|
@@ -1231,12 +1672,12 @@ def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
|
|
| 1231 |
)
|
| 1232 |
|
| 1233 |
|
| 1234 |
-
def
|
| 1235 |
"""Render the analysis tab."""
|
| 1236 |
|
| 1237 |
st.title("Analysis")
|
| 1238 |
st.caption(
|
| 1239 |
-
"Analyse persona vectors by cosine similarity, PCA, UMAP, or hierarchical clustering."
|
| 1240 |
)
|
| 1241 |
|
| 1242 |
source = _render_source_select()
|
|
@@ -1279,13 +1720,23 @@ def render_compare_tab() -> None:
|
|
| 1279 |
_render_dendrogram_analysis(store, mask_strategy)
|
| 1280 |
return
|
| 1281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1282 |
dimension_choice = st.segmented_control(
|
| 1283 |
"Projection dimensions",
|
| 1284 |
-
options=
|
| 1285 |
-
default=
|
| 1286 |
-
key=
|
| 1287 |
label_visibility="collapsed",
|
| 1288 |
)
|
|
|
|
|
|
|
| 1289 |
n_components = 3 if dimension_choice == "3D" else 2
|
| 1290 |
dim_suffix = "" if n_components == 2 else " (3D)"
|
| 1291 |
_render_layered_figure_analysis(
|
|
|
|
| 7 |
import plotly.graph_objects as go
|
| 8 |
import streamlit as st
|
| 9 |
from persona_data.environment import get_artifacts_dir
|
| 10 |
+
from persona_data.synth_persona import BASELINE_PERSONA_ID, SynthPersonaDataset
|
| 11 |
+
from persona_vectors.attributes import (
|
| 12 |
+
DEFAULT_MAX_ATTRIBUTE_CATEGORIES,
|
| 13 |
+
attribute_color_kwargs,
|
| 14 |
+
attribute_display_label,
|
| 15 |
+
)
|
| 16 |
from persona_vectors.extraction import MaskStrategy
|
| 17 |
from persona_vectors.plots import (
|
| 18 |
build_layered_figure,
|
|
|
|
| 20 |
build_similarity_figures,
|
| 21 |
plot_layer_similarity,
|
| 22 |
plot_persona_dendrogram,
|
| 23 |
+
prepare_layered_projection_data,
|
| 24 |
save_plot_html,
|
| 25 |
)
|
| 26 |
|
| 27 |
+
from utils.analysis_sources import (
|
| 28 |
DEFAULT_COMPARE_MODEL,
|
| 29 |
DEFAULT_HUB_REPO,
|
| 30 |
SOURCE_HUB,
|
|
|
|
| 34 |
activation_store_cached,
|
| 35 |
available_variants,
|
| 36 |
hub_models_by_mask_strategy,
|
| 37 |
+
load_persona_vectors_cached,
|
| 38 |
+
load_variant_vectors_cached,
|
| 39 |
local_model_matches,
|
| 40 |
local_model_options_cached,
|
| 41 |
persona_names_cached,
|
| 42 |
personas_cached,
|
| 43 |
+
release_hf_store_cache,
|
| 44 |
store_cache_parts,
|
| 45 |
store_id,
|
| 46 |
store_layers_cached,
|
|
|
|
| 63 |
|
| 64 |
# Keep compare-tab selection state separate so projection defaults do not
|
| 65 |
# overwrite cosine similarity defaults.
|
| 66 |
+
_LAST_COSINE_PERSONAS_KEY = "analysis:last_personas:cosine"
|
| 67 |
+
_LAST_PROJECTION_PERSONAS_KEY = "analysis:last_personas:projection"
|
| 68 |
+
_LAST_SIMILARITY_PERSONAS_KEY = "analysis:last_personas:similarity"
|
| 69 |
+
_LAST_MASK_STRATEGY_KEY = "analysis:last_mask_strategy"
|
| 70 |
+
_LAST_SOURCE_KEY = "analysis:last_source"
|
| 71 |
+
_LAST_PROJECTION_VARIANT_KEY = "analysis:last_projection_variant"
|
| 72 |
+
_LAST_SIMILARITY_VARIANT_KEY = "analysis:last_similarity_variant"
|
| 73 |
+
_LAST_PROJECTION_COLOR_MODE_KEY = "analysis:last_projection_color_mode"
|
| 74 |
+
_LAST_PROJECTION_ATTRIBUTE_KEY = "analysis:last_projection_attribute"
|
| 75 |
+
_LAST_PROJECTION_CLUSTER_K_KEY = "analysis:last_projection_cluster_k"
|
| 76 |
+
_LAST_PROJECTION_CLUSTER_MODE_KEY = "analysis:last_projection_cluster_mode"
|
| 77 |
+
_LAST_PROJECTION_HIGHLIGHTS_KEY = "analysis:last_projection_highlights"
|
| 78 |
+
_LAST_PROJECTION_DIMS_KEY = "analysis:last_projection_dims"
|
| 79 |
+
_LAST_LAYER_FRAMES_KEY = "analysis:last_layer_frames"
|
| 80 |
|
| 81 |
_DEFAULT_LAYER_FRAMES = 16
|
| 82 |
_DEFAULT_PERSONA_LIMITS = {
|
| 83 |
"similarity": 120,
|
| 84 |
"pca": 500,
|
| 85 |
"umap": 500,
|
| 86 |
+
"isomap": 500,
|
| 87 |
"dendro": 160,
|
| 88 |
}
|
| 89 |
_MAX_SIMILARITY_CELLS = 4_000_000
|
| 90 |
_MAX_PAIR_TRAJECTORY_TRACES = 500
|
| 91 |
+
_DEFAULT_GRAPH_NEIGHBORS = 5
|
| 92 |
+
_PROJECTION_KINDS = {"pca", "umap", "isomap"}
|
|
|
|
|
|
|
|
|
|
| 93 |
_CLUSTER_MODES = {
|
| 94 |
"Mean across layers": "mean_across_layers",
|
| 95 |
"First selected layer": "first_layer",
|
| 96 |
"Per layer": "per_layer",
|
| 97 |
}
|
| 98 |
+
_PROJECTION_COLOR_MODES = ["Persona", "K-means clusters", "Persona attribute"]
|
| 99 |
+
_MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@st.cache_resource(show_spinner=False)
def _synth_persona_dataset() -> SynthPersonaDataset:
    """Build the ``SynthPersonaDataset`` used for attribute-based coloring.

    The function is wrapped in ``st.cache_resource`` with the spinner turned
    off, so callers simply call it wherever the dataset is needed.
    """
    dataset = SynthPersonaDataset()
    return dataset
|
| 105 |
|
| 106 |
|
| 107 |
def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
|
|
|
|
| 126 |
class PersonaOptions:
|
| 127 |
regular_ids: list[str]
|
| 128 |
assistant_id: str | None
|
| 129 |
+
persona_names: dict[str, str]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@dataclass(frozen=True)
|
| 133 |
+
class ProjectionColorConfig:
|
| 134 |
+
color_mode: str = "Persona"
|
| 135 |
+
n_clusters: int | None = None
|
| 136 |
+
cluster_mode: str | None = None
|
| 137 |
+
attribute_name: str | None = None
|
| 138 |
+
highlight_persona_ids: tuple[str, ...] = ()
|
| 139 |
+
highlight_persona_key: str = ""
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@dataclass(frozen=True)
|
| 143 |
+
class LayeredFigureStateKeys:
|
| 144 |
+
figure: str
|
| 145 |
+
projection: str | None = None
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
_HIGHLIGHT_OTHER_LABEL = "Other"
|
| 149 |
+
_HIGHLIGHT_OTHER_COLOR = "rgba(148, 163, 184, 0.35)"
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _persona_names_state_key(widget_scope: str) -> str:
    """Return the session-state key holding persona names for ``widget_scope``."""
    state_key = widget_key("load", "persona_names", widget_scope)
    return state_key
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def _persona_display_label(persona_names: dict[str, str], persona_id: str) -> str:
|
| 157 |
+
name = persona_names.get(persona_id, persona_id)
|
| 158 |
+
return f"{name} ({persona_id})" if name != persona_id else persona_id
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _highlight_persona_groups(
|
| 162 |
+
persona_ids: list[str],
|
| 163 |
+
persona_names: dict[str, str],
|
| 164 |
+
highlight_persona_ids: tuple[str, ...],
|
| 165 |
+
) -> list[str] | None:
|
| 166 |
+
if not highlight_persona_ids:
|
| 167 |
+
return None
|
| 168 |
+
|
| 169 |
+
highlighted = set(highlight_persona_ids)
|
| 170 |
+
return [
|
| 171 |
+
(
|
| 172 |
+
_persona_display_label(persona_names, persona_id)
|
| 173 |
+
if persona_id in highlighted
|
| 174 |
+
else _HIGHLIGHT_OTHER_LABEL
|
| 175 |
+
)
|
| 176 |
+
for persona_id in persona_ids
|
| 177 |
+
]
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def _sequence_to_list(value: object) -> list[object] | None:
|
| 181 |
+
if value is None or isinstance(value, (str, bytes)):
|
| 182 |
+
return None
|
| 183 |
+
if isinstance(value, list):
|
| 184 |
+
return value
|
| 185 |
+
if isinstance(value, tuple):
|
| 186 |
+
return list(value)
|
| 187 |
+
try:
|
| 188 |
+
return list(value)
|
| 189 |
+
except TypeError:
|
| 190 |
+
return None
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _gray_out_unselected_personas(fig: go.Figure) -> None:
    """Recolor the "other" personas of ``fig`` in place with the muted gray.

    Every trace in ``fig.data`` and in each animation frame is visited. Traces
    without a marker are left alone.
    """

    def _iter_traces():
        # Base traces first, then the traces of every animation frame.
        yield from fig.data
        for frame in fig.frames:
            yield from frame.data

    def _gray_trace(trace: object) -> None:
        if getattr(trace, "marker", None) is None:
            return

        point_colors = _sequence_to_list(getattr(trace.marker, "color", None))
        point_labels = _sequence_to_list(getattr(trace, "customdata", None))
        has_per_point_data = (
            point_colors is not None
            and point_labels is not None
            and len(point_colors) == len(point_labels)
        )
        if has_per_point_data:
            # One color per point: swap only the entries labelled "other".
            trace.marker.color = [
                (
                    _HIGHLIGHT_OTHER_COLOR
                    if str(point_label) == _HIGHLIGHT_OTHER_LABEL
                    else point_color
                )
                for point_label, point_color in zip(
                    point_labels, point_colors, strict=True
                )
            ]
        elif getattr(trace, "name", None) == _HIGHLIGHT_OTHER_LABEL:
            # Whole trace is the "other" group: gray it and fade it.
            trace.marker.color = _HIGHLIGHT_OTHER_COLOR
            trace.opacity = 0.28

    for trace in _iter_traces():
        _gray_trace(trace)
|
| 221 |
|
| 222 |
|
| 223 |
def _layers_for_variant(
|
|
|
|
| 244 |
persona_ids: list[str],
|
| 245 |
):
|
| 246 |
source, location, model_name = store_cache_parts(store)
|
| 247 |
+
return load_persona_vectors_cached(
|
| 248 |
source,
|
| 249 |
location,
|
| 250 |
model_name,
|
|
|
|
| 261 |
persona_ids: list[str],
|
| 262 |
):
|
| 263 |
source, location, model_name = store_cache_parts(store)
|
| 264 |
+
return load_variant_vectors_cached(
|
| 265 |
source,
|
| 266 |
location,
|
| 267 |
model_name,
|
|
|
|
| 271 |
)
|
| 272 |
|
| 273 |
|
| 274 |
+
def _clear_old_load_states(current_key: str, suffix: str) -> None:
    """Remove stale ``load`` entries from the session state.

    A key is removed when it is a string other than ``current_key`` whose
    first ``"::"``-separated segment is ``"load"`` and whose second segment
    ends with ``suffix``.
    """

    def _is_stale(key: object) -> bool:
        if key == current_key or not isinstance(key, str):
            return False
        head, *rest = key.split("::", 2)
        return bool(rest) and head == "load" and rest[0].endswith(suffix)

    # Iterate over a snapshot so entries can be removed while scanning.
    for candidate in list(st.session_state):
        if _is_stale(candidate):
            st.session_state.pop(candidate, None)
|
| 281 |
|
| 282 |
|
| 283 |
+
def _clear_old_figure_states(current_key: str) -> None:
    """Remove every stored figure state except the one at ``current_key``."""
    _clear_old_load_states(current_key, suffix="_fig_state")
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _clear_old_projection_states(current_key: str) -> None:
    """Remove every stored projection state except the one at ``current_key``."""
    _clear_old_load_states(current_key, suffix="_projection_state")
|
| 289 |
+
|
| 290 |
+
|
| 291 |
def _store_figure_state(key: str, value: object) -> None:
    """Store ``value`` under ``key`` after discarding other figure states."""
    # Drop older figure states first so only one figure state is kept.
    _clear_old_figure_states(key)
    st.session_state[key] = value
|
| 294 |
|
| 295 |
|
| 296 |
+
def _seed_selectbox_key(
    *,
    key: str,
    remember_key: str,
    options: list[str],
    default: str,
) -> str:
    """Pick the initial selection for a selectbox.

    Preference order: the value already stored under the widget ``key``, then
    the value stored under ``remember_key``, then ``default``. A candidate
    that is not in ``options`` is replaced by ``default``.
    """
    remembered = st.session_state.get(remember_key, default)
    candidate = st.session_state.get(key, remembered)
    return candidate if candidate in options else default
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _remember_multiselect(
    *,
    key: str,
    remember_key: str,
    options: list[str],
) -> list[str]:
    """Return the remembered multiselect values that are still valid.

    The stored selection is read from the widget ``key`` first and from
    ``remember_key`` otherwise. Anything that is not a list is treated as an
    empty selection, and entries no longer in ``options`` are dropped.
    """
    fallback = st.session_state.get(remember_key, [])
    stored = st.session_state.get(key, fallback)
    if not isinstance(stored, list):
        return []
    return [item for item in stored if item in options]
|
| 319 |
+
|
| 320 |
+
|
| 321 |
def _release_vector_memory(store: Store, variants: list[str] | tuple[str, ...]) -> None:
    """Release the store cache for ``variants`` and run the garbage collector."""
    release_hf_store_cache(store, variants)
    # Collect right away rather than waiting for the next automatic pass.
    gc.collect()
|
| 324 |
|
| 325 |
|
|
|
|
| 347 |
"Layer frames",
|
| 348 |
min_value=2,
|
| 349 |
max_value=len(layers),
|
| 350 |
+
value=min(
|
| 351 |
+
max(
|
| 352 |
+
int(
|
| 353 |
+
st.session_state.get(
|
| 354 |
+
_LAST_LAYER_FRAMES_KEY,
|
| 355 |
+
_DEFAULT_LAYER_FRAMES,
|
| 356 |
+
)
|
| 357 |
+
),
|
| 358 |
+
2,
|
| 359 |
+
),
|
| 360 |
+
len(layers),
|
| 361 |
+
),
|
| 362 |
key=widget_key("load", "layer_frames", scope, store_id(store)),
|
| 363 |
help="Limit animated Plotly frames to keep browser and RAM usage bounded.",
|
| 364 |
)
|
| 365 |
+
st.session_state[_LAST_LAYER_FRAMES_KEY] = frame_count
|
| 366 |
selected = _evenly_spaced_layers(layers, frame_count)
|
| 367 |
st.caption(f"Using {len(selected)} of {len(layers)} layers.")
|
| 368 |
return selected
|
|
|
|
| 415 |
if not regular_ids and assistant_id is None:
|
| 416 |
st.info("No personas found for this model and variant.")
|
| 417 |
return None
|
| 418 |
+
return PersonaOptions(
|
| 419 |
+
regular_ids=regular_ids,
|
| 420 |
+
assistant_id=assistant_id,
|
| 421 |
+
persona_names=persona_names,
|
| 422 |
+
)
|
| 423 |
|
| 424 |
|
| 425 |
def _seed_persona_memory(
|
|
|
|
| 526 |
empty_message=empty_message,
|
| 527 |
)
|
| 528 |
if options is None:
|
| 529 |
+
st.session_state.pop(_persona_names_state_key(widget_scope), None)
|
| 530 |
return []
|
| 531 |
|
| 532 |
default_count, include_assistant_default = _seed_persona_memory(
|
|
|
|
| 554 |
st.session_state[remembered_count_key] = persona_count
|
| 555 |
st.session_state[remembered_assistant_key] = include_assistant
|
| 556 |
st.session_state[remember_key] = persona_ids
|
| 557 |
+
st.session_state[_persona_names_state_key(widget_scope)] = options.persona_names
|
| 558 |
|
| 559 |
if not persona_ids:
|
| 560 |
st.info("Select at least one persona or include the Assistant persona.")
|
|
|
|
| 577 |
if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
|
| 578 |
try:
|
| 579 |
_style_plotly_figures(figs)
|
| 580 |
+
paths = [
|
| 581 |
+
save_plot_html(fig, fn) for fig, fn in zip(figs, filenames, strict=True)
|
| 582 |
+
]
|
| 583 |
st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
|
| 584 |
except Exception as exc:
|
| 585 |
st.error(f"Could not save HTML: {exc}")
|
|
|
|
| 594 |
|
| 595 |
def _plotly_chart(fig: object) -> None:
    """Apply the shared figure styling to ``fig`` and render it in Streamlit."""
    _style_plotly_figures([fig])
    chart_config = {"responsive": True, "displaylogo": False}
    st.plotly_chart(fig, width="stretch", config=chart_config)
|
| 602 |
|
| 603 |
|
| 604 |
def _render_mask_strategy_select(scope: str) -> MaskStrategy:
|
|
|
|
| 752 |
selection.persona_key,
|
| 753 |
)
|
| 754 |
filename = _filename(
|
| 755 |
+
"analysis",
|
| 756 |
"cosine",
|
| 757 |
store.model_name,
|
| 758 |
mask_strategy.value,
|
|
|
|
| 760 |
selection.variant_b,
|
| 761 |
)
|
| 762 |
pairs_filename = _filename(
|
| 763 |
+
"analysis",
|
| 764 |
"cosine_pairs",
|
| 765 |
store.model_name,
|
| 766 |
mask_strategy.value,
|
|
|
|
| 773 |
type="primary",
|
| 774 |
key=widget_key(
|
| 775 |
"load",
|
| 776 |
+
"analysis_vectors",
|
| 777 |
store_id(store),
|
| 778 |
store.model_name,
|
| 779 |
mask_strategy.value,
|
|
|
|
| 818 |
scope: str,
|
| 819 |
*,
|
| 820 |
remember_key: str,
|
| 821 |
+
variant_remember_key: str,
|
| 822 |
default_count_limit: int,
|
| 823 |
) -> tuple[str, list[str], str, list[int]] | None:
|
| 824 |
variants = available_variants(store, mask_strategy)
|
| 825 |
if not variants:
|
| 826 |
st.info("No variants with saved vectors for this model.")
|
| 827 |
return None
|
| 828 |
+
variant_key = widget_key("load", "variant", scope, store_id(store))
|
| 829 |
+
default_variant = "biography" if "biography" in variants else variants[0]
|
| 830 |
+
selected_variant = _seed_selectbox_key(
|
| 831 |
+
key=variant_key,
|
| 832 |
+
remember_key=variant_remember_key,
|
| 833 |
+
options=variants,
|
| 834 |
+
default=default_variant,
|
| 835 |
+
)
|
| 836 |
variant = st.selectbox(
|
| 837 |
"Variant",
|
| 838 |
options=variants,
|
| 839 |
+
index=variants.index(selected_variant),
|
| 840 |
format_func=prompt_variant_label,
|
| 841 |
+
key=variant_key,
|
| 842 |
)
|
| 843 |
+
st.session_state[variant_remember_key] = variant
|
| 844 |
persona_ids = _select_artifact_personas(
|
| 845 |
store,
|
| 846 |
[variant],
|
|
|
|
| 862 |
return variant, persona_ids, persona_key, selected_layers
|
| 863 |
|
| 864 |
|
| 865 |
+
def _render_pair_trajectory_control(
    *,
    enabled: bool,
    persona_count: int,
    scope: str,
    store: Store,
) -> bool:
    """Render the "Pair trajectories" checkbox and return whether it is on.

    Returns ``False`` without rendering a checkbox when the control is
    disabled, or when the number of persona pairs exceeds
    ``_MAX_PAIR_TRAJECTORY_TRACES`` (a caption explains why in that case).
    """
    if not enabled:
        return False

    # Unordered pairs: n choose 2.
    n_pairs = persona_count * (persona_count - 1) // 2
    if n_pairs <= _MAX_PAIR_TRAJECTORY_TRACES:
        checkbox_key = widget_key("load", "pair_trajectories", scope, store_id(store))
        return st.checkbox(
            "Pair trajectories",
            value=False,
            key=checkbox_key,
            help="Adds one line per persona pair. Keep this off for larger selections.",
        )

    st.caption(
        "Pair trajectories hidden because this selection would create "
        f"{n_pairs:,} Plotly traces."
    )
    return False
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
def _validate_layered_figure_size(
|
| 890 |
+
figure_kind: str,
|
| 891 |
+
persona_count: int,
|
| 892 |
+
selected_layers: list[int],
|
| 893 |
+
) -> bool:
|
| 894 |
+
if figure_kind != "similarity":
|
| 895 |
+
return True
|
| 896 |
+
similarity_cells = persona_count * persona_count * len(selected_layers)
|
| 897 |
+
if similarity_cells <= _MAX_SIMILARITY_CELLS:
|
| 898 |
+
return True
|
| 899 |
+
st.error(
|
| 900 |
+
"Reduce personas or layer frames before generating the similarity "
|
| 901 |
+
f"matrix ({similarity_cells:,} cells selected)."
|
| 902 |
+
)
|
| 903 |
+
return False
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
def _projection_kmeans_color_config(
    store: Store,
    scope: str,
    persona_ids: list[str],
    color_mode: str,
) -> ProjectionColorConfig | None:
    """Render the K-means controls (K slider, cluster-fit select)."""
    cluster_cap = min(10, len(persona_ids))
    if cluster_cap < 2:
        st.info("Select at least two personas to use K-means coloring.")
        return None

    cluster_key = widget_key("load", "cluster_k", scope, store_id(store))
    fallback_k = min(3, len(persona_ids))
    if cluster_key not in st.session_state:
        # Seed the slider from the remembered K, clamped to [2, cluster_cap].
        # NOTE(review): a value already stored under cluster_key is not
        # re-clamped if cluster_cap shrinks later; confirm the slider copes.
        remembered_k = int(
            st.session_state.get(_LAST_PROJECTION_CLUSTER_K_KEY, fallback_k)
        )
        st.session_state[cluster_key] = min(max(remembered_k, 2), cluster_cap)
    n_clusters = st.slider(
        "K (clusters)",
        min_value=2,
        max_value=cluster_cap,
        key=cluster_key,
    )

    mode_key = widget_key("load", "cluster_mode", scope, store_id(store))
    mode_labels = list(_CLUSTER_MODES)
    initial_mode = _seed_selectbox_key(
        key=mode_key,
        remember_key=_LAST_PROJECTION_CLUSTER_MODE_KEY,
        options=mode_labels,
        default=mode_labels[0],
    )
    mode_label = st.selectbox(
        "Cluster fit",
        options=mode_labels,
        index=mode_labels.index(initial_mode),
        key=mode_key,
        help=(
            "Mean across layers is the previous behavior. First selected "
            "layer keeps one fixed clustering from the first frame. Per layer "
            "recomputes clustering for each animation frame."
        ),
    )

    st.session_state[_LAST_PROJECTION_CLUSTER_K_KEY] = n_clusters
    st.session_state[_LAST_PROJECTION_CLUSTER_MODE_KEY] = mode_label
    return ProjectionColorConfig(
        color_mode=color_mode,
        n_clusters=n_clusters,
        cluster_mode=_CLUSTER_MODES[mode_label],
    )


def _projection_attribute_color_config(
    store: Store,
    scope: str,
    color_mode: str,
) -> ProjectionColorConfig | None:
    """Render the persona-attribute selector."""
    persona_dataset = _synth_persona_dataset()
    attribute_options = list(persona_dataset.attribute_names)
    if not attribute_options:
        st.info("No persona attributes are available for this dataset.")
        return None

    # Prefer "sex" as the initial attribute when the dataset offers it.
    if "sex" in attribute_options:
        default_position = attribute_options.index("sex")
    else:
        default_position = 0

    attribute_key = widget_key("load", "attribute", scope, store_id(store))
    initial_attribute = _seed_selectbox_key(
        key=attribute_key,
        remember_key=_LAST_PROJECTION_ATTRIBUTE_KEY,
        options=attribute_options,
        default=attribute_options[default_position],
    )
    attribute_name = st.selectbox(
        "Attribute",
        options=attribute_options,
        index=attribute_options.index(initial_attribute),
        format_func=lambda name: attribute_display_label(persona_dataset, name),
        key=attribute_key,
    )
    st.session_state[_LAST_PROJECTION_ATTRIBUTE_KEY] = attribute_name

    attribute_info = persona_dataset.attribute_info(attribute_name)
    if attribute_info.get("high_cardinality"):
        st.caption(
            "High-cardinality categorical attributes are grouped to the "
            f"top {_MAX_ATTRIBUTE_CATEGORIES} values plus Other."
        )
    return ProjectionColorConfig(
        color_mode=color_mode,
        attribute_name=attribute_name,
    )


def _projection_highlight_color_config(
    store: Store,
    scope: str,
    persona_ids: list[str],
    persona_key: str,
    persona_names: dict[str, str],
    color_mode: str,
) -> ProjectionColorConfig:
    """Render the optional persona-highlight multiselect."""
    chosen: tuple[str, ...] = ()
    if persona_ids:
        # The widget key includes the persona fingerprint.
        highlight_key = widget_key(
            "load", "persona_highlight", scope, store_id(store), persona_key
        )
        picked = st.multiselect(
            "Highlight personas",
            options=persona_ids,
            default=_remember_multiselect(
                key=highlight_key,
                remember_key=_LAST_PROJECTION_HIGHLIGHTS_KEY,
                options=persona_ids,
            ),
            format_func=lambda persona_id: _persona_display_label(
                persona_names, persona_id
            ),
            key=highlight_key,
            help=(
                "Select a few personas to keep their default colors while the rest "
                "are grayed out."
            ),
        )
        chosen = tuple(picked)
        st.session_state[_LAST_PROJECTION_HIGHLIGHTS_KEY] = list(picked)

    chosen_key = personas_fingerprint(chosen) if chosen else ""
    return ProjectionColorConfig(
        color_mode=color_mode,
        highlight_persona_ids=chosen,
        highlight_persona_key=chosen_key,
    )


def _render_projection_color_config(
    store: Store,
    scope: str,
    persona_ids: list[str],
) -> ProjectionColorConfig | None:
    """Render the "Color by" controls for a projection figure.

    Shows the color-mode selectbox, then the controls for the chosen mode
    (K-means, persona attribute, or per-persona highlighting), and remembers
    each choice in the session state.

    Returns the resulting ``ProjectionColorConfig``, or ``None`` when the
    chosen mode cannot be used with the current selection (an info message is
    shown in that case).
    """
    widget_scope = f"{scope}:{store_id(store)}"
    persona_key = personas_fingerprint(persona_ids)
    persona_names = st.session_state.get(
        _persona_names_state_key(widget_scope),
        {},
    )

    color_mode_key = widget_key("load", "color_mode", scope, store_id(store))
    initial_color_mode = _seed_selectbox_key(
        key=color_mode_key,
        remember_key=_LAST_PROJECTION_COLOR_MODE_KEY,
        options=_PROJECTION_COLOR_MODES,
        default="Persona",
    )
    color_mode = st.selectbox(
        "Color by",
        options=_PROJECTION_COLOR_MODES,
        index=_PROJECTION_COLOR_MODES.index(initial_color_mode),
        key=color_mode_key,
    )
    st.session_state[_LAST_PROJECTION_COLOR_MODE_KEY] = color_mode

    if color_mode == "K-means clusters":
        return _projection_kmeans_color_config(store, scope, persona_ids, color_mode)
    if color_mode == "Persona attribute":
        return _projection_attribute_color_config(store, scope, color_mode)
    return _projection_highlight_color_config(
        store,
        scope,
        persona_ids,
        persona_key,
        persona_names,
        color_mode,
    )
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
def _layered_figure_state_keys(
    store: Store,
    mask_strategy: MaskStrategy,
    *,
    scope: str,
    figure_kind: str,
    n_components: int,
    color_config: ProjectionColorConfig,
    variant: str,
    persona_key: str,
    selected_layers: list[int],
    pair_trajectories: bool,
) -> LayeredFigureStateKeys:
    """Build the session-state keys for one layered-figure configuration.

    The figure key covers every input that changes the rendered figure,
    coloring options included. The projection key is produced only for
    projection figure kinds and leaves the coloring options out.
    """
    layer_key = "_".join(map(str, selected_layers))

    def _head(state_name: str) -> tuple[str, ...]:
        # Leading components shared by both keys.
        return (
            "load",
            f"{scope}_{state_name}",
            store_id(store),
            store.model_name,
            mask_strategy.value,
            figure_kind,
            str(n_components),
        )

    # Trailing components shared by both keys.
    tail = (variant, "persona_vector", persona_key, layer_key)

    figure_key = widget_key(
        *_head("fig_state"),
        color_config.color_mode,
        str(color_config.attribute_name),
        str(color_config.n_clusters),
        str(color_config.cluster_mode),
        str(color_config.highlight_persona_key),
        *tail,
        str(pair_trajectories),
    )
    if figure_kind not in _PROJECTION_KINDS:
        return LayeredFigureStateKeys(figure=figure_key)

    uses_graph_overlay = figure_kind == "isomap"
    projection_key = widget_key(
        *_head("projection_state"),
        str(uses_graph_overlay),
        str(_DEFAULT_GRAPH_NEIGHBORS),
        *tail,
    )
    return LayeredFigureStateKeys(figure=figure_key, projection=projection_key)
|
| 1108 |
+
|
| 1109 |
+
|
| 1110 |
+
def _projection_build_kwargs(
    samples,
    *,
    figure_kind: str,
    selected_layers: list[int],
    n_components: int,
    color_config: ProjectionColorConfig,
    persona_ids: list[str],
    persona_names: dict[str, str],
    projection_key: str | None,
) -> dict:
    """Assemble the extra keyword arguments for building a projection figure.

    Returns an empty dict for figure kinds that are not projections. For
    projections the result holds the projection settings plus, depending on
    ``color_config``, the clustering, attribute-color or highlight-group
    arguments. When ``projection_key`` is given, projection data is read from
    the session state, and computed and stored there on a miss.
    """
    if figure_kind not in _PROJECTION_KINDS:
        return {}

    # The neighbor-graph overlay is enabled for Isomap only.
    use_graph = figure_kind == "isomap"
    kwargs: dict = {
        "n_components": n_components,
        "graph_overlay": use_graph,
        "graph_n_neighbors": _DEFAULT_GRAPH_NEIGHBORS,
    }

    if color_config.n_clusters is not None:
        kwargs["n_clusters"] = color_config.n_clusters
        kwargs["cluster_mode"] = color_config.cluster_mode

    if projection_key is not None:
        cached = st.session_state.get(projection_key)
        if cached is None:
            cached = prepare_layered_projection_data(
                samples,
                figure_kind,
                layers=selected_layers,
                n_components=n_components,
                graph_overlay=use_graph,
                graph_n_neighbors=_DEFAULT_GRAPH_NEIGHBORS,
            )
            st.session_state[projection_key] = cached
        kwargs["projection_data"] = cached

    if color_config.attribute_name is not None:
        attribute_kwargs = attribute_color_kwargs(
            _synth_persona_dataset(),
            color_config.attribute_name,
            persona_ids,
            max_categories=_MAX_ATTRIBUTE_CATEGORIES,
        )
        kwargs.update(attribute_kwargs)

    if color_config.color_mode == "Persona" and color_config.highlight_persona_ids:
        highlight_groups = _highlight_persona_groups(
            persona_ids,
            persona_names,
            color_config.highlight_persona_ids,
        )
        if highlight_groups is not None:
            kwargs["groups"] = highlight_groups

    return kwargs
|
| 1164 |
+
|
| 1165 |
+
|
| 1166 |
+
def _build_layered_analysis_figures(
    samples,
    *,
    figure_kind: str,
    selected_layers: list[int],
    variant: str,
    title_fn: Callable[[str], str],
    pair_trajectories: bool,
    build_kwargs: dict,
) -> tuple[go.Figure, go.Figure | None]:
    """Build the main figure and, optionally, the pair-trajectory figure.

    Similarity figures with pair trajectories are built together by
    ``build_similarity_figures``. Otherwise the main figure comes from
    ``build_layered_figure``, and the second element is the pair-similarity
    figure when ``pair_trajectories`` is set, else ``None``.
    """
    main_title = title_fn(variant)

    def _pair_title() -> str:
        # Built lazily: only needed when pair trajectories are requested.
        return (
            "Pair similarity trajectories - "
            f"{prompt_variant_label(variant)} - persona vectors"
        )

    if figure_kind == "similarity" and pair_trajectories:
        return build_similarity_figures(
            samples,
            layers=selected_layers,
            title=main_title,
            pair_title=_pair_title(),
        )

    main_fig = build_layered_figure(
        samples,
        figure_kind,
        layers=selected_layers,
        title=main_title,
        **build_kwargs,
    )
    if figure_kind in _PROJECTION_KINDS:
        # Projection plots get a fixed 700px height.
        main_fig.update_layout(height=700)

    if not pair_trajectories:
        return main_fig, None

    pair_fig = build_pair_similarity_figure(
        samples,
        layers=selected_layers,
        title=_pair_title(),
    )
    return main_fig, pair_fig
|
| 1209 |
+
|
| 1210 |
+
|
| 1211 |
def _render_layered_figure_analysis(
|
| 1212 |
store: Store,
|
| 1213 |
mask_strategy: MaskStrategy,
|
|
|
|
| 1231 |
mask_strategy,
|
| 1232 |
scope,
|
| 1233 |
remember_key=remember_key,
|
| 1234 |
+
variant_remember_key=(
|
| 1235 |
+
_LAST_PROJECTION_VARIANT_KEY
|
| 1236 |
+
if figure_kind in _PROJECTION_KINDS
|
| 1237 |
+
else _LAST_SIMILARITY_VARIANT_KEY
|
| 1238 |
+
),
|
| 1239 |
default_count_limit=default_count_limit,
|
| 1240 |
)
|
| 1241 |
if selected is None:
|
| 1242 |
return
|
| 1243 |
variant, persona_ids, persona_key, selected_layers = selected
|
| 1244 |
|
| 1245 |
+
pair_trajectories = _render_pair_trajectory_control(
|
| 1246 |
+
enabled=include_pair_trajectories,
|
| 1247 |
+
persona_count=len(persona_ids),
|
| 1248 |
+
scope=scope,
|
| 1249 |
+
store=store,
|
| 1250 |
+
)
|
| 1251 |
+
if not _validate_layered_figure_size(
|
| 1252 |
+
figure_kind, len(persona_ids), selected_layers
|
| 1253 |
+
):
|
| 1254 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1255 |
|
| 1256 |
+
color_config = ProjectionColorConfig()
|
| 1257 |
+
if figure_kind in _PROJECTION_KINDS:
|
| 1258 |
+
color_config = _render_projection_color_config(store, scope, persona_ids)
|
| 1259 |
+
if color_config is None:
|
|
|
|
|
|
|
|
|
|
| 1260 |
return
|
| 1261 |
|
| 1262 |
+
state_keys = _layered_figure_state_keys(
|
| 1263 |
+
store,
|
| 1264 |
+
mask_strategy,
|
| 1265 |
+
scope=scope,
|
| 1266 |
+
figure_kind=figure_kind,
|
| 1267 |
+
n_components=n_components,
|
| 1268 |
+
color_config=color_config,
|
| 1269 |
+
variant=variant,
|
| 1270 |
+
persona_key=persona_key,
|
| 1271 |
+
selected_layers=selected_layers,
|
| 1272 |
+
pair_trajectories=pair_trajectories,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1273 |
)
|
| 1274 |
+
if state_keys.projection is not None:
|
| 1275 |
+
_clear_old_projection_states(state_keys.projection)
|
| 1276 |
filename = scope
|
| 1277 |
+
_clear_old_figure_states(state_keys.figure)
|
| 1278 |
+
persona_names = st.session_state.get(
|
| 1279 |
+
_persona_names_state_key(f"{scope}:{store_id(store)}"),
|
| 1280 |
+
{},
|
| 1281 |
+
)
|
| 1282 |
|
| 1283 |
if st.button(button_label, type="primary"):
|
| 1284 |
build_label = {
|
| 1285 |
"umap": "Computing UMAP projections…",
|
| 1286 |
"pca": "Computing PCA projections…",
|
| 1287 |
+
"isomap": "Computing Isomap projections…",
|
| 1288 |
"similarity": "Computing similarity matrices…",
|
| 1289 |
}.get(figure_kind, "Building figure…")
|
| 1290 |
progress = st.progress(0, text="Loading activation vectors…")
|
|
|
|
| 1297 |
persona_ids,
|
| 1298 |
)
|
| 1299 |
progress.progress(55, text=build_label)
|
| 1300 |
+
build_kwargs = _projection_build_kwargs(
|
| 1301 |
+
samples,
|
| 1302 |
+
figure_kind=figure_kind,
|
| 1303 |
+
selected_layers=selected_layers,
|
| 1304 |
+
n_components=n_components,
|
| 1305 |
+
color_config=color_config,
|
| 1306 |
+
persona_ids=persona_ids,
|
| 1307 |
+
persona_names=persona_names,
|
| 1308 |
+
projection_key=state_keys.projection,
|
| 1309 |
+
)
|
| 1310 |
+
main_fig, extra_fig = _build_layered_analysis_figures(
|
| 1311 |
+
samples,
|
| 1312 |
+
figure_kind=figure_kind,
|
| 1313 |
+
selected_layers=selected_layers,
|
| 1314 |
+
variant=variant,
|
| 1315 |
+
title_fn=title_fn,
|
| 1316 |
+
pair_trajectories=pair_trajectories,
|
| 1317 |
+
build_kwargs=build_kwargs,
|
| 1318 |
+
)
|
| 1319 |
+
if (
|
| 1320 |
+
color_config.color_mode == "Persona"
|
| 1321 |
+
and color_config.highlight_persona_ids
|
| 1322 |
+
):
|
| 1323 |
+
_gray_out_unselected_personas(main_fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1324 |
progress.progress(90, text="Storing figure state…")
|
| 1325 |
n_samples = samples.vectors.shape[0]
|
| 1326 |
del samples
|
| 1327 |
+
_store_figure_state(state_keys.figure, (main_fig, extra_fig, n_samples))
|
| 1328 |
progress.progress(100, text="Done.")
|
| 1329 |
except Exception as exc:
|
| 1330 |
st.error(f"Could not build figure: {exc}")
|
| 1331 |
+
st.session_state.pop(state_keys.figure, None)
|
| 1332 |
finally:
|
| 1333 |
_release_vector_memory(store, [variant])
|
| 1334 |
progress.empty()
|
| 1335 |
|
| 1336 |
+
if state_keys.figure in st.session_state:
|
| 1337 |
+
main_fig, extra_fig, n_samples = st.session_state[state_keys.figure]
|
| 1338 |
_plotly_chart(main_fig)
|
| 1339 |
figs = [main_fig]
|
| 1340 |
filenames = [filename]
|
|
|
|
| 1347 |
st.success(f"Loaded {n_samples} samples.")
|
| 1348 |
|
| 1349 |
|
| 1350 |
+
_LAST_DENDRO_PERSONAS_KEY = "analysis:last_personas:dendro"
|
| 1351 |
_DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"]
|
| 1352 |
|
| 1353 |
|
|
|
|
| 1549 |
mask_strategy: MaskStrategy,
|
| 1550 |
) -> str:
|
| 1551 |
fallback_model = st.session_state.get(
|
| 1552 |
+
"analysis:hub_model_fallback",
|
| 1553 |
DEFAULT_COMPARE_MODEL,
|
| 1554 |
)
|
| 1555 |
try:
|
|
|
|
| 1559 |
return st.text_input(
|
| 1560 |
"Hub model",
|
| 1561 |
value=fallback_model,
|
| 1562 |
+
key="analysis:hub_model_fallback",
|
| 1563 |
help="Compare-only model id to use if Hub config discovery is unavailable.",
|
| 1564 |
)
|
| 1565 |
|
|
|
|
| 1571 |
return st.text_input(
|
| 1572 |
"Hub model",
|
| 1573 |
value=fallback_model,
|
| 1574 |
+
key="analysis:hub_model_fallback",
|
| 1575 |
help="Compare-only model id to use for this Hub repo.",
|
| 1576 |
)
|
| 1577 |
|
|
|
|
| 1596 |
artifacts_root: str,
|
| 1597 |
mask_strategy: MaskStrategy,
|
| 1598 |
) -> str:
|
| 1599 |
+
fallback_model = st.session_state.get("analysis:local_model", DEFAULT_COMPARE_MODEL)
|
| 1600 |
model_options = local_model_options_cached(artifacts_root, mask_strategy.value)
|
| 1601 |
if not model_options:
|
| 1602 |
return st.text_input(
|
| 1603 |
"Local model",
|
| 1604 |
value=fallback_model,
|
| 1605 |
+
key="analysis:local_model",
|
| 1606 |
help="Compare-only local model id or path.",
|
| 1607 |
)
|
| 1608 |
|
| 1609 |
custom = st.toggle(
|
| 1610 |
"Custom local model",
|
| 1611 |
value=False,
|
| 1612 |
+
key="analysis:local_model_custom_enabled",
|
| 1613 |
help="Enter a model id/path manually instead of choosing from activation directories.",
|
| 1614 |
)
|
| 1615 |
if custom:
|
| 1616 |
return st.text_input(
|
| 1617 |
"Local model",
|
| 1618 |
value=fallback_model,
|
| 1619 |
+
key="analysis:local_model",
|
| 1620 |
help="Compare-only local model id or path.",
|
| 1621 |
)
|
| 1622 |
|
| 1623 |
+
previous_model = st.session_state.get("analysis:local_model_select", fallback_model)
|
| 1624 |
if not any(local_model_matches(previous_model, option) for option in model_options):
|
| 1625 |
previous_model = fallback_model
|
| 1626 |
default_model = next(
|
|
|
|
| 1635 |
"Local model",
|
| 1636 |
options=model_options,
|
| 1637 |
index=model_options.index(default_model),
|
| 1638 |
+
key="analysis:local_model_select",
|
| 1639 |
help="Models discovered under the selected artifacts root.",
|
| 1640 |
)
|
| 1641 |
+
st.session_state["analysis:local_model"] = selected
|
| 1642 |
return selected
|
| 1643 |
|
| 1644 |
|
|
|
|
| 1646 |
if source == SOURCE_HUB:
|
| 1647 |
repo = st.text_input(
|
| 1648 |
"Hub repo",
|
| 1649 |
+
value=st.session_state.get("analysis:hub_repo", DEFAULT_HUB_REPO),
|
| 1650 |
+
key="analysis:hub_repo",
|
| 1651 |
help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
|
| 1652 |
)
|
| 1653 |
hub_model_name = _render_hub_model_select(repo, mask_strategy)
|
|
|
|
| 1660 |
artifacts_root = st.text_input(
|
| 1661 |
"Artifacts root",
|
| 1662 |
value=str(get_artifacts_dir() / "activations"),
|
| 1663 |
+
key="analysis:artifacts_root",
|
| 1664 |
)
|
| 1665 |
artifacts_root = str(Path(artifacts_root).expanduser())
|
| 1666 |
local_model_name = _render_local_model_select(artifacts_root, mask_strategy)
|
|
|
|
| 1672 |
)
|
| 1673 |
|
| 1674 |
|
| 1675 |
+
def render_analysis_tab() -> None:
|
| 1676 |
"""Render the analysis tab."""
|
| 1677 |
|
| 1678 |
st.title("Analysis")
|
| 1679 |
st.caption(
|
| 1680 |
+
"Analyse persona vectors by cosine similarity, PCA, UMAP, Isomap, or hierarchical clustering."
|
| 1681 |
)
|
| 1682 |
|
| 1683 |
source = _render_source_select()
|
|
|
|
| 1720 |
_render_dendrogram_analysis(store, mask_strategy)
|
| 1721 |
return
|
| 1722 |
|
| 1723 |
+
dim_options = ["2D", "3D"]
|
| 1724 |
+
dim_key = widget_key("load", "projection_dims", analysis_mode)
|
| 1725 |
+
remembered_dim = st.session_state.get(
|
| 1726 |
+
dim_key,
|
| 1727 |
+
st.session_state.get(_LAST_PROJECTION_DIMS_KEY, "2D"),
|
| 1728 |
+
)
|
| 1729 |
+
if remembered_dim not in dim_options:
|
| 1730 |
+
remembered_dim = "2D"
|
| 1731 |
dimension_choice = st.segmented_control(
|
| 1732 |
"Projection dimensions",
|
| 1733 |
+
options=dim_options,
|
| 1734 |
+
default=remembered_dim,
|
| 1735 |
+
key=dim_key,
|
| 1736 |
label_visibility="collapsed",
|
| 1737 |
)
|
| 1738 |
+
if dimension_choice is not None:
|
| 1739 |
+
st.session_state[_LAST_PROJECTION_DIMS_KEY] = dimension_choice
|
| 1740 |
n_components = 3 if dimension_choice == "3D" else 2
|
| 1741 |
dim_suffix = "" if n_components == 2 else " (3D)"
|
| 1742 |
_render_layered_figure_analysis(
|
tabs/chat.py
CHANGED
|
@@ -1,50 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from persona_data.synth_persona import PersonaData
|
| 3 |
|
| 4 |
-
from state import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from tabs.chat_ui import (
|
| 6 |
GenerationConfig,
|
| 7 |
render_advanced_settings,
|
| 8 |
render_chat_window,
|
| 9 |
-
render_persona_prompt_controls,
|
| 10 |
render_system_prompt,
|
| 11 |
)
|
| 12 |
-
from utils.chat import
|
| 13 |
-
ChatReply,
|
| 14 |
-
build_chat_messages,
|
| 15 |
-
generate_chat_reply,
|
| 16 |
-
resolve_system_prompt,
|
| 17 |
-
)
|
| 18 |
from utils.chat_export import save_chat_export
|
| 19 |
-
from utils.
|
| 20 |
-
from utils.helpers import widget_key
|
| 21 |
from utils.runtime import cached_model
|
| 22 |
|
| 23 |
-
_LAST_PERSONA_ID_KEY = "chat
|
| 24 |
-
_LAST_PROMPT_MODE_KEY = "chat
|
| 25 |
-
_LAST_COMPARE_MODE_KEY = "chat
|
| 26 |
-
_LAST_PROBE_ENABLED_KEY = "chat
|
| 27 |
-
_LAST_TOKEN_CONTRAST_KEY = "chat
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def _load_personas(dataset_source: str) -> list[PersonaData] | None:
|
| 31 |
-
try:
|
| 32 |
-
personas, dataset_status = load_persona_list(
|
| 33 |
-
dataset_source,
|
| 34 |
-
personas_file=st.session_state.get("extract__personas_file"),
|
| 35 |
-
qa_file=st.session_state.get("extract__qa_file"),
|
| 36 |
-
)
|
| 37 |
-
st.caption(dataset_status)
|
| 38 |
-
except Exception as exc:
|
| 39 |
-
st.error(f"Could not load data: {exc}")
|
| 40 |
-
st.info("Check the selected dataset source or upload both JSONL files.")
|
| 41 |
-
return None
|
| 42 |
-
|
| 43 |
-
if not personas:
|
| 44 |
-
st.warning("No personas found in the selected dataset.")
|
| 45 |
-
st.info("Try a different dataset source or upload a non-empty personas file.")
|
| 46 |
-
return None
|
| 47 |
-
return personas
|
| 48 |
|
| 49 |
|
| 50 |
def _render_single_chat_footer(
|
|
@@ -99,27 +88,32 @@ def _handle_single_chat_generation(
|
|
| 99 |
chat_state: ChatState,
|
| 100 |
active_system_prompt: str | None,
|
| 101 |
generation: GenerationConfig,
|
| 102 |
-
pending_action:
|
| 103 |
chat_log,
|
| 104 |
) -> None:
|
| 105 |
messages = build_chat_messages(active_system_prompt, chat_state["messages"])
|
| 106 |
|
| 107 |
with st.spinner("Generating reply..."):
|
| 108 |
model = cached_model(model_name=model_name)
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
model=model,
|
| 112 |
-
messages=messages,
|
| 113 |
-
remote=remote,
|
| 114 |
-
**generation.to_generate_kwargs(),
|
| 115 |
-
)
|
| 116 |
-
except Exception as exc:
|
| 117 |
with chat_log:
|
| 118 |
st.error(f"Could not generate a reply: {exc}")
|
| 119 |
st.info("Try a shorter prompt, reset the chat, or switch personas.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
if pending_action == "new_user_prompt" and chat_state["messages"]:
|
| 121 |
chat_state["messages"].pop()
|
| 122 |
return
|
|
|
|
|
|
|
| 123 |
|
| 124 |
chat_state["messages"].append({"role": "assistant", "content": reply.text})
|
| 125 |
st.rerun()
|
|
@@ -132,16 +126,14 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 132 |
st.caption("Chat with a persona, optionally side-by-side or with token contrast.")
|
| 133 |
|
| 134 |
context_key = chat_session_key(model_name, dataset_source)
|
| 135 |
-
chat_state = get_chat_state(model_name,
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
_LAST_PROMPT_MODE_KEY, "templated"
|
| 142 |
-
)
|
| 143 |
|
| 144 |
-
personas =
|
| 145 |
if personas is None:
|
| 146 |
return
|
| 147 |
|
|
@@ -166,7 +158,6 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 166 |
)
|
| 167 |
return
|
| 168 |
|
| 169 |
-
# ── Single-chat mode ──────────────────────────────────────────────────────
|
| 170 |
persona_select_key = widget_key(context_key, "persona_select")
|
| 171 |
prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
|
| 172 |
prompt_key = widget_key(context_key, "custom_system_prompt")
|
|
@@ -176,6 +167,20 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 176 |
reset_key = widget_key(context_key, "reset")
|
| 177 |
edit_key = widget_key(context_key, "edit_idx")
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
def _reset_active_chat_context() -> None:
|
| 180 |
reset_chat_context_state(
|
| 181 |
chat_state,
|
|
@@ -187,17 +192,6 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 187 |
)
|
| 188 |
st.session_state.pop(edit_key, None)
|
| 189 |
|
| 190 |
-
selected_persona, prompt_mode, changed_context = render_persona_prompt_controls(
|
| 191 |
-
personas,
|
| 192 |
-
chat_state["persona_id"],
|
| 193 |
-
chat_state["prompt_mode"],
|
| 194 |
-
persona_select_key,
|
| 195 |
-
prompt_mode_select_key,
|
| 196 |
-
column_widths=(2, 1),
|
| 197 |
-
)
|
| 198 |
-
st.session_state[_LAST_PERSONA_ID_KEY] = selected_persona.id
|
| 199 |
-
st.session_state[_LAST_PROMPT_MODE_KEY] = prompt_mode
|
| 200 |
-
|
| 201 |
active_system_prompt = resolve_system_prompt(
|
| 202 |
persona=selected_persona,
|
| 203 |
mode=prompt_mode,
|
|
@@ -259,14 +253,15 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 259 |
|
| 260 |
user_prompt = st.chat_input("Ask something...", key=chat_input_key)
|
| 261 |
|
| 262 |
-
# Pass 1: user submitted — append message and rerun so it renders before generation.
|
| 263 |
if user_prompt:
|
| 264 |
chat_state["messages"].append({"role": "user", "content": user_prompt})
|
| 265 |
st.session_state[pending_key] = "new_user_prompt"
|
| 266 |
st.rerun()
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
|
|
|
|
|
|
| 270 |
if not pending_action:
|
| 271 |
return
|
| 272 |
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import cast
|
| 4 |
+
|
| 5 |
import streamlit as st
|
| 6 |
from persona_data.synth_persona import PersonaData
|
| 7 |
|
| 8 |
+
from state import (
|
| 9 |
+
ChatState,
|
| 10 |
+
PendingChatAction,
|
| 11 |
+
chat_session_key,
|
| 12 |
+
get_chat_state,
|
| 13 |
+
reset_chat_context_state,
|
| 14 |
+
)
|
| 15 |
+
from tabs.chat_shared import (
|
| 16 |
+
generate_chat_reply_result,
|
| 17 |
+
hydrate_chat_state,
|
| 18 |
+
load_chat_personas,
|
| 19 |
+
render_chat_selection,
|
| 20 |
+
)
|
| 21 |
from tabs.chat_ui import (
|
| 22 |
GenerationConfig,
|
| 23 |
render_advanced_settings,
|
| 24 |
render_chat_window,
|
|
|
|
| 25 |
render_system_prompt,
|
| 26 |
)
|
| 27 |
+
from utils.chat import build_chat_messages, resolve_system_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
from utils.chat_export import save_chat_export
|
| 29 |
+
from utils.helpers import session_key, widget_key
|
|
|
|
| 30 |
from utils.runtime import cached_model
|
| 31 |
|
| 32 |
+
_LAST_PERSONA_ID_KEY = session_key("chat", "last_persona_id")
|
| 33 |
+
_LAST_PROMPT_MODE_KEY = session_key("chat", "last_prompt_mode")
|
| 34 |
+
_LAST_COMPARE_MODE_KEY = session_key("chat", "last_compare_mode")
|
| 35 |
+
_LAST_PROBE_ENABLED_KEY = session_key("chat", "last_probe_enabled")
|
| 36 |
+
_LAST_TOKEN_CONTRAST_KEY = session_key("chat", "last_token_contrast")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
|
| 39 |
def _render_single_chat_footer(
|
|
|
|
| 88 |
chat_state: ChatState,
|
| 89 |
active_system_prompt: str | None,
|
| 90 |
generation: GenerationConfig,
|
| 91 |
+
pending_action: PendingChatAction,
|
| 92 |
chat_log,
|
| 93 |
) -> None:
|
| 94 |
messages = build_chat_messages(active_system_prompt, chat_state["messages"])
|
| 95 |
|
| 96 |
with st.spinner("Generating reply..."):
|
| 97 |
model = cached_model(model_name=model_name)
|
| 98 |
+
|
| 99 |
+
def _show_error(exc: Exception) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
with chat_log:
|
| 101 |
st.error(f"Could not generate a reply: {exc}")
|
| 102 |
st.info("Try a shorter prompt, reset the chat, or switch personas.")
|
| 103 |
+
|
| 104 |
+
reply, error = generate_chat_reply_result(
|
| 105 |
+
model=model,
|
| 106 |
+
messages=messages,
|
| 107 |
+
remote=remote,
|
| 108 |
+
generation=generation,
|
| 109 |
+
on_error=_show_error,
|
| 110 |
+
)
|
| 111 |
+
if error is not None:
|
| 112 |
if pending_action == "new_user_prompt" and chat_state["messages"]:
|
| 113 |
chat_state["messages"].pop()
|
| 114 |
return
|
| 115 |
+
if reply is None:
|
| 116 |
+
return
|
| 117 |
|
| 118 |
chat_state["messages"].append({"role": "assistant", "content": reply.text})
|
| 119 |
st.rerun()
|
|
|
|
| 126 |
st.caption("Chat with a persona, optionally side-by-side or with token contrast.")
|
| 127 |
|
| 128 |
context_key = chat_session_key(model_name, dataset_source)
|
| 129 |
+
chat_state = get_chat_state(model_name, dataset_source)
|
| 130 |
+
hydrate_chat_state(
|
| 131 |
+
chat_state,
|
| 132 |
+
persisted_persona_key=_LAST_PERSONA_ID_KEY,
|
| 133 |
+
persisted_prompt_key=_LAST_PROMPT_MODE_KEY,
|
| 134 |
+
)
|
|
|
|
|
|
|
| 135 |
|
| 136 |
+
personas = load_chat_personas(dataset_source)
|
| 137 |
if personas is None:
|
| 138 |
return
|
| 139 |
|
|
|
|
| 158 |
)
|
| 159 |
return
|
| 160 |
|
|
|
|
| 161 |
persona_select_key = widget_key(context_key, "persona_select")
|
| 162 |
prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
|
| 163 |
prompt_key = widget_key(context_key, "custom_system_prompt")
|
|
|
|
| 167 |
reset_key = widget_key(context_key, "reset")
|
| 168 |
edit_key = widget_key(context_key, "edit_idx")
|
| 169 |
|
| 170 |
+
selection = render_chat_selection(
|
| 171 |
+
personas,
|
| 172 |
+
chat_state["persona_id"],
|
| 173 |
+
chat_state["prompt_mode"],
|
| 174 |
+
persona_select_key,
|
| 175 |
+
prompt_mode_select_key,
|
| 176 |
+
persisted_persona_key=_LAST_PERSONA_ID_KEY,
|
| 177 |
+
persisted_prompt_key=_LAST_PROMPT_MODE_KEY,
|
| 178 |
+
column_widths=(2, 1),
|
| 179 |
+
)
|
| 180 |
+
selected_persona = selection.persona
|
| 181 |
+
prompt_mode = selection.prompt_mode
|
| 182 |
+
changed_context = selection.changed
|
| 183 |
+
|
| 184 |
def _reset_active_chat_context() -> None:
|
| 185 |
reset_chat_context_state(
|
| 186 |
chat_state,
|
|
|
|
| 192 |
)
|
| 193 |
st.session_state.pop(edit_key, None)
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
active_system_prompt = resolve_system_prompt(
|
| 196 |
persona=selected_persona,
|
| 197 |
mode=prompt_mode,
|
|
|
|
| 253 |
|
| 254 |
user_prompt = st.chat_input("Ask something...", key=chat_input_key)
|
| 255 |
|
|
|
|
| 256 |
if user_prompt:
|
| 257 |
chat_state["messages"].append({"role": "user", "content": user_prompt})
|
| 258 |
st.session_state[pending_key] = "new_user_prompt"
|
| 259 |
st.rerun()
|
| 260 |
|
| 261 |
+
pending_action = cast(
|
| 262 |
+
PendingChatAction | None,
|
| 263 |
+
st.session_state.pop(pending_key, None),
|
| 264 |
+
)
|
| 265 |
if not pending_action:
|
| 266 |
return
|
| 267 |
|
tabs/chat_shared.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from collections.abc import Callable
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
import streamlit as st
|
| 7 |
+
from persona_data.synth_persona import PersonaData
|
| 8 |
+
|
| 9 |
+
from state import ChatState
|
| 10 |
+
from tabs.chat_ui import GenerationConfig, render_persona_prompt_controls
|
| 11 |
+
from utils.chat import ChatReply, generate_chat_reply
|
| 12 |
+
from utils.datasets import load_persona_list
|
| 13 |
+
from utils.helpers import session_key
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass(frozen=True)
class ChatSelection:
    """Immutable result of rendering the persona / prompt-mode controls."""

    # Persona chosen in the selection controls.
    persona: PersonaData
    # Prompt mode chosen alongside the persona (e.g. "templated").
    prompt_mode: str
    # Change flag reported by the underlying controls.
    changed: bool
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_chat_personas(dataset_source: str) -> list[PersonaData] | None:
    """Load the persona list for ``dataset_source`` and report problems in the UI.

    The optional ``personas_file`` / ``qa_file`` arguments are read from
    session state under the ``extract`` namespace. On success the loader's
    status text is shown as a caption.

    Returns:
        The loaded personas, or ``None`` when loading raised or produced an
        empty list (an error or warning has already been rendered).
    """
    personas_state_key = session_key("extract", "personas_file")
    qa_state_key = session_key("extract", "qa_file")

    try:
        loaded, status_text = load_persona_list(
            dataset_source,
            personas_file=st.session_state.get(personas_state_key),
            qa_file=st.session_state.get(qa_state_key),
        )
        st.caption(status_text)
    except Exception as exc:
        # UI boundary: surface any loader failure instead of crashing the tab.
        st.error(f"Could not load data: {exc}")
        st.info("Check the selected dataset source or upload both JSONL files.")
        return None

    if loaded:
        return loaded

    st.warning("No personas found in the selected dataset.")
    st.info("Try a different dataset source or upload a non-empty personas file.")
    return None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def hydrate_chat_state(
    state: ChatState,
    *,
    persisted_persona_key: str,
    persisted_prompt_key: str,
    default_prompt_mode: str = "templated",
) -> None:
    """Seed a chat state that has no persona yet from persisted session values.

    When ``state["persona_id"]`` is already set, nothing is touched.
    Otherwise the persona id and prompt mode are read from session state,
    with ``default_prompt_mode`` used when no prompt mode is stored.
    """
    if state["persona_id"] is not None:
        return

    session = st.session_state
    state["persona_id"] = session.get(persisted_persona_key)
    state["prompt_mode"] = session.get(persisted_prompt_key, default_prompt_mode)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def render_chat_selection(
    personas: list[PersonaData],
    current_persona_id: str | None,
    current_prompt_mode: str,
    persona_key: str,
    prompt_key: str,
    *,
    persisted_persona_key: str,
    persisted_prompt_key: str,
    column_widths: tuple[int, int] = (3, 2),
) -> ChatSelection:
    """Render the persona / prompt-mode controls and persist the choice.

    Delegates rendering to ``render_persona_prompt_controls``, then writes
    the selected persona id and prompt mode to session state under the two
    ``persisted_*`` keys.

    Returns:
        A ``ChatSelection`` carrying the persona, the prompt mode and the
        change flag reported by the controls.
    """
    persona, mode, has_changed = render_persona_prompt_controls(
        personas,
        current_persona_id,
        current_prompt_mode,
        persona_key,
        prompt_key,
        column_widths=column_widths,
    )

    # Record the latest choice under the caller-supplied session keys.
    st.session_state[persisted_persona_key] = persona.id
    st.session_state[persisted_prompt_key] = mode

    return ChatSelection(persona=persona, prompt_mode=mode, changed=has_changed)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def generate_chat_reply_result(
    *,
    model: object,
    messages: list[dict[str, str]],
    remote: bool,
    generation: GenerationConfig,
    on_error: Callable[[Exception], None] | None = None,
) -> tuple[ChatReply | None, Exception | None]:
    """Generate a chat reply and return it as a ``(reply, error)`` pair.

    Args:
        model: Model object passed through to ``generate_chat_reply``.
        messages: Chat messages to send.
        remote: Forwarded unchanged to ``generate_chat_reply``.
        generation: Generation settings, expanded via ``to_generate_kwargs()``.
        on_error: Optional callback invoked with the exception on failure.

    Returns:
        ``(reply, None)`` on success, ``(None, exc)`` when generation raised.
    """
    try:
        reply = generate_chat_reply(
            model=model,
            messages=messages,
            remote=remote,
            **generation.to_generate_kwargs(),
        )
    except Exception as exc:
        # Failures are handed back to the caller rather than raised.
        if on_error is not None:
            on_error(exc)
        return None, exc

    return reply, None
|
tabs/chat_ui.py
CHANGED
|
@@ -29,19 +29,21 @@ GENERATION_DEFAULTS = {
|
|
| 29 |
_LAST_GEN_PREFIX = "chat:last_gen:"
|
| 30 |
|
| 31 |
|
| 32 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
"""Per-context widget key, seeded from the last cross-context value."""
|
| 34 |
-
last_key = f"{_LAST_GEN_PREFIX}{name}"
|
| 35 |
key = widget_key(context_key, name)
|
| 36 |
if key not in st.session_state:
|
| 37 |
-
st.session_state[key] = st.session_state.get(
|
|
|
|
|
|
|
|
|
|
| 38 |
return key
|
| 39 |
|
| 40 |
|
| 41 |
-
def _remember(name: str, value) -> None:
|
| 42 |
-
st.session_state[f"{_LAST_GEN_PREFIX}{name}"] = value
|
| 43 |
-
|
| 44 |
-
|
| 45 |
@dataclass(frozen=True)
|
| 46 |
class GenerationConfig:
|
| 47 |
max_new_tokens: int
|
|
@@ -100,7 +102,7 @@ def _open_edit_dialog(
|
|
| 100 |
|
| 101 |
save_col, cancel_col = st.columns(2)
|
| 102 |
with save_col:
|
| 103 |
-
if st.button("Save", type="primary",
|
| 104 |
messages[msg_index]["content"] = new_content
|
| 105 |
messages[msg_index].pop("_contrast", None)
|
| 106 |
if role == "assistant":
|
|
@@ -110,7 +112,7 @@ def _open_edit_dialog(
|
|
| 110 |
st.session_state[pending_key] = "regenerate_after_edit"
|
| 111 |
st.rerun()
|
| 112 |
with cancel_col:
|
| 113 |
-
if st.button("Cancel",
|
| 114 |
st.rerun()
|
| 115 |
|
| 116 |
|
|
@@ -129,13 +131,13 @@ def _open_system_prompt_dialog(
|
|
| 129 |
)
|
| 130 |
save_col, cancel_col = st.columns(2)
|
| 131 |
with save_col:
|
| 132 |
-
if st.button("Save", type="primary",
|
| 133 |
st.session_state[prompt_key] = new_value
|
| 134 |
if on_save is not None:
|
| 135 |
on_save()
|
| 136 |
st.rerun()
|
| 137 |
with cancel_col:
|
| 138 |
-
if st.button("Cancel",
|
| 139 |
st.rerun()
|
| 140 |
|
| 141 |
|
|
@@ -307,7 +309,9 @@ def _render_generation_fragment(context_key: str, remote: bool) -> GenerationCon
|
|
| 307 |
("top_k", top_k),
|
| 308 |
("seed_enabled", seed_enabled),
|
| 309 |
):
|
| 310 |
-
|
|
|
|
|
|
|
| 311 |
|
| 312 |
do_sample = bool(use_sampling)
|
| 313 |
return GenerationConfig(
|
|
|
|
| 29 |
_LAST_GEN_PREFIX = "chat:last_gen:"
|
| 30 |
|
| 31 |
|
| 32 |
+
def _last_generation_key(name: str) -> str:
    """Return the session key that stores the last-used value of ``name``."""
    last_value_key = f"{_LAST_GEN_PREFIX}{name}"
    return last_value_key
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _persisted_key(context_key: str, name: str, default: object) -> str:
    """Return the per-context widget key for ``name``, seeding it if unset.

    If the key is not yet in session state, it is initialised from the
    last cross-context value for ``name``, falling back to ``default``.
    """
    scoped_key = widget_key(context_key, name)
    if scoped_key in st.session_state:
        return scoped_key

    seed = st.session_state.get(_last_generation_key(name), default)
    st.session_state[scoped_key] = seed
    return scoped_key
|
| 45 |
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
@dataclass(frozen=True)
|
| 48 |
class GenerationConfig:
|
| 49 |
max_new_tokens: int
|
|
|
|
| 102 |
|
| 103 |
save_col, cancel_col = st.columns(2)
|
| 104 |
with save_col:
|
| 105 |
+
if st.button("Save", type="primary", width="stretch"):
|
| 106 |
messages[msg_index]["content"] = new_content
|
| 107 |
messages[msg_index].pop("_contrast", None)
|
| 108 |
if role == "assistant":
|
|
|
|
| 112 |
st.session_state[pending_key] = "regenerate_after_edit"
|
| 113 |
st.rerun()
|
| 114 |
with cancel_col:
|
| 115 |
+
if st.button("Cancel", width="stretch"):
|
| 116 |
st.rerun()
|
| 117 |
|
| 118 |
|
|
|
|
| 131 |
)
|
| 132 |
save_col, cancel_col = st.columns(2)
|
| 133 |
with save_col:
|
| 134 |
+
if st.button("Save", type="primary", width="stretch"):
|
| 135 |
st.session_state[prompt_key] = new_value
|
| 136 |
if on_save is not None:
|
| 137 |
on_save()
|
| 138 |
st.rerun()
|
| 139 |
with cancel_col:
|
| 140 |
+
if st.button("Cancel", width="stretch"):
|
| 141 |
st.rerun()
|
| 142 |
|
| 143 |
|
|
|
|
| 309 |
("top_k", top_k),
|
| 310 |
("seed_enabled", seed_enabled),
|
| 311 |
):
|
| 312 |
+
st.session_state[_last_generation_key(name)] = value
|
| 313 |
+
if seed is not None:
|
| 314 |
+
st.session_state[_last_generation_key("seed")] = seed
|
| 315 |
|
| 316 |
do_sample = bool(use_sampling)
|
| 317 |
return GenerationConfig(
|
tabs/compare_chat.py
CHANGED
|
@@ -6,22 +6,21 @@ from nnterp import StandardizedTransformer
|
|
| 6 |
from persona_data.synth_persona import PersonaData
|
| 7 |
|
| 8 |
from state import ChatState, default_chat_state, reset_chat_context_state
|
| 9 |
-
from
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
resolve_system_prompt,
|
| 14 |
)
|
|
|
|
| 15 |
from utils.chat_export import save_chat_export
|
| 16 |
from utils.contrast import compute_contrast, compute_contrast_pair
|
| 17 |
-
from utils.helpers import persona_label, widget_key
|
| 18 |
from utils.runtime import cached_model
|
| 19 |
|
| 20 |
from .chat_ui import (
|
| 21 |
GenerationConfig,
|
| 22 |
render_chat_message,
|
| 23 |
render_chat_window,
|
| 24 |
-
render_persona_prompt_controls,
|
| 25 |
render_system_prompt,
|
| 26 |
)
|
| 27 |
|
|
@@ -68,21 +67,26 @@ def _render_compare_panel(
|
|
| 68 |
edit_key = widget_key(panel_key, "edit_idx")
|
| 69 |
pending_key = widget_key(panel_key, "pending_regen")
|
| 70 |
|
| 71 |
-
persist_persona_key =
|
| 72 |
-
persist_prompt_key =
|
| 73 |
-
|
| 74 |
-
state
|
| 75 |
-
|
|
|
|
|
|
|
| 76 |
|
| 77 |
-
|
| 78 |
personas,
|
| 79 |
state["persona_id"],
|
| 80 |
state["prompt_mode"],
|
| 81 |
widget_key(panel_key, "persona"),
|
| 82 |
widget_key(panel_key, "prompt_mode"),
|
|
|
|
|
|
|
| 83 |
)
|
| 84 |
-
|
| 85 |
-
|
|
|
|
| 86 |
|
| 87 |
if changed:
|
| 88 |
reset_chat_context_state(
|
|
@@ -136,19 +140,13 @@ def _generate_panels(
|
|
| 136 |
results: list[ChatReply | Exception] = []
|
| 137 |
with st.spinner(spinner_label):
|
| 138 |
for panel in panels:
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
remote=remote,
|
| 147 |
-
**generation.to_generate_kwargs(),
|
| 148 |
-
)
|
| 149 |
-
)
|
| 150 |
-
except Exception as exc:
|
| 151 |
-
results.append(exc)
|
| 152 |
return results
|
| 153 |
|
| 154 |
|
|
|
|
| 6 |
from persona_data.synth_persona import PersonaData
|
| 7 |
|
| 8 |
from state import ChatState, default_chat_state, reset_chat_context_state
|
| 9 |
+
from tabs.chat_shared import (
|
| 10 |
+
generate_chat_reply_result,
|
| 11 |
+
hydrate_chat_state,
|
| 12 |
+
render_chat_selection,
|
|
|
|
| 13 |
)
|
| 14 |
+
from utils.chat import ChatReply, build_chat_messages, resolve_system_prompt
|
| 15 |
from utils.chat_export import save_chat_export
|
| 16 |
from utils.contrast import compute_contrast, compute_contrast_pair
|
| 17 |
+
from utils.helpers import persona_label, session_key, widget_key
|
| 18 |
from utils.runtime import cached_model
|
| 19 |
|
| 20 |
from .chat_ui import (
|
| 21 |
GenerationConfig,
|
| 22 |
render_chat_message,
|
| 23 |
render_chat_window,
|
|
|
|
| 24 |
render_system_prompt,
|
| 25 |
)
|
| 26 |
|
|
|
|
| 67 |
edit_key = widget_key(panel_key, "edit_idx")
|
| 68 |
pending_key = widget_key(panel_key, "pending_regen")
|
| 69 |
|
| 70 |
+
persist_persona_key = session_key("chat", f"last_cmp_{side}_persona")
|
| 71 |
+
persist_prompt_key = session_key("chat", f"last_cmp_{side}_prompt")
|
| 72 |
+
hydrate_chat_state(
|
| 73 |
+
state,
|
| 74 |
+
persisted_persona_key=persist_persona_key,
|
| 75 |
+
persisted_prompt_key=persist_prompt_key,
|
| 76 |
+
)
|
| 77 |
|
| 78 |
+
selection = render_chat_selection(
|
| 79 |
personas,
|
| 80 |
state["persona_id"],
|
| 81 |
state["prompt_mode"],
|
| 82 |
widget_key(panel_key, "persona"),
|
| 83 |
widget_key(panel_key, "prompt_mode"),
|
| 84 |
+
persisted_persona_key=persist_persona_key,
|
| 85 |
+
persisted_prompt_key=persist_prompt_key,
|
| 86 |
)
|
| 87 |
+
selected_persona = selection.persona
|
| 88 |
+
prompt_mode = selection.prompt_mode
|
| 89 |
+
changed = selection.changed
|
| 90 |
|
| 91 |
if changed:
|
| 92 |
reset_chat_context_state(
|
|
|
|
| 140 |
results: list[ChatReply | Exception] = []
|
| 141 |
with st.spinner(spinner_label):
|
| 142 |
for panel in panels:
|
| 143 |
+
reply, error = generate_chat_reply_result(
|
| 144 |
+
model=model,
|
| 145 |
+
messages=build_chat_messages(panel.prompt, panel.state["messages"]),
|
| 146 |
+
remote=remote,
|
| 147 |
+
generation=generation,
|
| 148 |
+
)
|
| 149 |
+
results.append(reply if error is None else error)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
return results
|
| 151 |
|
| 152 |
|
tabs/extract.py
CHANGED
|
@@ -5,7 +5,7 @@ import streamlit as st
|
|
| 5 |
from catppuccin import PALETTE
|
| 6 |
from persona_data.prompts import format_prompt
|
| 7 |
from persona_data.synth_persona import BASELINE_PERSONA_ID, PersonaData, QAPair
|
| 8 |
-
from persona_vectors.artifacts import
|
| 9 |
from persona_vectors.extraction import (
|
| 10 |
MaskStrategy,
|
| 11 |
prepare_inputs_for_strategy,
|
|
@@ -14,11 +14,12 @@ from persona_vectors.extraction import (
|
|
| 14 |
from persona_vectors.preview import TokenSegment, preview_token_segments
|
| 15 |
|
| 16 |
from utils.controls import render_mask_strategy_select
|
| 17 |
-
from utils.datasets import load_dataset,
|
| 18 |
from utils.helpers import (
|
| 19 |
NDIF_STATUS_ICONS,
|
| 20 |
persona_label,
|
| 21 |
prompt_variant_label,
|
|
|
|
| 22 |
widget_key,
|
| 23 |
)
|
| 24 |
from utils.runtime import cached_model
|
|
@@ -29,6 +30,9 @@ _LAST_PERSONA_IDS_KEY = "extract:last_persona_ids"
|
|
| 29 |
_LAST_MAX_QUESTIONS_KEY = "extract:last_max_questions"
|
| 30 |
_LAST_MASK_STRATEGY_KEY = "extract:last_mask_strategy"
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
_DEFAULT_MAX_QUESTIONS = 50
|
| 33 |
|
| 34 |
|
|
@@ -42,7 +46,7 @@ def _build_run_plan(
|
|
| 42 |
selected_variants: list[str],
|
| 43 |
runs: list[tuple[PersonaData, list[QAPair]]],
|
| 44 |
) -> list[tuple[PersonaData, list[QAPair], str]]:
|
| 45 |
-
"""Cartesian product of personas
|
| 46 |
return [(p, qa, v) for v in selected_variants for p, qa in runs]
|
| 47 |
|
| 48 |
|
|
@@ -63,13 +67,13 @@ def _render_local_dataset_upload(dataset_source: str) -> None:
|
|
| 63 |
st.file_uploader(
|
| 64 |
"personas.jsonl",
|
| 65 |
type=["jsonl"],
|
| 66 |
-
key=
|
| 67 |
help="Expected fields: id, persona, templated_view, biography_view",
|
| 68 |
)
|
| 69 |
st.file_uploader(
|
| 70 |
"qa.jsonl",
|
| 71 |
type=["jsonl"],
|
| 72 |
-
key=
|
| 73 |
help="Expected fields: id, qid, type, item_type, scope, question, answer",
|
| 74 |
)
|
| 75 |
|
|
@@ -80,12 +84,14 @@ def _render_variant_controls(
|
|
| 80 |
remote: bool,
|
| 81 |
dataset_source: str,
|
| 82 |
) -> tuple[list[str], bool] | None:
|
| 83 |
-
default_variants = st.session_state.get(
|
|
|
|
|
|
|
| 84 |
selected_variants = st.multiselect(
|
| 85 |
"Persona variants",
|
| 86 |
-
options=
|
| 87 |
-
default=[v for v in default_variants if v in
|
| 88 |
-
or list(
|
| 89 |
format_func=prompt_variant_label,
|
| 90 |
key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
|
| 91 |
help="Extract these variants for each selected persona.",
|
|
@@ -110,14 +116,10 @@ def _load_qa_dataset_personas(
|
|
| 110 |
try:
|
| 111 |
dataset, dataset_status = load_dataset(
|
| 112 |
dataset_source,
|
| 113 |
-
personas_file=st.session_state.get(
|
| 114 |
-
qa_file=st.session_state.get(
|
| 115 |
-
)
|
| 116 |
-
personas, _ = load_persona_list(
|
| 117 |
-
dataset_source,
|
| 118 |
-
personas_file=st.session_state.get("extract__personas_file"),
|
| 119 |
-
qa_file=st.session_state.get("extract__qa_file"),
|
| 120 |
)
|
|
|
|
| 121 |
st.caption(dataset_status)
|
| 122 |
except Exception as exc:
|
| 123 |
st.error(f"Could not load data: {exc}")
|
|
@@ -289,10 +291,10 @@ def _render_extract_actions() -> tuple[bool, bool]:
|
|
| 289 |
run_clicked = st.button(
|
| 290 |
"Run extraction",
|
| 291 |
type="primary",
|
| 292 |
-
|
| 293 |
)
|
| 294 |
with preview_col:
|
| 295 |
-
preview_clicked = st.button("Preview tokens",
|
| 296 |
return run_clicked, preview_clicked
|
| 297 |
|
| 298 |
|
|
|
|
| 5 |
from catppuccin import PALETTE
|
| 6 |
from persona_data.prompts import format_prompt
|
| 7 |
from persona_data.synth_persona import BASELINE_PERSONA_ID, PersonaData, QAPair
|
| 8 |
+
from persona_vectors.artifacts import SUPPORTED_VARIANTS
|
| 9 |
from persona_vectors.extraction import (
|
| 10 |
MaskStrategy,
|
| 11 |
prepare_inputs_for_strategy,
|
|
|
|
| 14 |
from persona_vectors.preview import TokenSegment, preview_token_segments
|
| 15 |
|
| 16 |
from utils.controls import render_mask_strategy_select
|
| 17 |
+
from utils.datasets import load_dataset, load_persona_list_from_dataset
|
| 18 |
from utils.helpers import (
|
| 19 |
NDIF_STATUS_ICONS,
|
| 20 |
persona_label,
|
| 21 |
prompt_variant_label,
|
| 22 |
+
session_key,
|
| 23 |
widget_key,
|
| 24 |
)
|
| 25 |
from utils.runtime import cached_model
|
|
|
|
| 30 |
_LAST_MAX_QUESTIONS_KEY = "extract:last_max_questions"
|
| 31 |
_LAST_MASK_STRATEGY_KEY = "extract:last_mask_strategy"
|
| 32 |
|
| 33 |
+
_PERSONAS_FILE_KEY = session_key("extract", "personas_file")
|
| 34 |
+
_QA_FILE_KEY = session_key("extract", "qa_file")
|
| 35 |
+
|
| 36 |
_DEFAULT_MAX_QUESTIONS = 50
|
| 37 |
|
| 38 |
|
|
|
|
| 46 |
selected_variants: list[str],
|
| 47 |
runs: list[tuple[PersonaData, list[QAPair]]],
|
| 48 |
) -> list[tuple[PersonaData, list[QAPair], str]]:
|
| 49 |
+
"""Cartesian product of personas x variants."""
|
| 50 |
return [(p, qa, v) for v in selected_variants for p, qa in runs]
|
| 51 |
|
| 52 |
|
|
|
|
| 67 |
st.file_uploader(
|
| 68 |
"personas.jsonl",
|
| 69 |
type=["jsonl"],
|
| 70 |
+
key=_PERSONAS_FILE_KEY,
|
| 71 |
help="Expected fields: id, persona, templated_view, biography_view",
|
| 72 |
)
|
| 73 |
st.file_uploader(
|
| 74 |
"qa.jsonl",
|
| 75 |
type=["jsonl"],
|
| 76 |
+
key=_QA_FILE_KEY,
|
| 77 |
help="Expected fields: id, qid, type, item_type, scope, question, answer",
|
| 78 |
)
|
| 79 |
|
|
|
|
| 84 |
remote: bool,
|
| 85 |
dataset_source: str,
|
| 86 |
) -> tuple[list[str], bool] | None:
|
| 87 |
+
default_variants = st.session_state.get(
|
| 88 |
+
_LAST_VARIANTS_KEY, list(SUPPORTED_VARIANTS)
|
| 89 |
+
)
|
| 90 |
selected_variants = st.multiselect(
|
| 91 |
"Persona variants",
|
| 92 |
+
options=SUPPORTED_VARIANTS,
|
| 93 |
+
default=[v for v in default_variants if v in SUPPORTED_VARIANTS]
|
| 94 |
+
or list(SUPPORTED_VARIANTS),
|
| 95 |
format_func=prompt_variant_label,
|
| 96 |
key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
|
| 97 |
help="Extract these variants for each selected persona.",
|
|
|
|
| 116 |
try:
|
| 117 |
dataset, dataset_status = load_dataset(
|
| 118 |
dataset_source,
|
| 119 |
+
personas_file=st.session_state.get(_PERSONAS_FILE_KEY),
|
| 120 |
+
qa_file=st.session_state.get(_QA_FILE_KEY),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
)
|
| 122 |
+
personas = load_persona_list_from_dataset(dataset)
|
| 123 |
st.caption(dataset_status)
|
| 124 |
except Exception as exc:
|
| 125 |
st.error(f"Could not load data: {exc}")
|
|
|
|
| 291 |
run_clicked = st.button(
|
| 292 |
"Run extraction",
|
| 293 |
type="primary",
|
| 294 |
+
width="stretch",
|
| 295 |
)
|
| 296 |
with preview_col:
|
| 297 |
+
preview_clicked = st.button("Preview tokens", width="stretch")
|
| 298 |
return run_clicked, preview_clicked
|
| 299 |
|
| 300 |
|
tabs/probe_ui.py
CHANGED
|
@@ -197,7 +197,7 @@ def _trace_requested(context_key: str) -> bool:
|
|
| 197 |
if st.button(
|
| 198 |
"Trace conversation",
|
| 199 |
key=widget_key(context_key, "probe_trace"),
|
| 200 |
-
|
| 201 |
):
|
| 202 |
st.session_state[trace_key] = True
|
| 203 |
return bool(st.session_state.get(trace_key, False))
|
|
|
|
| 197 |
if st.button(
|
| 198 |
"Trace conversation",
|
| 199 |
key=widget_key(context_key, "probe_trace"),
|
| 200 |
+
width="stretch",
|
| 201 |
):
|
| 202 |
st.session_state[trace_key] = True
|
| 203 |
return bool(st.session_state.get(trace_key, False))
|
utils/{compare_sources.py → analysis_sources.py}
RENAMED
|
@@ -1,12 +1,10 @@
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
-
import
|
| 5 |
-
from persona_vectors.analysis import LayeredSamples
|
| 6 |
from persona_vectors.artifacts import (
|
| 7 |
ActivationStore,
|
| 8 |
HFActivationStore,
|
| 9 |
-
activation_config_name,
|
| 10 |
discover_activation_models,
|
| 11 |
model_dir_name,
|
| 12 |
)
|
|
@@ -25,28 +23,6 @@ SOURCE_LOCAL = "Local activations"
|
|
| 25 |
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
| 26 |
|
| 27 |
|
| 28 |
-
def _hub_split(repo_id: str, model_name: str, mask_strategy_value: str, variant: str):
|
| 29 |
-
from datasets import load_dataset
|
| 30 |
-
|
| 31 |
-
return load_dataset(
|
| 32 |
-
repo_id,
|
| 33 |
-
name=activation_config_name(model_name, mask_strategy_value),
|
| 34 |
-
split=variant,
|
| 35 |
-
keep_in_memory=False,
|
| 36 |
-
)
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def _hub_split_columns(
|
| 40 |
-
repo_id: str,
|
| 41 |
-
model_name: str,
|
| 42 |
-
mask_strategy_value: str,
|
| 43 |
-
variant: str,
|
| 44 |
-
columns: list[str],
|
| 45 |
-
):
|
| 46 |
-
dataset = _hub_split(repo_id, model_name, mask_strategy_value, variant)
|
| 47 |
-
return dataset.select_columns(columns)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
@st.cache_resource(show_spinner=False, max_entries=1)
|
| 51 |
def activation_store_cached(
|
| 52 |
source: str,
|
|
@@ -67,8 +43,9 @@ def available_variants_cached(
|
|
| 67 |
model_name: str,
|
| 68 |
mask_strategy_value: str,
|
| 69 |
) -> list[str]:
|
| 70 |
-
|
| 71 |
-
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
@st.cache_data(show_spinner=False)
|
|
@@ -79,31 +56,9 @@ def personas_cached(
|
|
| 79 |
mask_strategy_value: str,
|
| 80 |
variants: tuple[str, ...],
|
| 81 |
) -> list[str]:
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
_hub_split_columns(
|
| 86 |
-
location,
|
| 87 |
-
model_name,
|
| 88 |
-
mask_strategy_value,
|
| 89 |
-
variant,
|
| 90 |
-
["persona_id"],
|
| 91 |
-
)["persona_id"]
|
| 92 |
-
)
|
| 93 |
-
for variant in variants
|
| 94 |
-
]
|
| 95 |
-
if not variant_ids:
|
| 96 |
-
return []
|
| 97 |
-
shared = set(variant_ids[0])
|
| 98 |
-
for ids in variant_ids[1:]:
|
| 99 |
-
shared &= set(ids)
|
| 100 |
-
return [persona_id for persona_id in variant_ids[0] if persona_id in shared]
|
| 101 |
-
|
| 102 |
-
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 103 |
-
return store.list_personas(
|
| 104 |
-
list(variants),
|
| 105 |
-
mask_strategy=MaskStrategy(mask_strategy_value),
|
| 106 |
-
)
|
| 107 |
|
| 108 |
|
| 109 |
@st.cache_data(show_spinner=False)
|
|
@@ -115,31 +70,24 @@ def persona_names_cached(
|
|
| 115 |
variants: tuple[str, ...],
|
| 116 |
persona_ids: tuple[str, ...],
|
| 117 |
) -> dict[str, str]:
|
| 118 |
-
if source == SOURCE_HUB:
|
| 119 |
-
requested = set(persona_ids)
|
| 120 |
-
names: dict[str, str] = {}
|
| 121 |
-
for variant in variants:
|
| 122 |
-
metadata = _hub_split_columns(
|
| 123 |
-
location,
|
| 124 |
-
model_name,
|
| 125 |
-
mask_strategy_value,
|
| 126 |
-
variant,
|
| 127 |
-
["persona_id", "name"],
|
| 128 |
-
)
|
| 129 |
-
for row in metadata:
|
| 130 |
-
persona_id = row["persona_id"]
|
| 131 |
-
if persona_id in requested and persona_id not in names:
|
| 132 |
-
names[persona_id] = row.get("name") or persona_id
|
| 133 |
-
if len(names) == len(requested):
|
| 134 |
-
return {pid: names.get(pid, pid) for pid in persona_ids}
|
| 135 |
-
return {pid: names.get(pid, pid) for pid in persona_ids}
|
| 136 |
-
|
| 137 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
@st.cache_data(show_spinner=False)
|
|
@@ -151,11 +99,11 @@ def local_model_options_cached(
|
|
| 151 |
|
| 152 |
@st.cache_data(show_spinner=False)
|
| 153 |
def hub_models_by_mask_strategy(repo_id: str) -> dict[MaskStrategy, list[str]]:
|
| 154 |
-
|
| 155 |
return {
|
| 156 |
MaskStrategy(strategy_value): models
|
| 157 |
-
for strategy_value, models in
|
| 158 |
-
if strategy_value in
|
| 159 |
}
|
| 160 |
|
| 161 |
|
|
@@ -173,56 +121,14 @@ def store_id(store: Store) -> str:
|
|
| 173 |
|
| 174 |
def available_variants(store: Store, mask_strategy: MaskStrategy) -> list[str]:
|
| 175 |
source, location, model_name = store_cache_parts(store)
|
| 176 |
-
return available_variants_cached(
|
| 177 |
-
source,
|
| 178 |
-
location,
|
| 179 |
-
model_name,
|
| 180 |
-
mask_strategy.value,
|
| 181 |
-
)
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
@st.cache_data(show_spinner=False)
|
| 185 |
-
def store_layers_cached(
|
| 186 |
-
source: str,
|
| 187 |
-
location: str,
|
| 188 |
-
model_name: str,
|
| 189 |
-
mask_strategy_value: str,
|
| 190 |
-
variants: tuple[str, ...],
|
| 191 |
-
persona_ids: tuple[str, ...],
|
| 192 |
-
) -> list[int]:
|
| 193 |
-
if source == SOURCE_HUB:
|
| 194 |
-
shared_layers: set[int] | None = None
|
| 195 |
-
requested = list(persona_ids)
|
| 196 |
-
for variant in variants:
|
| 197 |
-
dataset = _hub_split(location, model_name, mask_strategy_value, variant)
|
| 198 |
-
ids = list(dataset.select_columns(["persona_id"])["persona_id"])
|
| 199 |
-
sample_id = requested[0] if requested else (ids[0] if ids else None)
|
| 200 |
-
if sample_id is None:
|
| 201 |
-
return []
|
| 202 |
-
if requested and any(persona_id not in ids for persona_id in requested):
|
| 203 |
-
return []
|
| 204 |
-
vector = torch.as_tensor(dataset[ids.index(sample_id)]["vector"])
|
| 205 |
-
if vector.ndim != 2:
|
| 206 |
-
raise ValueError(
|
| 207 |
-
f"tensor for {sample_id!r} must have shape (num_layers, hidden_size)"
|
| 208 |
-
)
|
| 209 |
-
layers = set(range(int(vector.shape[0])))
|
| 210 |
-
shared_layers = layers if shared_layers is None else shared_layers & layers
|
| 211 |
-
return sorted(shared_layers or set())
|
| 212 |
-
|
| 213 |
-
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 214 |
-
return store.list_layers(
|
| 215 |
-
list(variants),
|
| 216 |
-
list(persona_ids),
|
| 217 |
-
mask_strategy=MaskStrategy(mask_strategy_value),
|
| 218 |
-
)
|
| 219 |
|
| 220 |
|
| 221 |
def local_model_matches(left: str, right: str) -> bool:
|
| 222 |
return model_dir_name(left) == model_dir_name(right)
|
| 223 |
|
| 224 |
|
| 225 |
-
def
|
| 226 |
source: str,
|
| 227 |
location: str,
|
| 228 |
model_name: str,
|
|
@@ -230,61 +136,16 @@ def load_persona_vectors_lean(
|
|
| 230 |
variant: str,
|
| 231 |
persona_ids: tuple[str, ...],
|
| 232 |
) -> LayeredSamples:
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
mask_strategy_value,
|
| 241 |
-
)
|
| 242 |
-
return load_persona_vectors(
|
| 243 |
-
store,
|
| 244 |
-
variant,
|
| 245 |
-
mask_strategy=MaskStrategy(mask_strategy_value),
|
| 246 |
-
persona_ids=list(persona_ids),
|
| 247 |
-
)
|
| 248 |
-
|
| 249 |
-
dataset = _hub_split(location, model_name, mask_strategy_value, variant)
|
| 250 |
-
metadata = dataset.select_columns(["persona_id", "name"])
|
| 251 |
-
index_by_id: dict[str, int] = {}
|
| 252 |
-
name_by_id: dict[str, str] = {}
|
| 253 |
-
requested = set(persona_ids)
|
| 254 |
-
for index, row in enumerate(metadata):
|
| 255 |
-
persona_id = row["persona_id"]
|
| 256 |
-
if persona_id in requested:
|
| 257 |
-
index_by_id[persona_id] = index
|
| 258 |
-
name_by_id[persona_id] = row.get("name") or persona_id
|
| 259 |
-
if len(index_by_id) == len(requested):
|
| 260 |
-
break
|
| 261 |
-
|
| 262 |
-
missing = [
|
| 263 |
-
persona_id for persona_id in persona_ids if persona_id not in index_by_id
|
| 264 |
-
]
|
| 265 |
-
if missing:
|
| 266 |
-
raise FileNotFoundError(
|
| 267 |
-
f"Missing {len(missing)} persona vector(s) in {variant!r}: {missing[:3]}"
|
| 268 |
-
)
|
| 269 |
-
|
| 270 |
-
vectors, labels, hover_text = [], [], []
|
| 271 |
-
for persona_id in persona_ids:
|
| 272 |
-
name = name_by_id.get(persona_id, persona_id)
|
| 273 |
-
vector = torch.as_tensor(
|
| 274 |
-
dataset[index_by_id[persona_id]]["vector"],
|
| 275 |
-
dtype=torch.float32,
|
| 276 |
-
)
|
| 277 |
-
if vector.ndim != 2:
|
| 278 |
-
raise ValueError(
|
| 279 |
-
f"tensor for {persona_id!r} must have shape (num_layers, hidden_size)"
|
| 280 |
-
)
|
| 281 |
-
vectors.append(vector)
|
| 282 |
-
labels.append(name)
|
| 283 |
-
hover_text.append(f"Persona: {name}<br>ID: {persona_id}")
|
| 284 |
-
return LayeredSamples(torch.stack(vectors), labels, hover_text)
|
| 285 |
|
| 286 |
|
| 287 |
-
def
|
| 288 |
source: str,
|
| 289 |
location: str,
|
| 290 |
model_name: str,
|
|
@@ -293,27 +154,18 @@ def load_variant_vectors_lean(
|
|
| 293 |
persona_ids: tuple[str, ...],
|
| 294 |
) -> dict[str, LayeredSamples]:
|
| 295 |
return {
|
| 296 |
-
variant:
|
| 297 |
-
source,
|
| 298 |
-
location,
|
| 299 |
-
model_name,
|
| 300 |
-
mask_strategy_value,
|
| 301 |
-
variant,
|
| 302 |
-
persona_ids,
|
| 303 |
)
|
| 304 |
for variant in variants
|
| 305 |
}
|
| 306 |
|
| 307 |
|
| 308 |
-
def
|
| 309 |
store: Store,
|
| 310 |
variants: list[str] | tuple[str, ...] | None = None,
|
| 311 |
) -> None:
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
cache.clear()
|
| 317 |
-
return
|
| 318 |
-
for variant in variants:
|
| 319 |
-
cache.pop(variant, None)
|
|
|
|
| 1 |
import os
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
+
from persona_vectors.analysis import LayeredSamples, load_persona_vectors
|
|
|
|
| 5 |
from persona_vectors.artifacts import (
|
| 6 |
ActivationStore,
|
| 7 |
HFActivationStore,
|
|
|
|
| 8 |
discover_activation_models,
|
| 9 |
model_dir_name,
|
| 10 |
)
|
|
|
|
| 23 |
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
@st.cache_resource(show_spinner=False, max_entries=1)
|
| 27 |
def activation_store_cached(
|
| 28 |
source: str,
|
|
|
|
| 43 |
model_name: str,
|
| 44 |
mask_strategy_value: str,
|
| 45 |
) -> list[str]:
|
| 46 |
+
return activation_store_cached(
|
| 47 |
+
source, location, model_name, mask_strategy_value
|
| 48 |
+
).available_variants()
|
| 49 |
|
| 50 |
|
| 51 |
@st.cache_data(show_spinner=False)
|
|
|
|
| 56 |
mask_strategy_value: str,
|
| 57 |
variants: tuple[str, ...],
|
| 58 |
) -> list[str]:
|
| 59 |
+
return activation_store_cached(
|
| 60 |
+
source, location, model_name, mask_strategy_value
|
| 61 |
+
).list_personas(list(variants))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
|
| 64 |
@st.cache_data(show_spinner=False)
|
|
|
|
| 70 |
variants: tuple[str, ...],
|
| 71 |
persona_ids: tuple[str, ...],
|
| 72 |
) -> dict[str, str]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 74 |
+
names = store.persona_names(list(persona_ids), variants=list(variants))
|
| 75 |
+
# Preserve input order, fall back to the id when the row has no display name.
|
| 76 |
+
return {pid: names.get(pid, pid) for pid in persona_ids}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@st.cache_data(show_spinner=False)
|
| 80 |
+
def store_layers_cached(
|
| 81 |
+
source: str,
|
| 82 |
+
location: str,
|
| 83 |
+
model_name: str,
|
| 84 |
+
mask_strategy_value: str,
|
| 85 |
+
variants: tuple[str, ...],
|
| 86 |
+
persona_ids: tuple[str, ...],
|
| 87 |
+
) -> list[int]:
|
| 88 |
+
return activation_store_cached(
|
| 89 |
+
source, location, model_name, mask_strategy_value
|
| 90 |
+
).list_layers(list(variants), list(persona_ids))
|
| 91 |
|
| 92 |
|
| 93 |
@st.cache_data(show_spinner=False)
|
|
|
|
| 99 |
|
| 100 |
@st.cache_data(show_spinner=False)
|
| 101 |
def hub_models_by_mask_strategy(repo_id: str) -> dict[MaskStrategy, list[str]]:
|
| 102 |
+
valid = {strategy.value for strategy in MaskStrategy}
|
| 103 |
return {
|
| 104 |
MaskStrategy(strategy_value): models
|
| 105 |
+
for strategy_value, models in list_hub_vector_models(repo_id).items()
|
| 106 |
+
if strategy_value in valid
|
| 107 |
}
|
| 108 |
|
| 109 |
|
|
|
|
| 121 |
|
| 122 |
def available_variants(store: Store, mask_strategy: MaskStrategy) -> list[str]:
|
| 123 |
source, location, model_name = store_cache_parts(store)
|
| 124 |
+
return available_variants_cached(source, location, model_name, mask_strategy.value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
|
| 127 |
def local_model_matches(left: str, right: str) -> bool:
|
| 128 |
return model_dir_name(left) == model_dir_name(right)
|
| 129 |
|
| 130 |
|
| 131 |
+
def load_persona_vectors_cached(
|
| 132 |
source: str,
|
| 133 |
location: str,
|
| 134 |
model_name: str,
|
|
|
|
| 136 |
variant: str,
|
| 137 |
persona_ids: tuple[str, ...],
|
| 138 |
) -> LayeredSamples:
|
| 139 |
+
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 140 |
+
return load_persona_vectors(
|
| 141 |
+
store,
|
| 142 |
+
variant,
|
| 143 |
+
mask_strategy=MaskStrategy(mask_strategy_value),
|
| 144 |
+
persona_ids=list(persona_ids),
|
| 145 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
+
def load_variant_vectors_cached(
|
| 149 |
source: str,
|
| 150 |
location: str,
|
| 151 |
model_name: str,
|
|
|
|
| 154 |
persona_ids: tuple[str, ...],
|
| 155 |
) -> dict[str, LayeredSamples]:
|
| 156 |
return {
|
| 157 |
+
variant: load_persona_vectors_cached(
|
| 158 |
+
source, location, model_name, mask_strategy_value, variant, persona_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
)
|
| 160 |
for variant in variants
|
| 161 |
}
|
| 162 |
|
| 163 |
|
| 164 |
+
def release_hf_store_cache(
|
| 165 |
store: Store,
|
| 166 |
variants: list[str] | tuple[str, ...] | None = None,
|
| 167 |
) -> None:
|
| 168 |
+
"""Drop cached HF data for ``variants`` (or all) on Hub stores."""
|
| 169 |
+
release_cache = getattr(store, "release_cache", None)
|
| 170 |
+
if isinstance(store, HFActivationStore) and callable(release_cache):
|
| 171 |
+
release_cache(variants)
|
|
|
|
|
|
|
|
|
|
|
|
utils/chat.py
CHANGED
|
@@ -5,11 +5,11 @@ from contextlib import contextmanager, nullcontext
|
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import TYPE_CHECKING, Literal
|
| 7 |
|
|
|
|
| 8 |
from persona_data.prompts import format_messages, format_prompt, normalize_messages
|
| 9 |
from persona_data.synth_persona import PersonaData
|
| 10 |
|
| 11 |
if TYPE_CHECKING:
|
| 12 |
-
import torch
|
| 13 |
from nnterp import StandardizedTransformer
|
| 14 |
|
| 15 |
logger = logging.getLogger(__name__)
|
|
@@ -133,8 +133,6 @@ def format_generation_prompt(
|
|
| 133 |
|
| 134 |
def resolve_saved_tensor(value: object) -> torch.Tensor:
|
| 135 |
"""Resolve an nnsight ``.save()`` proxy (or raw tensor) to a CPU tensor."""
|
| 136 |
-
import torch
|
| 137 |
-
|
| 138 |
resolved = value.value if getattr(value, "value", None) is not None else value
|
| 139 |
if not isinstance(resolved, torch.Tensor):
|
| 140 |
raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
|
|
@@ -160,8 +158,6 @@ def _seeded_rng(seed: int | None):
|
|
| 160 |
yield
|
| 161 |
return
|
| 162 |
|
| 163 |
-
import torch
|
| 164 |
-
|
| 165 |
cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
|
| 166 |
mps_ctx = (
|
| 167 |
torch.random.fork_rng(devices=range(1), device_type="mps")
|
|
@@ -207,8 +203,6 @@ def generate_chat_reply(
|
|
| 207 |
ChatReply with generated text and token ids.
|
| 208 |
"""
|
| 209 |
|
| 210 |
-
import torch
|
| 211 |
-
|
| 212 |
tokenizer = model.tokenizer
|
| 213 |
prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)
|
| 214 |
|
|
@@ -228,9 +222,11 @@ def generate_chat_reply(
|
|
| 228 |
generation_kwargs["repetition_penalty"] = repetition_penalty
|
| 229 |
# `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
|
| 230 |
# forwarded to the underlying model's generate
|
| 231 |
-
with
|
| 232 |
-
|
| 233 |
-
|
|
|
|
|
|
|
| 234 |
|
| 235 |
if getattr(generated, "value", None) is not None:
|
| 236 |
generated = generated.value
|
|
|
|
| 5 |
from dataclasses import dataclass
|
| 6 |
from typing import TYPE_CHECKING, Literal
|
| 7 |
|
| 8 |
+
import torch
|
| 9 |
from persona_data.prompts import format_messages, format_prompt, normalize_messages
|
| 10 |
from persona_data.synth_persona import PersonaData
|
| 11 |
|
| 12 |
if TYPE_CHECKING:
|
|
|
|
| 13 |
from nnterp import StandardizedTransformer
|
| 14 |
|
| 15 |
logger = logging.getLogger(__name__)
|
|
|
|
| 133 |
|
| 134 |
def resolve_saved_tensor(value: object) -> torch.Tensor:
|
| 135 |
"""Resolve an nnsight ``.save()`` proxy (or raw tensor) to a CPU tensor."""
|
|
|
|
|
|
|
| 136 |
resolved = value.value if getattr(value, "value", None) is not None else value
|
| 137 |
if not isinstance(resolved, torch.Tensor):
|
| 138 |
raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
|
|
|
|
| 158 |
yield
|
| 159 |
return
|
| 160 |
|
|
|
|
|
|
|
| 161 |
cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
|
| 162 |
mps_ctx = (
|
| 163 |
torch.random.fork_rng(devices=range(1), device_type="mps")
|
|
|
|
| 203 |
ChatReply with generated text and token ids.
|
| 204 |
"""
|
| 205 |
|
|
|
|
|
|
|
| 206 |
tokenizer = model.tokenizer
|
| 207 |
prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)
|
| 208 |
|
|
|
|
| 222 |
generation_kwargs["repetition_penalty"] = repetition_penalty
|
| 223 |
# `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
|
| 224 |
# forwarded to the underlying model's generate
|
| 225 |
+
with (
|
| 226 |
+
_seeded_rng(seed if do_sample and not remote else None),
|
| 227 |
+
model.generate(prompt, remote=remote, **generation_kwargs) as tracer,
|
| 228 |
+
):
|
| 229 |
+
generated = tracer.result.save()
|
| 230 |
|
| 231 |
if getattr(generated, "value", None) is not None:
|
| 232 |
generated = generated.value
|
utils/chat_export.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import json
|
| 2 |
from dataclasses import asdict, is_dataclass
|
| 3 |
-
from datetime import
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
from utils.helpers import slugify
|
|
@@ -72,7 +72,7 @@ def save_chat_export(
|
|
| 72 |
)
|
| 73 |
export_dir.mkdir(parents=True, exist_ok=True)
|
| 74 |
|
| 75 |
-
timestamp = datetime.now(
|
| 76 |
filename_parts = [
|
| 77 |
timestamp,
|
| 78 |
slugify(persona_name or persona_id),
|
|
|
|
| 1 |
import json
|
| 2 |
from dataclasses import asdict, is_dataclass
|
| 3 |
+
from datetime import UTC, datetime
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
from utils.helpers import slugify
|
|
|
|
| 72 |
)
|
| 73 |
export_dir.mkdir(parents=True, exist_ok=True)
|
| 74 |
|
| 75 |
+
timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
|
| 76 |
filename_parts = [
|
| 77 |
timestamp,
|
| 78 |
slugify(persona_name or persona_id),
|
utils/contrast.py
CHANGED
|
@@ -244,7 +244,9 @@ def render_contrast_html(result: TokenContrast) -> str:
|
|
| 244 |
it is, with a hover tooltip showing the raw Δlog P, plus a legend.
|
| 245 |
"""
|
| 246 |
spans: list[str] = []
|
| 247 |
-
for token, weight, raw in zip(
|
|
|
|
|
|
|
| 248 |
bg = _weight_to_bg(weight)
|
| 249 |
tip = escape(f"Δlog P(A−B): {raw:+.3f}")
|
| 250 |
text = escape(token)
|
|
|
|
| 244 |
it is, with a hover tooltip showing the raw Δlog P, plus a legend.
|
| 245 |
"""
|
| 246 |
spans: list[str] = []
|
| 247 |
+
for token, weight, raw in zip(
|
| 248 |
+
result.tokens, result.weights, result.raw_diffs, strict=True
|
| 249 |
+
):
|
| 250 |
bg = _weight_to_bg(weight)
|
| 251 |
tip = escape(f"Δlog P(A−B): {raw:+.3f}")
|
| 252 |
text = escape(token)
|
utils/datasets.py
CHANGED
|
@@ -13,7 +13,7 @@ from persona_data.nemotron_personas import (
|
|
| 13 |
from persona_data.synth_persona import PersonaDataset as LocalPersonaDataset
|
| 14 |
from persona_data.synth_persona import SynthPersonaDataset
|
| 15 |
|
| 16 |
-
from .helpers import
|
| 17 |
|
| 18 |
|
| 19 |
@st.cache_resource(show_spinner=False)
|
|
@@ -63,6 +63,12 @@ def load_persona_list(
|
|
| 63 |
"""
|
| 64 |
|
| 65 |
dataset, status = load_dataset(dataset_source, personas_file, qa_file)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
cached = getattr(dataset, "_persona_list_cache", None)
|
| 67 |
if cached is None:
|
| 68 |
cached = list(dataset)
|
|
@@ -70,7 +76,7 @@ def load_persona_list(
|
|
| 70 |
dataset._persona_list_cache = cached
|
| 71 |
except (AttributeError, TypeError):
|
| 72 |
pass
|
| 73 |
-
return cached
|
| 74 |
|
| 75 |
|
| 76 |
def load_dataset(
|
|
@@ -86,13 +92,13 @@ def load_dataset(
|
|
| 86 |
]:
|
| 87 |
"""Load the selected dataset source for the UI."""
|
| 88 |
|
| 89 |
-
if dataset_source ==
|
| 90 |
return _cached_dataset(SynthPersonaDataset), "SynthPersona"
|
| 91 |
|
| 92 |
-
if dataset_source ==
|
| 93 |
return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"
|
| 94 |
|
| 95 |
-
if dataset_source ==
|
| 96 |
return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"
|
| 97 |
|
| 98 |
if personas_file is None or qa_file is None:
|
|
|
|
| 13 |
from persona_data.synth_persona import PersonaDataset as LocalPersonaDataset
|
| 14 |
from persona_data.synth_persona import SynthPersonaDataset
|
| 15 |
|
| 16 |
+
from .helpers import DatasetSource
|
| 17 |
|
| 18 |
|
| 19 |
@st.cache_resource(show_spinner=False)
|
|
|
|
| 63 |
"""
|
| 64 |
|
| 65 |
dataset, status = load_dataset(dataset_source, personas_file, qa_file)
|
| 66 |
+
return load_persona_list_from_dataset(dataset), status
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_persona_list_from_dataset(dataset: Any) -> list:
|
| 70 |
+
"""Materialize and cache personas from an already-loaded dataset."""
|
| 71 |
+
|
| 72 |
cached = getattr(dataset, "_persona_list_cache", None)
|
| 73 |
if cached is None:
|
| 74 |
cached = list(dataset)
|
|
|
|
| 76 |
dataset._persona_list_cache = cached
|
| 77 |
except (AttributeError, TypeError):
|
| 78 |
pass
|
| 79 |
+
return cached
|
| 80 |
|
| 81 |
|
| 82 |
def load_dataset(
|
|
|
|
| 92 |
]:
|
| 93 |
"""Load the selected dataset source for the UI."""
|
| 94 |
|
| 95 |
+
if dataset_source == DatasetSource.SYNTH_PERSONA.value:
|
| 96 |
return _cached_dataset(SynthPersonaDataset), "SynthPersona"
|
| 97 |
|
| 98 |
+
if dataset_source == DatasetSource.NEMOTRON_FRANCE.value:
|
| 99 |
return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"
|
| 100 |
|
| 101 |
+
if dataset_source == DatasetSource.NEMOTRON_USA.value:
|
| 102 |
return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"
|
| 103 |
|
| 104 |
if personas_file is None or qa_file is None:
|
utils/helpers.py
CHANGED
|
@@ -1,9 +1,21 @@
|
|
| 1 |
import hashlib
|
| 2 |
import re
|
| 3 |
from collections.abc import Iterable
|
|
|
|
| 4 |
|
| 5 |
from persona_data.synth_persona import PersonaData
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
# Variant key -> human-readable label mapping
|
| 8 |
VARIANT_LABELS = {
|
| 9 |
"empty": "None",
|
|
@@ -16,21 +28,21 @@ VARIANT_LABELS = {
|
|
| 16 |
CHAT_PROMPT_MODES = ("empty", "templated", "biography", "custom")
|
| 17 |
CHAT_PROMPT_MODE_LABELS = [VARIANT_LABELS[key] for key in CHAT_PROMPT_MODES]
|
| 18 |
CHAT_PROMPT_MODE_LABEL_TO_KEY = {VARIANT_LABELS[key]: key for key in CHAT_PROMPT_MODES}
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
"
|
| 23 |
-
"
|
| 24 |
-
"
|
| 25 |
-
"
|
| 26 |
]
|
| 27 |
-
ANALYSIS_MODES = ["Cosine similarity", "Similarity matrix", "PCA", "UMAP", "Dendrogram"]
|
| 28 |
|
| 29 |
ANALYSIS_HELP_TEXT = {
|
| 30 |
"Cosine similarity": "Compare layer-wise alignment between variants.",
|
| 31 |
"Similarity matrix": "Compare centered pairwise similarity between persona vectors by layer, with pair trajectories across layers.",
|
| 32 |
"PCA": "Project per-persona vectors into a 2D or 3D global view.",
|
| 33 |
"UMAP": "Project per-persona vectors into a 2D or 3D local-neighborhood view.",
|
|
|
|
| 34 |
"Dendrogram": "Hierarchical clustering of persona vectors — shows biography and templated side by side for direct comparison.",
|
| 35 |
}
|
| 36 |
|
|
@@ -56,6 +68,12 @@ def widget_key(*parts: str) -> str:
|
|
| 56 |
return "::".join(parts)
|
| 57 |
|
| 58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
def personas_fingerprint(persona_ids: Iterable[str]) -> str:
|
| 60 |
"""Stable short fingerprint for a set of persona ids.
|
| 61 |
|
|
@@ -78,11 +96,3 @@ def persona_label(persona: PersonaData) -> str:
|
|
| 78 |
"""Format a persona for selection widgets."""
|
| 79 |
|
| 80 |
return f"{persona.name} ({persona.id})"
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
def persona_display_label(persona_id: str, persona_name: str | None) -> str:
|
| 84 |
-
"""Format a persona id with an optional display name."""
|
| 85 |
-
|
| 86 |
-
if persona_name:
|
| 87 |
-
return f"{persona_name} ({persona_id})"
|
| 88 |
-
return persona_id
|
|
|
|
| 1 |
import hashlib
|
| 2 |
import re
|
| 3 |
from collections.abc import Iterable
|
| 4 |
+
from enum import Enum
|
| 5 |
|
| 6 |
from persona_data.synth_persona import PersonaData
|
| 7 |
|
| 8 |
+
|
| 9 |
+
class DatasetSource(str, Enum):
|
| 10 |
+
SYNTH_PERSONA = "HuggingFace: synth-persona"
|
| 11 |
+
NEMOTRON_FRANCE = "HuggingFace: nemotron-france"
|
| 12 |
+
NEMOTRON_USA = "HuggingFace: nemotron-usa"
|
| 13 |
+
LOCAL_UPLOAD = "Local JSONL upload"
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
DATASET_SOURCES = [s.value for s in DatasetSource]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
# Variant key -> human-readable label mapping
|
| 20 |
VARIANT_LABELS = {
|
| 21 |
"empty": "None",
|
|
|
|
| 28 |
CHAT_PROMPT_MODES = ("empty", "templated", "biography", "custom")
|
| 29 |
CHAT_PROMPT_MODE_LABELS = [VARIANT_LABELS[key] for key in CHAT_PROMPT_MODES]
|
| 30 |
CHAT_PROMPT_MODE_LABEL_TO_KEY = {VARIANT_LABELS[key]: key for key in CHAT_PROMPT_MODES}
|
| 31 |
+
ANALYSIS_MODES = [
|
| 32 |
+
"Cosine similarity",
|
| 33 |
+
"Similarity matrix",
|
| 34 |
+
"PCA",
|
| 35 |
+
"UMAP",
|
| 36 |
+
"Isomap",
|
| 37 |
+
"Dendrogram",
|
| 38 |
]
|
|
|
|
| 39 |
|
| 40 |
ANALYSIS_HELP_TEXT = {
|
| 41 |
"Cosine similarity": "Compare layer-wise alignment between variants.",
|
| 42 |
"Similarity matrix": "Compare centered pairwise similarity between persona vectors by layer, with pair trajectories across layers.",
|
| 43 |
"PCA": "Project per-persona vectors into a 2D or 3D global view.",
|
| 44 |
"UMAP": "Project per-persona vectors into a 2D or 3D local-neighborhood view.",
|
| 45 |
+
"Isomap": "Project per-persona vectors with graph-geodesic distances to probe manifold-like geometry.",
|
| 46 |
"Dendrogram": "Hierarchical clustering of persona vectors — shows biography and templated side by side for direct comparison.",
|
| 47 |
}
|
| 48 |
|
|
|
|
| 68 |
return "::".join(parts)
|
| 69 |
|
| 70 |
|
| 71 |
+
def session_key(*parts: str) -> str:
|
| 72 |
+
"""Generate a colon-separated Streamlit session-state key from parts."""
|
| 73 |
+
|
| 74 |
+
return ":".join(parts)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
def personas_fingerprint(persona_ids: Iterable[str]) -> str:
|
| 78 |
"""Stable short fingerprint for a set of persona ids.
|
| 79 |
|
|
|
|
| 96 |
"""Format a persona for selection widgets."""
|
| 97 |
|
| 98 |
return f"{persona.name} ({persona.id})"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uv.lock
CHANGED
|
@@ -748,11 +748,11 @@ wheels = [
|
|
| 748 |
|
| 749 |
[[package]]
|
| 750 |
name = "idna"
|
| 751 |
-
version = "3.
|
| 752 |
source = { registry = "https://pypi.org/simple" }
|
| 753 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 754 |
wheels = [
|
| 755 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 756 |
]
|
| 757 |
|
| 758 |
[[package]]
|
|
@@ -1559,7 +1559,7 @@ wheels = [
|
|
| 1559 |
|
| 1560 |
[[package]]
|
| 1561 |
name = "persona-data"
|
| 1562 |
-
version = "0.
|
| 1563 |
source = { registry = "https://pypi.org/simple" }
|
| 1564 |
dependencies = [
|
| 1565 |
{ name = "huggingface-hub" },
|
|
@@ -1568,9 +1568,9 @@ dependencies = [
|
|
| 1568 |
{ name = "python-dotenv" },
|
| 1569 |
{ name = "torch" },
|
| 1570 |
]
|
| 1571 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1572 |
wheels = [
|
| 1573 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1574 |
]
|
| 1575 |
|
| 1576 |
[[package]]
|
|
@@ -1581,7 +1581,6 @@ dependencies = [
|
|
| 1581 |
{ name = "catppuccin" },
|
| 1582 |
{ name = "datasets" },
|
| 1583 |
{ name = "huggingface-hub" },
|
| 1584 |
-
{ name = "persona-data" },
|
| 1585 |
{ name = "persona-vectors" },
|
| 1586 |
{ name = "plotly" },
|
| 1587 |
{ name = "python-dotenv" },
|
|
@@ -1593,8 +1592,7 @@ requires-dist = [
|
|
| 1593 |
{ name = "catppuccin", specifier = ">=2.5.0" },
|
| 1594 |
{ name = "datasets", specifier = ">=4.8.5" },
|
| 1595 |
{ name = "huggingface-hub", specifier = ">=1.14.0" },
|
| 1596 |
-
{ name = "persona-
|
| 1597 |
-
{ name = "persona-vectors", specifier = ">=0.7.3" },
|
| 1598 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1599 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1600 |
{ name = "streamlit", specifier = ">=1.44.0" },
|
|
@@ -1602,7 +1600,7 @@ requires-dist = [
|
|
| 1602 |
|
| 1603 |
[[package]]
|
| 1604 |
name = "persona-vectors"
|
| 1605 |
-
version = "0.
|
| 1606 |
source = { registry = "https://pypi.org/simple" }
|
| 1607 |
dependencies = [
|
| 1608 |
{ name = "datasets" },
|
|
@@ -1621,9 +1619,9 @@ dependencies = [
|
|
| 1621 |
{ name = "transformers" },
|
| 1622 |
{ name = "umap-learn" },
|
| 1623 |
]
|
| 1624 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1625 |
wheels = [
|
| 1626 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1627 |
]
|
| 1628 |
|
| 1629 |
[[package]]
|
|
@@ -2838,7 +2836,7 @@ wheels = [
|
|
| 2838 |
|
| 2839 |
[[package]]
|
| 2840 |
name = "transformers"
|
| 2841 |
-
version = "5.8.
|
| 2842 |
source = { registry = "https://pypi.org/simple" }
|
| 2843 |
dependencies = [
|
| 2844 |
{ name = "huggingface-hub" },
|
|
@@ -2851,9 +2849,9 @@ dependencies = [
|
|
| 2851 |
{ name = "tqdm" },
|
| 2852 |
{ name = "typer" },
|
| 2853 |
]
|
| 2854 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 2855 |
wheels = [
|
| 2856 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 2857 |
]
|
| 2858 |
|
| 2859 |
[[package]]
|
|
|
|
| 748 |
|
| 749 |
[[package]]
|
| 750 |
name = "idna"
|
| 751 |
+
version = "3.15"
|
| 752 |
source = { registry = "https://pypi.org/simple" }
|
| 753 |
+
sdist = { url = "https://files.pythonhosted.org/packages/82/77/7b3966d0b9d1d31a36ddf1746926a11dface89a83409bf1483f0237aa758/idna-3.15.tar.gz", hash = "sha256:ca962446ea538f7092a95e057da437618e886f4d349216d2b1e294abfdb65fdc", size = 199245, upload-time = "2026-05-12T22:45:57.011Z" }
|
| 754 |
wheels = [
|
| 755 |
+
{ url = "https://files.pythonhosted.org/packages/d2/23/408243171aa9aaba178d3e2559159c24c1171a641aa83b67bdd3394ead8e/idna-3.15-py3-none-any.whl", hash = "sha256:048adeaf8c2d788c40fee287673ccaa74c24ffd8dcf09ffa555a2fbb59f10ac8", size = 72340, upload-time = "2026-05-12T22:45:55.733Z" },
|
| 756 |
]
|
| 757 |
|
| 758 |
[[package]]
|
|
|
|
| 1559 |
|
| 1560 |
[[package]]
|
| 1561 |
name = "persona-data"
|
| 1562 |
+
version = "0.5.1"
|
| 1563 |
source = { registry = "https://pypi.org/simple" }
|
| 1564 |
dependencies = [
|
| 1565 |
{ name = "huggingface-hub" },
|
|
|
|
| 1568 |
{ name = "python-dotenv" },
|
| 1569 |
{ name = "torch" },
|
| 1570 |
]
|
| 1571 |
+
sdist = { url = "https://files.pythonhosted.org/packages/de/9f/2257b6df8c28f0844b88f64a200a4d27f82ea10a16e657ba9fd02f561135/persona_data-0.5.1.tar.gz", hash = "sha256:5ac4467c449905fecf26a743b7128f76dbd984a076426c3ce854a13394c1fc5c", size = 10336, upload-time = "2026-05-13T11:55:00.356Z" }
|
| 1572 |
wheels = [
|
| 1573 |
+
{ url = "https://files.pythonhosted.org/packages/55/ec/328013ee81672ba800777b3a9c24f18dc7cb3a93223391e3476cac55fa1b/persona_data-0.5.1-py3-none-any.whl", hash = "sha256:ccf230b4028d08b9345910b57de6ea4b60e9ec7f65ce12203f69693988314543", size = 13078, upload-time = "2026-05-13T11:55:01.402Z" },
|
| 1574 |
]
|
| 1575 |
|
| 1576 |
[[package]]
|
|
|
|
| 1581 |
{ name = "catppuccin" },
|
| 1582 |
{ name = "datasets" },
|
| 1583 |
{ name = "huggingface-hub" },
|
|
|
|
| 1584 |
{ name = "persona-vectors" },
|
| 1585 |
{ name = "plotly" },
|
| 1586 |
{ name = "python-dotenv" },
|
|
|
|
| 1592 |
{ name = "catppuccin", specifier = ">=2.5.0" },
|
| 1593 |
{ name = "datasets", specifier = ">=4.8.5" },
|
| 1594 |
{ name = "huggingface-hub", specifier = ">=1.14.0" },
|
| 1595 |
+
{ name = "persona-vectors", specifier = ">=0.8.0" },
|
|
|
|
| 1596 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1597 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1598 |
{ name = "streamlit", specifier = ">=1.44.0" },
|
|
|
|
| 1600 |
|
| 1601 |
[[package]]
|
| 1602 |
name = "persona-vectors"
|
| 1603 |
+
version = "0.8.0"
|
| 1604 |
source = { registry = "https://pypi.org/simple" }
|
| 1605 |
dependencies = [
|
| 1606 |
{ name = "datasets" },
|
|
|
|
| 1619 |
{ name = "transformers" },
|
| 1620 |
{ name = "umap-learn" },
|
| 1621 |
]
|
| 1622 |
+
sdist = { url = "https://files.pythonhosted.org/packages/76/22/8a0ca0e6e54ebd8dd07a4064c2890ec33b68ad81a00e4e93c4f9eee2bcf7/persona_vectors-0.8.0.tar.gz", hash = "sha256:3775afc7e04ab1d02582e9c4b3f2d124174ea40d376dd2b91492457a747dd553", size = 31938, upload-time = "2026-05-13T20:00:46.357Z" }
|
| 1623 |
wheels = [
|
| 1624 |
+
{ url = "https://files.pythonhosted.org/packages/43/a6/7f67a7df27d78db706cbc9afd5d5ca4b52970b9005717c3bfcc0ce90ec71/persona_vectors-0.8.0-py3-none-any.whl", hash = "sha256:08b37a749f98b764d22d4c943158922338ab054729f7137eff2c3a167e2b2ae5", size = 36838, upload-time = "2026-05-13T20:00:47.252Z" },
|
| 1625 |
]
|
| 1626 |
|
| 1627 |
[[package]]
|
|
|
|
| 2836 |
|
| 2837 |
[[package]]
|
| 2838 |
name = "transformers"
|
| 2839 |
+
version = "5.8.1"
|
| 2840 |
source = { registry = "https://pypi.org/simple" }
|
| 2841 |
dependencies = [
|
| 2842 |
{ name = "huggingface-hub" },
|
|
|
|
| 2849 |
{ name = "tqdm" },
|
| 2850 |
{ name = "typer" },
|
| 2851 |
]
|
| 2852 |
+
sdist = { url = "https://files.pythonhosted.org/packages/e7/e6/4134ea2fbea322cddc7ffc94a0d8ee47fe32ce8e876b320cd37d88edfc4d/transformers-5.8.1.tar.gz", hash = "sha256:4dd5b6de4105725104d84fd6abd74b305f4debfc251b38c648ee5dd087cf543b", size = 8532019, upload-time = "2026-05-13T03:21:57.234Z" }
|
| 2853 |
wheels = [
|
| 2854 |
+
{ url = "https://files.pythonhosted.org/packages/fc/b1/8be7e7ef0b5200491312201918b6125ef9c9df9dd0f0240ccef9ac824e6b/transformers-5.8.1-py3-none-any.whl", hash = "sha256:5340fb95962162cdfdae5cc91d7f8fedd92ed75216c1154c5e1f590fcf56dd0e", size = 10632882, upload-time = "2026-05-13T03:21:52.876Z" },
|
| 2855 |
]
|
| 2856 |
|
| 2857 |
[[package]]
|