Jac-Zac committed on
Commit
db3d901
·
1 Parent(s): 99c28ab

Big refactor and feature addition to analyses + support latest persona-vector

Browse files
README.md CHANGED
@@ -20,7 +20,7 @@ Streamlit interface for persona vector extraction, analysis, and chat.
20
  A web app built on top of [persona-vectors](../persona-vectors) that provides three tabs:
21
 
22
  - **Chat** — interactive conversations with a model using persona-based system prompts (templated or biography)
23
- - **Compare** — load local or Hub persona vectors and explore cosine similarity, PCA, UMAP, and similarity views
24
  - **Extract** — run activation extraction from HuggingFace persona datasets or a local JSONL dataset directly from the browser
25
 
26
  ## Repository Layout
@@ -31,7 +31,7 @@ persona-ui/
31
  ├── state.py # Session state management (chat history, KV cache)
32
  ├── tabs/
33
  │ ├── chat.py # Chat tab
34
- │ ├── compare.py # Activation comparison tab
35
  │ ├── compare_chat.py # Side-by-side chat comparison mode
36
  │ ├── extract.py # Extraction tab
37
  │ └── probe_ui.py # Probe upload and tracing controls
 
20
  A web app built on top of [persona-vectors](../persona-vectors) that provides three tabs:
21
 
22
  - **Chat** — interactive conversations with a model using persona-based system prompts (templated or biography)
23
+ - **Compare** — load local or Hub persona vectors and explore cosine similarity, PCA, UMAP, attribute-colored projections, and dendrograms
24
  - **Extract** — run activation extraction from HuggingFace persona datasets or a local JSONL dataset directly from the browser
25
 
26
  ## Repository Layout
 
31
  ├── state.py # Session state management (chat history, KV cache)
32
  ├── tabs/
33
  │ ├── chat.py # Chat tab
34
+ │ ├── analysis.py # Analysis tab (cosine similarity, PCA, UMAP, Isomap, dendrogram)
35
  │ ├── compare_chat.py # Side-by-side chat comparison mode
36
  │ ├── extract.py # Extraction tab
37
  │ └── probe_ui.py # Probe upload and tracing controls
app.py CHANGED
@@ -4,13 +4,26 @@ from dataclasses import dataclass
4
  import streamlit as st
5
  from dotenv import load_dotenv
6
 
7
- from utils.helpers import DATASET_SOURCES
 
 
8
 
9
  load_dotenv()
10
  DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
11
  REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
12
- _LAST_LOCAL_MODEL_KEY = "sidebar:last_local_model"
13
- _LAST_REMOTE_MODEL_KEY = "sidebar:last_remote_model"
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  _TABS = ["Chat", "Analysis", "Extract"]
@@ -35,9 +48,9 @@ def _remote_model_input(remote_models: list[str]) -> str:
35
  model_name = st.text_input(
36
  "Model",
37
  value=st.session_state.get(
38
- "sidebar__remote_model_custom_value", last_remote
39
  ),
40
- key="sidebar__remote_model_custom_value",
41
  help="NDIF model id. Use this to cold-load a remote model.",
42
  )
43
  st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
@@ -46,16 +59,16 @@ def _remote_model_input(remote_models: list[str]) -> str:
46
  custom = st.toggle(
47
  "Custom remote model",
48
  value=False,
49
- key="sidebar__remote_model_custom_enabled",
50
  help="Enter any NDIF-loadable model id, even if it is not currently running.",
51
  )
52
  if custom:
53
  model_name = st.text_input(
54
  "Model",
55
  value=st.session_state.get(
56
- "sidebar__remote_model_custom_value", last_remote
57
  ),
58
- key="sidebar__remote_model_custom_value",
59
  help="NDIF model id. Example: openai/gpt-oss-20b",
60
  )
61
  st.caption(
@@ -63,20 +76,20 @@ def _remote_model_input(remote_models: list[str]) -> str:
63
  "Custom model ids can cold-load if your NDIF account allows it."
64
  )
65
  else:
66
- default_model = st.session_state.get("sidebar__remote_model", last_remote)
67
  if default_model not in remote_models:
68
  default_model = (
69
  REMOTE_DEFAULT_MODEL
70
  if REMOTE_DEFAULT_MODEL in remote_models
71
  else remote_models[0]
72
  )
73
- if st.session_state.get("sidebar__remote_model") not in remote_models:
74
- st.session_state["sidebar__remote_model"] = default_model
75
  model_name = st.selectbox(
76
  "Model",
77
  options=remote_models,
78
  index=remote_models.index(default_model),
79
- key="sidebar__remote_model",
80
  help="Running NDIF model.",
81
  )
82
  st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
@@ -84,15 +97,13 @@ def _remote_model_input(remote_models: list[str]) -> str:
84
 
85
 
86
  def _sidebar_controls() -> SidebarState:
87
- from utils.runtime import list_remote_models
88
-
89
  with st.sidebar:
90
  st.markdown("## Persona UI")
91
 
92
- if "sidebar__active_tab" not in st.session_state:
93
- st.session_state["sidebar__active_tab"] = "Chat"
94
 
95
- active_tab = st.session_state["sidebar__active_tab"]
96
  for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
97
  is_selected = tab_name == active_tab
98
  if st.button(
@@ -102,13 +113,13 @@ def _sidebar_controls() -> SidebarState:
102
  type="primary" if is_selected else "secondary",
103
  icon=icon,
104
  ):
105
- st.session_state["sidebar__active_tab"] = tab_name
106
  st.rerun()
107
 
108
  if active_tab == "Analysis":
109
  model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
110
  dataset_source = st.session_state.get(
111
- "sidebar__dataset_source",
112
  DATASET_SOURCES[0],
113
  )
114
  return SidebarState(
@@ -120,7 +131,7 @@ def _sidebar_controls() -> SidebarState:
120
 
121
  st.divider()
122
  st.caption("Runtime")
123
- remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")
124
 
125
  if remote:
126
  model_name = _remote_model_input(list_remote_models())
@@ -128,7 +139,7 @@ def _sidebar_controls() -> SidebarState:
128
  model_name = st.text_input(
129
  "Model",
130
  value=st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL),
131
- key="sidebar__local_model",
132
  help="Local model id or path.",
133
  )
134
  st.session_state[_LAST_LOCAL_MODEL_KEY] = model_name
@@ -137,7 +148,7 @@ def _sidebar_controls() -> SidebarState:
137
  dataset_source = st.selectbox(
138
  "Source",
139
  DATASET_SOURCES,
140
- key="sidebar__dataset_source",
141
  help="Dataset for Chat and Extract.",
142
  )
143
 
@@ -153,8 +164,6 @@ def main() -> None:
153
  """Run the Streamlit app."""
154
 
155
  st.set_page_config(page_title="Persona UI", layout="wide")
156
- from utils.theme import install_catppuccin_theme
157
-
158
  install_catppuccin_theme(st.get_option("theme.base"))
159
 
160
  sidebar = _sidebar_controls()
@@ -164,9 +173,9 @@ def main() -> None:
164
 
165
  render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
166
  elif sidebar.active_tab == "Analysis":
167
- from tabs.compare import render_compare_tab
168
 
169
- render_compare_tab()
170
  else:
171
  from tabs.chat import render_chat_tab
172
 
 
4
  import streamlit as st
5
  from dotenv import load_dotenv
6
 
7
+ from utils.helpers import DATASET_SOURCES, session_key
8
+ from utils.runtime import list_remote_models
9
+ from utils.theme import install_catppuccin_theme
10
 
11
  load_dotenv()
12
  DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
13
  REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
14
+ _LAST_LOCAL_MODEL_KEY = session_key("sidebar", "last_local_model")
15
+ _LAST_REMOTE_MODEL_KEY = session_key("sidebar", "last_remote_model")
16
+ _SIDEBAR_ACTIVE_TAB_KEY = session_key("sidebar", "active_tab")
17
+ _SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY = session_key(
18
+ "sidebar", "remote_model_custom_value"
19
+ )
20
+ _SIDEBAR_REMOTE_MODEL_CUSTOM_ENABLED_KEY = session_key(
21
+ "sidebar", "remote_model_custom_enabled"
22
+ )
23
+ _SIDEBAR_REMOTE_MODEL_KEY = session_key("sidebar", "remote_model")
24
+ _SIDEBAR_LOCAL_MODEL_KEY = session_key("sidebar", "local_model")
25
+ _SIDEBAR_REMOTE_KEY = session_key("sidebar", "remote")
26
+ _SIDEBAR_DATASET_SOURCE_KEY = session_key("sidebar", "dataset_source")
27
 
28
 
29
  _TABS = ["Chat", "Analysis", "Extract"]
 
48
  model_name = st.text_input(
49
  "Model",
50
  value=st.session_state.get(
51
+ _SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, last_remote
52
  ),
53
+ key=_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY,
54
  help="NDIF model id. Use this to cold-load a remote model.",
55
  )
56
  st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
 
59
  custom = st.toggle(
60
  "Custom remote model",
61
  value=False,
62
+ key=_SIDEBAR_REMOTE_MODEL_CUSTOM_ENABLED_KEY,
63
  help="Enter any NDIF-loadable model id, even if it is not currently running.",
64
  )
65
  if custom:
66
  model_name = st.text_input(
67
  "Model",
68
  value=st.session_state.get(
69
+ _SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY, last_remote
70
  ),
71
+ key=_SIDEBAR_REMOTE_MODEL_CUSTOM_VALUE_KEY,
72
  help="NDIF model id. Example: openai/gpt-oss-20b",
73
  )
74
  st.caption(
 
76
  "Custom model ids can cold-load if your NDIF account allows it."
77
  )
78
  else:
79
+ default_model = st.session_state.get(_SIDEBAR_REMOTE_MODEL_KEY, last_remote)
80
  if default_model not in remote_models:
81
  default_model = (
82
  REMOTE_DEFAULT_MODEL
83
  if REMOTE_DEFAULT_MODEL in remote_models
84
  else remote_models[0]
85
  )
86
+ if st.session_state.get(_SIDEBAR_REMOTE_MODEL_KEY) not in remote_models:
87
+ st.session_state[_SIDEBAR_REMOTE_MODEL_KEY] = default_model
88
  model_name = st.selectbox(
89
  "Model",
90
  options=remote_models,
91
  index=remote_models.index(default_model),
92
+ key=_SIDEBAR_REMOTE_MODEL_KEY,
93
  help="Running NDIF model.",
94
  )
95
  st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
 
97
 
98
 
99
  def _sidebar_controls() -> SidebarState:
 
 
100
  with st.sidebar:
101
  st.markdown("## Persona UI")
102
 
103
+ if _SIDEBAR_ACTIVE_TAB_KEY not in st.session_state:
104
+ st.session_state[_SIDEBAR_ACTIVE_TAB_KEY] = "Chat"
105
 
106
+ active_tab = st.session_state[_SIDEBAR_ACTIVE_TAB_KEY]
107
  for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
108
  is_selected = tab_name == active_tab
109
  if st.button(
 
113
  type="primary" if is_selected else "secondary",
114
  icon=icon,
115
  ):
116
+ st.session_state[_SIDEBAR_ACTIVE_TAB_KEY] = tab_name
117
  st.rerun()
118
 
119
  if active_tab == "Analysis":
120
  model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
121
  dataset_source = st.session_state.get(
122
+ _SIDEBAR_DATASET_SOURCE_KEY,
123
  DATASET_SOURCES[0],
124
  )
125
  return SidebarState(
 
131
 
132
  st.divider()
133
  st.caption("Runtime")
134
+ remote = st.toggle("Remote (NDIF)", value=False, key=_SIDEBAR_REMOTE_KEY)
135
 
136
  if remote:
137
  model_name = _remote_model_input(list_remote_models())
 
139
  model_name = st.text_input(
140
  "Model",
141
  value=st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL),
142
+ key=_SIDEBAR_LOCAL_MODEL_KEY,
143
  help="Local model id or path.",
144
  )
145
  st.session_state[_LAST_LOCAL_MODEL_KEY] = model_name
 
148
  dataset_source = st.selectbox(
149
  "Source",
150
  DATASET_SOURCES,
151
+ key=_SIDEBAR_DATASET_SOURCE_KEY,
152
  help="Dataset for Chat and Extract.",
153
  )
154
 
 
164
  """Run the Streamlit app."""
165
 
166
  st.set_page_config(page_title="Persona UI", layout="wide")
 
 
167
  install_catppuccin_theme(st.get_option("theme.base"))
168
 
169
  sidebar = _sidebar_controls()
 
173
 
174
  render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
175
  elif sidebar.active_tab == "Analysis":
176
+ from tabs.analysis import render_analysis_tab
177
 
178
+ render_analysis_tab()
179
  else:
180
  from tabs.chat import render_chat_tab
181
 
pyproject.toml CHANGED
@@ -5,8 +5,7 @@ description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
- "persona-vectors>=0.7.3",
9
- "persona-data>=0.4.2",
10
  "datasets>=4.8.5",
11
  "huggingface-hub>=1.14.0",
12
  "streamlit>=1.44.0",
 
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
+ "persona-vectors>=0.8.0",
 
9
  "datasets>=4.8.5",
10
  "huggingface-hub>=1.14.0",
11
  "streamlit>=1.44.0",
state.py CHANGED
@@ -2,7 +2,8 @@ from typing import Literal, NotRequired, TypedDict
2
 
3
  import streamlit as st
4
 
5
- _CHAT_STATE_PREFIX = "chat_state::"
 
6
  PendingChatAction = Literal["new_user_prompt", "regenerate_after_edit"]
7
 
8
 
@@ -22,7 +23,7 @@ class ChatState(TypedDict):
22
  def chat_session_key(model_name: str, dataset_source: str) -> str:
23
  """Build the session-state key for a chat context."""
24
 
25
- return f"{_CHAT_STATE_PREFIX}{model_name}::{dataset_source}"
26
 
27
 
28
  def default_chat_state() -> ChatState:
