Jac-Zac committed on
Commit ·
ae347c6
1
Parent(s): b279884
add session-scoped NDIF execution and improve cold-load UX
Browse files
- bind remote NDIF backends to per-session API keys instead of
process-global state
- route chat, compare-chat, probe tracing, contrast scoring, and
extraction through explicit remote backends
- reuse upstream extraction via backend_factory with persona-vectors
- add Hugging Face cold-load notices for datasets and vector stores
- improve NDIF model discovery resilience and refresh sidebar key
handling
- update docs and env guidance for per-session NDIF keys and current
dependency setup
- improve assorted UI details and add dataset loading
information
- add NDIF status information when performing generation
- .env.example +3 -3
- README.md +11 -9
- app.py +5 -7
- pyproject.toml +2 -1
- tabs/chat.py +2 -1
- tabs/chat_shared.py +2 -0
- tabs/compare_chat.py +5 -2
- tabs/extract.py +14 -2
- tabs/probe_ui.py +2 -1
- tests/test_analysis_sources.py +57 -0
- tests/test_datasets.py +17 -4
- tests/test_runtime_session_ndif.py +75 -0
- utils/analysis_sources.py +44 -13
- utils/chat.py +7 -32
- utils/contrast.py +12 -3
- utils/datasets.py +3 -1
- utils/probe_trace.py +8 -1
- utils/runtime.py +53 -5
- uv.lock +4 -4
.env.example
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
# Copy this file to .env and fill in the values.
|
| 2 |
|
| 3 |
-
# NDIF API key for remote nnsight execution
|
| 4 |
-
#
|
| 5 |
-
# Get
|
| 6 |
NDIF_API_KEY=your-ndif-api-key-here
|
| 7 |
|
| 8 |
# HuggingFace model cache directory
|
|
|
|
| 1 |
# Copy this file to .env and fill in the values.
|
| 2 |
|
| 3 |
+
# Optional app-level NDIF API key for remote nnsight execution.
|
| 4 |
+
# If omitted, users can enter their own per-session key in the sidebar.
|
| 5 |
+
# Get one at https://login.ndif.us
|
| 6 |
NDIF_API_KEY=your-ndif-api-key-here
|
| 7 |
|
| 8 |
# HuggingFace model cache directory
|
README.md
CHANGED
|
@@ -63,11 +63,9 @@ cp .env.example .env
|
|
| 63 |
|
| 64 |
## Local Development
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
`persona-data` can also be checked out next to this repo for local package work.
|
| 71 |
|
| 72 |
Example:
|
| 73 |
|
|
@@ -97,13 +95,15 @@ This app can be deployed to Hugging Face Spaces using Docker.
|
|
| 97 |
|
| 98 |
### Prerequisites
|
| 99 |
|
| 100 |
-
|
|
|
|
|
|
|
| 101 |
|
| 102 |
### Build Locally (Optional)
|
| 103 |
|
| 104 |
```bash
|
| 105 |
docker build -t persona-ui .
|
| 106 |
-
#
|
| 107 |
docker run --env-file .env --rm -p 8501:8501 persona-ui
|
| 108 |
```
|
| 109 |
|
|
@@ -112,7 +112,7 @@ docker run --env-file .env --rm -p 8501:8501 persona-ui
|
|
| 112 |
Copy `.env.example` to `.env` and fill in:
|
| 113 |
|
| 114 |
```bash
|
| 115 |
-
NDIF_API_KEY=... #
|
| 116 |
HF_HOME=... # Optional: HuggingFace cache directory
|
| 117 |
ARTIFACTS_DIR=... # Optional: where persona vectors are read from (default: ./artifacts)
|
| 118 |
PERSONA_VECTORS_HUB_REPO=... # Optional: default Analysis/Probing Hub dataset repo
|
|
@@ -122,7 +122,9 @@ PERSONA_UI_FIGURE_STATE_ENTRIES=2 # Optional: recent rendered Analysis figur
|
|
| 122 |
PERSONA_UI_PREPARED_STATE_ENTRIES=4 # Optional: recent projection-ready markers kept in-session
|
| 123 |
```
|
| 124 |
|
| 125 |
-
The app picks up this file automatically via `load_dotenv()` on startup.
|
|
|
|
|
|
|
| 126 |
|
| 127 |
## Persona Vectors
|
| 128 |
|
|
|
|
| 63 |
|
| 64 |
## Local Development
|
| 65 |
|
| 66 |
+
The checked-in dependency config uses published packages. For local package
|
| 67 |
+
work, uncomment the `tool.uv.sources` block in `pyproject.toml` and keep sibling
|
| 68 |
+
checkouts next to this repo.
|
|
|
|
|
|
|
| 69 |
|
| 70 |
Example:
|
| 71 |
|
|
|
|
| 95 |
|
| 96 |
### Prerequisites
|
| 97 |
|
| 98 |
+
Dependencies are published on PyPI, so deployment does not require sibling
|
| 99 |
+
checkouts. Remote NDIF execution still needs an API key, either configured as an
|
| 100 |
+
environment variable or entered by each user in the sidebar.
|
| 101 |
|
| 102 |
### Build Locally (Optional)
|
| 103 |
|
| 104 |
```bash
|
| 105 |
docker build -t persona-ui .
|
| 106 |
+
# Pass your local .env if you want the container to use the same configuration
|
| 107 |
docker run --env-file .env --rm -p 8501:8501 persona-ui
|
| 108 |
```
|
| 109 |
|
|
|
|
| 112 |
Copy `.env.example` to `.env` and fill in:
|
| 113 |
|
| 114 |
```bash
|
| 115 |
+
NDIF_API_KEY=... # Optional shared NDIF key; users can also enter one per session
|
| 116 |
HF_HOME=... # Optional: HuggingFace cache directory
|
| 117 |
ARTIFACTS_DIR=... # Optional: where persona vectors are read from (default: ./artifacts)
|
| 118 |
PERSONA_VECTORS_HUB_REPO=... # Optional: default Analysis/Probing Hub dataset repo
|
|
|
|
| 122 |
PERSONA_UI_PREPARED_STATE_ENTRIES=4 # Optional: recent projection-ready markers kept in-session
|
| 123 |
```
|
| 124 |
|
| 125 |
+
The app picks up this file automatically via `load_dotenv()` on startup. If
|
| 126 |
+
`NDIF_API_KEY` is unset, Chat and Extract users are prompted for a per-session
|
| 127 |
+
key when they need remote execution.
|
| 128 |
|
| 129 |
## Persona Vectors
|
| 130 |
|
app.py
CHANGED
|
@@ -7,7 +7,7 @@ from dotenv import load_dotenv
|
|
| 7 |
from utils.analysis_sources import DEFAULT_COMPARE_MODEL, DEFAULT_HUB_REPO, SOURCE_HUB
|
| 8 |
from utils.helpers import DATASET_SOURCES, session_key, widget_key
|
| 9 |
from utils.preload import preload_once
|
| 10 |
-
from utils.runtime import list_remote_models
|
| 11 |
from utils.theme import active_base, install_catppuccin_theme
|
| 12 |
|
| 13 |
load_dotenv()
|
|
@@ -181,10 +181,10 @@ def _remote_model_input(remote_models: list[str]) -> str:
|
|
| 181 |
|
| 182 |
|
| 183 |
def _ndif_api_key_input() -> None:
|
| 184 |
-
"""Prompt for
|
| 185 |
-
import nnsight
|
| 186 |
|
| 187 |
-
if
|
|
|
|
| 188 |
return
|
| 189 |
|
| 190 |
api_key = st.text_input(
|
|
@@ -193,9 +193,7 @@ def _ndif_api_key_input() -> None:
|
|
| 193 |
key=_SIDEBAR_NDIF_API_KEY,
|
| 194 |
help=f"Required for remote (NDIF) execution. Register at {NDIF_REGISTRATION_URL}",
|
| 195 |
)
|
| 196 |
-
if api_key:
|
| 197 |
-
nnsight.CONFIG.API.APIKEY = api_key
|
| 198 |
-
else:
|
| 199 |
st.caption(f"No NDIF API key found. [Get one]({NDIF_REGISTRATION_URL}).")
|
| 200 |
|
| 201 |
|
|
|
|
| 7 |
from utils.analysis_sources import DEFAULT_COMPARE_MODEL, DEFAULT_HUB_REPO, SOURCE_HUB
|
| 8 |
from utils.helpers import DATASET_SOURCES, session_key, widget_key
|
| 9 |
from utils.preload import preload_once
|
| 10 |
+
from utils.runtime import configured_ndif_api_key, list_remote_models
|
| 11 |
from utils.theme import active_base, install_catppuccin_theme
|
| 12 |
|
| 13 |
load_dotenv()
|
|
|
|
| 181 |
|
| 182 |
|
| 183 |
def _ndif_api_key_input() -> None:
|
| 184 |
+
"""Prompt for a per-session NDIF API key."""
|
|
|
|
| 185 |
|
| 186 |
+
if configured_ndif_api_key():
|
| 187 |
+
st.caption("Using NDIF API key from environment.")
|
| 188 |
return
|
| 189 |
|
| 190 |
api_key = st.text_input(
|
|
|
|
| 193 |
key=_SIDEBAR_NDIF_API_KEY,
|
| 194 |
help=f"Required for remote (NDIF) execution. Register at {NDIF_REGISTRATION_URL}",
|
| 195 |
)
|
| 196 |
+
if not api_key:
|
|
|
|
|
|
|
| 197 |
st.caption(f"No NDIF API key found. [Get one]({NDIF_REGISTRATION_URL}).")
|
| 198 |
|
| 199 |
|
pyproject.toml
CHANGED
|
@@ -5,7 +5,7 @@ description = "Streamlit UI for persona-vectors"
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
-
"persona-vectors>=0.8.
|
| 9 |
"datasets>=4.8.5",
|
| 10 |
"huggingface-hub>=1.14.0",
|
| 11 |
"streamlit>=1.44.0",
|
|
@@ -22,6 +22,7 @@ dev = [
|
|
| 22 |
|
| 23 |
[tool.pytest.ini_options]
|
| 24 |
testpaths = ["tests"]
|
|
|
|
| 25 |
|
| 26 |
# Local development:
|
| 27 |
# [tool.uv.sources]
|
|
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
+
"persona-vectors>=0.8.4",
|
| 9 |
"datasets>=4.8.5",
|
| 10 |
"huggingface-hub>=1.14.0",
|
| 11 |
"streamlit>=1.44.0",
|
|
|
|
| 22 |
|
| 23 |
[tool.pytest.ini_options]
|
| 24 |
testpaths = ["tests"]
|
| 25 |
+
pythonpath = ["."]
|
| 26 |
|
| 27 |
# Local development:
|
| 28 |
# [tool.uv.sources]
|
tabs/chat.py
CHANGED
|
@@ -28,7 +28,7 @@ from tabs.chat_ui import (
|
|
| 28 |
from utils.chat import build_chat_messages, resolve_system_prompt
|
| 29 |
from utils.chat_export import save_chat_export
|
| 30 |
from utils.helpers import format_ndif_status, session_key, widget_key
|
| 31 |
-
from utils.runtime import cached_model
|
| 32 |
|
| 33 |
if TYPE_CHECKING:
|
| 34 |
from persona_data.synth_persona import PersonaData
|
|
@@ -129,6 +129,7 @@ def _handle_single_chat_generation(
|
|
| 129 |
generation=generation,
|
| 130 |
on_status=_show_ndif_status if remote else None,
|
| 131 |
on_error=_show_error,
|
|
|
|
| 132 |
)
|
| 133 |
if error is not None:
|
| 134 |
status_box.empty()
|
|
|
|
| 28 |
from utils.chat import build_chat_messages, resolve_system_prompt
|
| 29 |
from utils.chat_export import save_chat_export
|
| 30 |
from utils.helpers import format_ndif_status, session_key, widget_key
|
| 31 |
+
from utils.runtime import cached_model, session_ndif_api_key
|
| 32 |
|
| 33 |
if TYPE_CHECKING:
|
| 34 |
from persona_data.synth_persona import PersonaData
|
|
|
|
| 129 |
generation=generation,
|
| 130 |
on_status=_show_ndif_status if remote else None,
|
| 131 |
on_error=_show_error,
|
| 132 |
+
ndif_api_key=session_ndif_api_key(),
|
| 133 |
)
|
| 134 |
if error is not None:
|
| 135 |
status_box.empty()
|
tabs/chat_shared.py
CHANGED
|
@@ -109,6 +109,7 @@ def generate_chat_reply_result(
|
|
| 109 |
generation: GenerationConfig,
|
| 110 |
on_status: Callable[[str, str, str], None] | None = None,
|
| 111 |
on_error: Callable[[Exception], None] | None = None,
|
|
|
|
| 112 |
) -> tuple[ChatReply | None, Exception | None]:
|
| 113 |
try:
|
| 114 |
return (
|
|
@@ -117,6 +118,7 @@ def generate_chat_reply_result(
|
|
| 117 |
messages=messages,
|
| 118 |
remote=remote,
|
| 119 |
on_status=on_status,
|
|
|
|
| 120 |
**generation.to_generate_kwargs(),
|
| 121 |
),
|
| 122 |
None,
|
|
|
|
| 109 |
generation: GenerationConfig,
|
| 110 |
on_status: Callable[[str, str, str], None] | None = None,
|
| 111 |
on_error: Callable[[Exception], None] | None = None,
|
| 112 |
+
ndif_api_key: str | None = None,
|
| 113 |
) -> tuple[ChatReply | None, Exception | None]:
|
| 114 |
try:
|
| 115 |
return (
|
|
|
|
| 118 |
messages=messages,
|
| 119 |
remote=remote,
|
| 120 |
on_status=on_status,
|
| 121 |
+
ndif_api_key=ndif_api_key,
|
| 122 |
**generation.to_generate_kwargs(),
|
| 123 |
),
|
| 124 |
None,
|
tabs/compare_chat.py
CHANGED
|
@@ -15,7 +15,7 @@ from utils.chat import ChatReply, build_chat_messages, resolve_system_prompt
|
|
| 15 |
from utils.chat_export import save_chat_export
|
| 16 |
from utils.contrast import compute_contrast, compute_contrast_pair
|
| 17 |
from utils.helpers import format_ndif_status, persona_label, session_key, widget_key
|
| 18 |
-
from utils.runtime import cached_model
|
| 19 |
|
| 20 |
from .chat_ui import (
|
| 21 |
GenerationConfig,
|
|
@@ -173,6 +173,7 @@ def _generate_panels(
|
|
| 173 |
remote=remote,
|
| 174 |
generation=generation,
|
| 175 |
on_status=_show_ndif_status if remote else None,
|
|
|
|
| 176 |
)
|
| 177 |
results.append(reply if error is None else error)
|
| 178 |
status_box.empty()
|
|
@@ -254,6 +255,7 @@ def _recompute_pending_contrast(
|
|
| 254 |
label_a=label_a,
|
| 255 |
label_b=label_b,
|
| 256 |
remote=remote,
|
|
|
|
| 257 |
)
|
| 258 |
if contrast is not None:
|
| 259 |
msg["_contrast"] = contrast
|
|
@@ -295,7 +297,7 @@ def _render_compare_footer(
|
|
| 295 |
|
| 296 |
footer = st.container()
|
| 297 |
with footer:
|
| 298 |
-
exp_col, rst_col, _spacer = st.columns([1, 1.25,
|
| 299 |
with exp_col:
|
| 300 |
if st.button(
|
| 301 |
"",
|
|
@@ -379,6 +381,7 @@ def _compute_new_reply_contrast(
|
|
| 379 |
label_a=persona_label(left.persona),
|
| 380 |
label_b=persona_label(right.persona),
|
| 381 |
remote=remote,
|
|
|
|
| 382 |
)
|
| 383 |
if left_contrast is not None:
|
| 384 |
left.state["messages"][-1]["_contrast"] = left_contrast
|
|
|
|
| 15 |
from utils.chat_export import save_chat_export
|
| 16 |
from utils.contrast import compute_contrast, compute_contrast_pair
|
| 17 |
from utils.helpers import format_ndif_status, persona_label, session_key, widget_key
|
| 18 |
+
from utils.runtime import cached_model, session_ndif_api_key
|
| 19 |
|
| 20 |
from .chat_ui import (
|
| 21 |
GenerationConfig,
|
|
|
|
| 173 |
remote=remote,
|
| 174 |
generation=generation,
|
| 175 |
on_status=_show_ndif_status if remote else None,
|
| 176 |
+
ndif_api_key=session_ndif_api_key(),
|
| 177 |
)
|
| 178 |
results.append(reply if error is None else error)
|
| 179 |
status_box.empty()
|
|
|
|
| 255 |
label_a=label_a,
|
| 256 |
label_b=label_b,
|
| 257 |
remote=remote,
|
| 258 |
+
ndif_api_key=session_ndif_api_key(),
|
| 259 |
)
|
| 260 |
if contrast is not None:
|
| 261 |
msg["_contrast"] = contrast
|
|
|
|
| 297 |
|
| 298 |
footer = st.container()
|
| 299 |
with footer:
|
| 300 |
+
exp_col, rst_col, _spacer = st.columns([1, 1.25, 20], gap="xsmall")
|
| 301 |
with exp_col:
|
| 302 |
if st.button(
|
| 303 |
"",
|
|
|
|
| 381 |
label_a=persona_label(left.persona),
|
| 382 |
label_b=persona_label(right.persona),
|
| 383 |
remote=remote,
|
| 384 |
+
ndif_api_key=session_ndif_api_key(),
|
| 385 |
)
|
| 386 |
if left_contrast is not None:
|
| 387 |
left.state["messages"][-1]["_contrast"] = left_contrast
|
tabs/extract.py
CHANGED
|
@@ -26,7 +26,7 @@ from utils.helpers import (
|
|
| 26 |
session_key,
|
| 27 |
widget_key,
|
| 28 |
)
|
| 29 |
-
from utils.runtime import cached_model
|
| 30 |
from utils.theme import active_base
|
| 31 |
|
| 32 |
_LAST_VARIANTS_KEY = "extract:last_variants"
|
|
@@ -366,16 +366,28 @@ def _run_extraction_plan(
|
|
| 366 |
step / total_steps if total_steps else 1.0,
|
| 367 |
text=f"{_row_label(persona, variant)} ({step + 1}/{total_steps})",
|
| 368 |
)
|
|
|
|
| 369 |
results.extend(
|
| 370 |
run_extraction(
|
| 371 |
model=model,
|
| 372 |
model_name=model_name,
|
| 373 |
-
qa_pairs=
|
| 374 |
variants=(variant,),
|
| 375 |
persona=persona,
|
| 376 |
mask_strategy=settings.mask_strategy,
|
| 377 |
remote=remote,
|
| 378 |
on_status=_on_ndif_status if remote else None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
)
|
| 380 |
)
|
| 381 |
|
|
|
|
| 26 |
session_key,
|
| 27 |
widget_key,
|
| 28 |
)
|
| 29 |
+
from utils.runtime import cached_model, remote_backend, session_ndif_api_key
|
| 30 |
from utils.theme import active_base
|
| 31 |
|
| 32 |
_LAST_VARIANTS_KEY = "extract:last_variants"
|
|
|
|
| 366 |
step / total_steps if total_steps else 1.0,
|
| 367 |
text=f"{_row_label(persona, variant)} ({step + 1}/{total_steps})",
|
| 368 |
)
|
| 369 |
+
selected_qa = qa_pairs[: settings.max_questions]
|
| 370 |
results.extend(
|
| 371 |
run_extraction(
|
| 372 |
model=model,
|
| 373 |
model_name=model_name,
|
| 374 |
+
qa_pairs=selected_qa,
|
| 375 |
variants=(variant,),
|
| 376 |
persona=persona,
|
| 377 |
mask_strategy=settings.mask_strategy,
|
| 378 |
remote=remote,
|
| 379 |
on_status=_on_ndif_status if remote else None,
|
| 380 |
+
backend_factory=(
|
| 381 |
+
(
|
| 382 |
+
lambda: remote_backend(
|
| 383 |
+
model,
|
| 384 |
+
session_ndif_api_key(),
|
| 385 |
+
on_status=_on_ndif_status,
|
| 386 |
+
)
|
| 387 |
+
)
|
| 388 |
+
if remote
|
| 389 |
+
else None
|
| 390 |
+
),
|
| 391 |
)
|
| 392 |
)
|
| 393 |
|
tabs/probe_ui.py
CHANGED
|
@@ -28,7 +28,7 @@ from utils.probes import (
|
|
| 28 |
load_probe,
|
| 29 |
load_probe_from_bytes,
|
| 30 |
)
|
| 31 |
-
from utils.runtime import cached_model
|
| 32 |
from utils.selection_controls import remembered_segmented_control
|
| 33 |
|
| 34 |
_LAST_SOURCE_KEY = session_key("probe", "last_source")
|
|
@@ -428,6 +428,7 @@ def render_probe_inspector(
|
|
| 428 |
layer=layer,
|
| 429 |
location=location,
|
| 430 |
remote=remote,
|
|
|
|
| 431 |
)
|
| 432 |
except Exception as exc:
|
| 433 |
_reset()
|
|
|
|
| 28 |
load_probe,
|
| 29 |
load_probe_from_bytes,
|
| 30 |
)
|
| 31 |
+
from utils.runtime import cached_model, session_ndif_api_key
|
| 32 |
from utils.selection_controls import remembered_segmented_control
|
| 33 |
|
| 34 |
_LAST_SOURCE_KEY = session_key("probe", "last_source")
|
|
|
|
| 428 |
layer=layer,
|
| 429 |
location=location,
|
| 430 |
remote=remote,
|
| 431 |
+
ndif_api_key=session_ndif_api_key(),
|
| 432 |
)
|
| 433 |
except Exception as exc:
|
| 434 |
_reset()
|
tests/test_analysis_sources.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from utils import analysis_sources
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class _Notice:
|
| 7 |
+
def __init__(self) -> None:
|
| 8 |
+
self.messages: list[str] = []
|
| 9 |
+
self.empty_calls = 0
|
| 10 |
+
|
| 11 |
+
def warning(self, message: str) -> None:
|
| 12 |
+
self.messages.append(message)
|
| 13 |
+
|
| 14 |
+
def empty(self) -> None:
|
| 15 |
+
self.empty_calls += 1
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def test_hub_vector_notice_is_transient_for_unopened_variants(monkeypatch):
|
| 19 |
+
notice = _Notice()
|
| 20 |
+
|
| 21 |
+
class DummyHubStore:
|
| 22 |
+
_datasets = {"templated": object()}
|
| 23 |
+
|
| 24 |
+
monkeypatch.setattr(
|
| 25 |
+
analysis_sources,
|
| 26 |
+
"HFPersonaVectorStore",
|
| 27 |
+
DummyHubStore,
|
| 28 |
+
)
|
| 29 |
+
monkeypatch.setattr(analysis_sources.st, "empty", lambda: notice)
|
| 30 |
+
|
| 31 |
+
with analysis_sources._hub_vector_notice(
|
| 32 |
+
DummyHubStore(), ("templated", "biography")
|
| 33 |
+
):
|
| 34 |
+
pass
|
| 35 |
+
|
| 36 |
+
assert notice.messages
|
| 37 |
+
assert "persona vectors from Hugging Face" in notice.messages[0]
|
| 38 |
+
assert notice.empty_calls == 1
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_hub_vector_notice_stays_quiet_when_variants_are_open(monkeypatch):
|
| 42 |
+
class DummyHubStore:
|
| 43 |
+
_datasets = {"templated": object()}
|
| 44 |
+
|
| 45 |
+
monkeypatch.setattr(
|
| 46 |
+
analysis_sources,
|
| 47 |
+
"HFPersonaVectorStore",
|
| 48 |
+
DummyHubStore,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
called = []
|
| 52 |
+
monkeypatch.setattr(analysis_sources.st, "empty", lambda: called.append(True))
|
| 53 |
+
|
| 54 |
+
with analysis_sources._hub_vector_notice(DummyHubStore(), ("templated",)):
|
| 55 |
+
pass
|
| 56 |
+
|
| 57 |
+
assert called == []
|
tests/test_datasets.py
CHANGED
|
@@ -11,8 +11,20 @@ class _Progress:
|
|
| 11 |
self.updates.append((value, text))
|
| 12 |
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def test_download_missing_startup_files_only_fetches_uncached_files(monkeypatch):
|
| 15 |
-
|
| 16 |
progress = _Progress()
|
| 17 |
downloads: list[tuple[str, str, str]] = []
|
| 18 |
|
|
@@ -21,7 +33,7 @@ def test_download_missing_startup_files_only_fetches_uncached_files(monkeypatch)
|
|
| 21 |
"_is_cached",
|
| 22 |
lambda _repo, filename: filename == "already.jsonl",
|
| 23 |
)
|
| 24 |
-
monkeypatch.setattr(datasets.st, "
|
| 25 |
monkeypatch.setattr(
|
| 26 |
datasets.st,
|
| 27 |
"progress",
|
|
@@ -41,7 +53,8 @@ def test_download_missing_startup_files_only_fetches_uncached_files(monkeypatch)
|
|
| 41 |
"Example",
|
| 42 |
)
|
| 43 |
|
| 44 |
-
assert
|
|
|
|
| 45 |
assert downloads == [("org/repo", "missing.jsonl", "dataset")]
|
| 46 |
assert progress.updates[-1] == (1.0, "Downloaded missing.jsonl (1/1)")
|
| 47 |
|
|
@@ -52,7 +65,7 @@ def test_download_missing_startup_files_stays_quiet_when_cached(monkeypatch):
|
|
| 52 |
def unexpected(*_args, **_kwargs):
|
| 53 |
raise AssertionError("cold-download UI should not render for warm cache")
|
| 54 |
|
| 55 |
-
monkeypatch.setattr(datasets.st, "
|
| 56 |
monkeypatch.setattr(datasets.st, "progress", unexpected)
|
| 57 |
monkeypatch.setattr(datasets, "hf_hub_download", unexpected)
|
| 58 |
|
|
|
|
| 11 |
self.updates.append((value, text))
|
| 12 |
|
| 13 |
|
| 14 |
+
class _Notice:
|
| 15 |
+
def __init__(self) -> None:
|
| 16 |
+
self.messages: list[str] = []
|
| 17 |
+
self.empty_calls = 0
|
| 18 |
+
|
| 19 |
+
def warning(self, message: str) -> None:
|
| 20 |
+
self.messages.append(message)
|
| 21 |
+
|
| 22 |
+
def empty(self) -> None:
|
| 23 |
+
self.empty_calls += 1
|
| 24 |
+
|
| 25 |
+
|
| 26 |
def test_download_missing_startup_files_only_fetches_uncached_files(monkeypatch):
|
| 27 |
+
notice = _Notice()
|
| 28 |
progress = _Progress()
|
| 29 |
downloads: list[tuple[str, str, str]] = []
|
| 30 |
|
|
|
|
| 33 |
"_is_cached",
|
| 34 |
lambda _repo, filename: filename == "already.jsonl",
|
| 35 |
)
|
| 36 |
+
monkeypatch.setattr(datasets.st, "empty", lambda: notice)
|
| 37 |
monkeypatch.setattr(
|
| 38 |
datasets.st,
|
| 39 |
"progress",
|
|
|
|
| 53 |
"Example",
|
| 54 |
)
|
| 55 |
|
| 56 |
+
assert notice.messages and "First-time setup for Example" in notice.messages[0]
|
| 57 |
+
assert notice.empty_calls == 1
|
| 58 |
assert downloads == [("org/repo", "missing.jsonl", "dataset")]
|
| 59 |
assert progress.updates[-1] == (1.0, "Downloaded missing.jsonl (1/1)")
|
| 60 |
|
|
|
|
| 65 |
def unexpected(*_args, **_kwargs):
|
| 66 |
raise AssertionError("cold-download UI should not render for warm cache")
|
| 67 |
|
| 68 |
+
monkeypatch.setattr(datasets.st, "empty", unexpected)
|
| 69 |
monkeypatch.setattr(datasets.st, "progress", unexpected)
|
| 70 |
monkeypatch.setattr(datasets, "hf_hub_download", unexpected)
|
| 71 |
|
tests/test_runtime_session_ndif.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from utils import runtime
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def test_session_ndif_api_key_is_read_from_current_session(monkeypatch) -> None:
|
| 7 |
+
monkeypatch.setattr(
|
| 8 |
+
runtime.st,
|
| 9 |
+
"session_state",
|
| 10 |
+
{"sidebar:ndif_api_key": "user-a-key"},
|
| 11 |
+
)
|
| 12 |
+
assert runtime.session_ndif_api_key() == "user-a-key"
|
| 13 |
+
|
| 14 |
+
monkeypatch.setattr(
|
| 15 |
+
runtime.st,
|
| 16 |
+
"session_state",
|
| 17 |
+
{"sidebar:ndif_api_key": "user-b-key"},
|
| 18 |
+
)
|
| 19 |
+
assert runtime.session_ndif_api_key() == "user-b-key"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def test_configured_ndif_api_key_reads_environment(monkeypatch) -> None:
|
| 23 |
+
monkeypatch.setenv("NDIF_API_KEY", "env-key")
|
| 24 |
+
assert runtime.configured_ndif_api_key() == "env-key"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_remote_backend_binds_explicit_session_key(monkeypatch) -> None:
|
| 28 |
+
from nnsight.intervention.backends import remote
|
| 29 |
+
|
| 30 |
+
seen: list[str | None] = []
|
| 31 |
+
|
| 32 |
+
class FakeBackend:
|
| 33 |
+
def __init__(self, model_key: str, api_key: str | None = None) -> None:
|
| 34 |
+
self.model_key = model_key
|
| 35 |
+
self.api_key = api_key
|
| 36 |
+
self.verbose = False
|
| 37 |
+
seen.append(api_key)
|
| 38 |
+
|
| 39 |
+
class FakeModel:
|
| 40 |
+
def to_model_key(self) -> str:
|
| 41 |
+
return "model-key"
|
| 42 |
+
|
| 43 |
+
monkeypatch.setattr(remote, "RemoteBackend", FakeBackend)
|
| 44 |
+
monkeypatch.setattr(
|
| 45 |
+
runtime.st,
|
| 46 |
+
"session_state",
|
| 47 |
+
{"sidebar:ndif_api_key": "ambient-session-key"},
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
backend = runtime.remote_backend(FakeModel(), "explicit-user-key")
|
| 51 |
+
|
| 52 |
+
assert backend.api_key == "explicit-user-key"
|
| 53 |
+
assert seen == ["explicit-user-key"]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_remote_backend_falls_back_to_environment_key(monkeypatch) -> None:
|
| 57 |
+
from nnsight.intervention.backends import remote
|
| 58 |
+
|
| 59 |
+
class FakeBackend:
|
| 60 |
+
def __init__(self, model_key: str, api_key: str | None = None) -> None:
|
| 61 |
+
self.model_key = model_key
|
| 62 |
+
self.api_key = api_key
|
| 63 |
+
self.verbose = False
|
| 64 |
+
|
| 65 |
+
class FakeModel:
|
| 66 |
+
def to_model_key(self) -> str:
|
| 67 |
+
return "model-key"
|
| 68 |
+
|
| 69 |
+
monkeypatch.setattr(remote, "RemoteBackend", FakeBackend)
|
| 70 |
+
monkeypatch.setattr(runtime.st, "session_state", {})
|
| 71 |
+
monkeypatch.setenv("NDIF_API_KEY", "env-key")
|
| 72 |
+
|
| 73 |
+
backend = runtime.remote_backend(FakeModel())
|
| 74 |
+
|
| 75 |
+
assert backend.api_key == "env-key"
|
utils/analysis_sources.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
from persona_vectors.analysis import (
|
|
@@ -39,6 +40,34 @@ _VECTOR_CACHE_ENTRIES = env_int("PERSONA_UI_VECTOR_CACHE_ENTRIES", 4)
|
|
| 39 |
_PREPARED_CACHE_ENTRIES = env_int("PERSONA_UI_PREPARED_CACHE_ENTRIES", 8)
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
@st.cache_resource(show_spinner=False, max_entries=_STORE_CACHE_ENTRIES)
|
| 43 |
def activation_store_cached(
|
| 44 |
source: str,
|
|
@@ -74,9 +103,9 @@ def personas_cached(
|
|
| 74 |
*,
|
| 75 |
include_baseline: bool = False,
|
| 76 |
) -> list[str]:
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
|
| 81 |
|
| 82 |
@st.cache_data(show_spinner=False)
|
|
@@ -89,7 +118,8 @@ def persona_names_cached(
|
|
| 89 |
persona_ids: tuple[str, ...],
|
| 90 |
) -> dict[str, str]:
|
| 91 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 92 |
-
|
|
|
|
| 93 |
# Preserve input order, fall back to the id when the row has no display name.
|
| 94 |
return {pid: names.get(pid, pid) for pid in persona_ids}
|
| 95 |
|
|
@@ -103,9 +133,9 @@ def store_layers_cached(
|
|
| 103 |
variants: tuple[str, ...],
|
| 104 |
persona_ids: tuple[str, ...],
|
| 105 |
) -> list[int]:
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
|
| 110 |
|
| 111 |
@st.cache_data(show_spinner=False)
|
|
@@ -156,12 +186,13 @@ def load_analysis_dataset_cached(
|
|
| 156 |
persona_ids: tuple[str, ...],
|
| 157 |
) -> AnalysisDataset:
|
| 158 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
| 165 |
|
| 166 |
|
| 167 |
def load_persona_vectors_cached(
|
|
|
|
| 1 |
import os
|
| 2 |
+
from contextlib import contextmanager
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
from persona_vectors.analysis import (
|
|
|
|
| 40 |
_PREPARED_CACHE_ENTRIES = env_int("PERSONA_UI_PREPARED_CACHE_ENTRIES", 8)
|
| 41 |
|
| 42 |
|
| 43 |
+
def _hub_variants_pending(store: Store, variants: tuple[str, ...]) -> tuple[str, ...]:
|
| 44 |
+
"""Return Hub variants that have not yet been opened by this store instance."""
|
| 45 |
+
|
| 46 |
+
if not isinstance(store, HFPersonaVectorStore):
|
| 47 |
+
return ()
|
| 48 |
+
return tuple(variant for variant in variants if variant not in store._datasets)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@contextmanager
|
| 52 |
+
def _hub_vector_notice(store: Store, variants: tuple[str, ...]):
|
| 53 |
+
"""Show a transient, honest cold-load note for Hub-backed vector data."""
|
| 54 |
+
|
| 55 |
+
pending = _hub_variants_pending(store, variants)
|
| 56 |
+
if not pending:
|
| 57 |
+
yield
|
| 58 |
+
return
|
| 59 |
+
|
| 60 |
+
notice = st.empty()
|
| 61 |
+
notice.warning(
|
| 62 |
+
"Loading persona vectors from Hugging Face. "
|
| 63 |
+
"On a cold cache, this may download Hub dataset files."
|
| 64 |
+
)
|
| 65 |
+
try:
|
| 66 |
+
yield
|
| 67 |
+
finally:
|
| 68 |
+
notice.empty()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
@st.cache_resource(show_spinner=False, max_entries=_STORE_CACHE_ENTRIES)
|
| 72 |
def activation_store_cached(
|
| 73 |
source: str,
|
|
|
|
| 103 |
*,
|
| 104 |
include_baseline: bool = False,
|
| 105 |
) -> list[str]:
|
| 106 |
+
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 107 |
+
with _hub_vector_notice(store, variants):
|
| 108 |
+
return store.list_personas(list(variants), include_baseline=include_baseline)
|
| 109 |
|
| 110 |
|
| 111 |
@st.cache_data(show_spinner=False)
|
|
|
|
| 118 |
persona_ids: tuple[str, ...],
|
| 119 |
) -> dict[str, str]:
|
| 120 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 121 |
+
with _hub_vector_notice(store, variants):
|
| 122 |
+
names = store.persona_names(list(persona_ids), variants=list(variants))
|
| 123 |
# Preserve input order, fall back to the id when the row has no display name.
|
| 124 |
return {pid: names.get(pid, pid) for pid in persona_ids}
|
| 125 |
|
|
|
|
| 133 |
variants: tuple[str, ...],
|
| 134 |
persona_ids: tuple[str, ...],
|
| 135 |
) -> list[int]:
|
| 136 |
+
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 137 |
+
with _hub_vector_notice(store, variants):
|
| 138 |
+
return store.list_layers(list(variants), list(persona_ids))
|
| 139 |
|
| 140 |
|
| 141 |
@st.cache_data(show_spinner=False)
|
|
|
|
| 186 |
persona_ids: tuple[str, ...],
|
| 187 |
) -> AnalysisDataset:
|
| 188 |
store = activation_store_cached(source, location, model_name, mask_strategy_value)
|
| 189 |
+
with _hub_vector_notice(store, variants):
|
| 190 |
+
return load_analysis_dataset(
|
| 191 |
+
store,
|
| 192 |
+
variants,
|
| 193 |
+
mask_strategy=MaskStrategy(mask_strategy_value),
|
| 194 |
+
persona_ids=persona_ids,
|
| 195 |
+
)
|
| 196 |
|
| 197 |
|
| 198 |
def load_persona_vectors_cached(
|
utils/chat.py
CHANGED
|
@@ -187,6 +187,7 @@ def generate_chat_reply(
|
|
| 187 |
repetition_penalty: float = 1.0,
|
| 188 |
seed: int | None = None,
|
| 189 |
on_status: Callable[[str, str, str], None] | None = None,
|
|
|
|
| 190 |
) -> ChatReply:
|
| 191 |
"""Generate one assistant reply from a full chat history.
|
| 192 |
|
|
@@ -230,7 +231,12 @@ def generate_chat_reply(
|
|
| 230 |
generation_kwargs["repetition_penalty"] = repetition_penalty
|
| 231 |
# `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
|
| 232 |
# forwarded to the underlying model's generate
|
| 233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
with (
|
| 236 |
_seeded_rng(seed if do_sample and not remote else None),
|
|
@@ -256,34 +262,3 @@ def generate_chat_reply(
|
|
| 256 |
text=text,
|
| 257 |
generated_ids=generated_ids.detach().cpu(),
|
| 258 |
)
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
def _build_remote_backend(
|
| 262 |
-
model: StandardizedTransformer,
|
| 263 |
-
on_status: Callable[[str, str, str], None] | None,
|
| 264 |
-
):
|
| 265 |
-
"""Build an NDIF backend that can surface lifecycle updates to callers."""
|
| 266 |
-
|
| 267 |
-
if on_status is None:
|
| 268 |
-
return None
|
| 269 |
-
|
| 270 |
-
from nnsight.intervention.backends.remote import JobStatusDisplay, RemoteBackend
|
| 271 |
-
|
| 272 |
-
class _CallbackJobStatusDisplay(JobStatusDisplay):
|
| 273 |
-
def update(
|
| 274 |
-
self,
|
| 275 |
-
job_id: str = "",
|
| 276 |
-
status_name: str = "",
|
| 277 |
-
description: str = "",
|
| 278 |
-
):
|
| 279 |
-
super().update(job_id, status_name, description)
|
| 280 |
-
if status_name:
|
| 281 |
-
on_status(job_id, status_name, description)
|
| 282 |
-
|
| 283 |
-
backend = RemoteBackend(model.to_model_key())
|
| 284 |
-
backend.CONNECT_TIMEOUT = 300.0
|
| 285 |
-
backend.status_display = _CallbackJobStatusDisplay(
|
| 286 |
-
enabled=True,
|
| 287 |
-
verbose=backend.verbose,
|
| 288 |
-
)
|
| 289 |
-
return backend
|
|
|
|
| 187 |
repetition_penalty: float = 1.0,
|
| 188 |
seed: int | None = None,
|
| 189 |
on_status: Callable[[str, str, str], None] | None = None,
|
| 190 |
+
ndif_api_key: str | None = None,
|
| 191 |
) -> ChatReply:
|
| 192 |
"""Generate one assistant reply from a full chat history.
|
| 193 |
|
|
|
|
| 231 |
generation_kwargs["repetition_penalty"] = repetition_penalty
|
| 232 |
# `remote` is captured by nnsight's RemoteableMixin.trace() and is NOT
|
| 233 |
# forwarded to the underlying model's generate
|
| 234 |
+
if remote:
|
| 235 |
+
from utils.runtime import remote_backend
|
| 236 |
+
|
| 237 |
+
backend = remote_backend(model, ndif_api_key, on_status=on_status)
|
| 238 |
+
else:
|
| 239 |
+
backend = None
|
| 240 |
|
| 241 |
with (
|
| 242 |
_seeded_rng(seed if do_sample and not remote else None),
|
|
|
|
| 262 |
text=text,
|
| 263 |
generated_ids=generated_ids.detach().cpu(),
|
| 264 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils/contrast.py
CHANGED
|
@@ -100,6 +100,7 @@ def _score_passes(
|
|
| 100 |
model: StandardizedTransformer,
|
| 101 |
specs: list[PassSpec],
|
| 102 |
remote: bool,
|
|
|
|
| 103 |
) -> dict[str, torch.Tensor]:
|
| 104 |
"""
|
| 105 |
Run one forward pass per spec and return reduced per-token logprobs.
|
|
@@ -115,7 +116,13 @@ def _score_passes(
|
|
| 115 |
n_resp: int,
|
| 116 |
target_ids: torch.Tensor,
|
| 117 |
) -> torch.Tensor:
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
# logit at position i predicts token i+1, so response token j
|
| 120 |
# (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
|
| 121 |
resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
|
|
@@ -157,6 +164,7 @@ def compute_contrast(
|
|
| 157 |
label_a: str,
|
| 158 |
label_b: str,
|
| 159 |
remote: bool = False,
|
|
|
|
| 160 |
) -> "TokenContrast | None":
|
| 161 |
"""Compute per-token contrast weights for a single response (2 forward passes)."""
|
| 162 |
tokenizer = model.tokenizer
|
|
@@ -164,7 +172,7 @@ def compute_contrast(
|
|
| 164 |
return None
|
| 165 |
|
| 166 |
specs = _specs_for_response(tokenizer, response_ids, context_a, context_b, "r")
|
| 167 |
-
out = _score_passes(model, specs, remote)
|
| 168 |
return _build_contrast(
|
| 169 |
tokenizer, response_ids, out["r_under_a"], out["r_under_b"], label_a, label_b
|
| 170 |
)
|
|
@@ -179,6 +187,7 @@ def compute_contrast_pair(
|
|
| 179 |
label_a: str,
|
| 180 |
label_b: str,
|
| 181 |
remote: bool = False,
|
|
|
|
| 182 |
) -> tuple["TokenContrast | None", "TokenContrast | None"]:
|
| 183 |
"""
|
| 184 |
Compute contrast weights for both panel responses (up to 4 remote passes).
|
|
@@ -197,7 +206,7 @@ def compute_contrast_pair(
|
|
| 197 |
tokenizer, response_ids_b, context_a, context_b, "b"
|
| 198 |
)
|
| 199 |
|
| 200 |
-
out = _score_passes(model, specs, remote)
|
| 201 |
|
| 202 |
def _build(resp_ids: torch.Tensor, prefix: str) -> "TokenContrast | None":
|
| 203 |
k_a, k_b = f"{prefix}_under_a", f"{prefix}_under_b"
|
|
|
|
| 100 |
model: StandardizedTransformer,
|
| 101 |
specs: list[PassSpec],
|
| 102 |
remote: bool,
|
| 103 |
+
ndif_api_key: str | None = None,
|
| 104 |
) -> dict[str, torch.Tensor]:
|
| 105 |
"""
|
| 106 |
Run one forward pass per spec and return reduced per-token logprobs.
|
|
|
|
| 116 |
n_resp: int,
|
| 117 |
target_ids: torch.Tensor,
|
| 118 |
) -> torch.Tensor:
|
| 119 |
+
if remote:
|
| 120 |
+
from utils.runtime import remote_backend
|
| 121 |
+
|
| 122 |
+
backend = remote_backend(model, ndif_api_key)
|
| 123 |
+
else:
|
| 124 |
+
backend = None
|
| 125 |
+
with torch.no_grad(), model.trace(input_ids, remote=remote, backend=backend):
|
| 126 |
# logit at position i predicts token i+1, so response token j
|
| 127 |
# (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
|
| 128 |
resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
|
|
|
|
| 164 |
label_a: str,
|
| 165 |
label_b: str,
|
| 166 |
remote: bool = False,
|
| 167 |
+
ndif_api_key: str | None = None,
|
| 168 |
) -> "TokenContrast | None":
|
| 169 |
"""Compute per-token contrast weights for a single response (2 forward passes)."""
|
| 170 |
tokenizer = model.tokenizer
|
|
|
|
| 172 |
return None
|
| 173 |
|
| 174 |
specs = _specs_for_response(tokenizer, response_ids, context_a, context_b, "r")
|
| 175 |
+
out = _score_passes(model, specs, remote, ndif_api_key)
|
| 176 |
return _build_contrast(
|
| 177 |
tokenizer, response_ids, out["r_under_a"], out["r_under_b"], label_a, label_b
|
| 178 |
)
|
|
|
|
| 187 |
label_a: str,
|
| 188 |
label_b: str,
|
| 189 |
remote: bool = False,
|
| 190 |
+
ndif_api_key: str | None = None,
|
| 191 |
) -> tuple["TokenContrast | None", "TokenContrast | None"]:
|
| 192 |
"""
|
| 193 |
Compute contrast weights for both panel responses (up to 4 remote passes).
|
|
|
|
| 206 |
tokenizer, response_ids_b, context_a, context_b, "b"
|
| 207 |
)
|
| 208 |
|
| 209 |
+
out = _score_passes(model, specs, remote, ndif_api_key)
|
| 210 |
|
| 211 |
def _build(resp_ids: torch.Tensor, prefix: str) -> "TokenContrast | None":
|
| 212 |
k_a, k_b = f"{prefix}_under_a", f"{prefix}_under_b"
|
utils/datasets.py
CHANGED
|
@@ -183,7 +183,8 @@ def _download_missing_startup_files_if_needed(
|
|
| 183 |
if not missing:
|
| 184 |
return
|
| 185 |
|
| 186 |
-
st.
|
|
|
|
| 187 |
f"First-time setup for {label}: downloading dataset files from Hugging Face. "
|
| 188 |
"Later loads should use the local cache."
|
| 189 |
)
|
|
@@ -199,6 +200,7 @@ def _download_missing_startup_files_if_needed(
|
|
| 199 |
index / total,
|
| 200 |
text=f"Downloaded {filename} ({index}/{total})",
|
| 201 |
)
|
|
|
|
| 202 |
|
| 203 |
|
| 204 |
def _prepare_nemotron_startup_download(dataset_source: str, label: str) -> None:
|
|
|
|
| 183 |
if not missing:
|
| 184 |
return
|
| 185 |
|
| 186 |
+
notice = st.empty()
|
| 187 |
+
notice.warning(
|
| 188 |
f"First-time setup for {label}: downloading dataset files from Hugging Face. "
|
| 189 |
"Later loads should use the local cache."
|
| 190 |
)
|
|
|
|
| 200 |
index / total,
|
| 201 |
text=f"Downloaded {filename} ({index}/{total})",
|
| 202 |
)
|
| 203 |
+
notice.empty()
|
| 204 |
|
| 205 |
|
| 206 |
def _prepare_nemotron_startup_download(dataset_source: str, label: str) -> None:
|
utils/probe_trace.py
CHANGED
|
@@ -51,6 +51,7 @@ def trace_conversation(
|
|
| 51 |
layer: int,
|
| 52 |
location: str,
|
| 53 |
remote: bool,
|
|
|
|
| 54 |
) -> ConversationTrace:
|
| 55 |
prompt_text, _ = format_generation_prompt(
|
| 56 |
messages,
|
|
@@ -71,7 +72,13 @@ def trace_conversation(
|
|
| 71 |
return cached
|
| 72 |
|
| 73 |
accessor = _select_accessor(model, location)
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
saved_ids = model.input_ids[0].detach().cpu().save()
|
| 76 |
saved_acts = accessor[layer][0].detach().float().cpu().save()
|
| 77 |
|
|
|
|
| 51 |
layer: int,
|
| 52 |
location: str,
|
| 53 |
remote: bool,
|
| 54 |
+
ndif_api_key: str | None = None,
|
| 55 |
) -> ConversationTrace:
|
| 56 |
prompt_text, _ = format_generation_prompt(
|
| 57 |
messages,
|
|
|
|
| 72 |
return cached
|
| 73 |
|
| 74 |
accessor = _select_accessor(model, location)
|
| 75 |
+
if remote:
|
| 76 |
+
from utils.runtime import remote_backend
|
| 77 |
+
|
| 78 |
+
backend = remote_backend(model, ndif_api_key)
|
| 79 |
+
else:
|
| 80 |
+
backend = None
|
| 81 |
+
with torch.no_grad(), model.trace(prompt_text, remote=remote, backend=backend):
|
| 82 |
saved_ids = model.input_ids[0].detach().cpu().save()
|
| 83 |
saved_acts = accessor[layer][0].detach().float().cpu().save()
|
| 84 |
|
utils/runtime.py
CHANGED
|
@@ -1,15 +1,17 @@
|
|
| 1 |
import json
|
| 2 |
import logging
|
|
|
|
| 3 |
from collections.abc import Iterable
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
|
| 7 |
-
from utils.helpers import env_int
|
| 8 |
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
_LANGUAGE_MODEL_CLASSES = {"LanguageModel", "StandardizedTransformer"}
|
| 11 |
_EXPECTED_NDIF_STATES = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}
|
| 12 |
_MODEL_CACHE_ENTRIES = env_int("PERSONA_UI_MODEL_CACHE_ENTRIES", 1)
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
def _iter_deployments(raw: object) -> Iterable[dict]:
|
|
@@ -60,17 +62,17 @@ def _unexpected_state(deployment: dict) -> tuple[str, str] | None:
|
|
| 60 |
def list_remote_models() -> list[str]:
|
| 61 |
"""Return the NDIF language models that are currently running.
|
| 62 |
|
| 63 |
-
Parses the raw NDIF response directly instead of going through
|
| 64 |
-
``nnsight.
|
| 65 |
any deployment with an ``application_state`` that isn't in nnsight's
|
| 66 |
``ModelStatus`` enum (e.g. ``UNHEALTHY``) — one bad deployment poisons
|
| 67 |
the whole response. See nnsight 0.6.3 ``ndif.py::status``.
|
| 68 |
"""
|
| 69 |
|
| 70 |
-
|
| 71 |
|
| 72 |
try:
|
| 73 |
-
raw =
|
| 74 |
except Exception:
|
| 75 |
logger.warning("Failed to fetch NDIF status", exc_info=True)
|
| 76 |
return []
|
|
@@ -94,6 +96,52 @@ def list_remote_models() -> list[str]:
|
|
| 94 |
return sorted(set(model_names))
|
| 95 |
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
@st.cache_resource(show_spinner=False, max_entries=_MODEL_CACHE_ENTRIES)
|
| 98 |
def cached_model(model_name: str):
|
| 99 |
"""Load and cache a standardized nnterp model.
|
|
|
|
| 1 |
import json
|
| 2 |
import logging
|
| 3 |
+
import os
|
| 4 |
from collections.abc import Iterable
|
| 5 |
|
| 6 |
import streamlit as st
|
| 7 |
|
| 8 |
+
from utils.helpers import env_int, session_key
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
_LANGUAGE_MODEL_CLASSES = {"LanguageModel", "StandardizedTransformer"}
|
| 12 |
_EXPECTED_NDIF_STATES = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}
|
| 13 |
_MODEL_CACHE_ENTRIES = env_int("PERSONA_UI_MODEL_CACHE_ENTRIES", 1)
|
| 14 |
+
_SESSION_NDIF_API_KEY = session_key("sidebar", "ndif_api_key")
|
| 15 |
|
| 16 |
|
| 17 |
def _iter_deployments(raw: object) -> Iterable[dict]:
|
|
|
|
| 62 |
def list_remote_models() -> list[str]:
|
| 63 |
"""Return the NDIF language models that are currently running.
|
| 64 |
|
| 65 |
+
Parses the raw NDIF response directly instead of going through the formatted
|
| 66 |
+
``nnsight.ndif.status()`` response because formatting crashes whenever NDIF reports
|
| 67 |
any deployment with an ``application_state`` that isn't in nnsight's
|
| 68 |
``ModelStatus`` enum (e.g. ``UNHEALTHY``) — one bad deployment poisons
|
| 69 |
the whole response. See nnsight 0.6.3 ``ndif.py::status``.
|
| 70 |
"""
|
| 71 |
|
| 72 |
+
from nnsight.ndif import status
|
| 73 |
|
| 74 |
try:
|
| 75 |
+
raw = status(raw=True)
|
| 76 |
except Exception:
|
| 77 |
logger.warning("Failed to fetch NDIF status", exc_info=True)
|
| 78 |
return []
|
|
|
|
| 96 |
return sorted(set(model_names))
|
| 97 |
|
| 98 |
|
| 99 |
+
def session_ndif_api_key() -> str | None:
    """Look up the NDIF API key stored in the current Streamlit session.

    Only ``st.session_state`` is consulted; no process-wide state is read
    or modified.

    Returns:
        The stored key when it is a non-empty string, otherwise ``None``.
    """
    stored = st.session_state.get(_SESSION_NDIF_API_KEY)
    # Anything that is not a string (including a missing entry) counts as unset.
    if not isinstance(stored, str):
        return None
    # An empty string also counts as unset.
    return stored or None
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def configured_ndif_api_key() -> str | None:
|
| 107 |
+
"""Return an app-level NDIF key configured through the environment, if any."""
|
| 108 |
+
|
| 109 |
+
value = os.environ.get("NDIF_API_KEY")
|
| 110 |
+
return value if value else None
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def remote_backend(model: object, api_key: str | None = None, *, on_status=None):
    """Create an NDIF ``RemoteBackend`` that carries an explicit API key.

    The key is resolved in priority order: the ``api_key`` argument, then
    ``session_ndif_api_key()``, then ``configured_ndif_api_key()``.

    Args:
        model: Object exposing ``to_model_key()``; the result is passed to
            ``RemoteBackend`` as its first argument.
        api_key: Key to use for this backend. Falsy values fall through to
            the session and environment lookups.
        on_status: Optional callable invoked as
            ``on_status(job_id, status_name, description)`` on every status
            update that carries a non-empty ``status_name``.

    Returns:
        The configured ``RemoteBackend`` instance.

    Raises:
        RuntimeError: If no API key could be resolved from any source.
    """
    from nnsight.intervention.backends.remote import JobStatusDisplay, RemoteBackend

    resolved_key = api_key or session_ndif_api_key() or configured_ndif_api_key()
    if not resolved_key:
        raise RuntimeError("Enter your NDIF API key before using remote execution.")

    ndif_backend = RemoteBackend(model.to_model_key(), api_key=resolved_key)
    ndif_backend.CONNECT_TIMEOUT = 300.0

    if on_status is not None:

        # Status display that forwards each named update to the caller's
        # callback after the stock display has handled it.
        class _CallbackJobStatusDisplay(JobStatusDisplay):
            def update(
                self,
                job_id: str = "",
                status_name: str = "",
                description: str = "",
            ):
                super().update(job_id, status_name, description)
                if status_name:
                    on_status(job_id, status_name, description)

        ndif_backend.status_display = _CallbackJobStatusDisplay(
            enabled=True,
            verbose=ndif_backend.verbose,
        )

    return ndif_backend
