Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from fastapi.testclient import TestClient | |
| import pytest | |
| from app.api.main import app | |
| from app.core.schemas import ( | |
| AnalysisResult, | |
| CurrentUserResponse, | |
| ModelCapability, | |
| RuntimeMetadata, | |
| SessionResponse, | |
| ValidationMetadata, | |
| ) | |
@pytest.fixture(autouse=True)
def auth_override(monkeypatch):
    """Replace the API's auth lookups with a fixed, authenticated test user.

    Registered as an autouse fixture because tests in this module (for
    example ``test_me_reports_current_user``) expect an authenticated user
    without requesting this helper explicitly; without the registration the
    patches below would never be applied.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture, used to swap
            ``require_user`` and ``get_optional_user`` on ``app.api.main``.
    """
    required_user_attrs = {
        "id": "user-123",
        "display_name": "Test User",
        "authenticated": True,
    }
    optional_user_attrs = {
        "id": "user-123",
        "username": "tester",
        "display_name": "Test User",
        "avatar_url": "https://example.com/avatar.png",
        "authenticated": True,
    }

    def fake_require_user(_request, _settings):
        # A throwaway class whose class attributes are the canned user fields.
        return type("User", (), required_user_attrs)()

    def fake_get_optional_user(_request):
        # Same idea, with the extra profile fields this stand-in carries.
        return type("User", (), optional_user_attrs)()

    monkeypatch.setattr("app.api.main.require_user", fake_require_user)
    monkeypatch.setattr("app.api.main.get_optional_user", fake_get_optional_user)
def test_healthz_returns_runtime_flags() -> None:
    """GET /healthz answers 200 with an "ok" status and the runtime flag keys."""
    with TestClient(app) as api:
        health = api.get("/healthz")

    assert health.status_code == 200
    body = health.json()
    assert body["status"] == "ok"
    # Only the presence of the runtime flags is checked, not their values.
    for flag in ("cuda_available", "dtype_preference"):
        assert flag in body
def test_analyze_delegates_to_runtime(monkeypatch) -> None:
    """POST /api/analyze echoes the answer and model name of the patched routine."""

    def fake_compute_attribution_analysis(**_kwargs):
        # Canned single-sentence analysis; every keyword argument is ignored.
        capability = ModelCapability(
            supports_attribution=True,
            layer_count=2,
            attention_impl="eager",
        )
        runtime = RuntimeMetadata(device="cpu", capability=capability)
        validation = ValidationMetadata(enabled=False, top_k=0)
        return AnalysisResult(
            question="Why?",
            model_name="fake-model",
            answer="Because.",
            raw_trace_text="<think>Alpha.</think>",
            normalized_trace_text="Alpha.",
            sentences=["Alpha."],
            sentence_token_ranges=[(0, 1)],
            suppression_matrix=[[0.0]],
            raw_suppression_matrix=[[0.0]],
            outgoing_importance=[0.0],
            incoming_importance=[0.0],
            top_edges=[],
            runtime_metadata=runtime,
            validation_metadata=validation,
        )

    monkeypatch.setattr(
        "app.api.main.compute_attribution_analysis",
        fake_compute_attribution_analysis,
    )

    request_body = {
        "question": "Why?",
        "max_new_tokens": 8,
        "validate_top_k": 0,
    }
    with TestClient(app) as api:
        reply = api.post("/api/analyze", json=request_body)

    assert reply.status_code == 200
    body = reply.json()
    assert body["answer"] == "Because."
    assert body["model_name"] == "fake-model"
def test_me_reports_current_user() -> None:
    """GET /api/me returns a body that validates as an authenticated "tester"."""
    with TestClient(app) as api:
        reply = api.get("/api/me")

    assert reply.status_code == 200
    # Validate through the schema so a malformed body fails here.
    current_user = CurrentUserResponse.model_validate(reply.json())
    assert current_user.authenticated is True
    assert current_user.username == "tester"
def test_root_serves_frontend() -> None:
    """GET / answers 200 and the response text mentions "Thought Anchors"."""
    with TestClient(app) as api:
        page = api.get("/")

    assert page.status_code == 200
    assert "Thought Anchors" in page.text
def test_session_routes_use_service(monkeypatch) -> None:
    """Session endpoints answer 200 and surface the fake service's canned data."""

    class FakeSessionService:
        """Stand-in service that returns one fixed, completed session."""

        def __init__(self) -> None:
            capability = ModelCapability(
                supports_attribution=True,
                layer_count=2,
                attention_impl="eager",
            )
            analysis = AnalysisResult(
                question="Why?",
                model_name="fake-model",
                answer="Because.",
                raw_trace_text="<think>Alpha.</think>",
                normalized_trace_text="Alpha.",
                sentences=["Alpha."],
                sentence_token_ranges=[(0, 1)],
                suppression_matrix=[[0.0]],
                raw_suppression_matrix=[[0.0]],
                outgoing_importance=[0.0],
                incoming_importance=[0.0],
                top_edges=[],
                runtime_metadata=RuntimeMetadata(device="cpu", capability=capability),
                validation_metadata=ValidationMetadata(enabled=False, top_k=0),
            )
            # The analysis is stored in dumped (plain data) form inside the payload.
            self.payload = {
                "id": "session-123",
                "status": "completed",
                "question": "Why?",
                "model_name": "fake-model",
                "error": None,
                "created_at": "2026-04-06T00:00:00+00:00",
                "updated_at": "2026-04-06T00:00:05+00:00",
                "answer": "Because.",
                "raw_trace_text": "<think>Alpha.</think>",
                "normalized_trace_text": "Alpha.",
                "sentences": ["Alpha."],
                "generation_metadata": {"max_new_tokens": 8},
                "analysis": analysis.model_dump(),
            }

        def create_session(self, _request, **_kwargs):
            return SessionResponse.model_validate(self.payload)

        def get_session_payload(self, _session_id: str, **_kwargs):
            return self.payload

        def start_analysis(self, _session_id: str, **_kwargs):
            return SessionResponse.model_validate(self.payload)

        def list_sessions(self, _owner_id: str, **_kwargs):
            return [self.payload]

        def get_analysis_result(self, _session_id: str, **_kwargs):
            return AnalysisResult.model_validate(self.payload["analysis"])

    def make_fake_service():
        # A fresh fake is built on every lookup.
        return FakeSessionService()

    monkeypatch.setattr("app.api.main.get_session_service", make_fake_service)

    session_url = "/api/sessions/session-123"
    with TestClient(app) as api:
        listing = api.get("/api/sessions")
        created = api.post("/api/sessions", json={"question": "Why?"})
        session = api.get(session_url)
        result = api.get(session_url + "/result")
        exported_json = api.get(session_url + "/export.json")
        exported_csv = api.get(session_url + "/export.csv")

    # Every route must succeed before the bodies are inspected.
    for reply in (listing, created, session, result, exported_json, exported_csv):
        assert reply.status_code == 200

    assert created.json()["id"] == "session-123"
    assert session.json()["answer"] == "Because."
    assert result.json()["analysis"]["model_name"] == "fake-model"