Spaces:
Running
Running
| from __future__ import annotations | |
| import re | |
| from enum import Enum | |
| from typing import Any, Optional | |
| import sqlglot | |
| from sqlglot import exp | |
| from sqlglot.errors import ParseError | |
class QueryType(str, Enum):
    """Statement categories reported by the validator.

    The ``str`` mix-in makes every member compare equal to its upper-case
    keyword (``QueryType.SELECT == "SELECT"``). Definition order matters:
    ``_detect_query_type`` walks the members in this order for its
    class-name fallback, with ``UNKNOWN`` acting as the catch-all.
    """

    # Row-level reads and writes
    SELECT = "SELECT"
    INSERT = "INSERT"
    UPDATE = "UPDATE"
    DELETE = "DELETE"

    # Schema / object changes
    CREATE = "CREATE"
    ALTER = "ALTER"
    DROP = "DROP"
    TRUNCATE = "TRUNCATE"

    # Everything else
    MERGE = "MERGE"
    WITH = "WITH"
    CALL = "CALL"
    EXPLAIN = "EXPLAIN"
    UNKNOWN = "UNKNOWN"
| _DIALECT_ALIASES: dict[str, Optional[str]] = { | |
| "mysql": "mysql", | |
| "mariadb": "mysql", | |
| "postgres": "postgres", | |
| "postgresql": "postgres", | |
| "pg": "postgres", | |
| "redshift": "redshift", | |
| "cockroachdb": "cockroachdb", | |
| "cockroach": "cockroachdb", | |
| "crdb": "cockroachdb", | |
| "sqlite": "sqlite", | |
| "sqlserver": "tsql", | |
| "mssql": "tsql", | |
| "tsql": "tsql", | |
| "oracle": "oracle", | |
| "oracledb": "oracle", | |
| "bigquery": "bigquery", | |
| "gcp": "bigquery", | |
| "bq": "bigquery", | |
| "snowflake": "snowflake", | |
| "sf": "snowflake", | |
| "spark": "spark", | |
| "hive": "hive", | |
| "databricks": "databricks", | |
| "duckdb": "duckdb", | |
| "presto": "presto", | |
| "trino": "trino", | |
| "clickhouse": "clickhouse", | |
| "ansi": None, | |
| "standard": None, | |
| "sql": None, | |
| } | |
# Query types that are candidates for being read-only. WITH is only a
# candidate: ``_is_read_only`` additionally scans such statements for write
# nodes before confirming.
_READ_ONLY_TYPES: frozenset[QueryType] = frozenset(
    (QueryType.SELECT, QueryType.WITH, QueryType.EXPLAIN)
)
# Ordered (AST node class, query type) pairs. ``_detect_query_type`` returns
# the type of the first pair whose class matches the statement, so order is
# significant and must be preserved.
_AST_TYPE_MAP: list[tuple[type[exp.Expression], QueryType]] = [
    # Plain selects and set operations all count as SELECT.
    *(
        (query_node, QueryType.SELECT)
        for query_node in (exp.Select, exp.Union, exp.Intersect, exp.Except)
    ),
    (exp.Insert, QueryType.INSERT),
    (exp.Update, QueryType.UPDATE),
    (exp.Delete, QueryType.DELETE),
    (exp.Create, QueryType.CREATE),
    *((alter_node, QueryType.ALTER) for alter_node in (exp.Alter, exp.AlterColumn)),
    (exp.Drop, QueryType.DROP),
    (exp.TruncateTable, QueryType.TRUNCATE),
    (exp.Merge, QueryType.MERGE),
    (exp.With, QueryType.WITH),
    # Command nodes are deliberately classified as UNKNOWN.
    (exp.Command, QueryType.UNKNOWN),
]
# AST node classes whose presence anywhere in a statement marks it as a write.
# Passed as a single tuple to ``find_all`` by ``_is_read_only``.
_WRITE_NODE_TYPES: tuple[type[exp.Expression], ...] = (
    (exp.Insert, exp.Update, exp.Delete)  # row-level changes
    + (exp.Create, exp.Alter, exp.AlterColumn, exp.Drop, exp.TruncateTable)  # schema changes
    + (exp.Merge,)
)
| _PLACEHOLDER_PATTERNS: list[tuple[str, re.Pattern[str]]] = [ | |
| ("positional_?", re.compile(r"\?")), | |
| ("pyformat_%s", re.compile(r"%s")), | |
| ("numeric_$n", re.compile(r"\$\d+")), | |
| ("named_:param", re.compile(r"(?<![:\w]):(\w+)")), | |
| ] | |
| def _normalize_dialect(dialect: Optional[str]) -> Optional[str]: | |
| if dialect is None: | |
| return None | |
| key = dialect.strip().lower() | |
| return _DIALECT_ALIASES.get(key, key) | |
| def _has_ctes(statement: exp.Expression) -> bool: | |
| return statement.args.get("with") is not None | |
def _detect_query_type(statement: exp.Expression) -> QueryType:
    """Classify a parsed statement as a ``QueryType``.

    Resolution order:
      1. a statement for which ``_has_ctes`` is true is always ``WITH``;
      2. the first ``_AST_TYPE_MAP`` entry whose class matches the statement;
      3. the first ``QueryType`` (other than ``UNKNOWN``) whose value occurs
         in the upper-cased class name of the statement;
      4. ``QueryType.UNKNOWN``.
    """
    if _has_ctes(statement):
        return QueryType.WITH

    mapped = next(
        (
            query_type
            for node_class, query_type in _AST_TYPE_MAP
            if isinstance(statement, node_class)
        ),
        None,
    )
    if mapped is not None:
        return mapped

    # Fallback for node classes missing from the table: match on class name.
    class_name = type(statement).__name__.upper()
    return next(
        (
            candidate
            for candidate in QueryType
            if candidate != QueryType.UNKNOWN and candidate.value in class_name
        ),
        QueryType.UNKNOWN,
    )
