autonomic-dbre / dbre /workload_generator.py
ZeroiJ's picture
Initial commit: Autonomic DBRE - Self-Improving Database Reliability Engineer
b59a07e
from __future__ import annotations
import random
from typing import Any, Tuple
class WorkloadGenerator:
    """Generates intentionally broken SQL queries for training optimization models.

    Each ``_<pattern>`` method returns one hard-coded anti-pattern query
    together with a hard-coded baseline latency in milliseconds.
    ``_optimize_query`` maps each of those queries to a hand-written rewrite,
    and ``get_expected_rows`` runs that rewrite to obtain reference rows.
    """

    def __init__(self, conn: Any, seed: int = 42) -> None:
        """Store the connection and seed the RNG used to pick patterns.

        Args:
            conn: Database connection exposing ``cursor()``; a falsy value
                makes ``get_expected_rows`` return an empty list.
            seed: Seed for pattern selection.
        """
        self.conn = conn
        # NOTE: this seeds the module-level ``random`` generator, so it also
        # affects every other user of ``random`` in the process.
        random.seed(seed)

    def generate_broken_query(self) -> Tuple[str, float]:
        """Returns a tuple of (broken_query_string, baseline_latency_ms)."""
        patterns = [
            self._n_plus_one_query,
            self._missing_index_scan,
            self._bad_join_order,
            self._correlated_subquery,
            self._unnecessary_distinct,
            self._function_on_indexed_column,
        ]
        generator = random.choice(patterns)
        return generator()

    def _n_plus_one_query(self) -> Tuple[str, float]:
        """Per-row scalar subquery that looks up the customer name."""
        query = """
            SELECT o.order_id, o.customer_id, o.order_date, o.status,
            (SELECT c.name FROM customers c WHERE c.customer_id = o.customer_id) as customer_name
            FROM orders o
        """
        return query.strip(), 150.0

    def _missing_index_scan(self) -> Tuple[str, float]:
        """``SELECT *`` filtered on ``city``."""
        query = """
            SELECT * FROM customers WHERE city = 'Springfield'
        """
        return query.strip(), 120.0

    def _bad_join_order(self) -> Tuple[str, float]:
        """Four-table join that starts from ``order_items``."""
        query = """
            SELECT o.order_id, c.name, p.name, oi.quantity
            FROM order_items oi
            JOIN orders o ON oi.order_id = o.order_id
            JOIN products p ON oi.product_id = p.product_id
            JOIN customers c ON o.customer_id = c.customer_id
        """
        return query.strip(), 200.0

    def _correlated_subquery(self) -> Tuple[str, float]:
        """Per-customer correlated ``COUNT(*)`` subquery."""
        query = """
            SELECT c.customer_id, c.name,
            (SELECT COUNT(*) FROM orders o WHERE o.customer_id = c.customer_id) as order_count
            FROM customers c
        """
        return query.strip(), 180.0

    def _unnecessary_distinct(self) -> Tuple[str, float]:
        """``SELECT DISTINCT`` over ``order_items.product_id``."""
        query = """
            SELECT DISTINCT product_id FROM order_items
        """
        return query.strip(), 90.0

    def _function_on_indexed_column(self) -> Tuple[str, float]:
        """Predicate that wraps ``email`` in ``LOWER()``."""
        query = """
            SELECT * FROM customers WHERE LOWER(email) = 'test@example.com'
        """
        return query.strip(), 100.0

    def get_expected_rows(self, broken_query: str) -> list[tuple[Any, ...]]:
        """Returns ground truth rows from the optimized version of the query.

        Returns an empty list when no connection is configured or when the
        query fails; a failure is logged rather than raised so callers keep
        the original best-effort contract.
        """
        import logging

        optimized = self._optimize_query(broken_query)
        if not self.conn:
            return []
        try:
            with self.conn.cursor() as cur:
                cur.execute(optimized)
                return cur.fetchall()
        except Exception:
            # Best-effort by design, but leave a trace so an empty ground
            # truth can be told apart from a failed query.
            logging.getLogger(__name__).warning(
                "Failed to fetch expected rows for optimized query",
                exc_info=True,
            )
            return []

    def _optimize_query(self, broken_query: str) -> str:
        """Convert broken query to optimized version.

        Recognizes the queries produced by the ``_<pattern>`` methods by
        substring match and returns a hand-written rewrite; any other input
        is returned stripped but otherwise unchanged.
        """
        optimized = broken_query.strip()

        if (
            "SELECT o.order_id, o.customer_id, o.order_date, o.status," in optimized
            and "(SELECT c.name" in optimized
        ):
            # LEFT JOIN keeps orders without a matching customer (name is
            # NULL), exactly as the scalar subquery does; an inner join
            # would silently drop those rows.
            return """
                SELECT o.order_id, o.customer_id, o.order_date, o.status, c.name as customer_name
                FROM orders o
                LEFT JOIN customers c ON o.customer_id = c.customer_id
            """.strip()

        if "WHERE city =" in optimized and "SELECT * FROM customers" in optimized:
            # NOTE(review): the explicit column list assumes these are all of
            # the columns of ``customers`` -- confirm against the schema.
            return optimized.replace(
                "SELECT *", "SELECT customer_id, name, email, city, created_at"
            )

        if "FROM order_items oi" in optimized and "JOIN orders o" in optimized:
            # Only the join order changes. The rewrite must not add filters,
            # otherwise the reference rows differ from the broken query's.
            return """
                SELECT o.order_id, c.name, p.name, oi.quantity
                FROM orders o
                JOIN customers c ON o.customer_id = c.customer_id
                JOIN order_items oi ON o.order_id = oi.order_id
                JOIN products p ON oi.product_id = p.product_id
            """.strip()

        if "(SELECT COUNT(*) FROM orders o WHERE o.customer_id = c.customer_id)" in optimized:
            return """
                SELECT c.customer_id, c.name, COUNT(o.order_id) as order_count
                FROM customers c
                LEFT JOIN orders o ON c.customer_id = o.customer_id
                GROUP BY c.customer_id, c.name
            """.strip()

        if "SELECT DISTINCT product_id FROM order_items" in optimized:
            # NOTE(review): dropping DISTINCT returns duplicate product_ids
            # if a product appears in more than one order item -- confirm
            # this rewrite is intended as ground truth.
            return "SELECT product_id FROM order_items"

        if "WHERE LOWER(email) =" in optimized:
            # NOTE(review): equivalent only if stored emails are already
            # lower-case -- confirm against the data.
            return optimized.replace("WHERE LOWER(email) =", "WHERE email =")

        return optimized

    @staticmethod
    def _timed_fetch(conn: Any, sql: str) -> Tuple[Any, float]:
        """Run *sql*, fetch all rows, and return (rows, elapsed_ms)."""
        import time

        start = time.perf_counter()
        cur = conn.cursor()
        try:
            cur.execute(sql)
            rows = cur.fetchall()
        finally:
            # Always release the cursor, including when the statement fails.
            cur.close()
        return rows, (time.perf_counter() - start) * 1000

    def measure_latency(self, conn: Any, query: str) -> float:
        """Execute EXPLAIN ANALYZE and extract execution time in ms.

        Returns the time reported in the plan output when an
        "Execution Time" (or "Total runtime") line is present. Otherwise
        returns the client-side wall-clock time of the EXPLAIN ANALYZE call.
        If EXPLAIN ANALYZE itself fails, the raw query is executed and its
        wall-clock time is returned; an error from that fallback propagates.
        """
        import re

        try:
            rows, wall_clock_ms = self._timed_fetch(conn, f"EXPLAIN ANALYZE {query}")
            for row in rows:
                found = re.search(
                    r"(?:Execution Time|Total runtime):\s*([0-9]+(?:\.[0-9]+)?)\s*ms",
                    str(row[0]),
                )
                if found:
                    return float(found.group(1))
            # No timing line in the output: fall back to client-side timing.
            return wall_clock_ms
        except Exception:
            # Fallback: raw timing.
            # NOTE(review): some databases leave the transaction aborted
            # after a failed statement; the caller may need to roll back for
            # this fallback to succeed -- confirm for the driver in use.
            _, wall_clock_ms = self._timed_fetch(conn, query)
            return wall_clock_ms