File size: 6,943 Bytes
253d988
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
tests/test_buffer.py
Tests for TopKBuffer (pure Python / heapq β€” no torch, no model).
"""
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src"))

from sae_gemma.capture_activations import TopKBuffer


# ── Basic keep-top-k behaviour ────────────────────────────────────────────────


def test_topk_keeps_highest_activations():
    """After inserting 100 items into a top-5 buffer, the 5 highest survive."""
    buf = TopKBuffer(n_features=1, top_k=5)
    for i in range(100):
        buf.update(0, float(i), {"token": f"tok{i}", "token_pos": i, "seq_id": 0, "context": ""})

    # Heap items: (activation, metadata)
    heap = buf.heaps[0]
    surviving_activations = sorted([item[0] for item in heap], reverse=True)
    # Should be [99, 98, 97, 96, 95]
    assert surviving_activations == [99.0, 98.0, 97.0, 96.0, 95.0], surviving_activations


def test_topk_does_not_exceed_k():
    """Buffer never holds more than top_k items per feature."""
    k = 7
    buf = TopKBuffer(n_features=3, top_k=k)
    for fid in range(3):
        for i in range(50):
            buf.update(fid, float(i), {"token": "x", "token_pos": i, "seq_id": 0, "context": ""})
        assert len(buf.heaps[fid]) == k


def test_topk_updates_when_higher():
    """A new item that beats the current minimum replaces it."""
    buf = TopKBuffer(n_features=1, top_k=3)
    buf.update(0, 1.0, {"token": "a", "token_pos": 0, "seq_id": 0, "context": ""})
    buf.update(0, 2.0, {"token": "b", "token_pos": 1, "seq_id": 0, "context": ""})
    buf.update(0, 3.0, {"token": "c", "token_pos": 2, "seq_id": 0, "context": ""})

    # All three slots full; minimum is 1.0
    assert len(buf.heaps[0]) == 3

    # Insert something lower β€” should be rejected
    buf.update(0, 0.5, {"token": "low", "token_pos": 3, "seq_id": 0, "context": ""})
    acts = {item[0] for item in buf.heaps[0]}
    assert 0.5 not in acts

    # Insert something higher β€” should replace the minimum (1.0)
    buf.update(0, 10.0, {"token": "high", "token_pos": 4, "seq_id": 0, "context": ""})
    acts = {item[0] for item in buf.heaps[0]}
    assert 10.0 in acts
    assert 1.0 not in acts


def test_topk_does_not_update_when_lower():
    """Items below the current minimum are silently dropped."""
    buf = TopKBuffer(n_features=1, top_k=2)
    buf.update(0, 5.0, {"token": "x", "token_pos": 0, "seq_id": 0, "context": ""})
    buf.update(0, 6.0, {"token": "y", "token_pos": 1, "seq_id": 0, "context": ""})
    # Heap full; current min = 5.0
    buf.update(0, 3.0, {"token": "z", "token_pos": 2, "seq_id": 0, "context": ""})
    assert len(buf.heaps[0]) == 2
    acts = {item[0] for item in buf.heaps[0]}
    assert 3.0 not in acts


# ── to_dataframe ──────────────────────────────────────────────────────────────


def test_to_dataframe_shape():
    """to_dataframe returns one row per (feature, rank) pair."""
    n_features = 4
    top_k = 3
    buf = TopKBuffer(n_features=n_features, top_k=top_k)
    for fid in range(n_features):
        for i in range(top_k):
            buf.update(fid, float(i + 1), {
                "token": f"t{i}", "token_pos": i, "seq_id": fid * 10 + i, "context": "ctx"
            })

    df = buf.to_dataframe()
    assert len(df) == n_features * top_k
    assert set(df.columns) >= {"feature_id", "rank", "activation", "token", "context", "token_pos", "seq_id"}


def test_to_dataframe_dtypes():
    """to_dataframe enforces the expected dtypes."""
    buf = TopKBuffer(n_features=2, top_k=2)
    for fid in range(2):
        for i in range(2):
            buf.update(fid, float(i + 1), {"token": "a", "token_pos": i, "seq_id": i, "context": ""})

    df = buf.to_dataframe()
    assert str(df["feature_id"].dtype) == "int32"
    assert str(df["rank"].dtype) == "int8"
    assert str(df["activation"].dtype) == "float32"
    assert str(df["token_pos"].dtype) == "int32"
    assert str(df["seq_id"].dtype) == "int64"


def test_to_dataframe_ranks_descending():
    """Rank 0 should have the highest activation within each feature."""
    buf = TopKBuffer(n_features=1, top_k=5)
    for v in [10.0, 3.0, 7.0, 1.0, 5.0]:
        buf.update(0, v, {"token": str(v), "token_pos": 0, "seq_id": 0, "context": ""})

    df = buf.to_dataframe()
    feat_df = df[df["feature_id"] == 0].sort_values("rank")
    acts = feat_df["activation"].tolist()
    # Must be strictly descending
    for i in range(len(acts) - 1):
        assert acts[i] >= acts[i + 1], f"Not descending: {acts}"


def test_to_dataframe_empty_buffer():
    """Buffer with no updates produces an empty DataFrame with correct columns."""
    buf = TopKBuffer(n_features=3, top_k=5)
    df = buf.to_dataframe()
    assert len(df) == 0
    assert "feature_id" in df.columns


# ── Edge cases ────────────────────────────────────────────────────────────────


def test_edge_case_n_features_1_top_k_1():
    """n_features=1, top_k=1: only the single highest activation survives."""
    buf = TopKBuffer(n_features=1, top_k=1)
    for v in [0.1, 0.9, 0.5, 0.8]:
        buf.update(0, v, {"token": str(v), "token_pos": 0, "seq_id": 0, "context": ""})

    assert len(buf.heaps[0]) == 1
    assert abs(buf.heaps[0][0][0] - 0.9) < 1e-6

    df = buf.to_dataframe()
    assert len(df) == 1
    assert abs(float(df["activation"].iloc[0]) - 0.9) < 1e-6


def test_edge_case_fewer_inserts_than_k():
    """Buffer only partially filled β€” should hold exactly what was inserted."""
    buf = TopKBuffer(n_features=1, top_k=10)
    for i in range(3):
        buf.update(0, float(i), {"token": "t", "token_pos": i, "seq_id": i, "context": ""})

    assert len(buf.heaps[0]) == 3
    df = buf.to_dataframe()
    assert len(df) == 3


def test_multiple_features_independent():
    """Updates to feature 0 must not affect feature 1."""
    buf = TopKBuffer(n_features=2, top_k=3)
    for v in [100.0, 200.0, 300.0]:
        buf.update(0, v, {"token": "f0", "token_pos": 0, "seq_id": 0, "context": ""})
    # Feature 1 receives no updates
    assert len(buf.heaps[0]) == 3
    assert len(buf.heaps[1]) == 0


def test_to_dataframe_metadata_preserved():
    """to_dataframe correctly passes through token, context, token_pos, seq_id."""
    buf = TopKBuffer(n_features=1, top_k=1)
    buf.update(0, 42.0, {
        "token": "hello",
        "context": "some context",
        "token_pos": 7,
        "seq_id": 99,
    })
    df = buf.to_dataframe()
    row = df.iloc[0]
    assert row["token"] == "hello"
    assert row["context"] == "some context"
    assert row["token_pos"] == 7
    assert row["seq_id"] == 99