File size: 22,872 Bytes
f7021ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f50d923
 
 
 
 
 
f7021ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f50d923
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
"""
DuckDB Q&A Agent (LangGraph + ReAct) — Streaming Edition
========================================================

Adds:
- Streaming responses for the final natural-language answer (CLI prints tokens live).
- Stricter schema discovery focused on *user* tables from the provided DB (defaults to schema 'main').
- Keeps SELECT-only safety, up to 3 SQL refinements on error, optional plotting saved to ./plots.

Quick start
-----------
1) Install deps (example):
   pip install --upgrade duckdb pandas matplotlib python-dotenv langgraph langchain langchain-openai

2) Ensure an OpenAI API key is available:
   - Put it in a .env file as: OPENAI_API_KEY=sk-...
   - Or set an env var: set OPENAI_API_KEY=... (Windows) / export OPENAI_API_KEY=... (macOS/Linux)

3) Run the agent in an interactive loop:
   python duckdb_react_agent.py --duckdb path/to/your.db --stream

Notes
-----
- Targets DuckDB SQL. The LLM prompt instructs to write *only* a SELECT statement.
- ReAct loop: on error, the LLM sees the error message + previous SQL and attempts to fix it (max 3 tries).
- The agent avoids DDL/DML; SELECT-only for safety.
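- Plots are saved to ./plots/ as PNG files (no GUI required). The script uses a non-interactive backend.
  Caveat: node_decide_viz currently always reports make_plot=False, so no chart is produced until that decision step is switched back on.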
- Internal/helper tables beginning with "__" (e.g., ingestion metadata) and non-user schemas are ignored.
"""

import os
import re
import json
import uuid
import argparse
from typing import Any, Dict, List, Optional, TypedDict, Callable

# Non-GUI backend for saving plots on servers/Windows terminals
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import duckdb
import pandas as pd

from dotenv import load_dotenv, find_dotenv

# LangChain / LangGraph
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, END


# -----------------------------
# Utility: safe number formatting
# -----------------------------
def format_number(x: Any) -> str:
    """Render a value as a short, human-readable string.

    Real numbers (``int``/``float``, but not ``bool``) whose magnitude lies in
    ``[1000, 1e12)`` are printed with thousands separators and at most two
    decimals, trailing zeros removed. Any other float is printed with four
    significant digits. Everything else is passed through ``str``.

    Args:
        x: Any value; non-numeric values are simply stringified.

    Returns:
        The formatted text.
    """
    if isinstance(x, bool) or not isinstance(x, (int, float)):
        return str(x)

    try:
        magnitude = abs(x)
        if 1000 <= magnitude < 1e12:
            grouped = f"{x:,.2f}"
            # "1,000.00" -> "1,000." -> "1,000"
            return grouped.rstrip("0").rstrip(".")
        if isinstance(x, float):
            return f"{x:.4g}"
        return str(x)
    except Exception:
        # Formatting is best-effort: fall back to the plain representation.
        return str(x)


# -----------------------------
# Schema introspection from DuckDB (user tables only)
# -----------------------------
def _quote_ident(identifier: str) -> str:
    """Return `identifier` as a double-quoted SQL identifier (embedded quotes doubled)."""
    return '"' + str(identifier).replace('"', '""') + '"'


def get_schema_summary(con: duckdb.DuckDBPyConnection, sample_rows: int = 3, include_views: bool = True,
                       allowed_schemas: Optional[List[str]] = None) -> str:
    """
    Build a compact, text-only schema snapshot of the user tables in the connected DB.

    For every table (and view, if requested) in the allowed schemas this lists each
    column with its data type and a few example values taken from the first rows.
    Tables whose name starts with "__" or matches 'duckdb_%' / 'sqlite_%' are skipped.

    Args:
        con: Open DuckDB connection to introspect.
        sample_rows: Number of rows to read per table for the example values.
        include_views: When True, views are listed alongside base tables.
        allowed_schemas: Schemas to include; defaults to ["main"].

    Returns:
        A newline-joined description, or a single "(No user tables discovered ...)"
        line when nothing qualifies.
    """
    if allowed_schemas is None:
        allowed_schemas = ["main"]
    schemas = list(allowed_schemas)
    nothing_found = "(No user tables discovered in schema(s): " + ", ".join(schemas) + ")"

    if not schemas:
        # An empty list would render "IN ()", which is a SQL syntax error.
        return nothing_found

    type_filter = "('BASE TABLE','VIEW')" if include_views else "('BASE TABLE')"
    placeholders = ",".join("?" for _ in schemas)

    tables_df = con.execute(
        f"""
        SELECT table_schema, table_name, table_type
        FROM information_schema.tables
        WHERE table_type IN {type_filter}
          AND table_schema IN ({placeholders})
          AND table_name NOT LIKE 'duckdb_%'
          AND table_name NOT LIKE 'sqlite_%'
        ORDER BY table_schema, table_name
        """,
        schemas,
    ).fetchdf()

    # Coerce once so the value spliced into the LIMIT clause is always a plain
    # non-negative integer.
    row_limit = max(int(sample_rows), 0)

    lines: List[str] = []
    for schema, name in zip(tables_df["table_schema"], tables_df["table_name"]):
        if name.startswith("__"):
            # Common ingestion metadata; not user-facing
            continue

        cols = con.execute(
            """
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = ? AND table_name = ?
            ORDER BY ordinal_position
            """,
            [schema, name],
        ).fetchdf()

        lines.append(f"TABLE {schema}.{name}")
        if cols.empty:
            lines.append("  (no columns discovered)")
            continue

        # Sample a small set of values per column to guide the LLM. Identifiers
        # cannot be bound as parameters, so they are quoted instead; this keeps
        # names with spaces, mixed case or reserved words working.
        qualified = f"{_quote_ident(schema)}.{_quote_ident(name)}"
        try:
            sample = con.execute(f"SELECT * FROM {qualified} LIMIT {row_limit}").fetchdf()
        except Exception:
            # Sampling is best-effort: still list the columns, just without examples.
            sample = pd.DataFrame(columns=cols["column_name"].tolist())

        for col, dtype in zip(cols["column_name"], cols["data_type"]):
            if col in sample.columns:
                examples = [format_number(v)[:80] for v in sample[col].tolist()]
            else:
                examples = []
            example_str = ", ".join(examples)
            lines.append(f"  - {col} :: {dtype}  e.g. [{example_str}]")
        lines.append("")

    if not lines:
        lines.append(nothing_found)

    return "\n".join(lines)


# -----------------------------
# LLM helpers
# -----------------------------
def make_llm(model: str = "gpt-4o-mini", temperature: float = 0.0) -> ChatOpenAI:
    """Create the chat model client shared by the agent nodes.

    Args:
        model: OpenAI chat model name.
        temperature: Sampling temperature forwarded to the client.

    Returns:
        A configured ``ChatOpenAI`` instance.
    """
    client = ChatOpenAI(model=model, temperature=temperature)
    return client


# System prompt for the first SQL draft (sent by node_draft_sql together with the
# schema snapshot and the user's question). It asks for exactly one SELECT statement
# and nothing else, which is why extract_sql only has to strip optional code fences.
SQL_SYSTEM_PROMPT = """You are an expert DuckDB SQL writer. You will receive:
- A database schema (tables and columns) with a few sample values
- A user's natural-language question

Write exactly ONE valid DuckDB SQL SELECT statement to answer the question.
Rules:
- Output ONLY the SQL (no backticks, no explanation, no comments)
- Use DuckDB SQL
- Never write DDL/DML (CREATE/INSERT/UPDATE/DELETE), only SELECT statements
- Prefer explicit column names, avoid SELECT *
- If time grouping is needed, use date_trunc('month', col), date_trunc('day', col), etc.
- If casting is needed, use CAST(... AS TYPE)
- Use ONLY the listed user tables from the provided database (e.g., schema 'main'). Do not rely on external/attached sources.
- Avoid tables starting with "__" unless explicitly referenced.
- If uncertain, make a reasonable assumption and produce the best possible query

Be careful with joins and filters. Ensure SQL parses successfully.
"""

# System prompt for the repair step (sent by node_refine_sql after a query failed).
# The accompanying human message carries the question, the failing SQL and the
# error text.
REFINE_SYSTEM_PROMPT = """You are fixing a DuckDB SQL query that errored. You'll see:
- the user's question
- the previous SQL
- the exact error message

Return a corrected DuckDB SELECT statement that addresses the error.
Rules:
- Output ONLY the SQL (no backticks, no explanation, no comments)
- SELECT-only (no DDL/DML)
- Keep the intent of the question intact
- Fix joins, field names, casts, aggregations or date functions as needed
- Use ONLY the listed user tables from the provided database (e.g., schema 'main')
"""

# System prompt for the final natural-language answer (sent by node_answer together
# with the question, the result columns, a preview of the rows and the plot path).
ANSWER_SYSTEM_PROMPT = """You are a helpful data analyst. Given a user's question and the query result data,
write a clear, concise, conversational answer in plain English. Use bullet points or short paragraphs.
- Format large numbers with thousands separators where appropriate.
- If percentages are present, include % signs.
- If a chart/image accompanies the answer, DO NOT mention file paths or filenames. Refer generically, e.g., 'as depicted in the chart' or 'see the chart alongside this answer.'
- If the result is empty, say so and suggest a next step.
- Do not invent columns or values that are not in the result.
- (Optional) You may refer to prior Q&A context if it informs this answer.
"""

# System prompt for the chart decision (sent by node_decide_viz). The reply is
# expected to be a JSON object; the keys "make_plot", "chart", "x", "y" and "series"
# are the ones node_make_plot reads from the resulting spec.
VIZ_SYSTEM_PROMPT = """You will decide whether a chart would help answer the question using the SQL result.
Return STRICT JSON with keys: {"make_plot": bool, "chart": "line|bar|scatter|hist|box|pie", "x": "<col or null>", "y": "<col or null>", "series": "<col or null>", "agg": "<sum|avg|count|max|min|null>", "reason": "<short reason>"}.
Conservative rules (prefer NO plot):
- Only set make_plot=true if the user explicitly asks for a chart/visualization OR the result involves trends over time, distributions, or comparisons with many rows (e.g., > 20) where a chart adds clarity.
- For simple aggregates or few rows (<= 10), set make_plot=false unless explicitly requested.
Guidelines when make_plot=true:
- Prefer line for trends over time
- Prefer bar for ranking/comparison of categories
- Prefer scatter for correlation between two numeric columns
- Prefer hist for distribution of a single numeric column
- Prefer box for distribution with quartiles per category
- Prefer pie sparingly for simple part-to-whole
JSON only. No prose.
"""


# Matches the first ``` fenced block. The optional prefix consumes either a language
# tag that sits on its own line (sql, duckdb, postgresql, ...) or a bare "sql" glued
# to the opening fence. The negative lookahead keeps a query that starts right after
# the fence ("```SELECT\n...") from being mistaken for a language tag.
_SQL_FENCE_RE = re.compile(
    r"```(?:(?!(?:select|with)\b)[\w+-]*[ \t]*\r?\n|sql)?(.*?)```",
    re.DOTALL | re.IGNORECASE,
)


def extract_sql(text: str) -> str:
    """Extract the SQL statement from raw LLM output.

    The prompts ask for bare SQL, but models sometimes wrap it in a Markdown code
    fence, with or without a language tag. If a fenced block is present, the content
    of the first one is returned; otherwise the stripped text is returned unchanged.

    Args:
        text: Raw model output.

    Returns:
        The SQL text with surrounding whitespace, fences and fence tag removed.
    """
    text = text.strip()
    fenced = _SQL_FENCE_RE.search(text)
    if fenced:
        return fenced.group(1).strip()
    return text


# -----------------------------
# LangGraph State
# -----------------------------
class AgentState(TypedDict):
    """State dictionary threaded through the graph nodes defined in this file.

    Each node receives the current state and returns an updated copy. Keys with a
    leading underscore hold in-process helpers (DataFrame, viz spec, streaming
    hooks) and are not included in the dict returned by answer_question.
    """
    # The user's natural-language question.
    question: str
    # Text schema snapshot placed in the SQL-drafting prompt.
    schema: str
    # Number of refinement rounds performed so far (incremented by node_refine_sql).
    attempts: int
    # Most recent SQL produced by the drafting/refinement nodes.
    sql: Optional[str]
    # Error text from the last failed execution; None after a successful run.
    error: Optional[str]
    # JSON-encoded list of record dicts for the first rows of the result.
    result_json: Optional[str]        # JSON records preview
    # Column names of the query result.
    result_columns: Optional[List[str]]
    # Path of the saved chart image, when node_make_plot produced one.
    plot_path: Optional[str]
    # Final natural-language answer (or the failure message).
    final_answer: Optional[str]
    # Full query result as a DataFrame, used for plotting.
    _result_df: Optional[pd.DataFrame]
    # Chart specification dict; node_make_plot only plots when "make_plot" is truthy.
    _viz_spec: Optional[Dict[str, Any]]
    # Whether the final answer should be streamed.
    _stream: bool
    # Callback invoked with each streamed text piece.
    _token_cb: Optional[Callable[[str], None]]


# -----------------------------
# Nodes: SQL drafting, execution, refinement, viz, answer
# -----------------------------
def node_draft_sql(state: AgentState, llm: ChatOpenAI) -> AgentState:
    """Graph node: ask the LLM for a first SQL draft.

    Sends the schema snapshot and the question under SQL_SYSTEM_PROMPT, extracts
    the SQL from the reply and stores it in ``sql`` while clearing ``error``.
    """
    system_msg = SystemMessage(content=SQL_SYSTEM_PROMPT)
    user_msg = HumanMessage(
        content=f"SCHEMA:\n{state['schema']}\n\nQUESTION:\n{state['question']}"
    )
    reply = llm.invoke([system_msg, user_msg])  # type: ignore
    drafted = extract_sql(reply.content or "")

    updated = dict(state)
    updated.update(sql=drafted, error=None)
    return updated  # type: ignore[return-value]


# String literals, quoted identifiers and comments. They are blanked out before the
# statement is inspected so that a ';' or a keyword inside them cannot confuse the guard.
_SQL_LITERAL_OR_COMMENT_RE = re.compile(
    r"'(?:[^']|'')*'"       # 'single-quoted string' ('' is an escaped quote)
    r'|"(?:[^"]|"")*"'      # "double-quoted identifier"
    r"|--[^\n]*"            # -- line comment
    r"|/\*.*?\*/",          # /* block comment */
    re.DOTALL,
)


def run_duckdb_query(con: duckdb.DuckDBPyConnection, sql: str) -> pd.DataFrame:
    """Execute a single read-only query and return the result as a DataFrame.

    The statement must start with SELECT or WITH (any whitespace may follow the
    keyword, and leading comments are ignored) and must not contain a second
    statement after a ';'. This is a lightweight guard, not a SQL sandbox; the
    connection should additionally be opened read-only.

    Args:
        con: Open DuckDB connection.
        sql: The SQL text to run.

    Returns:
        The full result set.

    Raises:
        ValueError: If the text is not a single SELECT/WITH statement.
        Exception: Whatever DuckDB raises while executing the query.
    """
    skeleton = _SQL_LITERAL_OR_COMMENT_RE.sub(" ", sql).strip().rstrip(";").strip()

    leading_word = re.match(r"\w+", skeleton)
    if leading_word is None or leading_word.group(0).upper() not in ("SELECT", "WITH"):
        raise ValueError("Only SELECT queries are allowed.")

    # Trailing semicolons were stripped above, so any ';' left separates statements.
    if ";" in skeleton:
        raise ValueError("Only a single SELECT statement is allowed.")

    return con.execute(sql).fetchdf()


def node_run_sql(state: AgentState, con: duckdb.DuckDBPyConnection) -> AgentState:
    """Graph node: execute the current SQL and record the outcome.

    On success the state gets a JSON preview of the first 50 rows, the column
    names and the full DataFrame, and ``error`` is cleared. On any failure the
    exception text is stored in ``error`` and the result fields are reset to None.
    """
    try:
        frame = run_duckdb_query(con, state["sql"] or "")
        head_records = frame.head(50).to_dict(orient="records")
        outcome: Dict[str, Any] = {
            "error": None,
            # default=str stringifies values json cannot encode natively.
            "result_json": json.dumps(head_records, default=str),
            "result_columns": list(frame.columns),
            "_result_df": frame,
        }
    except Exception as exc:
        outcome = {
            "error": str(exc),
            "result_json": None,
            "result_columns": None,
            "_result_df": None,
        }
    return {**state, **outcome}  # type: ignore[return-value]


def node_refine_sql(state: AgentState, llm: ChatOpenAI) -> AgentState:
    """Graph node: ask the LLM to repair the SQL that just failed.

    The request contains the schema snapshot, the question, the failing SQL and the
    error text. The schema is included because REFINE_SYSTEM_PROMPT tells the model
    to use only "the listed user tables"; without it the model cannot correct a
    wrong table or column name.

    Returns:
        The state unchanged once 3 attempts have been used; otherwise a copy with
        the new ``sql``, ``attempts`` incremented by one and ``error`` cleared.
    """
    attempts_so_far = state.get("attempts", 0)
    if attempts_so_far >= 3:
        return state

    # "or ''" keeps a None value from being rendered as the literal text "None".
    msgs = [
        SystemMessage(content=REFINE_SYSTEM_PROMPT),
        HumanMessage(
            content=(
                f"SCHEMA:\n{state.get('schema') or ''}\n\n"
                f"QUESTION:\n{state['question']}\n\n"
                f"PREVIOUS SQL:\n{state.get('sql') or ''}\n\n"
                f"ERROR:\n{state.get('error') or ''}"
            )
        ),
    ]
    resp = llm.invoke(msgs)  # type: ignore
    sql = extract_sql(resp.content or "")
    return {**state, "sql": sql, "attempts": attempts_so_far + 1, "error": None}


def node_decide_viz(state: AgentState, llm: ChatOpenAI, enable_plots: bool = False) -> AgentState:
    """Graph node: decide whether the result should be charted.

    Plotting is off by default, so the node records ``{"make_plot": False}`` without
    contacting the LLM. Previously the request was still sent and its reply thrown
    away, costing a round-trip per question. Pass ``enable_plots=True`` to have the
    LLM produce the chart spec described by VIZ_SYSTEM_PROMPT.

    Args:
        state: Current agent state; nothing happens when ``result_json`` is empty.
        llm: Chat model used for the decision when plotting is enabled.
        enable_plots: Opt-in switch for LLM-driven chart decisions.

    Returns:
        The state unchanged when there is no result, otherwise a copy whose
        ``_viz_spec`` holds the decision.
    """
    if not state.get("result_json"):
        return state

    no_plot: Dict[str, Any] = {"make_plot": False}
    if not enable_plots:
        return {**state, "_viz_spec": no_plot}

    result_preview = state["result_json"]
    msgs = [
        SystemMessage(content=VIZ_SYSTEM_PROMPT),
        HumanMessage(
            content=(
                f"QUESTION:\n{state['question']}\n\n"
                f"COLUMNS: {state.get('result_columns',[])}\n"
                f"RESULT PREVIEW (first rows):\n{result_preview}"
            )
        ),
    ]
    resp = llm.invoke(msgs)  # type: ignore
    raw = resp.content if isinstance(resp.content, str) else str(resp.content)

    # Tolerate code fences or stray prose around the JSON object by parsing only
    # the outermost {...} span.
    spec = no_plot
    start, end = raw.find("{"), raw.rfind("}")
    if start != -1 and end > start:
        try:
            parsed = json.loads(raw[start:end + 1])
        except ValueError:
            parsed = None
        if isinstance(parsed, dict) and "make_plot" in parsed:
            spec = parsed

    return {**state, "_viz_spec": spec}


def _draw_chart(fig, ax, df: pd.DataFrame, chart: str, x, y, series) -> bool:
    """Draw `chart` for `df` onto `ax`; return False when nothing can be drawn."""

    def has_col(c) -> bool:
        return isinstance(c, str) and c in df.columns

    # These chart types need both axes; bail out instead of relying on a KeyError.
    if chart in ("line", "bar", "scatter", "pie") and y is None:
        return False

    if chart == "line":
        if series and has_col(series):
            for key, group in df.groupby(series):
                ax.plot(group[x], group[y], label=str(key))
            ax.legend(loc="best")
        else:
            ax.plot(df[x], df[y])
        ax.set_xlabel(str(x))
        ax.set_ylabel(str(y))

    elif chart == "bar":
        if series and has_col(series):
            pivot = df.pivot_table(index=x, columns=series, values=y, aggfunc="sum", fill_value=0)
            pivot.plot(kind="bar", ax=ax)
        else:
            ax.bar(df[x], df[y])
        ax.set_xlabel(str(x))
        ax.set_ylabel(str(y))
        fig.autofmt_xdate(rotation=45)

    elif chart == "scatter":
        ax.scatter(df[x], df[y])
        ax.set_xlabel(str(x))
        ax.set_ylabel(str(y))

    elif chart == "hist":
        column = y if has_col(y) else x
        ax.hist(df[column], bins=30)
        ax.set_xlabel(str(column))
        ax.set_ylabel("Frequency")

    elif chart == "box":
        column = y if has_col(y) else x
        ax.boxplot([df[column]])
        ax.set_ylabel(str(column))

    elif chart == "pie":
        ax.pie(df[y].tolist(), labels=df[x].astype(str).tolist(), autopct="%1.1f%%")
        ax.axis("equal")

    else:
        return False

    return True


def node_make_plot(state: AgentState) -> AgentState:
    """Graph node: render the chart described by ``_viz_spec`` into ./plots.

    Nothing is plotted when the spec does not request a plot, when there is no
    (non-empty) result DataFrame, or when the question contains no visualization
    keyword and the result has 20 rows or fewer. Plotting is best-effort: any
    failure leaves the state unchanged.

    Returns:
        The state unchanged, or a copy with ``plot_path`` set to the saved PNG.
    """
    spec = state.get("_viz_spec") or {}
    if not spec.get("make_plot"):
        return state

    df: Optional[pd.DataFrame] = state.get("_result_df")
    if df is None or df.empty:
        return state

    # Additional guard: avoid trivial plots unless explicitly requested
    q_lower = (state.get('question') or '').lower()
    explicit_viz = any(k in q_lower for k in ['chart', 'plot', 'graph', 'visual', 'visualize', 'trend'])
    if not explicit_viz and df.shape[0] <= 20:
        return state

    def col_ok(c: Optional[str]) -> bool:
        return isinstance(c, str) and c in df.columns

    # Fall back to the first/second result column when the spec names an unknown one.
    x = spec.get("x")
    y = spec.get("y")
    if not col_ok(x):
        x = df.columns[0]
    if not col_ok(y):
        y = df.columns[1] if df.shape[1] >= 2 else None
    series = spec.get("series")
    chart = (spec.get("chart") or "").lower()

    fig = None
    try:
        os.makedirs("plots", exist_ok=True)
        fig = plt.figure()
        if not _draw_chart(fig, fig.gca(), df, chart, x, y, series):
            return state

        out_path = os.path.join("plots", f"duckdb_answer_{uuid.uuid4().hex[:8]}.png")
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
        return {**state, "plot_path": out_path}
    except Exception:
        # A chart is optional; the textual answer must still be produced.
        return state
    finally:
        # pyplot keeps every figure alive until it is closed, so close it on all
        # paths (success, unsupported chart, and errors) to avoid leaking figures
        # across questions in the interactive loop.
        if fig is not None:
            plt.close(fig)


def _stream_llm_answer(llm: ChatOpenAI, msgs: List[Any], token_cb: Optional[Callable[[str], None]]) -> str:
    """Stream the final answer, forwarding each piece to `token_cb`.

    If streaming fails, the answer is fetched with a regular ``invoke`` call. Text
    that was already forwarded before the failure is not sent again: when the full
    answer starts with it, only the remainder is emitted; otherwise the full
    answer is emitted on a new line.

    Args:
        llm: Chat model providing ``stream`` and ``invoke``.
        msgs: Messages to send.
        token_cb: Optional callback receiving each text piece as it arrives.

    Returns:
        The complete answer text.
    """
    pieces: List[str] = []
    try:
        for chunk in llm.stream(msgs):  # type: ignore
            piece = getattr(chunk, "content", None)
            if not piece and getattr(chunk, "message", None):
                piece = getattr(chunk.message, "content", "")
            if not piece:
                continue
            if not isinstance(piece, str):
                # Same coercion the non-streaming path applies to resp.content.
                piece = str(piece)
            pieces.append(piece)
            if token_cb:
                token_cb(piece)
        return "".join(pieces)
    except Exception:
        # Fallback to non-streaming if stream not supported or it broke midway
        already_sent = "".join(pieces)
        resp = llm.invoke(msgs)  # type: ignore
        full = resp.content if isinstance(resp.content, str) else str(resp.content)
        if token_cb:
            if not already_sent:
                token_cb(full)
            elif full.startswith(already_sent):
                remainder = full[len(already_sent):]
                if remainder:
                    token_cb(remainder)
            else:
                token_cb("\n" + full)
        return full


def node_answer(state: AgentState, llm: ChatOpenAI) -> AgentState:
    """Graph node: produce the final natural-language answer.

    When the query failed and no result exists, a fixed failure message (with the
    error details) becomes the answer. Otherwise the LLM is asked to summarize the
    result preview, which is cut to 4000 characters. The answer is streamed through
    ``_token_cb`` when ``_stream`` is set.
    """
    rows_text = state.get("result_json") or "[]"
    if len(rows_text) > 4000:
        rows_text = rows_text[:4000] + " ..."
    column_names = state.get("result_columns") or []
    chart_file = state.get("plot_path")

    user_prompt = (
        f"QUESTION:\n{state['question']}\n\n"
        f"COLUMNS: {column_names}\n"
        f"RESULT PREVIEW (rows):\n{rows_text}\n\n"
        f"PLOT_PATH: {chart_file if chart_file else 'None'}"
    )
    conversation = [
        SystemMessage(content=ANSWER_SYSTEM_PROMPT),
        HumanMessage(content=user_prompt),
    ]

    streaming = state.get("_stream")

    query_failed = bool(state.get("error")) and not state.get("result_json")
    if query_failed:
        failure_text = (
            "I couldn't produce a working SQL query after 3 attempts.\n\n"
            "Details:\n" + (state.get("error") or "Unknown error")
        )
        # The failure text goes through the same channel as a streamed answer.
        emit = state.get("_token_cb")
        if streaming and emit:
            emit(failure_text)
        return {**state, "final_answer": failure_text}

    if not streaming:
        reply = llm.invoke(conversation)  # type: ignore
        content = reply.content
        final_text = content if isinstance(content, str) else str(content)
    else:
        final_text = _stream_llm_answer(llm, conversation, state.get("_token_cb"))

    return {**state, "final_answer": final_text}


# -----------------------------
# Graph assembly
# -----------------------------
def build_graph(con: duckdb.DuckDBPyConnection, llm: ChatOpenAI):
    """Assemble and compile the agent's LangGraph workflow.

    Flow: draft_sql -> run_sql; on error run_sql loops through refine_sql while
    fewer than 3 attempts were made, then gives up and goes to answer; on success
    it continues to decide_viz, which goes to make_plot (then answer) or straight
    to answer.

    Args:
        con: DuckDB connection bound into the execution node.
        llm: Chat model bound into the LLM-backed nodes.

    Returns:
        The compiled graph.
    """

    # Node callables with the connection / model bound in.
    def draft_sql(s):
        return node_draft_sql(s, llm)

    def run_sql(s):
        return node_run_sql(s, con)

    def refine_sql(s):
        return node_refine_sql(s, llm)

    def decide_viz(s):
        return node_decide_viz(s, llm)

    def answer(s):
        return node_answer(s, llm)

    # Routers returning the name of the next node.
    def on_run_sql(state: AgentState):
        if not state.get("error"):
            return "decide_viz"
        return "refine_sql" if state.get("attempts", 0) < 3 else "answer"

    def on_decide_viz(state: AgentState):
        wants_plot = (state.get("_viz_spec") or {}).get("make_plot")
        return "make_plot" if wants_plot else "answer"

    workflow = StateGraph(AgentState)
    for label, action in (
        ("draft_sql", draft_sql),
        ("run_sql", run_sql),
        ("refine_sql", refine_sql),
        ("decide_viz", decide_viz),
        ("make_plot", node_make_plot),
        ("answer", answer),
    ):
        workflow.add_node(label, action)

    workflow.set_entry_point("draft_sql")
    workflow.add_edge("draft_sql", "run_sql")
    workflow.add_conditional_edges(
        "run_sql",
        on_run_sql,
        {"refine_sql": "refine_sql", "decide_viz": "decide_viz", "answer": "answer"},
    )
    workflow.add_edge("refine_sql", "run_sql")
    workflow.add_conditional_edges(
        "decide_viz",
        on_decide_viz,
        {"make_plot": "make_plot", "answer": "answer"},
    )
    workflow.add_edge("make_plot", "answer")
    workflow.add_edge("answer", END)

    return workflow.compile()


# -----------------------------
# Public function to ask a question
# -----------------------------
def answer_question(con: duckdb.DuckDBPyConnection, llm: ChatOpenAI, schema_text: str, question: str,
                    stream: bool = False, token_callback: Optional[Callable[[str], None]] = None) -> Dict[str, Any]:
    """Run the whole agent workflow for a single question.

    Args:
        con: DuckDB connection the generated SQL runs against.
        llm: Chat model used by the graph nodes.
        schema_text: Schema snapshot shown to the model.
        question: The user's question.
        stream: Whether the final answer is streamed.
        token_callback: Receives streamed text pieces when `stream` is set.

    Returns:
        A dict with the keys "question", "sql", "answer", "plot_path" and "error".
    """
    workflow = build_graph(con, llm)

    seed: AgentState = {
        # inputs
        "question": question,
        "schema": schema_text,
        # retry bookkeeping
        "attempts": 0,
        "sql": None,
        "error": None,
        # outputs filled in by the nodes
        "result_json": None,
        "result_columns": None,
        "plot_path": None,
        "final_answer": None,
        "_result_df": None,
        "_viz_spec": None,
        # streaming hooks
        "_stream": stream,
        "_token_cb": token_callback,
    }
    outcome: AgentState = workflow.invoke(seed)  # type: ignore

    summary: Dict[str, Any] = {"question": question}
    for out_key, state_key in (
        ("sql", "sql"),
        ("answer", "final_answer"),
        ("plot_path", "plot_path"),
        ("error", "error"),
    ):
        summary[out_key] = outcome.get(state_key)
    return summary


# -----------------------------
# CLI
# -----------------------------
def main():
    """Command-line entry point: interactive question loop over a DuckDB file.

    Loads environment variables from a .env file, parses the CLI arguments, opens
    the database read-only, prints the schema snapshot and then answers questions
    until the user types 'exit'/'quit' or sends EOF/Ctrl-C at the prompt. A failure
    while answering one question is reported and the loop continues; the connection
    is closed on every exit path.
    """
    load_dotenv(find_dotenv())

    parser = argparse.ArgumentParser(description="DuckDB Q&A ReAct Agent (Streaming)")
    parser.add_argument("--duckdb", required=True, help="Path to DuckDB database file")
    parser.add_argument("--model", default="gpt-4o", help="OpenAI chat model (default: gpt-4o)")
    parser.add_argument("--stream", action="store_true", help="Stream the final answer to stdout")
    parser.add_argument("--schemas", nargs="*", default=["main"], help="Schemas to include (default: main)")
    args = parser.parse_args()

    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("ERROR: OPENAI_API_KEY not set. Put it in a .env file or export the env var.")
        return

    # Connect DuckDB strictly to the provided file (user DB)
    try:
        con = duckdb.connect(database=args.duckdb, read_only=True)
    except Exception as e:
        print(f"Failed to open DuckDB at {args.duckdb}: {e}")
        return

    def print_token(t: str):
        print(t, end="", flush=True)

    try:
        # Introspect schema (user tables only)
        schema_text = get_schema_summary(con, allowed_schemas=args.schemas)

        # Init LLM; enable streaming capability (used only in final answer node)
        llm = make_llm(model=args.model, temperature=0.0)

        print("DuckDB Q&A Agent (ReAct, Streaming)\n")
        print(f"Connected to: {args.duckdb}")
        print(f"Schemas included: {', '.join(args.schemas)}")
        print("\nSchema snapshot:\n-----------------\n")
        print(schema_text)
        print("\nType your question and press ENTER. Type 'exit' to quit.\n")

        while True:
            try:
                q = input("Q> ").strip()
            except (EOFError, KeyboardInterrupt):
                print("\nExiting.")
                break
            if q.lower() in ("exit", "quit"):
                print("Goodbye.")
                break
            if not q:
                continue

            if args.stream:
                print("\n--- ANSWER (streaming) ---")

            try:
                result = answer_question(
                    con, llm, schema_text, q,
                    stream=args.stream,
                    token_callback=print_token if args.stream else None,
                )
            except Exception as e:
                # Top-level boundary: a transient API/network failure should not
                # end the interactive session.
                print(f"\nERROR: could not answer the question: {e}\n")
                continue

            if args.stream:
                print("")  # newline after stream
                if result.get("plot_path"):
                    print(f"\nChart saved to: {result['plot_path']}")
                print("\n--- SQL ---")
                print((result.get("sql") or "").strip())
                if result.get("error"):
                    print("\nERROR: " + str(result.get("error")))
            else:
                print("\n--- SQL ---")
                print((result.get("sql") or "").strip())
                if result.get("error"):
                    print("\n--- RESULT ---")
                    print("Sorry, I couldn't resolve a working query after 3 attempts.")
                    print("Error: " + str(result.get("error")))
                else:
                    print("\n--- ANSWER ---")
                    print(result.get("answer") or "")
                    if result.get("plot_path"):
                        print(f"\nChart saved to: {result['plot_path']}")
            print("\n")
    finally:
        # Release the database file even if introspection or a question raised.
        con.close()


# Start the interactive CLI only when the file is executed as a script, so that
# importing the module (e.g. to reuse answer_question) has no side effects.
if __name__ == "__main__":
    main()