# persona-ui/tests/test_probe_sweep.py
# Jac-Zac
# Updated to latest persona-vector version
# bf30281
from __future__ import annotations
from types import SimpleNamespace
import torch
from persona_vectors.analysis import LayeredSamples
from persona_vectors.probes import AttributeLabels
from tabs import probe_sweep
def test_cached_sweep_keeps_per_attribute_samples_and_full_plus_pca(monkeypatch):
    """Verify how ``cached_sweep`` drives its collaborators for two attributes.

    Every ``probe_sweep`` helper patched below is swapped for an in-memory
    stub, so only the orchestration is observed: the labels of the returned
    row groups, the attributes present in the per-attribute result, and the
    order and ``n_pca_components`` of the ``sweep_attribute`` calls.
    """
    persona_ids = ("p0", "p1", "p2")
    attributes = ("sex", "gender")

    fake_samples = LayeredSamples(
        vectors=torch.zeros((3, 2, 4)),
        labels=list(persona_ids),
        hover_text=list(persona_ids),
    )
    # One (attribute, n_pca_components) entry per stubbed sweep invocation.
    recorded_calls: list[tuple[str, int | None]] = []

    def fake_load_vectors(*_args):
        # Whatever is requested, hand back the single shared samples object.
        return fake_samples

    def fake_dataset():
        return SimpleNamespace()

    def fake_labels(_dataset, attribute, _persona_ids, *, task):
        # Identical three-row, two-class labelling for every attribute.
        class_ids = torch.tensor([0, 1, 0]).numpy()
        return AttributeLabels(
            attribute_name=attribute,
            task=task,
            y=class_ids,
            labels=["a", "b", "a"],
            class_names=["a", "b"],
        )

    def fake_filter(input_samples, labels, *, min_count):
        # Pass-through that only checks the configured threshold arrives.
        assert min_count == 2
        return input_samples, labels

    def fake_sweep(
        input_samples,
        labels,
        *,
        layers,
        probe_kinds,
        n_pca_components,
        seed,
    ):
        # The sweep must receive the very object returned by the loader stub
        # and the layer / probe / seed settings converted to plain lists.
        assert input_samples is fake_samples
        assert layers == [0, 1]
        assert probe_kinds == ["logistic_regression"]
        assert seed == 0
        recorded_calls.append((labels.attribute_name, n_pca_components))
        row = {
            "attribute": labels.attribute_name,
            "layer": 0,
            "probe_kind": probe_kinds[0],
            "balanced_accuracy": 0.5,
        }
        return [row]

    # Installed in this order; dicts preserve insertion order.
    stubs = {
        "load_persona_vectors_cached": fake_load_vectors,
        "synth_persona_dataset_cached": fake_dataset,
        "attribute_probe_labels": fake_labels,
        "filter_attribute_samples_min_count": fake_filter,
        "sweep_attribute": fake_sweep,
    }
    for target_name, stub in stubs.items():
        monkeypatch.setattr(probe_sweep, target_name, stub)

    sweep_inputs = probe_sweep.SweepInputs(
        source="src",
        location="loc",
        model_name="model",
        mask_value="answer_mean",
        variant="templated",
        persona_ids=persona_ids,
        attributes=attributes,
        task="binary",
        probe_kinds=("logistic_regression",),
        n_pca_components=2,
        layers=(0, 1),
        min_class_count=2,
        seed=0,
    )

    # NOTE(review): ``__wrapped__`` presumably reaches the undecorated function
    # behind a caching decorator -- confirm against ``probe_sweep``.
    rows_by_label, per_attr = probe_sweep.cached_sweep.__wrapped__(sweep_inputs)

    assert list(rows_by_label) == ["full", "pca2"]
    full_attributes = [row["attribute"] for row in rows_by_label["full"]]
    assert full_attributes == ["sex", "gender"]
    assert set(per_attr) == {"sex", "gender"}

    # All attributes are swept without PCA first, then again with 2 components.
    expected_calls = [
        ("sex", None),
        ("gender", None),
        ("sex", 2),
        ("gender", 2),
    ]
    assert recorded_calls == expected_calls