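# Spaces: Sleeping
# NOTE(review): the line above looks like page-status text captured along with
# the file rather than module content -- confirm and remove if so.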
from __future__ import annotations

import logging
import random
import re
import time
from typing import Any, Tuple

class WorkloadGenerator:
    """Generates intentionally broken SQL queries for training optimization models.

    Each ``_<pattern>`` method returns one hard-coded anti-pattern query together
    with a hard-coded baseline latency in milliseconds. ``_optimize_query`` maps
    those queries to hand-written rewrites, ``get_expected_rows`` executes the
    rewrite on ``self.conn``, and ``measure_latency`` times a query on a
    caller-supplied connection.
    """

    # Matches the "Execution Time: <n> ms" summary line in EXPLAIN ANALYZE output.
    # Case-insensitive so a differently capitalised label is still recognised.
    _EXECUTION_TIME_RE = re.compile(
        r"Execution Time:\s*(\d+(?:\.\d+)?)\s*ms", re.IGNORECASE
    )

    _logger = logging.getLogger(__name__)

    def __init__(self, conn: Any, seed: int = 42) -> None:
        """Store the connection and create a seeded random source.

        Args:
            conn: Object exposing ``cursor()``; may be falsy (e.g. ``None``), in
                which case ``get_expected_rows`` returns an empty list.
            seed: Seed for this instance's pattern selection.
        """
        self.conn = conn
        # A private Random instance keeps selection reproducible per generator
        # without reseeding the process-wide ``random`` module state. For a given
        # seed it yields the same sequence that seeding the global RNG would.
        self._rng = random.Random(seed)

    def generate_broken_query(self) -> Tuple[str, float]:
        """Returns a tuple of (broken_query_string, baseline_latency_ms).

        One of the six anti-pattern builders is picked with this instance's
        seeded random source and its result is returned unchanged.
        """
        patterns = (
            self._n_plus_one_query,
            self._missing_index_scan,
            self._bad_join_order,
            self._correlated_subquery,
            self._unnecessary_distinct,
            self._function_on_indexed_column,
        )
        return self._rng.choice(patterns)()

    def _n_plus_one_query(self) -> Tuple[str, float]:
        """Return a query that looks up the customer name via a scalar subquery."""
        query = (
            "SELECT o.order_id, o.customer_id, o.order_date, o.status,\n"
            "(SELECT c.name FROM customers c WHERE c.customer_id = o.customer_id)"
            " as customer_name\n"
            "FROM orders o"
        )
        return query, 150.0

    def _missing_index_scan(self) -> Tuple[str, float]:
        """Return a ``SELECT *`` on customers filtered by ``city``."""
        return "SELECT * FROM customers WHERE city = 'Springfield'", 120.0

    def _bad_join_order(self) -> Tuple[str, float]:
        """Return a four-table join that starts from ``order_items``."""
        query = (
            "SELECT o.order_id, c.name, p.name, oi.quantity\n"
            "FROM order_items oi\n"
            "JOIN orders o ON oi.order_id = o.order_id\n"
            "JOIN products p ON oi.product_id = p.product_id\n"
            "JOIN customers c ON o.customer_id = c.customer_id"
        )
        return query, 200.0

    def _correlated_subquery(self) -> Tuple[str, float]:
        """Return a query that counts orders per customer via a correlated subquery."""
        query = (
            "SELECT c.customer_id, c.name,\n"
            "(SELECT COUNT(*) FROM orders o WHERE o.customer_id = c.customer_id)"
            " as order_count\n"
            "FROM customers c"
        )
        return query, 180.0

    def _unnecessary_distinct(self) -> Tuple[str, float]:
        """Return a ``SELECT DISTINCT product_id`` over ``order_items``."""
        return "SELECT DISTINCT product_id FROM order_items", 90.0

    def _function_on_indexed_column(self) -> Tuple[str, float]:
        """Return a query that wraps ``email`` in ``LOWER()`` inside the predicate."""
        query = "SELECT * FROM customers WHERE LOWER(email) = 'test@example.com'"
        return query, 100.0

    def get_expected_rows(self, broken_query: str) -> list[tuple[Any, ...]]:
        """Returns ground truth rows from the optimized version of the query.

        Args:
            broken_query: A query string, normally one produced by
                ``generate_broken_query``.

        Returns:
            The rows fetched for the rewritten query, or ``[]`` when there is no
            connection or when execution fails (the failure is logged).
        """
        optimized = self._optimize_query(broken_query)
        if not self.conn:
            return []
        try:
            with self.conn.cursor() as cur:
                cur.execute(optimized)
                return cur.fetchall()
        except Exception:
            # Best-effort by design: callers get an empty list rather than an
            # exception, but the cause is logged so a failure is not mistaken
            # for a genuinely empty result set.
            self._logger.warning(
                "Could not fetch expected rows for optimized query", exc_info=True
            )
            return []

    def _optimize_query(self, broken_query: str) -> str:
        """Convert broken query to optimized version.

        The input is matched by substring against the six known patterns, in a
        fixed order; an unrecognised query is returned stripped but otherwise
        unchanged.
        """
        optimized = broken_query.strip()

        if (
            "SELECT o.order_id, o.customer_id, o.order_date, o.status," in optimized
            and "(SELECT c.name" in optimized
        ):
            return (
                "SELECT o.order_id, o.customer_id, o.order_date, o.status,"
                " c.name as customer_name\n"
                "FROM orders o\n"
                "JOIN customers c ON o.customer_id = c.customer_id"
            )

        if "WHERE city =" in optimized and "SELECT * FROM customers" in optimized:
            # NOTE(review): the explicit column list assumes these are exactly the
            # columns of ``customers`` -- confirm against the schema.
            return optimized.replace(
                "SELECT *", "SELECT customer_id, name, email, city, created_at"
            )

        if "FROM order_items oi" in optimized and "JOIN orders o" in optimized:
            # Only the join order changes. No extra filter is applied, so the
            # rewrite selects from the same joined rows as the broken query.
            return (
                "SELECT o.order_id, c.name, p.name, oi.quantity\n"
                "FROM orders o\n"
                "JOIN customers c ON o.customer_id = c.customer_id\n"
                "JOIN order_items oi ON o.order_id = oi.order_id\n"
                "JOIN products p ON oi.product_id = p.product_id"
            )

        if (
            "(SELECT COUNT(*) FROM orders o WHERE o.customer_id = c.customer_id)"
            in optimized
        ):
            return (
                "SELECT c.customer_id, c.name, COUNT(o.order_id) as order_count\n"
                "FROM customers c\n"
                "LEFT JOIN orders o ON c.customer_id = o.customer_id\n"
                "GROUP BY c.customer_id, c.name"
            )

        if "SELECT DISTINCT product_id FROM order_items" in optimized:
            # NOTE(review): dropping DISTINCT keeps the same rows only if
            # product_id values in order_items never repeat -- confirm.
            return "SELECT product_id FROM order_items"

        if "WHERE LOWER(email) =" in optimized:
            # NOTE(review): equivalent only if stored emails are already
            # lower-case -- confirm.
            return optimized.replace("WHERE LOWER(email) =", "WHERE email =")

        return optimized

    @staticmethod
    def _timed_fetch(conn: Any, statement: str) -> Tuple[Any, float]:
        """Run ``statement`` on a fresh cursor; return (rows, elapsed wall-clock ms)."""
        start = time.perf_counter()
        cur = conn.cursor()
        try:
            cur.execute(statement)
            rows = cur.fetchall()
        finally:
            # Close the cursor even when execute/fetch raises.
            cur.close()
        return rows, (time.perf_counter() - start) * 1000

    def measure_latency(self, conn: Any, query: str) -> float:
        """Execute EXPLAIN ANALYZE and extract execution time in ms.

        Args:
            conn: Object exposing ``cursor()``; used instead of ``self.conn``.
            query: SQL text to measure. It is interpolated into the EXPLAIN
                statement as-is.

        Returns:
            The value from the ``Execution Time: <n> ms`` line of the plan output
            when one is present. Otherwise the wall-clock time of the EXPLAIN
            ANALYZE call. If EXPLAIN ANALYZE itself raises, the wall-clock time of
            running ``query`` directly.

        Raises:
            Exception: Whatever the driver raises if the fallback run of
                ``query`` also fails.
        """
        try:
            # In PostgreSQL, EXPLAIN ANALYZE actually runs the statement, so any
            # side effects of ``query`` happen here.
            plan_rows, wall_ms = self._timed_fetch(conn, f"EXPLAIN ANALYZE {query}")
        except Exception:
            # Fallback: raw timing, e.g. for backends without EXPLAIN ANALYZE.
            self._logger.debug(
                "EXPLAIN ANALYZE failed; timing the query directly", exc_info=True
            )
            _, wall_ms = self._timed_fetch(conn, query)
            return wall_ms

        for row in plan_rows:
            try:
                line = str(row[0])
            except (IndexError, KeyError, TypeError):
                # Row shape not positional/indexable; skip it.
                continue
            found = self._EXECUTION_TIME_RE.search(line)
            if found:
                return float(found.group(1))
        # No summary line found: report the measured wall-clock time instead.
        return wall_ms