File size: 13,306 Bytes
abd4352
 
c5f9c5f
 
 
 
 
 
 
 
 
 
 
 
abd4352
c5f9c5f
 
 
abd4352
 
 
c5f9c5f
abd4352
 
c5f9c5f
 
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f9c5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f126119
c5f9c5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71d945d
 
c5f9c5f
 
71d945d
c5f9c5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abd4352
c5f9c5f
abd4352
 
 
 
 
 
c5f9c5f
 
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f9c5f
abd4352
 
 
 
 
 
c5f9c5f
abd4352
 
 
 
 
 
 
 
 
 
c5f9c5f
 
 
abd4352
 
 
 
c5f9c5f
abd4352
c5f9c5f
abd4352
 
 
 
 
 
 
 
 
 
 
c5f9c5f
abd4352
 
 
 
 
 
 
c5f9c5f
abd4352
c5f9c5f
abd4352
 
 
 
 
 
 
c5f9c5f
 
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
"""
agent/graph.py
LangGraph stateful agent graph with tracing, anomaly detection, and
PERFORMANCE OPTIMIZATIONS:

1. Fused memory_retriever + query_planner into a single node that runs
   memory vector recall and schema RAG concurrently via ThreadPoolExecutor.
2. Fused insight_synthesizer + anomaly_detector + visualizer into a single
   "output_pipeline" node that runs the LLM insight call concurrently with
   CPU-bound anomaly detection and chart generation.
3. memory_updater runs as fire-and-forget background I/O β€” the response is
   returned to the user BEFORE the database write completes.

Flow (optimized):
  intent_router
    β”œβ”€ sql      β†’ planner_with_memory β†’ sql_generator β†’ safety_validator β†’ executor
    β”œβ”€ pandas   β†’ planner_with_memory β†’ pandas_generator β†’ safety_validator β†’ executor
    └─ insight  β†’ output_pipeline (skip code gen)
                                                           β”‚
                                                    (error?) yes β†’ error_classifier β†’ self_corrector β†’ safety_validator (loop)
                                                           β”‚ no
                                                    output_pipeline [insight + anomaly + visualizer in parallel] β†’ memory_updater_async β†’ END
"""

import concurrent.futures

from langgraph.graph import END, StateGraph

from agent.state import AgentState
from agent.trace import trace_node
from agent.nodes import (
    error_classifier,
    executor,
    insight_synthesizer,
    intent_router,
    memory_retriever,
    memory_updater,
    pandas_generator,
    query_planner,
    safety_validator,
    self_corrector,
    sql_generator,
    visualizer,
)
from agent.nodes.anomaly_detector import anomaly_detector


# ── Persistent thread pool for parallel node execution ─────────────────────────
# Created once at import time and reused across requests; the fused nodes
# below submit their independent sub-tasks here and wait on the futures.
_PARALLEL_POOL_WORKERS = 4
_parallel_pool = concurrent.futures.ThreadPoolExecutor(
    max_workers=_PARALLEL_POOL_WORKERS,
    thread_name_prefix="agent_parallel",
)


# ── Fused node: planner_with_memory ────────────────────────────────────────────
# Runs memory_retriever and the expensive schema vector search concurrently,
# then feeds both into the query planner LLM call.

# System prompt for the planner LLM call (sent verbatim).
_PLANNER_SYSTEM = """You are a data analyst query planner.
Given the user query, relevant table schemas, and memory context, produce a concise query plan.
Respond ONLY with JSON:
{
  "tables": ["table1", "table2"],
  "approach": "one sentence describing the analytical approach",
  "complexity": "simple|medium|complex",
  "requires_join": true|false
}"""


def _parse_query_plan(raw):
    """
    Parse the planner LLM reply into a plan dict.

    A surrounding Markdown code fence (```json ... ```) is stripped before
    parsing, because chat models often add one despite the "Respond ONLY with
    JSON" instruction. Anything that does not decode to a JSON object — a
    non-string reply, invalid JSON, or valid JSON of another type — yields a
    fresh direct-query fallback plan.
    """
    import json

    fallback = {
        "tables": [],
        "approach": "direct query",
        "complexity": "simple",
        "requires_join": False,
    }
    if not isinstance(raw, str):
        return fallback
    text = raw.strip()
    if text.startswith("```"):
        # Drop the opening fence line (``` or ```json) and any closing fence.
        _, _, text = text.partition("\n")
        text = text.strip()
        if text.endswith("```"):
            text = text[:-3].strip()
    try:
        plan = json.loads(text)
    except json.JSONDecodeError:
        return fallback
    return plan if isinstance(plan, dict) else fallback


