# persona-ui/tabs/probe.py
# Jac-Zac
# Big refactoring
# b279884
"""Probing tab: run linear-probe sweeps over persona vectors.
UX mirrors the Analysis tab (source -> mask -> variant -> personas), but
the action is a probe sweep and the output is a metric-over-layer curve,
the best-layer summary, and optional controls (shuffled-label selectivity,
save artifact).
The probe primitives all live in ``persona_vectors.probes``; this file
is a thin Streamlit wrapper around them.
"""
from __future__ import annotations
import streamlit as st
from persona_vectors.analysis import LayeredSamples
from persona_vectors.attributes import attribute_display_label
from persona_vectors.extraction import MaskStrategy
from persona_vectors.plots import plot_metric_comparison, plot_metric_over_layers
from persona_vectors.probes import (
AttributeLabels,
default_probe_kinds,
infer_probe_task,
layer_matrix,
save_probe_artifact,
shuffle_label_baseline,
)
from tabs.probe_sweep import SweepInputs, cached_sweep
from utils.analysis_metadata import (
synth_persona_attribute_names,
synth_persona_dataset_cached,
)
from utils.analysis_sources import (
Store,
available_variants,
persona_names_cached,
personas_cached,
store_cache_parts,
store_layers_cached,
)
from utils.controls import render_mask_strategy_select
from utils.helpers import widget_key
from utils.source_controls import render_source_select, render_store_select
# ---------------------------------------------------------------------------
# Constants and config
# ---------------------------------------------------------------------------
# Root directory pre-filled in the "Save best probe" form.
_DEFAULT_OUTPUT_DIR = "artifacts/probes"
# Forwarded to the sweep as ``min_class_count``.
_MIN_CLASS_COUNT = 5

# Per-task primary metric: drives the "best layer" choice and the first plot.
# The three classification-style tasks share one metric; only numeric differs.
_PRIMARY_METRIC = {
    **dict.fromkeys(("binary", "categorical", "ordinal"), "balanced_accuracy"),
    "numeric": "r2",
}

# Per-task secondary metric for the optional second plot (``None`` = no plot).
_SECONDARY_METRIC = {
    **dict.fromkeys(("binary", "categorical")),
    **dict.fromkeys(("ordinal", "numeric"), "mae"),
}
def _select_variant(store: Store, mask_strategy: MaskStrategy) -> str | None:
    """Render the variant picker and return the chosen variant.

    Shows an info message and returns ``None`` when ``available_variants``
    yields nothing for this store / mask strategy.
    """
    options = available_variants(store, mask_strategy)
    if not options:
        st.info("No variants with saved vectors for this selection.")
        return None

    # Start on the previously chosen variant when it is still on offer.
    remembered = st.session_state.get("probe:variant", options[0])
    start_index = options.index(remembered) if remembered in options else 0
    return st.selectbox(
        "Variant",
        options=options,
        index=start_index,
        key="probe:variant",
    )
def _select_personas(
    store: Store, variant: str, mask_strategy: MaskStrategy
) -> list[str]:
    """Choose which of the variant's personas take part in the probe.

    With ten or fewer personas all of them are used and a warning says so;
    otherwise a slider (minimum ten) picks the size of the leading slice.
    The chosen count is mirrored into
    ``st.session_state["probe:persona_count"]``.

    Returns:
        The selected persona ids, or an empty list (after an info message)
        when fewer than two personas are available.
    """
    source, location, model_name = store_cache_parts(store)
    variant_key = (variant,)
    # NOTE(review): nothing is filtered here although the messages below say
    # "non-assistant" -- presumably ``personas_cached`` already leaves the
    # assistant persona out; confirm.
    candidates = personas_cached(
        source, location, model_name, mask_strategy.value, variant_key
    )
    if not candidates:
        st.info("No personas found for this variant.")
        return []

    total = len(candidates)
    if total < 2:
        st.info("At least two non-assistant personas are needed for probing.")
        return []

    floor = min(10, total)
    use_everything = floor == total
    if use_everything:
        count = total
        st.warning(
            f"Only {count} non-assistant personas are available; using all of them."
        )
    else:
        remembered = st.session_state.get("probe:persona_count", total)
        count = st.slider(
            "Personas",
            min_value=floor,
            max_value=total,
            value=min(total, max(floor, remembered)),
            key="probe:persona_count_slider",
        )
    st.session_state["probe:persona_count"] = count

    chosen = candidates if use_everything else candidates[:count]
    # Return value intentionally unused -- presumably this warms the name
    # cache for the selected personas; confirm.
    persona_names_cached(
        source,
        location,
        model_name,
        mask_strategy.value,
        variant_key,
        tuple(chosen),
    )

    if use_everything:
        st.caption(f"Probing {len(chosen)} non-assistant personas.")
    else:
        st.caption(f"Probing {len(chosen)} of {total} non-assistant personas.")
    return chosen
