File size: 3,190 Bytes
570f7bd
105e019
 
570f7bd
 
 
 
c1bc4eb
570f7bd
 
 
 
 
 
 
 
 
 
 
 
 
370553a
 
 
c1bc4eb
370553a
 
 
 
 
570f7bd
 
370553a
 
 
570f7bd
 
c1bc4eb
570f7bd
 
 
 
 
c1bc4eb
570f7bd
370553a
 
 
 
570f7bd
370553a
 
570f7bd
c1bc4eb
570f7bd
 
 
 
 
 
370553a
 
 
570f7bd
 
 
 
 
 
 
c1bc4eb
 
 
570f7bd
 
 
 
 
 
 
 
c1bc4eb
 
 
570f7bd
 
 
 
 
 
c1bc4eb
 
 
570f7bd
 
 
 
 
370553a
 
570f7bd
c1bc4eb
 
 
570f7bd
a337fad
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from __future__ import annotations
import re
import time
from nl2sql.types import StageResult, StageTrace

# --- Regex utils ---
_COMMENT_BLOCK = re.compile(r"/\*.*?\*/", re.DOTALL)
_COMMENT_LINE = re.compile(r"--.*?$", re.MULTILINE)
# string literals (single & double quotes), allow escaped quotes
_STRING_SINGLE = re.compile(r"'([^'\\]|\\.)*'", re.DOTALL)
_STRING_DOUBLE = re.compile(r'"([^"\\]|\\.)*"', re.DOTALL)

# case-insensitive, word-boundary forbidden keywords
_FORBIDDEN = re.compile(
    r"\b(delete|update|insert|drop|create|alter|attach|pragma|reindex|vacuum|replace|grant|revoke|execute)\b",
    re.IGNORECASE,
)

# allow: SELECT ...   or   WITH <cte...> SELECT ...
_ALLOW_SELECT = re.compile(r"^(?:WITH\b.*?\)\s*)?SELECT\b", re.IGNORECASE | re.DOTALL)

# --- New cleanup helpers ---
_FENCE_SQL = re.compile(r"```sql", re.IGNORECASE)
_FENCE_ANY = re.compile(r"```")


def _sanitize_sql(sql: str) -> str:
    """Remove markdown fences, comments, and surrounding junk."""
    s = _FENCE_SQL.sub("", sql)
    s = _FENCE_ANY.sub("", s)
    s = _COMMENT_BLOCK.sub(" ", s)
    s = _COMMENT_LINE.sub(" ", s)
    s = s.strip()
    # remove trailing semicolon safely
    s = s.rstrip(";").strip()
    return s


def _mask_strings(s: str) -> str:
    """Return *s* with every quoted literal replaced by a fixed placeholder.

    Single-quoted matches become ``'X'`` first; double-quoted matches in the
    resulting text then become ``"X"``.
    """
    masked = s
    for pattern, placeholder in ((_STRING_SINGLE, "'X'"), (_STRING_DOUBLE, '"X"')):
        masked = pattern.sub(placeholder, masked)
    return masked


def _split_statements(s: str) -> list[str]:
    """
    Split only if there are real multiple statements,
    ignoring harmless trailing semicolons or markdown.
    """
    parts = [p.strip() for p in s.split(";")]
    parts = [p for p in parts if p]
    return parts


class Safety:
    """Pipeline stage that accepts exactly one read-only SELECT statement.

    The candidate SQL is sanitized, its quoted literals are masked, and the
    masked text must (a) contain exactly one statement, (b) contain none of
    the forbidden keywords, and (c) start with ``SELECT`` or ``WITH ... SELECT``.
    """

    name = "safety"

    def _reject(self, message: str, t0: float) -> StageResult:
        """Build a failed ``StageResult`` carrying *message* and the elapsed time since *t0*."""
        return StageResult(
            ok=False,
            error=[message],
            trace=StageTrace(
                stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
            ),
        )

    def check(self, sql: str) -> StageResult:
        """Validate *sql* as a single SELECT-only statement.

        Args:
            sql: Candidate SQL text (may include markdown fences and comments).

        Returns:
            ``StageResult(ok=True)`` whose ``data["sql"]`` is the sanitized
            statement with its string literals intact, or
            ``StageResult(ok=False)`` with a one-element ``error`` list:
            ``"Multiple statements detected"`` (also used when no statement
            remains after sanitizing), ``"Forbidden keyword detected"`` or
            ``"Non-SELECT statement"``.
        """
        t0 = time.perf_counter()
        print("🧩 SQL candidate:", sql)

        # --- sanitize first ---
        cleaned = _sanitize_sql(sql)
        # Checks run on a masked copy so keywords/semicolons inside literals
        # cannot trigger (or dodge) the rules below.
        masked = _mask_strings(cleaned).strip()

        stmts = _split_statements(masked)
        if len(stmts) != 1:
            return self._reject("Multiple statements detected", t0)

        body = stmts[0]

        if _FORBIDDEN.search(body):
            return self._reject("Forbidden keyword detected", t0)

        if not _ALLOW_SELECT.match(body):
            return self._reject("Non-SELECT statement", t0)

        # Return the real statement, not the masked one: the masked text has
        # every literal replaced by a placeholder and would no longer be the
        # caller's query. With exactly one statement in the masked text, any
        # remaining semicolons outside literals sit only at the edges, so
        # trimming them (and whitespace) yields the unmasked counterpart of
        # ``body``.
        statement = cleaned.strip()
        while statement.startswith(";") or statement.endswith(";"):
            statement = statement.strip(";").strip()

        return StageResult(
            ok=True,
            data={
                "sql": statement,
                "rationale": "Statement validated as SELECT-only (strings/comments/markdown ignored).",
            },
            trace=StageTrace(
                stage=self.name, duration_ms=(time.perf_counter() - t0) * 1000
            ),
        )

    run = check