def _planner_with_memory(state: AgentState) -> AgentState:
    """
    Fused node: memory recall and schema RAG run concurrently, then their
    results feed a single query-planner LLM call.

    Reads ``user_query``, ``connector_id`` and ``session_id`` from *state* and
    returns a shallow copy of it with ``memory_context``, ``relevant_tables``,
    ``schema_context`` and ``query_plan`` set.

    Latency is roughly: query embedding + max(memory recall, schema RAG) +
    planner LLM call, instead of the lookups running back to back.

    Memory recall is best-effort. If it raises or exceeds its 10 s timeout,
    the failure is logged and planning continues with no memory context (the
    prompt then says "none"). A schema-retrieval failure or timeout still
    propagates. An unparseable planner reply yields the direct-query fallback
    plan from ``_parse_query_plan``.
    """
    import logging

    from llm import get_embedder, get_groq_client
    from schema.ingestor import get_relevant_tables
    from db.pool import pooled_cursor

    embedder = get_embedder()
    query = state["user_query"]
    connector_id = state["connector_id"]

    # Embed once for the memory similarity search. get_relevant_tables only
    # receives the raw query text, so it does not reuse this vector.
    query_vec = embedder.embed(query)

    # ── Run memory recall and schema RAG concurrently ──────────────────────────
    def _fetch_memory():
        with pooled_cursor(readonly=True, dict_cursor=True) as (cur, _conn):
            cur.execute(
                """
                SELECT query, insight, table_names,
                       1 - (embedding <=> %s::vector) AS similarity
                FROM memory_embeddings
                WHERE session_id = %s
                ORDER BY similarity DESC
                LIMIT 3
                """,
                (query_vec, state["session_id"]),
            )
            rows = cur.fetchall()
        # Postgres sorts NULLs first under DESC, so a row with a NULL
        # similarity can come back first; skip it instead of evaluating
        # None > 0.75 (TypeError).
        return "\n---\n".join(
            f"[Past query: {r['query']}]\n[Insight: {r['insight']}]"
            for r in rows
            if r["similarity"] is not None and r["similarity"] > 0.75
        )

    def _fetch_schema():
        return get_relevant_tables(
            connector_id=connector_id,
            query=query,
            top_k=15,
        )

    mem_future = _parallel_pool.submit(_fetch_memory)
    schema_future = _parallel_pool.submit(_fetch_schema)

    try:
        memory_context = mem_future.result(timeout=10)
    except Exception:
        # Memory is optional context (the prompt falls back to "none"), so a
        # failed or slow recall must not fail the whole request.
        logging.getLogger(__name__).exception(
            "planner_with_memory: memory recall failed; continuing without it"
        )
        memory_context = ""
    relevant_tables = schema_future.result(timeout=10)

    # ── Build schema context ───────────────────────────────────────────────────
    schema_lines = []
    for t in relevant_tables:
        cols = ", ".join(f"{c['name']} ({c['type']})" for c in t.get("columns", []))
        schema_lines.append(f"Table: {t['table']}\nColumns: {cols}")
    schema_context = "\n\n".join(schema_lines)

    # ── Run query planner LLM call ─────────────────────────────────────────────
    client = get_groq_client()
    user_msg = (
        f"User query: {query}\n\n"
        f"Available schema:\n{schema_context}\n\n"
        f"Memory context:\n{memory_context or 'none'}"
    )
    raw = client.complete_system(
        system=_PLANNER_SYSTEM,
        user=user_msg,
        model=client.reason_model,
        max_tokens=256,
    )
    plan = _parse_query_plan(raw)

    return {
        **state,
        "memory_context": memory_context,
        "relevant_tables": relevant_tables,
        "schema_context": schema_context,
        "query_plan": plan,
    }


# ── Fused node: output_pipeline ────────────────────────────────────────────────
# Runs insight synthesis (LLM), anomaly detection (CPU), and visualization (CPU)
# concurrently instead of sequentially.

def _auxiliary_output(future, key, default, timeout=10):
    """
    Return ``future.result(timeout=timeout).get(key, default)``.

    Falls back to *default* if the task raised, timed out, or returned
    something without ``.get``. Used for the non-essential output tasks so
    that one of them failing does not discard the insight text. The failure
    is logged with its traceback.
    """
    import logging

    try:
        return future.result(timeout=timeout).get(key, default)
    except Exception:
        logging.getLogger(__name__).exception(
            "output_pipeline: %r task failed; using %r", key, default
        )
        return default


