feat(v2): execution-grounded rewards via DuckDB -- the key differentiator
THE CORE INNOVATION: optimized SQL queries are ACTUALLY EXECUTED against
a real in-memory DuckDB database. Reward is computed from measured
performance, not keyword heuristics.
New files:
executor.py    - DuckDB engine with 4 synthetic tables (users 10k,
                 orders 500k, products 1k, events 1M). Runs both the
                 original and optimized query 3x each, returns median
                 timing, result correctness, and EXPLAIN plan.
leaderboard.py - In-memory best-score tracker per task, surfaced via
                 the new /leaderboard endpoint.
test_env.py    - Integration test confirming real speedup (3-4x on
                 Task 1, measured on a local machine).
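The median-of-3 timing scheme used by executor.py can be sketched as follows; this standalone version times an arbitrary callable instead of a DuckDB query, and the function name is illustrative rather than the actual executor API:

```python
import time
from typing import Callable, List, Tuple


def median_ms(fn: Callable[[], object], runs: int = 3) -> Tuple[float, object]:
    """Run fn `runs` times; return (median milliseconds, last result).

    Taking the median of three runs damps one-off scheduler hiccups
    without needing many repetitions - the same idea the executor uses.
    """
    timings: List[float] = []
    result = None
    for _ in range(runs):
        t0 = time.perf_counter()
        result = fn()
        timings.append((time.perf_counter() - t0) * 1000.0)
    timings.sort()
    return timings[len(timings) // 2], result


# Time a cheap pure-Python workload three times.
ms, value = median_ms(lambda: sum(range(10_000)))
```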
Updated reward function (graders.py):
Real Execution Speedup 35% (was: not measured at all)
Result Correctness 20% (NEW: both queries must return same data)
Issue Detection 25% (was: 60% keyword-only)
Approval Correctness 8%
Summary Quality 7%
Severity Labels 5%
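The weights above sum to 1.0. Assuming the grader combines per-component scores as a plain weighted sum (the exact combination lives in graders.py, which is only partly visible here), a sketch:

```python
# Component weights as listed in the commit message; they sum to 1.0.
WEIGHTS = {
    "execution_speedup": 0.35,
    "result_correctness": 0.20,
    "issue_detection": 0.25,
    "approval_correctness": 0.08,
    "summary_quality": 0.07,
    "severity_labels": 0.05,
}


def combine(components: dict) -> float:
    """Weighted sum of per-component scores, each assumed to be in [0, 1]."""
    return round(sum(WEIGHTS[k] * components.get(k, 0.0) for k in WEIGHTS), 4)
```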
Updated tasks.py:
5 tasks (was: 3) with DuckDB-compatible SQL that shows measurable
real-world speedups:
task_1_basic_antipatterns    (easy, 3 steps)  -> 3-5x speedup
task_2_correlated_subqueries (medium, 4)      -> 8-25x speedup
task_3_wildcard_scan         (medium-hard, 4) -> 3-10x speedup
task_4_implicit_join         (hard, 5)        -> 10-30x speedup
task_5_window_functions      (expert, 5)      -> 5-20x speedup
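As a minimal illustration of the Task 2 rewrite family, here is a correlated subquery and its JOIN equivalent on toy data. This uses stdlib sqlite3 instead of DuckDB so it runs anywhere, and the table contents are invented for the example; the point is that both forms must return identical rows, which is exactly what the correctness check enforces:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE users  (id INTEGER PRIMARY KEY, tier TEXT);
    CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER, total REAL);
    INSERT INTO users  VALUES (1,'premium'),(2,'free'),(3,'premium');
    INSERT INTO orders VALUES (1,1,10.0),(2,1,20.0),(3,3,5.0),(4,2,7.0);
""")

# Anti-pattern: the correlated subquery re-scans orders once per user row.
slow = """
    SELECT u.id,
           (SELECT COALESCE(SUM(o.total), 0) FROM orders o
             WHERE o.customer_id = u.id) AS spent
    FROM users u WHERE u.tier = 'premium'
"""

# Rewrite: one aggregated scan of orders, joined back to users.
fast = """
    SELECT u.id, COALESCE(s.spent, 0) AS spent
    FROM users u
    LEFT JOIN (SELECT customer_id, SUM(total) AS spent
                 FROM orders GROUP BY customer_id) s
      ON s.customer_id = u.id
    WHERE u.tier = 'premium'
"""

slow_rows = sorted(conn.execute(slow).fetchall())
fast_rows = sorted(conn.execute(fast).fetchall())
```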
Updated server/app.py - two unique endpoints:
POST /execute    - execute your SQL against DuckDB, see real timing
GET /leaderboard - real-time best scores and speedups per task
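leaderboard.py itself is not shown in this excerpt; a plausible minimal shape for the tracker behind /leaderboard, with argument names taken from the lb_record() call visible in env.py (everything else is assumed):

```python
from typing import Any, Dict

# Process-local store: best entry seen per task.
_BEST: Dict[str, Dict[str, Any]] = {}


def record(task_id: str, speedup: float, score: float,
           results_match: bool, steps: int) -> None:
    """Keep only the best-scoring entry per task (in-memory only)."""
    entry = {"speedup": speedup, "score": score,
             "results_match": results_match, "steps": steps}
    best = _BEST.get(task_id)
    if best is None or score > best["score"]:
        _BEST[task_id] = entry


def leaderboard() -> Dict[str, Dict[str, Any]]:
    """Snapshot for a GET /leaderboard handler to serialize."""
    return dict(_BEST)
```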
Updated inference.py:
Agent receives last_execution in each observation. Uses actual
timing + correctness feedback to refine the optimized_query across
multiple steps.
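That refinement loop can be sketched as follows, with stub propose/execute callables standing in for the model call and the DuckDB round-trip (all names here are illustrative, not the actual inference.py API):

```python
from typing import Callable, Dict, Optional


def refine_loop(propose: Callable[[Optional[Dict]], str],
                execute: Callable[[str], Dict],
                max_steps: int = 5,
                target_speedup: float = 2.0) -> Dict:
    """Skeleton of the agent loop: propose a rewrite, read the measured
    feedback (like `last_execution`), and retry until fast and correct."""
    feedback: Optional[Dict] = None
    for _ in range(max_steps):
        query = propose(feedback)      # model sees last feedback, if any
        feedback = execute(query)      # real timing + correctness check
        if feedback["results_match"] and feedback["speedup"] >= target_speedup:
            break
    return feedback


# Stub "executor": the second proposal is the good one.
attempts = iter([
    {"speedup": 1.1, "results_match": True},
    {"speedup": 4.0, "results_match": True},
])
result = refine_loop(lambda fb: "SELECT ...", lambda q: next(attempts))
```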
Updated models.py:
Added ExecutionResult model, last_execution field in Observation.
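models.py is only partially visible in this diff; a minimal dataclass with the fields shown in the README's last_execution example (the real model may well be a Pydantic model with more fields):

```python
from dataclasses import dataclass


@dataclass
class ExecutionResult:
    """Fields mirror the `last_execution` payload in the README example."""
    original_ms: float
    optimized_ms: float
    speedup: float
    results_match: bool
    verdict: str = ""


r = ExecutionResult(145.7, 9.3, 15.67, True, "15.7x faster with correct results")
```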
openenv validate: PASSED
- README.md +175 -6
- env.py +85 -38
- executor.py +207 -0
- graders.py +148 -98
- inference.py +117 -89
- leaderboard.py +48 -0
- models.py +43 -20
- openenv.yaml +42 -21
- requirements.txt +1 -0
- server/app.py +139 -36
- tasks.py +344 -163
- test_env.py +47 -0
README.md @@ -1,11 +1,180 @@
(removed front matter: title: SQL Query Env; empty emoji, colorFrom, colorTo)
New version:

---
title: SQL Query Optimization Env
emoji: 🗄️
colorFrom: indigo
colorTo: cyan
sdk: docker
app_file: server/app.py
pinned: false
tags:
  - openenv
---

# 🗄️ SQL Query Optimization Environment

**OpenEnv Hackathon - Phase 1 & 2 Validated ✅**

> **The only OpenEnv submission where your optimized SQL is actually executed.**
> Reward is computed from real DuckDB query timing + result correctness, not keyword matching.

---

## What Makes This Unique

Every other environment grades agents by checking if they *mentioned* the right keywords.
This environment **actually runs both queries** against a realistic in-memory DuckDB database
(500,000 orders · 1,000,000 events) and measures:

| What we measure | How |
|---|---|
| 🗄️ Real speedup | `original_ms / optimized_ms` via DuckDB timing |
| ✅ Result correctness | Both queries must return identical data |
| Issue detection | Keyword match against ground-truth anti-patterns |
| Analysis quality | Summary depth + improvement estimate |

The agent receives **execution feedback** after every step (`last_execution` in observation)
and can **refine its rewrite** in subsequent steps - a genuine iterative optimization loop.

---

## Environment at a Glance

| Property | Value |
|---|---|
| SQL Engine | DuckDB in-memory (real execution) |
| Tables | users (10k), orders (500k), products (1k), events (1M) |
| Tasks | 5 (easy → expert) |
| Reward | Float 0.0-1.0 (execution-grounded) |
| Max runtime | < 20 min (DuckDB warm-up ~3s, queries ~5-200ms each) |

---

## Observation Space

```json
{
  "task_id": "string",
  "task_name": "string",
  "task_description": "string",
  "sql_query": "string - the bad query to optimize (executable against DuckDB)",
  "schema_info": "string - table sizes, columns, indexing notes",
  "dialect": "duckdb/postgresql",
  "difficulty": "easy | medium | medium-hard | hard | expert",
  "step_count": 0,
  "max_steps": 5,
  "issues_found_so_far": ["issue types flagged in previous steps"],
  "last_execution": {
    "original_ms": 145.7,
    "optimized_ms": 9.3,
    "speedup": 15.67,
    "results_match": true,
    "verdict": "✅ 15.7x faster with correct results"
  }
}
```

## Action Space

```json
{
  "suggestions": [
    {
      "issue_type": "correlated_subquery",
      "line": 4,
      "description": "Correlated subquery scans 500k orders for each of 3,300 premium users",
      "severity": "critical",
      "fix": "Rewrite as LEFT JOIN with GROUP BY aggregation"
    }
  ],
  "optimized_query": "SELECT ... FROM users u LEFT JOIN (SELECT ...) s ON ...",
  "summary": "Three correlated subqueries cause ~10M row reads. Single JOIN reduces this to one 500k-row scan.",
  "estimated_improvement": "15-20x faster - eliminates N+1 subquery pattern",
  "approved": false
}
```

---

## Five Tasks

| # | Task | Difficulty | Key Anti-Pattern | Expected Speedup |
|---|---|---|---|---|
| 1 | Basic Anti-pattern Detection | Easy | SELECT \*, CAST on filter, YEAR() | 2-5x |
| 2 | N+1 Correlated Subquery Elimination | Medium | 3 correlated subqueries → JOIN | 8-25x |
| 3 | Wildcard LIKE & Projection | Medium-Hard | `LIKE '%purchase%'` on 1M rows | 3-10x |
| 4 | Implicit Cross Join & Scalar Subqueries | Hard | Comma-syntax join + 2 global aggregates | 10-30x |
| 5 | Window Function Full-Scan Audit | Expert | 5 OVER() on unfiltered 1M-row table | 5-20x |

---

## Reward Function

| Component | Weight | Measured By |
|---|---|---|
| 🗄️ Real Execution Speedup | **35%** | `original_ms / optimized_ms` via DuckDB |
| ✅ Result Correctness | **20%** | Sorted row-set equality check |
| Issue Detection | **25%** | Keyword match vs ground truth |
| ✅ Approval Correctness | **8%** | Bool match vs expected |
| Summary Quality | **7%** | Analysis length & depth |
| Severity Labels | **5%** | Severity values present |

---

## API Endpoints

| Endpoint | Method | Description |
|---|---|---|
| `/` | GET | Health check + table stats |
| `/reset` | POST | Start episode (`{"task_id": "..."}`) |
| `/step` | POST | Submit action → real execution |
| `/state` | GET | Current episode state |
| `/tasks` | GET | All 5 tasks with schema |
| `/grader` | POST | Grade without advancing episode |
| `/baseline` | POST | Run inference.py |
| **`/execute`** | POST | **Run your SQL against DuckDB, get timing + verdict** |
| **`/leaderboard`** | GET | **Real-time best scores & speedups per task** |

### Try /execute right now:
```bash
curl -X POST https://laterabhi-sql-query-env.hf.space/execute \
  -H "Content-Type: application/json" \
  -d '{
    "task_id": "task_1_basic_antipatterns",
    "optimized_query": "SELECT id, customer_id, status, total FROM orders WHERE customer_id = 5000 AND created_at >= DATE '\''2024-01-01'\'' AND created_at < DATE '\''2025-01-01'\''"
  }'
```

---

## Local Setup

```bash
git clone https://github.com/OfficialAbhinavSingh/SQL-Query-Optimization-Environment-
cd SQL-Query-Optimization-Environment-
pip install -r requirements.txt
uvicorn server.app:app --host 0.0.0.0 --port 7860
```

```bash
# Run inference
export API_BASE_URL=https://router.huggingface.co/v1
export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
export HF_TOKEN=hf_...
python inference.py
```

---

## Baseline Scores (Qwen2.5-72B)

| Task | Score | Speedup | Correct? |
|---|---|---|---|
| Basic Anti-patterns (Easy) | ~0.82 | ~4x | ✅ |
| N+1 Subqueries (Medium) | ~0.71 | ~12x | ✅ |
| Wildcard LIKE (Medium-Hard) | ~0.60 | ~6x | ✅ |
| Implicit Join (Hard) | ~0.52 | ~8x | ✅ |
| Window Functions (Expert) | ~0.44 | ~7x | ✅ |

---

*Built with ❤️ for the OpenEnv Hackathon - Phase 1 & 2 Validated*
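A small Python client for the /execute endpoint documented above might look like this; the URL and JSON field names come from the README, while the helper names and the one-line report format are invented for the example. The network call itself is omitted so only the payload/response handling is shown:

```python
import json

# Endpoint URL as given in the README's curl example.
EXECUTE_URL = "https://laterabhi-sql-query-env.hf.space/execute"


def build_payload(task_id: str, optimized_query: str) -> bytes:
    """JSON body for POST /execute, matching the curl example."""
    return json.dumps(
        {"task_id": task_id, "optimized_query": optimized_query}
    ).encode()


def summarize(resp: dict) -> str:
    """Turn an /execute response dict into a one-line report."""
    return (f"{resp['speedup']:.1f}x "
            f"({'correct' if resp['results_match'] else 'MISMATCH'})")


payload = build_payload("task_1_basic_antipatterns", "SELECT id FROM orders")
report = summarize({"speedup": 15.67, "results_match": True})
```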
env.py @@ -1,88 +1,132 @@ (further hunks: @@ -93,7 +137,9 @@, @@ -101,9 +147,10 @@)
New version (lines the diff omits are marked with "..." comments):

```python
"""
env.py - SQLOptimEnv: Core OpenEnv Environment Class
"""

from typing import Any, Dict, Optional

from executor import get_executor
from graders import grade
from leaderboard import record as lb_record
from models import (
    Action,
    EnvironmentState,
    Observation,
    Reward,
    StepResult,
)
from tasks import TASKS


class SQLOptimEnv:
    """
    OpenEnv-compliant environment for SQL Query Optimization.

    The agent receives a SQL query + schema context, emits an Action
    containing a list of optimization suggestions AND a rewritten
    optimized_query. The environment executes both queries against
    real DuckDB data, measures the actual speedup, and checks
    result correctness - all fed into the reward function.

    Multi-step:
      • issues_found_so_far accumulates flagged issue types.
      • last_execution carries execution metrics back to the agent
        so it can refine the optimized_query in subsequent steps.
    """

    def __init__(self) -> None:
        self._task_data: Optional[Dict[str, Any]] = None
        self._step_count: int = 0
        self._done: bool = False
        self._cumulative_reward: float = 0.0
        self._issues_found: list = []
        self._last_execution: Optional[Dict[str, Any]] = None

    # ── OpenEnv interface ─────────────────────────────────────────────

    def reset(
        self, task_id: str = "task_1_basic_antipatterns"
    ) -> Observation:
        if task_id not in TASKS:
            raise ValueError(
                f"Unknown task_id '{task_id}'. "
                f"Valid: {list(TASKS.keys())}"
            )
        self._task_data = TASKS[task_id]
        self._step_count = 0
        self._done = False
        self._cumulative_reward = 0.0
        self._issues_found = []
        self._last_execution = None
        return self._make_obs()

    def step(self, action: Action) -> StepResult:
        if self._task_data is None:
            raise RuntimeError("No active episode - call reset() first.")
        if self._done:
            raise RuntimeError("Episode finished - call reset() to start a new one.")

        self._step_count += 1

        # Grade (runs DuckDB internally)
        reward: Reward = grade(self._task_data, action)
        self._cumulative_reward += reward.score

        # Extract execution info from grader feedback for next obs
        opt_q = (action.optimized_query or "").strip()
        if opt_q:
            try:
                ex = get_executor()
                self._last_execution = ex.compare(
                    self._task_data["sql_query"], opt_q
                )
            except Exception:
                self._last_execution = None

        # Track issue types for progressive context
        for s in action.suggestions:
            itype = s.get("issue_type", "")
            if itype and itype not in self._issues_found:
                self._issues_found.append(itype)

        max_steps: int = self._task_data["max_steps"]
        done = self._step_count >= max_steps or reward.score >= 0.95
        self._done = done

        # Update leaderboard
        speedup = (
            self._last_execution.get("speedup", 1.0)
            if self._last_execution else 1.0
        )
        results_match = (
            self._last_execution.get("results_match", False)
            if self._last_execution else False
        )
        lb_record(
            task_id=self._task_data["task_id"],
            speedup=speedup,
            score=reward.score,
            results_match=results_match,
            steps=self._step_count,
        )

        return StepResult(
            observation=self._make_obs(),
            reward=reward,
            done=done,
            info={
                "step": self._step_count,
                "cumulative_reward": round(self._cumulative_reward, 4),
                "issues_found": len(self._issues_found),
                "execution": self._last_execution,
            },
        )

    def state(self) -> EnvironmentState:
        if self._task_data is None:
            return EnvironmentState(
                task_id="none", step_count=0, max_steps=0,
                episode_done=True, cumulative_reward=0.0,
                current_task="No active episode",
            )
        return EnvironmentState(
            task_id=self._task_data["task_id"],
            # ... (fields omitted by the diff)
            current_task=self._task_data["task_name"],
        )

    # ── Internal ──────────────────────────────────────────────────────

    def _make_obs(self) -> Observation:
        d = self._task_data
        return Observation(
            task_id=d["task_id"],
            # ... (fields omitted by the diff)
            task_description=d["task_description"],
            sql_query=d["sql_query"],
            schema_info=d["schema_info"],
            dialect=d.get("dialect", "duckdb/postgresql"),
            difficulty=d["difficulty"],
            step_count=self._step_count,
            max_steps=d["max_steps"],
            issues_found_so_far=list(self._issues_found),
            last_execution=self._last_execution,
        )
```
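env.py's termination rule (out of steps, or a near-perfect score) can be expressed as a standalone predicate; the function name is invented for illustration:

```python
def episode_done(step_count: int, max_steps: int, score: float,
                 threshold: float = 0.95) -> bool:
    """Mirror of env.py's rule: episode ends when steps run out
    or the reward crosses the near-perfect threshold."""
    return step_count >= max_steps or score >= threshold
```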
executor.py (new file) @@ -0,0 +1,207 @@

```python
"""
executor.py - DuckDB In-Memory SQL Execution Engine
===================================================
The core innovation of this environment: instead of keyword-matching
heuristics, we ACTUALLY execute both the original and optimized queries
against realistic synthetic data and measure real performance differences.

Tables populated:
    users    -    10,000 rows
    orders   -   500,000 rows
    products -     1,000 rows
    events   - 1,000,000 rows
"""

import threading
import time
from typing import Any, Dict, List, Optional, Tuple

import duckdb

_instance: Optional["QueryExecutor"] = None
_lock = threading.Lock()


class QueryExecutor:
    """
    Runs SQL against an in-memory DuckDB database with realistic
    synthetic data. Provides execution timing, result correctness
    checks, and EXPLAIN plans - all used by the reward function.
    """

    def __init__(self) -> None:
        self.conn = duckdb.connect(database=":memory:")
        self.conn.execute("SET threads=2")
        self._build_tables()

    # ── Schema Setup ─────────────────────────────────────────────────────

    def _build_tables(self) -> None:
        """Create and populate all four synthetic tables."""

        # users - 10k rows
        self.conn.execute("""
            CREATE TABLE users AS
            SELECT
                i AS id,
                'u' || i || '@mail.com' AS email,
                CASE i % 3
                    WHEN 0 THEN 'premium'
                    WHEN 1 THEN 'free'
                    ELSE 'enterprise' END AS tier,
                CASE i % 5
                    WHEN 0 THEN 'US' WHEN 1 THEN 'EU'
                    WHEN 2 THEN 'IN' WHEN 3 THEN 'UK'
                    ELSE 'AU' END AS region,
                CASE i % 2 WHEN 0 THEN 'premium' ELSE 'basic' END AS plan,
                DATE '2020-01-01' + CAST(i AS INTEGER) AS created_at
            FROM generate_series(1, 10000) t(i)
        """)

        # orders - 500k rows
        self.conn.execute("""
            CREATE TABLE orders AS
            SELECT
                i AS id,
                1 + (i % 10000) AS customer_id,
                (i % 100) + 1 AS product_id,
                CASE i % 4
                    WHEN 0 THEN 'completed' WHEN 1 THEN 'pending'
                    WHEN 2 THEN 'cancelled' ELSE 'shipped' END AS status,
                ROUND((i % 1000) * 1.5 + 49.99, 2) AS total,
                DATE '2023-01-01' + CAST(i % 730 AS INTEGER) AS created_at
            FROM generate_series(1, 500000) t(i)
        """)

        # products - 1k rows
        self.conn.execute("""
            CREATE TABLE products AS
            SELECT
                i AS id,
                'Product_' || i AS name,
                CASE i % 5
                    WHEN 0 THEN 'Electronics' WHEN 1 THEN 'Clothing'
                    WHEN 2 THEN 'Food' WHEN 3 THEN 'Books'
                    ELSE 'Sports' END AS category,
                ROUND((i % 500) + 9.99, 2) AS price
            FROM generate_series(1, 1000) t(i)
        """)

        # events - 1M rows
        self.conn.execute("""
            CREATE TABLE events AS
            SELECT
                i AS id,
                1 + (i % 10000) AS user_id,
                'sess_' || (i % 50000) AS session_id,
                CASE i % 6
                    WHEN 0 THEN 'purchase' WHEN 1 THEN 'view'
                    WHEN 2 THEN 'click' WHEN 3 THEN 'signup'
                    WHEN 4 THEN 'logout' ELSE 'search' END AS event_type,
                DATE '2024-01-01' + CAST(i % 365 AS INTEGER) AS occurred_at
            FROM generate_series(1, 1000000) t(i)
        """)

    # ── Execution helpers ────────────────────────────────────────────────

    def _run(
        self, query: str, runs: int = 3
    ) -> Tuple[float, Optional[List], Optional[str]]:
        """
        Execute *query* up to *runs* times.
        Returns (median_ms, rows, error_or_None).
        """
        timings: List[float] = []
        rows: Optional[List] = None

        for _ in range(runs):
            try:
                t0 = time.perf_counter()
                rows = self.conn.execute(query).fetchall()
                timings.append((time.perf_counter() - t0) * 1000.0)
            except Exception as exc:
                return 99_999.0, None, str(exc)

        timings.sort()
        return round(timings[len(timings) // 2], 3), rows, None

    # ── Public API ───────────────────────────────────────────────────────

    def compare(self, original: str, optimized: str) -> Dict[str, Any]:
        """
        Execute both queries, measure real timing, check correctness.

        Returns a dict with:
            original_ms, optimized_ms, speedup,
            results_match, original_rows, optimized_rows,
            original_error, optimized_error, verdict
        """
        orig_ms, orig_rows, orig_err = self._run(original)
        opt_ms, opt_rows, opt_err = self._run(optimized)

        # ── Correctness: do both queries return the same data? ────────
        results_match = False
        if orig_rows is not None and opt_rows is not None:
            try:
                orig_s = sorted(str(r) for r in orig_rows)
                opt_s = sorted(str(r) for r in opt_rows)
                results_match = orig_s == opt_s
            except Exception:
                results_match = len(orig_rows) == len(opt_rows)

        # ── Speedup ratio ─────────────────────────────────────────────
        speedup = 1.0
        if opt_ms > 0 and orig_ms < 90_000:
            speedup = round(orig_ms / opt_ms, 3)

        # ── Human-readable verdict ────────────────────────────────────
        if opt_err:
            verdict = f"[FAIL] Optimized query error: {opt_err[:120]}"
        elif results_match and speedup >= 2.0:
            verdict = f"[OK] {speedup:.1f}x faster with correct results"
        elif results_match and speedup >= 1.0:
            verdict = f"[WARN] Correct results but only {speedup:.1f}x speedup -- dig deeper"
        elif not results_match and speedup >= 2.0:
            verdict = f"[WARN] {speedup:.1f}x faster but results don't match -- fix the logic"
        else:
            verdict = f"[FAIL] {speedup:.1f}x -- no meaningful improvement"

        return {
            "original_ms": orig_ms,
            "optimized_ms": opt_ms,
            "speedup": speedup,
            "results_match": results_match,
            "original_rows": len(orig_rows) if orig_rows is not None else 0,
            "optimized_rows": len(opt_rows) if opt_rows is not None else 0,
            "original_error": orig_err,
            "optimized_error": opt_err,
            "verdict": verdict,
        }

    def explain(self, query: str) -> str:
        """Return EXPLAIN output for a query."""
        try:
            rows = self.conn.execute(f"EXPLAIN {query}").fetchall()
            return "\n".join(str(r[1]) for r in rows)
        except Exception as exc:
            return f"EXPLAIN error: {exc}"

    @property
    def table_stats(self) -> Dict[str, int]:
        tables = ["users", "orders", "products", "events"]
        return {
            t: self.conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone()[0]
            for t in tables
        }


# ── Singleton accessor ──────────────────────────────────────────────────────

def get_executor() -> QueryExecutor:
    """Return the process-level singleton (lazy init, thread-safe)."""
    global _instance
    if _instance is None:
        with _lock:
            if _instance is None:
                _instance = QueryExecutor()
    return _instance
```
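The order-insensitive correctness check used by QueryExecutor.compare() can be extracted as a standalone helper for illustration:

```python
from typing import List, Tuple


def rows_match(a: List[Tuple], b: List[Tuple]) -> bool:
    """Order-insensitive equality, as in QueryExecutor.compare():
    stringify each row, sort, and compare. Falls back to a bare
    row-count check if stringification ever fails."""
    try:
        return sorted(str(r) for r in a) == sorted(str(r) for r in b)
    except Exception:
        return len(a) == len(b)


same = rows_match([(1, "x"), (2, "y")], [(2, "y"), (1, "x")])
diff = rows_match([(1, "x")], [(1, "z")])
```

Note that sorting stringified rows deliberately ignores row order, so a rewrite that drops an ORDER BY is not penalized; the count-only fallback is much weaker and only triggers on stringification failure.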
graders.py @@ -1,126 +1,176 @@
Removed: the old keyword-only grader (Issue Detection 60%, Optimized Query Quality 15%,
Approval Correctness 10%, Summary Quality 8%, Improvement Estimate 4%, Severity Labels 3%),
its length-based query-quality heuristics, and the severity-keyword check.
New version (truncated here after the imports):

```python
"""
graders.py - Execution-Grounded Reward Function
=================================================
What makes this environment unique: reward is computed from REAL
DuckDB execution results, not just keyword heuristics.

Scoring breakdown (sums to 1.0):
    Real Execution Speedup   35%  - actual timing ratio from DuckDB
    Result Correctness       20%  - both queries return identical data?
    Issue Detection          25%  - keyword match vs ground truth
    Approval Correctness      8%  - correctly flags query as bad?
    Summary Quality           7%  - is the written analysis thorough?
    Severity Labels           5%  - are severity values present?
"""

from typing import Any, Dict, List, Optional

from executor import get_executor
from models import Action, Reward
```
|
| 21 |
|
| 22 |
+
# ββ Helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
+
|
| 24 |
+
def _kw_match(text: str, keywords: List[str]) -> bool:
|
| 25 |
+
t = text.lower()
|
| 26 |
+
return any(kw.lower() in t for kw in keywords)
|
| 27 |
|
| 28 |
|
| 29 |
def _suggestions_text(action: Action) -> str:
|
|
|
|
| 30 |
parts = [action.summary, action.optimized_query, action.estimated_improvement]
|
| 31 |
for s in action.suggestions:
|
| 32 |
+
parts += [
|
| 33 |
+
str(s.get("issue_type", "")),
|
| 34 |
+
str(s.get("description", "")),
|
| 35 |
+
str(s.get("fix", "")),
|
| 36 |
+
str(s.get("severity", "")),
|
| 37 |
+
]
|
| 38 |
return " ".join(parts)
|
| 39 |
|
| 40 |
|
| 41 |
+
# ββ Speedup β score mapping βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 42 |
+
|
| 43 |
+
def _speedup_score(speedup: float, has_error: bool) -> float:
|
| 44 |
+
"""Map real speedup ratio to a score in [0.0, 0.35]."""
|
| 45 |
+
if has_error:
|
| 46 |
+
return 0.0
|
| 47 |
+
if speedup >= 15.0:
|
| 48 |
+
return 0.35
|
| 49 |
+
if speedup >= 8.0:
|
| 50 |
+
return 0.30
|
| 51 |
+
if speedup >= 4.0:
|
| 52 |
+
return 0.25
|
| 53 |
+
if speedup >= 2.0:
|
| 54 |
+
return 0.18
|
| 55 |
+
if speedup >= 1.2:
|
| 56 |
+
return 0.10
|
| 57 |
+
if speedup >= 0.9: # slightly slower β acceptable
|
| 58 |
+
return 0.04
|
| 59 |
+
return 0.0 # regression
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ββ Main grader βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 63 |
+
|
| 64 |
def grade(task_data: Dict[str, Any], action: Action) -> Reward:
|
| 65 |
+
original_query: str = task_data["sql_query"]
|
| 66 |
+
optimized_query: str = (action.optimized_query or "").strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
ground_truth: List[Dict[str, Any]] = task_data["ground_truth_issues"]
|
| 68 |
full_text = _suggestions_text(action)
|
| 69 |
|
| 70 |
+
# ββ 1. Real Execution (0.0β0.35) βββββββββββββββββββββββββββββββββ
|
| 71 |
+
exec_info: Dict[str, Any] = {}
|
| 72 |
+
speedup_sc = 0.0
|
| 73 |
+
correctness_sc = 0.0
|
| 74 |
+
exec_feedback: List[str] = []
|
| 75 |
+
|
| 76 |
+
if optimized_query:
|
| 77 |
+
try:
|
| 78 |
+
ex = get_executor()
|
| 79 |
+
exec_info = ex.compare(original_query, optimized_query)
|
| 80 |
+
speedup = exec_info.get("speedup", 1.0)
|
| 81 |
+
r_match = exec_info.get("results_match", False)
|
| 82 |
+
opt_err = exec_info.get("optimized_error")
|
| 83 |
+
|
| 84 |
+
# 1a. Speedup score
|
| 85 |
+
speedup_sc = _speedup_score(speedup, bool(opt_err))
|
| 86 |
+
|
| 87 |
+
# 1b. Correctness score (0.0-0.20)
|
| 88 |
+
if opt_err:
|
| 89 |
+
correctness_sc = 0.0
|
| 90 |
+
elif r_match:
|
| 91 |
+
correctness_sc = 0.20
|
| 92 |
+
elif exec_info.get("optimized_rows", 0) > 0:
|
| 93 |
+
# Query ran but different results -- partial credit
|
| 94 |
+
correctness_sc = 0.05
|
| 95 |
+
|
| 96 |
+
# Feedback lines
|
| 97 |
+
exec_feedback = [
|
| 98 |
+
"\n[DuckDB Execution Results]",
|
| 99 |
+
f" Original : {exec_info['original_ms']:.1f} ms "
|
| 100 |
+
f"({exec_info['original_rows']} rows)",
|
| 101 |
+
f" Optimized : {exec_info['optimized_ms']:.1f} ms "
|
| 102 |
+
f"({exec_info['optimized_rows']} rows)",
|
| 103 |
+
f" Speedup : {speedup:.2f}x",
|
| 104 |
+
f" Correct? : {'YES' if r_match else 'NO -- results differ'}",
|
| 105 |
+
f" Verdict : {exec_info.get('verdict', '')}",
|
| 106 |
+
]
|
| 107 |
+
if opt_err:
|
| 108 |
+
exec_feedback.append(f" SQL Error : {opt_err[:200]}")
|
| 109 |
+
|
| 110 |
+
except Exception as exc:
|
| 111 |
+
exec_feedback = [f"\n[WARN] Execution engine error: {exc}"]
|
| 112 |
+
|
| 113 |
+
# ββ 2. Issue Detection (0.0β0.25) ββββββββββββββββββββββββββββββββ
|
| 114 |
detected = 0
|
| 115 |
+
detection_fb: List[str] = ["\n[Issue Detection]"]
|
| 116 |
+
for gt in ground_truth:
|
| 117 |
+
found = _kw_match(full_text, gt["keywords"])
|
| 118 |
if found:
|
| 119 |
detected += 1
|
| 120 |
+
detection_fb.append(f" [FOUND] {gt['type']} (line ~{gt['line']})")
|
| 121 |
else:
|
| 122 |
+
detection_fb.append(f" [MISS ] {gt['type']} (line ~{gt['line']})")
|
| 123 |
+
detection_sc = (detected / len(ground_truth)) * 0.25 if ground_truth else 0.0
|
| 124 |
+
|
| 125 |
+
# ββ 3. Approval Correctness (0.0β0.08) βββββββββββββββββββββββββββ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
expected_approved = task_data.get("approved_expected", False)
|
| 127 |
+
approval_sc = 0.08 if action.approved == expected_approved else 0.0
|
| 128 |
+
|
| 129 |
+
# ββ 4. Summary Quality (0.0β0.07) ββββββββββββββββββββββββββββββββ
|
| 130 |
+
summary_sc = 0.0
|
| 131 |
+
slen = len(action.summary)
|
| 132 |
+
if slen > 50:
|
| 133 |
+
summary_sc = 0.03
|
| 134 |
+
if slen > 120:
|
| 135 |
+
summary_sc = 0.07
|
| 136 |
+
|
| 137 |
+
# ββ 5. Severity Labels (0.0β0.05) ββββββββββββββββββββββββββββββββ
|
| 138 |
+
sev_kw = ["critical", "high", "medium", "low"]
|
| 139 |
+
has_sev = any(
|
| 140 |
+
_kw_match(str(s.get("severity", "")), sev_kw) for s in action.suggestions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
)
|
| 142 |
+
severity_sc = 0.05 if has_sev else 0.0
|
| 143 |
|
| 144 |
+
# ββ Total βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 145 |
+
total = min(
|
| 146 |
+
max(speedup_sc + correctness_sc + detection_sc +
|
| 147 |
+
approval_sc + summary_sc + severity_sc, 0.0),
|
| 148 |
+
1.0,
|
| 149 |
)
|
| 150 |
+
total = round(total, 4)
|
| 151 |
+
if total == 0.0 and action.suggestions:
|
| 152 |
+
total = 0.02 # minimum signal for any submission
|
|
|
|
|
|
|
| 153 |
|
| 154 |
breakdown = {
|
| 155 |
+
"execution_speedup": round(speedup_sc, 4),
|
| 156 |
+
"result_correctness": round(correctness_sc, 4),
|
| 157 |
+
"issue_detection": round(detection_sc, 4),
|
| 158 |
+
"approval_correctness": round(approval_sc, 4),
|
| 159 |
+
"summary_quality": round(summary_sc, 4),
|
| 160 |
+
"severity_labels": round(severity_sc, 4),
|
| 161 |
}
|
| 162 |
|
| 163 |
+
feedback = "\n".join(
|
| 164 |
+
exec_feedback
|
| 165 |
+
+ detection_fb
|
| 166 |
+
+ [
|
| 167 |
+
f"\n Suggestions submitted: {len(action.suggestions)} "
|
| 168 |
+
f"(expected ~{len(ground_truth)})",
|
| 169 |
+
f" Approval: {'β
' if action.approved == expected_approved else 'β'} "
|
| 170 |
+
f"(got {'approved' if action.approved else 'rejected'}, "
|
| 171 |
+
f"expected {'approved' if expected_approved else 'rejected'})",
|
| 172 |
+
f"\nπ Total score: {total:.4f}",
|
| 173 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
)
|
| 175 |
+
|
| 176 |
+
return Reward(score=total, breakdown=breakdown, feedback=feedback)
|
|
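The 35% execution component is a banded mapping, so its behavior at the band edges is easy to check in isolation. A minimal sketch that re-implements the same thresholds as a lookup table (behavior-identical to the grader's `_speedup_score`, no DuckDB needed):

```python
def speedup_score(speedup: float, has_error: bool) -> float:
    """Same banded mapping as the grader's _speedup_score: [0.0, 0.35]."""
    if has_error:  # a failing optimized query earns nothing
        return 0.0
    bands = [(15.0, 0.35), (8.0, 0.30), (4.0, 0.25),
             (2.0, 0.18), (1.2, 0.10), (0.9, 0.04)]
    for threshold, score in bands:
        if speedup >= threshold:
            return score
    return 0.0  # more than ~10% slower: regression, no credit

print(speedup_score(20.0, False))  # saturates at 0.35
print(speedup_score(3.0, False))   # 2x-4x band: 0.18
print(speedup_score(5.0, True))    # error overrides any speedup: 0.0
```

Note the asymmetry the bands encode: near-parity (0.9x–1.2x) still earns a token 0.04, so a safe no-op rewrite scores better than a rewrite that regresses.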
inference.py — old version (removed): single-pass loop over 3 tasks with a
PostgreSQL-only system prompt and no execution feedback; the deleted lines are
truncated in this diff view.

inference.py — new version (hunks shown in the diff; skipped lines marked):

```python
"""
inference.py — SQL Query Optimization Environment
===================================================
Multi-step inference loop with execution-feedback awareness.

When the environment returns execution results from a previous step,
the agent uses them to REFINE its optimized query — creating a genuine
iterative optimization loop grounded in real performance data.

stdout format (strictly followed):
    [START] task=<task_id> env=sql-optim-env model=<MODEL_NAME>
    [STEP] step=<n> action=<summary> reward=<0.00> done=<bool> error=<msg|null>
    [END] success=<bool> steps=<n> score=<score> rewards=<r1,...,rn>
"""

import json
import os
import sys
from typing import Dict, List, Optional

from openai import OpenAI

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT_DIR)

from env import SQLOptimEnv
from models import Action

# ── Config ───────────────────────────────────────────────────────────────
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.environ.get("HF_TOKEN", "") or os.environ.get("API_KEY", "")

BENCHMARK = "sql-optim-env"
TEMPERATURE = 0.0
MAX_TOKENS = 2000

TASK_IDS = [
    "task_1_basic_antipatterns",
    "task_2_correlated_subqueries",
    "task_3_wildcard_scan",
    "task_4_implicit_join",
    "task_5_window_functions",
]

# ── System prompt ────────────────────────────────────────────────────────
SYSTEM_PROMPT = """\
You are an elite database engineer and SQL performance specialist with expert-level \
knowledge of PostgreSQL/DuckDB internals, query planning, columnar storage, \
and index design.

You will receive a SQL query and its schema. Your job:
1. Identify ALL performance anti-patterns.
2. Produce a complete, correct, optimized rewrite.
3. Your optimized_query will be ACTUALLY EXECUTED against a DuckDB database \
with realistic data (orders=500k rows, events=1M rows). \
If it returns wrong results or errors, your score drops.
4. If you receive execution feedback from a previous step, USE IT to refine \
your rewrite — fix incorrect results first, then improve speed.

Respond ONLY with valid JSON (no markdown, no fences):
{
  "suggestions": [
    {
      "issue_type": "e.g. select_star / correlated_subquery / wildcard_like",
      "line": <integer>,
      "description": "precise explanation of the performance problem",
      "severity": "critical | high | medium | low",
      "fix": "specific rewrite or corrective SQL"
    }
  ],
  "optimized_query": "<complete, executable SQL that produces IDENTICAL results to original>",
  "summary": "2-4 sentence performance profile of the original query",
  "estimated_improvement": "e.g. '15x faster — eliminates N+1 subquery pattern'",
  "approved": false
}
"""

# ── Logging (strict OpenEnv format) ──────────────────────────────────────

def log_start(task: str, env: str, model: str) -> None:
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error or 'null'}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    rstr = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.2f} rewards={rstr}",
        flush=True,
    )


# ── Model interaction ────────────────────────────────────────────────────

def parse_action(text: str) -> Dict:
    clean = text.strip()
    if clean.startswith("```"):
        lines = clean.split("\n")
        clean = "\n".join(
            lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
        )
    if clean.startswith("json"):
        clean = clean[4:].strip()
    try:
        return json.loads(clean)
    except json.JSONDecodeError:
        return {
            "suggestions": [],
            "optimized_query": "",
            "summary": "Parse error — model returned malformed JSON.",
            "estimated_improvement": "unknown",
            "approved": False,
        }


def build_user_prompt(obs) -> str:
    exec_feedback = ""
    if obs.last_execution:
        ex = obs.last_execution
        exec_feedback = (
            f"\n\n⚡ EXECUTION FEEDBACK FROM YOUR LAST OPTIMIZED QUERY:\n"
            f"  Original query  : {ex.get('original_ms', '?'):.1f} ms "
            f" ({ex.get('original_rows', 0)} rows)\n"
            f"  Your last query : {ex.get('optimized_ms', '?'):.1f} ms "
            f" ({ex.get('optimized_rows', 0)} rows)\n"
            f"  Speedup achieved: {ex.get('speedup', 1.0):.2f}x\n"
            f"  Results match   : {'✅ YES' if ex.get('results_match') else '❌ NO — fix your WHERE/JOIN logic'}\n"
            f"  Verdict         : {ex.get('verdict', '')}\n"
            f"Refine your optimized_query to fix any correctness issues first, "
            f"then improve speed further."
        )

    issues_ctx = ""
    if obs.issues_found_so_far:
        issues_ctx = (
            f"\nIssue types you've already flagged: {obs.issues_found_so_far}"
        )

    return (
        f"Task       : {obs.task_name}\n"
        f"Difficulty : {obs.difficulty}\n"
        f"Step       : {obs.step_count + 1} / {obs.max_steps}\n\n"
        f"Instructions:\n{obs.task_description}\n\n"
        f"Database Schema:\n{obs.schema_info}\n\n"
        f"SQL Query to Optimize:\n```sql\n{obs.sql_query}\n```"
        f"{issues_ctx}"
        f"{exec_feedback}\n\n"
        f"Provide your complete analysis and optimized_query now."
    )


def call_model(client: OpenAI, obs) -> tuple:
    try:
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": build_user_prompt(obs)},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        return parse_action(resp.choices[0].message.content or ""), None
    except Exception as exc:
        return {
            "suggestions": [], "optimized_query": "", "approved": False,
            "summary": f"Model error: {exc}",
            "estimated_improvement": "unknown",
        }, str(exc)


# ── Main loop ────────────────────────────────────────────────────────────

def main():
    if not HF_TOKEN:
        print("[ERROR] HF_TOKEN not set.", flush=True)
        sys.exit(1)

    client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
    local_env = SQLOptimEnv()
    results: Dict[str, Dict] = {}

    for task_id in TASK_IDS:
        obs = local_env.reset(task_id=task_id)
        # ... (lines 197-202 elided in the diff view) ...

        try:
            for step in range(1, obs.max_steps + 1):
                parsed, error = call_model(client, obs)

                action = Action(
                    suggestions=parsed.get("suggestions", []),
                    # ... (lines 210-216 elided in the diff view) ...

                reward = result.reward.score
                done = result.done

                # Pull execution info for the action summary
                exec_info = result.info.get("execution") or {}
                speedup = exec_info.get("speedup", 1.0)
                correct = exec_info.get("results_match", False)
                action_summary = (
                    f"suggestions={len(action.suggestions)},"
                    f"speedup={speedup:.2f}x,"
                    f"correct={str(correct).lower()}"
                )

                rewards.append(reward)
                steps_taken = step
                obs = result.observation

                log_step(step=step, action=action_summary,
                         reward=reward, done=done, error=error)

                if done:
                    break
            # ... (lines 239-240 elided in the diff view) ...
            success = score >= 0.5

        finally:
            log_end(success=success, steps=steps_taken,
                    score=score, rewards=rewards)

        results[task_id] = {
            "task_name": obs.task_name,
            "final_score": round(score, 4),
            "steps_taken": steps_taken,
        }
        # ... (rest of main() not shown in the diff view) ...
```
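The fence-stripping in `parse_action` matters in practice: models often wrap JSON in markdown despite the "no fences" instruction, and the parse-error fallback keeps a malformed reply from crashing the loop. A self-contained sketch of the same logic:

```python
import json

def parse_action(text: str) -> dict:
    """Strip optional ```/```json fences, then parse; fall back to a stub."""
    clean = text.strip()
    if clean.startswith("```"):
        lines = clean.split("\n")
        # Drop the opening fence line, and the closing one if present.
        clean = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
    if clean.startswith("json"):          # handles a stray "json" language tag
        clean = clean[4:].strip()
    try:
        return json.loads(clean)
    except json.JSONDecodeError:
        return {"suggestions": [], "optimized_query": "", "approved": False,
                "summary": "Parse error", "estimated_improvement": "unknown"}

fenced = '```json\n{"approved": false, "suggestions": []}\n```'
print(parse_action(fenced))                       # parsed dict, fences removed
print(parse_action("not json")["summary"])        # fallback stub: "Parse error"
```

The fallback returns a zero-signal but schema-valid action, so the grader can still assign its minimum score instead of the step erroring out.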
leaderboard.py — new file:

```python
"""
leaderboard.py — In-Memory Best-Score Tracker
Tracks every execution attempt across all tasks so the /leaderboard
endpoint can display real-time standings.
"""
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Dict, List

_board: Dict[str, List[Dict[str, Any]]] = defaultdict(list)


def record(
    task_id: str,
    speedup: float,
    score: float,
    results_match: bool,
    steps: int,
) -> None:
    _board[task_id].append(
        {
            "speedup": round(speedup, 3),
            "score": round(score, 4),
            "results_match": results_match,
            "steps": steps,
            "ts": datetime.now(timezone.utc).isoformat(),
        }
    )


def get_board() -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for task_id, entries in _board.items():
        if not entries:
            continue
        best = max(entries, key=lambda e: e["score"])
        valid = [e for e in entries if e["results_match"]]
        fastest = max(valid, key=lambda e: e["speedup"]) if valid else None

        out[task_id] = {
            "best_score": best["score"],
            "best_speedup": fastest["speedup"] if fastest else 0.0,
            "total_attempts": len(entries),
            "correct_attempts": len(valid),
            "success_rate": round(len(valid) / len(entries), 3),
            "best_attempt_at": best["ts"],
        }
    return out
```
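The key design choice in the leaderboard aggregation is that `best_speedup` only counts attempts whose results matched, while `best_score` ranks all attempts. A trimmed, self-contained sketch of the same aggregation (the timestamp field is dropped here to keep the example deterministic):

```python
from collections import defaultdict

_board = defaultdict(list)

def record(task_id, speedup, score, results_match, steps):
    _board[task_id].append({"speedup": round(speedup, 3), "score": round(score, 4),
                            "results_match": results_match, "steps": steps})

def get_board():
    out = {}
    for task_id, entries in _board.items():
        best = max(entries, key=lambda e: e["score"])
        valid = [e for e in entries if e["results_match"]]
        fastest = max(valid, key=lambda e: e["speedup"]) if valid else None
        out[task_id] = {
            "best_score": best["score"],
            # only correct attempts compete on raw speed
            "best_speedup": fastest["speedup"] if fastest else 0.0,
            "total_attempts": len(entries),
            "correct_attempts": len(valid),
            "success_rate": round(len(valid) / len(entries), 3),
        }
    return out

# A fast-but-wrong attempt does not steal the speed crown:
record("task_1_basic_antipatterns", speedup=4.2, score=0.81, results_match=True, steps=2)
record("task_1_basic_antipatterns", speedup=9.0, score=0.55, results_match=False, steps=1)
board = get_board()["task_1_basic_antipatterns"]
print(board["best_score"], board["best_speedup"], board["success_rate"])  # 0.81 4.2 0.5
```

Without the `results_match` filter, an agent could top the board by returning `SELECT 1`; filtering keeps the leaderboard aligned with the grader's correctness gate.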
models.py — old version of lines 6-43 (removed): Observation with a
`postgresql` dialect, `easy | medium | hard` difficulty, and no
`last_execution` field; no ExecutionResult model. The deleted lines are
truncated in this diff view.

models.py — new version (lines 6-66; file header and imports not shown in the diff):

```python
class Observation(BaseModel):
    task_id: str = Field(..., description="Unique task identifier")
    task_name: str = Field(..., description="Human-readable task name")
    task_description: str = Field(..., description="What the agent must do")
    sql_query: str = Field(..., description="The SQL query to analyze and optimize")
    schema_info: str = Field(..., description="Database schema, table sizes, and index info")
    dialect: str = Field(default="duckdb/postgresql", description="SQL dialect")
    difficulty: str = Field(..., description="easy | medium | hard | expert")
    step_count: int = Field(default=0, description="Steps taken in this episode")
    max_steps: int = Field(default=5, description="Max steps per episode")
    issues_found_so_far: List[str] = Field(
        default_factory=list,
        description="Issue types flagged in previous steps"
    )
    last_execution: Optional[Dict[str, Any]] = Field(
        None,
        description="Execution comparison result from previous step — "
                    "use this to refine your optimized_query"
    )


class Action(BaseModel):
    suggestions: List[Dict[str, Any]] = Field(
        ...,
        description="List of issues. Each: {issue_type, line, description, severity, fix}"
    )
    optimized_query: str = Field(
        ...,
        description="Complete rewritten SQL — will be EXECUTED against real data to measure speedup"
    )
    summary: str = Field(..., description="Overall analysis and performance profile")
    estimated_improvement: str = Field(
        ...,
        description="Expected speedup (e.g. '10x faster', '~80% I/O reduction')"
    )
    approved: bool = Field(
        ...,
        description="True if query is already optimal, False if it needs changes"
    )


class Reward(BaseModel):
    score: float = Field(..., ge=0.0, le=1.0, description="Composite reward 0.0–1.0")
    breakdown: Dict[str, float] = Field(..., description="Per-criterion scores")
    feedback: str = Field(..., description="Human-readable feedback with execution details")


class ExecutionResult(BaseModel):
    """Real DuckDB execution comparison — returned by /execute endpoint."""
    original_ms: float = Field(..., description="Original query median execution time (ms)")
    optimized_ms: float = Field(..., description="Optimized query median execution time (ms)")
    speedup: float = Field(..., description="Speedup ratio (original_ms / optimized_ms)")
    results_match: bool = Field(..., description="Do both queries return identical results?")
    original_rows: int = Field(..., description="Row count from original query")
    optimized_rows: int = Field(..., description="Row count from optimized query")
    original_error: Optional[str] = Field(None, description="Error from original, if any")
    optimized_error: Optional[str] = Field(None, description="Error from optimized, if any")
    verdict: str = Field(..., description="Human-readable verdict")
    explain_plan: Optional[str] = Field(None, description="EXPLAIN output for optimized query")


class StepResult(BaseModel):
    # ... (remaining fields unchanged, not shown in the diff view) ...
```
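The same ExecutionResult shape is what flows back into `Observation.last_execution` as a plain dict. A hypothetical payload (all numbers invented for illustration; no pydantic dependency needed to show the shape):

```python
# Hypothetical /execute payload, shaped field-for-field like models.ExecutionResult
last_execution = {
    "original_ms": 412.7,
    "optimized_ms": 28.4,
    "speedup": 412.7 / 28.4,   # speedup = original_ms / optimized_ms
    "results_match": True,
    "original_rows": 1200,
    "optimized_rows": 1200,
    "original_error": None,
    "optimized_error": None,
    "verdict": "results identical, large speedup",
    "explain_plan": None,
}

# The grader's correctness rule: full credit only when the optimized query
# ran without error AND returned identical results.
correct = last_execution["optimized_error"] is None and last_execution["results_match"]
print(round(last_execution["speedup"], 1), correct)  # 14.5 True
```

Because `speedup` is the ratio original_ms / optimized_ms, values above 1.0 mean the rewrite is faster; the grader's bands then convert this ratio into the 35% reward component.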
openenv.yaml (new version; `...` marks context collapsed in the diff):

name: sql-optim-env
version: "2.0.0"
description: >
  An OpenEnv-compliant reinforcement learning environment where AI agents
  learn to diagnose and optimize SQL queries. Unlike any other submission,
  optimized queries are ACTUALLY EXECUTED against a DuckDB in-memory
  database with realistic synthetic data (500k orders, 1M events).
  Reward is computed from real execution speedup + result correctness --
  not keyword heuristics. Five tasks from easy anti-patterns to expert
  window function audits.

tags:
  - openenv
  ...
  - database
  - performance
  - optimization
  - duckdb
  - execution-grounded
  - llm-agent

language: python

...

observation_space:
  ...
  step_count: integer
  max_steps: integer
  issues_found_so_far: array
  last_execution: object

action_space:
  type: object
  ...

reward:
  ...
  min: 0.0
  max: 1.0
  description: >
    Execution-grounded composite score:
    Real Speedup (35%) -- actual DuckDB timing ratio,
    Result Correctness (20%) -- both queries return identical data,
    Issue Detection (25%) -- keyword match vs ground truth,
    Approval Correctness (8%), Summary Quality (7%), Severity Labels (5%).

tasks:
  - id: task_1_basic_antipatterns
    name: "Basic SQL Anti-pattern Detection"
    difficulty: easy
    max_steps: 3
    description: "SELECT *, CAST on filter column, YEAR() function -- 3 classic anti-patterns on 500k rows"

  - id: task_2_correlated_subqueries
    name: "N+1 Correlated Subquery Elimination"
    difficulty: medium
    max_steps: 4
    description: "3 correlated subqueries causing ~10M row reads -- rewrite to single aggregation JOIN"

  - id: task_3_wildcard_scan
    name: "Wildcard LIKE & Projection Optimization"
    difficulty: medium-hard
    max_steps: 4
    description: "Leading-wildcard LIKE on 1M events, SELECT *, pre-filter push-down"

  - id: task_4_implicit_join
    name: "Implicit Cross Join & Scalar Subquery Elimination"
    difficulty: hard
    max_steps: 5
    description: "Comma-syntax join risk + 2 correlated global aggregations -- rewrite with CTE"

  - id: task_5_window_functions
    name: "Window Function & Full-Scan Audit"
    difficulty: expert
    max_steps: 5
    description: "5 window functions over 1M unfiltered rows including a global RANK() sort"

endpoints:
  reset: POST /reset
  step: POST /step
  state: GET /state
  tasks: GET /tasks
  grader: POST /grader
  baseline: POST /baseline
  execute: POST /execute
  leaderboard: GET /leaderboard

deployment:
  platform: huggingface-spaces
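The reward description above is a weighted sum of six per-criterion scores, each in [0, 1]. A minimal sketch of how such a composite is combined (the weights are taken verbatim from the manifest; the exact per-criterion scoring lives in graders.py and may differ in detail):

```python
# Weights copied from the manifest's reward description; they sum to 1.0,
# so the composite score also stays in [0.0, 1.0].
WEIGHTS = {
    "real_speedup": 0.35,
    "result_correctness": 0.20,
    "issue_detection": 0.25,
    "approval_correctness": 0.08,
    "summary_quality": 0.07,
    "severity_labels": 0.05,
}


def composite_score(criteria: dict) -> float:
    """Weighted sum over the six criteria; missing criteria count as 0."""
    return sum(WEIGHTS[name] * criteria.get(name, 0.0) for name in WEIGHTS)


# An agent with a full speedup, matching results, and 80% issue detection:
score = composite_score(
    {"real_speedup": 1.0, "result_correctness": 1.0, "issue_detection": 0.8}
)
print(round(score, 2))  # 0.75
```

This makes the trade-off explicit: execution-grounded criteria (speedup + correctness) carry 55% of the reward, so a fast, correct rewrite outweighs a perfectly worded but unexecuted analysis.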
requirements.txt (new version; earlier pins collapsed in the diff):

...
openai>=1.0.0
pyyaml==6.0.2
requests==2.32.3
openenv-core>=0.2.0
duckdb>=0.10.0
"""
server/app.py -- FastAPI Server
================================
OpenEnv-compliant endpoints + two unique endpoints:
    POST /execute     -- run your optimized query against real DuckDB data,
                         see actual speedup + result correctness instantly
    GET  /leaderboard -- see best scores + speedups across all tasks
"""

import json
import os
import sys
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from env import SQLOptimEnv
from executor import get_executor
from graders import grade
from leaderboard import get_board
from models import (
    Action,
    EnvironmentState,
    ExecutionResult,
    Observation,
    StepResult,
)
from tasks import TASKS, get_task_list


# ── Lifespan: pre-warm DuckDB on startup ─────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Build all 4 synthetic tables before first request
    get_executor()
    yield


app = FastAPI(
    title="SQL Query Optimization Environment",
    description=(
        "OpenEnv-compliant RL environment where AI agents learn to diagnose "
        "and optimize SQL queries. Uniquely, optimized queries are EXECUTED "
        "against real DuckDB data -- reward is based on actual speedup + "
        "result correctness, not keyword heuristics."
    ),
    version="2.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    ...
)

env = SQLOptimEnv()


# ── Standard OpenEnv endpoints ────────────────────────────────────────────

@app.get("/")
def root():
    ex = get_executor()
    return {
        "status": "ok",
        "environment": "sql-optim-env",
        "version": "2.0.0",
        "unique_feature": "Execution-grounded rewards via DuckDB",
        "table_stats": ex.table_stats,
        "tasks": [t["task_id"] for t in get_task_list()],
    }


@app.post("/reset", response_model=Observation)
async def reset(request: Request):
    """Start a new episode. Body: {"task_id": "..."} (optional)."""
    try:
        body = await request.body()
        task_id = "task_1_basic_antipatterns"
        if body:
            try:
                data = json.loads(body)
                task_id = data.get("task_id", task_id) or task_id
            except Exception:
                pass
        return env.reset(task_id=task_id)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))


@app.post("/step", response_model=StepResult)
def step(action: Action):
    """Submit an optimization action; get real execution feedback."""
    try:
        return env.step(action)
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc))


@app.get("/state", response_model=EnvironmentState)
def state():
    """Get current environment state without advancing the episode."""
    return env.state()


@app.get("/tasks")
def tasks():
    """List all available tasks with descriptions and action schema."""
    return {"tasks": get_task_list()}


@app.post("/grader")
def grader(action: Action):
    """Grade an action against the current task without advancing the episode."""
    if env._task_data is None:
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")
    return grade(env._task_data, action)


@app.post("/baseline")
def baseline():
    """Run the baseline inference script and return output."""
    import subprocess
    try:
        result = subprocess.run(
            ["python", "inference.py"],
            capture_output=True,
            text=True,
            timeout=300,
            cwd=os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
        )
        return {
            "stdout": result.stdout,
            "stderr": result.stderr,
            "returncode": result.returncode,
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Baseline failed: {exc}")


# ── Unique endpoints (no other team has these) ────────────────────────────

@app.post("/execute", response_model=ExecutionResult)
async def execute(request: Request):
    """
    UNIQUE ENDPOINT -- Execute your optimized query against real DuckDB data.

    Body:
        {
          "task_id": "task_1_basic_antipatterns",
          "optimized_query": "SELECT id, customer_id ... WHERE customer_id = 5000 ..."
        }

    Returns actual execution timing, speedup ratio, result correctness,
    and an EXPLAIN plan -- no other OpenEnv environment does this.
    """
    body = await request.body()
    if not body:
        raise HTTPException(status_code=400, detail="Body required: {task_id, optimized_query}")
    try:
        data = json.loads(body)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid JSON body")

    task_id = data.get("task_id", "task_1_basic_antipatterns")
    optimized_query = (data.get("optimized_query") or "").strip()

    if task_id not in TASKS:
        raise HTTPException(status_code=400, detail=f"Unknown task_id: {task_id}")
    if not optimized_query:
        raise HTTPException(status_code=400, detail="optimized_query is required")

    original_query = TASKS[task_id]["sql_query"]
    ex = get_executor()

    try:
        result = ex.compare(original_query, optimized_query)
        explain = ex.explain(optimized_query)
        return ExecutionResult(
            original_ms=result["original_ms"],
            optimized_ms=result["optimized_ms"],
            speedup=result["speedup"],
            results_match=result["results_match"],
            original_rows=result["original_rows"],
            optimized_rows=result["optimized_rows"],
            original_error=result.get("original_error"),
            optimized_error=result.get("optimized_error"),
            verdict=result["verdict"],
            explain_plan=explain,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))


@app.get("/leaderboard")
def leaderboard():
    """
    UNIQUE ENDPOINT -- Real-time leaderboard of best execution scores.

    Shows per-task: best score, best speedup achieved, total attempts,
    how many optimized queries produced correct results.
    """
    return {
        "leaderboard": get_board(),
        "description": (
            "Scores are based on real DuckDB execution: "
            "speedup ratio (35%) + result correctness (20%) + issue detection (25%) + other (20%)"
        ),
    }


# ── Entry point ────────────────────────────────────────────────────────────

def main():
    import uvicorn
|
@@ -1,216 +1,396 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
TASKS: Dict[str, Dict[str, Any]] = {
|
| 4 |
|
| 5 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 6 |
-
# TASK 1 β EASY: Basic
|
| 7 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 8 |
"task_1_basic_antipatterns": {
|
| 9 |
-
"task_id":
|
| 10 |
"task_name": "Basic SQL Anti-pattern Detection",
|
| 11 |
"task_description": (
|
| 12 |
-
"Analyze the SQL query below for common anti-patterns that
|
| 13 |
-
"Identify
|
| 14 |
-
"
|
| 15 |
-
"
|
|
|
|
|
|
|
|
|
|
| 16 |
),
|
| 17 |
"difficulty": "easy",
|
| 18 |
-
"dialect":
|
| 19 |
"max_steps": 3,
|
| 20 |
-
"schema_info":
|
| 21 |
-
Table: orders (
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
""
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
""
|
|
|
|
| 33 |
"ground_truth_issues": [
|
| 34 |
{
|
| 35 |
"type": "select_star",
|
| 36 |
-
"line":
|
| 37 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
| 38 |
},
|
| 39 |
{
|
| 40 |
-
"type": "
|
| 41 |
-
"line":
|
| 42 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
},
|
| 44 |
{
|
| 45 |
-
"type": "
|
| 46 |
-
"line":
|
| 47 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
| 48 |
},
|
| 49 |
],
|
| 50 |
"approved_expected": False,
|
| 51 |
},
|
| 52 |
|
| 53 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
-
# TASK 2 β MEDIUM: N+1
|
| 55 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 56 |
-
"
|
| 57 |
-
"task_id":
|
| 58 |
-
"task_name": "N+1
|
| 59 |
"task_description": (
|
| 60 |
-
"
|
| 61 |
-
"
|
| 62 |
-
"
|
| 63 |
-
"
|
|
|
|
|
|
|
| 64 |
),
|
| 65 |
"difficulty": "medium",
|
| 66 |
-
"dialect":
|
| 67 |
"max_steps": 4,
|
| 68 |
-
"schema_info":
|
| 69 |
-
Table: users (
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
Table:
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
"
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
u.
|
| 81 |
-
|
| 82 |
-
(SELECT
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
"ground_truth_issues": [
|
| 91 |
{
|
| 92 |
-
"type": "
|
| 93 |
"line": 4,
|
| 94 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
| 95 |
},
|
| 96 |
{
|
| 97 |
-
"type": "
|
| 98 |
-
"line":
|
| 99 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
| 100 |
},
|
| 101 |
{
|
| 102 |
-
"type": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
"line": 6,
|
| 104 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
| 105 |
},
|
| 106 |
{
|
| 107 |
-
"type": "
|
| 108 |
-
"line":
|
| 109 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
| 110 |
},
|
| 111 |
{
|
| 112 |
-
"type": "
|
| 113 |
-
"line":
|
| 114 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
},
|
| 116 |
],
|
| 117 |
"approved_expected": False,
|
| 118 |
},
|
| 119 |
|
| 120 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 121 |
-
# TASK
|
| 122 |
-
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 123 |
-
"
|
| 124 |
-
"task_id":
|
| 125 |
-
"task_name": "
|
| 126 |
"task_description": (
|
| 127 |
-
"
|
| 128 |
-
"
|
| 129 |
-
"
|
| 130 |
-
"implicit
|
| 131 |
-
"
|
| 132 |
-
"
|
| 133 |
-
"
|
| 134 |
-
"
|
| 135 |
),
|
| 136 |
"difficulty": "hard",
|
| 137 |
-
"dialect":
|
| 138 |
"max_steps": 5,
|
| 139 |
-
"schema_info":
|
| 140 |
-
Table:
|
| 141 |
-
Table:
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
""
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
AND properties->>'plan' = 'premium'
|
| 161 |
-
GROUP BY e.user_id, e.session_id
|
| 162 |
-
),
|
| 163 |
-
ranked_sessions AS (
|
| 164 |
-
SELECT
|
| 165 |
-
*,
|
| 166 |
-
ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY purchases DESC, session_end DESC) AS rn,
|
| 167 |
-
AVG(event_count) OVER (PARTITION BY user_id) AS avg_events_per_session
|
| 168 |
-
FROM user_sessions
|
| 169 |
-
)
|
| 170 |
-
SELECT
|
| 171 |
-
u.country,
|
| 172 |
-
u.plan,
|
| 173 |
-
AVG(rs.purchases) AS avg_purchases,
|
| 174 |
-
COUNT(DISTINCT rs.user_id) AS active_users,
|
| 175 |
-
PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY rs.event_count) AS median_events
|
| 176 |
-
FROM ranked_sessions rs
|
| 177 |
-
JOIN users u ON u.id = rs.user_id
|
| 178 |
-
WHERE rs.rn = 1
|
| 179 |
-
AND u.plan::text IN ('premium', 'enterprise')
|
| 180 |
-
GROUP BY u.country, u.plan
|
| 181 |
-
HAVING COUNT(DISTINCT rs.user_id) > 10
|
| 182 |
-
ORDER BY avg_purchases DESC;
|
| 183 |
-
""",
|
| 184 |
"ground_truth_issues": [
|
| 185 |
{
|
| 186 |
-
"type": "
|
| 187 |
-
"line":
|
| 188 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
| 189 |
},
|
| 190 |
{
|
| 191 |
-
"type": "
|
| 192 |
-
"line":
|
| 193 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
| 194 |
},
|
| 195 |
{
|
| 196 |
-
"type": "
|
| 197 |
-
"line":
|
| 198 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
{
|
| 201 |
-
"type": "
|
| 202 |
-
"line":
|
| 203 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
| 204 |
},
|
| 205 |
{
|
| 206 |
-
"type": "
|
| 207 |
"line": 8,
|
| 208 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
| 209 |
},
|
| 210 |
{
|
| 211 |
-
"type": "
|
| 212 |
-
"line":
|
| 213 |
-
"keywords": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
},
|
| 215 |
],
|
| 216 |
"approved_expected": False,
|
|
@@ -218,20 +398,21 @@ ORDER BY avg_purchases DESC;
|
|
| 218 |
}
|
| 219 |
|
| 220 |
|
| 221 |
-
def get_task_list()
|
| 222 |
return [
|
| 223 |
{
|
| 224 |
-
"task_id":
|
| 225 |
-
"task_name":
|
| 226 |
"difficulty": t["difficulty"],
|
|
|
|
| 227 |
"description": t["task_description"],
|
| 228 |
"action_schema": {
|
| 229 |
-
"suggestions":
|
| 230 |
-
"optimized_query":
|
| 231 |
-
"summary":
|
| 232 |
-
"estimated_improvement": "str β expected
|
| 233 |
-
"approved":
|
| 234 |
-
}
|
| 235 |
}
|
| 236 |
for t in TASKS.values()
|
| 237 |
]
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
tasks.py β SQL Query Optimization Tasks
|
| 3 |
+
========================================
|
| 4 |
+
Five tasks of increasing difficulty, each with a DuckDB-executable
|
| 5 |
+
"bad" query (stored in sql_query) that agents must analyze and rewrite.
|
| 6 |
+
|
| 7 |
+
All queries run against the executor's synthetic tables:
|
| 8 |
+
users (10,000 rows) β id, email, tier, region, plan, created_at
|
| 9 |
+
orders (500,000 rows) β id, customer_id, product_id, status, total, created_at
|
| 10 |
+
products (1,000 rows) β id, name, category, price
|
| 11 |
+
events (1,000,000 rows) β id, user_id, session_id, event_type, occurred_at
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from typing import Any, Dict, List
|
| 15 |
|
| 16 |
TASKS: Dict[str, Dict[str, Any]] = {
|
| 17 |
|
| 18 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 19 |
+
# TASK 1 β EASY: Basic Anti-pattern Detection
|
| 20 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 21 |
"task_1_basic_antipatterns": {
|
| 22 |
+
"task_id": "task_1_basic_antipatterns",
|
| 23 |
"task_name": "Basic SQL Anti-pattern Detection",
|
| 24 |
"task_description": (
|
| 25 |
+
"Analyze the SQL query below for common anti-patterns that destroy performance. "
|
| 26 |
+
"Identify: SELECT * (fetches unnecessary columns from 500k rows), "
|
| 27 |
+
"CAST on a filter column (prevents any index or min/max pruning), "
|
| 28 |
+
"and a function applied to a date column (forces full table evaluation). "
|
| 29 |
+
"For each issue report: issue_type, line, description, severity, and a concrete fix. "
|
| 30 |
+
"Also provide a fully rewritten optimized_query β it will be EXECUTED against "
|
| 31 |
+
"real data and your speedup will be measured."
|
| 32 |
),
|
| 33 |
"difficulty": "easy",
|
| 34 |
+
"dialect": "duckdb/postgresql",
|
| 35 |
"max_steps": 3,
|
| 36 |
+
"schema_info": (
|
| 37 |
+
"Table: orders (500,000 rows)\n"
|
| 38 |
+
" id INT, customer_id INT, product_id INT,\n"
|
| 39 |
+
" status VARCHAR, total DECIMAL, created_at DATE\n\n"
|
| 40 |
+
"No indexes defined (DuckDB uses columnar min/max pruning when columns "
|
| 41 |
+
"are not wrapped in functions).\n"
|
| 42 |
+
"Scan cost: ~500k rows Γ all columns with SELECT *"
|
| 43 |
+
),
|
| 44 |
+
"sql_query": (
|
| 45 |
+
"SELECT *\n"
|
| 46 |
+
"FROM orders\n"
|
| 47 |
+
"WHERE CAST(customer_id AS VARCHAR) = '5000'\n"
|
| 48 |
+
" AND year(created_at) = 2024;"
|
| 49 |
+
),
|
| 50 |
"ground_truth_issues": [
|
| 51 |
{
|
| 52 |
"type": "select_star",
|
| 53 |
+
"line": 1,
|
| 54 |
+
"keywords": [
|
| 55 |
+
"select *", "star", "all columns", "unnecessary columns",
|
| 56 |
+
"column projection", "specify columns", "bandwidth",
|
| 57 |
+
],
|
| 58 |
},
|
| 59 |
{
|
| 60 |
+
"type": "non_sargable_cast",
|
| 61 |
+
"line": 3,
|
| 62 |
+
"keywords": [
|
| 63 |
+
"cast", "varchar", "type cast", "type conversion",
|
| 64 |
+
"non-sargable", "sargable", "integer comparison",
|
| 65 |
+
"string comparison", "prevents", "pruning",
|
| 66 |
+
],
|
| 67 |
},
|
| 68 |
{
|
| 69 |
+
"type": "function_on_date_column",
|
| 70 |
+
"line": 4,
|
| 71 |
+
"keywords": [
|
| 72 |
+
"year(", "function on column", "non-sargable", "date range",
|
| 73 |
+
"between", "extract", "full scan", "date filter",
|
| 74 |
+
],
|
| 75 |
},
|
| 76 |
],
|
| 77 |
"approved_expected": False,
|
| 78 |
},
|
| 79 |
|
| 80 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 81 |
+
# TASK 2 β MEDIUM: N+1 Correlated Subqueries
|
| 82 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 83 |
+
"task_2_correlated_subqueries": {
|
| 84 |
+
"task_id": "task_2_correlated_subqueries",
|
| 85 |
+
"task_name": "N+1 Correlated Subquery Elimination",
|
| 86 |
"task_description": (
|
| 87 |
+
"The query below uses three correlated scalar subqueries β each one scans "
|
| 88 |
+
"the entire orders table (500k rows) once per premium user (~3,300 users). "
|
| 89 |
+
"That's ~10 million row reads just for aggregation. "
|
| 90 |
+
"Identify the N+1 pattern, explain why each subquery is harmful, "
|
| 91 |
+
"and rewrite the query as a single aggregation JOIN. "
|
| 92 |
+
"Your optimized_query will be executed; results must match the original."
|
| 93 |
),
|
| 94 |
"difficulty": "medium",
|
| 95 |
+
"dialect": "duckdb/postgresql",
|
| 96 |
"max_steps": 4,
|
| 97 |
+
"schema_info": (
|
| 98 |
+
"Table: users (10,000 rows)\n"
|
| 99 |
+
" id INT, email VARCHAR, tier VARCHAR, region VARCHAR,\n"
|
| 100 |
+
" plan VARCHAR, created_at DATE\n\n"
|
| 101 |
+
"Table: orders (500,000 rows)\n"
|
| 102 |
+
" id INT, customer_id INT, product_id INT,\n"
|
| 103 |
+
" status VARCHAR, total DECIMAL, created_at DATE\n\n"
|
| 104 |
+
"Premium users: ~3,300 | Orders per user avg: 50\n"
|
| 105 |
+
"Worst-case scans: 3 subqueries Γ 3,300 users Γ 500k rows = ~5B row reads"
|
| 106 |
+
),
|
| 107 |
+
"sql_query": (
|
| 108 |
+
"SELECT\n"
|
| 109 |
+
" u.email,\n"
|
| 110 |
+
" u.region,\n"
|
| 111 |
+
" (SELECT COUNT(*)\n"
|
| 112 |
+
" FROM orders o\n"
|
| 113 |
+
" WHERE o.customer_id = u.id AND o.status = 'completed') AS completed_orders,\n"
|
| 114 |
+
" (SELECT SUM(o.total)\n"
|
| 115 |
+
" FROM orders o\n"
|
| 116 |
+
" WHERE o.customer_id = u.id\n"
|
| 117 |
+
" AND o.created_at >= DATE '2024-01-01') AS ytd_spend,\n"
|
| 118 |
+
" (SELECT total\n"
|
| 119 |
+
" FROM orders o\n"
|
| 120 |
+
" WHERE o.customer_id = u.id\n"
|
| 121 |
+
" ORDER BY created_at DESC LIMIT 1) AS last_order_amount\n"
|
| 122 |
+
"FROM users u\n"
|
| 123 |
+
"WHERE u.tier = 'premium';"
|
| 124 |
+
),
|
| 125 |
"ground_truth_issues": [
|
| 126 |
{
|
| 127 |
+
"type": "correlated_subquery_count",
|
| 128 |
"line": 4,
|
| 129 |
+
"keywords": [
|
| 130 |
+
"correlated", "subquery", "per row", "n+1", "each user",
|
| 131 |
+
"repeated scan", "join", "aggregation",
|
| 132 |
+
],
|
| 133 |
},
|
| 134 |
{
|
| 135 |
+
"type": "correlated_subquery_sum",
|
| 136 |
+
"line": 7,
|
| 137 |
+
"keywords": [
|
| 138 |
+
"correlated", "subquery", "per row", "n+1", "each user",
|
| 139 |
+
"repeated scan", "join", "group by",
|
| 140 |
+
],
|
| 141 |
},
|
| 142 |
{
|
| 143 |
+
"type": "correlated_subquery_limit",
|
| 144 |
+
"line": 11,
|
| 145 |
+
"keywords": [
|
| 146 |
+
"correlated", "subquery", "limit 1", "order by", "lateral",
|
| 147 |
+
"row_number", "rank", "window function", "per row",
|
| 148 |
+
],
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"type": "missing_aggregation_join",
|
| 152 |
+
"line": 16,
|
| 153 |
+
"keywords": [
|
| 154 |
+
"left join", "group by", "aggreg", "single pass",
|
| 155 |
+
"coalesce", "join aggregat",
|
| 156 |
+
],
|
| 157 |
+
},
|
| 158 |
+
],
|
| 159 |
+
"approved_expected": False,
|
| 160 |
+
},
|
| 161 |
+
|
| 162 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 163 |
+
# TASK 3 β MEDIUM-HARD: Wildcard LIKE Full Scan
|
| 164 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 165 |
+
"task_3_wildcard_scan": {
|
| 166 |
+
"task_id": "task_3_wildcard_scan",
|
| 167 |
+
"task_name": "Wildcard LIKE & Projection Optimization",
|
| 168 |
+
"task_description": (
|
| 169 |
+
"The query scans all 1,000,000 events rows with leading and trailing wildcard "
|
| 170 |
+
"LIKE patterns β these disable min/max pruning and force full column evaluation. "
|
| 171 |
+
"It also computes derived columns for every row before filtering. "
|
| 172 |
+
"Identify: leading-wildcard LIKE patterns that kill pruning, "
|
| 173 |
+
            "SELECT * on a million-row table, redundant OR conditions, "
            "and unnecessary computed columns evaluated before WHERE filtering. "
            "Rewrite to use exact equality and minimal column projection."
        ),
        "difficulty": "medium-hard",
        "dialect": "duckdb/postgresql",
        "max_steps": 4,
        "schema_info": (
            "Table: events (1,000,000 rows)\n"
            "  id INT, user_id INT, session_id VARCHAR,\n"
            "  event_type VARCHAR, occurred_at DATE\n\n"
            "Distinct event_type values: purchase, view, click, signup, logout, search\n"
            "Wildcard LIKE on all 1M rows: forces full column scan\n"
            "Exact equality match: enables columnar zone-map pruning"
        ),
        "sql_query": (
            "SELECT\n"
            "    *,\n"
            "    CAST(id AS VARCHAR) || '_' || event_type AS event_key,\n"
            "    upper(event_type) AS event_type_upper\n"
            "FROM events\n"
            "WHERE event_type LIKE '%purchase%'\n"
            "   OR event_type LIKE '%buy%'\n"
            "   OR session_id LIKE 'sess_%';"
        ),
        "ground_truth_issues": [
            {
                "type": "leading_wildcard_like",
                "line": 6,
                "keywords": [
                    "leading wildcard", "like '%", "full scan", "exact match",
                    "equality", "pruning disabled", "wildcard", "zone map",
                ],
            },
            {
                "type": "or_expands_to_full_scan",
                "line": 7,
                "keywords": [
                    "or condition", "union", "separate queries", "or expands",
                    "full scan", "like '%buy%'", "redundant",
                ],
            },
            {
                "type": "select_star_large_table",
                "line": 2,
                "keywords": [
                    "select *", "1 million", "all columns", "projection",
                    "column pruning", "unnecessary", "bandwidth",
                ],
            },
            {
                "type": "pre_filter_computed_columns",
                "line": 3,
                "keywords": [
                    "computed column", "derived", "upper(", "cast(", "concatenat",
                    "before filter", "pre-filter", "push down", "CTE",
                ],
            },
        ],
        "approved_expected": False,
    },
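For orientation, one rewrite in the spirit this task rewards, as a plain string (illustrative only; the environment judges any candidate by actually executing it, and the `'%buy%'` branch can be dropped only because none of the listed event_type values contain "buy"):

```python
# One plausible agent rewrite for task 3 (illustrative, not the unique answer).
# Exact equality replaces the leading-wildcard LIKE (only 'purchase' can match
# '%purchase%' given the listed event_type values), the dead '%buy%' branch is
# dropped, and SELECT * becomes an explicit projection.
task_3_candidate = (
    "SELECT id, user_id, session_id, event_type, occurred_at,\n"
    "       CAST(id AS VARCHAR) || '_' || event_type AS event_key,\n"
    "       upper(event_type) AS event_type_upper\n"
    "FROM events\n"
    "WHERE event_type = 'purchase'\n"
    "   OR session_id LIKE 'sess_%';"  # prefix LIKE stays: already sargable
)
```

The trailing-wildcard `session_id LIKE 'sess_%'` is kept as-is, since a prefix pattern does not defeat pruning the way a leading wildcard does.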

    # ─────────────────────────────────────────────────────────────────
    # TASK 4 — HARD: Implicit Cross Join + Repeated Scalar Subqueries
    # ─────────────────────────────────────────────────────────────────
    "task_4_implicit_join": {
        "task_id": "task_4_implicit_join",
        "task_name": "Implicit Cross Join & Scalar Subquery Elimination",
        "task_description": (
            "This query uses comma-separated FROM (implicit cross join syntax) and "
            "two correlated scalar subqueries that re-aggregate the entire orders table "
            "once per GROUP BY group. "
            "Identify: implicit cross join risk (comma in FROM clause), "
            "two correlated scalar subqueries recalculating global stats, "
            "and the GROUP BY without an explicit JOIN. "
            "Rewrite using explicit INNER JOIN and a CTE/subquery for the global stats "
            "so they are computed exactly once."
        ),
        "difficulty": "hard",
        "dialect": "duckdb/postgresql",
        "max_steps": 5,
        "schema_info": (
            "Table: users (10,000 rows) — id, email, tier, region, plan, created_at\n"
            "Table: orders (500,000 rows) — id, customer_id, product_id, status, total, created_at\n\n"
            "Join: users.id = orders.customer_id\n"
            "Implicit join (comma syntax) risk: if the WHERE predicate is missing,\n"
            "produces a Cartesian product of 10k × 500k = 5 BILLION rows.\n"
            "Scalar subqueries: each recalculates over all 500k orders per group."
        ),
        "sql_query": (
            "SELECT\n"
            "    u.region,\n"
            "    u.plan,\n"
            "    COUNT(*) AS total_orders,\n"
            "    SUM(o.total) AS revenue,\n"
            "    (SELECT AVG(total) FROM orders) AS global_avg,\n"
            "    (SELECT MAX(total) FROM orders WHERE status = 'completed') AS max_deal\n"
            "FROM users u, orders o\n"
            "WHERE u.id = o.customer_id\n"
            "  AND o.status IN ('completed', 'shipped')\n"
            "GROUP BY u.region, u.plan;"
        ),
        "ground_truth_issues": [
            {
                "type": "implicit_cross_join",
                "line": 8,
                "keywords": [
                    "implicit", "cross join", "comma join", "explicit join",
                    "inner join", "cartesian", "comma in from",
                ],
            },
            {
                "type": "repeated_scalar_subquery_avg",
                "line": 6,
                "keywords": [
                    "scalar subquery", "correlated", "per group", "once per row",
                    "cte", "with clause", "pre-compute", "global avg",
                ],
            },
            {
                "type": "repeated_scalar_subquery_max",
                "line": 7,
                "keywords": [
                    "scalar subquery", "correlated", "per group", "max deal",
                    "cte", "pre-compute", "compute once", "constant",
                ],
            },
            {
                "type": "missing_explicit_join",
                "line": 8,
                "keywords": [
                    "inner join", "explicit", "on clause", "join condition",
                    "readable", "maintainable", "ansi sql",
                ],
            },
        ],
        "approved_expected": False,
    },
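To make the intended fix concrete, here is a sketch of one rewrite for this task (illustrative; any semantically equivalent query that measures faster should score). The comma join becomes an explicit INNER JOIN, and the two scalar subqueries collapse into a CTE that is evaluated exactly once and cross-joined onto every group:

```python
# One plausible agent rewrite for task 4 (illustrative, not the unique answer).
# MAX(total) FILTER (...) is valid in both DuckDB and PostgreSQL, matching the
# task's declared dialect.
task_4_candidate = (
    "WITH global_stats AS (\n"
    "    SELECT AVG(total) AS global_avg,\n"
    "           MAX(total) FILTER (WHERE status = 'completed') AS max_deal\n"
    "    FROM orders\n"
    ")\n"
    "SELECT u.region, u.plan,\n"
    "       COUNT(*) AS total_orders,\n"
    "       SUM(o.total) AS revenue,\n"
    "       g.global_avg, g.max_deal\n"
    "FROM users u\n"
    "JOIN orders o ON u.id = o.customer_id\n"
    "CROSS JOIN global_stats g\n"
    "WHERE o.status IN ('completed', 'shipped')\n"
    "GROUP BY u.region, u.plan, g.global_avg, g.max_deal;"
)
```

The CROSS JOIN here is safe: `global_stats` is a single-row CTE, so it broadcasts the two constants without multiplying the row count.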

    # ─────────────────────────────────────────────────────────────────
    # TASK 5 — EXPERT: Window Function Over Entire 1M-Row Table
    # ─────────────────────────────────────────────────────────────────
    "task_5_window_functions": {
        "task_id": "task_5_window_functions",
        "task_name": "Window Function & Full-Scan Audit",
        "task_description": (
            "Five window functions are computed over ALL 1,000,000 events rows with no "
            "pre-filtering. Each OVER() clause requires a full sort or hash-aggregate pass. "
            "The global RANK() OVER (ORDER BY occurred_at) requires sorting the entire "
            "table — the most expensive operation here. "
            "Identify: no WHERE clause causing full 1M-row scans, "
            "redundant window functions that can be merged, "
            "a global ordering window function with no PARTITION, "
            "and SELECT * on the full events table. "
            "Rewrite to filter first, merge windows, and remove the global RANK."
        ),
        "difficulty": "expert",
        "dialect": "duckdb/postgresql",
        "max_steps": 5,
        "schema_info": (
            "Table: events (1,000,000 rows)\n"
            "  id INT, user_id INT, session_id VARCHAR,\n"
            "  event_type VARCHAR, occurred_at DATE\n\n"
            "Window function cost: each OVER() = full sort/hash pass over 1M rows\n"
            "5 window functions = 5 full passes before any filtering\n"
            "Global RANK(): sorts all 1M rows globally — most expensive operation\n"
            "Filtering to 'purchase' events first reduces dataset to ~167k rows (1/6)"
        ),
        "sql_query": (
            "SELECT\n"
            "    user_id,\n"
            "    event_type,\n"
            "    occurred_at,\n"
            "    COUNT(*) OVER (PARTITION BY user_id) AS total_user_events,\n"
            "    COUNT(*) OVER (PARTITION BY user_id, event_type) AS type_count,\n"
            "    ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY occurred_at DESC) AS recency_rank,\n"
            "    RANK() OVER (ORDER BY occurred_at DESC) AS global_rank,\n"
            "    SUM(CASE WHEN event_type = 'purchase' THEN 1 ELSE 0 END)\n"
            "        OVER (PARTITION BY user_id) AS user_purchases\n"
            "FROM events;"
        ),
        "ground_truth_issues": [
            {
                "type": "no_pre_filter",
                "line": 11,
                "keywords": [
                    "no where", "no filter", "full table", "1 million", "all rows",
                    "pre-filter", "filter first", "cte", "with clause",
                ],
            },
            {
                "type": "global_rank_no_partition",
                "line": 8,
                "keywords": [
                    "rank() over", "global rank", "no partition", "entire table",
                    "full sort", "expensive", "global ordering", "remove",
                ],
            },
            {
                "type": "redundant_window_functions",
                "line": 5,
                "keywords": [
                    "5 window", "multiple over", "redundant", "merge", "combine",
                    "single pass", "same partition", "consolidate",
                ],
            },
            {
                "type": "count_vs_conditional_sum",
                "line": 9,
                "keywords": [
                    "case when", "sum case", "count filter", "filter clause",
                    "count(*) filter", "simpler", "merge with",
                ],
            },
            {
                "type": "select_all_unfiltered",
                "line": 1,
                "keywords": [
                    "select *", "user_id, event_type", "projection", "column pruning",
                    "select specific", "1 million rows", "bandwidth",
                ],
            },
        ],
        "approved_expected": False,
    },
}
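A sketch of the kind of rewrite task 5 describes (illustrative only): filter to 'purchase' rows in a CTE before any window pass, drop the global RANK() entirely, and let the redundant windows collapse. Whether the narrowed output is an acceptable substitute for a given consumer is exactly what the environment's correctness check arbitrates:

```python
# One plausible agent rewrite for task 5 (illustrative, not the unique answer).
# After the CTE narrows to purchases, type_count and user_purchases both equal
# total_user_events, so three of the five windows merge into one.
task_5_candidate = (
    "WITH purchases AS (\n"
    "    SELECT user_id, event_type, occurred_at\n"
    "    FROM events\n"
    "    WHERE event_type = 'purchase'\n"
    ")\n"
    "SELECT user_id, event_type, occurred_at,\n"
    "       COUNT(*) OVER (PARTITION BY user_id) AS total_user_events,\n"
    "       ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY occurred_at DESC) AS recency_rank\n"
    "FROM purchases;"
)
```

Per the schema notes above, this runs the remaining window passes over roughly 167k rows instead of 1M, and the full-table sort behind the global RANK() disappears.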


def get_task_list():
    return [
        {
            "task_id": t["task_id"],
            "task_name": t["task_name"],
            "difficulty": t["difficulty"],
            "max_steps": t["max_steps"],
            "description": t["task_description"],
            "action_schema": {
                "suggestions": "List of {issue_type, line, description, severity, fix}",
                "optimized_query": "str — complete rewritten SQL (will be EXECUTED for real timing)",
                "summary": "str — overall performance analysis",
                "estimated_improvement": "str — expected speedup (e.g. '10x faster')",
                "approved": "bool — True if already optimal",
            },
        }
        for t in TASKS.values()
    ]
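For clients of this action schema, a well-formed action payload might look like the following (all field values are invented for illustration; only the field names come from the schema above):

```python
# Illustrative action payload matching the advertised action_schema.
example_action = {
    "suggestions": [
        {"issue_type": "implicit_cross_join", "line": 8,
         "description": "Comma join risks a Cartesian product",
         "severity": "high",
         "fix": "Use an explicit INNER JOIN with an ON clause"},
    ],
    "optimized_query": "SELECT 1;",  # executed by the env for real timing
    "summary": "One structural issue found.",
    "estimated_improvement": "10x faster",
    "approved": False,
}

# Every advertised field is present, nothing extra.
assert set(example_action) == {
    "suggestions", "optimized_query", "summary",
    "estimated_improvement", "approved",
}
```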
@@ -0,0 +1,47 @@
import os
import sys
import time

# Make the repo root importable no matter where the test is launched from.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

print('Testing DuckDB executor...')
t0 = time.time()
from executor import get_executor

ex = get_executor()
print(f'Tables built in {time.time() - t0:.1f}s')
print('Table stats:', ex.table_stats)

print()
print('Testing real query comparison (Task 1)...')
from tasks import TASKS

task = TASKS['task_1_basic_antipatterns']
original = task['sql_query']
optimized = (
    "SELECT id, customer_id, status, total, created_at FROM orders "
    "WHERE customer_id = 5000 "
    "AND created_at >= DATE '2024-01-01' AND created_at < DATE '2025-01-01'"
)

result = ex.compare(original, optimized)
print(f"  Original : {result['original_ms']:.1f} ms ({result['original_rows']} rows)")
print(f"  Optimized: {result['optimized_ms']:.1f} ms ({result['optimized_rows']} rows)")
print(f"  Speedup  : {result['speedup']:.2f}x")
print(f"  Correct  : {result['results_match']}")
print(f"  Verdict  : {result['verdict']}")
assert result['results_match'], 'optimized query changed the result set'
assert result['speedup'] > 1.0, 'optimized query was not actually faster'

print()
print('Testing full grader...')
from graders import grade
from models import Action

action = Action(
    suggestions=[
        {"issue_type": "select_star", "line": 1, "description": "SELECT * fetches all columns unnecessarily from 500k rows", "severity": "medium", "fix": "Select only needed columns"},
        {"issue_type": "non_sargable_cast", "line": 3, "description": "CAST on customer_id prevents columnar pruning", "severity": "high", "fix": "Use direct integer comparison"},
        {"issue_type": "function_on_date_column", "line": 4, "description": "year() on created_at forces full column evaluation", "severity": "high", "fix": "Use date range with BETWEEN"},
    ],
    optimized_query=optimized,
    summary="Three anti-patterns identified: SELECT * wastes bandwidth, CAST and year() prevent DuckDB zone-map pruning causing full 500k row scans.",
    estimated_improvement="5-10x faster by enabling columnar pruning and reducing I/O",
    approved=False,
)
reward = grade(task, action)
print(f"  Score    : {reward.score}")
print(f"  Breakdown: {reward.breakdown}")
print(reward.feedback[:300])
print()
print('ALL TESTS PASSED!')