# why-agent — agent/tools/run_sql.py
# Source: Hugging Face Spaces deployment (commit 5d30bdc) by MapoTofu9.
"""run_sql tool — execute a read-only SELECT against DuckDB.
Use this AFTER calling inspect_schema to confirm table and column names.
Returns rows as a list of dicts, plus truncation metadata. Never raises —
all failures come back as {error, hint} so the agent can self-correct.
"""
from __future__ import annotations
import logging
import os
import re
import time
from pathlib import Path
import duckdb
from agent.constants import DEFAULT_PARQUET_DIR, ENV_PARQUET_DIR
from agent.tools.schemas import RunSqlInput, RunSqlOutput
logger = logging.getLogger(__name__)
# Stem must be a valid SQL identifier — prevents injection via filenames.
_SAFE_STEM_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def build_connection(parquet_dir: str | None = None) -> duckdb.DuckDBPyConnection:
    """Create an in-memory DuckDB connection with parquet files registered as views.

    Every ``*.parquet`` file directly inside the directory becomes a view named
    after its file stem (``orders.parquet`` -> ``"orders"``). Files are
    registered in sorted path order.

    Args:
        parquet_dir: Directory to scan. When ``None`` or empty, the directory is
            read from the ``ENV_PARQUET_DIR`` environment variable, falling back
            to ``DEFAULT_PARQUET_DIR``.

    Returns:
        A new connection owned by the caller, who is responsible for closing it.

    Raises:
        Any error DuckDB raises while creating a view (presumably e.g. an
        unreadable parquet file — TODO confirm which errors surface at view
        creation). The half-built connection is closed before re-raising so it
        does not leak.

    Files whose stems are not valid SQL identifiers are skipped with a warning,
    and accepted stems are double-quoted. DuckDB DDL cannot take bound
    parameters, so the file path is embedded as a single-quoted SQL literal
    with any embedded single quotes doubled.
    """
    directory = Path(parquet_dir or os.getenv(ENV_PARQUET_DIR, DEFAULT_PARQUET_DIR))
    if not directory.is_dir():
        # glob() on a missing directory silently yields nothing; make that visible.
        logger.warning(
            "Parquet directory %s is not a directory; no views will be registered",
            directory,
        )
    conn = duckdb.connect()
    try:
        for pq_file in sorted(directory.glob("*.parquet")):
            stem = pq_file.stem
            # fullmatch rejects a trailing newline even if the pattern ends in "$".
            if not _SAFE_STEM_RE.fullmatch(stem):
                logger.warning("Skipping parquet file with unsafe stem: %r", stem)
                continue
            # The path comes from the local filesystem, not user input; doubling
            # single quotes keeps it a well-formed SQL string literal.
            safe_path = str(pq_file.resolve()).replace("'", "''")
            conn.execute(f"CREATE VIEW \"{stem}\" AS SELECT * FROM read_parquet('{safe_path}')")
            logger.debug("Registered view %r from %s", stem, pq_file)
    except BaseException:
        # Don't leak the in-memory connection if any view fails to register.
        conn.close()
        raise
    return conn
def _strip_leading_sql_comments(query: str) -> str:
"""Remove leading SQL comments before validating the first executable token."""
stripped = query.lstrip()
while stripped:
if stripped.startswith("--"):
newline = stripped.find("\n")
if newline == -1:
return ""
stripped = stripped[newline + 1 :].lstrip()
continue
if stripped.startswith("/*"):
end = stripped.find("*/", 2)
if end == -1:
return ""
stripped = stripped[end + 2 :].lstrip()
continue
return stripped
return stripped
def _is_readonly(query: str) -> bool:
    """Return True only for a bare SELECT or a WITH … SELECT (CTE).

    Leading ``--`` / ``/* */`` comments are ignored; the first word after them
    must be ``SELECT`` or ``WITH`` (case-insensitive).

    Any semicolon in the remaining text is rejected outright — even one inside
    a string literal or a trailing comment. LLM-generated queries never need
    them, and DuckDB's Python ``execute()`` accepts multi-statement strings
    (running the earlier statements and returning only the last result), so
    allowing them would let a SELECT smuggle in further statements.

    NOTE(review): this is a textual check. A leading WITH only proves the query
    starts with a CTE; the statement after the CTE is not parsed here.
    """
    stripped = _strip_leading_sql_comments(query)
    if not stripped or ";" in stripped:
        return False
    # Take the leading run of word characters so "SELECT*FROM t" and
    # "SELECT(1)" are recognised; a whitespace split would see "SELECT*FROM".
    first_word = re.match(r"\w+", stripped)
    return first_word is not None and first_word.group(0).upper() in {"SELECT", "WITH"}