# ---------------------------------------------------------------------------
# Probe config UI
# ---------------------------------------------------------------------------
@st.cache_data(show_spinner=False)
def _attribute_tasks() -> dict[str, str]:
    """Map every synth-persona attribute name to its inferred probe task."""
    dataset = synth_persona_dataset_cached()
    task_by_attribute: dict[str, str] = {}
    for attribute in synth_persona_attribute_names():
        task_by_attribute[attribute] = infer_probe_task(dataset, attribute)
    return task_by_attribute
def _select_attributes() -> list[str]:
    """Render the attribute multi-select, restricted to a single task type.

    The first selected attribute determines the task; while anything is
    selected only attributes with that same task are offered. An empty
    selection offers every attribute again. On first use the selection is
    seeded with ``"sex"`` when present, otherwise with the first attribute.
    """
    state_key = "probe:attributes"
    dataset = synth_persona_dataset_cached()
    task_of = _attribute_tasks()
    names = list(synth_persona_attribute_names())

    if state_key not in st.session_state:
        st.session_state[state_key] = ["sex"] if "sex" in names else names[:1]

    current = st.session_state[state_key]
    if current:
        locked_task = task_of[current[0]]
        choices = [name for name in names if task_of[name] == locked_task]
    else:
        choices = names

    def _display(name: str):
        return attribute_display_label(dataset, name)

    return st.multiselect(
        "Attributes to probe",
        options=choices,
        format_func=_display,
        key=state_key,
        help="Pick one or more attributes of the same task type. They are "
        "overlaid in one figure. Remove all to switch to a different task type.",
    )
def _select_probe_kinds(task: str) -> list[str]:
    """Return the probe families to fit for ``task``.

    The multi-select is rendered only when the task offers at least two
    kinds; an empty selection falls back to every available kind.
    """
    kinds = list(default_probe_kinds(task))  # type: ignore[arg-type]
    if len(kinds) <= 1:
        return kinds

    chosen = st.multiselect(
        "Probe kinds to fit",
        options=kinds,
        default=kinds,
        key=f"probe:kinds:{task}",
        help="Which probe families to fit at each layer. Defaults to all "
        "available for this task.",
    )
    if chosen:
        return chosen
    return kinds
def _select_pca_components() -> int | None:
    """Return the PCA component count, or ``None`` when the toggle is off."""
    compare_with_pca = st.toggle(
        "Add PCA-compressed comparison",
        value=False,
        key="probe:use_pca",
        help="Runs the normal full-activation sweep and a second sweep where "
        "PCA is fit on the train split only before probing.",
    )
    if compare_with_pca:
        n_components = st.number_input(
            "PCA components",
            min_value=2,
            max_value=512,
            value=10,
            step=1,
            key="probe:pca_components",
        )
        return int(n_components)
    return None
