File size: 3,967 Bytes
370553a
6a94b42
 
570f7bd
 
6a94b42
570f7bd
 
c1bc4eb
570f7bd
 
 
6a94b42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370553a
 
6a94b42
 
 
 
 
 
 
 
c1bc4eb
6a94b42
570f7bd
6a94b42
 
 
 
370553a
6a94b42
 
 
370553a
 
6a94b42
570f7bd
 
 
6a94b42
 
 
 
 
370553a
6a94b42
 
 
 
 
 
 
 
 
 
570f7bd
370553a
570f7bd
6a94b42
 
c1bc4eb
6a94b42
370553a
6a94b42
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import time
from typing import Any, Iterable

import sqlglot
from sqlglot import expressions as exp

from nl2sql.types import StageResult, StageTrace


class Verifier:
    name = "verifier"

    # ----------------- helpers -----------------
    @staticmethod
    def _extract_ok(exec_result: Any) -> bool | None:
        """Normalize exec_result.ok across dict or object."""
        if exec_result is None:
            return None
        if isinstance(exec_result, dict):
            return bool(exec_result.get("ok")) if "ok" in exec_result else None
        if hasattr(exec_result, "ok"):
            try:
                return bool(getattr(exec_result, "ok"))
            except Exception:
                return None
        return None

    @staticmethod
    def _extract_errors(exec_result: Any) -> list[str] | None:
        """Pull ['...'] from exec_result['error'] or exec_result.error."""
        val = None
        if isinstance(exec_result, dict):
            val = exec_result.get("error")
        elif hasattr(exec_result, "error"):
            val = getattr(exec_result, "error")

        if val is None:
            return None
        if isinstance(val, str):
            return [val]
        if isinstance(val, Iterable):
            # normalize to list[str]
            return [str(x) for x in val]
        return [str(val)]

    @staticmethod
    def _has_aggregation(tree: exp.Expression) -> bool:
        for node in tree.walk():
            if getattr(node, "is_aggregate", False):
                return True
            if isinstance(node, (exp.Count, exp.Sum, exp.Avg, exp.Min, exp.Max)):
                return True
        return False

    @staticmethod
    def _has_group_by(select: exp.Select) -> bool:
        return bool(select.args.get("group"))

    # ------------------- main -------------------
    def run(self, *, sql: str, exec_result: Any) -> StageResult:
        t0 = time.perf_counter()

        # 1) validate / normalize executor result
        ok_flag = self._extract_ok(exec_result)
        if ok_flag is False:
            errs = self._extract_errors(exec_result) or ["execution_error"]
            trace_err = StageTrace(
                stage=self.name,
                duration_ms=(time.perf_counter() - t0) * 1000,
                notes={"reason": "execution_error"},
            )
            return StageResult(ok=False, error=errs, trace=trace_err)

        if exec_result is None:
            trace_inv = StageTrace(
                stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
            )
            return StageResult(
                ok=False,
                error=["invalid or missing exec_result"],
                trace=trace_inv,
            )

        # 2) structural verification
        try:
            tree = sqlglot.parse_one(sql)
        except Exception as e:
            # parsing failed → accept with a note
            trace_skip = StageTrace(
                stage=self.name,
                duration_ms=(time.perf_counter() - t0) * 1000,
                notes={"note": f"Skipped parse: {e}"},
            )
            return StageResult(ok=True, data={"verified": True}, trace=trace_skip)

        issues: list[str] = []

        # Detect ANY aggregation without GROUP BY for SELECT statements
        if isinstance(tree, exp.Select):
            has_agg = self._has_aggregation(tree)
            has_group = self._has_group_by(tree)
            if has_agg and not has_group:
                issues.append("Aggregation without GROUP BY")

        dur = (time.perf_counter() - t0) * 1000
        if issues:
            trace_bad = StageTrace(
                stage=self.name, duration_ms=dur, notes={"issues": issues}
            )
            return StageResult(ok=False, error=issues, trace=trace_bad)

        # 3) success
        trace_ok = StageTrace(stage=self.name, duration_ms=dur)
        return StageResult(ok=True, data={"verified": True}, trace=trace_ok)