| """Probing tab: run linear-probe sweeps over persona vectors. |
| |
| UX mirrors the Analysis tab (source -> mask -> variant -> personas), but |
| the action is a probe sweep and the output is a metric-over-layer curve, |
| the best-layer summary, and optional controls (shuffled-label selectivity, |
| save artifact). |
| |
| The probe primitives all live in ``persona_vectors.probes``; this file |
| is a thin Streamlit wrapper around them. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import streamlit as st |
| from persona_vectors.analysis import LayeredSamples |
| from persona_vectors.attributes import attribute_display_label |
| from persona_vectors.extraction import MaskStrategy |
| from persona_vectors.plots import plot_metric_comparison, plot_metric_over_layers |
| from persona_vectors.probes import ( |
| AttributeLabels, |
| default_probe_kinds, |
| infer_probe_task, |
| layer_matrix, |
| save_probe_artifact, |
| shuffle_label_baseline, |
| ) |
|
|
| from tabs.probe_sweep import SweepInputs, cached_sweep |
| from utils.analysis_metadata import ( |
| synth_persona_attribute_names, |
| synth_persona_dataset_cached, |
| ) |
| from utils.analysis_sources import ( |
| Store, |
| available_variants, |
| persona_names_cached, |
| personas_cached, |
| store_cache_parts, |
| store_layers_cached, |
| ) |
| from utils.controls import render_mask_strategy_select |
| from utils.helpers import widget_key |
| from utils.source_controls import render_source_select, render_store_select |
|
|
| |
| |
| |
|
|
| _DEFAULT_OUTPUT_DIR = "artifacts/probes" |
| _MIN_CLASS_COUNT = 5 |
|
|
| |
| _PRIMARY_METRIC = { |
| "binary": "balanced_accuracy", |
| "categorical": "balanced_accuracy", |
| "ordinal": "balanced_accuracy", |
| "numeric": "r2", |
| } |
| _SECONDARY_METRIC = { |
| "binary": None, |
| "categorical": None, |
| "ordinal": "mae", |
| "numeric": "mae", |
| } |
|
|
|
|
def _select_variant(store: Store, mask_strategy: MaskStrategy) -> str | None:
    """Render the variant picker and return the chosen variant.

    Shows an info notice and returns ``None`` when ``available_variants``
    yields nothing for ``store`` and ``mask_strategy``.
    """
    choices = available_variants(store, mask_strategy)
    if not choices:
        st.info("No variants with saved vectors for this selection.")
        return None

    # Start from the previously selected variant when it is still offered.
    remembered = st.session_state.get("probe:variant", choices[0])
    start_index = choices.index(remembered) if remembered in choices else 0
    return st.selectbox(
        "Variant",
        options=choices,
        index=start_index,
        key="probe:variant",
    )
|
|
|
|
def _select_personas(
    store: Store, variant: str, mask_strategy: MaskStrategy
) -> list[str]:
    """Choose which personas to probe and return their ids.

    With ten or fewer personas every one is used and a warning says so;
    otherwise a slider picks how many to take from the front of the list
    (minimum ten). Returns an empty list, after an info notice, when fewer
    than two personas are available.
    """
    source, location, model_name = store_cache_parts(store)
    variant_key = (variant,)
    candidates = personas_cached(
        source, location, model_name, mask_strategy.value, variant_key
    )
    if not candidates:
        st.info("No personas found for this variant.")
        return []

    # NOTE(review): the messages below say "non-assistant", but no filtering
    # happens in this function -- presumably ``personas_cached`` already
    # excludes the assistant persona; confirm.
    total = len(candidates)
    if total < 2:
        st.info("At least two non-assistant personas are needed for probing.")
        return []

    floor = min(10, total)
    if floor == total:
        # Ten or fewer personas: nothing to choose, take them all.
        st.warning(
            f"Only {total} non-assistant personas are available; using all of them."
        )
        st.session_state["probe:persona_count"] = total
        chosen = candidates
        summary = f"Probing {len(chosen)} non-assistant personas."
    else:
        remembered = st.session_state.get("probe:persona_count", total)
        picked = st.slider(
            "Personas",
            min_value=floor,
            max_value=total,
            value=min(total, max(floor, remembered)),
            key="probe:persona_count_slider",
        )
        st.session_state["probe:persona_count"] = picked
        chosen = candidates[:picked]
        summary = f"Probing {len(chosen)} of {total} non-assistant personas."

    # Result intentionally discarded -- presumably this warms the name cache
    # for the chosen personas (TODO confirm).
    persona_names_cached(
        source,
        location,
        model_name,
        mask_strategy.value,
        variant_key,
        tuple(chosen),
    )
    st.caption(summary)
    return chosen
|
|
|
|
| |
| |
| |
|
|
|
|
@st.cache_data(show_spinner=False)
def _attribute_tasks() -> dict[str, str]:
    """Map every synthetic-persona attribute name to its inferred probe task.

    The result is memoised through ``st.cache_data``.
    """
    dataset = synth_persona_dataset_cached()
    task_by_name: dict[str, str] = {}
    for attribute in synth_persona_attribute_names():
        task_by_name[attribute] = infer_probe_task(dataset, attribute)
    return task_by_name
|
|
|
|
def _select_attributes() -> list[str]:
    """Multi-select locked to one task type.

    Picking the first attribute fixes the task; only same-task attributes stay
    selectable. Clearing the selection reopens every attribute again.

    The selection stored in ``st.session_state`` is sanitised before the
    widget is built: names that are no longer known attributes, and names
    whose task differs from the first selected attribute, are dropped. This
    prevents a ``KeyError`` on a stale selection and keeps the stored value
    inside the ``options`` handed to the widget.

    Returns:
        The attribute names currently selected in the widget.
    """
    dataset = synth_persona_dataset_cached()
    tasks = _attribute_tasks()
    all_names = list(synth_persona_attribute_names())

    key = "probe:attributes"
    if key not in st.session_state:
        st.session_state[key] = ["sex"] if "sex" in all_names else all_names[:1]

    stored = list(st.session_state[key] or [])
    # Keep only names that are both offered and have a known task.
    selected = [name for name in stored if name in tasks and name in all_names]
    if selected:
        locked = tasks[selected[0]]
        selected = [name for name in selected if tasks[name] == locked]
        # ``tasks.get`` tolerates a name missing from the cached task map.
        options = [name for name in all_names if tasks.get(name) == locked]
    else:
        options = all_names

    # Write back only on change; setting a widget's key is allowed as long as
    # it happens before the widget is instantiated in this run.
    if selected != stored:
        st.session_state[key] = selected

    return st.multiselect(
        "Attributes to probe",
        options=options,
        format_func=lambda name: attribute_display_label(dataset, name),
        key=key,
        help="Pick one or more attributes of the same task type. They are "
        "overlaid in one figure. Remove all to switch to a different task type.",
    )
|
|
|
|
def _select_probe_kinds(task: str) -> list[str]:
    """Let the user choose probe families for ``task``.

    The widget is rendered only when ``default_probe_kinds(task)`` offers at
    least two families. An empty selection falls back to all of them.
    """
    families = list(default_probe_kinds(task))
    if len(families) <= 1:
        return families

    picked = st.multiselect(
        "Probe kinds to fit",
        options=families,
        default=families,
        key=f"probe:kinds:{task}",
        help="Which probe families to fit at each layer. Defaults to all "
        "available for this task.",
    )
    if picked:
        return picked
    return families
|
|
|
|
def _select_pca_components() -> int | None:
    """Return the requested PCA component count, or ``None`` when PCA is off.

    The number input (2-512, default 10) is rendered only while the toggle is
    switched on.
    """
    enabled = st.toggle(
        "Add PCA-compressed comparison",
        value=False,
        key="probe:use_pca",
        help="Runs the normal full-activation sweep and a second sweep where "
        "PCA is fit on the train split only before probing.",
    )
    if enabled:
        components = st.number_input(
            "PCA components",
            min_value=2,
            max_value=512,
            value=10,
            step=1,
            key="probe:pca_components",
        )
        return int(components)
    return None
|
|
|
|
def _select_layers(num_layers: int) -> list[int]:
    """Return the layer indices to sweep.

    With the "fast" toggle on (the default), up to five roughly evenly spaced
    layers are returned: the first, the quarter points and the last, with
    duplicates collapsed. With it off, every index in ``range(num_layers)`` is
    returned.

    Args:
        num_layers: Number of layers; indices run from 0 to ``num_layers - 1``.

    Returns:
        Sorted layer indices, or an empty list when ``num_layers`` is not
        positive.
    """
    fast = st.toggle(
        "Fast layer set (5 evenly-spaced)",
        value=True,
        key="probe:fast",
        help="Off = sweep every layer. Slow on big models.",
    )
    if num_layers <= 0:
        # Without this guard the fast set would contain ``num_layers - 1``,
        # which is a negative index.
        return []
    if not fast:
        return list(range(num_layers))
    return sorted(
        {
            0,
            num_layers // 4,
            num_layers // 2,
            (3 * num_layers) // 4,
            num_layers - 1,
        }
    )
|
|
|
|
| |
| |
| |
|
|
|
|
def _finite_metric(row: dict[str, object], metric: str) -> float | None:
    """Return ``row[metric]`` as a float, or ``None`` if missing, non-numeric or NaN."""
    import math

    value = row.get(metric)
    if value is None:
        return None
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return None if math.isnan(number) else number


def _pick_best_row(
    rows: list[dict[str, object]],
    metric: str,
    *,
    higher_better: bool = True,
) -> dict[str, object] | None:
    """Return the row with the best usable ``metric``, or ``None`` if there is none.

    Rows whose metric is missing, ``None`` or NaN are ignored; NaN in
    particular would make ``max`` depend on row order. Ties keep the earliest
    row.
    """
    scored = []
    for row in rows:
        value = _finite_metric(row, metric)
        if value is not None:
            scored.append((value if higher_better else -value, row))
    if not scored:
        return None
    return max(scored, key=lambda pair: pair[0])[1]


def _show_sweep(
    rows_by_label: dict[str, list[dict[str, object]]],
    per_attr: dict[str, tuple[AttributeLabels, LayeredSamples]],
    attributes: tuple[str, ...],
    task: str,
    inputs: SweepInputs,
) -> None:
    """Plot a finished sweep, summarise the best layer and render the controls.

    Args:
        rows_by_label: Result rows keyed by feature-set label (``"full"`` or
            ``"pca<N>"``). Each row is read for ``attribute``, ``layer``,
            ``probe_kind`` and the metric columns.
        per_attr: Labels and samples per attribute name, used by the
            selectivity and save controls.
        attributes: Attributes that were probed.
        task: Probe task; selects the primary and secondary metrics.
        inputs: The sweep inputs that produced ``rows_by_label``.
    """
    if not rows_by_label:
        # ``next(iter(...))`` below would raise StopIteration on an empty dict.
        st.warning("The sweep returned no results.")
        return

    primary = _PRIMARY_METRIC[task]
    secondary = _SECONDARY_METRIC.get(task)
    baseline_key = f"baseline_{primary}"

    # When PCA is enabled the PCA rows drive the best-layer pick, so the
    # controls below (which forward ``n_pca_components``) match that row.
    primary_label = (
        f"pca{inputs.n_pca_components}" if inputs.n_pca_components else "full"
    )
    rows = rows_by_label.get(primary_label) or next(iter(rows_by_label.values()))

    multi_attr = len(attributes) > 1
    comparing = len(rows_by_label) > 1 or multi_attr

    def _plot(metric: str):
        if comparing:
            return plot_metric_comparison(
                rows_by_label, list(attributes), metric=metric
            )
        return plot_metric_over_layers(rows, attributes[0], metric=metric)

    def _or_nan(value: float | None) -> float:
        return float("nan") if value is None else value

    st.plotly_chart(_plot(primary), width="stretch")
    if secondary is not None:
        st.plotly_chart(_plot(secondary), width="stretch")

    higher_better = primary != "mae"

    best = _pick_best_row(rows, primary, higher_better=higher_better)
    if best is None:
        st.warning(f"No rows reported {primary!r}; can't pick a best layer.")
        return

    if comparing:
        # One summary line per (feature set, attribute) pair that has a
        # usable primary metric.
        summary_rows = []
        for label, label_rows in rows_by_label.items():
            for attribute in attributes:
                attr_rows = [
                    row for row in label_rows if row.get("attribute") == attribute
                ]
                label_best = _pick_best_row(
                    attr_rows, primary, higher_better=higher_better
                )
                if label_best is None:
                    continue
                summary_row: dict[str, object] = {}
                if multi_attr:
                    summary_row["attribute"] = attribute
                summary_row.update(
                    {
                        "features": label,
                        "best_layer": label_best["layer"],
                        "probe": label_best["probe_kind"],
                        primary: round(
                            _or_nan(_finite_metric(label_best, primary)), 3
                        ),
                        baseline_key: round(
                            _or_nan(_finite_metric(label_best, baseline_key)), 3
                        ),
                    }
                )
                summary_rows.append(summary_row)
        if summary_rows:
            st.dataframe(summary_rows, width="stretch", hide_index=True)

    feature_desc = f" · pca{inputs.n_pca_components}" if inputs.n_pca_components else ""

    best_attr = str(best["attribute"])
    labels, samples = per_attr[best_attr]
    if multi_attr:
        st.caption(f"Controls below use the best result: **{best_attr}**.")
    else:
        best_value = _or_nan(_finite_metric(best, primary))
        baseline_value = _or_nan(_finite_metric(best, baseline_key))
        cols = st.columns([1, 1.2, 1.8])
        cols[0].metric("Best layer", best["layer"])
        cols[1].metric(
            f"Best {primary}",
            f"{best_value:.3f}",
            delta=f"baseline {baseline_value:.3f}",
            delta_color="off",
        )
        cols[2].metric("Probe", f"{best['probe_kind']}{feature_desc}")

    _render_selectivity_control(best, labels, samples, task, inputs)
    _render_save_artifact(best, labels, samples, task, inputs)
|
|
|
|
def _render_selectivity_control(
    best: dict[str, object],
    labels: AttributeLabels,
    samples: LayeredSamples,
    task: str,
    inputs: SweepInputs,
) -> None:
    """Render the shuffled-label control for the best probe.

    Nothing is rendered for ``"numeric"`` tasks. On button press the probe
    described by ``best`` is refit on shuffled labels at ``best["layer"]``,
    and the real and shuffled balanced accuracies are shown side by side.
    """
    if task == "numeric":
        return

    with st.expander("Selectivity control (shuffled labels)"):
        st.caption(
            "Trains the same probe on shuffled labels. The gap between the real-label "
            "score and this shuffled score is the probe's *selectivity* "
            "(Hewitt & Liang 2019). High shuffled scores mean the probe is reading "
            "dataset artifacts, not the property."
        )
        repeats = st.slider(
            "Shuffle repeats",
            min_value=3,
            max_value=15,
            value=5,
            key="probe:shuffle_repeats",
        )
        if not st.button("Run selectivity control", key="probe:run_shuffle"):
            return

        layer = int(best["layer"])
        with st.spinner("Running shuffled-label control..."):
            features = layer_matrix(samples, layer)
            shuffled = shuffle_label_baseline(
                features,
                labels.y,
                task=task,
                layer=layer,
                probe_kind=best["probe_kind"],
                n_pca_components=inputs.n_pca_components,
                n_repeats=repeats,
            )

        real_col, shuffled_col = st.columns(2)
        real_col.metric(
            "Real balanced acc.",
            f"{float(best['balanced_accuracy']):.3f}",
        )
        shuffled_col.metric(
            "Shuffled balanced acc.",
            f"{shuffled['balanced_accuracy_mean']:.3f}",
            delta=f"+/- {shuffled['balanced_accuracy_std']:.3f}",
            delta_color="off",
        )
|
|
|
|
def _render_save_artifact(
    best: dict[str, object],
    labels: AttributeLabels,
    samples: LayeredSamples,
    task: str,
    inputs: SweepInputs,
) -> None:
    """Render the form that saves the best probe as an artifact.

    The model, variant and mask fields follow the sweep's values until the
    user edits them. Pressing "Save" calls ``save_probe_artifact`` on the
    layer matrix at ``best["layer"]`` and reports the returned location.
    """

    def follow_default(key: str, default: str) -> str:
        # Overwrite the field only while it is unset or still equal to the
        # default applied last time, so a user edit survives later sweeps.
        marker = f"{key}:default"
        last_default = st.session_state.get(marker)
        current = st.session_state.get(key)
        untouched = current is None or current == last_default
        if untouched:
            st.session_state[key] = default
        st.session_state[marker] = default
        return st.session_state[key]

    with st.expander("Save best probe (loadable by the Chat tab)"):
        output_dir = st.text_input(
            "Output directory",
            value=st.session_state.get("probe:output_dir", _DEFAULT_OUTPUT_DIR),
            key="probe:output_dir",
            help="Probe artifacts will be written under this root.",
        )

        # (result name, widget key, label, value taken from the sweep)
        synced_fields = (
            (
                "model_name",
                "probe:save_model",
                "Model name (for the artifact path)",
                inputs.model_name,
            ),
            ("variant", "probe:save_variant", "Variant", inputs.variant),
            ("mask_strategy", "probe:save_mask", "Mask strategy", inputs.mask_value),
        )
        entered: dict[str, str] = {}
        for name, state_key, label, sweep_value in synced_fields:
            follow_default(state_key, sweep_value)
            entered[name] = st.text_input(label, key=state_key)

        if not st.button("Save", key="probe:save_artifact"):
            return

        layer = int(best["layer"])
        directory = save_probe_artifact(
            X=layer_matrix(samples, layer),
            y=labels.y,
            labels=labels,
            task=task,
            probe_kind=best["probe_kind"],
            n_pca_components=inputs.n_pca_components,
            layer=layer,
            model_name=entered["model_name"],
            variant=entered["variant"],
            mask_strategy=entered["mask_strategy"],
            output_dir=output_dir,
            metrics=best,
        )
        st.success(f"Saved to `{directory}`")
        st.caption(
            "Wrote `probe.json` + `weights.safetensors`. "
            "The Chat tab can load the saved `probe.json` artifact."
        )
|
|
|
|
| |
| |
| |
|
|
|
|
def render_probing_tab() -> None:
    """Render the Probing tab.

    Flow: choose source, mask, store, variant and personas; configure the
    probe; run the sweep on demand; then show the most recent result kept in
    ``st.session_state``. Returns early, after an info notice, whenever a
    required selection is missing.
    """
    st.title("Probing")

    picked_source = render_source_select(widget_scope="probe")
    with st.expander("Source", expanded=True):
        mask_strategy = render_mask_strategy_select(
            key=widget_key("probe", "mask_strategy"),
            last_key="probe:last_mask_strategy",
            remember_key="source:last_mask_strategy",
            help_text="Which extracted activation set to load.",
        )
        store = render_store_select(
            picked_source,
            mask_strategy,
            state_prefix="probe",
            widget_scope="probe",
            artifacts_root_key="probe:local_root",
        )
        variant = _select_variant(store, mask_strategy)
        if variant is None:
            return
        persona_ids = _select_personas(store, variant, mask_strategy)
        if not persona_ids:
            return

    with st.expander("Probe configuration", expanded=True):
        attributes = _select_attributes()
        if not attributes:
            st.info("Select at least one attribute to probe.")
            return
        # Every selected attribute shares one task, so the first decides it.
        task = _attribute_tasks()[attributes[0]]
        st.caption(f"Inferred task: **{task}**")

        probe_kinds = _select_probe_kinds(task)
        n_pca_components = _select_pca_components()

        source, location, model_name = store_cache_parts(store)
        persona_key = tuple(persona_ids)
        known_layers = store_layers_cached(
            source,
            location,
            model_name,
            mask_strategy.value,
            (variant,),
            persona_key,
        )
        if not known_layers:
            st.info("No layers found for the selected personas.")
            return
        # NOTE(review): this treats layers as contiguous 0..max; if a store
        # can hold a sparse set of layers, the sweep may be asked for layers
        # that were never saved -- confirm.
        layers = _select_layers(max(known_layers) + 1)

    inputs = SweepInputs(
        source=source,
        location=location,
        model_name=model_name,
        mask_value=mask_strategy.value,
        variant=variant,
        persona_ids=persona_key,
        attributes=tuple(attributes),
        task=task,
        probe_kinds=tuple(probe_kinds),
        n_pca_components=n_pca_components,
        layers=tuple(layers),
        min_class_count=_MIN_CLASS_COUNT,
        seed=0,
    )

    state_key = "probe:last_result"
    if st.button("Run sweep", type="primary", key="probe:run"):
        with st.spinner("Evaluating probes across layers..."):
            try:
                sweep, per_attr = cached_sweep(inputs)
            except Exception as exc:
                # UI boundary: report the failure and drop any stale result.
                st.error(f"Sweep failed: {exc}")
                st.session_state.pop(state_key, None)
                return
            else:
                st.session_state[state_key] = (sweep, per_attr, inputs)

    if state_key not in st.session_state:
        return

    saved_result = st.session_state[state_key]
    if len(saved_result) != 3:
        # Stored value does not have the expected three-part shape; drop it.
        st.session_state.pop(state_key, None)
        return

    # Render with the inputs that produced the result, which may differ from
    # the widgets' current values.
    sweep, per_attr, result_inputs = saved_result
    _show_sweep(
        sweep,
        per_attr,
        result_inputs.attributes,
        result_inputs.task,
        result_inputs,
    )
|
|