File size: 4,977 Bytes
723f9ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import re
from typing import Tuple, Dict

class QueryComplexityAnalyzer:
    """Analyze and limit SQL query complexity to prevent DoS.

    The analysis is purely lexical: the SQL text is upper-cased and scanned
    with regular expressions. Nothing is parsed, so keywords that appear inside
    string literals or comments are counted as well. Limits and cost weights
    are plain instance attributes and may be adjusted after construction.
    """

    # Patterns are compiled once and always matched against upper-cased SQL.
    # Multi-word keywords use \s+ so "CROSS\n JOIN" or "ORDER  BY" cannot slip
    # past the check, and \b prevents matches inside identifiers such as
    # "UNLIMITED_PLANS" or "INDISTINCT".
    _JOIN_RE = re.compile(r'\bJOIN\b')
    _SUBQUERY_RE = re.compile(r'\(\s*SELECT\b')
    _UNION_RE = re.compile(r'\bUNION\b')
    _AGGREGATION_RE = re.compile(r'\b(COUNT|SUM|AVG|MIN|MAX|GROUP_CONCAT)\b')
    _CROSS_JOIN_RE = re.compile(r'\bCROSS\s+JOIN\b')
    # Also catches "SELECT DISTINCT *" / "SELECT ALL *".
    _SELECT_STAR_RE = re.compile(r'\bSELECT\s+(?:(?:DISTINCT|ALL)\s+)?\*')
    _LIMIT_KEYWORD_RE = re.compile(r'\bLIMIT\b')
    # "LIMIT count" (optionally followed by "OFFSET n") or MySQL-style
    # "LIMIT offset, count"; group 2 is only set for the comma form.
    _LIMIT_VALUE_RE = re.compile(r'\bLIMIT\s+(\d+)(?:\s*,\s*(\d+))?')
    _ORDER_BY_RE = re.compile(r'\bORDER\s+BY\b')
    _GROUP_BY_RE = re.compile(r'\bGROUP\s+BY\b')
    _DISTINCT_RE = re.compile(r'\bDISTINCT\b')

    def __init__(self):
        # Hard limits: exceeding any of these rejects the query outright.
        self.max_joins = 5
        self.max_subqueries = 3
        self.max_union_operations = 3
        self.max_aggregations = 10
        self.max_result_rows = 10000
        # Not enforced anywhere in this class; presumably consumed by the
        # query executor — verify against callers.
        self.max_query_time_seconds = 30
        # Upper bound on the weighted score computed by analyze_complexity.
        self.max_complexity = 100

        # Cost weights used to build the complexity score. 'select' and
        # 'cross_join' are not read by analyze_complexity (CROSS JOIN is
        # rejected outright rather than scored).
        self.costs = {
            'select': 1,
            'join': 10,
            'subquery': 15,
            'union': 5,
            'aggregation': 5,
            'sort': 8,
            'group_by': 8,
            'distinct': 10,
            'cross_join': 50,
        }

    def _limit_row_counts(self, sql_upper: str) -> Tuple[int, ...]:
        """Return the row count of every LIMIT clause in upper-cased SQL.

        For ``LIMIT offset, count`` the second number is the row count; for
        ``LIMIT count`` the first number is. Subquery LIMITs are included.
        """
        return tuple(
            int(count if count else first)
            for first, count in self._LIMIT_VALUE_RE.findall(sql_upper)
        )

    def analyze_complexity(self, sql: str) -> Tuple[bool, str, int]:
        """
        Analyze query complexity.

        Hard limits are checked first, in this order: JOINs, subqueries,
        UNIONs, aggregations, CROSS JOIN, SELECT * without LIMIT, then the
        LIMIT values. Any failure returns immediately with a score of 0.
        Otherwise a weighted score is computed and compared against
        ``self.max_complexity``.

        Args:
            sql: The SQL text to inspect (any case).

        Returns:
            (is_allowed, reason, complexity_score) — ``reason`` is "" when the
            query is allowed.
        """
        sql_upper = sql.upper()
        complexity_score = 0

        # Count JOINs (CROSS JOIN is also counted here before being rejected below)
        join_count = len(self._JOIN_RE.findall(sql_upper))
        if join_count > self.max_joins:
            return False, f"Too many JOINs ({join_count}). Maximum: {self.max_joins}", 0
        complexity_score += join_count * self.costs['join']

        # Count subqueries (a SELECT directly after an opening parenthesis)
        subquery_count = len(self._SUBQUERY_RE.findall(sql_upper))
        if subquery_count > self.max_subqueries:
            return False, f"Too many subqueries ({subquery_count}). Maximum: {self.max_subqueries}", 0
        complexity_score += subquery_count * self.costs['subquery']

        # Count UNION operations
        union_count = len(self._UNION_RE.findall(sql_upper))
        if union_count > self.max_union_operations:
            return False, f"Too many UNION operations ({union_count}). Maximum: {self.max_union_operations}", 0
        complexity_score += union_count * self.costs['union']

        # Count aggregations
        aggregation_count = len(self._AGGREGATION_RE.findall(sql_upper))
        if aggregation_count > self.max_aggregations:
            return False, f"Too many aggregations ({aggregation_count}). Maximum: {self.max_aggregations}", 0
        complexity_score += aggregation_count * self.costs['aggregation']

        # CROSS JOIN is rejected outright (very expensive)
        if self._CROSS_JOIN_RE.search(sql_upper):
            return False, "CROSS JOIN is not allowed (too expensive)", 0

        # SELECT * must be bounded by a real LIMIT keyword, not a substring
        # of some identifier.
        has_limit = self._LIMIT_KEYWORD_RE.search(sql_upper) is not None
        if self._SELECT_STAR_RE.search(sql_upper) and not has_limit:
            return False, "SELECT * requires LIMIT clause", 0

        # Check every LIMIT, not just the first: an inner subquery LIMIT must
        # not hide an oversized outer one.
        for limit_value in self._limit_row_counts(sql_upper):
            if limit_value > self.max_result_rows:
                return False, f"LIMIT too high ({limit_value}). Maximum: {self.max_result_rows}", 0

        # Add base costs
        if self._ORDER_BY_RE.search(sql_upper):
            complexity_score += self.costs['sort']
        if self._GROUP_BY_RE.search(sql_upper):
            complexity_score += self.costs['group_by']
        if self._DISTINCT_RE.search(sql_upper):
            complexity_score += self.costs['distinct']

        # Complexity threshold
        if complexity_score > self.max_complexity:
            return False, f"Query too complex (score: {complexity_score}). Maximum: {self.max_complexity}", complexity_score

        return True, "", complexity_score

    def estimate_result_size(self, sql: str) -> int:
        """Estimate result size from the LIMIT clause(s).

        Returns the largest LIMIT row count found (a conservative bound when
        subqueries carry their own LIMIT), or ``self.max_result_rows`` when
        the query has no LIMIT at all.
        """
        counts = self._limit_row_counts(sql.upper())
        return max(counts) if counts else self.max_result_rows

    def add_safety_limit(self, sql: str) -> str:
        """Append ``LIMIT max_result_rows`` if the query has no LIMIT keyword.

        Trailing whitespace and semicolons are removed before appending. If
        the last line may end in a ``--`` or ``#`` line comment, the LIMIT is
        placed on a new line so it is not commented out. A newline is always
        valid SQL whitespace, so a false positive here is harmless.
        """
        if self._LIMIT_KEYWORD_RE.search(sql.upper()):
            return sql

        stripped = re.sub(r'[\s;]+$', '', sql)
        last_line = stripped.rsplit('\n', 1)[-1]
        separator = '\n' if ('--' in last_line or '#' in last_line) else ' '
        return f'{stripped}{separator}LIMIT {self.max_result_rows}'

    def get_complexity_report(self, sql: str) -> Dict:
        """Get a detailed complexity report.

        Uses the same patterns as analyze_complexity, so the counts here
        always agree with what that method enforces.
        """
        sql_upper = sql.upper()

        return {
            'joins': len(self._JOIN_RE.findall(sql_upper)),
            'subqueries': len(self._SUBQUERY_RE.findall(sql_upper)),
            'unions': len(self._UNION_RE.findall(sql_upper)),
            'aggregations': len(self._AGGREGATION_RE.findall(sql_upper)),
            'has_order_by': self._ORDER_BY_RE.search(sql_upper) is not None,
            'has_group_by': self._GROUP_BY_RE.search(sql_upper) is not None,
            'has_distinct': self._DISTINCT_RE.search(sql_upper) is not None,
            'has_limit': self._LIMIT_KEYWORD_RE.search(sql_upper) is not None,
            'estimated_rows': self.estimate_result_size(sql),
        }

# Shared module-level analyzer instance, built with the default limits.
query_complexity_analyzer: QueryComplexityAnalyzer = QueryComplexityAnalyzer()