def _output_pipeline(state: AgentState) -> AgentState:
    """
    Fused output node: insight synthesis, anomaly detection and chart
    generation run concurrently on the shared worker pool.

    If ``execution_error`` is set, no task is started. The error is surfaced
    as ``insight_text`` with empty ``anomalies`` and no ``chart_spec``.

    Otherwise all three tasks receive the same *state*, and wall time is
    bounded by the slowest task rather than their sum.
    - The insight result is waited on for up to 30 s; its failure or timeout
      propagates.
    - Anomaly detection and visualization are auxiliary. If either raises or
      exceeds 10 s, the failure is logged and the error-path defaults
      (``[]`` / ``None``) are used.

    Returns a shallow copy of *state* with ``insight_text``, ``anomalies`` and
    ``chart_spec`` set.
    """
    error_msg = state.get("execution_error")
    if error_msg:
        return {
            **state,
            "insight_text": f"Execution failed: {error_msg}",
            "anomalies": [],
            "chart_spec": None,
        }

    # Submit all three before waiting on any so they overlap.
    insight_future = _parallel_pool.submit(insight_synthesizer, state)
    anomaly_future = _parallel_pool.submit(anomaly_detector, state)
    visualizer_future = _parallel_pool.submit(visualizer, state)

    # The insight is the primary output: let its failure/timeout propagate.
    insight_state = insight_future.result(timeout=30)
    anomalies = _auxiliary_output(anomaly_future, "anomalies", [])
    chart_spec = _auxiliary_output(visualizer_future, "chart_spec", None)

    # NOTE(review): only these three keys are copied back from the task
    # results; any other key a task sets on its returned state is dropped.
    return {
        **state,
        "insight_text": insight_state.get("insight_text", ""),
        "anomalies": anomalies,
        "chart_spec": chart_spec,
    }


# ── Async memory updater (fire-and-forget) ─────────────────────────────────────

# Dedicated pool for background persistence. Keeping these writes off
# _parallel_pool means a backlog of memory writes cannot occupy the workers
# that the request-path nodes block on with timeouts.
_memory_write_pool = concurrent.futures.ThreadPoolExecutor(
    max_workers=2, thread_name_prefix="agent_memory_write"
)


def _memory_updater_async(state: AgentState) -> AgentState:
    """
    Schedule the memory write (via memory_updater) on a background thread and
    return immediately, without waiting for persistence.

    A fresh ``history_id`` (UUID4 string) is generated here. It is set both on
    the state copy handed to the background writer and on the returned state,
    so the caller and the writer see the same id.

    NOTE(review): ids only line up if memory_updater stores the history row
    under ``state["history_id"]``; if it generates its own id, the returned
    id still matches no persisted row. Confirm in agent.nodes.
    """
    import uuid

    history_id = str(uuid.uuid4())
    # The writer gets its own copy so later changes to the returned state
    # cannot race with the background thread.
    _memory_write_pool.submit(_safe_memory_write, {**state, "history_id": history_id})
    return {**state, "history_id": history_id}


def _safe_memory_write(state: AgentState):
    """
    Background task: persist query history and memory embeddings via
    memory_updater.

    Exceptions are logged with their traceback rather than raised. Nothing
    waits on this future, so a raised error would otherwise vanish silently.
    """
    try:
        memory_updater(state)
    except Exception:
        import logging

        logging.getLogger(__name__).exception("Background memory write failed")


# ── Wrap nodes with tracing ────────────────────────────────────────────────────
# Each callable later registered as a graph node is wrapped once, at import
# time, by trace_node under its graph node name.


def _with_trace(node_name, node_fn):
    """Return *node_fn* wrapped by ``trace_node(node_name)``."""
    return trace_node(node_name)(node_fn)


_traced_intent_router = _with_trace("intent_router", intent_router)
_traced_planner_with_memory = _with_trace("planner_with_memory", _planner_with_memory)
_traced_sql_generator = _with_trace("sql_generator", sql_generator)
_traced_pandas_generator = _with_trace("pandas_generator", pandas_generator)
_traced_safety_validator = _with_trace("safety_validator", safety_validator)
_traced_executor = _with_trace("executor", executor)
_traced_error_classifier = _with_trace("error_classifier", error_classifier)
_traced_self_corrector = _with_trace("self_corrector", self_corrector)
_traced_output_pipeline = _with_trace("output_pipeline", _output_pipeline)
# The "memory_updater" node is the fire-and-forget wrapper, not memory_updater.
_traced_memory_updater = _with_trace("memory_updater", _memory_updater_async)


# ── Conditional edges ──────────────────────────────────────────────────────────

def route_intent(state: AgentState) -> str:
    """
    Pick the branch label followed out of ``intent_router``.

    ``unsupported`` and ``pandas`` map to themselves and ``insight`` maps to
    ``insight_only``. Anything else, including a missing intent, takes the
    ``sql`` branch.
    """
    branch_for_intent = {
        "unsupported": "unsupported",
        "pandas": "pandas",
        "insight": "insight_only",
    }
    return branch_for_intent.get(state.get("intent", "sql"), "sql")


