Spaces:
Sleeping
Sleeping
File size: 6,394 Bytes
7814c1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
"""
SQL Query Validator and Optimizer
Provides validation, optimization suggestions, and query analysis
"""
import re
from typing import Any, Dict, List, Tuple
class SQLValidator:
    """Validate and analyze SQL query text against a known schema.

    Every check is a regex/keyword heuristic over the raw SQL string (there
    is no SQL parser), so unusual syntax such as keywords inside string
    literals or comments, quoted identifiers, or vendor extensions can still
    produce false positives or false negatives.

    Args:
        schema: Mapping of table name -> iterable of column dicts. Only the
            ``'name'`` key of each column dict is read.
    """

    # `=` (which also covers `!=`, `<=`, `>=`, `==`) or `<>` followed by the
    # NULL keyword; the trailing \b keeps `= NULLIF(...)` / `= null_flag` out.
    _NULL_COMPARISON_RE = re.compile(r"(?:=|<>)\s*NULL\b", re.IGNORECASE)
    # An UPDATE's SET clause, up to WHERE or the end of the statement.
    # `col = NULL` there is an assignment, not a comparison.
    _SET_CLAUSE_RE = re.compile(
        r"\bSET\b.*?(?=\bWHERE\b|;|$)", re.IGNORECASE | re.DOTALL
    )
    # DELETE/UPDATE that starts a statement (at the beginning, after `;`, or
    # after a CTE's closing paren). Anchoring avoids identifiers such as
    # `updated_at` and the `SELECT ... FOR UPDATE` locking clause.
    _WRITE_STATEMENT_RE = re.compile(
        r"(?:^|[;)])\s*(?:DELETE|UPDATE)\b", re.IGNORECASE
    )
    # Functions whose argument syntax uses FROM internally, e.g.
    # EXTRACT(YEAR FROM ts); their operand must not be read as a table.
    _FROM_KEYWORD_FUNC_RE = re.compile(
        r"\b(?:EXTRACT|TRIM|SUBSTRING|POSITION|OVERLAY)\s*\([^()]*\)",
        re.IGNORECASE,
    )
    # Table name after FROM/JOIN; an optional `schema.` prefix is skipped.
    _TABLE_REF_RE = re.compile(
        r"\b(?:FROM|JOIN)\s+(?:\w+\.)?([A-Za-z_]\w*)", re.IGNORECASE
    )
    # CTE definitions: `WITH [RECURSIVE] name [(cols)] AS (` or `, name AS (`.
    _CTE_NAME_RE = re.compile(
        r"(?:\bWITH\b(?:\s+RECURSIVE\b)?|,)\s*([A-Za-z_]\w*)"
        r"\s*(?:\([^()]*\))?\s*AS\s*\(",
        re.IGNORECASE,
    )

    def __init__(self, schema: Dict):
        self.schema = schema
        self.table_names = set(schema.keys())
        # Lower-cased column name -> tables containing it, in schema order.
        self.column_map: Dict[str, List[str]] = {}
        for table, columns in schema.items():
            for col in columns:
                self.column_map.setdefault(col['name'].lower(), []).append(table)

    def validate(self, sql: str) -> Dict[str, Any]:
        """Run every check on *sql* and collect the findings.

        Returns:
            Dict with ``valid`` (False iff any error was recorded),
            ``errors`` / ``warnings`` / ``suggestions`` (lists of message
            strings) and ``query_type`` (from ``_detect_query_type``).
        """
        results = {
            "valid": True,
            "errors": [],
            "warnings": [],
            "suggestions": [],
            "query_type": self._detect_query_type(sql),
        }
        # Each check appends its findings to the matching list in `results`.
        self._check_null_comparison(sql, results)
        self._check_select_star(sql, results)
        self._check_missing_where(sql, results)
        self._check_implicit_joins(sql, results)
        self._check_table_names(sql, results)
        results["valid"] = not results["errors"]
        return results

    def _detect_query_type(self, sql: str) -> str:
        """Classify as cte/window/aggregate/join/union/subquery/simple.

        Categories are tested in that order and the first match wins.
        Keywords are matched as whole words, so identifiers such as
        ``withdrawals`` or ``joined_at`` do not trigger a category.
        """
        def has(pattern: str) -> bool:
            return re.search(pattern, sql, re.IGNORECASE) is not None

        # A CTE query starts with WITH and defines at least one `name AS (`.
        if has(r"^\s*WITH\b") and has(r"\bAS\s*\("):
            return "cte"
        if has(r"\b(?:ROW_NUMBER|RANK|DENSE_RANK)\s*\(|\bPARTITION\s+BY\b"
               r"|\bOVER\s*\("):
            return "window"
        if has(r"\bGROUP\s+BY\b|\b(?:COUNT|SUM|AVG|MAX|MIN)\s*\("):
            return "aggregate"
        if has(r"\bJOIN\b"):
            return "join"
        if has(r"\bUNION\b"):
            return "union"
        if has(r"\bEXISTS\b|\bIN\s*\(\s*SELECT\b"):
            return "subquery"
        return "simple"

    def _check_null_comparison(self, sql: str, results: Dict):
        """Record an error for `= NULL` / `!= NULL` / `<> NULL` comparisons.

        In standard SQL such comparisons yield UNKNOWN rather than true, so
        IS [NOT] NULL is required. Assignments inside an UPDATE's SET clause
        are excluded.
        """
        comparisons = self._SET_CLAUSE_RE.sub(" ", sql)
        if self._NULL_COMPARISON_RE.search(comparisons):
            results["errors"].append(
                "Use IS NULL or IS NOT NULL instead of = NULL or != NULL"
            )

    def _check_select_star(self, sql: str, results: Dict):
        """Warn about SELECT * (including SELECT DISTINCT *)."""
        if re.search(r"\bSELECT\s+(?:DISTINCT\s+)?\*", sql, re.IGNORECASE):
            results["warnings"].append(
                "Consider specifying column names instead of SELECT * for better performance"
            )

    def _check_missing_where(self, sql: str, results: Dict):
        """Record an error for a DELETE/UPDATE statement with no WHERE."""
        if self._WRITE_STATEMENT_RE.search(sql) and not re.search(
            r"\bWHERE\b", sql, re.IGNORECASE
        ):
            results["errors"].append(
                "DELETE or UPDATE without WHERE clause will affect all rows"
            )

    def _check_implicit_joins(self, sql: str, results: Dict):
        """Suggest explicit JOINs for `FROM a [[AS] x], b` comma joins."""
        if re.search(
            r"\bFROM\s+[\w.]+(?:\s+(?:AS\s+)?\w+)?\s*,\s*\w+",
            sql,
            re.IGNORECASE,
        ):
            results["suggestions"].append(
                "Consider using explicit JOIN syntax instead of comma-separated tables"
            )

    def _check_table_names(self, sql: str, results: Dict):
        """Record an error for FROM/JOIN tables that are not in the schema.

        Names defined by the query's own CTEs count as known. Comparison is
        case-insensitive, and missing names are reported in sorted order so
        the message is deterministic.
        """
        # Drop FROM keywords that do not introduce a table reference.
        cleaned = self._FROM_KEYWORD_FUNC_RE.sub(" ", sql)
        cleaned = re.sub(r"\bDISTINCT\s+FROM\b", " ", cleaned, flags=re.IGNORECASE)

        referenced = {m.group(1).lower() for m in self._TABLE_REF_RE.finditer(cleaned)}
        cte_names = {m.group(1).lower() for m in self._CTE_NAME_RE.finditer(cleaned)}
        known = {t.lower() for t in self.table_names} | cte_names

        missing = sorted(referenced - known)
        if missing:
            results["errors"].append(
                f"Table(s) not found in schema: {', '.join(missing)}"
            )

    def suggest_optimizations(self, sql: str) -> List[str]:
        """Return optimization hints for *sql* (empty list if none apply)."""
        suggestions = []

        # NOT IN over a subquery matches no rows if the subquery yields a NULL.
        if re.search(r"\bNOT\s+IN\s*\(\s*SELECT\b", sql, re.IGNORECASE):
            suggestions.append(
                "Consider using NOT EXISTS instead of NOT IN for better NULL handling"
            )

        if len(re.findall(r"\(\s*SELECT\b", sql, re.IGNORECASE)) > 2:
            suggestions.append(
                "Consider using CTEs (WITH clause) for better readability with multiple subqueries"
            )

        if re.search(r"\bDISTINCT\b", sql, re.IGNORECASE) and not re.search(
            r"\bGROUP\s+BY\b", sql, re.IGNORECASE
        ):
            suggestions.append(
                "DISTINCT can be expensive. Consider if GROUP BY might be more appropriate"
            )

        if self._has_ordered_subquery(sql):
            suggestions.append(
                "ORDER BY in subquery may be ignored. Apply ORDER BY to outer query"
            )

        return suggestions

    @staticmethod
    def _has_ordered_subquery(sql: str) -> bool:
        """Return True if a parenthesized SELECT has ORDER BY at its own level.

        ORDER BY nested in further parentheses inside the subquery, such as
        a window's ``OVER (ORDER BY ...)``, is not counted. An unbalanced
        subquery is scanned to the end of the string.
        """
        for opener in re.finditer(r"\(\s*SELECT\b", sql, re.IGNORECASE):
            depth = 0
            own_level = []
            for ch in sql[opener.start() + 1:]:
                if ch == ")" and depth == 0:
                    break  # closing paren of this subquery
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                # Keep only text at the subquery's own nesting level.
                own_level.append(ch if depth == 0 and ch not in "()" else " ")
            if re.search(r"\bORDER\s+BY\b", "".join(own_level), re.IGNORECASE):
                return True
        return False

    def format_sql(self, sql: str) -> str:
        """Put each major clause keyword on its own line (upper-cased).

        This is a simple formatter; use a proper SQL formatter for production.
        Whitespace before each keyword is collapsed into one newline, which
        makes the result idempotent and free of trailing spaces. Keywords
        inside string literals are also split.
        """
        formatted = sql.strip()
        for keyword in ('SELECT', 'FROM', 'WHERE', 'GROUP BY', 'HAVING',
                        'ORDER BY', 'LIMIT'):
            # Allow any whitespace between the words of GROUP BY / ORDER BY.
            pattern = r'\s*\b' + keyword.replace(' ', r'\s+') + r'\b'
            formatted = re.sub(pattern, '\n' + keyword, formatted,
                               flags=re.IGNORECASE)
        return formatted.strip()