"""
DuckDB Q&A Agent (LangGraph + ReAct) — Streaming Edition
========================================================
Adds:
- Streaming responses for the final natural-language answer (CLI prints tokens live).
- Stricter schema discovery focused on *user* tables from the provided DB (defaults to schema 'main').
- Keeps SELECT-only safety, up to 3 SQL refinements on error, optional plotting saved to ./plots.
Quick start
-----------
1) Install deps (example):
pip install --upgrade duckdb pandas matplotlib python-dotenv langgraph langchain langchain-openai
2) Ensure an OpenAI API key is available:
- Put it in a .env file as: OPENAI_API_KEY=sk-...
- Or set an env var: set OPENAI_API_KEY=... (Windows) / export OPENAI_API_KEY=... (macOS/Linux)
3) Run the agent in an interactive loop:
python duckdb_react_agent.py --duckdb path/to/your.db --stream
Notes
-----
- Targets DuckDB SQL. The LLM prompt instructs to write *only* a SELECT statement.
- ReAct loop: on error, the LLM sees the error message + previous SQL and attempts to fix it (max 3 tries).
- The agent avoids DDL/DML; SELECT-only for safety.
- Plots are saved to ./plots/ as PNG files (no GUI required). The script uses a non-interactive backend.
- Internal/helper tables beginning with "__" (e.g., ingestion metadata) and non-user schemas are ignored.
"""
import os
import re
import json
import uuid
import argparse
from typing import Any, Dict, List, Optional, TypedDict, Callable
# Non-GUI backend for saving plots on servers/Windows terminals
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import duckdb
import pandas as pd
from dotenv import load_dotenv, find_dotenv
# LangChain / LangGraph
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.graph import StateGraph, END
# -----------------------------
# Utility: safe number formatting
# -----------------------------
def format_number(x: Any) -> str:
    """Render a value for compact display.

    Mid-range numbers (1,000 <= |x| < 1e12) get thousands separators with
    trailing zeros trimmed; other floats use 4 significant digits; anything
    else falls through to str(). Never raises — any formatting error falls
    back to str(x).
    """
    try:
        is_plain_number = isinstance(x, (int, float)) and not isinstance(x, bool)
        if not is_plain_number:
            return str(x)
        magnitude = abs(x)
        if 1000 <= magnitude < 1e12:
            # Comma-grouped with two decimals, then drop trailing zeros / dot.
            return f"{x:,.2f}".rstrip('0').rstrip('.')
        if isinstance(x, float):
            return f"{x:.4g}"
        return str(x)
    except Exception:
        return str(x)
