File size: 5,595 Bytes
b59a07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from __future__ import annotations

import contextlib
import logging
import random
import re
import time
from typing import Any, Tuple


class WorkloadGenerator:
    """Generates intentionally broken SQL queries for training optimization models.

    Each ``_<pattern>`` method returns one anti-pattern query together with a
    hard-coded baseline latency in milliseconds. ``_optimize_query`` maps each
    of those queries to a hand-written rewrite intended to return the same
    rows, and ``get_expected_rows`` executes that rewrite to produce ground
    truth.
    """

    # Server-side timing line from EXPLAIN ANALYZE output, e.g.
    # "Execution Time: 0.123 ms". Older PostgreSQL releases print
    # "Execution time" or "Total runtime", hence the alternation and IGNORECASE.
    _EXEC_TIME_RE = re.compile(
        r"(?:Execution Time|Total runtime):\s*([0-9]+(?:\.[0-9]+)?)\s*ms",
        re.IGNORECASE,
    )

    def __init__(self, conn: Any, seed: int = 42) -> None:
        """Store the connection and seed pattern selection.

        Args:
            conn: Object exposing ``cursor()``, used by ``get_expected_rows``.
                May be falsy, in which case ``get_expected_rows`` returns ``[]``.
            seed: Seed controlling which pattern ``generate_broken_query`` picks.
        """
        self.conn = conn
        # NOTE(review): this seeds the module-global ``random`` RNG, so creating
        # a generator also resets random state for every other user of
        # ``random`` in the process. Kept as-is for backward compatibility.
        random.seed(seed)

    def generate_broken_query(self) -> Tuple[str, float]:
        """Return a randomly chosen ``(broken_query_string, baseline_latency_ms)``."""
        patterns = [
            self._n_plus_one_query,
            self._missing_index_scan,
            self._bad_join_order,
            self._correlated_subquery,
            self._unnecessary_distinct,
            self._function_on_indexed_column,
        ]
        generator = random.choice(patterns)
        return generator()

    def _n_plus_one_query(self) -> Tuple[str, float]:
        """Scalar subquery evaluated per order row (N+1 pattern)."""
        query = """
            SELECT o.order_id, o.customer_id, o.order_date, o.status,
                   (SELECT c.name FROM customers c WHERE c.customer_id = o.customer_id) as customer_name
            FROM orders o
        """
        return query.strip(), 150.0

    def _missing_index_scan(self) -> Tuple[str, float]:
        """Filter on ``city``, intended to force a scan when no index exists."""
        query = """
            SELECT * FROM customers WHERE city = 'Springfield'
        """
        return query.strip(), 120.0

    def _bad_join_order(self) -> Tuple[str, float]:
        """Four-way join written starting from the line-item table."""
        query = """
            SELECT o.order_id, c.name, p.name, oi.quantity
            FROM order_items oi
            JOIN orders o ON oi.order_id = o.order_id
            JOIN products p ON oi.product_id = p.product_id
            JOIN customers c ON o.customer_id = c.customer_id
        """
        return query.strip(), 200.0

    def _correlated_subquery(self) -> Tuple[str, float]:
        """Per-customer COUNT(*) via a correlated subquery."""
        query = """
            SELECT c.customer_id, c.name,
                   (SELECT COUNT(*) FROM orders o WHERE o.customer_id = c.customer_id) as order_count
            FROM customers c
        """
        return query.strip(), 180.0

    def _unnecessary_distinct(self) -> Tuple[str, float]:
        """DISTINCT over ``order_items.product_id``."""
        query = """
            SELECT DISTINCT product_id FROM order_items
        """
        return query.strip(), 90.0

    def _function_on_indexed_column(self) -> Tuple[str, float]:
        """Function wrapped around the filtered column (defeats a plain index)."""
        query = """
            SELECT * FROM customers WHERE LOWER(email) = 'test@example.com'
        """
        return query.strip(), 100.0

    def get_expected_rows(self, broken_query: str) -> list[tuple[Any, ...]]:
        """Return ground-truth rows by executing the rewrite of *broken_query*.

        Returns ``[]`` when ``self.conn`` is falsy or when executing the
        rewrite fails; failures are logged at WARNING level with a traceback.
        """
        if not self.conn:
            return []

        optimized = self._optimize_query(broken_query)
        try:
            # closing() rather than ``with cursor``: not every DB-API driver's
            # cursor is a context manager (sqlite3's is not), and the old form
            # made such drivers silently yield [].
            with contextlib.closing(self.conn.cursor()) as cur:
                cur.execute(optimized)
                return cur.fetchall()
        except Exception:
            # Best-effort by design, but never silently: record why it failed.
            logging.getLogger(__name__).warning(
                "Could not fetch expected rows for query: %s", optimized, exc_info=True
            )
            return []

    def _optimize_query(self, broken_query: str) -> str:
        """Return a rewrite of *broken_query* that should yield the same rows.

        Each pattern from the ``_<pattern>`` generators is recognised by
        substring. Unrecognised input is returned stripped but otherwise
        unchanged.
        """
        optimized = broken_query.strip()

        if "SELECT o.order_id, o.customer_id, o.order_date, o.status," in optimized and "(SELECT c.name" in optimized:
            # LEFT JOIN, not INNER: an order with no matching customer must
            # still appear with a NULL name, exactly as the scalar subquery does.
            return """
                SELECT o.order_id, o.customer_id, o.order_date, o.status, c.name as customer_name
                FROM orders o
                LEFT JOIN customers c ON o.customer_id = c.customer_id
            """.strip()

        if "WHERE city =" in optimized and "SELECT * FROM customers" in optimized:
            # NOTE(review): this equals SELECT * only if customers has exactly
            # these columns, in this order. Confirm against the schema.
            return optimized.replace("SELECT *", "SELECT customer_id, name, email, city, created_at")

        if "FROM order_items oi" in optimized and "JOIN orders o" in optimized:
            # Same inner joins, reordered. No extra filter may be added here,
            # or the ground truth would silently drop rows.
            return """
                SELECT o.order_id, c.name, p.name, oi.quantity
                FROM orders o
                JOIN customers c ON o.customer_id = c.customer_id
                JOIN order_items oi ON o.order_id = oi.order_id
                JOIN products p ON oi.product_id = p.product_id
            """.strip()

        if "(SELECT COUNT(*) FROM orders o WHERE o.customer_id = c.customer_id)" in optimized:
            # COUNT(o.order_id) over a LEFT JOIN yields 0 for customers without
            # orders, matching the correlated COUNT(*).
            return """
                SELECT c.customer_id, c.name, COUNT(o.order_id) as order_count
                FROM customers c
                LEFT JOIN orders o ON c.customer_id = o.customer_id
                GROUP BY c.customer_id, c.name
            """.strip()

        if "SELECT DISTINCT product_id FROM order_items" in optimized:
            # Nothing guarantees product_id is unique in order_items, so simply
            # dropping DISTINCT would add duplicate rows. GROUP BY keeps the
            # same row set.
            return "SELECT product_id FROM order_items GROUP BY product_id"

        if "WHERE LOWER(email) =" in optimized:
            # NOTE(review): equivalent only if stored emails are already
            # lower-case. Confirm against the data.
            return optimized.replace("WHERE LOWER(email) =", "WHERE email =")

        return optimized

    def measure_latency(self, conn: Any, query: str) -> float:
        """Return how long *query* takes to execute, in milliseconds.

        Runs ``EXPLAIN ANALYZE <query>`` and returns the server-reported
        execution time parsed from its output. If no timing line is present,
        returns the client-side wall-clock time of the EXPLAIN ANALYZE call.

        If that step raises (e.g. the backend lacks EXPLAIN ANALYZE), *query*
        is executed directly and its wall-clock time is returned. Errors from
        that fallback propagate to the caller.

        Args:
            conn: Object exposing ``cursor()``. This is independent of ``self.conn``.
            query: SQL to time. It really is executed.
        """
        try:
            start = time.perf_counter()
            with contextlib.closing(conn.cursor()) as cur:
                cur.execute(f"EXPLAIN ANALYZE {query}")
                plan_rows = cur.fetchall()
            wall_ms = (time.perf_counter() - start) * 1000
            for row in plan_rows:
                match = self._EXEC_TIME_RE.search(str(row[0]))
                if match:
                    return float(match.group(1))
            return wall_ms
        except Exception:
            # Fallback: raw client-side timing of the query itself.
            # NOTE(review): with PostgreSQL drivers outside autocommit, a failed
            # statement aborts the transaction, so this fallback presumably only
            # helps on backends without EXPLAIN ANALYZE.
            start = time.perf_counter()
            with contextlib.closing(conn.cursor()) as cur:
                cur.execute(query)
                cur.fetchall()
            return (time.perf_counter() - start) * 1000