def _select_layers(num_layers: int) -> list[int]:
    """Return the layer indices to sweep.

    With the "fast" toggle on (the default) this is the de-duplicated,
    sorted set of the first layer, the three quartile layers and the last
    layer; with it off, every index in ``range(num_layers)``.
    """
    use_fast_set = st.toggle(
        "Fast layer set (5 evenly-spaced)",
        value=True,
        key="probe:fast",
        help="Off = sweep every layer. Slow on big models.",
    )
    if use_fast_set:
        quartiles = [(quarter * num_layers) // 4 for quarter in (1, 2, 3)]
        # A set collapses duplicates that appear on very shallow models.
        return sorted({0, *quartiles, num_layers - 1})
    return list(range(num_layers))
# ---------------------------------------------------------------------------
# Sweep + display
# ---------------------------------------------------------------------------
def _metric_or_nan(row: dict[str, object], metric: str) -> float:
    """Return ``row[metric]`` as a float, or NaN when it is absent or ``None``."""
    value = row.get(metric)
    return float("nan") if value is None else float(value)  # type: ignore[arg-type]


def _show_sweep(
    rows_by_label: dict[str, list[dict[str, object]]],
    per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]],
    attributes: tuple[str, ...],
    task: str,
    inputs: SweepInputs,
) -> None:
    """Render sweep plots, the best-layer summary and the follow-up controls.

    Args:
        rows_by_label: Sweep rows keyed by feature label (``"full"`` or
            ``"pca<N>"``). Each row is read for ``"attribute"``, ``"layer"``,
            ``"probe_kind"``, the task's primary metric and its
            ``baseline_<metric>`` companion.
        per_attr: Labels and samples per attribute, looked up for the
            attribute that owns the best row.
        attributes: Attributes that were swept.
        task: Task name; selects the primary and secondary metric.
        inputs: The sweep configuration the rows were produced with.

    Rows whose primary metric is missing, ``None`` or NaN never count as
    "best". When no usable row exists a warning is shown and the controls
    are skipped.
    """
    import math  # function-local so the file's import header stays untouched

    if not rows_by_label:
        # Guards the ``next(iter(...))`` fallback below against StopIteration.
        st.warning("Sweep produced no results to display.")
        return

    primary = _PRIMARY_METRIC[task]
    secondary = _SECONDARY_METRIC.get(task)

    # Prefer the PCA rows when a PCA comparison was requested; otherwise (or
    # when that label is absent/empty) use the first label's rows.
    primary_label = (
        f"pca{inputs.n_pca_components}" if inputs.n_pca_components else "full"
    )
    rows = rows_by_label.get(primary_label) or next(iter(rows_by_label.values()))
    multi_attr = len(attributes) > 1
    comparing = len(rows_by_label) > 1 or multi_attr

    def _plot(metric: str):
        if comparing:
            return plot_metric_comparison(
                rows_by_label, list(attributes), metric=metric
            )
        return plot_metric_over_layers(rows, attributes[0], metric=metric)

    st.plotly_chart(_plot(primary), width="stretch")
    if secondary is not None:
        st.plotly_chart(_plot(secondary), width="stretch")

    # Error-style metrics are better when lower; flip the sign so ``max`` works.
    sign = 1.0 if primary != "mae" else -1.0

    def _best_row(label_rows: list[dict[str, object]]) -> dict[str, object] | None:
        scored = ((sign * _metric_or_nan(row, primary), row) for row in label_rows)
        # NaN compares false against everything, which would make ``max``
        # depend on row order; drop those rows before ranking.
        usable = [(score, row) for score, row in scored if not math.isnan(score)]
        if not usable:
            return None
        return max(usable, key=lambda pair: pair[0])[1]

    best = _best_row(rows)
    if best is None:
        st.warning(f"No rows reported {primary!r}; can't pick a best layer.")
        return

    baseline_key = f"baseline_{primary}"

    if comparing:
        summary_rows = []
        for label, label_rows in rows_by_label.items():
            for attribute in attributes:
                attr_rows = [
                    row for row in label_rows if row.get("attribute") == attribute
                ]
                label_best = _best_row(attr_rows)
                if label_best is None:
                    continue
                summary_row: dict[str, object] = {}
                if multi_attr:
                    summary_row["attribute"] = attribute
                summary_row.update(
                    {
                        "features": label,
                        "best_layer": label_best["layer"],
                        "probe": label_best["probe_kind"],
                        primary: round(float(label_best[primary]), 3),  # type: ignore[arg-type]
                        baseline_key: round(
                            _metric_or_nan(label_best, baseline_key), 3
                        ),
                    }
                )
                summary_rows.append(summary_row)
        if summary_rows:
            st.dataframe(summary_rows, width="stretch", hide_index=True)

    feature_desc = f" · pca{inputs.n_pca_components}" if inputs.n_pca_components else ""
    best_attr = str(best["attribute"])
    labels, samples = per_attr[best_attr]

    if multi_attr:
        # The per-attribute summary table above already covers every result;
        # a single "best" card would only show one attribute, so skip it and
        # just say which one the controls below operate on.
        st.caption(f"Controls below use the best result: **{best_attr}**.")
    else:
        cols = st.columns([1, 1.2, 1.8])
        cols[0].metric("Best layer", best["layer"])
        cols[1].metric(
            f"Best {primary}",
            f"{float(best[primary]):.3f}",  # type: ignore[arg-type]
            delta=f"baseline {_metric_or_nan(best, baseline_key):.3f}",
            delta_color="off",
        )
        cols[2].metric("Probe", f"{best['probe_kind']}{feature_desc}")

    _render_selectivity_control(best, labels, samples, task, inputs)
    _render_save_artifact(best, labels, samples, task, inputs)
