File size: 6,531 Bytes
fda8fb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
from __future__ import annotations

from types import SimpleNamespace

from fastapi.testclient import TestClient
import pytest

from app.api.main import app
from app.core.schemas import (
    AnalysisResult,
    CurrentUserResponse,
    ModelCapability,
    RuntimeMetadata,
    SessionResponse,
    ValidationMetadata,
)


@pytest.fixture(autouse=True)
def auth_override(monkeypatch: pytest.MonkeyPatch) -> None:
    """Stub out the app's auth helpers so every test runs as one fixed user.

    Patches ``require_user`` and ``get_optional_user`` on ``app.api.main``.
    Both stubs return a user with the same attribute set, so endpoints
    guarded by either helper see identical fields.
    """

    def make_user() -> SimpleNamespace:
        # A fresh object per call, matching the original per-call instantiation.
        # The values are instance attributes, so vars()/__dict__ see them too.
        return SimpleNamespace(
            id="user-123",
            username="tester",
            display_name="Test User",
            avatar_url="https://example.com/avatar.png",
            authenticated=True,
        )

    # The stubs keep the call shapes the app uses:
    # require_user(request, settings) and get_optional_user(request).
    monkeypatch.setattr(
        "app.api.main.require_user",
        lambda _request, _settings: make_user(),
    )
    monkeypatch.setattr(
        "app.api.main.get_optional_user",
        lambda _request: make_user(),
    )


def test_healthz_returns_runtime_flags() -> None:
    """``GET /healthz`` answers 200, reports status ``ok`` and exposes runtime flags."""
    with TestClient(app) as client:
        resp = client.get("/healthz")

    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "ok"
    # Only the presence of the flags is checked; their values depend on the host.
    for flag in ("cuda_available", "dtype_preference"):
        assert flag in body


def test_analyze_delegates_to_runtime(monkeypatch: pytest.MonkeyPatch) -> None:
    """``POST /api/analyze`` delegates to ``compute_attribution_analysis``.

    The runtime entry point on ``app.api.main`` is replaced with a stub that
    records its keyword arguments and returns a minimal ``AnalysisResult``.
    The test asserts the stub ran exactly once and that the endpoint's JSON
    carries the stub's answer and model name.
    """
    calls: list[dict[str, object]] = []

    def fake_compute_attribution_analysis(**kwargs):
        # Record every invocation so delegation is verified directly instead
        # of being inferred from the response body alone.
        calls.append(kwargs)
        return AnalysisResult(
            question="Why?",
            model_name="fake-model",
            answer="Because.",
            raw_trace_text="<think>Alpha.</think>",
            normalized_trace_text="Alpha.",
            sentences=["Alpha."],
            sentence_token_ranges=[(0, 1)],
            suppression_matrix=[[0.0]],
            raw_suppression_matrix=[[0.0]],
            outgoing_importance=[0.0],
            incoming_importance=[0.0],
            top_edges=[],
            runtime_metadata=RuntimeMetadata(
                device="cpu",
                capability=ModelCapability(supports_attribution=True, layer_count=2, attention_impl="eager"),
            ),
            validation_metadata=ValidationMetadata(enabled=False, top_k=0),
        )

    monkeypatch.setattr("app.api.main.compute_attribution_analysis", fake_compute_attribution_analysis)

    with TestClient(app) as client:
        response = client.post(
            "/api/analyze",
            json={
                "question": "Why?",
                "max_new_tokens": 8,
                "validate_top_k": 0,
            },
        )

    assert response.status_code == 200
    # Exactly one delegation: the analysis is expensive and must not be repeated.
    assert len(calls) == 1
    payload = response.json()
    assert payload["answer"] == "Because."
    assert payload["model_name"] == "fake-model"


def test_me_reports_current_user() -> None:
    """``GET /api/me`` returns the user supplied by this module's auth stubs."""
    with TestClient(app) as client:
        me = client.get("/api/me")

    assert me.status_code == 200
    # Validate against the response schema rather than poking at raw JSON.
    current = CurrentUserResponse.model_validate(me.json())
    assert current.authenticated is True
    assert current.username == "tester"


def test_root_serves_frontend() -> None:
    """``GET /`` serves the frontend: a 200 whose body mentions "Thought Anchors"."""
    with TestClient(app) as client:
        page = client.get("/")

    assert page.status_code == 200
    assert "Thought Anchors" in page.text


def test_session_routes_use_service(monkeypatch: pytest.MonkeyPatch) -> None:
    """Session endpoints route their work through the session service.

    ``get_session_service`` is patched to build an in-memory fake that always
    answers with one completed session. The list, create, fetch, result and
    JSON/CSV export routes are then exercised against it.
    """

    class FakeSessionService:
        """In-memory stand-in that serves a single canned, completed session."""

        def __init__(self) -> None:
            analysis = AnalysisResult(
                question="Why?",
                model_name="fake-model",
                answer="Because.",
                raw_trace_text="<think>Alpha.</think>",
                normalized_trace_text="Alpha.",
                sentences=["Alpha."],
                sentence_token_ranges=[(0, 1)],
                suppression_matrix=[[0.0]],
                raw_suppression_matrix=[[0.0]],
                outgoing_importance=[0.0],
                incoming_importance=[0.0],
                top_edges=[],
                runtime_metadata=RuntimeMetadata(
                    device="cpu",
                    capability=ModelCapability(
                        supports_attribution=True,
                        layer_count=2,
                        attention_impl="eager",
                    ),
                ),
                validation_metadata=ValidationMetadata(enabled=False, top_k=0),
            )
            self.payload = {
                "id": "session-123",
                "status": "completed",
                "question": "Why?",
                "model_name": "fake-model",
                "error": None,
                "created_at": "2026-04-06T00:00:00+00:00",
                "updated_at": "2026-04-06T00:00:05+00:00",
                "answer": "Because.",
                "raw_trace_text": "<think>Alpha.</think>",
                "normalized_trace_text": "Alpha.",
                "sentences": ["Alpha."],
                "generation_metadata": {"max_new_tokens": 8},
                # Stored as a plain dict, the way a persisted session would be.
                "analysis": analysis.model_dump(),
            }

        def create_session(self, _request, **_kwargs):
            return SessionResponse.model_validate(self.payload)

        def get_session_payload(self, _session_id: str, **_kwargs):
            return self.payload

        def start_analysis(self, _session_id: str, **_kwargs):
            return SessionResponse.model_validate(self.payload)

        def list_sessions(self, _owner_id: str, **_kwargs):
            return [self.payload]

        def get_analysis_result(self, _session_id: str, **_kwargs):
            return AnalysisResult.model_validate(self.payload["analysis"])

    # The class itself is the zero-argument factory; no wrapping lambda is needed.
    monkeypatch.setattr("app.api.main.get_session_service", FakeSessionService)

    with TestClient(app) as client:
        listing = client.get("/api/sessions")
        created = client.post("/api/sessions", json={"question": "Why?"})
        session = client.get("/api/sessions/session-123")
        result = client.get("/api/sessions/session-123/result")
        exported_json = client.get("/api/sessions/session-123/export.json")
        exported_csv = client.get("/api/sessions/session-123/export.csv")

    responses = {
        "listing": listing,
        "created": created,
        "session": session,
        "result": result,
        "exported_json": exported_json,
        "exported_csv": exported_csv,
    }
    for label, response in responses.items():
        assert response.status_code == 200, label

    assert created.json()["id"] == "session-123"
    assert session.json()["answer"] == "Because."
    assert result.json()["analysis"]["model_name"] == "fake-model"