def _is_read_only(statement: exp.Expression, query_type: QueryType) -> bool:
    """Decide whether *statement* can be treated as read-only.

    A statement is read-only only when both conditions hold:
      * its detected type is one of ``_READ_ONLY_TYPES``, and
      * no node of a write type (``_WRITE_NODE_TYPES``) occurs anywhere in
        its tree.

    The tree scan is applied to every read-only candidate, not just to
    ``WITH`` statements, so a write hidden below the top level (for example
    a data-modifying CTE inside a subquery of a plain SELECT) is still
    reported as a write.

    Args:
        statement: Parsed statement to inspect.
        query_type: Type previously detected for *statement*.

    Returns:
        ``True`` if the statement is considered read-only, else ``False``.
    """
    if query_type not in _READ_ONLY_TYPES:
        return False
    # Only existence matters, so stop at the first write node found.
    first_write = next(iter(statement.find_all(_WRITE_NODE_TYPES)), None)
    return first_write is None
def _extract_tables(statement: exp.Expression) -> list[str]:
    """List the distinct table names referenced in *statement*.

    Each name is built from the non-empty catalog, db and table-name parts of
    an ``exp.Table`` node, joined with dots. Names are returned in order of
    first appearance; nodes yielding an empty name are skipped.
    """
    qualified_names = (
        ".".join(filter(None, (table.catalog, table.db, table.name)))
        for table in statement.find_all(exp.Table)
    )
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(name for name in qualified_names if name))
def _extract_columns(statement: exp.Expression) -> list[str]:
    """List the distinct column names referenced in *statement*.

    Only the ``name`` of each ``exp.Column`` node is used, so the same column
    name under different qualifiers is reported once. Order follows first
    appearance; empty names are skipped.
    """
    column_names = (column.name for column in statement.find_all(exp.Column))
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(name for name in column_names if name))
# Regions of a query in which placeholder-like characters are plain text:
# quoted strings/identifiers and SQL comments. Written as one alternation so
# that a single left-to-right pass decides which construct each quote or
# comment marker belongs to (e.g. an apostrophe inside a comment, or a `--`
# inside a string, is not misread).
_NON_CODE_SPAN = re.compile(
    r"'(?:[^'\\]|\\.)*'"       # single-quoted text
    r'|"(?:[^"\\]|\\.)*"'      # double-quoted text
    r"|--[^\n]*"               # line comment up to end of line
    r"|/\*[\s\S]*?\*/"         # block comment, possibly multi-line
)


