Jac-Zac committed on
Commit
9ba2da4
·
1 Parent(s): 77c2d62

Update code to support the latest versions of persona-vectors and persona-data

Browse files
Files changed (10) hide show
  1. pyproject.toml +2 -3
  2. tabs/chat.py +136 -93
  3. tabs/chat_ui.py +60 -6
  4. tabs/compare.py +204 -253
  5. tabs/compare_chat.py +326 -243
  6. tabs/extract.py +390 -258
  7. tabs/probe_ui.py +164 -125
  8. utils/contrast.py +17 -28
  9. utils/runtime.py +58 -43
  10. uv.lock +0 -0
pyproject.toml CHANGED
@@ -5,12 +5,11 @@ description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
- "persona-vectors>=0.4.4",
9
- "persona-data>=0.3.4",
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
12
  "python-dotenv>=1.2.2",
13
- "transformers>=5.5.0",
14
  ]
15
 
16
  # Local development:
 
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
+ "persona-vectors>=0.5.1",
9
+ "persona-data>=0.4.0",
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
12
  "python-dotenv>=1.2.2",
 
13
  ]
14
 
15
  # Local development:
tabs/chat.py CHANGED
@@ -1,11 +1,11 @@
1
  import streamlit as st
 
2
 
3
- from state import chat_session_key, get_chat_state, reset_chat_context_state
4
  from tabs.chat_ui import (
5
  GenerationConfig,
6
- generation_dict,
7
  render_chat_window,
8
- render_generation_settings,
9
  render_persona_prompt_controls,
10
  render_system_prompt,
11
  )
@@ -27,21 +27,7 @@ _LAST_PROMPT_MODE_KEY = "chat:last_prompt_mode"
27
  _LAST_COMPARE_MODE_KEY = "chat:last_compare_mode"
28
 
29
 
30
- def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
31
- """Render the chat tab."""
32
-
33
- st.title("Chat")
34
-
35
- context_key = chat_session_key(model_name, dataset_source)
36
- chat_state = get_chat_state(model_name, remote, dataset_source)
37
-
38
- # Carry over persona / prompt selections across model or remote switches.
39
- if chat_state["persona_id"] is None:
40
- chat_state["persona_id"] = st.session_state.get(_LAST_PERSONA_ID_KEY)
41
- chat_state["prompt_mode"] = st.session_state.get(
42
- _LAST_PROMPT_MODE_KEY, "templated"
43
- )
44
-
45
  try:
46
  dataset, dataset_status = load_dataset(
47
  dataset_source,
@@ -52,36 +38,124 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
52
  except Exception as exc:
53
  st.error(f"Could not load data: {exc}")
54
  st.info("Check the selected dataset source or upload both JSONL files.")
55
- return
56
 
57
  personas = list(dataset)
58
  if not personas:
59
  st.warning("No personas found in the selected dataset.")
60
  st.info("Try a different dataset source or upload a non-empty personas file.")
61
- return
 
62
 
63
- generation: GenerationConfig = render_generation_settings(context_key, remote)
64
- probe_enabled = st.toggle(
65
- "Probe tools",
66
- value=False,
67
- key=widget_key(context_key, "probe_enabled"),
68
- help="Trace chat activations and run compatible `.pt` probes on tapped tokens.",
69
- )
70
 
71
- # ── Mode toggle ───────────────────────────────────────────────────────────
72
- compare_key = widget_key(context_key, "compare_mode")
73
- if compare_key not in st.session_state:
74
- st.session_state[compare_key] = st.session_state.get(
75
- _LAST_COMPARE_MODE_KEY, False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  )
77
- compare_mode = st.toggle(
78
- "Compare mode",
79
- key=compare_key,
80
- help="Side-by-side: send one message to two independent persona/prompt configurations.",
81
- )
82
- st.session_state[_LAST_COMPARE_MODE_KEY] = compare_mode
83
 
84
- if compare_mode:
 
 
 
 
 
 
 
 
 
85
  render_compare_mode(
86
  remote,
87
  model_name,
@@ -89,6 +163,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
89
  dataset_source,
90
  personas,
91
  generation,
 
92
  )
93
  return
94
 
@@ -150,7 +225,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
150
  remote=remote,
151
  active_system_prompt=active_system_prompt,
152
  chat_state=chat_state,
153
- enabled=probe_enabled,
154
  )
155
 
156
  render_chat_window(
@@ -161,36 +236,18 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
161
  pending_key=pending_key,
162
  )
163
 
164
- footer = st.container()
165
- with footer:
166
- exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
167
- with exp_col:
168
- if st.button(
169
- "",
170
- icon=":material/download:",
171
- key=export_key,
172
- help="Export chat",
173
- ):
174
- save_chat_export(
175
- model_name=model_name,
176
- dataset_source=dataset_source,
177
- persona_id=selected_persona.id,
178
- persona_name=getattr(selected_persona, "name", None),
179
- prompt_mode=prompt_mode,
180
- system_prompt=active_system_prompt,
181
- messages=chat_state["messages"],
182
- generation=generation_dict(generation),
183
- )
184
- st.toast("Exported", icon=":material/check:")
185
- with rst_col:
186
- if st.button(
187
- "",
188
- icon=":material/delete_sweep:",
189
- key=reset_key,
190
- help="Reset chat",
191
- ):
192
- _reset_active_chat_context()
193
- st.rerun()
194
 
195
  user_prompt = st.chat_input("Ask something...", key=chat_input_key)
196
 
@@ -205,26 +262,12 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
205
  if not pending_action:
206
  return
207
 
208
- messages = build_chat_messages(active_system_prompt, chat_state["messages"])
209
-
210
- with st.spinner("Generating reply..."):
211
- model = cached_model(model_name=model_name, remote=remote)
212
- try:
213
- reply: ChatReply = generate_chat_reply(
214
- model=model,
215
- messages=messages,
216
- remote=remote,
217
- past_key_values=chat_state["past_key_values"],
218
- **generation.to_generate_kwargs(),
219
- )
220
- except Exception as exc:
221
- with chat_log:
222
- st.error(f"Could not generate a reply: {exc}")
223
- st.info("Try a shorter prompt, reset the chat, or switch personas.")
224
- if pending_action == "new_user_prompt" and chat_state["messages"]:
225
- chat_state["messages"].pop()
226
- return
227
-
228
- chat_state["messages"].append({"role": "assistant", "content": reply.text})
229
- chat_state["past_key_values"] = reply.past_key_values if not remote else None
230
- st.rerun()
 
1
  import streamlit as st
2
+ from persona_data.synth_persona import PersonaData
3
 
4
+ from state import ChatState, chat_session_key, get_chat_state, reset_chat_context_state
5
  from tabs.chat_ui import (
6
  GenerationConfig,
7
+ render_advanced_settings,
8
  render_chat_window,
 
9
  render_persona_prompt_controls,
10
  render_system_prompt,
11
  )
 
27
  _LAST_COMPARE_MODE_KEY = "chat:last_compare_mode"
28
 
29
 
30
+ def _load_personas(dataset_source: str) -> list[PersonaData] | None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  try:
32
  dataset, dataset_status = load_dataset(
33
  dataset_source,
 
38
  except Exception as exc:
39
  st.error(f"Could not load data: {exc}")
40
  st.info("Check the selected dataset source or upload both JSONL files.")
41
+ return None
42
 
43
  personas = list(dataset)
44
  if not personas:
45
  st.warning("No personas found in the selected dataset.")
46
  st.info("Try a different dataset source or upload a non-empty personas file.")
47
+ return None
48
+ return personas
49
 
 
 
 
 
 
 
 
50
 
51
def _render_single_chat_footer(
    *,
    model_name: str,
    dataset_source: str,
    persona: PersonaData,
    prompt_mode: str,
    system_prompt: str | None,
    chat_state: ChatState,
    generation: GenerationConfig,
    export_key: str,
    reset_key: str,
    on_reset,
) -> None:
    """Render the export / reset icon-button row for the single-chat view.

    Args:
        model_name: Forwarded to ``save_chat_export`` when exporting.
        dataset_source: Forwarded to ``save_chat_export`` when exporting.
        persona: Persona whose ``id`` (and ``name``, if present) is exported.
        prompt_mode: Forwarded to ``save_chat_export`` when exporting.
        system_prompt: Forwarded to ``save_chat_export`` when exporting.
        chat_state: Chat state; its ``"messages"`` entry is exported.
        generation: Generation settings, exported via ``to_export_dict()``.
        export_key: Streamlit widget key for the export button.
        reset_key: Streamlit widget key for the reset button.
        on_reset: Zero-argument callable invoked when reset is clicked,
            immediately before ``st.rerun()``.
    """
    with st.container():
        # Two narrow button columns followed by a wide spacer column.
        export_col, reset_col, _ = st.columns([0.5, 0.5, 10], gap="xsmall")

        with export_col:
            export_clicked = st.button(
                "",
                icon=":material/download:",
                key=export_key,
                help="Export chat",
            )
            if export_clicked:
                # Build the payload only on click so nothing is evaluated
                # on ordinary reruns.
                export_fields = {
                    "model_name": model_name,
                    "dataset_source": dataset_source,
                    "persona_id": persona.id,
                    "persona_name": getattr(persona, "name", None),
                    "prompt_mode": prompt_mode,
                    "system_prompt": system_prompt,
                    "messages": chat_state["messages"],
                    "generation": generation.to_export_dict(),
                }
                save_chat_export(**export_fields)
                st.toast("Exported", icon=":material/check:")

        with reset_col:
            reset_clicked = st.button(
                "",
                icon=":material/delete_sweep:",
                key=reset_key,
                help="Reset chat",
            )
            if reset_clicked:
                on_reset()
                st.rerun()
94
+
95
+
96
def _handle_single_chat_generation(
    *,
    remote: bool,
    model_name: str,
    chat_state: ChatState,
    active_system_prompt: str | None,
    generation: GenerationConfig,
    pending_action: object,
    chat_log,
) -> None:
    """Generate the assistant reply for the pending turn and store it.

    On success the reply text is appended to ``chat_state["messages"]``,
    ``chat_state["past_key_values"]`` is updated (set to ``None`` for remote
    runs), and the script is rerun. On failure an error and a hint are shown
    inside ``chat_log``; if the pending action was ``"new_user_prompt"`` the
    last message in the history is removed again.

    Args:
        remote: Whether generation runs remotely; forwarded to
            ``generate_chat_reply``.
        model_name: Name passed to ``cached_model``.
        chat_state: Chat state holding ``"messages"`` and
            ``"past_key_values"``.
        active_system_prompt: System prompt passed to ``build_chat_messages``.
        generation: Settings expanded via ``to_generate_kwargs()``.
        pending_action: Action marker compared against ``"new_user_prompt"``.
        chat_log: Context manager in which failure messages are rendered.
    """
    history = chat_state["messages"]
    prompt_messages = build_chat_messages(active_system_prompt, history)

    with st.spinner("Generating reply..."):
        model = cached_model(model_name=model_name)
        try:
            reply: ChatReply = generate_chat_reply(
                model=model,
                messages=prompt_messages,
                remote=remote,
                past_key_values=chat_state["past_key_values"],
                **generation.to_generate_kwargs(),
            )
        except Exception as exc:
            with chat_log:
                st.error(f"Could not generate a reply: {exc}")
                st.info("Try a shorter prompt, reset the chat, or switch personas.")
            # Roll back the just-added prompt so the history does not end on
            # an unanswered message.
            if pending_action == "new_user_prompt" and history:
                history.pop()
            return

    history.append({"role": "assistant", "content": reply.text})
    if remote:
        chat_state["past_key_values"] = None
    else:
        chat_state["past_key_values"] = reply.past_key_values
    st.rerun()
129
+
130
+
131
+
132
+
133
+ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
134
+ """Render the chat tab."""
135
+
136
+ st.title("Chat")
137
+ st.caption("Chat with a persona, optionally side-by-side or with token contrast.")
138
+
139
+ context_key = chat_session_key(model_name, dataset_source)
140
+ chat_state = get_chat_state(model_name, remote, dataset_source)
141
+
142
+ # Carry over persona / prompt selections across model or remote switches.
143
+ if chat_state["persona_id"] is None:
144
+ chat_state["persona_id"] = st.session_state.get(_LAST_PERSONA_ID_KEY)
145
+ chat_state["prompt_mode"] = st.session_state.get(
146
+ _LAST_PROMPT_MODE_KEY, "templated"
147
  )
 
 
 
 
 
 
148
 
149
+ personas = _load_personas(dataset_source)
150
+ if personas is None:
151
+ return
152
+
153
+ generation, tools = render_advanced_settings(
154
+ context_key,
155
+ remote,
156
+ last_compare_mode_key=_LAST_COMPARE_MODE_KEY,
157
+ )
158
+ if tools.compare_mode:
159
  render_compare_mode(
160
  remote,
161
  model_name,
 
163
  dataset_source,
164
  personas,
165
  generation,
166
+ contrast_enabled=tools.token_contrast,
167
  )
168
  return
169
 
 
225
  remote=remote,
226
  active_system_prompt=active_system_prompt,
227
  chat_state=chat_state,
228
+ enabled=tools.probe_enabled,
229
  )
230
 
231
  render_chat_window(
 
236
  pending_key=pending_key,
237
  )
238
 
