Jac-Zac commited on
Commit ·
a89a7f1
0
Parent(s):
First commit
Browse files- .env.example +19 -0
- .gitignore +64 -0
- README.md +79 -0
- WARNING.md +3 -0
- app.py +111 -0
- pyproject.toml +28 -0
- state.py +59 -0
- tabs/__init__.py +0 -0
- tabs/chat.py +636 -0
- tabs/compare.py +354 -0
- tabs/extract.py +195 -0
- utils/__init__.py +1 -0
- utils/artifacts.py +249 -0
- utils/chat.py +226 -0
- utils/chat_export.py +117 -0
- utils/datasets.py +59 -0
- utils/extraction.py +151 -0
- utils/helpers.py +66 -0
- utils/local_dataset.py +72 -0
- utils/runtime.py +53 -0
- uv.lock +0 -0
.env.example
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copy this file to .env and fill in the values.
|
| 2 |
+
|
| 3 |
+
# NDIF API key for remote nnsight execution
|
| 4 |
+
# Required only for remote (NDIF) execution, i.e. when the "Remote (NDIF)" sidebar toggle is on
|
| 5 |
+
# Get yours at https://login.ndif.us
|
| 6 |
+
NDIF_API_KEY=your-ndif-api-key-here
|
| 7 |
+
|
| 8 |
+
# HuggingFace model cache directory
|
| 9 |
+
# Defaults to ~/.cache/huggingface if unset
|
| 10 |
+
# Useful when working on a cluster with a shared cache or limited home quota
|
| 11 |
+
HF_HOME=/path/to/your/hf/cache
|
| 12 |
+
|
| 13 |
+
# Root directory for all generated artifacts (activations, plots, etc.)
|
| 14 |
+
# Defaults to artifacts if unset
|
| 15 |
+
ARTIFACTS_DIR=artifacts
|
| 16 |
+
|
| 17 |
+
# Default model IDs shown in the sidebar (optional — change to override the built-in defaults)
|
| 18 |
+
# DEFAULT_MODEL=google/gemma-2-2b-it
|
| 19 |
+
# REMOTE_DEFAULT_MODEL=google/gemma-2-9b-it
|
.gitignore
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
.venv/
|
| 25 |
+
venv/
|
| 26 |
+
ENV/
|
| 27 |
+
env/
|
| 28 |
+
|
| 29 |
+
# Environment variables — .env.example is intentionally tracked
|
| 30 |
+
.env
|
| 31 |
+
.env.*
|
| 32 |
+
!.env.example
|
| 33 |
+
|
| 34 |
+
# IDE
|
| 35 |
+
.idea/
|
| 36 |
+
.vscode/
|
| 37 |
+
*.swp
|
| 38 |
+
*.swo
|
| 39 |
+
*~
|
| 40 |
+
|
| 41 |
+
# Jupyter
|
| 42 |
+
.ipynb_checkpoints/
|
| 43 |
+
|
| 44 |
+
# Testing
|
| 45 |
+
.pytest_cache/
|
| 46 |
+
.coverage
|
| 47 |
+
htmlcov/
|
| 48 |
+
|
| 49 |
+
# OS
|
| 50 |
+
.DS_Store
|
| 51 |
+
Thumbs.db
|
| 52 |
+
|
| 53 |
+
# Project specific
|
| 54 |
+
results/
|
| 55 |
+
outputs/
|
| 56 |
+
artifacts/
|
| 57 |
+
*.json.bak
|
| 58 |
+
*.jsonl
|
| 59 |
+
*.jsonl.bak
|
| 60 |
+
|
| 61 |
+
# Tmp to avoid pushing things I'm testing
|
| 62 |
+
__marimo__/
|
| 63 |
+
AGENTS.md
|
| 64 |
+
# notebook_marimo.py
|
README.md
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Persona UI
|
| 2 |
+
|
| 3 |
+
Streamlit interface for persona vector extraction, analysis, and chat.
|
| 4 |
+
|
| 5 |
+
> [!WARNING]
|
| 6 |
+
> This is a proof-of-concept UI, mostly vibe-coded. It will likely be replaced by a proper frontend/backend in the future.
|
| 7 |
+
|
| 8 |
+
## Overview
|
| 9 |
+
|
| 10 |
+
A web app built on top of [persona-vectors](../persona-vectors) that provides three tabs:
|
| 11 |
+
|
| 12 |
+
- **Chat** — interactive conversations with a model using persona-based system prompts (templated or biography)
|
| 13 |
+
- **Compare** — load saved activations and explore layer-wise cosine similarity, PCA, and UMAP projections
|
| 14 |
+
- **Extract** — run activation extraction from HuggingFace or a local JSONL dataset directly from the browser
|
| 15 |
+
|
| 16 |
+
## Repository Layout
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
persona-ui/
|
| 20 |
+
├── app.py # Main entry point (Streamlit)
|
| 21 |
+
├── state.py # Session state management (chat history, KV cache)
|
| 22 |
+
├── tabs/
|
| 23 |
+
│ ├── chat.py # Chat tab
|
| 24 |
+
│ ├── compare.py # Activation comparison tab
|
| 25 |
+
│ └── extract.py # Extraction tab
|
| 26 |
+
└── utils/
|
| 27 |
+
├── artifacts.py # Load saved activations metadata
|
| 28 |
+
├── chat.py # Chat generation logic
|
| 29 |
+
├── chat_export.py # Export chat logs to JSON
|
| 30 |
+
├── datasets.py # Dataset loader wrapper
|
| 31 |
+
├── extraction.py # Extraction orchestration
|
| 32 |
+
├── helpers.py # UI labels and slug helpers
|
| 33 |
+
├── local_dataset.py # Local JSONL dataset parsing
|
| 34 |
+
└── runtime.py # Model caching and NDIF queries
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
Dataset loading and environment helpers are provided by the sibling
|
| 38 |
+
[persona-data](../persona-data) package. Core extraction, analysis, and
|
| 39 |
+
steering logic comes from [persona-vectors](../persona-vectors).
|
| 40 |
+
|
| 41 |
+
## Installation
|
| 42 |
+
|
| 43 |
+
```bash
|
| 44 |
+
uv sync
|
| 45 |
+
cp .env.example .env
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
## Quickstart
|
| 49 |
+
|
| 50 |
+
```bash
|
| 51 |
+
streamlit run app.py
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Configuration
|
| 55 |
+
|
| 56 |
+
Copy `.env.example` to `.env` and fill in:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
NDIF_API_KEY=... # Required for remote (NDIF) model execution
|
| 60 |
+
HF_HOME=... # Optional: HuggingFace cache directory
|
| 61 |
+
ARTIFACTS_DIR=... # Optional: where activations are read from (default: ./artifacts)
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
The app picks up this file automatically via `load_env()` on startup.
|
| 65 |
+
|
| 66 |
+
## Saved Artifacts
|
| 67 |
+
|
| 68 |
+
The Compare and Extract tabs read from / write to:
|
| 69 |
+
|
| 70 |
+
```
|
| 71 |
+
artifacts/
|
| 72 |
+
├── activations/<model_dir>/<prompt_variant>/<persona_id>/
|
| 73 |
+
│ ├── activations.safetensors
|
| 74 |
+
│ └── metadata.json
|
| 75 |
+
└── chats/<model_dir>/<prompt_variant>/
|
| 76 |
+
└── <export>.json
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
`<model_dir>` is the model name with `/` replaced by `__` (e.g. `google__gemma-2-9b-it`).
|
WARNING.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# WARNING 🚨
|
| 2 |
+
|
| 3 |
+
This part of the project is largely vibe-coded, mostly because it will probably be replaced in the future by a proper interface with a separate backend and frontend, without Streamlit. As of now it is mostly a proof of concept and an easy-to-develop part of the project.
|
app.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
import streamlit as st
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
# Load .env early so DEFAULT_MODEL / REMOTE_DEFAULT_MODEL can be overridden via env
|
| 8 |
+
load_dotenv(Path(__file__).parent / ".env")
|
| 9 |
+
|
| 10 |
+
from utils.helpers import DATASET_SOURCES
|
| 11 |
+
|
| 12 |
+
# Model id pre-filled in the sidebar's local "Model" text input; override with the DEFAULT_MODEL env var.
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "google/gemma-2-2b-it")
|
| 13 |
+
# Model id preselected in the remote model list when it is among the running NDIF models, and the
# fallback model name when no running model is found; override with the REMOTE_DEFAULT_MODEL env var.
REMOTE_DEFAULT_MODEL = os.environ.get("REMOTE_DEFAULT_MODEL", "google/gemma-2-9b-it")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _sidebar_controls() -> tuple[bool, str, str, str]:
    """Draw the sidebar and report the user's current selections.

    The sidebar holds, top to bottom: one button per tab, the remote/local
    runtime toggle with its model picker, and the dataset source picker.

    Returns:
        ``(remote, model_name, dataset_source, active_tab)``.
    """
    # Imported lazily, matching the other deferred imports in this module.
    from utils.runtime import list_remote_models

    with st.sidebar:
        st.markdown("# Persona UI")
        st.caption("Chat, extract, and compare persona runs.")

        # First run of the session: start on the first tab.
        if "sidebar__active_tab" not in st.session_state:
            st.session_state["sidebar__active_tab"] = _TABS[0]
        active_tab = st.session_state["sidebar__active_tab"]

        for tab_name, icon in zip(_TABS, _TAB_ICONS, strict=True):
            button_type = "primary" if tab_name == active_tab else "secondary"
            clicked = st.button(
                tab_name,
                key=f"sidebar__tab__{tab_name.lower()}",
                use_container_width=True,
                type=button_type,
                icon=icon,
            )
            if clicked:
                # Persist the choice, then rerun so the highlight and the page update.
                st.session_state["sidebar__active_tab"] = tab_name
                st.rerun()

        st.divider()
        st.caption("Runtime")
        remote = st.toggle("Remote (NDIF)", value=False, key="sidebar__remote")

        if not remote:
            model_name = st.text_input(
                "Model",
                value=DEFAULT_MODEL,
                key="sidebar__local_model",
                help="Local model id or path.",
            )
        else:
            remote_models = list_remote_models()
            if not remote_models:
                st.error("No running NDIF models found.")
                model_name = REMOTE_DEFAULT_MODEL
            else:
                # Prefer the configured default when it is running, else the first listed model.
                if REMOTE_DEFAULT_MODEL in remote_models:
                    default_model = REMOTE_DEFAULT_MODEL
                else:
                    default_model = remote_models[0]
                model_name = st.selectbox(
                    "Model",
                    options=remote_models,
                    index=remote_models.index(default_model),
                    key="sidebar__remote_model",
                    help="Running NDIF model.",
                )

        st.caption("Data")
        dataset_source = st.selectbox(
            "Source",
            DATASET_SOURCES,
            key="sidebar__dataset_source",
            help="Dataset for Chat and Extract.",
        )

    return remote, model_name, dataset_source, active_tab
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# Tab names shown as sidebar buttons, in display order; the first entry is the initial tab.
_TABS = ["Chat", "Compare", "Extract"]
|
| 81 |
+
# Icon for each tab button, paired positionally with _TABS (the two are zipped with strict=True).
_TAB_ICONS = [":material/chat:", ":material/search:", ":material/tune:"]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def main() -> None:
    """Run the Streamlit app.

    Configures the page, draws the sidebar, then renders whichever tab the
    sidebar reports as active. Tab modules are imported lazily so only the
    selected tab's dependencies are loaded on a given run.
    """
    # Configure the page before the slow torch import so the title and wide
    # layout are applied immediately; this also keeps set_page_config the first
    # Streamlit command of the run.
    st.set_page_config(page_title="Persona UI", layout="wide")

    # Deferred: importing torch is slow, so it is kept out of module import time.
    import torch

    # The UI only runs inference, so gradient tracking is switched off globally.
    torch.set_grad_enabled(False)

    remote, model_name, dataset_source, active_tab = _sidebar_controls()

    if active_tab == "Extract":
        from tabs.extract import render_extract_tab

        render_extract_tab(remote, model_name, dataset_source)
    elif active_tab == "Compare":
        from tabs.compare import render_compare_tab

        render_compare_tab(model_name)
    else:
        # "Chat" and any unrecognised value fall through to the chat tab.
        from tabs.chat import render_chat_tab

        render_chat_tab(remote, model_name, dataset_source)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
if __name__ == "__main__":
|
| 111 |
+
main()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "persona-ui"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Streamlit UI for persona-vectors"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"persona-vectors",
|
| 9 |
+
"persona-data",
|
| 10 |
+
"nnterp>=1.3.0",
|
| 11 |
+
"streamlit>=1.44.0",
|
| 12 |
+
"plotly>=6.6.0",
|
| 13 |
+
"kaleido>=1.0.0",
|
| 14 |
+
"python-dotenv>=1.2.2",
|
| 15 |
+
"torch>=2.10.0",
|
| 16 |
+
"transformers>=5.2.0",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[tool.uv.sources]
|
| 20 |
+
# NOTE: git sources are active; for local development, switch to the commented-out editable path sources below
|
| 21 |
+
persona-vectors = { git = "ssh://git@github.com/implicit-personalization/persona-vectors.git" }
|
| 22 |
+
persona-data = { git = "ssh://git@github.com/implicit-personalization/persona-data.git" }
|
| 23 |
+
# persona-vectors = { path = "../persona-vectors", editable = true }
|
| 24 |
+
# persona-data = { path = "../persona-data", editable = true }
|
| 25 |
+
|
| 26 |
+
# [build-system]
|
| 27 |
+
# requires = ["uv_build>=0.11.3,<0.12"]
|
| 28 |
+
# build-backend = "uv_build"
|
state.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
# Prefix shared by every chat-context key in st.session_state; also used to find contexts when evicting KV caches.
_CHAT_STATE_PREFIX = "chat_state::"
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def chat_session_key(model_name: str, dataset_source: str) -> str:
    """Return the session-state key for one (model, dataset source) chat context."""
    return "{}{}::{}".format(_CHAT_STATE_PREFIX, model_name, dataset_source)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _default_chat_state() -> dict[str, object]:
|
| 13 |
+
return {
|
| 14 |
+
"messages": [],
|
| 15 |
+
"persona_id": None,
|
| 16 |
+
"prompt_mode": "templated",
|
| 17 |
+
"past_key_values": None,
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _evict_inactive_kv_caches(active_key: str) -> None:
    """Clear ``past_key_values`` on every chat context other than ``active_key``.

    Only session-state entries whose key is a string starting with the chat
    prefix and whose value is a dict are touched; everything else is skipped.
    """
    for name in st.session_state:
        if not isinstance(name, str):
            continue
        if name == active_key or not name.startswith(_CHAT_STATE_PREFIX):
            continue
        entry = st.session_state[name]
        if not isinstance(entry, dict):
            continue
        if entry.get("past_key_values") is not None:
            entry["past_key_values"] = None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_chat_state(
    model_name: str, remote: bool, dataset_source: str
) -> dict[str, object]:
    """Fetch (creating if needed) the mutable chat state of the active context.

    An existing state is topped up with any default fields it lacks. KV caches
    of all other chat contexts are dropped, and the active context's cache is
    dropped as well when running remotely.
    """
    key = chat_session_key(model_name, dataset_source)
    defaults = _default_chat_state()

    state = st.session_state.get(key)
    if state is not None:
        # Backfill fields missing from a previously stored state.
        for field, fallback in defaults.items():
            state.setdefault(field, fallback)
    else:
        state = defaults
        st.session_state[key] = state

    _evict_inactive_kv_caches(key)

    if remote:
        if state.get("past_key_values") is not None:
            state["past_key_values"] = None
    return state
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def reset_chat_state(model_name: str, remote: bool, dataset_source: str) -> None:
    """Empty the message history and drop the KV cache of the active context."""
    active_state = get_chat_state(model_name, remote, dataset_source)
    active_state.update(messages=[], past_key_values=None)
|
tabs/__init__.py
ADDED
|
File without changes
|
tabs/chat.py
ADDED
|
@@ -0,0 +1,636 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import threading
|
| 2 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 3 |
+
from contextlib import nullcontext
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
|
| 7 |
+
from state import chat_session_key, get_chat_state, reset_chat_state
|
| 8 |
+
from utils.chat import ChatReply, generate_chat_reply, resolve_system_prompt
|
| 9 |
+
from utils.chat_export import save_chat_export
|
| 10 |
+
from utils.datasets import load_dataset
|
| 11 |
+
from utils.helpers import (
|
| 12 |
+
MODE_LABEL_TO_KEY,
|
| 13 |
+
MODE_LABELS,
|
| 14 |
+
VARIANT_LABELS,
|
| 15 |
+
persona_label,
|
| 16 |
+
widget_key,
|
| 17 |
+
)
|
| 18 |
+
from utils.runtime import cached_model
|
| 19 |
+
|
| 20 |
+
# Number of most recent messages a compare panel shows until "Show earlier" is clicked.
_VISIBLE_MESSAGE_COUNT = 5
|
| 21 |
+
# Held around local (non-remote) generation so the compare-mode worker threads run one at a time.
_model_lock = threading.Lock()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _render_chat_message(message: dict[str, str]) -> None:
    """Render one chat bubble for ``message``; messages with empty or missing content are skipped."""
    text = message.get("content")
    if text:
        with st.chat_message(message["role"]):
            st.markdown(text)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _clear_chat_ui_state(*keys: str) -> None:
    """Remove each given key from ``st.session_state``; keys that are absent are ignored."""
    for stale_key in keys:
        if stale_key in st.session_state:
            del st.session_state[stale_key]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _generation_dict(gen_kwargs: dict, advanced_generation: bool) -> dict[str, object]:
|
| 37 |
+
return {
|
| 38 |
+
"max_new_tokens": int(gen_kwargs["max_new_tokens"]),
|
| 39 |
+
"advanced_generation": bool(advanced_generation),
|
| 40 |
+
"use_sampling": bool(gen_kwargs["do_sample"]),
|
| 41 |
+
"temperature": float(gen_kwargs["temperature"]),
|
| 42 |
+
"top_p": float(gen_kwargs["top_p"]),
|
| 43 |
+
"top_k": int(gen_kwargs["top_k"]),
|
| 44 |
+
"repetition_penalty": float(gen_kwargs["repetition_penalty"]),
|
| 45 |
+
"seed": gen_kwargs["seed"],
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ── Compare mode helpers ───────────────────────────────────────────────────────
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _panel_state(panel_key: str) -> dict:
    """Return the compare-panel chat state stored under ``panel_key``, creating it on first use."""
    if panel_key in st.session_state:
        return st.session_state[panel_key]

    st.session_state[panel_key] = dict(
        messages=[],
        persona_id=None,
        prompt_mode="templated",
        past_key_values=None,
    )
    return st.session_state[panel_key]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _render_compare_panel(
    side: str,
    context_key: str,
    personas: list,
    remote: bool,
    model_name: str,
    dataset_source: str,
    gen_kwargs: dict,
    advanced_generation: bool,
) -> dict:
    """Draw one side of the comparison view and return its generation context.

    Renders, in order: the persona and prompt-mode selectors, the optional
    "Edit prompt" expander, the export/reset buttons, and the message log.
    The conversation is wiped whenever the selected persona or prompt mode
    differs from what the panel state recorded. ``remote`` is accepted but not
    read by this function.

    Returns:
        A dict with the keys ``panel_key``, ``state``, ``active_system_prompt``,
        ``selected_persona`` and ``chat_log``.
    """
    panel_key = widget_key(context_key, f"cmp_{side}")
    state = _panel_state(panel_key)

    def _wipe_conversation() -> None:
        # Drop the history, the KV cache, and the panel's prompt/show-all UI values.
        state["messages"] = []
        state["past_key_values"] = None
        _clear_chat_ui_state(
            widget_key(panel_key, "custom_prompt"),
            widget_key(panel_key, "show_all"),
        )

    # ── Per-panel selectors ──────────────────────────────────────────────────
    persona_col, mode_col = st.columns([3, 2])
    with persona_col:
        # Preselect the persona remembered in state; fall back to the first one.
        selected_index = 0
        for position, candidate in enumerate(personas):
            if candidate.id == state["persona_id"]:
                selected_index = position
                break
        selected_persona = st.selectbox(
            "Persona",
            options=personas,
            index=selected_index,
            format_func=persona_label,
            key=widget_key(panel_key, "persona"),
        )
    with mode_col:
        current_label = VARIANT_LABELS.get(state["prompt_mode"], "None")
        prompt_mode_label = st.selectbox(
            "Prompt",
            options=MODE_LABELS,
            index=MODE_LABELS.index(current_label),
            key=widget_key(panel_key, "prompt_mode"),
        )
    prompt_mode = MODE_LABEL_TO_KEY[prompt_mode_label]

    # A different persona or prompt mode starts a new conversation.
    same_persona = state["persona_id"] == selected_persona.id
    same_mode = state["prompt_mode"] == prompt_mode
    if not (same_persona and same_mode):
        _wipe_conversation()
        state["persona_id"] = selected_persona.id
        state["prompt_mode"] = prompt_mode

    # ── System prompt ────────────────────────────────────────────────────────
    active_system_prompt = resolve_system_prompt(
        persona=selected_persona, mode=prompt_mode
    )
    custom_prompt_key = widget_key(panel_key, "custom_prompt")
    if prompt_mode != "empty":
        # Seed the editable prompt once; afterwards the user's edits are kept.
        if custom_prompt_key not in st.session_state:
            st.session_state[custom_prompt_key] = active_system_prompt
        with st.expander("Edit prompt", expanded=False):
            edited_prompt = st.text_area(
                "prompt",
                key=custom_prompt_key,
                height=150,
                label_visibility="collapsed",
            )
            # An emptied text area means "no system prompt".
            active_system_prompt = edited_prompt or None

    # ── Export / reset actions ───────────────────────────────────────────────
    export_success_message: str | None = None
    export_col, reset_col = st.columns(2)
    with export_col:
        export_clicked = st.button(
            "Export chat",
            key=widget_key(panel_key, "export_chat"),
            use_container_width=True,
        )
        if export_clicked:
            export_path = save_chat_export(
                model_name=model_name,
                dataset_source=dataset_source,
                persona_id=selected_persona.id,
                persona_name=getattr(selected_persona, "name", None),
                panel_label=side,
                prompt_mode=prompt_mode,
                system_prompt=active_system_prompt,
                messages=state["messages"],
                generation=_generation_dict(gen_kwargs, advanced_generation),
            )
            export_success_message = f"Saved chat export to {export_path}"
    with reset_col:
        reset_clicked = st.button(
            "Reset chat",
            key=widget_key(panel_key, "reset"),
            use_container_width=True,
            type="secondary",
        )
        if reset_clicked:
            _wipe_conversation()
            st.rerun()

    # Shown below the button row rather than inside the export column.
    if export_success_message:
        st.success(export_success_message)

    # ── Message history ──────────────────────────────────────────────────────
    show_all_key = widget_key(panel_key, "show_all")
    messages = state["messages"]
    hidden_count = len(messages) - _VISIBLE_MESSAGE_COUNT
    visible = messages
    if hidden_count > 0 and not st.session_state.get(show_all_key, False):
        if st.button(
            f"Show earlier ({hidden_count} hidden)",
            key=widget_key(panel_key, "show_all_btn"),
        ):
            st.session_state[show_all_key] = True
            st.rerun()
        visible = messages[-_VISIBLE_MESSAGE_COUNT:]

    # The container is returned so the caller can append new messages to it later.
    chat_log = st.container()
    with chat_log:
        for entry in visible:
            _render_chat_message(entry)

    return dict(
        panel_key=panel_key,
        state=state,
        active_system_prompt=active_system_prompt,
        selected_persona=selected_persona,
        chat_log=chat_log,
    )
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _generate_for_panel(
    panel: dict,
    model,
    remote: bool,
    gen_kwargs: dict,
) -> ChatReply:
    """Generate the next reply for one compare panel.

    The conversation sent to the model is the panel's system prompt (when it
    has one) followed by its stored messages. Local runs hold ``_model_lock``
    for the duration of the call; remote runs take no lock.
    """
    system_prompt = panel["active_system_prompt"]
    prefix = [{"role": "system", "content": system_prompt}] if system_prompt else []
    conversation = [*prefix, *panel["state"]["messages"]]

    guard = _model_lock if not remote else nullcontext()
    with guard:
        reply = generate_chat_reply(
            model=model,
            messages=conversation,
            remote=remote,
            past_key_values=panel["state"]["past_key_values"],
            **gen_kwargs,
        )
    return reply
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def _render_compare_mode(
    remote: bool,
    model_name: str,
    context_key: str,
    dataset_source: str,
    personas: list,
    gen_kwargs: dict,
    advanced_generation: bool,
) -> None:
    """Render the side-by-side comparison UI and answer a shared user prompt.

    Two panels are drawn in adjacent columns. When the shared chat input
    yields a prompt, it is appended to both panels, a reply is generated for
    each on a worker thread, and each panel receives either its reply or an
    error message.
    """
    # Each entry pairs a panel's context dict with the column it was drawn in.
    panels = []
    for side, column in zip(("left", "right"), st.columns(2)):
        with column:
            rendered = _render_compare_panel(
                side,
                context_key,
                personas,
                remote,
                model_name,
                dataset_source,
                gen_kwargs,
                advanced_generation,
            )
        panels.append((rendered, column))

    user_prompt = st.chat_input(
        "Ask both...",
        key=widget_key(context_key, "cmp_input"),
    )
    if not user_prompt:
        return

    model = cached_model(model_name=model_name, remote=remote)

    # Record and display the user's message on both sides before generating.
    for panel, col in panels:
        panel["state"]["messages"].append({"role": "user", "content": user_prompt})
        with col, panel["chat_log"]:
            _render_chat_message({"role": "user", "content": user_prompt})

    def _settle(future):
        # Turn a raised exception into a value so one failure cannot hide the other result.
        try:
            return future.result()
        except Exception as exc:
            return exc

    # Both replies are submitted together; local runs are serialised by the model lock.
    with st.spinner("Generating..."):
        with ThreadPoolExecutor(max_workers=2) as executor:
            futures = [
                executor.submit(_generate_for_panel, panel, model, remote, gen_kwargs)
                for panel, _ in panels
            ]
            results = [_settle(future) for future in futures]

    for (panel, col), result in zip(panels, results):
        panel_state = panel["state"]
        if isinstance(result, Exception):
            with col, panel["chat_log"]:
                st.error(f"Generation failed: {result}")
            # Remove the unanswered user message appended above.
            panel_state["messages"].pop()
            continue

        panel_state["messages"].append({"role": "assistant", "content": result.text})
        # The KV cache is kept only for local runs.
        panel_state["past_key_values"] = None if remote else result.past_key_values
        with col, panel["chat_log"]:
            _render_chat_message({"role": "assistant", "content": result.text})
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# ── Main tab entry point ───────────────────────────────────────────────────────
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
    """Render the chat tab.

    Loads the personas from ``dataset_source``, draws the generation settings
    and the compare-mode toggle, and then either delegates to
    ``_render_compare_mode`` or runs the single-persona chat.

    The single chat works in two script passes: submitting a prompt appends
    the user message, sets a pending flag and reruns so the message is drawn
    first; the following pass sees the flag, generates the reply, exports the
    conversation and reruns again.

    Args:
        remote: Passed to ``cached_model`` and ``generate_chat_reply``. When
            true the seed controls are disabled and the returned
            ``past_key_values`` are not kept between turns.
        model_name: Model identifier handed to ``cached_model``; also part of
            the per-session widget/state key.
        dataset_source: Dataset selector handed to ``load_dataset``; also part
            of the per-session widget/state key.
    """

    st.title("Chat")

    context_key = chat_session_key(model_name, dataset_source)
    chat_state = get_chat_state(model_name, remote, dataset_source)
    try:
        dataset, dataset_status = load_dataset(dataset_source)
        st.caption(dataset_status)
    except Exception as exc:
        st.error(f"Could not load data: {exc}")
        st.info("Check the selected dataset source or upload both JSONL files.")
        return

    personas = list(dataset)
    if not personas:
        st.warning("No personas found in the selected dataset.")
        st.info("Try a different dataset source or upload a non-empty personas file.")
        return

    # ── Generation settings ───────────────────────────────────────────────────
    with st.expander("Advanced", expanded=False):
        config_col1, config_col2 = st.columns([2, 1])
        with config_col1:
            max_new_tokens = st.slider(
                "Max new tokens",
                min_value=16,
                max_value=512,
                value=256,
                step=16,
                key=widget_key(context_key, "max_new_tokens"),
            )
        with config_col2:
            repetition_penalty = st.slider(
                "Repetition penalty",
                min_value=0.5,
                max_value=2.0,
                value=1.0,
                step=0.05,
                key=widget_key(context_key, "repetition_penalty"),
            )

        use_sampling = st.checkbox(
            "Random sampling",
            value=False,
            key=widget_key(context_key, "use_sampling"),
        )

        # The sampling controls are always drawn but greyed out unless
        # "Random sampling" is ticked.
        sampling_disabled = not use_sampling
        sampling_col1, sampling_col2, sampling_col3 = st.columns(3)
        with sampling_col1:
            temperature = st.slider(
                "Temperature",
                min_value=0.01,
                max_value=2.0,
                value=1.0,
                step=0.01,
                disabled=sampling_disabled,
                key=widget_key(context_key, "temperature"),
            )
        with sampling_col2:
            top_p = st.slider(
                "Top-p",
                min_value=0.01,
                max_value=1.0,
                value=1.0,
                step=0.01,
                disabled=sampling_disabled,
                key=widget_key(context_key, "top_p"),
            )
        with sampling_col3:
            top_k = st.slider(
                "Top-k (0 = off)",
                min_value=0,
                max_value=100,
                value=50,
                step=1,
                disabled=sampling_disabled,
                key=widget_key(context_key, "top_k"),
            )

        seed_disabled = sampling_disabled or remote
        seed_enabled = st.checkbox(
            "Fix seed",
            value=False,
            disabled=seed_disabled,
            key=widget_key(context_key, "seed_enabled"),
        )
        if seed_enabled:
            seed = int(
                st.number_input(
                    "Seed",
                    min_value=0,
                    max_value=2_147_483_647,
                    value=0,
                    step=1,
                    disabled=seed_disabled,
                    key=widget_key(context_key, "seed"),
                )
            )
        else:
            seed = None

        if remote:
            st.caption("Seed is local-only and disabled for remote runs.")

    # True as soon as any setting differs from the widget defaults above.
    advanced_generation = (
        max_new_tokens != 256
        or use_sampling
        or temperature != 1.0
        or top_p != 1.0
        or top_k != 50
        or repetition_penalty != 1.0
        or seed is not None
    )

    do_sample = bool(use_sampling)
    # A seed is only forwarded when sampling locally; otherwise it is dropped.
    generation_seed = seed if do_sample and seed is not None and not remote else None
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        seed=generation_seed,
    )

    # ── Mode toggle ───────────────────────────────────────────────────────────
    compare_mode = st.toggle(
        "Compare mode",
        value=False,
        key=widget_key(context_key, "compare_mode"),
        help="Side-by-side: send one message to two independent persona/prompt configurations.",
    )

    if compare_mode:
        _render_compare_mode(
            remote,
            model_name,
            context_key,
            dataset_source,
            personas,
            gen_kwargs,
            advanced_generation,
        )
        return

    # ── Single-chat mode ──────────────────────────────────────────────────────
    persona_select_key = widget_key(context_key, "persona_select")
    prompt_mode_select_key = widget_key(context_key, "system_prompt_select")

    col1, col2 = st.columns([2, 1])
    with col1:
        # Preselect the persona remembered in the chat state, else the first.
        selected_index = next(
            (i for i, p in enumerate(personas) if p.id == chat_state["persona_id"]),
            0,
        )
        selected_persona = st.selectbox(
            "Persona",
            options=personas,
            index=selected_index,
            format_func=persona_label,
            key=persona_select_key,
        )
    with col2:
        current_mode_label = VARIANT_LABELS.get(chat_state["prompt_mode"], "None")
        prompt_mode_label = st.selectbox(
            "Prompt",
            options=MODE_LABELS,
            index=MODE_LABELS.index(current_mode_label),
            key=prompt_mode_select_key,
        )
        prompt_mode = MODE_LABEL_TO_KEY[prompt_mode_label]

    active_system_prompt = resolve_system_prompt(
        persona=selected_persona,
        mode=prompt_mode,
    )

    chat_input_key = widget_key(context_key, "chat_input")
    show_all_key = widget_key(context_key, "show_all_messages")
    custom_prompt_key = widget_key(context_key, "custom_system_prompt")
    pending_key = widget_key(context_key, "pending_prompt")
    export_success_message: str | None = None

    action_col1, action_col2 = st.columns(2)
    with action_col1:
        if st.button("Reset chat", use_container_width=True, type="secondary"):
            reset_chat_state(model_name, remote, dataset_source)
            _clear_chat_ui_state(
                chat_input_key,
                show_all_key,
                custom_prompt_key,
                pending_key,
            )
            st.rerun()
    with action_col2:
        if st.button("Export chat", use_container_width=True):
            # The prompt editor is drawn further down, so at this point
            # ``active_system_prompt`` is still the resolved persona prompt.
            # Prefer the edited text kept in session state so a manual export
            # records the same prompt that is sent to the model and written by
            # the automatic export after each reply.
            export_system_prompt = active_system_prompt
            if prompt_mode != "empty" and custom_prompt_key in st.session_state:
                export_system_prompt = st.session_state[custom_prompt_key] or None
            export_path = save_chat_export(
                model_name=model_name,
                dataset_source=dataset_source,
                persona_id=selected_persona.id,
                persona_name=getattr(selected_persona, "name", None),
                prompt_mode=prompt_mode,
                system_prompt=export_system_prompt,
                messages=chat_state["messages"],
                generation=_generation_dict(gen_kwargs, advanced_generation),
            )
            export_success_message = f"Saved chat export to {export_path}"

    if export_success_message:
        st.success(export_success_message)

    # Switching persona or prompt mode invalidates the running conversation.
    changed_context = (
        chat_state["persona_id"] != selected_persona.id
        or chat_state["prompt_mode"] != prompt_mode
    )
    if changed_context:
        had_history = bool(chat_state["messages"])
        chat_state["persona_id"] = selected_persona.id
        chat_state["prompt_mode"] = prompt_mode
        # NOTE(review): assumes reset_chat_state clears this same state object
        # in place and keeps persona_id/prompt_mode — confirm in state.py.
        reset_chat_state(model_name, remote, dataset_source)
        _clear_chat_ui_state(
            chat_input_key,
            show_all_key,
            custom_prompt_key,
            pending_key,
        )
        if had_history:
            st.info("Chat history reset because the persona or system prompt changed.")

    chat_log = st.container()

    with chat_log:
        # System prompt as first item in conversation — collapsed by default, editable.
        if prompt_mode != "empty":
            if custom_prompt_key not in st.session_state:
                st.session_state[custom_prompt_key] = active_system_prompt
            with st.expander("Edit prompt", expanded=False):
                # An emptied text area means "no system prompt".
                active_system_prompt = (
                    st.text_area(
                        "Prompt",
                        key=custom_prompt_key,
                        height=200,
                        label_visibility="collapsed",
                    )
                    or None
                )

        # Collapse older messages, show only the most recent ones.
        messages = chat_state["messages"]
        if len(messages) > _VISIBLE_MESSAGE_COUNT and not st.session_state.get(
            show_all_key, False
        ):
            hidden_count = len(messages) - _VISIBLE_MESSAGE_COUNT
            if st.button(
                f"Show earlier messages ({hidden_count} hidden)",
                key=widget_key(context_key, "show_all_btn"),
            ):
                st.session_state[show_all_key] = True
                st.rerun()
            visible_messages = messages[-_VISIBLE_MESSAGE_COUNT:]
        else:
            visible_messages = messages

        for message in visible_messages:
            _render_chat_message(message)

    user_prompt = st.chat_input(
        "Ask something...",
        key=chat_input_key,
    )

    # Pass 1: user submitted — append message and rerun so it renders before generation.
    if user_prompt:
        chat_state["messages"].append({"role": "user", "content": user_prompt})
        st.session_state[pending_key] = True
        st.rerun()

    # Pass 2: message is already rendered above; now run generation.
    if not st.session_state.pop(pending_key, False):
        return

    # Conversation sent to the model: optional system prompt, then the history.
    messages = []
    if active_system_prompt:
        messages.append({"role": "system", "content": active_system_prompt})
    messages.extend(chat_state["messages"])

    with st.spinner("Generating reply..."):
        model = cached_model(model_name=model_name, remote=remote)
        try:
            reply: ChatReply = generate_chat_reply(
                model=model,
                messages=messages,
                remote=remote,
                past_key_values=chat_state["past_key_values"],
                **gen_kwargs,
            )
        except Exception as exc:
            with chat_log:
                st.error(f"Could not generate a reply: {exc}")
                st.info("Try a shorter prompt, reset the chat, or switch personas.")
            # Drop the unanswered user message so the history stays consistent.
            chat_state["messages"].pop()
            return

    chat_state["messages"].append({"role": "assistant", "content": reply.text})
    chat_state["past_key_values"] = reply.past_key_values if not remote else None

    # Every successful turn is exported automatically before the final rerun.
    save_chat_export(
        model_name=model_name,
        dataset_source=dataset_source,
        persona_id=selected_persona.id,
        persona_name=getattr(selected_persona, "name", None),
        prompt_mode=prompt_mode,
        system_prompt=active_system_prompt,
        messages=chat_state["messages"],
        generation=_generation_dict(gen_kwargs, advanced_generation),
    )
    st.rerun()
|
tabs/compare.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
from persona_data.environment import get_artifacts_dir
|
| 3 |
+
from persona_vectors.analysis import build_embedding_figure, project_pca, project_umap
|
| 4 |
+
from persona_vectors.plots import (
|
| 5 |
+
plot_multiple_layer_similarities,
|
| 6 |
+
save_plot_html,
|
| 7 |
+
save_plot_png,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
from utils.artifacts import (
|
| 11 |
+
artifact_persona_options,
|
| 12 |
+
list_available_layers,
|
| 13 |
+
load_cosine_traces,
|
| 14 |
+
load_embedding_samples,
|
| 15 |
+
)
|
| 16 |
+
from utils.helpers import (
|
| 17 |
+
ANALYSIS_HELP_TEXT,
|
| 18 |
+
ANALYSIS_LABELS,
|
| 19 |
+
ANALYSIS_MODES,
|
| 20 |
+
PROMPT_VARIANTS,
|
| 21 |
+
persona_display_label,
|
| 22 |
+
prompt_variant_label,
|
| 23 |
+
slugify,
|
| 24 |
+
widget_key,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _filename(*parts: str) -> str:
    """Build a file stem from the non-empty ``parts``.

    Each truthy part is passed through ``slugify`` and the results are joined
    with a double underscore; falsy parts are skipped.
    """
    slugs = [slugify(piece) for piece in parts if piece]
    return "__".join(slugs)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _select_artifact_personas(
    artifacts_root: str,
    model_name: str,
    variants: list[str],
) -> tuple[list[str], dict[str, str]]:
    """Draw the persona multiselect for the given model and variants.

    The selectable ids and their display names come from
    ``artifact_persona_options``. When nothing is selectable an info hint is
    shown instead of the widget.

    Returns:
        ``(selected_persona_ids, persona_names)``; the id list is empty when
        no persona is available.
    """
    available_ids, names_by_id = artifact_persona_options(
        artifacts_root,
        model_name,
        variants,
    )

    if not available_ids:
        # Nothing to pick from: explain why, worded for one or several variants.
        if len(variants) > 1:
            hint = "No personas have saved activations for all selected variants. Run extraction for both variants first."
        else:
            hint = "No personas found for this model yet. Run extraction first."
        st.info(hint)
        return [], names_by_id

    def _label(persona_id):
        return persona_display_label(persona_id, names_by_id.get(persona_id))

    # Preselect only the first persona when several are available.
    if len(available_ids) > 1:
        preselected = available_ids[:1]
    else:
        preselected = available_ids

    chosen_ids = st.multiselect(
        "Personas",
        options=available_ids,
        default=preselected,
        format_func=_label,
        key=widget_key("load", "personas", model_name, *variants),
    )
    return chosen_ids, names_by_id
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _render_cosine_similarity(
    artifacts_root: str,
    model_name: str,
) -> None:
    """Render the cosine-similarity comparison between two prompt variants.

    Lets the user pick two different variants and a set of personas, builds
    the similarity figure when "Compare vectors" is pressed, and keeps the
    figure in session state so it survives reruns and can be saved as HTML or
    PNG.

    The save filename is cached together with the figure. The figure stays on
    screen after the variant selection changes, so rebuilding the filename
    from the current selection would save the old plot under a name that
    describes different variants.

    Args:
        artifacts_root: Root directory passed to the artifact loaders.
        model_name: Model whose saved activations are compared.
    """
    col1, col2 = st.columns(2)
    with col1:
        variant_a = st.selectbox(
            "Variant A",
            options=PROMPT_VARIANTS,
            index=0,
            format_func=prompt_variant_label,
            key=widget_key("load", "variant_a"),
        )
    with col2:
        variant_b = st.selectbox(
            "Variant B",
            options=PROMPT_VARIANTS,
            # Default to the second variant when there is one.
            index=min(1, len(PROMPT_VARIANTS) - 1),
            format_func=prompt_variant_label,
            key=widget_key("load", "variant_b"),
        )

    if variant_a == variant_b:
        st.warning("Choose two different variants to compare.")
        return

    persona_ids, _ = _select_artifact_personas(
        artifacts_root,
        model_name,
        [variant_a, variant_b],
    )
    if not persona_ids:
        return

    cosine_fig_key = widget_key("load", "cosine_fig_state", model_name)
    filename = _filename("compare", "cosine", model_name, variant_a, variant_b)

    if st.button("Compare vectors", type="primary"):
        traces, loaded_names, errors = load_cosine_traces(
            artifacts_root,
            model_name,
            persona_ids,
            variant_a,
            variant_b,
        )

        if errors:
            for err in errors:
                st.error(f"Failed to load vectors: `{err}`")
        if not traces:
            st.error("No personas loaded successfully.")
            st.info(
                "Check that extraction has been run for both variants and selected personas."
            )
            # Drop any earlier figure so a failed run does not show stale data.
            st.session_state.pop(cosine_fig_key, None)
            return

        display_traces = [
            (
                persona_display_label(persona_id, loaded_names.get(persona_id)),
                short,
                long,
            )
            for persona_id, short, long in traces
        ]
        fig = plot_multiple_layer_similarities(
            display_traces,
            title=f"{prompt_variant_label(variant_a)} vs {prompt_variant_label(variant_b)}",
            show=False,
        )
        # Cache the filename with the figure so later saves are named after
        # the variants the plot was actually built from.
        st.session_state[cosine_fig_key] = (fig, len(traces), filename)

    if cosine_fig_key in st.session_state:
        cached = st.session_state[cosine_fig_key]
        fig, n_traces = cached[0], cached[1]
        # Entries stored without a filename only hold (fig, n_traces); fall
        # back to the current selection for those.
        saved_filename = cached[2] if len(cached) > 2 else filename
        st.plotly_chart(fig, use_container_width=True)
        save_col1, save_col2 = st.columns(2)
        with save_col1:
            if st.button("Save HTML", key=widget_key("load", "save_cosine_html")):
                output_path = save_plot_html(fig, saved_filename)
                st.success(f"Saved HTML to `{output_path}`")
        with save_col2:
            if st.button("Save PNG", key=widget_key("load", "save_cosine_png")):
                try:
                    output_path = save_plot_png(fig, saved_filename)
                    st.success(f"Saved PNG to `{output_path}`")
                except Exception as exc:
                    st.error(f"Could not save PNG: {exc}")
        st.success(f"Loaded {n_traces} personas for cosine comparison.")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _saved_plots_location(saved_paths) -> str:
    """Return the directory (or directories) the given plot files live in.

    Falls back to ``artifacts/plots`` when no path is given.
    """
    from pathlib import Path

    directories = sorted({str(Path(str(path)).parent) for path in saved_paths})
    return ", ".join(directories) if directories else "artifacts/plots"


def _render_embedding_analysis(
    artifacts_root: str,
    model_name: str,
    analysis_mode: str,
) -> None:
    """Render the PCA/UMAP projection view for one prompt variant.

    Lets the user pick a variant, personas and layers, builds one projection
    figure per layer when the generate button is pressed, and keeps the
    figures (with the persona key and variant they were built from) in
    session state so they survive reruns and can be saved as HTML or PNG.

    Args:
        artifacts_root: Root directory passed to the artifact loaders.
        model_name: Model whose saved activations are projected.
        analysis_mode: ``"PCA"`` selects ``project_pca``; any other value
            selects ``project_umap``. Also used as key into
            ``ANALYSIS_LABELS``.
    """
    selected_variant = st.selectbox(
        "Variant",
        options=PROMPT_VARIANTS,
        format_func=prompt_variant_label,
        key=widget_key("load", "variant"),
    )

    persona_ids, persona_names = _select_artifact_personas(
        artifacts_root,
        model_name,
        [selected_variant],
    )
    if not persona_ids:
        return

    layer_options = list_available_layers(
        artifacts_root,
        model_name,
        [selected_variant],
        persona_ids,
    )
    if not layer_options:
        st.info(
            "No shared layers are available for the selected personas. Try fewer personas or a different variant."
        )
        return

    persona_key = "_".join(sorted(persona_ids))
    layer_key = widget_key("load", "layers", model_name, selected_variant, persona_key)
    # Keep a previous selection where it is still valid, else the first three.
    default_layers = [
        layer
        for layer in st.session_state.get(layer_key, layer_options[:3])
        if layer in layer_options
    ] or layer_options[:3]
    selected_layers = st.multiselect(
        "Layers",
        options=layer_options,
        default=default_layers,
        key=layer_key,
    )
    if not selected_layers:
        st.info("Select at least one layer.")
        return

    button_label = (
        "Generate PCA projection"
        if analysis_mode == "PCA"
        else "Generate UMAP projection"
    )

    embedding_fig_key = widget_key(
        "load", "embedding_fig_state", model_name, analysis_mode
    )

    if st.button(button_label, type="primary"):
        progress = st.progress(0, text="Preparing projections...")

        def update_progress(current: int, total: int, loaded: int) -> None:
            # Guard against a zero total so the bar simply shows "done".
            fraction = current / total if total else 1.0
            progress.progress(
                fraction,
                text=f"Processing layer {current}/{total} ({loaded} plot(s) ready)",
            )

        project_fn = project_pca if analysis_mode == "PCA" else project_umap
        try:
            plots, errors = load_embedding_samples(
                artifacts_root,
                model_name,
                persona_ids,
                selected_variant,
                selected_layers,
                project_fn,
                persona_names,
                progress_fn=update_progress,
            )

            if errors:
                for err in errors:
                    # Missing-layer messages are reported as warnings, anything
                    # else as an error.
                    if (
                        "missing layer" in err
                        or "no selected personas have this layer" in err
                    ):
                        st.warning(f"Skipping unavailable data: `{err}`")
                    else:
                        st.error(f"Failed to load vectors: `{err}`")
            if not plots:
                st.warning(
                    "No projections could be built for the current persona/layer selection."
                )
                st.info("Try fewer personas, fewer layers, or a different variant.")
                st.session_state.pop(embedding_fig_key, None)
            else:
                title_prefix, x_label, y_label = ANALYSIS_LABELS[analysis_mode]
                rendered_figures: list[tuple[int, object]] = []
                for layer_idx, coords, labels, hover_text in plots:
                    fig = build_embedding_figure(
                        coords=coords,
                        labels=labels,
                        title=f"{title_prefix}, layer {layer_idx}",
                        x_label=x_label,
                        y_label=y_label,
                        hover_text=hover_text,
                    )
                    rendered_figures.append((layer_idx, fig))
                total_samples = sum(coords.shape[0] for _, coords, _, _ in plots)
                # Store what the figures were built from so filenames stay
                # correct after the selection widgets change.
                st.session_state[embedding_fig_key] = (
                    rendered_figures,
                    persona_key,
                    selected_variant,
                    total_samples,
                )
        finally:
            # Always remove the progress bar, also when loading raised.
            progress.empty()

    if embedding_fig_key in st.session_state:
        rendered_figures, saved_persona_key, saved_variant, total_samples = (
            st.session_state[embedding_fig_key]
        )
        cols = st.columns(2)
        for idx, (layer_idx, fig) in enumerate(rendered_figures):
            # Alternate figures between the two columns.
            with cols[idx % 2]:
                st.plotly_chart(fig, use_container_width=True)
        st.success(
            f"Loaded {total_samples} samples across {len(rendered_figures)} layers."
        )
        filenames = [
            _filename(
                "compare",
                analysis_mode,
                model_name,
                saved_variant,
                saved_persona_key,
                str(layer_idx),
            )
            for layer_idx, _ in rendered_figures
        ]
        save_col1, save_col2 = st.columns(2)
        with save_col1:
            if st.button(
                "Save HTML",
                key=widget_key("load", "save_embedding_html", analysis_mode),
            ):
                saved_paths = [
                    save_plot_html(fig, fn)
                    for (_, fig), fn in zip(rendered_figures, filenames)
                ]
                # Report where the files really went; the artifacts directory
                # is configurable, so it is not always `artifacts/plots`.
                st.success(
                    f"Saved {len(saved_paths)} HTML plot(s) to `{_saved_plots_location(saved_paths)}`."
                )
        with save_col2:
            if st.button(
                "Save PNG",
                key=widget_key("load", "save_embedding_png", analysis_mode),
            ):
                try:
                    saved_paths = [
                        save_plot_png(fig, fn)
                        for (_, fig), fn in zip(rendered_figures, filenames)
                    ]
                    st.success(
                        f"Saved {len(saved_paths)} PNG plot(s) to `{_saved_plots_location(saved_paths)}`."
                    )
                except Exception as exc:
                    st.error(f"Could not save PNGs: {exc}")
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def render_compare_tab(model_name: str) -> None:
    """Render the compare tab.

    Draws the header, the artifacts-root input and the analysis-mode selector,
    then hands over to the cosine-similarity view or the embedding
    (PCA/UMAP) view for ``model_name``.
    """
    st.title("Compare")
    st.caption("Compare saved activations by cosine similarity, PCA, or UMAP.")
    st.subheader("Analysis")

    default_root = str(get_artifacts_dir() / "activations")
    with st.expander("Advanced", expanded=False):
        artifacts_root = st.text_input("Artifacts root", value=default_root)

    chosen_mode = st.segmented_control(
        "Analysis mode",
        options=ANALYSIS_MODES,
        default=ANALYSIS_MODES[0],
        key=widget_key("load", "analysis_mode"),
        label_visibility="collapsed",
    )
    # The control may report no selection; fall back to the first mode then.
    analysis_mode = ANALYSIS_MODES[0] if chosen_mode is None else chosen_mode
    st.caption(ANALYSIS_HELP_TEXT[analysis_mode])

    if analysis_mode == "Cosine similarity":
        _render_cosine_similarity(artifacts_root, model_name)
    else:
        _render_embedding_analysis(artifacts_root, model_name, analysis_mode)
|
tabs/extract.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
from utils.datasets import load_dataset
|
| 4 |
+
from utils.extraction import run_extraction
|
| 5 |
+
from utils.helpers import (
|
| 6 |
+
PROMPT_VARIANTS,
|
| 7 |
+
persona_label,
|
| 8 |
+
prompt_variant_label,
|
| 9 |
+
widget_key,
|
| 10 |
+
)
|
| 11 |
+
from utils.runtime import cached_model
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _extract_widget_key(
    model_name: str, remote: bool, dataset_source: str, suffix: str
) -> str:
    """Return the widget key for an extract-tab widget.

    The key is built by ``widget_key`` from the ``"extract"`` prefix, the
    stringified ``remote`` flag, the model name, the dataset source and
    ``suffix``.
    """
    key_parts = ("extract", str(remote), model_name, dataset_source, suffix)
    return widget_key(*key_parts)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _render_local_dataset_uploads() -> None:
    """Render file inputs for local dataset uploads."""
    # (label, session-state key, help text) for each uploader, in display order.
    upload_specs = (
        (
            "personas.jsonl",
            "extract__personas_file",
            "Expected fields: id, persona, templated_prompt, biography_md",
        ),
        (
            "qa.jsonl",
            "extract__qa_file",
            "Expected fields: id, qid, type, question, answer, difficulty",
        ),
    )

    with st.expander("Local dataset upload", expanded=True):
        for label, state_key, help_text in upload_specs:
            st.file_uploader(
                label,
                type=["jsonl"],
                key=state_key,
                help=help_text,
            )
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> None:
    """Render the extraction tab.

    The tab lets the user choose prompt variants, personas and QA filters,
    then calls ``run_extraction`` once per (persona, variant) pair and lists
    the returned results. Each validation step that fails shows a message and
    ends the render early.

    Args:
        remote: Forwarded to ``cached_model`` and ``run_extraction``; also
            part of every widget key on this tab.
        model_name: Model identifier to load and extract from.
        dataset_source: Dataset source label selected in the UI.
    """

    def key_for(suffix: str) -> str:
        return _extract_widget_key(model_name, remote, dataset_source, suffix)

    st.title("Extract")

    if dataset_source == "Local JSONL upload":
        _render_local_dataset_uploads()

    selected_variants = st.multiselect(
        "Prompt variants",
        options=PROMPT_VARIANTS,
        default=PROMPT_VARIANTS,
        format_func=prompt_variant_label,
        key=key_for("prompt_variants"),
    )
    if not selected_variants:
        st.info("Select at least one prompt variant.")
        return

    try:
        dataset, dataset_status = load_dataset(dataset_source)
        st.caption(dataset_status)
    except Exception as exc:
        st.error(f"Could not load data: {exc}")
        st.info(
            "Upload both JSONL files or switch to the built-in SynthPersona source."
        )
        return

    personas = list(dataset)
    if not personas:
        st.warning("No personas found in the selected dataset.")
        st.info(
            "Try another dataset source or check that the personas file is not empty."
        )
        return

    selected_personas = st.multiselect(
        "Personas",
        options=personas,
        # ``personas`` is known to be non-empty here (checked just above).
        default=[personas[0]],
        format_func=persona_label,
        key=key_for("persona_select"),
    )
    if not selected_personas:
        st.info("Select at least one persona.")
        return

    with st.expander("Advanced", expanded=False):
        st.caption("Filters")

        col1, col2, col3 = st.columns([2, 2, 1])
        with col1:
            qa_type_select = st.selectbox(
                "QA type",
                options=["all", "explicit", "implicit"],
                index=0,
                key=key_for("qa_type_select"),
            )
            # "all" means no type filter.
            qa_filter_type: str | None = (
                qa_type_select if qa_type_select in ("explicit", "implicit") else None
            )
        with col2:
            difficulty_values = st.multiselect(
                "Difficulty",
                options=[1, 2, 3],
                default=[1, 2, 3],
                key=key_for("difficulty_select"),
            )
            # An empty selection is treated as "no difficulty filter".
            qa_filter_difficulty: list[int] | None = (
                difficulty_values if difficulty_values else None
            )

    # Pre-load QA pairs for all selected personas to validate filters and set
    # the slider range. The messages are emitted outside the (collapsed by
    # default) expander so they stay visible.
    qa_by_persona = {
        p.id: dataset.get_qa(
            p.id, type=qa_filter_type, difficulty=qa_filter_difficulty
        )
        for p in selected_personas
    }
    personas_without_qa = [p for p in selected_personas if not qa_by_persona[p.id]]
    if personas_without_qa:
        names = ", ".join(p.name for p in personas_without_qa)
        st.warning(f"No QA pairs match filters for: {names}. They will be skipped.")

    personas_to_run = [p for p in selected_personas if qa_by_persona[p.id]]
    if not personas_to_run:
        st.info("No personas have matching QA pairs. Widen the filters.")
        return

    # The slider is capped at the smallest per-persona QA count so the same
    # limit is valid for every persona that will run.
    min_qa_count = min(len(qa_by_persona[p.id]) for p in personas_to_run)

    with col3:
        max_questions = st.slider(
            "Max questions",
            min_value=1,
            max_value=min_qa_count,
            value=min_qa_count,
            key=key_for("max_questions"),
        )

    run_clicked = st.button("Run extraction", type="primary")
    if not run_clicked:
        return

    status_box = st.empty()
    status_box.info("Extraction in progress...")
    progress = st.progress(0, text="Preparing extraction...")

    results = []
    try:
        # Model loading lives inside the try block so that a load failure is
        # reported like any other extraction failure and the progress bar and
        # status banner are always cleaned up.
        with st.spinner("Loading model..."):
            model = cached_model(model_name=model_name, remote=remote)

        total_steps = len(personas_to_run) * len(selected_variants)
        step = 0

        for persona in personas_to_run:
            qa_pairs = qa_by_persona[persona.id][:max_questions]
            for variant in selected_variants:
                progress.progress(
                    step / total_steps if total_steps else 1.0,
                    text=f"{persona.name} · {prompt_variant_label(variant)} ({step + 1}/{total_steps})",
                )
                variant_results = run_extraction(
                    model=model,
                    model_name=model_name,
                    persona=persona,
                    qa_pairs=qa_pairs,
                    variants=[variant],
                    remote=remote,
                )
                results.extend(variant_results)
                step += 1

        progress.progress(1.0, text="Extraction complete")
    except Exception as exc:
        # Overwrite the "in progress" banner rather than leaving it on screen
        # next to the error.
        status_box.error(f"Extraction failed: {exc}")
        return
    finally:
        progress.empty()

    status_box.success("Extraction complete")
    st.success(f"Saved {len(results)} artifact set(s)")

    for result in results:
        st.markdown(
            f"- **{result.persona_name}** · {prompt_variant_label(result.variant)}: "
            f"{result.n_questions} questions, {result.n_layers} layers, {result.d_model} hidden size"
        )
|
utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Utility helpers for the Streamlit UI."""
|
utils/artifacts.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from collections.abc import Callable
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import streamlit as st
|
| 6 |
+
import torch
|
| 7 |
+
from persona_vectors.activation_io import (
|
| 8 |
+
load_activation_metadata,
|
| 9 |
+
load_per_question_vectors,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def model_dir_name(model_name: str) -> str:
    """Return ``model_name`` with every ``/`` replaced by ``__``.

    The result is used as a single path component in artifact paths.
    """
    return "__".join(model_name.split("/"))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def list_available_personas(
|
| 22 |
+
artifacts_root: str | Path,
|
| 23 |
+
model_name: str,
|
| 24 |
+
variants: list[str],
|
| 25 |
+
) -> list[str]:
|
| 26 |
+
"""List persona ids available for every requested variant."""
|
| 27 |
+
|
| 28 |
+
shared_personas: set[str] | None = None
|
| 29 |
+
root = Path(artifacts_root)
|
| 30 |
+
for variant in variants:
|
| 31 |
+
model_dir = root / model_dir_name(model_name) / variant
|
| 32 |
+
if not model_dir.exists():
|
| 33 |
+
return []
|
| 34 |
+
|
| 35 |
+
variant_personas = {d.name for d in model_dir.iterdir() if d.is_dir()}
|
| 36 |
+
if shared_personas is None:
|
| 37 |
+
shared_personas = variant_personas
|
| 38 |
+
else:
|
| 39 |
+
shared_personas &= variant_personas
|
| 40 |
+
|
| 41 |
+
if not shared_personas:
|
| 42 |
+
return []
|
| 43 |
+
|
| 44 |
+
return sorted(shared_personas or set())
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_persona_names(
|
| 48 |
+
artifacts_root: str | Path,
|
| 49 |
+
model_name: str,
|
| 50 |
+
variants: list[str],
|
| 51 |
+
persona_ids: list[str],
|
| 52 |
+
) -> dict[str, str]:
|
| 53 |
+
"""Load display names from saved activation metadata."""
|
| 54 |
+
|
| 55 |
+
names: dict[str, str] = {}
|
| 56 |
+
for persona_id in persona_ids:
|
| 57 |
+
for variant in variants:
|
| 58 |
+
try:
|
| 59 |
+
metadata = load_activation_metadata(
|
| 60 |
+
root_dir=artifacts_root,
|
| 61 |
+
model_name=model_name,
|
| 62 |
+
prompt_variant=variant,
|
| 63 |
+
persona_id=persona_id,
|
| 64 |
+
)
|
| 65 |
+
except Exception:
|
| 66 |
+
logger.debug(
|
| 67 |
+
"Failed to load metadata for persona %s variant %s",
|
| 68 |
+
persona_id,
|
| 69 |
+
variant,
|
| 70 |
+
exc_info=True,
|
| 71 |
+
)
|
| 72 |
+
continue
|
| 73 |
+
|
| 74 |
+
persona_name = metadata.get("persona_name")
|
| 75 |
+
if isinstance(persona_name, str) and persona_name:
|
| 76 |
+
names[persona_id] = persona_name
|
| 77 |
+
break
|
| 78 |
+
|
| 79 |
+
return names
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def artifact_persona_options(
    artifacts_root: str | Path,
    model_name: str,
    variants: list[str],
) -> tuple[list[str], dict[str, str]]:
    """Return the available persona ids together with their display names.

    The ids come from ``list_available_personas``; the names are looked up for
    exactly those ids with ``load_persona_names``.
    """
    available_ids = list_available_personas(artifacts_root, model_name, variants)
    display_names = load_persona_names(
        artifacts_root, model_name, variants, available_ids
    )
    return available_ids, display_names
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@st.cache_data(show_spinner=False)
def list_available_layers(
    artifacts_root: str,
    model_name: str,
    variants: list[str],
    persona_ids: list[str],
) -> list[int]:
    """List the layer indices present in every activation file that loads.

    Each (variant, persona) combination is loaded with
    ``load_per_question_vectors``; the size of dimension 1 of the returned
    tensor is taken as its layer count (assumes that axis indexes layers —
    TODO confirm against ``activation_io``). Combinations that fail to load are
    logged at debug level and do not constrain the result. Returns an empty
    list when nothing could be loaded.
    """
    # Intersecting range(n) sets is the same as keeping the smallest n.
    fewest_layers: int | None = None

    for prompt_variant in variants:
        for pid in persona_ids:
            try:
                tensor, _ = load_per_question_vectors(
                    root_dir=artifacts_root,
                    model_name=model_name,
                    prompt_variant=prompt_variant,
                    persona_id=pid,
                )
            except Exception:
                logger.debug(
                    "Failed to load vectors for persona %s variant %s",
                    pid,
                    prompt_variant,
                    exc_info=True,
                )
            else:
                layer_count = tensor.shape[1]
                if fewest_layers is None or layer_count < fewest_layers:
                    fewest_layers = layer_count

    if fewest_layers is None:
        return []
    return list(range(fewest_layers))
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def load_cosine_traces(
    artifacts_root: str | Path,
    model_name: str,
    persona_ids: list[str],
    variant_a: str,
    variant_b: str,
) -> tuple[list[tuple[str, torch.Tensor, torch.Tensor]], dict[str, str], list[str]]:
    """Load per-persona mean activations for two prompt variants.

    Returns:
        A ``(traces, persona_names, errors)`` tuple. ``traces`` holds one
        ``(persona_id, mean_a, mean_b)`` entry per persona whose vectors loaded
        for both variants, each mean being taken over dimension 0 after casting
        to float. ``persona_names`` comes from ``load_persona_names``.
        ``errors`` has one ``"<persona_id>: <exception>"`` string per persona
        that failed to load.
    """

    def _vectors_for(prompt_variant: str, pid: str) -> torch.Tensor:
        tensor, _ = load_per_question_vectors(
            root_dir=artifacts_root,
            model_name=model_name,
            prompt_variant=prompt_variant,
            persona_id=pid,
        )
        return tensor

    display_names = load_persona_names(
        artifacts_root, model_name, [variant_a, variant_b], persona_ids
    )
    collected: list[tuple[str, torch.Tensor, torch.Tensor]] = []
    failures: list[str] = []

    for pid in persona_ids:
        try:
            first = _vectors_for(variant_a, pid)
            second = _vectors_for(variant_b, pid)
        except Exception as exc:
            failures.append(f"{pid}: {exc}")
            continue

        mean_first = first.float().mean(dim=0)
        mean_second = second.float().mean(dim=0)
        collected.append((pid, mean_first, mean_second))

    return collected, display_names, failures
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def load_embedding_samples(
|
| 180 |
+
artifacts_root: str | Path,
|
| 181 |
+
model_name: str,
|
| 182 |
+
persona_ids: list[str],
|
| 183 |
+
variant: str,
|
| 184 |
+
selected_layers: list[int],
|
| 185 |
+
project_fn: Callable[[torch.Tensor], torch.Tensor],
|
| 186 |
+
persona_names: dict[str, str],
|
| 187 |
+
progress_fn: Callable[[int, int, int], None] | None = None,
|
| 188 |
+
) -> tuple[list[tuple[int, torch.Tensor, list[str], list[str]]], list[str]]:
|
| 189 |
+
"""Load samples for 2D projections without re-reading each layer from disk."""
|
| 190 |
+
|
| 191 |
+
plots: list[tuple[int, torch.Tensor, list[str], list[str]]] = []
|
| 192 |
+
errors: list[str] = []
|
| 193 |
+
vectors_by_persona: dict[str, torch.Tensor] = {}
|
| 194 |
+
|
| 195 |
+
for persona_id in persona_ids:
|
| 196 |
+
try:
|
| 197 |
+
vectors, _ = load_per_question_vectors(
|
| 198 |
+
root_dir=artifacts_root,
|
| 199 |
+
model_name=model_name,
|
| 200 |
+
prompt_variant=variant,
|
| 201 |
+
persona_id=persona_id,
|
| 202 |
+
)
|
| 203 |
+
except Exception as exc:
|
| 204 |
+
errors.append(f"{persona_id} / {variant}: {exc}")
|
| 205 |
+
continue
|
| 206 |
+
|
| 207 |
+
vectors_by_persona[persona_id] = vectors
|
| 208 |
+
|
| 209 |
+
total_layers = len(selected_layers)
|
| 210 |
+
for idx, layer_idx in enumerate(selected_layers, start=1):
|
| 211 |
+
samples: list[torch.Tensor] = []
|
| 212 |
+
labels: list[str] = []
|
| 213 |
+
hover_text: list[str] = []
|
| 214 |
+
|
| 215 |
+
for persona_id, vectors in vectors_by_persona.items():
|
| 216 |
+
if layer_idx >= vectors.shape[1]:
|
| 217 |
+
errors.append(f"{persona_id} / {variant}: missing layer {layer_idx}")
|
| 218 |
+
continue
|
| 219 |
+
|
| 220 |
+
layer_vectors = vectors[:, layer_idx, :]
|
| 221 |
+
samples.append(layer_vectors)
|
| 222 |
+
labels.extend([persona_id] * layer_vectors.shape[0])
|
| 223 |
+
display_name = persona_names.get(persona_id) or persona_id
|
| 224 |
+
hover_text.extend(
|
| 225 |
+
[
|
| 226 |
+
f"<b>{display_name}</b><br>{variant}",
|
| 227 |
+
]
|
| 228 |
+
* layer_vectors.shape[0]
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
if not samples:
|
| 232 |
+
errors.append(f"Layer {layer_idx}: no selected personas have this layer")
|
| 233 |
+
else:
|
| 234 |
+
all_samples = torch.cat(samples, dim=0)
|
| 235 |
+
if all_samples.shape[0] < 2:
|
| 236 |
+
errors.append(
|
| 237 |
+
f"Layer {layer_idx}: need at least 2 samples after filtering selected personas"
|
| 238 |
+
)
|
| 239 |
+
else:
|
| 240 |
+
try:
|
| 241 |
+
coords = project_fn(all_samples)
|
| 242 |
+
plots.append((layer_idx, coords, labels, hover_text))
|
| 243 |
+
except Exception as exc:
|
| 244 |
+
errors.append(f"Layer {layer_idx}: {exc}")
|
| 245 |
+
|
| 246 |
+
if progress_fn is not None:
|
| 247 |
+
progress_fn(idx, total_layers, len(plots))
|
| 248 |
+
|
| 249 |
+
return plots, errors
|
utils/chat.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from contextlib import contextmanager, nullcontext
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Literal
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from nnterp import StandardizedTransformer
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
from persona_data.synth_persona import PersonaData
|
| 12 |
+
from persona_data.prompts import (
|
| 13 |
+
format_biography_prompt,
|
| 14 |
+
format_templated_prompt,
|
| 15 |
+
normalize_messages,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
SystemPromptMode = Literal["empty", "templated", "biography", "custom"]
|
| 19 |
+
|
| 20 |
+
_CUSTOM_PROMPT_DEFAULT = "You are a helpful assistant."
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
|
| 24 |
+
class ChatReply:
|
| 25 |
+
text: str
|
| 26 |
+
prompt_tokens: int
|
| 27 |
+
output_tokens: int
|
| 28 |
+
past_key_values: object | None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def resolve_system_prompt(
|
| 32 |
+
persona: PersonaData | None,
|
| 33 |
+
mode: SystemPromptMode,
|
| 34 |
+
) -> str:
|
| 35 |
+
"""Resolve the active system prompt for chat.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
persona: Selected persona, if any.
|
| 39 |
+
mode: Prompt mode selected in the UI.
|
| 40 |
+
|
| 41 |
+
Returns:
|
| 42 |
+
The rendered system prompt string.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
if persona is None:
|
| 46 |
+
return ""
|
| 47 |
+
|
| 48 |
+
if mode == "templated":
|
| 49 |
+
return format_templated_prompt(persona.templated_prompt)
|
| 50 |
+
if mode == "biography":
|
| 51 |
+
return format_biography_prompt(persona.biography_md)
|
| 52 |
+
if mode == "custom":
|
| 53 |
+
return _CUSTOM_PROMPT_DEFAULT
|
| 54 |
+
return ""
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _format_plain_messages(
    messages: list[dict[str, str]], add_generation_prompt: bool
) -> str:
    """Render messages as ``Role: content`` paragraphs (last-resort fallback).

    System messages with empty content are dropped. Every other message
    becomes ``"<Role>: <content>"`` with the role title-cased. When
    ``add_generation_prompt`` is set and the last rendered paragraph does not
    already start with ``"Assistant:"``, a bare ``"Assistant:"`` cue is
    appended. Paragraphs are joined with a blank line.
    """
    paragraphs: list[str] = []

    for entry in messages:
        speaker = entry["role"]
        body = entry["content"]
        if speaker == "system" and not body:
            continue
        # "system" / "user" / "assistant" title-case to the same labels the
        # explicit branches would produce; other roles are title-cased too.
        paragraphs.append(f"{speaker.title()}: {body}")

    ends_with_assistant = bool(paragraphs) and paragraphs[-1].startswith("Assistant:")
    if add_generation_prompt and not ends_with_assistant:
        paragraphs.append("Assistant:")

    return "\n\n".join(paragraphs)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _format_generation_prompt(
    messages: list[dict[str, str]], tokenizer: object
) -> tuple[str, int]:
    """Render ``messages`` into one prompt string and count its tokens.

    Three renderers are tried in order: the tokenizer's chat template on the
    raw messages, the chat template on ``normalize_messages(messages)``, and
    finally the plain-text format from ``_format_plain_messages``. Template
    failures are logged at debug level.

    Returns:
        ``(prompt, prompt_token_count)`` where the count is the second
        dimension of ``tokenizer(prompt, return_tensors="pt").input_ids``.
    """

    def _apply_template(batch: list[dict[str, str]]) -> str:
        return tokenizer.apply_chat_template(
            batch,
            tokenize=False,
            add_generation_prompt=True,
        )

    try:
        rendered = _apply_template(messages)
    except Exception:
        logger.debug(
            "Chat template failed on raw messages, trying normalized", exc_info=True
        )
        cleaned = normalize_messages(messages)
        try:
            rendered = _apply_template(cleaned)
        except Exception:
            logger.debug(
                "Chat template failed on normalized messages, falling back to plain format",
                exc_info=True,
            )
            rendered = _format_plain_messages(cleaned, add_generation_prompt=True)

    encoded = tokenizer(rendered, return_tensors="pt")
    return rendered, encoded.input_ids.shape[1]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@contextmanager
|
| 126 |
+
def _seeded_rng(seed: int | None):
|
| 127 |
+
"""Context manager that forks the RNG state and sets a deterministic seed."""
|
| 128 |
+
if seed is None:
|
| 129 |
+
yield
|
| 130 |
+
return
|
| 131 |
+
|
| 132 |
+
cuda_ctx = torch.random.fork_rng(devices=range(torch.cuda.device_count()))
|
| 133 |
+
mps_ctx = (
|
| 134 |
+
torch.random.fork_rng(devices=range(1), device_type="mps")
|
| 135 |
+
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
| 136 |
+
else nullcontext()
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
with cuda_ctx, mps_ctx:
|
| 140 |
+
torch.manual_seed(seed)
|
| 141 |
+
yield
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def generate_chat_reply(
    model: StandardizedTransformer,
    messages: list[dict[str, str]],
    remote: bool,
    past_key_values: object | None = None,
    max_new_tokens: int = 256,
    do_sample: bool = False,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
    seed: int | None = None,
) -> ChatReply:
    """Generate one assistant reply from a full chat history.

    The helper uses ``model.generate`` so it works with both local and NDIF-backed
    nnsight models. The full conversation is re-rendered each turn and the cache from
    the previous turn is reused when available (local execution only).

    Args:
        model: Loaded standardized nnterp model.
        messages: Full chat history, including any system prompt as the first message.
        remote: Whether to execute the generation on NDIF.
        past_key_values: Cache returned by the previous generation step. Ignored
            when ``remote`` is true.
        max_new_tokens: Maximum number of assistant tokens to generate.
        do_sample: Whether to sample from the model distribution. Always forwarded
            explicitly so the model's own generation defaults cannot override it.
        temperature: Sampling temperature, used only when sampling is enabled.
        top_p: Nucleus sampling threshold, used only when sampling is enabled.
        top_k: Top-k cutoff, used only when sampling is enabled.
        repetition_penalty: Repetition penalty applied during decoding; only
            forwarded when it differs from 1.0.
        seed: Optional local RNG seed for sampled generation. Not applied for
            remote runs or when sampling is disabled.

    Returns:
        ChatReply with generated text and the updated cache (``None`` for remote runs).

    Raises:
        ValueError: If the generation result has no ``sequences`` attribute.
        TypeError: If the returned ``sequences`` is not a tensor.
    """

    tokenizer = model.tokenizer
    prompt, prompt_token_count = _format_generation_prompt(messages, tokenizer)

    generation_kwargs: dict[str, object] = {
        "max_new_tokens": max_new_tokens,
        "return_dict_in_generate": True,
        "use_cache": True,
        # Passed even when False: a checkpoint's generation config may enable
        # sampling by default, which would otherwise silently turn a "greedy"
        # request into unseeded sampling.
        "do_sample": do_sample,
    }
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p
        generation_kwargs["top_k"] = top_k
    if repetition_penalty != 1.0:
        generation_kwargs["repetition_penalty"] = repetition_penalty
    if past_key_values is not None and not remote:
        generation_kwargs["past_key_values"] = past_key_values
    if remote:
        generation_kwargs["remote"] = True
        # WARNING: NDIF returns caches on CPU, so cross-turn cache reuse is not stable.

    with _seeded_rng(seed if do_sample and not remote else None):
        with model.generate(prompt, **generation_kwargs) as tracer:
            generated = tracer.result.save()

    # Unwrap the saved result when it is exposed through a ``.value`` attribute.
    if hasattr(generated, "value") and getattr(generated, "value") is not None:
        generated = generated.value

    if not hasattr(generated, "sequences"):
        raise ValueError("Generation did not return token sequences")

    sequences = generated.sequences
    if not isinstance(sequences, torch.Tensor):
        raise TypeError("Generated sequences must be a tensor")

    # Everything after the prompt tokens of the first sequence is the reply.
    generated_ids = sequences[0, prompt_token_count:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    output_tokens = int(sequences.shape[1] - prompt_token_count)

    return ChatReply(
        text=text,
        prompt_tokens=prompt_token_count,
        output_tokens=max(0, output_tokens),
        past_key_values=(
            getattr(generated, "past_key_values", None) if not remote else None
        ),
    )
|
utils/chat_export.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from datetime import datetime, timezone
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
from persona_data.environment import get_artifacts_dir
|
| 6 |
+
|
| 7 |
+
from utils.artifacts import model_dir_name
|
| 8 |
+
from utils.helpers import slugify
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def build_chat_export_payload(
|
| 12 |
+
*,
|
| 13 |
+
model_name: str,
|
| 14 |
+
dataset_source: str,
|
| 15 |
+
persona_id: str,
|
| 16 |
+
persona_name: str | None,
|
| 17 |
+
panel_label: str | None,
|
| 18 |
+
prompt_mode: str,
|
| 19 |
+
system_prompt: str | None,
|
| 20 |
+
messages: list[dict[str, str]],
|
| 21 |
+
generation: dict[str, object],
|
| 22 |
+
) -> dict[str, object]:
|
| 23 |
+
"""Build a JSON-serializable snapshot of the current chat session.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
model_name: Model identifier used for the chat.
|
| 27 |
+
dataset_source: Human-readable dataset source label.
|
| 28 |
+
persona_id: Selected persona id.
|
| 29 |
+
persona_name: Selected persona display name, if available.
|
| 30 |
+
prompt_mode: Active system prompt mode.
|
| 31 |
+
messages: Conversation messages without the system prompt.
|
| 32 |
+
generation: Generation settings used for the chat.
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
A JSON-serializable dictionary.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
return {
|
| 39 |
+
"model_name": model_name,
|
| 40 |
+
"dataset_source": dataset_source,
|
| 41 |
+
"persona": {
|
| 42 |
+
"id": persona_id,
|
| 43 |
+
"name": persona_name,
|
| 44 |
+
},
|
| 45 |
+
"panel_label": panel_label,
|
| 46 |
+
"prompt_mode": prompt_mode,
|
| 47 |
+
"generation": generation,
|
| 48 |
+
"messages": (
|
| 49 |
+
[{"role": "system", "content": system_prompt}] if system_prompt else []
|
| 50 |
+
)
|
| 51 |
+
+ messages,
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def save_chat_export(
    *,
    model_name: str,
    dataset_source: str,
    persona_id: str,
    persona_name: str | None,
    prompt_mode: str,
    system_prompt: str | None,
    messages: list[dict[str, str]],
    generation: dict[str, object],
    panel_label: str | None = None,
) -> Path:
    """Write the current chat session as JSON under the artifacts ``chats`` folder.

    The file lands in
    ``<artifacts>/chats/<model dir>/<dataset slug>/<persona id slug>/`` and is
    named from a UTC timestamp, the persona name (or id), the prompt mode and,
    when given, the panel label, joined with ``__``.

    Args:
        model_name: Model identifier used for the chat.
        dataset_source: Human-readable dataset source label.
        persona_id: Selected persona id.
        persona_name: Selected persona display name, if available.
        prompt_mode: Active system prompt mode.
        system_prompt: Current system prompt text, if any.
        messages: Conversation messages without the system prompt.
        generation: Generation settings used for the chat.
        panel_label: Optional chat panel label, appended to the file name.

    Returns:
        The path the export was written to.
    """
    snapshot = build_chat_export_payload(
        model_name=model_name,
        dataset_source=dataset_source,
        persona_id=persona_id,
        persona_name=persona_name,
        panel_label=panel_label,
        prompt_mode=prompt_mode,
        system_prompt=system_prompt,
        messages=messages,
        generation=generation,
    )

    target_dir = get_artifacts_dir() / "chats"
    target_dir = target_dir / model_dir_name(model_name)
    target_dir = target_dir / slugify(dataset_source) / slugify(persona_id)
    target_dir.mkdir(parents=True, exist_ok=True)

    stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    name_parts = [stamp, slugify(persona_name or persona_id), slugify(prompt_mode)]
    if panel_label:
        name_parts.append(slugify(panel_label))

    export_path = target_dir / ("__".join(name_parts) + ".json")
    serialized = json.dumps(snapshot, indent=2, ensure_ascii=False)
    export_path.write_text(serialized + "\n", encoding="utf-8")

    return export_path
|
utils/datasets.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import atexit
|
| 2 |
+
import shutil
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from tempfile import mkdtemp
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import streamlit as st
|
| 8 |
+
from persona_data.synth_persona import SynthPersonaDataset
|
| 9 |
+
|
| 10 |
+
from .helpers import DATASET_SOURCES
|
| 11 |
+
from .local_dataset import LocalPersonaDataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@st.cache_resource(show_spinner=False)
def cached_hf_dataset() -> SynthPersonaDataset:
    """Build the default SynthPersona dataset behind Streamlit's resource cache.

    The function takes no arguments, so the ``st.cache_resource`` wrapper has
    a single cache entry to serve; the spinner is disabled on the decorator.

    Returns:
        A ``SynthPersonaDataset`` constructed with its default arguments.
    """
    dataset = SynthPersonaDataset()
    return dataset
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _upload_cache_dir() -> Path:
    """Return the per-session temp directory used to stage uploaded files.

    The directory is created lazily on first use and its path is remembered in
    ``st.session_state`` so later calls in the same session reuse it.

    Returns:
        Path of the session's upload staging directory.
    """
    session_key = "_upload_cache_dir"

    existing = st.session_state.get(session_key)
    if existing is not None:
        return Path(existing)

    created = mkdtemp(prefix="persona_vectors_uploads_")
    st.session_state[session_key] = created
    # Remove the staging directory when the server process exits; errors
    # during removal are ignored.
    atexit.register(shutil.rmtree, created, ignore_errors=True)
    return Path(created)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _uploaded_file_to_temp_path(uploaded_file: Any, stem: str) -> Path:
    """Stage an uploaded file on disk and return its path.

    Args:
        uploaded_file: Upload object exposing ``.name`` and ``.getvalue()``
            (presumably a Streamlit ``UploadedFile`` — verify against caller).
        stem: File name (without extension) to use inside the staging
            directory; the extension is taken from the uploaded file's name,
            falling back to ``.jsonl``.

    Returns:
        Path of the staged copy inside the session's upload directory.
    """
    suffix = Path(uploaded_file.name).suffix or ".jsonl"
    temp_path = _upload_cache_dir() / f"{stem}{suffix}"
    data = uploaded_file.getvalue()

    # Skip the rewrite only when the staged bytes really match. Comparing the
    # size alone would keep serving stale content after the user re-uploads a
    # different file that happens to have the same length.
    try:
        unchanged = (
            temp_path.stat().st_size == len(data) and temp_path.read_bytes() == data
        )
    except FileNotFoundError:
        unchanged = False

    if not unchanged:
        temp_path.write_bytes(data)
    return temp_path
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def load_dataset(
    dataset_source: str,
) -> tuple[SynthPersonaDataset | LocalPersonaDataset, str]:
    """Resolve the dataset selected in the UI.

    Args:
        dataset_source: One of the entries of ``DATASET_SOURCES``. The first
            entry selects the cached HuggingFace dataset; any other value
            selects the locally uploaded JSONL files.

    Returns:
        A ``(dataset, label)`` pair, where ``label`` is ``"SynthPersona"`` or
        ``"Local upload"``.

    Raises:
        ValueError: If the local source is selected but either upload is
            missing from the session state.
    """
    use_local_upload = dataset_source != DATASET_SOURCES[0]
    if not use_local_upload:
        return cached_hf_dataset(), "SynthPersona"

    session = st.session_state
    uploaded_personas = session.get("extract__personas_file")
    uploaded_qa = session.get("extract__qa_file")
    if uploaded_personas is None or uploaded_qa is None:
        raise ValueError("Upload both personas.jsonl and qa.jsonl files")

    # Stage both uploads on disk because LocalPersonaDataset reads from paths.
    local_dataset = LocalPersonaDataset(
        personas_path=_uploaded_file_to_temp_path(uploaded_personas, stem="personas"),
        qa_path=_uploaded_file_to_temp_path(uploaded_qa, stem="qa"),
    )
    return local_dataset, "Local upload"
|
utils/extraction.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import logging
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from nnterp import StandardizedTransformer
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
from persona_data.environment import get_artifacts_dir
|
| 11 |
+
from persona_data.synth_persona import PersonaData, QAPair
|
| 12 |
+
from persona_vectors.activation_io import save_per_question_vectors
|
| 13 |
+
from persona_vectors.activations import extract_activations
|
| 14 |
+
from persona_data.prompts import (
|
| 15 |
+
format_biography_prompt,
|
| 16 |
+
format_messages,
|
| 17 |
+
format_templated_prompt,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class VariantExtractionResult:
    """Summary of one prompt variant's activation-extraction run."""

    # Prompt variant key the extraction was run for.
    variant: str
    # Directory the per-question vectors were saved to, as a string.
    output_dir: str
    # Sizes of the three leading axes of the extracted tensor.
    n_questions: int
    n_layers: int
    d_model: int
    # Display name of the persona; empty when not provided.
    persona_name: str = ""
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _prepare_inputs(
    tokenizer: object,
    system_prompt: str,
    qa_pairs: list[QAPair],
) -> tuple[list[str], list[torch.Tensor], list[str]]:
    """Render each QA pair as a chat prompt and mask its answer positions.

    Args:
        tokenizer: Tokenizer taken from the model; it is called on the
            rendered prompt and passed to ``format_messages``.
        system_prompt: System message placed at the start of every
            conversation.
        qa_pairs: Question-answer pairs to render.

    Returns:
        ``(full_texts, token_masks, questions)``: the rendered prompts, one
        boolean tensor per prompt that is ``True`` from ``answer_start``
        onwards, and the raw question strings, all in input order.
    """
    rendered_prompts: list[str] = []
    answer_masks: list[torch.Tensor] = []
    question_texts: list[str] = []

    for pair in qa_pairs:
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": pair.question},
            {"role": "assistant", "content": pair.answer},
        ]
        prompt_text, answer_start = format_messages(conversation, tokenizer)

        # Only the token count is needed here, to size the mask.
        encoded = tokenizer(prompt_text, return_tensors="pt")
        n_tokens = encoded.input_ids.shape[1]
        positions = torch.arange(n_tokens)

        rendered_prompts.append(prompt_text)
        answer_masks.append(positions >= answer_start)
        question_texts.append(pair.question)

    return rendered_prompts, answer_masks, question_texts
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def run_extraction(
    model: StandardizedTransformer,
    model_name: str,
    persona: PersonaData,
    qa_pairs: list[QAPair],
    variants: list[str],
    remote: bool,
) -> list[VariantExtractionResult]:
    """Run activation extraction and save outputs for selected variants.

    Args:
        model: Loaded standardized nnterp model.
        model_name: HuggingFace model identifier used for artifact paths.
        persona: The persona whose QA pairs are being extracted.
        qa_pairs: Question-answer pairs to run extraction on.
        variants: Prompt variants to extract (``"templated"``, ``"biography"``).
        remote: Whether to execute on NDIF.

    Returns:
        A list of extraction results, one per variant, in request order.

    Raises:
        ValueError: If ``qa_pairs`` is empty or an unsupported variant is
            given. Both checks run before any extraction starts, so a bad
            request never leaves partially written artifacts behind.
    """
    if not qa_pairs:
        raise ValueError("No QA pairs selected for extraction")

    # Builders are lazy so that only the requested variants are rendered.
    prompt_builders = {
        "templated": lambda: format_templated_prompt(persona.templated_prompt),
        "biography": lambda: format_biography_prompt(persona.biography_md),
    }

    # Fail fast: reject the whole request before doing any expensive work.
    for variant in variants:
        if variant not in prompt_builders:
            raise ValueError(f"Unsupported variant: {variant}")

    tokenizer = model.tokenizer
    activations_dir = get_artifacts_dir() / "activations"

    results: list[VariantExtractionResult] = []

    for variant in variants:
        full_texts, token_masks, questions = _prepare_inputs(
            tokenizer=tokenizer,
            system_prompt=prompt_builders[variant](),
            qa_pairs=qa_pairs,
        )

        per_question_vectors = extract_activations(
            model=model,
            full_texts=full_texts,
            token_masks=token_masks,
            remote=remote,
        )

        artifact_dir = save_per_question_vectors(
            root_dir=activations_dir,
            model_name=model_name,
            prompt_variant=variant,
            persona_id=persona.id,
            persona_name=persona.name,
            per_question_vectors=per_question_vectors,
            questions=questions,
        )

        # The first three axes are reported as (questions, layers, d_model).
        n_questions, n_layers, d_model = per_question_vectors.shape[:3]
        results.append(
            VariantExtractionResult(
                variant=variant,
                output_dir=str(artifact_dir),
                n_questions=n_questions,
                n_layers=n_layers,
                d_model=d_model,
                persona_name=persona.name,
            )
        )

        # Free activation tensors between variants to keep memory bounded.
        del per_question_vectors, full_texts, token_masks
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Only touch the MPS allocator when the backend is actually usable;
        # the mere presence of ``torch.mps`` does not imply that.
        mps_backend = getattr(torch.backends, "mps", None)
        if (
            mps_backend is not None
            and mps_backend.is_available()
            and hasattr(torch, "mps")
            and hasattr(torch.mps, "empty_cache")
        ):
            torch.mps.empty_cache()

    return results
|
utils/helpers.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from persona_data.synth_persona import PersonaData
|
| 2 |
+
|
| 3 |
+
# Prompt-variant key -> label shown in the UI (insertion order matters below).
VARIANT_LABELS = {
    "empty": "None",
    "templated": "Template",
    "biography": "Biography",
    "custom": "Custom",
}

# Variants backed by a persona-derived system prompt
# (leaves out "empty" and "custom").
PROMPT_VARIANTS = ["templated", "biography"]

# Selectbox options: the labels, in the order the variants are defined.
MODE_LABELS = [*VARIANT_LABELS.values()]

# Inverse of VARIANT_LABELS: UI label -> variant key.
MODE_LABEL_TO_KEY = {label: key for key, label in VARIANT_LABELS.items()}

DATASET_SOURCES = ["HuggingFace: synth-persona", "Local JSONL upload"]
ANALYSIS_MODES = ["Cosine similarity", "PCA", "UMAP"]

# Projection mode -> (title, x-axis label, y-axis label).
ANALYSIS_LABELS = {
    "PCA": ("PCA", "PC1", "PC2"),
    "UMAP": ("UMAP", "UMAP 1", "UMAP 2"),
}

# One-line help text for each analysis mode.
ANALYSIS_HELP_TEXT = {
    "Cosine similarity": "Compare layer-wise alignment between variants.",
    "PCA": "Project the selected layers into a global 2D view.",
    "UMAP": "Project the selected layers into a local-neighborhood 2D view.",
}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def slugify(value: str) -> str:
    """Reduce a string to lowercase ASCII letters, digits and underscores.

    Every run of other characters becomes a single ``_``; leading and trailing
    underscores are dropped. If nothing is left, ``"unknown"`` is returned.
    """

    import re

    lowered = value.lower()
    collapsed = re.sub(r"[^a-z0-9]+", "_", lowered)
    slug = collapsed.strip("_")
    if not slug:
        return "unknown"
    return slug
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def widget_key(*parts: str) -> str:
    """Build a namespaced Streamlit widget key by joining ``parts`` with ``::``."""

    separator = "::"
    return separator.join(parts)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def prompt_variant_label(variant: str) -> str:
    """Map a prompt-variant key to its display label.

    Keys present in ``VARIANT_LABELS`` use the configured label; any other
    key falls back to its title-cased form.
    """

    if variant in VARIANT_LABELS:
        return VARIANT_LABELS[variant]
    return variant.title()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def persona_label(persona: PersonaData) -> str:
    """Format a persona for selection widgets.

    Delegates to ``persona_display_label`` so both helpers render personas the
    same way, including the fallback to the bare id when the name is empty.
    """

    return persona_display_label(persona.id, persona.name)


def persona_display_label(persona_id: str, persona_name: str | None) -> str:
    """Format a persona id with an optional display name.

    Args:
        persona_id: Identifier of the persona.
        persona_name: Display name; ``None`` or an empty string means "no name".

    Returns:
        ``"<name> (<id>)"`` when a name is available, otherwise just the id.
    """

    if persona_name:
        return f"{persona_name} ({persona_id})"
    return persona_id
|
utils/local_dataset.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Iterator, Literal
|
| 6 |
+
|
| 7 |
+
from persona_data.synth_persona import PersonaData, QAPair
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class LocalPersonaDataset:
    """Dataset loaded from local JSONL files.

    Both files are read eagerly when the instance is created. Each non-blank
    line of ``personas_path`` must be a JSON object with the keys ``id``,
    ``persona``, ``templated_prompt`` and ``biography_md``; each non-blank
    line of ``qa_path`` must carry ``id`` (the persona id), ``qid``, ``type``,
    ``question``, ``answer`` and ``difficulty``.
    """

    personas_path: Path
    qa_path: Path

    def __post_init__(self) -> None:
        """Parse both JSONL files into in-memory persona and QA collections."""
        self._personas: list[PersonaData] = [
            PersonaData(
                id=record["id"],
                persona=record["persona"],
                templated_prompt=record["templated_prompt"],
                biography_md=record["biography_md"],
            )
            for record in self._read_jsonl(self.personas_path)
        ]

        # QA pairs grouped by the persona id they belong to.
        self._qa: dict[str, list[QAPair]] = defaultdict(list)
        for record in self._read_jsonl(self.qa_path):
            self._qa[record["id"]].append(
                QAPair(
                    qid=record["qid"],
                    type=record["type"],
                    question=record["question"],
                    answer=record["answer"],
                    difficulty=record["difficulty"],
                )
            )

    @staticmethod
    def _read_jsonl(path: Path) -> Iterator[dict]:
        """Yield one parsed JSON value per non-blank line of ``path``."""
        # Decode explicitly instead of relying on the platform default;
        # "utf-8-sig" reads plain UTF-8 and also strips a leading BOM.
        with path.open(encoding="utf-8-sig") as handle:
            for line in handle:
                if line.strip():
                    yield json.loads(line)

    def __len__(self) -> int:
        """Return the number of personas."""
        return len(self._personas)

    def __iter__(self) -> Iterator[PersonaData]:
        """Iterate over personas in file order."""
        return iter(self._personas)

    def __getitem__(self, idx: int) -> PersonaData:
        """Return the persona at position ``idx``."""
        return self._personas[idx]

    def get_qa(
        self,
        persona_id: str,
        type: Literal["explicit", "implicit"] | None = None,
        difficulty: int | list[int] | None = None,
    ) -> list[QAPair]:
        """Return the QA pairs of a persona, optionally filtered.

        Args:
            persona_id: Persona whose QA pairs are requested.
            type: Keep only pairs of this type; ``None`` keeps all.
            difficulty: Keep only pairs with this difficulty, or with any of
                the listed difficulties; ``None`` keeps all.

        Returns:
            A new list of matching pairs (empty for an unknown persona).
            Mutating it does not affect the dataset.
        """
        # Copy so callers can never mutate the internal per-persona list.
        pairs = list(self._qa.get(persona_id, []))
        if type is not None:
            pairs = [pair for pair in pairs if pair.type == type]

        if difficulty is not None:
            levels = {difficulty} if isinstance(difficulty, int) else set(difficulty)
            pairs = [pair for pair in pairs if pair.difficulty in levels]

        return pairs
|
utils/runtime.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import streamlit as st
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger(__name__)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@st.cache_data(show_spinner=False, ttl=30)
def list_remote_models() -> list[str]:
    """List the repo ids of NDIF language models reported as running.

    The result is cached by Streamlit with ``ttl=30``. If fetching the NDIF
    status fails, the failure is logged as a warning and an empty list is
    returned.

    Returns:
        Sorted, de-duplicated repo ids of entries whose model class is
        ``LanguageModel`` or ``StandardizedTransformer`` and whose state is
        ``RUNNING``.
    """

    import nnsight

    try:
        ndif_status = nnsight.ndif_status()
    except Exception:
        logger.warning("Failed to fetch NDIF status", exc_info=True)
        return []

    supported_classes = {"LanguageModel", "StandardizedTransformer"}

    def _is_running(record: dict) -> bool:
        # The state is read via ``.name`` first, then ``.value``
        # (presumably an enum member — verify against nnsight).
        state = record.get("state")
        label = getattr(state, "name", None) or getattr(state, "value", None)
        return label == "RUNNING"

    running_repo_ids = {
        record.get("repo_id")
        for record in ndif_status.values()
        if isinstance(record, dict)
        and record.get("model_class") in supported_classes
        and _is_running(record)
        and isinstance(record.get("repo_id"), str)
    }
    return sorted(running_repo_ids)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@st.cache_resource(show_spinner=False, max_entries=1)
def cached_model(model_name: str, remote: bool):
    """Load a standardized nnterp model behind Streamlit's resource cache.

    Streamlit reruns the app on every interaction, so caching avoids
    reloading weights on every widget change. The decorator is configured
    with ``max_entries=1``.

    Args:
        model_name: Model identifier passed to ``StandardizedTransformer``.
        remote: Not forwarded to the constructor at present (see the HACK
            note below).

    Returns:
        The ``StandardizedTransformer`` instance for ``model_name``.
    """

    from nnterp import StandardizedTransformer

    # HACK: For now do it like this because of the bug.
    # model = StandardizedTransformer(model_name, remote=True)
    model = StandardizedTransformer(model_name)
    return model
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|