Spaces:
Running
Running
| import sqlglot | |
| from sqlglot import expressions as exp | |
| from nl2sql.types import StageResult, StageTrace | |
| class Verifier: | |
| name = "verifier" | |
| def run(self, sql: str, exec_result: StageResult) -> StageResult: | |
| if not exec_result.ok: | |
| return StageResult(ok=False, data=None, | |
| trace=StageTrace(stage=self.name, duration_ms=0, | |
| notes={"reason": "execution_error"}), | |
| error=exec_result.errors) | |
| # Rule 1: check SELECT / GROUP consistency | |
| issues = [] | |
| try: | |
| tree = sqlglot.parse_one(sql) | |
| if isinstance(tree, exp.Select): | |
| group = tree.args.get("group") | |
| aggs = [a for a in tree.find_all(exp.AggFunc)] | |
| if aggs and not group: | |
| issues.append("Aggregation without GROUP BY.") | |
| except Exception as e: | |
| issues.append(f"Parse error during verification: {e}") | |
| if issues: | |
| return StageResult(ok=False, data=None, | |
| trace=StageTrace(stage=self.name, duration_ms=0, | |
| notes={"issues": issues}), | |
| error=issues) | |
| return StageResult(ok=True, data={"verified": True}, | |
| trace=StageTrace(stage=self.name, duration_ms=0)) | |