239
+ _render_single_chat_footer(
240
+ model_name=model_name,
241
+ dataset_source=dataset_source,
242
+ persona=selected_persona,
243
+ prompt_mode=prompt_mode,
244
+ system_prompt=active_system_prompt,
245
+ chat_state=chat_state,
246
+ generation=generation,
247
+ export_key=export_key,
248
+ reset_key=reset_key,
249
+ on_reset=_reset_active_chat_context,
250
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
  user_prompt = st.chat_input("Ask something...", key=chat_input_key)
253
 
 
262
  if not pending_action:
263
  return
264
 
265
+ _handle_single_chat_generation(
266
+ remote=remote,
267
+ model_name=model_name,
268
+ chat_state=chat_state,
269
+ active_system_prompt=active_system_prompt,
270
+ generation=generation,
271
+ pending_action=pending_action,
272
+ chat_log=chat_log,
273
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tabs/chat_ui.py CHANGED
@@ -48,6 +48,13 @@ class GenerationConfig:
48
  }
49
 
50
 
 
 
 
 
 
 
 
51
  @st.dialog("Edit", width="medium")
52
  def _open_edit_dialog(
53
  *,
@@ -108,13 +115,54 @@ def _open_system_prompt_dialog(*, prompt_key: str, current_value: str) -> None:
108
  st.rerun()
109
 
110
 
111
- def generation_dict(config: GenerationConfig) -> dict[str, object]:
112
- return config.to_export_dict()
 
 
 
 
 
 
 
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
- def render_generation_settings(context_key: str, remote: bool) -> GenerationConfig:
116
- """Render the Advanced generation settings expander."""
117
- with st.expander("Advanced", expanded=False):
118
  config_col1, config_col2 = st.columns([2, 1])
119
  with config_col1:
120
  max_new_tokens = st.slider(
@@ -199,7 +247,7 @@ def render_generation_settings(context_key: str, remote: bool) -> GenerationConf
199
  st.caption("Seed is local-only and disabled for remote runs.")
200
 
201
  do_sample = bool(use_sampling)
202
- return GenerationConfig(
203
  max_new_tokens=int(max_new_tokens),
204
  do_sample=do_sample,
205
  temperature=float(temperature),
@@ -208,6 +256,12 @@ def render_generation_settings(context_key: str, remote: bool) -> GenerationConf
208
  repetition_penalty=float(repetition_penalty),
209
  seed=seed if do_sample and seed is not None and not remote else None,
210
  )
 
 
 
 
 
 
211
 
212
 
213
  def render_chat_message(
 
48
  }
49
 
50
 
51
@dataclass(frozen=True)
class ChatTools:
    """Read-only bundle of the tool toggles returned by ``render_advanced_settings``."""

    # State of the "Probe tools" toggle.
    probe_enabled: bool
    # State of the "Compare mode" toggle.
    compare_mode: bool
    # State of the "Token contrast" toggle; the builder only sets this True
    # when compare mode is also on.
    token_contrast: bool
56
+
57
+
58
  @st.dialog("Edit", width="medium")
59
  def _open_edit_dialog(
60
  *,
 
115
  st.rerun()
116
 
117
 
118
+ def render_advanced_settings(
119
+ context_key: str,
120
+ remote: bool,
121
+ *,
122
+ last_compare_mode_key: str,
123
+ ) -> tuple[GenerationConfig, ChatTools]:
124
+ """Render the Advanced expander: tool toggles + generation settings."""
125
+ with st.expander("Advanced", expanded=False):
126
+ st.caption("Tools")
127
 
128
+ compare_key = widget_key(context_key, "compare_mode")
129
+ if compare_key not in st.session_state:
130
+ st.session_state[compare_key] = st.session_state.get(
131
+ last_compare_mode_key, False
132
+ )
133
+
134
+ tools_col1, tools_col2, tools_col3 = st.columns(3)
135
+ with tools_col1:
136
+ probe_enabled = st.toggle(
137
+ "Probe tools",
138
+ value=False,
139
+ key=widget_key(context_key, "probe_enabled"),
140
+ help="Trace chat activations and run compatible `.pt` probes on tapped tokens.",
141
+ )
142
+ with tools_col2:
143
+ compare_mode = st.toggle(
144
+ "Compare mode",
145
+ key=compare_key,
146
+ help="Side-by-side: send one message to two independent persona/prompt configurations.",
147
+ )
148
+ with tools_col3:
149
+ token_contrast = st.toggle(
150
+ "Token contrast",
151
+ value=False,
152
+ key=widget_key(context_key, "token_contrast"),
153
+ disabled=not compare_mode,
154
+ help=(
155
+ "Color each generated token by how characteristic it is of each persona. "
156
+ "Red = more likely under the left persona, blue = more likely under the "
157
+ "right. Requires up to four extra scoring passes after each turn. "
158
+ "Available only in Compare mode."
159
+ ),
160
+ )
161
+ st.session_state[last_compare_mode_key] = compare_mode
162
+
163
+ st.divider()
164
+ st.caption("Generation")
165
 
 
 
 
166
  config_col1, config_col2 = st.columns([2, 1])
167
  with config_col1:
168
  max_new_tokens = st.slider(
 
247
  st.caption("Seed is local-only and disabled for remote runs.")
248
 
249
  do_sample = bool(use_sampling)
250
+ generation = GenerationConfig(
251
  max_new_tokens=int(max_new_tokens),
252
  do_sample=do_sample,
253
  temperature=float(temperature),
 
256
  repetition_penalty=float(repetition_penalty),
257
  seed=seed if do_sample and seed is not None and not remote else None,
258
  )
259
+ tools = ChatTools(
260
+ probe_enabled=probe_enabled,
261
+ compare_mode=compare_mode,
262
+ token_contrast=token_contrast and compare_mode,
263
+ )
264
+ return generation, tools
265
 
266
 
267
  def render_chat_message(
tabs/compare.py CHANGED
@@ -1,8 +1,9 @@
 
1
  from itertools import combinations
 
2
 
3
  import streamlit as st
4
  from persona_data.environment import get_artifacts_dir
5
- from persona_data.prompts import BASELINE_PERSONA_ID
6
  from persona_vectors.analysis import (
7
  load_persona_mean_samples,
8
  load_variant_mean_samples,
@@ -41,6 +42,15 @@ _LAST_PROJECTION_PERSONAS_KEY = "compare:last_personas:projection"
41
  _LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
42
 
43
 
 
 
 
 
 
 
 
 
 
44
  def _select_artifact_personas(
45
  store: ActivationStore,
46
  variants: list[str],
@@ -143,71 +153,157 @@ def _render_mask_strategy_select(scope: str) -> MaskStrategy:
143
  return selected
144
 
145
 
146
- def _render_cosine_similarity(
147
  store: ActivationStore,
148
  mask_strategy: MaskStrategy,
149
- ) -> None:
150
  variants = list(store.variants)
151
  if len(variants) < 2:
152
  st.info("Need at least two non-baseline variants for cosine comparison.")
153
- return
154
 
155
- col1, col2 = st.columns(2)
156
- with col1:
157
- variant_a = st.selectbox(
158
- "Variant A",
159
- options=variants,
160
- index=0,
161
- format_func=prompt_variant_label,
162
- key=widget_key("load", "variant_a"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  )
164
- with col2:
165
- variant_b = st.selectbox(
166
- "Variant B",
167
- options=variants,
168
- index=min(1, len(variants) - 1),
169
- format_func=prompt_variant_label,
170
- key=widget_key("load", "variant_b"),
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  )
 
 
 
172
 
173
- if variant_a == variant_b:
174
- st.warning("Choose two different variants to compare.")
175
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
- persona_ids, _ = _select_artifact_personas(
178
- store,
179
- [variant_a, variant_b],
180
- mask_strategy,
181
- widget_scope="cosine",
182
- remember_key=_LAST_COSINE_PERSONAS_KEY,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  )
184
- if not persona_ids:
 
 
 
 
 
 
 
 
185
  return
186
- persona_key = "_".join(sorted(persona_ids))
187
 
188
  cosine_fig_key = widget_key(
189
  "load",
190
  "cosine_fig_state",
191
  store.model_name,
192
  mask_strategy.value,
193
- variant_a,
194
- variant_b,
195
- persona_key,
196
  )
197
  filename = _filename(
198
  "compare",
199
  "cosine",
200
  store.model_name,
201
  mask_strategy.value,
202
- variant_a,
203
- variant_b,
204
  )
205
  pairs_filename = _filename(
206
  "compare",
207
  "cosine_pairs",
208
  store.model_name,
209
  mask_strategy.value,
210
- "_".join(variants),
211
  )
212
 
213
  if st.button(
@@ -218,79 +314,16 @@ def _render_cosine_similarity(
218
  "compare_vectors",
219
  store.model_name,
220
  mask_strategy.value,
221
- variant_a,
222
- variant_b,
223
- persona_key,
224
  ),
225
  ):
226
- try:
227
- variant_samples = load_variant_mean_samples(
228
- store,
229
- [variant_a, variant_b],
230
- persona_ids=persona_ids,
231
- )
232
- except Exception as exc:
233
- st.error(f"Could not load vectors: {exc}")
234
  st.session_state.pop(cosine_fig_key, None)
235
  return
236
-
237
- labels = variant_samples[variant_a].labels
238
- display_traces = [
239
- (
240
- label,
241
- variant_samples[variant_a].vectors[index],
242
- variant_samples[variant_b].vectors[index],
243
- )
244
- for index, label in enumerate(labels)
245
- ]
246
- fig = plot_layer_similarity(
247
- display_traces,
248
- title=f"{prompt_variant_label(variant_a)} vs {prompt_variant_label(variant_b)}",
249
- show=False,
250
- )
251
-
252
- pair_traces = []
253
- pair_errors = []
254
- for left, right in combinations(variants, 2):
255
- try:
256
- pair_samples = (
257
- variant_samples
258
- if {left, right} == {variant_a, variant_b}
259
- else load_variant_mean_samples(
260
- store,
261
- [left, right],
262
- persona_ids=persona_ids,
263
- )
264
- )
265
- except Exception as exc:
266
- pair_errors.append(f"{left} vs {right}: {exc}")
267
- continue
268
- pair_traces.append(
269
- (
270
- f"{prompt_variant_label(left)} vs {prompt_variant_label(right)}",
271
- pair_samples[left].vectors.mean(dim=0),
272
- pair_samples[right].vectors.mean(dim=0),
273
- )
274
- )
275
-
276
- if pair_errors:
277
- for err in pair_errors:
278
- st.warning(f"Skipped pair trace: `{err}`")
279
- pair_fig = (
280
- plot_layer_similarity(
281
- pair_traces,
282
- title="Variant-pair cosine similarity averaged over selected personas",
283
- show=False,
284
- )
285
- if pair_traces
286
- else None
287
- )
288
- st.session_state[cosine_fig_key] = (
289
- fig,
290
- pair_fig,
291
- len(display_traces),
292
- len(pair_traces),
293
- )
294
 
295
  if cosine_fig_key in st.session_state:
296
  fig, pair_fig, n_traces, n_pair_traces = st.session_state[cosine_fig_key]
@@ -369,190 +402,89 @@ def _select_single_variant_samples(
369
  return variant, persona_ids, persona_key, selected_layers
370
 
371
 
372
- def _baseline_available(
373
- store: ActivationStore,
374
- ) -> bool:
375
- return BASELINE_PERSONA_ID in store.list_personas(
376
- [BASELINE_PERSONA_ID],
377
- warn_missing=False,
378
- )
379
-
380
-
381
- def _render_baseline_reference_toggle(
382
  store: ActivationStore,
383
  mask_strategy: MaskStrategy,
 
384
  scope: str,
385
- ) -> bool:
386
- available = _baseline_available(store)
387
- return st.checkbox(
388
- "Include Assistant baseline reference",
389
- value=available,
390
- disabled=not available,
391
- key=widget_key("load", "include_baseline", scope, mask_strategy.value),
392
- help=(
393
- "Adds the single saved baseline artifact as one reference sample."
394
- if available
395
- else "Run Assistant baseline extraction first."
396
- ),
397
- )
398
-
399
-
400
- def _render_similarity_matrix(
401
- store: ActivationStore,
402
- mask_strategy: MaskStrategy,
403
  ) -> None:
404
- selected = _select_single_variant_samples(
405
- store,
406
- mask_strategy,
407
- "similarity_matrix",
408
- )
409
- if selected is None:
410
- return
411
- variant, persona_ids, persona_key, selected_layers = selected
412
- include_baseline = _render_baseline_reference_toggle(
413
- store,
414
- mask_strategy,
415
- "similarity_matrix",
416
- )
417
-
418
- fig_key = widget_key(
419
- "load",
420
- "similarity_matrix_fig_state",
421
- store.model_name,
422
- mask_strategy.value,
423
- variant,
424
- "persona_mean",
425
- persona_key,
426
- BASELINE_PERSONA_ID if include_baseline else "no_baseline",
427
- )
428
- filename = _filename(
429
- "compare",
430
- "similarity_matrix",
431
- store.model_name,
432
- mask_strategy.value,
433
- variant,
434
- "persona_mean",
435
- persona_key,
436
- BASELINE_PERSONA_ID if include_baseline else "",
437
- )
438
 
439
- if st.button("Generate similarity matrix", type="primary"):
440
- try:
441
- samples = load_persona_mean_samples(
442
- store,
443
- variant,
444
- mask_strategy=mask_strategy,
445
- persona_ids=persona_ids,
446
- include_baseline=include_baseline,
447
- )
448
- matrix_fig = build_layered_figure(
449
- samples,
450
- "similarity",
451
- layers=selected_layers,
452
- title=(
453
- "Centered similarity - "
454
- f"{prompt_variant_label(variant)} - personas averaged over questions"
455
- ),
456
- )
457
- trajectory_fig = build_pair_similarity_figure(
458
- samples,
459
- layers=selected_layers,
460
- title=(
461
- "Pair similarity trajectories - "
462
- f"{prompt_variant_label(variant)} - personas averaged over questions"
463
- ),
464
- )
465
- st.session_state[fig_key] = (
466
- matrix_fig,
467
- trajectory_fig,
468
- samples.vectors.shape[0],
469
- )
470
- except Exception as exc:
471
- st.error(f"Could not build similarity matrix: {exc}")
472
- st.session_state.pop(fig_key, None)
473
-
474
- if fig_key in st.session_state:
475
- matrix_fig, trajectory_fig, n_samples = st.session_state[fig_key]
476
- st.plotly_chart(matrix_fig, width="stretch")
477
- st.subheader("Pair trajectories")
478
- st.plotly_chart(trajectory_fig, width="stretch")
479
- _render_save_buttons(
480
- [matrix_fig, trajectory_fig],
481
- [filename, f"{filename}__pair_trajectories"],
482
- "similarity_matrix",
483
- )
484
- st.success(f"Loaded {n_samples} samples.")
485
-
486
-
487
- def _render_embedding_analysis(
488
- store: ActivationStore,
489
- analysis_mode: str,
490
- mask_strategy: MaskStrategy,
491
- ) -> None:
492
- selected = _select_single_variant_samples(
493
- store,
494
- mask_strategy,
495
- analysis_mode.lower(),
496
- )
497
  if selected is None:
498
  return
499
  variant, persona_ids, persona_key, selected_layers = selected
500
 
501
- figure_kind = analysis_mode.lower()
502
- include_baseline = _render_baseline_reference_toggle(
503
- store,
504
- mask_strategy,
505
- analysis_mode.lower(),
506
- )
507
-
508
  fig_key = widget_key(
509
  "load",
510
- "embedding_fig_state",
511
  store.model_name,
512
  mask_strategy.value,
513
  figure_kind,
514
  variant,
515
  "persona_mean",
516
  persona_key,
517
- BASELINE_PERSONA_ID if include_baseline else "no_baseline",
518
  )
519
  filename = _filename(
520
  "compare",
521
- figure_kind,
522
  store.model_name,
523
  mask_strategy.value,
524
  variant,
525
  "persona_mean",
526
  persona_key,
527
- BASELINE_PERSONA_ID if include_baseline else "",
528
  )
529
 
530
- if st.button(f"Generate {analysis_mode} projection", type="primary"):
531
  try:
532
  samples = load_persona_mean_samples(
533
  store,
534
  variant,
535
  mask_strategy=mask_strategy,
536
  persona_ids=persona_ids,
537
- include_baseline=include_baseline,
538
  )
539
- fig = build_layered_figure(
540
  samples,
541
  figure_kind,
542
  layers=selected_layers,
543
- title=(
544
- f"{analysis_mode} - {prompt_variant_label(variant)} - Persona means"
545
- ),
 
 
 
 
 
 
 
 
 
 
 
546
  )
547
- st.session_state[fig_key] = (fig, samples.vectors.shape[0])
548
  except Exception as exc:
549
- st.error(f"Could not build {analysis_mode}: {exc}")
550
  st.session_state.pop(fig_key, None)
551
 
552
  if fig_key in st.session_state:
553
- fig, n_samples = st.session_state[fig_key]
554
- st.plotly_chart(fig, width="stretch")
555
- _render_save_buttons([fig], [filename], figure_kind)
 
 
 
 
 
 
 
556
  st.success(f"Loaded {n_samples} samples.")
557
 
558
 
@@ -562,9 +494,7 @@ def render_compare_tab(model_name: str) -> None:
562
  st.title("Compare")
563
  st.caption("Compare saved activations by cosine similarity, PCA, or UMAP.")
564
 
565
- st.subheader("Analysis")
566
-
567
- with st.expander("Advanced", expanded=False):
568
  artifacts_root = st.text_input(
569
  "Artifacts root",
570
  value=str(get_artifacts_dir() / "activations"),
@@ -580,14 +510,35 @@ def render_compare_tab(model_name: str) -> None:
580
  if analysis_mode is None:
581
  analysis_mode = ANALYSIS_MODES[0]
582
  st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
583
- mask_strategy = _render_mask_strategy_select(analysis_mode)
 
584
  store = ActivationStore(model_name, artifacts_root, mask_strategy=mask_strategy)
585
 
586
  if analysis_mode == "Cosine similarity":
587
  _render_cosine_similarity(store, mask_strategy)
588
  return
589
  if analysis_mode == "Similarity matrix":
590
- _render_similarity_matrix(store, mask_strategy)
 
 
 
 
 
 
 
 
 
 
 
591
  return
592
 
593
- _render_embedding_analysis(store, analysis_mode, mask_strategy)
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Callable
2
  from itertools import combinations
3
+ from dataclasses import dataclass
4
 
5
  import streamlit as st
6
  from persona_data.environment import get_artifacts_dir
 
7
  from persona_vectors.analysis import (
8
  load_persona_mean_samples,
9
  load_variant_mean_samples,
 
42
  _LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
43
 
44
 
45
@dataclass(frozen=True)
class CosineSelection:
    """Immutable bundle of the user's choices for a cosine comparison.

    Produced by ``_render_cosine_selection`` and consumed by the figure
    builder and the widget/filename key construction.
    """

    # All variants available in the store (``list(store.variants)``).
    variants: list[str]
    # The two variants chosen in the "Variant A" / "Variant B" select boxes.
    variant_a: str
    variant_b: str
    # Persona ids chosen for the comparison.
    persona_ids: list[str]
    # Sorted persona ids joined with "_"; used as a stable cache/widget key part.
    persona_key: str
52
+
53
+
54
  def _select_artifact_personas(
55
  store: ActivationStore,
56
  variants: list[str],
 
153
  return selected
154
 
155
 
156
def _render_cosine_selection(
    store: ActivationStore,
    mask_strategy: MaskStrategy,
) -> CosineSelection | None:
    """Render the variant/persona pickers for the cosine-similarity view.

    Args:
        store: Activation store whose ``variants`` populate the select boxes.
        mask_strategy: Forwarded to ``_select_artifact_personas``.

    Returns:
        A ``CosineSelection`` describing the chosen variants and personas, or
        ``None`` when fewer than two variants exist, both select boxes point
        at the same variant, or no personas were selected.
    """
    variants = list(store.variants)
    if len(variants) < 2:
        st.info("Need at least two non-baseline variants for cosine comparison.")
        return None

    # NOTE(review): indentation was lost in the diff rendering this block was
    # recovered from. Which statements sit inside the expander (beyond the two
    # select boxes) is reconstructed -- confirm against the repository.
    with st.expander("Vector selection", expanded=True):
        col1, col2 = st.columns(2)
        with col1:
            variant_a = st.selectbox(
                "Variant A",
                options=variants,
                index=0,
                format_func=prompt_variant_label,
                key=widget_key("load", "variant_a"),
            )
        with col2:
            variant_b = st.selectbox(
                "Variant B",
                options=variants,
                # Default to the second variant so A and B differ initially.
                index=min(1, len(variants) - 1),
                format_func=prompt_variant_label,
                key=widget_key("load", "variant_b"),
            )

        if variant_a == variant_b:
            st.warning("Choose two different variants to compare.")
            return None

        persona_ids, _ = _select_artifact_personas(
            store,
            [variant_a, variant_b],
            mask_strategy,
            widget_scope="cosine",
            remember_key=_LAST_COSINE_PERSONAS_KEY,
        )

    if not persona_ids:
        return None
    return CosineSelection(
        variants=variants,
        variant_a=variant_a,
        variant_b=variant_b,
        persona_ids=persona_ids,
        # Sorted so the key does not depend on selection order.
        persona_key="_".join(sorted(persona_ids)),
    )
204
+
205
+
206
def _build_cosine_figures(
    store: ActivationStore,
    selection: CosineSelection,
) -> tuple[object, object | None, int, int] | None:
    """Build the cosine-similarity figures for the current selection.

    The first figure has one trace per label of the selected variants,
    comparing variant A's vector against variant B's vector at the same index.
    The second figure has one trace per unordered pair of all available
    variants, comparing the vectors averaged over ``dim=0``.

    Args:
        store: Activation store to load mean samples from.
        selection: Variants and personas chosen by the user.

    Returns:
        ``(fig, pair_fig, n_traces, n_pair_traces)``; ``pair_fig`` is ``None``
        when no pair trace could be built. Returns ``None`` (after showing an
        error) when the vectors for the selected variants cannot be loaded.
    """
    try:
        variant_samples = load_variant_mean_samples(
            store,
            [selection.variant_a, selection.variant_b],
            persona_ids=selection.persona_ids,
        )
    except Exception as exc:
        st.error(f"Could not load vectors: {exc}")
        return None

    samples_a = variant_samples[selection.variant_a]
    samples_b = variant_samples[selection.variant_b]
    display_traces = [
        (label, samples_a.vectors[index], samples_b.vectors[index])
        for index, label in enumerate(samples_a.labels)
    ]
    fig = plot_layer_similarity(
        display_traces,
        title=(
            f"{prompt_variant_label(selection.variant_a)} vs "
            f"{prompt_variant_label(selection.variant_b)}"
        ),
        show=False,
    )

    selected_pair = {selection.variant_a, selection.variant_b}
    pair_traces = []
    pair_errors = []
    for left, right in combinations(selection.variants, 2):
        try:
            # Reuse the already-loaded samples for the selected pair instead
            # of loading them a second time.
            pair_samples = (
                variant_samples
                if {left, right} == selected_pair
                else load_variant_mean_samples(
                    store,
                    [left, right],
                    persona_ids=selection.persona_ids,
                )
            )
            pair_traces.append(
                (
                    f"{prompt_variant_label(left)} vs {prompt_variant_label(right)}",
                    pair_samples[left].vectors.mean(dim=0),
                    pair_samples[right].vectors.mean(dim=0),
                )
            )
        except Exception as exc:
            # Best effort: a failing pair is reported below and skipped.
            pair_errors.append(f"{left} vs {right}: {exc}")

    for err in pair_errors:
        st.warning(f"Skipped pair trace: `{err}`")
    pair_fig = (
        plot_layer_similarity(
            pair_traces,
            title="Variant-pair cosine similarity averaged over selected personas",
            show=False,
        )
        if pair_traces
        else None
    )
    return fig, pair_fig, len(display_traces), len(pair_traces)
274
+
275
+
276
+ def _render_cosine_similarity(
277
+ store: ActivationStore,
278
+ mask_strategy: MaskStrategy,
279
+ ) -> None:
280
+ selection = _render_cosine_selection(store, mask_strategy)
281
+ if selection is None:
282
  return
 
283
 
284
  cosine_fig_key = widget_key(
285
  "load",
286
  "cosine_fig_state",
287
  store.model_name,
288
  mask_strategy.value,
289
+ selection.variant_a,
290
+ selection.variant_b,
291
+ selection.persona_key,
292
  )
293
  filename = _filename(
294
  "compare",
295
  "cosine",
296
  store.model_name,
297
  mask_strategy.value,
298
+ selection.variant_a,
299
+ selection.variant_b,
300
  )
301
  pairs_filename = _filename(
302
  "compare",
303
  "cosine_pairs",
304
  store.model_name,
305
  mask_strategy.value,
306
+ "_".join(selection.variants),
307
  )
308
 
309
  if st.button(
 
314
  "compare_vectors",
315
  store.model_name,
316
  mask_strategy.value,
317
+ selection.variant_a,
318
+ selection.variant_b,
319
+ selection.persona_key,
320
  ),
321
  ):
322
+ figures = _build_cosine_figures(store, selection)
323
+ if figures is None:
 
 
 
 
 
 
324
  st.session_state.pop(cosine_fig_key, None)
325
  return
326
+ st.session_state[cosine_fig_key] = figures
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
  if cosine_fig_key in st.session_state:
329
  fig, pair_fig, n_traces, n_pair_traces = st.session_state[cosine_fig_key]
 
402
  return variant, persona_ids, persona_key, selected_layers
403
 
404
 
405
def _render_layered_figure_analysis(
    store: ActivationStore,
    mask_strategy: MaskStrategy,
    *,
    scope: str,
    figure_kind: str,
    button_label: str,
    title_fn: Callable[[str], str],
    include_pair_trajectories: bool = False,
) -> None:
    """Render a single-variant layered analysis: select → button → figure(s).

    Used for similarity matrix, PCA, and UMAP. Set ``include_pair_trajectories``
    to add the pair-similarity-trajectory figure (similarity matrix only).

    Args:
        store: Activation store to load persona-mean samples from.
        mask_strategy: Mask strategy forwarded to selection and loading.
        scope: Name used for the selection widgets, the session-state key,
            the export filename and the save buttons.
        figure_kind: Kind passed to ``build_layered_figure``.
        button_label: Label of the "generate" button.
        title_fn: Maps the selected variant to the main figure title.
        include_pair_trajectories: Also build and show the pair-trajectory
            figure.
    """
    selected = _select_single_variant_samples(store, mask_strategy, scope)
    if selected is None:
        return
    variant, persona_ids, persona_key, selected_layers = selected

    # The built figures are cached in session state under this key so they
    # are still shown on later runs when the button is not pressed.
    fig_key = widget_key(
        "load",
        f"{scope}_fig_state",
        store.model_name,
        mask_strategy.value,
        figure_kind,
        variant,
        "persona_mean",
        persona_key,
    )
    filename = _filename(
        "compare",
        scope,
        store.model_name,
        mask_strategy.value,
        variant,
        "persona_mean",
        persona_key,
    )

    if st.button(button_label, type="primary"):
        try:
            samples = load_persona_mean_samples(
                store,
                variant,
                mask_strategy=mask_strategy,
                persona_ids=persona_ids,
            )
            main_fig = build_layered_figure(
                samples,
                figure_kind,
                layers=selected_layers,
                title=title_fn(variant),
            )
            extra_fig = (
                build_pair_similarity_figure(
                    samples,
                    layers=selected_layers,
                    title=(
                        "Pair similarity trajectories - "
                        f"{prompt_variant_label(variant)} - "
                        "persona mean activations"
                    ),
                )
                if include_pair_trajectories
                else None
            )
            st.session_state[fig_key] = (main_fig, extra_fig, samples.vectors.shape[0])
        except Exception as exc:
            st.error(f"Could not build figure: {exc}")
            # Drop any stale figure so a failed build is not followed by the
            # previous result.
            st.session_state.pop(fig_key, None)

    if fig_key in st.session_state:
        main_fig, extra_fig, n_samples = st.session_state[fig_key]
        st.plotly_chart(main_fig, width="stretch")
        figs = [main_fig]
        filenames = [filename]
        if extra_fig is not None:
            st.subheader("Pair trajectories")
            st.plotly_chart(extra_fig, width="stretch")
            figs.append(extra_fig)
            filenames.append(f"{filename}__pair_trajectories")
        _render_save_buttons(figs, filenames, scope)
        st.success(f"Loaded {n_samples} samples.")
489
 
490
 
 
494
  st.title("Compare")
495
  st.caption("Compare saved activations by cosine similarity, PCA, or UMAP.")
496
 
497
+ with st.expander("Artifact settings", expanded=False):
 
 
498
  artifacts_root = st.text_input(
499
  "Artifacts root",
500
  value=str(get_artifacts_dir() / "activations"),
 
510
  if analysis_mode is None:
511
  analysis_mode = ANALYSIS_MODES[0]
512
  st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
513
+ with st.expander("Activation settings", expanded=False):
514
+ mask_strategy = _render_mask_strategy_select(analysis_mode)
515
  store = ActivationStore(model_name, artifacts_root, mask_strategy=mask_strategy)
516
 
517
  if analysis_mode == "Cosine similarity":
518
  _render_cosine_similarity(store, mask_strategy)
519
  return
520
  if analysis_mode == "Similarity matrix":
521
+ _render_layered_figure_analysis(
522
+ store,
523
+ mask_strategy,
524
+ scope="similarity_matrix",
525
+ figure_kind="similarity",
526
+ button_label="Generate similarity matrix",
527
+ title_fn=lambda v: (
528
+ "Centered similarity - "
529
+ f"{prompt_variant_label(v)} - persona mean activations"
530
+ ),
531
+ include_pair_trajectories=True,
532
+ )
533
  return
534
 
535
+ _render_layered_figure_analysis(
536
+ store,
537
+ mask_strategy,
538
+ scope=analysis_mode.lower(),
539
+ figure_kind=analysis_mode.lower(),
540
+ button_label=f"Generate {analysis_mode} projection",
541
+ title_fn=lambda v: (
542
+ f"{analysis_mode} - {prompt_variant_label(v)} - Persona means"
543
+ ),
544
+ )
tabs/compare_chat.py CHANGED
@@ -1,10 +1,11 @@
1
- from typing import Any, NamedTuple
 
2
 
3
  import streamlit as st
4
  from nnterp import StandardizedTransformer
5
  from persona_data.synth_persona import PersonaData
6
 
7
- from state import default_chat_state, reset_chat_context_state
8
  from utils.chat import (
9
  ChatReply,
10
  build_chat_messages,
@@ -18,7 +19,6 @@ from utils.runtime import cached_model
18
 
19
  from .chat_ui import (
20
  GenerationConfig,
21
- generation_dict,
22
  render_chat_message,
23
  render_chat_window,
24
  render_persona_prompt_controls,
@@ -26,9 +26,10 @@ from .chat_ui import (
26
  )
27
 
28
 
29
- class ComparePanel(NamedTuple):
 
30
  side: str
31
- state: dict[str, object]
32
  log: Any
33
  prompt: str | None
34
  persona: PersonaData
@@ -37,6 +38,13 @@ class ComparePanel(NamedTuple):
37
  pending_key: str
38
 
39
 
 
 
 
 
 
 
 
40
  def _reset_compare_panel(panel: ComparePanel) -> None:
41
  reset_chat_context_state(
42
  panel.state,
@@ -48,195 +56,188 @@ def _reset_compare_panel(panel: ComparePanel) -> None:
48
  st.session_state.pop(panel.edit_key, None)
49
 
50
 
51
- def _generate_panel_reply(
52
  *,
53
- model: StandardizedTransformer,
54
- remote: bool,
55
- panel_state: dict[str, object],
56
- panel_prompt: str | None,
57
- generation: GenerationConfig,
58
- ) -> ChatReply:
59
- return generate_chat_reply(
60
- model=model,
61
- messages=build_chat_messages(panel_prompt, panel_state["messages"]),
62
- remote=remote,
63
- past_key_values=panel_state["past_key_values"],
64
- **generation.to_generate_kwargs(),
65
- )
66
-
67
-
68
- def render_compare_mode(
69
- remote: bool,
70
- model_name: str,
71
  context_key: str,
72
- dataset_source: str,
73
  personas: list[PersonaData],
74
- generation: GenerationConfig,
75
- ) -> None:
76
- """Render the full side-by-side comparison UI."""
77
- model: StandardizedTransformer | None = None
78
-
79
- def _get_model() -> StandardizedTransformer:
80
- nonlocal model
81
- if model is None:
82
- model = cached_model(model_name=model_name, remote=remote)
83
- return model
84
-
85
- contrast_key = widget_key(context_key, "token_contrast")
86
- contrast_enabled = st.toggle(
87
- "Token contrast",
88
- value=False,
89
- key=contrast_key,
90
- help=(
91
- "Color each generated token by how characteristic it is of each persona. "
92
- "Red = more likely under the left persona, blue = more likely under the right. "
93
- "Requires up to four extra scoring passes after each turn."
94
- ),
95
  )
96
-
97
- def render_panel(side: str) -> ComparePanel:
98
- panel_key = widget_key(context_key, f"cmp_{side}")
99
- if panel_key not in st.session_state:
100
- st.session_state[panel_key] = default_chat_state()
101
- state = st.session_state[panel_key]
102
-
103
- prompt_key = widget_key(panel_key, "custom_prompt")
104
- edit_key = widget_key(panel_key, "edit_idx")
105
- pending_key = widget_key(panel_key, "pending_regen")
106
-
107
- # Carry over persona / prompt selections across model or remote switches.
108
- persist_persona_key = f"chat:last_cmp_{side}_persona"
109
- persist_prompt_key = f"chat:last_cmp_{side}_prompt"
110
- if state["persona_id"] is None:
111
- state["persona_id"] = st.session_state.get(persist_persona_key)
112
- state["prompt_mode"] = st.session_state.get(persist_prompt_key, "templated")
113
-
114
- selected_persona, prompt_mode, changed = render_persona_prompt_controls(
115
- personas,
116
- state["persona_id"],
117
- state["prompt_mode"],
118
- widget_key(panel_key, "persona"),
119
- widget_key(panel_key, "prompt_mode"),
120
  )
121
- st.session_state[persist_persona_key] = selected_persona.id
122
- st.session_state[persist_prompt_key] = prompt_mode
123
 
124
- if changed:
125
- reset_chat_context_state(
126
- state, selected_persona.id, prompt_mode, prompt_key, pending_key
127
- )
128
- st.session_state.pop(edit_key, None)
129
 
130
- active_system_prompt = resolve_system_prompt(
131
- persona=selected_persona, mode=prompt_mode
 
 
 
 
132
  )
133
 
134
- chat_log = st.container()
135
- with chat_log:
136
- active_system_prompt = render_system_prompt(
137
- prompt_key, prompt_mode, active_system_prompt
138
- )
139
- return ComparePanel(
140
- side=side,
141
- state=state,
142
- log=chat_log,
143
- prompt=active_system_prompt,
144
- persona=selected_persona,
145
- prompt_key=prompt_key,
146
- edit_key=edit_key,
147
- pending_key=pending_key,
148
- )
149
 
150
- left_col, right_col = st.columns(2)
151
- with left_col:
152
- left = render_panel("left")
153
- with right_col:
154
- right = render_panel("right")
155
- panels: list[ComparePanel] = [left, right]
156
 
157
- # Handle per-panel regeneration triggered by message edits
158
- regen_panels = [p for p in panels if st.session_state.pop(p.pending_key, False)]
159
- if regen_panels:
160
- model = _get_model()
161
-
162
- results: list[ChatReply | Exception] = []
163
- with st.spinner("Regenerating..."):
164
- for panel in regen_panels:
165
- try:
166
- results.append(
167
- _generate_panel_reply(
168
- model=model,
169
- remote=remote,
170
- panel_state=panel.state,
171
- panel_prompt=panel.prompt,
172
- generation=generation,
173
- )
 
 
 
 
174
  )
175
- except Exception as exc:
176
- results.append(exc)
 
 
177
 
178
- for panel, result in zip(regen_panels, results):
179
- if isinstance(result, Exception):
180
- with panel.log:
181
- st.error(f"Generation failed: {result}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  continue
183
- panel.state["messages"].append(
184
- {"role": "assistant", "content": result.text}
 
 
185
  )
186
- panel.state["past_key_values"] = (
187
- result.past_key_values if not remote else None
 
188
  )
189
- st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
- # Recompute contrast for assistant messages that were edited in place.
192
- if contrast_enabled:
193
- pending_edits: list[tuple[int, int]] = [
194
- (panel_idx, msg_idx)
195
- for panel_idx, panel in enumerate(panels)
196
- for msg_idx, msg in enumerate(panel.state["messages"])
197
- if msg.get("_needs_contrast") and msg.get("role") == "assistant"
198
- ]
199
- if pending_edits:
200
- model = _get_model()
201
- label_a = persona_label(left.persona)
202
- label_b = persona_label(right.persona)
203
- with st.spinner("Recomputing token contrast…"):
204
- for panel_idx, msg_idx in pending_edits:
205
- panel = panels[panel_idx]
206
- msg = panel.state["messages"][msg_idx]
207
- if msg_idx >= len(left.state["messages"]) or msg_idx >= len(
208
- right.state["messages"]
209
- ):
210
- msg.pop("_needs_contrast", None)
211
- continue
212
- context_a = build_chat_messages(
213
- left.prompt, left.state["messages"][:msg_idx]
214
- )
215
- context_b = build_chat_messages(
216
- right.prompt, right.state["messages"][:msg_idx]
217
- )
218
- try:
219
- response_ids = model.tokenizer(
220
- msg["content"],
221
- add_special_tokens=False,
222
- return_tensors="pt",
223
- ).input_ids[0]
224
- tc = compute_contrast(
225
- model=model,
226
- context_a=context_a,
227
- context_b=context_b,
228
- response_ids=response_ids,
229
- label_a=label_a,
230
- label_b=label_b,
231
- remote=remote,
232
- )
233
- if tc is not None:
234
- msg["_contrast"] = tc
235
- except Exception as exc:
236
- st.warning(f"Token contrast recompute failed: {exc}")
237
- msg.pop("_needs_contrast", None)
238
- st.rerun()
239
 
 
 
 
 
 
240
  for panel in panels:
241
  render_chat_window(
242
  chat_log=panel.log,
@@ -248,10 +249,23 @@ def render_compare_mode(
248
  edit_column_ratio=(10, 1),
249
  )
250
 
251
- footer = st.container()
 
 
 
 
 
 
 
 
 
 
 
252
  reset_menu_nonce_key = widget_key(context_key, "cmp_reset_menu_nonce")
253
  if reset_menu_nonce_key not in st.session_state:
254
  st.session_state[reset_menu_nonce_key] = 0
 
 
255
  with footer:
256
  exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
257
  with exp_col:
@@ -270,7 +284,7 @@ def render_compare_mode(
270
  prompt_mode=panel.state["prompt_mode"],
271
  system_prompt=panel.prompt,
272
  messages=panel.state["messages"],
273
- generation=generation_dict(generation),
274
  panel_label=panel.side,
275
  )
276
  st.toast("Exported", icon=":material/check:")
@@ -304,81 +318,150 @@ def render_compare_mode(
304
  st.session_state[reset_menu_nonce_key] += 1
305
  st.rerun()
306
 
307
- user_prompt = st.chat_input(
308
- "Ask both...",
309
- key=widget_key(context_key, "cmp_input"),
310
- )
311
-
312
- if not user_prompt:
313
- return
314
-
315
- model = cached_model(model_name=model_name, remote=remote)
316
 
 
317
  for panel in panels:
318
  panel.state["messages"].append({"role": "user", "content": user_prompt})
319
  with panel.log:
320
  render_chat_message({"role": "user", "content": user_prompt})
321
 
322
- # Snapshot contexts before the new assistant turn is appended (needed for contrast).
323
- pre_gen_contexts = [
324
- build_chat_messages(panel.prompt, panel.state["messages"]) for panel in panels
325
- ]
326
 
327
- results: list[ChatReply | Exception] = []
328
- with st.spinner("Generating..."):
329
- # Sequential generation keeps both panels using model/session state safely.
330
- for panel in panels:
331
- try:
332
- results.append(
333
- _generate_panel_reply(
334
- model=model,
335
- remote=remote,
336
- panel_state=panel.state,
337
- panel_prompt=panel.prompt,
338
- generation=generation,
339
- )
340
- )
341
- except Exception as exc:
342
- results.append(exc)
343
 
344
- valid_results: list[ChatReply | None] = []
345
- for panel, result in zip(panels, results):
346
- if isinstance(result, Exception):
347
- with panel.log:
348
- st.error(f"Generation failed: {result}")
349
- panel.state["messages"].pop()
350
- valid_results.append(None)
351
- continue
 
 
 
 
 
 
 
 
 
 
 
352
 
353
- panel.state["messages"].append({"role": "assistant", "content": result.text})
354
- panel.state["past_key_values"] = result.past_key_values if not remote else None
355
- valid_results.append(result)
356
 
357
- # Compute contrastive token coloring when both panels succeeded.
358
- if (
359
- contrast_enabled
360
- and len(valid_results) == 2
361
- and all(r is not None and r.generated_ids is not None for r in valid_results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  ):
363
- with st.spinner("Computing token contrast…"):
364
- try:
365
- tc_a, tc_b = compute_contrast_pair(
366
- model=model,
367
- context_a=pre_gen_contexts[0],
368
- context_b=pre_gen_contexts[1],
369
- response_ids_a=valid_results[0].generated_ids,
370
- response_ids_b=valid_results[1].generated_ids,
371
- label_a=persona_label(left.persona),
372
- label_b=persona_label(right.persona),
373
- remote=remote,
374
- )
375
- if tc_a is not None:
376
- left.state["messages"][-1]["_contrast"] = tc_a
377
- if tc_b is not None:
378
- right.state["messages"][-1]["_contrast"] = tc_b
379
- except Exception as exc:
380
- st.warning(f"Token contrast failed: {exc}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
 
382
- # Rerun so the newly appended turns are redrawn through the editable history
383
- # renderer instead of only appearing in the one-off generation pass.
384
  st.rerun()
 
1
+ from dataclasses import dataclass
2
+ from typing import Any
3
 
4
  import streamlit as st
5
  from nnterp import StandardizedTransformer
6
  from persona_data.synth_persona import PersonaData
7
 
8
+ from state import ChatState, default_chat_state, reset_chat_context_state
9
  from utils.chat import (
10
  ChatReply,
11
  build_chat_messages,
 
19
 
20
  from .chat_ui import (
21
  GenerationConfig,
 
22
  render_chat_message,
23
  render_chat_window,
24
  render_persona_prompt_controls,
 
26
  )
27
 
28
 
29
+ @dataclass(frozen=True)
30
+ class ComparePanel:
31
  side: str
32
+ state: ChatState
33
  log: Any
34
  prompt: str | None
35
  persona: PersonaData
 
38
  pending_key: str
39
 
40
 
41
def _get_compare_state(context_key: str, side: str) -> tuple[str, ChatState]:
    """Return the session-state key and chat state for one compare panel.

    The state is created with ``default_chat_state()`` the first time the
    panel key is seen.

    Args:
        context_key: Key prefix identifying the current chat context.
        side: Panel identifier (the caller uses ``"left"`` / ``"right"``).

    Returns:
        ``(panel_key, state)`` where ``state`` is the object stored in
        ``st.session_state[panel_key]``.
    """
    panel_key = widget_key(context_key, f"cmp_{side}")
    if panel_key not in st.session_state:
        st.session_state[panel_key] = default_chat_state()
    return panel_key, st.session_state[panel_key]
46
+
47
+
48
  def _reset_compare_panel(panel: ComparePanel) -> None:
49
  reset_chat_context_state(
50
  panel.state,
 
56
  st.session_state.pop(panel.edit_key, None)
57
 
58
 
59
def _render_compare_panel(
    *,
    context_key: str,
    side: str,
    personas: list[PersonaData],
) -> ComparePanel:
    """Render the persona/prompt controls for one side and describe the panel.

    Args:
        context_key: Key prefix identifying the current chat context.
        side: Panel identifier, used in widget and persistence keys.
        personas: Personas offered in the persona selector.

    Returns:
        A ``ComparePanel`` holding the panel state, its chat-log container,
        the active system prompt, the selected persona and the widget keys
        used for the custom prompt, message editing and pending regeneration.
    """
    panel_key, state = _get_compare_state(context_key, side)

    prompt_key = widget_key(panel_key, "custom_prompt")
    edit_key = widget_key(panel_key, "edit_idx")
    pending_key = widget_key(panel_key, "pending_regen")

    # Carry over persona / prompt selections across model or remote switches.
    persist_persona_key = f"chat:last_cmp_{side}_persona"
    persist_prompt_key = f"chat:last_cmp_{side}_prompt"
    # NOTE(review): indentation was lost in the diff rendering this block was
    # recovered from; both assignments are assumed to belong to this `if`.
    if state["persona_id"] is None:
        state["persona_id"] = st.session_state.get(persist_persona_key)
        state["prompt_mode"] = st.session_state.get(persist_prompt_key, "templated")

    selected_persona, prompt_mode, changed = render_persona_prompt_controls(
        personas,
        state["persona_id"],
        state["prompt_mode"],
        widget_key(panel_key, "persona"),
        widget_key(panel_key, "prompt_mode"),
    )
    st.session_state[persist_persona_key] = selected_persona.id
    st.session_state[persist_prompt_key] = prompt_mode

    # NOTE(review): the edit-index cleanup is assumed to be part of the
    # `changed` branch (mirrors `_reset_compare_panel`) -- confirm.
    if changed:
        reset_chat_context_state(
            state,
            selected_persona.id,
            prompt_mode,
            prompt_key,
            pending_key,
        )
        st.session_state.pop(edit_key, None)

    active_system_prompt = resolve_system_prompt(
        persona=selected_persona,
        mode=prompt_mode,
    )

    chat_log = st.container()
    with chat_log:
        # render_system_prompt returns the prompt to use from here on.
        active_system_prompt = render_system_prompt(
            prompt_key,
            prompt_mode,
            active_system_prompt,
        )

    return ComparePanel(
        side=side,
        state=state,
        log=chat_log,
        prompt=active_system_prompt,
        persona=selected_persona,
        prompt_key=prompt_key,
        edit_key=edit_key,
        pending_key=pending_key,
    )
 
 
 
 
 
120
 
 
 
 
 
 
 
121
 
122
def _generate_panels(
    *,
    model: StandardizedTransformer,
    remote: bool,
    panels: list[ComparePanel],
    generation: GenerationConfig,
    spinner_label: str,
) -> list[ChatReply | Exception]:
    """Generate one reply per panel, sequentially, under a single spinner.

    Args:
        model: Model used for generation.
        remote: Forwarded to ``generate_chat_reply``.
        panels: Panels to generate for, in order.
        generation: Generation settings; expanded via ``to_generate_kwargs()``.
        spinner_label: Text shown in the spinner while generating.

    Returns:
        One entry per panel, in the same order: the ``ChatReply`` on success
        or the caught exception on failure. Exceptions are returned, not
        raised, so one failing panel does not prevent the others from running.
    """
    results: list[ChatReply | Exception] = []
    with st.spinner(spinner_label):
        for panel in panels:
            try:
                reply = generate_chat_reply(
                    model=model,
                    messages=build_chat_messages(
                        panel.prompt, panel.state["messages"]
                    ),
                    remote=remote,
                    past_key_values=panel.state["past_key_values"],
                    **generation.to_generate_kwargs(),
                )
            except Exception as exc:
                results.append(exc)
            else:
                results.append(reply)
    return results
148
 
149
+
150
def _apply_panel_results(
    *,
    panels: list[ComparePanel],
    results: list[ChatReply | Exception],
    remote: bool,
    rollback_user_on_error: bool,
) -> list[ChatReply | None]:
    """Write generation results back into each panel's chat state.

    Args:
        panels: Panels that were generated for.
        results: Outcome per panel, aligned with ``panels`` (``zip`` is strict,
            so a length mismatch raises ``ValueError``).
        remote: When true, ``past_key_values`` is cleared instead of stored.
        rollback_user_on_error: When true, a failed panel has its last message
            popped (the caller appends the user turn before generating).

    Returns:
        One entry per panel: the successful ``ChatReply`` or ``None`` where
        generation failed.
    """
    valid_results: list[ChatReply | None] = []
    for panel, result in zip(panels, results, strict=True):
        if isinstance(result, Exception):
            with panel.log:
                st.error(f"Generation failed: {result}")
            if rollback_user_on_error and panel.state["messages"]:
                panel.state["messages"].pop()
            valid_results.append(None)
            continue

        panel.state["messages"].append({"role": "assistant", "content": result.text})
        panel.state["past_key_values"] = result.past_key_values if not remote else None
        valid_results.append(result)
    return valid_results
171
+
172
+
173
def _pending_contrast_edits(panels: list[ComparePanel]) -> list[tuple[int, int]]:
    """Find assistant messages flagged for a token-contrast recompute.

    Returns:
        ``(panel_index, message_index)`` pairs for every message whose role
        is ``"assistant"`` and whose ``"_needs_contrast"`` entry is truthy.
    """
    return [
        (panel_idx, msg_idx)
        for panel_idx, panel in enumerate(panels)
        for msg_idx, msg in enumerate(panel.state["messages"])
        if msg.get("_needs_contrast") and msg.get("role") == "assistant"
    ]
180
+
181
+
182
def _recompute_pending_contrast(
    *,
    model: StandardizedTransformer,
    remote: bool,
    panels: list[ComparePanel],
) -> bool:
    """Recompute token contrast for assistant messages that were edited.

    Args:
        model: Model whose tokenizer and scoring are used.
        remote: Forwarded to ``compute_contrast``.
        panels: Exactly two panels, ``[left, right]``.

    Returns:
        ``True`` when at least one message was pending (the caller then
        reruns the app), ``False`` when there was nothing to do.
    """
    pending_edits = _pending_contrast_edits(panels)
    if not pending_edits:
        return False

    left, right = panels
    label_a = persona_label(left.persona)
    label_b = persona_label(right.persona)
    with st.spinner("Recomputing token contrast..."):
        for panel_idx, msg_idx in pending_edits:
            panel = panels[panel_idx]
            msg = panel.state["messages"][msg_idx]
            # Contexts are sliced from both histories; skip the message when
            # its index is not present in both.
            if msg_idx >= len(left.state["messages"]) or msg_idx >= len(
                right.state["messages"]
            ):
                msg.pop("_needs_contrast", None)
                continue

            context_a = build_chat_messages(
                left.prompt,
                left.state["messages"][:msg_idx],
            )
            context_b = build_chat_messages(
                right.prompt,
                right.state["messages"][:msg_idx],
            )
            try:
                response_ids = model.tokenizer(
                    msg["content"],
                    add_special_tokens=False,
                    return_tensors="pt",
                ).input_ids[0]
                contrast = compute_contrast(
                    model=model,
                    context_a=context_a,
                    context_b=context_b,
                    response_ids=response_ids,
                    label_a=label_a,
                    label_b=label_b,
                    remote=remote,
                )
                if contrast is not None:
                    msg["_contrast"] = contrast
            except Exception as exc:
                st.warning(f"Token contrast recompute failed: {exc}")
            # Clear the flag whether or not the recompute succeeded, so the
            # message is not reprocessed on the rerun the caller triggers.
            msg.pop("_needs_contrast", None)
    return True
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
+ def _render_compare_history(
237
+ *,
238
+ panels: list[ComparePanel],
239
+ contrast_enabled: bool,
240
+ ) -> None:
241
  for panel in panels:
242
  render_chat_window(
243
  chat_log=panel.log,
 
249
  edit_column_ratio=(10, 1),
250
  )
251
 
252
+
253
+ def _render_compare_footer(
254
+ *,
255
+ context_key: str,
256
+ model_name: str,
257
+ dataset_source: str,
258
+ panels: list[ComparePanel],
259
+ generation: GenerationConfig,
260
+ ) -> None:
261
+ # Bumping this nonce after a reset gives the popover a fresh widget key,
262
+ # which forces Streamlit to re-mount it closed (popovers don't auto-close
263
+ # on click).
264
  reset_menu_nonce_key = widget_key(context_key, "cmp_reset_menu_nonce")
265
  if reset_menu_nonce_key not in st.session_state:
266
  st.session_state[reset_menu_nonce_key] = 0
267
+
268
+ footer = st.container()
269
  with footer:
270
  exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
271
  with exp_col:
 
284
  prompt_mode=panel.state["prompt_mode"],
285
  system_prompt=panel.prompt,
286
  messages=panel.state["messages"],
287
+ generation=generation.to_export_dict(),
288
  panel_label=panel.side,
289
  )
290
  st.toast("Exported", icon=":material/check:")
 
318
  st.session_state[reset_menu_nonce_key] += 1
319
  st.rerun()
320
 
 
 
 
 
 
 
 
 
 
321
 
322
def _append_user_prompt(panels: list[ComparePanel], user_prompt: str) -> None:
    """Append the user's message to every panel and render it in each log.

    Each panel receives its own message dict, so the histories do not share
    message objects.
    """
    for panel in panels:
        panel.state["messages"].append({"role": "user", "content": user_prompt})
        with panel.log:
            render_chat_message({"role": "user", "content": user_prompt})
327
 
 
 
 
 
328
 
329
def _compute_new_reply_contrast(
    *,
    model: StandardizedTransformer,
    remote: bool,
    panels: list[ComparePanel],
    pre_gen_contexts: list[list[dict[str, str]]],
    results: list[ChatReply | None],
) -> None:
    """Attach token-contrast data to the two freshly generated replies.

    Does nothing unless both panels produced a reply with ``generated_ids``.
    Failures are shown as a warning and leave the messages without contrast.

    Args:
        model: Model used for scoring.
        remote: Forwarded to ``compute_contrast_pair``.
        panels: Exactly two panels, ``[left, right]``.
        pre_gen_contexts: Chat messages per panel as they were before the
            assistant reply was appended.
        results: Per-panel replies; ``None`` marks a failed generation.
    """
    if len(results) != 2 or any(
        result is None or result.generated_ids is None for result in results
    ):
        return

    left, right = panels
    with st.spinner("Computing token contrast..."):
        try:
            left_contrast, right_contrast = compute_contrast_pair(
                model=model,
                context_a=pre_gen_contexts[0],
                context_b=pre_gen_contexts[1],
                response_ids_a=results[0].generated_ids,
                response_ids_b=results[1].generated_ids,
                label_a=persona_label(left.persona),
                label_b=persona_label(right.persona),
                remote=remote,
            )
            # The last message of each panel is the reply that was just
            # appended by the caller.
            if left_contrast is not None:
                left.state["messages"][-1]["_contrast"] = left_contrast
            if right_contrast is not None:
                right.state["messages"][-1]["_contrast"] = right_contrast
        except Exception as exc:
            st.warning(f"Token contrast failed: {exc}")
361
 
 
 
 
362
 
363
def _render_compare_panels(
    *,
    context_key: str,
    personas: list[PersonaData],
) -> list[ComparePanel]:
    """Render the left and right panels in two columns.

    Returns:
        ``[left, right]`` panel descriptors, in that order.
    """
    left_col, right_col = st.columns(2)
    with left_col:
        left = _render_compare_panel(
            context_key=context_key,
            side="left",
            personas=personas,
        )
    with right_col:
        right = _render_compare_panel(
            context_key=context_key,
            side="right",
            personas=personas,
        )
    return [left, right]
382
+
383
+
384
def render_compare_mode(
    remote: bool,
    model_name: str,
    context_key: str,
    dataset_source: str,
    personas: list[PersonaData],
    generation: GenerationConfig,
    *,
    contrast_enabled: bool,
) -> None:
    """Render the full side-by-side comparison UI.

    Order of work on each run: render both panels, serve pending per-panel
    regenerations, recompute token contrast for edited messages, draw the
    histories and footer, then handle a new user prompt for both panels.

    Args:
        remote: Forwarded to generation and contrast computation.
        model_name: Name of the model to load via ``cached_model``.
        context_key: Key prefix identifying the current chat context.
        dataset_source: Forwarded to the footer.
        personas: Personas offered in each panel's selector.
        generation: Generation settings shared by both panels.
        contrast_enabled: Whether token contrast is computed and shown.
    """
    panels = _render_compare_panels(context_key=context_key, personas=personas)

    # Popping the flag means each requested regeneration is served once.
    regen_panels = [
        panel for panel in panels if st.session_state.pop(panel.pending_key, False)
    ]
    if regen_panels:
        results = _generate_panels(
            model=cached_model(model_name=model_name),
            remote=remote,
            panels=regen_panels,
            generation=generation,
            spinner_label="Regenerating...",
        )
        _apply_panel_results(
            panels=regen_panels,
            results=results,
            remote=remote,
            rollback_user_on_error=False,
        )
        st.rerun()

    if contrast_enabled and _recompute_pending_contrast(
        model=cached_model(model_name=model_name),
        remote=remote,
        panels=panels,
    ):
        st.rerun()

    _render_compare_history(panels=panels, contrast_enabled=contrast_enabled)
    _render_compare_footer(
        context_key=context_key,
        model_name=model_name,
        dataset_source=dataset_source,
        panels=panels,
        generation=generation,
    )

    user_prompt = st.chat_input(
        "Ask both...",
        key=widget_key(context_key, "cmp_input"),
    )
    if not user_prompt:
        return

    _append_user_prompt(panels, user_prompt)
    # Snapshot contexts before the new assistant turn is appended (needed for
    # contrast).
    pre_gen_contexts = [
        build_chat_messages(panel.prompt, panel.state["messages"]) for panel in panels
    ]
    model = cached_model(model_name=model_name)
    results = _generate_panels(
        model=model,
        remote=remote,
        panels=panels,
        generation=generation,
        spinner_label="Generating...",
    )
    valid_results = _apply_panel_results(
        panels=panels,
        results=results,
        remote=remote,
        rollback_user_on_error=True,
    )
    if contrast_enabled:
        _compute_new_reply_contrast(
            model=model,
            remote=remote,
            panels=panels,
            pre_gen_contexts=pre_gen_contexts,
            results=valid_results,
        )

    # Rerun so the newly appended turns are redrawn through the editable
    # history renderer instead of only appearing in this generation pass.
    st.rerun()
tabs/extract.py CHANGED
@@ -1,12 +1,9 @@
1
  import html
 
2
  from typing import Literal, cast
3
 
4
  import streamlit as st
5
- from persona_data.prompts import (
6
- BASELINE_PERSONA_ID,
7
- BASELINE_PERSONA_NAME,
8
- format_prompt,
9
- )
10
  from persona_data.synth_persona import PersonaData, QAPair
11
  from persona_vectors.artifacts import PERSONA_VARIANTS
12
  from persona_vectors.extraction import (
@@ -42,8 +39,11 @@ _ITEM_TYPE_OPTIONS = ["all", "mcq", "frq"]
42
  _SCOPE_OPTIONS = ["all", "individual", "shared"]
43
 
44
 
45
- def _option_index(options: list[str], value: str) -> int:
46
- return options.index(value) if value in options else 0
 
 
 
47
 
48
 
49
  def _remembered_select(
@@ -53,10 +53,11 @@ def _remembered_select(
53
  key: str,
54
  default: str = "all",
55
  ) -> str:
 
56
  selected = st.selectbox(
57
  label,
58
  options=options,
59
- index=_option_index(options, st.session_state.get(state_key, default)),
60
  key=key,
61
  )
62
  st.session_state[state_key] = selected
@@ -66,21 +67,13 @@ def _remembered_select(
66
  def _build_run_plan(
67
  selected_variants: list[str],
68
  runs: list[tuple[PersonaData, list[QAPair]]],
69
- ) -> list[tuple[PersonaData | None, list[QAPair], str]]:
70
- """Expand selected variants × personas into one (persona, qa, variant) per call.
71
-
72
- The baseline variant is run once across the first persona's QA pairs and
73
- has no associated persona.
74
- """
75
- plan: list[tuple[PersonaData | None, list[QAPair], str]] = []
76
- for variant in selected_variants:
77
- if variant == BASELINE_PERSONA_ID:
78
- _, qa_pairs = runs[0]
79
- plan.append((None, qa_pairs, variant))
80
- else:
81
- for persona, qa_pairs in runs:
82
- plan.append((persona, qa_pairs, variant))
83
- return plan
84
 
85
 
86
  def _extract_widget_key(
@@ -89,103 +82,55 @@ def _extract_widget_key(
89
  return widget_key("extract", str(remote), model_name, dataset_source, suffix)
90
 
91
 
92
- _TOKEN_LEGEND = (
93
- '<div style="display:flex;gap:12px;flex-wrap:wrap;font-size:0.8em;margin-bottom:8px">'
94
- '<span style="background:#86efac;color:black;padding:1px 6px;border-radius:3px">masked</span>'
95
- '<span style="color:#fde047;padding:1px 6px">question</span>'
96
- '<span style="color:#22d3ee;padding:1px 6px">response</span>'
97
- '<span style="color:#d946ef;font-weight:bold;padding:1px 6px">special</span>'
98
- '<span style="color:#9ca3af;padding:1px 6px">template</span>'
99
- "</div>"
100
- )
101
-
102
- _MAX_PREVIEW_SAMPLES = 3
103
-
104
-
105
- def _token_style(segment: TokenSegment) -> str:
106
- style = {
107
- "response": "color:#22d3ee",
108
- "question": "color:#fde047",
109
- }.get(segment.role, "color:#9ca3af")
110
-
111
- if segment.is_special:
112
- style = "color:#d946ef;font-weight:bold"
113
- if segment.is_masked:
114
- style = f"{style};background:#86efac;border-radius:2px;padding:0 1px"
115
- return style
116
-
117
-
118
- def _render_sample_tokens_html(p, tokenizer, *, max_tokens: int = 200) -> str:
119
- spans: list[str] = []
120
- for segment in preview_token_segments(p, tokenizer, max_tokens=max_tokens):
121
- spans.append(
122
- f'<span style="{_token_style(segment)}">{html.escape(segment.text)}</span>'
123
  )
124
 
125
- return (
126
- '<pre style="white-space:pre-wrap;font-size:0.82em;line-height:1.5;'
127
- "background:#0e1117;padding:8px 10px;border-radius:6px;"
128
- 'border:1px solid #333;margin:0">'
129
- f"{''.join(spans)}</pre>"
130
- )
131
-
132
-
133
- def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> None:
134
- """Render the extraction tab."""
135
-
136
- st.title("Extract")
137
-
138
- if dataset_source == "Local JSONL upload":
139
- with st.expander("Local dataset upload", expanded=True):
140
- st.file_uploader(
141
- "personas.jsonl",
142
- type=["jsonl"],
143
- key="extract__personas_file",
144
- help="Expected fields: id, persona, templated_view, biography_view",
145
- )
146
- st.file_uploader(
147
- "qa.jsonl",
148
- type=["jsonl"],
149
- key="extract__qa_file",
150
- help="Expected fields: id, qid, type, item_type, scope, question, answer",
151
- )
152
 
153
- last_variants = st.session_state.get(
154
- _LAST_VARIANTS_KEY, [*PERSONA_VARIANTS, BASELINE_PERSONA_ID]
155
- )
156
- default_persona_variants = [
157
- v for v in last_variants if v in PERSONA_VARIANTS
158
- ] or list(PERSONA_VARIANTS)
159
- selected_persona_variants = st.multiselect(
 
160
  "Persona variants",
161
  options=PERSONA_VARIANTS,
162
- default=default_persona_variants,
 
163
  format_func=prompt_variant_label,
164
  key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
165
  help="Extract these variants for each selected persona.",
166
  )
167
  include_baseline = st.checkbox(
168
  "Extract Assistant baseline",
169
- value=st.session_state.get(
170
- _LAST_BASELINE_KEY,
171
- BASELINE_PERSONA_ID in last_variants,
172
- ),
173
  key=_extract_widget_key(model_name, remote, dataset_source, "baseline"),
174
- help=(
175
- "Extracts the persona-less Assistant prompt once using the first "
176
- "selected persona's QA set."
177
- ),
178
  )
179
- selected_variants = [
180
- *selected_persona_variants,
181
- *([BASELINE_PERSONA_ID] if include_baseline else []),
182
- ]
183
  st.session_state[_LAST_VARIANTS_KEY] = selected_variants
184
  st.session_state[_LAST_BASELINE_KEY] = include_baseline
185
  if not selected_variants:
186
- st.info("Select at least one persona variant or enable the baseline.")
187
- return
 
 
188
 
 
189
  try:
190
  dataset, dataset_status = load_dataset(
191
  dataset_source,
@@ -198,11 +143,11 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
198
  st.info(
199
  "Upload both JSONL files or switch to the built-in SynthPersona source."
200
  )
201
- return
202
 
203
  if not getattr(dataset, "supports_qa", True):
204
  st.info("This dataset is persona-only for now. Use Chat to browse personas.")
205
- return
206
 
207
  personas = list(dataset)
208
  if not personas:
@@ -210,8 +155,17 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
210
  st.info(
211
  "Try another dataset source or check that the personas file is not empty."
212
  )
213
- return
 
 
214
 
 
 
 
 
 
 
 
215
  last_persona_ids: set[str] = set(st.session_state.get(_LAST_PERSONA_IDS_KEY, []))
216
  default_personas = [p for p in personas if p.id in last_persona_ids] or [
217
  personas[0]
@@ -227,185 +181,295 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
227
 
228
  if not selected_personas:
229
  st.info("Select at least one persona.")
230
- return
 
231
 
232
- with st.expander("Advanced", expanded=False):
233
- st.caption("Filters")
234
 
235
- col1, col2, col3 = st.columns(3)
236
- with col1:
237
- qa_type_select = _remembered_select(
238
- "QA type",
239
- _QA_TYPE_OPTIONS,
240
- _LAST_QA_TYPE_KEY,
241
- key=_extract_widget_key(
242
- model_name, remote, dataset_source, "qa_type_select"
243
- ),
244
- )
245
- qa_filter_type: Literal["explicit", "implicit"] | None = (
246
- cast(Literal["explicit", "implicit"], qa_type_select)
247
- if qa_type_select in ("explicit", "implicit")
248
- else None
249
- )
250
- with col2:
251
- item_type_select = _remembered_select(
252
- "Item type",
253
- _ITEM_TYPE_OPTIONS,
254
- _LAST_ITEM_TYPE_KEY,
255
- key=_extract_widget_key(
256
- model_name, remote, dataset_source, "item_type_select"
257
- ),
258
- )
259
- qa_filter_item_type: Literal["mcq", "frq"] | None = (
260
- cast(Literal["mcq", "frq"], item_type_select)
261
- if item_type_select in ("mcq", "frq")
262
- else None
263
- )
264
- with col3:
265
- scope_select = _remembered_select(
266
- "Scope",
267
- _SCOPE_OPTIONS,
268
- _LAST_SCOPE_KEY,
269
- key=_extract_widget_key(
270
- model_name,
271
- remote,
272
- dataset_source,
273
- "scope_select",
274
- ),
275
- )
276
- qa_filter_scope: Literal["individual", "shared"] | None = (
277
- cast(Literal["individual", "shared"], scope_select)
278
- if scope_select in ("individual", "shared")
279
- else None
280
- )
281
 
282
- st.caption("Extraction settings")
283
- last_strategy = st.session_state.get(
284
- _LAST_MASK_STRATEGY_KEY,
285
- MaskStrategy.ANSWER_MEAN.value,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  )
287
- strategy_options = list(MaskStrategy)
288
- mask_strategy = st.selectbox(
289
- "Mask strategy",
290
- options=strategy_options,
291
- index=next(
292
- (
293
- idx
294
- for idx, strategy in enumerate(strategy_options)
295
- if strategy.value == last_strategy
296
- ),
297
- 0,
298
- ),
299
- format_func=lambda s: s.value.replace("_", " ").title(),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  key=_extract_widget_key(
301
  model_name,
302
  remote,
303
  dataset_source,
304
- "mask_strategy",
305
  ),
306
- help="Which tokens contribute to the averaged hidden state.",
307
  )
308
- st.session_state[_LAST_MASK_STRATEGY_KEY] = mask_strategy.value
309
-
310
- runs, skipped = [], []
311
- for persona in selected_personas:
312
- qa = list(
313
- dataset.get_qa(
314
- persona.id,
315
- type=qa_filter_type,
316
- item_type=qa_filter_item_type,
317
- scope=qa_filter_scope,
318
- )
319
- )
320
- if qa:
321
- runs.append((persona, qa))
322
- else:
323
- skipped.append(persona)
324
- if skipped:
325
- names = ", ".join(p.name for p in skipped)
326
- st.warning(f"No QA pairs match filters for: {names}. They will be skipped.")
327
-
328
- if not runs:
329
- st.info("No personas have matching QA pairs. Widen the filters.")
330
- return
331
-
332
- max_q = min(len(qa_pairs) for _, qa_pairs in runs)
333
- max_questions = st.slider(
334
- "Max questions",
335
- min_value=1,
336
- max_value=max_q,
337
- value=min(
338
- max(st.session_state.get(_LAST_MAX_QUESTIONS_KEY, max_q), 1),
339
- max_q,
340
- ),
341
- key=_extract_widget_key(
342
- model_name, remote, dataset_source, "max_questions"
 
 
 
 
 
343
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  )
345
- st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  run_col, preview_col, _spacer = st.columns([1, 1, 4], gap="small")
348
  with run_col:
349
  run_clicked = st.button(
350
- "Run extraction", type="primary", use_container_width=True
 
 
351
  )
352
  with preview_col:
353
  preview_clicked = st.button("Preview tokens", use_container_width=True)
354
-
355
- run_plan = _build_run_plan(selected_variants, runs)
356
-
357
- def _row_label(persona: PersonaData | None, variant: str) -> str:
358
- name = persona.name if persona is not None else BASELINE_PERSONA_NAME
359
- return f"{name} · {prompt_variant_label(variant)}"
360
-
361
- if preview_clicked:
362
- with st.spinner("Loading tokenizer..."):
363
- model = cached_model(model_name=model_name, remote=remote)
364
- st.markdown(_TOKEN_LEGEND, unsafe_allow_html=True)
365
- for persona, qa_pairs, variant in run_plan:
366
- system_prompt = (
367
- format_prompt()
368
- if persona is None
369
- else format_prompt(persona, variant) # type: ignore[arg-type]
370
- )
371
- prepared = prepare_inputs_for_strategy(
372
- tokenizer=model.tokenizer,
373
- system_prompt=system_prompt,
374
- qa_pairs=qa_pairs[:max_questions],
375
- mask_strategy=mask_strategy,
376
- )
377
- st.caption(_row_label(persona, variant))
378
- for i, p in enumerate(prepared[:_MAX_PREVIEW_SAMPLES]):
379
- question = (
380
- p.question if len(p.question) <= 60 else p.question[:57] + "..."
 
 
 
 
381
  )
382
- seq_len = int(p.input_ids.shape[0])
383
- masked = int(p.token_mask.sum())
384
- label = f"sample {i} {question} (len={seq_len}, masked={masked})"
385
- with st.expander(label):
386
- st.markdown(
387
- _render_sample_tokens_html(p, model.tokenizer),
388
- unsafe_allow_html=True,
389
- )
390
- if len(prepared) > _MAX_PREVIEW_SAMPLES:
391
- remaining = len(prepared) - _MAX_PREVIEW_SAMPLES
392
- st.caption(f"… and {remaining} more sample(s) not shown.")
393
- return
394
-
395
- if not run_clicked:
396
- return
397
-
398
  status_box = st.empty()
399
  status_box.info("Extraction in progress...")
400
  progress = st.progress(0, text="Preparing extraction...")
401
- ndif_status_box = st.empty() # shows live NDIF job status when remote=True
402
 
403
  def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
404
  icon = NDIF_STATUS_ICONS.get(status_name, "•")
405
  ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
406
 
407
  with st.spinner("Loading model..."):
408
- model = cached_model(model_name=model_name, remote=remote)
409
 
410
  try:
411
  total_steps = len(run_plan)
@@ -419,10 +483,10 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
419
  run_extraction(
420
  model=model,
421
  model_name=model_name,
422
- qa_pairs=qa_pairs[:max_questions],
423
  variants=(variant,),
424
  persona=persona,
425
- mask_strategy=mask_strategy,
426
  remote=remote,
427
  on_status=_on_ndif_status if remote else None,
428
  )
@@ -444,3 +508,71 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
444
  f"- **{result.persona_name}** · {prompt_variant_label(result.variant)}: "
445
  f"{result.n_questions} questions"
446
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import html
2
+ from dataclasses import dataclass
3
  from typing import Literal, cast
4
 
5
  import streamlit as st
6
+ from persona_data.prompts import format_prompt
 
 
 
 
7
  from persona_data.synth_persona import PersonaData, QAPair
8
  from persona_vectors.artifacts import PERSONA_VARIANTS
9
  from persona_vectors.extraction import (
 
39
  _SCOPE_OPTIONS = ["all", "individual", "shared"]
40
 
41
 
42
+ @dataclass(frozen=True)
43
+ class ExtractSettings:
44
+ runs: list[tuple[PersonaData, list[QAPair]]]
45
+ mask_strategy: MaskStrategy
46
+ max_questions: int
47
 
48
 
49
  def _remembered_select(
 
53
  key: str,
54
  default: str = "all",
55
  ) -> str:
56
+ current = st.session_state.get(state_key, default)
57
  selected = st.selectbox(
58
  label,
59
  options=options,
60
+ index=options.index(current) if current in options else 0,
61
  key=key,
62
  )
63
  st.session_state[state_key] = selected
 
67
  def _build_run_plan(
68
  selected_variants: list[str],
69
  runs: list[tuple[PersonaData, list[QAPair]]],
70
+ ) -> list[tuple[PersonaData, list[QAPair], str]]:
71
+ """Cartesian product of personas × variants."""
72
+ return [(p, qa, v) for v in selected_variants for p, qa in runs]
73
+
74
+
75
+ def _row_label(persona: PersonaData, variant: str) -> str:
76
+ return f"{persona.name} · {prompt_variant_label(variant)}"
 
 
 
 
 
 
 
 
77
 
78
 
79
  def _extract_widget_key(
 
82
  return widget_key("extract", str(remote), model_name, dataset_source, suffix)
83
 
84
 
85
+ def _render_local_dataset_upload(dataset_source: str) -> None:
86
+ if dataset_source != "Local JSONL upload":
87
+ return
88
+ with st.expander("Local dataset upload", expanded=True):
89
+ st.file_uploader(
90
+ "personas.jsonl",
91
+ type=["jsonl"],
92
+ key="extract__personas_file",
93
+ help="Expected fields: id, persona, templated_view, biography_view",
94
+ )
95
+ st.file_uploader(
96
+ "qa.jsonl",
97
+ type=["jsonl"],
98
+ key="extract__qa_file",
99
+ help="Expected fields: id, qid, type, item_type, scope, question, answer",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  )
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
+ def _render_variant_controls(
104
+ *,
105
+ model_name: str,
106
+ remote: bool,
107
+ dataset_source: str,
108
+ ) -> tuple[list[str], bool] | None:
109
+ default_variants = st.session_state.get(_LAST_VARIANTS_KEY, list(PERSONA_VARIANTS))
110
+ selected_variants = st.multiselect(
111
  "Persona variants",
112
  options=PERSONA_VARIANTS,
113
+ default=[v for v in default_variants if v in PERSONA_VARIANTS]
114
+ or list(PERSONA_VARIANTS),
115
  format_func=prompt_variant_label,
116
  key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
117
  help="Extract these variants for each selected persona.",
118
  )
119
  include_baseline = st.checkbox(
120
  "Extract Assistant baseline",
121
+ value=st.session_state.get(_LAST_BASELINE_KEY, True),
 
 
 
122
  key=_extract_widget_key(model_name, remote, dataset_source, "baseline"),
123
+ help="Also extract the Assistant baseline persona using the first persona's QA set.",
 
 
 
124
  )
 
 
 
 
125
  st.session_state[_LAST_VARIANTS_KEY] = selected_variants
126
  st.session_state[_LAST_BASELINE_KEY] = include_baseline
127
  if not selected_variants:
128
+ st.info("Select at least one persona variant.")
129
+ return None
130
+ return selected_variants, include_baseline
131
+
132
 
133
+ def _load_qa_dataset_personas(dataset_source: str) -> tuple[object, list[PersonaData]] | None:
134
  try:
135
  dataset, dataset_status = load_dataset(
136
  dataset_source,
 
143
  st.info(
144
  "Upload both JSONL files or switch to the built-in SynthPersona source."
145
  )
146
+ return None
147
 
148
  if not getattr(dataset, "supports_qa", True):
149
  st.info("This dataset is persona-only for now. Use Chat to browse personas.")
150
+ return None
151
 
152
  personas = list(dataset)
153
  if not personas:
 
155
  st.info(
156
  "Try another dataset source or check that the personas file is not empty."
157
  )
158
+ return None
159
+ return dataset, personas
160
+
161
 
162
+ def _render_persona_select(
163
+ *,
164
+ personas: list[PersonaData],
165
+ model_name: str,
166
+ remote: bool,
167
+ dataset_source: str,
168
+ ) -> list[PersonaData] | None:
169
  last_persona_ids: set[str] = set(st.session_state.get(_LAST_PERSONA_IDS_KEY, []))
170
  default_personas = [p for p in personas if p.id in last_persona_ids] or [
171
  personas[0]
 
181
 
182
  if not selected_personas:
183
  st.info("Select at least one persona.")
184
+ return None
185
+ return selected_personas
186
 
 
 
187
 
188
+ _TOKEN_LEGEND = (
189
+ '<div style="display:flex;gap:12px;flex-wrap:wrap;font-size:0.8em;margin-bottom:8px">'
190
+ '<span style="background:#86efac;color:black;padding:1px 6px;border-radius:3px">masked</span>'
191
+ '<span style="color:#fde047;padding:1px 6px">question</span>'
192
+ '<span style="color:#22d3ee;padding:1px 6px">response</span>'
193
+ '<span style="color:#d946ef;font-weight:bold;padding:1px 6px">special</span>'
194
+ '<span style="color:#9ca3af;padding:1px 6px">template</span>'
195
+ "</div>"
196
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
+ _MAX_PREVIEW_SAMPLES = 3
199
+
200
+
201
+ def _token_style(segment: TokenSegment) -> str:
202
+ style = {
203
+ "response": "color:#22d3ee",
204
+ "question": "color:#fde047",
205
+ }.get(segment.role, "color:#9ca3af")
206
+
207
+ if segment.is_special:
208
+ style = "color:#d946ef;font-weight:bold"
209
+ if segment.is_masked:
210
+ style = f"{style};background:#86efac;border-radius:2px;padding:0 1px"
211
+ return style
212
+
213
+
214
+ def _render_sample_tokens_html(p, tokenizer, *, max_tokens: int = 200) -> str:
215
+ spans: list[str] = []
216
+ for segment in preview_token_segments(p, tokenizer, max_tokens=max_tokens):
217
+ spans.append(
218
+ f'<span style="{_token_style(segment)}">{html.escape(segment.text)}</span>'
219
  )
220
+
221
+ return (
222
+ '<pre style="white-space:pre-wrap;font-size:0.82em;line-height:1.5;'
223
+ "background:var(--secondary-background-color,rgba(127,127,127,0.08));"
224
+ "padding:8px 10px;border-radius:6px;"
225
+ 'border:1px solid rgba(127,127,127,0.25);margin:0">'
226
+ f"{''.join(spans)}</pre>"
227
+ )
228
+
229
+
230
+ def _render_filter_controls(
231
+ *,
232
+ model_name: str,
233
+ remote: bool,
234
+ dataset_source: str,
235
+ ) -> tuple[
236
+ Literal["explicit", "implicit"] | None,
237
+ Literal["mcq", "frq"] | None,
238
+ Literal["individual", "shared"] | None,
239
+ ]:
240
+ col1, col2, col3 = st.columns(3)
241
+ with col1:
242
+ qa_type_select = _remembered_select(
243
+ "QA type",
244
+ _QA_TYPE_OPTIONS,
245
+ _LAST_QA_TYPE_KEY,
246
+ key=_extract_widget_key(model_name, remote, dataset_source, "qa_type_select"),
247
+ )
248
+ with col2:
249
+ item_type_select = _remembered_select(
250
+ "Item type",
251
+ _ITEM_TYPE_OPTIONS,
252
+ _LAST_ITEM_TYPE_KEY,
253
  key=_extract_widget_key(
254
  model_name,
255
  remote,
256
  dataset_source,
257
+ "item_type_select",
258
  ),
 
259
  )
260
+ with col3:
261
+ scope_select = _remembered_select(
262
+ "Scope",
263
+ _SCOPE_OPTIONS,
264
+ _LAST_SCOPE_KEY,
265
+ key=_extract_widget_key(model_name, remote, dataset_source, "scope_select"),
266
+ )
267
+
268
+ return (
269
+ cast(Literal["explicit", "implicit"], qa_type_select)
270
+ if qa_type_select in ("explicit", "implicit")
271
+ else None,
272
+ cast(Literal["mcq", "frq"], item_type_select)
273
+ if item_type_select in ("mcq", "frq")
274
+ else None,
275
+ cast(Literal["individual", "shared"], scope_select)
276
+ if scope_select in ("individual", "shared")
277
+ else None,
278
+ )
279
+
280
+
281
+ def _render_mask_strategy_select(
282
+ *,
283
+ model_name: str,
284
+ remote: bool,
285
+ dataset_source: str,
286
+ ) -> MaskStrategy:
287
+ last_strategy = st.session_state.get(
288
+ _LAST_MASK_STRATEGY_KEY,
289
+ MaskStrategy.ANSWER_MEAN.value,
290
+ )
291
+ strategy_options = list(MaskStrategy)
292
+ mask_strategy = st.selectbox(
293
+ "Mask strategy",
294
+ options=strategy_options,
295
+ index=next(
296
+ (
297
+ idx
298
+ for idx, strategy in enumerate(strategy_options)
299
+ if strategy.value == last_strategy
300
  ),
301
+ 0,
302
+ ),
303
+ format_func=lambda s: s.value.replace("_", " ").title(),
304
+ key=_extract_widget_key(model_name, remote, dataset_source, "mask_strategy"),
305
+ help="Which tokens contribute to the averaged hidden state.",
306
+ )
307
+ st.session_state[_LAST_MASK_STRATEGY_KEY] = mask_strategy.value
308
+ return mask_strategy
309
+
310
+
311
+ def _collect_runs(
312
+ *,
313
+ dataset,
314
+ selected_personas: list[PersonaData],
315
+ qa_filter_type: Literal["explicit", "implicit"] | None,
316
+ qa_filter_item_type: Literal["mcq", "frq"] | None,
317
+ qa_filter_scope: Literal["individual", "shared"] | None,
318
+ ) -> list[tuple[PersonaData, list[QAPair]]] | None:
319
+ runs, skipped = [], []
320
+ for persona in selected_personas:
321
+ qa = list(
322
+ dataset.get_qa(
323
+ persona.id,
324
+ type=qa_filter_type,
325
+ item_type=qa_filter_item_type,
326
+ scope=qa_filter_scope,
327
+ )
328
+ )
329
+ if qa:
330
+ runs.append((persona, qa))
331
+ else:
332
+ skipped.append(persona)
333
+ if skipped:
334
+ names = ", ".join(p.name for p in skipped)
335
+ st.warning(f"No QA pairs match filters for: {names}. They will be skipped.")
336
+
337
+ if not runs:
338
+ st.info("No personas have matching QA pairs. Widen the filters.")
339
+ return None
340
+ return runs
341
+
342
+
343
+ def _render_max_questions(
344
+ *,
345
+ model_name: str,
346
+ remote: bool,
347
+ dataset_source: str,
348
+ runs: list[tuple[PersonaData, list[QAPair]]],
349
+ ) -> int:
350
+ max_q = min(len(qa_pairs) for _, qa_pairs in runs)
351
+ max_questions = st.slider(
352
+ "Max questions",
353
+ min_value=1,
354
+ max_value=max_q,
355
+ value=min(max(st.session_state.get(_LAST_MAX_QUESTIONS_KEY, max_q), 1), max_q),
356
+ key=_extract_widget_key(model_name, remote, dataset_source, "max_questions"),
357
+ )
358
+ st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
359
+ return max_questions
360
+
361
+
362
+ def _render_advanced_settings(
363
+ *,
364
+ dataset,
365
+ selected_personas: list[PersonaData],
366
+ model_name: str,
367
+ remote: bool,
368
+ dataset_source: str,
369
+ ) -> ExtractSettings | None:
370
+ with st.expander("Advanced", expanded=False):
371
+ st.caption("Filters")
372
+ qa_filter_type, qa_filter_item_type, qa_filter_scope = _render_filter_controls(
373
+ model_name=model_name,
374
+ remote=remote,
375
+ dataset_source=dataset_source,
376
  )
 
377
 
378
+ st.caption("Extraction settings")
379
+ mask_strategy = _render_mask_strategy_select(
380
+ model_name=model_name,
381
+ remote=remote,
382
+ dataset_source=dataset_source,
383
+ )
384
+ runs = _collect_runs(
385
+ dataset=dataset,
386
+ selected_personas=selected_personas,
387
+ qa_filter_type=qa_filter_type,
388
+ qa_filter_item_type=qa_filter_item_type,
389
+ qa_filter_scope=qa_filter_scope,
390
+ )
391
+ if runs is None:
392
+ return None
393
+
394
+ max_questions = _render_max_questions(
395
+ model_name=model_name,
396
+ remote=remote,
397
+ dataset_source=dataset_source,
398
+ runs=runs,
399
+ )
400
+
401
+ return ExtractSettings(
402
+ runs=runs,
403
+ mask_strategy=mask_strategy,
404
+ max_questions=max_questions,
405
+ )
406
+
407
+
408
+ def _render_extract_actions() -> tuple[bool, bool]:
409
  run_col, preview_col, _spacer = st.columns([1, 1, 4], gap="small")
410
  with run_col:
411
  run_clicked = st.button(
412
+ "Run extraction",
413
+ type="primary",
414
+ use_container_width=True,
415
  )
416
  with preview_col:
417
  preview_clicked = st.button("Preview tokens", use_container_width=True)
418
+ return run_clicked, preview_clicked
419
+
420
+
421
+ def _render_token_preview(
422
+ *,
423
+ remote: bool,
424
+ model_name: str,
425
+ run_plan: list[tuple[PersonaData, list[QAPair], str]],
426
+ settings: ExtractSettings,
427
+ ) -> None:
428
+ with st.spinner("Loading tokenizer..."):
429
+ model = cached_model(model_name=model_name)
430
+ st.markdown(_TOKEN_LEGEND, unsafe_allow_html=True)
431
+ for persona, qa_pairs, variant in run_plan:
432
+ system_prompt = format_prompt(persona, variant) # type: ignore[arg-type]
433
+ prepared = prepare_inputs_for_strategy(
434
+ tokenizer=model.tokenizer,
435
+ system_prompt=system_prompt,
436
+ qa_pairs=qa_pairs[: settings.max_questions],
437
+ mask_strategy=settings.mask_strategy,
438
+ )
439
+ st.caption(_row_label(persona, variant))
440
+ for i, p in enumerate(prepared[:_MAX_PREVIEW_SAMPLES]):
441
+ question = p.question if len(p.question) <= 60 else p.question[:57] + "..."
442
+ seq_len = int(p.input_ids.shape[0])
443
+ masked = int(p.token_mask.sum())
444
+ label = f"sample {i} — {question} (len={seq_len}, masked={masked})"
445
+ with st.expander(label):
446
+ st.markdown(
447
+ _render_sample_tokens_html(p, model.tokenizer),
448
+ unsafe_allow_html=True,
449
  )
450
+ if len(prepared) > _MAX_PREVIEW_SAMPLES:
451
+ remaining = len(prepared) - _MAX_PREVIEW_SAMPLES
452
+ st.caption(f" and {remaining} more sample(s) not shown.")
453
+
454
+
455
+ def _run_extraction_plan(
456
+ *,
457
+ remote: bool,
458
+ model_name: str,
459
+ run_plan: list[tuple[PersonaData, list[QAPair], str]],
460
+ settings: ExtractSettings,
461
+ ) -> None:
 
 
 
 
462
  status_box = st.empty()
463
  status_box.info("Extraction in progress...")
464
  progress = st.progress(0, text="Preparing extraction...")
465
+ ndif_status_box = st.empty()
466
 
467
  def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
468
  icon = NDIF_STATUS_ICONS.get(status_name, "•")
469
  ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
470
 
471
  with st.spinner("Loading model..."):
472
+ model = cached_model(model_name=model_name)
473
 
474
  try:
475
  total_steps = len(run_plan)
 
483
  run_extraction(
484
  model=model,
485
  model_name=model_name,
486
+ qa_pairs=qa_pairs[: settings.max_questions],
487
  variants=(variant,),
488
  persona=persona,
489
+ mask_strategy=settings.mask_strategy,
490
  remote=remote,
491
  on_status=_on_ndif_status if remote else None,
492
  )
 
508
  f"- **{result.persona_name}** · {prompt_variant_label(result.variant)}: "
509
  f"{result.n_questions} questions"
510
  )
511
+
512
+
513
+ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> None:
514
+ """Render the extraction tab."""
515
+
516
+ st.title("Extract")
517
+ st.caption("Extract per-persona activation vectors from QA pairs.")
518
+
519
+ _render_local_dataset_upload(dataset_source)
520
+ variant_choice = _render_variant_controls(
521
+ model_name=model_name,
522
+ remote=remote,
523
+ dataset_source=dataset_source,
524
+ )
525
+ if variant_choice is None:
526
+ return
527
+ selected_variants, include_baseline = variant_choice
528
+
529
+ loaded = _load_qa_dataset_personas(dataset_source)
530
+ if loaded is None:
531
+ return
532
+ dataset, personas = loaded
533
+
534
+ selected_personas = _render_persona_select(
535
+ personas=personas,
536
+ model_name=model_name,
537
+ remote=remote,
538
+ dataset_source=dataset_source,
539
+ )
540
+ if selected_personas is None:
541
+ return
542
+
543
+ settings = _render_advanced_settings(
544
+ dataset=dataset,
545
+ selected_personas=selected_personas,
546
+ model_name=model_name,
547
+ remote=remote,
548
+ dataset_source=dataset_source,
549
+ )
550
+ if settings is None:
551
+ return
552
+
553
+ runs = list(settings.runs)
554
+ baseline = getattr(dataset, "baseline", None)
555
+ if include_baseline and baseline is not None and runs:
556
+ runs.append((baseline, runs[0][1]))
557
+
558
+ run_clicked, preview_clicked = _render_extract_actions()
559
+ run_plan = _build_run_plan(selected_variants, runs)
560
+
561
+ if preview_clicked:
562
+ _render_token_preview(
563
+ remote=remote,
564
+ model_name=model_name,
565
+ run_plan=run_plan,
566
+ settings=settings,
567
+ )
568
+ return
569
+
570
+ if not run_clicked:
571
+ return
572
+
573
+ _run_extraction_plan(
574
+ remote=remote,
575
+ model_name=model_name,
576
+ run_plan=run_plan,
577
+ settings=settings,
578
+ )
tabs/probe_ui.py CHANGED
@@ -16,13 +16,6 @@ from utils.probes import (
16
  from utils.runtime import cached_model
17
 
18
 
19
- def _token_button_label(index: int, token: str) -> str:
20
- display = token.encode("unicode_escape").decode("ascii") or "<empty>"
21
- if len(display) > 18:
22
- display = display[:15] + "..."
23
- return f"{index}: {display}"
24
-
25
-
26
  def _render_probe_results(result: ProbeRunResult, probe: LoadedProbe) -> None:
27
  top_k = min(5, int(result.probabilities.numel()))
28
  if top_k == 0:
@@ -95,55 +88,27 @@ def _load_probe_from_controls(context_key: str) -> LoadedProbe | None:
95
  return load_probe(repo_id.strip(), selected_file)
96
 
97
 
98
- def _render_token_buttons(trace: ConversationTrace, context_key: str) -> int:
99
- selected_key = widget_key(
100
- context_key,
101
- "probe_selected_token",
102
- trace.prompt_hash[:12],
103
- )
104
- selected = int(st.session_state.get(selected_key, trace.n_tokens - 1))
105
- selected = max(0, min(selected, trace.n_tokens - 1))
106
-
107
- window_size = st.slider(
108
- "Token window",
109
- min_value=8,
110
- max_value=min(96, max(8, trace.n_tokens)),
111
- value=min(32, max(8, trace.n_tokens)),
112
- step=8,
113
- key=widget_key(context_key, "probe_token_window", trace.prompt_hash[:12]),
114
- )
115
- center = st.slider(
116
- "Window center",
117
  min_value=0,
118
  max_value=trace.n_tokens - 1,
119
- value=selected,
120
- key=widget_key(context_key, "probe_token_center", trace.prompt_hash[:12]),
121
  )
122
- start = max(0, center - window_size // 2)
123
- end = min(trace.n_tokens, start + window_size)
124
- start = max(0, end - window_size)
125
-
126
- cols = st.columns(8)
127
- for offset, token_index in enumerate(range(start, end)):
128
- col = cols[offset % len(cols)]
129
- token = trace.tokens[token_index]
130
- if col.button(
131
- _token_button_label(token_index, token),
132
- key=widget_key(
133
- context_key,
134
- "probe_token",
135
- trace.prompt_hash[:12],
136
- str(token_index),
137
- ),
138
- type="primary" if token_index == selected else "secondary",
139
- help=token.encode("unicode_escape").decode("ascii"),
140
- ):
141
- selected = token_index
142
- st.session_state[selected_key] = token_index
143
 
144
- st.caption(
145
- f"Selected token {selected}: "
146
- f"`{trace.tokens[selected].encode('unicode_escape').decode('ascii')}`"
 
 
 
 
 
 
 
 
 
147
  )
148
  return selected
149
 
@@ -163,6 +128,128 @@ def _model_dimensions(model: object) -> tuple[int, int]:
163
  return int(hidden_size), int(num_layers)
164
 
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  def render_probe_inspector(
167
  *,
168
  context_key: str,
@@ -188,88 +275,40 @@ def render_probe_inspector(
188
  if probe is None:
189
  return
190
 
191
- with st.spinner("Loading model metadata..."):
192
- model = cached_model(model_name=model_name, remote=remote)
193
- try:
194
- hidden_size, num_layers = _model_dimensions(model)
195
- except Exception as exc:
196
- st.error(str(exc))
197
  return
 
198
 
199
- layer = probe.layer
200
- if layer is None:
201
- layer = int(
202
- st.number_input(
203
- "Layer",
204
- min_value=0,
205
- max_value=max(0, num_layers - 1),
206
- value=min(15, max(0, num_layers - 1)),
207
- step=1,
208
- key=widget_key(context_key, "probe_layer"),
209
- )
210
- )
211
-
212
- location = probe.location
213
- if location is None:
214
- location = st.selectbox(
215
- "Activation location",
216
- options=("post_reasoning", "pre_reasoning"),
217
- key=widget_key(context_key, "probe_location"),
218
- )
219
-
220
  st.caption(
221
  f"Probe layer {layer}; {location}; input dim {probe.input_dim}; "
222
  f"model hidden size {hidden_size}"
223
  )
224
- if not 0 <= layer < num_layers:
225
- st.error(f"Probe layer {layer} is outside the model's {num_layers} layers.")
226
- return
227
- if probe.input_dim != hidden_size:
228
- st.warning(
229
- "This probe input dim does not match a single-token activation "
230
- "for the active model."
231
- )
232
- return
233
-
234
- trace_key = widget_key(context_key, "probe_trace_enabled")
235
- if st.button(
236
- "Trace conversation",
237
- key=widget_key(context_key, "probe_trace"),
238
- use_container_width=True,
239
  ):
240
- st.session_state[trace_key] = True
241
- if not st.session_state.get(trace_key, False):
242
- return
243
-
244
- messages = build_chat_messages(active_system_prompt, chat_state["messages"])
245
- with st.spinner("Tracing conversation..."):
246
- trace = trace_conversation(
247
- model=model,
248
- model_name=model_name,
249
- messages=messages,
250
- layer=layer,
251
- location=location,
252
- remote=remote,
253
- )
254
-
255
- st.caption(
256
- f"Cached {trace.n_tokens} tokens from layer {trace.layer}; "
257
- f"prompt hash `{trace.prompt_hash[:10]}`"
258
- )
259
- if trace.n_tokens == 0:
260
- st.warning("The traced conversation produced no tokens.")
261
  return
262
 
263
- selected_token = _render_token_buttons(trace, context_key)
264
- try:
265
- vector = vectorize_token(trace, token_index=selected_token)
266
- result = probe.run(vector.vector)
267
- except Exception as exc:
268
- st.error(f"Probe execution failed: {exc}")
269
  return
270
 
271
- st.caption(
272
- f"Vectorization {vector.mode}; token {vector.token_index}; "
273
- f"vector dim {int(vector.vector.shape[0])}"
 
 
 
 
 
274
  )
275
- _render_probe_results(result, probe)
 
 
 
16
  from utils.runtime import cached_model
17
 
18
 
 
 
 
 
 
 
 
19
  def _render_probe_results(result: ProbeRunResult, probe: LoadedProbe) -> None:
20
  top_k = min(5, int(result.probabilities.numel()))
21
  if top_k == 0:
 
88
  return load_probe(repo_id.strip(), selected_file)
89
 
90
 
91
+ def _render_token_picker(trace: ConversationTrace, context_key: str) -> int:
92
+ selected = st.slider(
93
+ "Token index",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  min_value=0,
95
  max_value=trace.n_tokens - 1,
96
+ value=trace.n_tokens - 1,
97
+ key=widget_key(context_key, "probe_selected_token", trace.prompt_hash[:12]),
98
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
+ window = 8
101
+ start = max(0, selected - window)
102
+ end = min(trace.n_tokens, selected + window + 1)
103
+ parts: list[str] = []
104
+ for i in range(start, end):
105
+ token_repr = trace.tokens[i].encode("unicode_escape").decode("ascii") or "·"
106
+ parts.append(f"**[{token_repr}]**" if i == selected else token_repr)
107
+ st.markdown(
108
+ f"<div style='font-family:ui-monospace,monospace;font-size:0.85em;"
109
+ f"line-height:1.6;background:rgba(127,127,127,0.08);padding:6px 10px;"
110
+ f"border-radius:4px;'>{' '.join(parts)}</div>",
111
+ unsafe_allow_html=True,
112
  )
113
  return selected
114
 
 
128
  return int(hidden_size), int(num_layers)
129
 
130
 
131
+ def _load_model_with_dimensions(model_name: str) -> tuple[object, int, int] | None:
132
+ with st.spinner("Loading model metadata..."):
133
+ model = cached_model(model_name=model_name)
134
+ try:
135
+ hidden_size, num_layers = _model_dimensions(model)
136
+ except Exception as exc:
137
+ st.error(str(exc))
138
+ return None
139
+ return model, hidden_size, num_layers
140
+
141
+
142
+ def _select_probe_target(
143
+ *,
144
+ probe: LoadedProbe,
145
+ context_key: str,
146
+ num_layers: int,
147
+ ) -> tuple[int, str]:
148
+ layer = probe.layer
149
+ if layer is None:
150
+ layer = int(
151
+ st.number_input(
152
+ "Layer",
153
+ min_value=0,
154
+ max_value=max(0, num_layers - 1),
155
+ value=min(15, max(0, num_layers - 1)),
156
+ step=1,
157
+ key=widget_key(context_key, "probe_layer"),
158
+ )
159
+ )
160
+
161
+ location = probe.location
162
+ if location is None:
163
+ location = st.selectbox(
164
+ "Activation location",
165
+ options=("post_reasoning", "pre_reasoning"),
166
+ key=widget_key(context_key, "probe_location"),
167
+ )
168
+ return layer, location
169
+
170
+
171
+ def _probe_target_is_valid(
172
+ *,
173
+ probe: LoadedProbe,
174
+ layer: int,
175
+ num_layers: int,
176
+ hidden_size: int,
177
+ ) -> bool:
178
+ if not 0 <= layer < num_layers:
179
+ st.error(f"Probe layer {layer} is outside the model's {num_layers} layers.")
180
+ return False
181
+ if probe.input_dim != hidden_size:
182
+ st.warning(
183
+ "This probe input dim does not match a single-token activation "
184
+ "for the active model."
185
+ )
186
+ return False
187
+ return True
188
+
189
+
190
+ def _trace_requested(context_key: str) -> bool:
191
+ trace_key = widget_key(context_key, "probe_trace_enabled")
192
+ if st.button(
193
+ "Trace conversation",
194
+ key=widget_key(context_key, "probe_trace"),
195
+ use_container_width=True,
196
+ ):
197
+ st.session_state[trace_key] = True
198
+ return bool(st.session_state.get(trace_key, False))
199
+
200
+
201
+ def _trace_active_conversation(
202
+ *,
203
+ model: object,
204
+ model_name: str,
205
+ remote: bool,
206
+ active_system_prompt: str | None,
207
+ chat_state: dict[str, object],
208
+ layer: int,
209
+ location: str,
210
+ ) -> ConversationTrace | None:
211
+ messages = build_chat_messages(active_system_prompt, chat_state["messages"])
212
+ with st.spinner("Tracing conversation..."):
213
+ trace = trace_conversation(
214
+ model=model,
215
+ model_name=model_name,
216
+ messages=messages,
217
+ layer=layer,
218
+ location=location,
219
+ remote=remote,
220
+ )
221
+
222
+ st.caption(
223
+ f"Cached {trace.n_tokens} tokens from layer {trace.layer}; "
224
+ f"prompt hash `{trace.prompt_hash[:10]}`"
225
+ )
226
+ if trace.n_tokens == 0:
227
+ st.warning("The traced conversation produced no tokens.")
228
+ return None
229
+ return trace
230
+
231
+
232
+ def _run_probe_on_selected_token(
233
+ *,
234
+ trace: ConversationTrace,
235
+ context_key: str,
236
+ probe: LoadedProbe,
237
+ ) -> None:
238
+ selected_token = _render_token_picker(trace, context_key)
239
+ try:
240
+ vector = vectorize_token(trace, token_index=selected_token)
241
+ result = probe.run(vector.vector)
242
+ except Exception as exc:
243
+ st.error(f"Probe execution failed: {exc}")
244
+ return
245
+
246
+ st.caption(
247
+ f"Vectorization {vector.mode}; token {vector.token_index}; "
248
+ f"vector dim {int(vector.vector.shape[0])}"
249
+ )
250
+ _render_probe_results(result, probe)
251
+
252
+
253
  def render_probe_inspector(
254
  *,
255
  context_key: str,
 
275
  if probe is None:
276
  return
277
 
278
+ loaded = _load_model_with_dimensions(model_name)
279
+ if loaded is None:
 
 
 
 
280
  return
281
+ model, hidden_size, num_layers = loaded
282
 
283
+ layer, location = _select_probe_target(
284
+ probe=probe,
285
+ context_key=context_key,
286
+ num_layers=num_layers,
287
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  st.caption(
289
  f"Probe layer {layer}; {location}; input dim {probe.input_dim}; "
290
  f"model hidden size {hidden_size}"
291
  )
292
+ if not _probe_target_is_valid(
293
+ probe=probe,
294
+ layer=layer,
295
+ num_layers=num_layers,
296
+ hidden_size=hidden_size,
 
 
 
 
 
 
 
 
 
 
297
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  return
299
 
300
+ if not _trace_requested(context_key):
 
 
 
 
 
301
  return
302
 
303
+ trace = _trace_active_conversation(
304
+ model=model,
305
+ model_name=model_name,
306
+ remote=remote,
307
+ active_system_prompt=active_system_prompt,
308
+ chat_state=chat_state,
309
+ layer=layer,
310
+ location=location,
311
  )
312
+ if trace is None:
313
+ return
314
+ _run_probe_on_selected_token(trace=trace, context_key=context_key, probe=probe)
utils/contrast.py CHANGED
@@ -11,7 +11,6 @@ Negative (blue) → token is more characteristic of persona B.
11
  Near-zero (gray) → both personas would emit this token with similar likelihood.
12
  """
13
 
14
- import logging
15
  from dataclasses import dataclass
16
  from html import escape
17
 
@@ -20,8 +19,6 @@ from nnterp import StandardizedTransformer
20
 
21
  from utils.chat import format_generation_prompt
22
 
23
- logger = logging.getLogger(__name__)
24
-
25
 
26
  @dataclass
27
  class TokenContrast:
@@ -73,28 +70,18 @@ def _strip_special_ids(
73
  return ids[keep], keep
74
 
75
 
76
- def _prepare_trace_text(
77
  tokenizer: object,
78
  context_messages: list[dict[str, str]],
79
  response_ids: torch.Tensor,
80
- ) -> tuple[str, int, int]:
81
- """Build the trace text and return ``(full_text, n_ctx, n_resp)``."""
82
  context_prompt, _ = format_generation_prompt(context_messages, tokenizer)
83
  context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
84
- response_text = _decode_ids(tokenizer, response_ids.tolist())
85
- full_text = context_prompt + response_text
86
- full_ids = tokenizer(full_text, return_tensors="pt").input_ids[0]
87
- expected_ids = torch.cat([context_ids, response_ids.cpu()])
88
- if full_ids.tolist() != expected_ids.tolist():
89
- logger.warning(
90
- "contrast trace text did not round-trip to the expected token ids "
91
- "(expected %d tokens, got %d); contrast scores may be slightly misaligned",
92
- len(expected_ids),
93
- len(full_ids),
94
- )
95
  n_ctx = len(context_ids)
96
  n_resp = len(response_ids)
97
- return full_text, n_ctx, n_resp
98
 
99
 
100
  def _build_contrast(
@@ -122,8 +109,8 @@ def _token_display(tokenizer: object, token_id: int) -> str:
122
  return _decode_ids(tokenizer, [token_id])
123
 
124
 
125
- # Each spec: (key, full_text, n_ctx, n_resp, target_ids).
126
- PassSpec = tuple[str, str, int, int, torch.Tensor]
127
 
128
 
129
  def _score_passes(
@@ -140,12 +127,12 @@ def _score_passes(
140
  """
141
 
142
  def _score_pass(
143
- full_text: str,
144
  n_ctx: int,
145
  n_resp: int,
146
  target_ids: torch.Tensor,
147
  ) -> torch.Tensor:
148
- with torch.no_grad(), model.trace(full_text, remote=remote):
149
  # logit at position i predicts token i+1, so response token j
150
  # (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
151
  resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
@@ -163,8 +150,8 @@ def _score_passes(
163
  return out.detach().cpu()
164
 
165
  return {
166
- key: _score_pass(full_text, n_ctx, n_resp, target_ids)
167
- for key, full_text, n_ctx, n_resp, target_ids in specs
168
  }
169
 
170
 
@@ -176,11 +163,13 @@ def _specs_for_response(
176
  prefix: str,
177
  ) -> list[PassSpec]:
178
  """Build the (under_a, under_b) pass specs for a single response."""
179
- text_a, n_ctx_a, n_resp = _prepare_trace_text(tokenizer, context_a, response_ids)
180
- text_b, n_ctx_b, _ = _prepare_trace_text(tokenizer, context_b, response_ids)
 
 
181
  return [
182
- (f"{prefix}_under_a", text_a, n_ctx_a, n_resp, response_ids),
183
- (f"{prefix}_under_b", text_b, n_ctx_b, n_resp, response_ids),
184
  ]
185
 
186
 
 
11
  Near-zero (gray) → both personas would emit this token with similar likelihood.
12
  """
13
 
 
14
  from dataclasses import dataclass
15
  from html import escape
16
 
 
19
 
20
  from utils.chat import format_generation_prompt
21
 
 
 
22
 
23
  @dataclass
24
  class TokenContrast:
 
70
  return ids[keep], keep
71
 
72
 
73
+ def _prepare_trace_input_ids(
74
  tokenizer: object,
75
  context_messages: list[dict[str, str]],
76
  response_ids: torch.Tensor,
77
+ ) -> tuple[torch.Tensor, int, int]:
78
+ """Build exact trace input ids and return ``(input_ids, n_ctx, n_resp)``."""
79
  context_prompt, _ = format_generation_prompt(context_messages, tokenizer)
80
  context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
81
+ input_ids = torch.cat([context_ids.cpu(), response_ids.detach().cpu()])
 
 
 
 
 
 
 
 
 
 
82
  n_ctx = len(context_ids)
83
  n_resp = len(response_ids)
84
+ return input_ids, n_ctx, n_resp
85
 
86
 
87
  def _build_contrast(
 
109
  return _decode_ids(tokenizer, [token_id])
110
 
111
 
112
+ # Each spec: (key, input_ids, n_ctx, n_resp, target_ids).
113
+ PassSpec = tuple[str, torch.Tensor, int, int, torch.Tensor]
114
 
115
 
116
  def _score_passes(
 
127
  """
128
 
129
  def _score_pass(
130
+ input_ids: torch.Tensor,
131
  n_ctx: int,
132
  n_resp: int,
133
  target_ids: torch.Tensor,
134
  ) -> torch.Tensor:
135
+ with torch.no_grad(), model.trace(input_ids, remote=remote):
136
  # logit at position i predicts token i+1, so response token j
137
  # (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
138
  resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
 
150
  return out.detach().cpu()
151
 
152
  return {
153
+ key: _score_pass(input_ids, n_ctx, n_resp, target_ids)
154
+ for key, input_ids, n_ctx, n_resp, target_ids in specs
155
  }
156
 
157
 
 
163
  prefix: str,
164
  ) -> list[PassSpec]:
165
  """Build the (under_a, under_b) pass specs for a single response."""
166
+ input_a, n_ctx_a, n_resp = _prepare_trace_input_ids(
167
+ tokenizer, context_a, response_ids
168
+ )
169
+ input_b, n_ctx_b, _ = _prepare_trace_input_ids(tokenizer, context_b, response_ids)
170
  return [
171
+ (f"{prefix}_under_a", input_a, n_ctx_a, n_resp, response_ids),
172
+ (f"{prefix}_under_b", input_b, n_ctx_b, n_resp, response_ids),
173
  ]
174
 
175
 
utils/runtime.py CHANGED
@@ -1,8 +1,56 @@
 
1
  import logging
 
2
 
3
  import streamlit as st
4
 
5
  logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  @st.cache_data(show_spinner=False, ttl=30)
@@ -16,8 +64,6 @@ def list_remote_models() -> list[str]:
16
  the whole response. See nnsight 0.6.3 ``ndif.py::status``.
17
  """
18
 
19
- import json
20
-
21
  import nnsight
22
 
23
  try:
@@ -29,32 +75,11 @@ def list_remote_models() -> list[str]:
29
  model_names: list[str] = []
30
  bad_states: list[tuple[str, str]] = [] # (repo_id_or_key, application_state)
31
 
32
- for value in (raw or {}).get("deployments", {}).values():
33
- if not isinstance(value, dict):
34
- continue
35
- if (
36
- value.get("deployment_level") not in {"HOT", "WARM"}
37
- and "schedule" not in value
38
- ):
39
- continue
40
-
41
- model_key = value.get("model_key", "")
42
- model_class = model_key.split(":", 1)[0].split(".")[-1]
43
- try:
44
- repo_id = json.loads(model_key.split(":", 1)[-1]).get("repo_id")
45
- except Exception:
46
- repo_id = model_key
47
-
48
- state = value.get("application_state", "NOT DEPLOYED")
49
- if state not in {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}:
50
- bad_states.append((repo_id or model_key, state))
51
-
52
- if model_class not in {"LanguageModel", "StandardizedTransformer"}:
53
- continue
54
- if state != "RUNNING":
55
- continue
56
- if isinstance(repo_id, str):
57
- model_names.append(repo_id)
58
 
59
  if bad_states:
60
  logger.warning(
@@ -67,27 +92,17 @@ def list_remote_models() -> list[str]:
67
 
68
 
69
  @st.cache_resource(show_spinner=False, max_entries=1)
70
- def _cached_model_by_name(model_name: str):
71
  """Load and cache a standardized nnterp model.
72
 
73
  Streamlit reruns this app on every interaction, so caching keeps one loaded
74
  model instance per model name instead of reloading weights on every widget
75
- change.
 
 
 
76
  """
77
 
78
  from nnterp import StandardizedTransformer
79
 
80
- # The remote constructor path is currently unstable for this model wrapper.
81
- # return StandardizedTransformer(model_name, remote=remote, check_renaming=False)
82
  return StandardizedTransformer(model_name)
83
-
84
-
85
- def cached_model(model_name: str, remote: bool):
86
- """Return the cached model for ``model_name``.
87
-
88
- ``remote`` still matters at generation/trace time, but the current
89
- ``StandardizedTransformer`` constructor ignores it. Keeping it out of the
90
- cache key avoids loading duplicate local model objects when toggling NDIF.
91
- """
92
-
93
- return _cached_model_by_name(model_name)
 
1
+ import json
2
  import logging
3
+ from collections.abc import Iterable
4
 
5
  import streamlit as st
6
 
7
  logger = logging.getLogger(__name__)
8
+ _LANGUAGE_MODEL_CLASSES = {"LanguageModel", "StandardizedTransformer"}
9
+ _EXPECTED_NDIF_STATES = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}
10
+
11
+
12
+ def _iter_deployments(raw: object) -> Iterable[dict]:
13
+ if not isinstance(raw, dict):
14
+ return ()
15
+ deployments = raw.get("deployments", {})
16
+ if not isinstance(deployments, dict):
17
+ return ()
18
+ return (value for value in deployments.values() if isinstance(value, dict))
19
+
20
+
21
+ def _is_visible_deployment(deployment: dict) -> bool:
22
+ return deployment.get("deployment_level") in {"HOT", "WARM"} or (
23
+ "schedule" in deployment
24
+ )
25
+
26
+
27
+ def _repo_id_from_model_key(model_key: str) -> str:
28
+ try:
29
+ repo_id = json.loads(model_key.split(":", 1)[-1]).get("repo_id")
30
+ except Exception:
31
+ return model_key
32
+ return repo_id if isinstance(repo_id, str) else model_key
33
+
34
+
35
+ def _running_language_model(deployment: dict) -> str | None:
36
+ if not _is_visible_deployment(deployment):
37
+ return None
38
+
39
+ model_key = deployment.get("model_key", "")
40
+ model_class = model_key.split(":", 1)[0].split(".")[-1]
41
+ if model_class not in _LANGUAGE_MODEL_CLASSES:
42
+ return None
43
+ if deployment.get("application_state", "NOT DEPLOYED") != "RUNNING":
44
+ return None
45
+ return _repo_id_from_model_key(model_key)
46
+
47
+
48
+ def _unexpected_state(deployment: dict) -> tuple[str, str] | None:
49
+ state = deployment.get("application_state", "NOT DEPLOYED")
50
+ if state in _EXPECTED_NDIF_STATES:
51
+ return None
52
+ model_key = deployment.get("model_key", "")
53
+ return _repo_id_from_model_key(model_key), state
54
 
55
 
56
  @st.cache_data(show_spinner=False, ttl=30)
 
64
  the whole response. See nnsight 0.6.3 ``ndif.py::status``.
65
  """
66
 
 
 
67
  import nnsight
68
 
69
  try:
 
75
  model_names: list[str] = []
76
  bad_states: list[tuple[str, str]] = [] # (repo_id_or_key, application_state)
77
 
78
+ for deployment in _iter_deployments(raw):
79
+ if bad_state := _unexpected_state(deployment):
80
+ bad_states.append(bad_state)
81
+ if model_name := _running_language_model(deployment):
82
+ model_names.append(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  if bad_states:
85
  logger.warning(
 
92
 
93
 
94
  @st.cache_resource(show_spinner=False, max_entries=1)
95
+ def cached_model(model_name: str):
96
  """Load and cache a standardized nnterp model.
97
 
98
  Streamlit reruns this app on every interaction, so caching keeps one loaded
99
  model instance per model name instead of reloading weights on every widget
100
+ change. ``remote`` is intentionally not part of the cache key: it matters
101
+ at generation/trace time, but the current ``StandardizedTransformer``
102
+ constructor ignores it, and excluding it avoids loading duplicate local
103
+ model objects when toggling NDIF.
104
  """
105
 
106
  from nnterp import StandardizedTransformer
107
 
 
 
108
  return StandardizedTransformer(model_name)
 
 
 
 
 
 
 
 
 
 
 
uv.lock CHANGED
The diff for this file is too large to render. See raw diff