nl2sql-copilot / nl2sql /verifier.py
Melika Kheirieh
style: format code with ruff
c1bc4eb
raw
history blame
1.45 kB
import sqlglot
from sqlglot import expressions as exp
from nl2sql.types import StageResult, StageTrace
class Verifier:
    """Post-execution sanity checks on a generated SQL query.

    Fails immediately when the execution stage reported a failure; otherwise
    parses the SQL with sqlglot and applies static consistency rules.
    """

    name = "verifier"

    def run(self, sql: str, exec_result: StageResult) -> StageResult:
        """Verify ``sql`` given the result of executing it.

        Args:
            sql: The SQL text that was executed.
            exec_result: Result of the execution stage; only ``ok`` and
                ``error`` are read.

        Returns:
            - A failed ``StageResult`` carrying ``exec_result.error`` when
              execution failed.
            - A failed result whose ``error`` lists the verification issues
              when the SQL cannot be parsed or a rule fires.
            - Otherwise ``ok=True`` with ``data={"verified": True}``.
        """
        if not exec_result.ok:
            return StageResult(
                ok=False,
                data=None,
                trace=StageTrace(
                    stage=self.name, duration_ms=0, notes={"reason": "execution_error"}
                ),
                # StageResult's field is ``error`` (the keyword it is built
                # with throughout this module); ``.errors`` does not exist.
                error=exec_result.error,
            )

        issues = self._collect_issues(sql)
        if issues:
            return StageResult(
                ok=False,
                data=None,
                trace=StageTrace(
                    stage=self.name, duration_ms=0, notes={"issues": issues}
                ),
                error=issues,
            )
        return StageResult(
            ok=True,
            data={"verified": True},
            trace=StageTrace(stage=self.name, duration_ms=0),
        )

    def _collect_issues(self, sql: str) -> list:
        """Parse ``sql`` and return a list of human-readable issues (may be empty)."""
        # Only the parse step is guarded: a failure here is genuinely a parse
        # problem, whereas errors in the rule logic below should surface.
        try:
            tree = sqlglot.parse_one(sql)
        except Exception as e:
            return [f"Parse error during verification: {e}"]
        if tree is None:
            return ["Parse error during verification: no statement was parsed"]

        issues = []
        # Rule 1: SELECT / GROUP BY consistency. Checked for every SELECT in
        # the statement (root, CTEs, subqueries, UNION branches) on its own.
        if any(
            self._mixes_aggregates_without_group(select)
            for select in tree.find_all(exp.Select)
        ):
            issues.append("Aggregation without GROUP BY.")
        return issues

    @staticmethod
    def _mixes_aggregates_without_group(select: exp.Select) -> bool:
        """Return True if ``select`` lacks GROUP BY yet projects both aggregated
        and bare (non-aggregated) column expressions.

        A whole-table aggregate such as ``SELECT COUNT(*) FROM t`` is valid SQL
        and is NOT flagged; ``SELECT country, COUNT(*) FROM t`` is.
        """
        if select.args.get("group"):
            return False

        has_aggregate = False
        has_bare_column = False
        for projection in select.expressions:
            for agg in projection.find_all(exp.AggFunc):
                # Window aggregates (``COUNT(*) OVER (...)``) do not collapse
                # rows, and aggregates inside a nested subquery belong to that
                # subquery; only aggregates owned directly by ``select`` count.
                if agg.find_ancestor(exp.Window, exp.Select) is select:
                    has_aggregate = True
            for ref in projection.find_all(exp.Column, exp.Star):
                # A column/star is "bare" when no aggregate of this SELECT
                # encloses it (e.g. the ``*`` in ``COUNT(*)`` is not bare).
                if ref.find_ancestor(exp.AggFunc, exp.Select) is select:
                    has_bare_column = True
            if has_aggregate and has_bare_column:
                return True
        return False