File size: 7,091 Bytes
c15d346
 
 
 
 
 
 
 
 
 
e8bc352
 
c15d346
 
 
 
e8bc352
 
 
 
c15d346
e8bc352
c15d346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8bc352
 
 
 
c15d346
 
 
 
e8bc352
c15d346
 
e8bc352
 
 
 
 
 
 
 
 
 
 
 
c15d346
 
e8bc352
 
c15d346
e8bc352
c15d346
e8bc352
c15d346
 
 
 
e8bc352
 
 
 
 
c15d346
e8bc352
 
 
 
 
 
 
 
 
c15d346
 
 
e8bc352
 
 
 
c15d346
e8bc352
c15d346
 
 
e8bc352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c15d346
e8bc352
 
 
 
c15d346
 
e8bc352
 
 
c15d346
 
 
 
e8bc352
 
c15d346
 
e8bc352
 
c15d346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8bc352
 
c15d346
e8bc352
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
"""
server/app.py — FastAPI Server
================================
OpenEnv-compliant endpoints + two unique endpoints:
  POST /execute    — run your optimized query against real DuckDB data,
                     see actual speedup + result correctness instantly
  GET  /leaderboard — see best scores + speedups across all tasks
"""

import json
import os
import sys
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from env import SQLOptimEnv
from executor import get_executor
from graders import grade
from leaderboard import get_board
from models import (
    Action,
    EnvironmentState,
    ExecutionResult,
    Observation,
    StepResult,
)
from tasks import TASKS, get_task_list


