| """ |
| tests/test_autointerp.py |
| Pure-logic tests for autointerp.py helpers. |
| No model loading, no subprocess calls. |
| """ |
| import json |
| import sys |
| import threading |
| from pathlib import Path |
|
|
| import pytest |
|
|
| sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) |
|
|
| from sae_gemma.autointerp import _append_to_cache, _format_snippets, _load_cache |
|
|
|
|
| |
|
|
|
|
| def test_format_snippets_marks_token(): |
| """Each activating token should be wrapped in <<...>> in the output.""" |
| rows = [ |
| {"context": "The cat sat on the mat", "token": "cat"}, |
| {"context": "Dogs are great pets", "token": "Dogs"}, |
| ] |
| result = _format_snippets(rows) |
| assert "<<cat>>" in result |
| assert "<<Dogs>>" in result |
|
|
|
|
| def test_format_snippets_numbering(): |
| """Each snippet is prefixed with a 1-based index.""" |
| rows = [{"context": "hello world", "token": "hello"}] |
| result = _format_snippets(rows) |
| assert result.startswith("1.") |
|
|
|
|
| def test_format_snippets_multiple_rows_numbered(): |
| rows = [ |
| {"context": "foo bar baz", "token": "foo"}, |
| {"context": "alpha beta", "token": "alpha"}, |
| {"context": "x y z", "token": "x"}, |
| ] |
| result = _format_snippets(rows) |
| assert "1." in result |
| assert "2." in result |
| assert "3." in result |
|
|
|
|
| def test_format_snippets_token_not_in_context(): |
| """If token is not a substring of context, the context is returned unchanged (no crash).""" |
| rows = [{"context": "some text here", "token": "absent_token"}] |
| result = _format_snippets(rows) |
| |
| assert "some text here" in result |
|
|
|
|
| def test_format_snippets_empty_token(): |
| """Empty token string: context is returned without any <<>> wrapping.""" |
| rows = [{"context": "plain text", "token": ""}] |
| result = _format_snippets(rows) |
| assert "<<>>" not in result |
| assert "plain text" in result |
|
|
|
|
| def test_format_snippets_caps_at_20(): |
| """Only the first 20 rows are used even if more are provided.""" |
| rows = [{"context": f"context {i}", "token": f"tok{i}"} for i in range(30)] |
| result = _format_snippets(rows) |
| lines = [l for l in result.strip().split("\n") if l] |
| assert len(lines) == 20 |
|
|
|
|
| def test_format_snippets_empty_list(): |
| """Empty input produces an empty string (no crash).""" |
| result = _format_snippets([]) |
| assert result == "" |
|
|
|
|
| def test_format_snippets_replaces_first_occurrence_only(): |
| """Only the first occurrence of the token in context is wrapped.""" |
| rows = [{"context": "cat cat cat", "token": "cat"}] |
| result = _format_snippets(rows) |
| |
| assert result.count("<<cat>>") == 1 |
|
|
|
|
| |
|
|
|
|
| def test_load_cache_missing_file(tmp_path): |
| """_load_cache returns {} when the file does not exist.""" |
| missing = tmp_path / "nonexistent.json" |
| result = _load_cache(missing) |
| assert result == {} |
|
|
|
|
| def test_load_cache_reads_correctly(tmp_path): |
| """_load_cache reads a valid JSON file and converts keys to int.""" |
| cache_file = tmp_path / "labels.json" |
| data = {"1": "some label", "42": "another label"} |
| cache_file.write_text(json.dumps(data), encoding="utf-8") |
|
|
| result = _load_cache(cache_file) |
| assert result == {1: "some label", 42: "another label"} |
|
|
|
|
| def test_load_cache_key_types_are_int(tmp_path): |
| """Keys returned by _load_cache must be Python ints, not strings.""" |
| cache_file = tmp_path / "labels.json" |
| cache_file.write_text(json.dumps({"7": "label"}), encoding="utf-8") |
| result = _load_cache(cache_file) |
| for k in result: |
| assert isinstance(k, int), f"Expected int key, got {type(k)}" |
|
|
|
|
| def test_load_cache_returns_empty_on_corrupt_json(tmp_path): |
| """_load_cache returns {} when the file contains invalid JSON.""" |
| bad_file = tmp_path / "bad.json" |
| bad_file.write_text("{this is not valid json", encoding="utf-8") |
| result = _load_cache(bad_file) |
| assert result == {} |
|
|
|
|
| def test_load_cache_empty_json_object(tmp_path): |
| """_load_cache returns {} for an empty JSON object.""" |
| cache_file = tmp_path / "empty.json" |
| cache_file.write_text("{}", encoding="utf-8") |
| result = _load_cache(cache_file) |
| assert result == {} |
|
|
|
|
| |
|
|
|
|
| def test_append_to_cache_creates_file(tmp_path): |
| """_append_to_cache creates the file if it does not exist.""" |
| cache_file = tmp_path / "labels.json" |
| assert not cache_file.exists() |
| _append_to_cache(cache_file, feature_id=5, label="detects cats") |
| assert cache_file.exists() |
|
|
|
|
| def test_append_to_cache_single_entry(tmp_path): |
| """A single append produces a valid JSON file with one entry.""" |
| cache_file = tmp_path / "labels.json" |
| _append_to_cache(cache_file, feature_id=10, label="detects numbers") |
| with cache_file.open() as f: |
| data = json.load(f) |
| assert "10" in data |
| assert data["10"] == "detects numbers" |
|
|
|
|
| def test_append_to_cache_accumulates(tmp_path): |
| """Multiple sequential appends accumulate without overwriting.""" |
| cache_file = tmp_path / "labels.json" |
| _append_to_cache(cache_file, feature_id=1, label="label one") |
| _append_to_cache(cache_file, feature_id=2, label="label two") |
| _append_to_cache(cache_file, feature_id=3, label="label three") |
|
|
| with cache_file.open() as f: |
| data = json.load(f) |
| assert len(data) == 3 |
| assert data["1"] == "label one" |
| assert data["2"] == "label two" |
| assert data["3"] == "label three" |
|
|
|
|
| def test_append_to_cache_overwrites_existing_key(tmp_path): |
| """Writing the same feature_id twice updates the label.""" |
| cache_file = tmp_path / "labels.json" |
| _append_to_cache(cache_file, feature_id=7, label="old label") |
| _append_to_cache(cache_file, feature_id=7, label="new label") |
|
|
| with cache_file.open() as f: |
| data = json.load(f) |
| assert data["7"] == "new label" |
| assert len(data) == 1 |
|
|
|
|
| def test_append_to_cache_thread_safe(tmp_path): |
| """Concurrent appends from multiple threads should not lose any entry.""" |
| cache_file = tmp_path / "labels.json" |
| n_threads = 20 |
| errors = [] |
|
|
| def worker(fid): |
| try: |
| _append_to_cache(cache_file, feature_id=fid, label=f"label-{fid}") |
| except Exception as exc: |
| errors.append(exc) |
|
|
| threads = [threading.Thread(target=worker, args=(i,)) for i in range(n_threads)] |
| for t in threads: |
| t.start() |
| for t in threads: |
| t.join() |
|
|
| assert not errors, f"Thread errors: {errors}" |
|
|
| with cache_file.open() as f: |
| data = json.load(f) |
|
|
| |
| assert len(data) == n_threads, f"Expected {n_threads} entries, got {len(data)}" |
| for i in range(n_threads): |
| assert str(i) in data, f"Missing feature_id {i}" |
|
|