| """ |
| tests/test_buffer.py |
| Tests for TopKBuffer (pure Python / heapq — no torch, no model). |
| """ |
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) |
|
|
| from sae_gemma.capture_activations import TopKBuffer |
|
|
|
|
| |
|
|
|
|
def test_topk_keeps_highest_activations():
    """Only the five largest of 100 inserted activations remain in a top-5 buffer."""
    buffer = TopKBuffer(n_features=1, top_k=5)
    for idx in range(100):
        meta = {"token": f"tok{idx}", "token_pos": idx, "seq_id": 0, "context": ""}
        buffer.update(0, float(idx), meta)

    # Slot 0 of each stored entry is the activation; sort explicitly so the
    # comparison does not depend on the heap's internal ordering.
    kept = [entry[0] for entry in buffer.heaps[0]]
    kept.sort(reverse=True)

    assert kept == [99.0, 98.0, 97.0, 96.0, 95.0], kept
|
|
|
|
def test_topk_does_not_exceed_k():
    """Buffer never holds more than top_k items per feature.

    The size bound is checked after every single update (not only once all
    inserts are done), so a transient overflow is caught as well. After 50
    inserts into a top-7 buffer each feature's heap must be exactly full.
    """
    k = 7
    n_features = 3
    buf = TopKBuffer(n_features=n_features, top_k=k)
    for fid in range(n_features):
        for i in range(50):
            buf.update(fid, float(i), {"token": "x", "token_pos": i, "seq_id": 0, "context": ""})
            # Invariant from the docstring: the cap holds at every step.
            assert len(buf.heaps[fid]) <= k
        # More inserts than k were made, so the heap ends up at capacity.
        assert len(buf.heaps[fid]) == k
|
|
|
|
def test_topk_updates_when_higher():
    """A candidate above the current minimum displaces it; one below is ignored."""
    buffer = TopKBuffer(n_features=1, top_k=3)

    # Fill the buffer to capacity with activations 1.0, 2.0, 3.0.
    seed = [(1.0, "a"), (2.0, "b"), (3.0, "c")]
    for pos, (value, tok) in enumerate(seed):
        buffer.update(0, value, {"token": tok, "token_pos": pos, "seq_id": 0, "context": ""})
    assert len(buffer.heaps[0]) == 3

    # A value below every stored activation must not get in.
    buffer.update(0, 0.5, {"token": "low", "token_pos": 3, "seq_id": 0, "context": ""})
    stored = {entry[0] for entry in buffer.heaps[0]}
    assert 0.5 not in stored

    # A value above the minimum enters and pushes the old minimum (1.0) out.
    buffer.update(0, 10.0, {"token": "high", "token_pos": 4, "seq_id": 0, "context": ""})
    stored = {entry[0] for entry in buffer.heaps[0]}
    assert 10.0 in stored
    assert 1.0 not in stored
|
|
|
|
def test_topk_does_not_update_when_lower():
    """An activation below the current minimum leaves a full buffer unchanged in size and is not stored."""
    buffer = TopKBuffer(n_features=1, top_k=2)

    # Fill to capacity with 5.0 and 6.0.
    for pos, (value, tok) in enumerate([(5.0, "x"), (6.0, "y")]):
        buffer.update(0, value, {"token": tok, "token_pos": pos, "seq_id": 0, "context": ""})

    # 3.0 is lower than both stored values.
    buffer.update(0, 3.0, {"token": "z", "token_pos": 2, "seq_id": 0, "context": ""})

    heap = buffer.heaps[0]
    assert len(heap) == 2
    stored = {entry[0] for entry in heap}
    assert 3.0 not in stored
|
|
|
|
| |
|
|
|
|
def test_to_dataframe_shape():
    """to_dataframe yields n_features * top_k rows and includes all expected columns."""
    n_features, top_k = 4, 3
    buffer = TopKBuffer(n_features=n_features, top_k=top_k)
    for fid in range(n_features):
        for i in range(top_k):
            meta = {"token": f"t{i}", "token_pos": i, "seq_id": fid * 10 + i, "context": "ctx"}
            buffer.update(fid, float(i + 1), meta)

    frame = buffer.to_dataframe()
    assert len(frame) == n_features * top_k

    # Extra columns are allowed; these must all be present.
    required = {"feature_id", "rank", "activation", "token", "context", "token_pos", "seq_id"}
    assert required.issubset(frame.columns)
