Jac-Zac commited on
Commit ·
9ba2da4
1
Parent(s): 77c2d62
Updated code supporting latest version of persona-vector and data
Browse files- pyproject.toml +2 -3
- tabs/chat.py +136 -93
- tabs/chat_ui.py +60 -6
- tabs/compare.py +204 -253
- tabs/compare_chat.py +326 -243
- tabs/extract.py +390 -258
- tabs/probe_ui.py +164 -125
- utils/contrast.py +17 -28
- utils/runtime.py +58 -43
- uv.lock +0 -0
pyproject.toml
CHANGED
|
@@ -5,12 +5,11 @@ description = "Streamlit UI for persona-vectors"
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
-
"persona-vectors>=0.
|
| 9 |
-
"persona-data>=0.
|
| 10 |
"streamlit>=1.44.0",
|
| 11 |
"plotly>=6.6.0",
|
| 12 |
"python-dotenv>=1.2.2",
|
| 13 |
-
"transformers>=5.5.0",
|
| 14 |
]
|
| 15 |
|
| 16 |
# Local development:
|
|
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
+
"persona-vectors>=0.5.1",
|
| 9 |
+
"persona-data>=0.4.0",
|
| 10 |
"streamlit>=1.44.0",
|
| 11 |
"plotly>=6.6.0",
|
| 12 |
"python-dotenv>=1.2.2",
|
|
|
|
| 13 |
]
|
| 14 |
|
| 15 |
# Local development:
|
tabs/chat.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
| 1 |
import streamlit as st
|
|
|
|
| 2 |
|
| 3 |
-
from state import chat_session_key, get_chat_state, reset_chat_context_state
|
| 4 |
from tabs.chat_ui import (
|
| 5 |
GenerationConfig,
|
| 6 |
-
|
| 7 |
render_chat_window,
|
| 8 |
-
render_generation_settings,
|
| 9 |
render_persona_prompt_controls,
|
| 10 |
render_system_prompt,
|
| 11 |
)
|
|
@@ -27,21 +27,7 @@ _LAST_PROMPT_MODE_KEY = "chat:last_prompt_mode"
|
|
| 27 |
_LAST_COMPARE_MODE_KEY = "chat:last_compare_mode"
|
| 28 |
|
| 29 |
|
| 30 |
-
def
|
| 31 |
-
"""Render the chat tab."""
|
| 32 |
-
|
| 33 |
-
st.title("Chat")
|
| 34 |
-
|
| 35 |
-
context_key = chat_session_key(model_name, dataset_source)
|
| 36 |
-
chat_state = get_chat_state(model_name, remote, dataset_source)
|
| 37 |
-
|
| 38 |
-
# Carry over persona / prompt selections across model or remote switches.
|
| 39 |
-
if chat_state["persona_id"] is None:
|
| 40 |
-
chat_state["persona_id"] = st.session_state.get(_LAST_PERSONA_ID_KEY)
|
| 41 |
-
chat_state["prompt_mode"] = st.session_state.get(
|
| 42 |
-
_LAST_PROMPT_MODE_KEY, "templated"
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
try:
|
| 46 |
dataset, dataset_status = load_dataset(
|
| 47 |
dataset_source,
|
|
@@ -52,36 +38,124 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 52 |
except Exception as exc:
|
| 53 |
st.error(f"Could not load data: {exc}")
|
| 54 |
st.info("Check the selected dataset source or upload both JSONL files.")
|
| 55 |
-
return
|
| 56 |
|
| 57 |
personas = list(dataset)
|
| 58 |
if not personas:
|
| 59 |
st.warning("No personas found in the selected dataset.")
|
| 60 |
st.info("Try a different dataset source or upload a non-empty personas file.")
|
| 61 |
-
return
|
|
|
|
| 62 |
|
| 63 |
-
generation: GenerationConfig = render_generation_settings(context_key, remote)
|
| 64 |
-
probe_enabled = st.toggle(
|
| 65 |
-
"Probe tools",
|
| 66 |
-
value=False,
|
| 67 |
-
key=widget_key(context_key, "probe_enabled"),
|
| 68 |
-
help="Trace chat activations and run compatible `.pt` probes on tapped tokens.",
|
| 69 |
-
)
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
)
|
| 77 |
-
compare_mode = st.toggle(
|
| 78 |
-
"Compare mode",
|
| 79 |
-
key=compare_key,
|
| 80 |
-
help="Side-by-side: send one message to two independent persona/prompt configurations.",
|
| 81 |
-
)
|
| 82 |
-
st.session_state[_LAST_COMPARE_MODE_KEY] = compare_mode
|
| 83 |
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
render_compare_mode(
|
| 86 |
remote,
|
| 87 |
model_name,
|
|
@@ -89,6 +163,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 89 |
dataset_source,
|
| 90 |
personas,
|
| 91 |
generation,
|
|
|
|
| 92 |
)
|
| 93 |
return
|
| 94 |
|
|
@@ -150,7 +225,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 150 |
remote=remote,
|
| 151 |
active_system_prompt=active_system_prompt,
|
| 152 |
chat_state=chat_state,
|
| 153 |
-
enabled=probe_enabled,
|
| 154 |
)
|
| 155 |
|
| 156 |
render_chat_window(
|
|
@@ -161,36 +236,18 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 161 |
pending_key=pending_key,
|
| 162 |
)
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
dataset_source=dataset_source,
|
| 177 |
-
persona_id=selected_persona.id,
|
| 178 |
-
persona_name=getattr(selected_persona, "name", None),
|
| 179 |
-
prompt_mode=prompt_mode,
|
| 180 |
-
system_prompt=active_system_prompt,
|
| 181 |
-
messages=chat_state["messages"],
|
| 182 |
-
generation=generation_dict(generation),
|
| 183 |
-
)
|
| 184 |
-
st.toast("Exported", icon=":material/check:")
|
| 185 |
-
with rst_col:
|
| 186 |
-
if st.button(
|
| 187 |
-
"",
|
| 188 |
-
icon=":material/delete_sweep:",
|
| 189 |
-
key=reset_key,
|
| 190 |
-
help="Reset chat",
|
| 191 |
-
):
|
| 192 |
-
_reset_active_chat_context()
|
| 193 |
-
st.rerun()
|
| 194 |
|
| 195 |
user_prompt = st.chat_input("Ask something...", key=chat_input_key)
|
| 196 |
|
|
@@ -205,26 +262,12 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 205 |
if not pending_action:
|
| 206 |
return
|
| 207 |
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
past_key_values=chat_state["past_key_values"],
|
| 218 |
-
**generation.to_generate_kwargs(),
|
| 219 |
-
)
|
| 220 |
-
except Exception as exc:
|
| 221 |
-
with chat_log:
|
| 222 |
-
st.error(f"Could not generate a reply: {exc}")
|
| 223 |
-
st.info("Try a shorter prompt, reset the chat, or switch personas.")
|
| 224 |
-
if pending_action == "new_user_prompt" and chat_state["messages"]:
|
| 225 |
-
chat_state["messages"].pop()
|
| 226 |
-
return
|
| 227 |
-
|
| 228 |
-
chat_state["messages"].append({"role": "assistant", "content": reply.text})
|
| 229 |
-
chat_state["past_key_values"] = reply.past_key_values if not remote else None
|
| 230 |
-
st.rerun()
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
+
from persona_data.synth_persona import PersonaData
|
| 3 |
|
| 4 |
+
from state import ChatState, chat_session_key, get_chat_state, reset_chat_context_state
|
| 5 |
from tabs.chat_ui import (
|
| 6 |
GenerationConfig,
|
| 7 |
+
render_advanced_settings,
|
| 8 |
render_chat_window,
|
|
|
|
| 9 |
render_persona_prompt_controls,
|
| 10 |
render_system_prompt,
|
| 11 |
)
|
|
|
|
| 27 |
_LAST_COMPARE_MODE_KEY = "chat:last_compare_mode"
|
| 28 |
|
| 29 |
|
| 30 |
+
def _load_personas(dataset_source: str) -> list[PersonaData] | None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
try:
|
| 32 |
dataset, dataset_status = load_dataset(
|
| 33 |
dataset_source,
|
|
|
|
| 38 |
except Exception as exc:
|
| 39 |
st.error(f"Could not load data: {exc}")
|
| 40 |
st.info("Check the selected dataset source or upload both JSONL files.")
|
| 41 |
+
return None
|
| 42 |
|
| 43 |
personas = list(dataset)
|
| 44 |
if not personas:
|
| 45 |
st.warning("No personas found in the selected dataset.")
|
| 46 |
st.info("Try a different dataset source or upload a non-empty personas file.")
|
| 47 |
+
return None
|
| 48 |
+
return personas
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
+
def _render_single_chat_footer(
|
| 52 |
+
*,
|
| 53 |
+
model_name: str,
|
| 54 |
+
dataset_source: str,
|
| 55 |
+
persona: PersonaData,
|
| 56 |
+
prompt_mode: str,
|
| 57 |
+
system_prompt: str | None,
|
| 58 |
+
chat_state: ChatState,
|
| 59 |
+
generation: GenerationConfig,
|
| 60 |
+
export_key: str,
|
| 61 |
+
reset_key: str,
|
| 62 |
+
on_reset,
|
| 63 |
+
) -> None:
|
| 64 |
+
footer = st.container()
|
| 65 |
+
with footer:
|
| 66 |
+
exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
|
| 67 |
+
with exp_col:
|
| 68 |
+
if st.button(
|
| 69 |
+
"",
|
| 70 |
+
icon=":material/download:",
|
| 71 |
+
key=export_key,
|
| 72 |
+
help="Export chat",
|
| 73 |
+
):
|
| 74 |
+
save_chat_export(
|
| 75 |
+
model_name=model_name,
|
| 76 |
+
dataset_source=dataset_source,
|
| 77 |
+
persona_id=persona.id,
|
| 78 |
+
persona_name=getattr(persona, "name", None),
|
| 79 |
+
prompt_mode=prompt_mode,
|
| 80 |
+
system_prompt=system_prompt,
|
| 81 |
+
messages=chat_state["messages"],
|
| 82 |
+
generation=generation.to_export_dict(),
|
| 83 |
+
)
|
| 84 |
+
st.toast("Exported", icon=":material/check:")
|
| 85 |
+
with rst_col:
|
| 86 |
+
if st.button(
|
| 87 |
+
"",
|
| 88 |
+
icon=":material/delete_sweep:",
|
| 89 |
+
key=reset_key,
|
| 90 |
+
help="Reset chat",
|
| 91 |
+
):
|
| 92 |
+
on_reset()
|
| 93 |
+
st.rerun()
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _handle_single_chat_generation(
|
| 97 |
+
*,
|
| 98 |
+
remote: bool,
|
| 99 |
+
model_name: str,
|
| 100 |
+
chat_state: ChatState,
|
| 101 |
+
active_system_prompt: str | None,
|
| 102 |
+
generation: GenerationConfig,
|
| 103 |
+
pending_action: object,
|
| 104 |
+
chat_log,
|
| 105 |
+
) -> None:
|
| 106 |
+
messages = build_chat_messages(active_system_prompt, chat_state["messages"])
|
| 107 |
+
|
| 108 |
+
with st.spinner("Generating reply..."):
|
| 109 |
+
model = cached_model(model_name=model_name)
|
| 110 |
+
try:
|
| 111 |
+
reply: ChatReply = generate_chat_reply(
|
| 112 |
+
model=model,
|
| 113 |
+
messages=messages,
|
| 114 |
+
remote=remote,
|
| 115 |
+
past_key_values=chat_state["past_key_values"],
|
| 116 |
+
**generation.to_generate_kwargs(),
|
| 117 |
+
)
|
| 118 |
+
except Exception as exc:
|
| 119 |
+
with chat_log:
|
| 120 |
+
st.error(f"Could not generate a reply: {exc}")
|
| 121 |
+
st.info("Try a shorter prompt, reset the chat, or switch personas.")
|
| 122 |
+
if pending_action == "new_user_prompt" and chat_state["messages"]:
|
| 123 |
+
chat_state["messages"].pop()
|
| 124 |
+
return
|
| 125 |
+
|
| 126 |
+
chat_state["messages"].append({"role": "assistant", "content": reply.text})
|
| 127 |
+
chat_state["past_key_values"] = reply.past_key_values if not remote else None
|
| 128 |
+
st.rerun()
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
| 134 |
+
"""Render the chat tab."""
|
| 135 |
+
|
| 136 |
+
st.title("Chat")
|
| 137 |
+
st.caption("Chat with a persona, optionally side-by-side or with token contrast.")
|
| 138 |
+
|
| 139 |
+
context_key = chat_session_key(model_name, dataset_source)
|
| 140 |
+
chat_state = get_chat_state(model_name, remote, dataset_source)
|
| 141 |
+
|
| 142 |
+
# Carry over persona / prompt selections across model or remote switches.
|
| 143 |
+
if chat_state["persona_id"] is None:
|
| 144 |
+
chat_state["persona_id"] = st.session_state.get(_LAST_PERSONA_ID_KEY)
|
| 145 |
+
chat_state["prompt_mode"] = st.session_state.get(
|
| 146 |
+
_LAST_PROMPT_MODE_KEY, "templated"
|
| 147 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
+
personas = _load_personas(dataset_source)
|
| 150 |
+
if personas is None:
|
| 151 |
+
return
|
| 152 |
+
|
| 153 |
+
generation, tools = render_advanced_settings(
|
| 154 |
+
context_key,
|
| 155 |
+
remote,
|
| 156 |
+
last_compare_mode_key=_LAST_COMPARE_MODE_KEY,
|
| 157 |
+
)
|
| 158 |
+
if tools.compare_mode:
|
| 159 |
render_compare_mode(
|
| 160 |
remote,
|
| 161 |
model_name,
|
|
|
|
| 163 |
dataset_source,
|
| 164 |
personas,
|
| 165 |
generation,
|
| 166 |
+
contrast_enabled=tools.token_contrast,
|
| 167 |
)
|
| 168 |
return
|
| 169 |
|
|
|
|
| 225 |
remote=remote,
|
| 226 |
active_system_prompt=active_system_prompt,
|
| 227 |
chat_state=chat_state,
|
| 228 |
+
enabled=tools.probe_enabled,
|
| 229 |
)
|
| 230 |
|
| 231 |
render_chat_window(
|
|
|
|
| 236 |
pending_key=pending_key,
|
| 237 |
)
|
| 238 |
|
| 239 |
+
_render_single_chat_footer(
|
| 240 |
+
model_name=model_name,
|
| 241 |
+
dataset_source=dataset_source,
|
| 242 |
+
persona=selected_persona,
|
| 243 |
+
prompt_mode=prompt_mode,
|
| 244 |
+
system_prompt=active_system_prompt,
|
| 245 |
+
chat_state=chat_state,
|
| 246 |
+
generation=generation,
|
| 247 |
+
export_key=export_key,
|
| 248 |
+
reset_key=reset_key,
|
| 249 |
+
on_reset=_reset_active_chat_context,
|
| 250 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
user_prompt = st.chat_input("Ask something...", key=chat_input_key)
|
| 253 |
|
|
|
|
| 262 |
if not pending_action:
|
| 263 |
return
|
| 264 |
|
| 265 |
+
_handle_single_chat_generation(
|
| 266 |
+
remote=remote,
|
| 267 |
+
model_name=model_name,
|
| 268 |
+
chat_state=chat_state,
|
| 269 |
+
active_system_prompt=active_system_prompt,
|
| 270 |
+
generation=generation,
|
| 271 |
+
pending_action=pending_action,
|
| 272 |
+
chat_log=chat_log,
|
| 273 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tabs/chat_ui.py
CHANGED
|
@@ -48,6 +48,13 @@ class GenerationConfig:
|
|
| 48 |
}
|
| 49 |
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
@st.dialog("Edit", width="medium")
|
| 52 |
def _open_edit_dialog(
|
| 53 |
*,
|
|
@@ -108,13 +115,54 @@ def _open_system_prompt_dialog(*, prompt_key: str, current_value: str) -> None:
|
|
| 108 |
st.rerun()
|
| 109 |
|
| 110 |
|
| 111 |
-
def
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
-
def render_generation_settings(context_key: str, remote: bool) -> GenerationConfig:
|
| 116 |
-
"""Render the Advanced generation settings expander."""
|
| 117 |
-
with st.expander("Advanced", expanded=False):
|
| 118 |
config_col1, config_col2 = st.columns([2, 1])
|
| 119 |
with config_col1:
|
| 120 |
max_new_tokens = st.slider(
|
|
@@ -199,7 +247,7 @@ def render_generation_settings(context_key: str, remote: bool) -> GenerationConf
|
|
| 199 |
st.caption("Seed is local-only and disabled for remote runs.")
|
| 200 |
|
| 201 |
do_sample = bool(use_sampling)
|
| 202 |
-
|
| 203 |
max_new_tokens=int(max_new_tokens),
|
| 204 |
do_sample=do_sample,
|
| 205 |
temperature=float(temperature),
|
|
@@ -208,6 +256,12 @@ def render_generation_settings(context_key: str, remote: bool) -> GenerationConf
|
|
| 208 |
repetition_penalty=float(repetition_penalty),
|
| 209 |
seed=seed if do_sample and seed is not None and not remote else None,
|
| 210 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
|
| 213 |
def render_chat_message(
|
|
|
|
| 48 |
}
|
| 49 |
|
| 50 |
|
| 51 |
+
@dataclass(frozen=True)
|
| 52 |
+
class ChatTools:
|
| 53 |
+
probe_enabled: bool
|
| 54 |
+
compare_mode: bool
|
| 55 |
+
token_contrast: bool
|
| 56 |
+
|
| 57 |
+
|
| 58 |
@st.dialog("Edit", width="medium")
|
| 59 |
def _open_edit_dialog(
|
| 60 |
*,
|
|
|
|
| 115 |
st.rerun()
|
| 116 |
|
| 117 |
|
| 118 |
+
def render_advanced_settings(
|
| 119 |
+
context_key: str,
|
| 120 |
+
remote: bool,
|
| 121 |
+
*,
|
| 122 |
+
last_compare_mode_key: str,
|
| 123 |
+
) -> tuple[GenerationConfig, ChatTools]:
|
| 124 |
+
"""Render the Advanced expander: tool toggles + generation settings."""
|
| 125 |
+
with st.expander("Advanced", expanded=False):
|
| 126 |
+
st.caption("Tools")
|
| 127 |
|
| 128 |
+
compare_key = widget_key(context_key, "compare_mode")
|
| 129 |
+
if compare_key not in st.session_state:
|
| 130 |
+
st.session_state[compare_key] = st.session_state.get(
|
| 131 |
+
last_compare_mode_key, False
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
tools_col1, tools_col2, tools_col3 = st.columns(3)
|
| 135 |
+
with tools_col1:
|
| 136 |
+
probe_enabled = st.toggle(
|
| 137 |
+
"Probe tools",
|
| 138 |
+
value=False,
|
| 139 |
+
key=widget_key(context_key, "probe_enabled"),
|
| 140 |
+
help="Trace chat activations and run compatible `.pt` probes on tapped tokens.",
|
| 141 |
+
)
|
| 142 |
+
with tools_col2:
|
| 143 |
+
compare_mode = st.toggle(
|
| 144 |
+
"Compare mode",
|
| 145 |
+
key=compare_key,
|
| 146 |
+
help="Side-by-side: send one message to two independent persona/prompt configurations.",
|
| 147 |
+
)
|
| 148 |
+
with tools_col3:
|
| 149 |
+
token_contrast = st.toggle(
|
| 150 |
+
"Token contrast",
|
| 151 |
+
value=False,
|
| 152 |
+
key=widget_key(context_key, "token_contrast"),
|
| 153 |
+
disabled=not compare_mode,
|
| 154 |
+
help=(
|
| 155 |
+
"Color each generated token by how characteristic it is of each persona. "
|
| 156 |
+
"Red = more likely under the left persona, blue = more likely under the "
|
| 157 |
+
"right. Requires up to four extra scoring passes after each turn. "
|
| 158 |
+
"Available only in Compare mode."
|
| 159 |
+
),
|
| 160 |
+
)
|
| 161 |
+
st.session_state[last_compare_mode_key] = compare_mode
|
| 162 |
+
|
| 163 |
+
st.divider()
|
| 164 |
+
st.caption("Generation")
|
| 165 |
|
|
|
|
|
|
|
|
|
|
| 166 |
config_col1, config_col2 = st.columns([2, 1])
|
| 167 |
with config_col1:
|
| 168 |
max_new_tokens = st.slider(
|
|
|
|
| 247 |
st.caption("Seed is local-only and disabled for remote runs.")
|
| 248 |
|
| 249 |
do_sample = bool(use_sampling)
|
| 250 |
+
generation = GenerationConfig(
|
| 251 |
max_new_tokens=int(max_new_tokens),
|
| 252 |
do_sample=do_sample,
|
| 253 |
temperature=float(temperature),
|
|
|
|
| 256 |
repetition_penalty=float(repetition_penalty),
|
| 257 |
seed=seed if do_sample and seed is not None and not remote else None,
|
| 258 |
)
|
| 259 |
+
tools = ChatTools(
|
| 260 |
+
probe_enabled=probe_enabled,
|
| 261 |
+
compare_mode=compare_mode,
|
| 262 |
+
token_contrast=token_contrast and compare_mode,
|
| 263 |
+
)
|
| 264 |
+
return generation, tools
|
| 265 |
|
| 266 |
|
| 267 |
def render_chat_message(
|
tabs/compare.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
|
|
| 1 |
from itertools import combinations
|
|
|
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
from persona_data.environment import get_artifacts_dir
|
| 5 |
-
from persona_data.prompts import BASELINE_PERSONA_ID
|
| 6 |
from persona_vectors.analysis import (
|
| 7 |
load_persona_mean_samples,
|
| 8 |
load_variant_mean_samples,
|
|
@@ -41,6 +42,15 @@ _LAST_PROJECTION_PERSONAS_KEY = "compare:last_personas:projection"
|
|
| 41 |
_LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
def _select_artifact_personas(
|
| 45 |
store: ActivationStore,
|
| 46 |
variants: list[str],
|
|
@@ -143,71 +153,157 @@ def _render_mask_strategy_select(scope: str) -> MaskStrategy:
|
|
| 143 |
return selected
|
| 144 |
|
| 145 |
|
| 146 |
-
def
|
| 147 |
store: ActivationStore,
|
| 148 |
mask_strategy: MaskStrategy,
|
| 149 |
-
) -> None:
|
| 150 |
variants = list(store.variants)
|
| 151 |
if len(variants) < 2:
|
| 152 |
st.info("Need at least two non-baseline variants for cosine comparison.")
|
| 153 |
-
return
|
| 154 |
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
)
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
)
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
)
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
return
|
| 186 |
-
persona_key = "_".join(sorted(persona_ids))
|
| 187 |
|
| 188 |
cosine_fig_key = widget_key(
|
| 189 |
"load",
|
| 190 |
"cosine_fig_state",
|
| 191 |
store.model_name,
|
| 192 |
mask_strategy.value,
|
| 193 |
-
variant_a,
|
| 194 |
-
variant_b,
|
| 195 |
-
persona_key,
|
| 196 |
)
|
| 197 |
filename = _filename(
|
| 198 |
"compare",
|
| 199 |
"cosine",
|
| 200 |
store.model_name,
|
| 201 |
mask_strategy.value,
|
| 202 |
-
variant_a,
|
| 203 |
-
variant_b,
|
| 204 |
)
|
| 205 |
pairs_filename = _filename(
|
| 206 |
"compare",
|
| 207 |
"cosine_pairs",
|
| 208 |
store.model_name,
|
| 209 |
mask_strategy.value,
|
| 210 |
-
"_".join(variants),
|
| 211 |
)
|
| 212 |
|
| 213 |
if st.button(
|
|
@@ -218,79 +314,16 @@ def _render_cosine_similarity(
|
|
| 218 |
"compare_vectors",
|
| 219 |
store.model_name,
|
| 220 |
mask_strategy.value,
|
| 221 |
-
variant_a,
|
| 222 |
-
variant_b,
|
| 223 |
-
persona_key,
|
| 224 |
),
|
| 225 |
):
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
store,
|
| 229 |
-
[variant_a, variant_b],
|
| 230 |
-
persona_ids=persona_ids,
|
| 231 |
-
)
|
| 232 |
-
except Exception as exc:
|
| 233 |
-
st.error(f"Could not load vectors: {exc}")
|
| 234 |
st.session_state.pop(cosine_fig_key, None)
|
| 235 |
return
|
| 236 |
-
|
| 237 |
-
labels = variant_samples[variant_a].labels
|
| 238 |
-
display_traces = [
|
| 239 |
-
(
|
| 240 |
-
label,
|
| 241 |
-
variant_samples[variant_a].vectors[index],
|
| 242 |
-
variant_samples[variant_b].vectors[index],
|
| 243 |
-
)
|
| 244 |
-
for index, label in enumerate(labels)
|
| 245 |
-
]
|
| 246 |
-
fig = plot_layer_similarity(
|
| 247 |
-
display_traces,
|
| 248 |
-
title=f"{prompt_variant_label(variant_a)} vs {prompt_variant_label(variant_b)}",
|
| 249 |
-
show=False,
|
| 250 |
-
)
|
| 251 |
-
|
| 252 |
-
pair_traces = []
|
| 253 |
-
pair_errors = []
|
| 254 |
-
for left, right in combinations(variants, 2):
|
| 255 |
-
try:
|
| 256 |
-
pair_samples = (
|
| 257 |
-
variant_samples
|
| 258 |
-
if {left, right} == {variant_a, variant_b}
|
| 259 |
-
else load_variant_mean_samples(
|
| 260 |
-
store,
|
| 261 |
-
[left, right],
|
| 262 |
-
persona_ids=persona_ids,
|
| 263 |
-
)
|
| 264 |
-
)
|
| 265 |
-
except Exception as exc:
|
| 266 |
-
pair_errors.append(f"{left} vs {right}: {exc}")
|
| 267 |
-
continue
|
| 268 |
-
pair_traces.append(
|
| 269 |
-
(
|
| 270 |
-
f"{prompt_variant_label(left)} vs {prompt_variant_label(right)}",
|
| 271 |
-
pair_samples[left].vectors.mean(dim=0),
|
| 272 |
-
pair_samples[right].vectors.mean(dim=0),
|
| 273 |
-
)
|
| 274 |
-
)
|
| 275 |
-
|
| 276 |
-
if pair_errors:
|
| 277 |
-
for err in pair_errors:
|
| 278 |
-
st.warning(f"Skipped pair trace: `{err}`")
|
| 279 |
-
pair_fig = (
|
| 280 |
-
plot_layer_similarity(
|
| 281 |
-
pair_traces,
|
| 282 |
-
title="Variant-pair cosine similarity averaged over selected personas",
|
| 283 |
-
show=False,
|
| 284 |
-
)
|
| 285 |
-
if pair_traces
|
| 286 |
-
else None
|
| 287 |
-
)
|
| 288 |
-
st.session_state[cosine_fig_key] = (
|
| 289 |
-
fig,
|
| 290 |
-
pair_fig,
|
| 291 |
-
len(display_traces),
|
| 292 |
-
len(pair_traces),
|
| 293 |
-
)
|
| 294 |
|
| 295 |
if cosine_fig_key in st.session_state:
|
| 296 |
fig, pair_fig, n_traces, n_pair_traces = st.session_state[cosine_fig_key]
|
|
@@ -369,190 +402,89 @@ def _select_single_variant_samples(
|
|
| 369 |
return variant, persona_ids, persona_key, selected_layers
|
| 370 |
|
| 371 |
|
| 372 |
-
def
|
| 373 |
-
store: ActivationStore,
|
| 374 |
-
) -> bool:
|
| 375 |
-
return BASELINE_PERSONA_ID in store.list_personas(
|
| 376 |
-
[BASELINE_PERSONA_ID],
|
| 377 |
-
warn_missing=False,
|
| 378 |
-
)
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
def _render_baseline_reference_toggle(
|
| 382 |
store: ActivationStore,
|
| 383 |
mask_strategy: MaskStrategy,
|
|
|
|
| 384 |
scope: str,
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
value=available,
|
| 390 |
-
disabled=not available,
|
| 391 |
-
key=widget_key("load", "include_baseline", scope, mask_strategy.value),
|
| 392 |
-
help=(
|
| 393 |
-
"Adds the single saved baseline artifact as one reference sample."
|
| 394 |
-
if available
|
| 395 |
-
else "Run Assistant baseline extraction first."
|
| 396 |
-
),
|
| 397 |
-
)
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
def _render_similarity_matrix(
|
| 401 |
-
store: ActivationStore,
|
| 402 |
-
mask_strategy: MaskStrategy,
|
| 403 |
) -> None:
|
| 404 |
-
|
| 405 |
-
store,
|
| 406 |
-
mask_strategy,
|
| 407 |
-
"similarity_matrix",
|
| 408 |
-
)
|
| 409 |
-
if selected is None:
|
| 410 |
-
return
|
| 411 |
-
variant, persona_ids, persona_key, selected_layers = selected
|
| 412 |
-
include_baseline = _render_baseline_reference_toggle(
|
| 413 |
-
store,
|
| 414 |
-
mask_strategy,
|
| 415 |
-
"similarity_matrix",
|
| 416 |
-
)
|
| 417 |
-
|
| 418 |
-
fig_key = widget_key(
|
| 419 |
-
"load",
|
| 420 |
-
"similarity_matrix_fig_state",
|
| 421 |
-
store.model_name,
|
| 422 |
-
mask_strategy.value,
|
| 423 |
-
variant,
|
| 424 |
-
"persona_mean",
|
| 425 |
-
persona_key,
|
| 426 |
-
BASELINE_PERSONA_ID if include_baseline else "no_baseline",
|
| 427 |
-
)
|
| 428 |
-
filename = _filename(
|
| 429 |
-
"compare",
|
| 430 |
-
"similarity_matrix",
|
| 431 |
-
store.model_name,
|
| 432 |
-
mask_strategy.value,
|
| 433 |
-
variant,
|
| 434 |
-
"persona_mean",
|
| 435 |
-
persona_key,
|
| 436 |
-
BASELINE_PERSONA_ID if include_baseline else "",
|
| 437 |
-
)
|
| 438 |
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
variant,
|
| 444 |
-
mask_strategy=mask_strategy,
|
| 445 |
-
persona_ids=persona_ids,
|
| 446 |
-
include_baseline=include_baseline,
|
| 447 |
-
)
|
| 448 |
-
matrix_fig = build_layered_figure(
|
| 449 |
-
samples,
|
| 450 |
-
"similarity",
|
| 451 |
-
layers=selected_layers,
|
| 452 |
-
title=(
|
| 453 |
-
"Centered similarity - "
|
| 454 |
-
f"{prompt_variant_label(variant)} - personas averaged over questions"
|
| 455 |
-
),
|
| 456 |
-
)
|
| 457 |
-
trajectory_fig = build_pair_similarity_figure(
|
| 458 |
-
samples,
|
| 459 |
-
layers=selected_layers,
|
| 460 |
-
title=(
|
| 461 |
-
"Pair similarity trajectories - "
|
| 462 |
-
f"{prompt_variant_label(variant)} - personas averaged over questions"
|
| 463 |
-
),
|
| 464 |
-
)
|
| 465 |
-
st.session_state[fig_key] = (
|
| 466 |
-
matrix_fig,
|
| 467 |
-
trajectory_fig,
|
| 468 |
-
samples.vectors.shape[0],
|
| 469 |
-
)
|
| 470 |
-
except Exception as exc:
|
| 471 |
-
st.error(f"Could not build similarity matrix: {exc}")
|
| 472 |
-
st.session_state.pop(fig_key, None)
|
| 473 |
-
|
| 474 |
-
if fig_key in st.session_state:
|
| 475 |
-
matrix_fig, trajectory_fig, n_samples = st.session_state[fig_key]
|
| 476 |
-
st.plotly_chart(matrix_fig, width="stretch")
|
| 477 |
-
st.subheader("Pair trajectories")
|
| 478 |
-
st.plotly_chart(trajectory_fig, width="stretch")
|
| 479 |
-
_render_save_buttons(
|
| 480 |
-
[matrix_fig, trajectory_fig],
|
| 481 |
-
[filename, f"{filename}__pair_trajectories"],
|
| 482 |
-
"similarity_matrix",
|
| 483 |
-
)
|
| 484 |
-
st.success(f"Loaded {n_samples} samples.")
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
def _render_embedding_analysis(
|
| 488 |
-
store: ActivationStore,
|
| 489 |
-
analysis_mode: str,
|
| 490 |
-
mask_strategy: MaskStrategy,
|
| 491 |
-
) -> None:
|
| 492 |
-
selected = _select_single_variant_samples(
|
| 493 |
-
store,
|
| 494 |
-
mask_strategy,
|
| 495 |
-
analysis_mode.lower(),
|
| 496 |
-
)
|
| 497 |
if selected is None:
|
| 498 |
return
|
| 499 |
variant, persona_ids, persona_key, selected_layers = selected
|
| 500 |
|
| 501 |
-
figure_kind = analysis_mode.lower()
|
| 502 |
-
include_baseline = _render_baseline_reference_toggle(
|
| 503 |
-
store,
|
| 504 |
-
mask_strategy,
|
| 505 |
-
analysis_mode.lower(),
|
| 506 |
-
)
|
| 507 |
-
|
| 508 |
fig_key = widget_key(
|
| 509 |
"load",
|
| 510 |
-
"
|
| 511 |
store.model_name,
|
| 512 |
mask_strategy.value,
|
| 513 |
figure_kind,
|
| 514 |
variant,
|
| 515 |
"persona_mean",
|
| 516 |
persona_key,
|
| 517 |
-
BASELINE_PERSONA_ID if include_baseline else "no_baseline",
|
| 518 |
)
|
| 519 |
filename = _filename(
|
| 520 |
"compare",
|
| 521 |
-
|
| 522 |
store.model_name,
|
| 523 |
mask_strategy.value,
|
| 524 |
variant,
|
| 525 |
"persona_mean",
|
| 526 |
persona_key,
|
| 527 |
-
BASELINE_PERSONA_ID if include_baseline else "",
|
| 528 |
)
|
| 529 |
|
| 530 |
-
if st.button(
|
| 531 |
try:
|
| 532 |
samples = load_persona_mean_samples(
|
| 533 |
store,
|
| 534 |
variant,
|
| 535 |
mask_strategy=mask_strategy,
|
| 536 |
persona_ids=persona_ids,
|
| 537 |
-
include_baseline=include_baseline,
|
| 538 |
)
|
| 539 |
-
|
| 540 |
samples,
|
| 541 |
figure_kind,
|
| 542 |
layers=selected_layers,
|
| 543 |
-
title=(
|
| 544 |
-
|
| 545 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
)
|
| 547 |
-
st.session_state[fig_key] = (
|
| 548 |
except Exception as exc:
|
| 549 |
-
st.error(f"Could not build
|
| 550 |
st.session_state.pop(fig_key, None)
|
| 551 |
|
| 552 |
if fig_key in st.session_state:
|
| 553 |
-
|
| 554 |
-
st.plotly_chart(
|
| 555 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
st.success(f"Loaded {n_samples} samples.")
|
| 557 |
|
| 558 |
|
|
@@ -562,9 +494,7 @@ def render_compare_tab(model_name: str) -> None:
|
|
| 562 |
st.title("Compare")
|
| 563 |
st.caption("Compare saved activations by cosine similarity, PCA, or UMAP.")
|
| 564 |
|
| 565 |
-
st.
|
| 566 |
-
|
| 567 |
-
with st.expander("Advanced", expanded=False):
|
| 568 |
artifacts_root = st.text_input(
|
| 569 |
"Artifacts root",
|
| 570 |
value=str(get_artifacts_dir() / "activations"),
|
|
@@ -580,14 +510,35 @@ def render_compare_tab(model_name: str) -> None:
|
|
| 580 |
if analysis_mode is None:
|
| 581 |
analysis_mode = ANALYSIS_MODES[0]
|
| 582 |
st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
|
| 583 |
-
|
|
|
|
| 584 |
store = ActivationStore(model_name, artifacts_root, mask_strategy=mask_strategy)
|
| 585 |
|
| 586 |
if analysis_mode == "Cosine similarity":
|
| 587 |
_render_cosine_similarity(store, mask_strategy)
|
| 588 |
return
|
| 589 |
if analysis_mode == "Similarity matrix":
|
| 590 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 591 |
return
|
| 592 |
|
| 593 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Callable
|
| 2 |
from itertools import combinations
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
from persona_data.environment import get_artifacts_dir
|
|
|
|
| 7 |
from persona_vectors.analysis import (
|
| 8 |
load_persona_mean_samples,
|
| 9 |
load_variant_mean_samples,
|
|
|
|
| 42 |
_LAST_MASK_STRATEGY_KEY = "compare:last_mask_strategy"
|
| 43 |
|
| 44 |
|
| 45 |
+
@dataclass(frozen=True)
|
| 46 |
+
class CosineSelection:
|
| 47 |
+
variants: list[str]
|
| 48 |
+
variant_a: str
|
| 49 |
+
variant_b: str
|
| 50 |
+
persona_ids: list[str]
|
| 51 |
+
persona_key: str
|
| 52 |
+
|
| 53 |
+
|
| 54 |
def _select_artifact_personas(
|
| 55 |
store: ActivationStore,
|
| 56 |
variants: list[str],
|
|
|
|
| 153 |
return selected
|
| 154 |
|
| 155 |
|
| 156 |
+
def _render_cosine_selection(
|
| 157 |
store: ActivationStore,
|
| 158 |
mask_strategy: MaskStrategy,
|
| 159 |
+
) -> CosineSelection | None:
|
| 160 |
variants = list(store.variants)
|
| 161 |
if len(variants) < 2:
|
| 162 |
st.info("Need at least two non-baseline variants for cosine comparison.")
|
| 163 |
+
return None
|
| 164 |
|
| 165 |
+
with st.expander("Vector selection", expanded=True):
|
| 166 |
+
col1, col2 = st.columns(2)
|
| 167 |
+
with col1:
|
| 168 |
+
variant_a = st.selectbox(
|
| 169 |
+
"Variant A",
|
| 170 |
+
options=variants,
|
| 171 |
+
index=0,
|
| 172 |
+
format_func=prompt_variant_label,
|
| 173 |
+
key=widget_key("load", "variant_a"),
|
| 174 |
+
)
|
| 175 |
+
with col2:
|
| 176 |
+
variant_b = st.selectbox(
|
| 177 |
+
"Variant B",
|
| 178 |
+
options=variants,
|
| 179 |
+
index=min(1, len(variants) - 1),
|
| 180 |
+
format_func=prompt_variant_label,
|
| 181 |
+
key=widget_key("load", "variant_b"),
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
if variant_a == variant_b:
|
| 185 |
+
st.warning("Choose two different variants to compare.")
|
| 186 |
+
return None
|
| 187 |
+
|
| 188 |
+
persona_ids, _ = _select_artifact_personas(
|
| 189 |
+
store,
|
| 190 |
+
[variant_a, variant_b],
|
| 191 |
+
mask_strategy,
|
| 192 |
+
widget_scope="cosine",
|
| 193 |
+
remember_key=_LAST_COSINE_PERSONAS_KEY,
|
| 194 |
)
|
| 195 |
+
if not persona_ids:
|
| 196 |
+
return None
|
| 197 |
+
return CosineSelection(
|
| 198 |
+
variants=variants,
|
| 199 |
+
variant_a=variant_a,
|
| 200 |
+
variant_b=variant_b,
|
| 201 |
+
persona_ids=persona_ids,
|
| 202 |
+
persona_key="_".join(sorted(persona_ids)),
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _build_cosine_figures(
|
| 207 |
+
store: ActivationStore,
|
| 208 |
+
selection: CosineSelection,
|
| 209 |
+
) -> tuple[object, object | None, int, int] | None:
|
| 210 |
+
try:
|
| 211 |
+
variant_samples = load_variant_mean_samples(
|
| 212 |
+
store,
|
| 213 |
+
[selection.variant_a, selection.variant_b],
|
| 214 |
+
persona_ids=selection.persona_ids,
|
| 215 |
)
|
| 216 |
+
except Exception as exc:
|
| 217 |
+
st.error(f"Could not load vectors: {exc}")
|
| 218 |
+
return None
|
| 219 |
|
| 220 |
+
labels = variant_samples[selection.variant_a].labels
|
| 221 |
+
display_traces = [
|
| 222 |
+
(
|
| 223 |
+
label,
|
| 224 |
+
variant_samples[selection.variant_a].vectors[index],
|
| 225 |
+
variant_samples[selection.variant_b].vectors[index],
|
| 226 |
+
)
|
| 227 |
+
for index, label in enumerate(labels)
|
| 228 |
+
]
|
| 229 |
+
fig = plot_layer_similarity(
|
| 230 |
+
display_traces,
|
| 231 |
+
title=(
|
| 232 |
+
f"{prompt_variant_label(selection.variant_a)} vs "
|
| 233 |
+
f"{prompt_variant_label(selection.variant_b)}"
|
| 234 |
+
),
|
| 235 |
+
show=False,
|
| 236 |
+
)
|
| 237 |
|
| 238 |
+
pair_traces = []
|
| 239 |
+
pair_errors = []
|
| 240 |
+
for left, right in combinations(selection.variants, 2):
|
| 241 |
+
try:
|
| 242 |
+
pair_samples = (
|
| 243 |
+
variant_samples
|
| 244 |
+
if {left, right} == {selection.variant_a, selection.variant_b}
|
| 245 |
+
else load_variant_mean_samples(
|
| 246 |
+
store,
|
| 247 |
+
[left, right],
|
| 248 |
+
persona_ids=selection.persona_ids,
|
| 249 |
+
)
|
| 250 |
+
)
|
| 251 |
+
pair_traces.append(
|
| 252 |
+
(
|
| 253 |
+
f"{prompt_variant_label(left)} vs {prompt_variant_label(right)}",
|
| 254 |
+
pair_samples[left].vectors.mean(dim=0),
|
| 255 |
+
pair_samples[right].vectors.mean(dim=0),
|
| 256 |
+
)
|
| 257 |
+
)
|
| 258 |
+
except Exception as exc:
|
| 259 |
+
pair_errors.append(f"{left} vs {right}: {exc}")
|
| 260 |
+
continue
|
| 261 |
+
|
| 262 |
+
for err in pair_errors:
|
| 263 |
+
st.warning(f"Skipped pair trace: `{err}`")
|
| 264 |
+
pair_fig = (
|
| 265 |
+
plot_layer_similarity(
|
| 266 |
+
pair_traces,
|
| 267 |
+
title="Variant-pair cosine similarity averaged over selected personas",
|
| 268 |
+
show=False,
|
| 269 |
+
)
|
| 270 |
+
if pair_traces
|
| 271 |
+
else None
|
| 272 |
)
|
| 273 |
+
return fig, pair_fig, len(display_traces), len(pair_traces)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def _render_cosine_similarity(
|
| 277 |
+
store: ActivationStore,
|
| 278 |
+
mask_strategy: MaskStrategy,
|
| 279 |
+
) -> None:
|
| 280 |
+
selection = _render_cosine_selection(store, mask_strategy)
|
| 281 |
+
if selection is None:
|
| 282 |
return
|
|
|
|
| 283 |
|
| 284 |
cosine_fig_key = widget_key(
|
| 285 |
"load",
|
| 286 |
"cosine_fig_state",
|
| 287 |
store.model_name,
|
| 288 |
mask_strategy.value,
|
| 289 |
+
selection.variant_a,
|
| 290 |
+
selection.variant_b,
|
| 291 |
+
selection.persona_key,
|
| 292 |
)
|
| 293 |
filename = _filename(
|
| 294 |
"compare",
|
| 295 |
"cosine",
|
| 296 |
store.model_name,
|
| 297 |
mask_strategy.value,
|
| 298 |
+
selection.variant_a,
|
| 299 |
+
selection.variant_b,
|
| 300 |
)
|
| 301 |
pairs_filename = _filename(
|
| 302 |
"compare",
|
| 303 |
"cosine_pairs",
|
| 304 |
store.model_name,
|
| 305 |
mask_strategy.value,
|
| 306 |
+
"_".join(selection.variants),
|
| 307 |
)
|
| 308 |
|
| 309 |
if st.button(
|
|
|
|
| 314 |
"compare_vectors",
|
| 315 |
store.model_name,
|
| 316 |
mask_strategy.value,
|
| 317 |
+
selection.variant_a,
|
| 318 |
+
selection.variant_b,
|
| 319 |
+
selection.persona_key,
|
| 320 |
),
|
| 321 |
):
|
| 322 |
+
figures = _build_cosine_figures(store, selection)
|
| 323 |
+
if figures is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
st.session_state.pop(cosine_fig_key, None)
|
| 325 |
return
|
| 326 |
+
st.session_state[cosine_fig_key] = figures
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
if cosine_fig_key in st.session_state:
|
| 329 |
fig, pair_fig, n_traces, n_pair_traces = st.session_state[cosine_fig_key]
|
|
|
|
| 402 |
return variant, persona_ids, persona_key, selected_layers
|
| 403 |
|
| 404 |
|
| 405 |
+
def _render_layered_figure_analysis(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 406 |
store: ActivationStore,
|
| 407 |
mask_strategy: MaskStrategy,
|
| 408 |
+
*,
|
| 409 |
scope: str,
|
| 410 |
+
figure_kind: str,
|
| 411 |
+
button_label: str,
|
| 412 |
+
title_fn: Callable[[str], str],
|
| 413 |
+
include_pair_trajectories: bool = False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
) -> None:
|
| 415 |
+
"""Render a single-variant layered analysis: select → button → figure(s).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
+
Used for similarity matrix, PCA, and UMAP. Set ``include_pair_trajectories``
|
| 418 |
+
to add the pair-similarity-trajectory figure (similarity matrix only).
|
| 419 |
+
"""
|
| 420 |
+
selected = _select_single_variant_samples(store, mask_strategy, scope)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 421 |
if selected is None:
|
| 422 |
return
|
| 423 |
variant, persona_ids, persona_key, selected_layers = selected
|
| 424 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 425 |
fig_key = widget_key(
|
| 426 |
"load",
|
| 427 |
+
f"{scope}_fig_state",
|
| 428 |
store.model_name,
|
| 429 |
mask_strategy.value,
|
| 430 |
figure_kind,
|
| 431 |
variant,
|
| 432 |
"persona_mean",
|
| 433 |
persona_key,
|
|
|
|
| 434 |
)
|
| 435 |
filename = _filename(
|
| 436 |
"compare",
|
| 437 |
+
scope,
|
| 438 |
store.model_name,
|
| 439 |
mask_strategy.value,
|
| 440 |
variant,
|
| 441 |
"persona_mean",
|
| 442 |
persona_key,
|
|
|
|
| 443 |
)
|
| 444 |
|
| 445 |
+
if st.button(button_label, type="primary"):
|
| 446 |
try:
|
| 447 |
samples = load_persona_mean_samples(
|
| 448 |
store,
|
| 449 |
variant,
|
| 450 |
mask_strategy=mask_strategy,
|
| 451 |
persona_ids=persona_ids,
|
|
|
|
| 452 |
)
|
| 453 |
+
main_fig = build_layered_figure(
|
| 454 |
samples,
|
| 455 |
figure_kind,
|
| 456 |
layers=selected_layers,
|
| 457 |
+
title=title_fn(variant),
|
| 458 |
+
)
|
| 459 |
+
extra_fig = (
|
| 460 |
+
build_pair_similarity_figure(
|
| 461 |
+
samples,
|
| 462 |
+
layers=selected_layers,
|
| 463 |
+
title=(
|
| 464 |
+
"Pair similarity trajectories - "
|
| 465 |
+
f"{prompt_variant_label(variant)} - "
|
| 466 |
+
"persona mean activations"
|
| 467 |
+
),
|
| 468 |
+
)
|
| 469 |
+
if include_pair_trajectories
|
| 470 |
+
else None
|
| 471 |
)
|
| 472 |
+
st.session_state[fig_key] = (main_fig, extra_fig, samples.vectors.shape[0])
|
| 473 |
except Exception as exc:
|
| 474 |
+
st.error(f"Could not build figure: {exc}")
|
| 475 |
st.session_state.pop(fig_key, None)
|
| 476 |
|
| 477 |
if fig_key in st.session_state:
|
| 478 |
+
main_fig, extra_fig, n_samples = st.session_state[fig_key]
|
| 479 |
+
st.plotly_chart(main_fig, width="stretch")
|
| 480 |
+
figs = [main_fig]
|
| 481 |
+
filenames = [filename]
|
| 482 |
+
if extra_fig is not None:
|
| 483 |
+
st.subheader("Pair trajectories")
|
| 484 |
+
st.plotly_chart(extra_fig, width="stretch")
|
| 485 |
+
figs.append(extra_fig)
|
| 486 |
+
filenames.append(f"{filename}__pair_trajectories")
|
| 487 |
+
_render_save_buttons(figs, filenames, scope)
|
| 488 |
st.success(f"Loaded {n_samples} samples.")
|
| 489 |
|
| 490 |
|
|
|
|
| 494 |
st.title("Compare")
|
| 495 |
st.caption("Compare saved activations by cosine similarity, PCA, or UMAP.")
|
| 496 |
|
| 497 |
+
with st.expander("Artifact settings", expanded=False):
|
|
|
|
|
|
|
| 498 |
artifacts_root = st.text_input(
|
| 499 |
"Artifacts root",
|
| 500 |
value=str(get_artifacts_dir() / "activations"),
|
|
|
|
| 510 |
if analysis_mode is None:
|
| 511 |
analysis_mode = ANALYSIS_MODES[0]
|
| 512 |
st.caption(ANALYSIS_HELP_TEXT[analysis_mode])
|
| 513 |
+
with st.expander("Activation settings", expanded=False):
|
| 514 |
+
mask_strategy = _render_mask_strategy_select(analysis_mode)
|
| 515 |
store = ActivationStore(model_name, artifacts_root, mask_strategy=mask_strategy)
|
| 516 |
|
| 517 |
if analysis_mode == "Cosine similarity":
|
| 518 |
_render_cosine_similarity(store, mask_strategy)
|
| 519 |
return
|
| 520 |
if analysis_mode == "Similarity matrix":
|
| 521 |
+
_render_layered_figure_analysis(
|
| 522 |
+
store,
|
| 523 |
+
mask_strategy,
|
| 524 |
+
scope="similarity_matrix",
|
| 525 |
+
figure_kind="similarity",
|
| 526 |
+
button_label="Generate similarity matrix",
|
| 527 |
+
title_fn=lambda v: (
|
| 528 |
+
"Centered similarity - "
|
| 529 |
+
f"{prompt_variant_label(v)} - persona mean activations"
|
| 530 |
+
),
|
| 531 |
+
include_pair_trajectories=True,
|
| 532 |
+
)
|
| 533 |
return
|
| 534 |
|
| 535 |
+
_render_layered_figure_analysis(
|
| 536 |
+
store,
|
| 537 |
+
mask_strategy,
|
| 538 |
+
scope=analysis_mode.lower(),
|
| 539 |
+
figure_kind=analysis_mode.lower(),
|
| 540 |
+
button_label=f"Generate {analysis_mode} projection",
|
| 541 |
+
title_fn=lambda v: (
|
| 542 |
+
f"{analysis_mode} - {prompt_variant_label(v)} - Persona means"
|
| 543 |
+
),
|
| 544 |
+
)
|
tabs/compare_chat.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
-
from
|
|
|
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
from nnterp import StandardizedTransformer
|
| 5 |
from persona_data.synth_persona import PersonaData
|
| 6 |
|
| 7 |
-
from state import default_chat_state, reset_chat_context_state
|
| 8 |
from utils.chat import (
|
| 9 |
ChatReply,
|
| 10 |
build_chat_messages,
|
|
@@ -18,7 +19,6 @@ from utils.runtime import cached_model
|
|
| 18 |
|
| 19 |
from .chat_ui import (
|
| 20 |
GenerationConfig,
|
| 21 |
-
generation_dict,
|
| 22 |
render_chat_message,
|
| 23 |
render_chat_window,
|
| 24 |
render_persona_prompt_controls,
|
|
@@ -26,9 +26,10 @@ from .chat_ui import (
|
|
| 26 |
)
|
| 27 |
|
| 28 |
|
| 29 |
-
|
|
|
|
| 30 |
side: str
|
| 31 |
-
state:
|
| 32 |
log: Any
|
| 33 |
prompt: str | None
|
| 34 |
persona: PersonaData
|
|
@@ -37,6 +38,13 @@ class ComparePanel(NamedTuple):
|
|
| 37 |
pending_key: str
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def _reset_compare_panel(panel: ComparePanel) -> None:
|
| 41 |
reset_chat_context_state(
|
| 42 |
panel.state,
|
|
@@ -48,195 +56,188 @@ def _reset_compare_panel(panel: ComparePanel) -> None:
|
|
| 48 |
st.session_state.pop(panel.edit_key, None)
|
| 49 |
|
| 50 |
|
| 51 |
-
def
|
| 52 |
*,
|
| 53 |
-
model: StandardizedTransformer,
|
| 54 |
-
remote: bool,
|
| 55 |
-
panel_state: dict[str, object],
|
| 56 |
-
panel_prompt: str | None,
|
| 57 |
-
generation: GenerationConfig,
|
| 58 |
-
) -> ChatReply:
|
| 59 |
-
return generate_chat_reply(
|
| 60 |
-
model=model,
|
| 61 |
-
messages=build_chat_messages(panel_prompt, panel_state["messages"]),
|
| 62 |
-
remote=remote,
|
| 63 |
-
past_key_values=panel_state["past_key_values"],
|
| 64 |
-
**generation.to_generate_kwargs(),
|
| 65 |
-
)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def render_compare_mode(
|
| 69 |
-
remote: bool,
|
| 70 |
-
model_name: str,
|
| 71 |
context_key: str,
|
| 72 |
-
|
| 73 |
personas: list[PersonaData],
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
"Requires up to four extra scoring passes after each turn."
|
| 94 |
-
),
|
| 95 |
)
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
# Carry over persona / prompt selections across model or remote switches.
|
| 108 |
-
persist_persona_key = f"chat:last_cmp_{side}_persona"
|
| 109 |
-
persist_prompt_key = f"chat:last_cmp_{side}_prompt"
|
| 110 |
-
if state["persona_id"] is None:
|
| 111 |
-
state["persona_id"] = st.session_state.get(persist_persona_key)
|
| 112 |
-
state["prompt_mode"] = st.session_state.get(persist_prompt_key, "templated")
|
| 113 |
-
|
| 114 |
-
selected_persona, prompt_mode, changed = render_persona_prompt_controls(
|
| 115 |
-
personas,
|
| 116 |
-
state["persona_id"],
|
| 117 |
-
state["prompt_mode"],
|
| 118 |
-
widget_key(panel_key, "persona"),
|
| 119 |
-
widget_key(panel_key, "prompt_mode"),
|
| 120 |
)
|
| 121 |
-
st.session_state
|
| 122 |
-
st.session_state[persist_prompt_key] = prompt_mode
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
st.session_state.pop(edit_key, None)
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
)
|
| 133 |
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
persona=selected_persona,
|
| 145 |
-
prompt_key=prompt_key,
|
| 146 |
-
edit_key=edit_key,
|
| 147 |
-
pending_key=pending_key,
|
| 148 |
-
)
|
| 149 |
|
| 150 |
-
left_col, right_col = st.columns(2)
|
| 151 |
-
with left_col:
|
| 152 |
-
left = render_panel("left")
|
| 153 |
-
with right_col:
|
| 154 |
-
right = render_panel("right")
|
| 155 |
-
panels: list[ComparePanel] = [left, right]
|
| 156 |
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
)
|
| 175 |
-
|
| 176 |
-
|
|
|
|
|
|
|
| 177 |
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
continue
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
| 185 |
)
|
| 186 |
-
|
| 187 |
-
|
|
|
|
| 188 |
)
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
|
| 191 |
-
# Recompute contrast for assistant messages that were edited in place.
|
| 192 |
-
if contrast_enabled:
|
| 193 |
-
pending_edits: list[tuple[int, int]] = [
|
| 194 |
-
(panel_idx, msg_idx)
|
| 195 |
-
for panel_idx, panel in enumerate(panels)
|
| 196 |
-
for msg_idx, msg in enumerate(panel.state["messages"])
|
| 197 |
-
if msg.get("_needs_contrast") and msg.get("role") == "assistant"
|
| 198 |
-
]
|
| 199 |
-
if pending_edits:
|
| 200 |
-
model = _get_model()
|
| 201 |
-
label_a = persona_label(left.persona)
|
| 202 |
-
label_b = persona_label(right.persona)
|
| 203 |
-
with st.spinner("Recomputing token contrast…"):
|
| 204 |
-
for panel_idx, msg_idx in pending_edits:
|
| 205 |
-
panel = panels[panel_idx]
|
| 206 |
-
msg = panel.state["messages"][msg_idx]
|
| 207 |
-
if msg_idx >= len(left.state["messages"]) or msg_idx >= len(
|
| 208 |
-
right.state["messages"]
|
| 209 |
-
):
|
| 210 |
-
msg.pop("_needs_contrast", None)
|
| 211 |
-
continue
|
| 212 |
-
context_a = build_chat_messages(
|
| 213 |
-
left.prompt, left.state["messages"][:msg_idx]
|
| 214 |
-
)
|
| 215 |
-
context_b = build_chat_messages(
|
| 216 |
-
right.prompt, right.state["messages"][:msg_idx]
|
| 217 |
-
)
|
| 218 |
-
try:
|
| 219 |
-
response_ids = model.tokenizer(
|
| 220 |
-
msg["content"],
|
| 221 |
-
add_special_tokens=False,
|
| 222 |
-
return_tensors="pt",
|
| 223 |
-
).input_ids[0]
|
| 224 |
-
tc = compute_contrast(
|
| 225 |
-
model=model,
|
| 226 |
-
context_a=context_a,
|
| 227 |
-
context_b=context_b,
|
| 228 |
-
response_ids=response_ids,
|
| 229 |
-
label_a=label_a,
|
| 230 |
-
label_b=label_b,
|
| 231 |
-
remote=remote,
|
| 232 |
-
)
|
| 233 |
-
if tc is not None:
|
| 234 |
-
msg["_contrast"] = tc
|
| 235 |
-
except Exception as exc:
|
| 236 |
-
st.warning(f"Token contrast recompute failed: {exc}")
|
| 237 |
-
msg.pop("_needs_contrast", None)
|
| 238 |
-
st.rerun()
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
for panel in panels:
|
| 241 |
render_chat_window(
|
| 242 |
chat_log=panel.log,
|
|
@@ -248,10 +249,23 @@ def render_compare_mode(
|
|
| 248 |
edit_column_ratio=(10, 1),
|
| 249 |
)
|
| 250 |
|
| 251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
reset_menu_nonce_key = widget_key(context_key, "cmp_reset_menu_nonce")
|
| 253 |
if reset_menu_nonce_key not in st.session_state:
|
| 254 |
st.session_state[reset_menu_nonce_key] = 0
|
|
|
|
|
|
|
| 255 |
with footer:
|
| 256 |
exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
|
| 257 |
with exp_col:
|
|
@@ -270,7 +284,7 @@ def render_compare_mode(
|
|
| 270 |
prompt_mode=panel.state["prompt_mode"],
|
| 271 |
system_prompt=panel.prompt,
|
| 272 |
messages=panel.state["messages"],
|
| 273 |
-
generation=
|
| 274 |
panel_label=panel.side,
|
| 275 |
)
|
| 276 |
st.toast("Exported", icon=":material/check:")
|
|
@@ -304,81 +318,150 @@ def render_compare_mode(
|
|
| 304 |
st.session_state[reset_menu_nonce_key] += 1
|
| 305 |
st.rerun()
|
| 306 |
|
| 307 |
-
user_prompt = st.chat_input(
|
| 308 |
-
"Ask both...",
|
| 309 |
-
key=widget_key(context_key, "cmp_input"),
|
| 310 |
-
)
|
| 311 |
-
|
| 312 |
-
if not user_prompt:
|
| 313 |
-
return
|
| 314 |
-
|
| 315 |
-
model = cached_model(model_name=model_name, remote=remote)
|
| 316 |
|
|
|
|
| 317 |
for panel in panels:
|
| 318 |
panel.state["messages"].append({"role": "user", "content": user_prompt})
|
| 319 |
with panel.log:
|
| 320 |
render_chat_message({"role": "user", "content": user_prompt})
|
| 321 |
|
| 322 |
-
# Snapshot contexts before the new assistant turn is appended (needed for contrast).
|
| 323 |
-
pre_gen_contexts = [
|
| 324 |
-
build_chat_messages(panel.prompt, panel.state["messages"]) for panel in panels
|
| 325 |
-
]
|
| 326 |
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
)
|
| 340 |
-
)
|
| 341 |
-
except Exception as exc:
|
| 342 |
-
results.append(exc)
|
| 343 |
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
|
| 353 |
-
panel.state["messages"].append({"role": "assistant", "content": result.text})
|
| 354 |
-
panel.state["past_key_values"] = result.past_key_values if not remote else None
|
| 355 |
-
valid_results.append(result)
|
| 356 |
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
):
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
-
# Rerun so the newly appended turns are redrawn through the editable history
|
| 383 |
-
# renderer instead of only appearing in the one-off generation pass.
|
| 384 |
st.rerun()
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Any
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
from nnterp import StandardizedTransformer
|
| 6 |
from persona_data.synth_persona import PersonaData
|
| 7 |
|
| 8 |
+
from state import ChatState, default_chat_state, reset_chat_context_state
|
| 9 |
from utils.chat import (
|
| 10 |
ChatReply,
|
| 11 |
build_chat_messages,
|
|
|
|
| 19 |
|
| 20 |
from .chat_ui import (
|
| 21 |
GenerationConfig,
|
|
|
|
| 22 |
render_chat_message,
|
| 23 |
render_chat_window,
|
| 24 |
render_persona_prompt_controls,
|
|
|
|
| 26 |
)
|
| 27 |
|
| 28 |
|
| 29 |
+
@dataclass(frozen=True)
|
| 30 |
+
class ComparePanel:
|
| 31 |
side: str
|
| 32 |
+
state: ChatState
|
| 33 |
log: Any
|
| 34 |
prompt: str | None
|
| 35 |
persona: PersonaData
|
|
|
|
| 38 |
pending_key: str
|
| 39 |
|
| 40 |
|
| 41 |
+
def _get_compare_state(context_key: str, side: str) -> tuple[str, ChatState]:
|
| 42 |
+
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 43 |
+
if panel_key not in st.session_state:
|
| 44 |
+
st.session_state[panel_key] = default_chat_state()
|
| 45 |
+
return panel_key, st.session_state[panel_key]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
def _reset_compare_panel(panel: ComparePanel) -> None:
|
| 49 |
reset_chat_context_state(
|
| 50 |
panel.state,
|
|
|
|
| 56 |
st.session_state.pop(panel.edit_key, None)
|
| 57 |
|
| 58 |
|
| 59 |
+
def _render_compare_panel(
|
| 60 |
*,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
context_key: str,
|
| 62 |
+
side: str,
|
| 63 |
personas: list[PersonaData],
|
| 64 |
+
) -> ComparePanel:
|
| 65 |
+
panel_key, state = _get_compare_state(context_key, side)
|
| 66 |
+
|
| 67 |
+
prompt_key = widget_key(panel_key, "custom_prompt")
|
| 68 |
+
edit_key = widget_key(panel_key, "edit_idx")
|
| 69 |
+
pending_key = widget_key(panel_key, "pending_regen")
|
| 70 |
+
|
| 71 |
+
persist_persona_key = f"chat:last_cmp_{side}_persona"
|
| 72 |
+
persist_prompt_key = f"chat:last_cmp_{side}_prompt"
|
| 73 |
+
if state["persona_id"] is None:
|
| 74 |
+
state["persona_id"] = st.session_state.get(persist_persona_key)
|
| 75 |
+
state["prompt_mode"] = st.session_state.get(persist_prompt_key, "templated")
|
| 76 |
+
|
| 77 |
+
selected_persona, prompt_mode, changed = render_persona_prompt_controls(
|
| 78 |
+
personas,
|
| 79 |
+
state["persona_id"],
|
| 80 |
+
state["prompt_mode"],
|
| 81 |
+
widget_key(panel_key, "persona"),
|
| 82 |
+
widget_key(panel_key, "prompt_mode"),
|
|
|
|
|
|
|
| 83 |
)
|
| 84 |
+
st.session_state[persist_persona_key] = selected_persona.id
|
| 85 |
+
st.session_state[persist_prompt_key] = prompt_mode
|
| 86 |
+
|
| 87 |
+
if changed:
|
| 88 |
+
reset_chat_context_state(
|
| 89 |
+
state,
|
| 90 |
+
selected_persona.id,
|
| 91 |
+
prompt_mode,
|
| 92 |
+
prompt_key,
|
| 93 |
+
pending_key,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
)
|
| 95 |
+
st.session_state.pop(edit_key, None)
|
|
|
|
| 96 |
|
| 97 |
+
active_system_prompt = resolve_system_prompt(
|
| 98 |
+
persona=selected_persona,
|
| 99 |
+
mode=prompt_mode,
|
| 100 |
+
)
|
|
|
|
| 101 |
|
| 102 |
+
chat_log = st.container()
|
| 103 |
+
with chat_log:
|
| 104 |
+
active_system_prompt = render_system_prompt(
|
| 105 |
+
prompt_key,
|
| 106 |
+
prompt_mode,
|
| 107 |
+
active_system_prompt,
|
| 108 |
)
|
| 109 |
|
| 110 |
+
return ComparePanel(
|
| 111 |
+
side=side,
|
| 112 |
+
state=state,
|
| 113 |
+
log=chat_log,
|
| 114 |
+
prompt=active_system_prompt,
|
| 115 |
+
persona=selected_persona,
|
| 116 |
+
prompt_key=prompt_key,
|
| 117 |
+
edit_key=edit_key,
|
| 118 |
+
pending_key=pending_key,
|
| 119 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
def _generate_panels(
|
| 123 |
+
*,
|
| 124 |
+
model: StandardizedTransformer,
|
| 125 |
+
remote: bool,
|
| 126 |
+
panels: list[ComparePanel],
|
| 127 |
+
generation: GenerationConfig,
|
| 128 |
+
spinner_label: str,
|
| 129 |
+
) -> list[ChatReply | Exception]:
|
| 130 |
+
results: list[ChatReply | Exception] = []
|
| 131 |
+
with st.spinner(spinner_label):
|
| 132 |
+
for panel in panels:
|
| 133 |
+
try:
|
| 134 |
+
results.append(
|
| 135 |
+
generate_chat_reply(
|
| 136 |
+
model=model,
|
| 137 |
+
messages=build_chat_messages(
|
| 138 |
+
panel.prompt, panel.state["messages"]
|
| 139 |
+
),
|
| 140 |
+
remote=remote,
|
| 141 |
+
past_key_values=panel.state["past_key_values"],
|
| 142 |
+
**generation.to_generate_kwargs(),
|
| 143 |
)
|
| 144 |
+
)
|
| 145 |
+
except Exception as exc:
|
| 146 |
+
results.append(exc)
|
| 147 |
+
return results
|
| 148 |
|
| 149 |
+
|
| 150 |
+
def _apply_panel_results(
|
| 151 |
+
*,
|
| 152 |
+
panels: list[ComparePanel],
|
| 153 |
+
results: list[ChatReply | Exception],
|
| 154 |
+
remote: bool,
|
| 155 |
+
rollback_user_on_error: bool,
|
| 156 |
+
) -> list[ChatReply | None]:
|
| 157 |
+
valid_results: list[ChatReply | None] = []
|
| 158 |
+
for panel, result in zip(panels, results, strict=True):
|
| 159 |
+
if isinstance(result, Exception):
|
| 160 |
+
with panel.log:
|
| 161 |
+
st.error(f"Generation failed: {result}")
|
| 162 |
+
if rollback_user_on_error and panel.state["messages"]:
|
| 163 |
+
panel.state["messages"].pop()
|
| 164 |
+
valid_results.append(None)
|
| 165 |
+
continue
|
| 166 |
+
|
| 167 |
+
panel.state["messages"].append({"role": "assistant", "content": result.text})
|
| 168 |
+
panel.state["past_key_values"] = result.past_key_values if not remote else None
|
| 169 |
+
valid_results.append(result)
|
| 170 |
+
return valid_results
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def _pending_contrast_edits(panels: list[ComparePanel]) -> list[tuple[int, int]]:
|
| 174 |
+
return [
|
| 175 |
+
(panel_idx, msg_idx)
|
| 176 |
+
for panel_idx, panel in enumerate(panels)
|
| 177 |
+
for msg_idx, msg in enumerate(panel.state["messages"])
|
| 178 |
+
if msg.get("_needs_contrast") and msg.get("role") == "assistant"
|
| 179 |
+
]
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _recompute_pending_contrast(
|
| 183 |
+
*,
|
| 184 |
+
model: StandardizedTransformer,
|
| 185 |
+
remote: bool,
|
| 186 |
+
panels: list[ComparePanel],
|
| 187 |
+
) -> bool:
|
| 188 |
+
pending_edits = _pending_contrast_edits(panels)
|
| 189 |
+
if not pending_edits:
|
| 190 |
+
return False
|
| 191 |
+
|
| 192 |
+
left, right = panels
|
| 193 |
+
label_a = persona_label(left.persona)
|
| 194 |
+
label_b = persona_label(right.persona)
|
| 195 |
+
with st.spinner("Recomputing token contrast..."):
|
| 196 |
+
for panel_idx, msg_idx in pending_edits:
|
| 197 |
+
panel = panels[panel_idx]
|
| 198 |
+
msg = panel.state["messages"][msg_idx]
|
| 199 |
+
if msg_idx >= len(left.state["messages"]) or msg_idx >= len(
|
| 200 |
+
right.state["messages"]
|
| 201 |
+
):
|
| 202 |
+
msg.pop("_needs_contrast", None)
|
| 203 |
continue
|
| 204 |
+
|
| 205 |
+
context_a = build_chat_messages(
|
| 206 |
+
left.prompt,
|
| 207 |
+
left.state["messages"][:msg_idx],
|
| 208 |
)
|
| 209 |
+
context_b = build_chat_messages(
|
| 210 |
+
right.prompt,
|
| 211 |
+
right.state["messages"][:msg_idx],
|
| 212 |
)
|
| 213 |
+
try:
|
| 214 |
+
response_ids = model.tokenizer(
|
| 215 |
+
msg["content"],
|
| 216 |
+
add_special_tokens=False,
|
| 217 |
+
return_tensors="pt",
|
| 218 |
+
).input_ids[0]
|
| 219 |
+
contrast = compute_contrast(
|
| 220 |
+
model=model,
|
| 221 |
+
context_a=context_a,
|
| 222 |
+
context_b=context_b,
|
| 223 |
+
response_ids=response_ids,
|
| 224 |
+
label_a=label_a,
|
| 225 |
+
label_b=label_b,
|
| 226 |
+
remote=remote,
|
| 227 |
+
)
|
| 228 |
+
if contrast is not None:
|
| 229 |
+
msg["_contrast"] = contrast
|
| 230 |
+
except Exception as exc:
|
| 231 |
+
st.warning(f"Token contrast recompute failed: {exc}")
|
| 232 |
+
msg.pop("_needs_contrast", None)
|
| 233 |
+
return True
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
+
def _render_compare_history(
|
| 237 |
+
*,
|
| 238 |
+
panels: list[ComparePanel],
|
| 239 |
+
contrast_enabled: bool,
|
| 240 |
+
) -> None:
|
| 241 |
for panel in panels:
|
| 242 |
render_chat_window(
|
| 243 |
chat_log=panel.log,
|
|
|
|
| 249 |
edit_column_ratio=(10, 1),
|
| 250 |
)
|
| 251 |
|
| 252 |
+
|
| 253 |
+
def _render_compare_footer(
|
| 254 |
+
*,
|
| 255 |
+
context_key: str,
|
| 256 |
+
model_name: str,
|
| 257 |
+
dataset_source: str,
|
| 258 |
+
panels: list[ComparePanel],
|
| 259 |
+
generation: GenerationConfig,
|
| 260 |
+
) -> None:
|
| 261 |
+
# Bumping this nonce after a reset gives the popover a fresh widget key,
|
| 262 |
+
# which forces Streamlit to re-mount it closed (popovers don't auto-close
|
| 263 |
+
# on click).
|
| 264 |
reset_menu_nonce_key = widget_key(context_key, "cmp_reset_menu_nonce")
|
| 265 |
if reset_menu_nonce_key not in st.session_state:
|
| 266 |
st.session_state[reset_menu_nonce_key] = 0
|
| 267 |
+
|
| 268 |
+
footer = st.container()
|
| 269 |
with footer:
|
| 270 |
exp_col, rst_col, _spacer = st.columns([0.5, 0.5, 10], gap="xsmall")
|
| 271 |
with exp_col:
|
|
|
|
| 284 |
prompt_mode=panel.state["prompt_mode"],
|
| 285 |
system_prompt=panel.prompt,
|
| 286 |
messages=panel.state["messages"],
|
| 287 |
+
generation=generation.to_export_dict(),
|
| 288 |
panel_label=panel.side,
|
| 289 |
)
|
| 290 |
st.toast("Exported", icon=":material/check:")
|
|
|
|
| 318 |
st.session_state[reset_menu_nonce_key] += 1
|
| 319 |
st.rerun()
|
| 320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
|
| 322 |
+
def _append_user_prompt(panels: list[ComparePanel], user_prompt: str) -> None:
|
| 323 |
for panel in panels:
|
| 324 |
panel.state["messages"].append({"role": "user", "content": user_prompt})
|
| 325 |
with panel.log:
|
| 326 |
render_chat_message({"role": "user", "content": user_prompt})
|
| 327 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
+
def _compute_new_reply_contrast(
|
| 330 |
+
*,
|
| 331 |
+
model: StandardizedTransformer,
|
| 332 |
+
remote: bool,
|
| 333 |
+
panels: list[ComparePanel],
|
| 334 |
+
pre_gen_contexts: list[list[dict[str, str]]],
|
| 335 |
+
results: list[ChatReply | None],
|
| 336 |
+
) -> None:
|
| 337 |
+
if len(results) != 2 or any(
|
| 338 |
+
result is None or result.generated_ids is None for result in results
|
| 339 |
+
):
|
| 340 |
+
return
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
|
| 342 |
+
left, right = panels
|
| 343 |
+
with st.spinner("Computing token contrast..."):
|
| 344 |
+
try:
|
| 345 |
+
left_contrast, right_contrast = compute_contrast_pair(
|
| 346 |
+
model=model,
|
| 347 |
+
context_a=pre_gen_contexts[0],
|
| 348 |
+
context_b=pre_gen_contexts[1],
|
| 349 |
+
response_ids_a=results[0].generated_ids,
|
| 350 |
+
response_ids_b=results[1].generated_ids,
|
| 351 |
+
label_a=persona_label(left.persona),
|
| 352 |
+
label_b=persona_label(right.persona),
|
| 353 |
+
remote=remote,
|
| 354 |
+
)
|
| 355 |
+
if left_contrast is not None:
|
| 356 |
+
left.state["messages"][-1]["_contrast"] = left_contrast
|
| 357 |
+
if right_contrast is not None:
|
| 358 |
+
right.state["messages"][-1]["_contrast"] = right_contrast
|
| 359 |
+
except Exception as exc:
|
| 360 |
+
st.warning(f"Token contrast failed: {exc}")
|
| 361 |
|
|
|
|
|
|
|
|
|
|
| 362 |
|
| 363 |
+
def _render_compare_panels(
|
| 364 |
+
*,
|
| 365 |
+
context_key: str,
|
| 366 |
+
personas: list[PersonaData],
|
| 367 |
+
) -> list[ComparePanel]:
|
| 368 |
+
left_col, right_col = st.columns(2)
|
| 369 |
+
with left_col:
|
| 370 |
+
left = _render_compare_panel(
|
| 371 |
+
context_key=context_key,
|
| 372 |
+
side="left",
|
| 373 |
+
personas=personas,
|
| 374 |
+
)
|
| 375 |
+
with right_col:
|
| 376 |
+
right = _render_compare_panel(
|
| 377 |
+
context_key=context_key,
|
| 378 |
+
side="right",
|
| 379 |
+
personas=personas,
|
| 380 |
+
)
|
| 381 |
+
return [left, right]
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def render_compare_mode(
|
| 385 |
+
remote: bool,
|
| 386 |
+
model_name: str,
|
| 387 |
+
context_key: str,
|
| 388 |
+
dataset_source: str,
|
| 389 |
+
personas: list[PersonaData],
|
| 390 |
+
generation: GenerationConfig,
|
| 391 |
+
*,
|
| 392 |
+
contrast_enabled: bool,
|
| 393 |
+
) -> None:
|
| 394 |
+
"""Render the full side-by-side comparison UI."""
|
| 395 |
+
|
| 396 |
+
panels = _render_compare_panels(context_key=context_key, personas=personas)
|
| 397 |
+
|
| 398 |
+
regen_panels = [
|
| 399 |
+
panel for panel in panels if st.session_state.pop(panel.pending_key, False)
|
| 400 |
+
]
|
| 401 |
+
if regen_panels:
|
| 402 |
+
results = _generate_panels(
|
| 403 |
+
model=cached_model(model_name=model_name),
|
| 404 |
+
remote=remote,
|
| 405 |
+
panels=regen_panels,
|
| 406 |
+
generation=generation,
|
| 407 |
+
spinner_label="Regenerating...",
|
| 408 |
+
)
|
| 409 |
+
_apply_panel_results(
|
| 410 |
+
panels=regen_panels,
|
| 411 |
+
results=results,
|
| 412 |
+
remote=remote,
|
| 413 |
+
rollback_user_on_error=False,
|
| 414 |
+
)
|
| 415 |
+
st.rerun()
|
| 416 |
+
|
| 417 |
+
if contrast_enabled and _recompute_pending_contrast(
|
| 418 |
+
model=cached_model(model_name=model_name),
|
| 419 |
+
remote=remote,
|
| 420 |
+
panels=panels,
|
| 421 |
):
|
| 422 |
+
st.rerun()
|
| 423 |
+
|
| 424 |
+
_render_compare_history(panels=panels, contrast_enabled=contrast_enabled)
|
| 425 |
+
_render_compare_footer(
|
| 426 |
+
context_key=context_key,
|
| 427 |
+
model_name=model_name,
|
| 428 |
+
dataset_source=dataset_source,
|
| 429 |
+
panels=panels,
|
| 430 |
+
generation=generation,
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
user_prompt = st.chat_input(
|
| 434 |
+
"Ask both...",
|
| 435 |
+
key=widget_key(context_key, "cmp_input"),
|
| 436 |
+
)
|
| 437 |
+
if not user_prompt:
|
| 438 |
+
return
|
| 439 |
+
|
| 440 |
+
_append_user_prompt(panels, user_prompt)
|
| 441 |
+
pre_gen_contexts = [
|
| 442 |
+
build_chat_messages(panel.prompt, panel.state["messages"]) for panel in panels
|
| 443 |
+
]
|
| 444 |
+
model = cached_model(model_name=model_name)
|
| 445 |
+
results = _generate_panels(
|
| 446 |
+
model=model,
|
| 447 |
+
remote=remote,
|
| 448 |
+
panels=panels,
|
| 449 |
+
generation=generation,
|
| 450 |
+
spinner_label="Generating...",
|
| 451 |
+
)
|
| 452 |
+
valid_results = _apply_panel_results(
|
| 453 |
+
panels=panels,
|
| 454 |
+
results=results,
|
| 455 |
+
remote=remote,
|
| 456 |
+
rollback_user_on_error=True,
|
| 457 |
+
)
|
| 458 |
+
if contrast_enabled:
|
| 459 |
+
_compute_new_reply_contrast(
|
| 460 |
+
model=model,
|
| 461 |
+
remote=remote,
|
| 462 |
+
panels=panels,
|
| 463 |
+
pre_gen_contexts=pre_gen_contexts,
|
| 464 |
+
results=valid_results,
|
| 465 |
+
)
|
| 466 |
|
|
|
|
|
|
|
| 467 |
st.rerun()
|
tabs/extract.py
CHANGED
|
@@ -1,12 +1,9 @@
|
|
| 1 |
import html
|
|
|
|
| 2 |
from typing import Literal, cast
|
| 3 |
|
| 4 |
import streamlit as st
|
| 5 |
-
from persona_data.prompts import
|
| 6 |
-
BASELINE_PERSONA_ID,
|
| 7 |
-
BASELINE_PERSONA_NAME,
|
| 8 |
-
format_prompt,
|
| 9 |
-
)
|
| 10 |
from persona_data.synth_persona import PersonaData, QAPair
|
| 11 |
from persona_vectors.artifacts import PERSONA_VARIANTS
|
| 12 |
from persona_vectors.extraction import (
|
|
@@ -42,8 +39,11 @@ _ITEM_TYPE_OPTIONS = ["all", "mcq", "frq"]
|
|
| 42 |
_SCOPE_OPTIONS = ["all", "individual", "shared"]
|
| 43 |
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
|
| 49 |
def _remembered_select(
|
|
@@ -53,10 +53,11 @@ def _remembered_select(
|
|
| 53 |
key: str,
|
| 54 |
default: str = "all",
|
| 55 |
) -> str:
|
|
|
|
| 56 |
selected = st.selectbox(
|
| 57 |
label,
|
| 58 |
options=options,
|
| 59 |
-
index=
|
| 60 |
key=key,
|
| 61 |
)
|
| 62 |
st.session_state[state_key] = selected
|
|
@@ -66,21 +67,13 @@ def _remembered_select(
|
|
| 66 |
def _build_run_plan(
|
| 67 |
selected_variants: list[str],
|
| 68 |
runs: list[tuple[PersonaData, list[QAPair]]],
|
| 69 |
-
) -> list[tuple[PersonaData
|
| 70 |
-
"""
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
for variant in selected_variants:
|
| 77 |
-
if variant == BASELINE_PERSONA_ID:
|
| 78 |
-
_, qa_pairs = runs[0]
|
| 79 |
-
plan.append((None, qa_pairs, variant))
|
| 80 |
-
else:
|
| 81 |
-
for persona, qa_pairs in runs:
|
| 82 |
-
plan.append((persona, qa_pairs, variant))
|
| 83 |
-
return plan
|
| 84 |
|
| 85 |
|
| 86 |
def _extract_widget_key(
|
|
@@ -89,103 +82,55 @@ def _extract_widget_key(
|
|
| 89 |
return widget_key("extract", str(remote), model_name, dataset_source, suffix)
|
| 90 |
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
"response": "color:#22d3ee",
|
| 108 |
-
"question": "color:#fde047",
|
| 109 |
-
}.get(segment.role, "color:#9ca3af")
|
| 110 |
-
|
| 111 |
-
if segment.is_special:
|
| 112 |
-
style = "color:#d946ef;font-weight:bold"
|
| 113 |
-
if segment.is_masked:
|
| 114 |
-
style = f"{style};background:#86efac;border-radius:2px;padding:0 1px"
|
| 115 |
-
return style
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def _render_sample_tokens_html(p, tokenizer, *, max_tokens: int = 200) -> str:
|
| 119 |
-
spans: list[str] = []
|
| 120 |
-
for segment in preview_token_segments(p, tokenizer, max_tokens=max_tokens):
|
| 121 |
-
spans.append(
|
| 122 |
-
f'<span style="{_token_style(segment)}">{html.escape(segment.text)}</span>'
|
| 123 |
)
|
| 124 |
|
| 125 |
-
return (
|
| 126 |
-
'<pre style="white-space:pre-wrap;font-size:0.82em;line-height:1.5;'
|
| 127 |
-
"background:#0e1117;padding:8px 10px;border-radius:6px;"
|
| 128 |
-
'border:1px solid #333;margin:0">'
|
| 129 |
-
f"{''.join(spans)}</pre>"
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
| 134 |
-
"""Render the extraction tab."""
|
| 135 |
-
|
| 136 |
-
st.title("Extract")
|
| 137 |
-
|
| 138 |
-
if dataset_source == "Local JSONL upload":
|
| 139 |
-
with st.expander("Local dataset upload", expanded=True):
|
| 140 |
-
st.file_uploader(
|
| 141 |
-
"personas.jsonl",
|
| 142 |
-
type=["jsonl"],
|
| 143 |
-
key="extract__personas_file",
|
| 144 |
-
help="Expected fields: id, persona, templated_view, biography_view",
|
| 145 |
-
)
|
| 146 |
-
st.file_uploader(
|
| 147 |
-
"qa.jsonl",
|
| 148 |
-
type=["jsonl"],
|
| 149 |
-
key="extract__qa_file",
|
| 150 |
-
help="Expected fields: id, qid, type, item_type, scope, question, answer",
|
| 151 |
-
)
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
| 160 |
"Persona variants",
|
| 161 |
options=PERSONA_VARIANTS,
|
| 162 |
-
default=
|
|
|
|
| 163 |
format_func=prompt_variant_label,
|
| 164 |
key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
|
| 165 |
help="Extract these variants for each selected persona.",
|
| 166 |
)
|
| 167 |
include_baseline = st.checkbox(
|
| 168 |
"Extract Assistant baseline",
|
| 169 |
-
value=st.session_state.get(
|
| 170 |
-
_LAST_BASELINE_KEY,
|
| 171 |
-
BASELINE_PERSONA_ID in last_variants,
|
| 172 |
-
),
|
| 173 |
key=_extract_widget_key(model_name, remote, dataset_source, "baseline"),
|
| 174 |
-
help=
|
| 175 |
-
"Extracts the persona-less Assistant prompt once using the first "
|
| 176 |
-
"selected persona's QA set."
|
| 177 |
-
),
|
| 178 |
)
|
| 179 |
-
selected_variants = [
|
| 180 |
-
*selected_persona_variants,
|
| 181 |
-
*([BASELINE_PERSONA_ID] if include_baseline else []),
|
| 182 |
-
]
|
| 183 |
st.session_state[_LAST_VARIANTS_KEY] = selected_variants
|
| 184 |
st.session_state[_LAST_BASELINE_KEY] = include_baseline
|
| 185 |
if not selected_variants:
|
| 186 |
-
st.info("Select at least one persona variant
|
| 187 |
-
return
|
|
|
|
|
|
|
| 188 |
|
|
|
|
| 189 |
try:
|
| 190 |
dataset, dataset_status = load_dataset(
|
| 191 |
dataset_source,
|
|
@@ -198,11 +143,11 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 198 |
st.info(
|
| 199 |
"Upload both JSONL files or switch to the built-in SynthPersona source."
|
| 200 |
)
|
| 201 |
-
return
|
| 202 |
|
| 203 |
if not getattr(dataset, "supports_qa", True):
|
| 204 |
st.info("This dataset is persona-only for now. Use Chat to browse personas.")
|
| 205 |
-
return
|
| 206 |
|
| 207 |
personas = list(dataset)
|
| 208 |
if not personas:
|
|
@@ -210,8 +155,17 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 210 |
st.info(
|
| 211 |
"Try another dataset source or check that the personas file is not empty."
|
| 212 |
)
|
| 213 |
-
return
|
|
|
|
|
|
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
last_persona_ids: set[str] = set(st.session_state.get(_LAST_PERSONA_IDS_KEY, []))
|
| 216 |
default_personas = [p for p in personas if p.id in last_persona_ids] or [
|
| 217 |
personas[0]
|
|
@@ -227,185 +181,295 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 227 |
|
| 228 |
if not selected_personas:
|
| 229 |
st.info("Select at least one persona.")
|
| 230 |
-
return
|
|
|
|
| 231 |
|
| 232 |
-
with st.expander("Advanced", expanded=False):
|
| 233 |
-
st.caption("Filters")
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
)
|
| 245 |
-
qa_filter_type: Literal["explicit", "implicit"] | None = (
|
| 246 |
-
cast(Literal["explicit", "implicit"], qa_type_select)
|
| 247 |
-
if qa_type_select in ("explicit", "implicit")
|
| 248 |
-
else None
|
| 249 |
-
)
|
| 250 |
-
with col2:
|
| 251 |
-
item_type_select = _remembered_select(
|
| 252 |
-
"Item type",
|
| 253 |
-
_ITEM_TYPE_OPTIONS,
|
| 254 |
-
_LAST_ITEM_TYPE_KEY,
|
| 255 |
-
key=_extract_widget_key(
|
| 256 |
-
model_name, remote, dataset_source, "item_type_select"
|
| 257 |
-
),
|
| 258 |
-
)
|
| 259 |
-
qa_filter_item_type: Literal["mcq", "frq"] | None = (
|
| 260 |
-
cast(Literal["mcq", "frq"], item_type_select)
|
| 261 |
-
if item_type_select in ("mcq", "frq")
|
| 262 |
-
else None
|
| 263 |
-
)
|
| 264 |
-
with col3:
|
| 265 |
-
scope_select = _remembered_select(
|
| 266 |
-
"Scope",
|
| 267 |
-
_SCOPE_OPTIONS,
|
| 268 |
-
_LAST_SCOPE_KEY,
|
| 269 |
-
key=_extract_widget_key(
|
| 270 |
-
model_name,
|
| 271 |
-
remote,
|
| 272 |
-
dataset_source,
|
| 273 |
-
"scope_select",
|
| 274 |
-
),
|
| 275 |
-
)
|
| 276 |
-
qa_filter_scope: Literal["individual", "shared"] | None = (
|
| 277 |
-
cast(Literal["individual", "shared"], scope_select)
|
| 278 |
-
if scope_select in ("individual", "shared")
|
| 279 |
-
else None
|
| 280 |
-
)
|
| 281 |
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
)
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
key=_extract_widget_key(
|
| 301 |
model_name,
|
| 302 |
remote,
|
| 303 |
dataset_source,
|
| 304 |
-
"
|
| 305 |
),
|
| 306 |
-
help="Which tokens contribute to the averaged hidden state.",
|
| 307 |
)
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
if
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
)
|
| 345 |
-
st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
|
| 346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
run_col, preview_col, _spacer = st.columns([1, 1, 4], gap="small")
|
| 348 |
with run_col:
|
| 349 |
run_clicked = st.button(
|
| 350 |
-
"Run extraction",
|
|
|
|
|
|
|
| 351 |
)
|
| 352 |
with preview_col:
|
| 353 |
preview_clicked = st.button("Preview tokens", use_container_width=True)
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
)
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
if not run_clicked:
|
| 396 |
-
return
|
| 397 |
-
|
| 398 |
status_box = st.empty()
|
| 399 |
status_box.info("Extraction in progress...")
|
| 400 |
progress = st.progress(0, text="Preparing extraction...")
|
| 401 |
-
ndif_status_box = st.empty()
|
| 402 |
|
| 403 |
def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
|
| 404 |
icon = NDIF_STATUS_ICONS.get(status_name, "•")
|
| 405 |
ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
|
| 406 |
|
| 407 |
with st.spinner("Loading model..."):
|
| 408 |
-
model = cached_model(model_name=model_name
|
| 409 |
|
| 410 |
try:
|
| 411 |
total_steps = len(run_plan)
|
|
@@ -419,10 +483,10 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 419 |
run_extraction(
|
| 420 |
model=model,
|
| 421 |
model_name=model_name,
|
| 422 |
-
qa_pairs=qa_pairs[:max_questions],
|
| 423 |
variants=(variant,),
|
| 424 |
persona=persona,
|
| 425 |
-
mask_strategy=mask_strategy,
|
| 426 |
remote=remote,
|
| 427 |
on_status=_on_ndif_status if remote else None,
|
| 428 |
)
|
|
@@ -444,3 +508,71 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 444 |
f"- **{result.persona_name}** · {prompt_variant_label(result.variant)}: "
|
| 445 |
f"{result.n_questions} questions"
|
| 446 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import html
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
from typing import Literal, cast
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
+
from persona_data.prompts import format_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from persona_data.synth_persona import PersonaData, QAPair
|
| 8 |
from persona_vectors.artifacts import PERSONA_VARIANTS
|
| 9 |
from persona_vectors.extraction import (
|
|
|
|
| 39 |
_SCOPE_OPTIONS = ["all", "individual", "shared"]
|
| 40 |
|
| 41 |
|
| 42 |
+
@dataclass(frozen=True)
|
| 43 |
+
class ExtractSettings:
|
| 44 |
+
runs: list[tuple[PersonaData, list[QAPair]]]
|
| 45 |
+
mask_strategy: MaskStrategy
|
| 46 |
+
max_questions: int
|
| 47 |
|
| 48 |
|
| 49 |
def _remembered_select(
|
|
|
|
| 53 |
key: str,
|
| 54 |
default: str = "all",
|
| 55 |
) -> str:
|
| 56 |
+
current = st.session_state.get(state_key, default)
|
| 57 |
selected = st.selectbox(
|
| 58 |
label,
|
| 59 |
options=options,
|
| 60 |
+
index=options.index(current) if current in options else 0,
|
| 61 |
key=key,
|
| 62 |
)
|
| 63 |
st.session_state[state_key] = selected
|
|
|
|
| 67 |
def _build_run_plan(
|
| 68 |
selected_variants: list[str],
|
| 69 |
runs: list[tuple[PersonaData, list[QAPair]]],
|
| 70 |
+
) -> list[tuple[PersonaData, list[QAPair], str]]:
|
| 71 |
+
"""Cartesian product of personas × variants."""
|
| 72 |
+
return [(p, qa, v) for v in selected_variants for p, qa in runs]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _row_label(persona: PersonaData, variant: str) -> str:
|
| 76 |
+
return f"{persona.name} · {prompt_variant_label(variant)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
|
| 79 |
def _extract_widget_key(
|
|
|
|
| 82 |
return widget_key("extract", str(remote), model_name, dataset_source, suffix)
|
| 83 |
|
| 84 |
|
| 85 |
+
def _render_local_dataset_upload(dataset_source: str) -> None:
|
| 86 |
+
if dataset_source != "Local JSONL upload":
|
| 87 |
+
return
|
| 88 |
+
with st.expander("Local dataset upload", expanded=True):
|
| 89 |
+
st.file_uploader(
|
| 90 |
+
"personas.jsonl",
|
| 91 |
+
type=["jsonl"],
|
| 92 |
+
key="extract__personas_file",
|
| 93 |
+
help="Expected fields: id, persona, templated_view, biography_view",
|
| 94 |
+
)
|
| 95 |
+
st.file_uploader(
|
| 96 |
+
"qa.jsonl",
|
| 97 |
+
type=["jsonl"],
|
| 98 |
+
key="extract__qa_file",
|
| 99 |
+
help="Expected fields: id, qid, type, item_type, scope, question, answer",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
)
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
+
def _render_variant_controls(
|
| 104 |
+
*,
|
| 105 |
+
model_name: str,
|
| 106 |
+
remote: bool,
|
| 107 |
+
dataset_source: str,
|
| 108 |
+
) -> tuple[list[str], bool] | None:
|
| 109 |
+
default_variants = st.session_state.get(_LAST_VARIANTS_KEY, list(PERSONA_VARIANTS))
|
| 110 |
+
selected_variants = st.multiselect(
|
| 111 |
"Persona variants",
|
| 112 |
options=PERSONA_VARIANTS,
|
| 113 |
+
default=[v for v in default_variants if v in PERSONA_VARIANTS]
|
| 114 |
+
or list(PERSONA_VARIANTS),
|
| 115 |
format_func=prompt_variant_label,
|
| 116 |
key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
|
| 117 |
help="Extract these variants for each selected persona.",
|
| 118 |
)
|
| 119 |
include_baseline = st.checkbox(
|
| 120 |
"Extract Assistant baseline",
|
| 121 |
+
value=st.session_state.get(_LAST_BASELINE_KEY, True),
|
|
|
|
|
|
|
|
|
|
| 122 |
key=_extract_widget_key(model_name, remote, dataset_source, "baseline"),
|
| 123 |
+
help="Also extract the Assistant baseline persona using the first persona's QA set.",
|
|
|
|
|
|
|
|
|
|
| 124 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
st.session_state[_LAST_VARIANTS_KEY] = selected_variants
|
| 126 |
st.session_state[_LAST_BASELINE_KEY] = include_baseline
|
| 127 |
if not selected_variants:
|
| 128 |
+
st.info("Select at least one persona variant.")
|
| 129 |
+
return None
|
| 130 |
+
return selected_variants, include_baseline
|
| 131 |
+
|
| 132 |
|
| 133 |
+
def _load_qa_dataset_personas(dataset_source: str) -> tuple[object, list[PersonaData]] | None:
|
| 134 |
try:
|
| 135 |
dataset, dataset_status = load_dataset(
|
| 136 |
dataset_source,
|
|
|
|
| 143 |
st.info(
|
| 144 |
"Upload both JSONL files or switch to the built-in SynthPersona source."
|
| 145 |
)
|
| 146 |
+
return None
|
| 147 |
|
| 148 |
if not getattr(dataset, "supports_qa", True):
|
| 149 |
st.info("This dataset is persona-only for now. Use Chat to browse personas.")
|
| 150 |
+
return None
|
| 151 |
|
| 152 |
personas = list(dataset)
|
| 153 |
if not personas:
|
|
|
|
| 155 |
st.info(
|
| 156 |
"Try another dataset source or check that the personas file is not empty."
|
| 157 |
)
|
| 158 |
+
return None
|
| 159 |
+
return dataset, personas
|
| 160 |
+
|
| 161 |
|
| 162 |
+
def _render_persona_select(
|
| 163 |
+
*,
|
| 164 |
+
personas: list[PersonaData],
|
| 165 |
+
model_name: str,
|
| 166 |
+
remote: bool,
|
| 167 |
+
dataset_source: str,
|
| 168 |
+
) -> list[PersonaData] | None:
|
| 169 |
last_persona_ids: set[str] = set(st.session_state.get(_LAST_PERSONA_IDS_KEY, []))
|
| 170 |
default_personas = [p for p in personas if p.id in last_persona_ids] or [
|
| 171 |
personas[0]
|
|
|
|
| 181 |
|
| 182 |
if not selected_personas:
|
| 183 |
st.info("Select at least one persona.")
|
| 184 |
+
return None
|
| 185 |
+
return selected_personas
|
| 186 |
|
|
|
|
|
|
|
| 187 |
|
| 188 |
+
_TOKEN_LEGEND = (
|
| 189 |
+
'<div style="display:flex;gap:12px;flex-wrap:wrap;font-size:0.8em;margin-bottom:8px">'
|
| 190 |
+
'<span style="background:#86efac;color:black;padding:1px 6px;border-radius:3px">masked</span>'
|
| 191 |
+
'<span style="color:#fde047;padding:1px 6px">question</span>'
|
| 192 |
+
'<span style="color:#22d3ee;padding:1px 6px">response</span>'
|
| 193 |
+
'<span style="color:#d946ef;font-weight:bold;padding:1px 6px">special</span>'
|
| 194 |
+
'<span style="color:#9ca3af;padding:1px 6px">template</span>'
|
| 195 |
+
"</div>"
|
| 196 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
+
_MAX_PREVIEW_SAMPLES = 3
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _token_style(segment: TokenSegment) -> str:
|
| 202 |
+
style = {
|
| 203 |
+
"response": "color:#22d3ee",
|
| 204 |
+
"question": "color:#fde047",
|
| 205 |
+
}.get(segment.role, "color:#9ca3af")
|
| 206 |
+
|
| 207 |
+
if segment.is_special:
|
| 208 |
+
style = "color:#d946ef;font-weight:bold"
|
| 209 |
+
if segment.is_masked:
|
| 210 |
+
style = f"{style};background:#86efac;border-radius:2px;padding:0 1px"
|
| 211 |
+
return style
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def _render_sample_tokens_html(p, tokenizer, *, max_tokens: int = 200) -> str:
|
| 215 |
+
spans: list[str] = []
|
| 216 |
+
for segment in preview_token_segments(p, tokenizer, max_tokens=max_tokens):
|
| 217 |
+
spans.append(
|
| 218 |
+
f'<span style="{_token_style(segment)}">{html.escape(segment.text)}</span>'
|
| 219 |
)
|
| 220 |
+
|
| 221 |
+
return (
|
| 222 |
+
'<pre style="white-space:pre-wrap;font-size:0.82em;line-height:1.5;'
|
| 223 |
+
"background:var(--secondary-background-color,rgba(127,127,127,0.08));"
|
| 224 |
+
"padding:8px 10px;border-radius:6px;"
|
| 225 |
+
'border:1px solid rgba(127,127,127,0.25);margin:0">'
|
| 226 |
+
f"{''.join(spans)}</pre>"
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def _render_filter_controls(
|
| 231 |
+
*,
|
| 232 |
+
model_name: str,
|
| 233 |
+
remote: bool,
|
| 234 |
+
dataset_source: str,
|
| 235 |
+
) -> tuple[
|
| 236 |
+
Literal["explicit", "implicit"] | None,
|
| 237 |
+
Literal["mcq", "frq"] | None,
|
| 238 |
+
Literal["individual", "shared"] | None,
|
| 239 |
+
]:
|
| 240 |
+
col1, col2, col3 = st.columns(3)
|
| 241 |
+
with col1:
|
| 242 |
+
qa_type_select = _remembered_select(
|
| 243 |
+
"QA type",
|
| 244 |
+
_QA_TYPE_OPTIONS,
|
| 245 |
+
_LAST_QA_TYPE_KEY,
|
| 246 |
+
key=_extract_widget_key(model_name, remote, dataset_source, "qa_type_select"),
|
| 247 |
+
)
|
| 248 |
+
with col2:
|
| 249 |
+
item_type_select = _remembered_select(
|
| 250 |
+
"Item type",
|
| 251 |
+
_ITEM_TYPE_OPTIONS,
|
| 252 |
+
_LAST_ITEM_TYPE_KEY,
|
| 253 |
key=_extract_widget_key(
|
| 254 |
model_name,
|
| 255 |
remote,
|
| 256 |
dataset_source,
|
| 257 |
+
"item_type_select",
|
| 258 |
),
|
|
|
|
| 259 |
)
|
| 260 |
+
with col3:
|
| 261 |
+
scope_select = _remembered_select(
|
| 262 |
+
"Scope",
|
| 263 |
+
_SCOPE_OPTIONS,
|
| 264 |
+
_LAST_SCOPE_KEY,
|
| 265 |
+
key=_extract_widget_key(model_name, remote, dataset_source, "scope_select"),
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
return (
|
| 269 |
+
cast(Literal["explicit", "implicit"], qa_type_select)
|
| 270 |
+
if qa_type_select in ("explicit", "implicit")
|
| 271 |
+
else None,
|
| 272 |
+
cast(Literal["mcq", "frq"], item_type_select)
|
| 273 |
+
if item_type_select in ("mcq", "frq")
|
| 274 |
+
else None,
|
| 275 |
+
cast(Literal["individual", "shared"], scope_select)
|
| 276 |
+
if scope_select in ("individual", "shared")
|
| 277 |
+
else None,
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
def _render_mask_strategy_select(
|
| 282 |
+
*,
|
| 283 |
+
model_name: str,
|
| 284 |
+
remote: bool,
|
| 285 |
+
dataset_source: str,
|
| 286 |
+
) -> MaskStrategy:
|
| 287 |
+
last_strategy = st.session_state.get(
|
| 288 |
+
_LAST_MASK_STRATEGY_KEY,
|
| 289 |
+
MaskStrategy.ANSWER_MEAN.value,
|
| 290 |
+
)
|
| 291 |
+
strategy_options = list(MaskStrategy)
|
| 292 |
+
mask_strategy = st.selectbox(
|
| 293 |
+
"Mask strategy",
|
| 294 |
+
options=strategy_options,
|
| 295 |
+
index=next(
|
| 296 |
+
(
|
| 297 |
+
idx
|
| 298 |
+
for idx, strategy in enumerate(strategy_options)
|
| 299 |
+
if strategy.value == last_strategy
|
| 300 |
),
|
| 301 |
+
0,
|
| 302 |
+
),
|
| 303 |
+
format_func=lambda s: s.value.replace("_", " ").title(),
|
| 304 |
+
key=_extract_widget_key(model_name, remote, dataset_source, "mask_strategy"),
|
| 305 |
+
help="Which tokens contribute to the averaged hidden state.",
|
| 306 |
+
)
|
| 307 |
+
st.session_state[_LAST_MASK_STRATEGY_KEY] = mask_strategy.value
|
| 308 |
+
return mask_strategy
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def _collect_runs(
|
| 312 |
+
*,
|
| 313 |
+
dataset,
|
| 314 |
+
selected_personas: list[PersonaData],
|
| 315 |
+
qa_filter_type: Literal["explicit", "implicit"] | None,
|
| 316 |
+
qa_filter_item_type: Literal["mcq", "frq"] | None,
|
| 317 |
+
qa_filter_scope: Literal["individual", "shared"] | None,
|
| 318 |
+
) -> list[tuple[PersonaData, list[QAPair]]] | None:
|
| 319 |
+
runs, skipped = [], []
|
| 320 |
+
for persona in selected_personas:
|
| 321 |
+
qa = list(
|
| 322 |
+
dataset.get_qa(
|
| 323 |
+
persona.id,
|
| 324 |
+
type=qa_filter_type,
|
| 325 |
+
item_type=qa_filter_item_type,
|
| 326 |
+
scope=qa_filter_scope,
|
| 327 |
+
)
|
| 328 |
+
)
|
| 329 |
+
if qa:
|
| 330 |
+
runs.append((persona, qa))
|
| 331 |
+
else:
|
| 332 |
+
skipped.append(persona)
|
| 333 |
+
if skipped:
|
| 334 |
+
names = ", ".join(p.name for p in skipped)
|
| 335 |
+
st.warning(f"No QA pairs match filters for: {names}. They will be skipped.")
|
| 336 |
+
|
| 337 |
+
if not runs:
|
| 338 |
+
st.info("No personas have matching QA pairs. Widen the filters.")
|
| 339 |
+
return None
|
| 340 |
+
return runs
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def _render_max_questions(
|
| 344 |
+
*,
|
| 345 |
+
model_name: str,
|
| 346 |
+
remote: bool,
|
| 347 |
+
dataset_source: str,
|
| 348 |
+
runs: list[tuple[PersonaData, list[QAPair]]],
|
| 349 |
+
) -> int:
|
| 350 |
+
max_q = min(len(qa_pairs) for _, qa_pairs in runs)
|
| 351 |
+
max_questions = st.slider(
|
| 352 |
+
"Max questions",
|
| 353 |
+
min_value=1,
|
| 354 |
+
max_value=max_q,
|
| 355 |
+
value=min(max(st.session_state.get(_LAST_MAX_QUESTIONS_KEY, max_q), 1), max_q),
|
| 356 |
+
key=_extract_widget_key(model_name, remote, dataset_source, "max_questions"),
|
| 357 |
+
)
|
| 358 |
+
st.session_state[_LAST_MAX_QUESTIONS_KEY] = max_questions
|
| 359 |
+
return max_questions
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
def _render_advanced_settings(
|
| 363 |
+
*,
|
| 364 |
+
dataset,
|
| 365 |
+
selected_personas: list[PersonaData],
|
| 366 |
+
model_name: str,
|
| 367 |
+
remote: bool,
|
| 368 |
+
dataset_source: str,
|
| 369 |
+
) -> ExtractSettings | None:
|
| 370 |
+
with st.expander("Advanced", expanded=False):
|
| 371 |
+
st.caption("Filters")
|
| 372 |
+
qa_filter_type, qa_filter_item_type, qa_filter_scope = _render_filter_controls(
|
| 373 |
+
model_name=model_name,
|
| 374 |
+
remote=remote,
|
| 375 |
+
dataset_source=dataset_source,
|
| 376 |
)
|
|
|
|
| 377 |
|
| 378 |
+
st.caption("Extraction settings")
|
| 379 |
+
mask_strategy = _render_mask_strategy_select(
|
| 380 |
+
model_name=model_name,
|
| 381 |
+
remote=remote,
|
| 382 |
+
dataset_source=dataset_source,
|
| 383 |
+
)
|
| 384 |
+
runs = _collect_runs(
|
| 385 |
+
dataset=dataset,
|
| 386 |
+
selected_personas=selected_personas,
|
| 387 |
+
qa_filter_type=qa_filter_type,
|
| 388 |
+
qa_filter_item_type=qa_filter_item_type,
|
| 389 |
+
qa_filter_scope=qa_filter_scope,
|
| 390 |
+
)
|
| 391 |
+
if runs is None:
|
| 392 |
+
return None
|
| 393 |
+
|
| 394 |
+
max_questions = _render_max_questions(
|
| 395 |
+
model_name=model_name,
|
| 396 |
+
remote=remote,
|
| 397 |
+
dataset_source=dataset_source,
|
| 398 |
+
runs=runs,
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
return ExtractSettings(
|
| 402 |
+
runs=runs,
|
| 403 |
+
mask_strategy=mask_strategy,
|
| 404 |
+
max_questions=max_questions,
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def _render_extract_actions() -> tuple[bool, bool]:
|
| 409 |
run_col, preview_col, _spacer = st.columns([1, 1, 4], gap="small")
|
| 410 |
with run_col:
|
| 411 |
run_clicked = st.button(
|
| 412 |
+
"Run extraction",
|
| 413 |
+
type="primary",
|
| 414 |
+
use_container_width=True,
|
| 415 |
)
|
| 416 |
with preview_col:
|
| 417 |
preview_clicked = st.button("Preview tokens", use_container_width=True)
|
| 418 |
+
return run_clicked, preview_clicked
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def _render_token_preview(
|
| 422 |
+
*,
|
| 423 |
+
remote: bool,
|
| 424 |
+
model_name: str,
|
| 425 |
+
run_plan: list[tuple[PersonaData, list[QAPair], str]],
|
| 426 |
+
settings: ExtractSettings,
|
| 427 |
+
) -> None:
|
| 428 |
+
with st.spinner("Loading tokenizer..."):
|
| 429 |
+
model = cached_model(model_name=model_name)
|
| 430 |
+
st.markdown(_TOKEN_LEGEND, unsafe_allow_html=True)
|
| 431 |
+
for persona, qa_pairs, variant in run_plan:
|
| 432 |
+
system_prompt = format_prompt(persona, variant) # type: ignore[arg-type]
|
| 433 |
+
prepared = prepare_inputs_for_strategy(
|
| 434 |
+
tokenizer=model.tokenizer,
|
| 435 |
+
system_prompt=system_prompt,
|
| 436 |
+
qa_pairs=qa_pairs[: settings.max_questions],
|
| 437 |
+
mask_strategy=settings.mask_strategy,
|
| 438 |
+
)
|
| 439 |
+
st.caption(_row_label(persona, variant))
|
| 440 |
+
for i, p in enumerate(prepared[:_MAX_PREVIEW_SAMPLES]):
|
| 441 |
+
question = p.question if len(p.question) <= 60 else p.question[:57] + "..."
|
| 442 |
+
seq_len = int(p.input_ids.shape[0])
|
| 443 |
+
masked = int(p.token_mask.sum())
|
| 444 |
+
label = f"sample {i} — {question} (len={seq_len}, masked={masked})"
|
| 445 |
+
with st.expander(label):
|
| 446 |
+
st.markdown(
|
| 447 |
+
_render_sample_tokens_html(p, model.tokenizer),
|
| 448 |
+
unsafe_allow_html=True,
|
| 449 |
)
|
| 450 |
+
if len(prepared) > _MAX_PREVIEW_SAMPLES:
|
| 451 |
+
remaining = len(prepared) - _MAX_PREVIEW_SAMPLES
|
| 452 |
+
st.caption(f"… and {remaining} more sample(s) not shown.")
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def _run_extraction_plan(
|
| 456 |
+
*,
|
| 457 |
+
remote: bool,
|
| 458 |
+
model_name: str,
|
| 459 |
+
run_plan: list[tuple[PersonaData, list[QAPair], str]],
|
| 460 |
+
settings: ExtractSettings,
|
| 461 |
+
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
status_box = st.empty()
|
| 463 |
status_box.info("Extraction in progress...")
|
| 464 |
progress = st.progress(0, text="Preparing extraction...")
|
| 465 |
+
ndif_status_box = st.empty()
|
| 466 |
|
| 467 |
def _on_ndif_status(job_id: str, status_name: str, description: str) -> None:
|
| 468 |
icon = NDIF_STATUS_ICONS.get(status_name, "•")
|
| 469 |
ndif_status_box.caption(f"{icon} `{job_id}` **{status_name}** — {description}")
|
| 470 |
|
| 471 |
with st.spinner("Loading model..."):
|
| 472 |
+
model = cached_model(model_name=model_name)
|
| 473 |
|
| 474 |
try:
|
| 475 |
total_steps = len(run_plan)
|
|
|
|
| 483 |
run_extraction(
|
| 484 |
model=model,
|
| 485 |
model_name=model_name,
|
| 486 |
+
qa_pairs=qa_pairs[: settings.max_questions],
|
| 487 |
variants=(variant,),
|
| 488 |
persona=persona,
|
| 489 |
+
mask_strategy=settings.mask_strategy,
|
| 490 |
remote=remote,
|
| 491 |
on_status=_on_ndif_status if remote else None,
|
| 492 |
)
|
|
|
|
| 508 |
f"- **{result.persona_name}** · {prompt_variant_label(result.variant)}: "
|
| 509 |
f"{result.n_questions} questions"
|
| 510 |
)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
| 514 |
+
"""Render the extraction tab."""
|
| 515 |
+
|
| 516 |
+
st.title("Extract")
|
| 517 |
+
st.caption("Extract per-persona activation vectors from QA pairs.")
|
| 518 |
+
|
| 519 |
+
_render_local_dataset_upload(dataset_source)
|
| 520 |
+
variant_choice = _render_variant_controls(
|
| 521 |
+
model_name=model_name,
|
| 522 |
+
remote=remote,
|
| 523 |
+
dataset_source=dataset_source,
|
| 524 |
+
)
|
| 525 |
+
if variant_choice is None:
|
| 526 |
+
return
|
| 527 |
+
selected_variants, include_baseline = variant_choice
|
| 528 |
+
|
| 529 |
+
loaded = _load_qa_dataset_personas(dataset_source)
|
| 530 |
+
if loaded is None:
|
| 531 |
+
return
|
| 532 |
+
dataset, personas = loaded
|
| 533 |
+
|
| 534 |
+
selected_personas = _render_persona_select(
|
| 535 |
+
personas=personas,
|
| 536 |
+
model_name=model_name,
|
| 537 |
+
remote=remote,
|
| 538 |
+
dataset_source=dataset_source,
|
| 539 |
+
)
|
| 540 |
+
if selected_personas is None:
|
| 541 |
+
return
|
| 542 |
+
|
| 543 |
+
settings = _render_advanced_settings(
|
| 544 |
+
dataset=dataset,
|
| 545 |
+
selected_personas=selected_personas,
|
| 546 |
+
model_name=model_name,
|
| 547 |
+
remote=remote,
|
| 548 |
+
dataset_source=dataset_source,
|
| 549 |
+
)
|
| 550 |
+
if settings is None:
|
| 551 |
+
return
|
| 552 |
+
|
| 553 |
+
runs = list(settings.runs)
|
| 554 |
+
baseline = getattr(dataset, "baseline", None)
|
| 555 |
+
if include_baseline and baseline is not None and runs:
|
| 556 |
+
runs.append((baseline, runs[0][1]))
|
| 557 |
+
|
| 558 |
+
run_clicked, preview_clicked = _render_extract_actions()
|
| 559 |
+
run_plan = _build_run_plan(selected_variants, runs)
|
| 560 |
+
|
| 561 |
+
if preview_clicked:
|
| 562 |
+
_render_token_preview(
|
| 563 |
+
remote=remote,
|
| 564 |
+
model_name=model_name,
|
| 565 |
+
run_plan=run_plan,
|
| 566 |
+
settings=settings,
|
| 567 |
+
)
|
| 568 |
+
return
|
| 569 |
+
|
| 570 |
+
if not run_clicked:
|
| 571 |
+
return
|
| 572 |
+
|
| 573 |
+
_run_extraction_plan(
|
| 574 |
+
remote=remote,
|
| 575 |
+
model_name=model_name,
|
| 576 |
+
run_plan=run_plan,
|
| 577 |
+
settings=settings,
|
| 578 |
+
)
|
tabs/probe_ui.py
CHANGED
|
@@ -16,13 +16,6 @@ from utils.probes import (
|
|
| 16 |
from utils.runtime import cached_model
|
| 17 |
|
| 18 |
|
| 19 |
-
def _token_button_label(index: int, token: str) -> str:
|
| 20 |
-
display = token.encode("unicode_escape").decode("ascii") or "<empty>"
|
| 21 |
-
if len(display) > 18:
|
| 22 |
-
display = display[:15] + "..."
|
| 23 |
-
return f"{index}: {display}"
|
| 24 |
-
|
| 25 |
-
|
| 26 |
def _render_probe_results(result: ProbeRunResult, probe: LoadedProbe) -> None:
|
| 27 |
top_k = min(5, int(result.probabilities.numel()))
|
| 28 |
if top_k == 0:
|
|
@@ -95,55 +88,27 @@ def _load_probe_from_controls(context_key: str) -> LoadedProbe | None:
|
|
| 95 |
return load_probe(repo_id.strip(), selected_file)
|
| 96 |
|
| 97 |
|
| 98 |
-
def
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
"probe_selected_token",
|
| 102 |
-
trace.prompt_hash[:12],
|
| 103 |
-
)
|
| 104 |
-
selected = int(st.session_state.get(selected_key, trace.n_tokens - 1))
|
| 105 |
-
selected = max(0, min(selected, trace.n_tokens - 1))
|
| 106 |
-
|
| 107 |
-
window_size = st.slider(
|
| 108 |
-
"Token window",
|
| 109 |
-
min_value=8,
|
| 110 |
-
max_value=min(96, max(8, trace.n_tokens)),
|
| 111 |
-
value=min(32, max(8, trace.n_tokens)),
|
| 112 |
-
step=8,
|
| 113 |
-
key=widget_key(context_key, "probe_token_window", trace.prompt_hash[:12]),
|
| 114 |
-
)
|
| 115 |
-
center = st.slider(
|
| 116 |
-
"Window center",
|
| 117 |
min_value=0,
|
| 118 |
max_value=trace.n_tokens - 1,
|
| 119 |
-
value=
|
| 120 |
-
key=widget_key(context_key, "
|
| 121 |
)
|
| 122 |
-
start = max(0, center - window_size // 2)
|
| 123 |
-
end = min(trace.n_tokens, start + window_size)
|
| 124 |
-
start = max(0, end - window_size)
|
| 125 |
-
|
| 126 |
-
cols = st.columns(8)
|
| 127 |
-
for offset, token_index in enumerate(range(start, end)):
|
| 128 |
-
col = cols[offset % len(cols)]
|
| 129 |
-
token = trace.tokens[token_index]
|
| 130 |
-
if col.button(
|
| 131 |
-
_token_button_label(token_index, token),
|
| 132 |
-
key=widget_key(
|
| 133 |
-
context_key,
|
| 134 |
-
"probe_token",
|
| 135 |
-
trace.prompt_hash[:12],
|
| 136 |
-
str(token_index),
|
| 137 |
-
),
|
| 138 |
-
type="primary" if token_index == selected else "secondary",
|
| 139 |
-
help=token.encode("unicode_escape").decode("ascii"),
|
| 140 |
-
):
|
| 141 |
-
selected = token_index
|
| 142 |
-
st.session_state[selected_key] = token_index
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
)
|
| 148 |
return selected
|
| 149 |
|
|
@@ -163,6 +128,128 @@ def _model_dimensions(model: object) -> tuple[int, int]:
|
|
| 163 |
return int(hidden_size), int(num_layers)
|
| 164 |
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
def render_probe_inspector(
|
| 167 |
*,
|
| 168 |
context_key: str,
|
|
@@ -188,88 +275,40 @@ def render_probe_inspector(
|
|
| 188 |
if probe is None:
|
| 189 |
return
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
try:
|
| 194 |
-
hidden_size, num_layers = _model_dimensions(model)
|
| 195 |
-
except Exception as exc:
|
| 196 |
-
st.error(str(exc))
|
| 197 |
return
|
|
|
|
| 198 |
|
| 199 |
-
layer =
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
min_value=0,
|
| 205 |
-
max_value=max(0, num_layers - 1),
|
| 206 |
-
value=min(15, max(0, num_layers - 1)),
|
| 207 |
-
step=1,
|
| 208 |
-
key=widget_key(context_key, "probe_layer"),
|
| 209 |
-
)
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
location = probe.location
|
| 213 |
-
if location is None:
|
| 214 |
-
location = st.selectbox(
|
| 215 |
-
"Activation location",
|
| 216 |
-
options=("post_reasoning", "pre_reasoning"),
|
| 217 |
-
key=widget_key(context_key, "probe_location"),
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
st.caption(
|
| 221 |
f"Probe layer {layer}; {location}; input dim {probe.input_dim}; "
|
| 222 |
f"model hidden size {hidden_size}"
|
| 223 |
)
|
| 224 |
-
if not
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
"This probe input dim does not match a single-token activation "
|
| 230 |
-
"for the active model."
|
| 231 |
-
)
|
| 232 |
-
return
|
| 233 |
-
|
| 234 |
-
trace_key = widget_key(context_key, "probe_trace_enabled")
|
| 235 |
-
if st.button(
|
| 236 |
-
"Trace conversation",
|
| 237 |
-
key=widget_key(context_key, "probe_trace"),
|
| 238 |
-
use_container_width=True,
|
| 239 |
):
|
| 240 |
-
st.session_state[trace_key] = True
|
| 241 |
-
if not st.session_state.get(trace_key, False):
|
| 242 |
-
return
|
| 243 |
-
|
| 244 |
-
messages = build_chat_messages(active_system_prompt, chat_state["messages"])
|
| 245 |
-
with st.spinner("Tracing conversation..."):
|
| 246 |
-
trace = trace_conversation(
|
| 247 |
-
model=model,
|
| 248 |
-
model_name=model_name,
|
| 249 |
-
messages=messages,
|
| 250 |
-
layer=layer,
|
| 251 |
-
location=location,
|
| 252 |
-
remote=remote,
|
| 253 |
-
)
|
| 254 |
-
|
| 255 |
-
st.caption(
|
| 256 |
-
f"Cached {trace.n_tokens} tokens from layer {trace.layer}; "
|
| 257 |
-
f"prompt hash `{trace.prompt_hash[:10]}`"
|
| 258 |
-
)
|
| 259 |
-
if trace.n_tokens == 0:
|
| 260 |
-
st.warning("The traced conversation produced no tokens.")
|
| 261 |
return
|
| 262 |
|
| 263 |
-
|
| 264 |
-
try:
|
| 265 |
-
vector = vectorize_token(trace, token_index=selected_token)
|
| 266 |
-
result = probe.run(vector.vector)
|
| 267 |
-
except Exception as exc:
|
| 268 |
-
st.error(f"Probe execution failed: {exc}")
|
| 269 |
return
|
| 270 |
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
)
|
| 275 |
-
|
|
|
|
|
|
|
|
|
| 16 |
from utils.runtime import cached_model
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def _render_probe_results(result: ProbeRunResult, probe: LoadedProbe) -> None:
|
| 20 |
top_k = min(5, int(result.probabilities.numel()))
|
| 21 |
if top_k == 0:
|
|
|
|
| 88 |
return load_probe(repo_id.strip(), selected_file)
|
| 89 |
|
| 90 |
|
| 91 |
+
def _render_token_picker(trace: ConversationTrace, context_key: str) -> int:
|
| 92 |
+
selected = st.slider(
|
| 93 |
+
"Token index",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
min_value=0,
|
| 95 |
max_value=trace.n_tokens - 1,
|
| 96 |
+
value=trace.n_tokens - 1,
|
| 97 |
+
key=widget_key(context_key, "probe_selected_token", trace.prompt_hash[:12]),
|
| 98 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
+
window = 8
|
| 101 |
+
start = max(0, selected - window)
|
| 102 |
+
end = min(trace.n_tokens, selected + window + 1)
|
| 103 |
+
parts: list[str] = []
|
| 104 |
+
for i in range(start, end):
|
| 105 |
+
token_repr = trace.tokens[i].encode("unicode_escape").decode("ascii") or "·"
|
| 106 |
+
parts.append(f"**[{token_repr}]**" if i == selected else token_repr)
|
| 107 |
+
st.markdown(
|
| 108 |
+
f"<div style='font-family:ui-monospace,monospace;font-size:0.85em;"
|
| 109 |
+
f"line-height:1.6;background:rgba(127,127,127,0.08);padding:6px 10px;"
|
| 110 |
+
f"border-radius:4px;'>{' '.join(parts)}</div>",
|
| 111 |
+
unsafe_allow_html=True,
|
| 112 |
)
|
| 113 |
return selected
|
| 114 |
|
|
|
|
| 128 |
return int(hidden_size), int(num_layers)
|
| 129 |
|
| 130 |
|
| 131 |
+
def _load_model_with_dimensions(model_name: str) -> tuple[object, int, int] | None:
|
| 132 |
+
with st.spinner("Loading model metadata..."):
|
| 133 |
+
model = cached_model(model_name=model_name)
|
| 134 |
+
try:
|
| 135 |
+
hidden_size, num_layers = _model_dimensions(model)
|
| 136 |
+
except Exception as exc:
|
| 137 |
+
st.error(str(exc))
|
| 138 |
+
return None
|
| 139 |
+
return model, hidden_size, num_layers
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _select_probe_target(
|
| 143 |
+
*,
|
| 144 |
+
probe: LoadedProbe,
|
| 145 |
+
context_key: str,
|
| 146 |
+
num_layers: int,
|
| 147 |
+
) -> tuple[int, str]:
|
| 148 |
+
layer = probe.layer
|
| 149 |
+
if layer is None:
|
| 150 |
+
layer = int(
|
| 151 |
+
st.number_input(
|
| 152 |
+
"Layer",
|
| 153 |
+
min_value=0,
|
| 154 |
+
max_value=max(0, num_layers - 1),
|
| 155 |
+
value=min(15, max(0, num_layers - 1)),
|
| 156 |
+
step=1,
|
| 157 |
+
key=widget_key(context_key, "probe_layer"),
|
| 158 |
+
)
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
location = probe.location
|
| 162 |
+
if location is None:
|
| 163 |
+
location = st.selectbox(
|
| 164 |
+
"Activation location",
|
| 165 |
+
options=("post_reasoning", "pre_reasoning"),
|
| 166 |
+
key=widget_key(context_key, "probe_location"),
|
| 167 |
+
)
|
| 168 |
+
return layer, location
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _probe_target_is_valid(
|
| 172 |
+
*,
|
| 173 |
+
probe: LoadedProbe,
|
| 174 |
+
layer: int,
|
| 175 |
+
num_layers: int,
|
| 176 |
+
hidden_size: int,
|
| 177 |
+
) -> bool:
|
| 178 |
+
if not 0 <= layer < num_layers:
|
| 179 |
+
st.error(f"Probe layer {layer} is outside the model's {num_layers} layers.")
|
| 180 |
+
return False
|
| 181 |
+
if probe.input_dim != hidden_size:
|
| 182 |
+
st.warning(
|
| 183 |
+
"This probe input dim does not match a single-token activation "
|
| 184 |
+
"for the active model."
|
| 185 |
+
)
|
| 186 |
+
return False
|
| 187 |
+
return True
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def _trace_requested(context_key: str) -> bool:
|
| 191 |
+
trace_key = widget_key(context_key, "probe_trace_enabled")
|
| 192 |
+
if st.button(
|
| 193 |
+
"Trace conversation",
|
| 194 |
+
key=widget_key(context_key, "probe_trace"),
|
| 195 |
+
use_container_width=True,
|
| 196 |
+
):
|
| 197 |
+
st.session_state[trace_key] = True
|
| 198 |
+
return bool(st.session_state.get(trace_key, False))
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _trace_active_conversation(
|
| 202 |
+
*,
|
| 203 |
+
model: object,
|
| 204 |
+
model_name: str,
|
| 205 |
+
remote: bool,
|
| 206 |
+
active_system_prompt: str | None,
|
| 207 |
+
chat_state: dict[str, object],
|
| 208 |
+
layer: int,
|
| 209 |
+
location: str,
|
| 210 |
+
) -> ConversationTrace | None:
|
| 211 |
+
messages = build_chat_messages(active_system_prompt, chat_state["messages"])
|
| 212 |
+
with st.spinner("Tracing conversation..."):
|
| 213 |
+
trace = trace_conversation(
|
| 214 |
+
model=model,
|
| 215 |
+
model_name=model_name,
|
| 216 |
+
messages=messages,
|
| 217 |
+
layer=layer,
|
| 218 |
+
location=location,
|
| 219 |
+
remote=remote,
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
st.caption(
|
| 223 |
+
f"Cached {trace.n_tokens} tokens from layer {trace.layer}; "
|
| 224 |
+
f"prompt hash `{trace.prompt_hash[:10]}`"
|
| 225 |
+
)
|
| 226 |
+
if trace.n_tokens == 0:
|
| 227 |
+
st.warning("The traced conversation produced no tokens.")
|
| 228 |
+
return None
|
| 229 |
+
return trace
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _run_probe_on_selected_token(
|
| 233 |
+
*,
|
| 234 |
+
trace: ConversationTrace,
|
| 235 |
+
context_key: str,
|
| 236 |
+
probe: LoadedProbe,
|
| 237 |
+
) -> None:
|
| 238 |
+
selected_token = _render_token_picker(trace, context_key)
|
| 239 |
+
try:
|
| 240 |
+
vector = vectorize_token(trace, token_index=selected_token)
|
| 241 |
+
result = probe.run(vector.vector)
|
| 242 |
+
except Exception as exc:
|
| 243 |
+
st.error(f"Probe execution failed: {exc}")
|
| 244 |
+
return
|
| 245 |
+
|
| 246 |
+
st.caption(
|
| 247 |
+
f"Vectorization {vector.mode}; token {vector.token_index}; "
|
| 248 |
+
f"vector dim {int(vector.vector.shape[0])}"
|
| 249 |
+
)
|
| 250 |
+
_render_probe_results(result, probe)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
def render_probe_inspector(
|
| 254 |
*,
|
| 255 |
context_key: str,
|
|
|
|
| 275 |
if probe is None:
|
| 276 |
return
|
| 277 |
|
| 278 |
+
loaded = _load_model_with_dimensions(model_name)
|
| 279 |
+
if loaded is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
return
|
| 281 |
+
model, hidden_size, num_layers = loaded
|
| 282 |
|
| 283 |
+
layer, location = _select_probe_target(
|
| 284 |
+
probe=probe,
|
| 285 |
+
context_key=context_key,
|
| 286 |
+
num_layers=num_layers,
|
| 287 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
st.caption(
|
| 289 |
f"Probe layer {layer}; {location}; input dim {probe.input_dim}; "
|
| 290 |
f"model hidden size {hidden_size}"
|
| 291 |
)
|
| 292 |
+
if not _probe_target_is_valid(
|
| 293 |
+
probe=probe,
|
| 294 |
+
layer=layer,
|
| 295 |
+
num_layers=num_layers,
|
| 296 |
+
hidden_size=hidden_size,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
return
|
| 299 |
|
| 300 |
+
if not _trace_requested(context_key):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
return
|
| 302 |
|
| 303 |
+
trace = _trace_active_conversation(
|
| 304 |
+
model=model,
|
| 305 |
+
model_name=model_name,
|
| 306 |
+
remote=remote,
|
| 307 |
+
active_system_prompt=active_system_prompt,
|
| 308 |
+
chat_state=chat_state,
|
| 309 |
+
layer=layer,
|
| 310 |
+
location=location,
|
| 311 |
)
|
| 312 |
+
if trace is None:
|
| 313 |
+
return
|
| 314 |
+
_run_probe_on_selected_token(trace=trace, context_key=context_key, probe=probe)
|
utils/contrast.py
CHANGED
|
@@ -11,7 +11,6 @@ Negative (blue) → token is more characteristic of persona B.
|
|
| 11 |
Near-zero (gray) → both personas would emit this token with similar likelihood.
|
| 12 |
"""
|
| 13 |
|
| 14 |
-
import logging
|
| 15 |
from dataclasses import dataclass
|
| 16 |
from html import escape
|
| 17 |
|
|
@@ -20,8 +19,6 @@ from nnterp import StandardizedTransformer
|
|
| 20 |
|
| 21 |
from utils.chat import format_generation_prompt
|
| 22 |
|
| 23 |
-
logger = logging.getLogger(__name__)
|
| 24 |
-
|
| 25 |
|
| 26 |
@dataclass
|
| 27 |
class TokenContrast:
|
|
@@ -73,28 +70,18 @@ def _strip_special_ids(
|
|
| 73 |
return ids[keep], keep
|
| 74 |
|
| 75 |
|
| 76 |
-
def
|
| 77 |
tokenizer: object,
|
| 78 |
context_messages: list[dict[str, str]],
|
| 79 |
response_ids: torch.Tensor,
|
| 80 |
-
) -> tuple[
|
| 81 |
-
"""Build
|
| 82 |
context_prompt, _ = format_generation_prompt(context_messages, tokenizer)
|
| 83 |
context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
|
| 84 |
-
|
| 85 |
-
full_text = context_prompt + response_text
|
| 86 |
-
full_ids = tokenizer(full_text, return_tensors="pt").input_ids[0]
|
| 87 |
-
expected_ids = torch.cat([context_ids, response_ids.cpu()])
|
| 88 |
-
if full_ids.tolist() != expected_ids.tolist():
|
| 89 |
-
logger.warning(
|
| 90 |
-
"contrast trace text did not round-trip to the expected token ids "
|
| 91 |
-
"(expected %d tokens, got %d); contrast scores may be slightly misaligned",
|
| 92 |
-
len(expected_ids),
|
| 93 |
-
len(full_ids),
|
| 94 |
-
)
|
| 95 |
n_ctx = len(context_ids)
|
| 96 |
n_resp = len(response_ids)
|
| 97 |
-
return
|
| 98 |
|
| 99 |
|
| 100 |
def _build_contrast(
|
|
@@ -122,8 +109,8 @@ def _token_display(tokenizer: object, token_id: int) -> str:
|
|
| 122 |
return _decode_ids(tokenizer, [token_id])
|
| 123 |
|
| 124 |
|
| 125 |
-
# Each spec: (key,
|
| 126 |
-
PassSpec = tuple[str,
|
| 127 |
|
| 128 |
|
| 129 |
def _score_passes(
|
|
@@ -140,12 +127,12 @@ def _score_passes(
|
|
| 140 |
"""
|
| 141 |
|
| 142 |
def _score_pass(
|
| 143 |
-
|
| 144 |
n_ctx: int,
|
| 145 |
n_resp: int,
|
| 146 |
target_ids: torch.Tensor,
|
| 147 |
) -> torch.Tensor:
|
| 148 |
-
with torch.no_grad(), model.trace(
|
| 149 |
# logit at position i predicts token i+1, so response token j
|
| 150 |
# (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
|
| 151 |
resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
|
|
@@ -163,8 +150,8 @@ def _score_passes(
|
|
| 163 |
return out.detach().cpu()
|
| 164 |
|
| 165 |
return {
|
| 166 |
-
key: _score_pass(
|
| 167 |
-
for key,
|
| 168 |
}
|
| 169 |
|
| 170 |
|
|
@@ -176,11 +163,13 @@ def _specs_for_response(
|
|
| 176 |
prefix: str,
|
| 177 |
) -> list[PassSpec]:
|
| 178 |
"""Build the (under_a, under_b) pass specs for a single response."""
|
| 179 |
-
|
| 180 |
-
|
|
|
|
|
|
|
| 181 |
return [
|
| 182 |
-
(f"{prefix}_under_a",
|
| 183 |
-
(f"{prefix}_under_b",
|
| 184 |
]
|
| 185 |
|
| 186 |
|
|
|
|
| 11 |
Near-zero (gray) → both personas would emit this token with similar likelihood.
|
| 12 |
"""
|
| 13 |
|
|
|
|
| 14 |
from dataclasses import dataclass
|
| 15 |
from html import escape
|
| 16 |
|
|
|
|
| 19 |
|
| 20 |
from utils.chat import format_generation_prompt
|
| 21 |
|
|
|
|
|
|
|
| 22 |
|
| 23 |
@dataclass
|
| 24 |
class TokenContrast:
|
|
|
|
| 70 |
return ids[keep], keep
|
| 71 |
|
| 72 |
|
| 73 |
+
def _prepare_trace_input_ids(
|
| 74 |
tokenizer: object,
|
| 75 |
context_messages: list[dict[str, str]],
|
| 76 |
response_ids: torch.Tensor,
|
| 77 |
+
) -> tuple[torch.Tensor, int, int]:
|
| 78 |
+
"""Build exact trace input ids and return ``(input_ids, n_ctx, n_resp)``."""
|
| 79 |
context_prompt, _ = format_generation_prompt(context_messages, tokenizer)
|
| 80 |
context_ids = tokenizer(context_prompt, return_tensors="pt").input_ids[0]
|
| 81 |
+
input_ids = torch.cat([context_ids.cpu(), response_ids.detach().cpu()])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
n_ctx = len(context_ids)
|
| 83 |
n_resp = len(response_ids)
|
| 84 |
+
return input_ids, n_ctx, n_resp
|
| 85 |
|
| 86 |
|
| 87 |
def _build_contrast(
|
|
|
|
| 109 |
return _decode_ids(tokenizer, [token_id])
|
| 110 |
|
| 111 |
|
| 112 |
+
# Each spec: (key, input_ids, n_ctx, n_resp, target_ids).
|
| 113 |
+
PassSpec = tuple[str, torch.Tensor, int, int, torch.Tensor]
|
| 114 |
|
| 115 |
|
| 116 |
def _score_passes(
|
|
|
|
| 127 |
"""
|
| 128 |
|
| 129 |
def _score_pass(
|
| 130 |
+
input_ids: torch.Tensor,
|
| 131 |
n_ctx: int,
|
| 132 |
n_resp: int,
|
| 133 |
target_ids: torch.Tensor,
|
| 134 |
) -> torch.Tensor:
|
| 135 |
+
with torch.no_grad(), model.trace(input_ids, remote=remote):
|
| 136 |
# logit at position i predicts token i+1, so response token j
|
| 137 |
# (at full-text position n_ctx+j) uses logit at n_ctx+j-1.
|
| 138 |
resp_logits = model.logits[0, n_ctx - 1 : n_ctx - 1 + n_resp].float()
|
|
|
|
| 150 |
return out.detach().cpu()
|
| 151 |
|
| 152 |
return {
|
| 153 |
+
key: _score_pass(input_ids, n_ctx, n_resp, target_ids)
|
| 154 |
+
for key, input_ids, n_ctx, n_resp, target_ids in specs
|
| 155 |
}
|
| 156 |
|
| 157 |
|
|
|
|
| 163 |
prefix: str,
|
| 164 |
) -> list[PassSpec]:
|
| 165 |
"""Build the (under_a, under_b) pass specs for a single response."""
|
| 166 |
+
input_a, n_ctx_a, n_resp = _prepare_trace_input_ids(
|
| 167 |
+
tokenizer, context_a, response_ids
|
| 168 |
+
)
|
| 169 |
+
input_b, n_ctx_b, _ = _prepare_trace_input_ids(tokenizer, context_b, response_ids)
|
| 170 |
return [
|
| 171 |
+
(f"{prefix}_under_a", input_a, n_ctx_a, n_resp, response_ids),
|
| 172 |
+
(f"{prefix}_under_b", input_b, n_ctx_b, n_resp, response_ids),
|
| 173 |
]
|
| 174 |
|
| 175 |
|
utils/runtime.py
CHANGED
|
@@ -1,8 +1,56 @@
|
|
|
|
|
| 1 |
import logging
|
|
|
|
| 2 |
|
| 3 |
import streamlit as st
|
| 4 |
|
| 5 |
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
@st.cache_data(show_spinner=False, ttl=30)
|
|
@@ -16,8 +64,6 @@ def list_remote_models() -> list[str]:
|
|
| 16 |
the whole response. See nnsight 0.6.3 ``ndif.py::status``.
|
| 17 |
"""
|
| 18 |
|
| 19 |
-
import json
|
| 20 |
-
|
| 21 |
import nnsight
|
| 22 |
|
| 23 |
try:
|
|
@@ -29,32 +75,11 @@ def list_remote_models() -> list[str]:
|
|
| 29 |
model_names: list[str] = []
|
| 30 |
bad_states: list[tuple[str, str]] = [] # (repo_id_or_key, application_state)
|
| 31 |
|
| 32 |
-
for
|
| 33 |
-
if
|
| 34 |
-
|
| 35 |
-
if (
|
| 36 |
-
|
| 37 |
-
and "schedule" not in value
|
| 38 |
-
):
|
| 39 |
-
continue
|
| 40 |
-
|
| 41 |
-
model_key = value.get("model_key", "")
|
| 42 |
-
model_class = model_key.split(":", 1)[0].split(".")[-1]
|
| 43 |
-
try:
|
| 44 |
-
repo_id = json.loads(model_key.split(":", 1)[-1]).get("repo_id")
|
| 45 |
-
except Exception:
|
| 46 |
-
repo_id = model_key
|
| 47 |
-
|
| 48 |
-
state = value.get("application_state", "NOT DEPLOYED")
|
| 49 |
-
if state not in {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}:
|
| 50 |
-
bad_states.append((repo_id or model_key, state))
|
| 51 |
-
|
| 52 |
-
if model_class not in {"LanguageModel", "StandardizedTransformer"}:
|
| 53 |
-
continue
|
| 54 |
-
if state != "RUNNING":
|
| 55 |
-
continue
|
| 56 |
-
if isinstance(repo_id, str):
|
| 57 |
-
model_names.append(repo_id)
|
| 58 |
|
| 59 |
if bad_states:
|
| 60 |
logger.warning(
|
|
@@ -67,27 +92,17 @@ def list_remote_models() -> list[str]:
|
|
| 67 |
|
| 68 |
|
| 69 |
@st.cache_resource(show_spinner=False, max_entries=1)
|
| 70 |
-
def
|
| 71 |
"""Load and cache a standardized nnterp model.
|
| 72 |
|
| 73 |
Streamlit reruns this app on every interaction, so caching keeps one loaded
|
| 74 |
model instance per model name instead of reloading weights on every widget
|
| 75 |
-
change.
|
|
|
|
|
|
|
|
|
|
| 76 |
"""
|
| 77 |
|
| 78 |
from nnterp import StandardizedTransformer
|
| 79 |
|
| 80 |
-
# The remote constructor path is currently unstable for this model wrapper.
|
| 81 |
-
# return StandardizedTransformer(model_name, remote=remote, check_renaming=False)
|
| 82 |
return StandardizedTransformer(model_name)
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def cached_model(model_name: str, remote: bool):
|
| 86 |
-
"""Return the cached model for ``model_name``.
|
| 87 |
-
|
| 88 |
-
``remote`` still matters at generation/trace time, but the current
|
| 89 |
-
``StandardizedTransformer`` constructor ignores it. Keeping it out of the
|
| 90 |
-
cache key avoids loading duplicate local model objects when toggling NDIF.
|
| 91 |
-
"""
|
| 92 |
-
|
| 93 |
-
return _cached_model_by_name(model_name)
|
|
|
|
| 1 |
+
import json
|
| 2 |
import logging
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
|
| 5 |
import streamlit as st
|
| 6 |
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
+
_LANGUAGE_MODEL_CLASSES = {"LanguageModel", "StandardizedTransformer"}
|
| 9 |
+
_EXPECTED_NDIF_STATES = {"RUNNING", "NOT DEPLOYED", "DEPLOYING", "DELETING"}
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _iter_deployments(raw: object) -> Iterable[dict]:
|
| 13 |
+
if not isinstance(raw, dict):
|
| 14 |
+
return ()
|
| 15 |
+
deployments = raw.get("deployments", {})
|
| 16 |
+
if not isinstance(deployments, dict):
|
| 17 |
+
return ()
|
| 18 |
+
return (value for value in deployments.values() if isinstance(value, dict))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _is_visible_deployment(deployment: dict) -> bool:
|
| 22 |
+
return deployment.get("deployment_level") in {"HOT", "WARM"} or (
|
| 23 |
+
"schedule" in deployment
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _repo_id_from_model_key(model_key: str) -> str:
|
| 28 |
+
try:
|
| 29 |
+
repo_id = json.loads(model_key.split(":", 1)[-1]).get("repo_id")
|
| 30 |
+
except Exception:
|
| 31 |
+
return model_key
|
| 32 |
+
return repo_id if isinstance(repo_id, str) else model_key
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _running_language_model(deployment: dict) -> str | None:
|
| 36 |
+
if not _is_visible_deployment(deployment):
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
model_key = deployment.get("model_key", "")
|
| 40 |
+
model_class = model_key.split(":", 1)[0].split(".")[-1]
|
| 41 |
+
if model_class not in _LANGUAGE_MODEL_CLASSES:
|
| 42 |
+
return None
|
| 43 |
+
if deployment.get("application_state", "NOT DEPLOYED") != "RUNNING":
|
| 44 |
+
return None
|
| 45 |
+
return _repo_id_from_model_key(model_key)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _unexpected_state(deployment: dict) -> tuple[str, str] | None:
|
| 49 |
+
state = deployment.get("application_state", "NOT DEPLOYED")
|
| 50 |
+
if state in _EXPECTED_NDIF_STATES:
|
| 51 |
+
return None
|
| 52 |
+
model_key = deployment.get("model_key", "")
|
| 53 |
+
return _repo_id_from_model_key(model_key), state
|
| 54 |
|
| 55 |
|
| 56 |
@st.cache_data(show_spinner=False, ttl=30)
|
|
|
|
| 64 |
the whole response. See nnsight 0.6.3 ``ndif.py::status``.
|
| 65 |
"""
|
| 66 |
|
|
|
|
|
|
|
| 67 |
import nnsight
|
| 68 |
|
| 69 |
try:
|
|
|
|
| 75 |
model_names: list[str] = []
|
| 76 |
bad_states: list[tuple[str, str]] = [] # (repo_id_or_key, application_state)
|
| 77 |
|
| 78 |
+
for deployment in _iter_deployments(raw):
|
| 79 |
+
if bad_state := _unexpected_state(deployment):
|
| 80 |
+
bad_states.append(bad_state)
|
| 81 |
+
if model_name := _running_language_model(deployment):
|
| 82 |
+
model_names.append(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
if bad_states:
|
| 85 |
logger.warning(
|
|
|
|
| 92 |
|
| 93 |
|
| 94 |
@st.cache_resource(show_spinner=False, max_entries=1)
|
| 95 |
+
def cached_model(model_name: str):
|
| 96 |
"""Load and cache a standardized nnterp model.
|
| 97 |
|
| 98 |
Streamlit reruns this app on every interaction, so caching keeps one loaded
|
| 99 |
model instance per model name instead of reloading weights on every widget
|
| 100 |
+
change. ``remote`` is intentionally not part of the cache key: it matters
|
| 101 |
+
at generation/trace time, but the current ``StandardizedTransformer``
|
| 102 |
+
constructor ignores it, and excluding it avoids loading duplicate local
|
| 103 |
+
model objects when toggling NDIF.
|
| 104 |
"""
|
| 105 |
|
| 106 |
from nnterp import StandardizedTransformer
|
| 107 |
|
|
|
|
|
|
|
| 108 |
return StandardizedTransformer(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|