File size: 8,967 Bytes
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c960c82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f126119
c960c82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
"""
api/routers/query.py
POST /api/query/run        – run agent, return full result (with trace + anomalies)
POST /api/query/stream     – SSE stream with trace events + insight tokens
"""

import asyncio
import json
import logging
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from agent.graph import get_graph
from agent.state import AgentState
from agent.trace import AgentTracer, set_tracer, get_tracer
from agent.metrics import get_metrics_collector

router = APIRouter()

# ── In-memory conversation store (per session) ────────────────────────────────
# session_id -> recent turns, oldest first. This is a plain module-level dict:
# it lives only in this process and is lost when the process restarts.
# NOTE(review): entries are never evicted, so the dict grows with every distinct
# session_id seen — consider a bounded/TTL store if sessions are unbounded.
_MAX_HISTORY = 5  # Keep last 5 turns per session
_conversations: Dict[str, List[Dict[str, Any]]] = dict()


class QueryRequest(BaseModel):
    """Request body shared by the /run and /stream endpoints."""

    # The natural-language question, 1 to 2000 characters long.
    user_query: str = Field(..., min_length=1, max_length=2000)
    # Which data source the agent should query.
    connector_id: str = Field(..., description="e.g. neon:public or csv:<url>")
    # Key for multi-turn history; a fresh random UUID string when the client omits it.
    session_id: str = Field(default_factory=lambda: f"{uuid.uuid4()}")
    user_id: str = Field(default="anonymous")


class TraceEventResponse(BaseModel):
    """One agent trace event: which node ran, its status, and what it cost."""

    node: str
    status: str
    latency_ms: int = 0
    tokens_used: int = 0
    # default_factory gives every instance its own dict instead of declaring a
    # shared mutable literal as the class-level default.
    metadata: dict = Field(default_factory=dict)


class QueryResponse(BaseModel):
    """Response body of POST /run: the agent's final state plus timing and trace."""

    session_id: str
    intent: str
    generated_code: str
    code_type: str
    execution_result: list
    insight_text: str
    chart_spec: dict | None
    from_cache: bool
    latency_ms: int  # elapsed time of the whole agent run, in milliseconds
    correction_attempts: int
    history_id: str | None
    # default_factory gives every instance its own list instead of declaring a
    # shared mutable literal as the class-level default.
    anomalies: list = Field(default_factory=list)
    trace: list = Field(default_factory=list)


def _build_initial_state(req: QueryRequest) -> AgentState:
    """Build the initial agent state for one run of the graph.

    Every state key is seeded with an empty/default value. The request fields
    are copied in, and the session's stored turns are injected as
    ``conversation_history`` for multi-turn context.

    Args:
        req: The incoming query request.

    Returns:
        A fresh state dict suitable for ``graph.invoke``.
    """
    # Snapshot the stored turns. The per-session list is appended to later by
    # _update_conversation, so passing it directly would let this state (and
    # anything the graph keeps from it) change after the run starts.
    history = list(_conversations.get(req.session_id, []))

    return {
        "session_id": req.session_id,
        "user_id": req.user_id,
        "user_query": req.user_query,
        "connector_id": req.connector_id,
        "intent": "",
        "query_plan": {},
        "relevant_tables": [],
        "schema_context": "",
        "memory_context": "",
        "conversation_history": history,
        "generated_code": "",
        "code_type": "sql",
        "sql_dialect": "postgres",
        "execution_result": None,
        "execution_error": None,
        "from_cache": False,
        "error_class": None,
        "correction_attempts": 0,
        "max_corrections": 3,
        "insight_text": "",
        "chart_spec": None,
        "anomalies": [],
        "history_id": None,
        "latency_ms": None,
        "stream_tokens": [],
    }


def _update_conversation(session_id: str, result: dict):
    """Record a finished agent run as the newest turn of a session's history.

    The stored turn keeps the query, the generated code, the insight text and
    a JSON preview of at most the first 5 result rows. Once the session holds
    more than ``_MAX_HISTORY`` turns, the oldest ones are dropped.
    """
    rows = result.get("execution_result") or []
    turn = dict(
        query=result.get("user_query", ""),
        code=result.get("generated_code", ""),
        result_preview=json.dumps(rows[:5], default=str),
        insight=result.get("insight_text", ""),
    )

    turns = _conversations.setdefault(session_id, [])
    turns.append(turn)
    if len(turns) > _MAX_HISTORY:
        # Rebind the session to a new, trimmed list rather than trimming in place.
        _conversations[session_id] = turns[-_MAX_HISTORY:]


