| import re |
| from typing import Tuple, Dict |
|
|
class QueryComplexityAnalyzer:
    """Analyze and limit query complexity to prevent DoS.

    The analysis is purely lexical: the SQL text is upper-cased and keyword
    occurrences are counted with regular expressions. SQL is not parsed, so
    keywords inside string literals or comments are matched as if they were
    part of the query.
    """

    # All patterns are applied to the upper-cased SQL text.
    _JOIN_RE = re.compile(r'\bJOIN\b')
    _SUBQUERY_RE = re.compile(r'\(\s*SELECT\b')
    _UNION_RE = re.compile(r'\bUNION\b')
    _AGGREGATION_RE = re.compile(r'\b(COUNT|SUM|AVG|MIN|MAX|GROUP_CONCAT)\b')
    # \s+ so that "CROSS\nJOIN" or "ORDER  BY" cannot slip past the checks.
    _CROSS_JOIN_RE = re.compile(r'\bCROSS\s+JOIN\b')
    _ORDER_BY_RE = re.compile(r'\bORDER\s+BY\b')
    _GROUP_BY_RE = re.compile(r'\bGROUP\s+BY\b')
    _DISTINCT_RE = re.compile(r'\bDISTINCT\b')
    _SELECT_STAR_RE = re.compile(r'SELECT\s+\*')
    # Word boundaries so identifiers such as "rate_limits" do not count as LIMIT.
    _LIMIT_KEYWORD_RE = re.compile(r'\bLIMIT\b')
    # "LIMIT count" or the MySQL/SQLite "LIMIT offset, count" form.
    _LIMIT_VALUE_RE = re.compile(r'\bLIMIT\s+(\d+)(?:\s*,\s*(\d+))?')

    def __init__(self):
        # Hard per-feature limits; exceeding any one rejects the query outright.
        self.max_joins = 5
        self.max_subqueries = 3
        self.max_union_operations = 3
        self.max_aggregations = 10
        self.max_result_rows = 10000
        # NOTE(review): not read by this class; presumably enforced by a caller.
        self.max_query_time_seconds = 30
        # Upper bound for the weighted score computed by analyze_complexity().
        self.max_complexity_score = 100

        # Per-feature weights summed into the complexity score. 'select' and
        # 'cross_join' are not used by analyze_complexity() (CROSS JOIN is
        # rejected outright).
        self.costs = {
            'select': 1,
            'join': 10,
            'subquery': 15,
            'union': 5,
            'aggregation': 5,
            'sort': 8,
            'group_by': 8,
            'distinct': 10,
            'cross_join': 50,
        }

    def _parse_limits(self, sql_upper: str) -> Tuple[list, list]:
        """Collect the numbers found in every numeric LIMIT clause.

        Args:
            sql_upper: Upper-cased SQL text.

        Returns:
            ``(all_numbers, row_counts)``. ``all_numbers`` holds every number
            (offset and count) of every clause; ``row_counts`` holds only the
            row count of each clause, which is the second number in the
            ``LIMIT offset, count`` form and the only number otherwise.
        """
        all_numbers = []
        row_counts = []
        for first, second in self._LIMIT_VALUE_RE.findall(sql_upper):
            all_numbers.append(int(first))
            if second:
                all_numbers.append(int(second))
                row_counts.append(int(second))
            else:
                row_counts.append(int(first))
        return all_numbers, row_counts

    def analyze_complexity(self, sql: str) -> Tuple[bool, str, int]:
        """
        Analyze query complexity.

        Args:
            sql: Raw SQL text to check.

        Returns:
            (is_allowed, reason, complexity_score). ``reason`` is empty when the
            query is allowed. The score is 0 when a hard limit rejects the
            query, and the computed score when the query is rejected only for
            exceeding ``max_complexity_score``.
        """
        sql_upper = sql.upper()
        complexity_score = 0

        join_count = len(self._JOIN_RE.findall(sql_upper))
        if join_count > self.max_joins:
            return False, f"Too many JOINs ({join_count}). Maximum: {self.max_joins}", 0
        complexity_score += join_count * self.costs['join']

        subquery_count = len(self._SUBQUERY_RE.findall(sql_upper))
        if subquery_count > self.max_subqueries:
            return False, f"Too many subqueries ({subquery_count}). Maximum: {self.max_subqueries}", 0
        complexity_score += subquery_count * self.costs['subquery']

        union_count = len(self._UNION_RE.findall(sql_upper))
        if union_count > self.max_union_operations:
            return False, f"Too many UNION operations ({union_count}). Maximum: {self.max_union_operations}", 0
        complexity_score += union_count * self.costs['union']

        aggregation_count = len(self._AGGREGATION_RE.findall(sql_upper))
        if aggregation_count > self.max_aggregations:
            return False, f"Too many aggregations ({aggregation_count}). Maximum: {self.max_aggregations}", 0
        complexity_score += aggregation_count * self.costs['aggregation']

        if self._CROSS_JOIN_RE.search(sql_upper):
            return False, "CROSS JOIN is not allowed (too expensive)", 0

        if self._SELECT_STAR_RE.search(sql_upper) and not self._LIMIT_KEYWORD_RE.search(sql_upper):
            return False, "SELECT * requires LIMIT clause", 0

        # Check every LIMIT number (subqueries included, offset and count of the
        # comma form included) so a large value cannot hide behind a small one.
        limit_numbers, _ = self._parse_limits(sql_upper)
        if limit_numbers:
            limit_value = max(limit_numbers)
            if limit_value > self.max_result_rows:
                return False, f"LIMIT too high ({limit_value}). Maximum: {self.max_result_rows}", 0

        if self._ORDER_BY_RE.search(sql_upper):
            complexity_score += self.costs['sort']
        if self._GROUP_BY_RE.search(sql_upper):
            complexity_score += self.costs['group_by']
        if self._DISTINCT_RE.search(sql_upper):
            complexity_score += self.costs['distinct']

        max_complexity = self.max_complexity_score
        if complexity_score > max_complexity:
            return False, f"Query too complex (score: {complexity_score}). Maximum: {max_complexity}", complexity_score

        return True, "", complexity_score

    def estimate_result_size(self, sql: str) -> int:
        """Estimate result size from LIMIT clauses.

        Returns the largest row count among all numeric LIMIT clauses (a
        conservative estimate; for ``LIMIT offset, count`` the count is used),
        or ``max_result_rows`` when no numeric LIMIT is present.
        """
        _, row_counts = self._parse_limits(sql.upper())
        if row_counts:
            return max(row_counts)
        return self.max_result_rows

    def add_safety_limit(self, sql: str) -> str:
        """Append ``LIMIT max_result_rows`` if the query has no LIMIT keyword.

        A query that already contains LIMIT is returned unchanged. Otherwise
        trailing whitespace and semicolons are removed before appending. If
        the last line contains a line comment marker (``--`` or ``#``), the
        LIMIT goes on a new line so the comment cannot swallow it.
        """
        if self._LIMIT_KEYWORD_RE.search(sql.upper()):
            return sql

        # Strip whitespace around trailing semicolons: "...;\n" -> "...".
        body = sql.rstrip().rstrip(';').rstrip()
        last_line = body.rsplit('\n', 1)[-1]
        separator = '\n' if ('--' in last_line or '#' in last_line) else ' '
        return f'{body}{separator}LIMIT {self.max_result_rows}'

    def get_complexity_report(self, sql: str) -> Dict:
        """Get a detailed complexity report.

        Counts use the same patterns as analyze_complexity(), so the two never
        disagree (e.g. GROUP_CONCAT counts as an aggregation in both).
        """
        sql_upper = sql.upper()

        return {
            'joins': len(self._JOIN_RE.findall(sql_upper)),
            'subqueries': len(self._SUBQUERY_RE.findall(sql_upper)),
            'unions': len(self._UNION_RE.findall(sql_upper)),
            'aggregations': len(self._AGGREGATION_RE.findall(sql_upper)),
            'has_order_by': bool(self._ORDER_BY_RE.search(sql_upper)),
            'has_group_by': bool(self._GROUP_BY_RE.search(sql_upper)),
            'has_distinct': bool(self._DISTINCT_RE.search(sql_upper)),
            'has_limit': bool(self._LIMIT_KEYWORD_RE.search(sql_upper)),
            'estimated_rows': self.estimate_result_size(sql),
        }
|
|
# Shared module-level analyzer, constructed once with the default limits.
query_complexity_analyzer: QueryComplexityAnalyzer = QueryComplexityAnalyzer()
|
|