File size: 7,046 Bytes
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612eafc
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
tests/integration/test_query_api.py
FastAPI integration tests — agent graph is fully mocked so no LLM/DB calls happen.
Uses TestClient (synchronous ASGI runner).
"""

import json
import uuid
import pytest
from unittest.mock import patch, MagicMock


def _make_final_state(session_id: str) -> dict:
    return {
        "session_id": session_id,
        "user_id": "test",
        "user_query": "top products",
        "connector_id": "neon:public",
        "intent": "sql",
        "query_plan": {},
        "relevant_tables": [],
        "schema_context": "",
        "memory_context": "",
        "generated_code": "SELECT product, SUM(amount) FROM orders GROUP BY product",
        "code_type": "sql",
        "sql_dialect": "postgres",
        "execution_result": [{"product": "Widget", "total": 299.97}],
        "execution_error": None,
        "from_cache": False,
        "error_class": None,
        "correction_attempts": 0,
        "max_corrections": 3,
        "insight_text": "Widget is the top product with $300 in revenue.",
        "chart_spec": {
            "type": "bar",
            "plotly_json": {
                "data": [{"type": "bar", "x": ["Widget"], "y": [299.97]}],
                "layout": {"title": "top products"},
            },
        },
        "history_id": str(uuid.uuid4()),
        "latency_ms": 850,
        "stream_tokens": [],
    }


@pytest.fixture
def mocked_graph(session_id):
    """Patch ``api.routers.query.get_graph`` with a canned fake graph.

    Yields a ``(mock_graph, state)`` pair: ``mock_graph.invoke`` returns
    ``state``, the final state built for this test's ``session_id``. The patch
    is undone when the test finishes.
    """
    final_state = _make_final_state(session_id)
    fake_graph = MagicMock()
    fake_graph.invoke.return_value = final_state
    patcher = patch("api.routers.query.get_graph", return_value=fake_graph)
    with patcher:
        yield fake_graph, final_state


@pytest.mark.integration
class TestQueryRunEndpoint:
    """Integration tests for ``POST /api/query/run`` against a mocked agent graph."""

    @staticmethod
    def _post_run(api_client, session_id=None):
        """POST the canned "top products" request to ``/api/query/run``.

        ``session_id`` is only included in the body when given, so callers can
        exercise the server-side default.
        """
        payload = {"user_query": "top products", "connector_id": "neon:public"}
        if session_id is not None:
            payload["session_id"] = session_id
        return api_client.post("/api/query/run", json=payload)

    def test_returns_200(self, api_client, mocked_graph, session_id):
        resp = self._post_run(api_client, session_id)
        assert resp.status_code == 200

    def test_response_has_required_fields(self, api_client, mocked_graph, session_id):
        resp = self._post_run(api_client, session_id)
        # Check the status first so a server error is reported as such rather
        # than as a confusing list of missing fields.
        assert resp.status_code == 200, resp.text
        body = resp.json()
        required = ["session_id", "intent", "generated_code", "execution_result",
                    "insight_text", "chart_spec", "from_cache", "latency_ms"]
        missing = [field for field in required if field not in body]
        assert not missing, f"Missing fields: {missing}"

    def test_insight_text_returned(self, api_client, mocked_graph, session_id):
        _, state = mocked_graph
        resp = self._post_run(api_client, session_id)
        assert resp.status_code == 200, resp.text
        # Compare against the mocked state so the expectation has one source.
        assert resp.json()["insight_text"] == state["insight_text"]

    def test_chart_spec_returned(self, api_client, mocked_graph, session_id):
        _, state = mocked_graph
        resp = self._post_run(api_client, session_id)
        assert resp.status_code == 200, resp.text
        chart = resp.json()["chart_spec"]
        assert chart is not None
        assert chart["type"] == state["chart_spec"]["type"]

    # The validation tests also request ``mocked_graph``. If request
    # validation ever regressed, the request would then hit the fake graph
    # and fail on the status assertion, instead of reaching the real agent
    # graph (LLM / database).
    def test_empty_query_rejected(self, api_client, mocked_graph):
        resp = api_client.post("/api/query/run", json={
            "user_query": "",
            "connector_id": "neon:public",
        })
        assert resp.status_code == 422

    def test_missing_connector_id_rejected(self, api_client, mocked_graph):
        resp = api_client.post("/api/query/run", json={"user_query": "test"})
        assert resp.status_code == 422

    def test_session_id_auto_generated_when_absent(self, api_client, mocked_graph):
        resp = self._post_run(api_client)
        assert resp.status_code == 200
        assert resp.json()["session_id"]

    def test_graph_invoke_called_once(self, api_client, mocked_graph, session_id):
        mock_graph, _ = mocked_graph
        resp = self._post_run(api_client, session_id)
        assert resp.status_code == 200, resp.text
        mock_graph.invoke.assert_called_once()

    def test_graph_exception_returns_500(self, api_client, session_id):
        # NOTE(review): Starlette's TestClient re-raises unhandled server
        # exceptions by default (raise_server_exceptions=True). This test
        # therefore relies on the router converting the error into a 500, or
        # on api_client being built with raise_server_exceptions=False.
        # Confirm against the router and conftest.
        mock_graph = MagicMock()
        mock_graph.invoke.side_effect = RuntimeError("LLM timeout")
        with patch("api.routers.query.get_graph", return_value=mock_graph):
            resp = self._post_run(api_client, session_id)
        assert resp.status_code == 500


@pytest.mark.integration
class TestHealthEndpoint:
    """Smoke test for the ``GET /health`` liveness endpoint."""

    def test_health_returns_ok(self, api_client):
        expected_body = {"status": "ok", "version": "2.0.0"}
        response = api_client.get("/health")
        assert response.status_code == 200
        assert response.json() == expected_body


@pytest.mark.integration
class TestStreamEndpoint:
    """Integration tests for ``POST /api/query/stream`` (server-sent events)."""

    @staticmethod
    def _post_stream(api_client, session_id):
        """POST the canned "top products" request to the streaming endpoint."""
        return api_client.post("/api/query/stream", json={
            "user_query": "top products",
            "connector_id": "neon:public",
            "session_id": session_id,
        })

    @staticmethod
    def _parse_sse_events(raw: str) -> list:
        """Return the JSON-decoded payload of every ``data:`` line in *raw*.

        The SSE format allows CRLF, LF or CR line endings and makes the space
        after ``data:`` optional. Both variants are normalised here rather than
        assuming LF-only endings and a ``data: `` prefix.
        """
        normalized = raw.replace("\r\n", "\n").replace("\r", "\n")
        events = []
        for line in normalized.split("\n"):
            if not line.startswith("data:"):
                continue
            payload = line[len("data:"):]
            if payload.startswith(" "):
                # The SSE format strips exactly one leading space from a value.
                payload = payload[1:]
            events.append(json.loads(payload))
        return events

    def test_stream_returns_200(self, api_client, mocked_graph, session_id):
        resp = self._post_stream(api_client, session_id)
        assert resp.status_code == 200

    def test_stream_content_type_is_event_stream(self, api_client, mocked_graph, session_id):
        resp = self._post_stream(api_client, session_id)
        assert "text/event-stream" in resp.headers.get("content-type", "")

    def test_stream_contains_done_event(self, api_client, mocked_graph, session_id):
        resp = self._post_stream(api_client, session_id)
        assert resp.status_code == 200, resp.text
        events = self._parse_sse_events(resp.text)
        # Guard with isinstance so a non-object payload fails the assertion
        # instead of raising AttributeError on ``.get``.
        done_events = [e for e in events if isinstance(e, dict) and e.get("done")]
        assert len(done_events) == 1

    def test_stream_emits_token_events(self, api_client, mocked_graph, session_id):
        resp = self._post_stream(api_client, session_id)
        assert resp.status_code == 200, resp.text
        events = self._parse_sse_events(resp.text)
        token_events = [e for e in events if isinstance(e, dict) and "token" in e]
        assert token_events, f"No token events in stream: {events!r}"