Jac-Zac committed on
Commit
eb41f91
·
1 Parent(s): f4259c0

Updated to latest version

Browse files

Updated message

Fix bug

Fix bugs

Files changed (13) hide show
  1. .env.example +1 -0
  2. README.md +11 -2
  3. app.py +39 -5
  4. pyproject.toml +7 -6
  5. state.py +6 -4
  6. tabs/chat.py +206 -61
  7. tabs/compare.py +71 -57
  8. tabs/extract.py +58 -13
  9. utils/chat.py +6 -12
  10. utils/chat_export.py +1 -1
  11. utils/datasets.py +2 -2
  12. utils/helpers.py +0 -4
  13. uv.lock +32 -57
.env.example CHANGED
@@ -9,6 +9,7 @@ NDIF_API_KEY=your-ndif-api-key-here
9
  # Defaults to ~/.cache/huggingface if unset
10
  # Useful when working on a cluster with a shared cache or limited home quota
11
  HF_HOME=/path/to/your/hf/cache
 
12
 
13
  # Root directory for all generated artifacts (activations, plots, etc.)
14
  # Defaults to artifacts if unset
 
9
  # Defaults to ~/.cache/huggingface if unset
10
  # Useful when working on a cluster with a shared cache or limited home quota
11
  HF_HOME=/path/to/your/hf/cache
12
+ HF_TOKEN=your-token
13
 
14
  # Root directory for all generated artifacts (activations, plots, etc.)
15
  # Defaults to artifacts if unset
README.md CHANGED
@@ -42,9 +42,15 @@ uv sync
42
  cp .env.example .env
43
  ```
44
 
 
 
 
 
 
 
45
  ## Local Setup Note
46
 
47
- For now, `persona-data` and `persona-vectors` need to be checked out in the parent directory of `persona-ui`.
48
 
49
  Example:
50
 
@@ -80,6 +86,9 @@ ARTIFACTS_DIR=... # Optional: where activations are read from (default: ./a
80
 
81
  The app picks up this file automatically via `load_dotenv()` on startup.
82
 
 
 
 
83
  ## Saved Artifacts
84
 
85
  The Compare and Extract tabs read from / write to:
@@ -88,7 +97,7 @@ The Compare and Extract tabs read from / write to:
88
  artifacts/
89
  ├── activations/<model_dir>/<prompt_variant>/<persona_id>/
90
  │ ├── activations.safetensors
91
- │ └── metadata.json
92
  └── chats/<model_dir>/<prompt_variant>/
93
  └── <export>.json
94
  ```
 
42
  cp .env.example .env
43
  ```
44
 
45
+ ## Local Development
46
+
47
+ The committed dependency graph uses git sources so `persona-ui` can install cleanly in a Hugging Face Space or any isolated environment.
48
+
49
+ For local sibling checkouts, uncomment the `path` sources in `persona-ui/pyproject.toml` and `persona-vectors/pyproject.toml`, then comment out the git sources.
50
+
51
  ## Local Setup Note
52
 
53
+ For local development, `persona-data` and `persona-vectors` can still be checked out in the parent directory of `persona-ui`.
54
 
55
  Example:
56
 
 
86
 
87
  The app picks up this file automatically via `load_dotenv()` on startup.
88
 
89
+ You can also override the active NDIF or Hugging Face token from the sidebar
90
+ `API Keys` section. Those inputs are not written to `.env`; they update the environment of the running app process and last until it restarts.
91
+
92
  ## Saved Artifacts
93
 
94
  The Compare and Extract tabs read from / write to:
 
97
  artifacts/
98
  ├── activations/<model_dir>/<prompt_variant>/<persona_id>/
99
  │ ├── activations.safetensors
100
+ │ └── metadata.json # used for persona names and layer counts
101
  └── chats/<model_dir>/<prompt_variant>/
102
  └── <export>.json
103
  ```
app.py CHANGED
@@ -8,6 +8,42 @@ from utils.helpers import DATASET_SOURCES
8
  load_dotenv()
9
  DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
10
  REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def _sidebar_controls() -> tuple[bool, str, str, str]:
@@ -18,7 +54,7 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
18
  st.caption("Chat, extract, and compare persona runs.")
19
 
20
  if "sidebar__active_tab" not in st.session_state:
21
- st.session_state["sidebar__active_tab"] = _TABS[0]
22
 
23
  active_tab = st.session_state["sidebar__active_tab"]
24
  for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
@@ -71,11 +107,9 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
71
  help="Dataset for Chat and Extract.",
72
  )
73
 
74
- return remote, model_name, dataset_source, active_tab
75
-
76
 
77
- _TABS = ["Chat", "Compare", "Extract"]
78
- _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
79
 
80
 
81
  def main() -> None:
 
8
  load_dotenv()
9
  DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
10
  REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
11
# Credentials read once at import time; the sidebar key inputs use them as
# their starting point.
NDIF_API_KEY = os.environ.get("NDIF_API_KEY", "")
# Use `or` instead of a nested default so the legacy variable is still
# consulted when HF_TOKEN is present but empty (e.g. `HF_TOKEN=` in .env).
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN", "")
13
+
14
+
15
+ _TABS = ["Chat", "Compare", "Extract"]
16
+ _TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
17
+
18
+
19
+ def _sync_sidebar_api_key(env_var: str, value: str) -> None:
20
+ if value:
21
+ os.environ[env_var] = value
22
+
23
+
24
def _sidebar_api_keys() -> None:
    """Render the sidebar "API Keys" inputs and export any entered values.

    The inputs start empty instead of being pre-filled with the configured
    keys: a widget's ``value`` is delivered to the browser, so seeding it with
    a server-side secret would hand that secret to every visitor of a shared
    deployment. A placeholder signals when a key is already configured, and
    leaving a field blank keeps that key because ``_sync_sidebar_api_key``
    ignores empty input.

    NOTE(review): entered values are exported through ``os.environ``, which is
    process-wide, so they are not isolated to a single browser session.
    """
    configured_hint = "Configured via environment"

    with st.sidebar:
        st.divider()
        st.caption("API Keys")

        ndif_api_key = st.text_input(
            "NDIF API key",
            value="",
            type="password",
            key="sidebar__ndif_api_key",
            help="Overrides NDIF_API_KEY for this session.",
            placeholder=configured_hint if NDIF_API_KEY else "",
        )
        _sync_sidebar_api_key("NDIF_API_KEY", ndif_api_key)

        hf_token = st.text_input(
            "Hugging Face token",
            value="",
            type="password",
            key="sidebar__hf_token",
            help="Overrides HF_TOKEN and HUGGING_FACE_HUB_TOKEN for this session.",
            placeholder=configured_hint if HF_TOKEN else "",
        )
        # Both variable names are exported so either lookup sees the override.
        _sync_sidebar_api_key("HF_TOKEN", hf_token)
        _sync_sidebar_api_key("HUGGING_FACE_HUB_TOKEN", hf_token)
47
 
48
 
49
  def _sidebar_controls() -> tuple[bool, str, str, str]:
 
54
  st.caption("Chat, extract, and compare persona runs.")
55
 
56
  if "sidebar__active_tab" not in st.session_state:
57
+ st.session_state["sidebar__active_tab"] = "Chat"
58
 
59
  active_tab = st.session_state["sidebar__active_tab"]
60
  for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
 
107
  help="Dataset for Chat and Extract.",
108
  )
109
 
110
+ _sidebar_api_keys()
 
111
 
112
+ return remote, model_name, dataset_source, active_tab
 
113
 
114
 
115
  def main() -> None:
pyproject.toml CHANGED
@@ -5,18 +5,19 @@ description = "Streamlit UI for persona-vectors"
5
  readme = "README.md"
6
  requires-python = ">=3.10"
7
  dependencies = [
8
- "persona-vectors",
9
- "persona-data",
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
12
  "python-dotenv>=1.2.2",
13
  ]
14
 
15
  [tool.uv.sources]
16
- persona-vectors = { path = "../persona-vectors", editable = true }
17
- persona-data = { path = "../persona-data", editable = true }
18
- # persona-vectors = { git = "ssh://git@github.com/implicit-personalization/persona-vectors.git" } # use for release
19
- # persona-data = { git = "ssh://git@github.com/implicit-personalization/persona-data.git" } # use for release
 
20
 
21
  # [build-system]
22
  # requires = ["uv_build>=0.11.3,<0.12"]
 
5
  readme = "README.md"
6
  requires-python = ">=3.10"
7
  dependencies = [
8
+ "persona-vectors>=0.1.0",
9
+ "persona-data>=0.1.0",
10
  "streamlit>=1.44.0",
11
  "plotly>=6.6.0",
12
  "python-dotenv>=1.2.2",
13
  ]
14
 
15
  [tool.uv.sources]
16
+ # Local development:
17
+ # persona-vectors = { path = "../persona-vectors", editable = true }
18
+ # persona-data = { path = "../persona-data", editable = true }
19
+ persona-vectors = { git = "ssh://git@github.com/implicit-personalization/persona-vectors.git" }
20
+ persona-data = { git = "ssh://git@github.com/implicit-personalization/persona-data.git" }
21
 
22
  # [build-system]
23
  # requires = ["uv_build>=0.11.3,<0.12"]