# -----------------------------
# Schema introspection from DuckDB (user tables only)
# -----------------------------
def get_schema_summary(con: duckdb.DuckDBPyConnection, sample_rows: int = 3, include_views: bool = True,
                       allowed_schemas: Optional[List[str]] = None) -> str:
    """
    Build a compact schema snapshot for user tables in the *provided DB*.
    By default we only list tables/views in schema 'main' (user content) and skip internals.

    For each table/view: column names, data types, and up to ``sample_rows``
    example values per column (truncated to 80 chars) to guide the LLM.
    Tables whose names start with "__" (ingestion metadata) are skipped.

    Fix: the sample query now double-quotes identifiers (with embedded-quote
    escaping), so table/schema names containing spaces, mixed case, or
    reserved words no longer break the interpolated SELECT. ``sample_rows``
    is coerced to int before interpolation as an extra guard.
    """
    if allowed_schemas is None:
        allowed_schemas = ["main"]
    type_filter = "('BASE TABLE')" if not include_views else "('BASE TABLE','VIEW')"
    tables_df = con.execute(
        f"""
        SELECT table_schema, table_name, table_type
        FROM information_schema.tables
        WHERE table_type IN {type_filter}
          AND table_schema IN ({','.join(['?']*len(allowed_schemas))})
          AND table_name NOT LIKE 'duckdb_%'
          AND table_name NOT LIKE 'sqlite_%'
        ORDER BY table_schema, table_name
        """,
        allowed_schemas
    ).fetchdf()

    def _quote_ident(ident: str) -> str:
        # DuckDB identifier quoting: wrap in double quotes, double any
        # embedded double quotes.
        return '"' + str(ident).replace('"', '""') + '"'

    lines: List[str] = []
    for _, row in tables_df.iterrows():
        schema = row["table_schema"]
        name = row["table_name"]
        if name.startswith("__"):
            # Common ingestion metadata; not user-facing
            continue
        cols = con.execute(
            """
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = ? AND table_name = ?
            ORDER BY ordinal_position
            """,
            [schema, name],
        ).fetchdf()
        lines.append(f"TABLE {schema}.{name}")
        if len(cols) == 0:
            lines.append(" (no columns discovered)")
            continue
        # Sample small set of values per column to guide the LLM
        try:
            sample = con.execute(
                f"SELECT * FROM {_quote_ident(schema)}.{_quote_ident(name)} LIMIT {int(sample_rows)}"
            ).fetchdf()
        except Exception:
            # Sampling is best-effort; fall back to an empty frame with the
            # right columns so the per-column loop still works.
            sample = pd.DataFrame(columns=cols["column_name"].tolist())
        for _, c in cols.iterrows():
            col = c["column_name"]
            dtype = c["data_type"]
            examples: List[str] = []
            if col in sample.columns:
                for v in sample[col].tolist():
                    examples.append(format_number(v)[:80])
            example_str = ", ".join(examples) if examples else ""
            lines.append(f" - {col} :: {dtype} e.g. [{example_str}]")
        lines.append("")
    if not lines:
        lines.append("(No user tables discovered in schema(s): " + ", ".join(allowed_schemas) + ")")
    return "\n".join(lines)
# -----------------------------
# LLM helpers
# -----------------------------
def make_llm(model: str = "gpt-4o-mini", temperature: float = 0.0) -> ChatOpenAI:
    """Create a ChatOpenAI client; temperature 0 keeps SQL generation deterministic."""
    return ChatOpenAI(model=model, temperature=temperature)
