File size: 10,171 Bytes
570f7bd
b0bec17
105e019
 
b72c625
 
 
 
570f7bd
c24bfe8
 
570f7bd
b72c625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0bec17
b72c625
 
570f7bd
b72c625
 
 
 
 
 
570f7bd
 
 
b0bec17
b72c625
 
 
 
 
 
 
 
570f7bd
b72c625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1bc4eb
b72c625
370553a
b0bec17
b72c625
 
b0bec17
570f7bd
b72c625
 
 
c1bc4eb
570f7bd
b72c625
 
 
 
 
 
 
 
 
 
c1bc4eb
b72c625
 
370553a
b72c625
370553a
b72c625
 
 
 
 
 
 
b0bec17
 
b72c625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570f7bd
c1bc4eb
570f7bd
b72c625
 
 
 
 
570f7bd
 
b72c625
b0bec17
 
570f7bd
 
370553a
b72c625
 
c24bfe8
 
570f7bd
 
b72c625
 
 
 
c24bfe8
 
b72c625
 
 
b0bec17
570f7bd
 
b72c625
 
 
 
 
 
 
c24bfe8
 
b72c625
 
 
 
 
 
 
 
 
 
 
 
570f7bd
b72c625
 
 
b0bec17
b72c625
c24bfe8
 
570f7bd
 
b72c625
b0bec17
570f7bd
b72c625
 
 
 
c24bfe8
 
b72c625
 
 
 
 
570f7bd
b72c625
 
 
 
 
c24bfe8
 
b72c625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c24bfe8
 
b72c625
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0bec17
b72c625
c24bfe8
 
570f7bd
 
b72c625
b0bec17
570f7bd
 
b72c625
c24bfe8
 
b72c625
 
 
 
 
 
 
c24bfe8
 
570f7bd
 
 
370553a
b72c625
 
 
570f7bd
b0bec17
570f7bd
a337fad
b72c625
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
from __future__ import annotations

import re
import time
from typing import List, Pattern

import sqlglot

from nl2sql.types import StageResult, StageTrace
from nl2sql.metrics import safety_blocks_total, stage_duration_ms, safety_checks_total


# ------------------------- Zero-width & basic regexes -------------------------

_ZERO_WIDTH = [
    "\u200b",
    "\u200c",
    "\u200d",
    "\ufeff",
    "\u2060",
    "\u180e",
    "\u200e",
    "\u200f",
]
_ZERO_WIDTH_RE = re.compile("|".join(map(re.escape, _ZERO_WIDTH)))

# String / comment regexes
_STR_SINGLE_RE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
_STR_DOUBLE_RE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)
_LINE_COMMENT_RE = re.compile(r"--[^\n]*")
_BLOCK_COMMENT_RE = re.compile(r"/\*.*?\*/", re.DOTALL)

# Markdown code fences: ```sql\n ... \n```
_FENCE_RE = re.compile(r"^\s*```[a-zA-Z]*\n(?P<body>.*)\n```\s*$", re.DOTALL)

# Strict forbidden keywords (word boundaries)
_FORBIDDEN: Pattern[str] = re.compile(
    r"\b("
    r"delete|update|insert|drop|create|alter|truncate|merge|"
    r"grant|revoke|execute|call|copy|attach|pragma|reindex|vacuum|replace"
    r")\b",
    re.IGNORECASE,
)


def _loose_keyword(pattern: str) -> Pattern[str]:
    r"""
    Build a regex that allows arbitrary whitespace between characters of a keyword.
    Example: "insert" -> i\s*n\s*s\s*e\s*r\s*t
    """
    chars = r"\s*".join(list(pattern))
    return re.compile(rf"\b{chars}\b", re.IGNORECASE)


# Whitespace-tolerant variants of the forbidden keywords ("d r o p", "DEL ETE").
# One compiled pattern per keyword, kept in this order; since "\s*" also
# matches nothing, each pattern matches the plain keyword as well.
_FORBIDDEN_LOOSE: List[Pattern[str]] = list(
    map(
        _loose_keyword,
        (
            "delete",
            "update",
            "insert",
            "drop",
            "create",
            "alter",
            "truncate",
            "merge",
            "grant",
            "revoke",
            "execute",
            "call",
            "copy",
            "attach",
            "pragma",
            "reindex",
            "vacuum",
            "replace",
        ),
    )
)