|
|
|
|
def test_to_dataframe_dtypes():
    """The numeric columns of to_dataframe's output carry the expected dtypes."""
    buffer = TopKBuffer(n_features=2, top_k=2)
    for fid in range(2):
        for i in range(2):
            meta = {"token": "a", "token_pos": i, "seq_id": i, "context": ""}
            buffer.update(fid, float(i + 1), meta)

    frame = buffer.to_dataframe()

    expected = {
        "feature_id": "int32",
        "rank": "int8",
        "activation": "float32",
        "token_pos": "int32",
        "seq_id": "int64",
    }
    # Compare dtype names column by column in one go.
    actual = {column: str(frame[column].dtype) for column in expected}
    assert actual == expected
|
|
|
|
def test_to_dataframe_ranks_descending():
    """Rank 0 should have the highest activation within each feature.

    Inserts five distinct activations in shuffled order into a top-5 buffer,
    then reads the rows for feature 0 back ordered by rank and checks that
    the activations are non-increasing.
    """
    values = [10.0, 3.0, 7.0, 1.0, 5.0]
    buf = TopKBuffer(n_features=1, top_k=len(values))
    for v in values:
        buf.update(0, v, {"token": str(v), "token_pos": 0, "seq_id": 0, "context": ""})

    df = buf.to_dataframe()
    feat_df = df[df["feature_id"] == 0].sort_values("rank")
    acts = feat_df["activation"].tolist()

    # Guard against a vacuous pass: an empty or truncated frame would
    # trivially satisfy the ordering check below.
    assert len(acts) == len(values), acts
    assert acts == sorted(acts, reverse=True), f"Not descending: {acts}"
|
|
|
|
def test_to_dataframe_empty_buffer():
    """A buffer that received no updates yields a zero-row DataFrame that still has a feature_id column."""
    frame = TopKBuffer(n_features=3, top_k=5).to_dataframe()

    n_rows = len(frame)
    assert n_rows == 0
    assert "feature_id" in frame.columns
|
|
|
|
| |
|
|
|
|
def test_edge_case_n_features_1_top_k_1():
    """With one feature and top_k=1, only the largest inserted activation is kept."""
    buffer = TopKBuffer(n_features=1, top_k=1)
    for value in (0.1, 0.9, 0.5, 0.8):
        meta = {"token": str(value), "token_pos": 0, "seq_id": 0, "context": ""}
        buffer.update(0, value, meta)

    # Heap side: exactly one entry, whose activation (slot 0) is ~0.9.
    heap = buffer.heaps[0]
    assert len(heap) == 1
    top_activation = heap[0][0]
    assert abs(top_activation - 0.9) < 1e-6

    # DataFrame side: one row with the same activation, within tolerance.
    frame = buffer.to_dataframe()
    assert len(frame) == 1
    frame_activation = float(frame["activation"].iloc[0])
    assert abs(frame_activation - 0.9) < 1e-6
|
|
|
|
def test_edge_case_fewer_inserts_than_k():
    """A partially filled buffer holds exactly the items that were inserted."""
    n_inserted = 3
    buffer = TopKBuffer(n_features=1, top_k=10)
    for idx in range(n_inserted):
        meta = {"token": "t", "token_pos": idx, "seq_id": idx, "context": ""}
        buffer.update(0, float(idx), meta)

    assert len(buffer.heaps[0]) == n_inserted
    frame = buffer.to_dataframe()
    assert len(frame) == n_inserted
|
|
|
|
def test_multiple_features_independent():
    """Inserting into feature 0 leaves feature 1's heap empty."""
    buffer = TopKBuffer(n_features=2, top_k=3)
    for value in (100.0, 200.0, 300.0):
        meta = {"token": "f0", "token_pos": 0, "seq_id": 0, "context": ""}
        buffer.update(0, value, meta)

    touched = buffer.heaps[0]
    assert len(touched) == 3
    untouched = buffer.heaps[1]
    assert len(untouched) == 0
|
|
|
|
def test_to_dataframe_metadata_preserved():
    """token, context, token_pos and seq_id come back from to_dataframe unchanged."""
    expected = {
        "token": "hello",
        "context": "some context",
        "token_pos": 7,
        "seq_id": 99,
    }
    buffer = TopKBuffer(n_features=1, top_k=1)
    # Hand over a copy so the expectations stay independent of the buffer.
    buffer.update(0, 42.0, dict(expected))

    first_row = buffer.to_dataframe().iloc[0]
    for field, value in expected.items():
        assert first_row[field] == value
|
|