File size: 8,699 Bytes
"""run_sql tool — execute a read-only SELECT against DuckDB.
Use this AFTER calling inspect_schema to confirm table and column names.
Returns rows as a list of dicts, plus truncation metadata. Never raises —
all failures come back as {error, hint} so the agent can self-correct.
"""
from __future__ import annotations
import logging
import os
import re
import time
from pathlib import Path
import duckdb
from agent.constants import DEFAULT_PARQUET_DIR, ENV_PARQUET_DIR
from agent.tools.schemas import RunSqlInput, RunSqlOutput
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# A parquet file's stem is used verbatim as a view name, so it must be a plain
# SQL identifier (a letter or underscore, then letters, digits or underscores).
# Rejecting anything else prevents injection via filenames.
_SAFE_STEM_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def build_connection(parquet_dir: str | None = None) -> duckdb.DuckDBPyConnection:
    """Create an in-memory DuckDB connection with parquet files registered as views.

    Every ``*.parquet`` file directly inside the directory becomes a view named
    after the file's stem. Files whose stems are not valid SQL identifiers are
    skipped with a warning. The view name is double-quoted and the file path is
    embedded as a single-quote-escaped string literal.

    Args:
        parquet_dir: Directory to scan. When omitted (or empty), the directory
            named by the ``ENV_PARQUET_DIR`` environment variable is used,
            falling back to ``DEFAULT_PARQUET_DIR``.

    Returns:
        The open connection; the caller is responsible for closing it.

    Raises:
        Exception: Whatever DuckDB raises while registering a view. The
            half-built connection is closed before the error propagates.
    """
    directory = Path(parquet_dir or os.getenv(ENV_PARQUET_DIR, DEFAULT_PARQUET_DIR))
    conn = duckdb.connect()
    try:
        # Sorted so views are registered in a deterministic order.
        for pq_file in sorted(directory.glob("*.parquet")):
            stem = pq_file.stem
            if not _SAFE_STEM_RE.match(stem):
                logger.warning("Skipping parquet file with unsafe stem: %r", stem)
                continue
            # DuckDB DDL doesn't support parameterized queries, so we escape
            # single quotes in the path (path comes from the local filesystem,
            # not user input).
            safe_path = str(pq_file.resolve()).replace("'", "''")
            conn.execute(f"CREATE VIEW \"{stem}\" AS SELECT * FROM read_parquet('{safe_path}')")
            logger.debug("Registered view %r from %s", stem, pq_file)
    except Exception:
        # Don't leak the connection when a view cannot be registered.
        conn.close()
        raise
    return conn
def _strip_leading_sql_comments(query: str) -> str:
"""Remove leading SQL comments before validating the first executable token."""
stripped = query.lstrip()
while stripped:
if stripped.startswith("--"):
newline = stripped.find("\n")
if newline == -1:
return ""
stripped = stripped[newline + 1 :].lstrip()
continue
if stripped.startswith("/*"):
end = stripped.find("*/", 2)
if end == -1:
return ""
stripped = stripped[end + 2 :].lstrip()
continue
return stripped
return stripped
def _is_readonly(query: str) -> bool:
    """Return True only when the query starts with SELECT or WITH (a CTE).

    Leading whitespace and SQL comments are ignored before the first keyword
    is inspected. The keyword is matched case-insensitively and only needs to
    end at a word boundary, so ``SELECT*FROM t`` and ``SELECT(1)`` are
    accepted just like ``SELECT * FROM t``.

    Semicolons are rejected outright, wherever they appear — LLM-generated
    queries never need them, and refusing them means a second statement can
    never be smuggled in behind the SELECT.

    Args:
        query: The SQL text to vet.

    Returns:
        False for empty/comment-only input, for any query containing ``;``,
        and for any query whose first keyword is not SELECT or WITH.
    """
    # NOTE(review): a leading WITH does not by itself prove the main statement
    # is a SELECT (some SQL dialects allow WITH … INSERT/UPDATE/DELETE) —
    # confirm against DuckDB's grammar if this guard is the only protection.
    stripped = _strip_leading_sql_comments(query)
    if not stripped or ";" in stripped:
        return False
    # \b instead of a whitespace split: the keyword may be followed directly
    # by "*", "(", a quote or a comment rather than a space.
    return re.match(r"(?:SELECT|WITH)\b", stripped, re.IGNORECASE) is not None
