| """ |
| api/routers/query.py |
| POST /api/query/run – run agent, return full result (with trace + anomalies) |
| POST /api/query/stream – SSE stream with trace events + insight tokens |
| """ |
|
|
| import asyncio |
| import json |
| import time |
| import uuid |
| from typing import Any, AsyncGenerator, Dict, List, Optional |
|
|
| from fastapi import APIRouter, HTTPException |
| from fastapi.responses import StreamingResponse |
| from pydantic import BaseModel, Field |
|
|
| from agent.graph import get_graph |
| from agent.state import AgentState |
| from agent.trace import AgentTracer, set_tracer, get_tracer |
| from agent.metrics import get_metrics_collector |
|
|
| router = APIRouter() |
|
|
| |
| _conversations: Dict[str, List[Dict[str, Any]]] = {} |
| _MAX_HISTORY = 5 |
|
|
|
|
| class QueryRequest(BaseModel): |
| user_query: str = Field(..., min_length=1, max_length=2000) |
| connector_id: str = Field(..., description="e.g. neon:public or csv:<url>") |
| session_id: str = Field(default_factory=lambda: str(uuid.uuid4())) |
| user_id: str = Field(default="anonymous") |
|
|
|
|
| class TraceEventResponse(BaseModel): |
| node: str |
| status: str |
| latency_ms: int = 0 |
| tokens_used: int = 0 |
| metadata: dict = {} |
|
|
|
|
| class QueryResponse(BaseModel): |
| session_id: str |
| intent: str |
| generated_code: str |
| code_type: str |
| execution_result: list |
| insight_text: str |
| chart_spec: dict | None |
| from_cache: bool |
| latency_ms: int |
| correction_attempts: int |
| history_id: str | None |
| anomalies: list = [] |
| trace: list = [] |
|
|
|
|
| def _build_initial_state(req: QueryRequest) -> AgentState: |
| |
| history = _conversations.get(req.session_id, []) |
|
|
| return { |
| "session_id": req.session_id, |
| "user_id": req.user_id, |
| "user_query": req.user_query, |
| "connector_id": req.connector_id, |
| "intent": "", |
| "query_plan": {}, |
| "relevant_tables": [], |
| "schema_context": "", |
| "memory_context": "", |
| "conversation_history": history, |
| "generated_code": "", |
| "code_type": "sql", |
| "sql_dialect": "postgres", |
| "execution_result": None, |
| "execution_error": None, |
| "from_cache": False, |
| "error_class": None, |
| "correction_attempts": 0, |
| "max_corrections": 3, |
| "insight_text": "", |
| "chart_spec": None, |
| "anomalies": [], |
| "history_id": None, |
| "latency_ms": None, |
| "stream_tokens": [], |
| } |
|
|
|
|
| def _update_conversation(session_id: str, result: dict): |
| """Store this turn in conversation history for multi-turn context.""" |
| turn = { |
| "query": result.get("user_query", ""), |
| "code": result.get("generated_code", ""), |
| "result_preview": json.dumps((result.get("execution_result") or [])[:5], default=str), |
| "insight": result.get("insight_text", ""), |
| } |
| if session_id not in _conversations: |
| _conversations[session_id] = [] |
| _conversations[session_id].append(turn) |
| |
| if len(_conversations[session_id]) > _MAX_HISTORY: |
| _conversations[session_id] = _conversations[session_id][-_MAX_HISTORY:] |
|
|
|
|
| @router.post("/run", response_model=QueryResponse) |
| async def run_query(req: QueryRequest): |
| graph = get_graph() |
| state = _build_initial_state(req) |
|
|
| |
| tracer = AgentTracer() |
| set_tracer(tracer) |
|
|
| t0 = time.time() |
| try: |
| result = await asyncio.get_event_loop().run_in_executor( |
| None, graph.invoke, state |
| ) |
| except Exception as exc: |
| raise HTTPException(status_code=500, detail=str(exc)) |
| finally: |
| set_tracer(None) |
|
|
| total_ms = int((time.time() - t0) * 1000) |
|
|
| |
| _update_conversation(req.session_id, result) |
|
|
| |
| metrics = get_metrics_collector() |
| metrics.record_query( |
| latency_ms=total_ms, |
| from_cache=result.get("from_cache", False), |
| correction_attempts=result.get("correction_attempts", 0), |
| intent=result.get("intent", "sql"), |
| error_class=result.get("error_class"), |
| ) |
|
|
| return QueryResponse( |
| session_id=result["session_id"], |
| intent=result.get("intent", "sql"), |
| generated_code=result.get("generated_code", ""), |
| code_type=result.get("code_type", "sql"), |
| execution_result=result.get("execution_result") or [], |
| insight_text=result.get("insight_text", ""), |
| chart_spec=result.get("chart_spec"), |
| from_cache=result.get("from_cache", False), |
| latency_ms=total_ms, |
| correction_attempts=result.get("correction_attempts", 0), |
| history_id=result.get("history_id"), |
| anomalies=result.get("anomalies", []), |
| trace=tracer.get_events(), |
| ) |
|
|
|
|
| async def _stream_insight(req: QueryRequest) -> AsyncGenerator[str, None]: |
| """Run the agent, stream trace events live, then stream insight word-by-word.""" |
| graph = get_graph() |
| state = _build_initial_state(req) |
|
|
| |
| tracer = AgentTracer() |
| set_tracer(tracer) |
|
|
| t0 = time.time() |
| loop = asyncio.get_event_loop() |
| result = await loop.run_in_executor(None, graph.invoke, state) |
| total_ms = int((time.time() - t0) * 1000) |
| set_tracer(None) |
|
|
| |
| _update_conversation(req.session_id, result) |
|
|
| |
| metrics = get_metrics_collector() |
| metrics.record_query( |
| latency_ms=total_ms, |
| from_cache=result.get("from_cache", False), |
| correction_attempts=result.get("correction_attempts", 0), |
| intent=result.get("intent", "sql"), |
| error_class=result.get("error_class"), |
| ) |
|
|
| |
| for trace_event in tracer.get_events(): |
| yield f"data: {json.dumps(trace_event)}\n\n" |
|
|
| |
| insight = result.get("insight_text", "") |
| for word in insight.split(" "): |
| event = json.dumps({"token": word + " "}) |
| yield f"data: {event}\n\n" |
| await asyncio.sleep(0.03) |
|
|
| |
| final = { |
| "done": True, |
| "chart_spec": result.get("chart_spec"), |
| "generated_code": result.get("generated_code", ""), |
| "code_type": result.get("code_type", "sql"), |
| "execution_result": (result.get("execution_result") or [])[:20], |
| "latency_ms": total_ms, |
| "from_cache": result.get("from_cache", False), |
| "history_id": result.get("history_id"), |
| "anomalies": result.get("anomalies", []), |
| "correction_attempts": result.get("correction_attempts", 0), |
| "query_plan": result.get("query_plan", {}), |
| "intent": result.get("intent", "sql"), |
| "trace_summary": tracer.get_summary(), |
| } |
| yield f"data: {json.dumps(final, default=str)}\n\n" |
|
|
|
|
| @router.post("/stream") |
| async def stream_query(req: QueryRequest): |
| return StreamingResponse( |
| _stream_insight(req), |
| media_type="text/event-stream", |
| headers={ |
| "Cache-Control": "no-cache", |
| "X-Accel-Buffering": "no", |
| }, |
| ) |
|
|
|
|
| class SuggestRequest(BaseModel): |
| connector_id: str |
|
|
|
|
| class SuggestResponse(BaseModel): |
| suggestions: List[str] |
|
|
|
|
| @router.post("/suggest", response_model=SuggestResponse) |
| async def get_suggestions(req: SuggestRequest): |
| try: |
| from connectors.base import get_connector |
| connector = get_connector(req.connector_id) |
| schema = connector.get_schema() |
| except Exception as exc: |
| return SuggestResponse(suggestions=["What is the total number of rows in this dataset?"]) |
|
|
| |
| schema_lines = [] |
| for t in schema[:10]: |
| cols = ", ".join(f"{c['name']} ({c['type']})" for c in t.get("columns", [])[:15]) |
| schema_lines.append(f"Table: {t['table']}\nColumns: {cols}") |
| schema_context = "\n\n".join(schema_lines) |
|
|
| SYSTEM = """You are a senior data analyst. |
| Based on the provided database schema, generate 3 highly relevant, interesting analytical questions that a user might want to ask. |
| Return ONLY a JSON list of 3 strings. Example: ["question 1?", "question 2?", "question 3?"] |
| Focus on business metrics, trends, and aggregations.""" |
|
|
| from llm import get_groq_client |
| client = get_groq_client() |
| try: |
| raw = client.complete_system( |
| system=SYSTEM, |
| user=f"Schema:\n{schema_context}", |
| model=client.reason_model, |
| max_tokens=200, |
| ) |
| import json |
| suggestions = json.loads(raw) |
| if isinstance(suggestions, list) and len(suggestions) > 0: |
| return SuggestResponse(suggestions=suggestions[:4]) |
| except Exception: |
| pass |
|
|
| return SuggestResponse(suggestions=["What are the top trends in this dataset?"]) |
|
|