state.py CHANGED
@@ -51,9 +51,11 @@ def get_chat_state(
51
  return state
52
 
53
 
54
- def reset_chat_state(model_name: str, remote: bool, dataset_source: str) -> None:
55
  """Reset chat history and cache for the active context."""
56
 
57
- state = get_chat_state(model_name, remote, dataset_source)
58
- state["messages"] = []
59
- state["past_key_values"] = None
 
 
 
51
  return state
52
 
53
 
54
def reset_chat_state(model_name: str, dataset_source: str) -> None:
    """Clear the stored messages and ``past_key_values`` for one chat context.

    Does nothing when no state has been stored yet for this model/dataset
    pair, so calling it never creates a new session entry.
    """
    session_key = chat_session_key(model_name, dataset_source)
    if session_key not in st.session_state:
        return

    chat_state = st.session_state[session_key]
    chat_state["messages"] = []
    chat_state["past_key_values"] = None
tabs/chat.py CHANGED
@@ -23,12 +23,118 @@ from utils.helpers import (
23
  )
24
  from utils.runtime import cached_model
25
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def _render_chat_message(message: dict[str, str]) -> None:
28
  if not message.get("content"):
29
  return
30
  with st.chat_message(message["role"]):
31
- st.markdown(message["content"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
 
34
  def _clear_chat_ui_state(*keys: str) -> None:
@@ -38,14 +144,13 @@ def _clear_chat_ui_state(*keys: str) -> None:
38
 
39
  def _reset_single_chat_context(
40
  model_name: str,
41
- remote: bool,
42
  dataset_source: str,
43
  chat_state: dict[str, object],
44
  persona_id: str,
45
  prompt_mode: str,
46
  *ui_keys: str,
47
  ) -> None:
48
- reset_chat_state(model_name, remote, dataset_source)
49
  chat_state["persona_id"] = persona_id
50
  chat_state["prompt_mode"] = prompt_mode
51
  _clear_chat_ui_state(*ui_keys)
@@ -101,35 +206,6 @@ def _render_persona_prompt_controls(
101
  return selected_persona, prompt_mode, changed
102
 
103
 
104
- def _render_system_prompt_editor(
105
- prompt_key: str,
106
- prompt_mode: str,
107
- active_system_prompt: str | None,
108
- *,
109
- height: int,
110
- label: str = "Prompt",
111
- ) -> str | None:
112
- """Render the editable system prompt area for a chat panel."""
113
-
114
- if prompt_mode == "empty":
115
- return active_system_prompt
116
-
117
- if prompt_key not in st.session_state:
118
- st.session_state[prompt_key] = active_system_prompt or ""
119
-
120
- with st.expander("Edit prompt", expanded=False):
121
- edited_prompt = (
122
- st.text_area(
123
- label,
124
- key=prompt_key,
125
- height=height,
126
- label_visibility="collapsed",
127
- )
128
- or None
129
- )
130
- return edited_prompt
131
-
132
-
133
  def _render_chat_window(
134
  *,
135
  chat_log: Any,
@@ -137,6 +213,9 @@ def _render_chat_window(
137
  show_all_key: str,
138
  show_all_btn_key: str,
139
  show_earlier_label: str,
 
 
 
140
  ) -> Any:
141
  """Render the visible chat history inside one container."""
142
 
@@ -152,11 +231,19 @@ def _render_chat_window(
152
  st.session_state[show_all_key] = True
153
  st.rerun()
154
  visible_messages = messages[-VISIBLE_MESSAGE_COUNT:]
 
155
  else:
156
  visible_messages = messages
 
157
 
158
- for message in visible_messages:
159
- _render_chat_message(message)
 
 
 
 
 
 
160
 
161
  return chat_log
162
 
@@ -218,7 +305,9 @@ def _render_compare_mode(
218
  """Render the full side-by-side comparison UI."""
219
  left_col, right_col = st.columns(2)
220
 
221
- def render_panel(side: str, column) -> tuple[dict[str, object], Any, str | None]:
 
 
222
  panel_key = widget_key(context_key, f"cmp_{side}")
223
  state = st.session_state.get(panel_key)
224
  if state is None:
@@ -226,6 +315,8 @@ def _render_compare_mode(
226
  st.session_state[panel_key] = state
227
  prompt_key = widget_key(panel_key, "custom_prompt")
228
  show_all_key = widget_key(panel_key, "show_all")
 
 
229
 
230
  selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
231
  personas,
@@ -240,16 +331,11 @@ def _render_compare_mode(
240
  state["persona_id"] = selected_persona.id
241
  state["prompt_mode"] = prompt_mode
242
  _clear_chat_ui_state(prompt_key, show_all_key)
 
243
 
244
  active_system_prompt = resolve_system_prompt(
245
  persona=selected_persona, mode=prompt_mode
246
  )
247
- active_system_prompt = _render_system_prompt_editor(
248
- prompt_key,
249
- prompt_mode,
250
- active_system_prompt,
251
- height=150,
252
- )
253
 
254
  btn_col1, btn_col2 = st.columns(2)
255
  with btn_col1:
@@ -279,22 +365,73 @@ def _render_compare_mode(
279
  state["messages"] = []
280
  state["past_key_values"] = None
281
  _clear_chat_ui_state(prompt_key, show_all_key)
 
282
  st.rerun()
283
 
284
  chat_log = st.container()
 
 
 
 
 
 
 
 
285
  _render_chat_window(
286
  chat_log=chat_log,
287
  messages=state["messages"],
288
  show_all_key=show_all_key,
289
  show_all_btn_key=widget_key(panel_key, "show_all_btn"),
290
  show_earlier_label="Show earlier",
 
 
 
291
  )
292
- return state, chat_log, active_system_prompt
293
 
294
  with left_col:
295
- left_state, left_log, left_prompt = render_panel("left", left_col)
296
  with right_col:
297
- right_state, right_log, right_prompt = render_panel("right", right_col)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
  user_prompt = st.chat_input(
300
  "Ask both...",
@@ -304,12 +441,8 @@ def _render_compare_mode(
304
  return
305
 
306
  model = cached_model(model_name=model_name, remote=remote)
307
- panels = [
308
- (left_state, left_log, left_prompt),
309
- (right_state, right_log, right_prompt),
310
- ]
311
 
312
- for panel_state, panel_log, _panel_prompt in panels:
313
  panel_state["messages"].append({"role": "user", "content": user_prompt})
314
  with panel_log:
315
  _render_chat_message({"role": "user", "content": user_prompt})
@@ -331,7 +464,7 @@ def _render_compare_mode(
331
  past_key_values=panel_state["past_key_values"],
332
  **gen_kwargs,
333
  )
334
- for panel_state, _panel_log, panel_prompt in panels
335
  ]
336
  results: list[ChatReply | Exception] = []
337
  for future in futures:
@@ -341,7 +474,7 @@ def _render_compare_mode(
341
  results.append(exc)
342
  else:
343
  results = []
344
- for panel_state, _panel_log, panel_prompt in panels:
345
  try:
346
  results.append(
347
  generate_chat_reply(
@@ -360,7 +493,9 @@ def _render_compare_mode(
360
  except Exception as exc:
361
  results.append(exc)
362
 
363
- for (panel_state, panel_log, _panel_prompt), result in zip(panels, results):
 
 
364
  if isinstance(result, Exception):
365
  with panel_log:
366
  st.error(f"Generation failed: {result}")
@@ -384,7 +519,11 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
384
  context_key = chat_session_key(model_name, dataset_source)
385
  chat_state = get_chat_state(model_name, remote, dataset_source)
386
  try:
387
- dataset, dataset_status = load_dataset(dataset_source)
 
 
 
 
388
  st.caption(dataset_status)
389
  except Exception as exc:
390
  st.error(f"Could not load data: {exc}")
@@ -534,6 +673,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
534
  pending_key = widget_key(context_key, "pending_prompt")
535
  export_key = widget_key(context_key, "export_chat")
536
  reset_key = widget_key(context_key, "reset")
 
537
 
538
  col1, col2 = st.columns([2, 1])
539
  with col1:
@@ -571,7 +711,6 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
571
  had_history = bool(chat_state["messages"])
572
  _reset_single_chat_context(
573
  model_name,
574
- remote,
575
  dataset_source,
576
  chat_state,
577
  selected_persona.id,
@@ -581,17 +720,20 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
581
  prompt_key,
582
  pending_key,
583
  )
 
584
  if had_history:
585
  st.info("Chat history reset because the persona or system prompt changed.")
586
 
587
  chat_log = st.container()
588
 
589
- active_system_prompt = _render_system_prompt_editor(
590
- prompt_key,
591
- prompt_mode,
592
- active_system_prompt,
593
- height=200,
594
- )
 
 
595
 
596
  action_col1, action_col2 = st.columns(2)
597
  with action_col1:
@@ -612,7 +754,6 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
612
  if st.button("Reset chat", key=reset_key, width="stretch", type="secondary"):
613
  _reset_single_chat_context(
614
  model_name,
615
- remote,
616
  dataset_source,
617
  chat_state,
618
  selected_persona.id,
@@ -622,6 +763,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
622
  prompt_key,
623
  pending_key,
624
  )
 
625
  st.rerun()
626
 
627
  _render_chat_window(
@@ -630,6 +772,9 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
630
  show_all_key=show_all_key,
631
  show_all_btn_key=widget_key(context_key, "show_all_btn"),
632
  show_earlier_label="Show earlier messages",
 
 
 
633
  )
634
 
635
  user_prompt = st.chat_input(
 
23
  )
24
  from utils.runtime import cached_model
25
 
26
+ COLLAPSED_MESSAGE_CHAR_LIMIT = 500
27
+
28
+
29
def _render_collapsible_markdown(content: str) -> None:
    """Render markdown, tucking long text inside a collapsed expander.

    Text longer than ``COLLAPSED_MESSAGE_CHAR_LIMIT`` characters is placed in
    an expander labelled with its length; shorter text is rendered directly.
    """
    char_count = len(content)
    if char_count > COLLAPSED_MESSAGE_CHAR_LIMIT:
        with st.expander(f"Show full text ({char_count} chars)", expanded=False):
            st.markdown(content)
    else:
        st.markdown(content)
36
+
37
 
38
  def _render_chat_message(message: dict[str, str]) -> None:
39
  if not message.get("content"):
40
  return
41
  with st.chat_message(message["role"]):
42
+ _render_collapsible_markdown(message["content"])
43
+
44
+
45
def _render_inline_system_prompt(
    prompt_key: str,
    prompt_mode: str,
    active_system_prompt: str | None,
    edit_key: str,
    height: int = 200,
) -> str | None:
    """Show the system prompt in a bordered box with inline Edit/Save/Cancel.

    Args:
        prompt_key: Session-state key under which the prompt text is stored.
        prompt_mode: Prompt mode; ``"empty"`` skips rendering entirely.
        active_system_prompt: Prompt used to seed the stored text the first
            time ``prompt_key`` is seen.
        edit_key: Session-state key holding the index of the item being
            edited; the value ``-1`` marks the system prompt itself.
        height: Height passed to the text area shown while editing.

    Returns:
        ``active_system_prompt`` unchanged in ``"empty"`` mode, otherwise the
        stored prompt text, or ``None`` when that text is empty.
    """
    if prompt_mode == "empty":
        return active_system_prompt

    session = st.session_state
    if prompt_key not in session:
        session[prompt_key] = active_system_prompt or ""

    stored_prompt = session[prompt_key] or None
    editing_prompt = session.get(edit_key) == -1

    with st.container(border=True):
        st.caption("System prompt")
        if not editing_prompt:
            # Read-only view with an Edit button that flips into edit mode.
            if stored_prompt:
                _render_collapsible_markdown(stored_prompt)
            else:
                st.markdown("*(empty)*")
            if st.button("Edit", key=f"{edit_key}_sys_edit"):
                session[edit_key] = -1
                st.rerun()
        else:
            draft = st.text_area(
                "system_prompt_edit",
                value=stored_prompt or "",
                height=height,
                label_visibility="collapsed",
                key=f"{prompt_key}_inline_edit",
            )
            save_col, cancel_col = st.columns(2)
            with save_col:
                if st.button("Save", key=f"{edit_key}_sys_save", type="primary"):
                    session[prompt_key] = draft
                    session[edit_key] = None
                    st.rerun()
            with cancel_col:
                if st.button("Cancel", key=f"{edit_key}_sys_cancel"):
                    session[edit_key] = None
                    st.rerun()

    return session.get(prompt_key) or None
92
+
93
+
94
def _render_editable_message(
    message: dict[str, str],
    msg_index: int,
    messages: list[dict[str, str]],
    chat_state: dict[str, object] | None,
    edit_key: str,
    pending_key: str,
) -> None:
    """Render a single message with an inline edit button.

    Args:
        message: The message to render; skipped when its content is empty.
        msg_index: Position of ``message`` inside ``messages``.
        messages: Full message list; saving an edit rewrites the entry at
            ``msg_index`` and drops every later message.
        chat_state: Chat state whose ``past_key_values`` is cleared on save.
            May be ``None``, in which case there is nothing to clear.
        edit_key: Session-state key holding the index currently being edited.
        pending_key: Session-state key set to ``True`` after a user message is
            saved, to request a regenerated reply.
    """
    if not message.get("content"):
        return

    is_editing = st.session_state.get(edit_key) == msg_index

    with st.chat_message(message["role"]):
        if is_editing:
            new_content = st.text_area(
                "Edit",
                value=message["content"],
                height=100,
                label_visibility="collapsed",
                key=f"{edit_key}_msg_{msg_index}",
            )
            c1, c2 = st.columns(2)
            with c1:
                if st.button(
                    "Save", key=f"{edit_key}_msg_save_{msg_index}", type="primary"
                ):
                    if not new_content.strip():
                        # An empty message is skipped by the guard at the top
                        # of this function, so saving one would leave an
                        # invisible, uneditable entry (and, for user turns,
                        # trigger generation from an empty prompt).
                        st.warning("Message cannot be empty.")
                    else:
                        messages[msg_index]["content"] = new_content
                        # Later turns were produced from the old text.
                        del messages[msg_index + 1 :]
                        if chat_state is not None:
                            chat_state["past_key_values"] = None
                        st.session_state[edit_key] = None
                        if message["role"] == "user":
                            st.session_state[pending_key] = True
                        st.rerun()
            with c2:
                if st.button("Cancel", key=f"{edit_key}_msg_cancel_{msg_index}"):
                    st.session_state[edit_key] = None
                    st.rerun()
        else:
            st.markdown(message["content"])
            if st.button("Edit", key=f"{edit_key}_msg_edit_{msg_index}"):
                st.session_state[edit_key] = msg_index
                st.rerun()
138
 
139
 
140
  def _clear_chat_ui_state(*keys: str) -> None:
 
144
 
145
  def _reset_single_chat_context(
146
  model_name: str,
 
147
  dataset_source: str,
148
  chat_state: dict[str, object],
149
  persona_id: str,
150
  prompt_mode: str,
151
  *ui_keys: str,
152
  ) -> None:
153
+ reset_chat_state(model_name, dataset_source)
154
  chat_state["persona_id"] = persona_id
155
  chat_state["prompt_mode"] = prompt_mode
156
  _clear_chat_ui_state(*ui_keys)
 
206
  return selected_persona, prompt_mode, changed
207
 
208
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  def _render_chat_window(
210
  *,
211
  chat_log: Any,
 
213
  show_all_key: str,
214
  show_all_btn_key: str,
215
  show_earlier_label: str,
216
+ chat_state: dict[str, object] | None = None,
217
+ edit_key: str | None = None,
218
+ pending_key: str | None = None,
219
  ) -> Any:
220
  """Render the visible chat history inside one container."""
221
 
 
231
  st.session_state[show_all_key] = True
232
  st.rerun()
233
  visible_messages = messages[-VISIBLE_MESSAGE_COUNT:]
234
+ index_offset = len(messages) - VISIBLE_MESSAGE_COUNT
235
  else:
236
  visible_messages = messages
237
+ index_offset = 0
238
 
239
+ for i, message in enumerate(visible_messages):
240
+ actual_index = index_offset + i
241
+ if edit_key and pending_key:
242
+ _render_editable_message(
243
+ message, actual_index, messages, chat_state, edit_key, pending_key
244
+ )
245
+ else:
246
+ _render_chat_message(message)
247
 
248
  return chat_log
249
 
 
305
  """Render the full side-by-side comparison UI."""
306
  left_col, right_col = st.columns(2)
307
 
308
+ def render_panel(
309
+ side: str, column
310
+ ) -> tuple[dict[str, object], Any, str | None, str]:
311
  panel_key = widget_key(context_key, f"cmp_{side}")
312
  state = st.session_state.get(panel_key)
313
  if state is None:
 
315
  st.session_state[panel_key] = state
316
  prompt_key = widget_key(panel_key, "custom_prompt")
317
  show_all_key = widget_key(panel_key, "show_all")
318
+ edit_key = widget_key(panel_key, "edit_idx")
319
+ pending_regen_key = widget_key(panel_key, "pending_regen")
320
 
321
  selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
322
  personas,
 
331
  state["persona_id"] = selected_persona.id
332
  state["prompt_mode"] = prompt_mode
333
  _clear_chat_ui_state(prompt_key, show_all_key)
334
+ st.session_state.pop(edit_key, None)
335
 
336
  active_system_prompt = resolve_system_prompt(
337
  persona=selected_persona, mode=prompt_mode
338
  )
 
 
 
 
 
 
339
 
340
  btn_col1, btn_col2 = st.columns(2)
341
  with btn_col1:
 
365
  state["messages"] = []
366
  state["past_key_values"] = None
367
  _clear_chat_ui_state(prompt_key, show_all_key)
368
+ st.session_state.pop(edit_key, None)
369
  st.rerun()
370
 
371
  chat_log = st.container()
372
+ with chat_log:
373
+ active_system_prompt = _render_inline_system_prompt(
374
+ prompt_key,
375
+ prompt_mode,
376
+ active_system_prompt,
377
+ edit_key,
378
+ height=150,
379
+ )
380
  _render_chat_window(
381
  chat_log=chat_log,
382
  messages=state["messages"],
383
  show_all_key=show_all_key,
384
  show_all_btn_key=widget_key(panel_key, "show_all_btn"),
385
  show_earlier_label="Show earlier",
386
+ chat_state=state,
387
+ edit_key=edit_key,
388
+ pending_key=pending_regen_key,
389
  )
390
+ return state, chat_log, active_system_prompt, pending_regen_key
391
 
392
  with left_col:
393
+ left_state, left_log, left_prompt, left_pending = render_panel("left", left_col)
394
  with right_col:
395
+ right_state, right_log, right_prompt, right_pending = render_panel(
396
+ "right", right_col
397
+ )
398
+
399
+ panels = [
400
+ (left_state, left_log, left_prompt, left_pending),
401
+ (right_state, right_log, right_prompt, right_pending),
402
+ ]
403
+
404
+ # Handle per-panel regeneration triggered by message edits
405
+ any_regen = any(st.session_state.get(p_pending) for _, _, _, p_pending in panels)
406
+ if any_regen:
407
+ model = cached_model(model_name=model_name, remote=remote)
408
+ for panel_state, panel_log, panel_prompt, p_pending in panels:
409
+ if not st.session_state.pop(p_pending, False):
410
+ continue
411
+ regen_messages = _build_chat_messages(panel_prompt, panel_state["messages"])
412
+ with st.spinner("Regenerating..."):
413
+ try:
414
+ result = generate_chat_reply(
415
+ model=model,
416
+ messages=regen_messages,
417
+ remote=remote,
418
+ past_key_values=panel_state["past_key_values"],
419
+ **gen_kwargs,
420
+ )
421
+ except Exception as exc:
422
+ with panel_log:
423
+ st.error(f"Generation failed: {exc}")
424
+ panel_state["messages"].pop()
425
+ continue
426
+ panel_state["messages"].append(
427
+ {"role": "assistant", "content": result.text}
428
+ )
429
+ panel_state["past_key_values"] = (
430
+ result.past_key_values if not remote else None
431
+ )
432
+ with panel_log:
433
+ _render_chat_message({"role": "assistant", "content": result.text})
434
+ st.rerun()
435
 
436
  user_prompt = st.chat_input(
437
  "Ask both...",
 
441
  return
442
 
443
  model = cached_model(model_name=model_name, remote=remote)
 
 
 
 
444
 
445
+ for panel_state, panel_log, _panel_prompt, _p_pending in panels:
446
  panel_state["messages"].append({"role": "user", "content": user_prompt})
447
  with panel_log:
448
  _render_chat_message({"role": "user", "content": user_prompt})
 
464
  past_key_values=panel_state["past_key_values"],
465
  **gen_kwargs,
466
  )
467
+ for panel_state, _panel_log, panel_prompt, _p_pending in panels
468
  ]
469
  results: list[ChatReply | Exception] = []
470
  for future in futures:
 
474
  results.append(exc)
475
  else:
476
  results = []
477
+ for panel_state, _panel_log, panel_prompt, _p_pending in panels:
478
  try:
479
  results.append(
480
  generate_chat_reply(
 
493
  except Exception as exc:
494
  results.append(exc)
495
 
496
+ for (panel_state, panel_log, _panel_prompt, _p_pending), result in zip(
497
+ panels, results
498
+ ):
499
  if isinstance(result, Exception):
500
  with panel_log:
501
  st.error(f"Generation failed: {result}")
 
519
  context_key = chat_session_key(model_name, dataset_source)
520
  chat_state = get_chat_state(model_name, remote, dataset_source)
521
  try:
522
+ dataset, dataset_status = load_dataset(
523
+ dataset_source,
524
+ personas_file=st.session_state.get("extract__personas_file"),
525
+ qa_file=st.session_state.get("extract__qa_file"),
526
+ )
527
  st.caption(dataset_status)
528
  except Exception as exc:
529
  st.error(f"Could not load data: {exc}")
 
673
  pending_key = widget_key(context_key, "pending_prompt")
674
  export_key = widget_key(context_key, "export_chat")
675
  reset_key = widget_key(context_key, "reset")
676
+ edit_key = widget_key(context_key, "edit_idx")
677
 
678
  col1, col2 = st.columns([2, 1])
679
  with col1:
 
711
  had_history = bool(chat_state["messages"])
712
  _reset_single_chat_context(
713
  model_name,
 
714
  dataset_source,
715
  chat_state,
716
  selected_persona.id,
 
720
  prompt_key,
721
  pending_key,
722
  )
723
+ st.session_state.pop(edit_key, None)
724
  if had_history:
725
  st.info("Chat history reset because the persona or system prompt changed.")
726
 
727
  chat_log = st.container()
728
 
729
+ with chat_log:
730
+ active_system_prompt = _render_inline_system_prompt(
731
+ prompt_key,
732
+ prompt_mode,
733
+ active_system_prompt,
734
+ edit_key,
735
+ height=200,
736
+ )
737
 
738
  action_col1, action_col2 = st.columns(2)
739
  with action_col1:
 
754
  if st.button("Reset chat", key=reset_key, width="stretch", type="secondary"):
755
  _reset_single_chat_context(
756
  model_name,
 
757
  dataset_source,
758
  chat_state,
759
  selected_persona.id,
 
763
  prompt_key,
764
  pending_key,
765
  )
766
+ st.session_state.pop(edit_key, None)
767
  st.rerun()
768
 
769
  _render_chat_window(
 
772
  show_all_key=show_all_key,
773
  show_all_btn_key=widget_key(context_key, "show_all_btn"),
774
  show_earlier_label="Show earlier messages",
775
+ chat_state=chat_state,
776
+ edit_key=edit_key,
777
+ pending_key=pending_key,
778
  )
779
 
780
  user_prompt = st.chat_input(
tabs/compare.py CHANGED
@@ -5,7 +5,7 @@ import streamlit as st
5
  import torch
6
  from persona_data.environment import get_artifacts_dir
7
  from persona_vectors.analysis import build_embedding_figure, project_pca, project_umap
8
- from persona_vectors.artifacts import ActivationStore
9
  from persona_vectors.artifacts import list_layers as list_available_layers
10
  from persona_vectors.artifacts import list_personas as list_available_personas
11
  from persona_vectors.artifacts import load_mean_activations, load_persona_names
@@ -14,7 +14,6 @@ from persona_vectors.plots import plot_layer_similarity, save_plot_html, save_pl
14
  from utils.helpers import (
15
  ANALYSIS_HELP_TEXT,
16
  ANALYSIS_MODES,
17
- PROMPT_VARIANTS,
18
  persona_display_label,
19
  prompt_variant_label,
20
  slugify,
@@ -34,20 +33,27 @@ class ProjectionConfig:
34
  project_fn: Callable[[torch.Tensor], torch.Tensor]
35
 
36
 
 
 
 
 
 
 
 
 
 
37
  _PROJECTION_CONFIGS: dict[str, ProjectionConfig] = {
38
  "PCA": ProjectionConfig("PCA", "PC1", "PC2", project_pca),
39
  "UMAP": ProjectionConfig("UMAP", "UMAP 1", "UMAP 2", project_umap),
40
  }
41
 
 
42
 
43
- @st.cache_data(show_spinner=False)
44
- def _list_layers(
45
- root_dir: str,
46
- model_name: str,
47
- variants: list[str],
48
- persona_ids: list[str],
49
- ) -> list[int]:
50
- return list_available_layers(root_dir, model_name, variants, persona_ids)
51
 
52
 
53
  def _load_embedding_samples(
@@ -86,9 +92,9 @@ def _load_embedding_samples(
86
  continue
87
 
88
  layer_vectors = vectors[:, layer_idx, :]
89
- samples.append(layer_vectors)
90
- labels.extend([persona_id] * layer_vectors.shape[0])
91
  display_name = persona_names.get(persona_id) or persona_id
 
 
92
  hover_text.extend(
93
  [f"<b>{display_name}</b><br>{variant}"] * layer_vectors.shape[0]
94
  )
@@ -114,28 +120,8 @@ def _load_embedding_samples(
114
  return plots, errors
115
 
116
 
117
- def _build_embedding_figures(
118
- plots: list[tuple[int, torch.Tensor, list[str], list[str]]],
119
- config: ProjectionConfig,
120
- ) -> list[tuple[int, object]]:
121
- return [
122
- (
123
- layer_idx,
124
- build_embedding_figure(
125
- coords=coords,
126
- labels=labels,
127
- title=f"{config.title_prefix}, layer {layer_idx}",
128
- x_label=config.x_label,
129
- y_label=config.y_label,
130
- hover_text=hover_text,
131
- ),
132
- )
133
- for layer_idx, coords, labels, hover_text in plots
134
- ]
135
-
136
-
137
  def _render_embedding_results(
138
- store: ActivationStore,
139
  analysis_mode: str,
140
  rendered_figures: list[tuple[int, object]],
141
  saved_variant: str,
@@ -152,7 +138,7 @@ def _render_embedding_results(
152
  _filename(
153
  "compare",
154
  analysis_mode,
155
- store.model_name,
156
  saved_variant,
157
  saved_persona_key,
158
  str(layer_idx),
@@ -181,15 +167,20 @@ def _select_artifact_personas(
181
  st.info("No personas found for this model yet. Run extraction first.")
182
  return [], persona_names
183
 
 
 
 
 
184
  persona_ids = st.multiselect(
185
  "Personas",
186
  options=persona_options,
187
- default=persona_options[:1] if len(persona_options) > 1 else persona_options,
188
  format_func=lambda persona_id: persona_display_label(
189
  persona_id, persona_names.get(persona_id)
190
  ),
191
  key=widget_key("load", "personas", store.model_name, *variants),
192
  )
 
193
  return persona_ids, persona_names
194
 
195
 
@@ -215,11 +206,11 @@ def _render_save_buttons(
215
 
216
  def _select_embedding_config(
217
  store: ActivationStore,
218
- ) -> tuple[str, list[str], dict[str, str], list[int]] | None:
219
  """Render variant / persona / layer selectors and return the selection, or None on early exit."""
220
  selected_variant = st.selectbox(
221
  "Variant",
222
- options=PROMPT_VARIANTS,
223
  format_func=prompt_variant_label,
224
  key=widget_key("load", "variant"),
225
  )
@@ -228,7 +219,8 @@ def _select_embedding_config(
228
  if not persona_ids:
229
  return None
230
 
231
- layer_options = _list_layers(
 
232
  str(store.root_dir),
233
  store.model_name,
234
  [selected_variant],
@@ -240,14 +232,14 @@ def _select_embedding_config(
240
  )
241
  return None
242
 
243
- persona_key = "_".join(sorted(persona_ids))
244
  layer_key = widget_key(
245
  "load", "layers", store.model_name, selected_variant, persona_key
246
  )
 
 
 
247
  default_layers = [
248
- layer
249
- for layer in st.session_state.get(layer_key, layer_options[:3])
250
- if layer in layer_options
251
  ] or layer_options[:3]
252
  selected_layers = st.multiselect(
253
  "Layers",
@@ -259,7 +251,15 @@ def _select_embedding_config(
259
  st.info("Select at least one layer.")
260
  return None
261
 
262
- return selected_variant, persona_ids, persona_names, selected_layers
 
 
 
 
 
 
 
 
263
 
264
 
265
  def _render_cosine_similarity(store: ActivationStore) -> None:
@@ -267,7 +267,7 @@ def _render_cosine_similarity(store: ActivationStore) -> None:
267
  with col1:
268
  variant_a = st.selectbox(
269
  "Variant A",
270
- options=PROMPT_VARIANTS,
271
  index=0,
272
  format_func=prompt_variant_label,
273
  key=widget_key("load", "variant_a"),
@@ -275,8 +275,8 @@ def _render_cosine_similarity(store: ActivationStore) -> None:
275
  with col2:
276
  variant_b = st.selectbox(
277
  "Variant B",
278
- options=PROMPT_VARIANTS,
279
- index=min(1, len(PROMPT_VARIANTS) - 1),
280
  format_func=prompt_variant_label,
281
  key=widget_key("load", "variant_b"),
282
  )
@@ -289,7 +289,9 @@ def _render_cosine_similarity(store: ActivationStore) -> None:
289
  if not persona_ids:
290
  return
291
 
292
- cosine_fig_key = widget_key("load", "cosine_fig_state", store.model_name)
 
 
293
  filename = _filename("compare", "cosine", store.model_name, variant_a, variant_b)
294
 
295
  if st.button("Compare vectors", type="primary"):
@@ -334,8 +336,7 @@ def _render_embedding_analysis(store: ActivationStore, analysis_mode: str) -> No
334
  config = _select_embedding_config(store)
335
  if config is None:
336
  return
337
- selected_variant, persona_ids, persona_names, selected_layers = config
338
- persona_key = "_".join(sorted(persona_ids))
339
  projection_config = _PROJECTION_CONFIGS.get(analysis_mode)
340
  if projection_config is None:
341
  st.error(f"Unsupported analysis mode: {analysis_mode}")
@@ -358,11 +359,11 @@ def _render_embedding_analysis(store: ActivationStore, analysis_mode: str) -> No
358
  try:
359
  plots, errors = _load_embedding_samples(
360
  store,
361
- persona_ids,
362
- selected_variant,
363
- selected_layers,
364
  projection_config.project_fn,
365
- persona_names,
366
  progress_fn=update_progress,
367
  )
368
 
@@ -382,12 +383,25 @@ def _render_embedding_analysis(store: ActivationStore, analysis_mode: str) -> No
382
  st.info("Try fewer personas, fewer layers, or a different variant.")
383
  st.session_state.pop(embedding_fig_key, None)
384
  else:
385
- rendered_figures = _build_embedding_figures(plots, projection_config)
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  total_samples = sum(coords.shape[0] for _, coords, _, _ in plots)
387
  st.session_state[embedding_fig_key] = (
388
  rendered_figures,
389
- persona_key,
390
- selected_variant,
391
  total_samples,
392
  )
393
  finally:
@@ -398,7 +412,7 @@ def _render_embedding_analysis(store: ActivationStore, analysis_mode: str) -> No
398
  st.session_state[embedding_fig_key]
399
  )
400
  _render_embedding_results(
401
- store,
402
  analysis_mode,
403
  rendered_figures,
404
  saved_variant,
 
5
  import torch
6
  from persona_data.environment import get_artifacts_dir
7
  from persona_vectors.analysis import build_embedding_figure, project_pca, project_umap
8
+ from persona_vectors.artifacts import SUPPORTED_VARIANTS, ActivationStore
9
  from persona_vectors.artifacts import list_layers as list_available_layers
10
  from persona_vectors.artifacts import list_personas as list_available_personas
11
  from persona_vectors.artifacts import load_mean_activations, load_persona_names
 
14
  from utils.helpers import (
15
  ANALYSIS_HELP_TEXT,
16
  ANALYSIS_MODES,
 
17
  persona_display_label,
18
  prompt_variant_label,
19
  slugify,
 
33
  project_fn: Callable[[torch.Tensor], torch.Tensor]
34
 
35
 
36
+ @dataclass(frozen=True)
37
+ class _EmbeddingConfig:
38
+ variant: str
39
+ persona_ids: list[str]
40
+ persona_names: dict[str, str]
41
+ selected_layers: list[int]
42
+ persona_key: str
43
+
44
+
45
  _PROJECTION_CONFIGS: dict[str, ProjectionConfig] = {
46
  "PCA": ProjectionConfig("PCA", "PC1", "PC2", project_pca),
47
  "UMAP": ProjectionConfig("UMAP", "UMAP 1", "UMAP 2", project_umap),
48
  }
49
 
50
+ _list_layers_cached = st.cache_data(show_spinner=False)(list_available_layers)
51
 
52
+ # Cross-model/NDIF-switch persistence keys — written on every render so that
53
+ # when the model changes (and widget keys change) the last selection is reused
54
+ # as the default, filtered to whatever is available for the new model.
55
+ _LAST_PERSONAS_KEY = "compare:last_personas"
56
+ _LAST_LAYERS_KEY = "compare:last_layers"
 
 
 
57
 
58
 
59
  def _load_embedding_samples(
 
92
  continue
93
 
94
  layer_vectors = vectors[:, layer_idx, :]
 
 
95
  display_name = persona_names.get(persona_id) or persona_id
96
+ samples.append(layer_vectors)
97
+ labels.extend([display_name] * layer_vectors.shape[0])
98
  hover_text.extend(
99
  [f"<b>{display_name}</b><br>{variant}"] * layer_vectors.shape[0]
100
  )
 
120
  return plots, errors
121
 
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  def _render_embedding_results(
124
+ model_name: str,
125
  analysis_mode: str,
126
  rendered_figures: list[tuple[int, object]],
127
  saved_variant: str,
 
138
  _filename(
139
  "compare",
140
  analysis_mode,
141
+ model_name,
142
  saved_variant,
143
  saved_persona_key,
144
  str(layer_idx),
 
167
  st.info("No personas found for this model yet. Run extraction first.")
168
  return [], persona_names
169
 
170
+ last_personas: list[str] = st.session_state.get(_LAST_PERSONAS_KEY, [])
171
+ default_personas = [
172
+ p for p in last_personas if p in persona_options
173
+ ] or persona_options[:1]
174
  persona_ids = st.multiselect(
175
  "Personas",
176
  options=persona_options,
177
+ default=default_personas,
178
  format_func=lambda persona_id: persona_display_label(
179
  persona_id, persona_names.get(persona_id)
180
  ),
181
  key=widget_key("load", "personas", store.model_name, *variants),
182
  )
183
+ st.session_state[_LAST_PERSONAS_KEY] = persona_ids
184
  return persona_ids, persona_names
185
 
186
 
 
206
 
207
  def _select_embedding_config(
208
  store: ActivationStore,
209
+ ) -> _EmbeddingConfig | None:
210
  """Render variant / persona / layer selectors and return the selection, or None on early exit."""
211
  selected_variant = st.selectbox(
212
  "Variant",
213
+ options=SUPPORTED_VARIANTS,
214
  format_func=prompt_variant_label,
215
  key=widget_key("load", "variant"),
216
  )
 
219
  if not persona_ids:
220
  return None
221
 
222
+ persona_key = "_".join(sorted(persona_ids))
223
+ layer_options = _list_layers_cached(
224
  str(store.root_dir),
225
  store.model_name,
226
  [selected_variant],
 
232
  )
233
  return None
234
 
 
235
  layer_key = widget_key(
236
  "load", "layers", store.model_name, selected_variant, persona_key
237
  )
238
+ last_layers: list[int] = st.session_state.get(
239
+ layer_key, st.session_state.get(_LAST_LAYERS_KEY, layer_options[:3])
240
+ )
241
  default_layers = [
242
+ layer for layer in last_layers if layer in layer_options
 
 
243
  ] or layer_options[:3]
244
  selected_layers = st.multiselect(
245
  "Layers",
 
251
  st.info("Select at least one layer.")
252
  return None
253
 
254
+ st.session_state[_LAST_LAYERS_KEY] = selected_layers
255
+
256
+ return _EmbeddingConfig(
257
+ variant=selected_variant,
258
+ persona_ids=persona_ids,
259
+ persona_names=persona_names,
260
+ selected_layers=selected_layers,
261
+ persona_key=persona_key,
262
+ )
263
 
264
 
265
  def _render_cosine_similarity(store: ActivationStore) -> None:
 
267
  with col1:
268
  variant_a = st.selectbox(
269
  "Variant A",
270
+ options=SUPPORTED_VARIANTS,
271
  index=0,
272
  format_func=prompt_variant_label,
273
  key=widget_key("load", "variant_a"),
 
275
  with col2:
276
  variant_b = st.selectbox(
277
  "Variant B",
278
+ options=SUPPORTED_VARIANTS,
279
+ index=min(1, len(SUPPORTED_VARIANTS) - 1),
280
  format_func=prompt_variant_label,
281
  key=widget_key("load", "variant_b"),
282
  )
 
289
  if not persona_ids:
290
  return
291
 
292
+ cosine_fig_key = widget_key(
293
+ "load", "cosine_fig_state", store.model_name, variant_a, variant_b
294
+ )
295
  filename = _filename("compare", "cosine", store.model_name, variant_a, variant_b)
296
 
297
  if st.button("Compare vectors", type="primary"):
 
336
  config = _select_embedding_config(store)
337
  if config is None:
338
  return
339
+
 
340
  projection_config = _PROJECTION_CONFIGS.get(analysis_mode)
341
  if projection_config is None:
342
  st.error(f"Unsupported analysis mode: {analysis_mode}")
 
359
  try:
360
  plots, errors = _load_embedding_samples(
361
  store,
362
+ config.persona_ids,
363
+ config.variant,
364
+ config.selected_layers,
365
  projection_config.project_fn,
366
+ config.persona_names,
367
  progress_fn=update_progress,
368
  )
369
 
 
383
  st.info("Try fewer personas, fewer layers, or a different variant.")
384
  st.session_state.pop(embedding_fig_key, None)
385
  else:
386
+ rendered_figures = [
387
+ (
388
+ layer_idx,
389
+ build_embedding_figure(
390
+ coords=coords,
391
+ labels=labels,
392
+ title=f"{projection_config.title_prefix}, layer {layer_idx}",
393
+ x_label=projection_config.x_label,
394
+ y_label=projection_config.y_label,
395
+ hover_text=hover_text,
396
+ ),
397
+ )
398
+ for layer_idx, coords, labels, hover_text in plots
399
+ ]
400
  total_samples = sum(coords.shape[0] for _, coords, _, _ in plots)
401
  st.session_state[embedding_fig_key] = (
402
  rendered_figures,
403
+ config.persona_key,
404
+ config.variant,
405
  total_samples,
406
  )
407
  finally:
 
412
  st.session_state[embedding_fig_key]
413
  )
414
  _render_embedding_results(
415
+ store.model_name,
416
  analysis_mode,
417
  rendered_figures,
418
  saved_variant,
tabs/extract.py CHANGED
@@ -1,16 +1,28 @@
 
 
1
  import streamlit as st
 
2
  from persona_vectors.extraction import run_extraction
3
 
4
  from utils.datasets import load_dataset
5
  from utils.helpers import (
6
  NDIF_STATUS_ICONS,
7
- PROMPT_VARIANTS,
8
  persona_label,
9
  prompt_variant_label,
10
  widget_key,
11
  )
12
  from utils.runtime import cached_model
13
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def _extract_widget_key(
16
  model_name: str, remote: bool, dataset_source: str, suffix: str
@@ -26,7 +38,7 @@ def _render_local_dataset_uploads() -> None:
26
  "personas.jsonl",
27
  type=["jsonl"],
28
  key="extract__personas_file",
29
- help="Expected fields: id, persona, templated_prompt, biography_md",
30
  )
31
  st.file_uploader(
32
  "qa.jsonl",
@@ -44,19 +56,28 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
44
  if dataset_source == "Local JSONL upload":
45
  _render_local_dataset_uploads()
46
 
 
 
 
 
47
  selected_variants = st.multiselect(
48
  "Prompt variants",
49
- options=PROMPT_VARIANTS,
50
- default=PROMPT_VARIANTS,
51
  format_func=prompt_variant_label,
52
  key=_extract_widget_key(model_name, remote, dataset_source, "prompt_variants"),
53
  )
 
54
  if not selected_variants:
55
  st.info("Select at least one prompt variant.")
56
  return
57
 
58
  try:
59
- dataset, dataset_status = load_dataset(dataset_source)
 
 
 
 
60
  st.caption(dataset_status)
61
  except Exception as exc:
62
  st.error(f"Could not load data: {exc}")
@@ -73,13 +94,18 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
73
  )
74
  return
75
 
 
 
 
 
76
  selected_personas = st.multiselect(
77
  "Personas",
78
  options=personas,
79
- default=[personas[0]] if personas else [],
80
  format_func=persona_label,
81
  key=_extract_widget_key(model_name, remote, dataset_source, "persona_select"),
82
  )
 
83
 
84
  if not selected_personas:
85
  st.info("Select at least one persona.")
@@ -93,26 +119,42 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
93
 
94
  col1, col2, col3 = st.columns([2, 2, 1])
95
  with col1:
 
 
 
 
 
 
96
  qa_type_select = st.selectbox(
97
  "QA type",
98
- options=["all", "explicit", "implicit"],
99
- index=0,
100
  key=_extract_widget_key(
101
  model_name, remote, dataset_source, "qa_type_select"
102
  ),
103
  )
104
- qa_filter_type = (
105
- qa_type_select if qa_type_select in ("explicit", "implicit") else None
 
 
 
106
  )
107
  with col2:
 
 
 
 
 
 
108
  difficulty_values = st.multiselect(
109
  "Difficulty",
110
  options=[1, 2, 3],
111
- default=[1, 2, 3],
112
  key=_extract_widget_key(
113
  model_name, remote, dataset_source, "difficulty_select"
114
  ),
115
  )
 
116
  qa_filter_difficulty = difficulty_values if difficulty_values else None
117
 
118
  runs, skipped = [], []
@@ -135,15 +177,18 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
135
  return
136
 
137
  max_q = min(len(qa_pairs) for _, qa_pairs in runs)
 
 
138
  max_questions = st.slider(
139
  "Max questions",
140
  min_value=1,
141
  max_value=max_q,
142
- value=max_q,
143
  key=_extract_widget_key(
144
  model_name, remote, dataset_source, "max_questions"
145
  ),
146
  )
 
147
 
148
  if runs is None:
149
  return
@@ -180,7 +225,7 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
180
  model_name=model_name,
181
  persona=persona,
182
  qa_pairs=qa_pairs[:max_questions],
183
- variants=[variant],
184
  remote=remote,
185
  on_status=_on_ndif_status if remote else None,
186
  )
 
1
+ from typing import Literal, cast
2
+
3
  import streamlit as st
4
+ from persona_vectors.artifacts import SUPPORTED_VARIANTS
5
  from persona_vectors.extraction import run_extraction
6
 
7
  from utils.datasets import load_dataset
8
  from utils.helpers import (
9
  NDIF_STATUS_ICONS,
 
10
  persona_label,
11
  prompt_variant_label,
12
  widget_key,
13
  )
14
  from utils.runtime import cached_model
15
 
16
+ # Cross-model / remote-switch persistence — same pattern as compare.py.
17
+ # Written on every render so selections survive model or NDIF toggles.
18
+ _LAST_VARIANTS_KEY = "extract:last_variants"
19
+ _LAST_PERSONA_IDS_KEY = "extract:last_persona_ids"
20
+ _LAST_QA_TYPE_KEY = "extract:last_qa_type"
21
+ _LAST_DIFFICULTY_KEY = "extract:last_difficulty"
22
+ _LAST_MAX_QUESTIONS_KEY = "extract:last_max_questions"
23
+
24
+ _QA_TYPE_OPTIONS = ["all", "explicit", "implicit"]
25
+
26
 
27
  def _extract_widget_key(
28
  model_name: str, remote: bool, dataset_source: str, suffix: str
 
38
  "personas.jsonl",
39
  type=["jsonl"],
40
  key="extract__personas_file",
41
+ help="Expected fields: id, persona, templated_view, biography_view",
42
  )
43
  st.file_uploader(
44
  "qa.jsonl",
 
56
  if dataset_source == "Local JSONL upload":
57
  _render_local_dataset_uploads()
58
 
59
+ last_variants = st.session_state.get(_LAST_VARIANTS_KEY, list(SUPPORTED_VARIANTS))
60
+ default_variants = [v for v in last_variants if v in SUPPORTED_VARIANTS] or list(
61
+ SUPPORTED_VARIANTS
62
+ )
63
  selected_variants = st.multiselect(
64
  "Prompt variants",
65
+ options=SUPPORTED_VARIANTS,
66
+ default=default_variants,
67
  format_func=prompt_variant_label,
68
  key=_extract_widget_key(model_name, remote, dataset_source, "prompt_variants"),
69
  )
70
+ st.session_state[_LAST_VARIANTS_KEY] = selected_variants
71
  if not selected_variants:
72
  st.info("Select at least one prompt variant.")
73
  return
74
 
75
  try:
76
+ dataset, dataset_status = load_dataset(
77
+ dataset_source,
78
+ personas_file=st.session_state.get("extract__personas_file"),
79
+ qa_file=st.session_state.get("extract__qa_file"),
80
+ )
81
  st.caption(dataset_status)
82
  except Exception as exc:
83
  st.error(f"Could not load data: {exc}")
 
94
  )
95
  return
96
 
97
+ last_persona_ids: set[str] = set(st.session_state.get(_LAST_PERSONA_IDS_KEY, []))
98
+ default_personas = [p for p in personas if p.id in last_persona_ids] or [
99
+ personas[0]
100
+ ]
101
  selected_personas = st.multiselect(
102
  "Personas",
103
  options=personas,
104
+ default=default_personas,
105
  format_func=persona_label,
106
  key=_extract_widget_key(model_name, remote, dataset_source, "persona_select"),
107
  )
108
+ st.session_state[_LAST_PERSONA_IDS_KEY] = [p.id for p in selected_personas]
109
 
110
  if not selected_personas:
111
  st.info("Select at least one persona.")
 
119
 
120
  col1, col2, col3 = st.columns([2, 2, 1])
121
  with col1:
122
+ last_qa_type = st.session_state.get(_LAST_QA_TYPE_KEY, "all")
123
+ qa_type_index = (
124
+ _QA_TYPE_OPTIONS.index(last_qa_type)
125
+ if last_qa_type in _QA_TYPE_OPTIONS
126
+ else 0
127
+ )
128
  qa_type_select = st.selectbox(
129
  "QA type",
130
+ options=_QA_TYPE_OPTIONS,
131
+ index=qa_type_index,
132
  key=_extract_widget_key(
133
  model_name, remote, dataset_source, "qa_type_select"
134
  ),
135
  )
136
+ st.session_state[_LAST_QA_TYPE_KEY] = qa_type_select
137
+ qa_filter_type: Literal["explicit", "implicit"] | None = (
138
+ cast(Literal["explicit", "implicit"], qa_type_select)
139
+ if qa_type_select in ("explicit", "implicit")
140
+ else None
141
  )
142
  with col2:
143
+ last_difficulty = st.session_state.get(_LAST_DIFFICULTY_KEY, [1, 2, 3])
144
+ default_difficulty = [d for d in last_difficulty if d in (1, 2, 3)] or [
145
+ 1,
146
+ 2,
147
+ 3,
148
+ ]
149
  difficulty_values = st.multiselect(
150
  "Difficulty",
151
  options=[1, 2, 3],
152
+ default=default_difficulty,
153
  key=_extract_widget_key(
154
  model_name, remote, dataset_source, "difficulty_select"
155
  ),
156
  )
157
+ st.session_state[_LAST_DIFFICULTY_KEY] = difficulty_values
158
  qa_filter_difficulty = difficulty_values if difficulty_values else None
159
 
160
  runs, skipped = [], []
 
177
  return
178
 
179
  max_q = min(len(qa_pairs) for _, qa_pairs in runs)
180
+ last_max = st.session_state.get(_LAST_MAX_QUESTIONS_KEY, max_q)
181
+ default_max = min(max(last_max, 1), max_q)
182
  max_questions = st.slider(
183
  "Max questions",
184
  min_value=1,
185
  max_value=max_q,
186
+ value=default_max,
187
  key=_extract_widget_key(
188
  model_name, remote, dataset_source, "max_questions"
189
  ),
190
  )
191
+ st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
192
 
193
  if runs is None:
194
  return
 
225
  model_name=model_name,
226
  persona=persona,
227
  qa_pairs=qa_pairs[:max_questions],
228
+ variants=(variant,),
229
  remote=remote,
230
  on_status=_on_ndif_status if remote else None,
231
  )
utils/chat.py CHANGED
@@ -5,17 +5,10 @@ from typing import Literal
5
 
6
  import torch
7
  from nnterp import StandardizedTransformer
8
-
9
- logger = logging.getLogger(__name__)
10
-
11
- from persona_data.prompts import (
12
- format_biography_prompt,
13
- format_empty_persona_prompt,
14
- format_templated_prompt,
15
- normalize_messages,
16
- )
17
  from persona_data.synth_persona import PersonaData
18
 
 
19
  SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
20
 
21
 
@@ -47,11 +40,12 @@ def resolve_system_prompt(
47
  if mode == "empty":
48
  return ""
49
  if mode == "templated":
50
- return format_templated_prompt(persona.templated_prompt)
51
  if mode == "biography":
52
- return format_biography_prompt(persona.biography_md)
53
  if mode == "custom":
54
- return format_empty_persona_prompt()
 
55
 
56
 
57
  def _format_plain_messages(
 
5
 
6
  import torch
7
  from nnterp import StandardizedTransformer
8
+ from persona_data.prompts import format_roleplay_prompt, normalize_messages
 
 
 
 
 
 
 
 
9
  from persona_data.synth_persona import PersonaData
10
 
11
+ logger = logging.getLogger(__name__)
12
  SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
13
 
14
 
 
40
  if mode == "empty":
41
  return ""
42
  if mode == "templated":
43
+ return format_roleplay_prompt(persona.templated_view, mode="conversational")
44
  if mode == "biography":
45
+ return format_roleplay_prompt(persona.biography_view, mode="conversational")
46
  if mode == "custom":
47
+ return format_roleplay_prompt(mode="conversational")
48
+ raise ValueError(f"Unsupported system prompt mode: {mode}")
49
 
50
 
51
  def _format_plain_messages(
utils/chat_export.py CHANGED
@@ -54,7 +54,7 @@ def save_chat_export(
54
  export_dir = (
55
  get_artifacts_dir()
56
  / "chats"
57
- / model_name.replace("/", "__")
58
  / slugify(dataset_source)
59
  / slugify(persona_id)
60
  )
 
54
  export_dir = (
55
  get_artifacts_dir()
56
  / "chats"
57
+ / "__".join(slugify(part) for part in model_name.split("/"))
58
  / slugify(dataset_source)
59
  / slugify(persona_id)
60
  )
utils/datasets.py CHANGED
@@ -44,14 +44,14 @@ def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
44
 
45
  def load_dataset(
46
  dataset_source: str,
 
 
47
  ) -> tuple[SynthPersonaDataset | LocalPersonaDataset, str]:
48
  """Load the selected dataset source for the UI."""
49
 
50
  if dataset_source == DATASET_SOURCES[0]:
51
  return cached_hf_dataset(), "SynthPersona"
52
 
53
- personas_file = st.session_state.get("extract__personas_file")
54
- qa_file = st.session_state.get("extract__qa_file")
55
  if personas_file is None or qa_file is None:
56
  raise ValueError("Upload both personas.jsonl and qa.jsonl files")
57
 
 
44
 
45
  def load_dataset(
46
  dataset_source: str,
47
+ personas_file: Any = None,
48
+ qa_file: Any = None,
49
  ) -> tuple[SynthPersonaDataset | LocalPersonaDataset, str]:
50
  """Load the selected dataset source for the UI."""
51
 
52
  if dataset_source == DATASET_SOURCES[0]:
53
  return cached_hf_dataset(), "SynthPersona"
54
 
 
 
55
  if personas_file is None or qa_file is None:
56
  raise ValueError("Upload both personas.jsonl and qa.jsonl files")
57
 
utils/helpers.py CHANGED
@@ -1,7 +1,6 @@
1
  import re
2
 
3
  from persona_data.synth_persona import PersonaData
4
- from persona_vectors.artifacts import SUPPORTED_VARIANTS
5
 
6
  # Variant key -> human-readable label mapping
7
  VARIANT_LABELS = {
@@ -11,9 +10,6 @@ VARIANT_LABELS = {
11
  "custom": "Custom",
12
  }
13
 
14
- # Variants that correspond to actual system prompts (excludes "empty")
15
- PROMPT_VARIANTS = list(SUPPORTED_VARIANTS)
16
-
17
  # For selectbox options: list of labels in definition order
18
  MODE_LABELS = list(VARIANT_LABELS.values())
19
 
 
1
  import re
2
 
3
  from persona_data.synth_persona import PersonaData
 
4
 
5
  # Variant key -> human-readable label mapping
6
  VARIANT_LABELS = {
 
10
  "custom": "Custom",
11
  }
12
 
 
 
 
13
  # For selectbox options: list of labels in definition order
14
  MODE_LABELS = list(VARIANT_LABELS.values())
15
 
uv.lock CHANGED
@@ -297,7 +297,7 @@ name = "cuda-bindings"
297
  version = "13.2.0"
298
  source = { registry = "https://pypi.org/simple" }
299
  dependencies = [
300
- { name = "cuda-pathfinder", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
301
  ]
302
  wheels = [
303
  { url = "https://files.pythonhosted.org/packages/1a/fe/7351d7e586a8b4c9f89731bfe4cf0148223e8f9903ff09571f78b3fb0682/cuda_bindings-13.2.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08b395f79cb89ce0cd8effff07c4a1e20101b873c256a1aeb286e8fd7bd0f556", size = 5744254, upload-time = "2026-03-11T00:12:29.798Z" },
@@ -316,10 +316,10 @@ wheels = [
316
 
317
  [[package]]
318
  name = "cuda-pathfinder"
319
- version = "1.5.1"
320
  source = { registry = "https://pypi.org/simple" }
321
  wheels = [
322
- { url = "https://files.pythonhosted.org/packages/c4/74/8c66861b873d8eed51fde56d3091baa4906a56f0d4390cae991f2d41dda5/cuda_pathfinder-1.5.1-py3-none-any.whl", hash = "sha256:b3718097fb57cf9e8a904dd072d806f2c9a27627e35c020b06ab9454bcec08c0", size = 49861, upload-time = "2026-04-03T16:41:22.203Z" },
323
  ]
324
 
325
  [[package]]
@@ -332,37 +332,37 @@ wheels = [
332
 
333
  [package.optional-dependencies]
334
  cublas = [
335
- { name = "nvidia-cublas", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" },
336
  ]
337
  cudart = [
338
- { name = "nvidia-cuda-runtime", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" },
339
  ]
340
  cufft = [
341
- { name = "nvidia-cufft", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" },
342
  ]
343
  cufile = [
344
  { name = "nvidia-cufile", marker = "sys_platform == 'linux'" },
345
  ]
346
  cupti = [
347
- { name = "nvidia-cuda-cupti", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" },
348
  ]
349
  curand = [
350
- { name = "nvidia-curand", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" },
351
  ]
352
  cusolver = [
353
- { name = "nvidia-cusolver", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" },
354
  ]
355
  cusparse = [
356
- { name = "nvidia-cusparse", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" },
357
  ]
358
  nvjitlink = [
359
- { name = "nvidia-nvjitlink", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" },
360
  ]
361
  nvrtc = [
362
- { name = "nvidia-cuda-nvrtc", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" },
363
  ]
364
  nvtx = [
365
- { name = "nvidia-nvtx", marker = "(python_full_version < '3.11' and sys_platform == 'win32') or sys_platform == 'linux'" },
366
  ]
367
 
368
  [[package]]
@@ -508,7 +508,7 @@ wheels = [
508
 
509
  [[package]]
510
  name = "huggingface-hub"
511
- version = "1.9.0"
512
  source = { registry = "https://pypi.org/simple" }
513
  dependencies = [
514
  { name = "filelock" },
@@ -521,9 +521,9 @@ dependencies = [
521
  { name = "typer" },
522
  { name = "typing-extensions" },
523
  ]
524
- sdist = { url = "https://files.pythonhosted.org/packages/88/bb/62c7aa86f63a05e2f9b96642fdef9b94526a23979820b09f5455deff4983/huggingface_hub-1.9.0.tar.gz", hash = "sha256:0ea5be7a56135c91797cae6ad726e38eaeb6eb4b77cefff5c9d38ba0ecf874f7", size = 750326, upload-time = "2026-04-03T08:35:55.888Z" }
525
  wheels = [
526
- { url = "https://files.pythonhosted.org/packages/73/37/0d15d16150e1829f3e90962c99f28257f6de9e526a680b4c6f5acdb54fd2/huggingface_hub-1.9.0-py3-none-any.whl", hash = "sha256:2999328c058d39fd19ab748dd09bd4da2fbaa4f4c1ddea823eab103051e14a1f", size = 637355, upload-time = "2026-04-03T08:35:53.897Z" },
527
  ]
528
 
529
  [[package]]
@@ -883,11 +883,11 @@ wheels = [
883
 
884
  [[package]]
885
  name = "narwhals"
886
- version = "2.18.1"
887
  source = { registry = "https://pypi.org/simple" }
888
- sdist = { url = "https://files.pythonhosted.org/packages/59/96/45218c2fdec4c9f22178f905086e85ef1a6d63862dcc3cd68eb60f1867f5/narwhals-2.18.1.tar.gz", hash = "sha256:652a1fcc9d432bbf114846688884c215f17eb118aa640b7419295d2f910d2a8b", size = 620578, upload-time = "2026-03-24T15:11:25.456Z" }
889
  wheels = [
890
- { url = "https://files.pythonhosted.org/packages/3f/c3/06490e98393dcb4d6ce2bf331a39335375c300afaef526897881fbeae6ab/narwhals-2.18.1-py3-none-any.whl", hash = "sha256:a0a8bb80205323851338888ba3a12b4f65d352362c8a94be591244faf36504ad", size = 444952, upload-time = "2026-03-24T15:11:23.801Z" },
891
  ]
892
 
893
  [[package]]
@@ -1216,7 +1216,7 @@ name = "nvidia-cudnn-cu13"
1216
  version = "9.19.0.56"
1217
  source = { registry = "https://pypi.org/simple" }
1218
  dependencies = [
1219
- { name = "nvidia-cublas", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
1220
  ]
1221
  wheels = [
1222
  { url = "https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201, upload-time = "2026-02-03T20:40:53.805Z" },
@@ -1228,7 +1228,7 @@ name = "nvidia-cufft"
1228
  version = "12.0.0.61"
1229
  source = { registry = "https://pypi.org/simple" }
1230
  dependencies = [
1231
- { name = "nvidia-nvjitlink", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
1232
  ]
1233
  wheels = [
1234
  { url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" },
@@ -1258,9 +1258,9 @@ name = "nvidia-cusolver"
1258
  version = "12.0.4.66"
1259
  source = { registry = "https://pypi.org/simple" }
1260
  dependencies = [
1261
- { name = "nvidia-cublas", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
1262
- { name = "nvidia-cusparse", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
1263
- { name = "nvidia-nvjitlink", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
1264
  ]
1265
  wheels = [
1266
  { url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" },
@@ -1272,7 +1272,7 @@ name = "nvidia-cusparse"
1272
  version = "12.6.3.3"
1273
  source = { registry = "https://pypi.org/simple" }
1274
  dependencies = [
1275
- { name = "nvidia-nvjitlink", marker = "(python_full_version < '3.11' and sys_platform == 'emscripten') or (python_full_version < '3.11' and sys_platform == 'win32') or (sys_platform != 'emscripten' and sys_platform != 'win32')" },
1276
  ]
1277
  wheels = [
1278
  { url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" },
@@ -1561,7 +1561,7 @@ wheels = [
1561
  [[package]]
1562
  name = "persona-data"
1563
  version = "0.1.0"
1564
- source = { editable = "../persona-data" }
1565
  dependencies = [
1566
  { name = "huggingface-hub" },
1567
  { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
@@ -1570,14 +1570,6 @@ dependencies = [
1570
  { name = "torch" },
1571
  ]
1572
 
1573
- [package.metadata]
1574
- requires-dist = [
1575
- { name = "huggingface-hub", specifier = ">=0.30.0" },
1576
- { name = "numpy", specifier = ">=1.24.0" },
1577
- { name = "python-dotenv", specifier = ">=1.0.0" },
1578
- { name = "torch", specifier = ">=2.0.0" },
1579
- ]
1580
-
1581
  [[package]]
1582
  name = "persona-ui"
1583
  version = "0.1.0"
@@ -1592,8 +1584,8 @@ dependencies = [
1592
 
1593
  [package.metadata]
1594
  requires-dist = [
1595
- { name = "persona-data", editable = "../persona-data" },
1596
- { name = "persona-vectors", editable = "../persona-vectors" },
1597
  { name = "plotly", specifier = ">=6.6.0" },
1598
  { name = "python-dotenv", specifier = ">=1.2.2" },
1599
  { name = "streamlit", specifier = ">=1.44.0" },
@@ -1602,7 +1594,7 @@ requires-dist = [
1602
  [[package]]
1603
  name = "persona-vectors"
1604
  version = "0.1.0"
1605
- source = { editable = "../persona-vectors" }
1606
  dependencies = [
1607
  { name = "kaleido" },
1608
  { name = "nnsight" },
@@ -1620,23 +1612,6 @@ dependencies = [
1620
  { name = "umap-learn" },
1621
  ]
1622
 
1623
- [package.metadata]
1624
- requires-dist = [
1625
- { name = "kaleido", specifier = ">=1.0.0" },
1626
- { name = "nnsight", specifier = ">=0.6.1" },
1627
- { name = "nnterp", specifier = ">=1.3.0" },
1628
- { name = "persona-data", editable = "../persona-data" },
1629
- { name = "plotly", specifier = ">=6.6.0" },
1630
- { name = "python-dotenv", specifier = ">=1.2.2" },
1631
- { name = "safetensors", specifier = ">=0.7.0" },
1632
- { name = "scikit-learn", specifier = ">=1.6.0" },
1633
- { name = "torch", specifier = ">=2.10.0" },
1634
- { name = "torchvision", specifier = ">=0.26.0" },
1635
- { name = "tqdm", specifier = ">=4.67.3" },
1636
- { name = "transformers", specifier = ">=5.2.0" },
1637
- { name = "umap-learn", specifier = ">=0.5.7" },
1638
- ]
1639
-
1640
  [[package]]
1641
  name = "pexpect"
1642
  version = "4.9.0"
@@ -2075,7 +2050,7 @@ wheels = [
2075
 
2076
  [[package]]
2077
  name = "pytest"
2078
- version = "9.0.2"
2079
  source = { registry = "https://pypi.org/simple" }
2080
  dependencies = [
2081
  { name = "colorama", marker = "sys_platform == 'win32'" },
@@ -2086,9 +2061,9 @@ dependencies = [
2086
  { name = "pygments" },
2087
  { name = "tomli", marker = "python_full_version < '3.11'" },
2088
  ]
2089
- sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
2090
  wheels = [
2091
- { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
2092
  ]
2093
 
2094
  [[package]]
 
297
  version = "13.2.0"
298
  source = { registry = "https://pypi.org/simple" }
299
  dependencies = [
300
+ { name = "cuda-pathfinder" },
301
  ]
302
  wheels = [
303
  { url = "https://files.pythonhosted.org/packages/1a/fe/7351d7e586a8b4c9f89731bfe4cf0148223e8f9903ff09571f78b3fb0682/cuda_bindings-13.2.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08b395f79cb89ce0cd8effff07c4a1e20101b873c256a1aeb286e8fd7bd0f556", size = 5744254, upload-time = "2026-03-11T00:12:29.798Z" },
 
316
 
317
  [[package]]
318
  name = "cuda-pathfinder"
319
+ version = "1.5.2"
320
  source = { registry = "https://pypi.org/simple" }
321
  wheels = [
322
+ { url = "https://files.pythonhosted.org/packages/f2/f9/1b9b60a30fc463c14cdea7a77228131a0ccc89572e8df9cb86c9648271ab/cuda_pathfinder-1.5.2-py3-none-any.whl", hash = "sha256:0c5f160a7756c5b072723cbbd6d861e38917ef956c68150b02f0b6e9271c71fa", size = 49988, upload-time = "2026-04-06T23:01:05.17Z" },
323
  ]
324
 
325
  [[package]]
 
332
 
333
  [package.optional-dependencies]
334
  cublas = [
335
+ { name = "nvidia-cublas", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
336
  ]
337
  cudart = [
338
+ { name = "nvidia-cuda-runtime", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
339
  ]
340
  cufft = [
341
+ { name = "nvidia-cufft", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
342
  ]
343
  cufile = [
344
  { name = "nvidia-cufile", marker = "sys_platform == 'linux'" },
345
  ]
346
  cupti = [
347
+ { name = "nvidia-cuda-cupti", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
348
  ]
349
  curand = [
350
+ { name = "nvidia-curand", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
351
  ]
352
  cusolver = [
353
+ { name = "nvidia-cusolver", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
354
  ]
355
  cusparse = [
356
+ { name = "nvidia-cusparse", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
357
  ]
358
  nvjitlink = [
359
+ { name = "nvidia-nvjitlink", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
360
  ]
361
  nvrtc = [
362
+ { name = "nvidia-cuda-nvrtc", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
363
  ]
364
  nvtx = [
365
+ { name = "nvidia-nvtx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
366
  ]
367
 
368
  [[package]]
 
508
 
509
  [[package]]
510
  name = "huggingface-hub"
511
+ version = "1.9.2"
512
  source = { registry = "https://pypi.org/simple" }
513
  dependencies = [
514
  { name = "filelock" },
 
521
  { name = "typer" },
522
  { name = "typing-extensions" },
523
  ]
524
+ sdist = { url = "https://files.pythonhosted.org/packages/cf/65/fb800d327bf25bf31b798dd08935d326d064ecb9b359059fecd91b3a98e8/huggingface_hub-1.9.2.tar.gz", hash = "sha256:8d09d080a186bd950a361bfc04b862dfb04d6a2b41d48e9ba1b37507cfd3f1e1", size = 750284, upload-time = "2026-04-08T08:43:11.127Z" }
525
  wheels = [
526
+ { url = "https://files.pythonhosted.org/packages/57/d4/e33bf0b362810a9b96c5923e38908950d58ecb512db42e3730320c7f4a3a/huggingface_hub-1.9.2-py3-none-any.whl", hash = "sha256:e1e62ce237d4fbeca9f970aeb15176fbd503e04c25577bfd22f44aa7aa2b5243", size = 637349, upload-time = "2026-04-08T08:43:09.114Z" },
527
  ]
528
 
529
  [[package]]
 
883
 
884
  [[package]]
885
  name = "narwhals"
886
+ version = "2.19.0"
887
  source = { registry = "https://pypi.org/simple" }
888
+ sdist = { url = "https://files.pythonhosted.org/packages/4e/1a/bd3317c0bdbcd9ffb710ddf5250b32898f8f2c240be99494fe137feb77a7/narwhals-2.19.0.tar.gz", hash = "sha256:14fd7040b5ff211d415a82e4827b9d04c354e213e72a6d0730205ffd72e3b7ff", size = 623698, upload-time = "2026-04-06T15:50:58.786Z" }
889
  wheels = [
890
+ { url = "https://files.pythonhosted.org/packages/37/72/e61e3091e0e00fae9d3a8ef85ece9d2cd4b5966058e1f2901ce42679eebf/narwhals-2.19.0-py3-none-any.whl", hash = "sha256:1f8dfa4a33a6dbff878c3e9be4c3b455dfcaf2a9322f1357db00e4e92e95b84b", size = 446991, upload-time = "2026-04-06T15:50:57.046Z" },
891
  ]
892
 
893
  [[package]]
 
1216
  version = "9.19.0.56"
1217
  source = { registry = "https://pypi.org/simple" }
1218
  dependencies = [
1219
+ { name = "nvidia-cublas" },
1220
  ]
1221
  wheels = [
1222
  { url = "https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201, upload-time = "2026-02-03T20:40:53.805Z" },
 
1228
  version = "12.0.0.61"
1229
  source = { registry = "https://pypi.org/simple" }
1230
  dependencies = [
1231
+ { name = "nvidia-nvjitlink" },
1232
  ]
1233
  wheels = [
1234
  { url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" },
 
1258
  version = "12.0.4.66"
1259
  source = { registry = "https://pypi.org/simple" }
1260
  dependencies = [
1261
+ { name = "nvidia-cublas" },
1262
+ { name = "nvidia-cusparse" },
1263
+ { name = "nvidia-nvjitlink" },
1264
  ]
1265
  wheels = [
1266
  { url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" },
 
1272
  version = "12.6.3.3"
1273
  source = { registry = "https://pypi.org/simple" }
1274
  dependencies = [
1275
+ { name = "nvidia-nvjitlink" },
1276
  ]
1277
  wheels = [
1278
  { url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" },
 
1561
  [[package]]
1562
  name = "persona-data"
1563
  version = "0.1.0"
1564
+ source = { git = "ssh://git@github.com/implicit-personalization/persona-data.git#3763bd6e42472b589b4e32acd3e47b711a0af1f5" }
1565
  dependencies = [
1566
  { name = "huggingface-hub" },
1567
  { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
 
1570
  { name = "torch" },
1571
  ]
1572
 
 
 
 
 
 
 
 
 
1573
  [[package]]
1574
  name = "persona-ui"
1575
  version = "0.1.0"
 
1584
 
1585
  [package.metadata]
1586
  requires-dist = [
1587
+ { name = "persona-data", git = "ssh://git@github.com/implicit-personalization/persona-data.git" },
1588
+ { name = "persona-vectors", git = "ssh://git@github.com/implicit-personalization/persona-vectors.git" },
1589
  { name = "plotly", specifier = ">=6.6.0" },
1590
  { name = "python-dotenv", specifier = ">=1.2.2" },
1591
  { name = "streamlit", specifier = ">=1.44.0" },
 
1594
  [[package]]
1595
  name = "persona-vectors"
1596
  version = "0.1.0"
1597
+ source = { git = "ssh://git@github.com/implicit-personalization/persona-vectors.git#fa6b4b61eaaba9ce64ee8614766bf75879148bbb" }
1598
  dependencies = [
1599
  { name = "kaleido" },
1600
  { name = "nnsight" },
 
1612
  { name = "umap-learn" },
1613
  ]
1614
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1615
  [[package]]
1616
  name = "pexpect"
1617
  version = "4.9.0"
 
2050
 
2051
  [[package]]
2052
  name = "pytest"
2053
+ version = "9.0.3"
2054
  source = { registry = "https://pypi.org/simple" }
2055
  dependencies = [
2056
  { name = "colorama", marker = "sys_platform == 'win32'" },
 
2061
  { name = "pygments" },
2062
  { name = "tomli", marker = "python_full_version < '3.11'" },
2063
  ]
2064
+ sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
2065
  wheels = [
2066
+ { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
2067
  ]
2068
 
2069
  [[package]]