# Author: Abhinav Singh
# feat(v2): execution-grounded rewards via DuckDB -- the key differentiator
# commit: c15d346
"""
server/app.py β€” FastAPI Server
================================
OpenEnv-compliant endpoints + two unique endpoints:
POST /execute β€” run your optimized query against real DuckDB data,
see actual speedup + result correctness instantly
GET /leaderboard β€” see best scores + speedups across all tasks
"""
import json
import os
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from env import SQLOptimEnv
from executor import get_executor
from graders import grade
from leaderboard import get_board
from models import (
Action,
EnvironmentState,
ExecutionResult,
Observation,
StepResult,
)
from tasks import TASKS, get_task_list
# ── Lifespan: pre-warm DuckDB on startup ─────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Warm the DuckDB executor before the server accepts requests.

    `get_executor()` builds all four synthetic tables eagerly so the
    first /execute call doesn't pay the setup cost.
    """
    get_executor()
    yield
# FastAPI application; the `lifespan` handler pre-warms DuckDB at startup.
app = FastAPI(
    title="SQL Query Optimization Environment",
    description=(
        "OpenEnv-compliant RL environment where AI agents learn to diagnose "
        "and optimize SQL queries. Uniquely, optimized queries are EXECUTED "
        "against real DuckDB data β€” reward is based on actual speedup + "
        "result correctness, not keyword heuristics."
    ),
    version="2.0.0",
    lifespan=lifespan,
)
# Wide-open CORS so browser-based clients (demos, dashboards) can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Single module-level environment instance shared by all request handlers.
# NOTE(review): this assumes one episode at a time across all clients.
env = SQLOptimEnv()
# ── Standard OpenEnv endpoints ────────────────────────────────────────────
@app.get("/")
def root():
    """Health/info endpoint: status, version, table stats, and task ids."""
    executor = get_executor()
    task_ids = [task["task_id"] for task in get_task_list()]
    return {
        "status": "ok",
        "environment": "sql-optim-env",
        "version": "2.0.0",
        "unique_feature": "Execution-grounded rewards via DuckDB",
        "table_stats": executor.table_stats,
        "tasks": task_ids,
    }
@app.post("/reset", response_model=Observation)
async def reset(request: Request):
    """Start a new episode.

    Body (optional JSON): {"task_id": "..."}. A missing, empty, unparseable,
    or non-dict body falls back to the default task — best-effort by design.

    Raises:
        HTTPException 400: the environment rejected the task_id.
    """
    default_task = "task_1_basic_antipatterns"
    task_id = default_task
    body = await request.body()
    if body:
        data = None
        try:
            data = json.loads(body)
        except (json.JSONDecodeError, UnicodeDecodeError):
            pass  # deliberate: malformed body just means "use the default task"
        if isinstance(data, dict):
            task_id = data.get("task_id") or default_task
    # Keep the try minimal: only env.reset can raise the ValueError we map to 400.
    try:
        return env.reset(task_id=task_id)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.post("/step", response_model=StepResult)
def step(action: Action):
    """Submit an optimization action; get real execution feedback.

    Raises:
        HTTPException 400: the environment refused the step
            (e.g. no active episode) — chained to the original error.
    """
    try:
        return env.step(action)
    except RuntimeError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.get("/state", response_model=EnvironmentState)
def state():
    """Return a snapshot of the current episode state."""
    return env.state()
@app.get("/tasks")
def tasks():
    """List every available optimization task."""
    return {"tasks": get_task_list()}
@app.post("/grader")
def grader(action: Action):
    """Score *action* for the active task without advancing the episode.

    Raises:
        HTTPException 400: no episode has been started yet.
    """
    task = env._task_data
    if task is None:
        raise HTTPException(status_code=400, detail="No active episode. Call /reset first.")
    return grade(task, action)
@app.post("/baseline")
def baseline():
    """Run the baseline inference script and return its captured output.

    Uses sys.executable (the interpreter serving this app) instead of a bare
    "python", which may be absent from PATH (e.g. systems with only python3).

    Returns:
        dict with the subprocess's stdout, stderr, and returncode.

    Raises:
        HTTPException 500: the script timed out or failed to launch.
    """
    import subprocess

    repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    try:
        result = subprocess.run(
            [sys.executable, "inference.py"],
            capture_output=True,
            text=True,
            timeout=300,  # hard cap — inference should finish well within 5 min
            cwd=repo_root,
        )
    except Exception as exc:  # boundary handler: surface any failure as a 500
        raise HTTPException(status_code=500, detail=f"Baseline failed: {exc}") from exc
    return {
        "stdout": result.stdout,
        "stderr": result.stderr,
        "returncode": result.returncode,
    }
# ── Unique endpoints (no other team has these) ────────────────────────────
@app.post("/execute", response_model=ExecutionResult)
async def execute(request: Request):
    """
    πŸš€ UNIQUE ENDPOINT β€” Execute your optimized query against real DuckDB data.
    Body:
    {
    "task_id": "task_1_basic_antipatterns",
    "optimized_query": "SELECT id, customer_id ... WHERE customer_id = 5000 ..."
    }
    Returns actual execution timing, speedup ratio, result correctness,
    and an EXPLAIN plan β€” no other OpenEnv environment does this.

    Raises:
        HTTPException 400: missing/invalid body, unknown task, or empty query.
        HTTPException 500: the comparison itself failed.
    """
    body = await request.body()
    if not body:
        raise HTTPException(status_code=400, detail="Body required: {task_id, optimized_query}")
    try:
        data = json.loads(body)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    # A valid-JSON non-object (e.g. a list) previously crashed with a 500;
    # reject it up front as a client error instead.
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON body")
    task_id = data.get("task_id", "task_1_basic_antipatterns")
    optimized_query = (data.get("optimized_query") or "").strip()
    if task_id not in TASKS:
        raise HTTPException(status_code=400, detail=f"Unknown task_id: {task_id}")
    if not optimized_query:
        raise HTTPException(status_code=400, detail="optimized_query is required")
    original_query = TASKS[task_id]["sql_query"]
    ex = get_executor()
    try:
        result = ex.compare(original_query, optimized_query)
        explain = ex.explain(optimized_query)
        return ExecutionResult(
            original_ms=result["original_ms"],
            optimized_ms=result["optimized_ms"],
            speedup=result["speedup"],
            results_match=result["results_match"],
            original_rows=result["original_rows"],
            optimized_rows=result["optimized_rows"],
            original_error=result.get("original_error"),
            optimized_error=result.get("optimized_error"),
            verdict=result["verdict"],
            explain_plan=explain,
        )
    except Exception as exc:  # boundary handler: surface executor failures as 500
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@app.get("/leaderboard")
def leaderboard():
    """
    πŸ† UNIQUE ENDPOINT β€” Real-time leaderboard of best execution scores.

    Per task: best score, best speedup achieved, total attempts, and how
    many optimized queries produced correct results.
    """
    scoring_note = (
        "Scores are based on real DuckDB execution: "
        "speedup ratio (35%) + result correctness (20%) + issue detection (25%) + other (20%)"
    )
    return {"leaderboard": get_board(), "description": scoring_note}
# ── Entry point ───────────────────────────────────────────────────────────
def main():
    """Launch the API with uvicorn.

    The port may be overridden with the PORT environment variable
    (defaults to 7860, the Hugging Face Spaces convention).
    """
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))


if __name__ == "__main__":
    main()