File size: 4,390 Bytes
570f7bd
b0bec17
105e019
 
b0bec17
570f7bd
 
 
 
c1bc4eb
b0bec17
 
570f7bd
 
 
b0bec17
570f7bd
 
 
 
 
b0bec17
 
 
 
 
 
 
 
 
 
570f7bd
b0bec17
370553a
 
c1bc4eb
370553a
b0bec17
 
 
 
 
 
 
 
 
 
 
 
370553a
b0bec17
 
 
370553a
570f7bd
 
370553a
 
 
570f7bd
 
c1bc4eb
570f7bd
b0bec17
570f7bd
 
 
 
c1bc4eb
570f7bd
370553a
b0bec17
370553a
570f7bd
b0bec17
 
 
 
 
570f7bd
c1bc4eb
570f7bd
 
 
b0bec17
 
 
 
 
 
570f7bd
 
370553a
b0bec17
370553a
570f7bd
 
b0bec17
570f7bd
 
 
 
 
b0bec17
570f7bd
 
 
 
b0bec17
 
 
570f7bd
 
b0bec17
 
570f7bd
 
b0bec17
 
 
 
 
 
570f7bd
 
 
b0bec17
570f7bd
 
b0bec17
570f7bd
 
 
370553a
b0bec17
 
 
 
570f7bd
b0bec17
570f7bd
a337fad
b0bec17
a337fad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from __future__ import annotations

import re
import time
import unicodedata
from nl2sql.types import StageResult, StageTrace

# --- Regex utils ---
# Block comment "/* ... */". DOTALL lets the body span several lines and the
# lazy quantifier stops at the first closing "*/" (no nesting support).
_COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
# Line comment: "--" up to the end of that line. MULTILINE makes "$" match at
# every line end, and "." does not cross a newline here (no DOTALL).
_COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)

# String literals (single & double quotes), allow escaped quotes
# A backslash consumes the character that follows it, so \' or \" inside a
# literal does not terminate it. DOTALL allows literals that contain newlines.
_STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
_STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)

# Case-insensitive, word-boundary forbidden keywords
# The \b anchors mean an identifier that merely contains a keyword (e.g.
# "updated_at") does not match, while a standalone token such as "replace("
# does match.
_FORBIDDEN = re.compile(
    r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b",
    re.IGNORECASE,
)

# Allow: SELECT ...  or   WITH (one or many CTEs, optional RECURSIVE) ... SELECT ...
# The WITH prefix is matched loosely: each CTE is "anything (lazy) up to a ')'"
# and parentheses are not balanced. The pattern only decides whether the
# statement *starts* like a SELECT / WITH ... SELECT; it does not parse it.
_ALLOW_SELECT = re.compile(
    r"^(?:WITH\s+(?:RECURSIVE\s+)?"
    r".*?\)\s*(?:,\s*.*?\)\s*)*"
    r")?SELECT\b",
    re.IGNORECASE | re.DOTALL,
)

# Optional allowance: EXPLAIN SELECT ...
# Only the exact shape "EXPLAIN <whitespace> SELECT" is accepted; forms such as
# "EXPLAIN QUERY PLAN SELECT" or "EXPLAIN WITH ..." do not match this pattern.
_ALLOW_EXPLAIN_SELECT = re.compile(r"^EXPLAIN\s+SELECT\b", re.IGNORECASE | re.DOTALL)

# --- Cleanup helpers ---
# Markdown code-fence markers: the opening "```sql" tag (any letter case) and
# any bare triple backtick.
_FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
_FENCE_ANY = re.compile(r"```")


def _normalize_sql(sql: str) -> str:
    """Normalize to NFKC and strip zero-width characters to prevent obfuscation."""
    s = unicodedata.normalize("NFKC", sql)
    # strip common zero-width spaces/joiners
    return (
        s.replace("\u200b", "")
        .replace("\u200c", "")
        .replace("\u200d", "")
        .replace("\ufeff", "")
    )


# String literals and comments are recognised in ONE left-to-right scan so
# that whichever construct opens first wins. This keeps "--" or "/* */" that
# appear inside quoted text from being treated as comments, and keeps quotes
# that appear inside comments from being treated as literals.
_STRING_OR_COMMENT = re.compile(
    r"(?P<string>'(?:[^'\\]|\\.)*'"  # single-quoted literal (backslash escapes)
    r'|"(?:[^"\\]|\\.)*")'  # double-quoted literal / identifier
    r"|/\*.*?\*/"  # block comment, first closing "*/" ends it
    r"|--[^\n]*",  # line comment up to the end of the line
    re.DOTALL,
)


