"""SQL validation helpers built on top of ``sqlglot``.

The public entry point is :class:`SqlValidatorService`, whose ``validate``
method parses a query, classifies it, extracts the tables and columns it
references, and reports syntax errors and best-practice warnings.
"""

from __future__ import annotations

import re
from enum import Enum
from typing import Any, Optional

import sqlglot
from sqlglot import exp
from sqlglot.errors import ParseError


class QueryType(str, Enum):
    """Statement categories reported in the ``query_type`` result field."""

    SELECT = "SELECT"
    INSERT = "INSERT"
    UPDATE = "UPDATE"
    DELETE = "DELETE"
    CREATE = "CREATE"
    ALTER = "ALTER"
    DROP = "DROP"
    TRUNCATE = "TRUNCATE"
    MERGE = "MERGE"
    WITH = "WITH"
    CALL = "CALL"
    EXPLAIN = "EXPLAIN"
    UNKNOWN = "UNKNOWN"


# Accepted dialect spellings -> the dialect name handed to sqlglot.
# ``None`` means "pass no dialect", i.e. let sqlglot use its default.
_DIALECT_ALIASES: dict[str, Optional[str]] = {
    "mysql": "mysql",
    "mariadb": "mysql",
    "postgres": "postgres",
    "postgresql": "postgres",
    "pg": "postgres",
    "redshift": "redshift",
    "cockroachdb": "cockroachdb",
    "cockroach": "cockroachdb",
    "crdb": "cockroachdb",
    "sqlite": "sqlite",
    "sqlserver": "tsql",
    "mssql": "tsql",
    "tsql": "tsql",
    "oracle": "oracle",
    "oracledb": "oracle",
    "bigquery": "bigquery",
    "gcp": "bigquery",
    "bq": "bigquery",
    "snowflake": "snowflake",
    "sf": "snowflake",
    "spark": "spark",
    "hive": "hive",
    "databricks": "databricks",
    "duckdb": "duckdb",
    "presto": "presto",
    "trino": "trino",
    "clickhouse": "clickhouse",
    "ansi": None,
    "standard": None,
    "sql": None,
}

# Query types that may be reported as read-only (WITH needs a further check,
# see ``_is_read_only``).
_READ_ONLY_TYPES: frozenset[QueryType] = frozenset(
    {QueryType.SELECT, QueryType.WITH, QueryType.EXPLAIN}
)

# Ordered AST-class -> QueryType lookup; the first ``isinstance`` match wins.
_AST_TYPE_MAP: list[tuple[type[exp.Expression], QueryType]] = [
    (exp.Select, QueryType.SELECT),
    (exp.Union, QueryType.SELECT),
    (exp.Intersect, QueryType.SELECT),
    (exp.Except, QueryType.SELECT),
    (exp.Insert, QueryType.INSERT),
    (exp.Update, QueryType.UPDATE),
    (exp.Delete, QueryType.DELETE),
    (exp.Create, QueryType.CREATE),
    (exp.Alter, QueryType.ALTER),
    (exp.AlterColumn, QueryType.ALTER),
    (exp.Drop, QueryType.DROP),
    (exp.TruncateTable, QueryType.TRUNCATE),
    (exp.Merge, QueryType.MERGE),
    (exp.With, QueryType.WITH),
    (exp.Command, QueryType.UNKNOWN),
]

# AST node classes whose presence anywhere in a statement makes it a write.
_WRITE_NODE_TYPES: tuple[type[exp.Expression], ...] = (
    exp.Insert,
    exp.Update,
    exp.Delete,
    exp.Create,
    exp.Alter,
    exp.AlterColumn,
    exp.Drop,
    exp.TruncateTable,
    exp.Merge,
)

