File size: 4,621 Bytes
b72c625
e3e0ac5
b72c625
370553a
e3e0ac5
6a94b42
570f7bd
f89e294
 
 
 
b794494
 
570f7bd
e3e0ac5
b794494
e3e0ac5
 
b794494
e3e0ac5
b794494
e3e0ac5
370553a
e3e0ac5
7b9903c
370553a
e3e0ac5
 
 
570f7bd
b794494
e3e0ac5
 
 
 
 
 
 
7b9903c
 
 
 
b72c625
7b9903c
b72c625
e3e0ac5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370553a
 
e3e0ac5
 
 
 
 
 
 
 
 
 
 
 
 
 
7b9903c
 
 
 
e3e0ac5
7b9903c
b72c625
b794494
e3e0ac5
 
7b9903c
 
 
 
e3e0ac5
7b9903c
b794494
6a94b42
e3e0ac5
 
7b9903c
f89e294
e3e0ac5
 
 
 
 
b794494
e3e0ac5
570f7bd
e3e0ac5
7b9903c
 
 
 
e3e0ac5
7b9903c
 
c1bc4eb
370553a
7b9903c
 
 
 
 
 
 
 
 
 
 
 
 
f89e294
 
 
 
7b9903c
 
 
 
 
 
 
 
 
 
 
 
 
b794494
e3e0ac5
b794494
e3e0ac5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from __future__ import annotations

import re
import time
from typing import Any, Dict

from nl2sql.types import StageResult, StageTrace
from nl2sql.metrics import (
    verifier_checks_total,
    verifier_failures_total,
)


