from __future__ import annotations from types import SimpleNamespace import torch from persona_vectors.analysis import LayeredSamples from persona_vectors.probes import AttributeLabels from tabs import probe_sweep def test_cached_sweep_keeps_per_attribute_samples_and_full_plus_pca(monkeypatch): samples = LayeredSamples( vectors=torch.zeros((3, 2, 4)), labels=["p0", "p1", "p2"], hover_text=["p0", "p1", "p2"], ) sweep_calls: list[tuple[str, int | None]] = [] monkeypatch.setattr( probe_sweep, "load_persona_vectors_cached", lambda *args: samples, ) monkeypatch.setattr( probe_sweep, "synth_persona_dataset_cached", lambda: SimpleNamespace(), ) def labels_for(_dataset, attribute, _persona_ids, *, task): return AttributeLabels( attribute_name=attribute, task=task, y=torch.tensor([0, 1, 0]).numpy(), labels=["a", "b", "a"], class_names=["a", "b"], ) monkeypatch.setattr(probe_sweep, "attribute_probe_labels", labels_for) def filtered(input_samples, labels, *, min_count): assert min_count == 2 return input_samples, labels monkeypatch.setattr( probe_sweep, "filter_attribute_samples_min_count", filtered, ) def sweep( input_samples, labels, *, layers, probe_kinds, n_pca_components, seed, ): assert input_samples is samples assert layers == [0, 1] assert probe_kinds == ["logistic_regression"] assert seed == 0 sweep_calls.append((labels.attribute_name, n_pca_components)) return [ { "attribute": labels.attribute_name, "layer": 0, "probe_kind": probe_kinds[0], "balanced_accuracy": 0.5, } ] monkeypatch.setattr(probe_sweep, "sweep_attribute", sweep) inputs = probe_sweep.SweepInputs( source="src", location="loc", model_name="model", mask_value="answer_mean", variant="templated", persona_ids=("p0", "p1", "p2"), attributes=("sex", "gender"), task="binary", probe_kinds=("logistic_regression",), n_pca_components=2, layers=(0, 1), min_class_count=2, seed=0, ) rows_by_label, per_attr = probe_sweep.cached_sweep.__wrapped__(inputs) assert list(rows_by_label) == ["full", "pca2"] assert [row["attribute"] for row in rows_by_label["full"]] == ["sex", "gender"] assert set(per_attr) == {"sex", "gender"} assert sweep_calls == [ ("sex", None), ("gender", None), ("sex", 2), ("gender", 2), ]