# ── Lifespan: pre-warm DuckDB on startup ─────────────────────────────────

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan hook passed to ``FastAPI(lifespan=...)``.

    The code before ``yield`` runs once at startup; there is no teardown
    work after it. ``app`` is accepted because the lifespan protocol
    supplies it, but it is not used here.
    """
    # Call the executor factory up front so its setup cost is paid at
    # startup rather than on the first request (the original note says this
    # builds the synthetic tables). The returned object is not needed here.
    _ = get_executor()
    yield


# Application object. The `lifespan` hook is registered here so its startup
# code runs before the first request is served.
app = FastAPI(
    title="SQL Query Optimization Environment",
    description=(
        "OpenEnv-compliant RL environment where AI agents learn to diagnose "
        "and optimize SQL queries. Uniquely, optimized queries are EXECUTED "
        "against real DuckDB data — reward is based on actual speedup + "
        "result correctness, not keyword heuristics."
    ),
    version="2.0.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): fine for a public demo; tighten `allow_origins` if this is
# ever deployed somewhere that needs origin restrictions.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# One module-level environment instance; every endpoint below that touches
# `env` operates on this same object.
env = SQLOptimEnv()


# ── Standard OpenEnv endpoints ────────────────────────────────────────────

@app.get("/")
def root():
    """Return a static status payload plus table stats and known task ids."""
    executor = get_executor()
    stats = executor.table_stats
    task_ids = [entry["task_id"] for entry in get_task_list()]
    return {
        "status": "ok",
        "environment": "sql-optim-env",
        "version": "2.0.0",
        "unique_feature": "Execution-grounded rewards via DuckDB",
        "table_stats": stats,
        "tasks": task_ids,
    }


@app.post("/reset", response_model=Observation)
async def reset(request: Request):
    """Begin a new episode on the shared environment.

    The request body is optional. When present it should be JSON of the
    form {"task_id": "..."}; a missing, empty, unparsable or non-object
    body silently falls back to the default task.

    Raises:
        HTTPException: 400 when a ``ValueError`` is raised while handling
            the request (e.g. by ``env.reset``).
    """
    fallback_task = "task_1_basic_antipatterns"
    try:
        raw = await request.body()
        chosen_task = fallback_task
        if raw:
            try:
                payload = json.loads(raw)
                # A null/empty task_id also means "use the default".
                chosen_task = payload.get("task_id", fallback_task) or fallback_task
            except Exception:
                # Deliberate best-effort: any problem reading the body
                # leaves the default task selected.
                pass
        return env.reset(task_id=chosen_task)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))


@app.post("/step", response_model=StepResult)
def step(action: Action):
    """Forward ``action`` to ``env.step`` and return its result.

    Raises:
        HTTPException: 400 carrying the message of any ``RuntimeError``
            raised by ``env.step``.
    """
    try:
        outcome = env.step(action)
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc))
    return outcome


@app.get("/state", response_model=EnvironmentState)
def state():
    """Return whatever ``env.state()`` reports for the shared environment."""
    snapshot = env.state()
    return snapshot


@app.get("/tasks")
def tasks():
    """Return the task list from ``get_task_list()`` under the "tasks" key."""
    catalog = get_task_list()
    return {"tasks": catalog}


@app.post("/grader")
def grader(action: Action):
    """Score ``action`` with ``grade`` against the environment's current task.

    Only ``grade`` is called here; ``env.step`` is not, so the episode is
    not advanced by this endpoint.

    Raises:
        HTTPException: 400 when the environment has no task data yet.
    """
    # Reads the environment's private task slot; it is None until a task
    # has been loaded.
    active_task = env._task_data
    if active_task is not None:
        return grade(active_task, action)
    raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")


@app.post("/baseline")
def baseline():
    """Run ``inference.py`` from the project root and return its output.

    The script is launched with the same interpreter that is running this
    server (``sys.executable``), so it sees the same virtualenv and
    installed packages instead of whatever ``python`` happens to be on PATH.

    Returns:
        dict with the child's ``stdout``, ``stderr`` and ``returncode``.

    Raises:
        HTTPException: 500 if the process cannot be started, exceeds the
            300-second timeout, or fails in any other way.
    """
    import subprocess

    # Project root = parent of the directory containing this file.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # NOTE(review): this endpoint starts a process that may run for up to
    # five minutes on every call — confirm it is not exposed to untrusted
    # callers.
    try:
        result = subprocess.run(
            [sys.executable, "inference.py"],
            capture_output=True,
            text=True,
            timeout=300,
            cwd=project_root,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Baseline failed: {exc}") from exc

    return {
        "stdout": result.stdout,
        "stderr": result.stderr,
        "returncode": result.returncode,
    }


# ── Unique endpoints (no other team has these) ────────────────────────────

@app.post("/execute", response_model=ExecutionResult)
async def execute(request: Request):
    """
    🚀 UNIQUE ENDPOINT — Execute your optimized query against real DuckDB data.

    Body:
      {
        "task_id": "task_1_basic_antipatterns",
        "optimized_query": "SELECT id, customer_id ... WHERE customer_id = 5000 ..."
      }

    Returns actual execution timing, speedup ratio, result correctness,
    and an EXPLAIN plan — no other OpenEnv environment does this.

    Raises:
        HTTPException: 400 for a missing body, invalid JSON, a body that is
            not a JSON object, an unknown ``task_id``, or a missing or
            non-string ``optimized_query``; 500 if comparing or explaining
            the queries fails.
    """
    body = await request.body()
    if not body:
        raise HTTPException(status_code=400, detail="Body required: {task_id, optimized_query}")
    try:
        data = json.loads(body)
    except (ValueError, RecursionError) as exc:
        # ValueError covers both JSONDecodeError and undecodable bytes.
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc

    # Valid JSON is not necessarily an object; `[]` or `"x"` would otherwise
    # fail on `.get` below and surface as a 500.
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400,
            detail="Body must be a JSON object: {task_id, optimized_query}",
        )

    task_id = data.get("task_id", "task_1_basic_antipatterns")
    raw_query = data.get("optimized_query")

    # The isinstance guard keeps unhashable values (lists, objects) from
    # raising inside the membership test.
    if not isinstance(task_id, str) or task_id not in TASKS:
        raise HTTPException(status_code=400, detail=f"Unknown task_id: {task_id}")
    # Falsy values (null, "", 0, []) fall through to the "required" error
    # below; only truthy non-strings need their own message.
    if raw_query and not isinstance(raw_query, str):
        raise HTTPException(status_code=400, detail="optimized_query must be a string")
    optimized_query = (raw_query or "").strip()
    if not optimized_query:
        raise HTTPException(status_code=400, detail="optimized_query is required")

    original_query = TASKS[task_id]["sql_query"]
    ex = get_executor()

    try:
        result = ex.compare(original_query, optimized_query)
        explain = ex.explain(optimized_query)
        return ExecutionResult(
            original_ms=result["original_ms"],
            optimized_ms=result["optimized_ms"],
            speedup=result["speedup"],
            results_match=result["results_match"],
            original_rows=result["original_rows"],
            optimized_rows=result["optimized_rows"],
            original_error=result.get("original_error"),
            optimized_error=result.get("optimized_error"),
            verdict=result["verdict"],
            explain_plan=explain,
        )
    except Exception as exc:
        # Top-level boundary: report any executor/model failure as a 500.
        raise HTTPException(status_code=500, detail=str(exc)) from exc


@app.get("/leaderboard")
def leaderboard():
    """
    🏆 UNIQUE ENDPOINT — Leaderboard of best execution scores.

    Returns the board produced by ``get_board()`` together with a short
    description of how scores are weighted.
    """
    board = get_board()
    scoring_note = (
        "Scores are based on real DuckDB execution: "
        "speedup ratio (35%) + result correctness (20%) + issue detection (25%) + other (20%)"
    )
    return {"leaderboard": board, "description": scoring_note}


# ── Entry point ───────────────────────────────────────────────────────────

def main():
    """Serve ``app`` with uvicorn on host 0.0.0.0, port 7860."""
    # Imported lazily so uvicorn is only needed when the server is launched
    # through this entry point.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)


# Start the server only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()