# (label, pattern) pairs used to spot prepared-statement placeholders.
_PLACEHOLDER_PATTERNS: list[tuple[str, re.Pattern[str]]] = [
    ("positional_?", re.compile(r"\?")),
    ("pyformat_%s", re.compile(r"%s")),
    ("numeric_$n", re.compile(r"\$\d+")),
    # NOTE(review): the original text of this pattern (and anything that
    # followed it in this list) was truncated after "(?" in the source this
    # file was restored from. The look-behind keeps "::type" casts and
    # "a:b" forms from being reported -- confirm against version control.
    ("named_:param", re.compile(r"(?<![:\w]):[A-Za-z_]\w*")),
]

# Spans that must not be scanned for placeholders: quoted literals and SQL
# comments. A single alternation is used so that whichever construct starts
# first wins (a quote inside a comment, or "--" inside a string).
_IGNORED_SPANS: re.Pattern[str] = re.compile(
    r"'(?:[^'\\]|\\.)*'"  # single-quoted string literal
    r'|"(?:[^"\\]|\\.)*"'  # double-quoted identifier or string
    r"|--[^\n]*"  # line comment
    r"|/\*.*?\*/",  # block comment
    re.DOTALL,
)


def _normalize_dialect(dialect: Optional[str]) -> Optional[str]:
    """Map a user-supplied dialect name to the name passed to sqlglot.

    Returns ``None`` when *dialect* is ``None`` or is an alias mapped to
    ``None``. Names missing from ``_DIALECT_ALIASES`` are returned stripped
    and lower-cased, unchanged otherwise.
    """
    if dialect is None:
        return None
    key = dialect.strip().lower()
    return _DIALECT_ALIASES.get(key, key)


def _has_ctes(statement: exp.Expression) -> bool:
    """Return True when *statement* has a ``with`` argument attached."""
    return statement.args.get("with") is not None


def _detect_query_type(statement: exp.Expression) -> QueryType:
    """Classify a parsed statement.

    Order of checks: a statement carrying CTEs is always ``WITH``; otherwise
    the first matching entry of ``_AST_TYPE_MAP`` wins; otherwise the
    upper-cased AST class name is searched for any ``QueryType`` value.
    """
    if _has_ctes(statement):
        return QueryType.WITH

    for ast_type, query_type in _AST_TYPE_MAP:
        if isinstance(statement, ast_type):
            return query_type

    # Fallback for node classes that are not in the explicit map.
    type_name = type(statement).__name__.upper()
    for candidate in QueryType:
        if candidate != QueryType.UNKNOWN and candidate.value in type_name:
            return candidate
    return QueryType.UNKNOWN


def _is_read_only(statement: exp.Expression, query_type: QueryType) -> bool:
    """Return True when *statement* is classified as not modifying data.

    Types outside ``_READ_ONLY_TYPES`` are never read-only. ``WITH``
    statements are read-only only if no node of a write type occurs anywhere
    in the tree (the main statement after the CTEs may be a write).
    """
    if query_type not in _READ_ONLY_TYPES:
        return False
    if query_type != QueryType.WITH:
        return True
    return statement.find(*_WRITE_NODE_TYPES) is None


def _extract_tables(statement: exp.Expression) -> list[str]:
    """Return the distinct table names in *statement*, in discovery order.

    Each name is ``catalog.db.name`` with empty parts omitted.
    """
    tables: list[str] = []
    seen: set[str] = set()
    for table in statement.find_all(exp.Table):
        full_name = ".".join(
            part for part in (table.catalog, table.db, table.name) if part
        )
        if full_name and full_name not in seen:
            seen.add(full_name)
            tables.append(full_name)
    return tables


def _extract_columns(statement: exp.Expression) -> list[str]:
    """Return the distinct unqualified column names, in discovery order."""
    columns: list[str] = []
    seen: set[str] = set()
    for column in statement.find_all(exp.Column):
        name = column.name
        if name and name not in seen:
            seen.add(name)
            columns.append(name)
    return columns