def _hint_from_error(error: str) -> str:
"""Return a targeted hint by pattern-matching common DuckDB error messages."""
# "Referenced column X not found in FROM clause! Candidate bindings: Y"
col_match = re.search(
r'referenced column[^"]*"([^"]+)".*?candidate bindings:\s*"([^"]+)"',
error,
re.IGNORECASE | re.DOTALL,
)
if col_match:
missing, candidates = col_match.group(1), col_match.group(2)
return (
f'Column "{missing}" does not exist in the tables you queried '
f'(available: "{candidates}"). '
"If this column belongs to another table, add the appropriate JOIN — "
"call inspect_schema() with no args to list available tables and their join keys."
)
# "Referenced table X not found"
tbl_match = re.search(r'referenced table[^"]*"([^"]+)"', error, re.IGNORECASE)
if tbl_match:
return (
f'Table "{tbl_match.group(1)}" not found. '
"Call inspect_schema() with no args to list available tables, "
"then use the exact table name in your FROM clause."
)
# "Table X does not have a column named Y"
col2_match = re.search(r'table[^"]*"([^"]+)"[^"]*column[^"]*"([^"]+)"', error, re.IGNORECASE)
if col2_match:
return (
f'Column "{col2_match.group(2)}" not found in table "{col2_match.group(1)}". '
"Call inspect_schema(table=<name>) to see the correct column names, "
"paying attention to the primary_key field."
)
# "Values list "c" does not have a column named "name"" — DuckDB's label for
# any aliased subquery/CTE when column resolution fails.
values_match = re.search(
r'values list[^"]*"([^"]+)"[^"]*column[^"]*"([^"]+)"', error, re.IGNORECASE
)
if values_match:
alias, col = values_match.group(1), values_match.group(2)
return (
f'Column "{col}" does not exist in the result aliased as "{alias}". '
"Call inspect_schema(table=<name>) to confirm exact column names before writing SQL. "
"Rewrite the query using only columns confirmed by inspect_schema."
)
# "column X must appear in the GROUP BY clause"
if "must appear in the group by clause" in error.lower():
col_gb = re.search(r'column[^"]*"([^"]+)"', error, re.IGNORECASE)
col_name = col_gb.group(1) if col_gb else "unknown"
return (
f'Column "{col_name}" appears in SELECT or ORDER BY but is missing from GROUP BY. '
"Either add it to GROUP BY, or wrap it in an aggregate (e.g. ANY_VALUE(col))."
)
# "aggregate function calls cannot be nested"
if "aggregate function calls cannot be nested" in error.lower():
return (
"Nested aggregates (e.g. AVG(SUM(...))) are not allowed. "
"Use a subquery or CTE: compute the inner aggregate first, then aggregate the result."
)
return "Check table and column names with inspect_schema, then rewrite the query."
def _execute(args: RunSqlInput, conn: duckdb.DuckDBPyConnection) -> RunSqlOutput:
    """Run an already-validated query on ``conn`` and package the result.

    At most ``args.max_rows`` rows are returned. One extra row is requested so
    ``truncated`` can be reported without fetching the whole result set. Any
    exception raised while executing, fetching, or building the output is
    logged and turned into an error ``RunSqlOutput`` with a hint, never raised.
    """
    try:
        limit = args.max_rows
        t0 = time.monotonic()
        result = conn.execute(args.query)
        # One row past the limit tells us whether anything was cut off.
        fetched = result.fetchmany(limit + 1)
        duration_ms = (time.monotonic() - t0) * 1000.0
        has_more = len(fetched) > limit
        kept = fetched[:limit]
        names = [column_info[0] for column_info in result.description]
        records = [{name: value for name, value in zip(names, row)} for row in kept]
        return RunSqlOutput(
            rows=records,
            truncated=has_more,
            row_count=len(records),
            execution_ms=round(duration_ms, 3),
        )
    except Exception as exc:
        logger.warning("run_sql failed: %s", exc)
        message = str(exc)
        return RunSqlOutput(
            rows=[],
            truncated=False,
            row_count=0,
            execution_ms=0.0,
            error=message,
            hint=_hint_from_error(message),
        )
def run_sql(args: RunSqlInput, conn: duckdb.DuckDBPyConnection | None = None) -> RunSqlOutput:
    """Execute a read-only SELECT (or WITH … SELECT) against DuckDB.

    Call inspect_schema first to confirm table/column names.
    Returns up to max_rows rows; truncated=True signals more existed.
    All errors — a rejected query, a query failure, or failure to open the
    data source — are returned as {error, hint}; nothing is raised.

    Prefer injecting a shared `conn` from the graph; if omitted, a one-shot
    connection is built from PARQUET_DIR and closed after the query.
    """
    # Strip trailing semicolon(s) — LLMs routinely append one; it is harmless
    # but would trip the read-only guard below.
    clean_query = args.query.rstrip().rstrip(";").rstrip()
    # Guard before rebuilding `args`: a query that strips down to nothing is
    # reported as an error output here rather than reaching RunSqlInput, which
    # may validate its fields (NOTE(review): confirm the schema's constraints).
    if not _is_readonly(clean_query):
        return RunSqlOutput(
            rows=[],
            truncated=False,
            row_count=0,
            execution_ms=0.0,
            error="Only SELECT (or WITH … SELECT) statements are allowed. Semicolons are not permitted.",
            hint="Rewrite as a read-only SELECT. Use inspect_schema to find table/column names.",
        )
    if clean_query != args.query:
        args = RunSqlInput(query=clean_query, max_rows=args.max_rows)
    if conn is not None:
        return _execute(args, conn)
    # One-shot path: build, query, close to avoid connection leaks.
    logger.warning("run_sql called without injected connection; prefer passing a shared conn")
    try:
        tmp = build_connection()
    except Exception as exc:
        # Honour the "never raises" contract even when the data source is broken.
        logger.warning("run_sql could not build a DuckDB connection: %s", exc)
        return RunSqlOutput(
            rows=[],
            truncated=False,
            row_count=0,
            execution_ms=0.0,
            error=f"Could not load the parquet data source: {exc}",
            hint="This is a data-source problem, not a query problem; report the error instead of rewriting the query.",
        )
    try:
        return _execute(args, tmp)
    finally:
        tmp.close()