@router.post("/run", response_model=QueryResponse)
async def run_query(req: QueryRequest):
    graph = get_graph()
    state = _build_initial_state(req)

    # Set up tracing
    tracer = AgentTracer()
    set_tracer(tracer)

    t0 = time.time()
    try:
        result = await asyncio.get_event_loop().run_in_executor(
            None, graph.invoke, state
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    finally:
        set_tracer(None)

    total_ms = int((time.time() - t0) * 1000)

    # Update conversation history
    _update_conversation(req.session_id, result)

    # Record metrics
    metrics = get_metrics_collector()
    metrics.record_query(
        latency_ms=total_ms,
        from_cache=result.get("from_cache", False),
        correction_attempts=result.get("correction_attempts", 0),
        intent=result.get("intent", "sql"),
        error_class=result.get("error_class"),
    )

    return QueryResponse(
        session_id=result["session_id"],
        intent=result.get("intent", "sql"),
        generated_code=result.get("generated_code", ""),
        code_type=result.get("code_type", "sql"),
        execution_result=result.get("execution_result") or [],
        insight_text=result.get("insight_text", ""),
        chart_spec=result.get("chart_spec"),
        from_cache=result.get("from_cache", False),
        latency_ms=total_ms,
        correction_attempts=result.get("correction_attempts", 0),
        history_id=result.get("history_id"),
        anomalies=result.get("anomalies", []),
        trace=tracer.get_events(),
    )


async def _stream_insight(req: QueryRequest) -> AsyncGenerator[str, None]:
    """Run the agent, then emit its trace, the insight and a final payload as SSE frames.

    The graph runs to completion in a worker thread before anything is sent.
    Trace events are therefore replayed after the run rather than streamed live.
    Every frame is one SSE ``data:`` line holding JSON, sent in this order:

    1. One frame per event from ``tracer.get_events()``.
    2. ``{"token": "<word> "}`` for each space-separated word of the insight
       (none if the insight is empty), paced 30 ms apart.
    3. ``{"done": true, ...}`` with the code, up to 20 result rows, the chart
       spec, anomalies, the query plan and the trace summary.

    If the graph raises, a single ``{"done": true, "error": "<message>"}`` frame
    is sent instead. By the time this generator runs, the streaming response has
    already started, so an HTTP error status can no longer be returned.
    """
    graph = get_graph()
    state = _build_initial_state(req)

    tracer = AgentTracer()
    set_tracer(tracer)
    # NOTE(review): if set_tracer stores a process-global, concurrent requests
    # can overwrite (and clear) each other's tracer — confirm in agent.trace.

    failure: Optional[str] = None
    result: Dict[str, Any] = {}
    t0 = time.perf_counter()  # monotonic clock: immune to system clock changes
    try:
        # to_thread (unlike loop.run_in_executor) copies contextvars into the worker thread.
        result = await asyncio.to_thread(graph.invoke, state)
    except Exception as exc:
        logging.getLogger(__name__).exception(
            "Streamed agent run failed for session %s", req.session_id
        )
        failure = str(exc)
    finally:
        # Clear the tracer before yielding anything, even on failure or cancellation.
        set_tracer(None)
    total_ms = int((time.perf_counter() - t0) * 1000)

    if failure is not None:
        yield f"data: {json.dumps({'done': True, 'error': failure})}\n\n"
        return

    _update_conversation(req.session_id, result)

    get_metrics_collector().record_query(
        latency_ms=total_ms,
        from_cache=result.get("from_cache", False),
        correction_attempts=result.get("correction_attempts", 0),
        intent=result.get("intent", "sql"),
        error_class=result.get("error_class"),
    )

    # default=str matches the final payload, so a non-JSON value in trace
    # metadata cannot abort the stream mid-way.
    for trace_event in tracer.get_events():
        yield f"data: {json.dumps(trace_event, default=str)}\n\n"

    insight = result.get("insight_text") or ""
    if insight:  # "".split(" ") would yield one empty word -> a stray " " token
        for word in insight.split(" "):
            yield f"data: {json.dumps({'token': word + ' '})}\n\n"
            await asyncio.sleep(0.03)

    final = {
        "done": True,
        "chart_spec": result.get("chart_spec"),
        "generated_code": result.get("generated_code", ""),
        "code_type": result.get("code_type", "sql"),
        "execution_result": (result.get("execution_result") or [])[:20],
        "latency_ms": total_ms,
        "from_cache": result.get("from_cache", False),
        "history_id": result.get("history_id"),
        "anomalies": result.get("anomalies", []),
        "correction_attempts": result.get("correction_attempts", 0),
        "query_plan": result.get("query_plan", {}),
        "intent": result.get("intent", "sql"),
        "trace_summary": tracer.get_summary(),
    }
    yield f"data: {json.dumps(final, default=str)}\n\n"


@router.post("/stream")
async def stream_query(req: QueryRequest):
    return StreamingResponse(
        _stream_insight(req),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        },
    )


class SuggestRequest(BaseModel):
    """Request body of POST /suggest."""

    # Connector whose schema the suggested questions should be based on.
    connector_id: str = Field(...)


class SuggestResponse(BaseModel):
    """Response body of POST /suggest: example questions for the user to try."""

    suggestions: list[str]


@router.post("/suggest", response_model=SuggestResponse)
async def get_suggestions(req: SuggestRequest):
    try:
        from connectors.base import get_connector
        connector = get_connector(req.connector_id)
        schema = connector.get_schema()
    except Exception as exc:
        return SuggestResponse(suggestions=["What is the total number of rows in this dataset?"])

    # Format schema for prompt
    schema_lines = []
    for t in schema[:10]:  # Limit to 10 tables
        cols = ", ".join(f"{c['name']} ({c['type']})" for c in t.get("columns", [])[:15])
        schema_lines.append(f"Table: {t['table']}\nColumns: {cols}")
    schema_context = "\n\n".join(schema_lines)

    SYSTEM = """You are a senior data analyst.
Based on the provided database schema, generate 3 highly relevant, interesting analytical questions that a user might want to ask.
Return ONLY a JSON list of 3 strings. Example: ["question 1?", "question 2?", "question 3?"]
Focus on business metrics, trends, and aggregations."""

    from llm import get_groq_client
    client = get_groq_client()
    try:
        raw = client.complete_system(
            system=SYSTEM,
            user=f"Schema:\n{schema_context}",
            model=client.reason_model,
            max_tokens=200,
        )
        import json
        suggestions = json.loads(raw)
        if isinstance(suggestions, list) and len(suggestions) > 0:
            return SuggestResponse(suggestions=suggestions[:4])
    except Exception:
        pass

    return SuggestResponse(suggestions=["What are the top trends in this dataset?"])