def _sanitize_sql(sql: str) -> str:
    """Remove markdown fences, comments, and harmless trailing semicolons.

    Steps, in order: Unicode normalization, removal of markdown code-fence
    markers, replacement of every SQL comment with a single space, whitespace
    trimming, and removal of trailing semicolons.

    Comment stripping is string-aware: quoted literals are copied through
    unchanged, so a literal such as ``'a -- b'`` or ``'/* x */'`` is no longer
    truncated or hollowed out by the comment patterns.

    :param sql: Raw SQL text, possibly wrapped in markdown fences.
    :return: The cleaned statement text (string literals are NOT masked here).
    """
    s = _normalize_sql(sql)
    s = _FENCE_SQL.sub("", s)
    s = _FENCE_ANY.sub("", s)

    def _keep_strings(m: re.Match[str]) -> str:
        # A matched literal is emitted verbatim; a matched comment becomes a
        # single space so neighbouring tokens stay separated.
        literal = m.group("string")
        return literal if literal is not None else " "

    s = _STRING_OR_COMMENT.sub(_keep_strings, s)
    s = s.strip()
    # remove trailing semicolon(s) and any whitespace left in front of them
    return s.rstrip(";").strip()


def _mask_strings(s: str) -> str:
    """Replace string literals so that inner semicolons/keywords don't affect checks."""
    s = _STRING_SINGLE.sub("'X'", s)
    s = _STRING_DOUBLE.sub('"X"', s)
    return s


def _split_statements(s: str) -> list[str]:
    """
    Split on semicolons after string-masking. Ignore empties (e.g., trailing ';').
    """
    parts = [p.strip() for p in s.split(";")]
    return [p for p in parts if p]


def _ms(t0: float) -> int:
    return int((time.perf_counter() - t0) * 1000)


class Safety:
    """Pipeline stage that accepts exactly one read-only SELECT statement.

    The incoming SQL is sanitized (fences/comments removed), its string
    literals are masked, and the result must be a single statement that
    contains no forbidden keyword and starts like ``SELECT`` / ``WITH ...
    SELECT`` (or ``EXPLAIN SELECT`` when enabled).
    """

    name = "safety"

    def __init__(self, allow_explain: bool = False) -> None:
        """
        :param allow_explain: If True, 'EXPLAIN SELECT ...' is allowed in addition to SELECT.
        """
        self.allow_explain = allow_explain

    def check(self, sql: str) -> StageResult:
        """Validate *sql* and report the outcome as a ``StageResult``.

        On failure the result has ``ok=False`` and a one-element ``error``
        list; on success ``data`` carries the checked statement and a short
        rationale. Every result includes a trace with the elapsed time.

        NOTE(review): the ``"sql"`` value returned on success is the
        sanitized *and string-masked* text (literals replaced by
        placeholders), not the caller's original statement. Input that yields
        zero statements is also reported as "Multiple statements detected".
        """
        started = time.perf_counter()

        def rejected(message: str) -> StageResult:
            # All failure paths share the same result shape.
            return StageResult(
                ok=False,
                error=[message],
                trace=StageTrace(stage=self.name, duration_ms=_ms(started)),
            )

        # Clean the text, then hide literal contents from the checks below.
        masked = _mask_strings(_sanitize_sql(sql)).strip()

        # Exactly one statement must remain.
        statements = _split_statements(masked)
        if len(statements) != 1:
            return rejected("Multiple statements detected")
        (body,) = statements

        # Reject on the first forbidden keyword, quoting the offending token.
        offending = _FORBIDDEN.search(body)
        if offending is not None:
            return rejected(f"Forbidden keyword detected: '{offending.group(0)}'")

        # The statement has to open as SELECT, or as EXPLAIN SELECT when enabled.
        if not (
            _ALLOW_SELECT.match(body)
            or (self.allow_explain and _ALLOW_EXPLAIN_SELECT.match(body))
        ):
            return rejected("Non-SELECT statement")

        rationale = "Statement validated as SELECT-only (strings/comments/markdown ignored)."
        if self.allow_explain:
            rationale += " EXPLAIN SELECT allowed."

        return StageResult(
            ok=True,
            data={"sql": body, "rationale": rationale},
            trace=StageTrace(stage=self.name, duration_ms=_ms(started)),
        )

    # Backward-compat alias
    run = check