def _detect_placeholders(query: str) -> list[str]:
    """Warn about prepared-statement placeholders found in *query*.

    Quoted text and SQL comments (``-- ...`` and ``/* ... */``) are blanked
    out first, so characters such as ``?`` or ``:name`` that appear only
    inside them do not count as placeholders.

    Args:
        query: Raw SQL text.

    Returns:
        A list with one warning naming every placeholder style detected (the
        labels of ``_PLACEHOLDER_PATTERNS``, sorted), or an empty list when
        none was found.
    """
    # Replace with a space rather than nothing so that the tokens on either
    # side of a removed span are not fused together.
    code_only = _NON_CODE_SPAN.sub(" ", query)
    detected = {
        name for name, pattern in _PLACEHOLDER_PATTERNS if pattern.search(code_only)
    }
    if not detected:
        return []
    return [
        f"Prepared-statement placeholders detected: "
        f"{', '.join(sorted(detected))}. Ensure the target "
        f"database driver supports these placeholder styles."
    ]
def _check_best_practices(statement: exp.Expression) -> list[str]:
    """Collect style and safety warnings for one parsed statement.

    Checks performed (each yields at most one warning):
      * a star projection in any SELECT of the statement, whether bare
        (``*``) or qualified (``t.*``);
      * a top-level UPDATE without a WHERE clause;
      * a top-level DELETE without a WHERE clause;
      * any NATURAL JOIN.

    Args:
        statement: Parsed statement to inspect.

    Returns:
        The warnings, in the order listed above; empty if nothing was found.
    """
    warnings: list[str] = []

    # ``is_star`` covers both a bare star and a table-qualified star, which a
    # plain isinstance check against exp.Star would miss.
    if any(
        projection.is_star
        for select in statement.find_all(exp.Select)
        for projection in (select.expressions or [])
    ):
        warnings.append(
            "SELECT * detected — explicitly listing columns is "
            "recommended for performance and maintainability."
        )

    if isinstance(statement, exp.Update) and not statement.args.get("where"):
        warnings.append(
            "UPDATE without a WHERE clause will affect every row in the table."
        )

    if isinstance(statement, exp.Delete) and not statement.args.get("where"):
        warnings.append(
            "DELETE without a WHERE clause will remove every row from the table."
        )

    for join in statement.find_all(exp.Join):
        # NATURAL is looked for under both argument names.
        kind = (join.args.get("kind") or "").upper()
        method = (join.args.get("method") or "").upper()
        if "NATURAL" in (kind, method):
            warnings.append(
                "NATURAL JOIN can produce unexpected column matches — "
                "prefer explicit JOIN conditions."
            )
            break

    return warnings
def _analyze_statements(statements: list[exp.Expression]) -> dict[str, Any]:
    """Combine the per-statement analysis of *statements* into one summary.

    Returns:
        A dict with the keys
        ``query_type`` (value of the first statement's type, or ``UNKNOWN``
        for an empty list), ``is_read_only`` (true only if every statement is
        read-only), ``tables`` and ``columns`` (distinct names in order of
        first appearance across all statements) and ``warnings``.
    """
    # Dicts serve as insertion-ordered sets for de-duplication.
    table_index: dict[str, None] = {}
    column_index: dict[str, None] = {}
    collected: list[str] = []
    leading_type = QueryType.UNKNOWN
    every_read_only = True

    for position, statement in enumerate(statements):
        kind = _detect_query_type(statement)
        if position == 0:
            leading_type = kind

        every_read_only = _is_read_only(statement, kind) and every_read_only

        table_index.update(dict.fromkeys(_extract_tables(statement)))
        column_index.update(dict.fromkeys(_extract_columns(statement)))
        collected += _check_best_practices(statement)

        if isinstance(statement, exp.Command):
            verb = getattr(statement, "name", "")
            collected.append(
                f"Statement uses a {verb!r} command that could not be "
                f"fully analysed. Syntax validation may be incomplete."
            )

    if len(statements) > 1:
        collected.append(
            f"Multiple statements detected ({len(statements)} total). "
            f"Results reflect the combined analysis of all statements."
        )

    return {
        "query_type": leading_type.value,
        "is_read_only": every_read_only,
        "tables": list(table_index),
        "columns": list(column_index),
        "warnings": collected,
    }
