""" DuckDB Q&A Agent (LangGraph + ReAct) — Streaming Edition ======================================================== Adds: - Streaming responses for the final natural-language answer (CLI prints tokens live). - Stricter schema discovery focused on *user* tables from the provided DB (defaults to schema 'main'). - Keeps SELECT-only safety, up to 3 SQL refinements on error, optional plotting saved to ./plots. Quick start ----------- 1) Install deps (example): pip install --upgrade duckdb pandas matplotlib python-dotenv langgraph langchain langchain-openai 2) Ensure an OpenAI API key is available: - Put it in a .env file as: OPENAI_API_KEY=sk-... - Or set an env var: set OPENAI_API_KEY=... (Windows) / export OPENAI_API_KEY=... (macOS/Linux) 3) Run the agent in an interactive loop: python duckdb_react_agent.py --duckdb path/to/your.db --stream Notes ----- - Targets DuckDB SQL. The LLM prompt instructs to write *only* a SELECT statement. - ReAct loop: on error, the LLM sees the error message + previous SQL and attempts to fix it (max 3 tries). - The agent avoids DDL/DML; SELECT-only for safety. - Plots are saved to ./plots/ as PNG files (no GUI required). The script uses a non-interactive backend. - Internal/helper tables beginning with "__" (e.g., ingestion metadata) and non-user schemas are ignored. """ import os import re import json import uuid import argparse from typing import Any, Dict, List, Optional, TypedDict, Callable # Non-GUI backend for saving plots on servers/Windows terminals import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt import duckdb import pandas as pd from dotenv import load_dotenv, find_dotenv # LangChain / LangGraph from langchain_openai import ChatOpenAI from langchain_core.messages import SystemMessage, HumanMessage from langgraph.graph import StateGraph, END # ----------------------------- # Utility: safe number formatting # ----------------------------- def format_number(x: Any) -> str: try: if isinstance(x, (int, float)) and not isinstance(x, bool): if abs(x) >= 1000 and abs(x) < 1e12: return f"{x:,.2f}".rstrip('0').rstrip('.') else: return f"{x:.4g}" if isinstance(x, float) else str(x) return str(x) except Exception: return str(x) # ----------------------------- # Schema introspection from DuckDB (user tables only) # ----------------------------- def get_schema_summary(con: duckdb.DuckDBPyConnection, sample_rows: int = 3, include_views: bool = True, allowed_schemas: Optional[List[str]] = None) -> str: """ Build a compact schema snapshot for user tables in the *provided DB*. By default we only list tables/views in schema 'main' (user content) and skip internals. """ if allowed_schemas is None: allowed_schemas = ["main"] type_filter = "('BASE TABLE')" if not include_views else "('BASE TABLE','VIEW')" tables_df = con.execute( f""" SELECT table_schema, table_name, table_type FROM information_schema.tables WHERE table_type IN {type_filter} AND table_schema IN ({','.join(['?']*len(allowed_schemas))}) AND table_name NOT LIKE 'duckdb_%' AND table_name NOT LIKE 'sqlite_%' ORDER BY table_schema, table_name """, allowed_schemas ).fetchdf() lines: List[str] = [] for _, row in tables_df.iterrows(): schema = row["table_schema"] name = row["table_name"] if name.startswith("__"): # Common ingestion metadata; not user-facing continue cols = con.execute( """ SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = ? AND table_name = ? ORDER BY ordinal_position """, [schema, name], ).fetchdf() lines.append(f"TABLE {schema}.{name}") if len(cols) == 0: lines.append(" (no columns discovered)") continue # Sample small set of values per column to guide the LLM try: sample = con.execute(f"SELECT * FROM {schema}.{name} LIMIT {sample_rows}").fetchdf() except Exception: sample = pd.DataFrame(columns=cols["column_name"].tolist()) for _, c in cols.iterrows(): col = c["column_name"] dtype = c["data_type"] examples: List[str] = [] if col in sample.columns: for v in sample[col].tolist(): examples.append(format_number(v)[:80]) example_str = ", ".join(examples) if examples else "" lines.append(f" - {col} :: {dtype} e.g. [{example_str}]") lines.append("") if not lines: lines.append("(No user tables discovered in schema(s): " + ", ".join(allowed_schemas) + ")") return "\n".join(lines) # ----------------------------- # LLM helpers # ----------------------------- def make_llm(model: str = "gpt-4o-mini", temperature: float = 0.0) -> ChatOpenAI: return ChatOpenAI(model=model, temperature=temperature) SQL_SYSTEM_PROMPT = """You are an expert DuckDB SQL writer. You will receive: - A database schema (tables and columns) with a few sample values - A user's natural-language question Write exactly ONE valid DuckDB SQL SELECT statement to answer the question. Rules: - Output ONLY the SQL (no backticks, no explanation, no comments) - Use DuckDB SQL - Never write DDL/DML (CREATE/INSERT/UPDATE/DELETE), only SELECT statements - Prefer explicit column names, avoid SELECT * - If time grouping is needed, use date_trunc('month', col), date_trunc('day', col), etc. - If casting is needed, use CAST(... AS TYPE) - Use ONLY the listed user tables from the provided database (e.g., schema 'main'). Do not rely on external/attached sources. - Avoid tables starting with "__" unless explicitly referenced. - If uncertain, make a reasonable assumption and produce the best possible query Be careful with joins and filters. Ensure SQL parses successfully. """ REFINE_SYSTEM_PROMPT = """You are fixing a DuckDB SQL query that errored. You'll see: - the user's question - the previous SQL - the exact error message Return a corrected DuckDB SELECT statement that addresses the error. Rules: - Output ONLY the SQL (no backticks, no explanation, no comments) - SELECT-only (no DDL/DML) - Keep the intent of the question intact - Fix joins, field names, casts, aggregations or date functions as needed - Use ONLY the listed user tables from the provided database (e.g., schema 'main') """ ANSWER_SYSTEM_PROMPT = """You are a helpful data analyst. Given a user's question and the query result data, write a clear, concise, conversational answer in plain English. Use bullet points or short paragraphs. - Format large numbers with thousands separators where appropriate. - If percentages are present, include % signs. - If a chart/image accompanies the answer, DO NOT mention file paths or filenames. Refer generically, e.g., 'as depicted in the chart' or 'see the chart alongside this answer.' - If the result is empty, say so and suggest a next step. - Do not invent columns or values that are not in the result. - (Optional) You may refer to prior Q&A context if it informs this answer. """ VIZ_SYSTEM_PROMPT = """You will decide whether a chart would help answer the question using the SQL result. Return STRICT JSON with keys: {"make_plot": bool, "chart": "line|bar|scatter|hist|box|pie", "x": "", "y": "", "series": "", "agg": "", "reason": ""}. Conservative rules (prefer NO plot): - Only set make_plot=true if the user explicitly asks for a chart/visualization OR the result involves trends over time, distributions, or comparisons with many rows (e.g., > 20) where a chart adds clarity. - For simple aggregates or few rows (<= 10), set make_plot=false unless explicitly requested. Guidelines when make_plot=true: - Prefer line for trends over time - Prefer bar for ranking/comparison of categories - Prefer scatter for correlation between two numeric columns - Prefer hist for distribution of a single numeric column - Prefer box for distribution with quartiles per category - Prefer pie sparingly for simple part-to-whole JSON only. No prose. """ def extract_sql(text: str) -> str: """Extract SQL from potential LLM output. We expect raw SQL with no fences, but be defensive.""" text = text.strip() fence = re.compile(r"```(?:sql)?(.*?)```", re.DOTALL | re.IGNORECASE) m = fence.search(text) if m: return m.group(1).strip() return text # ----------------------------- # LangGraph State # ----------------------------- class AgentState(TypedDict): question: str schema: str attempts: int sql: Optional[str] error: Optional[str] result_json: Optional[str] # JSON records preview result_columns: Optional[List[str]] plot_path: Optional[str] final_answer: Optional[str] _result_df: Optional[pd.DataFrame] _viz_spec: Optional[Dict[str, Any]] _stream: bool _token_cb: Optional[Callable[[str], None]] # ----------------------------- # Nodes: SQL drafting, execution, refinement, viz, answer # ----------------------------- def node_draft_sql(state: AgentState, llm: ChatOpenAI) -> AgentState: msgs = [ SystemMessage(content=SQL_SYSTEM_PROMPT), HumanMessage(content=f"SCHEMA:\n{state['schema']}\n\nQUESTION:\n{state['question']}"), ] resp = llm.invoke(msgs) # type: ignore sql = extract_sql(resp.content or "") return {**state, "sql": sql, "error": None} def run_duckdb_query(con: duckdb.DuckDBPyConnection, sql: str) -> pd.DataFrame: first_token = re.split(r"\s+", sql.strip(), maxsplit=1)[0].upper() if first_token != "SELECT" and not sql.strip().upper().startswith("WITH "): raise ValueError("Only SELECT queries are allowed.") df = con.execute(sql).fetchdf() return df def node_run_sql(state: AgentState, con: duckdb.DuckDBPyConnection) -> AgentState: try: df = run_duckdb_query(con, state["sql"] or "") preview = df.head(50).to_dict(orient="records") return { **state, "error": None, "result_json": json.dumps(preview, default=str), "result_columns": list(df.columns), "_result_df": df, } except Exception as e: return {**state, "error": str(e), "result_json": None, "result_columns": None, "_result_df": None} def node_refine_sql(state: AgentState, llm: ChatOpenAI) -> AgentState: if state.get("attempts", 0) >= 3: return state msgs = [ SystemMessage(content=REFINE_SYSTEM_PROMPT), HumanMessage( content=( f"QUESTION:\n{state['question']}\n\n" f"PREVIOUS SQL:\n{state.get('sql','')}\n\n" f"ERROR:\n{state.get('error','')}" ) ), ] resp = llm.invoke(msgs) # type: ignore sql = extract_sql(resp.content or "") return {**state, "sql": sql, "attempts": state.get("attempts", 0) + 1, "error": None} def node_decide_viz(state: AgentState, llm: ChatOpenAI) -> AgentState: if not state.get("result_json"): return state result_preview = state["result_json"] msgs = [ SystemMessage(content=VIZ_SYSTEM_PROMPT), HumanMessage( content=( f"QUESTION:\n{state['question']}\n\n" f"COLUMNS: {state.get('result_columns',[])}\n" f"RESULT PREVIEW (first rows):\n{result_preview}" ) ), ] resp = llm.invoke(msgs) # type: ignore spec = {"make_plot": False} # try: # spec = json.loads(resp.content) # type: ignore # if not isinstance(spec, dict) or "make_plot" not in spec: # spec = {"make_plot": False} # except Exception: # spec = {"make_plot": False} return {**state, "_viz_spec": spec} def node_make_plot(state: AgentState) -> AgentState: spec = state.get("_viz_spec") or {} if not spec or not spec.get("make_plot"): return state df: Optional[pd.DataFrame] = state.get("_result_df") if df is None or df.empty: return state # Additional guard: avoid trivial plots unless explicitly requested q_lower = (state.get('question') or '').lower() explicit_viz = any(k in q_lower for k in ['chart', 'plot', 'graph', 'visual', 'visualize', 'trend']) many_rows = df.shape[0] > 20 if not explicit_viz and not many_rows: return state x = spec.get("x") y = spec.get("y") series = spec.get("series") chart = (spec.get("chart") or "").lower() def col_ok(c: Optional[str]) -> bool: return isinstance(c, str) and c in df.columns if not col_ok(x) and df.shape[1] >= 1: x = df.columns[0] if not col_ok(y) and df.shape[1] >= 2: y = df.columns[1] try: os.makedirs("plots", exist_ok=True) fig = plt.figure() ax = fig.gca() if chart == "line": if series and col_ok(series): for k, g in df.groupby(series): ax.plot(g[x], g[y], label=str(k)) ax.legend(loc="best") else: ax.plot(df[x], df[y]) ax.set_xlabel(str(x)); ax.set_ylabel(str(y)) elif chart == "bar": if series and col_ok(series): pivot = df.pivot_table(index=x, columns=series, values=y, aggfunc="sum", fill_value=0) pivot.plot(kind="bar", ax=ax) ax.set_xlabel(str(x)); ax.set_ylabel(str(y)) else: ax.bar(df[x], df[y]) ax.set_xlabel(str(x)); ax.set_ylabel(str(y)) fig.autofmt_xdate(rotation=45) elif chart == "scatter": ax.scatter(df[x], df[y]) ax.set_xlabel(str(x)); ax.set_ylabel(str(y)) elif chart == "hist": ax.hist(df[y] if col_ok(y) else df[x], bins=30) ax.set_xlabel(str(y if col_ok(y) else x)); ax.set_ylabel("Frequency") elif chart == "box": ax.boxplot([df[y]] if col_ok(y) else [df[x]]) ax.set_ylabel(str(y if col_ok(y) else x)) elif chart == "pie": labels = df[x].astype(str).tolist() values = df[y].tolist() ax.pie(values, labels=labels, autopct="%1.1f%%") ax.axis("equal") else: plt.close(fig) return state out_path = os.path.join("plots", f"duckdb_answer_{uuid.uuid4().hex[:8]}.png") fig.tight_layout() fig.savefig(out_path, dpi=150) plt.close(fig) return {**state, "plot_path": out_path} except Exception: return state def _stream_llm_answer(llm: ChatOpenAI, msgs: List[Any], token_cb: Optional[Callable[[str], None]]) -> str: """Stream tokens for the final answer; returns full text as well.""" out = "" try: for chunk in llm.stream(msgs): # type: ignore piece = getattr(chunk, "content", None) if not piece and hasattr(chunk, "message") and getattr(chunk, "message", None): piece = getattr(chunk.message, "content", "") if not piece: continue out += piece if token_cb: token_cb(piece) except Exception: # Fallback to non-streaming if stream not supported resp = llm.invoke(msgs) # type: ignore out = resp.content if isinstance(resp.content, str) else str(resp.content) if token_cb: token_cb(out) return out def node_answer(state: AgentState, llm: ChatOpenAI) -> AgentState: preview = state.get("result_json") or "[]" columns = state.get("result_columns") or [] plot_path = state.get("plot_path") if len(preview) > 4000: preview = preview[:4000] + " ..." msgs = [ SystemMessage(content=ANSWER_SYSTEM_PROMPT), HumanMessage( content=( f"QUESTION:\n{state['question']}\n\n" f"COLUMNS: {columns}\n" f"RESULT PREVIEW (rows):\n{preview}\n\n" f"PLOT_PATH: {plot_path if plot_path else 'None'}" ) ), ] if state.get("error") and not state.get("result_json"): err_text = ( "I couldn't produce a working SQL query after 3 attempts.\n\n" "Details:\n" + (state.get("error") or "Unknown error") ) # Stream the error too for consistency if state.get("_stream") and state.get("_token_cb"): state["_token_cb"](err_text) return {**state, "final_answer": err_text} # Stream the final answer if requested if state.get("_stream"): answer = _stream_llm_answer(llm, msgs, state.get("_token_cb")) else: resp = llm.invoke(msgs) # type: ignore answer = resp.content if isinstance(resp.content, str) else str(resp.content) return {**state, "final_answer": answer} # ----------------------------- # Graph assembly # ----------------------------- def build_graph(con: duckdb.DuckDBPyConnection, llm: ChatOpenAI): g = StateGraph(AgentState) g.add_node("draft_sql", lambda s: node_draft_sql(s, llm)) g.add_node("run_sql", lambda s: node_run_sql(s, con)) g.add_node("refine_sql", lambda s: node_refine_sql(s, llm)) g.add_node("decide_viz", lambda s: node_decide_viz(s, llm)) g.add_node("make_plot", node_make_plot) g.add_node("answer", lambda s: node_answer(s, llm)) g.set_entry_point("draft_sql") g.add_edge("draft_sql", "run_sql") def on_run_sql(state: AgentState): if state.get("error"): if state.get("attempts", 0) < 3: return "refine_sql" else: return "answer" return "decide_viz" g.add_conditional_edges("run_sql", on_run_sql, { "refine_sql": "refine_sql", "decide_viz": "decide_viz", "answer": "answer", }) g.add_edge("refine_sql", "run_sql") def on_decide_viz(state: AgentState): spec = state.get("_viz_spec") or {} if spec.get("make_plot"): return "make_plot" return "answer" g.add_conditional_edges("decide_viz", on_decide_viz, { "make_plot": "make_plot", "answer": "answer", }) g.add_edge("make_plot", "answer") g.add_edge("answer", END) return g.compile() # ----------------------------- # Public function to ask a question # ----------------------------- def answer_question(con: duckdb.DuckDBPyConnection, llm: ChatOpenAI, schema_text: str, question: str, stream: bool = False, token_callback: Optional[Callable[[str], None]] = None) -> Dict[str, Any]: app = build_graph(con, llm) initial: AgentState = { "question": question, "schema": schema_text, "attempts": 0, "sql": None, "error": None, "result_json": None, "result_columns": None, "plot_path": None, "final_answer": None, "_result_df": None, "_viz_spec": None, "_stream": stream, "_token_cb": token_callback, } final_state: AgentState = app.invoke(initial) # type: ignore return { "question": question, "sql": final_state.get("sql"), "answer": final_state.get("final_answer"), "plot_path": final_state.get("plot_path"), "error": final_state.get("error"), } # ----------------------------- # CLI # ----------------------------- def main(): load_dotenv(find_dotenv()) parser = argparse.ArgumentParser(description="DuckDB Q&A ReAct Agent (Streaming)") parser.add_argument("--duckdb", required=True, help="Path to DuckDB database file") parser.add_argument("--model", default="gpt-4o", help="OpenAI chat model (default: gpt-4o)") parser.add_argument("--stream", action="store_true", help="Stream the final answer to stdout") parser.add_argument("--schemas", nargs="*", default=["main"], help="Schemas to include (default: main)") args = parser.parse_args() api_key = os.environ.get("OPENAI_API_KEY") if not api_key: print("ERROR: OPENAI_API_KEY not set. Put it in a .env file or export the env var.") return # Connect DuckDB strictly to the provided file (user DB) try: con = duckdb.connect(database=args.duckdb, read_only=True) except Exception as e: print(f"Failed to open DuckDB at {args.duckdb}: {e}") return # Introspect schema (user tables only) schema_text = get_schema_summary(con, allowed_schemas=args.schemas) # Init LLM; enable streaming capability (used only in final answer node) llm = make_llm(model=args.model, temperature=0.0) print("DuckDB Q&A Agent (ReAct, Streaming)\n") print(f"Connected to: {args.duckdb}") print(f"Schemas included: {', '.join(args.schemas)}") print("\nSchema snapshot:\n-----------------\n") print(schema_text) print("\nType your question and press ENTER. Type 'exit' to quit.\n") def print_token(t: str): print(t, end="", flush=True) while True: try: q = input("Q> ").strip() except (EOFError, KeyboardInterrupt): print("\nExiting.") break if q.lower() in ("exit", "quit"): print("Goodbye.") break if not q: continue if args.stream: print("\n--- ANSWER (streaming) ---") result = answer_question(con, llm, schema_text, q, stream=True, token_callback=print_token) print("") # newline after stream if result.get("plot_path"): print(f"\nChart saved to: {result['plot_path']}") print("\n--- SQL ---") print((result.get("sql") or "").strip()) if result.get("error"): print("\nERROR: " + str(result.get("error"))) else: result = answer_question(con, llm, schema_text, q, stream=False, token_callback=None) print("\n--- SQL ---") print((result.get("sql") or "").strip()) if result.get("error"): print("\n--- RESULT ---") print("Sorry, I couldn't resolve a working query after 3 attempts.") print("Error: " + str(result.get("error"))) else: print("\n--- ANSWER ---") print(result.get("answer") or "") if result.get("plot_path"): print(f"\nChart saved to: {result['plot_path']}") print("\n") con.close() if __name__ == "__main__": main()