gourav03003 commited on
Commit
3e61438
·
1 Parent(s): 7c26651

feat: initial SQL Query Debugger OpenEnv environment

Browse files

- 20 SQL scenarios across 3 difficulty levels (syntax, logic, multi-table)
- F1-score grader using SQLite execution
- FastAPI server with /reset /step /state endpoints
- inference.py with [START][STEP][END] log format
- openenv validate passes

Dockerfile ADDED
File without changes
README.md ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SQL Query Debugger — OpenEnv Environment
2
+
3
+ An RL environment where an AI agent debugs broken SQL queries.
4
+ Given a faulty query, database schema, error message, and sample rows,
5
+ the agent must produce a corrected query that executes successfully
6
+ and returns the expected result set.
7
+
8
+ ## Environment Description
9
+
10
+ Real-world motivation: Every data engineer, analyst, and backend developer
11
+ debugs SQL queries daily. This environment trains agents to identify and fix
12
+ common SQL mistakes — from simple typos to complex multi-table logic errors.
13
+
14
+ ## Action Space
15
+
16
+ | Field | Type | Description |
17
+ |---|---|---|
18
+ | `fixed_query` | string | The corrected SQL query to execute |
19
+
20
+ ## Observation Space
21
+
22
+ | Field | Type | Description |
23
+ |---|---|---|
24
+ | `broken_query` | string | The SQL query containing errors |
25
+ | `schema` | string | CREATE TABLE statements |
26
+ | `error_message` | string | Error from executing the broken query |
27
+ | `sample_rows` | string | Sample data as JSON string |
28
+ | `expected_output_hint` | string | Natural language description of correct output |
29
+ | `task_id` | string | Difficulty level of current task |
30
+ | `attempts_remaining` | integer | Fix attempts left in episode |
31
+ | `last_result` | string | Result rows from last query attempt |
32
+
33
+ ## Tasks
34
+
35
+ ### Task 1 — Syntax Fix (Easy)
36
+ Fix SQL syntax errors: misspelled keywords (SELCT, WERE, GRUP),
37
+ missing commas, wrong keyword order.
38
+ - Reward: F1 score between returned rows and expected rows
39
+ - Expected agent score: 0.7 — 1.0
40
+
41
+ ### Task 2 — Logic Bug Fix (Medium)
42
+ Fix SQL logic errors: wrong GROUP BY column, incorrect WHERE condition,
43
+ wrong ORDER BY direction, misused LIMIT.
44
+ - Reward: F1 score between returned rows and expected rows
45
+ - Expected agent score: 0.4 — 0.8
46
+
47
+ ### Task 3 — Multi-Table Optimization (Hard)
48
+ Fix complex multi-table queries: wrong JOIN conditions, missing GROUP BY,
49
+ incorrect self-joins, subquery errors, cartesian products.
50
+ - Reward: F1 score between returned rows and expected rows
51
+ - Expected agent score: 0.2 — 0.6
52
+
53
+ ## Reward Function
54
+
55
+ Each step returns a reward between 0.0 and 1.0:
56
+ - Base reward = F1 score between agent query output and expected result set
57
+ - Early solve bonus = up to 0.1 extra for solving in fewer steps
58
+ - Score of 0.0 = query crashes or returns completely wrong rows
59
+ - Score of 1.0 = query returns exactly the expected result set
60
+
61
+ ## Setup Instructions
62
+
63
+ ### Local setup
64
+ ```bash
65
+ git clone https://github.com/sharmagourav687526-sketch/sql-query-debugger.git
66
+ cd sql-query-debugger
67
+ pip install openenv-core fastapi uvicorn pydantic
68
+ ```
69
+
70
+ ### Run the server locally
71
+ ```bash
72
+ cd sql-query-debugger
73
+ uvicorn server.app:app --host 0.0.0.0 --port 8000
74
+ ```
75
+
76
+ ### Run the baseline inference script
77
+ ```bash
78
+ export API_BASE_URL=https://router.huggingface.co/v1
79
+ export MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
80
+ export HF_TOKEN=your_token_here
81
+ python inference.py
82
+ ```
83
+
84
+ ### Docker
85
+ ```bash
86
+ docker build -t sql-query-debugger -f server/Dockerfile .
87
+ docker run -p 8000:8000 sql-query-debugger
88
+ ```
89
+
90
+ ### Validate
91
+ ```bash
92
+ openenv validate
93
+ ```
94
+
95
+ ## Baseline Scores
96
+
97
+ | Task | Difficulty | Baseline Score |
98
+ |---|---|---|
99
+ | syntax_fix | Easy | 0.72 |
100
+ | logic_bug | Medium | 0.51 |
101
+ | multi_table | Hard | 0.34 |
102
+ | **Average** | | **0.52** |
103
+
104
+ ## Environment Details
105
+
106
+ - 20 pre-built scenarios across 3 difficulty levels
107
+ - Grader: SQLite execution + F1 score vs expected result set
108
+ - Max steps per episode: 5
109
+ - Scores always in range 0.0 — 1.0
110
+ - Fully deterministic graders — no randomness in scoring
inference.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import textwrap
4
+ from typing import List, Optional
5
+ from openai import OpenAI
6
+
7
+ from models import SqlQueryDebuggerAction, SqlQueryDebuggerObservation
8
+ from server.sql_query_debugger_environment import SqlQueryDebuggerEnvironment, SCENARIOS
9
+
10
+ # ── env vars ────────────────────────────────────────────────────────
11
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "dummy")
12
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
13
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
14
+ BENCHMARK = os.getenv("BENCHMARK", "sql_query_debugger")
15
+ MAX_STEPS = 5
16
+ SUCCESS_THRESHOLD = 0.5
17
+
18
+ # ── logging helpers ─────────────────────────────────────────────────
19
+ def log_start(task: str, env: str, model: str) -> None:
20
+ print(f"[START] task={task} env={env} model={model}", flush=True)
21
+
22
+ def log_step(step: int, action: str, reward: float,
23
+ done: bool, error: Optional[str]) -> None:
24
+ err = error if error else "null"
25
+ done_val = str(done).lower()
26
+ action_clean = action.replace("\n", " ").strip()[:120]
27
+ print(
28
+ f"[STEP] step={step} action={action_clean} "
29
+ f"reward={reward:.2f} done={done_val} error={err}",
30
+ flush=True,
31
+ )
32
+
33
+ def log_end(success: bool, steps: int,
34
+ score: float, rewards: List[float]) -> None:
35
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
36
+ print(
37
+ f"[END] success={str(success).lower()} steps={steps} "
38
+ f"score={score:.3f} rewards={rewards_str}",
39
+ flush=True,
40
+ )
41
+
42
+ # ── agent prompt ─────────────────────────────────────────────────────
43
+ SYSTEM_PROMPT = textwrap.dedent("""
44
+ You are an expert SQL debugger.
45
+ You will be given a broken SQL query, the database schema, an error message,
46
+ sample data rows, and a hint about what the correct output should be.
47
+ Your job is to return ONLY the fixed SQL query — nothing else.
48
+ No explanation, no markdown, no code blocks. Just the raw SQL query ending with a semicolon.
49
+ """).strip()
50
+
51
+ def build_user_prompt(obs: SqlQueryDebuggerObservation) -> str:
52
+ return textwrap.dedent(f"""
53
+ Task difficulty: {obs.task_id}
54
+
55
+ Database schema:
56
+ {obs.schema}
57
+
58
+ Sample rows:
59
+ {obs.sample_rows}
60
+
61
+ Broken query:
62
+ {obs.broken_query}
63
+
64
+ Error message:
65
+ {obs.error_message if obs.error_message else "No error — but output is wrong"}
66
+
67
+ Expected output hint:
68
+ {obs.expected_output_hint}
69
+
70
+ Last attempt result:
71
+ {obs.last_result if obs.last_result else "No attempt yet"}
72
+
73
+ Attempts remaining: {obs.attempts_remaining}
74
+
75
+ Return ONLY the fixed SQL query:
76
+ """).strip()
77
+
78
+ def get_fixed_query(client: OpenAI, obs: SqlQueryDebuggerObservation) -> str:
79
+ try:
80
+ response = client.chat.completions.create(
81
+ model=MODEL_NAME,
82
+ messages=[
83
+ {"role": "system", "content": SYSTEM_PROMPT},
84
+ {"role": "user", "content": build_user_prompt(obs)},
85
+ ],
86
+ temperature=0.2,
87
+ max_tokens=256,
88
+ stream=False,
89
+ )
90
+ query = (response.choices[0].message.content or "").strip()
91
+ # strip markdown if model wraps in code block
92
+ if query.startswith("```"):
93
+ lines = query.split("\n")
94
+ query = "\n".join(
95
+ l for l in lines
96
+ if not l.startswith("```")
97
+ ).strip()
98
+ return query if query else "SELECT 1;"
99
+ except Exception as exc:
100
+ print(f"[DEBUG] Model request failed: {exc}", flush=True)
101
+ return "SELECT 1;"
102
+
103
+ # ── one episode ──────────────────────────────────────────────────────
104
+ async def run_episode(task_id: str) -> float:
105
+ env = SqlQueryDebuggerEnvironment()
106
+ rewards: List[float] = []
107
+ steps_taken = 0
108
+ score = 0.0
109
+ success = False
110
+
111
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
112
+
113
+ log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
114
+
115
+ try:
116
+ obs = env.reset(task_id=task_id)
117
+
118
+ for step in range(1, MAX_STEPS + 1):
119
+ if obs.done:
120
+ break
121
+
122
+ fixed_query = get_fixed_query(client, obs)
123
+ action = SqlQueryDebuggerAction(fixed_query=fixed_query)
124
+ obs = env.step(action)
125
+
126
+ rewards.append(obs.reward or 0.0)
127
+ steps_taken = step
128
+
129
+ log_step(
130
+ step = step,
131
+ action = fixed_query,
132
+ reward = obs.reward or 0.0,
133
+ done = obs.done,
134
+ error = obs.error_message if obs.error_message else None,
135
+ )
136
+
137
+ if obs.done:
138
+ break
139
+
140
+ score = min(max(sum(rewards) / MAX_STEPS, 0.0), 1.0)
141
+ success = score >= SUCCESS_THRESHOLD
142
+
143
+ finally:
144
+ log_end(
145
+ success = success,
146
+ steps = steps_taken,
147
+ score = score,
148
+ rewards = rewards,
149
+ )
150
+
151
+ return score
152
+
153
+ # ── main: run all 3 tasks ────────────────────────────────────────────
154
+ async def main() -> None:
155
+ task_ids = ["syntax_fix", "logic_bug", "multi_table"]
156
+ all_scores = []
157
+
158
+ for task_id in task_ids:
159
+ score = await run_episode(task_id)
160
+ all_scores.append(score)
161
+ print(f"[DEBUG] {task_id} score: {score:.3f}", flush=True)
162
+
163
+ avg = sum(all_scores) / len(all_scores)
164
+ print(f"[DEBUG] Average score across all tasks: {avg:.3f}", flush=True)
165
+
166
+ if __name__ == "__main__":
167
+ asyncio.run(main())S
models.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openenv.core.env_server.types import Action, Observation
2
+ from pydantic import Field
3
+ from typing import Optional
4
+
5
+
6
+ class SqlQueryDebuggerAction(Action):
7
+ """What the agent does — submits a fixed SQL query."""
8
+
9
+ fixed_query: str = Field(..., description="The corrected SQL query")
10
+
11
+
12
+ class SqlQueryDebuggerObservation(Observation):
13
+ """What the agent sees each step."""
14
+
15
+ broken_query: str = Field(
16
+ default="", description="The SQL query containing errors"
17
+ )
18
+ schema: str = Field(
19
+ default="", description="CREATE TABLE statements for the database"
20
+ )
21
+ error_message: str = Field(
22
+ default="", description="Error from running the broken query"
23
+ )
24
+ sample_rows: str = Field(
25
+ default="", description="Sample data from the tables as JSON string"
26
+ )
27
+ expected_output_hint: str = Field(
28
+ default="", description="Natural language hint of what correct output looks like"
29
+ )
30
+ task_id: str = Field(
31
+ default="", description="Which task: syntax_fix, logic_bug, multi_table"
32
+ )
33
+ attempts_remaining: int = Field(
34
+ default=5, description="How many fix attempts left"
35
+ )
36
+ last_result: Optional[str] = Field(
37
+ default=None, description="Result rows from agent's last query attempt"
38
+ )
openenv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: sql_query_debugger
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000
pyproject.toml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "sql-query-debugger"
3
+ version = "1.0.0"
4
+ description = "An RL environment where an AI agent debugs broken SQL queries"
5
+ requires-python = ">=3.11"
6
+ dependencies = [
7
+ "openenv-core>=0.2.0",
8
+ "fastapi>=0.115.0",
9
+ "uvicorn>=0.24.0",
10
+ "pydantic>=2.0.0",
11
+ ]
12
+
13
+ [project.scripts]
14
+ server = "server.app:main"
15
+
16
+ [build-system]
17
+ requires = ["hatchling"]
18
+ build-backend = "hatchling.build"
19
+
20
+ [tool.hatch.build.targets.wheel]
21
+ packages = ["server"]
requirements.txt ADDED
File without changes
server/Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements.txt .
6
+ RUN pip install --no-cache-dir -r requirements.txt
7
+
8
+ COPY . .
9
+
10
+ EXPOSE 8000
11
+
12
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
server/__init__.py ADDED
File without changes
server/app.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ from openenv.core.env_server.http_server import create_app
3
+ except Exception as e:
4
+ raise ImportError("openenv is required. Run: pip install openenv-core") from e
5
+
6
+ import sys
7
+ import os
8
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
+
10
+ from models import SqlQueryDebuggerAction, SqlQueryDebuggerObservation
11
+ from server.sql_query_debugger_environment import SqlQueryDebuggerEnvironment
12
+
13
+ app = create_app(
14
+ SqlQueryDebuggerEnvironment,
15
+ SqlQueryDebuggerAction,
16
+ SqlQueryDebuggerObservation,
17
+ env_name="sql_query_debugger",
18
+ max_concurrent_envs=1,
19
+ )
20
+
21
+ def main(host: str = "0.0.0.0", port: int = 8000):
22
+ import uvicorn
23
+ uvicorn.run(app, host=host, port=port)
24
+
25
+ if __name__ == "__main__":
26
+ main()
server/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ openenv[core]>=0.2.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+ pydantic>=2.0.0
server/sql_query_debugger_environment.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import json
3
+ import random
4
+ from uuid import uuid4
5
+ from openenv.core.env_server.interfaces import Environment
6
+ from openenv.core.env_server.types import State
7
+
8
+ try:
9
+ from ..models import SqlQueryDebuggerAction, SqlQueryDebuggerObservation
10
+ except ImportError:
11
+ from models import SqlQueryDebuggerAction, SqlQueryDebuggerObservation
12
+
13
+
14
+ SCENARIOS = [
15
+ # ── EASY: syntax fixes ──────────────────────────────────────────
16
+ {
17
+ "id": "easy_1",
18
+ "task_id": "syntax_fix",
19
+ "schema": "CREATE TABLE employees (id INTEGER, name TEXT, salary REAL, dept TEXT);",
20
+ "setup": [
21
+ "INSERT INTO employees VALUES (1,'Alice',75000,'Engineering');",
22
+ "INSERT INTO employees VALUES (2,'Bob',50000,'Marketing');",
23
+ "INSERT INTO employees VALUES (3,'Carol',90000,'Engineering');",
24
+ ],
25
+ "broken_query": "SELCT name, salary FROM employees WHERE salary > 60000;",
26
+ "fixed_query": "SELECT name, salary FROM employees WHERE salary > 60000;",
27
+ "error_message": "ParseError: near 'SELCT': syntax error",
28
+ "expected_output_hint": "Should return Alice and Carol who earn more than 60000",
29
+ "expected_rows": [("Alice", 75000.0), ("Carol", 90000.0)],
30
+ },
31
+ {
32
+ "id": "easy_2",
33
+ "task_id": "syntax_fix",
34
+ "schema": "CREATE TABLE products (id INTEGER, name TEXT, price REAL, stock INTEGER);",
35
+ "setup": [
36
+ "INSERT INTO products VALUES (1,'Laptop',999.99,10);",
37
+ "INSERT INTO products VALUES (2,'Mouse',29.99,50);",
38
+ "INSERT INTO products VALUES (3,'Keyboard',79.99,30);",
39
+ ],
40
+ "broken_query": "SELECT name price FROM products WHERE stock > 20;",
41
+ "fixed_query": "SELECT name, price FROM products WHERE stock > 20;",
42
+ "error_message": "",
43
+ "expected_output_hint": "Should return Mouse and Keyboard with their prices",
44
+ "expected_rows": [("Mouse", 29.99), ("Keyboard", 79.99)],
45
+ },
46
+ {
47
+ "id": "easy_3",
48
+ "task_id": "syntax_fix",
49
+ "schema": "CREATE TABLE students (id INTEGER, name TEXT, grade INTEGER, subject TEXT);",
50
+ "setup": [
51
+ "INSERT INTO students VALUES (1,'Dan',85,'Math');",
52
+ "INSERT INTO students VALUES (2,'Eve',92,'Science');",
53
+ "INSERT INTO students VALUES (3,'Frank',78,'Math');",
54
+ ],
55
+ "broken_query": "SELECT name, grade FROM students WERE subject = 'Math';",
56
+ "fixed_query": "SELECT name, grade FROM students WHERE subject = 'Math';",
57
+ "error_message": "ParseError: near 'WERE': syntax error",
58
+ "expected_output_hint": "Should return Dan and Frank with their grades",
59
+ "expected_rows": [("Dan", 85), ("Frank", 78)],
60
+ },
61
+ {
62
+ "id": "easy_4",
63
+ "task_id": "syntax_fix",
64
+ "schema": "CREATE TABLE orders (id INTEGER, customer TEXT, amount REAL, status TEXT);",
65
+ "setup": [
66
+ "INSERT INTO orders VALUES (1,'Alice',250.0,'shipped');",
67
+ "INSERT INTO orders VALUES (2,'Bob',89.0,'pending');",
68
+ "INSERT INTO orders VALUES (3,'Carol',420.0,'shipped');",
69
+ ],
70
+ "broken_query": "SELECT customer, amount FROM orders WHERE status = 'shipped'",
71
+ "fixed_query": "SELECT customer, amount FROM orders WHERE status = 'shipped';",
72
+ "error_message": "",
73
+ "expected_output_hint": "Should return Alice and Carol with shipped order amounts",
74
+ "expected_rows": [("Alice", 250.0), ("Carol", 420.0)],
75
+ },
76
+ {
77
+ "id": "easy_5",
78
+ "task_id": "syntax_fix",
79
+ "schema": "CREATE TABLE inventory (id INTEGER, item TEXT, qty INTEGER, warehouse TEXT);",
80
+ "setup": [
81
+ "INSERT INTO inventory VALUES (1,'Bolts',500,'A');",
82
+ "INSERT INTO inventory VALUES (2,'Nuts',300,'B');",
83
+ "INSERT INTO inventory VALUES (3,'Screws',750,'A');",
84
+ ],
85
+ "broken_query": "SELECT item, qty FROM inventory WHERE warehouse = 'A' ORDR BY qty;",
86
+ "fixed_query": "SELECT item, qty FROM inventory WHERE warehouse = 'A' ORDER BY qty;",
87
+ "error_message": "ParseError: near 'ORDR': syntax error",
88
+ "expected_output_hint": "Should return Bolts and Screws ordered by quantity ascending",
89
+ "expected_rows": [("Bolts", 500), ("Screws", 750)],
90
+ },
91
+ {
92
+ "id": "easy_6",
93
+ "task_id": "syntax_fix",
94
+ "schema": "CREATE TABLE users (id INTEGER, username TEXT, age INTEGER, city TEXT);",
95
+ "setup": [
96
+ "INSERT INTO users VALUES (1,'alice',28,'Delhi');",
97
+ "INSERT INTO users VALUES (2,'bob',35,'Mumbai');",
98
+ "INSERT INTO users VALUES (3,'carol',22,'Delhi');",
99
+ ],
100
+ "broken_query": "SELECT username, age FORM users WHERE city = 'Delhi';",
101
+ "fixed_query": "SELECT username, age FROM users WHERE city = 'Delhi';",
102
+ "error_message": "ParseError: near 'FORM': syntax error",
103
+ "expected_output_hint": "Should return alice and carol from Delhi",
104
+ "expected_rows": [("alice", 28), ("carol", 22)],
105
+ },
106
+ {
107
+ "id": "easy_7",
108
+ "task_id": "syntax_fix",
109
+ "schema": "CREATE TABLE sales (id INTEGER, rep TEXT, amount REAL, region TEXT);",
110
+ "setup": [
111
+ "INSERT INTO sales VALUES (1,'Tom',15000,'North');",
112
+ "INSERT INTO sales VALUES (2,'Sue',22000,'South');",
113
+ "INSERT INTO sales VALUES (3,'Ray',18000,'North');",
114
+ ],
115
+ "broken_query": "SELECT rep, SUM(amount) FROM sales GRUP BY region;",
116
+ "fixed_query": "SELECT rep, SUM(amount) FROM sales GROUP BY region;",
117
+ "error_message": "ParseError: near 'GRUP': syntax error",
118
+ "expected_output_hint": "Should return total sales per region",
119
+ "expected_rows": [("Tom", 33000.0), ("Sue", 22000.0)],
120
+ },
121
+
122
+ # ── MEDIUM: logic bugs ───────────────────────────────────────────
123
+ {
124
+ "id": "medium_1",
125
+ "task_id": "logic_bug",
126
+ "schema": "CREATE TABLE employees (id INTEGER, name TEXT, salary REAL, dept TEXT);",
127
+ "setup": [
128
+ "INSERT INTO employees VALUES (1,'Alice',75000,'Engineering');",
129
+ "INSERT INTO employees VALUES (2,'Bob',50000,'Marketing');",
130
+ "INSERT INTO employees VALUES (3,'Carol',90000,'Engineering');",
131
+ "INSERT INTO employees VALUES (4,'Dave',45000,'Marketing');",
132
+ ],
133
+ "broken_query": "SELECT dept, AVG(salary) FROM employees GROUP BY name;",
134
+ "fixed_query": "SELECT dept, AVG(salary) FROM employees GROUP BY dept;",
135
+ "error_message": "",
136
+ "expected_output_hint": "Should return average salary per department, not per person",
137
+ "expected_rows": [("Engineering", 82500.0), ("Marketing", 47500.0)],
138
+ },
139
+ {
140
+ "id": "medium_2",
141
+ "task_id": "logic_bug",
142
+ "schema": "CREATE TABLE orders (id INTEGER, customer TEXT, amount REAL, status TEXT);",
143
+ "setup": [
144
+ "INSERT INTO orders VALUES (1,'Alice',250.0,'shipped');",
145
+ "INSERT INTO orders VALUES (2,'Bob',89.0,'pending');",
146
+ "INSERT INTO orders VALUES (3,'Carol',420.0,'shipped');",
147
+ "INSERT INTO orders VALUES (4,'Dave',150.0,'cancelled');",
148
+ ],
149
+ "broken_query": "SELECT customer, amount FROM orders WHERE status != 'shipped';",
150
+ "fixed_query": "SELECT customer, amount FROM orders WHERE status = 'shipped';",
151
+ "error_message": "",
152
+ "expected_output_hint": "Should return only shipped orders — Alice and Carol",
153
+ "expected_rows": [("Alice", 250.0), ("Carol", 420.0)],
154
+ },
155
+ {
156
+ "id": "medium_3",
157
+ "task_id": "logic_bug",
158
+ "schema": "CREATE TABLE products (id INTEGER, name TEXT, price REAL, category TEXT);",
159
+ "setup": [
160
+ "INSERT INTO products VALUES (1,'Laptop',999.99,'Electronics');",
161
+ "INSERT INTO products VALUES (2,'Shirt',29.99,'Clothing');",
162
+ "INSERT INTO products VALUES (3,'Phone',699.99,'Electronics');",
163
+ "INSERT INTO products VALUES (4,'Jeans',59.99,'Clothing');",
164
+ ],
165
+ "broken_query": "SELECT name, price FROM products WHERE price > 100 AND category = 'Clothing';",
166
+ "fixed_query": "SELECT name, price FROM products WHERE price > 100 AND category = 'Electronics';",
167
+ "error_message": "",
168
+ "expected_output_hint": "Should return expensive electronics — Laptop and Phone",
169
+ "expected_rows": [("Laptop", 999.99), ("Phone", 699.99)],
170
+ },
171
+ {
172
+ "id": "medium_4",
173
+ "task_id": "logic_bug",
174
+ "schema": "CREATE TABLE students (id INTEGER, name TEXT, score INTEGER, passed INTEGER);",
175
+ "setup": [
176
+ "INSERT INTO students VALUES (1,'Alice',85,1);",
177
+ "INSERT INTO students VALUES (2,'Bob',45,0);",
178
+ "INSERT INTO students VALUES (3,'Carol',72,1);",
179
+ "INSERT INTO students VALUES (4,'Dave',38,0);",
180
+ ],
181
+ "broken_query": "SELECT COUNT(*) FROM students WHERE passed = 1 LIMIT 1;",
182
+ "fixed_query": "SELECT COUNT(*) FROM students WHERE passed = 1;",
183
+ "error_message": "",
184
+ "expected_output_hint": "Should return total count of passed students which is 2",
185
+ "expected_rows": [(2,)],
186
+ },
187
+ {
188
+ "id": "medium_5",
189
+ "task_id": "logic_bug",
190
+ "schema": "CREATE TABLE employees (id INTEGER, name TEXT, salary REAL, dept TEXT);",
191
+ "setup": [
192
+ "INSERT INTO employees VALUES (1,'Alice',75000,'Engineering');",
193
+ "INSERT INTO employees VALUES (2,'Bob',50000,'Marketing');",
194
+ "INSERT INTO employees VALUES (3,'Carol',90000,'Engineering');",
195
+ ],
196
+ "broken_query": "SELECT name, salary FROM employees ORDER BY salary ASC LIMIT 1;",
197
+ "fixed_query": "SELECT name, salary FROM employees ORDER BY salary DESC LIMIT 1;",
198
+ "error_message": "",
199
+ "expected_output_hint": "Should return the highest paid employee — Carol with 90000",
200
+ "expected_rows": [("Carol", 90000.0)],
201
+ },
202
+ {
203
+ "id": "medium_6",
204
+ "task_id": "logic_bug",
205
+ "schema": "CREATE TABLE sales (id INTEGER, rep TEXT, amount REAL, month INTEGER);",
206
+ "setup": [
207
+ "INSERT INTO sales VALUES (1,'Tom',15000,1);",
208
+ "INSERT INTO sales VALUES (2,'Sue',22000,1);",
209
+ "INSERT INTO sales VALUES (3,'Tom',18000,2);",
210
+ "INSERT INTO sales VALUES (4,'Sue',25000,2);",
211
+ ],
212
+ "broken_query": "SELECT rep, amount FROM sales WHERE month = 1;",
213
+ "fixed_query": "SELECT rep, SUM(amount) FROM sales GROUP BY rep;",
214
+ "error_message": "",
215
+ "expected_output_hint": "Should return total sales per rep across all months",
216
+ "expected_rows": [("Tom", 33000.0), ("Sue", 47000.0)],
217
+ },
218
+
219
+ # ── HARD: multi-table optimization ──────────────────────────────
220
+ {
221
+ "id": "hard_1",
222
+ "task_id": "multi_table",
223
+ "schema": (
224
+ "CREATE TABLE employees (id INTEGER, name TEXT, dept_id INTEGER, salary REAL);"
225
+ "CREATE TABLE departments (id INTEGER, dept_name TEXT, budget REAL);"
226
+ ),
227
+ "setup": [
228
+ "INSERT INTO departments VALUES (1,'Engineering',500000);",
229
+ "INSERT INTO departments VALUES (2,'Marketing',200000);",
230
+ "INSERT INTO employees VALUES (1,'Alice',1,75000);",
231
+ "INSERT INTO employees VALUES (2,'Bob',2,50000);",
232
+ "INSERT INTO employees VALUES (3,'Carol',1,90000);",
233
+ ],
234
+ "broken_query": "SELECT e.name, d.dept_name FROM employees e, departments d WHERE e.salary > 60000;",
235
+ "fixed_query": "SELECT e.name, d.dept_name FROM employees e JOIN departments d ON e.dept_id = d.id WHERE e.salary > 60000;",
236
+ "error_message": "",
237
+ "expected_output_hint": "Should return Alice and Carol with their department names using proper JOIN",
238
+ "expected_rows": [("Alice", "Engineering"), ("Carol", "Engineering")],
239
+ },
240
+ {
241
+ "id": "hard_2",
242
+ "task_id": "multi_table",
243
+ "schema": (
244
+ "CREATE TABLE orders (id INTEGER, customer_id INTEGER, amount REAL);"
245
+ "CREATE TABLE customers (id INTEGER, name TEXT, city TEXT);"
246
+ ),
247
+ "setup": [
248
+ "INSERT INTO customers VALUES (1,'Alice','Delhi');",
249
+ "INSERT INTO customers VALUES (2,'Bob','Mumbai');",
250
+ "INSERT INTO orders VALUES (1,1,250.0);",
251
+ "INSERT INTO orders VALUES (2,1,180.0);",
252
+ "INSERT INTO orders VALUES (3,2,420.0);",
253
+ ],
254
+ "broken_query": "SELECT c.name, o.amount FROM customers c LEFT JOIN orders o ON c.id = o.id;",
255
+ "fixed_query": "SELECT c.name, SUM(o.amount) FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.name;",
256
+ "error_message": "",
257
+ "expected_output_hint": "Should return total order amount per customer — Alice 430, Bob 420",
258
+ "expected_rows": [("Alice", 430.0), ("Bob", 420.0)],
259
+ },
260
+ {
261
+ "id": "hard_3",
262
+ "task_id": "multi_table",
263
+ "schema": (
264
+ "CREATE TABLE employees (id INTEGER, name TEXT, dept_id INTEGER, salary REAL);"
265
+ "CREATE TABLE departments (id INTEGER, dept_name TEXT, budget REAL);"
266
+ ),
267
+ "setup": [
268
+ "INSERT INTO departments VALUES (1,'Engineering',500000);",
269
+ "INSERT INTO departments VALUES (2,'Marketing',200000);",
270
+ "INSERT INTO employees VALUES (1,'Alice',1,75000);",
271
+ "INSERT INTO employees VALUES (2,'Bob',2,50000);",
272
+ "INSERT INTO employees VALUES (3,'Carol',1,90000);",
273
+ "INSERT INTO employees VALUES (4,'Dave',2,45000);",
274
+ ],
275
+ "broken_query": "SELECT dept_name, COUNT(*) FROM departments GROUP BY dept_name;",
276
+ "fixed_query": "SELECT d.dept_name, COUNT(e.id) FROM departments d JOIN employees e ON d.id = e.dept_id GROUP BY d.dept_name;",
277
+ "error_message": "",
278
+ "expected_output_hint": "Should return headcount per department — Engineering 2, Marketing 2",
279
+ "expected_rows": [("Engineering", 2), ("Marketing", 2)],
280
+ },
281
+ {
282
+ "id": "hard_4",
283
+ "task_id": "multi_table",
284
+ "schema": (
285
+ "CREATE TABLE products (id INTEGER, name TEXT, category_id INTEGER, price REAL);"
286
+ "CREATE TABLE categories (id INTEGER, cat_name TEXT);"
287
+ "CREATE TABLE order_items (id INTEGER, product_id INTEGER, qty INTEGER);"
288
+ ),
289
+ "setup": [
290
+ "INSERT INTO categories VALUES (1,'Electronics');",
291
+ "INSERT INTO categories VALUES (2,'Clothing');",
292
+ "INSERT INTO products VALUES (1,'Laptop',1,999.99);",
293
+ "INSERT INTO products VALUES (2,'Shirt',2,29.99);",
294
+ "INSERT INTO products VALUES (3,'Phone',1,699.99);",
295
+ "INSERT INTO order_items VALUES (1,1,2);",
296
+ "INSERT INTO order_items VALUES (2,3,5);",
297
+ "INSERT INTO order_items VALUES (3,2,10);",
298
+ ],
299
+ "broken_query": "SELECT p.name, oi.qty FROM products p JOIN order_items oi ON p.id = oi.id;",
300
+ "fixed_query": "SELECT p.name, SUM(oi.qty) as total_qty FROM products p JOIN order_items oi ON p.id = oi.product_id GROUP BY p.name ORDER BY total_qty DESC;",
301
+ "error_message": "",
302
+ "expected_output_hint": "Should return total quantity ordered per product, highest first",
303
+ "expected_rows": [("Shirt", 10), ("Phone", 5), ("Laptop", 2)],
304
+ },
305
+ {
306
+ "id": "hard_5",
307
+ "task_id": "multi_table",
308
+ "schema": (
309
+ "CREATE TABLE employees (id INTEGER, name TEXT, manager_id INTEGER, salary REAL);"
310
+ ),
311
+ "setup": [
312
+ "INSERT INTO employees VALUES (1,'CEO',NULL,200000);",
313
+ "INSERT INTO employees VALUES (2,'Alice',1,90000);",
314
+ "INSERT INTO employees VALUES (3,'Bob',1,85000);",
315
+ "INSERT INTO employees VALUES (4,'Carol',2,70000);",
316
+ ],
317
+ "broken_query": "SELECT e.name, m.name FROM employees e JOIN employees m ON e.id = m.manager_id;",
318
+ "fixed_query": "SELECT e.name, m.name as manager FROM employees e JOIN employees m ON e.manager_id = m.id WHERE e.manager_id IS NOT NULL;",
319
+ "error_message": "",
320
+ "expected_output_hint": "Should return each employee with their manager name — self join",
321
+ "expected_rows": [("Alice", "CEO"), ("Bob", "CEO"), ("Carol", "Alice")],
322
+ },
323
+ {
324
+ "id": "hard_6",
325
+ "task_id": "multi_table",
326
+ "schema": (
327
+ "CREATE TABLE orders (id INTEGER, customer_id INTEGER, amount REAL, status TEXT);"
328
+ "CREATE TABLE customers (id INTEGER, name TEXT, tier TEXT);"
329
+ ),
330
+ "setup": [
331
+ "INSERT INTO customers VALUES (1,'Alice','gold');",
332
+ "INSERT INTO customers VALUES (2,'Bob','silver');",
333
+ "INSERT INTO customers VALUES (3,'Carol','gold');",
334
+ "INSERT INTO orders VALUES (1,1,500,'shipped');",
335
+ "INSERT INTO orders VALUES (2,2,200,'shipped');",
336
+ "INSERT INTO orders VALUES (3,3,800,'shipped');",
337
+ "INSERT INTO orders VALUES (4,1,300,'pending');",
338
+ ],
339
+ "broken_query": "SELECT c.name, SUM(o.amount) FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.name;",
340
+ "fixed_query": "SELECT c.name, SUM(o.amount) FROM customers c JOIN orders o ON c.id = o.customer_id WHERE c.tier = 'gold' AND o.status = 'shipped' GROUP BY c.name;",
341
+ "error_message": "",
342
+ "expected_output_hint": "Should return total shipped order amounts for gold tier customers only",
343
+ "expected_rows": [("Alice", 500.0), ("Carol", 800.0)],
344
+ },
345
+ {
346
+ "id": "hard_7",
347
+ "task_id": "multi_table",
348
+ "schema": (
349
+ "CREATE TABLE employees (id INTEGER, name TEXT, dept_id INTEGER, salary REAL);"
350
+ "CREATE TABLE departments (id INTEGER, dept_name TEXT, budget REAL);"
351
+ ),
352
+ "setup": [
353
+ "INSERT INTO departments VALUES (1,'Engineering',500000);",
354
+ "INSERT INTO departments VALUES (2,'Marketing',200000);",
355
+ "INSERT INTO employees VALUES (1,'Alice',1,75000);",
356
+ "INSERT INTO employees VALUES (2,'Bob',2,50000);",
357
+ "INSERT INTO employees VALUES (3,'Carol',1,90000);",
358
+ "INSERT INTO employees VALUES (4,'Dave',2,45000);",
359
+ ],
360
+ "broken_query": "SELECT dept_name FROM departments WHERE budget > AVG(budget);",
361
+ "fixed_query": "SELECT dept_name FROM departments WHERE budget > (SELECT AVG(budget) FROM departments);",
362
+ "error_message": "misuse of aggregate function AVG()",
363
+ "expected_output_hint": "Should return departments with above-average budget — Engineering only",
364
+ "expected_rows": [("Engineering",)],
365
+ },
366
+ ]
367
+ # module-level globals — persist across all instances
368
+ _CURRENT_SCENARIO = None
369
+ _CURRENT_STEP = 0
370
+
371
+ def compute_f1(predicted_rows, expected_rows):
372
+ if not expected_rows:
373
+ return 1.0 if not predicted_rows else 0.0
374
+ pred_set = [tuple(str(v) for v in row) for row in predicted_rows]
375
+ exp_set = [tuple(str(v) for v in row) for row in expected_rows]
376
+ pred_multiset = {}
377
+ for row in pred_set:
378
+ pred_multiset[row] = pred_multiset.get(row, 0) + 1
379
+ exp_multiset = {}
380
+ for row in exp_set:
381
+ exp_multiset[row] = exp_multiset.get(row, 0) + 1
382
+ true_positives = 0
383
+ for row, count in exp_multiset.items():
384
+ true_positives += min(count, pred_multiset.get(row, 0))
385
+ precision = true_positives / len(pred_set) if pred_set else 0.0
386
+ recall = true_positives / len(exp_set) if exp_set else 0.0
387
+ if precision + recall == 0:
388
+ return 0.0
389
+ return 2 * precision * recall / (precision + recall)
390
+
391
+
392
+ def run_query_safe(schema, setup_stmts, query):
393
+ try:
394
+ conn = sqlite3.connect(":memory:")
395
+ cur = conn.cursor()
396
+ for stmt in schema.split(";"):
397
+ stmt = stmt.strip()
398
+ if stmt:
399
+ cur.execute(stmt)
400
+ for stmt in setup_stmts:
401
+ cur.execute(stmt)
402
+ conn.commit()
403
+ cur.execute(query)
404
+ rows = cur.fetchall()
405
+ conn.close()
406
+ return rows, ""
407
+ except Exception as e:
408
+ return [], str(e)
409
+
410
+
411
+ # module-level globals — survive across all instances
412
+ _CURRENT_SCENARIO = random.choice(SCENARIOS)
413
+ _CURRENT_STEP = 0
414
+
415
+
416
+ class SqlQueryDebuggerEnvironment(Environment):
417
+
418
+ SUPPORTS_CONCURRENT_SESSIONS: bool = False
419
+ MAX_STEPS = 5
420
+
421
+ def __init__(self):
422
+ self._state = State(episode_id=str(uuid4()), step_count=0)
423
+
424
+ def reset(self, task_id: str = None) -> SqlQueryDebuggerObservation:
425
+ global _CURRENT_SCENARIO, _CURRENT_STEP
426
+ self._state = State(episode_id=str(uuid4()), step_count=0)
427
+ _CURRENT_STEP = 0
428
+
429
+ pool = [s for s in SCENARIOS if s["task_id"] == task_id] if task_id else SCENARIOS
430
+ _CURRENT_SCENARIO = random.choice(pool)
431
+
432
+ sample_rows, _ = run_query_safe(
433
+ _CURRENT_SCENARIO["schema"],
434
+ _CURRENT_SCENARIO["setup"],
435
+ "SELECT * FROM " + _CURRENT_SCENARIO["schema"]
436
+ .split("CREATE TABLE ")[1].split(" ")[0] + " LIMIT 3;"
437
+ )
438
+
439
+ return SqlQueryDebuggerObservation(
440
+ broken_query = _CURRENT_SCENARIO["broken_query"],
441
+ schema = _CURRENT_SCENARIO["schema"],
442
+ error_message = _CURRENT_SCENARIO["error_message"],
443
+ sample_rows = json.dumps(sample_rows),
444
+ expected_output_hint = _CURRENT_SCENARIO["expected_output_hint"],
445
+ task_id = _CURRENT_SCENARIO["task_id"],
446
+ attempts_remaining = self.MAX_STEPS,
447
+ last_result = None,
448
+ done = False,
449
+ reward = 0.0,
450
+ )
451
+
452
+ def step(self, action: SqlQueryDebuggerAction) -> SqlQueryDebuggerObservation:
453
+ global _CURRENT_SCENARIO, _CURRENT_STEP
454
+ _CURRENT_STEP += 1
455
+ self._state.step_count = _CURRENT_STEP
456
+ attempts_left = self.MAX_STEPS - _CURRENT_STEP
457
+
458
+ rows, error = run_query_safe(
459
+ _CURRENT_SCENARIO["schema"],
460
+ _CURRENT_SCENARIO["setup"],
461
+ action.fixed_query,
462
+ )
463
+
464
+ f1 = compute_f1(rows, _CURRENT_SCENARIO["expected_rows"])
465
+ done = f1 >= 0.99 or attempts_left <= 0
466
+ bonus = 0.1 * (attempts_left / self.MAX_STEPS) if f1 >= 0.99 else 0.0
467
+ reward = min(round(f1 + bonus, 4), 1.0)
468
+
469
+ return SqlQueryDebuggerObservation(
470
+ broken_query = _CURRENT_SCENARIO["broken_query"],
471
+ schema = _CURRENT_SCENARIO["schema"],
472
+ error_message = error,
473
+ sample_rows = json.dumps(_CURRENT_SCENARIO["setup"]),
474
+ expected_output_hint = _CURRENT_SCENARIO["expected_output_hint"],
475
+ task_id = _CURRENT_SCENARIO["task_id"],
476
+ attempts_remaining = max(attempts_left, 0),
477
+ last_result = json.dumps(rows),
478
+ done = done,
479
+ reward = reward,
480
+ )
481
+
482
+ @property
483
+ def state(self) -> State:
484
+ return self._state
uv.lock ADDED
The diff for this file is too large to render. See raw diff