class SqlValidatorService:
    """Validate SQL text with sqlglot and summarise what it references."""

    @staticmethod
    def _rejection(
        dialect: Optional[str], errors: list[str], warnings: list[str]
    ) -> dict[str, Any]:
        """Build the result returned when no statement could be analysed."""
        return {
            "valid": False,
            "query_type": QueryType.UNKNOWN.value,
            "dialect": dialect,
            "errors": errors,
            "warnings": warnings,
            "is_read_only": False,
            "tables": [],
            "columns": [],
        }

    @staticmethod
    def _parse_all(
        sql: str, dialect: Optional[str], error_level: Any
    ) -> list[exp.Expression]:
        """Parse *sql* and return the statements, dropping ``None`` entries."""
        parsed = sqlglot.parse(sql, dialect=dialect, error_level=error_level)
        return [node for node in parsed if node is not None]

    def validate(self, query: str, dialect: Optional[str] = None) -> dict[str, Any]:
        """Validate *query* and describe it.

        Args:
            query: SQL text; may contain several statements.
            dialect: Optional dialect name, resolved via ``_normalize_dialect``.

        Returns:
            A dict with the keys ``valid``, ``query_type``, ``dialect``,
            ``errors``, ``warnings``, ``is_read_only``, ``tables`` and
            ``columns``.

        Raises:
            TypeError: If *query* is not a string, or *dialect* is neither a
                string nor ``None``.
        """
        if not isinstance(query, str):
            raise TypeError(f"query must be a string, got {type(query).__name__}")
        if not (dialect is None or isinstance(dialect, str)):
            raise TypeError(f"dialect must be a string or None, got {type(dialect).__name__}")

        resolved_dialect = _normalize_dialect(dialect)
        sql = query.strip()
        if not sql:
            return self._rejection(
                resolved_dialect, ["Empty query string provided."], []
            )

        notes: list[str] = _detect_placeholders(sql)
        problems: list[str] = []
        statements: list[exp.Expression] = []

        # First attempt: strict parsing, any syntax problem raises.
        try:
            statements = self._parse_all(sql, resolved_dialect, sqlglot.ErrorLevel.RAISE)
        except ParseError as exc:
            problems.append(f"SQL syntax error: {exc}")
        except RecursionError:
            problems.append("Query is too deeply nested to parse — consider simplifying.")
        except Exception as exc:
            problems.append(f"Unexpected error during parsing: {exc}")

        if problems:
            # Second attempt: lenient parsing, so whatever can still be
            # parsed is analysed even though the query is already invalid.
            try:
                statements = self._parse_all(sql, resolved_dialect, sqlglot.ErrorLevel.WARN)
            except Exception:
                statements = []

        if not statements:
            if not problems:
                problems.append("No valid SQL statements found. The query may be empty or contain only comments.")
            return self._rejection(resolved_dialect, problems, notes)

        summary = _analyze_statements(statements)
        notes.extend(summary["warnings"])

        is_valid = not problems
        if is_valid and summary["query_type"] == QueryType.UNKNOWN.value:
            # Parsing succeeded, yet the first statement is unclassified:
            # reject only if no statement at all has a recognised type.
            nothing_recognised = all(
                _detect_query_type(node) == QueryType.UNKNOWN for node in statements
            )
            if nothing_recognised:
                is_valid = False
                problems.append(
                    "Query could not be recognised as a valid SQL statement."
                )

        return {
            "valid": is_valid,
            "query_type": summary["query_type"],
            "dialect": resolved_dialect,
            "errors": problems,
            "warnings": notes,
            "is_read_only": summary["is_read_only"],
            "tables": summary["tables"],
            "columns": summary["columns"],
        }