def _detect_placeholders(query: str) -> list[str]:
    """Return a one-element warning list if *query* contains placeholders.

    Quoted literals and SQL comments are blanked out first so that text such
    as ``'50%s'`` or ``-- really?`` is not mistaken for a placeholder.
    Returns an empty list when nothing is detected.
    """
    scannable = _IGNORED_SPANS.sub(" ", query)
    detected = sorted(
        {
            label
            for label, pattern in _PLACEHOLDER_PATTERNS
            if pattern.search(scannable)
        }
    )
    if not detected:
        return []
    return [
        f"Prepared-statement placeholders detected: "
        f"{', '.join(detected)}. Ensure the target "
        f"database driver supports these placeholder styles."
    ]


def _is_star_projection(expression: exp.Expression) -> bool:
    """Return True for a bare ``*`` or a qualified ``t.*`` projection."""
    if isinstance(expression, exp.Star):
        return True
    return isinstance(expression, exp.Column) and isinstance(
        expression.this, exp.Star
    )


def _check_best_practices(statement: exp.Expression) -> list[str]:
    """Return best-practice warnings for a single parsed statement.

    Checks performed: star projections in any SELECT, UPDATE/DELETE without
    a WHERE clause (top-level statement only), and NATURAL joins. Each kind
    of warning is emitted at most once per statement.
    """
    warnings: list[str] = []

    uses_star = any(
        _is_star_projection(projection)
        for select in statement.find_all(exp.Select)
        for projection in (select.expressions or [])
    )
    if uses_star:
        warnings.append(
            "SELECT * detected — explicitly listing columns is "
            "recommended for performance and maintainability."
        )

    has_where = bool(statement.args.get("where"))
    if isinstance(statement, exp.Update) and not has_where:
        warnings.append(
            "UPDATE without a WHERE clause will affect every row in the table."
        )
    if isinstance(statement, exp.Delete) and not has_where:
        warnings.append(
            "DELETE without a WHERE clause will remove every row from the table."
        )

    for join in statement.find_all(exp.Join):
        kind = (join.args.get("kind") or "").upper()
        method = (join.args.get("method") or "").upper()
        if kind == "NATURAL" or method == "NATURAL":
            warnings.append(
                "NATURAL JOIN can produce unexpected column matches — "
                "prefer explicit JOIN conditions."
            )
            break

    return warnings


def _analyze_statements(statements: list[exp.Expression]) -> dict[str, Any]:
    """Aggregate type, read-only flag, tables, columns and warnings.

    ``query_type`` is the type of the first statement. ``is_read_only`` is
    True only if every statement is read-only. Tables and columns are
    de-duplicated across statements, preserving first-seen order.
    """
    all_tables: list[str] = []
    tables_seen: set[str] = set()
    all_columns: list[str] = []
    columns_seen: set[str] = set()
    all_warnings: list[str] = []
    primary_type: QueryType = QueryType.UNKNOWN
    is_read_only = True

    for idx, stmt in enumerate(statements):
        stmt_type = _detect_query_type(stmt)
        if idx == 0:
            primary_type = stmt_type
        if not _is_read_only(stmt, stmt_type):
            is_read_only = False

        for table in _extract_tables(stmt):
            if table not in tables_seen:
                tables_seen.add(table)
                all_tables.append(table)
        for column in _extract_columns(stmt):
            if column not in columns_seen:
                columns_seen.add(column)
                all_columns.append(column)

        all_warnings.extend(_check_best_practices(stmt))

        if isinstance(stmt, exp.Command):
            cmd_verb = getattr(stmt, "name", "")
            all_warnings.append(
                f"Statement uses a {cmd_verb!r} command that could not be "
                f"fully analysed. Syntax validation may be incomplete."
            )

    if len(statements) > 1:
        all_warnings.append(
            f"Multiple statements detected ({len(statements)} total). "
            f"Results reflect the combined analysis of all statements."
        )

    return {
        "query_type": primary_type.value,
        "is_read_only": is_read_only,
        "tables": all_tables,
        "columns": all_columns,
        "warnings": all_warnings,
    }