@@ -48,9 +49,8 @@ def reset_chat_context_state(
48
  st.session_state.pop(key, None)
49
 
50
 
51
- def get_chat_state(model_name: str, _remote: bool, dataset_source: str) -> ChatState:
52
  """Return the mutable chat state for the active context."""
53
 
54
  key = chat_session_key(model_name, dataset_source)
55
- state = st.session_state.setdefault(key, default_chat_state())
56
- return state
 
2
 
3
  import streamlit as st
4
 
5
+ from utils.helpers import session_key
6
+
7
  PendingChatAction = Literal["new_user_prompt", "regenerate_after_edit"]
8
 
9
 
 
23
  def chat_session_key(model_name: str, dataset_source: str) -> str:
24
  """Build the session-state key for a chat context."""
25
 
26
+ return session_key("chat_state", model_name, dataset_source)
27
 
28
 
29
  def default_chat_state() -> ChatState:
 
49
  st.session_state.pop(key, None)
50
 
51
 
52
+ def get_chat_state(model_name: str, dataset_source: str) -> ChatState:
53
  """Return the mutable chat state for the active context."""
54
 
55
  key = chat_session_key(model_name, dataset_source)
56
+ return st.session_state.setdefault(key, default_chat_state())
 
tabs/analysis.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .analysis_core import render_analysis_tab
tabs/{compare.py → analysis_core.py} RENAMED
@@ -7,7 +7,12 @@ from pathlib import Path
7
  import plotly.graph_objects as go
8
  import streamlit as st
9
  from persona_data.environment import get_artifacts_dir
10
- from persona_data.synth_persona import BASELINE_PERSONA_ID
 
 
 
 
 
11
  from persona_vectors.extraction import MaskStrategy
12
  from persona_vectors.plots import (
13
  build_layered_figure,
@@ -15,10 +20,11 @@ from persona_vectors.plots import (
15
  build_similarity_figures,
16
  plot_layer_similarity,
17
  plot_persona_dendrogram,
 
18
  save_plot_html,
19
  )
20
 
21
- from utils.compare_sources import (
22
  DEFAULT_COMPARE_MODEL,
23
  DEFAULT_HUB_REPO,
24
  SOURCE_HUB,
@@ -28,13 +34,13 @@ from utils.compare_sources import (
28
  activation_store_cached,
29
  available_variants,
30
  hub_models_by_mask_strategy,
31
- load_persona_vectors_lean,
32
- load_variant_vectors_lean,
33
  local_model_matches,
34
  local_model_options_cached,
35
  persona_names_cached,
36
  personas_cached,
37
- release_store_cache,
38
  store_cache_parts,
39
  store_id,
40
  store_layers_cached,
@@ -57,32 +63,45 @@ def _filename(*parts: str) -> str:
57
 
58
  # Keep compare-tab selection state separate so projection defaults do not
59
  # overwrite cosine similarity defaults.
60
- _LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
61
- _LAST_PROJECTION_PERSONAS_KEY = "compare:last_personas:projection"
62
- _LAST_SIMILARITY_PERSONAS_KEY = "compare:last_personas:similarity"
63
- _LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
64
- _LAST_SOURCE_KEY = "compare:last_source"
 
 
 
 
 
 
 
 
 
65
 
66
  _DEFAULT_LAYER_FRAMES = 16
67
  _DEFAULT_PERSONA_LIMITS = {
68
  "similarity": 120,
69
  "pca": 500,
70
  "umap": 500,
 
71
  "dendro": 160,
72
  }
73
  _MAX_SIMILARITY_CELLS = 4_000_000
74
  _MAX_PAIR_TRAJECTORY_TRACES = 500
75
- _CLUSTER_METHODS = {
76
- "K-means": "kmeans",
77
- "Agglomerative": "agglomerative",
78
- "HDBSCAN": "hdbscan",
79
- }
80
  _CLUSTER_MODES = {
81
  "Mean across layers": "mean_across_layers",
82
  "First selected layer": "first_layer",
83
  "Per layer": "per_layer",
84
  }
85
- _CLUSTER_LINKAGES = ["ward", "complete", "average", "single"]
 
 
 
 
 
 
86
 
87
 
88
  def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
@@ -107,6 +126,98 @@ class CosineSelection:
107
  class PersonaOptions:
108
  regular_ids: list[str]
109
  assistant_id: str | None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
 
112
  def _layers_for_variant(
@@ -133,7 +244,7 @@ def _load_persona_vectors(
133
  persona_ids: list[str],
134
  ):
135
  source, location, model_name = store_cache_parts(store)
136
- return load_persona_vectors_lean(
137
  source,
138
  location,
139
  model_name,
@@ -150,7 +261,7 @@ def _load_variant_vectors(
150
  persona_ids: list[str],
151
  ):
152
  source, location, model_name = store_cache_parts(store)
153
- return load_variant_vectors_lean(
154
  source,
155
  location,
156
  model_name,
@@ -160,22 +271,55 @@ def _load_variant_vectors(
160
  )
161
 
162
 
163
- def _clear_old_figure_states(current_key: str) -> None:
164
  for key in list(st.session_state):
165
  if key == current_key or not isinstance(key, str):
166
  continue
167
  parts = key.split("::", 2)
168
- if len(parts) >= 2 and parts[0] == "load" and parts[1].endswith("_fig_state"):
169
  st.session_state.pop(key, None)
170
 
171
 
 
 
 
 
 
 
 
 
172
  def _store_figure_state(key: str, value: object) -> None:
173
  _clear_old_figure_states(key)
174
  st.session_state[key] = value
175
 
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  def _release_vector_memory(store: Store, variants: list[str] | tuple[str, ...]) -> None:
178
- release_store_cache(store, variants)
179
  gc.collect()
180
 
181
 
@@ -203,10 +347,22 @@ def _render_layer_frame_controls(
203
  "Layer frames",
204
  min_value=2,
205
  max_value=len(layers),
206
- value=_DEFAULT_LAYER_FRAMES,
 
 
 
 
 
 
 
 
 
 
 
207
  key=widget_key("load", "layer_frames", scope, store_id(store)),
208
  help="Limit animated Plotly frames to keep browser and RAM usage bounded.",
209
  )
 
210
  selected = _evenly_spaced_layers(layers, frame_count)
211
  st.caption(f"Using {len(selected)} of {len(layers)} layers.")
212
  return selected
@@ -259,7 +415,11 @@ def _load_persona_options(
259
  if not regular_ids and assistant_id is None:
260
  st.info("No personas found for this model and variant.")
261
  return None
262
- return PersonaOptions(regular_ids=regular_ids, assistant_id=assistant_id)
 
 
 
 
263
 
264
 
265
  def _seed_persona_memory(
@@ -366,6 +526,7 @@ def _select_artifact_personas(
366
  empty_message=empty_message,
367
  )
368
  if options is None:
 
369
  return []
370
 
371
  default_count, include_assistant_default = _seed_persona_memory(
@@ -393,6 +554,7 @@ def _select_artifact_personas(
393
  st.session_state[remembered_count_key] = persona_count
394
  st.session_state[remembered_assistant_key] = include_assistant
395
  st.session_state[remember_key] = persona_ids
 
396
 
397
  if not persona_ids:
398
  st.info("Select at least one persona or include the Assistant persona.")
@@ -415,7 +577,9 @@ def _render_save_buttons(
415
  if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
416
  try:
417
  _style_plotly_figures(figs)
418
- paths = [save_plot_html(fig, fn) for fig, fn in zip(figs, filenames)]
 
 
419
  st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
420
  except Exception as exc:
421
  st.error(f"Could not save HTML: {exc}")
@@ -430,7 +594,11 @@ def _style_plotly_figures(figs: list[object]) -> None:
430
 
431
  def _plotly_chart(fig: object) -> None:
432
  _style_plotly_figures([fig])
433
- st.plotly_chart(fig, width="stretch")
 
 
 
 
434
 
435
 
436
  def _render_mask_strategy_select(scope: str) -> MaskStrategy:
@@ -584,7 +752,7 @@ def _render_cosine_similarity(
584
  selection.persona_key,
585
  )
586
  filename = _filename(
587
- "compare",
588
  "cosine",
589
  store.model_name,
590
  mask_strategy.value,
@@ -592,7 +760,7 @@ def _render_cosine_similarity(
592
  selection.variant_b,
593
  )
594
  pairs_filename = _filename(
595
- "compare",
596
  "cosine_pairs",
597
  store.model_name,
598
  mask_strategy.value,
@@ -605,7 +773,7 @@ def _render_cosine_similarity(
605
  type="primary",
606
  key=widget_key(
607
  "load",
608
- "compare_vectors",
609
  store_id(store),
610
  store.model_name,
611
  mask_strategy.value,
@@ -650,19 +818,29 @@ def _select_single_variant_samples(
650
  scope: str,
651
  *,
652
  remember_key: str,
 
653
  default_count_limit: int,
654
  ) -> tuple[str, list[str], str, list[int]] | None:
655
  variants = available_variants(store, mask_strategy)
656
  if not variants:
657
  st.info("No variants with saved vectors for this model.")
658
  return None
 
 
 
 
 
 
 
 
659
  variant = st.selectbox(
660
  "Variant",
661
  options=variants,
662
- index=variants.index("biography") if "biography" in variants else 0,
663
  format_func=prompt_variant_label,
664
- key=widget_key("load", "variant", scope, store_id(store)),
665
  )
 
666
  persona_ids = _select_artifact_personas(
667
  store,
668
  [variant],
@@ -684,6 +862,352 @@ def _select_single_variant_samples(
684
  return variant, persona_ids, persona_key, selected_layers
685
 
686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687
  def _render_layered_figure_analysis(
688
  store: Store,
689
  mask_strategy: MaskStrategy,
@@ -707,124 +1231,60 @@ def _render_layered_figure_analysis(
707
  mask_strategy,
708
  scope,
709
  remember_key=remember_key,
 
 
 
 
 
710
  default_count_limit=default_count_limit,
711
  )
712
  if selected is None:
713
  return
714
  variant, persona_ids, persona_key, selected_layers = selected
715
 
716
- pair_trajectories = False
717
- if include_pair_trajectories:
718
- pair_count = len(persona_ids) * (len(persona_ids) - 1) // 2
719
- if pair_count > _MAX_PAIR_TRAJECTORY_TRACES:
720
- st.caption(
721
- "Pair trajectories hidden because this selection would create "
722
- f"{pair_count:,} Plotly traces."
723
- )
724
- else:
725
- pair_trajectories = st.checkbox(
726
- "Pair trajectories",
727
- value=False,
728
- key=widget_key("load", "pair_trajectories", scope, store_id(store)),
729
- help="Adds one line per persona pair. Keep this off for larger selections.",
730
- )
731
 
732
- if figure_kind == "similarity":
733
- similarity_cells = len(persona_ids) * len(persona_ids) * len(selected_layers)
734
- if similarity_cells > _MAX_SIMILARITY_CELLS:
735
- st.error(
736
- "Reduce personas or layer frames before generating the similarity "
737
- f"matrix ({similarity_cells:,} cells selected)."
738
- )
739
  return
740
 
741
- n_clusters = None
742
- cluster_mode = None
743
- cluster_method = None
744
- cluster_linkage = None
745
- min_cluster_size = None
746
- if figure_kind in {"pca", "umap"}:
747
- use_clusters = st.toggle(
748
- "Color by clusters",
749
- value=False,
750
- key=widget_key("load", "clusters_enabled", scope, store_id(store)),
751
- help="Cluster persona vectors and color points by cluster.",
752
- )
753
- if use_clusters:
754
- method_label = st.selectbox(
755
- "Cluster algorithm",
756
- options=list(_CLUSTER_METHODS),
757
- index=0,
758
- key=widget_key("load", "cluster_method", scope, store_id(store)),
759
- )
760
- cluster_method = _CLUSTER_METHODS[method_label]
761
- if cluster_method in {"kmeans", "agglomerative"}:
762
- n_clusters = st.slider(
763
- "K (clusters)",
764
- min_value=2,
765
- max_value=min(10, len(persona_ids)),
766
- value=min(3, len(persona_ids)),
767
- key=widget_key("load", "cluster_k", scope, store_id(store)),
768
- )
769
- if cluster_method == "agglomerative":
770
- cluster_linkage = st.selectbox(
771
- "Linkage",
772
- options=_CLUSTER_LINKAGES,
773
- index=0,
774
- key=widget_key("load", "cluster_linkage", scope, store_id(store)),
775
- )
776
- if cluster_method == "hdbscan":
777
- min_cluster_size = st.slider(
778
- "Minimum cluster size",
779
- min_value=2,
780
- max_value=len(persona_ids),
781
- value=min(5, len(persona_ids)),
782
- key=widget_key(
783
- "load",
784
- "cluster_min_cluster_size",
785
- scope,
786
- store_id(store),
787
- ),
788
- )
789
- mode_label = st.selectbox(
790
- "Cluster fit",
791
- options=list(_CLUSTER_MODES),
792
- index=0,
793
- key=widget_key("load", "cluster_mode", scope, store_id(store)),
794
- help=(
795
- "Mean across layers is the previous behavior. First selected "
796
- "layer keeps one fixed clustering from the first frame. Per layer "
797
- "recomputes clustering for each animation frame."
798
- ),
799
- )
800
- cluster_mode = _CLUSTER_MODES[mode_label]
801
-
802
- fig_key = widget_key(
803
- "load",
804
- f"{scope}_fig_state",
805
- store_id(store),
806
- store.model_name,
807
- mask_strategy.value,
808
- figure_kind,
809
- str(n_components),
810
- str(n_clusters),
811
- str(cluster_mode),
812
- str(cluster_method),
813
- str(cluster_linkage),
814
- str(min_cluster_size),
815
- variant,
816
- "persona_vector",
817
- persona_key,
818
- "_".join(map(str, selected_layers)),
819
- str(pair_trajectories),
820
  )
 
 
821
  filename = scope
822
- _clear_old_figure_states(fig_key)
 
 
 
 
823
 
824
  if st.button(button_label, type="primary"):
825
  build_label = {
826
  "umap": "Computing UMAP projections…",
827
  "pca": "Computing PCA projections…",
 
828
  "similarity": "Computing similarity matrices…",
829
  }.get(figure_kind, "Building figure…")
830
  progress = st.progress(0, text="Loading activation vectors…")
@@ -837,63 +1297,44 @@ def _render_layered_figure_analysis(
837
  persona_ids,
838
  )
839
  progress.progress(55, text=build_label)
840
- build_kwargs = {}
841
- if figure_kind in {"umap", "pca"}:
842
- build_kwargs["n_components"] = n_components
843
- if cluster_method is not None:
844
- build_kwargs["cluster_method"] = cluster_method
845
- build_kwargs["n_clusters"] = n_clusters
846
- build_kwargs["cluster_mode"] = cluster_mode
847
- if cluster_linkage is not None:
848
- build_kwargs["cluster_linkage"] = cluster_linkage
849
- if min_cluster_size is not None:
850
- build_kwargs["min_cluster_size"] = min_cluster_size
851
- if figure_kind == "similarity" and pair_trajectories:
852
- main_fig, extra_fig = build_similarity_figures(
853
- samples,
854
- layers=selected_layers,
855
- title=title_fn(variant),
856
- pair_title=(
857
- "Pair similarity trajectories - "
858
- f"{prompt_variant_label(variant)} - persona vectors"
859
- ),
860
- )
861
- else:
862
- main_fig = build_layered_figure(
863
- samples,
864
- figure_kind,
865
- layers=selected_layers,
866
- title=title_fn(variant),
867
- **build_kwargs,
868
- )
869
- if figure_kind in {"umap", "pca"}:
870
- main_fig.update_layout(height=700)
871
- extra_fig = (
872
- build_pair_similarity_figure(
873
- samples,
874
- layers=selected_layers,
875
- title=(
876
- "Pair similarity trajectories - "
877
- f"{prompt_variant_label(variant)} - persona vectors"
878
- ),
879
- )
880
- if pair_trajectories
881
- else None
882
- )
883
  progress.progress(90, text="Storing figure state…")
884
  n_samples = samples.vectors.shape[0]
885
  del samples
886
- _store_figure_state(fig_key, (main_fig, extra_fig, n_samples))
887
  progress.progress(100, text="Done.")
888
  except Exception as exc:
889
  st.error(f"Could not build figure: {exc}")
890
- st.session_state.pop(fig_key, None)
891
  finally:
892
  _release_vector_memory(store, [variant])
893
  progress.empty()
894
 
895
- if fig_key in st.session_state:
896
- main_fig, extra_fig, n_samples = st.session_state[fig_key]
897
  _plotly_chart(main_fig)
898
  figs = [main_fig]
899
  filenames = [filename]
@@ -906,7 +1347,7 @@ def _render_layered_figure_analysis(
906
  st.success(f"Loaded {n_samples} samples.")
907
 
908
 
909
- _LAST_DENDRO_PERSONAS_KEY = "compare:last_personas:dendro"
910
  _DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"]
911
 
912
 
@@ -1108,7 +1549,7 @@ def _render_hub_model_select(
1108
  mask_strategy: MaskStrategy,
1109
  ) -> str:
1110
  fallback_model = st.session_state.get(
1111
- "compare:hub_model_fallback",
1112
  DEFAULT_COMPARE_MODEL,
1113
  )
1114
  try:
@@ -1118,7 +1559,7 @@ def _render_hub_model_select(
1118
  return st.text_input(
1119
  "Hub model",
1120
  value=fallback_model,
1121
- key="compare:hub_model_fallback",
1122
  help="Compare-only model id to use if Hub config discovery is unavailable.",
1123
  )
1124
 
@@ -1130,7 +1571,7 @@ def _render_hub_model_select(
1130
  return st.text_input(
1131
  "Hub model",
1132
  value=fallback_model,
1133
- key="compare:hub_model_fallback",
1134
  help="Compare-only model id to use for this Hub repo.",
1135
  )
1136
 
@@ -1155,31 +1596,31 @@ def _render_local_model_select(
1155
  artifacts_root: str,
1156
  mask_strategy: MaskStrategy,
1157
  ) -> str:
1158
- fallback_model = st.session_state.get("compare:local_model", DEFAULT_COMPARE_MODEL)
1159
  model_options = local_model_options_cached(artifacts_root, mask_strategy.value)
1160
  if not model_options:
1161
  return st.text_input(
1162
  "Local model",
1163
  value=fallback_model,
1164
- key="compare:local_model",
1165
  help="Compare-only local model id or path.",
1166
  )
1167
 
1168
  custom = st.toggle(
1169
  "Custom local model",
1170
  value=False,
1171
- key="compare:local_model_custom_enabled",
1172
  help="Enter a model id/path manually instead of choosing from activation directories.",
1173
  )
1174
  if custom:
1175
  return st.text_input(
1176
  "Local model",
1177
  value=fallback_model,
1178
- key="compare:local_model",
1179
  help="Compare-only local model id or path.",
1180
  )
1181
 
1182
- previous_model = st.session_state.get("compare:local_model_select", fallback_model)
1183
  if not any(local_model_matches(previous_model, option) for option in model_options):
1184
  previous_model = fallback_model
1185
  default_model = next(
@@ -1194,10 +1635,10 @@ def _render_local_model_select(
1194
  "Local model",
1195
  options=model_options,
1196
  index=model_options.index(default_model),
1197
- key="compare:local_model_select",
1198
  help="Models discovered under the selected artifacts root.",
1199
  )
1200
- st.session_state["compare:local_model"] = selected
1201
  return selected
1202
 
1203
 
@@ -1205,8 +1646,8 @@ def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
1205
  if source == SOURCE_HUB:
1206
  repo = st.text_input(
1207
  "Hub repo",
1208
- value=st.session_state.get("compare:hub_repo", DEFAULT_HUB_REPO),
1209
- key="compare:hub_repo",
1210
  help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
1211
  )
1212
  hub_model_name = _render_hub_model_select(repo, mask_strategy)
@@ -1219,7 +1660,7 @@ def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
1219
  artifacts_root = st.text_input(
1220
  "Artifacts root",
1221
  value=str(get_artifacts_dir() / "activations"),
1222
- key="compare:artifacts_root",
1223
  )
1224
  artifacts_root = str(Path(artifacts_root).expanduser())
1225
  local_model_name = _render_local_model_select(artifacts_root, mask_strategy)
@@ -1231,12 +1672,12 @@ def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
1231
  )
1232
 
1233
 
1234
- def render_compare_tab() -> None:
1235
  """Render the analysis tab."""
1236
 
1237
  st.title("Analysis")
1238
  st.caption(
1239
- "Analyse persona vectors by cosine similarity, PCA, UMAP, or hierarchical clustering."
1240
  )
1241
 
1242
  source = _render_source_select()
@@ -1279,13 +1720,23 @@ def render_compare_tab() -> None:
1279
  _render_dendrogram_analysis(store, mask_strategy)
1280
  return
1281
 
 
 
 
 
 
 
 
 
1282
  dimension_choice = st.segmented_control(
1283
  "Projection dimensions",
1284
- options=["2D", "3D"],
1285
- default="2D",
1286
- key=widget_key("load", "projection_dims", analysis_mode),
1287
  label_visibility="collapsed",
1288
  )
 
 
1289
  n_components = 3 if dimension_choice == "3D" else 2
1290
  dim_suffix = "" if n_components == 2 else " (3D)"
1291
  _render_layered_figure_analysis(
 
7
  import plotly.graph_objects as go
8
  import streamlit as st
9
  from persona_data.environment import get_artifacts_dir
10
+ from persona_data.synth_persona import BASELINE_PERSONA_ID, SynthPersonaDataset
11
+ from persona_vectors.attributes import (
12
+ DEFAULT_MAX_ATTRIBUTE_CATEGORIES,
13
+ attribute_color_kwargs,
14
+ attribute_display_label,
15
+ )
16
  from persona_vectors.extraction import MaskStrategy
17
  from persona_vectors.plots import (
18
  build_layered_figure,
 
20
  build_similarity_figures,
21
  plot_layer_similarity,
22
  plot_persona_dendrogram,
23
+ prepare_layered_projection_data,
24
  save_plot_html,
25
  )
26
 
27
+ from utils.analysis_sources import (
28
  DEFAULT_COMPARE_MODEL,
29
  DEFAULT_HUB_REPO,
30
  SOURCE_HUB,
 
34
  activation_store_cached,
35
  available_variants,
36
  hub_models_by_mask_strategy,
37
+ load_persona_vectors_cached,
38
+ load_variant_vectors_cached,
39
  local_model_matches,
40
  local_model_options_cached,
41
  persona_names_cached,
42
  personas_cached,
43
+ release_hf_store_cache,
44
  store_cache_parts,
45
  store_id,
46
  store_layers_cached,
 
63
 
64
  # Keep compare-tab selection state separate so projection defaults do not
65
  # overwrite cosine similarity defaults.
66
+ _LAST_COSINE_PERSONAS_KEY = "analysis:last_personas:cosine"
67
+ _LAST_PROJECTION_PERSONAS_KEY = "analysis:last_personas:projection"
68
+ _LAST_SIMILARITY_PERSONAS_KEY = "analysis:last_personas:similarity"
69
+ _LAST_MASK_STRATEGY_KEY = "analysis:last_mask_strategy"
70
+ _LAST_SOURCE_KEY = "analysis:last_source"
71
+ _LAST_PROJECTION_VARIANT_KEY = "analysis:last_projection_variant"
72
+ _LAST_SIMILARITY_VARIANT_KEY = "analysis:last_similarity_variant"
73
+ _LAST_PROJECTION_COLOR_MODE_KEY = "analysis:last_projection_color_mode"
74
+ _LAST_PROJECTION_ATTRIBUTE_KEY = "analysis:last_projection_attribute"
75
+ _LAST_PROJECTION_CLUSTER_K_KEY = "analysis:last_projection_cluster_k"
76
+ _LAST_PROJECTION_CLUSTER_MODE_KEY = "analysis:last_projection_cluster_mode"
77
+ _LAST_PROJECTION_HIGHLIGHTS_KEY = "analysis:last_projection_highlights"
78
+ _LAST_PROJECTION_DIMS_KEY = "analysis:last_projection_dims"
79
+ _LAST_LAYER_FRAMES_KEY = "analysis:last_layer_frames"
80
 
81
  _DEFAULT_LAYER_FRAMES = 16
82
  _DEFAULT_PERSONA_LIMITS = {
83
  "similarity": 120,
84
  "pca": 500,
85
  "umap": 500,
86
+ "isomap": 500,
87
  "dendro": 160,
88
  }
89
  _MAX_SIMILARITY_CELLS = 4_000_000
90
  _MAX_PAIR_TRAJECTORY_TRACES = 500
91
+ _DEFAULT_GRAPH_NEIGHBORS = 5
92
+ _PROJECTION_KINDS = {"pca", "umap", "isomap"}
 
 
 
93
  _CLUSTER_MODES = {
94
  "Mean across layers": "mean_across_layers",
95
  "First selected layer": "first_layer",
96
  "Per layer": "per_layer",
97
  }
98
+ _PROJECTION_COLOR_MODES = ["Persona", "K-means clusters", "Persona attribute"]
99
+ _MAX_ATTRIBUTE_CATEGORIES = DEFAULT_MAX_ATTRIBUTE_CATEGORIES
100
+
101
+
102
+ @st.cache_resource(show_spinner=False)
103
+ def _synth_persona_dataset() -> SynthPersonaDataset:
104
+ return SynthPersonaDataset()
105
 
106
 
107
  def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
 
126
  class PersonaOptions:
127
  regular_ids: list[str]
128
  assistant_id: str | None
129
+ persona_names: dict[str, str]
130
+
131
+
132
+ @dataclass(frozen=True)
133
+ class ProjectionColorConfig:
134
+ color_mode: str = "Persona"
135
+ n_clusters: int | None = None
136
+ cluster_mode: str | None = None
137
+ attribute_name: str | None = None
138
+ highlight_persona_ids: tuple[str, ...] = ()
139
+ highlight_persona_key: str = ""
140
+
141
+
142
+ @dataclass(frozen=True)
143
+ class LayeredFigureStateKeys:
144
+ figure: str
145
+ projection: str | None = None
146
+
147
+
148
+ _HIGHLIGHT_OTHER_LABEL = "Other"
149
+ _HIGHLIGHT_OTHER_COLOR = "rgba(148, 163, 184, 0.35)"
150
+
151
+
152
+ def _persona_names_state_key(widget_scope: str) -> str:
153
+ return widget_key("load", "persona_names", widget_scope)
154
+
155
+
156
+ def _persona_display_label(persona_names: dict[str, str], persona_id: str) -> str:
157
+ name = persona_names.get(persona_id, persona_id)
158
+ return f"{name} ({persona_id})" if name != persona_id else persona_id
159
+
160
+
161
+ def _highlight_persona_groups(
162
+ persona_ids: list[str],
163
+ persona_names: dict[str, str],
164
+ highlight_persona_ids: tuple[str, ...],
165
+ ) -> list[str] | None:
166
+ if not highlight_persona_ids:
167
+ return None
168
+
169
+ highlighted = set(highlight_persona_ids)
170
+ return [
171
+ (
172
+ _persona_display_label(persona_names, persona_id)
173
+ if persona_id in highlighted
174
+ else _HIGHLIGHT_OTHER_LABEL
175
+ )
176
+ for persona_id in persona_ids
177
+ ]
178
+
179
+
180
+ def _sequence_to_list(value: object) -> list[object] | None:
181
+ if value is None or isinstance(value, (str, bytes)):
182
+ return None
183
+ if isinstance(value, list):
184
+ return value
185
+ if isinstance(value, tuple):
186
+ return list(value)
187
+ try:
188
+ return list(value)
189
+ except TypeError:
190
+ return None
191
+
192
+
193
+ def _gray_out_unselected_personas(fig: go.Figure) -> None:
194
+ def _gray_trace(trace: object) -> None:
195
+ marker = getattr(trace, "marker", None)
196
+ if marker is None:
197
+ return
198
+
199
+ colors = _sequence_to_list(getattr(marker, "color", None))
200
+ labels = _sequence_to_list(getattr(trace, "customdata", None))
201
+ if colors is not None and labels is not None and len(colors) == len(labels):
202
+ trace.marker.color = [
203
+ (
204
+ _HIGHLIGHT_OTHER_COLOR
205
+ if str(label) == _HIGHLIGHT_OTHER_LABEL
206
+ else color
207
+ )
208
+ for label, color in zip(labels, colors, strict=True)
209
+ ]
210
+ return
211
+
212
+ if getattr(trace, "name", None) == _HIGHLIGHT_OTHER_LABEL:
213
+ trace.marker.color = _HIGHLIGHT_OTHER_COLOR
214
+ trace.opacity = 0.28
215
+
216
+ for trace in fig.data:
217
+ _gray_trace(trace)
218
+ for frame in fig.frames:
219
+ for trace in frame.data:
220
+ _gray_trace(trace)
221
 
222
 
223
  def _layers_for_variant(
 
244
  persona_ids: list[str],
245
  ):
246
  source, location, model_name = store_cache_parts(store)
247
+ return load_persona_vectors_cached(
248
  source,
249
  location,
250
  model_name,
 
261
  persona_ids: list[str],
262
  ):
263
  source, location, model_name = store_cache_parts(store)
264
+ return load_variant_vectors_cached(
265
  source,
266
  location,
267
  model_name,
 
271
  )
272
 
273
 
274
+ def _clear_old_load_states(current_key: str, suffix: str) -> None:
275
  for key in list(st.session_state):
276
  if key == current_key or not isinstance(key, str):
277
  continue
278
  parts = key.split("::", 2)
279
+ if len(parts) >= 2 and parts[0] == "load" and parts[1].endswith(suffix):
280
  st.session_state.pop(key, None)
281
 
282
 
283
+ def _clear_old_figure_states(current_key: str) -> None:
284
+ _clear_old_load_states(current_key, "_fig_state")
285
+
286
+
287
+ def _clear_old_projection_states(current_key: str) -> None:
288
+ _clear_old_load_states(current_key, "_projection_state")
289
+
290
+
291
  def _store_figure_state(key: str, value: object) -> None:
292
  _clear_old_figure_states(key)
293
  st.session_state[key] = value
294
 
295
 
296
+ def _seed_selectbox_key(
297
+ *,
298
+ key: str,
299
+ remember_key: str,
300
+ options: list[str],
301
+ default: str,
302
+ ) -> str:
303
+ value = st.session_state.get(key, st.session_state.get(remember_key, default))
304
+ if value not in options:
305
+ value = default
306
+ return value
307
+
308
+
309
+ def _remember_multiselect(
310
+ *,
311
+ key: str,
312
+ remember_key: str,
313
+ options: list[str],
314
+ ) -> list[str]:
315
+ remembered = st.session_state.get(key, st.session_state.get(remember_key, []))
316
+ if not isinstance(remembered, list):
317
+ remembered = []
318
+ return [value for value in remembered if value in options]
319
+
320
+
321
  def _release_vector_memory(store: Store, variants: list[str] | tuple[str, ...]) -> None:
322
+ release_hf_store_cache(store, variants)
323
  gc.collect()
324
 
325
 
 
347
  "Layer frames",
348
  min_value=2,
349
  max_value=len(layers),
350
+ value=min(
351
+ max(
352
+ int(
353
+ st.session_state.get(
354
+ _LAST_LAYER_FRAMES_KEY,
355
+ _DEFAULT_LAYER_FRAMES,
356
+ )
357
+ ),
358
+ 2,
359
+ ),
360
+ len(layers),
361
+ ),
362
  key=widget_key("load", "layer_frames", scope, store_id(store)),
363
  help="Limit animated Plotly frames to keep browser and RAM usage bounded.",
364
  )
365
+ st.session_state[_LAST_LAYER_FRAMES_KEY] = frame_count
366
  selected = _evenly_spaced_layers(layers, frame_count)
367
  st.caption(f"Using {len(selected)} of {len(layers)} layers.")
368
  return selected
 
415
  if not regular_ids and assistant_id is None:
416
  st.info("No personas found for this model and variant.")
417
  return None
418
+ return PersonaOptions(
419
+ regular_ids=regular_ids,
420
+ assistant_id=assistant_id,
421
+ persona_names=persona_names,
422
+ )
423
 
424
 
425
  def _seed_persona_memory(
 
526
  empty_message=empty_message,
527
  )
528
  if options is None:
529
+ st.session_state.pop(_persona_names_state_key(widget_scope), None)
530
  return []
531
 
532
  default_count, include_assistant_default = _seed_persona_memory(
 
554
  st.session_state[remembered_count_key] = persona_count
555
  st.session_state[remembered_assistant_key] = include_assistant
556
  st.session_state[remember_key] = persona_ids
557
+ st.session_state[_persona_names_state_key(widget_scope)] = options.persona_names
558
 
559
  if not persona_ids:
560
  st.info("Select at least one persona or include the Assistant persona.")
 
577
  if st.button("Save HTML", key=widget_key("load", "save_html", key_suffix)):
578
  try:
579
  _style_plotly_figures(figs)
580
+ paths = [
581
+ save_plot_html(fig, fn) for fig, fn in zip(figs, filenames, strict=True)
582
+ ]
583
  st.success(f"Saved {len(paths)} HTML file(s) to `artifacts/plots`.")
584
  except Exception as exc:
585
  st.error(f"Could not save HTML: {exc}")
 
594
 
595
  def _plotly_chart(fig: object) -> None:
596
  _style_plotly_figures([fig])
597
+ st.plotly_chart(
598
+ fig,
599
+ width="stretch",
600
+ config={"responsive": True, "displaylogo": False},
601
+ )
602
 
603
 
604
  def _render_mask_strategy_select(scope: str) -> MaskStrategy:
 
752
  selection.persona_key,
753
  )
754
  filename = _filename(
755
+ "analysis",
756
  "cosine",
757
  store.model_name,
758
  mask_strategy.value,
 
760
  selection.variant_b,
761
  )
762
  pairs_filename = _filename(
763
+ "analysis",
764
  "cosine_pairs",
765
  store.model_name,
766
  mask_strategy.value,
 
773
  type="primary",
774
  key=widget_key(
775
  "load",
776
+ "analysis_vectors",
777
  store_id(store),
778
  store.model_name,
779
  mask_strategy.value,
 
818
  scope: str,
819
  *,
820
  remember_key: str,
821
+ variant_remember_key: str,
822
  default_count_limit: int,
823
  ) -> tuple[str, list[str], str, list[int]] | None:
824
  variants = available_variants(store, mask_strategy)
825
  if not variants:
826
  st.info("No variants with saved vectors for this model.")
827
  return None
828
+ variant_key = widget_key("load", "variant", scope, store_id(store))
829
+ default_variant = "biography" if "biography" in variants else variants[0]
830
+ selected_variant = _seed_selectbox_key(
831
+ key=variant_key,
832
+ remember_key=variant_remember_key,
833
+ options=variants,
834
+ default=default_variant,
835
+ )
836
  variant = st.selectbox(
837
  "Variant",
838
  options=variants,
839
+ index=variants.index(selected_variant),
840
  format_func=prompt_variant_label,
841
+ key=variant_key,
842
  )
843
+ st.session_state[variant_remember_key] = variant
844
  persona_ids = _select_artifact_personas(
845
  store,
846
  [variant],
 
862
  return variant, persona_ids, persona_key, selected_layers
863
 
864
 
865
+ def _render_pair_trajectory_control(
866
+ *,
867
+ enabled: bool,
868
+ persona_count: int,
869
+ scope: str,
870
+ store: Store,
871
+ ) -> bool:
872
+ if not enabled:
873
+ return False
874
+ pair_count = persona_count * (persona_count - 1) // 2
875
+ if pair_count > _MAX_PAIR_TRAJECTORY_TRACES:
876
+ st.caption(
877
+ "Pair trajectories hidden because this selection would create "
878
+ f"{pair_count:,} Plotly traces."
879
+ )
880
+ return False
881
+ return st.checkbox(
882
+ "Pair trajectories",
883
+ value=False,
884
+ key=widget_key("load", "pair_trajectories", scope, store_id(store)),
885
+ help="Adds one line per persona pair. Keep this off for larger selections.",
886
+ )
887
+
888
+
889
+ def _validate_layered_figure_size(
890
+ figure_kind: str,
891
+ persona_count: int,
892
+ selected_layers: list[int],
893
+ ) -> bool:
894
+ if figure_kind != "similarity":
895
+ return True
896
+ similarity_cells = persona_count * persona_count * len(selected_layers)
897
+ if similarity_cells <= _MAX_SIMILARITY_CELLS:
898
+ return True
899
+ st.error(
900
+ "Reduce personas or layer frames before generating the similarity "
901
+ f"matrix ({similarity_cells:,} cells selected)."
902
+ )
903
+ return False
904
+
905
+
906
+ def _render_projection_color_config(
907
+ store: Store,
908
+ scope: str,
909
+ persona_ids: list[str],
910
+ ) -> ProjectionColorConfig | None:
911
+ widget_scope = f"{scope}:{store_id(store)}"
912
+ persona_key = personas_fingerprint(persona_ids)
913
+ persona_names = st.session_state.get(
914
+ _persona_names_state_key(widget_scope),
915
+ {},
916
+ )
917
+ color_mode_key = widget_key("load", "color_mode", scope, store_id(store))
918
+ selected_color_mode = _seed_selectbox_key(
919
+ key=color_mode_key,
920
+ remember_key=_LAST_PROJECTION_COLOR_MODE_KEY,
921
+ options=_PROJECTION_COLOR_MODES,
922
+ default="Persona",
923
+ )
924
+ color_mode = st.selectbox(
925
+ "Color by",
926
+ options=_PROJECTION_COLOR_MODES,
927
+ index=_PROJECTION_COLOR_MODES.index(selected_color_mode),
928
+ key=color_mode_key,
929
+ )
930
+ st.session_state[_LAST_PROJECTION_COLOR_MODE_KEY] = color_mode
931
+ if color_mode == "K-means clusters":
932
+ max_clusters = min(10, len(persona_ids))
933
+ if max_clusters < 2:
934
+ st.info("Select at least two personas to use K-means coloring.")
935
+ return None
936
+ cluster_key = widget_key("load", "cluster_k", scope, store_id(store))
937
+ default_clusters = min(3, len(persona_ids))
938
+ if cluster_key not in st.session_state:
939
+ st.session_state[cluster_key] = min(
940
+ max(
941
+ int(
942
+ st.session_state.get(
943
+ _LAST_PROJECTION_CLUSTER_K_KEY,
944
+ default_clusters,
945
+ )
946
+ ),
947
+ 2,
948
+ ),
949
+ max_clusters,
950
+ )
951
+ n_clusters = st.slider(
952
+ "K (clusters)",
953
+ min_value=2,
954
+ max_value=max_clusters,
955
+ key=cluster_key,
956
+ )
957
+ mode_key = widget_key("load", "cluster_mode", scope, store_id(store))
958
+ mode_options = list(_CLUSTER_MODES)
959
+ selected_mode = _seed_selectbox_key(
960
+ key=mode_key,
961
+ remember_key=_LAST_PROJECTION_CLUSTER_MODE_KEY,
962
+ options=mode_options,
963
+ default=mode_options[0],
964
+ )
965
+ mode_label = st.selectbox(
966
+ "Cluster fit",
967
+ options=mode_options,
968
+ index=mode_options.index(selected_mode),
969
+ key=mode_key,
970
+ help=(
971
+ "Mean across layers is the previous behavior. First selected "
972
+ "layer keeps one fixed clustering from the first frame. Per layer "
973
+ "recomputes clustering for each animation frame."
974
+ ),
975
+ )
976
+ st.session_state[_LAST_PROJECTION_CLUSTER_K_KEY] = n_clusters
977
+ st.session_state[_LAST_PROJECTION_CLUSTER_MODE_KEY] = mode_label
978
+ return ProjectionColorConfig(
979
+ color_mode=color_mode,
980
+ n_clusters=n_clusters,
981
+ cluster_mode=_CLUSTER_MODES[mode_label],
982
+ )
983
+
984
+ if color_mode == "Persona attribute":
985
+ persona_dataset = _synth_persona_dataset()
986
+ attribute_options = list(persona_dataset.attribute_names)
987
+ if not attribute_options:
988
+ st.info("No persona attributes are available for this dataset.")
989
+ return None
990
+ default_attribute = (
991
+ attribute_options.index("sex") if "sex" in attribute_options else 0
992
+ )
993
+ attribute_key = widget_key("load", "attribute", scope, store_id(store))
994
+ selected_attribute = _seed_selectbox_key(
995
+ key=attribute_key,
996
+ remember_key=_LAST_PROJECTION_ATTRIBUTE_KEY,
997
+ options=attribute_options,
998
+ default=attribute_options[default_attribute],
999
+ )
1000
+ attribute_name = st.selectbox(
1001
+ "Attribute",
1002
+ options=attribute_options,
1003
+ index=attribute_options.index(selected_attribute),
1004
+ format_func=lambda name: attribute_display_label(persona_dataset, name),
1005
+ key=attribute_key,
1006
+ )
1007
+ st.session_state[_LAST_PROJECTION_ATTRIBUTE_KEY] = attribute_name
1008
+ info = persona_dataset.attribute_info(attribute_name)
1009
+ if info.get("high_cardinality"):
1010
+ st.caption(
1011
+ "High-cardinality categorical attributes are grouped to the "
1012
+ f"top {_MAX_ATTRIBUTE_CATEGORIES} values plus Other."
1013
+ )
1014
+ return ProjectionColorConfig(
1015
+ color_mode=color_mode,
1016
+ attribute_name=attribute_name,
1017
+ )
1018
+
1019
+ highlight_persona_ids: tuple[str, ...] = ()
1020
+ if persona_ids:
1021
+ highlight_key = widget_key(
1022
+ "load", "persona_highlight", scope, store_id(store), persona_key
1023
+ )
1024
+ highlighted = st.multiselect(
1025
+ "Highlight personas",
1026
+ options=persona_ids,
1027
+ default=_remember_multiselect(
1028
+ key=highlight_key,
1029
+ remember_key=_LAST_PROJECTION_HIGHLIGHTS_KEY,
1030
+ options=persona_ids,
1031
+ ),
1032
+ format_func=lambda persona_id: _persona_display_label(
1033
+ persona_names, persona_id
1034
+ ),
1035
+ key=highlight_key,
1036
+ help=(
1037
+ "Select a few personas to keep their default colors while the rest "
1038
+ "are grayed out."
1039
+ ),
1040
+ )
1041
+ highlight_persona_ids = tuple(highlighted)
1042
+ st.session_state[_LAST_PROJECTION_HIGHLIGHTS_KEY] = list(highlighted)
1043
+
1044
+ highlight_persona_key = (
1045
+ personas_fingerprint(highlight_persona_ids) if highlight_persona_ids else ""
1046
+ )
1047
+
1048
+ return ProjectionColorConfig(
1049
+ color_mode=color_mode,
1050
+ highlight_persona_ids=highlight_persona_ids,
1051
+ highlight_persona_key=highlight_persona_key,
1052
+ )
1053
+
1054
+
1055
+ def _layered_figure_state_keys(
1056
+ store: Store,
1057
+ mask_strategy: MaskStrategy,
1058
+ *,
1059
+ scope: str,
1060
+ figure_kind: str,
1061
+ n_components: int,
1062
+ color_config: ProjectionColorConfig,
1063
+ variant: str,
1064
+ persona_key: str,
1065
+ selected_layers: list[int],
1066
+ pair_trajectories: bool,
1067
+ ) -> LayeredFigureStateKeys:
1068
+ layer_key = "_".join(map(str, selected_layers))
1069
+ figure_key = widget_key(
1070
+ "load",
1071
+ f"{scope}_fig_state",
1072
+ store_id(store),
1073
+ store.model_name,
1074
+ mask_strategy.value,
1075
+ figure_kind,
1076
+ str(n_components),
1077
+ color_config.color_mode,
1078
+ str(color_config.attribute_name),
1079
+ str(color_config.n_clusters),
1080
+ str(color_config.cluster_mode),
1081
+ str(color_config.highlight_persona_key),
1082
+ variant,
1083
+ "persona_vector",
1084
+ persona_key,
1085
+ layer_key,
1086
+ str(pair_trajectories),
1087
+ )
1088
+ if figure_kind not in _PROJECTION_KINDS:
1089
+ return LayeredFigureStateKeys(figure=figure_key)
1090
+
1091
+ graph_overlay = figure_kind == "isomap"
1092
+ projection_key = widget_key(
1093
+ "load",
1094
+ f"{scope}_projection_state",
1095
+ store_id(store),
1096
+ store.model_name,
1097
+ mask_strategy.value,
1098
+ figure_kind,
1099
+ str(n_components),
1100
+ str(graph_overlay),
1101
+ str(_DEFAULT_GRAPH_NEIGHBORS),
1102
+ variant,
1103
+ "persona_vector",
1104
+ persona_key,
1105
+ layer_key,
1106
+ )
1107
+ return LayeredFigureStateKeys(figure=figure_key, projection=projection_key)
1108
+
1109
+
1110
+ def _projection_build_kwargs(
1111
+ samples,
1112
+ *,
1113
+ figure_kind: str,
1114
+ selected_layers: list[int],
1115
+ n_components: int,
1116
+ color_config: ProjectionColorConfig,
1117
+ persona_ids: list[str],
1118
+ persona_names: dict[str, str],
1119
+ projection_key: str | None,
1120
+ ) -> dict:
1121
+ if figure_kind not in _PROJECTION_KINDS:
1122
+ return {}
1123
+
1124
+ graph_overlay = figure_kind == "isomap"
1125
+ build_kwargs = {
1126
+ "n_components": n_components,
1127
+ "graph_overlay": graph_overlay,
1128
+ "graph_n_neighbors": _DEFAULT_GRAPH_NEIGHBORS,
1129
+ }
1130
+ if color_config.n_clusters is not None:
1131
+ build_kwargs["n_clusters"] = color_config.n_clusters
1132
+ build_kwargs["cluster_mode"] = color_config.cluster_mode
1133
+ if projection_key is not None:
1134
+ projection_data = st.session_state.get(projection_key)
1135
+ if projection_data is None:
1136
+ projection_data = prepare_layered_projection_data(
1137
+ samples,
1138
+ figure_kind,
1139
+ layers=selected_layers,
1140
+ n_components=n_components,
1141
+ graph_overlay=graph_overlay,
1142
+ graph_n_neighbors=_DEFAULT_GRAPH_NEIGHBORS,
1143
+ )
1144
+ st.session_state[projection_key] = projection_data
1145
+ build_kwargs["projection_data"] = projection_data
1146
+ if color_config.attribute_name is not None:
1147
+ build_kwargs.update(
1148
+ attribute_color_kwargs(
1149
+ _synth_persona_dataset(),
1150
+ color_config.attribute_name,
1151
+ persona_ids,
1152
+ max_categories=_MAX_ATTRIBUTE_CATEGORIES,
1153
+ )
1154
+ )
1155
+ if color_config.color_mode == "Persona" and color_config.highlight_persona_ids:
1156
+ groups = _highlight_persona_groups(
1157
+ persona_ids,
1158
+ persona_names,
1159
+ color_config.highlight_persona_ids,
1160
+ )
1161
+ if groups is not None:
1162
+ build_kwargs["groups"] = groups
1163
+ return build_kwargs
1164
+
1165
+
1166
+ def _build_layered_analysis_figures(
1167
+ samples,
1168
+ *,
1169
+ figure_kind: str,
1170
+ selected_layers: list[int],
1171
+ variant: str,
1172
+ title_fn: Callable[[str], str],
1173
+ pair_trajectories: bool,
1174
+ build_kwargs: dict,
1175
+ ) -> tuple[go.Figure, go.Figure | None]:
1176
+ if figure_kind == "similarity" and pair_trajectories:
1177
+ return build_similarity_figures(
1178
+ samples,
1179
+ layers=selected_layers,
1180
+ title=title_fn(variant),
1181
+ pair_title=(
1182
+ "Pair similarity trajectories - "
1183
+ f"{prompt_variant_label(variant)} - persona vectors"
1184
+ ),
1185
+ )
1186
+
1187
+ main_fig = build_layered_figure(
1188
+ samples,
1189
+ figure_kind,
1190
+ layers=selected_layers,
1191
+ title=title_fn(variant),
1192
+ **build_kwargs,
1193
+ )
1194
+ if figure_kind in _PROJECTION_KINDS:
1195
+ main_fig.update_layout(height=700)
1196
+ extra_fig = (
1197
+ build_pair_similarity_figure(
1198
+ samples,
1199
+ layers=selected_layers,
1200
+ title=(
1201
+ "Pair similarity trajectories - "
1202
+ f"{prompt_variant_label(variant)} - persona vectors"
1203
+ ),
1204
+ )
1205
+ if pair_trajectories
1206
+ else None
1207
+ )
1208
+ return main_fig, extra_fig
1209
+
1210
+
1211
  def _render_layered_figure_analysis(
1212
  store: Store,
1213
  mask_strategy: MaskStrategy,
 
1231
  mask_strategy,
1232
  scope,
1233
  remember_key=remember_key,
1234
+ variant_remember_key=(
1235
+ _LAST_PROJECTION_VARIANT_KEY
1236
+ if figure_kind in _PROJECTION_KINDS
1237
+ else _LAST_SIMILARITY_VARIANT_KEY
1238
+ ),
1239
  default_count_limit=default_count_limit,
1240
  )
1241
  if selected is None:
1242
  return
1243
  variant, persona_ids, persona_key, selected_layers = selected
1244
 
1245
+ pair_trajectories = _render_pair_trajectory_control(
1246
+ enabled=include_pair_trajectories,
1247
+ persona_count=len(persona_ids),
1248
+ scope=scope,
1249
+ store=store,
1250
+ )
1251
+ if not _validate_layered_figure_size(
1252
+ figure_kind, len(persona_ids), selected_layers
1253
+ ):
1254
+ return
 
 
 
 
 
1255
 
1256
+ color_config = ProjectionColorConfig()
1257
+ if figure_kind in _PROJECTION_KINDS:
1258
+ color_config = _render_projection_color_config(store, scope, persona_ids)
1259
+ if color_config is None:
 
 
 
1260
  return
1261
 
1262
+ state_keys = _layered_figure_state_keys(
1263
+ store,
1264
+ mask_strategy,
1265
+ scope=scope,
1266
+ figure_kind=figure_kind,
1267
+ n_components=n_components,
1268
+ color_config=color_config,
1269
+ variant=variant,
1270
+ persona_key=persona_key,
1271
+ selected_layers=selected_layers,
1272
+ pair_trajectories=pair_trajectories,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1273
  )
1274
+ if state_keys.projection is not None:
1275
+ _clear_old_projection_states(state_keys.projection)
1276
  filename = scope
1277
+ _clear_old_figure_states(state_keys.figure)
1278
+ persona_names = st.session_state.get(
1279
+ _persona_names_state_key(f"{scope}:{store_id(store)}"),
1280
+ {},
1281
+ )
1282
 
1283
  if st.button(button_label, type="primary"):
1284
  build_label = {
1285
  "umap": "Computing UMAP projections…",
1286
  "pca": "Computing PCA projections…",
1287
+ "isomap": "Computing Isomap projections…",
1288
  "similarity": "Computing similarity matrices…",
1289
  }.get(figure_kind, "Building figure…")
1290
  progress = st.progress(0, text="Loading activation vectors…")
 
1297
  persona_ids,
1298
  )
1299
  progress.progress(55, text=build_label)
1300
+ build_kwargs = _projection_build_kwargs(
1301
+ samples,
1302
+ figure_kind=figure_kind,
1303
+ selected_layers=selected_layers,
1304
+ n_components=n_components,
1305
+ color_config=color_config,
1306
+ persona_ids=persona_ids,
1307
+ persona_names=persona_names,
1308
+ projection_key=state_keys.projection,
1309
+ )
1310
+ main_fig, extra_fig = _build_layered_analysis_figures(
1311
+ samples,
1312
+ figure_kind=figure_kind,
1313
+ selected_layers=selected_layers,
1314
+ variant=variant,
1315
+ title_fn=title_fn,
1316
+ pair_trajectories=pair_trajectories,
1317
+ build_kwargs=build_kwargs,
1318
+ )
1319
+ if (
1320
+ color_config.color_mode == "Persona"
1321
+ and color_config.highlight_persona_ids
1322
+ ):
1323
+ _gray_out_unselected_personas(main_fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1324
  progress.progress(90, text="Storing figure state…")
1325
  n_samples = samples.vectors.shape[0]
1326
  del samples
1327
+ _store_figure_state(state_keys.figure, (main_fig, extra_fig, n_samples))
1328
  progress.progress(100, text="Done.")
1329
  except Exception as exc:
1330
  st.error(f"Could not build figure: {exc}")
1331
+ st.session_state.pop(state_keys.figure, None)
1332
  finally:
1333
  _release_vector_memory(store, [variant])
1334
  progress.empty()
1335
 
1336
+ if state_keys.figure in st.session_state:
1337
+ main_fig, extra_fig, n_samples = st.session_state[state_keys.figure]
1338
  _plotly_chart(main_fig)
1339
  figs = [main_fig]
1340
  filenames = [filename]
 
1347
  st.success(f"Loaded {n_samples} samples.")
1348
 
1349
 
1350
+ _LAST_DENDRO_PERSONAS_KEY = "analysis:last_personas:dendro"
1351
  _DENDRO_LINKAGE_OPTIONS = ["ward", "complete", "average", "single"]
1352
 
1353
 
 
1549
  mask_strategy: MaskStrategy,
1550
  ) -> str:
1551
  fallback_model = st.session_state.get(
1552
+ "analysis:hub_model_fallback",
1553
  DEFAULT_COMPARE_MODEL,
1554
  )
1555
  try:
 
1559
  return st.text_input(
1560
  "Hub model",
1561
  value=fallback_model,
1562
+ key="analysis:hub_model_fallback",
1563
  help="Compare-only model id to use if Hub config discovery is unavailable.",
1564
  )
1565
 
 
1571
  return st.text_input(
1572
  "Hub model",
1573
  value=fallback_model,
1574
+ key="analysis:hub_model_fallback",
1575
  help="Compare-only model id to use for this Hub repo.",
1576
  )
1577
 
 
1596
  artifacts_root: str,
1597
  mask_strategy: MaskStrategy,
1598
  ) -> str:
1599
+ fallback_model = st.session_state.get("analysis:local_model", DEFAULT_COMPARE_MODEL)
1600
  model_options = local_model_options_cached(artifacts_root, mask_strategy.value)
1601
  if not model_options:
1602
  return st.text_input(
1603
  "Local model",
1604
  value=fallback_model,
1605
+ key="analysis:local_model",
1606
  help="Compare-only local model id or path.",
1607
  )
1608
 
1609
  custom = st.toggle(
1610
  "Custom local model",
1611
  value=False,
1612
+ key="analysis:local_model_custom_enabled",
1613
  help="Enter a model id/path manually instead of choosing from activation directories.",
1614
  )
1615
  if custom:
1616
  return st.text_input(
1617
  "Local model",
1618
  value=fallback_model,
1619
+ key="analysis:local_model",
1620
  help="Compare-only local model id or path.",
1621
  )
1622
 
1623
+ previous_model = st.session_state.get("analysis:local_model_select", fallback_model)
1624
  if not any(local_model_matches(previous_model, option) for option in model_options):
1625
  previous_model = fallback_model
1626
  default_model = next(
 
1635
  "Local model",
1636
  options=model_options,
1637
  index=model_options.index(default_model),
1638
+ key="analysis:local_model_select",
1639
  help="Models discovered under the selected artifacts root.",
1640
  )
1641
+ st.session_state["analysis:local_model"] = selected
1642
  return selected
1643
 
1644
 
 
1646
  if source == SOURCE_HUB:
1647
  repo = st.text_input(
1648
  "Hub repo",
1649
+ value=st.session_state.get("analysis:hub_repo", DEFAULT_HUB_REPO),
1650
+ key="analysis:hub_repo",
1651
  help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
1652
  )
1653
  hub_model_name = _render_hub_model_select(repo, mask_strategy)
 
1660
  artifacts_root = st.text_input(
1661
  "Artifacts root",
1662
  value=str(get_artifacts_dir() / "activations"),
1663
+ key="analysis:artifacts_root",
1664
  )
1665
  artifacts_root = str(Path(artifacts_root).expanduser())
1666
  local_model_name = _render_local_model_select(artifacts_root, mask_strategy)
 
1672
  )
1673
 
1674
 
1675
+ def render_analysis_tab() -> None:
1676
  """Render the analysis tab."""
1677
 
1678
  st.title("Analysis")
1679
  st.caption(
1680
+ "Analyse persona vectors by cosine similarity, PCA, UMAP, Isomap, or hierarchical clustering."
1681
  )
1682
 
1683
  source = _render_source_select()
 
1720
  _render_dendrogram_analysis(store, mask_strategy)
1721
  return
1722
 
1723
+ dim_options = ["2D", "3D"]
1724
+ dim_key = widget_key("load", "projection_dims", analysis_mode)
1725
+ remembered_dim = st.session_state.get(
1726
+ dim_key,
1727
+ st.session_state.get(_LAST_PROJECTION_DIMS_KEY, "2D"),
1728
+ )
1729
+ if remembered_dim not in dim_options:
1730
+ remembered_dim = "2D"
1731
  dimension_choice = st.segmented_control(
1732
  "Projection dimensions",
1733
+ options=dim_options,
1734
+ default=remembered_dim,
1735
+ key=dim_key,
1736
  label_visibility="collapsed",
1737
  )
1738
+ if dimension_choice is not None:
1739
+ st.session_state[_LAST_PROJECTION_DIMS_KEY] = dimension_choice
1740
  n_components = 3 if dimension_choice == "3D" else 2
1741
  dim_suffix = "" if n_components == 2 else " (3D)"
1742
  _render_layered_figure_analysis(
tabs/chat.py CHANGED
@@ -1,50 +1,39 @@
 
 
 
 
1
  import streamlit as st
2
  from persona_data.synth_persona import PersonaData
3
 
4
- from state import ChatState, chat_session_key, get_chat_state, reset_chat_context_state
 
 
 
 
 
 
 
 
 
 
 
 
5
  from tabs.chat_ui import (
6
  GenerationConfig,
7
  render_advanced_settings,
8
  render_chat_window,
9
- render_persona_prompt_controls,
10
  render_system_prompt,
11
  )
12
- from utils.chat import (
13
- ChatReply,
14
- build_chat_messages,
15
- generate_chat_reply,
16
- resolve_system_prompt,
17
- )
18
  from utils.chat_export import save_chat_export
19
- from utils.datasets import load_persona_list
20
- from utils.helpers import widget_key
21
  from utils.runtime import cached_model
22
 
23
- _LAST_PERSONA_ID_KEY = "chat:last_persona_id"
24
- _LAST_PROMPT_MODE_KEY = "chat:last_prompt_mode"
25
- _LAST_COMPARE_MODE_KEY = "chat:last_compare_mode"
26
- _LAST_PROBE_ENABLED_KEY = "chat:last_probe_enabled"
27
- _LAST_TOKEN_CONTRAST_KEY = "chat:last_token_contrast"
28
-
29
-
30
- def _load_personas(dataset_source: str) -> list[PersonaData] | None:
31
- try:
32
- personas, dataset_status = load_persona_list(
33
- dataset_source,
34
- personas_file=st.session_state.get("extract__personas_file"),
35
- qa_file=st.session_state.get("extract__qa_file"),
36
- )
37
- st.caption(dataset_status)
38
- except Exception as exc:
39
- st.error(f"Could not load data: {exc}")
40
- st.info("Check the selected dataset source or upload both JSONL files.")
41
- return None
42
-
43
- if not personas:
44
- st.warning("No personas found in the selected dataset.")
45
- st.info("Try a different dataset source or upload a non-empty personas file.")
46
- return None
47
- return personas
48
 
49
 
50
  def _render_single_chat_footer(
@@ -99,27 +88,32 @@ def _handle_single_chat_generation(
99
  chat_state: ChatState,
100
  active_system_prompt: str | None,
101
  generation: GenerationConfig,
102
- pending_action: object,
103
  chat_log,
104
  ) -> None:
105
  messages = build_chat_messages(active_system_prompt, chat_state["messages"])
106
 
107
  with st.spinner("Generating reply..."):
108
  model = cached_model(model_name=model_name)
109
- try:
110
- reply: ChatReply = generate_chat_reply(
111
- model=model,
112
- messages=messages,
113
- remote=remote,
114
- **generation.to_generate_kwargs(),
115
- )
116
- except Exception as exc:
117
  with chat_log:
118
  st.error(f"Could not generate a reply: {exc}")
119
  st.info("Try a shorter prompt, reset the chat, or switch personas.")
 
 
 
 
 
 
 
 
 
120
  if pending_action == "new_user_prompt" and chat_state["messages"]:
121
  chat_state["messages"].pop()
122
  return
 
 
123
 
124
  chat_state["messages"].append({"role": "assistant", "content": reply.text})
125
  st.rerun()
@@ -132,16 +126,14 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
132
  st.caption("Chat with a persona, optionally side-by-side or with token contrast.")
133
 
134
  context_key = chat_session_key(model_name, dataset_source)
135
- chat_state = get_chat_state(model_name, remote, dataset_source)
136
-
137
- # Carry over persona / prompt selections across model or remote switches.
138
- if chat_state["persona_id"] is None:
139
- chat_state["persona_id"] = st.session_state.get(_LAST_PERSONA_ID_KEY)
140
- chat_state["prompt_mode"] = st.session_state.get(
141
- _LAST_PROMPT_MODE_KEY, "templated"
142
- )
143
 
144
- personas = _load_personas(dataset_source)
145
  if personas is None:
146
  return
147
 
@@ -166,7 +158,6 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
166
  )
167
  return
168
 
169
- # ── Single-chat mode ──────────────────────────────────────────────────────
170
  persona_select_key = widget_key(context_key, "persona_select")
171
  prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
172
  prompt_key = widget_key(context_key, "custom_system_prompt")
@@ -176,6 +167,20 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
176
  reset_key = widget_key(context_key, "reset")
177
  edit_key = widget_key(context_key, "edit_idx")
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  def _reset_active_chat_context() -> None:
180
  reset_chat_context_state(
181
  chat_state,
@@ -187,17 +192,6 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
187
  )
188
  st.session_state.pop(edit_key, None)
189
 
190
- selected_persona, prompt_mode, changed_context = render_persona_prompt_controls(
191
- personas,
192
- chat_state["persona_id"],
193
- chat_state["prompt_mode"],
194
- persona_select_key,
195
- prompt_mode_select_key,
196
- column_widths=(2, 1),
197
- )
198
- st.session_state[_LAST_PERSONA_ID_KEY] = selected_persona.id
199
- st.session_state[_LAST_PROMPT_MODE_KEY] = prompt_mode
200
-
201
  active_system_prompt = resolve_system_prompt(
202
  persona=selected_persona,
203
  mode=prompt_mode,
@@ -259,14 +253,15 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
259
 
260
  user_prompt = st.chat_input("Ask something...", key=chat_input_key)
261
 
262
- # Pass 1: user submitted — append message and rerun so it renders before generation.
263
  if user_prompt:
264
  chat_state["messages"].append({"role": "user", "content": user_prompt})
265
  st.session_state[pending_key] = "new_user_prompt"
266
  st.rerun()
267
 
268
- # Pass 2: message is already rendered above; now run generation.
269
- pending_action = st.session_state.pop(pending_key, None)
 
 
270
  if not pending_action:
271
  return
272
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import cast
4
+
5
  import streamlit as st
6
  from persona_data.synth_persona import PersonaData
7
 
8
+ from state import (
9
+ ChatState,
10
+ PendingChatAction,
11
+ chat_session_key,
12
+ get_chat_state,
13
+ reset_chat_context_state,
14
+ )
15
+ from tabs.chat_shared import (
16
+ generate_chat_reply_result,
17
+ hydrate_chat_state,
18
+ load_chat_personas,
19
+ render_chat_selection,
20
+ )
21
  from tabs.chat_ui import (
22
  GenerationConfig,
23
  render_advanced_settings,
24
  render_chat_window,
 
25
  render_system_prompt,
26
  )
27
+ from utils.chat import build_chat_messages, resolve_system_prompt
 
 
 
 
 
28
  from utils.chat_export import save_chat_export
29
+ from utils.helpers import session_key, widget_key
 
30
  from utils.runtime import cached_model
31
 
32
+ _LAST_PERSONA_ID_KEY = session_key("chat", "last_persona_id")
33
+ _LAST_PROMPT_MODE_KEY = session_key("chat", "last_prompt_mode")
34
+ _LAST_COMPARE_MODE_KEY = session_key("chat", "last_compare_mode")
35
+ _LAST_PROBE_ENABLED_KEY = session_key("chat", "last_probe_enabled")
36
+ _LAST_TOKEN_CONTRAST_KEY = session_key("chat", "last_token_contrast")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
 
39
  def _render_single_chat_footer(
 
88
  chat_state: ChatState,
89
  active_system_prompt: str | None,
90
  generation: GenerationConfig,
91
+ pending_action: PendingChatAction,
92
  chat_log,
93
  ) -> None:
94
  messages = build_chat_messages(active_system_prompt, chat_state["messages"])
95
 
96
  with st.spinner("Generating reply..."):
97
  model = cached_model(model_name=model_name)
98
+
99
+ def _show_error(exc: Exception) -> None:
 
 
 
 
 
 
100
  with chat_log:
101
  st.error(f"Could not generate a reply: {exc}")
102
  st.info("Try a shorter prompt, reset the chat, or switch personas.")
103
+
104
+ reply, error = generate_chat_reply_result(
105
+ model=model,
106
+ messages=messages,
107
+ remote=remote,
108
+ generation=generation,
109
+ on_error=_show_error,
110
+ )
111
+ if error is not None:
112
  if pending_action == "new_user_prompt" and chat_state["messages"]:
113
  chat_state["messages"].pop()
114
  return
115
+ if reply is None:
116
+ return
117
 
118
  chat_state["messages"].append({"role": "assistant", "content": reply.text})
119
  st.rerun()
 
126
  st.caption("Chat with a persona, optionally side-by-side or with token contrast.")
127
 
128
  context_key = chat_session_key(model_name, dataset_source)
129
+ chat_state = get_chat_state(model_name, dataset_source)
130
+ hydrate_chat_state(
131
+ chat_state,
132
+ persisted_persona_key=_LAST_PERSONA_ID_KEY,
133
+ persisted_prompt_key=_LAST_PROMPT_MODE_KEY,
134
+ )
 
 
135
 
136
+ personas = load_chat_personas(dataset_source)
137
  if personas is None:
138
  return
139
 
 
158
  )
159
  return
160
 
 
161
  persona_select_key = widget_key(context_key, "persona_select")
162
  prompt_mode_select_key = widget_key(context_key, "system_prompt_select")
163
  prompt_key = widget_key(context_key, "custom_system_prompt")
 
167
  reset_key = widget_key(context_key, "reset")
168
  edit_key = widget_key(context_key, "edit_idx")
169
 
170
+ selection = render_chat_selection(
171
+ personas,
172
+ chat_state["persona_id"],
173
+ chat_state["prompt_mode"],
174
+ persona_select_key,
175
+ prompt_mode_select_key,
176
+ persisted_persona_key=_LAST_PERSONA_ID_KEY,
177
+ persisted_prompt_key=_LAST_PROMPT_MODE_KEY,
178
+ column_widths=(2, 1),
179
+ )
180
+ selected_persona = selection.persona
181
+ prompt_mode = selection.prompt_mode
182
+ changed_context = selection.changed
183
+
184
  def _reset_active_chat_context() -> None:
185
  reset_chat_context_state(
186
  chat_state,
 
192
  )
193
  st.session_state.pop(edit_key, None)
194
 
 
 
 
 
 
 
 
 
 
 
 
195
  active_system_prompt = resolve_system_prompt(
196
  persona=selected_persona,
197
  mode=prompt_mode,
 
253
 
254
  user_prompt = st.chat_input("Ask something...", key=chat_input_key)
255
 
 
256
  if user_prompt:
257
  chat_state["messages"].append({"role": "user", "content": user_prompt})
258
  st.session_state[pending_key] = "new_user_prompt"
259
  st.rerun()
260
 
261
+ pending_action = cast(
262
+ PendingChatAction | None,
263
+ st.session_state.pop(pending_key, None),
264
+ )
265
  if not pending_action:
266
  return
267
 
tabs/chat_shared.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from collections.abc import Callable
4
+ from dataclasses import dataclass
5
+
6
+ import streamlit as st
7
+ from persona_data.synth_persona import PersonaData
8
+
9
+ from state import ChatState
10
+ from tabs.chat_ui import GenerationConfig, render_persona_prompt_controls
11
+ from utils.chat import ChatReply, generate_chat_reply
12
+ from utils.datasets import load_persona_list
13
+ from utils.helpers import session_key
14
+
15
+
16
+ @dataclass(frozen=True)
17
+ class ChatSelection:
18
+ persona: PersonaData
19
+ prompt_mode: str
20
+ changed: bool
21
+
22
+
23
+ def load_chat_personas(dataset_source: str) -> list[PersonaData] | None:
24
+ personas_file_key = session_key("extract", "personas_file")
25
+ qa_file_key = session_key("extract", "qa_file")
26
+ try:
27
+ personas, dataset_status = load_persona_list(
28
+ dataset_source,
29
+ personas_file=st.session_state.get(personas_file_key),
30
+ qa_file=st.session_state.get(qa_file_key),
31
+ )
32
+ st.caption(dataset_status)
33
+ except Exception as exc:
34
+ st.error(f"Could not load data: {exc}")
35
+ st.info("Check the selected dataset source or upload both JSONL files.")
36
+ return None
37
+
38
+ if not personas:
39
+ st.warning("No personas found in the selected dataset.")
40
+ st.info("Try a different dataset source or upload a non-empty personas file.")
41
+ return None
42
+ return personas
43
+
44
+
45
+ def hydrate_chat_state(
46
+ state: ChatState,
47
+ *,
48
+ persisted_persona_key: str,
49
+ persisted_prompt_key: str,
50
+ default_prompt_mode: str = "templated",
51
+ ) -> None:
52
+ if state["persona_id"] is None:
53
+ state["persona_id"] = st.session_state.get(persisted_persona_key)
54
+ state["prompt_mode"] = st.session_state.get(
55
+ persisted_prompt_key,
56
+ default_prompt_mode,
57
+ )
58
+
59
+
60
+ def render_chat_selection(
61
+ personas: list[PersonaData],
62
+ current_persona_id: str | None,
63
+ current_prompt_mode: str,
64
+ persona_key: str,
65
+ prompt_key: str,
66
+ *,
67
+ persisted_persona_key: str,
68
+ persisted_prompt_key: str,
69
+ column_widths: tuple[int, int] = (3, 2),
70
+ ) -> ChatSelection:
71
+ selected_persona, prompt_mode, changed = render_persona_prompt_controls(
72
+ personas,
73
+ current_persona_id,
74
+ current_prompt_mode,
75
+ persona_key,
76
+ prompt_key,
77
+ column_widths=column_widths,
78
+ )
79
+ st.session_state[persisted_persona_key] = selected_persona.id
80
+ st.session_state[persisted_prompt_key] = prompt_mode
81
+ return ChatSelection(selected_persona, prompt_mode, changed)
82
+
83
+
84
+ def generate_chat_reply_result(
85
+ *,
86
+ model: object,
87
+ messages: list[dict[str, str]],
88
+ remote: bool,
89
+ generation: GenerationConfig,
90
+ on_error: Callable[[Exception], None] | None = None,
91
+ ) -> tuple[ChatReply | None, Exception | None]:
92
+ try:
93
+ return (
94
+ generate_chat_reply(
95
+ model=model,
96
+ messages=messages,
97
+ remote=remote,
98
+ **generation.to_generate_kwargs(),
99
+ ),
100
+ None,
101
+ )
102
+ except Exception as exc:
103
+ if on_error is not None:
104
+ on_error(exc)
105
+ return None, exc
tabs/chat_ui.py CHANGED
@@ -29,19 +29,21 @@ GENERATION_DEFAULTS = {
29
  _LAST_GEN_PREFIX = "chat:last_gen:"
30
 
31
 
32
- def _persisted_key(context_key: str, name: str, default) -> str:
 
 
 
 
33
  """Per-context widget key, seeded from the last cross-context value."""
34
- last_key = f"{_LAST_GEN_PREFIX}{name}"
35
  key = widget_key(context_key, name)
36
  if key not in st.session_state:
37
- st.session_state[key] = st.session_state.get(last_key, default)
 
 
 
38
  return key
39
 
40
 
41
- def _remember(name: str, value) -> None:
42
- st.session_state[f"{_LAST_GEN_PREFIX}{name}"] = value
43
-
44
-
45
  @dataclass(frozen=True)
46
  class GenerationConfig:
47
  max_new_tokens: int
@@ -100,7 +102,7 @@ def _open_edit_dialog(
100
 
101
  save_col, cancel_col = st.columns(2)
102
  with save_col:
103
- if st.button("Save", type="primary", use_container_width=True):
104
  messages[msg_index]["content"] = new_content
105
  messages[msg_index].pop("_contrast", None)
106
  if role == "assistant":
@@ -110,7 +112,7 @@ def _open_edit_dialog(
110
  st.session_state[pending_key] = "regenerate_after_edit"
111
  st.rerun()
112
  with cancel_col:
113
- if st.button("Cancel", use_container_width=True):
114
  st.rerun()
115
 
116
 
@@ -129,13 +131,13 @@ def _open_system_prompt_dialog(
129
  )
130
  save_col, cancel_col = st.columns(2)
131
  with save_col:
132
- if st.button("Save", type="primary", use_container_width=True):
133
  st.session_state[prompt_key] = new_value
134
  if on_save is not None:
135
  on_save()
136
  st.rerun()
137
  with cancel_col:
138
- if st.button("Cancel", use_container_width=True):
139
  st.rerun()
140
 
141
 
@@ -307,7 +309,9 @@ def _render_generation_fragment(context_key: str, remote: bool) -> GenerationCon
307
  ("top_k", top_k),
308
  ("seed_enabled", seed_enabled),
309
  ):
310
- _remember(name, value)
 
 
311
 
312
  do_sample = bool(use_sampling)
313
  return GenerationConfig(
 
29
  _LAST_GEN_PREFIX = "chat:last_gen:"
30
 
31
 
32
+ def _last_generation_key(name: str) -> str:
33
+ return f"{_LAST_GEN_PREFIX}{name}"
34
+
35
+
36
+ def _persisted_key(context_key: str, name: str, default: object) -> str:
37
  """Per-context widget key, seeded from the last cross-context value."""
 
38
  key = widget_key(context_key, name)
39
  if key not in st.session_state:
40
+ st.session_state[key] = st.session_state.get(
41
+ _last_generation_key(name),
42
+ default,
43
+ )
44
  return key
45
 
46
 
 
 
 
 
47
  @dataclass(frozen=True)
48
  class GenerationConfig:
49
  max_new_tokens: int
 
102
 
103
  save_col, cancel_col = st.columns(2)
104
  with save_col:
105
+ if st.button("Save", type="primary", width="stretch"):
106
  messages[msg_index]["content"] = new_content
107
  messages[msg_index].pop("_contrast", None)
108
  if role == "assistant":
 
112
  st.session_state[pending_key] = "regenerate_after_edit"
113
  st.rerun()
114
  with cancel_col:
115
+ if st.button("Cancel", width="stretch"):
116
  st.rerun()
117
 
118
 
 
131
  )
132
  save_col, cancel_col = st.columns(2)
133
  with save_col:
134
+ if st.button("Save", type="primary", width="stretch"):
135
  st.session_state[prompt_key] = new_value
136
  if on_save is not None:
137
  on_save()
138
  st.rerun()
139
  with cancel_col:
140
+ if st.button("Cancel", width="stretch"):
141
  st.rerun()
142
 
143
 
 
309
  ("top_k", top_k),
310
  ("seed_enabled", seed_enabled),
311
  ):
312
+ st.session_state[_last_generation_key(name)] = value
313
+ if seed is not None:
314
+ st.session_state[_last_generation_key("seed")] = seed
315
 
316
  do_sample = bool(use_sampling)
317
  return GenerationConfig(
tabs/compare_chat.py CHANGED
@@ -6,22 +6,21 @@ from nnterp import StandardizedTransformer
6
  from persona_data.synth_persona import PersonaData
7
 
8
  from state import ChatState, default_chat_state, reset_chat_context_state
9
- from utils.chat import (
10
- ChatReply,
11
- build_chat_messages,
12
- generate_chat_reply,
13
- resolve_system_prompt,
14
  )
 
15
  from utils.chat_export import save_chat_export
16
  from utils.contrast import compute_contrast, compute_contrast_pair
17
- from utils.helpers import persona_label, widget_key
18
  from utils.runtime import cached_model
19
 
20
  from .chat_ui import (
21
  GenerationConfig,
22
  render_chat_message,
23
  render_chat_window,
24
- render_persona_prompt_controls,
25
  render_system_prompt,
26
  )
27
 
@@ -68,21 +67,26 @@ def _render_compare_panel(
68
  edit_key = widget_key(panel_key, "edit_idx")
69
  pending_key = widget_key(panel_key, "pending_regen")
70
 
71
- persist_persona_key = f"chat:last_cmp_{side}_persona"
72
- persist_prompt_key = f"chat:last_cmp_{side}_prompt"
73
- if state["persona_id"] is None:
74
- state["persona_id"] = st.session_state.get(persist_persona_key)
75
- state["prompt_mode"] = st.session_state.get(persist_prompt_key, "templated")
 
 
76
 
77
- selected_persona, prompt_mode, changed = render_persona_prompt_controls(
78
  personas,
79
  state["persona_id"],
80
  state["prompt_mode"],
81
  widget_key(panel_key, "persona"),
82
  widget_key(panel_key, "prompt_mode"),
 
 
83
  )
84
- st.session_state[persist_persona_key] = selected_persona.id
85
- st.session_state[persist_prompt_key] = prompt_mode
 
86
 
87
  if changed:
88
  reset_chat_context_state(
@@ -136,19 +140,13 @@ def _generate_panels(
136
  results: list[ChatReply | Exception] = []
137
  with st.spinner(spinner_label):
138
  for panel in panels:
139
- try:
140
- results.append(
141
- generate_chat_reply(
142
- model=model,
143
- messages=build_chat_messages(
144
- panel.prompt, panel.state["messages"]
145
- ),
146
- remote=remote,
147
- **generation.to_generate_kwargs(),
148
- )
149
- )
150
- except Exception as exc:
151
- results.append(exc)
152
  return results
153
 
154
 
 
6
  from persona_data.synth_persona import PersonaData
7
 
8
  from state import ChatState, default_chat_state, reset_chat_context_state
9
+ from tabs.chat_shared import (
10
+ generate_chat_reply_result,
11
+ hydrate_chat_state,
12
+ render_chat_selection,
 
13
  )
14
+ from utils.chat import ChatReply, build_chat_messages, resolve_system_prompt
15
  from utils.chat_export import save_chat_export
16
  from utils.contrast import compute_contrast, compute_contrast_pair
17
+ from utils.helpers import persona_label, session_key, widget_key
18
  from utils.runtime import cached_model
19
 
20
  from .chat_ui import (
21
  GenerationConfig,
22
  render_chat_message,
23
  render_chat_window,
 
24
  render_system_prompt,
25
  )
26
 
 
67
  edit_key = widget_key(panel_key, "edit_idx")
68
  pending_key = widget_key(panel_key, "pending_regen")
69
 
70
+ persist_persona_key = session_key("chat", f"last_cmp_{side}_persona")
71
+ persist_prompt_key = session_key("chat", f"last_cmp_{side}_prompt")
72
+ hydrate_chat_state(
73
+ state,
74
+ persisted_persona_key=persist_persona_key,
75
+ persisted_prompt_key=persist_prompt_key,
76
+ )
77
 
78
+ selection = render_chat_selection(
79
  personas,
80
  state["persona_id"],
81
  state["prompt_mode"],
82
  widget_key(panel_key, "persona"),
83
  widget_key(panel_key, "prompt_mode"),
84
+ persisted_persona_key=persist_persona_key,
85
+ persisted_prompt_key=persist_prompt_key,
86
  )
87
+ selected_persona = selection.persona
88
+ prompt_mode = selection.prompt_mode
89
+ changed = selection.changed
90
 
91
  if changed:
92
  reset_chat_context_state(
 
140
  results: list[ChatReply | Exception] = []
141
  with st.spinner(spinner_label):
142
  for panel in panels:
143
+ reply, error = generate_chat_reply_result(
144
+ model=model,
145
+ messages=build_chat_messages(panel.prompt, panel.state["messages"]),
146
+ remote=remote,
147
+ generation=generation,
148
+ )
149
+ results.append(reply if error is None else error)
 
 
 
 
 
 
150
  return results
151
 
152
 
tabs/extract.py CHANGED
@@ -5,7 +5,7 @@ import streamlit as st
5
  from catppuccin import PALETTE
6
  from persona_data.prompts import format_prompt
7
  from persona_data.synth_persona import BASELINE_PERSONA_ID, PersonaData, QAPair
8
- from persona_vectors.artifacts import PERSONA_VARIANTS
9
  from persona_vectors.extraction import (
10
  MaskStrategy,
11
  prepare_inputs_for_strategy,
@@ -14,11 +14,12 @@ from persona_vectors.extraction import (
14
  from persona_vectors.preview import TokenSegment, preview_token_segments
15
 
16
  from utils.controls import render_mask_strategy_select
17
- from utils.datasets import load_dataset, load_persona_list
18
  from utils.helpers import (
19
  NDIF_STATUS_ICONS,
20
  persona_label,
21
  prompt_variant_label,
 
22
  widget_key,
23
  )
24
  from utils.runtime import cached_model
@@ -29,6 +30,9 @@ _LAST_PERSONA_IDS_KEY = "extract:last_persona_ids"
29
  _LAST_MAX_QUESTIONS_KEY = "extract:last_max_questions"
30
  _LAST_MASK_STRATEGY_KEY = "extract:last_mask_strategy"
31
 
 
 
 
32
  _DEFAULT_MAX_QUESTIONS = 50
33
 
34
 
@@ -42,7 +46,7 @@ def _build_run_plan(
42
  selected_variants: list[str],
43
  runs: list[tuple[PersonaData, list[QAPair]]],
44
  ) -> list[tuple[PersonaData, list[QAPair], str]]:
45
- """Cartesian product of personas × variants."""
46
  return [(p, qa, v) for v in selected_variants for p, qa in runs]
47
 
48
 
@@ -63,13 +67,13 @@ def _render_local_dataset_upload(dataset_source: str) -> None:
63
  st.file_uploader(
64
  "personas.jsonl",
65
  type=["jsonl"],
66
- key="extract__personas_file",
67
  help="Expected fields: id, persona, templated_view, biography_view",
68
  )
69
  st.file_uploader(
70
  "qa.jsonl",
71
  type=["jsonl"],
72
- key="extract__qa_file",
73
  help="Expected fields: id, qid, type, item_type, scope, question, answer",
74
  )
75
 
@@ -80,12 +84,14 @@ def _render_variant_controls(
80
  remote: bool,
81
  dataset_source: str,
82
  ) -> tuple[list[str], bool] | None:
83
- default_variants = st.session_state.get(_LAST_VARIANTS_KEY, list(PERSONA_VARIANTS))
 
 
84
  selected_variants = st.multiselect(
85
  "Persona variants",
86
- options=PERSONA_VARIANTS,
87
- default=[v for v in default_variants if v in PERSONA_VARIANTS]
88
- or list(PERSONA_VARIANTS),
89
  format_func=prompt_variant_label,
90
  key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
91
  help="Extract these variants for each selected persona.",
@@ -110,14 +116,10 @@ def _load_qa_dataset_personas(
110
  try:
111
  dataset, dataset_status = load_dataset(
112
  dataset_source,
113
- personas_file=st.session_state.get("extract__personas_file"),
114
- qa_file=st.session_state.get("extract__qa_file"),
115
- )
116
- personas, _ = load_persona_list(
117
- dataset_source,
118
- personas_file=st.session_state.get("extract__personas_file"),
119
- qa_file=st.session_state.get("extract__qa_file"),
120
  )
 
121
  st.caption(dataset_status)
122
  except Exception as exc:
123
  st.error(f"Could not load data: {exc}")
@@ -289,10 +291,10 @@ def _render_extract_actions() -> tuple[bool, bool]:
289
  run_clicked = st.button(
290
  "Run extraction",
291
  type="primary",
292
- use_container_width=True,
293
  )
294
  with preview_col:
295
- preview_clicked = st.button("Preview tokens", use_container_width=True)
296
  return run_clicked, preview_clicked
297
 
298
 
 
5
  from catppuccin import PALETTE
6
  from persona_data.prompts import format_prompt
7
  from persona_data.synth_persona import BASELINE_PERSONA_ID, PersonaData, QAPair
8
+ from persona_vectors.artifacts import SUPPORTED_VARIANTS
9
  from persona_vectors.extraction import (
10
  MaskStrategy,
11
  prepare_inputs_for_strategy,
 
14
  from persona_vectors.preview import TokenSegment, preview_token_segments
15
 
16
  from utils.controls import render_mask_strategy_select
17
+ from utils.datasets import load_dataset, load_persona_list_from_dataset
18
  from utils.helpers import (
19
  NDIF_STATUS_ICONS,
20
  persona_label,
21
  prompt_variant_label,
22
+ session_key,
23
  widget_key,
24
  )
25
  from utils.runtime import cached_model
 
30
  _LAST_MAX_QUESTIONS_KEY = "extract:last_max_questions"
31
  _LAST_MASK_STRATEGY_KEY = "extract:last_mask_strategy"
32
 
33
+ _PERSONAS_FILE_KEY = session_key("extract", "personas_file")
34
+ _QA_FILE_KEY = session_key("extract", "qa_file")
35
+
36
  _DEFAULT_MAX_QUESTIONS = 50
37
 
38
 
 
46
  selected_variants: list[str],
47
  runs: list[tuple[PersonaData, list[QAPair]]],
48
  ) -> list[tuple[PersonaData, list[QAPair], str]]:
49
+ """Cartesian product of personas x variants."""
50
  return [(p, qa, v) for v in selected_variants for p, qa in runs]
51
 
52
 
 
67
  st.file_uploader(
68
  "personas.jsonl",
69
  type=["jsonl"],
70
+ key=_PERSONAS_FILE_KEY,
71
  help="Expected fields: id, persona, templated_view, biography_view",
72
  )
73
  st.file_uploader(
74
  "qa.jsonl",
75
  type=["jsonl"],
76
+ key=_QA_FILE_KEY,
77
  help="Expected fields: id, qid, type, item_type, scope, question, answer",
78
  )
79
 
 
84
  remote: bool,
85
  dataset_source: str,
86
  ) -> tuple[list[str], bool] | None:
87
+ default_variants = st.session_state.get(
88
+ _LAST_VARIANTS_KEY, list(SUPPORTED_VARIANTS)
89
+ )
90
  selected_variants = st.multiselect(
91
  "Persona variants",
92
+ options=SUPPORTED_VARIANTS,
93
+ default=[v for v in default_variants if v in SUPPORTED_VARIANTS]
94
+ or list(SUPPORTED_VARIANTS),
95
  format_func=prompt_variant_label,
96
  key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
97
  help="Extract these variants for each selected persona.",
 
116
  try:
117
  dataset, dataset_status = load_dataset(
118
  dataset_source,
119
+ personas_file=st.session_state.get(_PERSONAS_FILE_KEY),
120
+ qa_file=st.session_state.get(_QA_FILE_KEY),
 
 
 
 
 
121
  )
122
+ personas = load_persona_list_from_dataset(dataset)
123
  st.caption(dataset_status)
124
  except Exception as exc:
125
  st.error(f"Could not load data: {exc}")
 
291
  run_clicked = st.button(
292
  "Run extraction",
293
  type="primary",
294
+ width="stretch",
295
  )
296
  with preview_col:
297
+ preview_clicked = st.button("Preview tokens", width="stretch")
298
  return run_clicked, preview_clicked
299
 
300
 
tabs/probe_ui.py CHANGED
@@ -197,7 +197,7 @@ def _trace_requested(context_key: str) -> bool:
197
  if st.button(
198
  "Trace conversation",
199
  key=widget_key(context_key, "probe_trace"),
200
- use_container_width=True,
201
  ):
202
  st.session_state[trace_key] = True
203
  return bool(st.session_state.get(trace_key, False))
 
197
  if st.button(
198
  "Trace conversation",
199
  key=widget_key(context_key, "probe_trace"),
200
+ width="stretch",
201
  ):
202
  st.session_state[trace_key] = True
203
  return bool(st.session_state.get(trace_key, False))
utils/{compare_sources.py → analysis_sources.py} RENAMED
@@ -1,12 +1,10 @@
1
  import os
2
 
3
  import streamlit as st
4
- import torch
5
- from persona_vectors.analysis import LayeredSamples
6
  from persona_vectors.artifacts import (
7
  ActivationStore,
8
  HFActivationStore,
9
- activation_config_name,
10
  discover_activation_models,
11
  model_dir_name,
12
  )
@@ -25,28 +23,6 @@ SOURCE_LOCAL = "Local activations"
25
  SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
26
 
27
 
28
- def _hub_split(repo_id: str, model_name: str, mask_strategy_value: str, variant: str):
29
- from datasets import load_dataset
30
-
31
- return load_dataset(
32
- repo_id,
33
- name=activation_config_name(model_name, mask_strategy_value),
34
- split=variant,
35
- keep_in_memory=False,
36
- )
37
-
38
-
39
- def _hub_split_columns(
40
- repo_id: str,
41
- model_name: str,
42
- mask_strategy_value: str,
43
- variant: str,
44
- columns: list[str],
45
- ):
46
- dataset = _hub_split(repo_id, model_name, mask_strategy_value, variant)
47
- return dataset.select_columns(columns)
48
-
49
-
50
  @st.cache_resource(show_spinner=False, max_entries=1)
51
  def activation_store_cached(
52
  source: str,
@@ -67,8 +43,9 @@ def available_variants_cached(
67
  model_name: str,
68
  mask_strategy_value: str,
69
  ) -> list[str]:
70
- store = activation_store_cached(source, location, model_name, mask_strategy_value)
71
- return store.available_variants()
 
72
 
73
 
74
  @st.cache_data(show_spinner=False)
@@ -79,31 +56,9 @@ def personas_cached(
79
  mask_strategy_value: str,
80
  variants: tuple[str, ...],
81
  ) -> list[str]:
82
- if source == SOURCE_HUB:
83
- variant_ids = [
84
- list(
85
- _hub_split_columns(
86
- location,
87
- model_name,
88
- mask_strategy_value,
89
- variant,
90
- ["persona_id"],
91
- )["persona_id"]
92
- )
93
- for variant in variants
94
- ]
95
- if not variant_ids:
96
- return []
97
- shared = set(variant_ids[0])
98
- for ids in variant_ids[1:]:
99
- shared &= set(ids)
100
- return [persona_id for persona_id in variant_ids[0] if persona_id in shared]
101
-
102
- store = activation_store_cached(source, location, model_name, mask_strategy_value)
103
- return store.list_personas(
104
- list(variants),
105
- mask_strategy=MaskStrategy(mask_strategy_value),
106
- )
107
 
108
 
109
  @st.cache_data(show_spinner=False)
@@ -115,31 +70,24 @@ def persona_names_cached(
115
  variants: tuple[str, ...],
116
  persona_ids: tuple[str, ...],
117
  ) -> dict[str, str]:
118
- if source == SOURCE_HUB:
119
- requested = set(persona_ids)
120
- names: dict[str, str] = {}
121
- for variant in variants:
122
- metadata = _hub_split_columns(
123
- location,
124
- model_name,
125
- mask_strategy_value,
126
- variant,
127
- ["persona_id", "name"],
128
- )
129
- for row in metadata:
130
- persona_id = row["persona_id"]
131
- if persona_id in requested and persona_id not in names:
132
- names[persona_id] = row.get("name") or persona_id
133
- if len(names) == len(requested):
134
- return {pid: names.get(pid, pid) for pid in persona_ids}
135
- return {pid: names.get(pid, pid) for pid in persona_ids}
136
-
137
  store = activation_store_cached(source, location, model_name, mask_strategy_value)
138
- return store.persona_names(
139
- list(persona_ids),
140
- variants=list(variants),
141
- mask_strategy=MaskStrategy(mask_strategy_value),
142
- )
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
 
145
  @st.cache_data(show_spinner=False)
@@ -151,11 +99,11 @@ def local_model_options_cached(
151
 
152
  @st.cache_data(show_spinner=False)
153
  def hub_models_by_mask_strategy(repo_id: str) -> dict[MaskStrategy, list[str]]:
154
- raw = list_hub_vector_models(repo_id)
155
  return {
156
  MaskStrategy(strategy_value): models
157
- for strategy_value, models in raw.items()
158
- if strategy_value in {strategy.value for strategy in MaskStrategy}
159
  }
160
 
161
 
@@ -173,56 +121,14 @@ def store_id(store: Store) -> str:
173
 
174
  def available_variants(store: Store, mask_strategy: MaskStrategy) -> list[str]:
175
  source, location, model_name = store_cache_parts(store)
176
- return available_variants_cached(
177
- source,
178
- location,
179
- model_name,
180
- mask_strategy.value,
181
- )
182
-
183
-
184
- @st.cache_data(show_spinner=False)
185
- def store_layers_cached(
186
- source: str,
187
- location: str,
188
- model_name: str,
189
- mask_strategy_value: str,
190
- variants: tuple[str, ...],
191
- persona_ids: tuple[str, ...],
192
- ) -> list[int]:
193
- if source == SOURCE_HUB:
194
- shared_layers: set[int] | None = None
195
- requested = list(persona_ids)
196
- for variant in variants:
197
- dataset = _hub_split(location, model_name, mask_strategy_value, variant)
198
- ids = list(dataset.select_columns(["persona_id"])["persona_id"])
199
- sample_id = requested[0] if requested else (ids[0] if ids else None)
200
- if sample_id is None:
201
- return []
202
- if requested and any(persona_id not in ids for persona_id in requested):
203
- return []
204
- vector = torch.as_tensor(dataset[ids.index(sample_id)]["vector"])
205
- if vector.ndim != 2:
206
- raise ValueError(
207
- f"tensor for {sample_id!r} must have shape (num_layers, hidden_size)"
208
- )
209
- layers = set(range(int(vector.shape[0])))
210
- shared_layers = layers if shared_layers is None else shared_layers & layers
211
- return sorted(shared_layers or set())
212
-
213
- store = activation_store_cached(source, location, model_name, mask_strategy_value)
214
- return store.list_layers(
215
- list(variants),
216
- list(persona_ids),
217
- mask_strategy=MaskStrategy(mask_strategy_value),
218
- )
219
 
220
 
221
  def local_model_matches(left: str, right: str) -> bool:
222
  return model_dir_name(left) == model_dir_name(right)
223
 
224
 
225
- def load_persona_vectors_lean(
226
  source: str,
227
  location: str,
228
  model_name: str,
@@ -230,61 +136,16 @@ def load_persona_vectors_lean(
230
  variant: str,
231
  persona_ids: tuple[str, ...],
232
  ) -> LayeredSamples:
233
- if source != SOURCE_HUB:
234
- from persona_vectors.analysis import load_persona_vectors
235
-
236
- store = activation_store_cached(
237
- source,
238
- location,
239
- model_name,
240
- mask_strategy_value,
241
- )
242
- return load_persona_vectors(
243
- store,
244
- variant,
245
- mask_strategy=MaskStrategy(mask_strategy_value),
246
- persona_ids=list(persona_ids),
247
- )
248
-
249
- dataset = _hub_split(location, model_name, mask_strategy_value, variant)
250
- metadata = dataset.select_columns(["persona_id", "name"])
251
- index_by_id: dict[str, int] = {}
252
- name_by_id: dict[str, str] = {}
253
- requested = set(persona_ids)
254
- for index, row in enumerate(metadata):
255
- persona_id = row["persona_id"]
256
- if persona_id in requested:
257
- index_by_id[persona_id] = index
258
- name_by_id[persona_id] = row.get("name") or persona_id
259
- if len(index_by_id) == len(requested):
260
- break
261
-
262
- missing = [
263
- persona_id for persona_id in persona_ids if persona_id not in index_by_id
264
- ]
265
- if missing:
266
- raise FileNotFoundError(
267
- f"Missing {len(missing)} persona vector(s) in {variant!r}: {missing[:3]}"
268
- )
269
-
270
- vectors, labels, hover_text = [], [], []
271
- for persona_id in persona_ids:
272
- name = name_by_id.get(persona_id, persona_id)
273
- vector = torch.as_tensor(
274
- dataset[index_by_id[persona_id]]["vector"],
275
- dtype=torch.float32,
276
- )
277
- if vector.ndim != 2:
278
- raise ValueError(
279
- f"tensor for {persona_id!r} must have shape (num_layers, hidden_size)"
280
- )
281
- vectors.append(vector)
282
- labels.append(name)
283
- hover_text.append(f"Persona: {name}<br>ID: {persona_id}")
284
- return LayeredSamples(torch.stack(vectors), labels, hover_text)
285
 
286
 
287
- def load_variant_vectors_lean(
288
  source: str,
289
  location: str,
290
  model_name: str,
@@ -293,27 +154,18 @@ def load_variant_vectors_lean(
293
  persona_ids: tuple[str, ...],
294
  ) -> dict[str, LayeredSamples]:
295
  return {
296
- variant: load_persona_vectors_lean(
297
- source,
298
- location,
299
- model_name,
300
- mask_strategy_value,
301
- variant,
302
- persona_ids,
303
  )
304
  for variant in variants
305
  }
306
 
307
 
308
- def release_store_cache(
309
  store: Store,
310
  variants: list[str] | tuple[str, ...] | None = None,
311
  ) -> None:
312
- cache = getattr(store, "_cache", None)
313
- if not isinstance(cache, dict):
314
- return
315
- if variants is None:
316
- cache.clear()
317
- return
318
- for variant in variants:
319
- cache.pop(variant, None)
 
1
  import os
2
 
3
  import streamlit as st
4
+ from persona_vectors.analysis import LayeredSamples, load_persona_vectors
 
5
  from persona_vectors.artifacts import (
6
  ActivationStore,
7
  HFActivationStore,
 
8
  discover_activation_models,
9
  model_dir_name,
10
  )
 
23
  SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  @st.cache_resource(show_spinner=False, max_entries=1)
27
  def activation_store_cached(
28
  source: str,
 
43
  model_name: str,
44
  mask_strategy_value: str,
45
  ) -> list[str]:
46
+ return activation_store_cached(
47
+ source, location, model_name, mask_strategy_value
48
+ ).available_variants()
49
 
50
 
51
  @st.cache_data(show_spinner=False)
 
56
  mask_strategy_value: str,
57
  variants: tuple[str, ...],
58
  ) -> list[str]:
59
+ return activation_store_cached(
60
+ source, location, model_name, mask_strategy_value
61
+ ).list_personas(list(variants))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
  @st.cache_data(show_spinner=False)
 
70
  variants: tuple[str, ...],
71
  persona_ids: tuple[str, ...],
72
  ) -> dict[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  store = activation_store_cached(source, location, model_name, mask_strategy_value)
74
+ names = store.persona_names(list(persona_ids), variants=list(variants))
75
+ # Preserve input order, fall back to the id when the row has no display name.
76
+ return {pid: names.get(pid, pid) for pid in persona_ids}
77
+
78
+
79
+ @st.cache_data(show_spinner=False)
80
+ def store_layers_cached(
81
+ source: str,
82
+ location: str,
83
+ model_name: str,
84
+ mask_strategy_value: str,
85
+ variants: tuple[str, ...],
86
+ persona_ids: tuple[str, ...],
87
+ ) -> list[int]:
88
+ return activation_store_cached(
89
+ source, location, model_name, mask_strategy_value
90
+ ).list_layers(list(variants), list(persona_ids))
91
 
92
 
93
  @st.cache_data(show_spinner=False)
 
99
 
100
  @st.cache_data(show_spinner=False)
101
  def hub_models_by_mask_strategy(repo_id: str) -> dict[MaskStrategy, list[str]]:
102
+ valid = {strategy.value for strategy in MaskStrategy}
103
  return {
104
  MaskStrategy(strategy_value): models
105
+ for strategy_value, models in list_hub_vector_models(repo_id).items()
106
+ if strategy_value in valid
107
  }
108
 
109
 
 
121
 
122
  def available_variants(store: Store, mask_strategy: MaskStrategy) -> list[str]:
123
  source, location, model_name = store_cache_parts(store)
124
+ return available_variants_cached(source, location, model_name, mask_strategy.value)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
 
127
  def local_model_matches(left: str, right: str) -> bool:
128
  return model_dir_name(left) == model_dir_name(right)
129
 
130
 
131
+ def load_persona_vectors_cached(
132
  source: str,
133
  location: str,
134
  model_name: str,
 
136
  variant: str,
137
  persona_ids: tuple[str, ...],
138
  ) -> LayeredSamples:
139
+ store = activation_store_cached(source, location, model_name, mask_strategy_value)
140
+ return load_persona_vectors(
141
+ store,
142
+ variant,
143
+ mask_strategy=MaskStrategy(mask_strategy_value),
144
+ persona_ids=list(persona_ids),
145
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
 
148
+ def load_variant_vectors_cached(
149
  source: str,
150
  location: str,
151
  model_name: str,
 
154
  persona_ids: tuple[str, ...],
155
  ) -> dict[str, LayeredSamples]:
156
  return {
157
+ variant: load_persona_vectors_cached(
158
+ source, location, model_name, mask_strategy_value, variant, persona_ids
 
 
 
 
 
159
  )
160
  for variant in variants
161
  }
162
 
163
 
164
+ def release_hf_store_cache(
165
  store: Store,
166
  variants: list[str] | tuple[str, ...] | None = None,
167
  ) -> None:
168
+ """Drop cached HF data for ``variants`` (or all) on Hub stores."""
169
+ release_cache = getattr(store, "release_cache", None)
170
+ if isinstance(store, HFActivationStore) and callable(release_cache):
171
+ release_cache(variants)
 
 
 
 
utils/chat.py CHANGED
@@ -5,11 +5,11 @@ from contextlib import contextmanager, nullcontext
5
  from dataclasses import dataclass
6
  from typing import TYPE_CHECKING, Literal
7
 
 
8
  from persona_data.prompts import format_messages, format_prompt, normalize_messages
9
  from persona_data.synth_persona import PersonaData
10
 
11
  if TYPE_CHECKING:
12
- import torch
13
  from nnterp import StandardizedTransformer
14
 
15
  logger = logging.getLogger(__name__)
@@ -133,8 +133,6 @@ def format_generation_prompt(
133
 
134
  def resolve_saved_tensor(value: object) -> torch.Tensor:
135
  """Resolve an nnsight ``.save()`` proxy (or raw tensor) to a CPU tensor."""
136
- import torch
137
-
138
  resolved = value.value if getattr(value, "value", None) is not None else value
139
  if not isinstance(resolved, torch.Tensor):
140
  raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
@@ -160,8 +158,6 @@ def _seeded_rng(seed: int | None):
160
  yield
161
  return
162
 
163
- import torch
164
-
165
  cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
166
  mps_ctx = (
167
  torch.random.fork_rng(devices=range(1), device_type="mps")
@@ -207,8 +203,6 @@ def generate_chat_reply(
207
  ChatReply with generated text and token ids.
208
  """
209
 
210
- import torch
211
-
212
  tokenizer = model.tokenizer
213
  prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)
214
 
@@ -228,9 +222,11 @@ def generate_chat_reply(
228
  generation_kwargs["repetition_penalty"] = repetition_penalty
229
  # `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
230
  # forwarded to the underlying model's generate
231
- with _seeded_rng(seed if do_sample and not remote else None):
232
- with model.generate(prompt, remote=remote, **generation_kwargs) as tracer:
233
- generated = tracer.result.save()
 
 
234
 
235
  if getattr(generated, "value", None) is not None:
236
  generated = generated.value
 
5
  from dataclasses import dataclass
6
  from typing import TYPE_CHECKING, Literal
7
 
8
+ import torch
9
  from persona_data.prompts import format_messages, format_prompt, normalize_messages
10
  from persona_data.synth_persona import PersonaData
11
 
12
  if TYPE_CHECKING:
 
13
  from nnterp import StandardizedTransformer
14
 
15
  logger = logging.getLogger(__name__)
 
133
 
134
  def resolve_saved_tensor(value: object) -> torch.Tensor:
135
  """Resolve an nnsight ``.save()`` proxy (or raw tensor) to a CPU tensor."""
 
 
136
  resolved = value.value if getattr(value, "value", None) is not None else value
137
  if not isinstance(resolved, torch.Tensor):
138
  raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
 
158
  yield
159
  return
160
 
 
 
161
  cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
162
  mps_ctx = (
163
  torch.random.fork_rng(devices=range(1), device_type="mps")
 
203
  ChatReply with generated text and token ids.
204
  """
205
 
 
 
206
  tokenizer = model.tokenizer
207
  prompt, prompt_token_count = format_generation_prompt(messages, tokenizer)
208
 
 
222
  generation_kwargs["repetition_penalty"] = repetition_penalty
223
  # `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
224
  # forwarded to the underlying model's generate
225
+ with (
226
+ _seeded_rng(seed if do_sample and not remote else None),
227
+ model.generate(prompt, remote=remote, **generation_kwargs) as tracer,
228
+ ):
229
+ generated = tracer.result.save()
230
 
231
  if getattr(generated, "value", None) is not None:
232
  generated = generated.value
utils/chat_export.py CHANGED
@@ -1,6 +1,6 @@
1
  import json
2
  from dataclasses import asdict, is_dataclass
3
- from datetime import datetime, timezone
4
  from pathlib import Path
5
 
6
  from utils.helpers import slugify
@@ -72,7 +72,7 @@ def save_chat_export(
72
  )
73
  export_dir.mkdir(parents=True, exist_ok=True)
74
 
75
- timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
76
  filename_parts = [
77
  timestamp,
78
  slugify(persona_name or persona_id),
 
1
  import json
2
  from dataclasses import asdict, is_dataclass
3
+ from datetime import UTC, datetime
4
  from pathlib import Path
5
 
6
  from utils.helpers import slugify
 
72
  )
73
  export_dir.mkdir(parents=True, exist_ok=True)
74
 
75
+ timestamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
76
  filename_parts = [
77
  timestamp,
78
  slugify(persona_name or persona_id),
utils/contrast.py CHANGED
@@ -244,7 +244,9 @@ def render_contrast_html(result: TokenContrast) -> str:
244
  it is, with a hover tooltip showing the raw Δlog P, plus a legend.
245
  """
246
  spans: list[str] = []
247
- for token, weight, raw in zip(result.tokens, result.weights, result.raw_diffs):
 
 
248
  bg = _weight_to_bg(weight)
249
  tip = escape(f"Δlog P(A−B): {raw:+.3f}")
250
  text = escape(token)
 
244
  it is, with a hover tooltip showing the raw Δlog P, plus a legend.
245
  """
246
  spans: list[str] = []
247
+ for token, weight, raw in zip(
248
+ result.tokens, result.weights, result.raw_diffs, strict=True
249
+ ):
250
  bg = _weight_to_bg(weight)
251
  tip = escape(f"Δlog P(A−B): {raw:+.3f}")
252
  text = escape(token)
utils/datasets.py CHANGED
@@ -13,7 +13,7 @@ from persona_data.nemotron_personas import (
13
  from persona_data.synth_persona import PersonaDataset as LocalPersonaDataset
14
  from persona_data.synth_persona import SynthPersonaDataset
15
 
16
- from .helpers import DATASET_SOURCES
17
 
18
 
19
  @st.cache_resource(show_spinner=False)
@@ -63,6 +63,12 @@ def load_persona_list(
63
  """
64
 
65
  dataset, status = load_dataset(dataset_source, personas_file, qa_file)
 
 
 
 
 
 
66
  cached = getattr(dataset, "_persona_list_cache", None)
67
  if cached is None:
68
  cached = list(dataset)
@@ -70,7 +76,7 @@ def load_persona_list(
70
  dataset._persona_list_cache = cached
71
  except (AttributeError, TypeError):
72
  pass
73
- return cached, status
74
 
75
 
76
  def load_dataset(
@@ -86,13 +92,13 @@ def load_dataset(
86
  ]:
87
  """Load the selected dataset source for the UI."""
88
 
89
- if dataset_source == DATASET_SOURCES[0]:
90
  return _cached_dataset(SynthPersonaDataset), "SynthPersona"
91
 
92
- if dataset_source == DATASET_SOURCES[1]:
93
  return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"
94
 
95
- if dataset_source == DATASET_SOURCES[2]:
96
  return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"
97
 
98
  if personas_file is None or qa_file is None:
 
13
  from persona_data.synth_persona import PersonaDataset as LocalPersonaDataset
14
  from persona_data.synth_persona import SynthPersonaDataset
15
 
16
+ from .helpers import DatasetSource
17
 
18
 
19
  @st.cache_resource(show_spinner=False)
 
63
  """
64
 
65
  dataset, status = load_dataset(dataset_source, personas_file, qa_file)
66
+ return load_persona_list_from_dataset(dataset), status
67
+
68
+
69
+ def load_persona_list_from_dataset(dataset: Any) -> list:
70
+ """Materialize and cache personas from an already-loaded dataset."""
71
+
72
  cached = getattr(dataset, "_persona_list_cache", None)
73
  if cached is None:
74
  cached = list(dataset)
 
76
  dataset._persona_list_cache = cached
77
  except (AttributeError, TypeError):
78
  pass
79
+ return cached
80
 
81
 
82
  def load_dataset(
 
92
  ]:
93
  """Load the selected dataset source for the UI."""
94
 
95
+ if dataset_source == DatasetSource.SYNTH_PERSONA.value:
96
  return _cached_dataset(SynthPersonaDataset), "SynthPersona"
97
 
98
+ if dataset_source == DatasetSource.NEMOTRON_FRANCE.value:
99
  return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"
100
 
101
+ if dataset_source == DatasetSource.NEMOTRON_USA.value:
102
  return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"
103
 
104
  if personas_file is None or qa_file is None:
utils/helpers.py CHANGED
@@ -1,9 +1,21 @@
1
  import hashlib
2
  import re
3
  from collections.abc import Iterable
 
4
 
5
  from persona_data.synth_persona import PersonaData
6
 
 
 
 
 
 
 
 
 
 
 
 
7
  # Variant key -> human-readable label mapping
8
  VARIANT_LABELS = {
9
  "empty": "None",
@@ -16,21 +28,21 @@ VARIANT_LABELS = {
16
  CHAT_PROMPT_MODES = ("empty", "templated", "biography", "custom")
17
  CHAT_PROMPT_MODE_LABELS = [VARIANT_LABELS[key] for key in CHAT_PROMPT_MODES]
18
  CHAT_PROMPT_MODE_LABEL_TO_KEY = {VARIANT_LABELS[key]: key for key in CHAT_PROMPT_MODES}
19
-
20
-
21
- DATASET_SOURCES = [
22
- "HuggingFace: synth-persona",
23
- "HuggingFace: nemotron-france",
24
- "HuggingFace: nemotron-usa",
25
- "Local JSONL upload",
26
  ]
27
- ANALYSIS_MODES = ["Cosine similarity", "Similarity matrix", "PCA", "UMAP", "Dendrogram"]
28
 
29
  ANALYSIS_HELP_TEXT = {
30
  "Cosine similarity": "Compare layer-wise alignment between variants.",
31
  "Similarity matrix": "Compare centered pairwise similarity between persona vectors by layer, with pair trajectories across layers.",
32
  "PCA": "Project per-persona vectors into a 2D or 3D global view.",
33
  "UMAP": "Project per-persona vectors into a 2D or 3D local-neighborhood view.",
 
34
  "Dendrogram": "Hierarchical clustering of persona vectors — shows biography and templated side by side for direct comparison.",
35
  }
36
 
@@ -56,6 +68,12 @@ def widget_key(*parts: str) -> str:
56
  return "::".join(parts)
57
 
58
 
 
 
 
 
 
 
59
  def personas_fingerprint(persona_ids: Iterable[str]) -> str:
60
  """Stable short fingerprint for a set of persona ids.
61
 
@@ -78,11 +96,3 @@ def persona_label(persona: PersonaData) -> str:
78
  """Format a persona for selection widgets."""
79
 
80
  return f"{persona.name} ({persona.id})"
81
-
82
-
83
- def persona_display_label(persona_id: str, persona_name: str | None) -> str:
84
- """Format a persona id with an optional display name."""
85
-
86
- if persona_name:
87
- return f"{persona_name} ({persona_id})"
88
- return persona_id
 
1
  import hashlib
2
  import re
3
  from collections.abc import Iterable
4
+ from enum import Enum
5
 
6
  from persona_data.synth_persona import PersonaData
7
 
8
+
9
+ class DatasetSource(str, Enum):
10
+ SYNTH_PERSONA = "HuggingFace: synth-persona"
11
+ NEMOTRON_FRANCE = "HuggingFace: nemotron-france"
12
+ NEMOTRON_USA = "HuggingFace: nemotron-usa"
13
+ LOCAL_UPLOAD = "Local JSONL upload"
14
+
15
+
16
+ DATASET_SOURCES = [s.value for s in DatasetSource]
17
+
18
+
19
  # Variant key -> human-readable label mapping
20
  VARIANT_LABELS = {
21
  "empty": "None",
 
28
  CHAT_PROMPT_MODES = ("empty", "templated", "biography", "custom")
29
  CHAT_PROMPT_MODE_LABELS = [VARIANT_LABELS[key] for key in CHAT_PROMPT_MODES]
30
  CHAT_PROMPT_MODE_LABEL_TO_KEY = {VARIANT_LABELS[key]: key for key in CHAT_PROMPT_MODES}
31
+ ANALYSIS_MODES = [
32
+ "Cosine similarity",
33
+ "Similarity matrix",
34
+ "PCA",
35
+ "UMAP",
36
+ "Isomap",
37
+ "Dendrogram",
38
  ]
 
39
 
40
  ANALYSIS_HELP_TEXT = {
41
  "Cosine similarity": "Compare layer-wise alignment between variants.",
42
  "Similarity matrix": "Compare centered pairwise similarity between persona vectors by layer, with pair trajectories across layers.",
43
  "PCA": "Project per-persona vectors into a 2D or 3D global view.",
44
  "UMAP": "Project per-persona vectors into a 2D or 3D local-neighborhood view.",
45
+ "Isomap": "Project per-persona vectors with graph-geodesic distances to probe manifold-like geometry.",
46
  "Dendrogram": "Hierarchical clustering of persona vectors — shows biography and templated side by side for direct comparison.",
47
  }
48
 
 
68
  return "::".join(parts)
69
 
70
 
71
+ def session_key(*parts: str) -> str:
72
+ """Generate a colon-separated Streamlit session-state key from parts."""
73
+
74
+ return ":".join(parts)
75
+
76
+
77
  def personas_fingerprint(persona_ids: Iterable[str]) -> str:
78
  """Stable short fingerprint for a set of persona ids.
79
 
 
96
  """Format a persona for selection widgets."""
97
 
98
  return f"{persona.name} ({persona.id})"
 
 
 
 
 
 
 
 
uv.lock CHANGED
@@ -748,11 +748,11 @@ wheels = [
748
 
749
  [[package]]
750
  name = "idna"
751
- version = "3.14"
752
  source = { registry = "https://pypi.org/simple" }
753
- sdist = { url = "https://files.pythonhosted.org/packages/05/b1/efac073e0c297ecf2fb33c346989a529d4e19164f1759102dee5953ee17e/idna-3.14.tar.gz", hash = "sha256:466d810d7a2cc1022bea9b037c39728d51ae7dad40d480fc9b7d7ecf98ba8ee3", size = 198272, upload-time = "2026-05-10T20:32:15.935Z" }
754
  wheels = [
755
- { url = "https://files.pythonhosted.org/packages/6c/3c/3f62dee257eb3d6b2c1ef2a09d36d9793c7111156a73b5654d2c2305e5ce/idna-3.14-py3-none-any.whl", hash = "sha256:e677eaf072e290f7b725f9acf0b3a2bd55f9fd6f7c70abe5f0e34823d0accf69", size = 72184, upload-time = "2026-05-10T20:32:14.295Z" },
756
  ]
757
 
758
  [[package]]
@@ -1559,7 +1559,7 @@ wheels = [
1559
 
1560
  [[package]]
1561
  name = "persona-data"
1562
- version = "0.4.2"
1563
  source = { registry = "https://pypi.org/simple" }
1564
  dependencies = [
1565
  { name = "huggingface-hub" },
@@ -1568,9 +1568,9 @@ dependencies = [
1568
  { name = "python-dotenv" },
1569
  { name = "torch" },
1570
  ]
1571
- sdist = { url = "https://files.pythonhosted.org/packages/a4/2f/099a74e54846172a20b697b46b285eb2f0004e1db530308d6b4ff1f19079/persona_data-0.4.2.tar.gz", hash = "sha256:7870292a79b3943a77c31595140de3b2243b783222590248d09891de70e7fe1b", size = 9276, upload-time = "2026-05-08T13:59:27.58Z" }
1572
  wheels = [
1573
- { url = "https://files.pythonhosted.org/packages/57/03/e76a48b41ee00684a4430269007e217e70f59e2597d7c862d93cfc5ac78b/persona_data-0.4.2-py3-none-any.whl", hash = "sha256:c881d6fb71af87a6fa773284076e4cb55794db6dc447a7eb0047eee2b389c855", size = 11914, upload-time = "2026-05-08T13:59:28.198Z" },
1574
  ]
1575
 
1576
  [[package]]
@@ -1581,7 +1581,6 @@ dependencies = [
1581
  { name = "catppuccin" },
1582
  { name = "datasets" },
1583
  { name = "huggingface-hub" },
1584
- { name = "persona-data" },
1585
  { name = "persona-vectors" },
1586
  { name = "plotly" },
1587
  { name = "python-dotenv" },
@@ -1593,8 +1592,7 @@ requires-dist = [
1593
  { name = "catppuccin", specifier = ">=2.5.0" },
1594
  { name = "datasets", specifier = ">=4.8.5" },
1595
  { name = "huggingface-hub", specifier = ">=1.14.0" },
1596
- { name = "persona-data", specifier = ">=0.4.2" },
1597
- { name = "persona-vectors", specifier = ">=0.7.3" },
1598
  { name = "plotly", specifier = ">=6.6.0" },
1599
  { name = "python-dotenv", specifier = ">=1.2.2" },
1600
  { name = "streamlit", specifier = ">=1.44.0" },
@@ -1602,7 +1600,7 @@ requires-dist = [
1602
 
1603
  [[package]]
1604
  name = "persona-vectors"
1605
- version = "0.7.3"
1606
  source = { registry = "https://pypi.org/simple" }
1607
  dependencies = [
1608
  { name = "datasets" },
@@ -1621,9 +1619,9 @@ dependencies = [
1621
  { name = "transformers" },
1622
  { name = "umap-learn" },
1623
  ]
1624
- sdist = { url = "https://files.pythonhosted.org/packages/6d/36/25d766934dc43f60faeba8a51c698da78bdd9af2e5d191b7ce8721612dc4/persona_vectors-0.7.3.tar.gz", hash = "sha256:75a90e68142097419a2f1cf6d21878dc5202234c12ed342d63349796255baad6", size = 28641, upload-time = "2026-05-12T10:04:37.21Z" }
1625
  wheels = [
1626
- { url = "https://files.pythonhosted.org/packages/6b/e7/db961133fda6755e215e6cd9d4058a1cb93719d05ab6e24030c5da885d15/persona_vectors-0.7.3-py3-none-any.whl", hash = "sha256:abf07b6715321a16b218aede69f7efac7bf6a309e090db62ee376e3f09240fde", size = 33224, upload-time = "2026-05-12T10:04:36.351Z" },
1627
  ]
1628
 
1629
  [[package]]
@@ -2838,7 +2836,7 @@ wheels = [
2838
 
2839
  [[package]]
2840
  name = "transformers"
2841
- version = "5.8.0"
2842
  source = { registry = "https://pypi.org/simple" }
2843
  dependencies = [
2844
  { name = "huggingface-hub" },
@@ -2851,9 +2849,9 @@ dependencies = [
2851
  { name = "tqdm" },
2852
  { name = "typer" },
2853
  ]
2854
- sdist = { url = "https://files.pythonhosted.org/packages/f2/36/390075693b76d4fb4a2bea360fb6080347763bd1f1147c49ed0ed938778c/transformers-5.8.0.tar.gz", hash = "sha256:6cc9a1f0291d16b1c1b735bad775e78ebefff7722701d4e28f98aaaa2bd6fb91", size = 8528141, upload-time = "2026-05-05T16:50:04.778Z" }
2855
  wheels = [
2856
- { url = "https://files.pythonhosted.org/packages/97/7b/5621d08b34ac35deb9fa14b58d27d124d21ef125ee1c64bc724ca47dfb63/transformers-5.8.0-py3-none-any.whl", hash = "sha256:e9d2cae6d195a7e1e05164c5ebf26142a7044e4dc4267274f4809204f92827e4", size = 10630279, upload-time = "2026-05-05T16:50:01.026Z" },
2857
  ]
2858
 
2859
  [[package]]
 
748
 
749
  [[package]]
750
  name = "idna"
751
+ version = "3.15"
752
  source = { registry = "https://pypi.org/simple" }
753
+ sdist = { url = "https://files.pythonhosted.org/packages/82/77/7b3966d0b9d1d31a36ddf1746926a11dface89a83409bf1483f0237aa758/idna-3.15.tar.gz", hash = "sha256:ca962446ea538f7092a95e057da437618e886f4d349216d2b1e294abfdb65fdc", size = 199245, upload-time = "2026-05-12T22:45:57.011Z" }
754
  wheels = [
755
+ { url = "https://files.pythonhosted.org/packages/d2/23/408243171aa9aaba178d3e2559159c24c1171a641aa83b67bdd3394ead8e/idna-3.15-py3-none-any.whl", hash = "sha256:048adeaf8c2d788c40fee287673ccaa74c24ffd8dcf09ffa555a2fbb59f10ac8", size = 72340, upload-time = "2026-05-12T22:45:55.733Z" },
756
  ]
757
 
758
  [[package]]
 
1559
 
1560
  [[package]]
1561
  name = "persona-data"
1562
+ version = "0.5.1"
1563
  source = { registry = "https://pypi.org/simple" }
1564
  dependencies = [
1565
  { name = "huggingface-hub" },
 
1568
  { name = "python-dotenv" },
1569
  { name = "torch" },
1570
  ]
1571
+ sdist = { url = "https://files.pythonhosted.org/packages/de/9f/2257b6df8c28f0844b88f64a200a4d27f82ea10a16e657ba9fd02f561135/persona_data-0.5.1.tar.gz", hash = "sha256:5ac4467c449905fecf26a743b7128f76dbd984a076426c3ce854a13394c1fc5c", size = 10336, upload-time = "2026-05-13T11:55:00.356Z" }
1572
  wheels = [
1573
+ { url = "https://files.pythonhosted.org/packages/55/ec/328013ee81672ba800777b3a9c24f18dc7cb3a93223391e3476cac55fa1b/persona_data-0.5.1-py3-none-any.whl", hash = "sha256:ccf230b4028d08b9345910b57de6ea4b60e9ec7f65ce12203f69693988314543", size = 13078, upload-time = "2026-05-13T11:55:01.402Z" },
1574
  ]
1575
 
1576
  [[package]]
 
1581
  { name = "catppuccin" },
1582
  { name = "datasets" },
1583
  { name = "huggingface-hub" },
 
1584
  { name = "persona-vectors" },
1585
  { name = "plotly" },
1586
  { name = "python-dotenv" },
 
1592
  { name = "catppuccin", specifier = ">=2.5.0" },
1593
  { name = "datasets", specifier = ">=4.8.5" },
1594
  { name = "huggingface-hub", specifier = ">=1.14.0" },
1595
+ { name = "persona-vectors", specifier = ">=0.8.0" },
 
1596
  { name = "plotly", specifier = ">=6.6.0" },
1597
  { name = "python-dotenv", specifier = ">=1.2.2" },
1598
  { name = "streamlit", specifier = ">=1.44.0" },
 
1600
 
1601
  [[package]]
1602
  name = "persona-vectors"
1603
+ version = "0.8.0"
1604
  source = { registry = "https://pypi.org/simple" }
1605
  dependencies = [
1606
  { name = "datasets" },
 
1619
  { name = "transformers" },
1620
  { name = "umap-learn" },
1621
  ]
1622
+ sdist = { url = "https://files.pythonhosted.org/packages/76/22/8a0ca0e6e54ebd8dd07a4064c2890ec33b68ad81a00e4e93c4f9eee2bcf7/persona_vectors-0.8.0.tar.gz", hash = "sha256:3775afc7e04ab1d02582e9c4b3f2d124174ea40d376dd2b91492457a747dd553", size = 31938, upload-time = "2026-05-13T20:00:46.357Z" }
1623
  wheels = [
1624
+ { url = "https://files.pythonhosted.org/packages/43/a6/7f67a7df27d78db706cbc9afd5d5ca4b52970b9005717c3bfcc0ce90ec71/persona_vectors-0.8.0-py3-none-any.whl", hash = "sha256:08b37a749f98b764d22d4c943158922338ab054729f7137eff2c3a167e2b2ae5", size = 36838, upload-time = "2026-05-13T20:00:47.252Z" },
1625
  ]
1626
 
1627
  [[package]]
 
2836
 
2837
  [[package]]
2838
  name = "transformers"
2839
+ version = "5.8.1"
2840
  source = { registry = "https://pypi.org/simple" }
2841
  dependencies = [
2842
  { name = "huggingface-hub" },
 
2849
  { name = "tqdm" },
2850
  { name = "typer" },
2851
  ]
2852
+ sdist = { url = "https://files.pythonhosted.org/packages/e7/e6/4134ea2fbea322cddc7ffc94a0d8ee47fe32ce8e876b320cd37d88edfc4d/transformers-5.8.1.tar.gz", hash = "sha256:4dd5b6de4105725104d84fd6abd74b305f4debfc251b38c648ee5dd087cf543b", size = 8532019, upload-time = "2026-05-13T03:21:57.234Z" }
2853
  wheels = [
2854
+ { url = "https://files.pythonhosted.org/packages/fc/b1/8be7e7ef0b5200491312201918b6125ef9c9df9dd0f0240ccef9ac824e6b/transformers-5.8.1-py3-none-any.whl", hash = "sha256:5340fb95962162cdfdae5cc91d7f8fedd92ed75216c1154c5e1f590fcf56dd0e", size = 10632882, upload-time = "2026-05-13T03:21:52.876Z" },
2855
  ]
2856
 
2857
  [[package]]