Spaces:
Sleeping
Sleeping
File size: 5,595 Bytes
b59a07e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | from __future__ import annotations
import random
from typing import Any, Tuple
class WorkloadGenerator:
    """Generate intentionally inefficient SQL queries for training optimization models.

    Each ``_*`` pattern method returns a ``(query, baseline_latency_ms)`` pair
    whose baseline is a fixed constant, not a measurement. ``_optimize_query``
    maps a known broken query to a rewrite that must return the same rows, and
    ``get_expected_rows`` executes that rewrite to obtain ground-truth results.
    """

    def __init__(self, conn: Any, seed: int = 42) -> None:
        """Store the connection and seed the module-level RNG.

        Args:
            conn: Connection used by ``get_expected_rows``; its ``cursor()``
                result must support the ``with`` statement. A falsy value
                disables ground-truth lookups.
            seed: Passed to ``random.seed``. This reseeds the *global*
                ``random`` module, which also affects every other user of it.
        """
        self.conn = conn
        random.seed(seed)

    def generate_broken_query(self) -> Tuple[str, float]:
        """Pick one anti-pattern at random and return ``(query, baseline_latency_ms)``.

        Selection uses the global ``random`` module, so it is reproducible
        relative to the seed set in ``__init__`` (or any later global reseed).
        """
        patterns = [
            self._n_plus_one_query,
            self._missing_index_scan,
            self._bad_join_order,
            self._correlated_subquery,
            self._unnecessary_distinct,
            self._function_on_indexed_column,
        ]
        generator = random.choice(patterns)
        return generator()

    def _n_plus_one_query(self) -> Tuple[str, float]:
        """Per-row scalar subquery to look up each order's customer name."""
        query = """
        SELECT o.order_id, o.customer_id, o.order_date, o.status,
               (SELECT c.name FROM customers c WHERE c.customer_id = o.customer_id) as customer_name
        FROM orders o
        """
        return query.strip(), 150.0

    def _missing_index_scan(self) -> Tuple[str, float]:
        """``SELECT *`` filtered on ``city`` (intended to hit an unindexed column)."""
        query = """
        SELECT * FROM customers WHERE city = 'Springfield'
        """
        return query.strip(), 120.0

    def _bad_join_order(self) -> Tuple[str, float]:
        """Four-way inner join that starts from ``order_items``."""
        query = """
        SELECT o.order_id, c.name, p.name, oi.quantity
        FROM order_items oi
        JOIN orders o ON oi.order_id = o.order_id
        JOIN products p ON oi.product_id = p.product_id
        JOIN customers c ON o.customer_id = c.customer_id
        """
        return query.strip(), 200.0

    def _correlated_subquery(self) -> Tuple[str, float]:
        """Correlated ``COUNT(*)`` subquery evaluated once per customer row."""
        query = """
        SELECT c.customer_id, c.name,
               (SELECT COUNT(*) FROM orders o WHERE o.customer_id = c.customer_id) as order_count
        FROM customers c
        """
        return query.strip(), 180.0

    def _unnecessary_distinct(self) -> Tuple[str, float]:
        """``SELECT DISTINCT`` over ``order_items.product_id``."""
        query = """
        SELECT DISTINCT product_id FROM order_items
        """
        return query.strip(), 90.0

    def _function_on_indexed_column(self) -> Tuple[str, float]:
        """Filter that wraps ``email`` in ``LOWER()``, preventing a plain index lookup."""
        query = """
        SELECT * FROM customers WHERE LOWER(email) = 'test@example.com'
        """
        return query.strip(), 100.0

    def get_expected_rows(self, broken_query: str) -> list[tuple[Any, ...]]:
        """Return ground-truth rows by executing the optimized form of *broken_query*.

        Returns an empty list when no connection is configured or when the
        query fails. Failures are logged (not raised) so callers still get a
        best-effort result.
        """
        import logging

        if not self.conn:
            return []
        optimized = self._optimize_query(broken_query)
        try:
            with self.conn.cursor() as cur:
                cur.execute(optimized)
                return cur.fetchall()
        except Exception:
            # NOTE(review): on PostgreSQL outside autocommit, a failed statement
            # leaves the transaction aborted; the caller may need to roll back.
            logging.getLogger(__name__).warning(
                "Could not compute expected rows for query: %s", optimized, exc_info=True
            )
            return []

    def _optimize_query(self, broken_query: str) -> str:
        """Return a rewrite of *broken_query* that must produce the same rows.

        Recognizes the queries emitted by the ``_*`` pattern methods via
        substring checks. Any other input is returned stripped but otherwise
        unchanged. The result is used as ground truth, so every rewrite must be
        result-equivalent to the query it replaces — never add filters.
        """
        optimized = broken_query.strip()

        if (
            "SELECT o.order_id, o.customer_id, o.order_date, o.status," in optimized
            and "(SELECT c.name" in optimized
        ):
            # LEFT JOIN (not INNER) keeps orders with no matching customer,
            # which the scalar subquery returns with a NULL customer_name.
            return """
            SELECT o.order_id, o.customer_id, o.order_date, o.status, c.name as customer_name
            FROM orders o
            LEFT JOIN customers c ON o.customer_id = c.customer_id
            """.strip()

        if "WHERE city =" in optimized and "SELECT * FROM customers" in optimized:
            # NOTE(review): equivalent to SELECT * only if these are exactly the
            # customers columns, in table order — confirm against the schema.
            return optimized.replace(
                "SELECT *", "SELECT customer_id, name, email, city, created_at"
            )

        if "FROM order_items oi" in optimized and "JOIN orders o" in optimized:
            # Inner joins commute, so reordering them preserves the result set.
            return """
            SELECT o.order_id, c.name, p.name, oi.quantity
            FROM orders o
            JOIN customers c ON o.customer_id = c.customer_id
            JOIN order_items oi ON o.order_id = oi.order_id
            JOIN products p ON oi.product_id = p.product_id
            """.strip()

        if "(SELECT COUNT(*) FROM orders o WHERE o.customer_id = c.customer_id)" in optimized:
            # COUNT(o.order_id) ignores the NULLs produced by the LEFT JOIN, so
            # customers without orders get 0, as with the correlated COUNT(*).
            return """
            SELECT c.customer_id, c.name, COUNT(o.order_id) as order_count
            FROM customers c
            LEFT JOIN orders o ON c.customer_id = o.customer_id
            GROUP BY c.customer_id, c.name
            """.strip()

        if "SELECT DISTINCT product_id FROM order_items" in optimized:
            # Simply dropping DISTINCT would return one row per order item;
            # GROUP BY yields the same de-duplicated rows as DISTINCT.
            return "SELECT product_id FROM order_items GROUP BY product_id"

        if "WHERE LOWER(email) =" in optimized:
            # NOTE(review): matches the original only if stored emails are
            # already lowercase — confirm that invariant holds.
            return optimized.replace("WHERE LOWER(email) =", "WHERE email =")

        return optimized

    def measure_latency(self, conn: Any, query: str) -> float:
        """Return the execution time of *query* in milliseconds.

        Runs ``EXPLAIN ANALYZE <query>`` and returns the server-reported
        ``Execution Time`` (or legacy ``Total runtime``) from the plan output.
        If no such line is present, the client-side wall-clock time of the
        EXPLAIN round trip is returned instead.

        If the EXPLAIN statement raises, *query* is executed directly and the
        wall-clock time to run it and fetch all rows is returned. An error from
        that fallback propagates to the caller.

        Note: EXPLAIN ANALYZE really executes *query*, and *query* is
        interpolated verbatim, so only pass trusted, read-only SQL.
        """
        import re
        import time

        try:
            start = time.perf_counter()
            cur = conn.cursor()
            try:
                cur.execute(f"EXPLAIN ANALYZE {query}")
                rows = cur.fetchall()
            finally:
                cur.close()
            wall_ms = (time.perf_counter() - start) * 1000
        except Exception:
            # Fallback: time the raw query. NOTE(review): on PostgreSQL outside
            # autocommit the failed EXPLAIN aborts the transaction, so this
            # will likely fail too unless the connection is in autocommit.
            start = time.perf_counter()
            cur = conn.cursor()
            try:
                cur.execute(query)
                cur.fetchall()
            finally:
                cur.close()
            return (time.perf_counter() - start) * 1000

        # Each EXPLAIN ANALYZE row is a single text column; the timing summary
        # appears as one of those lines.
        for row in rows:
            match = re.search(
                r"(?:Execution Time|Total runtime):\s*([0-9]*\.?[0-9]+)\s*ms",
                str(row[0]),
                re.IGNORECASE,
            )
            if match:
                return float(match.group(1))
        return wall_ms
|