nl2sql-copilot / nl2sql /verifier.py
Melika Kheirieh
fix(verifier): robust aggregate detection and projection-level semantic check
b72c625
raw
history blame
11.6 kB
from __future__ import annotations
import re
import time
from typing import Any, Iterable, List, Optional
import sqlglot
from sqlglot import expressions as exp
from nl2sql.types import StageResult, StageTrace
def _ms(t0: float) -> int:
return int((time.perf_counter() - t0) * 1000)
class Verifier:
    """Pipeline stage that sanity-checks a SQL string.

    ``verify`` parses the statement with sqlglot, looks for a projection that
    mixes aggregate calls with bare columns while having no GROUP BY, runs an
    optional preview through the supplied adapter, and reports every problem
    it found in a single ``StageResult``.
    """

    name = "verifier"

    # Textual fallback: scan for common aggregate calls
    _AGG_CALL_RE = re.compile(r"\b(count|sum|avg|min|max)\s*\(", re.IGNORECASE)
    _GROUP_BY_RE = re.compile(r"\bgroup\s+by\b", re.IGNORECASE)
    _OVER_RE = re.compile(r"\bover\s*\(", re.IGNORECASE)
    _SELECT_DISTINCT_RE = re.compile(r"\bselect\s+distinct\b", re.IGNORECASE)
    _SELECT_LIST_RE = re.compile(
        r"\bselect\s+(?P<sel>.+?)\s+\bfrom\b", re.IGNORECASE | re.DOTALL
    )

    # Comments and quoted spans, matched in ONE left-to-right pass so that a
    # comment marker inside a string literal (or a quote inside a comment) is
    # consumed by whichever construct starts first.
    _NOISE_RE = re.compile(
        r"/\*.*?\*/"  # block comment
        r"|--[^\n]*"  # line comment (up to end of line)
        r"|'(?:[^']|'')*'"  # single-quoted string, '' is an escaped quote
        r'|"(?:[^"]|"")*"'  # double-quoted string / identifier
        r"|`[^`]*`",  # backtick identifier
        re.DOTALL,
    )

    # Root node type names treated as "sqlglot could not really parse this".
    _UNPARSED_TYPE_NAMES = ("Command", "Unknown")
    # Root node type names accepted as queries without further suspicion.
    _QUERY_TYPE_NAMES = ("Select", "With", "Union", "Intersect", "Except", "Values")
    # Common misspellings of SELECT that may parse as some other statement.
    _MISSPELLED_SELECT = ("selct", "slect", "selet", "seelct")
    # Aggregate classes that may be missing from some sqlglot versions; each
    # one is looked up individually so a missing name never hides the others.
    _OPTIONAL_AGG_NAMES = ("AggFunc", "GroupConcat", "ArrayAgg", "StringAgg")

    # ----------------------- AST helpers (version-friendly) --------------------
    def _walk(
        self, node: "exp.Expression", stop: tuple = ()
    ) -> "Iterable[exp.Expression]":
        """Yield ``node`` and its descendants depth-first, in argument order.

        The traversal is iterative and only relies on each expression's
        ``args`` mapping. Nodes whose type is listed in ``stop`` are yielded
        but their children are skipped; ``node`` itself is always expanded.
        """
        pending = [node]
        while pending:
            cur = pending.pop()
            if not isinstance(cur, exp.Expression):
                continue
            yield cur
            if stop and cur is not node and isinstance(cur, stop):
                continue
            children = []
            for value in (getattr(cur, "args", None) or {}).values():
                if isinstance(value, exp.Expression):
                    children.append(value)
                elif isinstance(value, list):
                    children.extend(
                        item for item in value if isinstance(item, exp.Expression)
                    )
            # Reversed so the left-most child is popped (and yielded) first.
            pending.extend(reversed(children))

    def _first_select(self, tree: "exp.Expression") -> "Optional[exp.Select]":
        """Return the first ``Select`` met in pre-order, or ``None``."""
        for n in self._walk(tree):
            if isinstance(n, exp.Select):
                return n
        return None

    def _has_group_by(self, tree: "exp.Expression") -> bool:
        """True if the first SELECT in ``tree`` carries a GROUP BY clause."""
        sel = self._first_select(tree)
        if sel is None:
            return False
        # Clause nodes are stored in the args mapping, not as attributes.
        return bool(sel.args.get("group"))

    def _is_distinct_projection(self, tree: "exp.Expression") -> bool:
        """True if the first SELECT in ``tree`` is a SELECT DISTINCT."""
        sel = self._first_select(tree)
        if sel is None:
            return False
        # Read the clause from args: ``Select.distinct`` is a builder method,
        # so a plain attribute lookup is always truthy. Looking only at the
        # SELECT's own clause also keeps COUNT(DISTINCT x) from counting.
        return bool(sel.args.get("distinct"))

    def _has_windowed_aggregate(self, tree: "exp.Expression") -> bool:
        """True if any OVER(...) window appears anywhere in ``tree``."""
        # If there is any OVER(...) window, aggregates without GROUP BY can be legitimate
        return any(isinstance(n, exp.Window) for n in self._walk(tree))

    def _agg_types(self) -> tuple:
        """Return the sqlglot classes that are treated as aggregate calls."""
        found = [exp.Count, exp.Sum, exp.Avg, exp.Min, exp.Max]
        for cls_name in self._OPTIONAL_AGG_NAMES:
            cls = getattr(exp, cls_name, None)
            if isinstance(cls, type):
                found.append(cls)
        return tuple(found)

    def _expr_contains_agg(
        self, node: "exp.Expression", same_scope: bool = False
    ) -> bool:
        """True if the subtree under ``node`` contains an aggregate call.

        With ``same_scope=True`` nested subqueries are not entered, so an
        aggregate that belongs to a scalar subquery is not attributed to the
        enclosing projection.
        """
        stop = (exp.Subquery, exp.Select) if same_scope else ()
        agg_types = self._agg_types()
        return any(isinstance(n, agg_types) for n in self._walk(node, stop=stop))

    def _has_nonagg_column(self, node: "exp.Expression") -> bool:
        """True if ``node`` references a column that is NOT inside an aggregate.

        Columns that live inside a nested subquery are ignored as well: they
        belong to that subquery's scope.
        """
        agg_types = self._agg_types()
        if isinstance(node, agg_types):
            # The whole expression is an aggregate call; nothing is bare.
            return False
        stop = agg_types + (exp.Subquery, exp.Select)
        return any(isinstance(n, exp.Column) for n in self._walk(node, stop=stop))

    def _mixes_agg_and_bare_columns(self, tree: "exp.Expression") -> bool:
        """Apply the core rule to the first SELECT of ``tree``.

        Returns True when its projection has both an aggregate call and a
        column outside any aggregate, and neither GROUP BY, a window, nor
        SELECT DISTINCT is present.
        """
        sel = self._first_select(tree)
        if sel is None:
            return False
        if (
            self._has_group_by(tree)
            or self._has_windowed_aggregate(tree)
            or self._is_distinct_projection(tree)
        ):
            return False
        items = list(getattr(sel, "expressions", None) or [])
        if not any(self._expr_contains_agg(it, same_scope=True) for it in items):
            return False
        return any(self._has_nonagg_column(it) for it in items)

    # ----------------------- Textual fallback helpers -------------------------
    def _clean_sql_for_fn_scan(self, sql: str) -> str:
        """Blank out comments, strings and quoted identifiers, then squeeze spaces.

        Used before regex scans so keywords that only occur inside comments or
        quoted text are not mistaken for SQL.
        """
        without_noise = self._NOISE_RE.sub(" ", sql)
        return re.sub(r"\s+", " ", without_noise).strip()

    def _textual_agg_without_group(self, sql: str) -> bool:
        """Regex-only approximation of the aggregate/GROUP BY rule.

        Returns True when the cleaned text calls count/sum/avg/min/max, has no
        GROUP BY, OVER( or SELECT DISTINCT, and the first select list contains
        a comma. This is a coarse heuristic, meant only for the case where
        the AST check itself could not run.
        """
        cleaned = self._clean_sql_for_fn_scan(sql)
        if not self._AGG_CALL_RE.search(cleaned):
            return False
        if (
            self._GROUP_BY_RE.search(cleaned)
            or self._OVER_RE.search(cleaned)
            or self._SELECT_DISTINCT_RE.search(cleaned)
        ):
            return False
        m_sel = self._SELECT_LIST_RE.search(cleaned)
        # a comma suggests mixing aggregate and non-aggregate in the projection
        return m_sel is not None and "," in m_sel.group("sel")

    # ----------------------- Adapter result helpers ---------------------------
    def _extract_ok(self, exec_result: Any) -> Optional[bool]:
        """Return the boolean ``"ok"`` entry of a dict result, else ``None``."""
        if not isinstance(exec_result, dict):
            return None
        flag = exec_result.get("ok")
        return flag if isinstance(flag, bool) else None

    def _extract_error(self, exec_result: Any) -> Optional[str]:
        """Return the first non-empty error/message/detail entry as a string."""
        if not isinstance(exec_result, dict):
            return None
        for key in ("error", "message", "detail"):
            if exec_result.get(key):
                return str(exec_result[key])
        return None

    # ----------------------------- Parsing ------------------------------------
    def _parse(self, sql: str) -> "Optional[exp.Expression]":
        """Parse ``sql`` and return its tree, or ``None`` if it looks unparsed.

        ``None`` is returned when sqlglot raises, returns nothing, yields a
        Command/Unknown root, exposes a non-empty ``errors`` attribute, or
        yields a non-query root for text starting with a misspelled SELECT.
        """
        try:
            # read=None means sqlglot's default dialect; nothing is autodetected.
            tree = sqlglot.parse_one(sql, read=None)
            if tree is None:
                return None
            root_kind = type(tree).__name__
            # Broken SQL is often swallowed into a Command/Unknown root.
            if root_kind in self._UNPARSED_TYPE_NAMES:
                return None
            # Some sqlglot versions may attach parse errors to the tree.
            if getattr(tree, "errors", None):
                return None
            # A typo in SELECT can make the text parse as another statement.
            if root_kind not in self._QUERY_TYPE_NAMES and (
                sql.lower().strip().startswith(self._MISSPELLED_SELECT)
            ):
                return None
            return tree
        except Exception:
            # Any parser failure is reported uniformly as a parse error.
            return None

    # ----------------------------- Main entry ---------------------------------
    def verify(self, sql: str, *, adapter: Any) -> "StageResult":
        """Check ``sql`` and report all detected issues.

        Args:
            sql: The SQL text to verify.
            adapter: Object exposing ``execute_preview(sql)``; when falsy the
                preview step is treated as successful. A dict result with
                ``"ok": False`` is recorded as an execution error.

        Returns:
            ``StageResult(ok=False, error=["parse_error"], ...)`` when parsing
            fails; ``StageResult(ok=False, error=issues, ...)`` when any check
            reported an issue; otherwise ``StageResult(ok=True,
            data={"verified": True}, ...)``. Every result carries a
            ``StageTrace`` with the elapsed milliseconds.
        """
        t0 = time.perf_counter()
        issues: List[str] = []

        # 1) Parse
        tree = self._parse(sql)
        if tree is None:
            return StageResult(
                ok=False,
                error=["parse_error"],
                trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
            )

        # 2) Semantic check (AST-first)
        ast_checked = True
        try:
            if self._mixes_agg_and_bare_columns(tree):
                issues.append("aggregation_without_group_by")
        except Exception as e:
            # Don't crash the verifier; surface the problem and let fallback run
            ast_checked = False
            issues.append(f"semantic_check_error:{e!s}")

        # 3) Fallback textual scan — only when the AST check could not run.
        #    Its comma heuristic would otherwise reject valid queries such as
        #    "SELECT MIN(a), MAX(a) FROM t" that the AST check already cleared.
        if not ast_checked:
            try:
                if self._textual_agg_without_group(sql):
                    issues.append("aggregation_without_group_by")
            except Exception:
                # Best effort: the failure is already reported through the
                # semantic_check_error entry added above.
                pass

        # 4) Optional: cheap preview execution (adapter may be a stub in tests)
        try:
            exec_result = adapter.execute_preview(sql) if adapter else {"ok": True}
            if self._extract_ok(exec_result) is False:
                err = self._extract_error(exec_result)
                issues.append(f"exec_error:{err}" if err else "exec_error")
        except Exception as e:
            issues.append(f"exec_exception:{e!s}")

        # 5) Final decision — AFTER all checks
        if issues:
            return StageResult(
                ok=False,
                error=issues,
                trace=StageTrace(
                    stage=self.name, duration_ms=_ms(t0), notes={"issues": issues}
                ),
            )
        return StageResult(
            ok=True,
            data={"verified": True},
            trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
        )