"""
executor.py -- DuckDB In-Memory SQL Execution Engine
=====================================================
The core innovation of this environment: instead of keyword-matching
heuristics, we ACTUALLY execute both the original and optimized queries
against realistic synthetic data and measure real performance differences.

Tables populated:
  users    --    10,000 rows
  orders   --   500,000 rows
  products --     1,000 rows
  events   -- 1,000,000 rows
"""

import threading
import time
from typing import Any, Dict, List, Optional, Tuple

import duckdb

_instance: Optional["QueryExecutor"] = None
_lock = threading.Lock()


class QueryExecutor:
    """
    Runs SQL against an in-memory DuckDB database with realistic
    synthetic data.  Provides execution timing, result correctness
    checks, and EXPLAIN plans -- all used by the reward function.
    """

    def __init__(self) -> None:
        self.conn = duckdb.connect(database=":memory:")
        self.conn.execute("SET threads=2")
        self._build_tables()

    # ── Schema Setup ─────────────────────────────────────────────────────

    def _build_tables(self) -> None:
        """Create and populate all four synthetic tables."""

        # users β€” 10k rows
        self.conn.execute("""
            CREATE TABLE users AS
            SELECT
                i                                                      AS id,
                'u' || i || '@mail.com'                                AS email,
                CASE i % 3
                    WHEN 0 THEN 'premium'
                    WHEN 1 THEN 'free'
                    ELSE 'enterprise' END                              AS tier,
                CASE i % 5
                    WHEN 0 THEN 'US'   WHEN 1 THEN 'EU'
                    WHEN 2 THEN 'IN'   WHEN 3 THEN 'UK'
                    ELSE 'AU' END                                      AS region,
                CASE i % 2 WHEN 0 THEN 'premium' ELSE 'basic' END     AS plan,
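                -- one signup per day starting 2020-01-01 (rows 1..10,000 span
                -- roughly 27 years of created_at dates)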
                DATE '2020-01-01' + CAST(i AS INTEGER)                 AS created_at
            FROM generate_series(1, 10000) t(i)
        """)

        # orders β€” 500k rows
        self.conn.execute("""
            CREATE TABLE orders AS
            SELECT
                i                                                      AS id,
                1 + (i % 10000)                                        AS customer_id,
                (i % 100) + 1                                          AS product_id,
                CASE i % 4
                    WHEN 0 THEN 'completed'  WHEN 1 THEN 'pending'
                    WHEN 2 THEN 'cancelled'  ELSE 'shipped' END        AS status,
                ROUND((i % 1000) * 1.5 + 49.99, 2)                   AS total,
                DATE '2023-01-01' + CAST(i % 730 AS INTEGER)          AS created_at
            FROM generate_series(1, 500000) t(i)
        """)

        # products β€” 1k rows
        self.conn.execute("""
            CREATE TABLE products AS
            SELECT
                i                                                      AS id,
                'Product_' || i                                        AS name,
                CASE i % 5
                    WHEN 0 THEN 'Electronics'  WHEN 1 THEN 'Clothing'
                    WHEN 2 THEN 'Food'         WHEN 3 THEN 'Books'
                    ELSE 'Sports' END                                  AS category,
                ROUND((i % 500) + 9.99, 2)                            AS price
            FROM generate_series(1, 1000) t(i)
        """)

        # events β€” 1M rows
        self.conn.execute("""
            CREATE TABLE events AS
            SELECT
                i                                                      AS id,
                1 + (i % 10000)                                        AS user_id,
                'sess_' || (i % 50000)                                 AS session_id,
                CASE i % 6
                    WHEN 0 THEN 'purchase'  WHEN 1 THEN 'view'
                    WHEN 2 THEN 'click'     WHEN 3 THEN 'signup'
                    WHEN 4 THEN 'logout'    ELSE 'search' END          AS event_type,
                DATE '2024-01-01' + CAST(i % 365 AS INTEGER)          AS occurred_at
            FROM generate_series(1, 1000000) t(i)
        """)

    # ── Execution helpers ─────────────────────────────────────────────────

    def _run(
        self, query: str, runs: int = 3
    ) -> Tuple[float, Optional[List], Optional[str]]:
        """
        Execute *query* up to *runs* times.
        Returns (median_ms, rows, error_or_None).
        """
        timings: List[float] = []
        rows: Optional[List] = None

        for _ in range(runs):
            try:
                t0 = time.perf_counter()
                rows = self.conn.execute(query).fetchall()
                timings.append((time.perf_counter() - t0) * 1000.0)
            except Exception as exc:
                return 99_999.0, None, str(exc)

        timings.sort()
        return round(timings[len(timings) // 2], 3), rows, None

    # ── Public API ────────────────────────────────────────────────────────

    def compare(self, original: str, optimized: str) -> Dict[str, Any]:
        """
        Execute both queries, measure real timing, check correctness.

        Returns a dict with:
          original_ms, optimized_ms, speedup,
          results_match, original_rows, optimized_rows,
          original_error, optimized_error, verdict
        """
        orig_ms, orig_rows, orig_err = self._run(original)
        opt_ms, opt_rows, opt_err = self._run(optimized)

        # ── Correctness: do both queries return the same data? ────────
        results_match = False
        if orig_rows is not None and opt_rows is not None:
            try:
                orig_s = sorted(str(r) for r in orig_rows)
                opt_s = sorted(str(r) for r in opt_rows)
                results_match = orig_s == opt_s
            except Exception:
                results_match = len(orig_rows) == len(opt_rows)

        # ── Speedup ratio ─────────────────────────────────────────────
        speedup = 1.0
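        # _run() reports a 99,999 ms sentinel on error, so a huge original
        # timing means the original query failed and the ratio stays at 1.0.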
        if opt_ms > 0 and orig_ms < 90_000:
            speedup = round(orig_ms / opt_ms, 3)

        # ── Human-readable verdict ────────────────────────────────────
        if opt_err:
            verdict = f"[FAIL] Optimized query error: {opt_err[:120]}"
        elif results_match and speedup >= 2.0:
            verdict = f"[OK] {speedup:.1f}x faster with correct results"
        elif results_match and speedup >= 1.0:
            verdict = f"[WARN] Correct results but only {speedup:.1f}x speedup -- dig deeper"
        elif not results_match and speedup >= 2.0:
            verdict = f"[WARN] {speedup:.1f}x faster but results don't match -- fix the logic"
        else:
            verdict = f"[FAIL] {speedup:.1f}x -- no meaningful improvement"

        return {
            "original_ms":     orig_ms,
            "optimized_ms":    opt_ms,
            "speedup":         speedup,
            "results_match":   results_match,
            "original_rows":   len(orig_rows) if orig_rows is not None else 0,
            "optimized_rows":  len(opt_rows) if opt_rows is not None else 0,
            "original_error":  orig_err,
            "optimized_error": opt_err,
            "verdict":         verdict,
        }

    def explain(self, query: str) -> str:
        """Return EXPLAIN output for a query."""
        try:
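            # DuckDB's EXPLAIN returns (key, value) rows; the rendered plan
            # text is in the second column.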
            rows = self.conn.execute(f"EXPLAIN {query}").fetchall()
            return "\n".join(str(r[1]) for r in rows)
        except Exception as exc:
            return f"EXPLAIN error: {exc}"

    @property
    def table_stats(self) -> Dict[str, int]:
        tables = ["users", "orders", "products", "events"]
        return {
            t: self.conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone()[0]
            for t in tables
        }


# ── Singleton accessor ────────────────────────────────────────────────────

def get_executor() -> QueryExecutor:
    """Return the process-level singleton (lazy init, thread-safe)."""
    global _instance
    if _instance is None:
        with _lock:
            if _instance is None:
                _instance = QueryExecutor()
    return _instance
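

if __name__ == "__main__":
    # Minimal manual smoke test (a sketch, not part of the reward pipeline).
    # The two hand-written queries below are semantically equivalent: the
    # second replaces the IN-subquery with a join plus DISTINCT.
    executor = get_executor()
    print("table sizes:", executor.table_stats)

    original = (
        "SELECT u.id FROM users u "
        "WHERE u.id IN (SELECT customer_id FROM orders WHERE total > 500)"
    )
    optimized = (
        "SELECT DISTINCT u.id FROM users u "
        "JOIN orders o ON o.customer_id = u.id WHERE o.total > 500"
    )

    report = executor.compare(original, optimized)
    print("verdict:", report["verdict"])
    print(executor.explain(optimized))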