llm-ready-data / app /services /sql_validator_service.py
light-infer-chat's picture
ok
a6a4880
Raw
History Blame Contribute Delete
11.2 kB
from __future__ import annotations
import re
from enum import Enum
from typing import Any, Optional
import sqlglot
from sqlglot import exp
from sqlglot.errors import ParseError
class QueryType(str, Enum):
    """Coarse category reported for a SQL statement.

    Members also subclass ``str``, so each one compares equal to its plain
    upper-case value (``QueryType.SELECT == "SELECT"``).

    The declaration order is significant: ``_detect_query_type`` iterates the
    enum when it falls back to matching on the AST class name.
    """

    # Data manipulation.
    SELECT = "SELECT"
    INSERT = "INSERT"
    UPDATE = "UPDATE"
    DELETE = "DELETE"

    # Schema changes.
    CREATE = "CREATE"
    ALTER = "ALTER"
    DROP = "DROP"
    TRUNCATE = "TRUNCATE"

    # Everything else.
    MERGE = "MERGE"
    WITH = "WITH"  # assigned to statements that carry a CTE clause
    CALL = "CALL"
    EXPLAIN = "EXPLAIN"

    # Fallback when no other member applies.
    UNKNOWN = "UNKNOWN"
_DIALECT_ALIASES: dict[str, Optional[str]] = {
"mysql": "mysql",
"mariadb": "mysql",
"postgres": "postgres",
"postgresql": "postgres",
"pg": "postgres",
"redshift": "redshift",
"cockroachdb": "cockroachdb",
"cockroach": "cockroachdb",
"crdb": "cockroachdb",
"sqlite": "sqlite",
"sqlserver": "tsql",
"mssql": "tsql",
"tsql": "tsql",
"oracle": "oracle",
"oracledb": "oracle",
"bigquery": "bigquery",
"gcp": "bigquery",
"bq": "bigquery",
"snowflake": "snowflake",
"sf": "snowflake",
"spark": "spark",
"hive": "hive",
"databricks": "databricks",
"duckdb": "duckdb",
"presto": "presto",
"trino": "trino",
"clickhouse": "clickhouse",
"ansi": None,
"standard": None,
"sql": None,
}
# Categories that are candidates for being read-only.  WITH is only
# provisional: ``_is_read_only`` still scans such statements for write nodes.
_READ_ONLY_TYPES: frozenset[QueryType] = frozenset(
    (
        QueryType.SELECT,
        QueryType.WITH,
        QueryType.EXPLAIN,
    )
)
# Ordered lookup from sqlglot node class to the category reported for it.
# ``_detect_query_type`` walks this list with isinstance() and returns on the
# first hit, so the order of the entries is part of the behaviour.
_AST_TYPE_MAP: list[tuple[type[exp.Expression], QueryType]] = [
    # Plain queries and the set operations built from them.
    *(
        (node_cls, QueryType.SELECT)
        for node_cls in (exp.Select, exp.Union, exp.Intersect, exp.Except)
    ),
    (exp.Insert, QueryType.INSERT),
    (exp.Update, QueryType.UPDATE),
    (exp.Delete, QueryType.DELETE),
    (exp.Create, QueryType.CREATE),
    *((node_cls, QueryType.ALTER) for node_cls in (exp.Alter, exp.AlterColumn)),
    (exp.Drop, QueryType.DROP),
    (exp.TruncateTable, QueryType.TRUNCATE),
    (exp.Merge, QueryType.MERGE),
    (exp.With, QueryType.WITH),
    # Command nodes are reported as UNKNOWN.
    (exp.Command, QueryType.UNKNOWN),
]
# sqlglot node classes whose presence in a tree means the statement writes
# data or changes schema.  Used by ``_is_read_only`` when scanning a statement.
_WRITE_NODE_TYPES: tuple[type[exp.Expression], ...] = (
    # Row-level writes
    exp.Insert,
    exp.Update,
    exp.Delete,
    # Schema-level changes
    exp.Create,
    exp.Alter,
    exp.AlterColumn,
    exp.Drop,
    exp.TruncateTable,
    # Combined insert/update/delete
    exp.Merge,
)
_PLACEHOLDER_PATTERNS: list[tuple[str, re.Pattern[str]]] = [
("positional_?", re.compile(r"\?")),
("pyformat_%s", re.compile(r"%s")),
("numeric_$n", re.compile(r"\$\d+")),
("named_:param", re.compile(r"(?<![:\w]):(\w+)")),
]
def _normalize_dialect(dialect: Optional[str]) -> Optional[str]:
if dialect is None:
return None
key = dialect.strip().lower()
return _DIALECT_ALIASES.get(key, key)
def _has_ctes(statement: exp.Expression) -> bool:
return statement.args.get("with") is not None
def _detect_query_type(statement: exp.Expression) -> QueryType:
    """Classify a parsed statement.

    Resolution order:
      1. a statement carrying a CTE clause is ``QueryType.WITH``;
      2. the first ``_AST_TYPE_MAP`` entry the statement is an instance of;
      3. the first non-UNKNOWN member whose value occurs in the upper-cased
         class name of the node;
      4. ``QueryType.UNKNOWN``.
    """
    if _has_ctes(statement):
        return QueryType.WITH

    mapped = next(
        (kind for node_cls, kind in _AST_TYPE_MAP if isinstance(statement, node_cls)),
        None,
    )
    if mapped is not None:
        return mapped

    # Last resort: look for a category keyword inside the node's class name.
    class_name = type(statement).__name__.upper()
    return next(
        (
            kind
            for kind in QueryType
            if kind != QueryType.UNKNOWN and kind.value in class_name
        ),
        QueryType.UNKNOWN,
    )
