File size: 8,699 Bytes
5d30bdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""run_sql tool — execute a read-only SELECT against DuckDB.

Use this AFTER calling inspect_schema to confirm table and column names.
Returns rows as a list of dicts, plus truncation metadata. Never raises —
all failures come back as {error, hint} so the agent can self-correct.
"""

from __future__ import annotations

import logging
import os
import re
import time
from pathlib import Path

import duckdb

from agent.constants import DEFAULT_PARQUET_DIR, ENV_PARQUET_DIR
from agent.tools.schemas import RunSqlInput, RunSqlOutput

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# A parquet file's stem is used verbatim as a view name, so it must be a plain
# SQL identifier: a letter or underscore first, then letters, digits or
# underscores. Requiring this prevents injection via filenames.
_SAFE_STEM_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def build_connection(parquet_dir: str | None = None) -> duckdb.DuckDBPyConnection:
    """Create an in-memory DuckDB connection with parquet files registered as views.

    Every ``*.parquet`` file directly inside the directory becomes a view named
    after the file's stem. Files whose stems are not valid SQL identifiers are
    skipped with a warning. The view name is double-quoted and the file path is
    embedded as a string literal with single quotes doubled.

    Args:
        parquet_dir: Directory to scan. When omitted (or empty), the directory
            comes from the environment variable named by ``ENV_PARQUET_DIR``,
            falling back to ``DEFAULT_PARQUET_DIR``.

    Returns:
        An open connection. The caller is responsible for closing it.

    Raises:
        Exception: Whatever DuckDB raises while creating a view. The
            connection is closed before the error propagates, so a failed
            build does not leak it.
    """
    directory = Path(parquet_dir or os.getenv(ENV_PARQUET_DIR, DEFAULT_PARQUET_DIR))
    conn = duckdb.connect()
    try:
        # Sorted so views are registered in a deterministic order.
        for pq_file in sorted(directory.glob("*.parquet")):
            stem = pq_file.stem
            if not _SAFE_STEM_RE.match(stem):
                logger.warning("Skipping parquet file with unsafe stem: %r", stem)
                continue
            # DuckDB DDL doesn't support parameterized queries, so we escape
            # single quotes in the path (path comes from the local filesystem,
            # not user input).
            safe_path = str(pq_file.resolve()).replace("'", "''")
            conn.execute(f"CREATE VIEW \"{stem}\" AS SELECT * FROM read_parquet('{safe_path}')")
            logger.debug("Registered view %r from %s", stem, pq_file)
    except Exception:
        # Do not hand back (or leak) a half-initialised connection.
        conn.close()
        raise
    return conn


def _strip_leading_sql_comments(query: str) -> str:
    """Remove leading SQL comments before validating the first executable token."""
    stripped = query.lstrip()
    while stripped:
        if stripped.startswith("--"):
            newline = stripped.find("\n")
            if newline == -1:
                return ""
            stripped = stripped[newline + 1 :].lstrip()
            continue
        if stripped.startswith("/*"):
            end = stripped.find("*/", 2)
            if end == -1:
                return ""
            stripped = stripped[end + 2 :].lstrip()
            continue
        return stripped
    return stripped


def _is_readonly(query: str) -> bool:
    """Report whether ``query`` looks like a read-only SELECT or WITH statement.

    Leading SQL comments are skipped first. The query is then rejected when
    nothing is left, or when it contains a semicolon anywhere, so a query can
    never carry a second statement. Otherwise the first whitespace-delimited
    token must be ``SELECT`` or ``WITH`` (case-insensitive).

    NOTE(review): only the leading keyword is inspected; whatever follows a
    WITH clause is not examined here.
    """
    body = _strip_leading_sql_comments(query)
    if not body or ";" in body:
        return False
    leading_keyword = body.split(maxsplit=1)[0]
    return leading_keyword.upper() in ("SELECT", "WITH")


def _hint_from_error(error: str) -> str:
    """Return a targeted hint by pattern-matching common DuckDB error messages."""
    # "Referenced column X not found in FROM clause! Candidate bindings: Y"
    col_match = re.search(
        r'referenced column[^"]*"([^"]+)".*?candidate bindings:\s*"([^"]+)"',
        error,
        re.IGNORECASE | re.DOTALL,
    )
    if col_match:
        missing, candidates = col_match.group(1), col_match.group(2)
        return (
            f'Column "{missing}" does not exist in the tables you queried '
            f'(available: "{candidates}"). '
            "If this column belongs to another table, add the appropriate JOIN β€” "
            "call inspect_schema() with no args to list available tables and their join keys."
        )

    # "Referenced table X not found"
    tbl_match = re.search(r'referenced table[^"]*"([^"]+)"', error, re.IGNORECASE)
    if tbl_match:
        return (
            f'Table "{tbl_match.group(1)}" not found. '
            "Call inspect_schema() with no args to list available tables, "
            "then use the exact table name in your FROM clause."
        )

    # "Table X does not have a column named Y"
    col2_match = re.search(r'table[^"]*"([^"]+)"[^"]*column[^"]*"([^"]+)"', error, re.IGNORECASE)
    if col2_match:
        return (
            f'Column "{col2_match.group(2)}" not found in table "{col2_match.group(1)}". '
            "Call inspect_schema(table=<name>) to see the correct column names, "
            "paying attention to the primary_key field."
        )

    # "Values list "c" does not have a column named "name"" β€” DuckDB's label for
    # any aliased subquery/CTE when column resolution fails.
    values_match = re.search(
        r'values list[^"]*"([^"]+)"[^"]*column[^"]*"([^"]+)"', error, re.IGNORECASE
    )
    if values_match:
        alias, col = values_match.group(1), values_match.group(2)
        return (
            f'Column "{col}" does not exist in the result aliased as "{alias}". '
            "Call inspect_schema(table=<name>) to confirm exact column names before writing SQL. "
            "Rewrite the query using only columns confirmed by inspect_schema."
        )

    # "column X must appear in the GROUP BY clause"
    if "must appear in the group by clause" in error.lower():
        col_gb = re.search(r'column[^"]*"([^"]+)"', error, re.IGNORECASE)
        col_name = col_gb.group(1) if col_gb else "unknown"
        return (
            f'Column "{col_name}" appears in SELECT or ORDER BY but is missing from GROUP BY. '
            "Either add it to GROUP BY, or wrap it in an aggregate (e.g. ANY_VALUE(col))."
        )

    # "aggregate function calls cannot be nested"
    if "aggregate function calls cannot be nested" in error.lower():
        return (
            "Nested aggregates (e.g. AVG(SUM(...))) are not allowed. "
            "Use a subquery or CTE: compute the inner aggregate first, then aggregate the result."
        )

    return "Check table and column names with inspect_schema, then rewrite the query."


def _execute(args: RunSqlInput, conn: duckdb.DuckDBPyConnection) -> RunSqlOutput:
    """Run ``args.query`` on ``conn`` and package the outcome as a RunSqlOutput.

    On success the output holds at most ``args.max_rows`` rows, each a dict
    keyed by column name, together with a ``truncated`` flag, the row count
    and the elapsed time in milliseconds (execute plus fetch, rounded to three
    decimals). Any exception is logged and converted into an output with no
    rows and the ``error`` and ``hint`` fields filled in; nothing is raised.
    """
    try:
        started = time.monotonic()
        result = conn.execute(args.query)
        # Ask for one row beyond the limit: getting it back proves the result
        # was cut short, without pulling the whole result set into memory.
        fetched = result.fetchmany(args.max_rows + 1)
        elapsed_ms = (time.monotonic() - started) * 1000.0

        kept = fetched[: args.max_rows]
        # The first element of each description entry is used as the column name.
        names = [column[0] for column in result.description]
        records = [dict(zip(names, values)) for values in kept]

        return RunSqlOutput(
            rows=records,
            truncated=len(fetched) > args.max_rows,
            row_count=len(records),
            execution_ms=round(elapsed_ms, 3),
        )
    except Exception as exc:
        logger.warning("run_sql failed: %s", exc)
        message = str(exc)
        return RunSqlOutput(
            rows=[],
            truncated=False,
            row_count=0,
            execution_ms=0.0,
            error=message,
            hint=_hint_from_error(message),
        )


def run_sql(args: RunSqlInput, conn: duckdb.DuckDBPyConnection | None = None) -> RunSqlOutput:
    """Execute a read-only SELECT (or WITH … SELECT) against DuckDB.

    Call inspect_schema first to confirm table/column names.
    Returns up to max_rows rows; truncated=True signals more existed.
    All errors are returned as {error, hint} — never raised. That includes a
    failure to build the one-shot connection.

    Prefer injecting a shared `conn` from the graph; if omitted, a one-shot
    connection is built from PARQUET_DIR and closed after the query.

    Args:
        args: The query text and the maximum number of rows to return.
        conn: Optional shared connection. It is used as-is and left open.

    Returns:
        A RunSqlOutput with either the rows or the ``error``/``hint`` pair set.
    """
    # Strip trailing semicolons (and surrounding whitespace) — LLMs routinely
    # append one; it is harmless but trips the read-only guard below.
    clean_query = args.query.rstrip().rstrip(";").rstrip()
    if clean_query != args.query:
        args = RunSqlInput(query=clean_query, max_rows=args.max_rows)

    if not _is_readonly(args.query):
        return RunSqlOutput(
            rows=[],
            truncated=False,
            row_count=0,
            execution_ms=0.0,
            error="Only SELECT (or WITH … SELECT) statements are allowed. Semicolons are not permitted.",
            hint="Rewrite as a read-only SELECT. Use inspect_schema to find table/column names.",
        )

    if conn is not None:
        return _execute(args, conn)

    # One-shot path: build, query, close to avoid connection leaks.
    logger.warning("run_sql called without injected connection; prefer passing a shared conn")
    try:
        tmp = build_connection()
    except Exception as exc:
        # Building the connection can fail (e.g. a parquet file DuckDB cannot
        # read). Report it like any other failure instead of raising, so the
        # "never raised" contract holds on this path too.
        logger.warning("run_sql could not build a one-shot connection: %s", exc)
        return RunSqlOutput(
            rows=[],
            truncated=False,
            row_count=0,
            execution_ms=0.0,
            error=str(exc),
            hint=(
                "The parquet data source could not be loaded. This is an environment "
                "problem, not a query problem; rewriting the SQL will not help."
            ),
        )
    try:
        return _execute(args, tmp)
    finally:
        tmp.close()