# corpusdb / app / query_complexity.py
# mrsavage1's picture
# Upload 52 files
# 723f9ab verified
import re
from typing import Tuple, Dict
class QueryComplexityAnalyzer:
    """Analyze and limit query complexity to prevent DoS.

    Every check is a regular-expression heuristic applied to the raw SQL
    text. The query is never parsed, so keywords that appear inside string
    literals or comments are counted like any other occurrence.
    """

    # Keyword patterns are compiled once and shared by every method so that
    # analyze_complexity() and get_complexity_report() always agree.
    # ``\b`` keeps identifiers such as ``rate_limit`` from matching LIMIT,
    # and ``\s+`` lets two-word keywords span newlines or repeated spaces.
    _JOIN_RE = re.compile(r'\bJOIN\b', re.IGNORECASE)
    _SUBQUERY_RE = re.compile(r'\(\s*SELECT\b', re.IGNORECASE)
    _UNION_RE = re.compile(r'\bUNION\b', re.IGNORECASE)
    _AGGREGATION_RE = re.compile(
        r'\b(?:COUNT|SUM|AVG|MIN|MAX|GROUP_CONCAT)\b', re.IGNORECASE
    )
    _CROSS_JOIN_RE = re.compile(r'\bCROSS\s+JOIN\b', re.IGNORECASE)
    _SELECT_STAR_RE = re.compile(r'\bSELECT\s+\*', re.IGNORECASE)
    _LIMIT_RE = re.compile(r'\bLIMIT\b', re.IGNORECASE)
    _LIMIT_VALUE_RE = re.compile(r'\bLIMIT\s+(\d+)', re.IGNORECASE)
    _ORDER_BY_RE = re.compile(r'\bORDER\s+BY\b', re.IGNORECASE)
    _GROUP_BY_RE = re.compile(r'\bGROUP\s+BY\b', re.IGNORECASE)
    _DISTINCT_RE = re.compile(r'\bDISTINCT\b', re.IGNORECASE)

    def __init__(self):
        """Set the default limits and per-feature cost weights."""
        # Complexity limits
        self.max_joins = 5
        self.max_subqueries = 3
        self.max_union_operations = 3
        self.max_aggregations = 10
        self.max_result_rows = 10000
        self.max_query_time_seconds = 30
        # Highest total score analyze_complexity() accepts.
        self.max_complexity_score = 100
        # Cost weights
        self.costs = {
            'select': 1,
            'join': 10,
            'subquery': 15,
            'union': 5,
            'aggregation': 5,
            'sort': 8,
            'group_by': 8,
            'distinct': 10,
            'cross_join': 50,
        }

    def analyze_complexity(self, sql: str) -> Tuple[bool, str, int]:
        """Decide whether a query is cheap enough to run.

        Args:
            sql: The SQL text to inspect.

        Returns:
            ``(is_allowed, reason, complexity_score)``. ``reason`` is empty
            when the query is allowed. The score is 0 when a hard limit
            rejects the query before scoring finishes; when the total score
            itself is too high, that score is returned.
        """
        complexity_score = 0

        # Count JOINs
        join_count = len(self._JOIN_RE.findall(sql))
        if join_count > self.max_joins:
            return False, f"Too many JOINs ({join_count}). Maximum: {self.max_joins}", 0
        complexity_score += join_count * self.costs['join']

        # Count subqueries (nested SELECT statements)
        subquery_count = len(self._SUBQUERY_RE.findall(sql))
        if subquery_count > self.max_subqueries:
            return False, f"Too many subqueries ({subquery_count}). Maximum: {self.max_subqueries}", 0
        complexity_score += subquery_count * self.costs['subquery']

        # Count UNION operations
        union_count = len(self._UNION_RE.findall(sql))
        if union_count > self.max_union_operations:
            return False, f"Too many UNION operations ({union_count}). Maximum: {self.max_union_operations}", 0
        complexity_score += union_count * self.costs['union']

        # Count aggregations
        aggregation_count = len(self._AGGREGATION_RE.findall(sql))
        if aggregation_count > self.max_aggregations:
            return False, f"Too many aggregations ({aggregation_count}). Maximum: {self.max_aggregations}", 0
        complexity_score += aggregation_count * self.costs['aggregation']

        # CROSS JOIN is rejected outright (very expensive).
        if self._CROSS_JOIN_RE.search(sql):
            return False, "CROSS JOIN is not allowed (too expensive)", 0

        # SELECT * must be bounded by a LIMIT keyword.
        if self._SELECT_STAR_RE.search(sql) and not self._LIMIT_RE.search(sql):
            return False, "SELECT * requires LIMIT clause", 0

        # Check every numeric LIMIT, not just the first one: a small LIMIT in
        # a subquery must not hide an oversized LIMIT on the outer query.
        limit_values = [int(digits) for digits in self._LIMIT_VALUE_RE.findall(sql)]
        if limit_values:
            limit_value = max(limit_values)
            if limit_value > self.max_result_rows:
                return False, f"LIMIT too high ({limit_value}). Maximum: {self.max_result_rows}", 0

        # Flat surcharges for clauses that are costed once per query.
        if self._ORDER_BY_RE.search(sql):
            complexity_score += self.costs['sort']
        if self._GROUP_BY_RE.search(sql):
            complexity_score += self.costs['group_by']
        if self._DISTINCT_RE.search(sql):
            complexity_score += self.costs['distinct']

        # Complexity threshold
        max_complexity = self.max_complexity_score
        if complexity_score > max_complexity:
            return False, f"Query too complex (score: {complexity_score}). Maximum: {max_complexity}", complexity_score
        return True, "", complexity_score

    def estimate_result_size(self, sql: str) -> int:
        """Estimate result size from the first numeric LIMIT clause.

        Returns ``max_result_rows`` when the query has no numeric LIMIT.
        """
        limit_match = self._LIMIT_VALUE_RE.search(sql)
        if limit_match:
            return int(limit_match.group(1))
        return self.max_result_rows  # Assume max if no LIMIT

    def add_safety_limit(self, sql: str) -> str:
        """Append ``LIMIT max_result_rows`` when the query has no LIMIT keyword.

        Trailing semicolons and whitespace are removed before appending so
        the LIMIT clause stays part of the statement. A query that already
        contains a LIMIT keyword is returned unchanged.
        """
        if self._LIMIT_RE.search(sql):
            return sql
        # Strip whitespace as well as ';' so that "...;\n" is handled too.
        statement = sql.rstrip('; \t\r\n')
        return f'{statement} LIMIT {self.max_result_rows}'

    def get_complexity_report(self, sql: str) -> Dict:
        """Get detailed complexity report.

        Returns a dict of feature counts and presence flags computed with
        the same patterns analyze_complexity() uses, plus ``estimated_rows``
        from estimate_result_size().
        """
        return {
            'joins': len(self._JOIN_RE.findall(sql)),
            'subqueries': len(self._SUBQUERY_RE.findall(sql)),
            'unions': len(self._UNION_RE.findall(sql)),
            'aggregations': len(self._AGGREGATION_RE.findall(sql)),
            'has_order_by': bool(self._ORDER_BY_RE.search(sql)),
            'has_group_by': bool(self._GROUP_BY_RE.search(sql)),
            'has_distinct': bool(self._DISTINCT_RE.search(sql)),
            'has_limit': bool(self._LIMIT_RE.search(sql)),
            'estimated_rows': self.estimate_result_size(sql),
        }
# Module-level analyzer instance, created at import time with the limits
# that QueryComplexityAnalyzer.__init__ sets.
query_complexity_analyzer = QueryComplexityAnalyzer()