Jac-Zac commited on
Commit ·
2bf3d21
1
Parent(s): 12cdb17
App new feel and look revamp
Browse files- app.py +35 -8
- pyproject.toml +3 -1
- tabs/compare.py +339 -137
- tabs/extract.py +3 -19
- utils/compare_sources.py +186 -0
- utils/controls.py +29 -0
- uv.lock +22 -18
app.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
from dotenv import load_dotenv
|
|
@@ -16,6 +17,14 @@ _TABS = ["Chat", "Compare", "Extract"]
|
|
| 16 |
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def _remote_model_input(remote_models: list[str]) -> str:
|
| 20 |
"""Return the active remote model id, picking from running NDIF deployments or a custom value."""
|
| 21 |
|
|
@@ -74,7 +83,7 @@ def _remote_model_input(remote_models: list[str]) -> str:
|
|
| 74 |
return model_name
|
| 75 |
|
| 76 |
|
| 77 |
-
def _sidebar_controls() ->
|
| 78 |
from utils.runtime import list_remote_models
|
| 79 |
|
| 80 |
with st.sidebar:
|
|
@@ -96,6 +105,19 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
|
|
| 96 |
st.session_state["sidebar__active_tab"] = tab_name
|
| 97 |
st.rerun()
|
| 98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
st.divider()
|
| 100 |
st.caption("Runtime")
|
| 101 |
remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")
|
|
@@ -119,7 +141,12 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
|
|
| 119 |
help="Dataset for Chat and Extract.",
|
| 120 |
)
|
| 121 |
|
| 122 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
|
| 125 |
def main() -> None:
|
|
@@ -136,20 +163,20 @@ def main() -> None:
|
|
| 136 |
|
| 137 |
torch.set_grad_enabled(False)
|
| 138 |
|
| 139 |
-
|
| 140 |
|
| 141 |
-
if active_tab == "Extract":
|
| 142 |
from tabs.extract import render_extract_tab
|
| 143 |
|
| 144 |
-
render_extract_tab(remote, model_name, dataset_source)
|
| 145 |
-
elif active_tab == "Compare":
|
| 146 |
from tabs.compare import render_compare_tab
|
| 147 |
|
| 148 |
-
render_compare_tab(
|
| 149 |
else:
|
| 150 |
from tabs.chat import render_chat_tab
|
| 151 |
|
| 152 |
-
render_chat_tab(remote, model_name, dataset_source)
|
| 153 |
|
| 154 |
|
| 155 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import os
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
from dotenv import load_dotenv
|
|
|
|
| 17 |
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
| 18 |
|
| 19 |
|
| 20 |
+
@dataclass(frozen=True)
|
| 21 |
+
class SidebarState:
|
| 22 |
+
remote: bool
|
| 23 |
+
model_name: str
|
| 24 |
+
dataset_source: str
|
| 25 |
+
active_tab: str
|
| 26 |
+
|
| 27 |
+
|
| 28 |
def _remote_model_input(remote_models: list[str]) -> str:
|
| 29 |
"""Return the active remote model id, picking from running NDIF deployments or a custom value."""
|
| 30 |
|
|
|
|
| 83 |
return model_name
|
| 84 |
|
| 85 |
|
| 86 |
+
def _sidebar_controls() -> SidebarState:
|
| 87 |
from utils.runtime import list_remote_models
|
| 88 |
|
| 89 |
with st.sidebar:
|
|
|
|
| 105 |
st.session_state["sidebar__active_tab"] = tab_name
|
| 106 |
st.rerun()
|
| 107 |
|
| 108 |
+
if active_tab == "Compare":
|
| 109 |
+
model_name = st.session_state.get(_LAST_LOCAL_MODEL_KEY, DEFAULT_MODEL)
|
| 110 |
+
dataset_source = st.session_state.get(
|
| 111 |
+
"sidebar__dataset_source",
|
| 112 |
+
DATASET_SOURCES[0],
|
| 113 |
+
)
|
| 114 |
+
return SidebarState(
|
| 115 |
+
remote=False,
|
| 116 |
+
model_name=model_name,
|
| 117 |
+
dataset_source=dataset_source,
|
| 118 |
+
active_tab=active_tab,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
st.divider()
|
| 122 |
st.caption("Runtime")
|
| 123 |
remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")
|
|
|
|
| 141 |
help="Dataset for Chat and Extract.",
|
| 142 |
)
|
| 143 |
|
| 144 |
+
return SidebarState(
|
| 145 |
+
remote=remote,
|
| 146 |
+
model_name=model_name,
|
| 147 |
+
dataset_source=dataset_source,
|
| 148 |
+
active_tab=active_tab,
|
| 149 |
+
)
|
| 150 |
|
| 151 |
|
| 152 |
def main() -> None:
|
|
|
|
| 163 |
|
| 164 |
torch.set_grad_enabled(False)
|
| 165 |
|
| 166 |
+
sidebar = _sidebar_controls()
|
| 167 |
|
| 168 |
+
if sidebar.active_tab == "Extract":
|
| 169 |
from tabs.extract import render_extract_tab
|
| 170 |
|
| 171 |
+
render_extract_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
|
| 172 |
+
elif sidebar.active_tab == "Compare":
|
| 173 |
from tabs.compare import render_compare_tab
|
| 174 |
|
| 175 |
+
render_compare_tab()
|
| 176 |
else:
|
| 177 |
from tabs.chat import render_chat_tab
|
| 178 |
|
| 179 |
+
render_chat_tab(sidebar.remote, sidebar.model_name, sidebar.dataset_source)
|
| 180 |
|
| 181 |
|
| 182 |
if __name__ == "__main__":
|
pyproject.toml
CHANGED
|
@@ -1,12 +1,14 @@
|
|
| 1 |
[project]
|
| 2 |
name = "persona-ui"
|
| 3 |
-
version = "0.
|
| 4 |
description = "Streamlit UI for persona-vectors"
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
"persona-vectors>=0.6.4",
|
| 9 |
"persona-data>=0.4.2",
|
|
|
|
|
|
|
| 10 |
"streamlit>=1.44.0",
|
| 11 |
"plotly>=6.6.0",
|
| 12 |
"python-dotenv>=1.2.2",
|
|
|
|
| 1 |
[project]
|
| 2 |
name = "persona-ui"
|
| 3 |
+
version = "0.4.0"
|
| 4 |
description = "Streamlit UI for persona-vectors"
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
"persona-vectors>=0.6.4",
|
| 9 |
"persona-data>=0.4.2",
|
| 10 |
+
"datasets>=4.8.5",
|
| 11 |
+
"huggingface-hub>=1.14.0",
|
| 12 |
"streamlit>=1.44.0",
|
| 13 |
"plotly>=6.6.0",
|
| 14 |
"python-dotenv>=1.2.2",
|
tabs/compare.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
-
import os
|
| 2 |
from collections.abc import Callable
|
| 3 |
from dataclasses import dataclass
|
| 4 |
from itertools import combinations
|
|
|
|
| 5 |
|
| 6 |
import streamlit as st
|
| 7 |
from persona_data.environment import get_artifacts_dir
|
|
|
|
| 8 |
from persona_vectors.analysis import load_persona_vectors, load_variant_vectors
|
| 9 |
-
from persona_vectors.artifacts import
|
| 10 |
-
from persona_vectors.artifacts import list_layers as list_local_layers
|
| 11 |
from persona_vectors.extraction import MaskStrategy
|
| 12 |
from persona_vectors.plots import (
|
| 13 |
build_layered_figure,
|
|
@@ -16,50 +16,39 @@ from persona_vectors.plots import (
|
|
| 16 |
save_plot_html,
|
| 17 |
)
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
from utils.helpers import (
|
| 20 |
ANALYSIS_HELP_TEXT,
|
| 21 |
ANALYSIS_MODES,
|
| 22 |
-
persona_display_label,
|
| 23 |
prompt_variant_label,
|
| 24 |
slugify,
|
| 25 |
widget_key,
|
| 26 |
)
|
| 27 |
|
| 28 |
-
Store = ActivationStore | HFActivationStore
|
| 29 |
-
|
| 30 |
-
DEFAULT_HUB_REPO = os.environ.get(
|
| 31 |
-
"PERSONA_VECTORS_HUB_REPO",
|
| 32 |
-
"implicit-personalization/synth-persona-vectors",
|
| 33 |
-
)
|
| 34 |
-
SOURCE_HUB = "Hugging Face Hub"
|
| 35 |
-
SOURCE_LOCAL = "Local activations"
|
| 36 |
-
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
| 37 |
-
|
| 38 |
|
| 39 |
def _filename(*parts: str) -> str:
|
| 40 |
return "__".join(slugify(part) for part in parts if part)
|
| 41 |
|
| 42 |
|
| 43 |
-
_list_layers_cached = st.cache_data(show_spinner=False)(list_local_layers)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
@st.cache_data(show_spinner=False)
|
| 47 |
-
def _hub_layers_cached(
|
| 48 |
-
repo_id: str,
|
| 49 |
-
model_name: str,
|
| 50 |
-
mask_strategy_value: str,
|
| 51 |
-
variant: str,
|
| 52 |
-
persona_id: str,
|
| 53 |
-
) -> list[int]:
|
| 54 |
-
store = HFActivationStore(
|
| 55 |
-
repo_id,
|
| 56 |
-
model_name,
|
| 57 |
-
mask_strategy=MaskStrategy(mask_strategy_value),
|
| 58 |
-
)
|
| 59 |
-
sample = store.load(variant, persona_id)
|
| 60 |
-
return list(range(int(sample.shape[0])))
|
| 61 |
-
|
| 62 |
-
|
| 63 |
# Keep compare-tab selection state separate so projection defaults do not
|
| 64 |
# overwrite cosine similarity defaults.
|
| 65 |
_LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
|
|
@@ -68,6 +57,15 @@ _LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
|
|
| 68 |
_LAST_SOURCE_KEY = "compare:last_source"
|
| 69 |
|
| 70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
@dataclass(frozen=True)
|
| 72 |
class CosineSelection:
|
| 73 |
variants: list[str]
|
|
@@ -77,11 +75,10 @@ class CosineSelection:
|
|
| 77 |
persona_key: str
|
| 78 |
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
return f"local:{store.root_dir}"
|
| 85 |
|
| 86 |
|
| 87 |
def _layers_for_variant(
|
|
@@ -93,14 +90,14 @@ def _layers_for_variant(
|
|
| 93 |
if isinstance(store, HFActivationStore):
|
| 94 |
if not persona_ids:
|
| 95 |
return []
|
| 96 |
-
return
|
| 97 |
store.repo_id,
|
| 98 |
store.model_name,
|
| 99 |
mask_strategy.value,
|
| 100 |
variant,
|
| 101 |
persona_ids[0],
|
| 102 |
)
|
| 103 |
-
return
|
| 104 |
str(store.root_dir),
|
| 105 |
store.model_name,
|
| 106 |
[variant],
|
|
@@ -109,59 +106,188 @@ def _layers_for_variant(
|
|
| 109 |
)
|
| 110 |
|
| 111 |
|
| 112 |
-
def
|
| 113 |
store: Store,
|
| 114 |
variants: list[str],
|
| 115 |
mask_strategy: MaskStrategy,
|
| 116 |
*,
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
remember_key: str,
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
-
last_personas: list[str] = st.session_state.get(remember_key, [])
|
| 134 |
-
default_personas = [p for p in last_personas if p in persona_options]
|
| 135 |
-
if not default_personas:
|
| 136 |
-
default_personas = persona_options if default_all else persona_options[:1]
|
| 137 |
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
"load",
|
| 140 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
widget_scope,
|
| 142 |
store.model_name,
|
| 143 |
mask_strategy.value,
|
| 144 |
*variants,
|
| 145 |
)
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
|
| 167 |
def _render_save_buttons(
|
|
@@ -179,35 +305,18 @@ def _render_save_buttons(
|
|
| 179 |
|
| 180 |
|
| 181 |
def _render_mask_strategy_select(scope: str) -> MaskStrategy:
|
| 182 |
-
|
| 183 |
-
_LAST_MASK_STRATEGY_KEY,
|
| 184 |
-
MaskStrategy.ANSWER_MEAN.value,
|
| 185 |
-
)
|
| 186 |
-
strategies = list(MaskStrategy)
|
| 187 |
-
selected = st.selectbox(
|
| 188 |
-
"Mask strategy",
|
| 189 |
-
options=strategies,
|
| 190 |
-
index=next(
|
| 191 |
-
(
|
| 192 |
-
idx
|
| 193 |
-
for idx, strategy in enumerate(strategies)
|
| 194 |
-
if strategy.value == last_strategy
|
| 195 |
-
),
|
| 196 |
-
0,
|
| 197 |
-
),
|
| 198 |
-
format_func=lambda strategy: strategy.value.replace("_", " ").title(),
|
| 199 |
key=widget_key("load", "mask_strategy", scope),
|
|
|
|
| 200 |
help="Which extracted activation set to load.",
|
| 201 |
)
|
| 202 |
-
st.session_state[_LAST_MASK_STRATEGY_KEY] = selected.value
|
| 203 |
-
return selected
|
| 204 |
|
| 205 |
|
| 206 |
def _render_cosine_selection(
|
| 207 |
store: Store,
|
| 208 |
mask_strategy: MaskStrategy,
|
| 209 |
) -> CosineSelection | None:
|
| 210 |
-
variants =
|
| 211 |
if len(variants) < 2:
|
| 212 |
st.info("Need at least two variants with saved vectors for cosine comparison.")
|
| 213 |
return None
|
|
@@ -220,7 +329,7 @@ def _render_cosine_selection(
|
|
| 220 |
options=variants,
|
| 221 |
index=0,
|
| 222 |
format_func=prompt_variant_label,
|
| 223 |
-
key=widget_key("load", "variant_a",
|
| 224 |
)
|
| 225 |
with col2:
|
| 226 |
variant_b = st.selectbox(
|
|
@@ -228,18 +337,18 @@ def _render_cosine_selection(
|
|
| 228 |
options=variants,
|
| 229 |
index=min(1, len(variants) - 1),
|
| 230 |
format_func=prompt_variant_label,
|
| 231 |
-
key=widget_key("load", "variant_b",
|
| 232 |
)
|
| 233 |
|
| 234 |
if variant_a == variant_b:
|
| 235 |
st.warning("Choose two different variants to compare.")
|
| 236 |
return None
|
| 237 |
|
| 238 |
-
persona_ids
|
| 239 |
store,
|
| 240 |
[variant_a, variant_b],
|
| 241 |
mask_strategy,
|
| 242 |
-
widget_scope=f"cosine:{
|
| 243 |
remember_key=_LAST_COSINE_PERSONAS_KEY,
|
| 244 |
)
|
| 245 |
if not persona_ids:
|
|
@@ -334,7 +443,7 @@ def _render_cosine_similarity(
|
|
| 334 |
cosine_fig_key = widget_key(
|
| 335 |
"load",
|
| 336 |
"cosine_fig_state",
|
| 337 |
-
|
| 338 |
store.model_name,
|
| 339 |
mask_strategy.value,
|
| 340 |
selection.variant_a,
|
|
@@ -363,7 +472,7 @@ def _render_cosine_similarity(
|
|
| 363 |
key=widget_key(
|
| 364 |
"load",
|
| 365 |
"compare_vectors",
|
| 366 |
-
|
| 367 |
store.model_name,
|
| 368 |
mask_strategy.value,
|
| 369 |
selection.variant_a,
|
|
@@ -398,7 +507,7 @@ def _select_single_variant_samples(
|
|
| 398 |
mask_strategy: MaskStrategy,
|
| 399 |
scope: str,
|
| 400 |
) -> tuple[str, list[str], str, list[int]] | None:
|
| 401 |
-
variants =
|
| 402 |
if not variants:
|
| 403 |
st.info("No variants with saved vectors for this model.")
|
| 404 |
return None
|
|
@@ -407,13 +516,13 @@ def _select_single_variant_samples(
|
|
| 407 |
options=variants,
|
| 408 |
index=variants.index("biography") if "biography" in variants else 0,
|
| 409 |
format_func=prompt_variant_label,
|
| 410 |
-
key=widget_key("load", "variant", scope,
|
| 411 |
)
|
| 412 |
-
persona_ids
|
| 413 |
store,
|
| 414 |
[variant],
|
| 415 |
mask_strategy,
|
| 416 |
-
widget_scope=f"{scope}:{
|
| 417 |
remember_key=_LAST_PROJECTION_PERSONAS_KEY,
|
| 418 |
default_all=True,
|
| 419 |
)
|
|
@@ -426,26 +535,8 @@ def _select_single_variant_samples(
|
|
| 426 |
st.info("No shared layers are available for the selected personas.")
|
| 427 |
return None
|
| 428 |
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
options=layer_options,
|
| 432 |
-
default=layer_options,
|
| 433 |
-
key=widget_key(
|
| 434 |
-
"load",
|
| 435 |
-
"layers",
|
| 436 |
-
scope,
|
| 437 |
-
_store_id(store),
|
| 438 |
-
store.model_name,
|
| 439 |
-
mask_strategy.value,
|
| 440 |
-
variant,
|
| 441 |
-
persona_key,
|
| 442 |
-
),
|
| 443 |
-
)
|
| 444 |
-
if not selected_layers:
|
| 445 |
-
st.info("Select at least one layer.")
|
| 446 |
-
return None
|
| 447 |
-
|
| 448 |
-
return variant, persona_ids, persona_key, selected_layers
|
| 449 |
|
| 450 |
|
| 451 |
def _render_layered_figure_analysis(
|
|
@@ -472,7 +563,7 @@ def _render_layered_figure_analysis(
|
|
| 472 |
fig_key = widget_key(
|
| 473 |
"load",
|
| 474 |
f"{scope}_fig_state",
|
| 475 |
-
|
| 476 |
store.model_name,
|
| 477 |
mask_strategy.value,
|
| 478 |
figure_kind,
|
|
@@ -481,7 +572,7 @@ def _render_layered_figure_analysis(
|
|
| 481 |
"persona_vector",
|
| 482 |
persona_key,
|
| 483 |
)
|
| 484 |
-
filename = scope
|
| 485 |
|
| 486 |
if st.button(button_label, type="primary"):
|
| 487 |
try:
|
|
@@ -549,7 +640,105 @@ def _render_source_select() -> str:
|
|
| 549 |
return source
|
| 550 |
|
| 551 |
|
| 552 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
if source == SOURCE_HUB:
|
| 554 |
repo = st.text_input(
|
| 555 |
"Hub repo",
|
|
@@ -557,16 +746,29 @@ def _build_store(source: str, model_name: str, mask_strategy: MaskStrategy) -> S
|
|
| 557 |
key="compare:hub_repo",
|
| 558 |
help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
|
| 559 |
)
|
| 560 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 561 |
artifacts_root = st.text_input(
|
| 562 |
"Artifacts root",
|
| 563 |
value=str(get_artifacts_dir() / "activations"),
|
| 564 |
key="compare:artifacts_root",
|
| 565 |
)
|
| 566 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 567 |
|
| 568 |
|
| 569 |
-
def render_compare_tab(
|
| 570 |
"""Render the compare tab."""
|
| 571 |
|
| 572 |
st.title("Compare")
|
|
@@ -585,9 +787,9 @@ def render_compare_tab(model_name: str) -> None:
|
|
| 585 |
analysis_mode = ANALYSIS_MODES[0]
|
| 586 |
st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
|
| 587 |
|
| 588 |
-
with st.expander("Source settings", expanded=
|
| 589 |
mask_strategy = _render_mask_strategy_select(analysis_mode)
|
| 590 |
-
store = _build_store(source,
|
| 591 |
|
| 592 |
if analysis_mode == "Cosine similarity":
|
| 593 |
_render_cosine_similarity(store, mask_strategy)
|
|
|
|
|
|
|
| 1 |
from collections.abc import Callable
|
| 2 |
from dataclasses import dataclass
|
| 3 |
from itertools import combinations
|
| 4 |
+
from pathlib import Path
|
| 5 |
|
| 6 |
import streamlit as st
|
| 7 |
from persona_data.environment import get_artifacts_dir
|
| 8 |
+
from persona_data.synth_persona import BASELINE_PERSONA_ID
|
| 9 |
from persona_vectors.analysis import load_persona_vectors, load_variant_vectors
|
| 10 |
+
from persona_vectors.artifacts import HFActivationStore
|
|
|
|
| 11 |
from persona_vectors.extraction import MaskStrategy
|
| 12 |
from persona_vectors.plots import (
|
| 13 |
build_layered_figure,
|
|
|
|
| 16 |
save_plot_html,
|
| 17 |
)
|
| 18 |
|
| 19 |
+
from utils.compare_sources import (
|
| 20 |
+
DEFAULT_COMPARE_MODEL,
|
| 21 |
+
DEFAULT_HUB_REPO,
|
| 22 |
+
SOURCE_HUB,
|
| 23 |
+
SOURCE_LOCAL,
|
| 24 |
+
SOURCES,
|
| 25 |
+
Store,
|
| 26 |
+
activation_store_cached,
|
| 27 |
+
available_variants,
|
| 28 |
+
hub_layers_cached,
|
| 29 |
+
hub_models_by_mask_strategy,
|
| 30 |
+
list_layers_cached,
|
| 31 |
+
local_model_matches,
|
| 32 |
+
local_model_options_cached,
|
| 33 |
+
persona_names_cached,
|
| 34 |
+
personas_cached,
|
| 35 |
+
store_cache_parts,
|
| 36 |
+
store_id,
|
| 37 |
+
)
|
| 38 |
+
from utils.controls import render_mask_strategy_select
|
| 39 |
from utils.helpers import (
|
| 40 |
ANALYSIS_HELP_TEXT,
|
| 41 |
ANALYSIS_MODES,
|
|
|
|
| 42 |
prompt_variant_label,
|
| 43 |
slugify,
|
| 44 |
widget_key,
|
| 45 |
)
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
def _filename(*parts: str) -> str:
|
| 49 |
return "__".join(slugify(part) for part in parts if part)
|
| 50 |
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
# Keep compare-tab selection state separate so projection defaults do not
|
| 53 |
# overwrite cosine similarity defaults.
|
| 54 |
_LAST_COSINE_PERSONAS_KEY = "compare:last_personas:cosine"
|
|
|
|
| 57 |
_LAST_SOURCE_KEY = "compare:last_source"
|
| 58 |
|
| 59 |
|
| 60 |
+
def _is_assistant_persona(persona_id: str, persona_name: str | None = None) -> bool:
|
| 61 |
+
persona_id_normalized = persona_id.strip().lower()
|
| 62 |
+
persona_name_normalized = (persona_name or "").strip().lower()
|
| 63 |
+
return (
|
| 64 |
+
persona_id_normalized in {"assistant", BASELINE_PERSONA_ID.lower()}
|
| 65 |
+
or persona_name_normalized == "assistant"
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
@dataclass(frozen=True)
|
| 70 |
class CosineSelection:
|
| 71 |
variants: list[str]
|
|
|
|
| 75 |
persona_key: str
|
| 76 |
|
| 77 |
|
| 78 |
+
@dataclass(frozen=True)
|
| 79 |
+
class PersonaOptions:
|
| 80 |
+
regular_ids: list[str]
|
| 81 |
+
assistant_id: str | None
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
def _layers_for_variant(
|
|
|
|
| 90 |
if isinstance(store, HFActivationStore):
|
| 91 |
if not persona_ids:
|
| 92 |
return []
|
| 93 |
+
return hub_layers_cached(
|
| 94 |
store.repo_id,
|
| 95 |
store.model_name,
|
| 96 |
mask_strategy.value,
|
| 97 |
variant,
|
| 98 |
persona_ids[0],
|
| 99 |
)
|
| 100 |
+
return list_layers_cached(
|
| 101 |
str(store.root_dir),
|
| 102 |
store.model_name,
|
| 103 |
[variant],
|
|
|
|
| 106 |
)
|
| 107 |
|
| 108 |
|
| 109 |
+
def _load_persona_options(
|
| 110 |
store: Store,
|
| 111 |
variants: list[str],
|
| 112 |
mask_strategy: MaskStrategy,
|
| 113 |
*,
|
| 114 |
+
empty_message: str,
|
| 115 |
+
) -> PersonaOptions | None:
|
| 116 |
+
source, location, model_name = store_cache_parts(store)
|
| 117 |
+
variant_key = tuple(variants)
|
| 118 |
+
persona_ids = personas_cached(
|
| 119 |
+
source,
|
| 120 |
+
location,
|
| 121 |
+
model_name,
|
| 122 |
+
mask_strategy.value,
|
| 123 |
+
variant_key,
|
| 124 |
+
)
|
| 125 |
+
if not persona_ids:
|
| 126 |
+
st.info(empty_message)
|
| 127 |
+
return None
|
| 128 |
+
|
| 129 |
+
persona_names = persona_names_cached(
|
| 130 |
+
source,
|
| 131 |
+
location,
|
| 132 |
+
model_name,
|
| 133 |
+
mask_strategy.value,
|
| 134 |
+
variant_key,
|
| 135 |
+
tuple(persona_ids),
|
| 136 |
+
)
|
| 137 |
+
assistant_ids = [
|
| 138 |
+
persona_id
|
| 139 |
+
for persona_id in persona_ids
|
| 140 |
+
if _is_assistant_persona(persona_id, persona_names.get(persona_id))
|
| 141 |
+
]
|
| 142 |
+
assistant_id = next(
|
| 143 |
+
(
|
| 144 |
+
persona_id
|
| 145 |
+
for persona_id in assistant_ids
|
| 146 |
+
if persona_id == BASELINE_PERSONA_ID
|
| 147 |
+
),
|
| 148 |
+
assistant_ids[0] if assistant_ids else None,
|
| 149 |
+
)
|
| 150 |
+
regular_ids = [persona_id for persona_id in persona_ids if persona_id not in assistant_ids]
|
| 151 |
+
if not regular_ids and assistant_id is None:
|
| 152 |
+
st.info("No personas found for this model and variant.")
|
| 153 |
+
return None
|
| 154 |
+
return PersonaOptions(regular_ids=regular_ids, assistant_id=assistant_id)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def _seed_persona_memory(
|
| 158 |
remember_key: str,
|
| 159 |
+
options: PersonaOptions,
|
| 160 |
+
*,
|
| 161 |
+
default_all: bool,
|
| 162 |
+
) -> tuple[int, bool]:
|
| 163 |
+
remembered_count_key = f"{remember_key}:count"
|
| 164 |
+
remembered_assistant_key = f"{remember_key}:include_assistant"
|
| 165 |
+
legacy_ids = st.session_state.get(remember_key, [])
|
| 166 |
+
if isinstance(legacy_ids, list) and legacy_ids:
|
| 167 |
+
st.session_state.setdefault(
|
| 168 |
+
remembered_count_key,
|
| 169 |
+
sum(persona_id in options.regular_ids for persona_id in legacy_ids),
|
| 170 |
+
)
|
| 171 |
+
st.session_state.setdefault(
|
| 172 |
+
remembered_assistant_key,
|
| 173 |
+
options.assistant_id in legacy_ids,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
default_count = len(options.regular_ids) if default_all else min(1, len(options.regular_ids))
|
| 177 |
+
remembered_count = int(st.session_state.get(remembered_count_key, default_count))
|
| 178 |
+
persona_count = min(max(remembered_count, 0), len(options.regular_ids))
|
| 179 |
+
include_assistant = bool(
|
| 180 |
+
st.session_state.get(remembered_assistant_key, options.assistant_id is not None)
|
| 181 |
+
)
|
| 182 |
+
return persona_count, include_assistant
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
|
| 185 |
+
def _render_persona_count_controls(
|
| 186 |
+
store: Store,
|
| 187 |
+
variants: list[str],
|
| 188 |
+
mask_strategy: MaskStrategy,
|
| 189 |
+
widget_scope: str,
|
| 190 |
+
options: PersonaOptions,
|
| 191 |
+
*,
|
| 192 |
+
default_count: int,
|
| 193 |
+
include_assistant_default: bool,
|
| 194 |
+
) -> tuple[int, bool]:
|
| 195 |
+
count_key = widget_key(
|
| 196 |
"load",
|
| 197 |
+
"persona_count",
|
| 198 |
+
widget_scope,
|
| 199 |
+
store.model_name,
|
| 200 |
+
mask_strategy.value,
|
| 201 |
+
*variants,
|
| 202 |
+
)
|
| 203 |
+
assistant_key = widget_key(
|
| 204 |
+
"load",
|
| 205 |
+
"include_assistant",
|
| 206 |
widget_scope,
|
| 207 |
store.model_name,
|
| 208 |
mask_strategy.value,
|
| 209 |
*variants,
|
| 210 |
)
|
| 211 |
|
| 212 |
+
if options.regular_ids:
|
| 213 |
+
persona_count = st.slider(
|
| 214 |
+
"Personas",
|
| 215 |
+
min_value=0 if options.assistant_id is not None else 1,
|
| 216 |
+
max_value=len(options.regular_ids),
|
| 217 |
+
value=default_count,
|
| 218 |
+
key=count_key,
|
| 219 |
+
help="Use the first N available non-assistant personas.",
|
| 220 |
+
)
|
| 221 |
+
else:
|
| 222 |
+
persona_count = 0
|
| 223 |
+
st.caption("No non-assistant personas are available for this selection.")
|
| 224 |
+
include_assistant = False
|
| 225 |
+
if options.assistant_id is not None:
|
| 226 |
+
include_assistant = st.checkbox(
|
| 227 |
+
"Include Assistant persona",
|
| 228 |
+
value=include_assistant_default,
|
| 229 |
+
key=assistant_key,
|
| 230 |
+
)
|
| 231 |
+
return persona_count, include_assistant
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def _select_artifact_personas(
|
| 235 |
+
store: Store,
|
| 236 |
+
variants: list[str],
|
| 237 |
+
mask_strategy: MaskStrategy,
|
| 238 |
+
*,
|
| 239 |
+
widget_scope: str,
|
| 240 |
+
remember_key: str,
|
| 241 |
+
default_all: bool = False,
|
| 242 |
+
) -> list[str]:
|
| 243 |
+
empty_message = (
|
| 244 |
+
"No personas have vectors for all selected variants. "
|
| 245 |
+
"Pick a single variant or change the source."
|
| 246 |
+
if len(variants) > 1
|
| 247 |
+
else "No personas found for this model and variant."
|
| 248 |
)
|
| 249 |
+
options = _load_persona_options(
|
| 250 |
+
store,
|
| 251 |
+
variants,
|
| 252 |
+
mask_strategy,
|
| 253 |
+
empty_message=empty_message,
|
| 254 |
+
)
|
| 255 |
+
if options is None:
|
| 256 |
+
return []
|
| 257 |
+
|
| 258 |
+
default_count, include_assistant_default = _seed_persona_memory(
|
| 259 |
+
remember_key,
|
| 260 |
+
options,
|
| 261 |
+
default_all=default_all,
|
| 262 |
+
)
|
| 263 |
+
persona_count, include_assistant = _render_persona_count_controls(
|
| 264 |
+
store,
|
| 265 |
+
variants,
|
| 266 |
+
mask_strategy,
|
| 267 |
+
widget_scope,
|
| 268 |
+
options,
|
| 269 |
+
default_count=default_count,
|
| 270 |
+
include_assistant_default=include_assistant_default,
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
persona_ids = options.regular_ids[:persona_count]
|
| 274 |
+
if include_assistant and options.assistant_id is not None:
|
| 275 |
+
persona_ids.append(options.assistant_id)
|
| 276 |
+
|
| 277 |
+
remembered_count_key = f"{remember_key}:count"
|
| 278 |
+
remembered_assistant_key = f"{remember_key}:include_assistant"
|
| 279 |
+
st.session_state[remembered_count_key] = persona_count
|
| 280 |
+
st.session_state[remembered_assistant_key] = include_assistant
|
| 281 |
+
st.session_state[remember_key] = persona_ids
|
| 282 |
+
|
| 283 |
+
if not persona_ids:
|
| 284 |
+
st.info("Select at least one persona or include the Assistant persona.")
|
| 285 |
+
return []
|
| 286 |
+
|
| 287 |
+
regular_label = f"{persona_count} persona{'s' if persona_count != 1 else ''}"
|
| 288 |
+
assistant_label = " plus Assistant" if include_assistant and options.assistant_id else ""
|
| 289 |
+
st.caption(f"Using {regular_label}{assistant_label}.")
|
| 290 |
+
return persona_ids
|
| 291 |
|
| 292 |
|
| 293 |
def _render_save_buttons(
|
|
|
|
| 305 |
|
| 306 |
|
| 307 |
def _render_mask_strategy_select(scope: str) -> MaskStrategy:
|
| 308 |
+
return render_mask_strategy_select(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
key=widget_key("load", "mask_strategy", scope),
|
| 310 |
+
last_key=_LAST_MASK_STRATEGY_KEY,
|
| 311 |
help="Which extracted activation set to load.",
|
| 312 |
)
|
|
|
|
|
|
|
| 313 |
|
| 314 |
|
| 315 |
def _render_cosine_selection(
|
| 316 |
store: Store,
|
| 317 |
mask_strategy: MaskStrategy,
|
| 318 |
) -> CosineSelection | None:
|
| 319 |
+
variants = available_variants(store, mask_strategy)
|
| 320 |
if len(variants) < 2:
|
| 321 |
st.info("Need at least two variants with saved vectors for cosine comparison.")
|
| 322 |
return None
|
|
|
|
| 329 |
options=variants,
|
| 330 |
index=0,
|
| 331 |
format_func=prompt_variant_label,
|
| 332 |
+
key=widget_key("load", "variant_a", store_id(store)),
|
| 333 |
)
|
| 334 |
with col2:
|
| 335 |
variant_b = st.selectbox(
|
|
|
|
| 337 |
options=variants,
|
| 338 |
index=min(1, len(variants) - 1),
|
| 339 |
format_func=prompt_variant_label,
|
| 340 |
+
key=widget_key("load", "variant_b", store_id(store)),
|
| 341 |
)
|
| 342 |
|
| 343 |
if variant_a == variant_b:
|
| 344 |
st.warning("Choose two different variants to compare.")
|
| 345 |
return None
|
| 346 |
|
| 347 |
+
persona_ids = _select_artifact_personas(
|
| 348 |
store,
|
| 349 |
[variant_a, variant_b],
|
| 350 |
mask_strategy,
|
| 351 |
+
widget_scope=f"cosine:{store_id(store)}",
|
| 352 |
remember_key=_LAST_COSINE_PERSONAS_KEY,
|
| 353 |
)
|
| 354 |
if not persona_ids:
|
|
|
|
| 443 |
cosine_fig_key = widget_key(
|
| 444 |
"load",
|
| 445 |
"cosine_fig_state",
|
| 446 |
+
store_id(store),
|
| 447 |
store.model_name,
|
| 448 |
mask_strategy.value,
|
| 449 |
selection.variant_a,
|
|
|
|
| 472 |
key=widget_key(
|
| 473 |
"load",
|
| 474 |
"compare_vectors",
|
| 475 |
+
store_id(store),
|
| 476 |
store.model_name,
|
| 477 |
mask_strategy.value,
|
| 478 |
selection.variant_a,
|
|
|
|
| 507 |
mask_strategy: MaskStrategy,
|
| 508 |
scope: str,
|
| 509 |
) -> tuple[str, list[str], str, list[int]] | None:
|
| 510 |
+
variants = available_variants(store, mask_strategy)
|
| 511 |
if not variants:
|
| 512 |
st.info("No variants with saved vectors for this model.")
|
| 513 |
return None
|
|
|
|
| 516 |
options=variants,
|
| 517 |
index=variants.index("biography") if "biography" in variants else 0,
|
| 518 |
format_func=prompt_variant_label,
|
| 519 |
+
key=widget_key("load", "variant", scope, store_id(store)),
|
| 520 |
)
|
| 521 |
+
persona_ids = _select_artifact_personas(
|
| 522 |
store,
|
| 523 |
[variant],
|
| 524 |
mask_strategy,
|
| 525 |
+
widget_scope=f"{scope}:{store_id(store)}",
|
| 526 |
remember_key=_LAST_PROJECTION_PERSONAS_KEY,
|
| 527 |
default_all=True,
|
| 528 |
)
|
|
|
|
| 535 |
st.info("No shared layers are available for the selected personas.")
|
| 536 |
return None
|
| 537 |
|
| 538 |
+
st.caption(f"Using all {len(layer_options)} available layer(s).")
|
| 539 |
+
return variant, persona_ids, persona_key, layer_options
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 540 |
|
| 541 |
|
| 542 |
def _render_layered_figure_analysis(
|
|
|
|
| 563 |
fig_key = widget_key(
|
| 564 |
"load",
|
| 565 |
f"{scope}_fig_state",
|
| 566 |
+
store_id(store),
|
| 567 |
store.model_name,
|
| 568 |
mask_strategy.value,
|
| 569 |
figure_kind,
|
|
|
|
| 572 |
"persona_vector",
|
| 573 |
persona_key,
|
| 574 |
)
|
| 575 |
+
filename = scope
|
| 576 |
|
| 577 |
if st.button(button_label, type="primary"):
|
| 578 |
try:
|
|
|
|
| 640 |
return source
|
| 641 |
|
| 642 |
|
| 643 |
+
def _render_hub_model_select(
|
| 644 |
+
repo_id: str,
|
| 645 |
+
mask_strategy: MaskStrategy,
|
| 646 |
+
) -> str:
|
| 647 |
+
fallback_model = st.session_state.get(
|
| 648 |
+
"compare:hub_model_fallback",
|
| 649 |
+
DEFAULT_COMPARE_MODEL,
|
| 650 |
+
)
|
| 651 |
+
try:
|
| 652 |
+
models_by_strategy = hub_models_by_mask_strategy(repo_id)
|
| 653 |
+
except Exception as exc:
|
| 654 |
+
st.warning(f"Could not load Hub configs for `{repo_id}`: {exc}")
|
| 655 |
+
return st.text_input(
|
| 656 |
+
"Hub model",
|
| 657 |
+
value=fallback_model,
|
| 658 |
+
key="compare:hub_model_fallback",
|
| 659 |
+
help="Compare-only model id to use if Hub config discovery is unavailable.",
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
model_options = models_by_strategy.get(mask_strategy, [])
|
| 663 |
+
if not model_options:
|
| 664 |
+
st.warning(
|
| 665 |
+
f"No Hub vector configs found for `{mask_strategy.value}` in `{repo_id}`."
|
| 666 |
+
)
|
| 667 |
+
return st.text_input(
|
| 668 |
+
"Hub model",
|
| 669 |
+
value=fallback_model,
|
| 670 |
+
key="compare:hub_model_fallback",
|
| 671 |
+
help="Compare-only model id to use for this Hub repo.",
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
previous_model = st.session_state.get(
|
| 675 |
+
widget_key("load", "hub_model", repo_id, mask_strategy.value),
|
| 676 |
+
fallback_model,
|
| 677 |
+
)
|
| 678 |
+
default_model = (
|
| 679 |
+
previous_model if previous_model in model_options else model_options[0]
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
return st.selectbox(
|
| 683 |
+
"Hub model",
|
| 684 |
+
options=model_options,
|
| 685 |
+
index=model_options.index(default_model),
|
| 686 |
+
key=widget_key("load", "hub_model", repo_id, mask_strategy.value),
|
| 687 |
+
help="Models with vectors in the selected Hub repo and mask strategy.",
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def _render_local_model_select(
|
| 692 |
+
artifacts_root: str,
|
| 693 |
+
mask_strategy: MaskStrategy,
|
| 694 |
+
) -> str:
|
| 695 |
+
fallback_model = st.session_state.get("compare:local_model", DEFAULT_COMPARE_MODEL)
|
| 696 |
+
model_options = local_model_options_cached(artifacts_root, mask_strategy.value)
|
| 697 |
+
if not model_options:
|
| 698 |
+
return st.text_input(
|
| 699 |
+
"Local model",
|
| 700 |
+
value=fallback_model,
|
| 701 |
+
key="compare:local_model",
|
| 702 |
+
help="Compare-only local model id or path.",
|
| 703 |
+
)
|
| 704 |
+
|
| 705 |
+
custom = st.toggle(
|
| 706 |
+
"Custom local model",
|
| 707 |
+
value=False,
|
| 708 |
+
key="compare:local_model_custom_enabled",
|
| 709 |
+
help="Enter a model id/path manually instead of choosing from activation directories.",
|
| 710 |
+
)
|
| 711 |
+
if custom:
|
| 712 |
+
return st.text_input(
|
| 713 |
+
"Local model",
|
| 714 |
+
value=fallback_model,
|
| 715 |
+
key="compare:local_model",
|
| 716 |
+
help="Compare-only local model id or path.",
|
| 717 |
+
)
|
| 718 |
+
|
| 719 |
+
previous_model = st.session_state.get("compare:local_model_select", fallback_model)
|
| 720 |
+
if not any(local_model_matches(previous_model, option) for option in model_options):
|
| 721 |
+
previous_model = fallback_model
|
| 722 |
+
default_model = next(
|
| 723 |
+
(
|
| 724 |
+
option
|
| 725 |
+
for option in model_options
|
| 726 |
+
if local_model_matches(option, previous_model)
|
| 727 |
+
),
|
| 728 |
+
model_options[0],
|
| 729 |
+
)
|
| 730 |
+
selected = st.selectbox(
|
| 731 |
+
"Local model",
|
| 732 |
+
options=model_options,
|
| 733 |
+
index=model_options.index(default_model),
|
| 734 |
+
key="compare:local_model_select",
|
| 735 |
+
help="Models discovered under the selected artifacts root.",
|
| 736 |
+
)
|
| 737 |
+
st.session_state["compare:local_model"] = selected
|
| 738 |
+
return selected
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def _build_store(source: str, mask_strategy: MaskStrategy) -> Store:
|
| 742 |
if source == SOURCE_HUB:
|
| 743 |
repo = st.text_input(
|
| 744 |
"Hub repo",
|
|
|
|
| 746 |
key="compare:hub_repo",
|
| 747 |
help="Hugging Face dataset published by `scripts/push_to_hf.py`.",
|
| 748 |
)
|
| 749 |
+
hub_model_name = _render_hub_model_select(repo, mask_strategy)
|
| 750 |
+
return activation_store_cached(
|
| 751 |
+
SOURCE_HUB,
|
| 752 |
+
repo,
|
| 753 |
+
hub_model_name,
|
| 754 |
+
mask_strategy.value,
|
| 755 |
+
)
|
| 756 |
artifacts_root = st.text_input(
|
| 757 |
"Artifacts root",
|
| 758 |
value=str(get_artifacts_dir() / "activations"),
|
| 759 |
key="compare:artifacts_root",
|
| 760 |
)
|
| 761 |
+
artifacts_root = str(Path(artifacts_root).expanduser())
|
| 762 |
+
local_model_name = _render_local_model_select(artifacts_root, mask_strategy)
|
| 763 |
+
return activation_store_cached(
|
| 764 |
+
SOURCE_LOCAL,
|
| 765 |
+
artifacts_root,
|
| 766 |
+
local_model_name,
|
| 767 |
+
mask_strategy.value,
|
| 768 |
+
)
|
| 769 |
|
| 770 |
|
| 771 |
+
def render_compare_tab() -> None:
|
| 772 |
"""Render the compare tab."""
|
| 773 |
|
| 774 |
st.title("Compare")
|
|
|
|
| 787 |
analysis_mode = ANALYSIS_MODES[0]
|
| 788 |
st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
|
| 789 |
|
| 790 |
+
with st.expander("Source settings", expanded=True):
|
| 791 |
mask_strategy = _render_mask_strategy_select(analysis_mode)
|
| 792 |
+
store = _build_store(source, mask_strategy)
|
| 793 |
|
| 794 |
if analysis_mode == "Cosine similarity":
|
| 795 |
_render_cosine_similarity(store, mask_strategy)
|
tabs/extract.py
CHANGED
|
@@ -13,6 +13,7 @@ from persona_vectors.extraction import (
|
|
| 13 |
from persona_vectors.preview import TokenSegment, preview_token_segments
|
| 14 |
|
| 15 |
from utils.datasets import load_dataset, load_persona_list
|
|
|
|
| 16 |
from utils.helpers import (
|
| 17 |
NDIF_STATUS_ICONS,
|
| 18 |
persona_label,
|
|
@@ -211,28 +212,11 @@ def _render_mask_strategy_select(
|
|
| 211 |
remote: bool,
|
| 212 |
dataset_source: str,
|
| 213 |
) -> MaskStrategy:
|
| 214 |
-
|
| 215 |
-
_LAST_MASK_STRATEGY_KEY,
|
| 216 |
-
MaskStrategy.ANSWER_MEAN.value,
|
| 217 |
-
)
|
| 218 |
-
strategy_options = list(MaskStrategy)
|
| 219 |
-
mask_strategy = st.selectbox(
|
| 220 |
-
"Mask strategy",
|
| 221 |
-
options=strategy_options,
|
| 222 |
-
index=next(
|
| 223 |
-
(
|
| 224 |
-
idx
|
| 225 |
-
for idx, strategy in enumerate(strategy_options)
|
| 226 |
-
if strategy.value == last_strategy
|
| 227 |
-
),
|
| 228 |
-
0,
|
| 229 |
-
),
|
| 230 |
-
format_func=lambda s: s.value.replace("_", " ").title(),
|
| 231 |
key=_extract_widget_key(model_name, remote, dataset_source, "mask_strategy"),
|
|
|
|
| 232 |
help="Which tokens contribute to the averaged hidden state.",
|
| 233 |
)
|
| 234 |
-
st.session_state[_LAST_MASK_STRATEGY_KEY] = mask_strategy.value
|
| 235 |
-
return mask_strategy
|
| 236 |
|
| 237 |
|
| 238 |
def _collect_runs(
|
|
|
|
| 13 |
from persona_vectors.preview import TokenSegment, preview_token_segments
|
| 14 |
|
| 15 |
from utils.datasets import load_dataset, load_persona_list
|
| 16 |
+
from utils.controls import render_mask_strategy_select
|
| 17 |
from utils.helpers import (
|
| 18 |
NDIF_STATUS_ICONS,
|
| 19 |
persona_label,
|
|
|
|
| 212 |
remote: bool,
|
| 213 |
dataset_source: str,
|
| 214 |
) -> MaskStrategy:
|
| 215 |
+
return render_mask_strategy_select(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
key=_extract_widget_key(model_name, remote, dataset_source, "mask_strategy"),
|
| 217 |
+
last_key=_LAST_MASK_STRATEGY_KEY,
|
| 218 |
help="Which tokens contribute to the averaged hidden state.",
|
| 219 |
)
|
|
|
|
|
|
|
| 220 |
|
| 221 |
|
| 222 |
def _collect_runs(
|
utils/compare_sources.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import streamlit as st
|
| 5 |
+
from persona_vectors.artifacts import ActivationStore, HFActivationStore
|
| 6 |
+
from persona_vectors.artifacts import list_layers as list_local_layers
|
| 7 |
+
from persona_vectors.artifacts import model_dir_name
|
| 8 |
+
from persona_vectors.extraction import MaskStrategy
|
| 9 |
+
|
| 10 |
+
Store = ActivationStore | HFActivationStore
|
| 11 |
+
|
| 12 |
+
DEFAULT_HUB_REPO = os.environ.get(
|
| 13 |
+
"PERSONA_VECTORS_HUB_REPO",
|
| 14 |
+
"implicit-personalization/synth-persona-vectors",
|
| 15 |
+
)
|
| 16 |
+
DEFAULT_COMPARE_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
|
| 17 |
+
SOURCE_HUB = "Hugging Face Hub"
|
| 18 |
+
SOURCE_LOCAL = "Local activations"
|
| 19 |
+
SOURCES = (SOURCE_HUB, SOURCE_LOCAL)
|
| 20 |
+
|
| 21 |
+
list_layers_cached = st.cache_data(show_spinner=False)(list_local_layers)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@st.cache_resource(show_spinner=False)
|
| 25 |
+
def activation_store_cached(
|
| 26 |
+
source: str,
|
| 27 |
+
location: str,
|
| 28 |
+
model_name: str,
|
| 29 |
+
mask_strategy_value: str,
|
| 30 |
+
) -> Store:
|
| 31 |
+
mask_strategy = MaskStrategy(mask_strategy_value)
|
| 32 |
+
if source == SOURCE_HUB:
|
| 33 |
+
return HFActivationStore(location, model_name, mask_strategy=mask_strategy)
|
| 34 |
+
return ActivationStore(model_name, location, mask_strategy=mask_strategy)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@st.cache_data(show_spinner=False, ttl=10)
|
| 38 |
+
def available_variants_cached(
|
| 39 |
+
source: str,
|
| 40 |
+
location: str,
|
| 41 |
+
model_name: str,
|
| 42 |
+
mask_strategy_value: str,
|
| 43 |
+
) -> list[str]:
|
| 44 |
+
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 45 |
+
return store.available_variants()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@st.cache_data(show_spinner=False, ttl=10)
|
| 49 |
+
def personas_cached(
|
| 50 |
+
source: str,
|
| 51 |
+
location: str,
|
| 52 |
+
model_name: str,
|
| 53 |
+
mask_strategy_value: str,
|
| 54 |
+
variants: tuple[str, ...],
|
| 55 |
+
) -> list[str]:
|
| 56 |
+
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 57 |
+
return store.list_personas(
|
| 58 |
+
list(variants),
|
| 59 |
+
mask_strategy=MaskStrategy(mask_strategy_value),
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@st.cache_data(show_spinner=False, ttl=10)
|
| 64 |
+
def persona_names_cached(
|
| 65 |
+
source: str,
|
| 66 |
+
location: str,
|
| 67 |
+
model_name: str,
|
| 68 |
+
mask_strategy_value: str,
|
| 69 |
+
variants: tuple[str, ...],
|
| 70 |
+
persona_ids: tuple[str, ...],
|
| 71 |
+
) -> dict[str, str]:
|
| 72 |
+
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 73 |
+
return store.persona_names(
|
| 74 |
+
list(persona_ids),
|
| 75 |
+
variants=list(variants),
|
| 76 |
+
mask_strategy=MaskStrategy(mask_strategy_value),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@st.cache_data(show_spinner=False, ttl=10)
|
| 81 |
+
def local_model_options_cached(
|
| 82 |
+
artifacts_root: str, mask_strategy_value: str
|
| 83 |
+
) -> list[str]:
|
| 84 |
+
root = Path(artifacts_root).expanduser()
|
| 85 |
+
if not root.exists() or not root.is_dir():
|
| 86 |
+
return []
|
| 87 |
+
|
| 88 |
+
options = []
|
| 89 |
+
try:
|
| 90 |
+
model_roots = sorted(path for path in root.iterdir() if path.is_dir())
|
| 91 |
+
except OSError:
|
| 92 |
+
return []
|
| 93 |
+
|
| 94 |
+
for model_root in model_roots:
|
| 95 |
+
strategy_root = model_root / mask_strategy_value
|
| 96 |
+
if not strategy_root.is_dir():
|
| 97 |
+
continue
|
| 98 |
+
variant_roots = (
|
| 99 |
+
variant_root
|
| 100 |
+
for variant_root in strategy_root.iterdir()
|
| 101 |
+
if variant_root.is_dir()
|
| 102 |
+
)
|
| 103 |
+
if any(
|
| 104 |
+
(variant_root / "manifest.json").exists() for variant_root in variant_roots
|
| 105 |
+
):
|
| 106 |
+
options.append(model_root.name.replace("__", "/"))
|
| 107 |
+
return options
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@st.cache_data(show_spinner=False)
|
| 111 |
+
def hub_config_names_cached(repo_id: str) -> list[str]:
|
| 112 |
+
try:
|
| 113 |
+
from huggingface_hub import get_dataset_config_names
|
| 114 |
+
except ImportError:
|
| 115 |
+
from datasets import get_dataset_config_names
|
| 116 |
+
|
| 117 |
+
return sorted(get_dataset_config_names(repo_id))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@st.cache_data(show_spinner=False)
|
| 121 |
+
def hub_layers_cached(
|
| 122 |
+
repo_id: str,
|
| 123 |
+
model_name: str,
|
| 124 |
+
mask_strategy_value: str,
|
| 125 |
+
variant: str,
|
| 126 |
+
persona_id: str,
|
| 127 |
+
) -> list[int]:
|
| 128 |
+
store = HFActivationStore(
|
| 129 |
+
repo_id,
|
| 130 |
+
model_name,
|
| 131 |
+
mask_strategy=MaskStrategy(mask_strategy_value),
|
| 132 |
+
)
|
| 133 |
+
sample = store.load(variant, persona_id)
|
| 134 |
+
return list(range(int(sample.shape[0])))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def parse_hub_config_name(config_name: str) -> tuple[str, MaskStrategy] | None:
|
| 138 |
+
for strategy in MaskStrategy:
|
| 139 |
+
suffix = f"__{strategy.value}"
|
| 140 |
+
if config_name.endswith(suffix):
|
| 141 |
+
model_key = config_name[: -len(suffix)]
|
| 142 |
+
return model_key.replace("__", "/"), strategy
|
| 143 |
+
return None
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def hub_models_by_mask_strategy(repo_id: str) -> dict[MaskStrategy, list[str]]:
|
| 147 |
+
models_by_strategy: dict[MaskStrategy, set[str]] = {
|
| 148 |
+
strategy: set() for strategy in MaskStrategy
|
| 149 |
+
}
|
| 150 |
+
for config_name in hub_config_names_cached(repo_id):
|
| 151 |
+
parsed = parse_hub_config_name(config_name)
|
| 152 |
+
if parsed is None:
|
| 153 |
+
continue
|
| 154 |
+
model_name, strategy = parsed
|
| 155 |
+
models_by_strategy[strategy].add(model_name)
|
| 156 |
+
return {
|
| 157 |
+
strategy: sorted(models)
|
| 158 |
+
for strategy, models in models_by_strategy.items()
|
| 159 |
+
if models
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def store_cache_parts(store: Store) -> tuple[str, str, str]:
|
| 164 |
+
if isinstance(store, HFActivationStore):
|
| 165 |
+
return SOURCE_HUB, store.repo_id, store.model_name
|
| 166 |
+
return SOURCE_LOCAL, str(store.root_dir), store.model_name
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def store_id(store: Store) -> str:
|
| 170 |
+
if isinstance(store, HFActivationStore):
|
| 171 |
+
return f"hub:{store.repo_id}"
|
| 172 |
+
return f"local:{store.root_dir}"
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def available_variants(store: Store, mask_strategy: MaskStrategy) -> list[str]:
|
| 176 |
+
source, location, model_name = store_cache_parts(store)
|
| 177 |
+
return available_variants_cached(
|
| 178 |
+
source,
|
| 179 |
+
location,
|
| 180 |
+
model_name,
|
| 181 |
+
mask_strategy.value,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def local_model_matches(left: str, right: str) -> bool:
|
| 186 |
+
return model_dir_name(left) == model_dir_name(right)
|
utils/controls.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from persona_vectors.extraction import MaskStrategy
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def render_mask_strategy_select(
|
| 6 |
+
*,
|
| 7 |
+
key: str,
|
| 8 |
+
last_key: str,
|
| 9 |
+
help_text: str,
|
| 10 |
+
) -> MaskStrategy:
|
| 11 |
+
last_strategy = st.session_state.get(last_key, MaskStrategy.ANSWER_MEAN.value)
|
| 12 |
+
strategies = list(MaskStrategy)
|
| 13 |
+
selected = st.selectbox(
|
| 14 |
+
"Mask strategy",
|
| 15 |
+
options=strategies,
|
| 16 |
+
index=next(
|
| 17 |
+
(
|
| 18 |
+
idx
|
| 19 |
+
for idx, strategy in enumerate(strategies)
|
| 20 |
+
if strategy.value == last_strategy
|
| 21 |
+
),
|
| 22 |
+
0,
|
| 23 |
+
),
|
| 24 |
+
format_func=lambda strategy: strategy.value.replace("_", " ").title(),
|
| 25 |
+
key=key,
|
| 26 |
+
help=help_text,
|
| 27 |
+
)
|
| 28 |
+
st.session_state[last_key] = selected.value
|
| 29 |
+
return selected
|
uv.lock
CHANGED
|
@@ -376,7 +376,7 @@ name = "cuda-bindings"
|
|
| 376 |
version = "13.2.0"
|
| 377 |
source = { registry = "https://pypi.org/simple" }
|
| 378 |
dependencies = [
|
| 379 |
-
{ name = "cuda-pathfinder" },
|
| 380 |
]
|
| 381 |
wheels = [
|
| 382 |
{ url = "https://files.pythonhosted.org/packages/52/c8/b2589d68acf7e3d63e2be330b84bc25712e97ed799affbca7edd7eae25d6/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e865447abfb83d6a98ad5130ed3c70b1fc295ae3eeee39fd07b4ddb0671b6788", size = 5722404, upload-time = "2026-03-11T00:12:44.041Z" },
|
|
@@ -407,37 +407,37 @@ wheels = [
|
|
| 407 |
|
| 408 |
[package.optional-dependencies]
|
| 409 |
cublas = [
|
| 410 |
-
{ name = "nvidia-cublas", marker = "sys_platform == 'linux'
|
| 411 |
]
|
| 412 |
cudart = [
|
| 413 |
-
{ name = "nvidia-cuda-runtime", marker = "sys_platform == 'linux'
|
| 414 |
]
|
| 415 |
cufft = [
|
| 416 |
-
{ name = "nvidia-cufft", marker = "sys_platform == 'linux'
|
| 417 |
]
|
| 418 |
cufile = [
|
| 419 |
{ name = "nvidia-cufile", marker = "sys_platform == 'linux'" },
|
| 420 |
]
|
| 421 |
cupti = [
|
| 422 |
-
{ name = "nvidia-cuda-cupti", marker = "sys_platform == 'linux'
|
| 423 |
]
|
| 424 |
curand = [
|
| 425 |
-
{ name = "nvidia-curand", marker = "sys_platform == 'linux'
|
| 426 |
]
|
| 427 |
cusolver = [
|
| 428 |
-
{ name = "nvidia-cusolver", marker = "sys_platform == 'linux'
|
| 429 |
]
|
| 430 |
cusparse = [
|
| 431 |
-
{ name = "nvidia-cusparse", marker = "sys_platform == 'linux'
|
| 432 |
]
|
| 433 |
nvjitlink = [
|
| 434 |
-
{ name = "nvidia-nvjitlink", marker = "sys_platform == 'linux'
|
| 435 |
]
|
| 436 |
nvrtc = [
|
| 437 |
-
{ name = "nvidia-cuda-nvrtc", marker = "sys_platform == 'linux'
|
| 438 |
]
|
| 439 |
nvtx = [
|
| 440 |
-
{ name = "nvidia-nvtx", marker = "sys_platform == 'linux'
|
| 441 |
]
|
| 442 |
|
| 443 |
[[package]]
|
|
@@ -1326,7 +1326,7 @@ name = "nvidia-cudnn-cu13"
|
|
| 1326 |
version = "9.19.0.56"
|
| 1327 |
source = { registry = "https://pypi.org/simple" }
|
| 1328 |
dependencies = [
|
| 1329 |
-
{ name = "nvidia-cublas" },
|
| 1330 |
]
|
| 1331 |
wheels = [
|
| 1332 |
{ url = "https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201, upload-time = "2026-02-03T20:40:53.805Z" },
|
|
@@ -1338,7 +1338,7 @@ name = "nvidia-cufft"
|
|
| 1338 |
version = "12.0.0.61"
|
| 1339 |
source = { registry = "https://pypi.org/simple" }
|
| 1340 |
dependencies = [
|
| 1341 |
-
{ name = "nvidia-nvjitlink" },
|
| 1342 |
]
|
| 1343 |
wheels = [
|
| 1344 |
{ url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" },
|
|
@@ -1368,9 +1368,9 @@ name = "nvidia-cusolver"
|
|
| 1368 |
version = "12.0.4.66"
|
| 1369 |
source = { registry = "https://pypi.org/simple" }
|
| 1370 |
dependencies = [
|
| 1371 |
-
{ name = "nvidia-cublas" },
|
| 1372 |
-
{ name = "nvidia-cusparse" },
|
| 1373 |
-
{ name = "nvidia-nvjitlink" },
|
| 1374 |
]
|
| 1375 |
wheels = [
|
| 1376 |
{ url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" },
|
|
@@ -1382,7 +1382,7 @@ name = "nvidia-cusparse"
|
|
| 1382 |
version = "12.6.3.3"
|
| 1383 |
source = { registry = "https://pypi.org/simple" }
|
| 1384 |
dependencies = [
|
| 1385 |
-
{ name = "nvidia-nvjitlink" },
|
| 1386 |
]
|
| 1387 |
wheels = [
|
| 1388 |
{ url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" },
|
|
@@ -1575,10 +1575,12 @@ wheels = [
|
|
| 1575 |
|
| 1576 |
[[package]]
|
| 1577 |
name = "persona-ui"
|
| 1578 |
-
version = "0.
|
| 1579 |
source = { virtual = "." }
|
| 1580 |
dependencies = [
|
| 1581 |
{ name = "catppuccin" },
|
|
|
|
|
|
|
| 1582 |
{ name = "persona-data" },
|
| 1583 |
{ name = "persona-vectors" },
|
| 1584 |
{ name = "plotly" },
|
|
@@ -1589,6 +1591,8 @@ dependencies = [
|
|
| 1589 |
[package.metadata]
|
| 1590 |
requires-dist = [
|
| 1591 |
{ name = "catppuccin", specifier = ">=2.5.0" },
|
|
|
|
|
|
|
| 1592 |
{ name = "persona-data", specifier = ">=0.4.2" },
|
| 1593 |
{ name = "persona-vectors", specifier = ">=0.6.4" },
|
| 1594 |
{ name = "plotly", specifier = ">=6.6.0" },
|
|
|
|
| 376 |
version = "13.2.0"
|
| 377 |
source = { registry = "https://pypi.org/simple" }
|
| 378 |
dependencies = [
|
| 379 |
+
{ name = "cuda-pathfinder", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
|
| 380 |
]
|
| 381 |
wheels = [
|
| 382 |
{ url = "https://files.pythonhosted.org/packages/52/c8/b2589d68acf7e3d63e2be330b84bc25712e97ed799affbca7edd7eae25d6/cuda_bindings-13.2.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e865447abfb83d6a98ad5130ed3c70b1fc295ae3eeee39fd07b4ddb0671b6788", size = 5722404, upload-time = "2026-03-11T00:12:44.041Z" },
|
|
|
|
| 407 |
|
| 408 |
[package.optional-dependencies]
|
| 409 |
cublas = [
|
| 410 |
+
{ name = "nvidia-cublas", marker = "sys_platform == 'linux'" },
|
| 411 |
]
|
| 412 |
cudart = [
|
| 413 |
+
{ name = "nvidia-cuda-runtime", marker = "sys_platform == 'linux'" },
|
| 414 |
]
|
| 415 |
cufft = [
|
| 416 |
+
{ name = "nvidia-cufft", marker = "sys_platform == 'linux'" },
|
| 417 |
]
|
| 418 |
cufile = [
|
| 419 |
{ name = "nvidia-cufile", marker = "sys_platform == 'linux'" },
|
| 420 |
]
|
| 421 |
cupti = [
|
| 422 |
+
{ name = "nvidia-cuda-cupti", marker = "sys_platform == 'linux'" },
|
| 423 |
]
|
| 424 |
curand = [
|
| 425 |
+
{ name = "nvidia-curand", marker = "sys_platform == 'linux'" },
|
| 426 |
]
|
| 427 |
cusolver = [
|
| 428 |
+
{ name = "nvidia-cusolver", marker = "sys_platform == 'linux'" },
|
| 429 |
]
|
| 430 |
cusparse = [
|
| 431 |
+
{ name = "nvidia-cusparse", marker = "sys_platform == 'linux'" },
|
| 432 |
]
|
| 433 |
nvjitlink = [
|
| 434 |
+
{ name = "nvidia-nvjitlink", marker = "sys_platform == 'linux'" },
|
| 435 |
]
|
| 436 |
nvrtc = [
|
| 437 |
+
{ name = "nvidia-cuda-nvrtc", marker = "sys_platform == 'linux'" },
|
| 438 |
]
|
| 439 |
nvtx = [
|
| 440 |
+
{ name = "nvidia-nvtx", marker = "sys_platform == 'linux'" },
|
| 441 |
]
|
| 442 |
|
| 443 |
[[package]]
|
|
|
|
| 1326 |
version = "9.19.0.56"
|
| 1327 |
source = { registry = "https://pypi.org/simple" }
|
| 1328 |
dependencies = [
|
| 1329 |
+
{ name = "nvidia-cublas", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
|
| 1330 |
]
|
| 1331 |
wheels = [
|
| 1332 |
{ url = "https://files.pythonhosted.org/packages/f1/84/26025437c1e6b61a707442184fa0c03d083b661adf3a3eecfd6d21677740/nvidia_cudnn_cu13-9.19.0.56-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:6ed29ffaee1176c612daf442e4dd6cfeb6a0caa43ddcbeb59da94953030b1be4", size = 433781201, upload-time = "2026-02-03T20:40:53.805Z" },
|
|
|
|
| 1338 |
version = "12.0.0.61"
|
| 1339 |
source = { registry = "https://pypi.org/simple" }
|
| 1340 |
dependencies = [
|
| 1341 |
+
{ name = "nvidia-nvjitlink", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
|
| 1342 |
]
|
| 1343 |
wheels = [
|
| 1344 |
{ url = "https://files.pythonhosted.org/packages/8b/ae/f417a75c0259e85c1d2f83ca4e960289a5f814ed0cea74d18c353d3e989d/nvidia_cufft-12.0.0.61-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2708c852ef8cd89d1d2068bdbece0aa188813a0c934db3779b9b1faa8442e5f5", size = 214053554, upload-time = "2025-09-04T08:31:38.196Z" },
|
|
|
|
| 1368 |
version = "12.0.4.66"
|
| 1369 |
source = { registry = "https://pypi.org/simple" }
|
| 1370 |
dependencies = [
|
| 1371 |
+
{ name = "nvidia-cublas", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
|
| 1372 |
+
{ name = "nvidia-cusparse", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
|
| 1373 |
+
{ name = "nvidia-nvjitlink", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
|
| 1374 |
]
|
| 1375 |
wheels = [
|
| 1376 |
{ url = "https://files.pythonhosted.org/packages/c8/c3/b30c9e935fc01e3da443ec0116ed1b2a009bb867f5324d3f2d7e533e776b/nvidia_cusolver-12.0.4.66-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:02c2457eaa9e39de20f880f4bd8820e6a1cfb9f9a34f820eb12a155aa5bc92d2", size = 223467760, upload-time = "2025-09-04T08:33:04.222Z" },
|
|
|
|
| 1382 |
version = "12.6.3.3"
|
| 1383 |
source = { registry = "https://pypi.org/simple" }
|
| 1384 |
dependencies = [
|
| 1385 |
+
{ name = "nvidia-nvjitlink", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" },
|
| 1386 |
]
|
| 1387 |
wheels = [
|
| 1388 |
{ url = "https://files.pythonhosted.org/packages/f8/94/5c26f33738ae35276672f12615a64bd008ed5be6d1ebcb23579285d960a9/nvidia_cusparse-12.6.3.3-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:80bcc4662f23f1054ee334a15c72b8940402975e0eab63178fc7e670aa59472c", size = 162155568, upload-time = "2025-09-04T08:33:42.864Z" },
|
|
|
|
| 1575 |
|
| 1576 |
[[package]]
|
| 1577 |
name = "persona-ui"
|
| 1578 |
+
version = "0.4.0"
|
| 1579 |
source = { virtual = "." }
|
| 1580 |
dependencies = [
|
| 1581 |
{ name = "catppuccin" },
|
| 1582 |
+
{ name = "datasets" },
|
| 1583 |
+
{ name = "huggingface-hub" },
|
| 1584 |
{ name = "persona-data" },
|
| 1585 |
{ name = "persona-vectors" },
|
| 1586 |
{ name = "plotly" },
|
|
|
|
| 1591 |
[package.metadata]
|
| 1592 |
requires-dist = [
|
| 1593 |
{ name = "catppuccin", specifier = ">=2.5.0" },
|
| 1594 |
+
{ name = "datasets", specifier = ">=4.8.5" },
|
| 1595 |
+
{ name = "huggingface-hub", specifier = ">=1.14.0" },
|
| 1596 |
{ name = "persona-data", specifier = ">=0.4.2" },
|
| 1597 |
{ name = "persona-vectors", specifier = ">=0.6.4" },
|
| 1598 |
{ name = "plotly", specifier = ">=6.6.0" },
|