File size: 2,428 Bytes
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5f9c5f
abd4352
 
 
 
 
 
 
c5f9c5f
abd4352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
sandbox/sql_sandbox.py
AST-level SQL validation using sqlglot.
Blocks all DML/DDL write operations before any query reaches the database.
"""

from typing import Tuple

import sqlglot
import sqlglot.expressions as exp

# Blocked statement types (write operations).
# The classes are resolved by name with getattr() so that importing this module
# never fails on a sqlglot release that lacks or has renamed one of them
# (ALTER is exposed as `AlterTable` in some releases and `Alter` in others).
_BLOCKED_TYPE_NAMES = (
    "Drop",
    "Delete",
    "Update",
    "Insert",
    "Merge",        # MERGE can insert, update and delete rows
    "AlterTable",   # covers ALTER TABLE
    "Alter",        # same statement under its alternative class name
    "Create",
    "Copy",         # COPY, when sqlglot parses it into a dedicated node
    "Command",      # covers arbitrary COPY, VACUUM, etc. kept as raw commands
    "Transaction",
)

# Tuple form is required because it is passed straight to isinstance().
_BLOCKED_TYPES = tuple(
    node_type
    for node_type in (getattr(exp, name, None) for name in _BLOCKED_TYPE_NAMES)
    if node_type is not None
)

# Also block by class name for sqlglot versions that add or rename node types.
# "Into" catches `SELECT ... INTO new_table`, which creates a table in
# PostgreSQL even though the statement itself is a SELECT.
_BLOCKED_CLASS_NAMES = {
    "TruncateTable", "Truncate", "Revoke", "AlterTable", "Alter", "Grant",
    "Merge", "Copy", "Into",
}

# Blocked function names (extra caution). Compared against lower-cased names.
_BLOCKED_FUNCTIONS = {
    "pg_read_file", "pg_ls_dir", "pg_stat_file",
    "lo_import", "lo_export", "copy",
    "dblink", "dblink_exec",
}


def validate_sql(sql: str, dialect: str = "postgres") -> Tuple[bool, str]:
    """
    Parse and validate SQL.

    Every statement in ``sql`` is parsed with sqlglot, then the statement and
    each node of its AST are checked against ``_BLOCKED_TYPES``,
    ``_BLOCKED_CLASS_NAMES`` and ``_BLOCKED_FUNCTIONS``.

    Args:
        sql: Raw SQL text; may contain more than one statement.
        dialect: sqlglot dialect name used for parsing.

    Returns:
        (True, "") if safe, (False, reason) if the input is empty, cannot be
        parsed, contains no statement, or contains a blocked construct.
    """
    sql_stripped = sql.strip()
    if not sql_stripped:
        return False, "Empty query"

    try:
        statements = sqlglot.parse(
            sql_stripped, dialect=dialect, error_level=sqlglot.ErrorLevel.RAISE
        )
    except sqlglot.errors.SqlglotError as e:
        # SqlglotError is the common base of ParseError and TokenError, so
        # tokenizer failures (e.g. an unterminated string literal) are also
        # reported as a rejection instead of escaping as an exception.
        return False, f"SQL parse error: {e}"

    # The parse result can contain None entries for empty statements; drop
    # them first so that input made only of such entries is not waved through
    # as "safe" without a single check having run.
    parsed = [stmt for stmt in (statements or []) if stmt is not None]
    if not parsed:
        return False, "No valid SQL statement found"

    for stmt in parsed:
        stmt_name = type(stmt).__name__

        # Block write statement types
        if isinstance(stmt, _BLOCKED_TYPES) or stmt_name in _BLOCKED_CLASS_NAMES:
            return False, f"Blocked statement type: {stmt_name}"

        # Walk AST for any write nodes embedded in CTEs, subqueries, etc.
        for visited in stmt.walk():
            # Some sqlglot releases yield (node, parent, key) tuples from
            # walk() rather than bare nodes; unwrap so the checks see the node.
            node = visited[0] if isinstance(visited, tuple) else visited
            node_name = type(node).__name__

            if isinstance(node, _BLOCKED_TYPES) or node_name in _BLOCKED_CLASS_NAMES:
                return False, f"Blocked operation in query: {node_name}"

            # Block dangerous function calls. Functions sqlglot has no
            # dedicated node class for are represented as exp.Anonymous.
            if isinstance(node, exp.Anonymous):
                fname = (node.name or "").lower()
                if fname in _BLOCKED_FUNCTIONS:
                    return False, f"Blocked function: {fname}"

    return True, ""