def _hint_from_error(error: str) -> str:
"""Return a targeted hint by pattern-matching common DuckDB error messages."""
# "Referenced column X not found in FROM clause! Candidate bindings: Y"
col_match = re.search(
r'referenced column[^"]*"([^"]+)".*?candidate bindings:\s*"([^"]+)"',
error,
re.IGNORECASE | re.DOTALL,
)
if col_match:
missing, candidates = col_match.group(1), col_match.group(2)
return (
f'Column "{missing}" does not exist in the tables you queried '
f'(available: "{candidates}"). '
"If this column belongs to another table, add the appropriate JOIN β "
"call inspect_schema() with no args to list available tables and their join keys."
)
# "Referenced table X not found"
tbl_match = re.search(r'referenced table[^"]*"([^"]+)"', error, re.IGNORECASE)
if tbl_match:
return (
f'Table "{tbl_match.group(1)}" not found. '
"Call inspect_schema() with no args to list available tables, "
"then use the exact table name in your FROM clause."
)
# "Table X does not have a column named Y"
col2_match = re.search(r'table[^"]*"([^"]+)"[^"]*column[^"]*"([^"]+)"', error, re.IGNORECASE)
if col2_match:
return (
f'Column "{col2_match.group(2)}" not found in table "{col2_match.group(1)}". '
"Call inspect_schema(table=<name>) to see the correct column names, "
"paying attention to the primary_key field."
)
# "Values list "c" does not have a column named "name"" β DuckDB's label for
# any aliased subquery/CTE when column resolution fails.
values_match = re.search(
r'values list[^"]*"([^"]+)"[^"]*column[^"]*"([^"]+)"', error, re.IGNORECASE
)
if values_match:
alias, col = values_match.group(1), values_match.group(2)
return (
f'Column "{col}" does not exist in the result aliased as "{alias}". '
"Call inspect_schema(table=<name>) to confirm exact column names before writing SQL. "
"Rewrite the query using only columns confirmed by inspect_schema."
)
# "column X must appear in the GROUP BY clause"
if "must appear in the group by clause" in error.lower():
col_gb = re.search(r'column[^"]*"([^"]+)"', error, re.IGNORECASE)
col_name = col_gb.group(1) if col_gb else "unknown"
return (
f'Column "{col_name}" appears in SELECT or ORDER BY but is missing from GROUP BY. '
"Either add it to GROUP BY, or wrap it in an aggregate (e.g. ANY_VALUE(col))."
)
# "aggregate function calls cannot be nested"
if "aggregate function calls cannot be nested" in error.lower():
return (
"Nested aggregates (e.g. AVG(SUM(...))) are not allowed. "
"Use a subquery or CTE: compute the inner aggregate first, then aggregate the result."
)
return "Check table and column names with inspect_schema, then rewrite the query."
def _execute(args: RunSqlInput, conn: duckdb.DuckDBPyConnection) -> RunSqlOutput:
    """Run ``args.query`` on ``conn`` and wrap the outcome in a RunSqlOutput.

    At most ``args.max_rows`` rows are returned, each as a dict keyed by the
    result's column labels (if two columns share a label, the later one wins).
    ``truncated`` is True when at least one further row was available. Any
    exception is logged and converted into an output carrying ``error`` and
    ``hint`` instead of being raised.
    """
    try:
        started = time.monotonic()
        result = conn.execute(args.query)
        limit = args.max_rows
        # Ask for one row beyond the limit: its presence alone tells us the
        # result was cut short, without pulling the whole result set.
        fetched = result.fetchmany(limit + 1)
        duration_ms = (time.monotonic() - started) * 1000.0

        has_more = len(fetched) > limit
        kept = fetched[:limit]
        labels = [entry[0] for entry in result.description]
        records = [dict(zip(labels, values)) for values in kept]
        return RunSqlOutput(
            rows=records,
            truncated=has_more,
            row_count=len(records),
            execution_ms=round(duration_ms, 3),
        )
    except Exception as exc:
        logger.warning("run_sql failed: %s", exc)
        suggestion = _hint_from_error(str(exc))
        return RunSqlOutput(
            rows=[],
            truncated=False,
            row_count=0,
            execution_ms=0.0,
            error=str(exc),
            hint=suggestion,
        )
def run_sql(args: RunSqlInput, conn: duckdb.DuckDBPyConnection | None = None) -> RunSqlOutput:
    """Execute a read-only SELECT (or WITH … SELECT) against DuckDB.

    Call inspect_schema first to confirm table/column names.
    Returns up to max_rows rows; truncated=True signals more existed.
    All errors are returned as {error, hint} — never raised. That includes
    a failure to build the one-shot connection.

    Prefer injecting a shared `conn` from the graph; if omitted, a one-shot
    connection is built from PARQUET_DIR and closed after the query.

    Args:
        args: The query text and the row limit.
        conn: An open DuckDB connection to reuse. It is left open.

    Returns:
        The query result, or an output with empty rows plus ``error`` and
        ``hint`` when the query is rejected or fails.
    """
    # Strip trailing semicolons — LLMs routinely append one; it is harmless
    # but trips the read-only guard below.
    clean_query = args.query.rstrip().rstrip(";").rstrip()
    if clean_query != args.query:
        args = RunSqlInput(query=clean_query, max_rows=args.max_rows)

    if not _is_readonly(args.query):
        return RunSqlOutput(
            rows=[],
            truncated=False,
            row_count=0,
            execution_ms=0.0,
            error="Only SELECT (or WITH … SELECT) statements are allowed. Semicolons are not permitted.",
            hint="Rewrite as a read-only SELECT. Use inspect_schema to find table/column names.",
        )

    if conn is not None:
        return _execute(args, conn)

    # One-shot path: build, query, close to avoid connection leaks.
    logger.warning("run_sql called without injected connection; prefer passing a shared conn")
    try:
        tmp = build_connection()
    except Exception as exc:
        # Building the connection touches the filesystem and can fail; keep
        # the "never raises" contract by reporting it like any other error.
        logger.warning("run_sql could not build a connection: %s", exc)
        return RunSqlOutput(
            rows=[],
            truncated=False,
            row_count=0,
            execution_ms=0.0,
            error=str(exc),
            hint="Check that the parquet directory exists and its files are readable, then retry.",
        )
    try:
        return _execute(args, tmp)
    finally:
        tmp.close()
|