File size: 6,394 Bytes
7814c1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
"""
SQL Query Validator and Optimizer
Provides validation, optimization suggestions, and query analysis
"""

import re
from typing import Any, Dict, List, Tuple

class SQLValidator:
    """Validates and analyzes SQL queries.

    Every check is a lightweight regex heuristic, not a real SQL parser, so
    unusual dialect-specific syntax can still yield false positives or
    negatives. Patterns are compiled once at class level and use word
    boundaries so that keywords embedded in identifiers (e.g. ``UPDATE``
    inside ``last_updated`` or ``JOIN`` inside ``joined_at``) are not
    mistaken for SQL keywords.
    """

    # ---- query-type detection ------------------------------------------
    # Head of a WITH list: WITH [RECURSIVE] name [(cols)] AS [[NOT] MATERIALIZED] (
    _CTE_START_RE = re.compile(
        r"\bWITH(?:\s+RECURSIVE)?\s+\w+\s*(?:\([^)]*\)\s*)?"
        r"AS\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\(",
        re.IGNORECASE,
    )
    # Any CTE definition (first one after WITH, or a later one after a
    # comma). Group 1 is the CTE name.
    _CTE_NAME_RE = re.compile(
        r"(?:\bWITH(?:\s+RECURSIVE)?\s+|,\s*)(\w+)\s*(?:\([^)]*\)\s*)?"
        r"AS\s*(?:(?:NOT\s+)?MATERIALIZED\s*)?\(",
        re.IGNORECASE,
    )
    _WINDOW_RE = re.compile(
        r"\b(?:ROW_NUMBER|RANK|DENSE_RANK)\s*\(|\bPARTITION\s+BY\b|\bOVER\s*\(",
        re.IGNORECASE,
    )
    _AGGREGATE_RE = re.compile(
        r"\bGROUP\s+BY\b|\b(?:COUNT|SUM|AVG|MAX|MIN)\s*\(", re.IGNORECASE
    )
    _JOIN_RE = re.compile(r"\bJOIN\b", re.IGNORECASE)
    _UNION_RE = re.compile(r"\bUNION\b", re.IGNORECASE)
    _SUBQUERY_RE = re.compile(r"\bEXISTS\b|\bIN\s*\(\s*SELECT\b", re.IGNORECASE)

    # ---- anti-pattern checks -------------------------------------------
    # `NULL\b` keeps NULLIF(...) / NULL_col from matching.
    _NULL_COMPARISON_RE = re.compile(r"(?:!=|<>|=)\s*NULL\b", re.IGNORECASE)
    # `SET col = ...` assignments (up to WHERE or end) legitimately assign
    # NULL, so they are removed before looking for `= NULL` comparisons.
    _SET_CLAUSE_RE = re.compile(
        r"\bSET\s+[\w.\"`\[\]]+\s*=.*?(?=\bWHERE\b|\Z)",
        re.IGNORECASE | re.DOTALL,
    )
    _SELECT_STAR_RE = re.compile(r"SELECT\s+\*", re.IGNORECASE)
    # Actual DELETE/UPDATE statements. The SET requirement avoids matching
    # FOR UPDATE, ON UPDATE CASCADE, ON DELETE SET NULL, and similar clauses.
    _WRITE_STMT_RE = re.compile(
        r"\bDELETE\s+(?:[\w.]+\s+)?FROM\b"
        r"|\bUPDATE\s+[\w.\"`\[\]]+(?:\s+(?:AS\s+)?\w+)?\s+SET\b",
        re.IGNORECASE,
    )
    _WHERE_RE = re.compile(r"\bWHERE\b", re.IGNORECASE)
    _IMPLICIT_JOIN_RE = re.compile(r"FROM\s+\w+\s*,\s*\w+", re.IGNORECASE)

    # ---- table references ----------------------------------------------
    _TABLE_REF_RE = re.compile(r"\b(?:FROM|JOIN)\s+(\w+)", re.IGNORECASE)
    # Uses of FROM that do not introduce a table.
    _NON_TABLE_FROM_RE = re.compile(
        r"\bIS\s+(?:NOT\s+)?DISTINCT\s+FROM\b"
        r"|\b(?:EXTRACT|TRIM|SUBSTRING)\s*\([^()]*?\bFROM\b",
        re.IGNORECASE,
    )

    # ---- optimization hints --------------------------------------------
    _NOT_IN_SUBQUERY_RE = re.compile(r"\bNOT\s+IN\s*\(\s*SELECT\b", re.IGNORECASE)
    _SUBQUERY_OPEN_RE = re.compile(r"\(\s*SELECT\b", re.IGNORECASE)
    _DISTINCT_RE = re.compile(r"\bDISTINCT\b", re.IGNORECASE)
    _GROUP_BY_RE = re.compile(r"\bGROUP\s+BY\b", re.IGNORECASE)
    _ORDER_BY_RE = re.compile(r"\bORDER\s+BY\b", re.IGNORECASE)

    # ---- formatting ----------------------------------------------------
    # Single-quoted literals and double-quoted identifiers (doubled quotes
    # escape). The capture group makes re.split keep them at odd indices.
    _QUOTED_RE = re.compile(r"('(?:[^']|'')*'|\"(?:[^\"]|\"\")*\")")
    # (pattern, replacement) pairs. The leading \s* swallows whitespace that
    # is already there, so an existing newline is not doubled.
    _FORMAT_RULES = tuple(
        (
            re.compile(r"\s*\b" + r"\s+".join(kw.split()) + r"\b", re.IGNORECASE),
            "\n" + kw,
        )
        for kw in ('SELECT', 'FROM', 'WHERE', 'GROUP BY', 'HAVING', 'ORDER BY', 'LIMIT')
    )

    def __init__(self, schema: Dict):
        """Index ``schema`` for table and column lookups.

        Args:
            schema: mapping of table name -> iterable of column dicts, each
                with at least a ``'name'`` key.
        """
        self.schema = schema
        self.table_names = set(schema.keys())
        # lower-cased column name -> tables (in schema order) containing it
        self.column_map: Dict[str, List[str]] = {}
        for table, columns in schema.items():
            for col in columns:
                self.column_map.setdefault(col['name'].lower(), []).append(table)

    def validate(self, sql: str) -> Dict[str, Any]:
        """Validate ``sql`` against the schema.

        Returns:
            dict with keys ``valid`` (False iff any error was recorded),
            ``errors``, ``warnings``, ``suggestions`` (lists of messages)
            and ``query_type`` (see ``_detect_query_type``).
        """
        results: Dict[str, Any] = {
            "valid": True,
            "errors": [],
            "warnings": [],
            "suggestions": [],
            "query_type": self._detect_query_type(sql)
        }

        # Check for common SQL anti-patterns
        self._check_null_comparison(sql, results)
        self._check_select_star(sql, results)
        self._check_missing_where(sql, results)
        self._check_implicit_joins(sql, results)
        self._check_table_names(sql, results)

        if results["errors"]:
            results["valid"] = False

        return results

    def _detect_query_type(self, sql: str) -> str:
        """Classify ``sql``. The first matching category wins, in this order:
        cte, window, aggregate, join, union, subquery, simple."""
        if self._CTE_START_RE.search(sql):
            return "cte"
        if self._WINDOW_RE.search(sql):
            return "window"
        if self._AGGREGATE_RE.search(sql):
            return "aggregate"
        if self._JOIN_RE.search(sql):
            return "join"
        if self._UNION_RE.search(sql):
            return "union"
        if self._SUBQUERY_RE.search(sql):
            return "subquery"
        return "simple"

    def _check_null_comparison(self, sql: str, results: Dict) -> None:
        """Error on ``= NULL`` / ``!= NULL`` / ``<> NULL`` comparisons.

        ``SET col = NULL`` assignments are valid SQL, so they are excluded.
        """
        without_assignments = self._SET_CLAUSE_RE.sub(" ", sql)
        if self._NULL_COMPARISON_RE.search(without_assignments):
            results["errors"].append(
                "Use IS NULL or IS NOT NULL instead of = NULL or != NULL"
            )

    def _check_select_star(self, sql: str, results: Dict) -> None:
        """Warn about SELECT *."""
        if self._SELECT_STAR_RE.search(sql):
            results["warnings"].append(
                "Consider specifying column names instead of SELECT * for better performance"
            )

    def _check_missing_where(self, sql: str, results: Dict) -> None:
        """Error on a DELETE/UPDATE statement that has no WHERE clause."""
        if self._WRITE_STMT_RE.search(sql) and not self._WHERE_RE.search(sql):
            results["errors"].append(
                "DELETE or UPDATE without WHERE clause will affect all rows"
            )

    def _check_implicit_joins(self, sql: str, results: Dict) -> None:
        """Suggest explicit JOINs for ``FROM table1, table2``."""
        if self._IMPLICIT_JOIN_RE.search(sql):
            results["suggestions"].append(
                "Consider using explicit JOIN syntax instead of comma-separated tables"
            )

    def _check_table_names(self, sql: str, results: Dict) -> None:
        """Error on tables named after FROM/JOIN that are not in the schema.

        Names defined by CTEs in the same query count as known. FROM keywords
        that do not introduce a table (``IS DISTINCT FROM``,
        ``EXTRACT(... FROM x)``) are ignored. Comparison is case-insensitive.
        """
        scannable = self._NON_TABLE_FROM_RE.sub(" ", sql)
        referenced = {m.group(1).lower() for m in self._TABLE_REF_RE.finditer(scannable)}
        cte_names = {m.group(1).lower() for m in self._CTE_NAME_RE.finditer(sql)}
        known = {t.lower() for t in self.table_names} | cte_names

        invalid_tables = referenced - known
        if invalid_tables:
            # sorted() gives a deterministic message; set order is arbitrary.
            results["errors"].append(
                f"Table(s) not found in schema: {', '.join(sorted(invalid_tables))}"
            )

    @classmethod
    def _subquery_bodies(cls, sql: str) -> List[str]:
        """Return the text inside each parenthesised ``(SELECT ...)``.

        Parentheses are balanced by depth counting. Parentheses inside string
        literals are not special-cased. An unclosed subquery runs to the end
        of ``sql``.
        """
        bodies = []
        for match in cls._SUBQUERY_OPEN_RE.finditer(sql):
            depth = 0
            for pos in range(match.start(), len(sql)):
                char = sql[pos]
                if char == '(':
                    depth += 1
                elif char == ')':
                    depth -= 1
                    if depth == 0:
                        bodies.append(sql[match.start() + 1:pos])
                        break
            else:
                bodies.append(sql[match.start() + 1:])
        return bodies

    def suggest_optimizations(self, sql: str) -> List[str]:
        """Return optimization hints for ``sql`` (possibly empty)."""
        suggestions = []
        subqueries = self._subquery_bodies(sql)

        if self._NOT_IN_SUBQUERY_RE.search(sql):
            suggestions.append(
                "Consider using NOT EXISTS instead of NOT IN for better NULL handling"
            )

        if len(subqueries) > 2:
            suggestions.append(
                "Consider using CTEs (WITH clause) for better readability with multiple subqueries"
            )

        if self._DISTINCT_RE.search(sql) and not self._GROUP_BY_RE.search(sql):
            suggestions.append(
                "DISTINCT can be expensive. Consider if GROUP BY might be more appropriate"
            )

        # Only ORDER BY inside a subquery's own parentheses counts. An outer
        # ORDER BY that follows a subquery does not.
        if any(self._ORDER_BY_RE.search(body) for body in subqueries):
            suggestions.append(
                "ORDER BY in subquery may be ignored. Apply ORDER BY to outer query"
            )

        return suggestions

    def format_sql(self, sql: str) -> str:
        """Basic SQL formatting for readability.

        Puts each major clause keyword (SELECT, FROM, WHERE, GROUP BY,
        HAVING, ORDER BY, LIMIT) on its own line, upper-cased. Quoted
        literals and identifiers are never modified. This is a simple
        formatter; use a proper SQL formatter in production.
        """
        pieces = self._QUOTED_RE.split(sql.strip())
        # Even indices are unquoted SQL; odd indices are quoted text.
        for i in range(0, len(pieces), 2):
            for pattern, replacement in self._FORMAT_RULES:
                pieces[i] = pattern.sub(replacement, pieces[i])
        return "".join(pieces).strip()