def _render_selectivity_control(
    best: dict[str, object],
    labels: AttributeLabels,
    samples: LayeredSamples,
    task: str,
    inputs: SweepInputs,
) -> None:
    """Offer a shuffled-label control run for the best probe.

    Nothing is rendered for the ``"numeric"`` task. Otherwise an expander
    holds a repeat-count slider and a button; pressing the button refits
    the best probe kind on shuffled labels at the best layer and shows the
    real balanced accuracy beside the shuffled mean and standard deviation.
    """
    if task == "numeric":
        return  # selectivity control is classification-only

    with st.expander("Selectivity control (shuffled labels)"):
        st.caption(
            "Trains the same probe on shuffled labels. The gap between the real-label "
            "score and this shuffled score is the probe's *selectivity* "
            "(Hewitt & Liang 2019). High shuffled scores mean the probe is reading "
            "dataset artifacts, not the property."
        )
        repeats = st.slider(
            "Shuffle repeats",
            min_value=3,
            max_value=15,
            value=5,
            key="probe:shuffle_repeats",
        )
        if not st.button("Run selectivity control", key="probe:run_shuffle"):
            return

        best_layer = int(best["layer"])
        with st.spinner("Running shuffled-label control..."):
            features = layer_matrix(samples, best_layer)
            shuffled = shuffle_label_baseline(
                features,
                labels.y,
                task=task,  # type: ignore[arg-type]
                layer=best_layer,
                probe_kind=best["probe_kind"],  # type: ignore[arg-type]
                n_pca_components=inputs.n_pca_components,
                n_repeats=repeats,
            )

        real_col, shuffled_col = st.columns(2)
        real_col.metric(
            "Real balanced acc.",
            f"{float(best['balanced_accuracy']):.3f}",
        )
        shuffled_col.metric(
            "Shuffled balanced acc.",
            f"{shuffled['balanced_accuracy_mean']:.3f}",
            delta=f"+/- {shuffled['balanced_accuracy_std']:.3f}",
            delta_color="off",
        )
def _render_save_artifact(
    best: dict[str, object],
    labels: AttributeLabels,
    samples: LayeredSamples,
    task: str,
    inputs: SweepInputs,
) -> None:
    """Render the "save best probe" form and write the artifact on request.

    The model, variant and mask fields are pre-filled from ``inputs``. A
    field keeps following the sweep's value until the user types something
    else into it. Pressing "Save" passes the best layer's matrix, the
    labels and the form values to ``save_probe_artifact``.
    """

    def _sync_default(state_key: str, default: str) -> str:
        # Overwrite the stored value only when it is unset or still equal to
        # the default we wrote last time, i.e. the user has not edited it.
        marker_key = f"{state_key}:default"
        last_default = st.session_state.get(marker_key)
        stored = st.session_state.get(state_key)
        if stored is None or stored == last_default:
            st.session_state[state_key] = default
        st.session_state[marker_key] = default
        return st.session_state[state_key]

    with st.expander("Save best probe (loadable by the Chat tab)"):
        output_dir = st.text_input(
            "Output directory",
            value=st.session_state.get("probe:output_dir", _DEFAULT_OUTPUT_DIR),
            key="probe:output_dir",
            help="Probe artifacts will be written under this root.",
        )

        # (session key, widget label, value taken from the sweep)
        path_fields = (
            ("probe:save_model", "Model name (for the artifact path)", inputs.model_name),
            ("probe:save_variant", "Variant", inputs.variant),
            ("probe:save_mask", "Mask strategy", inputs.mask_value),
        )
        entered: dict[str, str] = {}
        for state_key, widget_label, sweep_value in path_fields:
            _sync_default(state_key, sweep_value)
            entered[state_key] = st.text_input(widget_label, key=state_key)

        if not st.button("Save", key="probe:save_artifact"):
            return

        best_layer = int(best["layer"])
        directory = save_probe_artifact(
            X=layer_matrix(samples, best_layer),
            y=labels.y,
            labels=labels,
            task=task,  # type: ignore[arg-type]
            probe_kind=best["probe_kind"],  # type: ignore[arg-type]
            n_pca_components=inputs.n_pca_components,
            layer=best_layer,
            model_name=entered["probe:save_model"],
            variant=entered["probe:save_variant"],
            mask_strategy=entered["probe:save_mask"],
            output_dir=output_dir,
            metrics=best,
        )
        st.success(f"Saved to `{directory}`")
        st.caption(
            "Wrote `probe.json` + `weights.safetensors`. "
            "The Chat tab can load the saved `probe.json` artifact."
        )