# Inputs longer than this many characters are rejected before any regex or
# parser work (defensive bound against catastrophic inputs).
_MAX_SQL_LEN = 200_000


def _ms(t0: float) -> int:
    return int((time.perf_counter() - t0) * 1000)


def _strip_fences(sql: str) -> str:
    """
    Unwrap a markdown code fence that encloses the whole input.

    ``"```sql\\nSELECT 1\\n```"`` yields ``"SELECT 1"``. The opening fence must be
    followed directly by an optional alphabetic tag and a newline. The closing
    fence must follow a newline. Input that does not match is returned unchanged.
    """
    fence = _FENCE_RE.match(sql)
    if fence is None:
        return sql
    return fence.group("body")


def _collapse_trailing_semicolons(body: str) -> str:
    """
    Keep at most one trailing semicolon. This makes 'SELECT 1;;' equivalent to 'SELECT 1;'.
    """
    body = body.rstrip()
    had_any = False
    while body.endswith(";"):
        had_any = True
        body = body[:-1].rstrip()
    return (body + ";") if had_any else body


def _sanitize(sql: str) -> str:
    """
    Normalize raw SQL text before any safety analysis.

    Steps, in order:
    1. Drop the invisible characters matched by ``_ZERO_WIDTH_RE``.
    2. Unwrap a surrounding markdown code fence.
    3. Trim surrounding whitespace.
    4. Collapse a run of trailing semicolons into one.

    Falsy input yields ``""``.
    """
    if not sql:
        return ""
    visible = _ZERO_WIDTH_RE.sub("", sql)
    unfenced = _strip_fences(visible).strip()
    return _collapse_trailing_semicolons(unfenced)


def _remove_comments(body: str) -> str:
    """
    Delete SQL comments from *body*: first ``/* ... */`` blocks, then ``--`` line comments.

    Neither pattern knows about string literals, so comment markers inside
    quotes (e.g. ``'--'``) are treated as real comments too.
    """
    without_blocks = _BLOCK_COMMENT_RE.sub("", body)
    return _LINE_COMMENT_RE.sub("", without_blocks)


def _strip_strings(body: str) -> str:
    """
    Empty every quoted literal so keyword checks don't fire on quoted text.

    Single-quoted literals become ``''`` and double-quoted ones become ``""``
    (single quotes are handled first).
    """
    without_single = _STR_SINGLE_RE.sub("''", body)
    return _STR_DOUBLE_RE.sub('""', without_single)


def _count_statements_semicolon(body: str) -> int:
    """
    Heuristically count statements by splitting on ``;``.

    Quoted literals are masked first (their contents replaced by ``S``), so a
    ``;`` or ``--`` inside quotes is ignored. Comments are removed next.
    Fragments that are blank after trimming, such as the one after a trailing
    ``;``, are not counted.
    """
    masked = _STR_DOUBLE_RE.sub('"S"', _STR_SINGLE_RE.sub("'S'", body))
    code_only = _remove_comments(masked)
    return sum(1 for fragment in code_only.split(";") if fragment.strip())


def _count_statements_sqlglot(body: str) -> int:
    """
    Count statements via the sqlglot parser.

    Comments are removed first so that comment-only fragments are not counted.
    That removal is regex-based and not string-aware, so a literal such as
    ``'--'`` or ``'/*'`` can truncate the text until it no longer parses. In
    that case the raw body is parsed instead, because sqlglot's own tokenizer
    understands strings and comments.

    Returning 1 on any parse failure (the previous behaviour) reported a single
    statement whenever the stripped text merely failed to parse, even when the
    raw text held several statements.

    Returns:
        Number of non-empty statements sqlglot finds. Returns 1 only when the
        raw body itself cannot be parsed. ``Safety.check`` parses the same raw
        body afterwards and rejects it as a parse error there.
    """
    try:
        trees = sqlglot.parse(_remove_comments(body))
    except Exception:
        try:
            trees = sqlglot.parse(body)
        except Exception:
            return 1
    # None marks an empty chunk (e.g. between ";;").
    # NOTE(review): newer sqlglot versions appear to emit a bare `Semicolon`
    # node for a comment that trails a ';' in the raw-body fallback. It holds
    # no statement, so it is not counted; confirm against the pinned version.
    return sum(
        1
        for tree in trees
        if tree is not None and type(tree).__name__.lower() != "semicolon"
    )


