File size: 10,764 Bytes
723f9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
037873c
723f9ab
 
037873c
 
 
3f2266f
723f9ab
 
3f2266f
 
 
 
 
 
723f9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204988a
723f9ab
 
 
204988a
 
 
 
 
723f9ab
 
 
 
 
 
 
 
 
 
 
 
204988a
 
 
 
 
723f9ab
204988a
 
 
 
 
723f9ab
037873c
 
 
 
 
723f9ab
 
 
 
 
3f2266f
 
 
 
 
 
723f9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f1bc70
5f30028
2f1bc70
 
5f30028
 
 
 
 
 
 
 
 
 
 
2f1bc70
 
 
5f30028
 
 
2f1bc70
 
 
5f30028
 
 
2f1bc70
 
 
5f30028
 
 
2f1bc70
 
 
5f30028
 
 
2f1bc70
723f9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
import re
from typing import Dict, Iterable, List, Optional, Tuple

class QueryParser:
    """Parse and validate SQL queries to prevent injection attacks.

    Every check is a keyword/regex heuristic over the raw SQL text rather than
    a real SQL parse. Keyword scans cover the whole text, string literals
    included, so a keyword that only appears inside a literal can still cause
    a rejection (the checks err on the side of refusing).
    """

    # Allowed SQL keywords for read operations (not referenced by any method
    # of this class).
    SAFE_READ_KEYWORDS = {
        'SELECT', 'FROM', 'WHERE', 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER',
        'ON', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'BETWEEN', 'IS', 'NULL',
        'ORDER', 'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'AS', 'DISTINCT',
        'COUNT', 'SUM', 'AVG', 'MIN', 'MAX', 'UNION', 'INTERSECT', 'EXCEPT'
    }

    # Write operations that need allow_write=True. Matched as whole words so
    # identifiers such as updated_at or is_deleted are not mistaken for them.
    WRITE_KEYWORDS = {
        'INSERT', 'UPDATE', 'DELETE', 'TRUNCATE', 'ALTER', 'DROP'
    }

    # Any CREATE statement needs allow_write=True; CREATE DATABASE/SCHEMA
    # additionally needs allow_schema_ops=True.
    CREATE_KEYWORDS = {'CREATE'}

    # Always dangerous - never allow. Deliberately matched as plain substrings
    # (not whole words), so identifiers that merely contain one of these
    # (e.g. pragma_table_info, recall) are rejected as well.
    ALWAYS_DANGEROUS = {
        'GRANT', 'REVOKE', 'EXEC', 'EXECUTE', 'CALL', 'DECLARE', 'CURSOR',
        'PRAGMA'
    }

    # Dangerous commands matched as whole words, so table names that merely
    # contain them are not rejected.
    DANGEROUS_COMMANDS = {
        'ATTACH', 'DETACH'
    }

    # Keywords that need special validation
    DROP_DATABASE_PATTERNS = [
        r'DROP\s+DATABASE',
        r'DROP\s+SCHEMA'
    ]

    CREATE_DATABASE_PATTERNS = [
        r'CREATE\s+DATABASE',
        r'CREATE\s+SCHEMA'
    ]

    # CREATE TABLE variants; only used to choose the error message.
    _CREATE_TABLE_RE = r'\bCREATE\s+(?:OR\s+REPLACE\s+)?(?:TEMP(?:ORARY)?\s+)?TABLE\b'

    # Common SQL injection signatures (comments are allowed for phpMyAdmin
    # imports). Searched with re.DOTALL so `.*` also spans line breaks.
    _INJECTION_PATTERNS = (
        r';\s*DROP\s+DATABASE',  # Only block DROP DATABASE after semicolon
        # UNION injection: a FROM after UNION SELECT that is not followed by a
        # plain (optionally dotted) name ending in whitespace, ';', ')' or the
        # end of the query.
        r'UNION\s+SELECT.*FROM\s+(?![\w.]+(?:\s|;|\)|$))',
        r'\bOR\s+1\s*=\s*1\b',
        r"\bOR\s+'1'\s*=\s*'1'",
    )

    # Words that can follow FROM/JOIN but are never table names.
    _NON_TABLE_WORDS = frozenset({
        'SELECT', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'LIMIT',
        'ORDER', 'GROUP', 'HAVING', 'SET', 'VALUES', 'INTO', 'AS',
        'ON', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'CROSS', 'NATURAL',
        'BETWEEN', 'IS', 'NULL', 'TRUE', 'FALSE', 'DISTINCT', 'ALL',
        'EXISTS', 'CASE', 'WHEN', 'THEN', 'ELSE', 'END', 'UNION',
        'INTERSECT', 'EXCEPT', 'WITH', 'RECURSIVE'
    })

    # Dump statements dropped by split_sql_statements (matched on the
    # upper-cased statement text).
    _SKIPPED_STATEMENT_RE = re.compile(r'(?:SET\s|START\s+TRANSACTION\b|COMMIT\b)')

    def __init__(self):
        # Upper bound on len(sql) accepted by validate_query.
        self.max_query_length = 10000

    @staticmethod
    def _first_keyword(sql_upper: str, keywords: Iterable[str],
                       whole_word: bool = True) -> Optional[str]:
        """Return the earliest keyword from *keywords* found in *sql_upper*.

        With ``whole_word`` the keyword must be delimited by non-word
        characters (``UPDATE`` does not match ``UPDATED_AT``); otherwise any
        substring occurrence counts. Longer keywords are tried first at a
        given position, so ``EXECUTE`` is reported rather than ``EXEC``.
        Returns None when no keyword occurs.
        """
        ordered = sorted(keywords, key=lambda k: (-len(k), k))
        if not ordered:
            return None
        alternation = '|'.join(re.escape(k) for k in ordered)
        pattern = rf'\b(?:{alternation})\b' if whole_word else alternation
        match = re.search(pattern, sql_upper)
        return match.group(0) if match else None

    def sanitize_identifier(self, identifier: str) -> str:
        """Reduce a table/column identifier to word characters and dots.

        Raises:
            ValueError: if nothing is left after sanitizing, or the result
                contains ``..`` (path-traversal style input).
        """
        # Remove any non-alphanumeric characters except underscore and dot.
        # '/' is stripped here too, so no separate leading-slash check is needed.
        sanitized = re.sub(r'[^\w.]', '', identifier)

        if not sanitized or '..' in sanitized:
            raise ValueError(f"Invalid identifier: {identifier}")

        return sanitized

    def validate_query(self, sql: str, allow_write: bool = False, allow_schema_ops: bool = False) -> Tuple[bool, str]:
        """
        Validate SQL query for safety

        Args:
            sql: SQL query to validate
            allow_write: Allow INSERT, UPDATE, DELETE, CREATE TABLE (and any
                other CREATE statement)
            allow_schema_ops: Allow CREATE/DROP DATABASE/SCHEMA (for SQL
                imports). This only lifts the DATABASE/SCHEMA restriction;
                such statements still need allow_write as well.

        Returns:
            (is_valid, error_message) - error_message is "" when valid.
        """
        if not sql or not sql.strip():
            return False, "Empty query"

        if len(sql) > self.max_query_length:
            return False, f"Query too long (max {self.max_query_length} chars)"

        sql_upper = sql.upper()

        # DROP/CREATE DATABASE/SCHEMA are blocked unless allow_schema_ops.
        if not allow_schema_ops:
            if any(re.search(p, sql_upper) for p in self.DROP_DATABASE_PATTERNS):
                return False, "DROP DATABASE/SCHEMA not allowed"
            if any(re.search(p, sql_upper) for p in self.CREATE_DATABASE_PATTERNS):
                return False, "CREATE DATABASE/SCHEMA not allowed via SQL. Use the 'Create Database' button or API endpoint instead."

        # Every CREATE (TABLE, VIEW, INDEX, ...) is a write. Whole-word
        # matching keeps columns like created_at from tripping this.
        if not allow_write and self._first_keyword(sql_upper, self.CREATE_KEYWORDS):
            if re.search(self._CREATE_TABLE_RE, sql_upper):
                return False, "CREATE TABLE requires allow_write=True"
            return False, "CREATE requires allow_write=True"

        keyword = self._first_keyword(sql_upper, self.ALWAYS_DANGEROUS, whole_word=False)
        if keyword:
            return False, f"Dangerous keyword not allowed: {keyword}"

        keyword = self._first_keyword(sql_upper, self.DANGEROUS_COMMANDS)
        if keyword:
            return False, f"Dangerous keyword not allowed: {keyword}"

        if not allow_write:
            keyword = self._first_keyword(sql_upper, self.WRITE_KEYWORDS)
            if keyword:
                return False, f"Write operation not allowed: {keyword} (use allow_write=True)"

        for pattern in self._INJECTION_PATTERNS:
            if re.search(pattern, sql_upper, re.DOTALL):
                return False, "Potential SQL injection detected"

        return True, ""

    def extract_tables(self, sql: str) -> List[str]:
        """Extract table names referenced by *sql* (best-effort regex scan).

        Names are taken after FROM/JOIN, CREATE [OR REPLACE] [TEMP|TEMPORARY]
        TABLE [IF NOT EXISTS], INSERT INTO, UPDATE and DELETE FROM. They may
        be quoted with ``"``, ``'`` or backticks and dot-qualified
        (``db.table``). Qualified names are always accepted. An unqualified
        name is only accepted when followed by whitespace or end of input,
        plus ``;``/``,`` after FROM/JOIN/DELETE FROM and ``;``/``,``/``(``
        after CREATE TABLE/INSERT INTO. This way, for example,
        ``EXTRACT(YEAR FROM col)`` is not reported. Unqualified FROM/JOIN
        names that are SQL keywords are dropped.

        Returns:
            Sanitized, de-duplicated names in sorted order.
        """
        quote = r'[`"\']?'
        name = rf'({quote}\w+{quote}(?:\.{quote}\w+{quote})*)'
        rules = (
            # (keyword prefix, extra chars allowed right after an unqualified
            #  name besides whitespace/end, drop names that are SQL keywords?)
            (r'\b(?:FROM|JOIN)\s+', ';,', True),
            (r'\bCREATE\s+(?:OR\s+REPLACE\s+)?(?:TEMP(?:ORARY)?\s+)?TABLE\s+'
             r'(?:IF\s+NOT\s+EXISTS\s+)?', ';,(', False),
            (r'\bINSERT\s+INTO\s+', ';,(', False),
            # Whitespace/end only, so "ON UPDATE current_timestamp()" in a
            # column definition is not reported as a table.
            (r'\bUPDATE\s+', '', False),
            (r'\bDELETE\s+FROM\s+', ';,', False),
        )
        found = set()
        for prefix, terminators, drop_keywords in rules:
            for match in re.finditer(prefix + name, sql, re.IGNORECASE):
                table = re.sub(r'[`"\']', '', match.group(1))
                if '.' not in table:
                    following = sql[match.end(1):match.end(1) + 1]
                    if following and not following.isspace() and following not in terminators:
                        continue
                    if drop_keywords and table.upper() in self._NON_TABLE_WORDS:
                        continue
                found.add(table)
        return sorted({self.sanitize_identifier(t) for t in found})

    def build_safe_query(self, database: str, table: str, filters: Optional[Dict] = None,
                         limit: Optional[int] = None, offset: Optional[int] = None) -> str:
        """Build a parameterized ``SELECT *`` over ``"database"."table"``.

        Identifiers pass through sanitize_identifier (which raises ValueError
        for invalid names) and are double-quoted. Each key of ``filters``
        becomes ``"key" = ?`` in the dict's iteration order. The values are
        NOT embedded and must be bound separately by the caller.

        ``limit`` is added whenever it is not None (so ``limit=0`` yields
        ``LIMIT 0``). ``offset`` is omitted when falsy, since OFFSET 0 is a
        no-op.
        """
        db = self.sanitize_identifier(database)
        tbl = self.sanitize_identifier(table)

        query = f'SELECT * FROM "{db}"."{tbl}"'

        if filters:
            conditions = [f'"{self.sanitize_identifier(key)}" = ?' for key in filters]
            query += ' WHERE ' + ' AND '.join(conditions)

        if limit is not None:
            query += f' LIMIT {int(limit)}'
        if offset:
            query += f' OFFSET {int(offset)}'

        return query

    def is_read_only(self, sql: str) -> bool:
        """Return True if *sql* looks like a read-only query.

        The query must start with the word SELECT or WITH. It must also
        contain no whole-word write or CREATE keyword anywhere, which catches
        ``SELECT 1; DELETE FROM t`` and data-modifying CTEs. Because the whole
        text is scanned, a keyword inside a string literal also yields False.
        """
        sql_upper = sql.upper().strip()
        if not re.match(r'(?:SELECT|WITH)\b', sql_upper):
            return False
        return self._first_keyword(sql_upper, self.WRITE_KEYWORDS | self.CREATE_KEYWORDS) is None

    def split_sql_statements(self, sql: str) -> List[str]:
        """Split SQL dump into individual statements

        Handles phpMyAdmin-style SQL dumps with comments and multiple
        statements:

        - ``--`` starts a comment that runs to end of line; it is removed.
        - ``/* ... */`` blocks are kept verbatim, but ``;`` and quotes inside
          them are ignored.
        - A ``;`` only ends a statement outside ``'``, ``"`` and backtick
          quotes. Inside ``'``/``"`` a backslash escapes the next character
          (MySQL style); doubled quotes work naturally.
        - Statements starting with SET, START TRANSACTION or COMMIT, and
          empty statements, are skipped.

        The remaining statement text, including newlines inside string
        literals, is kept verbatim except for surrounding whitespace. Every
        returned statement ends with ``;``.
        """
        statements: List[str] = []
        buffer: List[str] = []

        def flush() -> None:
            # Emit the buffered statement, if any, then reset the buffer.
            stmt = ''.join(buffer).strip()
            buffer.clear()
            if not stmt or stmt == ';':
                return
            if self._SKIPPED_STATEMENT_RE.match(stmt.upper()):
                return
            if not stmt.endswith(';'):
                stmt += ';'
            statements.append(stmt)

        quote: Optional[str] = None  # active quote character, None outside literals
        i, n = 0, len(sql)
        while i < n:
            ch = sql[i]
            if quote is not None:
                if ch == '\\' and quote != '`' and i + 1 < n:
                    # Copy the backslash and the escaped char together, so
                    # '\'' stays open and '\\' followed by a quote closes.
                    buffer.append(sql[i:i + 2])
                    i += 2
                    continue
                buffer.append(ch)
                if ch == quote:
                    quote = None
                i += 1
            elif ch in ('\'', '"', '`'):
                quote = ch
                buffer.append(ch)
                i += 1
            elif sql.startswith('--', i):
                # Drop the comment but keep its terminating newline.
                newline = sql.find('\n', i)
                i = n if newline == -1 else newline
            elif sql.startswith('/*', i):
                close = sql.find('*/', i + 2)
                end = n if close == -1 else close + 2
                buffer.append(sql[i:end])
                i = end
            else:
                buffer.append(ch)
                i += 1
                if ch == ';':
                    flush()

        # Trailing statement without a terminating semicolon.
        flush()
        return statements

# Module-wide shared parser. QueryParser's only instance attribute is
# max_query_length (set in __init__), so a single instance can be reused.
query_parser: QueryParser = QueryParser()