def route_after_validation(state: AgentState) -> str:
    """
    Route out of ``safety_validator``.

    Returns ``"blocked"`` when ``execution_error`` starts with
    ``SAFETY_BLOCK``. Otherwise, including when there is no error, returns
    ``"execute"``.
    """
    error = state.get("execution_error") or ""
    return "blocked" if error.startswith("SAFETY_BLOCK") else "execute"


def route_after_execution(state: AgentState) -> str:
    """
    Route out of ``executor``.

    Returns ``"success"`` when ``execution_error`` is falsy. Otherwise returns
    ``"give_up"`` once ``correction_attempts`` (default 0) has reached
    ``max_corrections`` (default 3), and ``"correct"`` to enter the
    self-correction loop.
    """
    if not state.get("execution_error"):
        return "success"
    attempts_used = state.get("correction_attempts", 0)
    attempt_budget = state.get("max_corrections", 3)
    return "give_up" if attempts_used >= attempt_budget else "correct"


def route_after_correction(state: AgentState) -> str:
    """
    Route out of ``self_corrector``: corrected code is always re-validated.

    The result does not depend on *state*. ``build_graph`` currently wires
    ``self_corrector`` to ``safety_validator`` with a plain edge, so this
    router is not referenced there.
    """
    next_step = "revalidate"
    return next_step


# ── Graph builder ──────────────────────────────────────────────────────────────

def build_graph() -> StateGraph:
    """
    Assemble the agent state machine and return it compiled.

    Although annotated as ``StateGraph``, the value returned is the result of
    ``StateGraph.compile()``, not the builder itself.

    Topology:
      * Entry is ``intent_router``. ``sql``/``pandas`` go through
        ``planner_with_memory`` to the matching code generator, ``insight``
        jumps straight to ``output_pipeline``, and ``unsupported`` ends the
        run.
      * Generated code goes to ``safety_validator``. Blocked code goes to
        ``output_pipeline``; otherwise ``executor`` runs it.
      * Execution errors loop ``error_classifier`` → ``self_corrector`` →
        ``safety_validator`` until ``route_after_execution`` gives up.
      * ``output_pipeline`` → ``memory_updater`` → END.
    """
    graph = StateGraph(AgentState)

    # Every node is a trace_node-wrapped callable; registration order is kept.
    traced_nodes = (
        ("intent_router", _traced_intent_router),
        ("planner_with_memory", _traced_planner_with_memory),
        ("sql_generator", _traced_sql_generator),
        ("pandas_generator", _traced_pandas_generator),
        ("safety_validator", _traced_safety_validator),
        ("executor", _traced_executor),
        ("error_classifier", _traced_error_classifier),
        ("self_corrector", _traced_self_corrector),
        ("output_pipeline", _traced_output_pipeline),
        ("memory_updater", _traced_memory_updater),
    )
    for node_name, node_fn in traced_nodes:
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("intent_router")

    # Intent routing.
    graph.add_conditional_edges(
        "intent_router",
        route_intent,
        {
            "sql": "planner_with_memory",
            "pandas": "planner_with_memory",
            "insight_only": "output_pipeline",
            "unsupported": END,
        },
    )

    # Fused planner → code generator chosen by intent.
    graph.add_conditional_edges(
        "planner_with_memory",
        lambda s: "pandas" if s.get("intent") == "pandas" else "sql",
        {"sql": "sql_generator", "pandas": "pandas_generator"},
    )

    # Both generators feed the safety check.
    for generator_node in ("sql_generator", "pandas_generator"):
        graph.add_edge(generator_node, "safety_validator")

    # Validation → execution, or straight to output when blocked.
    graph.add_conditional_edges(
        "safety_validator",
        route_after_validation,
        {"execute": "executor", "blocked": "output_pipeline"},
    )

    # Execution → output on success/give-up, otherwise self-correction.
    graph.add_conditional_edges(
        "executor",
        route_after_execution,
        {
            "success": "output_pipeline",
            "correct": "error_classifier",
            "give_up": "output_pipeline",
        },
    )

    # Self-correction loop; corrected code is re-validated.
    graph.add_edge("error_classifier", "self_corrector")
    graph.add_edge("self_corrector", "safety_validator")

    # Output → fire-and-forget memory write → END.
    graph.add_edge("output_pipeline", "memory_updater")
    graph.add_edge("memory_updater", END)

    return graph.compile()


# Singleton compiled graph, filled in lazily by get_graph().
_graph = None


def get_graph():
    """
    Return the process-wide compiled agent graph.

    build_graph() runs on the first call and its result is cached in the
    module-level ``_graph``. There is no lock, so threads racing on the very
    first call may each build a graph; the last assignment wins.
    """
    global _graph
    if _graph is not None:
        return _graph
    _graph = build_graph()
    return _graph