# System prompt for the initial draft: schema + question in, exactly one
# DuckDB SELECT statement out (no fences, no prose).
SQL_SYSTEM_PROMPT = """You are an expert DuckDB SQL writer. You will receive:
- A database schema (tables and columns) with a few sample values
- A user's natural-language question
Write exactly ONE valid DuckDB SQL SELECT statement to answer the question.
Rules:
- Output ONLY the SQL (no backticks, no explanation, no comments)
- Use DuckDB SQL
- Never write DDL/DML (CREATE/INSERT/UPDATE/DELETE), only SELECT statements
- Prefer explicit column names, avoid SELECT *
- If time grouping is needed, use date_trunc('month', col), date_trunc('day', col), etc.
- If casting is needed, use CAST(... AS TYPE)
- Use ONLY the listed user tables from the provided database (e.g., schema 'main'). Do not rely on external/attached sources.
- Avoid tables starting with "__" unless explicitly referenced.
- If uncertain, make a reasonable assumption and produce the best possible query
Be careful with joins and filters. Ensure SQL parses successfully.
"""
# System prompt for the refinement loop: the LLM sees the question, the
# failed SQL and the exact DuckDB error, and returns a corrected SELECT.
REFINE_SYSTEM_PROMPT = """You are fixing a DuckDB SQL query that errored. You'll see:
- the user's question
- the previous SQL
- the exact error message
Return a corrected DuckDB SELECT statement that addresses the error.
Rules:
- Output ONLY the SQL (no backticks, no explanation, no comments)
- SELECT-only (no DDL/DML)
- Keep the intent of the question intact
- Fix joins, field names, casts, aggregations or date functions as needed
- Use ONLY the listed user tables from the provided database (e.g., schema 'main')
"""
# System prompt for the final natural-language answer generated from the
# query result preview.
ANSWER_SYSTEM_PROMPT = """You are a helpful data analyst. Given a user's question and the query result data,
write a clear, concise, conversational answer in plain English. Use bullet points or short paragraphs.
- Format large numbers with thousands separators where appropriate.
- If percentages are present, include % signs.
- If a chart/image accompanies the answer, DO NOT mention file paths or filenames. Refer generically, e.g., 'as depicted in the chart' or 'see the chart alongside this answer.'
- If the result is empty, say so and suggest a next step.
- Do not invent columns or values that are not in the result.
- (Optional) You may refer to prior Q&A context if it informs this answer.
"""
# System prompt asking for a STRICT-JSON chart spec; deliberately
# conservative (prefers no plot for small/simple results).
VIZ_SYSTEM_PROMPT = """You will decide whether a chart would help answer the question using the SQL result.
Return STRICT JSON with keys: {"make_plot": bool, "chart": "line|bar|scatter|hist|box|pie", "x": "
", "y": "", "series": "", "agg": "", "reason": ""}.
Conservative rules (prefer NO plot):
- Only set make_plot=true if the user explicitly asks for a chart/visualization OR the result involves trends over time, distributions, or comparisons with many rows (e.g., > 20) where a chart adds clarity.
- For simple aggregates or few rows (<= 10), set make_plot=false unless explicitly requested.
Guidelines when make_plot=true:
- Prefer line for trends over time
- Prefer bar for ranking/comparison of categories
- Prefer scatter for correlation between two numeric columns
- Prefer hist for distribution of a single numeric column
- Prefer box for distribution with quartiles per category
- Prefer pie sparingly for simple part-to-whole
JSON only. No prose.
"""
def extract_sql(text: str) -> str:
    """Pull the SQL out of an LLM reply.

    The prompts demand bare SQL, but models sometimes wrap it in markdown
    fences; if a ```sql ...``` (or plain ```) fence is present, return its
    interior, otherwise return the stripped text unchanged.
    """
    cleaned = text.strip()
    match = re.search(r"```(?:sql)?(.*?)```", cleaned, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else cleaned
# -----------------------------
# LangGraph State
# -----------------------------
class AgentState(TypedDict):
    """State dict threaded through the LangGraph nodes.

    Keys with a leading underscore are internal plumbing (raw DataFrame,
    viz spec, streaming hooks) and are not part of the user-facing result.
    """
    question: str  # the user's natural-language question
    schema: str  # schema snapshot text given to the LLM
    attempts: int  # number of SQL refinement attempts so far (capped at 3)
    sql: Optional[str]  # current candidate SQL statement
    error: Optional[str]  # last execution error message, if any
    result_json: Optional[str]  # JSON records preview
    result_columns: Optional[List[str]]  # column names of the result set
    plot_path: Optional[str]  # path of the saved PNG chart, if one was made
    final_answer: Optional[str]  # natural-language answer text
    _result_df: Optional[pd.DataFrame]  # full result DataFrame (internal)
    _viz_spec: Optional[Dict[str, Any]]  # chart spec decided by the viz node (internal)
    _stream: bool  # whether to stream the final answer (internal)
    _token_cb: Optional[Callable[[str], None]]  # per-token callback for streaming (internal)
# -----------------------------
# Nodes: SQL drafting, execution, refinement, viz, answer
# -----------------------------
def node_draft_sql(state: AgentState, llm: ChatOpenAI) -> AgentState:
    """Graph node: ask the LLM for an initial SELECT answering the question."""
    prompt = [
        SystemMessage(content=SQL_SYSTEM_PROMPT),
        HumanMessage(content=f"SCHEMA:\n{state['schema']}\n\nQUESTION:\n{state['question']}"),
    ]
    reply = llm.invoke(prompt)  # type: ignore
    drafted = extract_sql(reply.content or "")
    return {**state, "sql": drafted, "error": None}
def run_duckdb_query(con: "duckdb.DuckDBPyConnection", sql: str) -> pd.DataFrame:
    """Execute *sql* on *con* and return the result as a DataFrame.

    Safety gate: only statements whose first token is SELECT or WITH are run.
    Note: leading SQL comments are not stripped, so commented queries are
    rejected — acceptable here since the prompts demand bare SQL.

    Fix: the old check required a literal "WITH " prefix (trailing space),
    so valid SQL like "WITH\\ncte AS (...)" was wrongly rejected. The check
    is now token-based and accepts any whitespace after WITH.

    Raises:
        ValueError: if the statement does not start with SELECT or WITH.
    """
    stripped = sql.strip()
    first_token = re.split(r"\s+", stripped, maxsplit=1)[0].upper() if stripped else ""
    if first_token not in ("SELECT", "WITH"):
        raise ValueError("Only SELECT queries are allowed.")
    df = con.execute(sql).fetchdf()
    return df
def node_run_sql(state: AgentState, con: duckdb.DuckDBPyConnection) -> AgentState:
    """Graph node: execute the current SQL.

    On success: stores a JSON preview of the first 50 rows, the column list,
    and the full DataFrame; clears 'error'. On any failure: records the error
    text and nulls out the result fields so the router can refine or bail.
    """
    try:
        frame = run_duckdb_query(con, state["sql"] or "")
        head_records = frame.head(50).to_dict(orient="records")
        success: Dict[str, Any] = {
            "error": None,
            "result_json": json.dumps(head_records, default=str),
            "result_columns": list(frame.columns),
            "_result_df": frame,
        }
        return {**state, **success}
    except Exception as exc:
        return {**state, "error": str(exc), "result_json": None, "result_columns": None, "_result_df": None}
def node_refine_sql(state: AgentState, llm: ChatOpenAI) -> AgentState:
    """Graph node: have the LLM repair the failed SQL (max 3 total attempts)."""
    tries = state.get("attempts", 0)
    if tries >= 3:
        # Refinement budget exhausted; let the answer node report the failure.
        return state
    context = (
        f"QUESTION:\n{state['question']}\n\n"
        f"PREVIOUS SQL:\n{state.get('sql','')}\n\n"
        f"ERROR:\n{state.get('error','')}"
    )
    prompt = [SystemMessage(content=REFINE_SYSTEM_PROMPT), HumanMessage(content=context)]
    reply = llm.invoke(prompt)  # type: ignore
    repaired = extract_sql(reply.content or "")
    return {**state, "sql": repaired, "attempts": tries + 1, "error": None}
def node_decide_viz(state: AgentState, llm: ChatOpenAI) -> AgentState:
if not state.get("result_json"):
return state
result_preview = state["result_json"]
msgs = [
SystemMessage(content=VIZ_SYSTEM_PROMPT),
HumanMessage(
content=(
f"QUESTION:\n{state['question']}\n\n"
f"COLUMNS: {state.get('result_columns',[])}\n"
f"RESULT PREVIEW (first rows):\n{result_preview}"
)
),
]
resp = llm.invoke(msgs) # type: ignore
spec = {"make_plot": False}
# try:
# spec = json.loads(resp.content) # type: ignore
# if not isinstance(spec, dict) or "make_plot" not in spec:
# spec = {"make_plot": False}
# except Exception:
# spec = {"make_plot": False}
return {**state, "_viz_spec": spec}
def node_make_plot(state: AgentState) -> AgentState:
    """Graph node: render the chart described by _viz_spec and save it to ./plots.

    Returns the state unchanged when no plot is wanted, the result is empty,
    the chart type is unknown, or rendering fails; otherwise returns state
    with 'plot_path' set to the saved PNG file.
    """
    spec = state.get("_viz_spec") or {}
    if not spec or not spec.get("make_plot"):
        return state
    df: Optional[pd.DataFrame] = state.get("_result_df")
    if df is None or df.empty:
        return state
    # Additional guard: avoid trivial plots unless explicitly requested
    q_lower = (state.get('question') or '').lower()
    explicit_viz = any(k in q_lower for k in ['chart', 'plot', 'graph', 'visual', 'visualize', 'trend'])
    many_rows = df.shape[0] > 20
    if not explicit_viz and not many_rows:
        return state
    x = spec.get("x")
    y = spec.get("y")
    series = spec.get("series")
    chart = (spec.get("chart") or "").lower()
    def col_ok(c: Optional[str]) -> bool:
        # A usable column reference is a string naming an actual column.
        return isinstance(c, str) and c in df.columns
    # If the spec names missing columns, fall back to the first/second columns.
    if not col_ok(x) and df.shape[1] >= 1:
        x = df.columns[0]
    if not col_ok(y) and df.shape[1] >= 2:
        y = df.columns[1]
    try:
        os.makedirs("plots", exist_ok=True)
        fig = plt.figure()
        ax = fig.gca()
        if chart == "line":
            if series and col_ok(series):
                # One line per series value, with a legend.
                for k, g in df.groupby(series):
                    ax.plot(g[x], g[y], label=str(k))
                ax.legend(loc="best")
            else:
                ax.plot(df[x], df[y])
            ax.set_xlabel(str(x)); ax.set_ylabel(str(y))
        elif chart == "bar":
            if series and col_ok(series):
                # Grouped bars: pivot to one column per series value.
                pivot = df.pivot_table(index=x, columns=series, values=y, aggfunc="sum", fill_value=0)
                pivot.plot(kind="bar", ax=ax)
                ax.set_xlabel(str(x)); ax.set_ylabel(str(y))
            else:
                ax.bar(df[x], df[y])
                ax.set_xlabel(str(x)); ax.set_ylabel(str(y))
            fig.autofmt_xdate(rotation=45)
        elif chart == "scatter":
            ax.scatter(df[x], df[y])
            ax.set_xlabel(str(x)); ax.set_ylabel(str(y))
        elif chart == "hist":
            # Prefer the y column for the distribution; fall back to x.
            ax.hist(df[y] if col_ok(y) else df[x], bins=30)
            ax.set_xlabel(str(y if col_ok(y) else x)); ax.set_ylabel("Frequency")
        elif chart == "box":
            ax.boxplot([df[y]] if col_ok(y) else [df[x]])
            ax.set_ylabel(str(y if col_ok(y) else x))
        elif chart == "pie":
            labels = df[x].astype(str).tolist()
            values = df[y].tolist()
            ax.pie(values, labels=labels, autopct="%1.1f%%")
            ax.axis("equal")
        else:
            # Unknown chart type: discard the figure and skip plotting.
            plt.close(fig)
            return state
        out_path = os.path.join("plots", f"duckdb_answer_{uuid.uuid4().hex[:8]}.png")
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
        plt.close(fig)
        return {**state, "plot_path": out_path}
    except Exception:
        # Plotting is best-effort; any failure falls back to a text-only answer.
        return state
def _stream_llm_answer(llm: ChatOpenAI, msgs: List[Any], token_cb: Optional[Callable[[str], None]]) -> str:
    """Stream the answer piece-by-piece through *token_cb*; return the full text.

    If streaming is unsupported or fails, falls back to a single blocking
    call and delivers the whole text to *token_cb* at once.
    """
    collected: List[str] = []
    try:
        for chunk in llm.stream(msgs):  # type: ignore
            text = getattr(chunk, "content", None)
            if not text and hasattr(chunk, "message") and getattr(chunk, "message", None):
                text = getattr(chunk.message, "content", "")
            if not text:
                continue
            collected.append(text)
            if token_cb:
                token_cb(text)
        return "".join(collected)
    except Exception:
        # Fallback to non-streaming if stream not supported
        resp = llm.invoke(msgs)  # type: ignore
        whole = resp.content if isinstance(resp.content, str) else str(resp.content)
        if token_cb:
            token_cb(whole)
        return whole
def node_answer(state: AgentState, llm: ChatOpenAI) -> AgentState:
preview = state.get("result_json") or "[]"
columns = state.get("result_columns") or []
plot_path = state.get("plot_path")
if len(preview) > 4000:
preview = preview[:4000] + " ..."
msgs = [
SystemMessage(content=ANSWER_SYSTEM_PROMPT),
HumanMessage(
content=(
f"QUESTION:\n{state['question']}\n\n"
f"COLUMNS: {columns}\n"
f"RESULT PREVIEW (rows):\n{preview}\n\n"
f"PLOT_PATH: {plot_path if plot_path else 'None'}"
)
),
]
if state.get("error") and not state.get("result_json"):
err_text = (
"I couldn't produce a working SQL query after 3 attempts.\n\n"
"Details:\n" + (state.get("error") or "Unknown error")
)
# Stream the error too for consistency
if state.get("_stream") and state.get("_token_cb"):
state["_token_cb"](err_text)
return {**state, "final_answer": err_text}
# Stream the final answer if requested
if state.get("_stream"):
answer = _stream_llm_answer(llm, msgs, state.get("_token_cb"))
else:
resp = llm.invoke(msgs) # type: ignore
answer = resp.content if isinstance(resp.content, str) else str(resp.content)
return {**state, "final_answer": answer}
# -----------------------------
# Graph assembly
# -----------------------------
def build_graph(con: duckdb.DuckDBPyConnection, llm: ChatOpenAI):
    """Assemble the LangGraph pipeline: draft -> run -> (refine loop) -> viz -> answer.

    Routing: a failed run goes back through refine_sql while attempts remain,
    otherwise straight to answer; a successful run goes to decide_viz, which
    optionally detours through make_plot before answering.
    """
    graph = StateGraph(AgentState)

    def route_after_run(state: AgentState):
        # Success heads to the viz decision; errors retry while budget lasts.
        if not state.get("error"):
            return "decide_viz"
        return "refine_sql" if state.get("attempts", 0) < 3 else "answer"

    def route_after_viz(state: AgentState):
        wants_plot = (state.get("_viz_spec") or {}).get("make_plot")
        return "make_plot" if wants_plot else "answer"

    graph.add_node("draft_sql", lambda s: node_draft_sql(s, llm))
    graph.add_node("run_sql", lambda s: node_run_sql(s, con))
    graph.add_node("refine_sql", lambda s: node_refine_sql(s, llm))
    graph.add_node("decide_viz", lambda s: node_decide_viz(s, llm))
    graph.add_node("make_plot", node_make_plot)
    graph.add_node("answer", lambda s: node_answer(s, llm))

    graph.set_entry_point("draft_sql")
    graph.add_edge("draft_sql", "run_sql")
    graph.add_conditional_edges("run_sql", route_after_run, {
        "refine_sql": "refine_sql",
        "decide_viz": "decide_viz",
        "answer": "answer",
    })
    graph.add_edge("refine_sql", "run_sql")
    graph.add_conditional_edges("decide_viz", route_after_viz, {
        "make_plot": "make_plot",
        "answer": "answer",
    })
    graph.add_edge("make_plot", "answer")
    graph.add_edge("answer", END)
    return graph.compile()
# -----------------------------
# Public function to ask a question
# -----------------------------
def answer_question(con: duckdb.DuckDBPyConnection, llm: ChatOpenAI, schema_text: str, question: str,
                    stream: bool = False, token_callback: Optional[Callable[[str], None]] = None) -> Dict[str, Any]:
    """Run the full graph for one question.

    Returns a plain dict with the user-facing fields only: question, sql,
    answer, plot_path, error. Internal "_"-prefixed state is not exposed.
    """
    app = build_graph(con, llm)
    seed: AgentState = {
        "question": question,
        "schema": schema_text,
        "attempts": 0,
        "sql": None,
        "error": None,
        "result_json": None,
        "result_columns": None,
        "plot_path": None,
        "final_answer": None,
        "_result_df": None,
        "_viz_spec": None,
        "_stream": stream,
        "_token_cb": token_callback,
    }
    finished: AgentState = app.invoke(seed)  # type: ignore
    return {
        "question": question,
        "sql": finished.get("sql"),
        "answer": finished.get("final_answer"),
        "plot_path": finished.get("plot_path"),
        "error": finished.get("error"),
    }
# -----------------------------
# CLI
# -----------------------------
def main():
    """Interactive CLI: parse args, open the DB read-only, introspect the schema,
    then answer questions in a loop until 'exit'/'quit' or EOF."""
    load_dotenv(find_dotenv())  # pick up OPENAI_API_KEY from a .env file if present
    parser = argparse.ArgumentParser(description="DuckDB Q&A ReAct Agent (Streaming)")
    parser.add_argument("--duckdb", required=True, help="Path to DuckDB database file")
    parser.add_argument("--model", default="gpt-4o", help="OpenAI chat model (default: gpt-4o)")
    parser.add_argument("--stream", action="store_true", help="Stream the final answer to stdout")
    parser.add_argument("--schemas", nargs="*", default=["main"], help="Schemas to include (default: main)")
    args = parser.parse_args()
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        print("ERROR: OPENAI_API_KEY not set. Put it in a .env file or export the env var.")
        return
    # Connect DuckDB strictly to the provided file (user DB)
    try:
        # read_only is a second line of defense alongside the SELECT-only gate.
        con = duckdb.connect(database=args.duckdb, read_only=True)
    except Exception as e:
        print(f"Failed to open DuckDB at {args.duckdb}: {e}")
        return
    # Introspect schema (user tables only)
    schema_text = get_schema_summary(con, allowed_schemas=args.schemas)
    # Init LLM; enable streaming capability (used only in final answer node)
    llm = make_llm(model=args.model, temperature=0.0)
    print("DuckDB Q&A Agent (ReAct, Streaming)\n")
    print(f"Connected to: {args.duckdb}")
    print(f"Schemas included: {', '.join(args.schemas)}")
    print("\nSchema snapshot:\n-----------------\n")
    print(schema_text)
    print("\nType your question and press ENTER. Type 'exit' to quit.\n")
    def print_token(t: str):
        # Emit streamed tokens in place: no newline, flushed immediately.
        print(t, end="", flush=True)
    while True:
        try:
            q = input("Q> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session cleanly.
            print("\nExiting.")
            break
        if q.lower() in ("exit", "quit"):
            print("Goodbye.")
            break
        if not q:
            continue
        if args.stream:
            print("\n--- ANSWER (streaming) ---")
            result = answer_question(con, llm, schema_text, q, stream=True, token_callback=print_token)
            print("")  # newline after stream
            if result.get("plot_path"):
                print(f"\nChart saved to: {result['plot_path']}")
            print("\n--- SQL ---")
            print((result.get("sql") or "").strip())
            if result.get("error"):
                print("\nERROR: " + str(result.get("error")))
        else:
            result = answer_question(con, llm, schema_text, q, stream=False, token_callback=None)
            print("\n--- SQL ---")
            print((result.get("sql") or "").strip())
            if result.get("error"):
                print("\n--- RESULT ---")
                print("Sorry, I couldn't resolve a working query after 3 attempts.")
                print("Error: " + str(result.get("error")))
            else:
                print("\n--- ANSWER ---")
                print(result.get("answer") or "")
                if result.get("plot_path"):
                    print(f"\nChart saved to: {result['plot_path']}")
        print("\n")
    con.close()
# Run the CLI only when executed as a script (not when imported as a module).
if __name__ == "__main__":
    main()