Neural-MRI / backend /tests /test_api_sae.py
Hiconcep's picture
Upload folder using huggingface_hub
0ce9643 verified
"""API tests for /api/sae endpoints."""
from unittest.mock import MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from neural_mri.api.routes_sae import get_model_manager, get_sae_manager, get_scan_cache
from neural_mri.core.scan_cache import ScanCache
from neural_mri.main import app
@pytest.fixture
def _override_deps(mock_model_manager, mock_sae_manager):
    """Point the SAE API's dependencies at test doubles for one test.

    Overrides the model manager, SAE manager and scan cache dependencies on
    the shared ``app``. On teardown, the override mapping is restored to
    exactly what it was before the fixture ran, rather than being wiped.

    Args:
        mock_model_manager: Model-manager double (fixture defined elsewhere,
            presumably in conftest).
        mock_sae_manager: SAE-manager double (fixture defined elsewhere,
            presumably in conftest).
    """
    # Build a single cache for the whole test. The previous lambda created a
    # fresh ScanCache on every request, so entries cached by one request were
    # invisible to the next. NOTE(review): assumes the real get_scan_cache is
    # long-lived; confirm.
    scan_cache = ScanCache(max_entries=5)

    # Snapshot existing overrides (e.g. from conftest) so teardown can put
    # them back instead of clearing everything on the shared app.
    saved_overrides = dict(app.dependency_overrides)

    app.dependency_overrides[get_model_manager] = lambda: mock_model_manager
    app.dependency_overrides[get_sae_manager] = lambda: mock_sae_manager
    app.dependency_overrides[get_scan_cache] = lambda: scan_cache
    yield
    app.dependency_overrides.clear()
    app.dependency_overrides.update(saved_overrides)
@pytest.fixture
def _override_no_model():
    """Simulate a server with no model loaded for one test.

    Only the model-manager dependency is overridden. Its double reports
    ``is_loaded = False`` and ``model_id = None``. On teardown, the override
    mapping is restored to what it was before the fixture ran, rather than
    being wiped.
    """
    mm = MagicMock()
    mm.is_loaded = False
    mm.model_id = None

    # Snapshot existing overrides so teardown does not discard ones installed
    # outside this fixture.
    saved_overrides = dict(app.dependency_overrides)

    app.dependency_overrides[get_model_manager] = lambda: mm
    yield
    app.dependency_overrides.clear()
    app.dependency_overrides.update(saved_overrides)
async def test_sae_info_model_loaded(_override_deps):
    """GET /api/sae/info reports SAE as available, with the model id "gpt2".

    The expected id presumably comes from the mock_model_manager fixture;
    confirm in conftest.
    """
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        response = await http.get("/api/sae/info")
        assert response.status_code == 200
        payload = response.json()
        assert payload["available"] is True
        assert payload["model_id"] == "gpt2"
async def test_sae_info_no_model(_override_no_model):
    """GET /api/sae/info still answers 200 but marks SAE as unavailable."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        response = await http.get("/api/sae/info")
        assert response.status_code == 200
        payload = response.json()
        assert payload["available"] is False
async def test_sae_support_returns_dict(_override_deps):
    """GET /api/sae/support returns a mapping with a boolean "gpt2" entry."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        response = await http.get("/api/sae/support")
        assert response.status_code == 200
        support = response.json()
        assert "gpt2" in support
        assert isinstance(support["gpt2"], bool)
async def test_sae_scan_no_model_400(_override_no_model):
    """POST /api/sae/scan is rejected with 400 while no model is loaded."""
    request_body = {"prompt": "test", "layer_idx": 0}
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        response = await http.post("/api/sae/scan", json=request_body)
        assert response.status_code == 400