Jac-Zac committed on
Commit
c30bbc5
·
1 Parent(s): 330d092

Updated cleaned up code

Browse files
Files changed (14) hide show
  1. app.py +59 -62
  2. pyproject.toml +2 -2
  3. state.py +3 -4
  4. tabs/chat.py +0 -2
  5. tabs/chat_ui.py +23 -53
  6. tabs/compare.py +1 -4
  7. tabs/extract.py +15 -23
  8. utils/chat.py +21 -3
  9. utils/contrast.py +3 -27
  10. utils/datasets.py +6 -20
  11. utils/helpers.py +1 -3
  12. utils/probe_trace.py +3 -21
  13. utils/probes.py +22 -22
  14. uv.lock +14 -14
app.py CHANGED
@@ -16,6 +16,64 @@ _TABS = ["Chat", "Compare", "Extract"]
16
  _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def _sidebar_controls() -> tuple[bool, str, str, str]:
20
  from utils.runtime import list_remote_models
21
 
@@ -44,68 +102,7 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
44
  remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")
45
 
46
  if remote:
47
- remote_models = list_remote_models()
48
- custom_remote_key = "sidebar__remote_model_custom_enabled"
49
- custom_remote_model = st.toggle(
50
- "Custom remote model",
51
- value=False,
52
- key=custom_remote_key,
53
- help="Enter any NDIF-loadable model id, even if it is not currently running.",
54
- )
55
- if remote_models:
56
- if custom_remote_model:
57
- model_name = st.text_input(
58
- "Model",
59
- value=st.session_state.get(
60
- "sidebar__remote_model_custom_value",
61
- st.session_state.get(
62
- _LAST_REMOTE_MODEL_KEY, REMOTE_DEFAULT_MODEL
63
- ),
64
- ),
65
- key="sidebar__remote_model_custom_value",
66
- help="NDIF model id. Example: openai/gpt-oss-20b",
67
- )
68
- st.caption(
69
- f"{len(remote_models)} running NDIF model(s) detected. Custom model ids can cold-load if your NDIF account allows it."
70
- )
71
- else:
72
- default_model = st.session_state.get(
73
- "sidebar__remote_model",
74
- st.session_state.get(_LAST_REMOTE_MODEL_KEY),
75
- )
76
- if default_model not in remote_models:
77
- default_model = (
78
- REMOTE_DEFAULT_MODEL
79
- if REMOTE_DEFAULT_MODEL in remote_models
80
- else remote_models[0]
81
- )
82
- if (
83
- st.session_state.get("sidebar__remote_model")
84
- not in remote_models
85
- ):
86
- st.session_state["sidebar__remote_model"] = default_model
87
- selected_remote_model = st.selectbox(
88
- "Model",
89
- options=remote_models,
90
- index=remote_models.index(default_model),
91
- key="sidebar__remote_model",
92
- help="Running NDIF model.",
93
- )
94
- model_name = selected_remote_model
95
- else:
96
- st.warning("No running NDIF models found.")
97
- model_name = st.text_input(
98
- "Model",
99
- value=st.session_state.get(
100
- "sidebar__remote_model_custom_value",
101
- st.session_state.get(
102
- _LAST_REMOTE_MODEL_KEY, REMOTE_DEFAULT_MODEL
103
- ),
104
- ),
105
- key="sidebar__remote_model_custom_value",
106
- help="NDIF model id. Use this to cold-load a remote model.",
107
- )
108
- st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
109
  else:
110
  model_name = st.text_input(
111
  "Model",
 
16
  _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
17
 
18
 
19
def _remote_model_input(remote_models: list[str]) -> str:
    """Render the remote-model widgets and return the chosen NDIF model id.

    When ``remote_models`` is empty, a warning and a free-text field are shown.
    Otherwise a "Custom remote model" toggle switches between the same
    free-text field and a selectbox over ``remote_models``. In every case the
    chosen id is stored under ``_LAST_REMOTE_MODEL_KEY`` in
    ``st.session_state`` before it is returned.
    """

    custom_value_key = "sidebar__remote_model_custom_value"
    picker_key = "sidebar__remote_model"
    previous = st.session_state.get(_LAST_REMOTE_MODEL_KEY, REMOTE_DEFAULT_MODEL)

    def _free_text(help_text: str) -> str:
        # Both free-text paths use the same widget key and initial value;
        # only the help text differs.
        return st.text_input(
            "Model",
            value=st.session_state.get(custom_value_key, previous),
            key=custom_value_key,
            help=help_text,
        )

    if remote_models:
        use_custom = st.toggle(
            "Custom remote model",
            value=False,
            key="sidebar__remote_model_custom_enabled",
            help="Enter any NDIF-loadable model id, even if it is not currently running.",
        )
        if use_custom:
            chosen = _free_text("NDIF model id. Example: openai/gpt-oss-20b")
            st.caption(
                f"{len(remote_models)} running NDIF model(s) detected. "
                "Custom model ids can cold-load if your NDIF account allows it."
            )
        else:
            preferred = st.session_state.get(picker_key, previous)
            if preferred not in remote_models:
                # Remembered choice is not running: use the default model if
                # it is running, otherwise the first running model.
                if REMOTE_DEFAULT_MODEL in remote_models:
                    preferred = REMOTE_DEFAULT_MODEL
                else:
                    preferred = remote_models[0]
            if st.session_state.get(picker_key) not in remote_models:
                # Replace a missing or stale widget value with a valid option.
                st.session_state[picker_key] = preferred
            chosen = st.selectbox(
                "Model",
                options=remote_models,
                index=remote_models.index(preferred),
                key=picker_key,
                help="Running NDIF model.",
            )
    else:
        st.warning("No running NDIF models found.")
        chosen = _free_text("NDIF model id. Use this to cold-load a remote model.")

    st.session_state[_LAST_REMOTE_MODEL_KEY] = chosen
    return chosen
75
+
76
+
77
  def _sidebar_controls() -> tuple[bool, str, str, str]:
78
  from utils.runtime import list_remote_models
79
 
 
102
  remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")
103
 
104
  if remote:
105
+ model_name = _remote_model_input(list_remote_models())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  else:
107
  model_name = st.text_input(
108
  "Model",
pyproject.toml CHANGED
@@ -5,8 +5,8 @@ description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
- "persona-vectors>=0.6.1",
9
- "persona-data>=0.4.1",
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
12
  "python-dotenv>=1.2.2",
 
5
  readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
+ "persona-vectors>=0.6.3",
9
+ "persona-data>=0.4.2",
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
12
  "python-dotenv>=1.2.2",
state.py CHANGED
@@ -1,6 +1,7 @@
1
- import streamlit as st
2
  from typing import Literal, NotRequired, TypedDict
3
 
 
 
4
  _CHAT_STATE_PREFIX = "chat_state::"
5
  PendingChatAction = Literal["new_user_prompt", "regenerate_after_edit"]
6
 
@@ -50,9 +51,7 @@ def reset_chat_context_state(
50
  st.session_state.pop(key, None)
51
 
52
 
53
- def get_chat_state(
54
- model_name: str, remote: bool, dataset_source: str
55
- ) -> ChatState:
56
  """Return the mutable chat state for the active context."""
57
 
58
  key = chat_session_key(model_name, dataset_source)
 
 
1
  from typing import Literal, NotRequired, TypedDict
2
 
3
+ import streamlit as st
4
+
5
  _CHAT_STATE_PREFIX = "chat_state::"
6
  PendingChatAction = Literal["new_user_prompt", "regenerate_after_edit"]
7
 
 
51
  st.session_state.pop(key, None)
52
 
53
 
54
+ def get_chat_state(model_name: str, remote: bool, dataset_source: str) -> ChatState:
 
 
55
  """Return the mutable chat state for the active context."""
56
 
57
  key = chat_session_key(model_name, dataset_source)
tabs/chat.py CHANGED
@@ -128,8 +128,6 @@ def _handle_single_chat_generation(
128
  st.rerun()
129
 
130
 
131
-
132
-
133
  def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
134
  """Render the chat tab."""
135
 
 
128
  st.rerun()
129
 
130
 
 
 
131
  def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
132
  """Render the chat tab."""
133
 
tabs/chat_ui.py CHANGED
@@ -269,76 +269,46 @@ def render_chat_message(
269
  ) -> None:
270
  if not message.get("content"):
271
  return
272
- role = message["role"]
273
  contrast: TokenContrast | None = message.get("_contrast") if show_contrast else None
274
- with st.chat_message(role):
275
  if contrast is not None:
276
  st.html(render_contrast_html(contrast))
277
  else:
278
  st.markdown(message["content"])
279
 
280
 
281
- def _render_editable_message(
282
- message: dict[str, str],
283
- msg_index: int,
284
- messages: list[dict[str, str]],
285
- chat_state: dict[str, object],
286
- edit_key: str,
287
- pending_key: str,
288
- show_contrast: bool = False,
289
- column_ratio: tuple[int, int] = (25, 1),
290
- ) -> None:
291
- if not message.get("content"):
292
- return
293
- role = message["role"]
294
- contrast: TokenContrast | None = message.get("_contrast") if show_contrast else None
295
- msg_col, edit_col = st.columns(
296
- list(column_ratio), gap="xsmall", vertical_alignment="center"
297
- )
298
-
299
- with msg_col:
300
- with st.chat_message(role):
301
- if contrast is not None:
302
- st.html(render_contrast_html(contrast))
303
- else:
304
- st.markdown(message["content"])
305
- with edit_col:
306
- if st.button(
307
- "", icon=":material/edit:", key=f"{edit_key}_edit_{msg_index}", help="Edit"
308
- ):
309
- _open_edit_dialog(
310
- msg_index=msg_index,
311
- messages=messages,
312
- chat_state=chat_state,
313
- pending_key=pending_key,
314
- )
315
-
316
-
317
  def render_chat_window(
318
  *,
319
  chat_log: Any,
320
  messages: list[dict[str, str]],
321
- chat_state: dict[str, object] | None = None,
322
- edit_key: str | None = None,
323
- pending_key: str | None = None,
324
  show_contrast: bool = False,
325
  edit_column_ratio: tuple[int, int] = (25, 1),
326
  ) -> None:
327
  with chat_log:
328
  for i, message in enumerate(messages):
329
- if edit_key and pending_key and chat_state is not None:
330
- _render_editable_message(
331
- message,
332
- i,
333
- messages,
334
- chat_state,
335
- edit_key,
336
- pending_key,
337
- show_contrast=show_contrast,
338
- column_ratio=edit_column_ratio,
339
- )
340
- else:
341
  render_chat_message(message, show_contrast=show_contrast)
 
 
 
 
 
 
 
 
 
 
 
 
 
342
 
343
 
344
  def _assistant_first(personas: list[PersonaData]) -> list[PersonaData]:
 
269
  ) -> None:
270
  if not message.get("content"):
271
  return
 
272
  contrast: TokenContrast | None = message.get("_contrast") if show_contrast else None
273
+ with st.chat_message(message["role"]):
274
  if contrast is not None:
275
  st.html(render_contrast_html(contrast))
276
  else:
277
  st.markdown(message["content"])
278
 
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
def render_chat_window(
    *,
    chat_log: Any,
    messages: list[dict[str, str]],
    chat_state: dict[str, object],
    edit_key: str,
    pending_key: str,
    show_contrast: bool = False,
    edit_column_ratio: tuple[int, int] = (25, 1),
) -> None:
    """Render every non-empty message with an edit button beside it.

    Inside the ``chat_log`` context, each message with truthy ``"content"``
    gets a two-column row sized by ``edit_column_ratio``: the message on the
    left and an edit icon button on the right. Pressing the button calls
    ``_open_edit_dialog`` for that message index. Messages without content
    are skipped.
    """
    with chat_log:
        for position, entry in enumerate(messages):
            if entry.get("content"):
                body_col, action_col = st.columns(
                    list(edit_column_ratio),
                    gap="xsmall",
                    vertical_alignment="center",
                )
                with body_col:
                    render_chat_message(entry, show_contrast=show_contrast)
                with action_col:
                    # Key carries the message index so each button is unique.
                    pressed = st.button(
                        "",
                        icon=":material/edit:",
                        key=f"{edit_key}_edit_{position}",
                        help="Edit",
                    )
                    if pressed:
                        _open_edit_dialog(
                            msg_index=position,
                            messages=messages,
                            chat_state=chat_state,
                            pending_key=pending_key,
                        )
312
 
313
 
314
  def _assistant_first(personas: list[PersonaData]) -> list[PersonaData]:
tabs/compare.py CHANGED
@@ -5,10 +5,7 @@ from itertools import combinations
5
 
6
  import streamlit as st
7
  from persona_data.environment import get_artifacts_dir
8
- from persona_vectors.analysis import (
9
- load_persona_vectors,
10
- load_variant_vectors,
11
- )
12
  from persona_vectors.artifacts import ActivationStore, HFActivationStore
13
  from persona_vectors.artifacts import list_layers as list_local_layers
14
  from persona_vectors.extraction import MaskStrategy
 
5
 
6
  import streamlit as st
7
  from persona_data.environment import get_artifacts_dir
8
+ from persona_vectors.analysis import load_persona_vectors, load_variant_vectors
 
 
 
9
  from persona_vectors.artifacts import ActivationStore, HFActivationStore
10
  from persona_vectors.artifacts import list_layers as list_local_layers
11
  from persona_vectors.extraction import MaskStrategy
tabs/extract.py CHANGED
@@ -102,7 +102,9 @@ def _render_variant_controls(
102
  return selected_variants, include_baseline
103
 
104
 
105
- def _load_qa_dataset_personas(dataset_source: str) -> tuple[object, list[PersonaData]] | None:
 
 
106
  try:
107
  dataset, dataset_status = load_dataset(
108
  dataset_source,
@@ -237,7 +239,9 @@ def _collect_runs(
237
  runs, skipped = [], []
238
  for persona in selected_personas:
239
  if persona.id == BASELINE_PERSONA_ID:
240
- qa = list(dataset.get_qa(BASELINE_PERSONA_ID, item_type="mcq", scope="shared"))
 
 
241
  elif hasattr(dataset, "train_test_split"):
242
  qa, _ = dataset.train_test_split(persona.id)
243
  else:
@@ -268,28 +272,15 @@ def _render_max_questions(
268
  "Max questions (train split)",
269
  min_value=1,
270
  max_value=max_q,
271
- value=min(max(st.session_state.get(_LAST_MAX_QUESTIONS_KEY, default), 1), max_q),
 
 
272
  key=_extract_widget_key(model_name, remote, dataset_source, "max_questions"),
273
  )
274
  st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
275
  return max_questions
276
 
277
 
278
- def _render_advanced_settings(
279
- *,
280
- model_name: str,
281
- remote: bool,
282
- dataset_source: str,
283
- ) -> MaskStrategy:
284
- with st.expander("Advanced", expanded=False):
285
- mask_strategy = _render_mask_strategy_select(
286
- model_name=model_name,
287
- remote=remote,
288
- dataset_source=dataset_source,
289
- )
290
- return mask_strategy
291
-
292
-
293
  def _render_extract_actions() -> tuple[bool, bool]:
294
  run_col, preview_col, _spacer = st.columns([1, 1, 4], gap="small")
295
  with run_col:
@@ -439,11 +430,12 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
439
  dataset_source=dataset_source,
440
  runs=runs,
441
  )
442
- mask_strategy = _render_advanced_settings(
443
- model_name=model_name,
444
- remote=remote,
445
- dataset_source=dataset_source,
446
- )
 
447
  settings = ExtractSettings(
448
  mask_strategy=mask_strategy,
449
  max_questions=max_questions,
 
102
  return selected_variants, include_baseline
103
 
104
 
105
+ def _load_qa_dataset_personas(
106
+ dataset_source: str,
107
+ ) -> tuple[object, list[PersonaData]] | None:
108
  try:
109
  dataset, dataset_status = load_dataset(
110
  dataset_source,
 
239
  runs, skipped = [], []
240
  for persona in selected_personas:
241
  if persona.id == BASELINE_PERSONA_ID:
242
+ qa = list(
243
+ dataset.get_qa(BASELINE_PERSONA_ID, item_type="mcq", scope="shared")
244
+ )
245
  elif hasattr(dataset, "train_test_split"):
246
  qa, _ = dataset.train_test_split(persona.id)
247
  else:
 
272
  "Max questions (train split)",
273
  min_value=1,
274
  max_value=max_q,
275
+ value=min(
276
+ max(st.session_state.get(_LAST_MAX_QUESTIONS_KEY, default), 1), max_q
277
+ ),
278
  key=_extract_widget_key(model_name, remote, dataset_source, "max_questions"),
279
  )
280
  st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
281
  return max_questions
282
 
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  def _render_extract_actions() -> tuple[bool, bool]:
285
  run_col, preview_col, _spacer = st.columns([1, 1, 4], gap="small")
286
  with run_col:
 
430
  dataset_source=dataset_source,
431
  runs=runs,
432
  )
433
+ with st.expander("Advanced", expanded=False):
434
+ mask_strategy = _render_mask_strategy_select(
435
+ model_name=model_name,
436
+ remote=remote,
437
+ dataset_source=dataset_source,
438
+ )
439
  settings = ExtractSettings(
440
  mask_strategy=mask_strategy,
441
  max_questions=max_questions,
utils/chat.py CHANGED
@@ -74,9 +74,7 @@ def _format_plain_messages(
74
  else:
75
  lines.append(f"{role.title()}: {content}")
76
 
77
- if add_generation_prompt and (
78
- not lines or not lines[-1].startswith("Assistant:")
79
- ):
80
  lines.append("Assistant:")
81
 
82
  return "\n\n".join(lines)
@@ -130,6 +128,26 @@ def format_generation_prompt(
130
  return prompt, prompt_token_count
131
 
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  @contextmanager
134
  def _seeded_rng(seed: int | None):
135
  """Context manager that forks the RNG state and sets a deterministic seed."""
 
74
  else:
75
  lines.append(f"{role.title()}: {content}")
76
 
77
+ if add_generation_prompt and (not lines or not lines[-1].startswith("Assistant:")):
 
 
78
  lines.append("Assistant:")
79
 
80
  return "\n\n".join(lines)
 
128
  return prompt, prompt_token_count
129
 
130
 
131
def resolve_saved_tensor(value: object) -> torch.Tensor:
    """Unwrap a saved trace result and return it as a detached CPU tensor.

    If ``value`` has a non-``None`` ``value`` attribute (as an nnsight
    ``.save()`` proxy may), that attribute is used; otherwise ``value`` itself
    is used.

    Raises:
        TypeError: If the unwrapped object is not a ``torch.Tensor``.
    """
    wrapped = getattr(value, "value", None)
    candidate = value if wrapped is None else wrapped
    if isinstance(candidate, torch.Tensor):
        return candidate.detach().cpu()
    raise TypeError(f"Trace result did not resolve to a tensor: {type(candidate)!r}")
137
+
138
+
139
def decode_token(tokenizer: object, token_id: int) -> str:
    """Decode one token id to text, keeping special tokens.

    ``clean_up_tokenization_spaces=False`` is passed first; if the tokenizer's
    ``decode`` raises ``TypeError`` (for example because it does not accept
    that keyword), the call is retried without it.
    """
    common = {"skip_special_tokens": False}
    try:
        text = tokenizer.decode(
            [token_id], **common, clean_up_tokenization_spaces=False
        )
    except TypeError:
        text = tokenizer.decode([token_id], **common)
    return text
149
+
150
+
151
  @contextmanager
152
  def _seeded_rng(seed: int | None):
153
  """Context manager that forks the RNG state and sets a deterministic seed."""
utils/contrast.py CHANGED
@@ -17,7 +17,7 @@ from html import escape
17
  import torch
18
  from nnterp import StandardizedTransformer
19
 
20
- from utils.chat import format_generation_prompt
21
 
22
 
23
  @dataclass
@@ -43,18 +43,6 @@ def _normalise_diffs(diffs: torch.Tensor) -> list[float]:
43
  return (diffs.float().clamp(-clip_val, clip_val) / clip_val).tolist()
44
 
45
 
46
- def _decode_ids(tokenizer: object, ids: list[int]) -> str:
47
- """Decode token IDs, falling back when clean_up_tokenization_spaces is unsupported."""
48
- try:
49
- return tokenizer.decode(
50
- ids,
51
- skip_special_tokens=False,
52
- clean_up_tokenization_spaces=False,
53
- )
54
- except TypeError:
55
- return tokenizer.decode(ids, skip_special_tokens=False)
56
-
57
-
58
  def _strip_special_ids(
59
  ids: torch.Tensor,
60
  tokenizer: object,
@@ -96,7 +84,7 @@ def _build_contrast(
96
  display_ids, keep_mask = _strip_special_ids(response_ids, tokenizer)
97
  display_diffs = diffs[keep_mask]
98
  return TokenContrast(
99
- tokens=[_token_display(tokenizer, tid.item()) for tid in display_ids],
100
  weights=_normalise_diffs(display_diffs),
101
  raw_diffs=display_diffs.float().tolist(),
102
  label_a=label_a,
@@ -104,11 +92,6 @@ def _build_contrast(
104
  )
105
 
106
 
107
- def _token_display(tokenizer: object, token_id: int) -> str:
108
- """Render a single token id as normal decoded text."""
109
- return _decode_ids(tokenizer, [token_id])
110
-
111
-
112
  # Each spec: (key, input_ids, n_ctx, n_resp, target_ids).
113
  PassSpec = tuple[str, torch.Tensor, int, int, torch.Tensor]
114
 
@@ -140,14 +123,7 @@ def _score_passes(
140
  targets = target_ids.to(log_probs.device).view(-1, 1)
141
  picked = log_probs.gather(1, targets).view(-1)
142
  out = picked.detach().cpu().save()
143
-
144
- if getattr(out, "value", None) is not None:
145
- out = out.value
146
- if not isinstance(out, torch.Tensor):
147
- raise TypeError(
148
- f"contrast score did not resolve to a tensor: {type(out)!r}"
149
- )
150
- return out.detach().cpu()
151
 
152
  return {
153
  key: _score_pass(input_ids, n_ctx, n_resp, target_ids)
 
17
  import torch
18
  from nnterp import StandardizedTransformer
19
 
20
+ from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor
21
 
22
 
23
  @dataclass
 
43
  return (diffs.float().clamp(-clip_val, clip_val) / clip_val).tolist()
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def _strip_special_ids(
47
  ids: torch.Tensor,
48
  tokenizer: object,
 
84
  display_ids, keep_mask = _strip_special_ids(response_ids, tokenizer)
85
  display_diffs = diffs[keep_mask]
86
  return TokenContrast(
87
+ tokens=[decode_token(tokenizer, tid.item()) for tid in display_ids],
88
  weights=_normalise_diffs(display_diffs),
89
  raw_diffs=display_diffs.float().tolist(),
90
  label_a=label_a,
 
92
  )
93
 
94
 
 
 
 
 
 
95
  # Each spec: (key, input_ids, n_ctx, n_resp, target_ids).
96
  PassSpec = tuple[str, torch.Tensor, int, int, torch.Tensor]
97
 
 
123
  targets = target_ids.to(log_probs.device).view(-1, 1)
124
  picked = log_probs.gather(1, targets).view(-1)
125
  out = picked.detach().cpu().save()
126
+ return resolve_saved_tensor(out)
 
 
 
 
 
 
 
127
 
128
  return {
129
  key: _score_pass(input_ids, n_ctx, n_resp, target_ids)
utils/datasets.py CHANGED
@@ -17,24 +17,10 @@ from .helpers import DATASET_SOURCES
17
 
18
 
19
  @st.cache_resource(show_spinner=False)
20
- def cached_hf_dataset() -> SynthPersonaDataset:
21
- """Load the default SynthPersona HuggingFace dataset once."""
22
 
23
- return SynthPersonaDataset()
24
-
25
-
26
- @st.cache_resource(show_spinner=False)
27
- def cached_nemotron_dataset() -> NemotronPersonasFranceDataset:
28
- """Load the Nemotron France HuggingFace dataset once."""
29
-
30
- return NemotronPersonasFranceDataset()
31
-
32
-
33
- @st.cache_resource(show_spinner=False)
34
- def cached_nemotron_usa_dataset() -> NemotronPersonasUSADataset:
35
- """Load the Nemotron USA HuggingFace dataset once."""
36
-
37
- return NemotronPersonasUSADataset()
38
 
39
 
40
  def _upload_cache_dir() -> Path:
@@ -74,13 +60,13 @@ def load_dataset(
74
  """Load the selected dataset source for the UI."""
75
 
76
  if dataset_source == DATASET_SOURCES[0]:
77
- return cached_hf_dataset(), "SynthPersona"
78
 
79
  if dataset_source == DATASET_SOURCES[1]:
80
- return cached_nemotron_dataset(), "Nemotron France"
81
 
82
  if dataset_source == DATASET_SOURCES[2]:
83
- return cached_nemotron_usa_dataset(), "Nemotron USA"
84
 
85
  if personas_file is None or qa_file is None:
86
  raise ValueError("Upload both personas.jsonl and qa.jsonl files")
 
17
 
18
 
19
@st.cache_resource(show_spinner=False)
def _cached_dataset(cls: type) -> Any:
    """Instantiate ``cls`` with no arguments and return the cached instance.

    ``st.cache_resource`` memoises the result per distinct ``cls`` argument,
    so repeated calls with the same dataset class return the same object
    instead of constructing it again. Per Streamlit's documentation, these
    cache entries are shared across reruns and across user sessions (not one
    per session), so the returned dataset should be treated as shared state.
    """

    return cls()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
  def _upload_cache_dir() -> Path:
 
60
  """Load the selected dataset source for the UI."""
61
 
62
  if dataset_source == DATASET_SOURCES[0]:
63
+ return _cached_dataset(SynthPersonaDataset), "SynthPersona"
64
 
65
  if dataset_source == DATASET_SOURCES[1]:
66
+ return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"
67
 
68
  if dataset_source == DATASET_SOURCES[2]:
69
+ return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"
70
 
71
  if personas_file is None or qa_file is None:
72
  raise ValueError("Upload both personas.jsonl and qa.jsonl files")
utils/helpers.py CHANGED
@@ -13,9 +13,7 @@ VARIANT_LABELS = {
13
 
14
  CHAT_PROMPT_MODES = ("empty", "templated", "biography", "custom")
15
  CHAT_PROMPT_MODE_LABELS = [VARIANT_LABELS[key] for key in CHAT_PROMPT_MODES]
16
- CHAT_PROMPT_MODE_LABEL_TO_KEY = {
17
- VARIANT_LABELS[key]: key for key in CHAT_PROMPT_MODES
18
- }
19
 
20
 
21
  DATASET_SOURCES = [
 
13
 
14
  CHAT_PROMPT_MODES = ("empty", "templated", "biography", "custom")
15
  CHAT_PROMPT_MODE_LABELS = [VARIANT_LABELS[key] for key in CHAT_PROMPT_MODES]
16
+ CHAT_PROMPT_MODE_LABEL_TO_KEY = {VARIANT_LABELS[key]: key for key in CHAT_PROMPT_MODES}
 
 
17
 
18
 
19
  DATASET_SOURCES = [
utils/probe_trace.py CHANGED
@@ -7,7 +7,7 @@ import streamlit as st
7
  import torch
8
  from nnterp import StandardizedTransformer
9
 
10
- from utils.chat import format_generation_prompt
11
 
12
  _TRACE_CACHE_KEY = "probe:trace_cache"
13
  _MAX_CACHED_TRACES = 3
@@ -74,8 +74,8 @@ def trace_conversation(
74
  saved_ids = model.input_ids[0].detach().cpu().save()
75
  saved_acts = accessor[layer][0].detach().float().cpu().save()
76
 
77
- input_ids = _resolve_saved_tensor(saved_ids)
78
- activations = _resolve_saved_tensor(saved_acts)
79
  if input_ids.ndim != 1:
80
  raise ValueError(
81
  f"Expected traced input ids to be [seq], got {tuple(input_ids.shape)}"
@@ -125,17 +125,6 @@ def vectorize_token(
125
  )
126
 
127
 
128
- def decode_token(tokenizer: object, token_id: int) -> str:
129
- try:
130
- return tokenizer.decode(
131
- [token_id],
132
- skip_special_tokens=False,
133
- clean_up_tokenization_spaces=False,
134
- )
135
- except TypeError:
136
- return tokenizer.decode([token_id], skip_special_tokens=False)
137
-
138
-
139
  def _select_accessor(model: StandardizedTransformer, location: str):
140
  normalized = location.lower()
141
  if normalized in {"pre_reasoning", "pre", "input", "layers_input"}:
@@ -145,13 +134,6 @@ def _select_accessor(model: StandardizedTransformer, location: str):
145
  raise ValueError(f"Unsupported trace location: {location!r}")
146
 
147
 
148
- def _resolve_saved_tensor(value) -> torch.Tensor:
149
- resolved = value.value if getattr(value, "value", None) is not None else value
150
- if not isinstance(resolved, torch.Tensor):
151
- raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
152
- return resolved.detach().cpu()
153
-
154
-
155
  def _trace_cache_key(
156
  *,
157
  model_name: str,
 
7
  import torch
8
  from nnterp import StandardizedTransformer
9
 
10
+ from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor
11
 
12
  _TRACE_CACHE_KEY = "probe:trace_cache"
13
  _MAX_CACHED_TRACES = 3
 
74
  saved_ids = model.input_ids[0].detach().cpu().save()
75
  saved_acts = accessor[layer][0].detach().float().cpu().save()
76
 
77
+ input_ids = resolve_saved_tensor(saved_ids)
78
+ activations = resolve_saved_tensor(saved_acts)
79
  if input_ids.ndim != 1:
80
  raise ValueError(
81
  f"Expected traced input ids to be [seq], got {tuple(input_ids.shape)}"
 
125
  )
126
 
127
 
 
 
 
 
 
 
 
 
 
 
 
128
  def _select_accessor(model: StandardizedTransformer, location: str):
129
  normalized = location.lower()
130
  if normalized in {"pre_reasoning", "pre", "input", "layers_input"}:
 
134
  raise ValueError(f"Unsupported trace location: {location!r}")
135
 
136
 
 
 
 
 
 
 
 
137
  def _trace_cache_key(
138
  *,
139
  model_name: str,
utils/probes.py CHANGED
@@ -225,15 +225,28 @@ def _load_probe_payload(
225
  num_classes=num_classes,
226
  )
227
  labels = _normalize_labels(payload.get("idx_to_label"), num_classes)
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  return LoadedProbe(
229
  model=model,
230
  input_dim=input_dim,
231
  labels=labels,
232
  model_type=str(payload.get("model_type") or metadata.model_type),
233
- layer=_coerce_optional_int(payload.get("layer"), metadata.layer),
234
- location=_coerce_location(payload.get("location"), metadata.location),
235
- scaler_mean=_coerce_tensor(payload.get("scaler_mean")),
236
- scaler_std=_coerce_tensor(payload.get("scaler_std")),
237
  )
238
 
239
 
@@ -296,7 +309,9 @@ def _coerce_probe_dim(
296
  weights = [
297
  tensor
298
  for key, tensor in state_dict.items()
299
- if key.endswith("weight") and isinstance(tensor, torch.Tensor) and tensor.ndim == 2
 
 
300
  ]
301
  if not weights:
302
  raise ValueError(f"Cannot infer probe {dim} dimension from state dict")
@@ -349,27 +364,12 @@ def _coerce_hidden_dims(value: Any) -> list[int]:
349
  raise TypeError(f"Unsupported hidden_dims value: {type(value)!r}")
350
 
351
 
352
- def _coerce_tensor(value: Any) -> torch.Tensor | None:
353
- if value is None or not isinstance(value, torch.Tensor):
354
  return None
355
  return value.detach().cpu()
356
 
357
 
358
- def _coerce_optional_int(value: Any, fallback: int | None) -> int | None:
359
- if value is None:
360
- return fallback
361
- try:
362
- return int(value)
363
- except (TypeError, ValueError):
364
- return fallback
365
-
366
-
367
- def _coerce_location(value: Any, fallback: str | None) -> str | None:
368
- if isinstance(value, str) and value:
369
- return value
370
- return fallback
371
-
372
-
373
  def _normalize_labels(raw_labels: Any, num_classes: int) -> list[str | None]:
374
  if isinstance(raw_labels, (list, tuple)):
375
  labels = [str(label) for label in raw_labels[:num_classes]]
 
225
  num_classes=num_classes,
226
  )
227
  labels = _normalize_labels(payload.get("idx_to_label"), num_classes)
228
+
229
+ raw_layer = payload.get("layer")
230
+ try:
231
+ layer = int(raw_layer) if raw_layer is not None else metadata.layer
232
+ except (TypeError, ValueError):
233
+ layer = metadata.layer
234
+ raw_location = payload.get("location")
235
+ location = (
236
+ raw_location
237
+ if isinstance(raw_location, str) and raw_location
238
+ else metadata.location
239
+ )
240
+
241
  return LoadedProbe(
242
  model=model,
243
  input_dim=input_dim,
244
  labels=labels,
245
  model_type=str(payload.get("model_type") or metadata.model_type),
246
+ layer=layer,
247
+ location=location,
248
+ scaler_mean=_as_cpu_tensor(payload.get("scaler_mean")),
249
+ scaler_std=_as_cpu_tensor(payload.get("scaler_std")),
250
  )
251
 
252
 
 
309
  weights = [
310
  tensor
311
  for key, tensor in state_dict.items()
312
+ if key.endswith("weight")
313
+ and isinstance(tensor, torch.Tensor)
314
+ and tensor.ndim == 2
315
  ]
316
  if not weights:
317
  raise ValueError(f"Cannot infer probe {dim} dimension from state dict")
 
364
  raise TypeError(f"Unsupported hidden_dims value: {type(value)!r}")
365
 
366
 
367
+ def _as_cpu_tensor(value: Any) -> torch.Tensor | None:
368
+ if not isinstance(value, torch.Tensor):
369
  return None
370
  return value.detach().cpu()
371
 
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  def _normalize_labels(raw_labels: Any, num_classes: int) -> list[str | None]:
374
  if isinstance(raw_labels, (list, tuple)):
375
  labels = [str(label) for label in raw_labels[:num_classes]]
uv.lock CHANGED
@@ -1120,11 +1120,11 @@ wheels = [
1120
 
1121
  [[package]]
1122
  name = "narwhals"
1123
- version = "2.20.0"
1124
  source = { registry = "https://pypi.org/simple" }
1125
- sdist = { url = "https://files.pythonhosted.org/packages/e9/f3/257adc69a71011b4c8cda321b00f02c5bf1980ae38ffd05a58d9632d4de8/narwhals-2.20.0.tar.gz", hash = "sha256:c10994975fa7dc5a68c2cffcddbd5908fc8ebb2d463c5bab085309c0ee1f551e", size = 627848, upload-time = "2026-04-20T12:11:45.427Z" }
1126
  wheels = [
1127
- { url = "https://files.pythonhosted.org/packages/d0/69/f24d3d1c38ad69e256138b4ec2452a8c7cf66be49dc214771ae99dd4f0a0/narwhals-2.20.0-py3-none-any.whl", hash = "sha256:16e750ea5507d4ba6e8d03455b5f93a535e0405976561baea235bca5dc9f475d", size = 449373, upload-time = "2026-04-20T12:11:43.596Z" },
1128
  ]
1129
 
1130
  [[package]]
@@ -1550,7 +1550,7 @@ wheels = [
1550
 
1551
  [[package]]
1552
  name = "persona-data"
1553
- version = "0.4.1"
1554
  source = { registry = "https://pypi.org/simple" }
1555
  dependencies = [
1556
  { name = "huggingface-hub" },
@@ -1559,9 +1559,9 @@ dependencies = [
1559
  { name = "python-dotenv" },
1560
  { name = "torch" },
1561
  ]
1562
- sdist = { url = "https://files.pythonhosted.org/packages/d5/9b/b9bc22cf6393cbd38529dfdc1128963c2935060f96b6896b81349bd34050/persona_data-0.4.1.tar.gz", hash = "sha256:1e98a8999f498f95eeaaa4f46931818b2b1296b5ad500a89c7ad1e87b5aa405f", size = 9294, upload-time = "2026-05-07T10:27:00.746Z" }
1563
  wheels = [
1564
- { url = "https://files.pythonhosted.org/packages/8a/a9/10a586bfb4a585931dbc0f657c71a35946880cbb7b25316594042cc1a00a/persona_data-0.4.1-py3-none-any.whl", hash = "sha256:53780689988e487b68d826c0cd980dfe6bb13a340e01a10c042e4dc86f46e765", size = 11937, upload-time = "2026-05-07T10:26:59.865Z" },
1565
  ]
1566
 
1567
  [[package]]
@@ -1578,8 +1578,8 @@ dependencies = [
1578
 
1579
  [package.metadata]
1580
  requires-dist = [
1581
- { name = "persona-data", specifier = ">=0.4.1" },
1582
- { name = "persona-vectors", specifier = ">=0.6.1" },
1583
  { name = "plotly", specifier = ">=6.6.0" },
1584
  { name = "python-dotenv", specifier = ">=1.2.2" },
1585
  { name = "streamlit", specifier = ">=1.44.0" },
@@ -1587,7 +1587,7 @@ requires-dist = [
1587
 
1588
  [[package]]
1589
  name = "persona-vectors"
1590
- version = "0.6.1"
1591
  source = { registry = "https://pypi.org/simple" }
1592
  dependencies = [
1593
  { name = "datasets" },
@@ -1606,9 +1606,9 @@ dependencies = [
1606
  { name = "transformers" },
1607
  { name = "umap-learn" },
1608
  ]
1609
- sdist = { url = "https://files.pythonhosted.org/packages/69/f3/6da35af90c8ea5333db1763ece04a3230353ac5a76c0dc8fea705a6e86cf/persona_vectors-0.6.1.tar.gz", hash = "sha256:552ac9a0d739a453c5d9eb612cb0d0d2820a1b53ce84f490295a84105a71f7cc", size = 24311, upload-time = "2026-05-07T15:07:29.951Z" }
1610
  wheels = [
1611
- { url = "https://files.pythonhosted.org/packages/86/66/91df378258e2c0cbc7860652b07b5e65ee1949ba14be2efdb6c646a933f1/persona_vectors-0.6.1-py3-none-any.whl", hash = "sha256:593977ad19c9f23df7d86e302fe4bcf49159425da67d83281a11858026c5e85e", size = 28683, upload-time = "2026-05-07T15:07:30.791Z" },
1612
  ]
1613
 
1614
  [[package]]
@@ -2912,11 +2912,11 @@ wheels = [
2912
 
2913
  [[package]]
2914
  name = "urllib3"
2915
- version = "2.6.3"
2916
  source = { registry = "https://pypi.org/simple" }
2917
- sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" }
2918
  wheels = [
2919
- { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" },
2920
  ]
2921
 
2922
  [[package]]
 
1120
 
1121
  [[package]]
1122
  name = "narwhals"
1123
+ version = "2.21.0"
1124
  source = { registry = "https://pypi.org/simple" }
1125
+ sdist = { url = "https://files.pythonhosted.org/packages/2d/0e/3ad61eb87088cc4932e0d851531fa82f845a6230b68b091a0e298cc7e537/narwhals-2.21.0.tar.gz", hash = "sha256:7c6e7f50528e62b7a967dd864d7e117d2955d38d4f730653ce46a9861358e2dc", size = 633083, upload-time = "2026-05-08T12:29:02.587Z" }
1126
  wheels = [
1127
+ { url = "https://files.pythonhosted.org/packages/c7/e1/68c2256b69a314eba133673377ba9118c356f6342a0c02b61de449cf2bf2/narwhals-2.21.0-py3-none-any.whl", hash = "sha256:1e6617d0fca68ae1fda29e5397c4eaacd3ffc9fffe6bcd6ded0c690475e853be", size = 451943, upload-time = "2026-05-08T12:29:01.058Z" },
1128
  ]
1129
 
1130
  [[package]]
 
1550
 
1551
  [[package]]
1552
  name = "persona-data"
1553
+ version = "0.4.2"
1554
  source = { registry = "https://pypi.org/simple" }
1555
  dependencies = [
1556
  { name = "huggingface-hub" },
 
1559
  { name = "python-dotenv" },
1560
  { name = "torch" },
1561
  ]
1562
+ sdist = { url = "https://files.pythonhosted.org/packages/a4/2f/099a74e54846172a20b697b46b285eb2f0004e1db530308d6b4ff1f19079/persona_data-0.4.2.tar.gz", hash = "sha256:7870292a79b3943a77c31595140de3b2243b783222590248d09891de70e7fe1b", size = 9276, upload-time = "2026-05-08T13:59:27.58Z" }
1563
  wheels = [
1564
+ { url = "https://files.pythonhosted.org/packages/57/03/e76a48b41ee00684a4430269007e217e70f59e2597d7c862d93cfc5ac78b/persona_data-0.4.2-py3-none-any.whl", hash = "sha256:c881d6fb71af87a6fa773284076e4cb55794db6dc447a7eb0047eee2b389c855", size = 11914, upload-time = "2026-05-08T13:59:28.198Z" },
1565
  ]
1566
 
1567
  [[package]]
 
1578
 
1579
  [package.metadata]
1580
  requires-dist = [
1581
+ { name = "persona-data", specifier = ">=0.4.2" },
1582
+ { name = "persona-vectors", specifier = ">=0.6.3" },
1583
  { name = "plotly", specifier = ">=6.6.0" },
1584
  { name = "python-dotenv", specifier = ">=1.2.2" },
1585
  { name = "streamlit", specifier = ">=1.44.0" },
 
1587
 
1588
  [[package]]
1589
  name = "persona-vectors"
1590
+ version = "0.6.3"
1591
  source = { registry = "https://pypi.org/simple" }
1592
  dependencies = [
1593
  { name = "datasets" },
 
1606
  { name = "transformers" },
1607
  { name = "umap-learn" },
1608
  ]
1609
+ sdist = { url = "https://files.pythonhosted.org/packages/42/f5/57836026dc1b8c716ff6e443ba3cc8fafef108078e52f872c101f66ab61c/persona_vectors-0.6.3.tar.gz", hash = "sha256:2389aaa4ab5e83c4541556a000e0268ad3f1f2d5e741ade9830cb3da972332c5", size = 24509, upload-time = "2026-05-08T14:10:37.09Z" }
1610
  wheels = [
1611
+ { url = "https://files.pythonhosted.org/packages/3c/92/912d2a6998bcc103631597125bad5b5644c981b52e62fff229aee64139ae/persona_vectors-0.6.3-py3-none-any.whl", hash = "sha256:9a7f275c7e58990e1228a0d35ca2a8898eb8330fd4a9a627fb28fc574883d260", size = 29366, upload-time = "2026-05-08T14:10:38.184Z" },
1612
  ]
1613
 
1614
  [[package]]
 
2912
 
2913
  [[package]]
2914
  name = "urllib3"
2915
+ version = "2.7.0"
2916
  source = { registry = "https://pypi.org/simple" }
2917
+ sdist = { url = "https://files.pythonhosted.org/packages/53/0c/06f8b233b8fd13b9e5ee11424ef85419ba0d8ba0b3138bf360be2ff56953/urllib3-2.7.0.tar.gz", hash = "sha256:231e0ec3b63ceb14667c67be60f2f2c40a518cb38b03af60abc813da26505f4c", size = 433602, upload-time = "2026-05-07T16:13:18.596Z" }
2918
  wheels = [
2919
+ { url = "https://files.pythonhosted.org/packages/7f/3e/5db95bcf282c52709639744ca2a8b149baccf648e39c8cc87553df9eae0c/urllib3-2.7.0-py3-none-any.whl", hash = "sha256:9fb4c81ebbb1ce9531cce37674bbc6f1360472bc18ca9a553ede278ef7276897", size = 131087, upload-time = "2026-05-07T16:13:17.151Z" },
2920
  ]
2921
 
2922
  [[package]]