class Safety:
    """
    Read-only safety gate for generated SQL.

    Accepts exactly one read-only statement: a SELECT (including
    WITH ... SELECT) or a set operation (UNION / INTERSECT / EXCEPT). When
    ``allow_explain`` is true it also accepts EXPLAIN of such a statement.

    Rejects:
    - empty or oversized input;
    - multiple statements;
    - DML/DDL keywords, including whitespace-obfuscated spellings such as
      ``d r o p``;
    - SELECT ... INTO;
    - anything else the parser does not classify as read-only.

    Metrics: every call increments ``safety_checks_total`` (``ok="true"`` or
    ``ok="false"``). Rejections also increment ``safety_blocks_total`` with a
    ``reason`` label. Accepted calls observe ``stage_duration_ms``.
    """

    name = "safety"

    # Lower-cased sqlglot expression class names accepted as read-only roots.
    _READ_ONLY_ROOTS = frozenset({"select", "with", "union", "intersect", "except"})

    # Leading EXPLAIN keyword. Handled textually because some dialects parse
    # EXPLAIN into a generic Command node rather than a dedicated expression.
    _EXPLAIN_HEAD_RE = re.compile(r"^\s*explain\s+", re.IGNORECASE)

    def __init__(self, allow_explain: bool = True) -> None:
        """
        Args:
            allow_explain: When true, ``EXPLAIN <read-only statement>`` is accepted.
        """
        self.allow_explain = allow_explain

    def _blocked(
        self, t0: float, reason: str, error: str, notes: dict | None = None
    ) -> StageResult:
        """Record a rejection in the metrics and build the failing StageResult.

        ``notes`` is forwarded to StageTrace only when given.
        """
        safety_blocks_total.labels(reason=reason).inc()
        safety_checks_total.labels(ok="false").inc()
        if notes is None:
            trace = StageTrace(stage=self.name, duration_ms=_ms(t0))
        else:
            trace = StageTrace(stage=self.name, duration_ms=_ms(t0), notes=notes)
        return StageResult(ok=False, error=[error], trace=trace)

    def _passed(self, t0: float, sql: str, body: str, allow_explain: bool) -> StageResult:
        """Record an acceptance in the metrics and build the passing StageResult."""
        stage_duration_ms.labels("safety").observe(_ms(t0) / 1.0)
        safety_checks_total.labels(ok="true").inc()
        return StageResult(
            ok=True,
            data={
                "sql": body,
                "original_len": len(sql),
                "sanitized_len": len(body),
                "allow_explain": allow_explain,
            },
            trace=StageTrace(stage=self.name, duration_ms=_ms(t0)),
        )

    def check(self, sql: str) -> StageResult:
        """
        Validate ``sql`` as a single read-only statement.

        Args:
            sql: Raw SQL text, possibly wrapped in a markdown code fence.

        Returns:
            On acceptance, a StageResult with ``ok=True`` and ``data["sql"]``
            holding the sanitized statement. On rejection, ``ok=False`` with a
            one-element ``error`` list.
        """
        t0 = time.perf_counter()

        # 0) nil / size guard
        if not sql or not sql.strip():
            return self._blocked(t0, "empty_sql", "empty_sql")
        if len(sql) > _MAX_SQL_LEN:
            return self._blocked(t0, "sql_too_long", "sql_too_long")

        # 1) sanitize. Input made only of zero-width characters, or an empty
        # code fence, survives the guard above but sanitizes to "".
        body = _sanitize(sql)
        if not body:
            return self._blocked(t0, "empty_sql", "empty_sql")

        # 2) single-statement pre-check (semicolon heuristic + parser)
        semicolon_count = _count_statements_semicolon(body)
        glot_count = _count_statements_sqlglot(body)
        if semicolon_count != 1 or glot_count != 1:
            return self._blocked(
                t0,
                "multiple_statements",
                "Multiple statements detected",
                notes={"semicolon_count": semicolon_count, "parser_count": glot_count},
            )

        # 3) forbidden keywords (ignore inside string literals). The strict
        # pattern is tried first, then each whitespace-tolerant pattern in order.
        scan_body = _strip_strings(body)
        hit = _FORBIDDEN.search(scan_body)
        if hit is None:
            hit = next(
                (m for m in (rx.search(scan_body) for rx in _FORBIDDEN_LOOSE) if m),
                None,
            )
        if hit is not None:
            return self._blocked(
                t0, "forbidden_keyword", f"Forbidden: {hit.group(0).strip().lower()}"
            )

        # 4) authoritative parse of the full body
        try:
            trees = sqlglot.parse(body)
            root = trees[0]
        except Exception as e:
            return self._blocked(
                t0, "parse_error", "parse_error", notes={"parse_error": str(e)}
            )

        # Steps 2-3 use regex lexers that disagree with sqlglot on backslash
        # escapes and on comment markers inside literals. Recount on the parse
        # that the root check uses. Otherwise an input like
        #   SELECT '/*\' ; DELETE FROM t; SELECT '*/'
        # can pass steps 2-3 while a backslash-is-literal dialect sees three
        # statements. None marks an empty chunk.
        # NOTE(review): newer sqlglot versions appear to emit a bare `Semicolon`
        # node for a comment trailing the final ';'. It carries no statement;
        # confirm against the pinned version.
        statement_count = sum(
            1
            for tree in trees
            if tree is not None and type(tree).__name__.lower() != "semicolon"
        )
        if statement_count > 1:
            return self._blocked(
                t0,
                "multiple_statements",
                "Multiple statements detected",
                notes={"semicolon_count": semicolon_count, "parser_count": statement_count},
            )

        root_type = type(root).__name__.lower()

        # A SELECT root is not proof of read-only: SELECT ... INTO writes data
        # (e.g. PostgreSQL / SQL Server create a new table from it).
        if root is not None and root.find(sqlglot.exp.Into) is not None:
            return self._blocked(t0, "select_into", "SELECT INTO not allowed")

        # Manual EXPLAIN handling for dialects that parse EXPLAIN to Command
        if self.allow_explain and self._EXPLAIN_HEAD_RE.match(body):
            remainder = self._EXPLAIN_HEAD_RE.sub("", body, count=1).lstrip()
            try:
                explained = sqlglot.parse_one(remainder)
                explained_type = type(explained).__name__.lower() if explained else ""
            except Exception:
                # Unparsable remainder: fall through to normal handling.
                explained_type = ""
            if explained_type in self._READ_ONLY_ROOTS:
                return self._passed(t0, sql, body, allow_explain=True)

        is_select_like = root_type in self._READ_ONLY_ROOTS
        is_explain = root_type == "explain"

        # The textual check also catches dialects where EXPLAIN parses to a
        # Command, which would otherwise be reported as a generic non-SELECT.
        if not self.allow_explain and (is_explain or self._EXPLAIN_HEAD_RE.match(body)):
            return self._blocked(t0, "explain_not_allowed", "EXPLAIN not allowed")

        if not (is_select_like or (is_explain and self.allow_explain)):
            return self._blocked(t0, "non_select", f"Non-SELECT statement: {root_type}")

        # 5) success
        return self._passed(t0, sql, body, self.allow_explain)

    # Keep Pipeline API compatibility (pipeline calls .run(sql=...))
    def run(self, *, sql: str) -> StageResult:
        """Entry point used by the pipeline; delegates to :meth:`check`."""
        return self.check(sql)