File size: 2,820 Bytes
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf30281
 
 
 
 
 
 
 
 
b279884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from __future__ import annotations

from types import SimpleNamespace

import torch
from persona_vectors.analysis import LayeredSamples
from persona_vectors.probes import AttributeLabels

from tabs import probe_sweep


def test_cached_sweep_keeps_per_attribute_samples_and_full_plus_pca(monkeypatch):
    """Run ``cached_sweep`` with its ``probe_sweep`` collaborators stubbed out.

    The loader, dataset factory, label builder, min-count filter and
    ``sweep_attribute`` are all replaced on the ``probe_sweep`` module, so only
    the orchestration is exercised. The test expects:

    * result rows keyed ``"full"`` then ``"pca2"``;
    * the ``"full"`` rows ordered like the requested attributes;
    * one per-attribute entry for every requested attribute;
    * ``sweep_attribute`` invoked for every attribute with
      ``n_pca_components=None`` before any call with ``n_pca_components=2``.
    """
    persona_ids = ("p0", "p1", "p2")
    attribute_names = ("sex", "gender")
    recorded_calls: list[tuple[str, int | None]] = []

    samples = LayeredSamples(
        vectors=torch.zeros((3, 2, 4)),
        labels=list(persona_ids),
        hover_text=list(persona_ids),
    )

    def fake_load(*args):
        # Whatever positional arguments are supplied, hand back the same object
        # so the identity check inside ``sweep`` can hold.
        return samples

    def fake_dataset():
        return SimpleNamespace()

    def labels_for(_dataset, attribute, _persona_ids, *, task):
        # Fixed two-class labelling; only the attribute name and task vary.
        encoded = torch.tensor([0, 1, 0]).numpy()
        return AttributeLabels(
            attribute_name=attribute,
            task=task,
            y=encoded,
            labels=["a", "b", "a"],
            class_names=["a", "b"],
        )

    def filtered(input_samples, labels, *, min_count):
        # Pass-through filter that also checks min_class_count is forwarded.
        assert min_count == 2
        return input_samples, labels

    def sweep(
        input_samples,
        labels,
        *,
        layers,
        probe_kinds,
        n_pca_components,
        seed,
    ):
        assert input_samples is samples
        assert layers == [0, 1]
        assert probe_kinds == ["logistic_regression"]
        assert seed == 0
        recorded_calls.append((labels.attribute_name, n_pca_components))
        row = {
            "attribute": labels.attribute_name,
            "layer": 0,
            "probe_kind": probe_kinds[0],
            "balanced_accuracy": 0.5,
        }
        return [row]

    stubs = {
        "load_persona_vectors_cached": fake_load,
        "synth_persona_dataset_cached": fake_dataset,
        "attribute_probe_labels": labels_for,
        "filter_attribute_samples_min_count": filtered,
        "sweep_attribute": sweep,
    }
    for target_name, stub in stubs.items():
        monkeypatch.setattr(probe_sweep, target_name, stub)

    sweep_config = {
        "source": "src",
        "location": "loc",
        "model_name": "model",
        "mask_value": "answer_mean",
        "variant": "templated",
        "persona_ids": persona_ids,
        "attributes": attribute_names,
        "task": "binary",
        "probe_kinds": ("logistic_regression",),
        "n_pca_components": 2,
        "layers": (0, 1),
        "min_class_count": 2,
        "seed": 0,
    }
    inputs = probe_sweep.SweepInputs(**sweep_config)

    # ``__wrapped__`` reaches the function underneath its decorator
    # (presumably a cache -- confirm in ``probe_sweep``), so the stubs above
    # are always hit.
    rows_by_label, per_attr = probe_sweep.cached_sweep.__wrapped__(inputs)

    assert list(rows_by_label) == ["full", "pca2"]

    full_attributes = [row["attribute"] for row in rows_by_label["full"]]
    assert full_attributes == list(attribute_names)

    assert set(per_attr) == set(attribute_names)

    # All attributes without PCA first, then all attributes with 2 components.
    expected_calls = [
        (name, n_components)
        for n_components in (None, 2)
        for name in attribute_names
    ]
    assert recorded_calls == expected_calls