Data_analysis_agent / tests /integration /test_query_api.py
rohitdeshmukh318's picture
deployment prep
612eafc
"""
tests/integration/test_query_api.py
FastAPI integration tests — agent graph is fully mocked so no LLM/DB calls happen.
Uses TestClient (synchronous ASGI runner).
"""
import json
import uuid
import pytest
from unittest.mock import patch, MagicMock
def _make_final_state(session_id: str) -> dict:
return {
"session_id": session_id,
"user_id": "test",
"user_query": "top products",
"connector_id": "neon:public",
"intent": "sql",
"query_plan": {},
"relevant_tables": [],
"schema_context": "",
"memory_context": "",
"generated_code": "SELECT product, SUM(amount) FROM orders GROUP BY product",
"code_type": "sql",
"sql_dialect": "postgres",
"execution_result": [{"product": "Widget", "total": 299.97}],
"execution_error": None,
"from_cache": False,
"error_class": None,
"correction_attempts": 0,
"max_corrections": 3,
"insight_text": "Widget is the top product with $300 in revenue.",
"chart_spec": {
"type": "bar",
"plotly_json": {
"data": [{"type": "bar", "x": ["Widget"], "y": [299.97]}],
"layout": {"title": "top products"},
},
},
"history_id": str(uuid.uuid4()),
"latency_ms": 850,
"stream_tokens": [],
}
@pytest.fixture
def mocked_graph(session_id):
"""Patch get_graph() to return a mock that returns controlled state."""
state = _make_final_state(session_id)
mock_graph = MagicMock()
mock_graph.invoke.return_value = state
with patch("api.routers.query.get_graph", return_value=mock_graph):
yield mock_graph, state
@pytest.mark.integration
class TestQueryRunEndpoint:
def test_returns_200(self, api_client, mocked_graph, session_id):
resp = api_client.post("/api/query/run", json={
"user_query": "top products",
"connector_id": "neon:public",
"session_id": session_id,
})
assert resp.status_code == 200
def test_response_has_required_fields(self, api_client, mocked_graph, session_id):
resp = api_client.post("/api/query/run", json={
"user_query": "top products",
"connector_id": "neon:public",
"session_id": session_id,
})
body = resp.json()
required = ["session_id", "intent", "generated_code", "execution_result",
"insight_text", "chart_spec", "from_cache", "latency_ms"]
for field in required:
assert field in body, f"Missing field: {field}"
def test_insight_text_returned(self, api_client, mocked_graph, session_id):
resp = api_client.post("/api/query/run", json={
"user_query": "top products",
"connector_id": "neon:public",
"session_id": session_id,
})
assert resp.json()["insight_text"] == "Widget is the top product with $300 in revenue."
def test_chart_spec_returned(self, api_client, mocked_graph, session_id):
resp = api_client.post("/api/query/run", json={
"user_query": "top products",
"connector_id": "neon:public",
"session_id": session_id,
})
chart = resp.json()["chart_spec"]
assert chart is not None
assert chart["type"] == "bar"
def test_empty_query_rejected(self, api_client):
resp = api_client.post("/api/query/run", json={
"user_query": "",
"connector_id": "neon:public",
})
assert resp.status_code == 422
def test_missing_connector_id_rejected(self, api_client):
resp = api_client.post("/api/query/run", json={"user_query": "test"})
assert resp.status_code == 422
def test_session_id_auto_generated_when_absent(self, api_client, mocked_graph):
resp = api_client.post("/api/query/run", json={
"user_query": "top products",
"connector_id": "neon:public",
})
assert resp.status_code == 200
assert resp.json()["session_id"]
def test_graph_invoke_called_once(self, api_client, mocked_graph, session_id):
mock_graph, _ = mocked_graph
api_client.post("/api/query/run", json={
"user_query": "top products",
"connector_id": "neon:public",
"session_id": session_id,
})
mock_graph.invoke.assert_called_once()
def test_graph_exception_returns_500(self, api_client, session_id):
mock_graph = MagicMock()
mock_graph.invoke.side_effect = RuntimeError("LLM timeout")
with patch("api.routers.query.get_graph", return_value=mock_graph):
resp = api_client.post("/api/query/run", json={
"user_query": "top products",
"connector_id": "neon:public",
"session_id": session_id,
})
assert resp.status_code == 500
@pytest.mark.integration
class TestHealthEndpoint:
def test_health_returns_ok(self, api_client):
resp = api_client.get("/health")
assert resp.status_code == 200
assert resp.json() == {"status": "ok", "version": "2.0.0"}
@pytest.mark.integration
class TestStreamEndpoint:
def test_stream_returns_200(self, api_client, mocked_graph, session_id):
resp = api_client.post("/api/query/stream", json={
"user_query": "top products",
"connector_id": "neon:public",
"session_id": session_id,
})
assert resp.status_code == 200
def test_stream_content_type_is_event_stream(self, api_client, mocked_graph, session_id):
resp = api_client.post("/api/query/stream", json={
"user_query": "top products",
"connector_id": "neon:public",
"session_id": session_id,
})
assert "text/event-stream" in resp.headers.get("content-type", "")
def test_stream_contains_done_event(self, api_client, mocked_graph, session_id):
resp = api_client.post("/api/query/stream", json={
"user_query": "top products",
"connector_id": "neon:public",
"session_id": session_id,
})
raw = resp.text
events = [
json.loads(line[len("data: "):])
for line in raw.split("\n")
if line.startswith("data: ")
]
done_events = [e for e in events if e.get("done")]
assert len(done_events) == 1
def test_stream_emits_token_events(self, api_client, mocked_graph, session_id):
resp = api_client.post("/api/query/stream", json={
"user_query": "top products",
"connector_id": "neon:public",
"session_id": session_id,
})
raw = resp.text
events = [
json.loads(line[len("data: "):])
for line in raw.split("\n")
if line.startswith("data: ")
]
token_events = [e for e in events if "token" in e]
assert len(token_events) > 0