# ---------------------------------------------------------------------------
# Tab entry point
# ---------------------------------------------------------------------------
def render_probing_tab() -> None:
    """Render the Probing tab: pick a source, configure a sweep, show results.

    Flow:
      1. "Source" expander: mask strategy, store, variant and personas.
      2. "Probe configuration" expander: attributes (which fix the task),
         probe kinds, optional PCA comparison and the layers to sweep.
      3. "Run sweep" calls ``cached_sweep`` and stores its result, together
         with the inputs that produced it, under
         ``st.session_state["probe:last_result"]``.
      4. Whatever result is stored is rendered with ``_show_sweep``.

    The function returns early, after an info message, whenever a step has
    nothing to offer (no variant, no personas, no attributes, no layers).
    """
    selected_source = render_source_select(widget_scope="probe")

    with st.expander("Source", expanded=True):
        mask_strategy = render_mask_strategy_select(
            key=widget_key("probe", "mask_strategy"),
            last_key="probe:last_mask_strategy",
            remember_key="source:last_mask_strategy",
            help_text="Which extracted activation set to load.",
        )
        store = render_store_select(
            selected_source,
            mask_strategy,
            state_prefix="probe",
            widget_scope="probe",
            artifacts_root_key="probe:local_root",
        )
        variant = _select_variant(store, mask_strategy)
        if variant is None:
            return
        persona_ids = _select_personas(store, variant, mask_strategy)
        if not persona_ids:
            return

    with st.expander("Probe configuration", expanded=True):
        attributes = _select_attributes()
        if not attributes:
            st.info("Select at least one attribute to probe.")
            return
        task = _attribute_tasks()[attributes[0]]
        st.caption(f"Inferred task: **{task}**")
        probe_kinds = _select_probe_kinds(task)
        n_pca_components = _select_pca_components()

        source, location, model_name = store_cache_parts(store)
        available_layers = store_layers_cached(
            source,
            location,
            model_name,
            mask_strategy.value,
            (variant,),
            tuple(persona_ids),
        )
        if not available_layers:
            st.info("No layers found for the selected personas.")
            return

        # The picker works on 0..max(layer); stored layers need not be
        # contiguous, so keep only indices that actually exist. The highest
        # layer is always picked and always present, so this is never empty.
        present = set(available_layers)
        num_layers = max(present) + 1
        layers = [
            layer for layer in _select_layers(num_layers) if layer in present
        ]

    inputs = SweepInputs(
        source=source,
        location=location,
        model_name=model_name,
        mask_value=mask_strategy.value,
        variant=variant,
        persona_ids=tuple(persona_ids),
        attributes=tuple(attributes),
        task=task,
        probe_kinds=tuple(probe_kinds),
        n_pca_components=n_pca_components,
        layers=tuple(layers),
        min_class_count=_MIN_CLASS_COUNT,
        seed=0,
    )

    state_key = "probe:last_result"
    if st.button("Run sweep", type="primary", key="probe:run"):
        with st.spinner("Evaluating probes across layers..."):
            try:
                sweep, per_attr = cached_sweep(inputs)
            except Exception as exc:
                # UI boundary: report any failure in the page rather than
                # crashing the tab, and drop the now-outdated result.
                st.error(f"Sweep failed: {exc}")
                st.session_state.pop(state_key, None)
                return
        st.session_state[state_key] = (sweep, per_attr, inputs)

    if state_key not in st.session_state:
        return

    saved_result = st.session_state[state_key]
    if len(saved_result) != 3:
        # Stale shape from a previous code version -- drop it.
        st.session_state.pop(state_key, None)
        return

    sweep, per_attr, result_inputs = saved_result
    # Render with the inputs that produced the result, not the current
    # widget values, which may have changed since the sweep ran.
    _show_sweep(
        sweep,
        per_attr,
        result_inputs.attributes,
        result_inputs.task,
        result_inputs,
    )