""" tests/integration/test_query_api.py FastAPI integration tests — agent graph is fully mocked so no LLM/DB calls happen. Uses TestClient (synchronous ASGI runner). """ import json import uuid import pytest from unittest.mock import patch, MagicMock def _make_final_state(session_id: str) -> dict: return { "session_id": session_id, "user_id": "test", "user_query": "top products", "connector_id": "neon:public", "intent": "sql", "query_plan": {}, "relevant_tables": [], "schema_context": "", "memory_context": "", "generated_code": "SELECT product, SUM(amount) FROM orders GROUP BY product", "code_type": "sql", "sql_dialect": "postgres", "execution_result": [{"product": "Widget", "total": 299.97}], "execution_error": None, "from_cache": False, "error_class": None, "correction_attempts": 0, "max_corrections": 3, "insight_text": "Widget is the top product with $300 in revenue.", "chart_spec": { "type": "bar", "plotly_json": { "data": [{"type": "bar", "x": ["Widget"], "y": [299.97]}], "layout": {"title": "top products"}, }, }, "history_id": str(uuid.uuid4()), "latency_ms": 850, "stream_tokens": [], } @pytest.fixture def mocked_graph(session_id): """Patch get_graph() to return a mock that returns controlled state.""" state = _make_final_state(session_id) mock_graph = MagicMock() mock_graph.invoke.return_value = state with patch("api.routers.query.get_graph", return_value=mock_graph): yield mock_graph, state @pytest.mark.integration class TestQueryRunEndpoint: def test_returns_200(self, api_client, mocked_graph, session_id): resp = api_client.post("/api/query/run", json={ "user_query": "top products", "connector_id": "neon:public", "session_id": session_id, }) assert resp.status_code == 200 def test_response_has_required_fields(self, api_client, mocked_graph, session_id): resp = api_client.post("/api/query/run", json={ "user_query": "top products", "connector_id": "neon:public", "session_id": session_id, }) body = resp.json() required = ["session_id", "intent", "generated_code", "execution_result", "insight_text", "chart_spec", "from_cache", "latency_ms"] for field in required: assert field in body, f"Missing field: {field}" def test_insight_text_returned(self, api_client, mocked_graph, session_id): resp = api_client.post("/api/query/run", json={ "user_query": "top products", "connector_id": "neon:public", "session_id": session_id, }) assert resp.json()["insight_text"] == "Widget is the top product with $300 in revenue." def test_chart_spec_returned(self, api_client, mocked_graph, session_id): resp = api_client.post("/api/query/run", json={ "user_query": "top products", "connector_id": "neon:public", "session_id": session_id, }) chart = resp.json()["chart_spec"] assert chart is not None assert chart["type"] == "bar" def test_empty_query_rejected(self, api_client): resp = api_client.post("/api/query/run", json={ "user_query": "", "connector_id": "neon:public", }) assert resp.status_code == 422 def test_missing_connector_id_rejected(self, api_client): resp = api_client.post("/api/query/run", json={"user_query": "test"}) assert resp.status_code == 422 def test_session_id_auto_generated_when_absent(self, api_client, mocked_graph): resp = api_client.post("/api/query/run", json={ "user_query": "top products", "connector_id": "neon:public", }) assert resp.status_code == 200 assert resp.json()["session_id"] def test_graph_invoke_called_once(self, api_client, mocked_graph, session_id): mock_graph, _ = mocked_graph api_client.post("/api/query/run", json={ "user_query": "top products", "connector_id": "neon:public", "session_id": session_id, }) mock_graph.invoke.assert_called_once() def test_graph_exception_returns_500(self, api_client, session_id): mock_graph = MagicMock() mock_graph.invoke.side_effect = RuntimeError("LLM timeout") with patch("api.routers.query.get_graph", return_value=mock_graph): resp = api_client.post("/api/query/run", json={ "user_query": "top products", "connector_id": "neon:public", "session_id": session_id, }) assert resp.status_code == 500 @pytest.mark.integration class TestHealthEndpoint: def test_health_returns_ok(self, api_client): resp = api_client.get("/health") assert resp.status_code == 200 assert resp.json() == {"status": "ok", "version": "2.0.0"} @pytest.mark.integration class TestStreamEndpoint: def test_stream_returns_200(self, api_client, mocked_graph, session_id): resp = api_client.post("/api/query/stream", json={ "user_query": "top products", "connector_id": "neon:public", "session_id": session_id, }) assert resp.status_code == 200 def test_stream_content_type_is_event_stream(self, api_client, mocked_graph, session_id): resp = api_client.post("/api/query/stream", json={ "user_query": "top products", "connector_id": "neon:public", "session_id": session_id, }) assert "text/event-stream" in resp.headers.get("content-type", "") def test_stream_contains_done_event(self, api_client, mocked_graph, session_id): resp = api_client.post("/api/query/stream", json={ "user_query": "top products", "connector_id": "neon:public", "session_id": session_id, }) raw = resp.text events = [ json.loads(line[len("data: "):]) for line in raw.split("\n") if line.startswith("data: ") ] done_events = [e for e in events if e.get("done")] assert len(done_events) == 1 def test_stream_emits_token_events(self, api_client, mocked_graph, session_id): resp = api_client.post("/api/query/stream", json={ "user_query": "top products", "connector_id": "neon:public", "session_id": session_id, }) raw = resp.text events = [ json.loads(line[len("data: "):]) for line in raw.split("\n") if line.startswith("data: ") ] token_events = [e for e in events if "token" in e] assert len(token_events) > 0