def _parse_statements(
    query: str, dialect: Optional[str]
) -> tuple[list[exp.Expression], list[str]]:
    """Parse *query* and return ``(statements, errors)``.

    A strict parse is attempted first. If it fails, the failure is recorded
    in ``errors`` and a lenient parse is attempted so that partial analysis
    is still possible; the statements from that second pass may be empty.
    """
    errors: list[str] = []
    try:
        parsed = sqlglot.parse(
            query,
            dialect=dialect,
            error_level=sqlglot.ErrorLevel.RAISE,
        )
    except ParseError as exc:
        errors.append(f"SQL syntax error: {exc}")
    except RecursionError:
        errors.append("Query is too deeply nested to parse — consider simplifying.")
    except Exception as exc:  # e.g. tokenizer or dialect-lookup failures
        errors.append(f"Unexpected error during parsing: {exc}")
    else:
        return [stmt for stmt in parsed if stmt is not None], errors

    try:
        parsed = sqlglot.parse(
            query,
            dialect=dialect,
            error_level=sqlglot.ErrorLevel.WARN,
        )
    except Exception:
        # Deliberate best effort: the strict-parse error is already recorded,
        # so a failure here only means there is nothing left to analyse.
        return [], errors
    return [stmt for stmt in parsed if stmt is not None], errors


def _invalid_result(
    dialect: Optional[str], errors: list[str], warnings: list[str]
) -> dict[str, Any]:
    """Build the result returned when no statement could be analysed."""
    return {
        "valid": False,
        "query_type": QueryType.UNKNOWN.value,
        "dialect": dialect,
        "errors": errors,
        "warnings": warnings,
        "is_read_only": False,
        "tables": [],
        "columns": [],
    }


class SqlValidatorService:
    """Validate SQL text and summarise what it does."""

    def validate(self, query: str, dialect: Optional[str] = None) -> dict[str, Any]:
        """Validate *query* and return a result dictionary.

        Args:
            query: The SQL text; may contain several statements.
            dialect: Optional dialect name or alias (see ``_DIALECT_ALIASES``).

        Returns:
            A dict with the keys ``valid``, ``query_type``, ``dialect``,
            ``errors``, ``warnings``, ``is_read_only``, ``tables`` and
            ``columns``.

        Raises:
            TypeError: If *query* is not a string, or *dialect* is neither a
                string nor ``None``.
        """
        if not isinstance(query, str):
            raise TypeError(f"query must be a string, got {type(query).__name__}")
        if dialect is not None and not isinstance(dialect, str):
            raise TypeError(
                f"dialect must be a string or None, got {type(dialect).__name__}"
            )

        normalized_dialect = _normalize_dialect(dialect)
        stripped_query = query.strip()
        if not stripped_query:
            return _invalid_result(
                normalized_dialect, ["Empty query string provided."], []
            )

        warnings: list[str] = _detect_placeholders(stripped_query)
        statements, errors = _parse_statements(stripped_query, normalized_dialect)

        if not statements:
            if not errors:
                errors.append(
                    "No valid SQL statements found. "
                    "The query may be empty or contain only comments."
                )
            return _invalid_result(normalized_dialect, errors, warnings)

        analysis = _analyze_statements(statements)
        warnings.extend(analysis["warnings"])

        valid = not errors
        if valid and analysis["query_type"] == QueryType.UNKNOWN.value:
            # The first statement is unrecognised; accept the query only if
            # at least one other statement has a known type.
            has_known_statement = any(
                _detect_query_type(stmt) != QueryType.UNKNOWN for stmt in statements
            )
            if not has_known_statement:
                valid = False
                errors.append(
                    "Query could not be recognised as a valid SQL statement."
                )

        return {
            "valid": valid,
            "query_type": analysis["query_type"],
            "dialect": normalized_dialect,
            "errors": errors,
            "warnings": warnings,
            "is_read_only": analysis["is_read_only"],
            "tables": analysis["tables"],
            "columns": analysis["columns"],
        }