Jac-Zac committed on
Commit ·
c30bbc5
1
Parent(s): 330d092
Updated cleaned up code
Browse files
- app.py +59 -62
- pyproject.toml +2 -2
- state.py +3 -4
- tabs/chat.py +0 -2
- tabs/chat_ui.py +23 -53
- tabs/compare.py +1 -4
- tabs/extract.py +15 -23
- utils/chat.py +21 -3
- utils/contrast.py +3 -27
- utils/datasets.py +6 -20
- utils/helpers.py +1 -3
- utils/probe_trace.py +3 -21
- utils/probes.py +22 -22
- uv.lock +14 -14
app.py
CHANGED
|
@@ -16,6 +16,64 @@ _TABS = ["Chat", "Compare", "Extract"]
|
|
| 16 |
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def _sidebar_controls() -> tuple[bool, str, str, str]:
|
| 20 |
from utils.runtime import list_remote_models
|
| 21 |
|
|
@@ -44,68 +102,7 @@ def _sidebar_controls() -> tuple[bool, str, str, str]:
|
|
| 44 |
remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")
|
| 45 |
|
| 46 |
if remote:
|
| 47 |
-
|
| 48 |
-
custom_remote_key = "sidebar__remote_model_custom_enabled"
|
| 49 |
-
custom_remote_model = st.toggle(
|
| 50 |
-
"Custom remote model",
|
| 51 |
-
value=False,
|
| 52 |
-
key=custom_remote_key,
|
| 53 |
-
help="Enter any NDIF-loadable model id, even if it is not currently running.",
|
| 54 |
-
)
|
| 55 |
-
if remote_models:
|
| 56 |
-
if custom_remote_model:
|
| 57 |
-
model_name = st.text_input(
|
| 58 |
-
"Model",
|
| 59 |
-
value=st.session_state.get(
|
| 60 |
-
"sidebar__remote_model_custom_value",
|
| 61 |
-
st.session_state.get(
|
| 62 |
-
_LAST_REMOTE_MODEL_KEY, REMOTE_DEFAULT_MODEL
|
| 63 |
-
),
|
| 64 |
-
),
|
| 65 |
-
key="sidebar__remote_model_custom_value",
|
| 66 |
-
help="NDIF model id. Example: openai/gpt-oss-20b",
|
| 67 |
-
)
|
| 68 |
-
st.caption(
|
| 69 |
-
f"{len(remote_models)} running NDIF model(s) detected. Custom model ids can cold-load if your NDIF account allows it."
|
| 70 |
-
)
|
| 71 |
-
else:
|
| 72 |
-
default_model = st.session_state.get(
|
| 73 |
-
"sidebar__remote_model",
|
| 74 |
-
st.session_state.get(_LAST_REMOTE_MODEL_KEY),
|
| 75 |
-
)
|
| 76 |
-
if default_model not in remote_models:
|
| 77 |
-
default_model = (
|
| 78 |
-
REMOTE_DEFAULT_MODEL
|
| 79 |
-
if REMOTE_DEFAULT_MODEL in remote_models
|
| 80 |
-
else remote_models[0]
|
| 81 |
-
)
|
| 82 |
-
if (
|
| 83 |
-
st.session_state.get("sidebar__remote_model")
|
| 84 |
-
not in remote_models
|
| 85 |
-
):
|
| 86 |
-
st.session_state["sidebar__remote_model"] = default_model
|
| 87 |
-
selected_remote_model = st.selectbox(
|
| 88 |
-
"Model",
|
| 89 |
-
options=remote_models,
|
| 90 |
-
index=remote_models.index(default_model),
|
| 91 |
-
key="sidebar__remote_model",
|
| 92 |
-
help="Running NDIF model.",
|
| 93 |
-
)
|
| 94 |
-
model_name = selected_remote_model
|
| 95 |
-
else:
|
| 96 |
-
st.warning("No running NDIF models found.")
|
| 97 |
-
model_name = st.text_input(
|
| 98 |
-
"Model",
|
| 99 |
-
value=st.session_state.get(
|
| 100 |
-
"sidebar__remote_model_custom_value",
|
| 101 |
-
st.session_state.get(
|
| 102 |
-
_LAST_REMOTE_MODEL_KEY, REMOTE_DEFAULT_MODEL
|
| 103 |
-
),
|
| 104 |
-
),
|
| 105 |
-
key="sidebar__remote_model_custom_value",
|
| 106 |
-
help="NDIF model id. Use this to cold-load a remote model.",
|
| 107 |
-
)
|
| 108 |
-
st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
|
| 109 |
else:
|
| 110 |
model_name = st.text_input(
|
| 111 |
"Model",
|
|
|
|
| 16 |
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
| 17 |
|
| 18 |
|
| 19 |
+
def _remote_model_input(remote_models: list[str]) -> str:
|
| 20 |
+
"""Return the active remote model id, picking from running NDIF deployments or a custom value."""
|
| 21 |
+
|
| 22 |
+
last_remote = st.session_state.get(_LAST_REMOTE_MODEL_KEY, REMOTE_DEFAULT_MODEL)
|
| 23 |
+
|
| 24 |
+
if not remote_models:
|
| 25 |
+
st.warning("No running NDIF models found.")
|
| 26 |
+
model_name = st.text_input(
|
| 27 |
+
"Model",
|
| 28 |
+
value=st.session_state.get(
|
| 29 |
+
"sidebar__remote_model_custom_value", last_remote
|
| 30 |
+
),
|
| 31 |
+
key="sidebar__remote_model_custom_value",
|
| 32 |
+
help="NDIF model id. Use this to cold-load a remote model.",
|
| 33 |
+
)
|
| 34 |
+
st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
|
| 35 |
+
return model_name
|
| 36 |
+
|
| 37 |
+
custom = st.toggle(
|
| 38 |
+
"Custom remote model",
|
| 39 |
+
value=False,
|
| 40 |
+
key="sidebar__remote_model_custom_enabled",
|
| 41 |
+
help="Enter any NDIF-loadable model id, even if it is not currently running.",
|
| 42 |
+
)
|
| 43 |
+
if custom:
|
| 44 |
+
model_name = st.text_input(
|
| 45 |
+
"Model",
|
| 46 |
+
value=st.session_state.get(
|
| 47 |
+
"sidebar__remote_model_custom_value", last_remote
|
| 48 |
+
),
|
| 49 |
+
key="sidebar__remote_model_custom_value",
|
| 50 |
+
help="NDIF model id. Example: openai/gpt-oss-20b",
|
| 51 |
+
)
|
| 52 |
+
st.caption(
|
| 53 |
+
f"{len(remote_models)} running NDIF model(s) detected. "
|
| 54 |
+
"Custom model ids can cold-load if your NDIF account allows it."
|
| 55 |
+
)
|
| 56 |
+
else:
|
| 57 |
+
default_model = st.session_state.get("sidebar__remote_model", last_remote)
|
| 58 |
+
if default_model not in remote_models:
|
| 59 |
+
default_model = (
|
| 60 |
+
REMOTE_DEFAULT_MODEL
|
| 61 |
+
if REMOTE_DEFAULT_MODEL in remote_models
|
| 62 |
+
else remote_models[0]
|
| 63 |
+
)
|
| 64 |
+
if st.session_state.get("sidebar__remote_model") not in remote_models:
|
| 65 |
+
st.session_state["sidebar__remote_model"] = default_model
|
| 66 |
+
model_name = st.selectbox(
|
| 67 |
+
"Model",
|
| 68 |
+
options=remote_models,
|
| 69 |
+
index=remote_models.index(default_model),
|
| 70 |
+
key="sidebar__remote_model",
|
| 71 |
+
help="Running NDIF model.",
|
| 72 |
+
)
|
| 73 |
+
st.session_state[_LAST_REMOTE_MODEL_KEY] = model_name
|
| 74 |
+
return model_name
|
| 75 |
+
|
| 76 |
+
|
| 77 |
def _sidebar_controls() -> tuple[bool, str, str, str]:
|
| 78 |
from utils.runtime import list_remote_models
|
| 79 |
|
|
|
|
| 102 |
remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")
|
| 103 |
|
| 104 |
if remote:
|
| 105 |
+
model_name = _remote_model_input(list_remote_models())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
else:
|
| 107 |
model_name = st.text_input(
|
| 108 |
"Model",
|
pyproject.toml
CHANGED
|
@@ -5,8 +5,8 @@ description = "Streamlit UI for persona-vectors"
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
-
"persona-vectors>=0.6.
|
| 9 |
-
"persona-data>=0.4.
|
| 10 |
"streamlit>=1.44.0",
|
| 11 |
"plotly>=6.6.0",
|
| 12 |
"python-dotenv>=1.2.2",
|
|
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
+
"persona-vectors>=0.6.3",
|
| 9 |
+
"persona-data>=0.4.2",
|
| 10 |
"streamlit>=1.44.0",
|
| 11 |
"plotly>=6.6.0",
|
| 12 |
"python-dotenv>=1.2.2",
|
state.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
from typing import Literal, NotRequired, TypedDict
|
| 3 |
|
|
|
|
|
|
|
| 4 |
_CHAT_STATE_PREFIX = "chat_state::"
|
| 5 |
PendingChatAction = Literal["new_user_prompt", "regenerate_after_edit"]
|
| 6 |
|
|
@@ -50,9 +51,7 @@ def reset_chat_context_state(
|
|
| 50 |
st.session_state.pop(key, None)
|
| 51 |
|
| 52 |
|
| 53 |
-
def get_chat_state(
|
| 54 |
-
model_name: str, remote: bool, dataset_source: str
|
| 55 |
-
) -> ChatState:
|
| 56 |
"""Return the mutable chat state for the active context."""
|
| 57 |
|
| 58 |
key = chat_session_key(model_name, dataset_source)
|
|
|
|
|
|
|
| 1 |
from typing import Literal, NotRequired, TypedDict
|
| 2 |
|
| 3 |
+
import streamlit as st
|
| 4 |
+
|
| 5 |
_CHAT_STATE_PREFIX = "chat_state::"
|
| 6 |
PendingChatAction = Literal["new_user_prompt", "regenerate_after_edit"]
|
| 7 |
|
|
|
|
| 51 |
st.session_state.pop(key, None)
|
| 52 |
|
| 53 |
|
| 54 |
+
def get_chat_state(model_name: str, remote: bool, dataset_source: str) -> ChatState:
|
|
|
|
|
|
|
| 55 |
"""Return the mutable chat state for the active context."""
|
| 56 |
|
| 57 |
key = chat_session_key(model_name, dataset_source)
|
tabs/chat.py
CHANGED
|
@@ -128,8 +128,6 @@ def _handle_single_chat_generation(
|
|
| 128 |
st.rerun()
|
| 129 |
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
| 134 |
"""Render the chat tab."""
|
| 135 |
|
|
|
|
| 128 |
st.rerun()
|
| 129 |
|
| 130 |
|
|
|
|
|
|
|
| 131 |
def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
| 132 |
"""Render the chat tab."""
|
| 133 |
|
tabs/chat_ui.py
CHANGED
|
@@ -269,76 +269,46 @@ def render_chat_message(
|
|
| 269 |
) -> None:
|
| 270 |
if not message.get("content"):
|
| 271 |
return
|
| 272 |
-
role = message["role"]
|
| 273 |
contrast: TokenContrast | None = message.get("_contrast") if show_contrast else None
|
| 274 |
-
with st.chat_message(role):
|
| 275 |
if contrast is not None:
|
| 276 |
st.html(render_contrast_html(contrast))
|
| 277 |
else:
|
| 278 |
st.markdown(message["content"])
|
| 279 |
|
| 280 |
|
| 281 |
-
def _render_editable_message(
|
| 282 |
-
message: dict[str, str],
|
| 283 |
-
msg_index: int,
|
| 284 |
-
messages: list[dict[str, str]],
|
| 285 |
-
chat_state: dict[str, object],
|
| 286 |
-
edit_key: str,
|
| 287 |
-
pending_key: str,
|
| 288 |
-
show_contrast: bool = False,
|
| 289 |
-
column_ratio: tuple[int, int] = (25, 1),
|
| 290 |
-
) -> None:
|
| 291 |
-
if not message.get("content"):
|
| 292 |
-
return
|
| 293 |
-
role = message["role"]
|
| 294 |
-
contrast: TokenContrast | None = message.get("_contrast") if show_contrast else None
|
| 295 |
-
msg_col, edit_col = st.columns(
|
| 296 |
-
list(column_ratio), gap="xsmall", vertical_alignment="center"
|
| 297 |
-
)
|
| 298 |
-
|
| 299 |
-
with msg_col:
|
| 300 |
-
with st.chat_message(role):
|
| 301 |
-
if contrast is not None:
|
| 302 |
-
st.html(render_contrast_html(contrast))
|
| 303 |
-
else:
|
| 304 |
-
st.markdown(message["content"])
|
| 305 |
-
with edit_col:
|
| 306 |
-
if st.button(
|
| 307 |
-
"", icon=":material/edit:", key=f"{edit_key}_edit_{msg_index}", help="Edit"
|
| 308 |
-
):
|
| 309 |
-
_open_edit_dialog(
|
| 310 |
-
msg_index=msg_index,
|
| 311 |
-
messages=messages,
|
| 312 |
-
chat_state=chat_state,
|
| 313 |
-
pending_key=pending_key,
|
| 314 |
-
)
|
| 315 |
-
|
| 316 |
-
|
| 317 |
def render_chat_window(
|
| 318 |
*,
|
| 319 |
chat_log: Any,
|
| 320 |
messages: list[dict[str, str]],
|
| 321 |
-
chat_state: dict[str, object]
|
| 322 |
-
edit_key: str
|
| 323 |
-
pending_key: str
|
| 324 |
show_contrast: bool = False,
|
| 325 |
edit_column_ratio: tuple[int, int] = (25, 1),
|
| 326 |
) -> None:
|
| 327 |
with chat_log:
|
| 328 |
for i, message in enumerate(messages):
|
| 329 |
-
if
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
edit_key,
|
| 336 |
-
pending_key,
|
| 337 |
-
show_contrast=show_contrast,
|
| 338 |
-
column_ratio=edit_column_ratio,
|
| 339 |
-
)
|
| 340 |
-
else:
|
| 341 |
render_chat_message(message, show_contrast=show_contrast)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
|
| 343 |
|
| 344 |
def _assistant_first(personas: list[PersonaData]) -> list[PersonaData]:
|
|
|
|
| 269 |
) -> None:
|
| 270 |
if not message.get("content"):
|
| 271 |
return
|
|
|
|
| 272 |
contrast: TokenContrast | None = message.get("_contrast") if show_contrast else None
|
| 273 |
+
with st.chat_message(message["role"]):
|
| 274 |
if contrast is not None:
|
| 275 |
st.html(render_contrast_html(contrast))
|
| 276 |
else:
|
| 277 |
st.markdown(message["content"])
|
| 278 |
|
| 279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
def render_chat_window(
|
| 281 |
*,
|
| 282 |
chat_log: Any,
|
| 283 |
messages: list[dict[str, str]],
|
| 284 |
+
chat_state: dict[str, object],
|
| 285 |
+
edit_key: str,
|
| 286 |
+
pending_key: str,
|
| 287 |
show_contrast: bool = False,
|
| 288 |
edit_column_ratio: tuple[int, int] = (25, 1),
|
| 289 |
) -> None:
|
| 290 |
with chat_log:
|
| 291 |
for i, message in enumerate(messages):
|
| 292 |
+
if not message.get("content"):
|
| 293 |
+
continue
|
| 294 |
+
msg_col, edit_col = st.columns(
|
| 295 |
+
list(edit_column_ratio), gap="xsmall", vertical_alignment="center"
|
| 296 |
+
)
|
| 297 |
+
with msg_col:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
render_chat_message(message, show_contrast=show_contrast)
|
| 299 |
+
with edit_col:
|
| 300 |
+
if st.button(
|
| 301 |
+
"",
|
| 302 |
+
icon=":material/edit:",
|
| 303 |
+
key=f"{edit_key}_edit_{i}",
|
| 304 |
+
help="Edit",
|
| 305 |
+
):
|
| 306 |
+
_open_edit_dialog(
|
| 307 |
+
msg_index=i,
|
| 308 |
+
messages=messages,
|
| 309 |
+
chat_state=chat_state,
|
| 310 |
+
pending_key=pending_key,
|
| 311 |
+
)
|
| 312 |
|
| 313 |
|
| 314 |
def _assistant_first(personas: list[PersonaData]) -> list[PersonaData]:
|
tabs/compare.py
CHANGED
|
@@ -5,10 +5,7 @@ from itertools import combinations
|
|
| 5 |
|
| 6 |
import streamlit as st
|
| 7 |
from persona_data.environment import get_artifacts_dir
|
| 8 |
-
from persona_vectors.analysis import
|
| 9 |
-
load_persona_vectors,
|
| 10 |
-
load_variant_vectors,
|
| 11 |
-
)
|
| 12 |
from persona_vectors.artifacts import ActivationStore, HFActivationStore
|
| 13 |
from persona_vectors.artifacts import list_layers as list_local_layers
|
| 14 |
from persona_vectors.extraction import MaskStrategy
|
|
|
|
| 5 |
|
| 6 |
import streamlit as st
|
| 7 |
from persona_data.environment import get_artifacts_dir
|
| 8 |
+
from persona_vectors.analysis import load_persona_vectors, load_variant_vectors
|
|
|
|
|
|
|
|
|
|
| 9 |
from persona_vectors.artifacts import ActivationStore, HFActivationStore
|
| 10 |
from persona_vectors.artifacts import list_layers as list_local_layers
|
| 11 |
from persona_vectors.extraction import MaskStrategy
|
tabs/extract.py
CHANGED
|
@@ -102,7 +102,9 @@ def _render_variant_controls(
|
|
| 102 |
return selected_variants, include_baseline
|
| 103 |
|
| 104 |
|
| 105 |
-
def _load_qa_dataset_personas(
|
|
|
|
|
|
|
| 106 |
try:
|
| 107 |
dataset, dataset_status = load_dataset(
|
| 108 |
dataset_source,
|
|
@@ -237,7 +239,9 @@ def _collect_runs(
|
|
| 237 |
runs, skipped = [], []
|
| 238 |
for persona in selected_personas:
|
| 239 |
if persona.id == BASELINE_PERSONA_ID:
|
| 240 |
-
qa = list(
|
|
|
|
|
|
|
| 241 |
elif hasattr(dataset, "train_test_split"):
|
| 242 |
qa, _ = dataset.train_test_split(persona.id)
|
| 243 |
else:
|
|
@@ -268,28 +272,15 @@ def _render_max_questions(
|
|
| 268 |
"Max questions (train split)",
|
| 269 |
min_value=1,
|
| 270 |
max_value=max_q,
|
| 271 |
-
value=min(
|
|
|
|
|
|
|
| 272 |
key=_extract_widget_key(model_name, remote, dataset_source, "max_questions"),
|
| 273 |
)
|
| 274 |
st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
|
| 275 |
return max_questions
|
| 276 |
|
| 277 |
|
| 278 |
-
def _render_advanced_settings(
|
| 279 |
-
*,
|
| 280 |
-
model_name: str,
|
| 281 |
-
remote: bool,
|
| 282 |
-
dataset_source: str,
|
| 283 |
-
) -> MaskStrategy:
|
| 284 |
-
with st.expander("Advanced", expanded=False):
|
| 285 |
-
mask_strategy = _render_mask_strategy_select(
|
| 286 |
-
model_name=model_name,
|
| 287 |
-
remote=remote,
|
| 288 |
-
dataset_source=dataset_source,
|
| 289 |
-
)
|
| 290 |
-
return mask_strategy
|
| 291 |
-
|
| 292 |
-
|
| 293 |
def _render_extract_actions() -> tuple[bool, bool]:
|
| 294 |
run_col, preview_col, _spacer = st.columns([1, 1, 4], gap="small")
|
| 295 |
with run_col:
|
|
@@ -439,11 +430,12 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 439 |
dataset_source=dataset_source,
|
| 440 |
runs=runs,
|
| 441 |
)
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
|
|
|
| 447 |
settings = ExtractSettings(
|
| 448 |
mask_strategy=mask_strategy,
|
| 449 |
max_questions=max_questions,
|
|
|
|
| 102 |
return selected_variants, include_baseline
|
| 103 |
|
| 104 |
|
| 105 |
+
def _load_qa_dataset_personas(
|
| 106 |
+
dataset_source: str,
|
| 107 |
+
) -> tuple[object, list[PersonaData]] | None:
|
| 108 |
try:
|
| 109 |
dataset, dataset_status = load_dataset(
|
| 110 |
dataset_source,
|
|
|
|
| 239 |
runs, skipped = [], []
|
| 240 |
for persona in selected_personas:
|
| 241 |
if persona.id == BASELINE_PERSONA_ID:
|
| 242 |
+
qa = list(
|
| 243 |
+
dataset.get_qa(BASELINE_PERSONA_ID, item_type="mcq", scope="shared")
|
| 244 |
+
)
|
| 245 |
elif hasattr(dataset, "train_test_split"):
|
| 246 |
qa, _ = dataset.train_test_split(persona.id)
|
| 247 |
else:
|
|
|
|
| 272 |
"Max questions (train split)",
|
| 273 |
min_value=1,
|
| 274 |
max_value=max_q,
|
| 275 |
+
value=min(
|
| 276 |
+
max(st.session_state.get(_LAST_MAX_QUESTIONS_KEY, default), 1), max_q
|
| 277 |
+
),
|
| 278 |
key=_extract_widget_key(model_name, remote, dataset_source, "max_questions"),
|
| 279 |
)
|
| 280 |
st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
|
| 281 |
return max_questions
|
| 282 |
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
def _render_extract_actions() -> tuple[bool, bool]:
|
| 285 |
run_col, preview_col, _spacer = st.columns([1, 1, 4], gap="small")
|
| 286 |
with run_col:
|
|
|
|
| 430 |
dataset_source=dataset_source,
|
| 431 |
runs=runs,
|
| 432 |
)
|
| 433 |
+
with st.expander("Advanced", expanded=False):
|
| 434 |
+
mask_strategy = _render_mask_strategy_select(
|
| 435 |
+
model_name=model_name,
|
| 436 |
+
remote=remote,
|
| 437 |
+
dataset_source=dataset_source,
|
| 438 |
+
)
|
| 439 |
settings = ExtractSettings(
|
| 440 |
mask_strategy=mask_strategy,
|
| 441 |
max_questions=max_questions,
|
utils/chat.py
CHANGED
|
@@ -74,9 +74,7 @@ def _format_plain_messages(
|
|
| 74 |
else:
|
| 75 |
lines.append(f"{role.title()}: {content}")
|
| 76 |
|
| 77 |
-
if add_generation_prompt and (
|
| 78 |
-
not lines or not lines[-1].startswith("Assistant:")
|
| 79 |
-
):
|
| 80 |
lines.append("Assistant:")
|
| 81 |
|
| 82 |
return "\n\n".join(lines)
|
|
@@ -130,6 +128,26 @@ def format_generation_prompt(
|
|
| 130 |
return prompt, prompt_token_count
|
| 131 |
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
@contextmanager
|
| 134 |
def _seeded_rng(seed: int | None):
|
| 135 |
"""Context manager that forks the RNG state and sets a deterministic seed."""
|
|
|
|
| 74 |
else:
|
| 75 |
lines.append(f"{role.title()}: {content}")
|
| 76 |
|
| 77 |
+
if add_generation_prompt and (not lines or not lines[-1].startswith("Assistant:")):
|
|
|
|
|
|
|
| 78 |
lines.append("Assistant:")
|
| 79 |
|
| 80 |
return "\n\n".join(lines)
|
|
|
|
| 128 |
return prompt, prompt_token_count
|
| 129 |
|
| 130 |
|
| 131 |
+
def resolve_saved_tensor(value: object) -> torch.Tensor:
|
| 132 |
+
"""Resolve an nnsight ``.save()`` proxy (or raw tensor) to a CPU tensor."""
|
| 133 |
+
resolved = value.value if getattr(value, "value", None) is not None else value
|
| 134 |
+
if not isinstance(resolved, torch.Tensor):
|
| 135 |
+
raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
|
| 136 |
+
return resolved.detach().cpu()
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def decode_token(tokenizer: object, token_id: int) -> str:
|
| 140 |
+
"""Decode a single token id, falling back when ``clean_up_tokenization_spaces`` is unsupported."""
|
| 141 |
+
try:
|
| 142 |
+
return tokenizer.decode(
|
| 143 |
+
[token_id],
|
| 144 |
+
skip_special_tokens=False,
|
| 145 |
+
clean_up_tokenization_spaces=False,
|
| 146 |
+
)
|
| 147 |
+
except TypeError:
|
| 148 |
+
return tokenizer.decode([token_id], skip_special_tokens=False)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
@contextmanager
|
| 152 |
def _seeded_rng(seed: int | None):
|
| 153 |
"""Context manager that forks the RNG state and sets a deterministic seed."""
|
utils/contrast.py
CHANGED
|
@@ -17,7 +17,7 @@ from html import escape
|
|
| 17 |
import torch
|
| 18 |
from nnterp import StandardizedTransformer
|
| 19 |
|
| 20 |
-
from utils.chat import format_generation_prompt
|
| 21 |
|
| 22 |
|
| 23 |
@dataclass
|
|
@@ -43,18 +43,6 @@ def _normalise_diffs(diffs: torch.Tensor) -> list[float]:
|
|
| 43 |
return (diffs.float().clamp(-clip_val, clip_val) / clip_val).tolist()
|
| 44 |
|
| 45 |
|
| 46 |
-
def _decode_ids(tokenizer: object, ids: list[int]) -> str:
|
| 47 |
-
"""Decode token IDs, falling back when clean_up_tokenization_spaces is unsupported."""
|
| 48 |
-
try:
|
| 49 |
-
return tokenizer.decode(
|
| 50 |
-
ids,
|
| 51 |
-
skip_special_tokens=False,
|
| 52 |
-
clean_up_tokenization_spaces=False,
|
| 53 |
-
)
|
| 54 |
-
except TypeError:
|
| 55 |
-
return tokenizer.decode(ids, skip_special_tokens=False)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
def _strip_special_ids(
|
| 59 |
ids: torch.Tensor,
|
| 60 |
tokenizer: object,
|
|
@@ -96,7 +84,7 @@ def _build_contrast(
|
|
| 96 |
display_ids, keep_mask = _strip_special_ids(response_ids, tokenizer)
|
| 97 |
display_diffs = diffs[keep_mask]
|
| 98 |
return TokenContrast(
|
| 99 |
-
tokens=[
|
| 100 |
weights=_normalise_diffs(display_diffs),
|
| 101 |
raw_diffs=display_diffs.float().tolist(),
|
| 102 |
label_a=label_a,
|
|
@@ -104,11 +92,6 @@ def _build_contrast(
|
|
| 104 |
)
|
| 105 |
|
| 106 |
|
| 107 |
-
def _token_display(tokenizer: object, token_id: int) -> str:
|
| 108 |
-
"""Render a single token id as normal decoded text."""
|
| 109 |
-
return _decode_ids(tokenizer, [token_id])
|
| 110 |
-
|
| 111 |
-
|
| 112 |
# Each spec: (key, input_ids, n_ctx, n_resp, target_ids).
|
| 113 |
PassSpec = tuple[str, torch.Tensor, int, int, torch.Tensor]
|
| 114 |
|
|
@@ -140,14 +123,7 @@ def _score_passes(
|
|
| 140 |
targets = target_ids.to(log_probs.device).view(-1, 1)
|
| 141 |
picked = log_probs.gather(1, targets).view(-1)
|
| 142 |
out = picked.detach().cpu().save()
|
| 143 |
-
|
| 144 |
-
if getattr(out, "value", None) is not None:
|
| 145 |
-
out = out.value
|
| 146 |
-
if not isinstance(out, torch.Tensor):
|
| 147 |
-
raise TypeError(
|
| 148 |
-
f"contrast score did not resolve to a tensor: {type(out)!r}"
|
| 149 |
-
)
|
| 150 |
-
return out.detach().cpu()
|
| 151 |
|
| 152 |
return {
|
| 153 |
key: _score_pass(input_ids, n_ctx, n_resp, target_ids)
|
|
|
|
| 17 |
import torch
|
| 18 |
from nnterp import StandardizedTransformer
|
| 19 |
|
| 20 |
+
from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor
|
| 21 |
|
| 22 |
|
| 23 |
@dataclass
|
|
|
|
| 43 |
return (diffs.float().clamp(-clip_val, clip_val) / clip_val).tolist()
|
| 44 |
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def _strip_special_ids(
|
| 47 |
ids: torch.Tensor,
|
| 48 |
tokenizer: object,
|
|
|
|
| 84 |
display_ids, keep_mask = _strip_special_ids(response_ids, tokenizer)
|
| 85 |
display_diffs = diffs[keep_mask]
|
| 86 |
return TokenContrast(
|
| 87 |
+
tokens=[decode_token(tokenizer, tid.item()) for tid in display_ids],
|
| 88 |
weights=_normalise_diffs(display_diffs),
|
| 89 |
raw_diffs=display_diffs.float().tolist(),
|
| 90 |
label_a=label_a,
|
|
|
|
| 92 |
)
|
| 93 |
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
# Each spec: (key, input_ids, n_ctx, n_resp, target_ids).
|
| 96 |
PassSpec = tuple[str, torch.Tensor, int, int, torch.Tensor]
|
| 97 |
|
|
|
|
| 123 |
targets = target_ids.to(log_probs.device).view(-1, 1)
|
| 124 |
picked = log_probs.gather(1, targets).view(-1)
|
| 125 |
out = picked.detach().cpu().save()
|
| 126 |
+
return resolve_saved_tensor(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
return {
|
| 129 |
key: _score_pass(input_ids, n_ctx, n_resp, target_ids)
|
utils/datasets.py
CHANGED
|
@@ -17,24 +17,10 @@ from .helpers import DATASET_SOURCES
|
|
| 17 |
|
| 18 |
|
| 19 |
@st.cache_resource(show_spinner=False)
|
| 20 |
-
def
|
| 21 |
-
"""
|
| 22 |
|
| 23 |
-
return
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
@st.cache_resource(show_spinner=False)
|
| 27 |
-
def cached_nemotron_dataset() -> NemotronPersonasFranceDataset:
|
| 28 |
-
"""Load the Nemotron France HuggingFace dataset once."""
|
| 29 |
-
|
| 30 |
-
return NemotronPersonasFranceDataset()
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
@st.cache_resource(show_spinner=False)
|
| 34 |
-
def cached_nemotron_usa_dataset() -> NemotronPersonasUSADataset:
|
| 35 |
-
"""Load the Nemotron USA HuggingFace dataset once."""
|
| 36 |
-
|
| 37 |
-
return NemotronPersonasUSADataset()
|
| 38 |
|
| 39 |
|
| 40 |
def _upload_cache_dir() -> Path:
|
|
@@ -74,13 +60,13 @@ def load_dataset(
|
|
| 74 |
"""Load the selected dataset source for the UI."""
|
| 75 |
|
| 76 |
if dataset_source == DATASET_SOURCES[0]:
|
| 77 |
-
return
|
| 78 |
|
| 79 |
if dataset_source == DATASET_SOURCES[1]:
|
| 80 |
-
return
|
| 81 |
|
| 82 |
if dataset_source == DATASET_SOURCES[2]:
|
| 83 |
-
return
|
| 84 |
|
| 85 |
if personas_file is None or qa_file is None:
|
| 86 |
raise ValueError("Upload both personas.jsonl and qa.jsonl files")
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
@st.cache_resource(show_spinner=False)
|
| 20 |
+
def _cached_dataset(cls: type) -> Any:
|
| 21 |
+
"""Instantiate and cache a HuggingFace dataset class once per session."""
|
| 22 |
|
| 23 |
+
return cls()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def _upload_cache_dir() -> Path:
|
|
|
|
| 60 |
"""Load the selected dataset source for the UI."""
|
| 61 |
|
| 62 |
if dataset_source == DATASET_SOURCES[0]:
|
| 63 |
+
return _cached_dataset(SynthPersonaDataset), "SynthPersona"
|
| 64 |
|
| 65 |
if dataset_source == DATASET_SOURCES[1]:
|
| 66 |
+
return _cached_dataset(NemotronPersonasFranceDataset), "Nemotron France"
|
| 67 |
|
| 68 |
if dataset_source == DATASET_SOURCES[2]:
|
| 69 |
+
return _cached_dataset(NemotronPersonasUSADataset), "Nemotron USA"
|
| 70 |
|
| 71 |
if personas_file is None or qa_file is None:
|
| 72 |
raise ValueError("Upload both personas.jsonl and qa.jsonl files")
|
utils/helpers.py
CHANGED
|
@@ -13,9 +13,7 @@ VARIANT_LABELS = {
|
|
| 13 |
|
| 14 |
CHAT_PROMPT_MODES = ("empty", "templated", "biography", "custom")
|
| 15 |
CHAT_PROMPT_MODE_LABELS = [VARIANT_LABELS[key] for key in CHAT_PROMPT_MODES]
|
| 16 |
-
CHAT_PROMPT_MODE_LABEL_TO_KEY = {
|
| 17 |
-
VARIANT_LABELS[key]: key for key in CHAT_PROMPT_MODES
|
| 18 |
-
}
|
| 19 |
|
| 20 |
|
| 21 |
DATASET_SOURCES = [
|
|
|
|
| 13 |
|
| 14 |
CHAT_PROMPT_MODES = ("empty", "templated", "biography", "custom")
|
| 15 |
CHAT_PROMPT_MODE_LABELS = [VARIANT_LABELS[key] for key in CHAT_PROMPT_MODES]
|
| 16 |
+
CHAT_PROMPT_MODE_LABEL_TO_KEY = {VARIANT_LABELS[key]: key for key in CHAT_PROMPT_MODES}
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
DATASET_SOURCES = [
|
utils/probe_trace.py
CHANGED
|
@@ -7,7 +7,7 @@ import streamlit as st
|
|
| 7 |
import torch
|
| 8 |
from nnterp import StandardizedTransformer
|
| 9 |
|
| 10 |
-
from utils.chat import format_generation_prompt
|
| 11 |
|
| 12 |
_TRACE_CACHE_KEY = "probe:trace_cache"
|
| 13 |
_MAX_CACHED_TRACES = 3
|
|
@@ -74,8 +74,8 @@ def trace_conversation(
|
|
| 74 |
saved_ids = model.input_ids[0].detach().cpu().save()
|
| 75 |
saved_acts = accessor[layer][0].detach().float().cpu().save()
|
| 76 |
|
| 77 |
-
input_ids =
|
| 78 |
-
activations =
|
| 79 |
if input_ids.ndim != 1:
|
| 80 |
raise ValueError(
|
| 81 |
f"Expected traced input ids to be [seq], got {tuple(input_ids.shape)}"
|
|
@@ -125,17 +125,6 @@ def vectorize_token(
|
|
| 125 |
)
|
| 126 |
|
| 127 |
|
| 128 |
-
def decode_token(tokenizer: object, token_id: int) -> str:
|
| 129 |
-
try:
|
| 130 |
-
return tokenizer.decode(
|
| 131 |
-
[token_id],
|
| 132 |
-
skip_special_tokens=False,
|
| 133 |
-
clean_up_tokenization_spaces=False,
|
| 134 |
-
)
|
| 135 |
-
except TypeError:
|
| 136 |
-
return tokenizer.decode([token_id], skip_special_tokens=False)
|
| 137 |
-
|
| 138 |
-
|
| 139 |
def _select_accessor(model: StandardizedTransformer, location: str):
|
| 140 |
normalized = location.lower()
|
| 141 |
if normalized in {"pre_reasoning", "pre", "input", "layers_input"}:
|
|
@@ -145,13 +134,6 @@ def _select_accessor(model: StandardizedTransformer, location: str):
|
|
| 145 |
raise ValueError(f"Unsupported trace location: {location!r}")
|
| 146 |
|
| 147 |
|
| 148 |
-
def _resolve_saved_tensor(value) -> torch.Tensor:
|
| 149 |
-
resolved = value.value if getattr(value, "value", None) is not None else value
|
| 150 |
-
if not isinstance(resolved, torch.Tensor):
|
| 151 |
-
raise TypeError(f"Trace result did not resolve to a tensor: {type(resolved)!r}")
|
| 152 |
-
return resolved.detach().cpu()
|
| 153 |
-
|
| 154 |
-
|
| 155 |
def _trace_cache_key(
|
| 156 |
*,
|
| 157 |
model_name: str,
|
|
|
|
| 7 |
import torch
|
| 8 |
from nnterp import StandardizedTransformer
|
| 9 |
|
| 10 |
+
from utils.chat import decode_token, format_generation_prompt, resolve_saved_tensor
|
| 11 |
|
| 12 |
_TRACE_CACHE_KEY = "probe:trace_cache"
|
| 13 |
_MAX_CACHED_TRACES = 3
|
|
|
|
| 74 |
saved_ids = model.input_ids[0].detach().cpu().save()
|
| 75 |
saved_acts = accessor[layer][0].detach().float().cpu().save()
|
| 76 |
|
| 77 |
+
input_ids = resolve_saved_tensor(saved_ids)
|
| 78 |
+
activations = resolve_saved_tensor(saved_acts)
|
| 79 |
if input_ids.ndim != 1:
|
| 80 |
raise ValueError(
|
| 81 |
f"Expected traced input ids to be [seq], got {tuple(input_ids.shape)}"
|
|
|
|
| 125 |
)
|
| 126 |
|
| 127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
def _select_accessor(model: StandardizedTransformer, location: str):
|
| 129 |
normalized = location.lower()
|
| 130 |
if normalized in {"pre_reasoning", "pre", "input", "layers_input"}:
|
|
|
|
| 134 |
raise ValueError(f"Unsupported trace location: {location!r}")
|
| 135 |
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def _trace_cache_key(
|
| 138 |
*,
|
| 139 |
model_name: str,
|
utils/probes.py
CHANGED
|
@@ -225,15 +225,28 @@ def _load_probe_payload(
|
|
| 225 |
num_classes=num_classes,
|
| 226 |
)
|
| 227 |
labels = _normalize_labels(payload.get("idx_to_label"), num_classes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
return LoadedProbe(
|
| 229 |
model=model,
|
| 230 |
input_dim=input_dim,
|
| 231 |
labels=labels,
|
| 232 |
model_type=str(payload.get("model_type") or metadata.model_type),
|
| 233 |
-
layer=
|
| 234 |
-
location=
|
| 235 |
-
scaler_mean=
|
| 236 |
-
scaler_std=
|
| 237 |
)
|
| 238 |
|
| 239 |
|
|
@@ -296,7 +309,9 @@ def _coerce_probe_dim(
|
|
| 296 |
weights = [
|
| 297 |
tensor
|
| 298 |
for key, tensor in state_dict.items()
|
| 299 |
-
if key.endswith("weight")
|
|
|
|
|
|
|
| 300 |
]
|
| 301 |
if not weights:
|
| 302 |
raise ValueError(f"Cannot infer probe {dim} dimension from state dict")
|
|
@@ -349,27 +364,12 @@ def _coerce_hidden_dims(value: Any) -> list[int]:
|
|
| 349 |
raise TypeError(f"Unsupported hidden_dims value: {type(value)!r}")
|
| 350 |
|
| 351 |
|
| 352 |
-
def
|
| 353 |
-
if
|
| 354 |
return None
|
| 355 |
return value.detach().cpu()
|
| 356 |
|
| 357 |
|
| 358 |
-
def _coerce_optional_int(value: Any, fallback: int | None) -> int | None:
|
| 359 |
-
if value is None:
|
| 360 |
-
return fallback
|
| 361 |
-
try:
|
| 362 |
-
return int(value)
|
| 363 |
-
except (TypeError, ValueError):
|
| 364 |
-
return fallback
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
def _coerce_location(value: Any, fallback: str | None) -> str | None:
|
| 368 |
-
if isinstance(value, str) and value:
|
| 369 |
-
return value
|
| 370 |
-
return fallback
|
| 371 |
-
|
| 372 |
-
|
| 373 |
def _normalize_labels(raw_labels: Any, num_classes: int) -> list[str | None]:
|
| 374 |
if isinstance(raw_labels, (list, tuple)):
|
| 375 |
labels = [str(label) for label in raw_labels[:num_classes]]
|
|
|
|
| 225 |
num_classes=num_classes,
|
| 226 |
)
|
| 227 |
labels = _normalize_labels(payload.get("idx_to_label"), num_classes)
|
| 228 |
+
|
| 229 |
+
raw_layer = payload.get("layer")
|
| 230 |
+
try:
|
| 231 |
+
layer = int(raw_layer) if raw_layer is not None else metadata.layer
|
| 232 |
+
except (TypeError, ValueError):
|
| 233 |
+
layer = metadata.layer
|
| 234 |
+
raw_location = payload.get("location")
|
| 235 |
+
location = (
|
| 236 |
+
raw_location
|
| 237 |
+
if isinstance(raw_location, str) and raw_location
|
| 238 |
+
else metadata.location
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
return LoadedProbe(
|
| 242 |
model=model,
|
| 243 |
input_dim=input_dim,
|
| 244 |
labels=labels,
|
| 245 |
model_type=str(payload.get("model_type") or metadata.model_type),
|
| 246 |
+
layer=layer,
|
| 247 |
+
location=location,
|
| 248 |
+
scaler_mean=_as_cpu_tensor(payload.get("scaler_mean")),
|
| 249 |
+
scaler_std=_as_cpu_tensor(payload.get("scaler_std")),
|
| 250 |
)
|
| 251 |
|
| 252 |
|
|
|
|
| 309 |
weights = [
|
| 310 |
tensor
|
| 311 |
for key, tensor in state_dict.items()
|
| 312 |
+
if key.endswith("weight")
|
| 313 |
+
and isinstance(tensor, torch.Tensor)
|
| 314 |
+
and tensor.ndim == 2
|
| 315 |
]
|
| 316 |
if not weights:
|
| 317 |
raise ValueError(f"Cannot infer probe {dim} dimension from state dict")
|
|
|
|
| 364 |
raise TypeError(f"Unsupported hidden_dims value: {type(value)!r}")
|
| 365 |
|
| 366 |
|
| 367 |
+
def _as_cpu_tensor(value: Any) -> torch.Tensor | None:
|
| 368 |
+
if not isinstance(value, torch.Tensor):
|
| 369 |
return None
|
| 370 |
return value.detach().cpu()
|
| 371 |
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
def _normalize_labels(raw_labels: Any, num_classes: int) -> list[str | None]:
|
| 374 |
if isinstance(raw_labels, (list, tuple)):
|
| 375 |
labels = [str(label) for label in raw_labels[:num_classes]]
|
uv.lock
CHANGED
|
@@ -1120,11 +1120,11 @@ wheels = [
|
|
| 1120 |
|
| 1121 |
[[package]]
|
| 1122 |
name = "narwhals"
|
| 1123 |
-
version = "2.
|
| 1124 |
source = { registry = "https://pypi.org/simple" }
|
| 1125 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1126 |
wheels = [
|
| 1127 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1128 |
]
|
| 1129 |
|
| 1130 |
[[package]]
|
|
@@ -1550,7 +1550,7 @@ wheels = [
|
|
| 1550 |
|
| 1551 |
[[package]]
|
| 1552 |
name = "persona-data"
|
| 1553 |
-
version = "0.4.
|
| 1554 |
source = { registry = "https://pypi.org/simple" }
|
| 1555 |
dependencies = [
|
| 1556 |
{ name = "huggingface-hub" },
|
|
@@ -1559,9 +1559,9 @@ dependencies = [
|
|
| 1559 |
{ name = "python-dotenv" },
|
| 1560 |
{ name = "torch" },
|
| 1561 |
]
|
| 1562 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1563 |
wheels = [
|
| 1564 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1565 |
]
|
| 1566 |
|
| 1567 |
[[package]]
|
|
@@ -1578,8 +1578,8 @@ dependencies = [
|
|
| 1578 |
|
| 1579 |
[package.metadata]
|
| 1580 |
requires-dist = [
|
| 1581 |
-
{ name = "persona-data", specifier = ">=0.4.
|
| 1582 |
-
{ name = "persona-vectors", specifier = ">=0.6.
|
| 1583 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1584 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1585 |
{ name = "streamlit", specifier = ">=1.44.0" },
|
|
@@ -1587,7 +1587,7 @@ requires-dist = [
|
|
| 1587 |
|
| 1588 |
[[package]]
|
| 1589 |
name = "persona-vectors"
|
| 1590 |
-
version = "0.6.
|
| 1591 |
source = { registry = "https://pypi.org/simple" }
|
| 1592 |
dependencies = [
|
| 1593 |
{ name = "datasets" },
|
|
@@ -1606,9 +1606,9 @@ dependencies = [
|
|
| 1606 |
{ name = "transformers" },
|
| 1607 |
{ name = "umap-learn" },
|
| 1608 |
]
|
| 1609 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1610 |
wheels = [
|
| 1611 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1612 |
]
|
| 1613 |
|
| 1614 |
[[package]]
|
|
@@ -2912,11 +2912,11 @@ wheels = [
|
|
| 2912 |
|
| 2913 |
[[package]]
|
| 2914 |
name = "urllib3"
|
| 2915 |
-
version = "2.
|
| 2916 |
source = { registry = "https://pypi.org/simple" }
|
| 2917 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 2918 |
wheels = [
|
| 2919 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 2920 |
]
|
| 2921 |
|
| 2922 |
[[package]]
|
|
|
|
| 1120 |
|
| 1121 |
[[package]]
|
| 1122 |
name = "narwhals"
|
| 1123 |
+
version = "2.21.0"
|
| 1124 |
source = { registry = "https://pypi.org/simple" }
|
| 1125 |
+
sdist = { url = "https://files.pythonhosted.org/packages/2d/0e/3ad61eb87088cc4932e0d851531fa82f845a6230b68b091a0e298cc7e537/narwhals-2.21.0.tar.gz", hash = "sha256:7c6e7f50528e62b7a967dd864d7e117d2955d38d4f730653ce46a9861358e2dc", size = 633083, upload-time = "2026-05-08T12:29:02.587Z" }
|
| 1126 |
wheels = [
|
| 1127 |
+
{ url = "https://files.pythonhosted.org/packages/c7/e1/68c2256b69a314eba133673377ba9118c356f6342a0c02b61de449cf2bf2/narwhals-2.21.0-py3-none-any.whl", hash = "sha256:1e6617d0fca68ae1fda29e5397c4eaacd3ffc9fffe6bcd6ded0c690475e853be", size = 451943, upload-time = "2026-05-08T12:29:01.058Z" },
|
| 1128 |
]
|
| 1129 |
|
| 1130 |
[[package]]
|
|
|
|
| 1550 |
|
| 1551 |
[[package]]
|
| 1552 |
name = "persona-data"
|
| 1553 |
+
version = "0.4.2"
|
| 1554 |
source = { registry = "https://pypi.org/simple" }
|
| 1555 |
dependencies = [
|
| 1556 |
{ name = "huggingface-hub" },
|
|
|
|
| 1559 |
{ name = "python-dotenv" },
|
| 1560 |
{ name = "torch" },
|
| 1561 |
]
|
| 1562 |
+
sdist = { url = "https://files.pythonhosted.org/packages/a4/2f/099a74e54846172a20b697b46b285eb2f0004e1db530308d6b4ff1f19079/persona_data-0.4.2.tar.gz", hash = "sha256:7870292a79b3943a77c31595140de3b2243b783222590248d09891de70e7fe1b", size = 9276, upload-time = "2026-05-08T13:59:27.58Z" }
|
| 1563 |
wheels = [
|
| 1564 |
+
{ url = "https://files.pythonhosted.org/packages/57/03/e76a48b41ee00684a4430269007e217e70f59e2597d7c862d93cfc5ac78b/persona_data-0.4.2-py3-none-any.whl", hash = "sha256:c881d6fb71af87a6fa773284076e4cb55794db6dc447a7eb0047eee2b389c855", size = 11914, upload-time = "2026-05-08T13:59:28.198Z" },
|
| 1565 |
]
|
| 1566 |
|
| 1567 |
[[package]]
|
|
|
|
| 1578 |
|
| 1579 |
[package.metadata]
|
| 1580 |
requires-dist = [
|
| 1581 |
+
{ name = "persona-data", specifier = ">=0.4.2" },
|
| 1582 |
+
{ name = "persona-vectors", specifier = ">=0.6.3" },
|
| 1583 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1584 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1585 |
{ name = "streamlit", specifier = ">=1.44.0" },
|
|
|
|
| 1587 |
|
| 1588 |
[[package]]
|
| 1589 |
name = "persona-vectors"
|
| 1590 |
+
version = "0.6.3"
|
| 1591 |
source = { registry = "https://pypi.org/simple" }
|
| 1592 |
dependencies = [
|
| 1593 |
{ name = "datasets" },
|
|
|
|
| 1606 |
{ name = "transformers" },
|
| 1607 |
{ name = "umap-learn" },
|
| 1608 |
]
|
| 1609 |
+
sdist = { url = "https://files.pythonhosted.org/packages/42/f5/57836026dc1b8c716ff6e443ba3cc8fafef108078e52f872c101f66ab61c/persona_vectors-0.6.3.tar.gz", hash = "sha256:2389aaa4ab5e83c4541556a000e0268ad3f1f2d5e741ade9830cb3da972332c5", size = 24509, upload-time = "2026-05-08T14:10:37.09Z" }
|
| 1610 |
wheels = [
|
| 1611 |
+
{ url = "https://files.pythonhosted.org/packages/3c/92/912d2a6998bcc103631597125bad5b5644c981b52e62fff229aee64139ae/persona_vectors-0.6.3-py3-none-any.whl", hash = "sha256:9a7f275c7e58990e1228a0d35ca2a8898eb8330fd4a9a627fb28fc574883d260", size = 29366, upload-time = "2026-05-08T14:10:38.184Z" },
|
| 1612 |
]
|
| 1613 |
|
| 1614 |
[[package]]
|
|
|
|
| 2912 |
|
| 2913 |
[[package]]
|
| 2914 |
name = "urllib3"
|
| 2915 |
+
version = "2.7.0"
|
| 2916 |
source = { registry = "https://pypi.org/simple" }
|
| 2917 |
+
sdist = { url = "https://files.pythonhosted.org/packages/53/0c/06f8b233b8fd13b9e5ee11424ef85419ba0d8ba0b3138bf360be2ff56953/urllib3-2.7.0.tar.gz", hash = "sha256:231e0ec3b63ceb14667c67be60f2f2c40a518cb38b03af60abc813da26505f4c", size = 433602, upload-time = "2026-05-07T16:13:18.596Z" }
|
| 2918 |
wheels = [
|
| 2919 |
+
{ url = "https://files.pythonhosted.org/packages/7f/3e/5db95bcf282c52709639744ca2a8b149baccf648e39c8cc87553df9eae0c/urllib3-2.7.0-py3-none-any.whl", hash = "sha256:9fb4c81ebbb1ce9531cce37674bbc6f1360472bc18ca9a553ede278ef7276897", size = 131087, upload-time = "2026-05-07T16:13:17.151Z" },
|
| 2920 |
]
|
| 2921 |
|
| 2922 |
[[package]]
|