# codeflow-ai / sql_validator.py
# unknown
# Initial commit: CodeFlow AI - NL to SQL Generator
# 7814c1f
"""
SQL Query Validator and Optimizer
Provides validation, optimization suggestions, and query analysis
"""
import re
from typing import Any, Dict, List, Tuple
class SQLValidator:
    """Validate and analyze SQL queries against a known schema.

    All checks are regex heuristics, not a real SQL parser. To keep false
    positives down, every check runs on a copy of the query in which
    single-quoted string literals and comments have been blanked out, and
    keywords are matched on word boundaries rather than as raw substrings.
    """

    # Single-quoted literals ('' is an escaped quote), line and block comments.
    _LITERAL_OR_COMMENT_RE = re.compile(
        r"'(?:[^']|'')*'|--[^\n]*|/\*.*?\*/", re.DOTALL
    )

    # "WITH [RECURSIVE] name [(cols)] AS [[NOT] MATERIALIZED] (" -- a real CTE,
    # as opposed to any text that merely contains the letters WITH and AS.
    _CTE_RE = re.compile(
        r"\bWITH\s+(?:RECURSIVE\s+)?\w+(?:\s*\([^)]*\)\s*|\s+)"
        r"AS\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\(",
        re.IGNORECASE,
    )
    # Same shape, but also matches the 2nd..nth CTE (", name AS (") and
    # captures the CTE name.
    _CTE_NAME_RE = re.compile(
        r"(?:\bWITH\s+(?:RECURSIVE\s+)?|,\s*)(\w+)(?:\s*\([^)]*\)\s*|\s+)"
        r"AS\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\(",
        re.IGNORECASE,
    )

    # Ordered classification rules: the first pattern that matches wins.
    _QUERY_TYPE_RULES = (
        ("cte", _CTE_RE),
        (
            "window",
            re.compile(
                r"\b(?:ROW_NUMBER|RANK|DENSE_RANK)\s*\("
                r"|\bPARTITION\s+BY\b|\bOVER\s*\(",
                re.IGNORECASE,
            ),
        ),
        (
            "aggregate",
            re.compile(
                r"\bGROUP\s+BY\b|\b(?:COUNT|SUM|AVG|MAX|MIN)\s*\(",
                re.IGNORECASE,
            ),
        ),
        ("join", re.compile(r"\bJOIN\b", re.IGNORECASE)),
        ("union", re.compile(r"\bUNION\b", re.IGNORECASE)),
        (
            "subquery",
            re.compile(r"\bEXISTS\b|\bIN\s*\(\s*SELECT\b", re.IGNORECASE),
        ),
    )

    _NULL_COMPARISON_RE = re.compile(r"(?:!=|<>|=)\s*NULL\b", re.IGNORECASE)
    # Assignment list of an UPDATE: "SET a = NULL" is valid SQL, not a comparison.
    _SET_CLAUSE_RE = re.compile(
        r"\bSET\b.*?(?=\bWHERE\b|\bFROM\b|\bRETURNING\b|;|\Z)",
        re.IGNORECASE | re.DOTALL,
    )
    _SELECT_STAR_RE = re.compile(
        r"\bSELECT\s+(?:(?:DISTINCT|ALL)\s+)?(?:\w+\.)?\*", re.IGNORECASE
    )
    # DELETE, or "UPDATE <target> ... SET". Requiring a target before SET skips
    # "SELECT ... FOR UPDATE" and upsert clauses such as "DO UPDATE SET".
    _MUTATION_RE = re.compile(
        r"\bDELETE\b|\bUPDATE\s+(?!SET\b)[^;]*?\bSET\b", re.IGNORECASE
    )
    _WHERE_RE = re.compile(r"\bWHERE\b", re.IGNORECASE)
    _IMPLICIT_JOIN_RE = re.compile(
        r"\bFROM\s+[\w.]+(?:\s+(?:AS\s+)?\w+)?\s*,\s*[\w.]+", re.IGNORECASE
    )
    _TABLE_REF_RE = re.compile(r"\b(?:FROM|JOIN)\s+([\w.]+)", re.IGNORECASE)
    # FROM used as a function-argument keyword, e.g. EXTRACT(YEAR FROM col).
    _FUNCTION_FROM_RE = re.compile(
        r"\b(?:EXTRACT|SUBSTRING|TRIM|OVERLAY)\s*\([^()]*?\bFROM\b",
        re.IGNORECASE,
    )

    _NOT_IN_SUBQUERY_RE = re.compile(
        r"\bNOT\s+IN\s*\(\s*SELECT\b", re.IGNORECASE
    )
    _SUBQUERY_OPEN_RE = re.compile(r"\(\s*SELECT\b", re.IGNORECASE)
    _DISTINCT_RE = re.compile(r"\bDISTINCT\b", re.IGNORECASE)
    _GROUP_BY_RE = re.compile(r"\bGROUP\s+BY\b", re.IGNORECASE)
    _PAREN_OR_ORDER_BY_RE = re.compile(
        r"[()]|\bORDER\s+BY\b", re.IGNORECASE
    )
    _SELECT_AHEAD_RE = re.compile(r"\s*SELECT\b", re.IGNORECASE)

    # Group 1: spans that must be copied verbatim (string literals, quoted
    # identifiers, comments). Group 2: a clause keyword plus the whitespace
    # in front of it.
    _FORMAT_RE = re.compile(
        r"""('(?:[^']|'')*'|"[^"]*"|`[^`]*`|--[^\n]*|/\*.*?\*/)"""
        r"""|\s*\b(SELECT|FROM|WHERE|GROUP\s+BY|HAVING|ORDER\s+BY|LIMIT)\b""",
        re.IGNORECASE | re.DOTALL,
    )

    def __init__(self, schema: Dict):
        """Store the schema and index its columns.

        Args:
            schema: Mapping of table name to an iterable of column
                descriptors; each descriptor is indexed with ``['name']``.
        """
        self.schema = schema
        self.table_names = set(schema.keys())
        # Lower-cased column name -> tables that declare a column of that name.
        self.column_map: Dict[str, List[str]] = {}
        for table, columns in schema.items():
            for col in columns:
                self.column_map.setdefault(col['name'].lower(), []).append(table)

    @classmethod
    def _strip_literals_and_comments(cls, sql: str) -> str:
        """Return *sql* with string literals emptied and comments blanked.

        Literals become ``''`` and comments a single space, so text inside
        them can no longer be mistaken for SQL keywords or table names.
        """
        return cls._LITERAL_OR_COMMENT_RE.sub(
            lambda m: "''" if m.group().startswith("'") else " ", sql
        )

    def validate(self, sql: str) -> Dict[str, Any]:
        """Validate a SQL query against the schema.

        Args:
            sql: The query text.

        Returns:
            A dict with keys ``valid`` (bool), ``errors``, ``warnings`` and
            ``suggestions`` (lists of messages) and ``query_type`` (str).
            ``valid`` is False whenever ``errors`` is non-empty. An empty or
            whitespace-only query is reported as an error.
        """
        results: Dict[str, Any] = {
            "valid": True,
            "errors": [],
            "warnings": [],
            "suggestions": [],
            "query_type": "simple",
        }
        if not sql or not sql.strip():
            results["errors"].append("SQL query is empty")
            results["valid"] = False
            return results

        # Analyze only real SQL tokens, never the contents of literals/comments.
        code = self._strip_literals_and_comments(sql)
        results["query_type"] = self._detect_query_type(code)

        # Check for common SQL anti-patterns
        self._check_null_comparison(code, results)
        self._check_select_star(code, results)
        self._check_missing_where(code, results)
        self._check_implicit_joins(code, results)
        self._check_table_names(code, results)

        # If we found errors, mark as invalid
        if results["errors"]:
            results["valid"] = False
        return results

    def _detect_query_type(self, sql: str) -> str:
        """Classify the query.

        Returns the first matching label of ``cte``, ``window``,
        ``aggregate``, ``join``, ``union``, ``subquery``; otherwise
        ``simple``.
        """
        for label, pattern in self._QUERY_TYPE_RULES:
            if pattern.search(sql):
                return label
        return "simple"

    def _check_null_comparison(self, sql: str, results: Dict):
        """Record an error for ``= NULL`` / ``!= NULL`` / ``<> NULL``.

        The assignment list of an UPDATE is skipped because ``SET col = NULL``
        is a legitimate assignment; a comparison nested inside that list is
        therefore not inspected.
        """
        comparable = self._SET_CLAUSE_RE.sub(" ", sql)
        if self._NULL_COMPARISON_RE.search(comparable):
            results["errors"].append(
                "Use IS NULL or IS NOT NULL instead of = NULL or != NULL"
            )

    def _check_select_star(self, sql: str, results: Dict):
        """Warn about ``SELECT *`` (also ``SELECT DISTINCT *`` and ``SELECT t.*``)."""
        if self._SELECT_STAR_RE.search(sql):
            results["warnings"].append(
                "Consider specifying column names instead of SELECT * for better performance"
            )

    def _check_missing_where(self, sql: str, results: Dict):
        """Record an error for a DELETE or UPDATE statement with no WHERE."""
        if self._MUTATION_RE.search(sql) and not self._WHERE_RE.search(sql):
            results["errors"].append(
                "DELETE or UPDATE without WHERE clause will affect all rows"
            )

    def _check_implicit_joins(self, sql: str, results: Dict):
        """Suggest explicit JOINs for ``FROM a, b`` (aliases allowed)."""
        if self._IMPLICIT_JOIN_RE.search(sql):
            results["suggestions"].append(
                "Consider using explicit JOIN syntax instead of comma-separated tables"
            )

    def _check_table_names(self, sql: str, results: Dict):
        """Record an error for FROM/JOIN targets that are not in the schema.

        Names are compared case-insensitively. A schema-qualified reference
        (``public.users``) is accepted when either the full name or its last
        segment is a schema table. Names defined by a WITH clause are never
        reported.
        """
        # Drop "FROM" that is a function keyword so its operand is not read
        # as a table name.
        searchable = self._FUNCTION_FROM_RE.sub(" ", sql)
        referenced = {
            m.group(1).lower().strip(".")
            for m in self._TABLE_REF_RE.finditer(searchable)
        }
        referenced.discard("")

        known = {t.lower() for t in self.table_names}
        cte_names = set()
        if self._CTE_RE.search(sql):
            cte_names = {
                m.group(1).lower() for m in self._CTE_NAME_RE.finditer(sql)
            }

        # Sorted so the message is deterministic across runs.
        invalid_tables = sorted(
            name
            for name in referenced
            if name not in known
            and name.rpartition(".")[2] not in known
            and name not in cte_names
        )
        if invalid_tables:
            results["errors"].append(
                f"Table(s) not found in schema: {', '.join(invalid_tables)}"
            )

    def _has_order_by_in_subquery(self, sql: str) -> bool:
        """Return True if an ORDER BY sits directly inside a ``(SELECT ...)``.

        Parentheses are tracked with a stack so that an ORDER BY of the outer
        query, or one inside ``OVER (...)``, is not attributed to a subquery.
        """
        opens_subquery = []  # one flag per currently open parenthesis
        for token in self._PAREN_OR_ORDER_BY_RE.finditer(sql):
            text = token.group()
            if text == "(":
                opens_subquery.append(
                    bool(self._SELECT_AHEAD_RE.match(sql, token.end()))
                )
            elif text == ")":
                if opens_subquery:
                    opens_subquery.pop()
            elif opens_subquery and opens_subquery[-1]:
                return True
        return False

    def suggest_optimizations(self, sql: str) -> List[str]:
        """Suggest optimizations for the query.

        Returns:
            A list of human-readable suggestions; empty when nothing applies
            or the query is empty.
        """
        suggestions: List[str] = []
        if not sql:
            return suggestions
        code = self._strip_literals_and_comments(sql)

        # Check for NOT IN with subquery
        if self._NOT_IN_SUBQUERY_RE.search(code):
            suggestions.append(
                "Consider using NOT EXISTS instead of NOT IN for better NULL handling"
            )
        # Check for multiple subqueries
        if len(self._SUBQUERY_OPEN_RE.findall(code)) > 2:
            suggestions.append(
                "Consider using CTEs (WITH clause) for better readability with multiple subqueries"
            )
        # Check for DISTINCT
        if self._DISTINCT_RE.search(code) and not self._GROUP_BY_RE.search(code):
            suggestions.append(
                "DISTINCT can be expensive. Consider if GROUP BY might be more appropriate"
            )
        # Check for ORDER BY in subquery
        if self._has_order_by_in_subquery(code):
            suggestions.append(
                "ORDER BY in subquery may be ignored. Apply ORDER BY to outer query"
            )
        return suggestions

    def format_sql(self, sql: str) -> str:
        """Basic SQL formatting for readability.

        Puts SELECT, FROM, WHERE, GROUP BY, HAVING, ORDER BY and LIMIT at the
        start of a line, upper-cased. String literals, quoted identifiers and
        comments are copied through untouched. Formatting an already
        formatted query returns it unchanged.

        This is a simple formatter - for production use a proper SQL formatter.
        """
        if not sql:
            return ""

        def _break_before_keyword(match):
            protected = match.group(1)
            if protected is not None:
                return protected
            # Collapse inner whitespace so "group   by" becomes "GROUP BY".
            return "\n" + " ".join(match.group(2).upper().split())

        return self._FORMAT_RE.sub(_break_before_keyword, sql.strip()).strip()