import re
from typing import Dict, List, Tuple

# All patterns are applied to the upper-cased query text. Matching is purely
# textual: keywords that appear inside string literals or SQL comments are
# counted as well, so the analysis errs on the side of over-counting.
_JOIN_RE = re.compile(r'\bJOIN\b')
_SUBQUERY_RE = re.compile(r'\(\s*SELECT\b')
_UNION_RE = re.compile(r'\bUNION\b')
_AGGREGATION_RE = re.compile(r'\b(?:COUNT|SUM|AVG|MIN|MAX|GROUP_CONCAT)\b')
# \s+ (not a literal space) so "CROSS   JOIN" / "ORDER\nBY" cannot slip past.
_CROSS_JOIN_RE = re.compile(r'\bCROSS\s+JOIN\b')
_ORDER_BY_RE = re.compile(r'\bORDER\s+BY\b')
_GROUP_BY_RE = re.compile(r'\bGROUP\s+BY\b')
_DISTINCT_RE = re.compile(r'\bDISTINCT(?:ROW)?\b')
# \s* so that "SELECT*" (no whitespace) is recognised too.
_SELECT_STAR_RE = re.compile(r'\bSELECT\s*\*')
# Word boundaries keep identifiers such as RATE_LIMIT from counting as LIMIT.
_LIMIT_KEYWORD_RE = re.compile(r'\bLIMIT\b')
# Group 1 is the row count for "LIMIT n"; for "LIMIT offset, n" group 1 is
# the offset and group 2 is the row count.
_LIMIT_VALUE_RE = re.compile(r'\bLIMIT\s+(\d+)(?:\s*,\s*(\d+))?')


class QueryComplexityAnalyzer:
    """Analyze and limit query complexity to prevent DoS.

    The analysis is a keyword heuristic over the upper-cased SQL text, not a
    real SQL parse. Limits and cost weights are plain instance attributes and
    may be adjusted after construction.
    """

    def __init__(self):
        # Complexity limits
        self.max_joins = 5
        self.max_subqueries = 3
        self.max_union_operations = 3
        self.max_aggregations = 10
        self.max_result_rows = 10000
        # Not read by any method in this file.
        self.max_query_time_seconds = 30
        # Highest total score analyze_complexity() will accept.
        self.max_complexity_score = 100

        # Cost weights
        self.costs = {
            'select': 1,
            'join': 10,
            'subquery': 15,
            'union': 5,
            'aggregation': 5,
            'sort': 8,
            'group_by': 8,
            'distinct': 10,
            'cross_join': 50,
        }

    @staticmethod
    def _limit_row_counts(sql_upper: str) -> List[int]:
        """Return the row count of every numeric LIMIT clause, in order."""
        counts = []
        for match in _LIMIT_VALUE_RE.finditer(sql_upper):
            first, second = match.groups()
            # "LIMIT offset, count": the second number is the row count.
            counts.append(int(first if second is None else second))
        return counts

    def analyze_complexity(self, sql: str) -> Tuple[bool, str, int]:
        """Analyze query complexity.

        Checks run in a fixed order and the first violation wins: JOIN,
        subquery, UNION and aggregation counts, then CROSS JOIN, then
        ``SELECT *`` without LIMIT, then LIMIT values, then the total score.

        Args:
            sql: The SQL text to inspect.

        Returns:
            ``(is_allowed, reason, complexity_score)``. ``reason`` is empty
            when the query is allowed. The score is 0 for every rejection
            except the "too complex" one, which reports the computed score.
        """
        sql_upper = sql.upper()
        complexity_score = 0

        # (label used in the message, pattern, allowed maximum, cost key)
        counted_features = (
            ('JOINs', _JOIN_RE, self.max_joins, 'join'),
            ('subqueries', _SUBQUERY_RE, self.max_subqueries, 'subquery'),
            ('UNION operations', _UNION_RE, self.max_union_operations,
             'union'),
            ('aggregations', _AGGREGATION_RE, self.max_aggregations,
             'aggregation'),
        )
        for label, pattern, maximum, cost_key in counted_features:
            count = len(pattern.findall(sql_upper))
            if count > maximum:
                return (
                    False,
                    f"Too many {label} ({count}). Maximum: {maximum}",
                    0,
                )
            complexity_score += count * self.costs[cost_key]

        # Check for CROSS JOIN (very expensive)
        if _CROSS_JOIN_RE.search(sql_upper):
            return False, "CROSS JOIN is not allowed (too expensive)", 0

        # Check for SELECT * without LIMIT
        has_limit = _LIMIT_KEYWORD_RE.search(sql_upper) is not None
        if _SELECT_STAR_RE.search(sql_upper) and not has_limit:
            return False, "SELECT * requires LIMIT clause", 0

        # Every numeric LIMIT clause must respect the cap, not just the first.
        for limit_value in self._limit_row_counts(sql_upper):
            if limit_value > self.max_result_rows:
                return (
                    False,
                    f"LIMIT too high ({limit_value}). "
                    f"Maximum: {self.max_result_rows}",
                    0,
                )

        # Add base costs
        if _ORDER_BY_RE.search(sql_upper):
            complexity_score += self.costs['sort']
        if _GROUP_BY_RE.search(sql_upper):
            complexity_score += self.costs['group_by']
        if _DISTINCT_RE.search(sql_upper):
            complexity_score += self.costs['distinct']

        # Complexity threshold
        max_complexity = self.max_complexity_score
        if complexity_score > max_complexity:
            return (
                False,
                f"Query too complex (score: {complexity_score}). "
                f"Maximum: {max_complexity}",
                complexity_score,
            )

        return True, "", complexity_score

    def estimate_result_size(self, sql: str) -> int:
        """Estimate result size from the first numeric LIMIT clause.

        Returns ``self.max_result_rows`` when no numeric LIMIT is present.
        """
        row_counts = self._limit_row_counts(sql.upper())
        if row_counts:
            return row_counts[0]
        return self.max_result_rows  # Assume max if no LIMIT

    def add_safety_limit(self, sql: str) -> str:
        """Append ``LIMIT <max_result_rows>`` when the query has no LIMIT.

        Trailing semicolons and whitespace are removed before appending so
        the clause lands inside the statement. Blank input and queries that
        already contain a LIMIT keyword are returned unchanged.
        """
        if _LIMIT_KEYWORD_RE.search(sql.upper()):
            return sql
        statement = sql.rstrip('; \t\r\n\f\v')
        if not statement:
            return sql
        return f'{statement} LIMIT {self.max_result_rows}'

    def get_complexity_report(self, sql: str) -> Dict:
        """Get detailed complexity report.

        Uses the same patterns as ``analyze_complexity`` so the reported
        counts agree with the values the limits are enforced on.
        """
        sql_upper = sql.upper()
        return {
            'joins': len(_JOIN_RE.findall(sql_upper)),
            'subqueries': len(_SUBQUERY_RE.findall(sql_upper)),
            'unions': len(_UNION_RE.findall(sql_upper)),
            'aggregations': len(_AGGREGATION_RE.findall(sql_upper)),
            'has_order_by': _ORDER_BY_RE.search(sql_upper) is not None,
            'has_group_by': _GROUP_BY_RE.search(sql_upper) is not None,
            'has_distinct': _DISTINCT_RE.search(sql_upper) is not None,
            'has_limit': _LIMIT_KEYWORD_RE.search(sql_upper) is not None,
            'estimated_rows': self.estimate_result_size(sql),
        }


query_complexity_analyzer = QueryComplexityAnalyzer()