Jac-Zac commited on
Commit ·
eb41f91
1
Parent(s): f4259c0
Updated to latest version
Browse filesUpdated message
Fix bug
Fix bugs
- .env.example +1 -0
- README.md +11 -2
- app.py +39 -5
- pyproject.toml +7 -6
- state.py +6 -4
- tabs/chat.py +206 -61
- tabs/compare.py +71 -57
- tabs/extract.py +58 -13
- utils/chat.py +6 -12
- utils/chat_export.py +1 -1
- utils/datasets.py +2 -2
- utils/helpers.py +0 -4
- uv.lock +32 -57
.env.example
CHANGED
|
@@ -9,6 +9,7 @@ NDIF_API_KEY=your-ndif-api-key-here
|
|
| 9 |
# Defaults to ~/.cache/huggingface if unset
|
| 10 |
# Useful when working on a cluster with a shared cache or limited home quota
|
| 11 |
HF_HOME=/path/to/your/hf/cache
|
|
|
|
| 12 |
|
| 13 |
# Root directory for all generated artifacts (activations, plots, etc.)
|
| 14 |
# Defaults to artifacts if unset
|
|
|
|
| 9 |
# Defaults to ~/.cache/huggingface if unset
|
| 10 |
# Useful when working on a cluster with a shared cache or limited home quota
|
| 11 |
HF_HOME=/path/to/your/hf/cache
|
| 12 |
+
HF_TOKEN=your-token
|
| 13 |
|
| 14 |
# Root directory for all generated artifacts (activations, plots, etc.)
|
| 15 |
# Defaults to artifacts if unset
|
README.md
CHANGED
|
@@ -42,9 +42,15 @@ uv sync
|
|
| 42 |
cp .env.example .env
|
| 43 |
```
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
## Local Setup Note
|
| 46 |
|
| 47 |
-
For
|
| 48 |
|
| 49 |
Example:
|
| 50 |
|
|
@@ -80,6 +86,9 @@ ARTIFACTS_DIR=... # Optional: where activations are read from (default: ./a
|
|
| 80 |
|
| 81 |
The app picks up this file automatically via `load_dotenv()` on startup.
|
| 82 |
|
|
|
|
|
|
|
|
|
|
| 83 |
## Saved Artifacts
|
| 84 |
|
| 85 |
The Compare and Extract tabs read from / write to:
|
|
@@ -88,7 +97,7 @@ The Compare and Extract tabs read from / write to:
|
|
| 88 |
artifacts/
|
| 89 |
├── activations/<model_dir>/<prompt_variant>/<persona_id>/
|
| 90 |
│ ├── activations.safetensors
|
| 91 |
-
│ └── metadata.json
|
| 92 |
└── chats/<model_dir>/<prompt_variant>/
|
| 93 |
└── <export>.json
|
| 94 |
```
|
|
|
|
| 42 |
cp .env.example .env
|
| 43 |
```
|
| 44 |
|
| 45 |
+
## Local Development
|
| 46 |
+
|
| 47 |
+
The committed dependency graph uses git sources so `persona-ui` can install cleanly in a Hugging Face Space or any isolated environment.
|
| 48 |
+
|
| 49 |
+
For local sibling checkouts, uncomment the `path` sources in `persona-ui/pyproject.toml` and `persona-vectors/pyproject.toml`, then comment out the git sources.
|
| 50 |
+
|
| 51 |
## Local Setup Note
|
| 52 |
|
| 53 |
+
For local development, `persona-data` and `persona-vectors` can still be checked out in the parent directory of `persona-ui`.
|
| 54 |
|
| 55 |
Example:
|
| 56 |
|
|
|
|
| 86 |
|
| 87 |
The app picks up this file automatically via `load_dotenv()` on startup.
|
| 88 |
|
| 89 |
+
You can also override the active NDIF or Hugging Face token from the sidebar
|
| 90 |
+
`API Keys` section. Those inputs only apply for the current session.
|
| 91 |
+
|
| 92 |
## Saved Artifacts
|
| 93 |
|
| 94 |
The Compare and Extract tabs read from / write to:
|
|
|
|
| 97 |
artifacts/
|
| 98 |
├── activations/<model_dir>/<prompt_variant>/<persona_id>/
|
| 99 |
│ ├── activations.safetensors
|
| 100 |
+
│ └── metadata.json # used for persona names and layer counts
|
| 101 |
└── chats/<model_dir>/<prompt_variant>/
|
| 102 |
└── <export>.json
|
| 103 |
```
|
app.py
CHANGED
|
@@ -8,6 +8,42 @@ from utils.helpers import DATASET_SOURCES
|
|
| 8 |
load_dotenv()
|
| 9 |
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
|
| 10 |
REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
def _sidebar_controls() -> tuple[bool, str, str, str]:
|
|
@@ -18,7 +54,7 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
|
|
| 18 |
st.caption("Chat, extract, and compare persona runs.")
|
| 19 |
|
| 20 |
if "sidebar__active_tab" not in st.session_state:
|
| 21 |
-
st.session_state["sidebar__active_tab"] =
|
| 22 |
|
| 23 |
active_tab = st.session_state["sidebar__active_tab"]
|
| 24 |
for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
|
|
@@ -71,11 +107,9 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
|
|
| 71 |
help="Dataset for Chat and Extract.",
|
| 72 |
)
|
| 73 |
|
| 74 |
-
|
| 75 |
-
|
| 76 |
|
| 77 |
-
|
| 78 |
-
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
| 79 |
|
| 80 |
|
| 81 |
def main() -> None:
|
|
|
|
| 8 |
load_dotenv()
|
| 9 |
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
|
| 10 |
REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
|
| 11 |
+
NDIF_API_KEY = os.environ.get("NDIF_API_KEY", "")
|
| 12 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", os.environ.get("HUGGING_FACE_HUB_TOKEN", ""))
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
_TABS = ["Chat", "Compare", "Extract"]
|
| 16 |
+
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _sync_sidebar_api_key(env_var: str, value: str) -> None:
|
| 20 |
+
if value:
|
| 21 |
+
os.environ[env_var] = value
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _sidebar_api_keys() -> None:
|
| 25 |
+
with st.sidebar:
|
| 26 |
+
st.divider()
|
| 27 |
+
st.caption("API Keys")
|
| 28 |
+
|
| 29 |
+
ndif_api_key = st.text_input(
|
| 30 |
+
"NDIF API key",
|
| 31 |
+
value=NDIF_API_KEY,
|
| 32 |
+
type="password",
|
| 33 |
+
key="sidebar__ndif_api_key",
|
| 34 |
+
help="Overrides NDIF_API_KEY for this session.",
|
| 35 |
+
)
|
| 36 |
+
_sync_sidebar_api_key("NDIF_API_KEY", ndif_api_key)
|
| 37 |
+
|
| 38 |
+
hf_token = st.text_input(
|
| 39 |
+
"Hugging Face token",
|
| 40 |
+
value=HF_TOKEN,
|
| 41 |
+
type="password",
|
| 42 |
+
key="sidebar__hf_token",
|
| 43 |
+
help="Overrides HF_TOKEN and HUGGING_FACE_HUB_TOKEN for this session.",
|
| 44 |
+
)
|
| 45 |
+
_sync_sidebar_api_key("HF_TOKEN", hf_token)
|
| 46 |
+
_sync_sidebar_api_key("HUGGING_FACE_HUB_TOKEN", hf_token)
|
| 47 |
|
| 48 |
|
| 49 |
def _sidebar_controls() -> tuple[bool, str, str, str]:
|
|
|
|
| 54 |
st.caption("Chat, extract, and compare persona runs.")
|
| 55 |
|
| 56 |
if "sidebar__active_tab" not in st.session_state:
|
| 57 |
+
st.session_state["sidebar__active_tab"] = "Chat"
|
| 58 |
|
| 59 |
active_tab = st.session_state["sidebar__active_tab"]
|
| 60 |
for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
|
|
|
|
| 107 |
help="Dataset for Chat and Extract.",
|
| 108 |
)
|
| 109 |
|
| 110 |
+
_sidebar_api_keys()
|
|
|
|
| 111 |
|
| 112 |
+
return remote, model_name, dataset_source, active_tab
|
|
|
|
| 113 |
|
| 114 |
|
| 115 |
def main() -> None:
|
pyproject.toml
CHANGED
|
@@ -5,18 +5,19 @@ description = "Streamlit UI for persona-vectors"
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.10"
|
| 7 |
dependencies = [
|
| 8 |
-
"persona-vectors",
|
| 9 |
-
"persona-data",
|
| 10 |
"streamlit>=1.44.0",
|
| 11 |
"plotly>=6.6.0",
|
| 12 |
"python-dotenv>=1.2.2",
|
| 13 |
]
|
| 14 |
|
| 15 |
[tool.uv.sources]
|
| 16 |
-
|
| 17 |
-
persona-
|
| 18 |
-
# persona-
|
| 19 |
-
|
|
|
|
| 20 |
|
| 21 |
# [build-system]
|
| 22 |
# requires = ["uv_build>=0.11.3,<0.12"]
|
|
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.10"
|
| 7 |
dependencies = [
|
| 8 |
+
"persona-vectors>=0.1.0",
|
| 9 |
+
"persona-data>=0.1.0",
|
| 10 |
"streamlit>=1.44.0",
|
| 11 |
"plotly>=6.6.0",
|
| 12 |
"python-dotenv>=1.2.2",
|
| 13 |
]
|
| 14 |
|
| 15 |
[tool.uv.sources]
|
| 16 |
+
# Local development:
|
| 17 |
+
# persona-vectors = { path = "../persona-vectors", editable = true }
|
| 18 |
+
# persona-data = { path = "../persona-data", editable = true }
|
| 19 |
+
persona-vectors = { git = "ssh://git@github.com/implicit-personalization/persona-vectors.git" }
|
| 20 |
+
persona-data = { git = "ssh://git@github.com/implicit-personalization/persona-data.git" }
|
| 21 |
|
| 22 |
# [build-system]
|
| 23 |
# requires = ["uv_build>=0.11.3,<0.12"]
|
state.py
CHANGED
|
@@ -51,9 +51,11 @@ def get_chat_state(
|
|
| 51 |
return state
|
| 52 |
|
| 53 |
|
| 54 |
-
def reset_chat_state(model_name: str,
|
| 55 |
"""Reset chat history and cache for the active context."""
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
| 51 |
return state
|
| 52 |
|
| 53 |
|
| 54 |
+
def reset_chat_state(model_name: str, dataset_source: str) -> None:
|
| 55 |
"""Reset chat history and cache for the active context."""
|
| 56 |
|
| 57 |
+
key = chat_session_key(model_name, dataset_source)
|
| 58 |
+
if key in st.session_state:
|
| 59 |
+
state = st.session_state[key]
|
| 60 |
+
state["messages"] = []
|
| 61 |
+
state["past_key_values"] = None
|
tabs/chat.py
CHANGED
|
@@ -23,12 +23,118 @@ from utils.helpers import (
|
|
| 23 |
)
|
| 24 |
from utils.runtime import cached_model
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def _render_chat_message(message: dict[str, str]) -> None:
|
| 28 |
if not message.get("content"):
|
| 29 |
return
|
| 30 |
with st.chat_message(message["role"]):
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
def _clear_chat_ui_state(*keys: str) -> None:
|
|
@@ -38,14 +144,13 @@ def _clear_chat_ui_state(*keys: str) -> None:
|
|
| 38 |
|
| 39 |
def _reset_single_chat_context(
|
| 40 |
model_name: str,
|
| 41 |
-
remote: bool,
|
| 42 |
dataset_source: str,
|
| 43 |
chat_state: dict[str, object],
|
| 44 |
persona_id: str,
|
| 45 |
prompt_mode: str,
|
| 46 |
*ui_keys: str,
|
| 47 |
) -> None:
|
| 48 |
-
reset_chat_state(model_name,
|
| 49 |
chat_state["persona_id"] = persona_id
|
| 50 |
chat_state["prompt_mode"] = prompt_mode
|
| 51 |
_clear_chat_ui_state(*ui_keys)
|
|
@@ -101,35 +206,6 @@ def _render_persona_prompt_controls(
|
|
| 101 |
return selected_persona, prompt_mode, changed
|
| 102 |
|
| 103 |
|
| 104 |
-
def _render_system_prompt_editor(
|
| 105 |
-
prompt_key: str,
|
| 106 |
-
prompt_mode: str,
|
| 107 |
-
active_system_prompt: str | None,
|
| 108 |
-
*,
|
| 109 |
-
height: int,
|
| 110 |
-
label: str = "Prompt",
|
| 111 |
-
) -> str | None:
|
| 112 |
-
"""Render the editable system prompt area for a chat panel."""
|
| 113 |
-
|
| 114 |
-
if prompt_mode == "empty":
|
| 115 |
-
return active_system_prompt
|
| 116 |
-
|
| 117 |
-
if prompt_key not in st.session_state:
|
| 118 |
-
st.session_state[prompt_key] = active_system_prompt or ""
|
| 119 |
-
|
| 120 |
-
with st.expander("Edit prompt", expanded=False):
|
| 121 |
-
edited_prompt = (
|
| 122 |
-
st.text_area(
|
| 123 |
-
label,
|
| 124 |
-
key=prompt_key,
|
| 125 |
-
height=height,
|
| 126 |
-
label_visibility="collapsed",
|
| 127 |
-
)
|
| 128 |
-
or None
|
| 129 |
-
)
|
| 130 |
-
return edited_prompt
|
| 131 |
-
|
| 132 |
-
|
| 133 |
def _render_chat_window(
|
| 134 |
*,
|
| 135 |
chat_log: Any,
|
|
@@ -137,6 +213,9 @@ def _render_chat_window(
|
|
| 137 |
show_all_key: str,
|
| 138 |
show_all_btn_key: str,
|
| 139 |
show_earlier_label: str,
|
|
|
|
|
|
|
|
|
|
| 140 |
) -> Any:
|
| 141 |
"""Render the visible chat history inside one container."""
|
| 142 |
|
|
@@ -152,11 +231,19 @@ def _render_chat_window(
|
|
| 152 |
st.session_state[show_all_key] = True
|
| 153 |
st.rerun()
|
| 154 |
visible_messages = messages[-VISIBLE_MESSAGE_COUNT:]
|
|
|
|
| 155 |
else:
|
| 156 |
visible_messages = messages
|
|
|
|
| 157 |
|
| 158 |
-
for message in visible_messages:
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
|
| 161 |
return chat_log
|
| 162 |
|
|
@@ -218,7 +305,9 @@ def _render_compare_mode(
|
|
| 218 |
"""Render the full side-by-side comparison UI."""
|
| 219 |
left_col, right_col = st.columns(2)
|
| 220 |
|
| 221 |
-
def render_panel(
|
|
|
|
|
|
|
| 222 |
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 223 |
state = st.session_state.get(panel_key)
|
| 224 |
if state is None:
|
|
@@ -226,6 +315,8 @@ def _render_compare_mode(
|
|
| 226 |
st.session_state[panel_key] = state
|
| 227 |
prompt_key = widget_key(panel_key, "custom_prompt")
|
| 228 |
show_all_key = widget_key(panel_key, "show_all")
|
|
|
|
|
|
|
| 229 |
|
| 230 |
selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
|
| 231 |
personas,
|
|
@@ -240,16 +331,11 @@ def _render_compare_mode(
|
|
| 240 |
state["persona_id"] = selected_persona.id
|
| 241 |
state["prompt_mode"] = prompt_mode
|
| 242 |
_clear_chat_ui_state(prompt_key, show_all_key)
|
|
|
|
| 243 |
|
| 244 |
active_system_prompt = resolve_system_prompt(
|
| 245 |
persona=selected_persona, mode=prompt_mode
|
| 246 |
)
|
| 247 |
-
active_system_prompt = _render_system_prompt_editor(
|
| 248 |
-
prompt_key,
|
| 249 |
-
prompt_mode,
|
| 250 |
-
active_system_prompt,
|
| 251 |
-
height=150,
|
| 252 |
-
)
|
| 253 |
|
| 254 |
btn_col1, btn_col2 = st.columns(2)
|
| 255 |
with btn_col1:
|
|
@@ -279,22 +365,73 @@ def _render_compare_mode(
|
|
| 279 |
state["messages"] = []
|
| 280 |
state["past_key_values"] = None
|
| 281 |
_clear_chat_ui_state(prompt_key, show_all_key)
|
|
|
|
| 282 |
st.rerun()
|
| 283 |
|
| 284 |
chat_log = st.container()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
_render_chat_window(
|
| 286 |
chat_log=chat_log,
|
| 287 |
messages=state["messages"],
|
| 288 |
show_all_key=show_all_key,
|
| 289 |
show_all_btn_key=widget_key(panel_key, "show_all_btn"),
|
| 290 |
show_earlier_label="Show earlier",
|
|
|
|
|
|
|
|
|
|
| 291 |
)
|
| 292 |
-
return state, chat_log, active_system_prompt
|
| 293 |
|
| 294 |
with left_col:
|
| 295 |
-
left_state, left_log, left_prompt = render_panel("left", left_col)
|
| 296 |
with right_col:
|
| 297 |
-
right_state, right_log, right_prompt = render_panel(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
|
| 299 |
user_prompt = st.chat_input(
|
| 300 |
"Ask both...",
|
|
@@ -304,12 +441,8 @@ def _render_compare_mode(
|
|
| 304 |
return
|
| 305 |
|
| 306 |
model = cached_model(model_name=model_name, remote=remote)
|
| 307 |
-
panels = [
|
| 308 |
-
(left_state, left_log, left_prompt),
|
| 309 |
-
(right_state, right_log, right_prompt),
|
| 310 |
-
]
|
| 311 |
|
| 312 |
-
for panel_state, panel_log, _panel_prompt in panels:
|
| 313 |
panel_state["messages"].append({"role": "user", "content": user_prompt})
|
| 314 |
with panel_log:
|
| 315 |
_render_chat_message({"role": "user", "content": user_prompt})
|
|
@@ -331,7 +464,7 @@ def _render_compare_mode(
|
|
| 331 |
past_key_values=panel_state["past_key_values"],
|
| 332 |
**gen_kwargs,
|
| 333 |
)
|
| 334 |
-
for panel_state, _panel_log, panel_prompt in panels
|
| 335 |
]
|
| 336 |
results: list[ChatReply | Exception] = []
|
| 337 |
for future in futures:
|
|
@@ -341,7 +474,7 @@ def _render_compare_mode(
|
|
| 341 |
results.append(exc)
|
| 342 |
else:
|
| 343 |
results = []
|
| 344 |
-
for panel_state, _panel_log, panel_prompt in panels:
|
| 345 |
try:
|
| 346 |
results.append(
|
| 347 |
generate_chat_reply(
|
|
@@ -360,7 +493,9 @@ def _render_compare_mode(
|
|
| 360 |
except Exception as exc:
|
| 361 |
results.append(exc)
|
| 362 |
|
| 363 |
-
for (panel_state, panel_log, _panel_prompt), result in zip(
|
|
|
|
|
|
|
| 364 |
if isinstance(result, Exception):
|
| 365 |
with panel_log:
|
| 366 |
st.error(f"Generation failed: {result}")
|
|
@@ -384,7 +519,11 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 384 |
context_key = chat_session_key(model_name, dataset_source)
|
| 385 |
chat_state = get_chat_state(model_name, remote, dataset_source)
|
| 386 |
try:
|
| 387 |
-
dataset, dataset_status = load_dataset(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
st.caption(dataset_status)
|
| 389 |
except Exception as exc:
|
| 390 |
st.error(f"Could not load data: {exc}")
|
|
@@ -534,6 +673,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 534 |
pending_key = widget_key(context_key, "pending_prompt")
|
| 535 |
export_key = widget_key(context_key, "export_chat")
|
| 536 |
reset_key = widget_key(context_key, "reset")
|
|
|
|
| 537 |
|
| 538 |
col1, col2 = st.columns([2, 1])
|
| 539 |
with col1:
|
|
@@ -571,7 +711,6 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 571 |
had_history = bool(chat_state["messages"])
|
| 572 |
_reset_single_chat_context(
|
| 573 |
model_name,
|
| 574 |
-
remote,
|
| 575 |
dataset_source,
|
| 576 |
chat_state,
|
| 577 |
selected_persona.id,
|
|
@@ -581,17 +720,20 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 581 |
prompt_key,
|
| 582 |
pending_key,
|
| 583 |
)
|
|
|
|
| 584 |
if had_history:
|
| 585 |
st.info("Chat history reset because the persona or system prompt changed.")
|
| 586 |
|
| 587 |
chat_log = st.container()
|
| 588 |
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
|
|
|
|
|
|
| 595 |
|
| 596 |
action_col1, action_col2 = st.columns(2)
|
| 597 |
with action_col1:
|
|
@@ -612,7 +754,6 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 612 |
if st.button("Reset chat", key=reset_key, width="stretch", type="secondary"):
|
| 613 |
_reset_single_chat_context(
|
| 614 |
model_name,
|
| 615 |
-
remote,
|
| 616 |
dataset_source,
|
| 617 |
chat_state,
|
| 618 |
selected_persona.id,
|
|
@@ -622,6 +763,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 622 |
prompt_key,
|
| 623 |
pending_key,
|
| 624 |
)
|
|
|
|
| 625 |
st.rerun()
|
| 626 |
|
| 627 |
_render_chat_window(
|
|
@@ -630,6 +772,9 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 630 |
show_all_key=show_all_key,
|
| 631 |
show_all_btn_key=widget_key(context_key, "show_all_btn"),
|
| 632 |
show_earlier_label="Show earlier messages",
|
|
|
|
|
|
|
|
|
|
| 633 |
)
|
| 634 |
|
| 635 |
user_prompt = st.chat_input(
|
|
|
|
| 23 |
)
|
| 24 |
from utils.runtime import cached_model
|
| 25 |
|
| 26 |
+
COLLAPSED_MESSAGE_CHAR_LIMIT = 500
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _render_collapsible_markdown(content: str) -> None:
|
| 30 |
+
if len(content) <= COLLAPSED_MESSAGE_CHAR_LIMIT:
|
| 31 |
+
st.markdown(content)
|
| 32 |
+
return
|
| 33 |
+
|
| 34 |
+
with st.expander(f"Show full text ({len(content)} chars)", expanded=False):
|
| 35 |
+
st.markdown(content)
|
| 36 |
+
|
| 37 |
|
| 38 |
def _render_chat_message(message: dict[str, str]) -> None:
|
| 39 |
if not message.get("content"):
|
| 40 |
return
|
| 41 |
with st.chat_message(message["role"]):
|
| 42 |
+
_render_collapsible_markdown(message["content"])
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _render_inline_system_prompt(
|
| 46 |
+
prompt_key: str,
|
| 47 |
+
prompt_mode: str,
|
| 48 |
+
active_system_prompt: str | None,
|
| 49 |
+
edit_key: str,
|
| 50 |
+
height: int = 200,
|
| 51 |
+
) -> str | None:
|
| 52 |
+
"""Render the system prompt as an inline editable item at the top of the chat."""
|
| 53 |
+
if prompt_mode == "empty":
|
| 54 |
+
return active_system_prompt
|
| 55 |
+
|
| 56 |
+
if prompt_key not in st.session_state:
|
| 57 |
+
st.session_state[prompt_key] = active_system_prompt or ""
|
| 58 |
+
|
| 59 |
+
current_prompt = st.session_state[prompt_key] or None
|
| 60 |
+
is_editing = st.session_state.get(edit_key) == -1
|
| 61 |
+
|
| 62 |
+
with st.container(border=True):
|
| 63 |
+
st.caption("System prompt")
|
| 64 |
+
if is_editing:
|
| 65 |
+
new_val = st.text_area(
|
| 66 |
+
"system_prompt_edit",
|
| 67 |
+
value=current_prompt or "",
|
| 68 |
+
height=height,
|
| 69 |
+
label_visibility="collapsed",
|
| 70 |
+
key=f"{prompt_key}_inline_edit",
|
| 71 |
+
)
|
| 72 |
+
c1, c2 = st.columns(2)
|
| 73 |
+
with c1:
|
| 74 |
+
if st.button("Save", key=f"{edit_key}_sys_save", type="primary"):
|
| 75 |
+
st.session_state[prompt_key] = new_val
|
| 76 |
+
st.session_state[edit_key] = None
|
| 77 |
+
st.rerun()
|
| 78 |
+
with c2:
|
| 79 |
+
if st.button("Cancel", key=f"{edit_key}_sys_cancel"):
|
| 80 |
+
st.session_state[edit_key] = None
|
| 81 |
+
st.rerun()
|
| 82 |
+
else:
|
| 83 |
+
if current_prompt:
|
| 84 |
+
_render_collapsible_markdown(current_prompt)
|
| 85 |
+
else:
|
| 86 |
+
st.markdown("*(empty)*")
|
| 87 |
+
if st.button("Edit", key=f"{edit_key}_sys_edit"):
|
| 88 |
+
st.session_state[edit_key] = -1
|
| 89 |
+
st.rerun()
|
| 90 |
+
|
| 91 |
+
return st.session_state.get(prompt_key) or None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _render_editable_message(
|
| 95 |
+
message: dict[str, str],
|
| 96 |
+
msg_index: int,
|
| 97 |
+
messages: list[dict[str, str]],
|
| 98 |
+
chat_state: dict[str, object],
|
| 99 |
+
edit_key: str,
|
| 100 |
+
pending_key: str,
|
| 101 |
+
) -> None:
|
| 102 |
+
"""Render a single message with an inline edit button."""
|
| 103 |
+
if not message.get("content"):
|
| 104 |
+
return
|
| 105 |
+
|
| 106 |
+
is_editing = st.session_state.get(edit_key) == msg_index
|
| 107 |
+
|
| 108 |
+
with st.chat_message(message["role"]):
|
| 109 |
+
if is_editing:
|
| 110 |
+
new_content = st.text_area(
|
| 111 |
+
"Edit",
|
| 112 |
+
value=message["content"],
|
| 113 |
+
height=100,
|
| 114 |
+
label_visibility="collapsed",
|
| 115 |
+
key=f"{edit_key}_msg_{msg_index}",
|
| 116 |
+
)
|
| 117 |
+
c1, c2 = st.columns(2)
|
| 118 |
+
with c1:
|
| 119 |
+
if st.button(
|
| 120 |
+
"Save", key=f"{edit_key}_msg_save_{msg_index}", type="primary"
|
| 121 |
+
):
|
| 122 |
+
messages[msg_index]["content"] = new_content
|
| 123 |
+
del messages[msg_index + 1 :]
|
| 124 |
+
chat_state["past_key_values"] = None
|
| 125 |
+
st.session_state[edit_key] = None
|
| 126 |
+
if message["role"] == "user":
|
| 127 |
+
st.session_state[pending_key] = True
|
| 128 |
+
st.rerun()
|
| 129 |
+
with c2:
|
| 130 |
+
if st.button("Cancel", key=f"{edit_key}_msg_cancel_{msg_index}"):
|
| 131 |
+
st.session_state[edit_key] = None
|
| 132 |
+
st.rerun()
|
| 133 |
+
else:
|
| 134 |
+
st.markdown(message["content"])
|
| 135 |
+
if st.button("Edit", key=f"{edit_key}_msg_edit_{msg_index}"):
|
| 136 |
+
st.session_state[edit_key] = msg_index
|
| 137 |
+
st.rerun()
|
| 138 |
|
| 139 |
|
| 140 |
def _clear_chat_ui_state(*keys: str) -> None:
|
|
|
|
| 144 |
|
| 145 |
def _reset_single_chat_context(
|
| 146 |
model_name: str,
|
|
|
|
| 147 |
dataset_source: str,
|
| 148 |
chat_state: dict[str, object],
|
| 149 |
persona_id: str,
|
| 150 |
prompt_mode: str,
|
| 151 |
*ui_keys: str,
|
| 152 |
) -> None:
|
| 153 |
+
reset_chat_state(model_name, dataset_source)
|
| 154 |
chat_state["persona_id"] = persona_id
|
| 155 |
chat_state["prompt_mode"] = prompt_mode
|
| 156 |
_clear_chat_ui_state(*ui_keys)
|
|
|
|
| 206 |
return selected_persona, prompt_mode, changed
|
| 207 |
|
| 208 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
def _render_chat_window(
|
| 210 |
*,
|
| 211 |
chat_log: Any,
|
|
|
|
| 213 |
show_all_key: str,
|
| 214 |
show_all_btn_key: str,
|
| 215 |
show_earlier_label: str,
|
| 216 |
+
chat_state: dict[str, object] | None = None,
|
| 217 |
+
edit_key: str | None = None,
|
| 218 |
+
pending_key: str | None = None,
|
| 219 |
) -> Any:
|
| 220 |
"""Render the visible chat history inside one container."""
|
| 221 |
|
|
|
|
| 231 |
st.session_state[show_all_key] = True
|
| 232 |
st.rerun()
|
| 233 |
visible_messages = messages[-VISIBLE_MESSAGE_COUNT:]
|
| 234 |
+
index_offset = len(messages) - VISIBLE_MESSAGE_COUNT
|
| 235 |
else:
|
| 236 |
visible_messages = messages
|
| 237 |
+
index_offset = 0
|
| 238 |
|
| 239 |
+
for i, message in enumerate(visible_messages):
|
| 240 |
+
actual_index = index_offset + i
|
| 241 |
+
if edit_key and pending_key:
|
| 242 |
+
_render_editable_message(
|
| 243 |
+
message, actual_index, messages, chat_state, edit_key, pending_key
|
| 244 |
+
)
|
| 245 |
+
else:
|
| 246 |
+
_render_chat_message(message)
|
| 247 |
|
| 248 |
return chat_log
|
| 249 |
|
|
|
|
| 305 |
"""Render the full side-by-side comparison UI."""
|
| 306 |
left_col, right_col = st.columns(2)
|
| 307 |
|
| 308 |
+
def render_panel(
|
| 309 |
+
side: str, column
|
| 310 |
+
) -> tuple[dict[str, object], Any, str | None, str]:
|
| 311 |
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 312 |
state = st.session_state.get(panel_key)
|
| 313 |
if state is None:
|
|
|
|
| 315 |
st.session_state[panel_key] = state
|
| 316 |
prompt_key = widget_key(panel_key, "custom_prompt")
|
| 317 |
show_all_key = widget_key(panel_key, "show_all")
|
| 318 |
+
edit_key = widget_key(panel_key, "edit_idx")
|
| 319 |
+
pending_regen_key = widget_key(panel_key, "pending_regen")
|
| 320 |
|
| 321 |
selected_persona, prompt_mode, changed = _render_persona_prompt_controls(
|
| 322 |
personas,
|
|
|
|
| 331 |
state["persona_id"] = selected_persona.id
|
| 332 |
state["prompt_mode"] = prompt_mode
|
| 333 |
_clear_chat_ui_state(prompt_key, show_all_key)
|
| 334 |
+
st.session_state.pop(edit_key, None)
|
| 335 |
|
| 336 |
active_system_prompt = resolve_system_prompt(
|
| 337 |
persona=selected_persona, mode=prompt_mode
|
| 338 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
btn_col1, btn_col2 = st.columns(2)
|
| 341 |
with btn_col1:
|
|
|
|
| 365 |
state["messages"] = []
|
| 366 |
state["past_key_values"] = None
|
| 367 |
_clear_chat_ui_state(prompt_key, show_all_key)
|
| 368 |
+
st.session_state.pop(edit_key, None)
|
| 369 |
st.rerun()
|
| 370 |
|
| 371 |
chat_log = st.container()
|
| 372 |
+
with chat_log:
|
| 373 |
+
active_system_prompt = _render_inline_system_prompt(
|
| 374 |
+
prompt_key,
|
| 375 |
+
prompt_mode,
|
| 376 |
+
active_system_prompt,
|
| 377 |
+
edit_key,
|
| 378 |
+
height=150,
|
| 379 |
+
)
|
| 380 |
_render_chat_window(
|
| 381 |
chat_log=chat_log,
|
| 382 |
messages=state["messages"],
|
| 383 |
show_all_key=show_all_key,
|
| 384 |
show_all_btn_key=widget_key(panel_key, "show_all_btn"),
|
| 385 |
show_earlier_label="Show earlier",
|
| 386 |
+
chat_state=state,
|
| 387 |
+
edit_key=edit_key,
|
| 388 |
+
pending_key=pending_regen_key,
|
| 389 |
)
|
| 390 |
+
return state, chat_log, active_system_prompt, pending_regen_key
|
| 391 |
|
| 392 |
with left_col:
|
| 393 |
+
left_state, left_log, left_prompt, left_pending = render_panel("left", left_col)
|
| 394 |
with right_col:
|
| 395 |
+
right_state, right_log, right_prompt, right_pending = render_panel(
|
| 396 |
+
"right", right_col
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
panels = [
|
| 400 |
+
(left_state, left_log, left_prompt, left_pending),
|
| 401 |
+
(right_state, right_log, right_prompt, right_pending),
|
| 402 |
+
]
|
| 403 |
+
|
| 404 |
+
# Handle per-panel regeneration triggered by message edits
|
| 405 |
+
any_regen = any(st.session_state.get(p_pending) for _, _, _, p_pending in panels)
|
| 406 |
+
if any_regen:
|
| 407 |
+
model = cached_model(model_name=model_name, remote=remote)
|
| 408 |
+
for panel_state, panel_log, panel_prompt, p_pending in panels:
|
| 409 |
+
if not st.session_state.pop(p_pending, False):
|
| 410 |
+
continue
|
| 411 |
+
regen_messages = _build_chat_messages(panel_prompt, panel_state["messages"])
|
| 412 |
+
with st.spinner("Regenerating..."):
|
| 413 |
+
try:
|
| 414 |
+
result = generate_chat_reply(
|
| 415 |
+
model=model,
|
| 416 |
+
messages=regen_messages,
|
| 417 |
+
remote=remote,
|
| 418 |
+
past_key_values=panel_state["past_key_values"],
|
| 419 |
+
**gen_kwargs,
|
| 420 |
+
)
|
| 421 |
+
except Exception as exc:
|
| 422 |
+
with panel_log:
|
| 423 |
+
st.error(f"Generation failed: {exc}")
|
| 424 |
+
panel_state["messages"].pop()
|
| 425 |
+
continue
|
| 426 |
+
panel_state["messages"].append(
|
| 427 |
+
{"role": "assistant", "content": result.text}
|
| 428 |
+
)
|
| 429 |
+
panel_state["past_key_values"] = (
|
| 430 |
+
result.past_key_values if not remote else None
|
| 431 |
+
)
|
| 432 |
+
with panel_log:
|
| 433 |
+
_render_chat_message({"role": "assistant", "content": result.text})
|
| 434 |
+
st.rerun()
|
| 435 |
|
| 436 |
user_prompt = st.chat_input(
|
| 437 |
"Ask both...",
|
|
|
|
| 441 |
return
|
| 442 |
|
| 443 |
model = cached_model(model_name=model_name, remote=remote)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
|
| 445 |
+
for panel_state, panel_log, _panel_prompt, _p_pending in panels:
|
| 446 |
panel_state["messages"].append({"role": "user", "content": user_prompt})
|
| 447 |
with panel_log:
|
| 448 |
_render_chat_message({"role": "user", "content": user_prompt})
|
|
|
|
| 464 |
past_key_values=panel_state["past_key_values"],
|
| 465 |
**gen_kwargs,
|
| 466 |
)
|
| 467 |
+
for panel_state, _panel_log, panel_prompt, _p_pending in panels
|
| 468 |
]
|
| 469 |
results: list[ChatReply | Exception] = []
|
| 470 |
for future in futures:
|
|
|
|
| 474 |
results.append(exc)
|
| 475 |
else:
|
| 476 |
results = []
|
| 477 |
+
for panel_state, _panel_log, panel_prompt, _p_pending in panels:
|
| 478 |
try:
|
| 479 |
results.append(
|
| 480 |
generate_chat_reply(
|
|
|
|
| 493 |
except Exception as exc:
|
| 494 |
results.append(exc)
|
| 495 |
|
| 496 |
+
for (panel_state, panel_log, _panel_prompt, _p_pending), result in zip(
|
| 497 |
+
panels, results
|
| 498 |
+
):
|
| 499 |
if isinstance(result, Exception):
|
| 500 |
with panel_log:
|
| 501 |
st.error(f"Generation failed: {result}")
|
|
|
|
| 519 |
context_key = chat_session_key(model_name, dataset_source)
|
| 520 |
chat_state = get_chat_state(model_name, remote, dataset_source)
|
| 521 |
try:
|
| 522 |
+
dataset, dataset_status = load_dataset(
|
| 523 |
+
dataset_source,
|
| 524 |
+
personas_file=st.session_state.get("extract__personas_file"),
|
| 525 |
+
qa_file=st.session_state.get("extract__qa_file"),
|
| 526 |
+
)
|
| 527 |
st.caption(dataset_status)
|
| 528 |
except Exception as exc:
|
| 529 |
st.error(f"Could not load data: {exc}")
|
|
|
|
| 673 |
pending_key = widget_key(context_key, "pending_prompt")
|
| 674 |
export_key = widget_key(context_key, "export_chat")
|
| 675 |
reset_key = widget_key(context_key, "reset")
|
| 676 |
+
edit_key = widget_key(context_key, "edit_idx")
|
| 677 |
|
| 678 |
col1, col2 = st.columns([2, 1])
|
| 679 |
with col1:
|
|
|
|
| 711 |
had_history = bool(chat_state["messages"])
|
| 712 |
_reset_single_chat_context(
|
| 713 |
model_name,
|
|
|
|
| 714 |
dataset_source,
|
| 715 |
chat_state,
|
| 716 |
selected_persona.id,
|
|
|
|
| 720 |
prompt_key,
|
| 721 |
pending_key,
|
| 722 |
)
|
| 723 |
+
st.session_state.pop(edit_key, None)
|
| 724 |
if had_history:
|
| 725 |
st.info("Chat history reset because the persona or system prompt changed.")
|
| 726 |
|
| 727 |
chat_log = st.container()
|
| 728 |
|
| 729 |
+
with chat_log:
|
| 730 |
+
active_system_prompt = _render_inline_system_prompt(
|
| 731 |
+
prompt_key,
|
| 732 |
+
prompt_mode,
|
| 733 |
+
active_system_prompt,
|
| 734 |
+
edit_key,
|
| 735 |
+
height=200,
|
| 736 |
+
)
|
| 737 |
|
| 738 |
action_col1, action_col2 = st.columns(2)
|
| 739 |
with action_col1:
|
|
|
|
| 754 |
if st.button("Reset chat", key=reset_key, width="stretch", type="secondary"):
|
| 755 |
_reset_single_chat_context(
|
| 756 |
model_name,
|
|
|
|
| 757 |
dataset_source,
|
| 758 |
chat_state,
|
| 759 |
selected_persona.id,
|
|
|
|
| 763 |
prompt_key,
|
| 764 |
pending_key,
|
| 765 |
)
|
| 766 |
+
st.session_state.pop(edit_key, None)
|
| 767 |
st.rerun()
|
| 768 |
|
| 769 |
_render_chat_window(
|
|
|
|
| 772 |
show_all_key=show_all_key,
|
| 773 |
show_all_btn_key=widget_key(context_key, "show_all_btn"),
|
| 774 |
show_earlier_label="Show earlier messages",
|
| 775 |
+
chat_state=chat_state,
|
| 776 |
+
edit_key=edit_key,
|
| 777 |
+
pending_key=pending_key,
|
| 778 |
)
|
| 779 |
|
| 780 |
user_prompt = st.chat_input(
|
tabs/compare.py
CHANGED
|
@@ -5,7 +5,7 @@ import streamlit as st
|
|
| 5 |
import torch
|
| 6 |
from persona_data.environment import get_artifacts_dir
|
| 7 |
from persona_vectors.analysis import build_embedding_figure, project_pca, project_umap
|
| 8 |
-
from persona_vectors.artifacts import ActivationStore
|
| 9 |
from persona_vectors.artifacts import list_layers as list_available_layers
|
| 10 |
from persona_vectors.artifacts import list_personas as list_available_personas
|
| 11 |
from persona_vectors.artifacts import load_mean_activations, load_persona_names
|
|
@@ -14,7 +14,6 @@ from persona_vectors.plots import plot_layer_similarity, save_plot_html, save_pl
|
|
| 14 |
from utils.helpers import (
|
| 15 |
ANALYSIS_HELP_TEXT,
|
| 16 |
ANALYSIS_MODES,
|
| 17 |
-
PROMPT_VARIANTS,
|
| 18 |
persona_display_label,
|
| 19 |
prompt_variant_label,
|
| 20 |
slugify,
|
|
@@ -34,20 +33,27 @@ class ProjectionConfig:
|
|
| 34 |
project_fn: Callable[[torch.Tensor], torch.Tensor]
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
_PROJECTION_CONFIGS: dict[str, ProjectionConfig] = {
|
| 38 |
"PCA": ProjectionConfig("PCA", "PC1", "PC2", project_pca),
|
| 39 |
"UMAP": ProjectionConfig("UMAP", "UMAP 1", "UMAP 2", project_umap),
|
| 40 |
}
|
| 41 |
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
persona_ids: list[str],
|
| 49 |
-
) -> list[int]:
|
| 50 |
-
return list_available_layers(root_dir, model_name, variants, persona_ids)
|
| 51 |
|
| 52 |
|
| 53 |
def _load_embedding_samples(
|
|
@@ -86,9 +92,9 @@ def _load_embedding_samples(
|
|
| 86 |
continue
|
| 87 |
|
| 88 |
layer_vectors = vectors[:, layer_idx, :]
|
| 89 |
-
samples.append(layer_vectors)
|
| 90 |
-
labels.extend([persona_id] * layer_vectors.shape[0])
|
| 91 |
display_name = persona_names.get(persona_id) or persona_id
|
|
|
|
|
|
|
| 92 |
hover_text.extend(
|
| 93 |
[f"<b>{display_name}</b><br>{variant}"] * layer_vectors.shape[0]
|
| 94 |
)
|
|
@@ -114,28 +120,8 @@ def _load_embedding_samples(
|
|
| 114 |
return plots, errors
|
| 115 |
|
| 116 |
|
| 117 |
-
def _build_embedding_figures(
|
| 118 |
-
plots: list[tuple[int, torch.Tensor, list[str], list[str]]],
|
| 119 |
-
config: ProjectionConfig,
|
| 120 |
-
) -> list[tuple[int, object]]:
|
| 121 |
-
return [
|
| 122 |
-
(
|
| 123 |
-
layer_idx,
|
| 124 |
-
build_embedding_figure(
|
| 125 |
-
coords=coords,
|
| 126 |
-
labels=labels,
|
| 127 |
-
title=f"{config.title_prefix}, layer {layer_idx}",
|
| 128 |
-
x_label=config.x_label,
|
| 129 |
-
y_label=config.y_label,
|
| 130 |
-
hover_text=hover_text,
|
| 131 |
-
),
|
| 132 |
-
)
|
| 133 |
-
for layer_idx, coords, labels, hover_text in plots
|
| 134 |
-
]
|
| 135 |
-
|
| 136 |
-
|
| 137 |
def _render_embedding_results(
|
| 138 |
-
|
| 139 |
analysis_mode: str,
|
| 140 |
rendered_figures: list[tuple[int, object]],
|
| 141 |
saved_variant: str,
|
|
@@ -152,7 +138,7 @@ def _render_embedding_results(
|
|
| 152 |
_filename(
|
| 153 |
"compare",
|
| 154 |
analysis_mode,
|
| 155 |
-
|
| 156 |
saved_variant,
|
| 157 |
saved_persona_key,
|
| 158 |
str(layer_idx),
|
|
@@ -181,15 +167,20 @@ def _select_artifact_personas(
|
|
| 181 |
st.info("No personas found for this model yet. Run extraction first.")
|
| 182 |
return [], persona_names
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
persona_ids = st.multiselect(
|
| 185 |
"Personas",
|
| 186 |
options=persona_options,
|
| 187 |
-
default=
|
| 188 |
format_func=lambda persona_id: persona_display_label(
|
| 189 |
persona_id, persona_names.get(persona_id)
|
| 190 |
),
|
| 191 |
key=widget_key("load", "personas", store.model_name, *variants),
|
| 192 |
)
|
|
|
|
| 193 |
return persona_ids, persona_names
|
| 194 |
|
| 195 |
|
|
@@ -215,11 +206,11 @@ def _render_save_buttons(
|
|
| 215 |
|
| 216 |
def _select_embedding_config(
|
| 217 |
store: ActivationStore,
|
| 218 |
-
) ->
|
| 219 |
"""Render variant / persona / layer selectors and return the selection, or None on early exit."""
|
| 220 |
selected_variant = st.selectbox(
|
| 221 |
"Variant",
|
| 222 |
-
options=
|
| 223 |
format_func=prompt_variant_label,
|
| 224 |
key=widget_key("load", "variant"),
|
| 225 |
)
|
|
@@ -228,7 +219,8 @@ def _select_embedding_config(
|
|
| 228 |
if not persona_ids:
|
| 229 |
return None
|
| 230 |
|
| 231 |
-
|
|
|
|
| 232 |
str(store.root_dir),
|
| 233 |
store.model_name,
|
| 234 |
[selected_variant],
|
|
@@ -240,14 +232,14 @@ def _select_embedding_config(
|
|
| 240 |
)
|
| 241 |
return None
|
| 242 |
|
| 243 |
-
persona_key = "_".join(sorted(persona_ids))
|
| 244 |
layer_key = widget_key(
|
| 245 |
"load", "layers", store.model_name, selected_variant, persona_key
|
| 246 |
)
|
|
|
|
|
|
|
|
|
|
| 247 |
default_layers = [
|
| 248 |
-
layer
|
| 249 |
-
for layer in st.session_state.get(layer_key, layer_options[:3])
|
| 250 |
-
if layer in layer_options
|
| 251 |
] or layer_options[:3]
|
| 252 |
selected_layers = st.multiselect(
|
| 253 |
"Layers",
|
|
@@ -259,7 +251,15 @@ def _select_embedding_config(
|
|
| 259 |
st.info("Select at least one layer.")
|
| 260 |
return None
|
| 261 |
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
|
| 265 |
def _render_cosine_similarity(store: ActivationStore) -> None:
|
|
@@ -267,7 +267,7 @@ def _render_cosine_similarity(store: ActivationStore) -> None:
|
|
| 267 |
with col1:
|
| 268 |
variant_a = st.selectbox(
|
| 269 |
"Variant A",
|
| 270 |
-
options=
|
| 271 |
index=0,
|
| 272 |
format_func=prompt_variant_label,
|
| 273 |
key=widget_key("load", "variant_a"),
|
|
@@ -275,8 +275,8 @@ def _render_cosine_similarity(store: ActivationStore) -> None:
|
|
| 275 |
with col2:
|
| 276 |
variant_b = st.selectbox(
|
| 277 |
"Variant B",
|
| 278 |
-
options=
|
| 279 |
-
index=min(1, len(
|
| 280 |
format_func=prompt_variant_label,
|
| 281 |
key=widget_key("load", "variant_b"),
|
| 282 |
)
|
|
@@ -289,7 +289,9 @@ def _render_cosine_similarity(store: ActivationStore) -> None:
|
|
| 289 |
if not persona_ids:
|
| 290 |
return
|
| 291 |
|
| 292 |
-
cosine_fig_key = widget_key(
|
|
|
|
|
|
|
| 293 |
filename = _filename("compare", "cosine", store.model_name, variant_a, variant_b)
|
| 294 |
|
| 295 |
if st.button("Compare vectors", type="primary"):
|
|
@@ -334,8 +336,7 @@ def _render_embedding_analysis(store: ActivationStore, analysis_mode: str) -> No
|
|
| 334 |
config = _select_embedding_config(store)
|
| 335 |
if config is None:
|
| 336 |
return
|
| 337 |
-
|
| 338 |
-
persona_key = "_".join(sorted(persona_ids))
|
| 339 |
projection_config = _PROJECTION_CONFIGS.get(analysis_mode)
|
| 340 |
if projection_config is None:
|
| 341 |
st.error(f"Unsupported analysis mode: {analysis_mode}")
|
|
@@ -358,11 +359,11 @@ def _render_embedding_analysis(store: ActivationStore, analysis_mode: str) -> No
|
|
| 358 |
try:
|
| 359 |
plots, errors = _load_embedding_samples(
|
| 360 |
store,
|
| 361 |
-
persona_ids,
|
| 362 |
-
|
| 363 |
-
selected_layers,
|
| 364 |
projection_config.project_fn,
|
| 365 |
-
persona_names,
|
| 366 |
progress_fn=update_progress,
|
| 367 |
)
|
| 368 |
|
|
@@ -382,12 +383,25 @@ def _render_embedding_analysis(store: ActivationStore, analysis_mode: str) -> No
|
|
| 382 |
st.info("Try fewer personas, fewer layers, or a different variant.")
|
| 383 |
st.session_state.pop(embedding_fig_key, None)
|
| 384 |
else:
|
| 385 |
-
rendered_figures =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
total_samples = sum(coords.shape[0] for _, coords, _, _ in plots)
|
| 387 |
st.session_state[embedding_fig_key] = (
|
| 388 |
rendered_figures,
|
| 389 |
-
persona_key,
|
| 390 |
-
|
| 391 |
total_samples,
|
| 392 |
)
|
| 393 |
finally:
|
|
@@ -398,7 +412,7 @@ def _render_embedding_analysis(store: ActivationStore, analysis_mode: str) -> No
|
|
| 398 |
st.session_state[embedding_fig_key]
|
| 399 |
)
|
| 400 |
_render_embedding_results(
|
| 401 |
-
store,
|
| 402 |
analysis_mode,
|
| 403 |
rendered_figures,
|
| 404 |
saved_variant,
|
|
|
|
| 5 |
import torch
|
| 6 |
from persona_data.environment import get_artifacts_dir
|
| 7 |
from persona_vectors.analysis import build_embedding_figure, project_pca, project_umap
|
| 8 |
+
from persona_vectors.artifacts import SUPPORTED_VARIANTS, ActivationStore
|
| 9 |
from persona_vectors.artifacts import list_layers as list_available_layers
|
| 10 |
from persona_vectors.artifacts import list_personas as list_available_personas
|
| 11 |
from persona_vectors.artifacts import load_mean_activations, load_persona_names
|
|
|
|
| 14 |
from utils.helpers import (
|
| 15 |
ANALYSIS_HELP_TEXT,
|
| 16 |
ANALYSIS_MODES,
|
|
|
|
| 17 |
persona_display_label,
|
| 18 |
prompt_variant_label,
|
| 19 |
slugify,
|
|
|
|
| 33 |
project_fn: Callable[[torch.Tensor], torch.Tensor]
|
| 34 |
|
| 35 |
|
| 36 |
+
@dataclass(frozen=True)
|
| 37 |
+
class _EmbeddingConfig:
|
| 38 |
+
variant: str
|
| 39 |
+
persona_ids: list[str]
|
| 40 |
+
persona_names: dict[str, str]
|
| 41 |
+
selected_layers: list[int]
|
| 42 |
+
persona_key: str
|
| 43 |
+
|
| 44 |
+
|
| 45 |
_PROJECTION_CONFIGS: dict[str, ProjectionConfig] = {
|
| 46 |
"PCA": ProjectionConfig("PCA", "PC1", "PC2", project_pca),
|
| 47 |
"UMAP": ProjectionConfig("UMAP", "UMAP 1", "UMAP 2", project_umap),
|
| 48 |
}
|
| 49 |
|
| 50 |
+
_list_layers_cached = st.cache_data(show_spinner=False)(list_available_layers)
|
| 51 |
|
| 52 |
+
# Cross-model/NDIF-switch persistence keys — written on every render so that
|
| 53 |
+
# when the model changes (and widget keys change) the last selection is reused
|
| 54 |
+
# as the default, filtered to whatever is available for the new model.
|
| 55 |
+
_LAST_PERSONAS_KEY = "compare:last_personas"
|
| 56 |
+
_LAST_LAYERS_KEY = "compare:last_layers"
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
|
| 59 |
def _load_embedding_samples(
|
|
|
|
| 92 |
continue
|
| 93 |
|
| 94 |
layer_vectors = vectors[:, layer_idx, :]
|
|
|
|
|
|
|
| 95 |
display_name = persona_names.get(persona_id) or persona_id
|
| 96 |
+
samples.append(layer_vectors)
|
| 97 |
+
labels.extend([display_name] * layer_vectors.shape[0])
|
| 98 |
hover_text.extend(
|
| 99 |
[f"<b>{display_name}</b><br>{variant}"] * layer_vectors.shape[0]
|
| 100 |
)
|
|
|
|
| 120 |
return plots, errors
|
| 121 |
|
| 122 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
def _render_embedding_results(
|
| 124 |
+
model_name: str,
|
| 125 |
analysis_mode: str,
|
| 126 |
rendered_figures: list[tuple[int, object]],
|
| 127 |
saved_variant: str,
|
|
|
|
| 138 |
_filename(
|
| 139 |
"compare",
|
| 140 |
analysis_mode,
|
| 141 |
+
model_name,
|
| 142 |
saved_variant,
|
| 143 |
saved_persona_key,
|
| 144 |
str(layer_idx),
|
|
|
|
| 167 |
st.info("No personas found for this model yet. Run extraction first.")
|
| 168 |
return [], persona_names
|
| 169 |
|
| 170 |
+
last_personas: list[str] = st.session_state.get(_LAST_PERSONAS_KEY, [])
|
| 171 |
+
default_personas = [
|
| 172 |
+
p for p in last_personas if p in persona_options
|
| 173 |
+
] or persona_options[:1]
|
| 174 |
persona_ids = st.multiselect(
|
| 175 |
"Personas",
|
| 176 |
options=persona_options,
|
| 177 |
+
default=default_personas,
|
| 178 |
format_func=lambda persona_id: persona_display_label(
|
| 179 |
persona_id, persona_names.get(persona_id)
|
| 180 |
),
|
| 181 |
key=widget_key("load", "personas", store.model_name, *variants),
|
| 182 |
)
|
| 183 |
+
st.session_state[_LAST_PERSONAS_KEY] = persona_ids
|
| 184 |
return persona_ids, persona_names
|
| 185 |
|
| 186 |
|
|
|
|
| 206 |
|
| 207 |
def _select_embedding_config(
|
| 208 |
store: ActivationStore,
|
| 209 |
+
) -> _EmbeddingConfig | None:
|
| 210 |
"""Render variant / persona / layer selectors and return the selection, or None on early exit."""
|
| 211 |
selected_variant = st.selectbox(
|
| 212 |
"Variant",
|
| 213 |
+
options=SUPPORTED_VARIANTS,
|
| 214 |
format_func=prompt_variant_label,
|
| 215 |
key=widget_key("load", "variant"),
|
| 216 |
)
|
|
|
|
| 219 |
if not persona_ids:
|
| 220 |
return None
|
| 221 |
|
| 222 |
+
persona_key = "_".join(sorted(persona_ids))
|
| 223 |
+
layer_options = _list_layers_cached(
|
| 224 |
str(store.root_dir),
|
| 225 |
store.model_name,
|
| 226 |
[selected_variant],
|
|
|
|
| 232 |
)
|
| 233 |
return None
|
| 234 |
|
|
|
|
| 235 |
layer_key = widget_key(
|
| 236 |
"load", "layers", store.model_name, selected_variant, persona_key
|
| 237 |
)
|
| 238 |
+
last_layers: list[int] = st.session_state.get(
|
| 239 |
+
layer_key, st.session_state.get(_LAST_LAYERS_KEY, layer_options[:3])
|
| 240 |
+
)
|
| 241 |
default_layers = [
|
| 242 |
+
layer for layer in last_layers if layer in layer_options
|
|
|
|
|
|
|
| 243 |
] or layer_options[:3]
|
| 244 |
selected_layers = st.multiselect(
|
| 245 |
"Layers",
|
|
|
|
| 251 |
st.info("Select at least one layer.")
|
| 252 |
return None
|
| 253 |
|
| 254 |
+
st.session_state[_LAST_LAYERS_KEY] = selected_layers
|
| 255 |
+
|
| 256 |
+
return _EmbeddingConfig(
|
| 257 |
+
variant=selected_variant,
|
| 258 |
+
persona_ids=persona_ids,
|
| 259 |
+
persona_names=persona_names,
|
| 260 |
+
selected_layers=selected_layers,
|
| 261 |
+
persona_key=persona_key,
|
| 262 |
+
)
|
| 263 |
|
| 264 |
|
| 265 |
def _render_cosine_similarity(store: ActivationStore) -> None:
|
|
|
|
| 267 |
with col1:
|
| 268 |
variant_a = st.selectbox(
|
| 269 |
"Variant A",
|
| 270 |
+
options=SUPPORTED_VARIANTS,
|
| 271 |
index=0,
|
| 272 |
format_func=prompt_variant_label,
|
| 273 |
key=widget_key("load", "variant_a"),
|
|
|
|
| 275 |
with col2:
|
| 276 |
variant_b = st.selectbox(
|
| 277 |
"Variant B",
|
| 278 |
+
options=SUPPORTED_VARIANTS,
|
| 279 |
+
index=min(1, len(SUPPORTED_VARIANTS) - 1),
|
| 280 |
format_func=prompt_variant_label,
|
| 281 |
key=widget_key("load", "variant_b"),
|
| 282 |
)
|
|
|
|
| 289 |
if not persona_ids:
|
| 290 |
return
|
| 291 |
|
| 292 |
+
cosine_fig_key = widget_key(
|
| 293 |
+
"load", "cosine_fig_state", store.model_name, variant_a, variant_b
|
| 294 |
+
)
|
| 295 |
filename = _filename("compare", "cosine", store.model_name, variant_a, variant_b)
|
| 296 |
|
| 297 |
if st.button("Compare vectors", type="primary"):
|
|
|
|
| 336 |
config = _select_embedding_config(store)
|
| 337 |
if config is None:
|
| 338 |
return
|
| 339 |
+
|
|
|
|
| 340 |
projection_config = _PROJECTION_CONFIGS.get(analysis_mode)
|
| 341 |
if projection_config is None:
|
| 342 |
st.error(f"Unsupported analysis mode: {analysis_mode}")
|
|
|
|
| 359 |
try:
|
| 360 |
plots, errors = _load_embedding_samples(
|
| 361 |
store,
|
| 362 |
+
config.persona_ids,
|
| 363 |
+
config.variant,
|
| 364 |
+
config.selected_layers,
|
| 365 |
projection_config.project_fn,
|
| 366 |
+
config.persona_names,
|
| 367 |
progress_fn=update_progress,
|
| 368 |
)
|
| 369 |
|
|
|
|
| 383 |
st.info("Try fewer personas, fewer layers, or a different variant.")
|
| 384 |
st.session_state.pop(embedding_fig_key, None)
|
| 385 |
else:
|
| 386 |
+
rendered_figures = [
|
| 387 |
+
(
|
| 388 |
+
layer_idx,
|
| 389 |
+
build_embedding_figure(
|
| 390 |
+
coords=coords,
|
| 391 |
+
labels=labels,
|
| 392 |
+
title=f"{projection_config.title_prefix}, layer {layer_idx}",
|
| 393 |
+
x_label=projection_config.x_label,
|
| 394 |
+
y_label=projection_config.y_label,
|
| 395 |
+
hover_text=hover_text,
|
| 396 |
+
),
|
| 397 |
+
)
|
| 398 |
+
for layer_idx, coords, labels, hover_text in plots
|
| 399 |
+
]
|
| 400 |
total_samples = sum(coords.shape[0] for _, coords, _, _ in plots)
|
| 401 |
st.session_state[embedding_fig_key] = (
|
| 402 |
rendered_figures,
|
| 403 |
+
config.persona_key,
|
| 404 |
+
config.variant,
|
| 405 |
total_samples,
|
| 406 |
)
|
| 407 |
finally:
|
|
|
|
| 412 |
st.session_state[embedding_fig_key]
|
| 413 |
)
|
| 414 |
_render_embedding_results(
|
| 415 |
+
store.model_name,
|
| 416 |
analysis_mode,
|
| 417 |
rendered_figures,
|
| 418 |
saved_variant,
|
tabs/extract.py
CHANGED
|
@@ -1,16 +1,28 @@
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
|
|
|
| 2 |
from persona_vectors.extraction import run_extraction
|
| 3 |
|
| 4 |
from utils.datasets import load_dataset
|
| 5 |
from utils.helpers import (
|
| 6 |
NDIF_STATUS_ICONS,
|
| 7 |
-
PROMPT_VARIANTS,
|
| 8 |
persona_label,
|
| 9 |
prompt_variant_label,
|
| 10 |
widget_key,
|
| 11 |
)
|
| 12 |
from utils.runtime import cached_model
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
def _extract_widget_key(
|
| 16 |
model_name: str, remote: bool, dataset_source: str, suffix: str
|
|
@@ -26,7 +38,7 @@ def _render_local_dataset_uploads() -> None:
|
|
| 26 |
"personas.jsonl",
|
| 27 |
type=["jsonl"],
|
| 28 |
key="extract__personas_file",
|
| 29 |
-
help="Expected fields: id, persona,
|
| 30 |
)
|
| 31 |
st.file_uploader(
|
| 32 |
"qa.jsonl",
|
|
@@ -44,19 +56,28 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 44 |
if dataset_source == "Local JSONL upload":
|
| 45 |
_render_local_dataset_uploads()
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
selected_variants = st.multiselect(
|
| 48 |
"Prompt variants",
|
| 49 |
-
options=
|
| 50 |
-
default=
|
| 51 |
format_func=prompt_variant_label,
|
| 52 |
key=_extract_widget_key(model_name, remote, dataset_source, "prompt_variants"),
|
| 53 |
)
|
|
|
|
| 54 |
if not selected_variants:
|
| 55 |
st.info("Select at least one prompt variant.")
|
| 56 |
return
|
| 57 |
|
| 58 |
try:
|
| 59 |
-
dataset, dataset_status = load_dataset(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
st.caption(dataset_status)
|
| 61 |
except Exception as exc:
|
| 62 |
st.error(f"Could not load data: {exc}")
|
|
@@ -73,13 +94,18 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 73 |
)
|
| 74 |
return
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
selected_personas = st.multiselect(
|
| 77 |
"Personas",
|
| 78 |
options=personas,
|
| 79 |
-
default=
|
| 80 |
format_func=persona_label,
|
| 81 |
key=_extract_widget_key(model_name, remote, dataset_source, "persona_select"),
|
| 82 |
)
|
|
|
|
| 83 |
|
| 84 |
if not selected_personas:
|
| 85 |
st.info("Select at least one persona.")
|
|
@@ -93,26 +119,42 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 93 |
|
| 94 |
col1, col2, col3 = st.columns([2, 2, 1])
|
| 95 |
with col1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
qa_type_select = st.selectbox(
|
| 97 |
"QA type",
|
| 98 |
-
options=
|
| 99 |
-
index=
|
| 100 |
key=_extract_widget_key(
|
| 101 |
model_name, remote, dataset_source, "qa_type_select"
|
| 102 |
),
|
| 103 |
)
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
| 106 |
)
|
| 107 |
with col2:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
difficulty_values = st.multiselect(
|
| 109 |
"Difficulty",
|
| 110 |
options=[1, 2, 3],
|
| 111 |
-
default=
|
| 112 |
key=_extract_widget_key(
|
| 113 |
model_name, remote, dataset_source, "difficulty_select"
|
| 114 |
),
|
| 115 |
)
|
|
|
|
| 116 |
qa_filter_difficulty = difficulty_values if difficulty_values else None
|
| 117 |
|
| 118 |
runs, skipped = [], []
|
|
@@ -135,15 +177,18 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 135 |
return
|
| 136 |
|
| 137 |
max_q = min(len(qa_pairs) for _, qa_pairs in runs)
|
|
|
|
|
|
|
| 138 |
max_questions = st.slider(
|
| 139 |
"Max questions",
|
| 140 |
min_value=1,
|
| 141 |
max_value=max_q,
|
| 142 |
-
value=
|
| 143 |
key=_extract_widget_key(
|
| 144 |
model_name, remote, dataset_source, "max_questions"
|
| 145 |
),
|
| 146 |
)
|
|
|
|
| 147 |
|
| 148 |
if runs is None:
|
| 149 |
return
|
|
@@ -180,7 +225,7 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 180 |
model_name=model_name,
|
| 181 |
persona=persona,
|
| 182 |
qa_pairs=qa_pairs[:max_questions],
|
| 183 |
-
variants=
|
| 184 |
remote=remote,
|
| 185 |
on_status=_on_ndif_status if remote else None,
|
| 186 |
)
|
|
|
|
| 1 |
+
from typing import Literal, cast
|
| 2 |
+
|
| 3 |
import streamlit as st
|
| 4 |
+
from persona_vectors.artifacts import SUPPORTED_VARIANTS
|
| 5 |
from persona_vectors.extraction import run_extraction
|
| 6 |
|
| 7 |
from utils.datasets import load_dataset
|
| 8 |
from utils.helpers import (
|
| 9 |
NDIF_STATUS_ICONS,
|
|
|
|
| 10 |
persona_label,
|
| 11 |
prompt_variant_label,
|
| 12 |
widget_key,
|
| 13 |
)
|
| 14 |
from utils.runtime import cached_model
|
| 15 |
|
| 16 |
+
# Cross-model / remote-switch persistence — same pattern as compare.py.
|
| 17 |
+
# Written on every render so selections survive model or NDIF toggles.
|
| 18 |
+
_LAST_VARIANTS_KEY = "extract:last_variants"
|
| 19 |
+
_LAST_PERSONA_IDS_KEY = "extract:last_persona_ids"
|
| 20 |
+
_LAST_QA_TYPE_KEY = "extract:last_qa_type"
|
| 21 |
+
_LAST_DIFFICULTY_KEY = "extract:last_difficulty"
|
| 22 |
+
_LAST_MAX_QUESTIONS_KEY = "extract:last_max_questions"
|
| 23 |
+
|
| 24 |
+
_QA_TYPE_OPTIONS = ["all", "explicit", "implicit"]
|
| 25 |
+
|
| 26 |
|
| 27 |
def _extract_widget_key(
|
| 28 |
model_name: str, remote: bool, dataset_source: str, suffix: str
|
|
|
|
| 38 |
"personas.jsonl",
|
| 39 |
type=["jsonl"],
|
| 40 |
key="extract__personas_file",
|
| 41 |
+
help="Expected fields: id, persona, templated_view, biography_view",
|
| 42 |
)
|
| 43 |
st.file_uploader(
|
| 44 |
"qa.jsonl",
|
|
|
|
| 56 |
if dataset_source == "Local JSONL upload":
|
| 57 |
_render_local_dataset_uploads()
|
| 58 |
|
| 59 |
+
last_variants = st.session_state.get(_LAST_VARIANTS_KEY, list(SUPPORTED_VARIANTS))
|
| 60 |
+
default_variants = [v for v in last_variants if v in SUPPORTED_VARIANTS] or list(
|
| 61 |
+
SUPPORTED_VARIANTS
|
| 62 |
+
)
|
| 63 |
selected_variants = st.multiselect(
|
| 64 |
"Prompt variants",
|
| 65 |
+
options=SUPPORTED_VARIANTS,
|
| 66 |
+
default=default_variants,
|
| 67 |
format_func=prompt_variant_label,
|
| 68 |
key=_extract_widget_key(model_name, remote, dataset_source, "prompt_variants"),
|
| 69 |
)
|
| 70 |
+
st.session_state[_LAST_VARIANTS_KEY] = selected_variants
|
| 71 |
if not selected_variants:
|
| 72 |
st.info("Select at least one prompt variant.")
|
| 73 |
return
|
| 74 |
|
| 75 |
try:
|
| 76 |
+
dataset, dataset_status = load_dataset(
|
| 77 |
+
dataset_source,
|
| 78 |
+
personas_file=st.session_state.get("extract__personas_file"),
|
| 79 |
+
qa_file=st.session_state.get("extract__qa_file"),
|
| 80 |
+
)
|
| 81 |
st.caption(dataset_status)
|
| 82 |
except Exception as exc:
|
| 83 |
st.error(f"Could not load data: {exc}")
|
|
|
|
| 94 |
)
|
| 95 |
return
|
| 96 |
|
| 97 |
+
last_persona_ids: set[str] = set(st.session_state.get(_LAST_PERSONA_IDS_KEY, []))
|
| 98 |
+
default_personas = [p for p in personas if p.id in last_persona_ids] or [
|
| 99 |
+
personas[0]
|
| 100 |
+
]
|
| 101 |
selected_personas = st.multiselect(
|
| 102 |
"Personas",
|
| 103 |
options=personas,
|
| 104 |
+
default=default_personas,
|
| 105 |
format_func=persona_label,
|
| 106 |
key=_extract_widget_key(model_name, remote, dataset_source, "persona_select"),
|
| 107 |
)
|
| 108 |
+
st.session_state[_LAST_PERSONA_IDS_KEY] = [p.id for p in selected_personas]
|
| 109 |
|
| 110 |
if not selected_personas:
|
| 111 |
st.info("Select at least one persona.")
|
|
|
|
| 119 |
|
| 120 |
col1, col2, col3 = st.columns([2, 2, 1])
|
| 121 |
with col1:
|
| 122 |
+
last_qa_type = st.session_state.get(_LAST_QA_TYPE_KEY, "all")
|
| 123 |
+
qa_type_index = (
|
| 124 |
+
_QA_TYPE_OPTIONS.index(last_qa_type)
|
| 125 |
+
if last_qa_type in _QA_TYPE_OPTIONS
|
| 126 |
+
else 0
|
| 127 |
+
)
|
| 128 |
qa_type_select = st.selectbox(
|
| 129 |
"QA type",
|
| 130 |
+
options=_QA_TYPE_OPTIONS,
|
| 131 |
+
index=qa_type_index,
|
| 132 |
key=_extract_widget_key(
|
| 133 |
model_name, remote, dataset_source, "qa_type_select"
|
| 134 |
),
|
| 135 |
)
|
| 136 |
+
st.session_state[_LAST_QA_TYPE_KEY] = qa_type_select
|
| 137 |
+
qa_filter_type: Literal["explicit", "implicit"] | None = (
|
| 138 |
+
cast(Literal["explicit", "implicit"], qa_type_select)
|
| 139 |
+
if qa_type_select in ("explicit", "implicit")
|
| 140 |
+
else None
|
| 141 |
)
|
| 142 |
with col2:
|
| 143 |
+
last_difficulty = st.session_state.get(_LAST_DIFFICULTY_KEY, [1, 2, 3])
|
| 144 |
+
default_difficulty = [d for d in last_difficulty if d in (1, 2, 3)] or [
|
| 145 |
+
1,
|
| 146 |
+
2,
|
| 147 |
+
3,
|
| 148 |
+
]
|
| 149 |
difficulty_values = st.multiselect(
|
| 150 |
"Difficulty",
|
| 151 |
options=[1, 2, 3],
|
| 152 |
+
default=default_difficulty,
|
| 153 |
key=_extract_widget_key(
|
| 154 |
model_name, remote, dataset_source, "difficulty_select"
|
| 155 |
),
|
| 156 |
)
|
| 157 |
+
st.session_state[_LAST_DIFFICULTY_KEY] = difficulty_values
|
| 158 |
qa_filter_difficulty = difficulty_values if difficulty_values else None
|
| 159 |
|
| 160 |
runs, skipped = [], []
|
|
|
|
| 177 |
return
|
| 178 |
|
| 179 |
max_q = min(len(qa_pairs) for _, qa_pairs in runs)
|
| 180 |
+
last_max = st.session_state.get(_LAST_MAX_QUESTIONS_KEY, max_q)
|
| 181 |
+
default_max = min(max(last_max, 1), max_q)
|
| 182 |
max_questions = st.slider(
|
| 183 |
"Max questions",
|
| 184 |
min_value=1,
|
| 185 |
max_value=max_q,
|
| 186 |
+
value=default_max,
|
| 187 |
key=_extract_widget_key(
|
| 188 |
model_name, remote, dataset_source, "max_questions"
|
| 189 |
),
|
| 190 |
)
|
| 191 |
+
st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
|
| 192 |
|
| 193 |
if runs is None:
|
| 194 |
return
|
|
|
|
| 225 |
model_name=model_name,
|
| 226 |
persona=persona,
|
| 227 |
qa_pairs=qa_pairs[:max_questions],
|
| 228 |
+
variants=(variant,),
|
| 229 |
remote=remote,
|
| 230 |
on_status=_on_ndif_status if remote else None,
|
| 231 |
)
|
utils/chat.py
CHANGED
|
@@ -5,17 +5,10 @@ from typing import Literal
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
from nnterp import StandardizedTransformer
|
| 8 |
-
|
| 9 |
-
logger = logging.getLogger(__name__)
|
| 10 |
-
|
| 11 |
-
from persona_data.prompts import (
|
| 12 |
-
format_biography_prompt,
|
| 13 |
-
format_empty_persona_prompt,
|
| 14 |
-
format_templated_prompt,
|
| 15 |
-
normalize_messages,
|
| 16 |
-
)
|
| 17 |
from persona_data.synth_persona import PersonaData
|
| 18 |
|
|
|
|
| 19 |
SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
|
| 20 |
|
| 21 |
|
|
@@ -47,11 +40,12 @@ def resolve_system_prompt(
|
|
| 47 |
if mode == "empty":
|
| 48 |
return ""
|
| 49 |
if mode == "templated":
|
| 50 |
-
return
|
| 51 |
if mode == "biography":
|
| 52 |
-
return
|
| 53 |
if mode == "custom":
|
| 54 |
-
return
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
def _format_plain_messages(
|
|
|
|
| 5 |
|
| 6 |
import torch
|
| 7 |
from nnterp import StandardizedTransformer
|
| 8 |
+
from persona_data.prompts import format_roleplay_prompt, normalize_messages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from persona_data.synth_persona import PersonaData
|
| 10 |
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
|
| 13 |
|
| 14 |
|
|
|
|
| 40 |
if mode == "empty":
|
| 41 |
return ""
|
| 42 |
if mode == "templated":
|
| 43 |
+
return format_roleplay_prompt(persona.templated_view, mode="conversational")
|
| 44 |
if mode == "biography":
|
| 45 |
+
return format_roleplay_prompt(persona.biography_view, mode="conversational")
|
| 46 |
if mode == "custom":
|
| 47 |
+
return format_roleplay_prompt(mode="conversational")
|
| 48 |
+
raise ValueError(f"Unsupported system prompt mode: {mode}")
|
| 49 |
|
| 50 |
|
| 51 |
def _format_plain_messages(
|
utils/chat_export.py
CHANGED
|
@@ -54,7 +54,7 @@ def save_chat_export(
|
|
| 54 |
export_dir = (
|
| 55 |
get_artifacts_dir()
|
| 56 |
/ "chats"
|
| 57 |
-
/ model_name.
|
| 58 |
/ slugify(dataset_source)
|
| 59 |
/ slugify(persona_id)
|
| 60 |
)
|
|
|
|
| 54 |
export_dir = (
|
| 55 |
get_artifacts_dir()
|
| 56 |
/ "chats"
|
| 57 |
+
/ "__".join(slugify(part) for part in model_name.split("/"))
|
| 58 |
/ slugify(dataset_source)
|
| 59 |
/ slugify(persona_id)
|
| 60 |
)
|
utils/datasets.py
CHANGED
|
@@ -44,14 +44,14 @@ def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
|
|
| 44 |
|
| 45 |
def load_dataset(
|
| 46 |
dataset_source: str,
|
|
|
|
|
|
|
| 47 |
) -> tuple[SynthPersonaDataset | LocalPersonaDataset, str]:
|
| 48 |
"""Load the selected dataset source for the UI."""
|
| 49 |
|
| 50 |
if dataset_source == DATASET_SOURCES[0]:
|
| 51 |
return cached_hf_dataset(), "SynthPersona"
|
| 52 |
|
| 53 |
-
personas_file = st.session_state.get("extract__personas_file")
|
| 54 |
-
qa_file = st.session_state.get("extract__qa_file")
|
| 55 |
if personas_file is None or qa_file is None:
|
| 56 |
raise ValueError("Upload both personas.jsonl and qa.jsonl files")
|
| 57 |
|
|
|
|
| 44 |
|
| 45 |
def load_dataset(
|
| 46 |
dataset_source: str,
|
| 47 |
+
personas_file: Any = None,
|
| 48 |
+
qa_file: Any = None,
|
| 49 |
) -> tuple[SynthPersonaDataset | LocalPersonaDataset, str]:
|
| 50 |
"""Load the selected dataset source for the UI."""
|
| 51 |
|
| 52 |
if dataset_source == DATASET_SOURCES[0]:
|
| 53 |
return cached_hf_dataset(), "SynthPersona"
|
| 54 |
|
|
|
|
|
|
|
| 55 |
if personas_file is None or qa_file is None:
|
| 56 |
raise ValueError("Upload both personas.jsonl and qa.jsonl files")
|
| 57 |
|
utils/helpers.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import re
|
| 2 |
|
| 3 |
from persona_data.synth_persona import PersonaData
|
| 4 |
-
from persona_vectors.artifacts import SUPPORTED_VARIANTS
|
| 5 |
|
| 6 |
# Variant key -> human-readable label mapping
|
| 7 |
VARIANT_LABELS = {
|
|
@@ -11,9 +10,6 @@ VARIANT_LABELS = {
|
|
| 11 |
"custom": "Custom",
|
| 12 |
}
|
| 13 |
|
| 14 |
-
# Variants that correspond to actual system prompts (excludes "empty")
|
| 15 |
-
PROMPT_VARIANTS = list(SUPPORTED_VARIANTS)
|
| 16 |
-
|
| 17 |
# For selectbox options: list of labels in definition order
|
| 18 |
MODE_LABELS = list(VARIANT_LABELS.values())
|
| 19 |
|
|
|
|
| 1 |
import re
|
| 2 |
|
| 3 |
from persona_data.synth_persona import PersonaData
|
|
|
|
| 4 |
|
| 5 |
# Variant key -> human-readable label mapping
|
| 6 |
VARIANT_LABELS = {
|
|
|
|
| 10 |
"custom": "Custom",
|
| 11 |
}
|
| 12 |
|
|
|
|
|
|
|
|
|
|
| 13 |
# For selectbox options: list of labels in definition order
|
| 14 |
MODE_LABELS = list(VARIANT_LABELS.values())
|
| 15 |
|
uv.lock
CHANGED
|
@@ -297,7 +297,7 @@ name = "cuda-bindings"
|
|
| 297 |
version = "13.2.0"
|
| 298 |
source = { registry = "https://pypi.org/simple" }
|
| 299 |
dependencies = [
|
| 300 |
-
{ name = "cuda-pathfinder"
|
| 301 |
]
|
| 302 |
wheels = [
|
| 303 |
{ url = "https://files.pythonhosted.org/packages/1a/fe/7351d7e586a8b4c9f89731bfe4cf0148223e8f9903ff09571f78b3fb0682/cuda_bindings-13.2.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08b395f79cb89ce0cd8effff07c4a1e20101b873c256a1aeb286e8fd7bd0f556", size = 5744254, upload-time = "2026-03-11T00:12:29.798Z" },
|
|
@@ -316,10 +316,10 @@ wheels = [
|
|
| 316 |
|
| 317 |
[[package]]
|
| 318 |
name = "cuda-pathfinder"
|
| 319 |
-
version = "1.5.
|
| 320 |
source = { registry = "https://pypi.org/simple" }
|
| 321 |
wheels = [
|
| 322 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 323 |
]
|
| 324 |
|
| 325 |
[[package]]
|
|
@@ -332,37 +332,37 @@ wheels = [
|
|
| 332 |
|
| 333 |
[package.optional-dependencies]
|
| 334 |
cublas = [
|
| 335 |
-
{ name = "nvidia-cublas", marker = "
|
| 336 |
]
|
| 337 |
cudart = [
|
| 338 |
-
{ name = "nvidia-cuda-runtime", marker = "
|
| 339 |
]
|
| 340 |
cufft = [
|
| 341 |
-
{ name = "nvidia-cufft", marker = "
|
| 342 |
]
|
| 343 |
cufile = [
|
| 344 |
{ name = "nvidia-cufile", marker = "sys_platform == 'linux'" },
|
| 345 |
]
|
| 346 |
cupti = [
|
| 347 |
-
{ name = "nvidia-cuda-cupti", marker = "
|
| 348 |
]
|
| 349 |
curand = [
|
| 350 |
-
{ name = "nvidia-curand", marker = "
|
| 351 |
]
|
| 352 |
cusolver = [
|
| 353 |
-
{ name = "nvidia-cusolver", marker = "
|
| 354 |
]
|
| 355 |
cusparse = [
|
| 356 |
-
{ name = "nvidia-cusparse", marker = "
|
| 357 |
]
|
| 358 |
nvjitlink = [
|
| 359 |
-
{ name = "nvidia-nvjitlink", marker = "
|
| 360 |
]
|
| 361 |
nvrtc = [
|
| 362 |
-
{ name = "nvidia-cuda-nvrtc", marker = "
|
| 363 |
]
|
| 364 |
nvtx = [
|
| 365 |
-
{ name = "nvidia-nvtx", marker = "
|
| 366 |
]
|
| 367 |
|
| 368 |
[[package]]
|
|
@@ -508,7 +508,7 @@ wheels = [
|
|
| 508 |
|
| 509 |
[[package]]
|
| 510 |
name = "huggingface-hub"
|
| 511 |
-
version = "1.9.
|
| 512 |
source = { registry = "https://pypi.org/simple" }
|
| 513 |
dependencies = [
|
| 514 |
{ name = "filelock" },
|
|
@@ -521,9 +521,9 @@ dependencies = [
|
|
| 521 |
{ name = "typer" },
|
| 522 |
{ name = "typing-extensions" },
|
| 523 |
]
|
| 524 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 525 |
wheels = [
|
| 526 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 527 |
]
|
| 528 |
|
| 529 |
[[package]]
|
|
@@ -883,11 +883,11 @@ wheels = [
|
|
| 883 |
|
| 884 |
[[package]]
|
| 885 |
name = "narwhals"
|
| 886 |
-
version = "2.
|
| 887 |
source = { registry = "https://pypi.org/simple" }
|
| 888 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 889 |
wheels = [
|
| 890 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 891 |
]
|
| 892 |
|
| 893 |
[[package]]
|
|
@@ -1216,7 +1216,7 @@ name = "nvidia-cudnn-cu13"
|
|
| 1216 |
version = "9.19.0.56"
|
| 1217 |
source = { registry = "https://pypi.org/simple" }
|
| 1218 |
dependencies = [
|
| 1219 |
-
{ name = "nvidia-cublas"
|
| 1220 |
]
|
| 1221 |
wheels = [
|
| 1222 |
{ url = "https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201, upload-time = "2026-02-03T20:40:53.805Z" },
|
|
@@ -1228,7 +1228,7 @@ name = "nvidia-cufft"
|
|
| 1228 |
version = "12.0.0.61"
|
| 1229 |
source = { registry = "https://pypi.org/simple" }
|
| 1230 |
dependencies = [
|
| 1231 |
-
{ name = "nvidia-nvjitlink"
|
| 1232 |
]
|
| 1233 |
wheels = [
|
| 1234 |
{ url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" },
|
|
@@ -1258,9 +1258,9 @@ name = "nvidia-cusolver"
|
|
| 1258 |
version = "12.0.4.66"
|
| 1259 |
source = { registry = "https://pypi.org/simple" }
|
| 1260 |
dependencies = [
|
| 1261 |
-
{ name = "nvidia-cublas"
|
| 1262 |
-
{ name = "nvidia-cusparse"
|
| 1263 |
-
{ name = "nvidia-nvjitlink"
|
| 1264 |
]
|
| 1265 |
wheels = [
|
| 1266 |
{ url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" },
|
|
@@ -1272,7 +1272,7 @@ name = "nvidia-cusparse"
|
|
| 1272 |
version = "12.6.3.3"
|
| 1273 |
source = { registry = "https://pypi.org/simple" }
|
| 1274 |
dependencies = [
|
| 1275 |
-
{ name = "nvidia-nvjitlink"
|
| 1276 |
]
|
| 1277 |
wheels = [
|
| 1278 |
{ url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" },
|
|
@@ -1561,7 +1561,7 @@ wheels = [
|
|
| 1561 |
[[package]]
|
| 1562 |
name = "persona-data"
|
| 1563 |
version = "0.1.0"
|
| 1564 |
-
source = {
|
| 1565 |
dependencies = [
|
| 1566 |
{ name = "huggingface-hub" },
|
| 1567 |
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
|
@@ -1570,14 +1570,6 @@ dependencies = [
|
|
| 1570 |
{ name = "torch" },
|
| 1571 |
]
|
| 1572 |
|
| 1573 |
-
[package.metadata]
|
| 1574 |
-
requires-dist = [
|
| 1575 |
-
{ name = "huggingface-hub", specifier = ">=0.30.0" },
|
| 1576 |
-
{ name = "numpy", specifier = ">=1.24.0" },
|
| 1577 |
-
{ name = "python-dotenv", specifier = ">=1.0.0" },
|
| 1578 |
-
{ name = "torch", specifier = ">=2.0.0" },
|
| 1579 |
-
]
|
| 1580 |
-
|
| 1581 |
[[package]]
|
| 1582 |
name = "persona-ui"
|
| 1583 |
version = "0.1.0"
|
|
@@ -1592,8 +1584,8 @@ dependencies = [
|
|
| 1592 |
|
| 1593 |
[package.metadata]
|
| 1594 |
requires-dist = [
|
| 1595 |
-
{ name = "persona-data",
|
| 1596 |
-
{ name = "persona-vectors",
|
| 1597 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1598 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1599 |
{ name = "streamlit", specifier = ">=1.44.0" },
|
|
@@ -1602,7 +1594,7 @@ requires-dist = [
|
|
| 1602 |
[[package]]
|
| 1603 |
name = "persona-vectors"
|
| 1604 |
version = "0.1.0"
|
| 1605 |
-
source = {
|
| 1606 |
dependencies = [
|
| 1607 |
{ name = "kaleido" },
|
| 1608 |
{ name = "nnsight" },
|
|
@@ -1620,23 +1612,6 @@ dependencies = [
|
|
| 1620 |
{ name = "umap-learn" },
|
| 1621 |
]
|
| 1622 |
|
| 1623 |
-
[package.metadata]
|
| 1624 |
-
requires-dist = [
|
| 1625 |
-
{ name = "kaleido", specifier = ">=1.0.0" },
|
| 1626 |
-
{ name = "nnsight", specifier = ">=0.6.1" },
|
| 1627 |
-
{ name = "nnterp", specifier = ">=1.3.0" },
|
| 1628 |
-
{ name = "persona-data", editable = "../persona-data" },
|
| 1629 |
-
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1630 |
-
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1631 |
-
{ name = "safetensors", specifier = ">=0.7.0" },
|
| 1632 |
-
{ name = "scikit-learn", specifier = ">=1.6.0" },
|
| 1633 |
-
{ name = "torch", specifier = ">=2.10.0" },
|
| 1634 |
-
{ name = "torchvision", specifier = ">=0.26.0" },
|
| 1635 |
-
{ name = "tqdm", specifier = ">=4.67.3" },
|
| 1636 |
-
{ name = "transformers", specifier = ">=5.2.0" },
|
| 1637 |
-
{ name = "umap-learn", specifier = ">=0.5.7" },
|
| 1638 |
-
]
|
| 1639 |
-
|
| 1640 |
[[package]]
|
| 1641 |
name = "pexpect"
|
| 1642 |
version = "4.9.0"
|
|
@@ -2075,7 +2050,7 @@ wheels = [
|
|
| 2075 |
|
| 2076 |
[[package]]
|
| 2077 |
name = "pytest"
|
| 2078 |
-
version = "9.0.
|
| 2079 |
source = { registry = "https://pypi.org/simple" }
|
| 2080 |
dependencies = [
|
| 2081 |
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
|
@@ -2086,9 +2061,9 @@ dependencies = [
|
|
| 2086 |
{ name = "pygments" },
|
| 2087 |
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
| 2088 |
]
|
| 2089 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 2090 |
wheels = [
|
| 2091 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 2092 |
]
|
| 2093 |
|
| 2094 |
[[package]]
|
|
|
|
| 297 |
version = "13.2.0"
|
| 298 |
source = { registry = "https://pypi.org/simple" }
|
| 299 |
dependencies = [
|
| 300 |
+
{ name = "cuda-pathfinder" },
|
| 301 |
]
|
| 302 |
wheels = [
|
| 303 |
{ url = "https://files.pythonhosted.org/packages/1a/fe/7351d7e586a8b4c9f89731bfe4cf0148223e8f9903ff09571f78b3fb0682/cuda_bindings-13.2.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08b395f79cb89ce0cd8effff07c4a1e20101b873c256a1aeb286e8fd7bd0f556", size = 5744254, upload-time = "2026-03-11T00:12:29.798Z" },
|
|
|
|
| 316 |
|
| 317 |
[[package]]
|
| 318 |
name = "cuda-pathfinder"
|
| 319 |
+
version = "1.5.2"
|
| 320 |
source = { registry = "https://pypi.org/simple" }
|
| 321 |
wheels = [
|
| 322 |
+
{ url = "https://files.pythonhosted.org/packages/f2/f9/1b9b60a30fc463c14cdea7a77228131a0ccc89572e8df9cb86c9648271ab/cuda_pathfinder-1.5.2-py3-none-any.whl", hash = "sha256:0c5f160a7756c5b072723cbbd6d861e38917ef956c68150b02f0b6e9271c71fa", size = 49988, upload-time = "2026-04-06T23:01:05.17Z" },
|
| 323 |
]
|
| 324 |
|
| 325 |
[[package]]
|
|
|
|
| 332 |
|
| 333 |
[package.optional-dependencies]
|
| 334 |
cublas = [
|
| 335 |
+
{ name = "nvidia-cublas", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 336 |
]
|
| 337 |
cudart = [
|
| 338 |
+
{ name = "nvidia-cuda-runtime", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 339 |
]
|
| 340 |
cufft = [
|
| 341 |
+
{ name = "nvidia-cufft", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 342 |
]
|
| 343 |
cufile = [
|
| 344 |
{ name = "nvidia-cufile", marker = "sys_platform == 'linux'" },
|
| 345 |
]
|
| 346 |
cupti = [
|
| 347 |
+
{ name = "nvidia-cuda-cupti", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 348 |
]
|
| 349 |
curand = [
|
| 350 |
+
{ name = "nvidia-curand", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 351 |
]
|
| 352 |
cusolver = [
|
| 353 |
+
{ name = "nvidia-cusolver", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 354 |
]
|
| 355 |
cusparse = [
|
| 356 |
+
{ name = "nvidia-cusparse", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 357 |
]
|
| 358 |
nvjitlink = [
|
| 359 |
+
{ name = "nvidia-nvjitlink", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 360 |
]
|
| 361 |
nvrtc = [
|
| 362 |
+
{ name = "nvidia-cuda-nvrtc", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 363 |
]
|
| 364 |
nvtx = [
|
| 365 |
+
{ name = "nvidia-nvtx", marker = "sys_platform == 'linux' or sys_platform == 'win32'" },
|
| 366 |
]
|
| 367 |
|
| 368 |
[[package]]
|
|
|
|
| 508 |
|
| 509 |
[[package]]
|
| 510 |
name = "huggingface-hub"
|
| 511 |
+
version = "1.9.2"
|
| 512 |
source = { registry = "https://pypi.org/simple" }
|
| 513 |
dependencies = [
|
| 514 |
{ name = "filelock" },
|
|
|
|
| 521 |
{ name = "typer" },
|
| 522 |
{ name = "typing-extensions" },
|
| 523 |
]
|
| 524 |
+
sdist = { url = "https://files.pythonhosted.org/packages/cf/65/fb800d327bf25bf31b798dd08935d326d064ecb9b359059fecd91b3a98e8/huggingface_hub-1.9.2.tar.gz", hash = "sha256:8d09d080a186bd950a361bfc04b862dfb04d6a2b41d48e9ba1b37507cfd3f1e1", size = 750284, upload-time = "2026-04-08T08:43:11.127Z" }
|
| 525 |
wheels = [
|
| 526 |
+
{ url = "https://files.pythonhosted.org/packages/57/d4/e33bf0b362810a9b96c5923e38908950d58ecb512db42e3730320c7f4a3a/huggingface_hub-1.9.2-py3-none-any.whl", hash = "sha256:e1e62ce237d4fbeca9f970aeb15176fbd503e04c25577bfd22f44aa7aa2b5243", size = 637349, upload-time = "2026-04-08T08:43:09.114Z" },
|
| 527 |
]
|
| 528 |
|
| 529 |
[[package]]
|
|
|
|
| 883 |
|
| 884 |
[[package]]
|
| 885 |
name = "narwhals"
|
| 886 |
+
version = "2.19.0"
|
| 887 |
source = { registry = "https://pypi.org/simple" }
|
| 888 |
+
sdist = { url = "https://files.pythonhosted.org/packages/4e/1a/bd3317c0bdbcd9ffb710ddf5250b32898f8f2c240be99494fe137feb77a7/narwhals-2.19.0.tar.gz", hash = "sha256:14fd7040b5ff211d415a82e4827b9d04c354e213e72a6d0730205ffd72e3b7ff", size = 623698, upload-time = "2026-04-06T15:50:58.786Z" }
|
| 889 |
wheels = [
|
| 890 |
+
{ url = "https://files.pythonhosted.org/packages/37/72/e61e3091e0e00fae9d3a8ef85ece9d2cd4b5966058e1f2901ce42679eebf/narwhals-2.19.0-py3-none-any.whl", hash = "sha256:1f8dfa4a33a6dbff878c3e9be4c3b455dfcaf2a9322f1357db00e4e92e95b84b", size = 446991, upload-time = "2026-04-06T15:50:57.046Z" },
|
| 891 |
]
|
| 892 |
|
| 893 |
[[package]]
|
|
|
|
| 1216 |
version = "9.19.0.56"
|
| 1217 |
source = { registry = "https://pypi.org/simple" }
|
| 1218 |
dependencies = [
|
| 1219 |
+
{ name = "nvidia-cublas" },
|
| 1220 |
]
|
| 1221 |
wheels = [
|
| 1222 |
{ url = "https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201, upload-time = "2026-02-03T20:40:53.805Z" },
|
|
|
|
| 1228 |
version = "12.0.0.61"
|
| 1229 |
source = { registry = "https://pypi.org/simple" }
|
| 1230 |
dependencies = [
|
| 1231 |
+
{ name = "nvidia-nvjitlink" },
|
| 1232 |
]
|
| 1233 |
wheels = [
|
| 1234 |
{ url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" },
|
|
|
|
| 1258 |
version = "12.0.4.66"
|
| 1259 |
source = { registry = "https://pypi.org/simple" }
|
| 1260 |
dependencies = [
|
| 1261 |
+
{ name = "nvidia-cublas" },
|
| 1262 |
+
{ name = "nvidia-cusparse" },
|
| 1263 |
+
{ name = "nvidia-nvjitlink" },
|
| 1264 |
]
|
| 1265 |
wheels = [
|
| 1266 |
{ url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" },
|
|
|
|
| 1272 |
version = "12.6.3.3"
|
| 1273 |
source = { registry = "https://pypi.org/simple" }
|
| 1274 |
dependencies = [
|
| 1275 |
+
{ name = "nvidia-nvjitlink" },
|
| 1276 |
]
|
| 1277 |
wheels = [
|
| 1278 |
{ url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" },
|
|
|
|
| 1561 |
[[package]]
|
| 1562 |
name = "persona-data"
|
| 1563 |
version = "0.1.0"
|
| 1564 |
+
source = { git = "ssh://git@github.com/implicit-personalization/persona-data.git#3763bd6e42472b589b4e32acd3e47b711a0af1f5" }
|
| 1565 |
dependencies = [
|
| 1566 |
{ name = "huggingface-hub" },
|
| 1567 |
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
|
|
|
| 1570 |
{ name = "torch" },
|
| 1571 |
]
|
| 1572 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1573 |
[[package]]
|
| 1574 |
name = "persona-ui"
|
| 1575 |
version = "0.1.0"
|
|
|
|
| 1584 |
|
| 1585 |
[package.metadata]
|
| 1586 |
requires-dist = [
|
| 1587 |
+
{ name = "persona-data", git = "ssh://git@github.com/implicit-personalization/persona-data.git" },
|
| 1588 |
+
{ name = "persona-vectors", git = "ssh://git@github.com/implicit-personalization/persona-vectors.git" },
|
| 1589 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1590 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1591 |
{ name = "streamlit", specifier = ">=1.44.0" },
|
|
|
|
| 1594 |
[[package]]
|
| 1595 |
name = "persona-vectors"
|
| 1596 |
version = "0.1.0"
|
| 1597 |
+
source = { git = "ssh://git@github.com/implicit-personalization/persona-vectors.git#fa6b4b61eaaba9ce64ee8614766bf75879148bbb" }
|
| 1598 |
dependencies = [
|
| 1599 |
{ name = "kaleido" },
|
| 1600 |
{ name = "nnsight" },
|
|
|
|
| 1612 |
{ name = "umap-learn" },
|
| 1613 |
]
|
| 1614 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1615 |
[[package]]
|
| 1616 |
name = "pexpect"
|
| 1617 |
version = "4.9.0"
|
|
|
|
| 2050 |
|
| 2051 |
[[package]]
|
| 2052 |
name = "pytest"
|
| 2053 |
+
version = "9.0.3"
|
| 2054 |
source = { registry = "https://pypi.org/simple" }
|
| 2055 |
dependencies = [
|
| 2056 |
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
|
|
|
| 2061 |
{ name = "pygments" },
|
| 2062 |
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
| 2063 |
]
|
| 2064 |
+
sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
|
| 2065 |
wheels = [
|
| 2066 |
+
{ url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
|
| 2067 |
]
|
| 2068 |
|
| 2069 |
[[package]]
|