File size: 1,381 Bytes
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import sqlglot
from sqlglot import expressions as exp
from nl2sql.types import StageResult, StageTrace

class Verifier:
    """Post-execution sanity checks on generated SQL.

    Runs after the execution stage. If execution failed, that failure is
    propagated. Otherwise the SQL is re-parsed with sqlglot and checked
    against static rules; any rule violation (or a failure while parsing /
    checking) produces a failed ``StageResult`` listing the issues.
    """

    name = "verifier"

    def run(self, sql: str, exec_result: StageResult) -> StageResult:
        """Verify *sql* given the result of executing it.

        Args:
            sql: The SQL text that was executed.
            exec_result: Result of the execution stage.

        Returns:
            - ``ok=False`` with ``notes={"reason": "execution_error"}`` and the
              execution stage's ``error`` when ``exec_result.ok`` is false;
            - ``ok=False`` with the issue messages in both ``notes["issues"]``
              and ``error`` when any rule fails;
            - otherwise ``ok=True`` with ``data={"verified": True}``.
        """
        if not exec_result.ok:
            return StageResult(
                ok=False,
                data=None,
                trace=StageTrace(stage=self.name, duration_ms=0,
                                 notes={"reason": "execution_error"}),
                # The field is ``error`` (the keyword StageResult is built
                # with throughout this file), not ``errors``.
                error=exec_result.error,
            )

        issues = []
        try:
            tree = sqlglot.parse_one(sql)
            # Rule 1: SELECT / GROUP BY consistency. Every SELECT scope is
            # checked independently (top level, subqueries, CTEs, UNION
            # branches).
            for select in tree.find_all(exp.Select):
                issues.extend(self._aggregation_issues(select))
        except Exception as e:
            # Best-effort: report the problem as an issue rather than crash
            # the pipeline.
            issues.append(f"Parse error during verification: {e}")

        if issues:
            return StageResult(
                ok=False,
                data=None,
                trace=StageTrace(stage=self.name, duration_ms=0,
                                 notes={"issues": issues}),
                error=issues,
            )
        return StageResult(ok=True, data={"verified": True},
                           trace=StageTrace(stage=self.name, duration_ms=0))

    @staticmethod
    def _aggregation_issues(select: exp.Select) -> list:
        """Return issue messages for aggregate misuse within one SELECT scope.

        A SELECT without GROUP BY is only suspicious when it mixes aggregates
        with bare (non-aggregated) columns in its projection list, e.g.
        ``SELECT name, COUNT(*) FROM t``. A pure aggregate such as
        ``SELECT COUNT(*) FROM t`` is valid and is not flagged.
        """
        if select.args.get("group"):
            return []

        # Aggregates that belong to this SELECT: their nearest enclosing
        # SELECT is this one (not a nested subquery), and they are not window
        # functions (``SUM(x) OVER (...)`` does not collapse rows).
        aggregates = [
            agg for agg in select.find_all(exp.AggFunc)
            if agg.find_ancestor(exp.Select) is select
            and not isinstance(agg.parent, exp.Window)
        ]
        if not aggregates:
            return []

        # A projected column is "bare" when no aggregate sits between it and
        # this SELECT. Columns inside a nested subquery hit that inner SELECT
        # first and are skipped.
        bare_columns = [
            column.sql()
            for projection in select.expressions
            for column in projection.find_all(exp.Column)
            if column.find_ancestor(exp.AggFunc, exp.Select) is select
        ]
        if not bare_columns:
            return []

        # De-duplicate while keeping first-seen order for a stable message.
        names = ", ".join(dict.fromkeys(bare_columns))
        return [
            "Aggregation without GROUP BY. "
            f"Non-aggregated column(s) selected: {names}."
        ]