Spaces:
Runtime error
Runtime error
| """API tests for /api/sae endpoints.""" | |
| from unittest.mock import MagicMock | |
| import pytest | |
| from httpx import ASGITransport, AsyncClient | |
| from neural_mri.api.routes_sae import get_model_manager, get_sae_manager, get_scan_cache | |
| from neural_mri.core.scan_cache import ScanCache | |
| from neural_mri.main import app | |
@pytest.fixture
def _override_deps(mock_model_manager, mock_sae_manager):
    """Point the SAE routes' dependencies at test doubles for one test.

    Args:
        mock_model_manager: Fixture providing the model-manager double
            (presumably defined in a conftest.py — not visible here).
        mock_sae_manager: Fixture providing the SAE-manager double
            (same caveat).

    Yields:
        None. The overrides are active only while the requesting test runs.
    """
    # Snapshot existing overrides so teardown restores them rather than
    # wiping overrides this fixture did not install.
    previous = dict(app.dependency_overrides)
    app.dependency_overrides[get_model_manager] = lambda: mock_model_manager
    app.dependency_overrides[get_sae_manager] = lambda: mock_sae_manager
    # Each resolution builds a new, small ScanCache, so no cached scan state
    # is shared with other tests.
    app.dependency_overrides[get_scan_cache] = lambda: ScanCache(max_entries=5)
    try:
        yield
    finally:
        app.dependency_overrides.clear()
        app.dependency_overrides.update(previous)
@pytest.fixture
def _override_no_model():
    """Make the SAE routes see a model manager that has no model loaded.

    Only ``get_model_manager`` is overridden; every other dependency resolves
    normally. NOTE(review): confirm the routes exercised with this fixture do
    not also need SAE-manager / scan-cache doubles in the "no model" state.

    Yields:
        None. The override is active only while the requesting test runs.
    """
    stub = MagicMock()
    stub.is_loaded = False
    stub.model_id = None
    # Snapshot existing overrides so teardown restores them rather than
    # wiping overrides this fixture did not install.
    previous = dict(app.dependency_overrides)
    app.dependency_overrides[get_model_manager] = lambda: stub
    try:
        yield
    finally:
        app.dependency_overrides.clear()
        app.dependency_overrides.update(previous)
async def test_sae_info_model_loaded(_override_deps):
    """GET /api/sae/info reports SAE availability for the mocked model.

    The expected model id "gpt2" comes from the mock_model_manager fixture,
    which is defined outside this file. NOTE(review): no async marker is
    applied here; this assumes the project's pytest config runs async tests
    automatically — confirm.
    """
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        response = await ac.get("/api/sae/info")
        assert response.status_code == 200
        body = response.json()
        assert body["available"] is True
        assert body["model_id"] == "gpt2"
async def test_sae_info_no_model(_override_no_model):
    """GET /api/sae/info still answers 200 but reports unavailable with no model."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        response = await ac.get("/api/sae/info")
        assert response.status_code == 200
        assert response.json()["available"] is False
async def test_sae_support_returns_dict(_override_deps):
    """GET /api/sae/support returns a JSON object whose "gpt2" entry is a bool."""
    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        resp = await client.get("/api/sae/support")
        assert resp.status_code == 200
        data = resp.json()
        # Check the shape the test name promises; without this a non-dict
        # payload only fails later with a TypeError on data["gpt2"].
        assert isinstance(data, dict)
        assert "gpt2" in data
        assert isinstance(data["gpt2"], bool)
async def test_sae_scan_no_model_400(_override_no_model):
    """POST /api/sae/scan is rejected with HTTP 400 when no model is loaded."""
    payload = {"prompt": "test", "layer_idx": 0}
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        response = await ac.post("/api/sae/scan", json=payload)
        assert response.status_code == 400