def _is_read_only(statement: exp.Expression, query_type: QueryType) -> bool:
    """Return True when *statement* neither writes data nor changes schema.

    A statement is read-only when its category is one of ``_READ_ONLY_TYPES``
    and its tree contains none of the ``_WRITE_NODE_TYPES`` nodes.

    The tree is scanned for every read-only category, not only WITH.  A query
    classified as SELECT (for example a set operation or a subquery) can still
    embed a data-modifying CTE, and must not be reported as read-only.

    Args:
        statement: Parsed statement to inspect.
        query_type: Category previously determined for *statement*.
    """
    if query_type not in _READ_ONLY_TYPES:
        return False
    # find() returns the first matching node or None; a single hit is enough.
    return statement.find(*_WRITE_NODE_TYPES) is None
def _extract_tables(statement: exp.Expression) -> list[str]:
    """List the distinct table references found in *statement*.

    Each name is built as ``catalog.db.name`` with empty parts left out.
    Names appear in the order they are first encountered; references that
    yield an empty name are skipped.
    """
    qualified_names = (
        ".".join(part for part in (table.catalog, table.db, table.name) if part)
        for table in statement.find_all(exp.Table)
    )
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(name for name in qualified_names if name))
def _extract_columns(statement: exp.Expression) -> list[str]:
    """List the distinct column names referenced in *statement*.

    Only the bare column name is reported (no table qualifier).  Names appear
    in the order they are first encountered; empty names are skipped.
    """
    column_names = (column.name for column in statement.find_all(exp.Column))
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(name for name in column_names if name))
# Spans that are not executable SQL text: quoted literals/identifiers and
# comments.  They are matched in ONE left-to-right pass so that a quote inside
# a comment, "--" inside a string, or an apostrophe inside a double-quoted
# identifier cannot be misread as the start of a different kind of span.
_NON_CODE_SPAN_RE = re.compile(
    r"(?P<quoted>'(?:[^'\\]|\\.)*'"
    r'|"(?:[^"\\]|\\.)*")'
    r"|(?P<comment>--[^\n]*|/\*[\s\S]*?\*/)"
)


def _detect_placeholders(query: str) -> list[str]:
    """Warn when *query* appears to contain prepared-statement placeholders.

    Quoted strings, double-quoted identifiers and SQL comments (``-- ...`` and
    ``/* ... */``) are removed first, so characters such as ``?`` or ``:word``
    inside them do not trigger a warning.  The remaining text is tested
    against every entry of ``_PLACEHOLDER_PATTERNS``.

    Args:
        query: Raw SQL text.

    Returns:
        A list holding one warning that names every detected placeholder
        style (sorted), or an empty list when none is found.
    """
    # Quoted spans are dropped outright; a comment is replaced by a single
    # space because it separates tokens the way whitespace does.
    code_only = _NON_CODE_SPAN_RE.sub(
        lambda match: " " if match.lastgroup == "comment" else "",
        query,
    )
    detected = sorted(
        {name for name, pattern in _PLACEHOLDER_PATTERNS if pattern.search(code_only)}
    )
    if not detected:
        return []
    return [
        f"Prepared-statement placeholders detected: "
        f"{', '.join(detected)}. Ensure the target "
        f"database driver supports these placeholder styles."
    ]
def _check_best_practices(statement: exp.Expression) -> list[str]:
    """Collect style and safety warnings for a single parsed statement.

    Checks, in the order the warnings are emitted:
      * a bare ``*`` in the projection list of any SELECT in the tree;
      * a top-level UPDATE without a WHERE clause;
      * a top-level DELETE without a WHERE clause;
      * any NATURAL join in the tree.

    Each warning is emitted at most once.
    """
    findings: list[str] = []

    selects_star = any(
        isinstance(projection, exp.Star)
        for select in statement.find_all(exp.Select)
        for projection in (select.expressions or [])
    )
    if selects_star:
        findings.append(
            "SELECT * detected — explicitly listing columns is "
            "recommended for performance and maintainability."
        )

    lacks_where = not statement.args.get("where")
    if lacks_where and isinstance(statement, exp.Update):
        findings.append(
            "UPDATE without a WHERE clause will affect every row in the table."
        )
    if lacks_where and isinstance(statement, exp.Delete):
        findings.append(
            "DELETE without a WHERE clause will remove every row from the table."
        )

    def _is_natural(join: exp.Expression) -> bool:
        # NATURAL is looked for under both the "kind" and the "method" arg.
        markers = (
            (join.args.get("kind") or "").upper(),
            (join.args.get("method") or "").upper(),
        )
        return "NATURAL" in markers

    if any(_is_natural(join) for join in statement.find_all(exp.Join)):
        findings.append(
            "NATURAL JOIN can produce unexpected column matches — "
            "prefer explicit JOIN conditions."
        )

    return findings