class Verifier:
    """Static, heuristic SQL verifier used by tests and the pipeline.

    Provides ``verify(...)`` for tests and ``run(...)`` for the pipeline.
    No SQL is executed; the checks below run in order and the first failure
    wins:

    1. parse sanity: the SQL must contain both ``SELECT`` and ``FROM``
       (reason ``"parse-error"``);
    2. aggregation sanity: a projection that mixes aggregate and
       non-aggregate items needs ``GROUP BY``; a window ``OVER (...)`` or
       ``SELECT DISTINCT`` also exempts it
       (reason ``"aggregation-without-groupby"``);
    3. test sentinel: any SQL mentioning ``imaginary_table`` fails as if the
       table did not exist (reason ``"exec-error"``).

    An unexpected exception while analysing fails with reason
    ``"exception"``. The ``adapter`` argument is accepted but not used.
    """

    required = False

    # Keyword patterns are word-bounded and whitespace-tolerant (spaces, tabs,
    # newlines) so multi-line SQL is classified the same as single-line SQL.
    # They are applied to lower-cased SQL.
    _SELECT_RE = re.compile(r"\bselect\b")
    _FROM_RE = re.compile(r"\bfrom\b")
    _OVER_RE = re.compile(r"\bover\s*\(")
    _GROUP_BY_RE = re.compile(r"\bgroup\s+by\b")
    _DISTINCT_RE = re.compile(r"\bselect\s+distinct\b")
    _AGGREGATE_RE = re.compile(r"\b(count|sum|avg|min|max)\s*\(")

    def verify(self, sql: str, *, adapter: Any | None = None) -> StageResult:
        """Statically check *sql* and return a verifier ``StageResult``.

        Args:
            sql: SQL text to check; ``None``/empty is treated as ``""``.
            adapter: Accepted for interface compatibility; unused.

        Returns:
            ``StageResult(ok=True, data={"verified": True})`` on success.
            Otherwise the result built by ``_fail``, with ``ok=False`` and a
            one-element ``error`` list. The trace notes always carry
            ``sql_length``, ``verified`` and ``reason``, plus whichever
            check flags were computed.
        """
        t0 = time.perf_counter()
        notes: Dict[str, Any] = {}

        s = (sql or "").strip()
        sl = s.lower()
        notes["sql_length"] = len(s)

        # Only the heuristic analysis is guarded. Metric and trace calls on
        # the result paths are kept outside the try, so a failure there is not
        # re-routed through _fail and double-counted.
        try:
            failure = self._analyze(sl, notes)
        except Exception as e:
            return self._fail(
                t0,
                notes,
                error=[str(e)],
                reason="exception",
                exc_type=type(e).__name__,
            )

        if failure is not None:
            reason, error = failure
            return self._fail(t0, notes, error=[error], reason=reason)

        notes.update({"verified": True, "reason": "ok"})
        verifier_checks_total.labels(ok="true").inc()
        trace = StageTrace(
            stage="verifier",
            duration_ms=self._elapsed_ms(t0),
            summary="ok",
            notes=notes,
        )
        return StageResult(ok=True, data={"verified": True}, trace=trace)

    @classmethod
    def _analyze(cls, sl: str, notes: Dict[str, Any]) -> tuple[str, str] | None:
        """Run the static checks on lower-cased SQL *sl*.

        Findings are written into *notes* in place, so partial flags survive
        if a check raises.

        Returns:
            ``None`` when every check passes, otherwise a
            ``(reason, error_message)`` pair for the first failing check.
        """
        # --- quick parse sanity: require SELECT and FROM ---
        has_select = bool(cls._SELECT_RE.search(sl))
        has_from = bool(cls._FROM_RE.search(sl))
        notes["has_select"] = has_select
        notes["has_from"] = has_from
        if not (has_select and has_from):
            return "parse-error", "parse_error"

        # --- semantic sanity: aggregation without GROUP BY (unless allowed) ---
        has_over = bool(cls._OVER_RE.search(sl))
        has_group_by = bool(cls._GROUP_BY_RE.search(sl))
        has_distinct = bool(cls._DISTINCT_RE.search(sl))
        has_aggregate = bool(cls._AGGREGATE_RE.search(sl))
        notes.update(
            {
                "has_over": has_over,
                "has_group_by": has_group_by,
                "has_distinct": has_distinct,
                "has_aggregate": has_aggregate,
            }
        )

        # "Mixed" means the outer projection has at least one aggregate item
        # AND at least one non-aggregate item. This does not flag
        # all-aggregate projections, commas inside function calls, or
        # aggregates that occur only in a WHERE subquery.
        item_is_agg = [
            bool(cls._AGGREGATE_RE.search(item))
            for item in cls._projection_items(sl)
        ]
        mixes_cols = any(item_is_agg) and not all(item_is_agg)
        notes["mixes_cols"] = mixes_cols

        if mixes_cols and not (has_group_by or has_over or has_distinct):
            return "aggregation-without-groupby", "aggregation_without_group_by"

        # --- execution-error sentinel for tests ---
        if "imaginary_table" in sl:
            return "exec-error", "exec_error: no such table: imaginary_table"

        return None

    @staticmethod
    def _is_word_char(ch: str) -> bool:
        """Return True if *ch* counts as an identifier character."""
        return ch.isalnum() or ch == "_"

    @classmethod
    def _projection_items(cls, sl: str) -> list[str]:
        """Split the first SELECT's projection list into top-level items.

        Commas inside parentheses (function arguments, subqueries) or inside
        quoted text (``'...'``, ``"..."``, `` `...` ``) do not split. The
        projection ends at the first standalone ``from`` at nesting depth 0.

        Returns:
            The stripped, non-empty items. Returns ``[]`` when there is no
            SELECT or no top-level FROM after it.
        """
        m = cls._SELECT_RE.search(sl)
        if m is None:
            return []

        items: list[str] = []
        depth = 0
        quote: str | None = None
        item_start = m.end()
        end = len(sl)

        for i, ch in enumerate(sl[item_start:], start=item_start):
            if quote is not None:
                # Inside a literal/identifier. A doubled quote ('') closes and
                # immediately reopens the literal, so the net state is still
                # correct.
                if ch == quote:
                    quote = None
            elif ch in "'\"`":
                quote = ch
            elif ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            elif depth == 0:
                if ch == ",":
                    items.append(sl[item_start:i].strip())
                    item_start = i + 1
                elif (
                    sl.startswith("from", i)
                    and not cls._is_word_char(sl[i - 1])
                    and (i + 4 >= end or not cls._is_word_char(sl[i + 4]))
                ):
                    items.append(sl[item_start:i].strip())
                    return [item for item in items if item]

        # No top-level FROM found: treat the projection as unknown.
        return []

    @staticmethod
    def _elapsed_ms(t0: float) -> int:
        """Return whole milliseconds elapsed since ``time.perf_counter()`` *t0*."""
        return int(round((time.perf_counter() - t0) * 1000.0))

    def _fail(
        self,
        t0: float,
        notes: Dict[str, Any],
        *,
        error: list[str],
        reason: str,
        exc_type: str | None = None,
    ) -> StageResult:
        """Record a failed check and build the failing ``StageResult``.

        Sets ``verified=False`` and ``reason`` in *notes*, plus
        ``exception_type`` when *exc_type* is given. Increments
        ``verifier_checks_total{ok="false"}`` and
        ``verifier_failures_total{reason=...}``.

        Args:
            t0: ``time.perf_counter()`` value taken when verification started.
            notes: Trace notes collected so far; mutated in place.
            error: Error messages to attach to the result.
            reason: Short failure reason, used as the metric label.
            exc_type: Exception class name, when the failure was an exception.
        """
        notes.update({"verified": False, "reason": reason})
        if exc_type:
            notes["exception_type"] = exc_type

        verifier_checks_total.labels(ok="false").inc()
        verifier_failures_total.labels(reason=reason).inc()

        trace = StageTrace(
            stage="verifier",
            duration_ms=self._elapsed_ms(t0),
            summary="failed",
            notes=notes,
        )
        return StageResult(
            ok=False,
            data={"verified": False},
            trace=trace,
            error=error,
        )

    def run(
        self, *, sql: str, exec_result: Dict[str, Any], adapter: Any = None
    ) -> StageResult:
        """Pipeline entry point; delegates to ``verify``.

        ``exec_result`` is accepted for interface compatibility but is not
        inspected.
        """
        return self.verify(sql, adapter=adapter)