Neural-MRI / backend / tests / test_sae_registry.py
Hiconcep's picture
Upload folder using huggingface_hub
0ce9643 verified
"""Tests for the SAE registry."""
from neural_mri.core.sae_registry import get_sae_info, list_sae_support
def test_get_sae_info_gpt2():
    """Check the registry entry that ``get_sae_info`` returns for ``gpt2``.

    The entry must exist, name the ``gpt2-small-res-jb`` release, include
    layer 0 among its layers, and report a ``d_sae`` of 24576.
    """
    entry = get_sae_info("gpt2")

    # Guard first so a missing entry fails here rather than on a subscript.
    assert entry is not None
    assert entry["release"] == "gpt2-small-res-jb"
    assert 0 in entry["layers"]
    assert entry["d_sae"] == 24576
def test_get_sae_info_gemma():
    """Check the registry entry returned for ``google/gemma-2-2b``.

    The entry must exist, name the ``gemma-scope-2b-pt-res-canonical``
    release, and list exactly 26 layers.
    """
    entry = get_sae_info("google/gemma-2-2b")

    # Guard first so a missing entry fails here rather than on a subscript.
    assert entry is not None

    release_name = entry["release"]
    assert release_name == "gemma-scope-2b-pt-res-canonical"

    layer_count = len(entry["layers"])
    assert layer_count == 26
def test_get_sae_info_unknown_returns_none():
    """``get_sae_info`` returns ``None`` for ``gpt2-medium`` and ``nonexistent``."""
    for model_name in ("gpt2-medium", "nonexistent"):
        assert get_sae_info(model_name) is None
def test_list_sae_support_has_all_registry_models():
    """Check the keys and size of the mapping from ``list_sae_support``.

    Both ``gpt2`` and ``gpt2-medium`` must be present, and the mapping must
    contain exactly 8 entries in total.
    """
    support_map = list_sae_support()

    for model_name in ("gpt2", "gpt2-medium"):
        assert model_name in support_map

    assert len(support_map) == 8
def test_list_sae_support_gpt2_true():
    """``list_sae_support`` maps ``gpt2`` to exactly ``True``."""
    support_map = list_sae_support()
    gpt2_flag = support_map["gpt2"]

    # Identity check: a merely truthy value is not accepted.
    assert gpt2_flag is True
def test_list_sae_support_gpt2_medium_false():
    """``list_sae_support`` maps ``gpt2-medium`` to exactly ``False``."""
    support_map = list_sae_support()
    medium_flag = support_map["gpt2-medium"]

    # Identity check: a merely falsy value (e.g. ``None``) is not accepted.
    assert medium_flag is False