def _analyze_statements(statements: list[exp.Expression]) -> dict[str, Any]:
    """Merge the per-statement analysis of *statements* into one summary.

    Returns:
        A dict with the keys
        ``query_type`` (value of the FIRST statement's category, or
        ``"UNKNOWN"`` for an empty list),
        ``is_read_only`` (True only if every statement is read-only),
        ``tables`` and ``columns`` (de-duplicated, in first-seen order), and
        ``warnings`` (all best-practice and analysis warnings, in order).
    """
    # Dicts are used as insertion-ordered sets.
    table_names: dict[str, None] = {}
    column_names: dict[str, None] = {}
    notes: list[str] = []
    leading_type: QueryType = QueryType.UNKNOWN
    read_only = True

    for position, statement in enumerate(statements):
        kind = _detect_query_type(statement)
        if position == 0:
            leading_type = kind
        read_only = _is_read_only(statement, kind) and read_only

        # update() keeps the position of keys that are already present.
        table_names.update(dict.fromkeys(_extract_tables(statement)))
        column_names.update(dict.fromkeys(_extract_columns(statement)))

        notes.extend(_check_best_practices(statement))
        if isinstance(statement, exp.Command):
            verb = getattr(statement, "name", "")
            notes.append(
                f"Statement uses a {verb!r} command that could not be "
                f"fully analysed. Syntax validation may be incomplete."
            )

    if len(statements) > 1:
        notes.append(
            f"Multiple statements detected ({len(statements)} total). "
            f"Results reflect the combined analysis of all statements."
        )

    return {
        "query_type": leading_type.value,
        "is_read_only": read_only,
        "tables": list(table_names),
        "columns": list(column_names),
        "warnings": notes,
    }
class SqlValidatorService:
def validate(self, query: str, dialect: Optional[str] = None) -> dict[str, Any]:
if not isinstance(query, str):
raise TypeError(f"query must be a string, got {type(query).__name__}")
if dialect is not None and not isinstance(dialect, str):
raise TypeError(f"dialect must be a string or None, got {type(dialect).__name__}")
normalized_dialect = _normalize_dialect(dialect)
stripped_query = query.strip()
if not stripped_query:
return {
"valid": False,
"query_type": QueryType.UNKNOWN.value,
"dialect": normalized_dialect,
"errors": ["Empty query string provided."],
"warnings": [],
"is_read_only": False,
"tables": [],
"columns": [],
}
warnings: list[str] = _detect_placeholders(stripped_query)
statements: list[exp.Expression] = []
errors: list[str] = []
try:
parsed = sqlglot.parse(
stripped_query,
dialect=normalized_dialect,
error_level=sqlglot.ErrorLevel.RAISE,
)
statements = [s for s in parsed if s is not None]
except ParseError as exc:
errors.append(f"SQL syntax error: {exc}")
except RecursionError:
errors.append("Query is too deeply nested to parse — consider simplifying.")
except Exception as exc:
errors.append(f"Unexpected error during parsing: {exc}")
if errors:
try:
parsed = sqlglot.parse(
stripped_query,
dialect=normalized_dialect,
error_level=sqlglot.ErrorLevel.WARN,
)
statements = [s for s in parsed if s is not None]
except Exception:
statements = []
if not statements:
if not errors:
errors.append("No valid SQL statements found. The query may be empty or contain only comments.")
return {
"valid": False,
"query_type": QueryType.UNKNOWN.value,
"dialect": normalized_dialect,
"errors": errors,
"warnings": warnings,
"is_read_only": False,
"tables": [],
"columns": [],
}
analysis = _analyze_statements(statements)
warnings.extend(analysis["warnings"])
valid = len(errors) == 0
if valid and analysis["query_type"] == QueryType.UNKNOWN.value:
has_known_statement = any(
_detect_query_type(s) != QueryType.UNKNOWN for s in statements
)
if not has_known_statement:
valid = False
if not errors:
errors.append(
"Query could not be recognised as a valid SQL statement."
)
return {
"valid": valid,
"query_type": analysis["query_type"],
"dialect": normalized_dialect,
"errors": errors,
"warnings": warnings,
"is_read_only": analysis["is_read_only"],
"tables": analysis["tables"],
"columns": analysis["columns"],
}