|
| 143 |
+
|
| 144 |
+
|
| 145 |
@st.cache_resource(show_spinner=False, max_entries=_MODEL_CACHE_ENTRIES)
|
| 146 |
def cached_model(model_name: str):
|
| 147 |
"""Load and cache a standardized nnterp model.
|
uv.lock
CHANGED
|
@@ -1608,7 +1608,7 @@ requires-dist = [
|
|
| 1608 |
{ name = "catppuccin", specifier = ">=2.5.0" },
|
| 1609 |
{ name = "datasets", specifier = ">=4.8.5" },
|
| 1610 |
{ name = "huggingface-hub", specifier = ">=1.14.0" },
|
| 1611 |
-
{ name = "persona-vectors", specifier = ">=0.8.
|
| 1612 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1613 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1614 |
{ name = "safetensors", specifier = ">=0.7.0" },
|
|
@@ -1620,7 +1620,7 @@ dev = [{ name = "pytest", specifier = ">=9.0.3" }]
|
|
| 1620 |
|
| 1621 |
[[package]]
|
| 1622 |
name = "persona-vectors"
|
| 1623 |
-
version = "0.8.
|
| 1624 |
source = { registry = "https://pypi.org/simple" }
|
| 1625 |
dependencies = [
|
| 1626 |
{ name = "datasets" },
|
|
@@ -1639,9 +1639,9 @@ dependencies = [
|
|
| 1639 |
{ name = "transformers" },
|
| 1640 |
{ name = "umap-learn" },
|
| 1641 |
]
|
| 1642 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1643 |
wheels = [
|
| 1644 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1645 |
]
|
| 1646 |
|
| 1647 |
[[package]]
|
|
|
|
| 1608 |
{ name = "catppuccin", specifier = ">=2.5.0" },
|
| 1609 |
{ name = "datasets", specifier = ">=4.8.5" },
|
| 1610 |
{ name = "huggingface-hub", specifier = ">=1.14.0" },
|
| 1611 |
+
{ name = "persona-vectors", specifier = ">=0.8.4" },
|
| 1612 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1613 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1614 |
{ name = "safetensors", specifier = ">=0.7.0" },
|
|
|
|
| 1620 |
|
| 1621 |
[[package]]
|
| 1622 |
name = "persona-vectors"
|
| 1623 |
+
version = "0.8.4"
|
| 1624 |
source = { registry = "https://pypi.org/simple" }
|
| 1625 |
dependencies = [
|
| 1626 |
{ name = "datasets" },
|
|
|
|
| 1639 |
{ name = "transformers" },
|
| 1640 |
{ name = "umap-learn" },
|
| 1641 |
]
|
| 1642 |
+
sdist = { url = "https://files.pythonhosted.org/packages/65/e4/9f7d9e082d3719e7b0e808b853c74795a902c2c433a9bf5cab1bfe712385/persona_vectors-0.8.4.tar.gz", hash = "sha256:46a941c6f6c4029c0ac32c103c9f8c9574fdb3a288fb07b9477c13e08b6941e8", size = 43333, upload-time = "2026-05-18T17:28:07.812Z" }
|
| 1643 |
wheels = [
|
| 1644 |
+
{ url = "https://files.pythonhosted.org/packages/4e/6f/25f63c81c0ac7f5daafe8a18a23a11b351be982109f8e12d615f9bb97080/persona_vectors-0.8.4-py3-none-any.whl", hash = "sha256:4f3de83a4527c432e8974e509bfc0e92dfc53a199ee52421a217bfc2edfbe0d0", size = 53324, upload-time = "2026-05-18T17:28:06.862Z" },
|
| 1645 |
]
|
| 1646 |
|
| 1647 |
[[package]]
|