Spaces:
Sleeping
Sleeping
Jac-Zac committed on
Commit ·
6b81e1e
1
Parent(s): 88f2164
Testing newer version implementation
Browse files- pyproject.toml +2 -2
- state.py +1 -27
- tabs/chat.py +6 -23
- tabs/compare.py +8 -9
- tabs/compare_chat.py +105 -198
- tabs/extract.py +87 -56
- uv.lock +55 -89
pyproject.toml
CHANGED
|
@@ -6,7 +6,7 @@ readme = "README.md"
|
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
"persona-vectors>=0.4.3",
|
| 9 |
-
"persona-data>=0.
|
| 10 |
"streamlit>=1.44.0",
|
| 11 |
"plotly>=6.6.0",
|
| 12 |
"python-dotenv>=1.2.2",
|
|
@@ -15,7 +15,7 @@ dependencies = [
|
|
| 15 |
|
| 16 |
# Local development:
|
| 17 |
[tool.uv.sources]
|
| 18 |
-
|
| 19 |
# persona-data = { path = "../persona-data", editable = true }
|
| 20 |
|
| 21 |
# [build-system]
|
|
|
|
| 6 |
requires-python = ">=3.12"
|
| 7 |
dependencies = [
|
| 8 |
"persona-vectors>=0.4.3",
|
| 9 |
+
"persona-data>=0.3.4",
|
| 10 |
"streamlit>=1.44.0",
|
| 11 |
"plotly>=6.6.0",
|
| 12 |
"python-dotenv>=1.2.2",
|
|
|
|
| 15 |
|
| 16 |
# Local development:
|
| 17 |
[tool.uv.sources]
|
| 18 |
+
persona-vectors = { path = "../persona-vectors", editable = true }
|
| 19 |
# persona-data = { path = "../persona-data", editable = true }
|
| 20 |
|
| 21 |
# [build-system]
|
state.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
_CHAT_STATE_PREFIX = "chat_state::"
|
| 4 |
-
_CHAT_KEYS_REGISTRY = "chat_state::_registered_keys"
|
| 5 |
|
| 6 |
|
| 7 |
def chat_session_key(model_name: str, dataset_source: str) -> str:
|
|
@@ -35,38 +34,13 @@ def reset_chat_context_state(
|
|
| 35 |
st.session_state.pop(key, None)
|
| 36 |
|
| 37 |
|
| 38 |
-
def _evict_inactive_kv_caches(active_key: str) -> None:
|
| 39 |
-
"""Drop past_key_values from every chat context except the active one."""
|
| 40 |
-
|
| 41 |
-
for key in st.session_state.get(_CHAT_KEYS_REGISTRY, ()):
|
| 42 |
-
if key != active_key:
|
| 43 |
-
state = st.session_state.get(key)
|
| 44 |
-
if isinstance(state, dict) and state.get("past_key_values") is not None:
|
| 45 |
-
state["past_key_values"] = None
|
| 46 |
-
|
| 47 |
-
|
| 48 |
def get_chat_state(
|
| 49 |
model_name: str, remote: bool, dataset_source: str
|
| 50 |
) -> dict[str, object]:
|
| 51 |
"""Return the mutable chat state for the active context."""
|
| 52 |
|
| 53 |
key = chat_session_key(model_name, dataset_source)
|
| 54 |
-
|
| 55 |
-
if registry is None:
|
| 56 |
-
registry = set()
|
| 57 |
-
st.session_state[_CHAT_KEYS_REGISTRY] = registry
|
| 58 |
-
registry.add(key)
|
| 59 |
-
|
| 60 |
-
state = st.session_state.get(key)
|
| 61 |
-
if state is None:
|
| 62 |
-
state = default_chat_state()
|
| 63 |
-
st.session_state[key] = state
|
| 64 |
-
else:
|
| 65 |
-
state.setdefault("messages", [])
|
| 66 |
-
state.setdefault("persona_id", None)
|
| 67 |
-
state.setdefault("prompt_mode", "templated")
|
| 68 |
-
state.setdefault("past_key_values", None)
|
| 69 |
-
_evict_inactive_kv_caches(key)
|
| 70 |
if remote and state.get("past_key_values") is not None:
|
| 71 |
state["past_key_values"] = None
|
| 72 |
return state
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
_CHAT_STATE_PREFIX = "chat_state::"
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
def chat_session_key(model_name: str, dataset_source: str) -> str:
|
|
|
|
| 34 |
st.session_state.pop(key, None)
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def get_chat_state(
|
| 38 |
model_name: str, remote: bool, dataset_source: str
|
| 39 |
) -> dict[str, object]:
|
| 40 |
"""Return the mutable chat state for the active context."""
|
| 41 |
|
| 42 |
key = chat_session_key(model_name, dataset_source)
|
| 43 |
+
state = st.session_state.setdefault(key, default_chat_state())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
if remote and state.get("past_key_values") is not None:
|
| 45 |
state["past_key_values"] = None
|
| 46 |
return state
|
tabs/chat.py
CHANGED
|
@@ -169,10 +169,9 @@ def render_system_prompt(
|
|
| 169 |
return st.session_state.get(prompt_key) or None
|
| 170 |
|
| 171 |
|
| 172 |
-
def generation_dict(gen_kwargs: dict
|
| 173 |
return {
|
| 174 |
"max_new_tokens": int(gen_kwargs["max_new_tokens"]),
|
| 175 |
-
"advanced_generation": bool(advanced_generation),
|
| 176 |
"use_sampling": bool(gen_kwargs["do_sample"]),
|
| 177 |
"temperature": float(gen_kwargs["temperature"]),
|
| 178 |
"top_p": float(gen_kwargs["top_p"]),
|
|
@@ -258,12 +257,8 @@ def build_chat_messages(
|
|
| 258 |
# ── Main tab entry point ───────────────────────────────────────────────────────
|
| 259 |
|
| 260 |
|
| 261 |
-
def _render_generation_settings(context_key: str, remote: bool) ->
|
| 262 |
-
"""Render the Advanced generation settings expander.
|
| 263 |
-
|
| 264 |
-
Returns ``(gen_kwargs, advanced_generation)`` where ``advanced_generation``
|
| 265 |
-
is True when any generation setting differs from its default.
|
| 266 |
-
"""
|
| 267 |
with st.expander("Advanced", expanded=False):
|
| 268 |
config_col1, config_col2 = st.columns([2, 1])
|
| 269 |
with config_col1:
|
|
@@ -349,19 +344,9 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
|
|
| 349 |
if remote:
|
| 350 |
st.caption("Seed is local-only and disabled for remote runs.")
|
| 351 |
|
| 352 |
-
advanced_generation = (
|
| 353 |
-
max_new_tokens != _GEN_DEFAULTS["max_new_tokens"]
|
| 354 |
-
or use_sampling
|
| 355 |
-
or temperature != _GEN_DEFAULTS["temperature"]
|
| 356 |
-
or top_p != _GEN_DEFAULTS["top_p"]
|
| 357 |
-
or top_k != _GEN_DEFAULTS["top_k"]
|
| 358 |
-
or repetition_penalty != _GEN_DEFAULTS["repetition_penalty"]
|
| 359 |
-
or seed is not None
|
| 360 |
-
)
|
| 361 |
-
|
| 362 |
do_sample = bool(use_sampling)
|
| 363 |
generation_seed = seed if do_sample and seed is not None and not remote else None
|
| 364 |
-
|
| 365 |
"max_new_tokens": int(max_new_tokens),
|
| 366 |
"do_sample": do_sample,
|
| 367 |
"temperature": temperature,
|
|
@@ -370,7 +355,6 @@ def _render_generation_settings(context_key: str, remote: bool) -> tuple[dict, b
|
|
| 370 |
"repetition_penalty": repetition_penalty,
|
| 371 |
"seed": generation_seed,
|
| 372 |
}
|
| 373 |
-
return gen_kwargs, advanced_generation
|
| 374 |
|
| 375 |
|
| 376 |
def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
@@ -406,7 +390,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 406 |
st.info("Try a different dataset source or upload a non-empty personas file.")
|
| 407 |
return
|
| 408 |
|
| 409 |
-
gen_kwargs
|
| 410 |
|
| 411 |
# ── Mode toggle ───────────────────────────────────────────────────────────
|
| 412 |
compare_key = widget_key(context_key, "compare_mode")
|
|
@@ -431,7 +415,6 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 431 |
dataset_source,
|
| 432 |
personas,
|
| 433 |
gen_kwargs,
|
| 434 |
-
advanced_generation,
|
| 435 |
)
|
| 436 |
return
|
| 437 |
|
|
@@ -513,7 +496,7 @@ def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
| 513 |
prompt_mode=prompt_mode,
|
| 514 |
system_prompt=active_system_prompt,
|
| 515 |
messages=chat_state["messages"],
|
| 516 |
-
generation=generation_dict(gen_kwargs
|
| 517 |
)
|
| 518 |
st.toast("Exported", icon=":material/check:")
|
| 519 |
with rst_col:
|
|
|
|
| 169 |
return st.session_state.get(prompt_key) or None
|
| 170 |
|
| 171 |
|
| 172 |
+
def generation_dict(gen_kwargs: dict) -> dict[str, object]:
|
| 173 |
return {
|
| 174 |
"max_new_tokens": int(gen_kwargs["max_new_tokens"]),
|
|
|
|
| 175 |
"use_sampling": bool(gen_kwargs["do_sample"]),
|
| 176 |
"temperature": float(gen_kwargs["temperature"]),
|
| 177 |
"top_p": float(gen_kwargs["top_p"]),
|
|
|
|
| 257 |
# ── Main tab entry point ───────────────────────────────────────────────────────
|
| 258 |
|
| 259 |
|
| 260 |
+
def _render_generation_settings(context_key: str, remote: bool) -> dict:
|
| 261 |
+
"""Render the Advanced generation settings expander and return generation kwargs."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
with st.expander("Advanced", expanded=False):
|
| 263 |
config_col1, config_col2 = st.columns([2, 1])
|
| 264 |
with config_col1:
|
|
|
|
| 344 |
if remote:
|
| 345 |
st.caption("Seed is local-only and disabled for remote runs.")
|
| 346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
do_sample = bool(use_sampling)
|
| 348 |
generation_seed = seed if do_sample and seed is not None and not remote else None
|
| 349 |
+
return {
|
| 350 |
"max_new_tokens": int(max_new_tokens),
|
| 351 |
"do_sample": do_sample,
|
| 352 |
"temperature": temperature,
|
|
|
|
| 355 |
"repetition_penalty": repetition_penalty,
|
| 356 |
"seed": generation_seed,
|
| 357 |
}
|
|
|
|
| 358 |
|
| 359 |
|
| 360 |
def render_chat_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
|
|
|
| 390 |
st.info("Try a different dataset source or upload a non-empty personas file.")
|
| 391 |
return
|
| 392 |
|
| 393 |
+
gen_kwargs = _render_generation_settings(context_key, remote)
|
| 394 |
|
| 395 |
# ── Mode toggle ───────────────────────────────────────────────────────────
|
| 396 |
compare_key = widget_key(context_key, "compare_mode")
|
|
|
|
| 415 |
dataset_source,
|
| 416 |
personas,
|
| 417 |
gen_kwargs,
|
|
|
|
| 418 |
)
|
| 419 |
return
|
| 420 |
|
|
|
|
| 496 |
prompt_mode=prompt_mode,
|
| 497 |
system_prompt=active_system_prompt,
|
| 498 |
messages=chat_state["messages"],
|
| 499 |
+
generation=generation_dict(gen_kwargs),
|
| 500 |
)
|
| 501 |
st.toast("Exported", icon=":material/check:")
|
| 502 |
with rst_col:
|
tabs/compare.py
CHANGED
|
@@ -124,18 +124,17 @@ def _render_mask_strategy_select(scope: str) -> MaskStrategy:
|
|
| 124 |
MaskStrategy.ANSWER_MEAN.value,
|
| 125 |
)
|
| 126 |
strategies = list(MaskStrategy)
|
| 127 |
-
default_index = next(
|
| 128 |
-
(
|
| 129 |
-
idx
|
| 130 |
-
for idx, strategy in enumerate(strategies)
|
| 131 |
-
if strategy.value == last_strategy
|
| 132 |
-
),
|
| 133 |
-
0,
|
| 134 |
-
)
|
| 135 |
selected = st.selectbox(
|
| 136 |
"Mask strategy",
|
| 137 |
options=strategies,
|
| 138 |
-
index=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
format_func=lambda strategy: strategy.value.replace("_", " ").title(),
|
| 140 |
key=widget_key("load", "mask_strategy", scope),
|
| 141 |
help="Which extracted activation artifact set to load.",
|
|
|
|
| 124 |
MaskStrategy.ANSWER_MEAN.value,
|
| 125 |
)
|
| 126 |
strategies = list(MaskStrategy)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
selected = st.selectbox(
|
| 128 |
"Mask strategy",
|
| 129 |
options=strategies,
|
| 130 |
+
index=next(
|
| 131 |
+
(
|
| 132 |
+
idx
|
| 133 |
+
for idx, strategy in enumerate(strategies)
|
| 134 |
+
if strategy.value == last_strategy
|
| 135 |
+
),
|
| 136 |
+
0,
|
| 137 |
+
),
|
| 138 |
format_func=lambda strategy: strategy.value.replace("_", " ").title(),
|
| 139 |
key=widget_key("load", "mask_strategy", scope),
|
| 140 |
help="Which extracted activation artifact set to load.",
|
tabs/compare_chat.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from nnterp import StandardizedTransformer
|
| 3 |
from persona_data.synth_persona import PersonaData
|
|
@@ -19,22 +21,26 @@ from .chat import (
|
|
| 19 |
)
|
| 20 |
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
def _generate_panel_reply(
|
|
@@ -61,7 +67,6 @@ def render_compare_mode(
|
|
| 61 |
dataset_source: str,
|
| 62 |
personas: list[PersonaData],
|
| 63 |
gen_kwargs: dict,
|
| 64 |
-
advanced_generation: bool,
|
| 65 |
) -> None:
|
| 66 |
"""Render the full side-by-side comparison UI."""
|
| 67 |
model: StandardizedTransformer | None = None
|
|
@@ -85,19 +90,15 @@ def render_compare_mode(
|
|
| 85 |
),
|
| 86 |
)
|
| 87 |
|
| 88 |
-
|
| 89 |
-
left_panel_key = widget_key(context_key, "cmp_left")
|
| 90 |
-
right_panel_key = widget_key(context_key, "cmp_right")
|
| 91 |
-
left_prompt_key = widget_key(left_panel_key, "custom_prompt")
|
| 92 |
-
right_prompt_key = widget_key(right_panel_key, "custom_prompt")
|
| 93 |
-
left_edit_key = widget_key(left_panel_key, "edit_idx")
|
| 94 |
-
right_edit_key = widget_key(right_panel_key, "edit_idx")
|
| 95 |
-
left_pending_key = widget_key(left_panel_key, "pending_regen")
|
| 96 |
-
right_pending_key = widget_key(right_panel_key, "pending_regen")
|
| 97 |
-
|
| 98 |
-
def render_panel(side: str) -> tuple[dict, object, str | None, str, PersonaData]:
|
| 99 |
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
# Carry over persona / prompt selections across model or remote switches.
|
| 103 |
persist_persona_key = f"chat:last_cmp_{side}_persona"
|
|
@@ -106,10 +107,6 @@ def render_compare_mode(
|
|
| 106 |
state["persona_id"] = st.session_state.get(persist_persona_key)
|
| 107 |
state["prompt_mode"] = st.session_state.get(persist_prompt_key, "templated")
|
| 108 |
|
| 109 |
-
prompt_key = widget_key(panel_key, "custom_prompt")
|
| 110 |
-
edit_key = widget_key(panel_key, "edit_idx")
|
| 111 |
-
pending_regen_key = widget_key(panel_key, "pending_regen")
|
| 112 |
-
|
| 113 |
selected_persona, prompt_mode, changed = render_persona_prompt_controls(
|
| 114 |
personas,
|
| 115 |
state["persona_id"],
|
|
@@ -122,11 +119,7 @@ def render_compare_mode(
|
|
| 122 |
|
| 123 |
if changed:
|
| 124 |
reset_chat_context_state(
|
| 125 |
-
state,
|
| 126 |
-
selected_persona.id,
|
| 127 |
-
prompt_mode,
|
| 128 |
-
prompt_key,
|
| 129 |
-
pending_regen_key,
|
| 130 |
)
|
| 131 |
st.session_state.pop(edit_key, None)
|
| 132 |
|
|
@@ -137,83 +130,57 @@ def render_compare_mode(
|
|
| 137 |
chat_log = st.container()
|
| 138 |
with chat_log:
|
| 139 |
active_system_prompt = render_system_prompt(
|
| 140 |
-
prompt_key,
|
| 141 |
-
prompt_mode,
|
| 142 |
-
active_system_prompt,
|
| 143 |
)
|
| 144 |
-
return (
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
selected_persona,
|
|
|
|
|
|
|
|
|
|
| 150 |
)
|
| 151 |
|
|
|
|
| 152 |
with left_col:
|
| 153 |
-
|
| 154 |
-
"left"
|
| 155 |
-
)
|
| 156 |
with right_col:
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
)
|
| 160 |
-
|
| 161 |
-
panels = [
|
| 162 |
-
(
|
| 163 |
-
left_state,
|
| 164 |
-
left_log,
|
| 165 |
-
left_prompt,
|
| 166 |
-
left_pending,
|
| 167 |
-
left_edit_key,
|
| 168 |
-
left_persona,
|
| 169 |
-
),
|
| 170 |
-
(
|
| 171 |
-
right_state,
|
| 172 |
-
right_log,
|
| 173 |
-
right_prompt,
|
| 174 |
-
right_pending,
|
| 175 |
-
right_edit_key,
|
| 176 |
-
right_persona,
|
| 177 |
-
),
|
| 178 |
-
]
|
| 179 |
|
| 180 |
# Handle per-panel regeneration triggered by message edits
|
| 181 |
-
regen_panels = [
|
| 182 |
-
(panel_state, panel_log, panel_prompt)
|
| 183 |
-
for panel_state, panel_log, panel_prompt, p_pending, _panel_edit_key, _ in panels
|
| 184 |
-
if st.session_state.pop(p_pending, False)
|
| 185 |
-
]
|
| 186 |
if regen_panels:
|
| 187 |
model = _get_model()
|
| 188 |
|
| 189 |
results: list[ChatReply | Exception] = []
|
| 190 |
with st.spinner("Regenerating..."):
|
| 191 |
-
for
|
| 192 |
try:
|
| 193 |
results.append(
|
| 194 |
_generate_panel_reply(
|
| 195 |
model=model,
|
| 196 |
remote=remote,
|
| 197 |
-
panel_state=
|
| 198 |
-
panel_prompt=
|
| 199 |
gen_kwargs=gen_kwargs,
|
| 200 |
)
|
| 201 |
)
|
| 202 |
except Exception as exc:
|
| 203 |
results.append(exc)
|
| 204 |
|
| 205 |
-
for
|
| 206 |
-
regen_panels, results
|
| 207 |
-
):
|
| 208 |
if isinstance(result, Exception):
|
| 209 |
-
with
|
| 210 |
st.error(f"Generation failed: {result}")
|
| 211 |
-
|
| 212 |
continue
|
| 213 |
-
|
| 214 |
{"role": "assistant", "content": result.text}
|
| 215 |
)
|
| 216 |
-
|
| 217 |
result.past_key_values if not remote else None
|
| 218 |
)
|
| 219 |
st.rerun()
|
|
@@ -222,28 +189,28 @@ def render_compare_mode(
|
|
| 222 |
if contrast_enabled:
|
| 223 |
pending_edits: list[tuple[int, int]] = [
|
| 224 |
(panel_idx, msg_idx)
|
| 225 |
-
for panel_idx,
|
| 226 |
-
for msg_idx, msg in enumerate(
|
| 227 |
if msg.get("_needs_contrast") and msg.get("role") == "assistant"
|
| 228 |
]
|
| 229 |
if pending_edits:
|
| 230 |
model = _get_model()
|
| 231 |
-
label_a = persona_label(
|
| 232 |
-
label_b = persona_label(
|
| 233 |
with st.spinner("Recomputing token contrast…"):
|
| 234 |
for panel_idx, msg_idx in pending_edits:
|
| 235 |
-
|
| 236 |
-
msg =
|
| 237 |
-
if msg_idx >= len(
|
| 238 |
-
|
| 239 |
):
|
| 240 |
msg.pop("_needs_contrast", None)
|
| 241 |
continue
|
| 242 |
context_a = build_chat_messages(
|
| 243 |
-
|
| 244 |
)
|
| 245 |
context_b = build_chat_messages(
|
| 246 |
-
|
| 247 |
)
|
| 248 |
try:
|
| 249 |
response_ids = model.tokenizer(
|
|
@@ -267,20 +234,13 @@ def render_compare_mode(
|
|
| 267 |
msg.pop("_needs_contrast", None)
|
| 268 |
st.rerun()
|
| 269 |
|
| 270 |
-
for
|
| 271 |
-
panel_state,
|
| 272 |
-
panel_log,
|
| 273 |
-
_panel_prompt,
|
| 274 |
-
panel_pending,
|
| 275 |
-
panel_edit_key,
|
| 276 |
-
_,
|
| 277 |
-
) in panels:
|
| 278 |
render_chat_window(
|
| 279 |
-
chat_log=
|
| 280 |
-
messages=
|
| 281 |
-
chat_state=
|
| 282 |
-
edit_key=
|
| 283 |
-
pending_key=
|
| 284 |
show_contrast=contrast_enabled,
|
| 285 |
edit_column_ratio=(10, 1),
|
| 286 |
)
|
|
@@ -298,20 +258,17 @@ def render_compare_mode(
|
|
| 298 |
key=widget_key(context_key, "cmp_export"),
|
| 299 |
help="Export both chats",
|
| 300 |
):
|
| 301 |
-
for
|
| 302 |
-
("left", left_state, left_prompt, left_persona),
|
| 303 |
-
("right", right_state, right_prompt, right_persona),
|
| 304 |
-
):
|
| 305 |
save_chat_export(
|
| 306 |
model_name=model_name,
|
| 307 |
dataset_source=dataset_source,
|
| 308 |
-
persona_id=
|
| 309 |
-
persona_name=getattr(
|
| 310 |
-
prompt_mode=
|
| 311 |
-
system_prompt=
|
| 312 |
-
messages=
|
| 313 |
-
generation=generation_dict(gen_kwargs
|
| 314 |
-
panel_label=side,
|
| 315 |
)
|
| 316 |
st.toast("Exported", icon=":material/check:")
|
| 317 |
with rst_col:
|
|
@@ -326,55 +283,21 @@ def render_compare_mode(
|
|
| 326 |
help="Reset chat",
|
| 327 |
key=popover_key,
|
| 328 |
):
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
left_state["prompt_mode"],
|
| 338 |
-
left_prompt_key,
|
| 339 |
-
left_pending_key,
|
| 340 |
-
)
|
| 341 |
-
st.session_state[reset_menu_nonce_key] += 1
|
| 342 |
-
st.rerun()
|
| 343 |
-
if st.button(
|
| 344 |
-
"Reset right",
|
| 345 |
-
key=widget_key(context_key, "cmp_reset_right"),
|
| 346 |
-
):
|
| 347 |
-
_reset_compare_panel(
|
| 348 |
-
right_state,
|
| 349 |
-
right_edit_key,
|
| 350 |
-
right_persona.id,
|
| 351 |
-
right_state["prompt_mode"],
|
| 352 |
-
right_prompt_key,
|
| 353 |
-
right_pending_key,
|
| 354 |
-
)
|
| 355 |
-
st.session_state[reset_menu_nonce_key] += 1
|
| 356 |
-
st.rerun()
|
| 357 |
if st.button(
|
| 358 |
"Reset both",
|
| 359 |
key=widget_key(context_key, "cmp_reset_both"),
|
| 360 |
type="primary",
|
| 361 |
):
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
left_edit_key,
|
| 365 |
-
left_persona.id,
|
| 366 |
-
left_state["prompt_mode"],
|
| 367 |
-
left_prompt_key,
|
| 368 |
-
left_pending_key,
|
| 369 |
-
)
|
| 370 |
-
_reset_compare_panel(
|
| 371 |
-
right_state,
|
| 372 |
-
right_edit_key,
|
| 373 |
-
right_persona.id,
|
| 374 |
-
right_state["prompt_mode"],
|
| 375 |
-
right_prompt_key,
|
| 376 |
-
right_pending_key,
|
| 377 |
-
)
|
| 378 |
st.session_state[reset_menu_nonce_key] += 1
|
| 379 |
st.rerun()
|
| 380 |
|
|
@@ -388,36 +311,27 @@ def render_compare_mode(
|
|
| 388 |
|
| 389 |
model = cached_model(model_name=model_name, remote=remote)
|
| 390 |
|
| 391 |
-
for
|
| 392 |
-
|
| 393 |
-
with
|
| 394 |
render_chat_message({"role": "user", "content": user_prompt})
|
| 395 |
|
| 396 |
# Snapshot contexts before the new assistant turn is appended (needed for contrast).
|
| 397 |
pre_gen_contexts = [
|
| 398 |
-
build_chat_messages(
|
| 399 |
-
for panel_state, _panel_log, panel_prompt, _p_pending, _panel_edit_key, _ in panels
|
| 400 |
]
|
| 401 |
|
| 402 |
results: list[ChatReply | Exception] = []
|
| 403 |
with st.spinner("Generating..."):
|
| 404 |
-
#
|
| 405 |
-
|
| 406 |
-
for (
|
| 407 |
-
panel_state,
|
| 408 |
-
_panel_log,
|
| 409 |
-
panel_prompt,
|
| 410 |
-
_p_pending,
|
| 411 |
-
_panel_edit_key,
|
| 412 |
-
_,
|
| 413 |
-
) in panels:
|
| 414 |
try:
|
| 415 |
results.append(
|
| 416 |
_generate_panel_reply(
|
| 417 |
model=model,
|
| 418 |
remote=remote,
|
| 419 |
-
panel_state=
|
| 420 |
-
panel_prompt=
|
| 421 |
gen_kwargs=gen_kwargs,
|
| 422 |
)
|
| 423 |
)
|
|
@@ -425,23 +339,16 @@ def render_compare_mode(
|
|
| 425 |
results.append(exc)
|
| 426 |
|
| 427 |
valid_results: list[ChatReply | None] = []
|
| 428 |
-
for (
|
| 429 |
-
panel_state,
|
| 430 |
-
panel_log,
|
| 431 |
-
_panel_prompt,
|
| 432 |
-
_p_pending,
|
| 433 |
-
_panel_edit_key,
|
| 434 |
-
_,
|
| 435 |
-
), result in zip(panels, results):
|
| 436 |
if isinstance(result, Exception):
|
| 437 |
-
with
|
| 438 |
st.error(f"Generation failed: {result}")
|
| 439 |
-
|
| 440 |
valid_results.append(None)
|
| 441 |
continue
|
| 442 |
|
| 443 |
-
|
| 444 |
-
|
| 445 |
valid_results.append(result)
|
| 446 |
|
| 447 |
# Compute contrastive token coloring when both panels succeeded.
|
|
@@ -458,14 +365,14 @@ def render_compare_mode(
|
|
| 458 |
context_b=pre_gen_contexts[1],
|
| 459 |
response_ids_a=valid_results[0].generated_ids,
|
| 460 |
response_ids_b=valid_results[1].generated_ids,
|
| 461 |
-
label_a=persona_label(
|
| 462 |
-
label_b=persona_label(
|
| 463 |
remote=remote,
|
| 464 |
)
|
| 465 |
if tc_a is not None:
|
| 466 |
-
|
| 467 |
if tc_b is not None:
|
| 468 |
-
|
| 469 |
except Exception as exc:
|
| 470 |
st.warning(f"Token contrast failed: {exc}")
|
| 471 |
|
|
|
|
| 1 |
+
from typing import Any, NamedTuple
|
| 2 |
+
|
| 3 |
import streamlit as st
|
| 4 |
from nnterp import StandardizedTransformer
|
| 5 |
from persona_data.synth_persona import PersonaData
|
|
|
|
| 21 |
)
|
| 22 |
|
| 23 |
|
| 24 |
+
class ComparePanel(NamedTuple):
|
| 25 |
+
side: str
|
| 26 |
+
state: dict[str, object]
|
| 27 |
+
log: Any
|
| 28 |
+
prompt: str | None
|
| 29 |
+
persona: PersonaData
|
| 30 |
+
prompt_key: str
|
| 31 |
+
edit_key: str
|
| 32 |
+
pending_key: str
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _reset_compare_panel(panel: ComparePanel) -> None:
|
| 36 |
+
reset_chat_context_state(
|
| 37 |
+
panel.state,
|
| 38 |
+
panel.persona.id,
|
| 39 |
+
panel.state["prompt_mode"],
|
| 40 |
+
panel.prompt_key,
|
| 41 |
+
panel.pending_key,
|
| 42 |
+
)
|
| 43 |
+
st.session_state.pop(panel.edit_key, None)
|
| 44 |
|
| 45 |
|
| 46 |
def _generate_panel_reply(
|
|
|
|
| 67 |
dataset_source: str,
|
| 68 |
personas: list[PersonaData],
|
| 69 |
gen_kwargs: dict,
|
|
|
|
| 70 |
) -> None:
|
| 71 |
"""Render the full side-by-side comparison UI."""
|
| 72 |
model: StandardizedTransformer | None = None
|
|
|
|
| 90 |
),
|
| 91 |
)
|
| 92 |
|
| 93 |
+
def render_panel(side: str) -> ComparePanel:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
panel_key = widget_key(context_key, f"cmp_{side}")
|
| 95 |
+
if panel_key not in st.session_state:
|
| 96 |
+
st.session_state[panel_key] = default_chat_state()
|
| 97 |
+
state = st.session_state[panel_key]
|
| 98 |
+
|
| 99 |
+
prompt_key = widget_key(panel_key, "custom_prompt")
|
| 100 |
+
edit_key = widget_key(panel_key, "edit_idx")
|
| 101 |
+
pending_key = widget_key(panel_key, "pending_regen")
|
| 102 |
|
| 103 |
# Carry over persona / prompt selections across model or remote switches.
|
| 104 |
persist_persona_key = f"chat:last_cmp_{side}_persona"
|
|
|
|
| 107 |
state["persona_id"] = st.session_state.get(persist_persona_key)
|
| 108 |
state["prompt_mode"] = st.session_state.get(persist_prompt_key, "templated")
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
selected_persona, prompt_mode, changed = render_persona_prompt_controls(
|
| 111 |
personas,
|
| 112 |
state["persona_id"],
|
|
|
|
| 119 |
|
| 120 |
if changed:
|
| 121 |
reset_chat_context_state(
|
| 122 |
+
state, selected_persona.id, prompt_mode, prompt_key, pending_key
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
)
|
| 124 |
st.session_state.pop(edit_key, None)
|
| 125 |
|
|
|
|
| 130 |
chat_log = st.container()
|
| 131 |
with chat_log:
|
| 132 |
active_system_prompt = render_system_prompt(
|
| 133 |
+
prompt_key, prompt_mode, active_system_prompt
|
|
|
|
|
|
|
| 134 |
)
|
| 135 |
+
return ComparePanel(
|
| 136 |
+
side=side,
|
| 137 |
+
state=state,
|
| 138 |
+
log=chat_log,
|
| 139 |
+
prompt=active_system_prompt,
|
| 140 |
+
persona=selected_persona,
|
| 141 |
+
prompt_key=prompt_key,
|
| 142 |
+
edit_key=edit_key,
|
| 143 |
+
pending_key=pending_key,
|
| 144 |
)
|
| 145 |
|
| 146 |
+
left_col, right_col = st.columns(2)
|
| 147 |
with left_col:
|
| 148 |
+
left = render_panel("left")
|
|
|
|
|
|
|
| 149 |
with right_col:
|
| 150 |
+
right = render_panel("right")
|
| 151 |
+
panels: list[ComparePanel] = [left, right]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
# Handle per-panel regeneration triggered by message edits
|
| 154 |
+
regen_panels = [p for p in panels if st.session_state.pop(p.pending_key, False)]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
if regen_panels:
|
| 156 |
model = _get_model()
|
| 157 |
|
| 158 |
results: list[ChatReply | Exception] = []
|
| 159 |
with st.spinner("Regenerating..."):
|
| 160 |
+
for panel in regen_panels:
|
| 161 |
try:
|
| 162 |
results.append(
|
| 163 |
_generate_panel_reply(
|
| 164 |
model=model,
|
| 165 |
remote=remote,
|
| 166 |
+
panel_state=panel.state,
|
| 167 |
+
panel_prompt=panel.prompt,
|
| 168 |
gen_kwargs=gen_kwargs,
|
| 169 |
)
|
| 170 |
)
|
| 171 |
except Exception as exc:
|
| 172 |
results.append(exc)
|
| 173 |
|
| 174 |
+
for panel, result in zip(regen_panels, results):
|
|
|
|
|
|
|
| 175 |
if isinstance(result, Exception):
|
| 176 |
+
with panel.log:
|
| 177 |
st.error(f"Generation failed: {result}")
|
| 178 |
+
panel.state["messages"].pop()
|
| 179 |
continue
|
| 180 |
+
panel.state["messages"].append(
|
| 181 |
{"role": "assistant", "content": result.text}
|
| 182 |
)
|
| 183 |
+
panel.state["past_key_values"] = (
|
| 184 |
result.past_key_values if not remote else None
|
| 185 |
)
|
| 186 |
st.rerun()
|
|
|
|
| 189 |
if contrast_enabled:
|
| 190 |
pending_edits: list[tuple[int, int]] = [
|
| 191 |
(panel_idx, msg_idx)
|
| 192 |
+
for panel_idx, panel in enumerate(panels)
|
| 193 |
+
for msg_idx, msg in enumerate(panel.state["messages"])
|
| 194 |
if msg.get("_needs_contrast") and msg.get("role") == "assistant"
|
| 195 |
]
|
| 196 |
if pending_edits:
|
| 197 |
model = _get_model()
|
| 198 |
+
label_a = persona_label(left.persona)
|
| 199 |
+
label_b = persona_label(right.persona)
|
| 200 |
with st.spinner("Recomputing token contrast…"):
|
| 201 |
for panel_idx, msg_idx in pending_edits:
|
| 202 |
+
panel = panels[panel_idx]
|
| 203 |
+
msg = panel.state["messages"][msg_idx]
|
| 204 |
+
if msg_idx >= len(left.state["messages"]) or msg_idx >= len(
|
| 205 |
+
right.state["messages"]
|
| 206 |
):
|
| 207 |
msg.pop("_needs_contrast", None)
|
| 208 |
continue
|
| 209 |
context_a = build_chat_messages(
|
| 210 |
+
left.prompt, left.state["messages"][:msg_idx]
|
| 211 |
)
|
| 212 |
context_b = build_chat_messages(
|
| 213 |
+
right.prompt, right.state["messages"][:msg_idx]
|
| 214 |
)
|
| 215 |
try:
|
| 216 |
response_ids = model.tokenizer(
|
|
|
|
| 234 |
msg.pop("_needs_contrast", None)
|
| 235 |
st.rerun()
|
| 236 |
|
| 237 |
+
for panel in panels:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
render_chat_window(
|
| 239 |
+
chat_log=panel.log,
|
| 240 |
+
messages=panel.state["messages"],
|
| 241 |
+
chat_state=panel.state,
|
| 242 |
+
edit_key=panel.edit_key,
|
| 243 |
+
pending_key=panel.pending_key,
|
| 244 |
show_contrast=contrast_enabled,
|
| 245 |
edit_column_ratio=(10, 1),
|
| 246 |
)
|
|
|
|
| 258 |
key=widget_key(context_key, "cmp_export"),
|
| 259 |
help="Export both chats",
|
| 260 |
):
|
| 261 |
+
for panel in panels:
|
|
|
|
|
|
|
|
|
|
| 262 |
save_chat_export(
|
| 263 |
model_name=model_name,
|
| 264 |
dataset_source=dataset_source,
|
| 265 |
+
persona_id=panel.persona.id,
|
| 266 |
+
persona_name=getattr(panel.persona, "name", None),
|
| 267 |
+
prompt_mode=panel.state["prompt_mode"],
|
| 268 |
+
system_prompt=panel.prompt,
|
| 269 |
+
messages=panel.state["messages"],
|
| 270 |
+
generation=generation_dict(gen_kwargs),
|
| 271 |
+
panel_label=panel.side,
|
| 272 |
)
|
| 273 |
st.toast("Exported", icon=":material/check:")
|
| 274 |
with rst_col:
|
|
|
|
| 283 |
help="Reset chat",
|
| 284 |
key=popover_key,
|
| 285 |
):
|
| 286 |
+
for panel in panels:
|
| 287 |
+
if st.button(
|
| 288 |
+
f"Reset {panel.side}",
|
| 289 |
+
key=widget_key(context_key, f"cmp_reset_{panel.side}"),
|
| 290 |
+
):
|
| 291 |
+
_reset_compare_panel(panel)
|
| 292 |
+
st.session_state[reset_menu_nonce_key] += 1
|
| 293 |
+
st.rerun()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
if st.button(
|
| 295 |
"Reset both",
|
| 296 |
key=widget_key(context_key, "cmp_reset_both"),
|
| 297 |
type="primary",
|
| 298 |
):
|
| 299 |
+
for panel in panels:
|
| 300 |
+
_reset_compare_panel(panel)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
st.session_state[reset_menu_nonce_key] += 1
|
| 302 |
st.rerun()
|
| 303 |
|
|
|
|
| 311 |
|
| 312 |
model = cached_model(model_name=model_name, remote=remote)
|
| 313 |
|
| 314 |
+
for panel in panels:
|
| 315 |
+
panel.state["messages"].append({"role": "user", "content": user_prompt})
|
| 316 |
+
with panel.log:
|
| 317 |
render_chat_message({"role": "user", "content": user_prompt})
|
| 318 |
|
| 319 |
# Snapshot contexts before the new assistant turn is appended (needed for contrast).
|
| 320 |
pre_gen_contexts = [
|
| 321 |
+
build_chat_messages(panel.prompt, panel.state["messages"]) for panel in panels
|
|
|
|
| 322 |
]
|
| 323 |
|
| 324 |
results: list[ChatReply | Exception] = []
|
| 325 |
with st.spinner("Generating..."):
|
| 326 |
+
# Sequential generation keeps both panels using model/session state safely.
|
| 327 |
+
for panel in panels:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
try:
|
| 329 |
results.append(
|
| 330 |
_generate_panel_reply(
|
| 331 |
model=model,
|
| 332 |
remote=remote,
|
| 333 |
+
panel_state=panel.state,
|
| 334 |
+
panel_prompt=panel.prompt,
|
| 335 |
gen_kwargs=gen_kwargs,
|
| 336 |
)
|
| 337 |
)
|
|
|
|
| 339 |
results.append(exc)
|
| 340 |
|
| 341 |
valid_results: list[ChatReply | None] = []
|
| 342 |
+
for panel, result in zip(panels, results):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
if isinstance(result, Exception):
|
| 344 |
+
with panel.log:
|
| 345 |
st.error(f"Generation failed: {result}")
|
| 346 |
+
panel.state["messages"].pop()
|
| 347 |
valid_results.append(None)
|
| 348 |
continue
|
| 349 |
|
| 350 |
+
panel.state["messages"].append({"role": "assistant", "content": result.text})
|
| 351 |
+
panel.state["past_key_values"] = result.past_key_values if not remote else None
|
| 352 |
valid_results.append(result)
|
| 353 |
|
| 354 |
# Compute contrastive token coloring when both panels succeeded.
|
|
|
|
| 365 |
context_b=pre_gen_contexts[1],
|
| 366 |
response_ids_a=valid_results[0].generated_ids,
|
| 367 |
response_ids_b=valid_results[1].generated_ids,
|
| 368 |
+
label_a=persona_label(left.persona),
|
| 369 |
+
label_b=persona_label(right.persona),
|
| 370 |
remote=remote,
|
| 371 |
)
|
| 372 |
if tc_a is not None:
|
| 373 |
+
left.state["messages"][-1]["_contrast"] = tc_a
|
| 374 |
if tc_b is not None:
|
| 375 |
+
right.state["messages"][-1]["_contrast"] = tc_b
|
| 376 |
except Exception as exc:
|
| 377 |
st.warning(f"Token contrast failed: {exc}")
|
| 378 |
|
tabs/extract.py
CHANGED
|
@@ -32,11 +32,18 @@ _LAST_VARIANTS_KEY = "extract:last_variants"
|
|
| 32 |
_LAST_BASELINE_KEY = "extract:last_include_baseline"
|
| 33 |
_LAST_PERSONA_IDS_KEY = "extract:last_persona_ids"
|
| 34 |
_LAST_QA_TYPE_KEY = "extract:last_qa_type"
|
| 35 |
-
|
|
|
|
| 36 |
_LAST_MAX_QUESTIONS_KEY = "extract:last_max_questions"
|
| 37 |
_LAST_MASK_STRATEGY_KEY = "extract:last_mask_strategy"
|
| 38 |
|
| 39 |
_QA_TYPE_OPTIONS = ["all", "explicit", "implicit"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
def _build_run_plan(
|
|
@@ -106,31 +113,25 @@ def _render_sample_tokens_html(p, tokenizer, *, max_tokens: int = 200) -> str:
|
|
| 106 |
)
|
| 107 |
|
| 108 |
|
| 109 |
-
def _render_local_dataset_uploads() -> None:
|
| 110 |
-
"""Render file inputs for local dataset uploads."""
|
| 111 |
-
|
| 112 |
-
with st.expander("Local dataset upload", expanded=True):
|
| 113 |
-
st.file_uploader(
|
| 114 |
-
"personas.jsonl",
|
| 115 |
-
type=["jsonl"],
|
| 116 |
-
key="extract__personas_file",
|
| 117 |
-
help="Expected fields: id, persona, templated_view, biography_view",
|
| 118 |
-
)
|
| 119 |
-
st.file_uploader(
|
| 120 |
-
"qa.jsonl",
|
| 121 |
-
type=["jsonl"],
|
| 122 |
-
key="extract__qa_file",
|
| 123 |
-
help="Expected fields: id, qid, type, question, answer, difficulty",
|
| 124 |
-
)
|
| 125 |
-
|
| 126 |
-
|
| 127 |
def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
| 128 |
"""Render the extraction tab."""
|
| 129 |
|
| 130 |
st.title("Extract")
|
| 131 |
|
| 132 |
if dataset_source == "Local JSONL upload":
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
last_variants = st.session_state.get(
|
| 136 |
_LAST_VARIANTS_KEY, [*PERSONA_VARIANTS, BASELINE_PERSONA_ID]
|
|
@@ -146,12 +147,12 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 146 |
key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
|
| 147 |
help="Extract these variants for each selected persona.",
|
| 148 |
)
|
| 149 |
-
include_baseline_default = st.session_state.get(
|
| 150 |
-
_LAST_BASELINE_KEY, BASELINE_PERSONA_ID in last_variants
|
| 151 |
-
)
|
| 152 |
include_baseline = st.checkbox(
|
| 153 |
"Extract Assistant baseline",
|
| 154 |
-
value=
|
|
|
|
|
|
|
|
|
|
| 155 |
key=_extract_widget_key(model_name, remote, dataset_source, "baseline"),
|
| 156 |
help=(
|
| 157 |
"Extracts the persona-less Assistant prompt once using the first "
|
|
@@ -214,18 +215,15 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 214 |
with st.expander("Advanced", expanded=False):
|
| 215 |
st.caption("Filters")
|
| 216 |
|
| 217 |
-
col1, col2,
|
| 218 |
with col1:
|
| 219 |
-
last_qa_type = st.session_state.get(_LAST_QA_TYPE_KEY, "all")
|
| 220 |
-
qa_type_index = (
|
| 221 |
-
_QA_TYPE_OPTIONS.index(last_qa_type)
|
| 222 |
-
if last_qa_type in _QA_TYPE_OPTIONS
|
| 223 |
-
else 0
|
| 224 |
-
)
|
| 225 |
qa_type_select = st.selectbox(
|
| 226 |
"QA type",
|
| 227 |
options=_QA_TYPE_OPTIONS,
|
| 228 |
-
index=
|
|
|
|
|
|
|
|
|
|
| 229 |
key=_extract_widget_key(
|
| 230 |
model_name, remote, dataset_source, "qa_type_select"
|
| 231 |
),
|
|
@@ -237,39 +235,68 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 237 |
else None
|
| 238 |
)
|
| 239 |
with col2:
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
key=_extract_widget_key(
|
| 251 |
-
model_name,
|
|
|
|
|
|
|
|
|
|
| 252 |
),
|
| 253 |
)
|
| 254 |
-
st.session_state[
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
st.caption("Extraction settings")
|
| 258 |
last_strategy = st.session_state.get(
|
| 259 |
-
_LAST_MASK_STRATEGY_KEY,
|
|
|
|
| 260 |
)
|
| 261 |
strategy_options = list(MaskStrategy)
|
| 262 |
-
strategy_index = next(
|
| 263 |
-
(i for i, s in enumerate(strategy_options) if s.value == last_strategy),
|
| 264 |
-
0,
|
| 265 |
-
)
|
| 266 |
mask_strategy = st.selectbox(
|
| 267 |
"Mask strategy",
|
| 268 |
options=strategy_options,
|
| 269 |
-
index=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
format_func=lambda s: s.value.replace("_", " ").title(),
|
| 271 |
key=_extract_widget_key(
|
| 272 |
-
model_name,
|
|
|
|
|
|
|
|
|
|
| 273 |
),
|
| 274 |
help="Which tokens contribute to the averaged hidden state.",
|
| 275 |
)
|
|
@@ -279,7 +306,10 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 279 |
for persona in selected_personas:
|
| 280 |
qa = list(
|
| 281 |
dataset.get_qa(
|
| 282 |
-
persona.id,
|
|
|
|
|
|
|
|
|
|
| 283 |
)
|
| 284 |
)
|
| 285 |
if qa:
|
|
@@ -295,13 +325,14 @@ def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> No
|
|
| 295 |
return
|
| 296 |
|
| 297 |
max_q = min(len(qa_pairs) for _, qa_pairs in runs)
|
| 298 |
-
last_max = st.session_state.get(_LAST_MAX_QUESTIONS_KEY, max_q)
|
| 299 |
-
default_max = min(max(last_max, 1), max_q)
|
| 300 |
max_questions = st.slider(
|
| 301 |
"Max questions",
|
| 302 |
min_value=1,
|
| 303 |
max_value=max_q,
|
| 304 |
-
value=
|
|
|
|
|
|
|
|
|
|
| 305 |
key=_extract_widget_key(
|
| 306 |
model_name, remote, dataset_source, "max_questions"
|
| 307 |
),
|
|
|
|
| 32 |
_LAST_BASELINE_KEY = "extract:last_include_baseline"
|
| 33 |
_LAST_PERSONA_IDS_KEY = "extract:last_persona_ids"
|
| 34 |
_LAST_QA_TYPE_KEY = "extract:last_qa_type"
|
| 35 |
+
_LAST_ITEM_TYPE_KEY = "extract:last_item_type"
|
| 36 |
+
_LAST_SCOPE_KEY = "extract:last_scope"
|
| 37 |
_LAST_MAX_QUESTIONS_KEY = "extract:last_max_questions"
|
| 38 |
_LAST_MASK_STRATEGY_KEY = "extract:last_mask_strategy"
|
| 39 |
|
| 40 |
_QA_TYPE_OPTIONS = ["all", "explicit", "implicit"]
|
| 41 |
+
_ITEM_TYPE_OPTIONS = ["all", "mcq", "frq"]
|
| 42 |
+
_SCOPE_OPTIONS = ["all", "individual", "shared"]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _option_index(options: list[str], value: str) -> int:
|
| 46 |
+
return options.index(value) if value in options else 0
|
| 47 |
|
| 48 |
|
| 49 |
def _build_run_plan(
|
|
|
|
| 113 |
)
|
| 114 |
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
def render_extract_tab(remote: bool, model_name: str, dataset_source: str) -> None:
|
| 117 |
"""Render the extraction tab."""
|
| 118 |
|
| 119 |
st.title("Extract")
|
| 120 |
|
| 121 |
if dataset_source == "Local JSONL upload":
|
| 122 |
+
with st.expander("Local dataset upload", expanded=True):
|
| 123 |
+
st.file_uploader(
|
| 124 |
+
"personas.jsonl",
|
| 125 |
+
type=["jsonl"],
|
| 126 |
+
key="extract__personas_file",
|
| 127 |
+
help="Expected fields: id, persona, templated_view, biography_view",
|
| 128 |
+
)
|
| 129 |
+
st.file_uploader(
|
| 130 |
+
"qa.jsonl",
|
| 131 |
+
type=["jsonl"],
|
| 132 |
+
key="extract__qa_file",
|
| 133 |
+
help="Expected fields: id, qid, type, item_type, scope, question, answer",
|
| 134 |
+
)
|
| 135 |
|
| 136 |
last_variants = st.session_state.get(
|
| 137 |
_LAST_VARIANTS_KEY, [*PERSONA_VARIANTS, BASELINE_PERSONA_ID]
|
|
|
|
| 147 |
key=_extract_widget_key(model_name, remote, dataset_source, "persona_variants"),
|
| 148 |
help="Extract these variants for each selected persona.",
|
| 149 |
)
|
|
|
|
|
|
|
|
|
|
| 150 |
include_baseline = st.checkbox(
|
| 151 |
"Extract Assistant baseline",
|
| 152 |
+
value=st.session_state.get(
|
| 153 |
+
_LAST_BASELINE_KEY,
|
| 154 |
+
BASELINE_PERSONA_ID in last_variants,
|
| 155 |
+
),
|
| 156 |
key=_extract_widget_key(model_name, remote, dataset_source, "baseline"),
|
| 157 |
help=(
|
| 158 |
"Extracts the persona-less Assistant prompt once using the first "
|
|
|
|
| 215 |
with st.expander("Advanced", expanded=False):
|
| 216 |
st.caption("Filters")
|
| 217 |
|
| 218 |
+
col1, col2, col3 = st.columns(3)
|
| 219 |
with col1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
qa_type_select = st.selectbox(
|
| 221 |
"QA type",
|
| 222 |
options=_QA_TYPE_OPTIONS,
|
| 223 |
+
index=_option_index(
|
| 224 |
+
_QA_TYPE_OPTIONS,
|
| 225 |
+
st.session_state.get(_LAST_QA_TYPE_KEY, "all"),
|
| 226 |
+
),
|
| 227 |
key=_extract_widget_key(
|
| 228 |
model_name, remote, dataset_source, "qa_type_select"
|
| 229 |
),
|
|
|
|
| 235 |
else None
|
| 236 |
)
|
| 237 |
with col2:
|
| 238 |
+
item_type_select = st.selectbox(
|
| 239 |
+
"Item type",
|
| 240 |
+
options=_ITEM_TYPE_OPTIONS,
|
| 241 |
+
index=_option_index(
|
| 242 |
+
_ITEM_TYPE_OPTIONS,
|
| 243 |
+
st.session_state.get(_LAST_ITEM_TYPE_KEY, "all"),
|
| 244 |
+
),
|
| 245 |
+
key=_extract_widget_key(
|
| 246 |
+
model_name, remote, dataset_source, "item_type_select"
|
| 247 |
+
),
|
| 248 |
+
)
|
| 249 |
+
st.session_state[_LAST_ITEM_TYPE_KEY] = item_type_select
|
| 250 |
+
qa_filter_item_type: Literal["mcq", "frq"] | None = (
|
| 251 |
+
cast(Literal["mcq", "frq"], item_type_select)
|
| 252 |
+
if item_type_select in ("mcq", "frq")
|
| 253 |
+
else None
|
| 254 |
+
)
|
| 255 |
+
with col3:
|
| 256 |
+
scope_select = st.selectbox(
|
| 257 |
+
"Scope",
|
| 258 |
+
options=_SCOPE_OPTIONS,
|
| 259 |
+
index=_option_index(
|
| 260 |
+
_SCOPE_OPTIONS,
|
| 261 |
+
st.session_state.get(_LAST_SCOPE_KEY, "all"),
|
| 262 |
+
),
|
| 263 |
key=_extract_widget_key(
|
| 264 |
+
model_name,
|
| 265 |
+
remote,
|
| 266 |
+
dataset_source,
|
| 267 |
+
"scope_select",
|
| 268 |
),
|
| 269 |
)
|
| 270 |
+
st.session_state[_LAST_SCOPE_KEY] = scope_select
|
| 271 |
+
qa_filter_scope: Literal["individual", "shared"] | None = (
|
| 272 |
+
cast(Literal["individual", "shared"], scope_select)
|
| 273 |
+
if scope_select in ("individual", "shared")
|
| 274 |
+
else None
|
| 275 |
+
)
|
| 276 |
|
| 277 |
st.caption("Extraction settings")
|
| 278 |
last_strategy = st.session_state.get(
|
| 279 |
+
_LAST_MASK_STRATEGY_KEY,
|
| 280 |
+
MaskStrategy.ANSWER_MEAN.value,
|
| 281 |
)
|
| 282 |
strategy_options = list(MaskStrategy)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
mask_strategy = st.selectbox(
|
| 284 |
"Mask strategy",
|
| 285 |
options=strategy_options,
|
| 286 |
+
index=next(
|
| 287 |
+
(
|
| 288 |
+
idx
|
| 289 |
+
for idx, strategy in enumerate(strategy_options)
|
| 290 |
+
if strategy.value == last_strategy
|
| 291 |
+
),
|
| 292 |
+
0,
|
| 293 |
+
),
|
| 294 |
format_func=lambda s: s.value.replace("_", " ").title(),
|
| 295 |
key=_extract_widget_key(
|
| 296 |
+
model_name,
|
| 297 |
+
remote,
|
| 298 |
+
dataset_source,
|
| 299 |
+
"mask_strategy",
|
| 300 |
),
|
| 301 |
help="Which tokens contribute to the averaged hidden state.",
|
| 302 |
)
|
|
|
|
| 306 |
for persona in selected_personas:
|
| 307 |
qa = list(
|
| 308 |
dataset.get_qa(
|
| 309 |
+
persona.id,
|
| 310 |
+
type=qa_filter_type,
|
| 311 |
+
item_type=qa_filter_item_type,
|
| 312 |
+
scope=qa_filter_scope,
|
| 313 |
)
|
| 314 |
)
|
| 315 |
if qa:
|
|
|
|
| 325 |
return
|
| 326 |
|
| 327 |
max_q = min(len(qa_pairs) for _, qa_pairs in runs)
|
|
|
|
|
|
|
| 328 |
max_questions = st.slider(
|
| 329 |
"Max questions",
|
| 330 |
min_value=1,
|
| 331 |
max_value=max_q,
|
| 332 |
+
value=min(
|
| 333 |
+
max(st.session_state.get(_LAST_MAX_QUESTIONS_KEY, max_q), 1),
|
| 334 |
+
max_q,
|
| 335 |
+
),
|
| 336 |
key=_extract_widget_key(
|
| 337 |
model_name, remote, dataset_source, "max_questions"
|
| 338 |
),
|
uv.lock
CHANGED
|
@@ -122,11 +122,11 @@ wheels = [
|
|
| 122 |
|
| 123 |
[[package]]
|
| 124 |
name = "cachetools"
|
| 125 |
-
version = "7.
|
| 126 |
source = { registry = "https://pypi.org/simple" }
|
| 127 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 128 |
wheels = [
|
| 129 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 130 |
]
|
| 131 |
|
| 132 |
[[package]]
|
|
@@ -511,15 +511,6 @@ wheels = [
|
|
| 511 |
{ url = "https://files.pythonhosted.org/packages/5d/13/ad7d7ca3808a898b4612b6fe93cde56b53f3034dcde235acb1f0e1df24c6/idna-3.13-py3-none-any.whl", hash = "sha256:892ea0cde124a99ce773decba204c5552b69c3c67ffd5f232eb7696135bc8bb3", size = 68629, upload-time = "2026-04-22T16:42:40.909Z" },
|
| 512 |
]
|
| 513 |
|
| 514 |
-
[[package]]
|
| 515 |
-
name = "iniconfig"
|
| 516 |
-
version = "2.3.0"
|
| 517 |
-
source = { registry = "https://pypi.org/simple" }
|
| 518 |
-
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
| 519 |
-
wheels = [
|
| 520 |
-
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
| 521 |
-
]
|
| 522 |
-
|
| 523 |
[[package]]
|
| 524 |
name = "ipython"
|
| 525 |
version = "9.13.0"
|
|
@@ -565,14 +556,14 @@ wheels = [
|
|
| 565 |
|
| 566 |
[[package]]
|
| 567 |
name = "jedi"
|
| 568 |
-
version = "0.
|
| 569 |
source = { registry = "https://pypi.org/simple" }
|
| 570 |
dependencies = [
|
| 571 |
{ name = "parso" },
|
| 572 |
]
|
| 573 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 574 |
wheels = [
|
| 575 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 576 |
]
|
| 577 |
|
| 578 |
[[package]]
|
|
@@ -625,18 +616,17 @@ wheels = [
|
|
| 625 |
|
| 626 |
[[package]]
|
| 627 |
name = "kaleido"
|
| 628 |
-
version = "1.
|
| 629 |
source = { registry = "https://pypi.org/simple" }
|
| 630 |
dependencies = [
|
| 631 |
{ name = "choreographer" },
|
| 632 |
{ name = "logistro" },
|
| 633 |
{ name = "orjson" },
|
| 634 |
{ name = "packaging" },
|
| 635 |
-
{ name = "pytest-timeout" },
|
| 636 |
]
|
| 637 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 638 |
wheels = [
|
| 639 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 640 |
]
|
| 641 |
|
| 642 |
[[package]]
|
|
@@ -797,7 +787,7 @@ wheels = [
|
|
| 797 |
|
| 798 |
[[package]]
|
| 799 |
name = "nnsight"
|
| 800 |
-
version = "0.
|
| 801 |
source = { registry = "https://pypi.org/simple" }
|
| 802 |
dependencies = [
|
| 803 |
{ name = "accelerate" },
|
|
@@ -813,24 +803,24 @@ dependencies = [
|
|
| 813 |
{ name = "transformers" },
|
| 814 |
{ name = "zstandard" },
|
| 815 |
]
|
| 816 |
-
sdist = { url = "https://files.pythonhosted.org/packages/a6/
|
| 817 |
wheels = [
|
| 818 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 819 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 820 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 821 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 822 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 823 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 824 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 825 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 826 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 827 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 828 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 829 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 830 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 831 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 832 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 833 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 834 |
]
|
| 835 |
|
| 836 |
[[package]]
|
|
@@ -1200,16 +1190,16 @@ wheels = [
|
|
| 1200 |
|
| 1201 |
[[package]]
|
| 1202 |
name = "parso"
|
| 1203 |
-
version = "0.8.
|
| 1204 |
source = { registry = "https://pypi.org/simple" }
|
| 1205 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1206 |
wheels = [
|
| 1207 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1208 |
]
|
| 1209 |
|
| 1210 |
[[package]]
|
| 1211 |
name = "persona-data"
|
| 1212 |
-
version = "0.
|
| 1213 |
source = { registry = "https://pypi.org/simple" }
|
| 1214 |
dependencies = [
|
| 1215 |
{ name = "huggingface-hub" },
|
|
@@ -1218,9 +1208,9 @@ dependencies = [
|
|
| 1218 |
{ name = "python-dotenv" },
|
| 1219 |
{ name = "torch" },
|
| 1220 |
]
|
| 1221 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 1222 |
wheels = [
|
| 1223 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 1224 |
]
|
| 1225 |
|
| 1226 |
[[package]]
|
|
@@ -1238,8 +1228,8 @@ dependencies = [
|
|
| 1238 |
|
| 1239 |
[package.metadata]
|
| 1240 |
requires-dist = [
|
| 1241 |
-
{ name = "persona-data", specifier = ">=0.
|
| 1242 |
-
{ name = "persona-vectors",
|
| 1243 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1244 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1245 |
{ name = "streamlit", specifier = ">=1.44.0" },
|
|
@@ -1249,7 +1239,7 @@ requires-dist = [
|
|
| 1249 |
[[package]]
|
| 1250 |
name = "persona-vectors"
|
| 1251 |
version = "0.4.3"
|
| 1252 |
-
source = {
|
| 1253 |
dependencies = [
|
| 1254 |
{ name = "kaleido" },
|
| 1255 |
{ name = "nnsight" },
|
|
@@ -1265,9 +1255,22 @@ dependencies = [
|
|
| 1265 |
{ name = "transformers" },
|
| 1266 |
{ name = "umap-learn" },
|
| 1267 |
]
|
| 1268 |
-
|
| 1269 |
-
|
| 1270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1271 |
]
|
| 1272 |
|
| 1273 |
[[package]]
|
|
@@ -1373,15 +1376,6 @@ wheels = [
|
|
| 1373 |
{ url = "https://files.pythonhosted.org/packages/90/ad/cba91b3bcf04073e4d1655a5c1710ef3f457f56f7d1b79dcc3d72f4dd912/plotly-6.7.0-py3-none-any.whl", hash = "sha256:ac8aca1c25c663a59b5b9140a549264a5badde2e057d79b8c772ae2920e32ff0", size = 9898444, upload-time = "2026-04-09T20:36:39.812Z" },
|
| 1374 |
]
|
| 1375 |
|
| 1376 |
-
[[package]]
|
| 1377 |
-
name = "pluggy"
|
| 1378 |
-
version = "1.6.0"
|
| 1379 |
-
source = { registry = "https://pypi.org/simple" }
|
| 1380 |
-
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
| 1381 |
-
wheels = [
|
| 1382 |
-
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
| 1383 |
-
]
|
| 1384 |
-
|
| 1385 |
[[package]]
|
| 1386 |
name = "prompt-toolkit"
|
| 1387 |
version = "3.0.52"
|
|
@@ -1626,34 +1620,6 @@ wheels = [
|
|
| 1626 |
{ url = "https://files.pythonhosted.org/packages/b2/e6/94145d714402fd5ade00b5661f2d0ab981219e07f7db9bfa16786cdb9c04/pynndescent-0.6.0-py3-none-any.whl", hash = "sha256:dc8c74844e4c7f5cbd1e0cd6909da86fdc789e6ff4997336e344779c3d5538ef", size = 73511, upload-time = "2026-01-08T21:29:57.306Z" },
|
| 1627 |
]
|
| 1628 |
|
| 1629 |
-
[[package]]
|
| 1630 |
-
name = "pytest"
|
| 1631 |
-
version = "9.0.3"
|
| 1632 |
-
source = { registry = "https://pypi.org/simple" }
|
| 1633 |
-
dependencies = [
|
| 1634 |
-
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
| 1635 |
-
{ name = "iniconfig" },
|
| 1636 |
-
{ name = "packaging" },
|
| 1637 |
-
{ name = "pluggy" },
|
| 1638 |
-
{ name = "pygments" },
|
| 1639 |
-
]
|
| 1640 |
-
sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
|
| 1641 |
-
wheels = [
|
| 1642 |
-
{ url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
|
| 1643 |
-
]
|
| 1644 |
-
|
| 1645 |
-
[[package]]
|
| 1646 |
-
name = "pytest-timeout"
|
| 1647 |
-
version = "2.4.0"
|
| 1648 |
-
source = { registry = "https://pypi.org/simple" }
|
| 1649 |
-
dependencies = [
|
| 1650 |
-
{ name = "pytest" },
|
| 1651 |
-
]
|
| 1652 |
-
sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" }
|
| 1653 |
-
wheels = [
|
| 1654 |
-
{ url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" },
|
| 1655 |
-
]
|
| 1656 |
-
|
| 1657 |
[[package]]
|
| 1658 |
name = "python-dateutil"
|
| 1659 |
version = "2.9.0.post0"
|
|
@@ -2564,11 +2530,11 @@ wheels = [
|
|
| 2564 |
|
| 2565 |
[[package]]
|
| 2566 |
name = "wcwidth"
|
| 2567 |
-
version = "0.
|
| 2568 |
source = { registry = "https://pypi.org/simple" }
|
| 2569 |
-
sdist = { url = "https://files.pythonhosted.org/packages/
|
| 2570 |
wheels = [
|
| 2571 |
-
{ url = "https://files.pythonhosted.org/packages/
|
| 2572 |
]
|
| 2573 |
|
| 2574 |
[[package]]
|
|
|
|
| 122 |
|
| 123 |
[[package]]
|
| 124 |
name = "cachetools"
|
| 125 |
+
version = "7.1.1"
|
| 126 |
source = { registry = "https://pypi.org/simple" }
|
| 127 |
+
sdist = { url = "https://files.pythonhosted.org/packages/ff/e2/85f227594656000ff4d8adadae91a21f536d4a84c6c716a86bd6685874be/cachetools-7.1.1.tar.gz", hash = "sha256:27bdf856d68fd3c71c26c01b5edc312124ed427524d1ddb31aa2b7746fe20d4b", size = 40202, upload-time = "2026-05-03T20:00:29.391Z" }
|
| 128 |
wheels = [
|
| 129 |
+
{ url = "https://files.pythonhosted.org/packages/bf/0f/f897abe4ea0a8c408ae65c8c83bffab4936ad65d6032d4fb4cd35bbdc3ee/cachetools-7.1.1-py3-none-any.whl", hash = "sha256:0335cd7a0952d2b22327441fb0628139e234c565559eeb91a8a4ac7551c5353d", size = 16775, upload-time = "2026-05-03T20:00:27.857Z" },
|
| 130 |
]
|
| 131 |
|
| 132 |
[[package]]
|
|
|
|
| 511 |
{ url = "https://files.pythonhosted.org/packages/5d/13/ad7d7ca3808a898b4612b6fe93cde56b53f3034dcde235acb1f0e1df24c6/idna-3.13-py3-none-any.whl", hash = "sha256:892ea0cde124a99ce773decba204c5552b69c3c67ffd5f232eb7696135bc8bb3", size = 68629, upload-time = "2026-04-22T16:42:40.909Z" },
|
| 512 |
]
|
| 513 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 514 |
[[package]]
|
| 515 |
name = "ipython"
|
| 516 |
version = "9.13.0"
|
|
|
|
| 556 |
|
| 557 |
[[package]]
|
| 558 |
name = "jedi"
|
| 559 |
+
version = "0.20.0"
|
| 560 |
source = { registry = "https://pypi.org/simple" }
|
| 561 |
dependencies = [
|
| 562 |
{ name = "parso" },
|
| 563 |
]
|
| 564 |
+
sdist = { url = "https://files.pythonhosted.org/packages/46/b7/a3635f6a2d7cf5b5dd98064fc1d5fbbafcb25477bcea204a3a92145d158b/jedi-0.20.0.tar.gz", hash = "sha256:c3f4ccbd276696f4b19c54618d4fb18f9fc24b0aef02acf704b23f487daa1011", size = 3119416, upload-time = "2026-05-01T23:38:47.814Z" }
|
| 565 |
wheels = [
|
| 566 |
+
{ url = "https://files.pythonhosted.org/packages/9a/93/242e2eab5fe682ffcb8b0084bde703a41d51e17ee0f3a31ff0d9d813620a/jedi-0.20.0-py2.py3-none-any.whl", hash = "sha256:7bdd9c2634f56713299976f4cbd59cb3fa92165cc5e05ea811fb253480728b67", size = 4884812, upload-time = "2026-05-01T23:38:43.919Z" },
|
| 567 |
]
|
| 568 |
|
| 569 |
[[package]]
|
|
|
|
| 616 |
|
| 617 |
[[package]]
|
| 618 |
name = "kaleido"
|
| 619 |
+
version = "1.3.0"
|
| 620 |
source = { registry = "https://pypi.org/simple" }
|
| 621 |
dependencies = [
|
| 622 |
{ name = "choreographer" },
|
| 623 |
{ name = "logistro" },
|
| 624 |
{ name = "orjson" },
|
| 625 |
{ name = "packaging" },
|
|
|
|
| 626 |
]
|
| 627 |
+
sdist = { url = "https://files.pythonhosted.org/packages/e0/64/53eac73d31dbfc3310ee2e87bcac1ae7417427f0fbe3dd800eaf676db324/kaleido-1.3.0.tar.gz", hash = "sha256:5e0378a7475e98852773deeb6483dee91f8aa7b364dde7b5f2b3622cb468a3e6", size = 68938, upload-time = "2026-05-04T19:45:28.932Z" }
|
| 628 |
wheels = [
|
| 629 |
+
{ url = "https://files.pythonhosted.org/packages/9e/b9/a6d8bb7d228940f01885bd9f327ab7f9d366a9be775c4bf366bf9d9477ae/kaleido-1.3.0-py3-none-any.whl", hash = "sha256:52714dfd38e8f2a114831826200c40bb10d0ca0c11d4272f3f48ad499cd8f8ea", size = 55580, upload-time = "2026-05-04T19:45:27.483Z" },
|
| 630 |
]
|
| 631 |
|
| 632 |
[[package]]
|
|
|
|
| 787 |
|
| 788 |
[[package]]
|
| 789 |
name = "nnsight"
|
| 790 |
+
version = "0.7.0"
|
| 791 |
source = { registry = "https://pypi.org/simple" }
|
| 792 |
dependencies = [
|
| 793 |
{ name = "accelerate" },
|
|
|
|
| 803 |
{ name = "transformers" },
|
| 804 |
{ name = "zstandard" },
|
| 805 |
]
|
| 806 |
+
sdist = { url = "https://files.pythonhosted.org/packages/a6/9e/76fd632deef926599d3d16c02fd736cf5f8465f49d708d5be13cd6638484/nnsight-0.7.0.tar.gz", hash = "sha256:5bc6678d567ecc5590b823b7bbab2c310c69f8dda6f4064684c6d488563eebee", size = 1912851, upload-time = "2026-05-05T05:40:50.084Z" }
|
| 807 |
wheels = [
|
| 808 |
+
{ url = "https://files.pythonhosted.org/packages/a8/4d/634c1f08e5d34d9d145f81a5b872c5d2a2fcae714c77450e4f14660670b8/nnsight-0.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d988fd477abcc15d413012aa1b6eaef667da81b2d61c86c528b2a154adf259c5", size = 265015, upload-time = "2026-05-05T05:40:30.453Z" },
|
| 809 |
+
{ url = "https://files.pythonhosted.org/packages/f5/fa/aee0a17356528f6cf4e0deaf3af2f0a88f660f793fb1fbb10f363682ef6a/nnsight-0.7.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:000f5285a754d63fafc0a677e4518192f16c5fa3f2af347a46e2f27c793ef2b2", size = 272487, upload-time = "2026-05-05T05:40:31.545Z" },
|
| 810 |
+
{ url = "https://files.pythonhosted.org/packages/85/77/c59d1e983936e77b2d9fcd4694461f78c66d0af95982b42ffc365fa0c224/nnsight-0.7.0-cp312-cp312-win32.whl", hash = "sha256:2601ccc8a352c09f6f17576ea07943ed29eb3c8d1bba2ec174fbb554942a3821", size = 267080, upload-time = "2026-05-05T05:40:32.909Z" },
|
| 811 |
+
{ url = "https://files.pythonhosted.org/packages/f1/1b/d3d60dfbbd1167320ec7f02cf06f4c589a3d2d462c3748f45bba7eaa17eb/nnsight-0.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:408e2ae8b0da3af0b71f3c8dbb2301b671801205f7a3f95ede7f3f6672c5a8a5", size = 267491, upload-time = "2026-05-05T05:40:34.469Z" },
|
| 812 |
+
{ url = "https://files.pythonhosted.org/packages/d3/dd/2e2f800876f00a0f38e7ef5e536c53bacf9ac7775efc8b337ba117a4459c/nnsight-0.7.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:753acc23dcc97ed32bfc3f650400f049e94621dc56cd9a372dbbfef868f98753", size = 265005, upload-time = "2026-05-05T05:40:36.066Z" },
|
| 813 |
+
{ url = "https://files.pythonhosted.org/packages/0a/e0/335c94482b3c739994cdabaef02a12458c806cbe7ca5883d03380f6fdf8a/nnsight-0.7.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b59ff8a2d41ed482313fa21f62e5a035068bd68828201cccb928c6e0ded2b4f", size = 272540, upload-time = "2026-05-05T05:40:37.293Z" },
|
| 814 |
+
{ url = "https://files.pythonhosted.org/packages/93/18/65f473ae3147156dec3e6d30b5fa386491474964db1af24f478ef467718d/nnsight-0.7.0-cp313-cp313-win32.whl", hash = "sha256:ed4c01b7cc882e3699de56f1ee77cedd7f720797bb3670a690ec1edeeb34f9bb", size = 267067, upload-time = "2026-05-05T05:40:38.35Z" },
|
| 815 |
+
{ url = "https://files.pythonhosted.org/packages/67/0f/eb2a6cdff12abcd7264b6e9992c9fe907cfd4c445ea3df1b3aa4165061ff/nnsight-0.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:8cd6b5aac59175a6c436fb4fca713a6a58fc85407f9b4f938a819a467c17108e", size = 267492, upload-time = "2026-05-05T05:40:39.66Z" },
|
| 816 |
+
{ url = "https://files.pythonhosted.org/packages/ba/76/6148cd42a1323d218c64ed122f4da775233fb2e44edef0341fd8aa976c7e/nnsight-0.7.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:baaa6409717a8e95cadf7d71fbb376103524f9137a0cc0d8ea55c2656add4326", size = 265011, upload-time = "2026-05-05T05:40:40.693Z" },
|
| 817 |
+
{ url = "https://files.pythonhosted.org/packages/29/a7/b1ba47acbe95fafc854170390d29ef3dd0bdaac44576f1554f25e53c8309/nnsight-0.7.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:22f172565dab55583fb2eb0d615697af76e2c22f12644309738bf7174090ccc6", size = 272569, upload-time = "2026-05-05T05:40:42.003Z" },
|
| 818 |
+
{ url = "https://files.pythonhosted.org/packages/d1/38/05845f3fc29631450efa5d9b8e17aaff37d2ae4e6eee8f8e38ce0ba5b162/nnsight-0.7.0-cp314-cp314-win32.whl", hash = "sha256:6b360febd90a1c5330c7e881a11aaa51f76f04e4876fe045c7f0481ded9e0398", size = 267182, upload-time = "2026-05-05T05:40:43.146Z" },
|
| 819 |
+
{ url = "https://files.pythonhosted.org/packages/ef/00/9300ab917f759d2e4cf2c4ceb23b76827fe5e77379c1933f6b67c0bb680c/nnsight-0.7.0-cp314-cp314-win_amd64.whl", hash = "sha256:65a894add8c93b38d781affac4bf10a58050ed3e46497b2038febf9d8633ce21", size = 267606, upload-time = "2026-05-05T05:40:44.303Z" },
|
| 820 |
+
{ url = "https://files.pythonhosted.org/packages/81/36/1f698a62ae01187dd6c4e63a561220ec7999ad489fea796df3733d700d36/nnsight-0.7.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:171a0a1a13f3f8ea8ddc06c98eb3f924c7a8dfcc108be1407c7c071504f0370f", size = 265117, upload-time = "2026-05-05T05:40:45.411Z" },
|
| 821 |
+
{ url = "https://files.pythonhosted.org/packages/ec/2e/0df50b4b98896b02fc76dd08997aab977eef96e0fb3a2646c5c695a3389f/nnsight-0.7.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:eeed67fa7d8ef03a1196533ec341dbe1dfb40bb7615c8dee06a22009eba1be0f", size = 273366, upload-time = "2026-05-05T05:40:46.648Z" },
|
| 822 |
+
{ url = "https://files.pythonhosted.org/packages/7f/da/f47e1102a9216dc487608ce5b7bc058afb0ea9f3f087e85f83ed877d156d/nnsight-0.7.0-cp314-cp314t-win32.whl", hash = "sha256:7140bf2cab8d4bff46e8f6a1a942185fcdf8f18f62ddb173ca7fe000b6e580cc", size = 267290, upload-time = "2026-05-05T05:40:47.956Z" },
|
| 823 |
+
{ url = "https://files.pythonhosted.org/packages/bf/c9/749370592057f790fc60f2542848b09542ca8eb75b4583e1dde01e7e40c9/nnsight-0.7.0-cp314-cp314t-win_amd64.whl", hash = "sha256:544f7e32e277738ca10e85cdd3f07d8965ec3388854fdf14373308580f8fd0f4", size = 267722, upload-time = "2026-05-05T05:40:49.036Z" },
|
| 824 |
]
|
| 825 |
|
| 826 |
[[package]]
|
|
|
|
| 1190 |
|
| 1191 |
[[package]]
|
| 1192 |
name = "parso"
|
| 1193 |
+
version = "0.8.7"
|
| 1194 |
source = { registry = "https://pypi.org/simple" }
|
| 1195 |
+
sdist = { url = "https://files.pythonhosted.org/packages/30/4b/90c937815137d43ce71ba043cd3566221e9df6b9c805f24b5d138c9d40a7/parso-0.8.7.tar.gz", hash = "sha256:eaaac4c9fdd5e9e8852dc778d2d7405897ec510f2a298071453e5e3a07914bb1", size = 401824, upload-time = "2026-05-01T23:13:02.138Z" }
|
| 1196 |
wheels = [
|
| 1197 |
+
{ url = "https://files.pythonhosted.org/packages/99/5d/8268b644392ee874ee82a635cd0df1773de230bde356c38de28e298392cc/parso-0.8.7-py2.py3-none-any.whl", hash = "sha256:a8926eb2a1b915486941fdbd31e86a4baf88fe8c210f25f2f35ecec5b574ca1c", size = 107025, upload-time = "2026-05-01T23:12:58.867Z" },
|
| 1198 |
]
|
| 1199 |
|
| 1200 |
[[package]]
|
| 1201 |
name = "persona-data"
|
| 1202 |
+
version = "0.3.4"
|
| 1203 |
source = { registry = "https://pypi.org/simple" }
|
| 1204 |
dependencies = [
|
| 1205 |
{ name = "huggingface-hub" },
|
|
|
|
| 1208 |
{ name = "python-dotenv" },
|
| 1209 |
{ name = "torch" },
|
| 1210 |
]
|
| 1211 |
+
sdist = { url = "https://files.pythonhosted.org/packages/c5/03/25b6dbcbf4be440cf85e5dc5ee30eff69a5c2db638c5452b50932e758141/persona_data-0.3.4.tar.gz", hash = "sha256:c887b53a6fcddd80bac6485f6639902456cc7fd878af89ba9b8102d02aaa335e", size = 8533, upload-time = "2026-05-04T15:27:17.057Z" }
|
| 1212 |
wheels = [
|
| 1213 |
+
{ url = "https://files.pythonhosted.org/packages/c7/28/b45feb5186615b8a4ff647eaaf35739a35b1afe3fd35b3d09e2734475b2c/persona_data-0.3.4-py3-none-any.whl", hash = "sha256:666c34b8bd7588bfe2d55a62737cd7d87e0d663fbead3a4458d284f3e660bab2", size = 11178, upload-time = "2026-05-04T15:27:17.749Z" },
|
| 1214 |
]
|
| 1215 |
|
| 1216 |
[[package]]
|
|
|
|
| 1228 |
|
| 1229 |
[package.metadata]
|
| 1230 |
requires-dist = [
|
| 1231 |
+
{ name = "persona-data", specifier = ">=0.3.4" },
|
| 1232 |
+
{ name = "persona-vectors", editable = "../persona-vectors" },
|
| 1233 |
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1234 |
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1235 |
{ name = "streamlit", specifier = ">=1.44.0" },
|
|
|
|
| 1239 |
[[package]]
|
| 1240 |
name = "persona-vectors"
|
| 1241 |
version = "0.4.3"
|
| 1242 |
+
source = { editable = "../persona-vectors" }
|
| 1243 |
dependencies = [
|
| 1244 |
{ name = "kaleido" },
|
| 1245 |
{ name = "nnsight" },
|
|
|
|
| 1255 |
{ name = "transformers" },
|
| 1256 |
{ name = "umap-learn" },
|
| 1257 |
]
|
| 1258 |
+
|
| 1259 |
+
[package.metadata]
|
| 1260 |
+
requires-dist = [
|
| 1261 |
+
{ name = "kaleido", specifier = ">=1.0.0" },
|
| 1262 |
+
{ name = "nnsight", specifier = ">=0.6.1" },
|
| 1263 |
+
{ name = "nnterp", specifier = ">=1.3.0" },
|
| 1264 |
+
{ name = "persona-data", specifier = ">=0.3.4" },
|
| 1265 |
+
{ name = "plotly", specifier = ">=6.6.0" },
|
| 1266 |
+
{ name = "python-dotenv", specifier = ">=1.2.2" },
|
| 1267 |
+
{ name = "safetensors", specifier = ">=0.7.0" },
|
| 1268 |
+
{ name = "scikit-learn", specifier = ">=1.6.0" },
|
| 1269 |
+
{ name = "torch", specifier = ">=2.10.0" },
|
| 1270 |
+
{ name = "torchvision", specifier = ">=0.26.0" },
|
| 1271 |
+
{ name = "tqdm", specifier = ">=4.67.3" },
|
| 1272 |
+
{ name = "transformers", specifier = ">=5.2.0" },
|
| 1273 |
+
{ name = "umap-learn", specifier = ">=0.5.7" },
|
| 1274 |
]
|
| 1275 |
|
| 1276 |
[[package]]
|
|
|
|
| 1376 |
{ url = "https://files.pythonhosted.org/packages/90/ad/cba91b3bcf04073e4d1655a5c1710ef3f457f56f7d1b79dcc3d72f4dd912/plotly-6.7.0-py3-none-any.whl", hash = "sha256:ac8aca1c25c663a59b5b9140a549264a5badde2e057d79b8c772ae2920e32ff0", size = 9898444, upload-time = "2026-04-09T20:36:39.812Z" },
|
| 1377 |
]
|
| 1378 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1379 |
[[package]]
|
| 1380 |
name = "prompt-toolkit"
|
| 1381 |
version = "3.0.52"
|
|
|
|
| 1620 |
{ url = "https://files.pythonhosted.org/packages/b2/e6/94145d714402fd5ade00b5661f2d0ab981219e07f7db9bfa16786cdb9c04/pynndescent-0.6.0-py3-none-any.whl", hash = "sha256:dc8c74844e4c7f5cbd1e0cd6909da86fdc789e6ff4997336e344779c3d5538ef", size = 73511, upload-time = "2026-01-08T21:29:57.306Z" },
|
| 1621 |
]
|
| 1622 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1623 |
[[package]]
|
| 1624 |
name = "python-dateutil"
|
| 1625 |
version = "2.9.0.post0"
|
|
|
|
| 2530 |
|
| 2531 |
[[package]]
|
| 2532 |
name = "wcwidth"
|
| 2533 |
+
version = "0.7.0"
|
| 2534 |
source = { registry = "https://pypi.org/simple" }
|
| 2535 |
+
sdist = { url = "https://files.pythonhosted.org/packages/2c/ee/afaf0f85a9a18fe47a67f1e4422ed6cf1fe642f0ae0a2f81166231303c52/wcwidth-0.7.0.tar.gz", hash = "sha256:90e3a7ea092341c44b99562e75d09e4d5160fe7a3974c6fb842a101a95e7eed0", size = 182132, upload-time = "2026-05-02T16:04:12.653Z" }
|
| 2536 |
wheels = [
|
| 2537 |
+
{ url = "https://files.pythonhosted.org/packages/41/52/e465037f5375f43533d1a80b6923955201596a99142ed524d77b571a1418/wcwidth-0.7.0-py3-none-any.whl", hash = "sha256:5d69154c429a82910e241c738cd0e2976fac8a2dd47a1a805f4afed1c0f136f2", size = 110825, upload-time = "2026-05-02T16:04:11.033Z" },
|
| 2538 |
]
|
| 2539 |
|
| 2540 |
[[package]]
|