Commit 269f632 · 0 parent(s)
feat: SQL Repair OpenEnv submission — Phase 1 validator passes
- 3 SQL repair tasks (easy/medium/hard) with SQLite-backed env
- FastAPI server with all required endpoints (/health /tasks /reset /step /grader /baseline)
- /reset accepts empty body (Phase 1 requirement)
- inference.py: HTTP client + OpenAI-compatible LLM caller
- Strict (0,1) score clamping with NaN/inf -> 0.5 fallback
- Every task emits exactly one [START]/[END] even on crash (Phase 2 lesson)
- Sterile stdout: only bracket lines on stdout, diagnostics on stderr
- pyproject.toml + uv.lock + server/app.py:main + openenv-core>=0.2.0
- openenv validate .: PASS
- 8/8 unit tests pass
- .dockerignore +13 -0
- .gitignore +21 -0
- Dockerfile +27 -0
- README.md +99 -0
- inference.py +425 -0
- openenv.yaml +28 -0
- pyproject.toml +30 -0
- requirements.txt +7 -0
- server/__init__.py +1 -0
- server/app.py +142 -0
- sql_env/__init__.py +14 -0
- sql_env/env_core.py +174 -0
- sql_env/grader.py +64 -0
- sql_env/tasks.py +81 -0
- tests/__init__.py +0 -0
- tests/test_smoke.py +110 -0
- uv.lock +0 -0
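The "exactly one [START]/[END] even on crash" guarantee from the commit message can be pictured as a small wrapper around each task. This is a minimal sketch under stated assumptions, not the code from this commit: the helper name `run_guarded` and its `body` callback are hypothetical, and the line format is copied from the sample output in the README.

```python
import sys
from typing import Callable


def run_guarded(task_id: str, body: Callable[[], float]) -> float:
    """Emit one [START] and one [END] line for task_id, even if body raises."""
    print(f"[START] {task_id}", flush=True)
    score, status = 0.5, "ok"  # 0.5 is the in-range fallback named in the commit message
    try:
        score = body()
    except Exception as exc:
        # Diagnostics go to stderr so stdout carries only bracket lines.
        print(f"# {task_id} failed: {exc}", file=sys.stderr, flush=True)
        score, status = 0.5, "crash"
    print(f"[END] {task_id} | score={score:.4f} | status={status}", flush=True)
    return score
```

Because the [END] line is printed after the `try`/`except` rather than inside it, a failing task still produces a complete, parseable pair of lines.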
.dockerignore
ADDED
@@ -0,0 +1,13 @@
+__pycache__
+*.pyc
+*.pyo
+.venv
+.env
+.git
+.pytest_cache
+.ruff_cache
+.mypy_cache
+tests/
+*.egg-info
+build/
+dist/
.gitignore
ADDED
@@ -0,0 +1,21 @@
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+.venv/
+venv/
+env/
+.env
+.env.local
+.pytest_cache/
+.ruff_cache/
+.mypy_cache/
+*.egg-info/
+build/
+dist/
+.DS_Store
+.claude-flow/
+.swarm/
+.claude/
+
Dockerfile
ADDED
@@ -0,0 +1,27 @@
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# System deps (curl for healthchecks)
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy & install Python deps first for layer caching
+COPY requirements.txt ./
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy application
+COPY . .
+
+# Install package so [project.scripts] is callable
+RUN pip install --no-cache-dir -e .
+
+ENV PYTHONUNBUFFERED=1 \
+    PYTHONDONTWRITEBYTECODE=1 \
+    PORT=8000
+
+EXPOSE 8000
+
+# Use the entry point declared in pyproject.toml
+CMD ["python", "-m", "server.app"]
README.md
ADDED
@@ -0,0 +1,99 @@
+# SQL Repair OpenEnv
+
+An OpenEnv environment for the **Meta PyTorch x Scaler hackathon** where
+agents repair broken SQL queries against a small SQLite schema.
+
+## Tasks
+
+| ID | Difficulty | What's broken |
+|----------|------------|------------------------------------------------|
+| `task_1` | easy | SELECT list missing commas |
+| `task_2` | medium | JOIN references columns that don't exist |
+| `task_3` | hard | Aggregate query missing GROUP BY |
+
+Each task gives the agent the schema, the broken query, the runtime error
+(if any), and a one-line hint. The agent submits a corrected query via the
+`/step` endpoint and is scored on whether the result rows match the
+canonical expected rows.
+
+## Architecture
+
+```
+.
+├── pyproject.toml          # uv project, server entry point
+├── uv.lock                 # uv lockfile
+├── Dockerfile              # builds the env server image
+├── inference.py            # AGENT — talks to the env via HTTP, calls an LLM
+├── openenv.yaml            # OpenEnv metadata
+├── server/
+│   └── app.py              # FastAPI env server (def main)
+├── sql_env/
+│   ├── env_core.py         # SQLite-backed env state
+│   ├── tasks.py            # Task definitions
+│   └── grader.py           # Strict (0, 1) score clamping
+└── tests/
+    └── test_smoke.py       # Pytest smoke suite
+```
+
+## HTTP API
+
+| Method | Path | Body | Returns |
+|--------|-------------|-------------------------------------------|--------------------------------------|
+| GET | `/health` | — | `{"status":"ok"}` |
+| GET | `/tasks` | — | task list + metadata |
+| POST | `/reset` | `{"task_id":"task_1"}` (optional) | observation |
+| POST | `/step` | `{"action":{"action_type":"submit_query","query":"..."}}` | observation/reward/done |
+| POST | `/grader` | `{"task_id":"task_1"}` | `{"score": float in (0,1)}` |
+| POST | `/baseline` | `{"tasks":[...]}` (optional) | scores for all tasks |
+
+`/reset` accepts an empty body and defaults to `task_1` — required by the
+OpenEnv validator.
+
+## Running locally
+
+```bash
+# 1. Install
+uv sync # or: pip install -e . && pip install -r requirements.txt
+
+# 2. Start the env server
+python -m server.app # listens on http://localhost:8000
+
+# 3. Run the agent (in another terminal)
+export HF_TOKEN=<your-groq-or-openai-key>
+export API_BASE_URL=https://api.groq.com/openai/v1
+export MODEL_NAME=llama-3.3-70b-versatile
+python inference.py
+```
+
+Expected output:
+
+```
+[START] task_1
+[STEP] 01 | task=task_1 | action=submit_query | reward=+1.0000 | matches=True | rows=5
+[END] task_1 | score=0.9890 | status=ok
+[START] task_2
+...
+```
+
+## Environment variables
+
+| Name | Default | Notes |
+|------------------|------------------------------------------|---------------------------------------------|
+| `API_BASE_URL` | `https://api.groq.com/openai/v1` | Required by OpenEnv submission checklist |
+| `MODEL_NAME` | `llama-3.3-70b-versatile` | Required by OpenEnv submission checklist |
+| `HF_TOKEN` | (none — must be set in HF Space Secrets) | Required by OpenEnv submission checklist |
+| `LOCAL_IMAGE_NAME` | (unset) | If set, inference.py boots a Docker image |
+| `ENV_URL` | `http://localhost:8000` | Where the env server is reachable |
+
+## Validation
+
+```bash
+# Phase 1 — official OpenEnv validator
+uvx --from openenv-core openenv validate .
+
+# Smoke tests
+python -m pytest tests/ -q
+```
+
+No API keys are hardcoded in this repo. The agent reads `HF_TOKEN` (with
+optional `GROQ_API_KEY`/`OPENAI_API_KEY` fallbacks) at runtime only.
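The bracket-line format shown under "Expected output" in the README can be checked with a short regex parser. The following is a minimal sketch, not the validator's actual implementation: the names `END_RE`, `parse_end_line`, and `score_in_open_interval` are hypothetical, and the pattern is inferred only from the single sample `[END]` line in the README.

```python
import re
from typing import Optional, Tuple

# Hypothetical pattern derived from the README sample line:
#   [END] task_1 | score=0.9890 | status=ok
END_RE = re.compile(
    r"^\[END\]\s+(?P<task>\S+)\s+\|\s+score=(?P<score>\d+(?:\.\d+)?)"
    r"\s+\|\s+status=(?P<status>\S+)$"
)


def parse_end_line(line: str) -> Optional[Tuple[str, float, str]]:
    """Return (task_id, score, status) for an [END] line, else None."""
    m = END_RE.match(line.strip())
    if m is None:
        return None
    return m.group("task"), float(m.group("score")), m.group("status")


def score_in_open_interval(score: float) -> bool:
    """True when the score lies strictly between 0 and 1."""
    return 0.0 < score < 1.0
```

A check like this could be run over captured stdout to confirm that each task produced a parseable `[END]` line with a score strictly inside (0, 1).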
inference.py
ADDED
@@ -0,0 +1,425 @@
+"""inference.py — SQL Repair OpenEnv agent.
+
+This script is the AGENT side of the OpenEnv hackathon submission. The
+validator runs `python inference.py`, expects exit code 0, and parses
+exactly these stdout lines per task:
+
+    [START] task_x
+    [STEP] NN | task=task_x | ...
+    [END] task_x | score=0.NNNN | status=ok
+
+INVARIANTS (each one was learned from a Phase 2 failure):
+  1. EVERY task emits exactly one [START] and one [END] line — even on crash.
+  2. EVERY score is strictly inside the open interval (0, 1) — never 0.0 or 1.0.
+  3. NaN, inf, and parsing failures collapse to 0.5 (in-range fallback).
+  4. NO non-bracket prints on stdout from the main path. Diagnostics go to stderr.
+  5. flush=True on every emit so partial output survives a SIGKILL.
+  6. inference.py exits 0 even on catastrophic failure (we still emit safe scores).
+
+The agent uses the standardized OpenEnv environment variables that the
+validator injects: API_BASE_URL, MODEL_NAME, HF_TOKEN.
+"""
+from __future__ import annotations
+
+import json
+import os
+import subprocess
+import sys
+import time
+import traceback
+from typing import Any, Dict, List, Optional
+
+# ===========================================================================
+# Standardized OpenEnv environment variables (REQUIRED by submission checklist)
+# ===========================================================================
+API_BASE_URL: str = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
+MODEL_NAME: str = os.getenv("MODEL_NAME", "llama-3.3-70b-versatile")
+HF_TOKEN: Optional[str] = os.getenv("HF_TOKEN")  # no default — must be set in HF Secrets
+
+# Optional knobs
+LOCAL_IMAGE_NAME: Optional[str] = os.getenv("LOCAL_IMAGE_NAME")
+ENV_URL_DEFAULT: str = os.getenv("ENV_URL", "http://localhost:8000")
+REPO_ROOT: str = os.path.dirname(os.path.abspath(__file__))
+
+TASK_IDS: List[str] = ["task_1", "task_2", "task_3"]
+MAX_STEPS: int = 6
+
+
+# ===========================================================================
+# Sterile stdout sink — only [START]/[STEP]/[END] lines pass through this.
+# ===========================================================================
+def emit(line: str) -> None:
+    print(line, flush=True)
+
+
+def warn(msg: str) -> None:
+    """Diagnostics — stderr only, never parsed by the validator."""
+    print(f"# {msg}", file=sys.stderr, flush=True)
+
+
+# ===========================================================================
+# Strict (0, 1) score clamp — duplicated here so the agent never depends on
+# importable env code (the validator may run inference.py outside the package).
+# ===========================================================================
+def clamp_score(value: Any) -> float:
+    try:
+        s = float(value)
+    except (TypeError, ValueError):
+        return 0.5
+    if s != s:  # NaN
+        return 0.5
+    if s == float("inf") or s == float("-inf"):
+        return 0.5
+    if s <= 0.0:
+        return 0.001
+    if s >= 1.0:
+        return 0.999
+    return round(s, 4)
+
+
+# ===========================================================================
+# HTTP env client — minimal, no openenv-core dependency required.
+# ===========================================================================
+class HttpEnvClient:
+    """Thin REST client for our env server."""
+
+    def __init__(self, base_url: str) -> None:
+        import requests  # local import so the module can load even without it
+        self._requests = requests
+        self.base_url = base_url.rstrip("/")
+
+    def health(self) -> Dict[str, Any]:
+        r = self._requests.get(f"{self.base_url}/health", timeout=10)
+        r.raise_for_status()
+        return r.json()
+
+    def reset(self, task_id: str) -> Dict[str, Any]:
+        r = self._requests.post(
+            f"{self.base_url}/reset",
+            json={"task_id": task_id},
+            timeout=30,
+        )
+        r.raise_for_status()
+        return r.json()
+
+    def step(self, action: Dict[str, Any]) -> Dict[str, Any]:
+        r = self._requests.post(
+            f"{self.base_url}/step",
+            json={"action": action},
+            timeout=60,
+        )
+        r.raise_for_status()
+        return r.json()
+
+    def grader(self, task_id: str) -> Dict[str, Any]:
+        r = self._requests.post(
+            f"{self.base_url}/grader",
+            json={"task_id": task_id},
+            timeout=30,
+        )
+        r.raise_for_status()
+        return r.json()
+
+
+def _wait_for_health(url: str, timeout: float = 60.0) -> bool:
+    import requests
+    deadline = time.time() + timeout
+    while time.time() < deadline:
+        try:
+            r = requests.get(f"{url}/health", timeout=3)
+            if r.status_code == 200:
+                return True
+        except Exception:
+            pass
+        time.sleep(0.5)
+    return False
+
+
+def get_env_client() -> HttpEnvClient:
+    """Connect to the env server using the first viable strategy.
+
+    Strategies (in order of preference):
+      1. openenv-core's Env.from_docker_image() if LOCAL_IMAGE_NAME is set
+      2. Direct HTTP at ENV_URL if /health responds
+      3. Spawn a local subprocess `python -m server.app` from this repo
+    """
+    # Strategy 1: openenv-core image launch (sample pattern)
+    if LOCAL_IMAGE_NAME:
+        try:
+            from openenv_core.client import Env  # type: ignore
+
+            env = Env.from_docker_image(LOCAL_IMAGE_NAME, ports={8000: 8000})
+            warn(f"openenv-core launched container from image {LOCAL_IMAGE_NAME}")
+            # Wait for the launched container to be reachable
+            if _wait_for_health("http://localhost:8000", timeout=60):
+                return HttpEnvClient("http://localhost:8000")
+            warn("Container started but health check failed; falling through")
+        except Exception as exc:
+            warn(f"openenv-core import/launch failed: {exc}")
+
+    # Strategy 2: env already running at ENV_URL
+    if _wait_for_health(ENV_URL_DEFAULT, timeout=5):
+        warn(f"Reusing already-running env at {ENV_URL_DEFAULT}")
+        return HttpEnvClient(ENV_URL_DEFAULT)
+
+    # Strategy 3: spawn a local server subprocess
+    warn("No env reachable — spawning local subprocess on port 8000")
+    env_proc = subprocess.Popen(
+        [sys.executable, "-m", "server.app"],
+        cwd=REPO_ROOT,
+        stdout=subprocess.DEVNULL,
+        stderr=subprocess.DEVNULL,
+        env={**os.environ, "PORT": "8000", "PYTHONUNBUFFERED": "1"},
+    )
+    if not _wait_for_health("http://localhost:8000", timeout=45):
+        try:
+            env_proc.terminate()
+        except Exception:
+            pass
+        raise RuntimeError("Local env server did not become healthy within 45s")
+    warn(f"Local env subprocess pid={env_proc.pid} healthy")
+    return HttpEnvClient("http://localhost:8000")
+
+
+# ===========================================================================
+# OpenAI-compatible LLM client (Groq / OpenAI / HF inference endpoints)
+# ===========================================================================
+def make_llm_client():
+    from openai import OpenAI
+
+    api_key = (
+        HF_TOKEN
+        or os.getenv("GROQ_API_KEY")
+        or os.getenv("OPENAI_API_KEY")
+    )
+    if not api_key:
+        raise EnvironmentError(
+            "No API key found. Set HF_TOKEN (or GROQ_API_KEY) in env."
+        )
+    return OpenAI(base_url=API_BASE_URL, api_key=api_key)
+
+
+SYSTEM_PROMPT = """You are an expert SQL engineer. Your job is to repair broken SQL queries.
+
+You will be given:
+- A SQL schema (CREATE TABLE / INSERT statements)
+- A broken SQL query that errors or returns the wrong rows
+- The error message (if any)
+- A short hint
+- The expected number of rows and columns
+
+Respond with ONLY a JSON object on a single line:
+{"query": "<the corrected SQL query>"}
+
+Do NOT include any prose, explanation, code fences, or markdown — only the JSON object."""
+
+
+def _parse_query(content: str) -> str:
+    """Best-effort extraction of a SQL string from an LLM response."""
+    if not content:
+        return ""
+    s = content.strip()
+    # Strip markdown code fences
+    if s.startswith("```"):
+        s = s.strip("`").strip()
+        if s.lower().startswith("json"):
+            s = s[4:].strip()
+        elif s.lower().startswith("sql"):
+            s = s[3:].strip()
+    # Try strict JSON
+    try:
+        data = json.loads(s)
+        if isinstance(data, dict) and "query" in data:
+            return str(data["query"]).strip()
+    except json.JSONDecodeError:
+        pass
+    # Fallback: regex for {"query": "..."}
+    import re
+    m = re.search(r'"query"\s*:\s*"((?:[^"\\]|\\.)*)"', s)
+    if m:
+        return m.group(1).encode().decode("unicode_escape")
+    # Last resort: return raw content (might be a bare SQL string)
+    return s
+
+
+def call_llm(client, observation: Dict[str, Any], previous_attempts: List[Dict[str, Any]]) -> str:
+    user_lines = [
+        f"Task: {observation.get('name') or observation.get('task_id', '?')}",
+        f"Difficulty: {observation.get('difficulty', '?')}",
+        "",
+        "Schema:",
+        observation.get("schema_sql", "") or "(missing)",
+        "",
+        "Broken query:",
+        observation.get("broken_query", "") or "(missing)",
+        "",
+        f"Broken query error: {observation.get('broken_query_error') or 'none (returns wrong rows)'}",
+        f"Hint: {observation.get('hint', '')}",
+        "",
+        f"Expected: {observation.get('expected_row_count', '?')} rows × "
+        f"{observation.get('expected_column_count', '?')} columns",
+    ]
+    if previous_attempts:
+        user_lines.append("")
+        user_lines.append("Previous attempts:")
+        for i, att in enumerate(previous_attempts[-3:], start=1):
+            user_lines.append(
+                f" {i}. query={att.get('query', '')!r} -> "
+                f"executed={att.get('executed')} matches={att.get('matches_expected')} "
+                f"error={att.get('error')!r}"
+            )
+    user_lines.append("")
+    user_lines.append('Return ONLY: {"query": "<fixed SQL>"}')
+
+    user_msg = "\n".join(user_lines)
+    try:
+        resp = client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[
+                {"role": "system", "content": SYSTEM_PROMPT},
+                {"role": "user", "content": user_msg},
+            ],
+            temperature=0.1,
+            max_tokens=512,
+        )
+        content = (resp.choices[0].message.content or "").strip()
+        return _parse_query(content)
+    except Exception as exc:
+        warn(f"LLM call failed: {exc}")
+        return ""
+
+
+# ===========================================================================
+# Per-task runner — NEVER raises. Always emits exactly one [START] / [END].
+# ===========================================================================
+def run_task(env: HttpEnvClient, llm_client, task_id: str) -> float:
+    emit(f"[START] {task_id}")
+    score: float = 0.5  # safe in-range fallback
+    status: str = "ok"
+
+    try:
+        obs = env.reset(task_id)
+        last_obs: Dict[str, Any] = dict(obs)
+        previous_attempts: List[Dict[str, Any]] = []
+        broken = obs.get("broken_query", "")
+
+        for step_idx in range(1, MAX_STEPS + 1):
+            try:
+                fixed = call_llm(llm_client, last_obs, previous_attempts)
+            except Exception as exc:  # noqa: BLE001
+                warn(f"LLM error on step {step_idx}: {exc}")
+                fixed = ""
+
+            if not fixed:
+                fixed = broken  # fall back to the broken query so step still runs
+
+            try:
+                result = env.step({"action_type": "submit_query", "query": fixed})
+            except Exception as exc:  # noqa: BLE001
+                warn(f"env.step failed on step {step_idx}: {exc}")
+                emit(
+                    f"[STEP] {step_idx:02d} | task={task_id} "
+                    f"| action=submit_query | reward=+0.0000 | status=step_error"
+                )
+                continue
+
+            reward = float(result.get("reward", 0.0))
+            obs2: Dict[str, Any] = result.get("observation", {}) or {}
+            done = bool(result.get("done", False))
+            matches = bool(obs2.get("matches_expected", False))
+
+            emit(
+                f"[STEP] {step_idx:02d} | task={task_id} "
+                f"| action=submit_query | reward={reward:+.4f} "
+                f"| matches={matches} | rows={obs2.get('result_row_count', 0)}"
+            )
+
+            previous_attempts.append(
+                {
+                    "query": fixed,
+                    "executed": obs2.get("executed", False),
+                    "matches_expected": matches,
+                    "error": obs2.get("error"),
+                }
+            )
+            # Update context for next prompt
+            last_obs.update(obs2)
+            last_obs["broken_query"] = fixed
+            last_obs["broken_query_error"] = obs2.get("error")
+            last_obs["hint"] = obs.get("hint", "")
+            last_obs["schema_sql"] = obs.get("schema_sql", "")
+            last_obs["expected_row_count"] = obs.get("expected_row_count")
+            last_obs["expected_column_count"] = obs.get("expected_column_count")
+            last_obs["name"] = obs.get("name")
+            last_obs["difficulty"] = obs.get("difficulty")
+
+            if done:
+                break
+
+        # Pull final score from the env grader, then strict-clamp.
+        try:
+            grader_resp = env.grader(task_id)
+            raw = grader_resp.get("score", 0.5)
+        except Exception as exc:  # noqa: BLE001
+            warn(f"grader call failed: {exc}")
+            raw = 0.5
+        score = clamp_score(raw)
+    except Exception:
+        traceback.print_exc(file=sys.stderr)
+        status = "crash"
+        score = 0.5  # in-range fallback
+
+    # FINAL emit — guaranteed exactly once per task, in (0, 1)
+    emit(f"[END] {task_id} | score={score:.4f} | status={status}")
+    return score
+
+
+# ===========================================================================
+# Main entry point. Exits 0 even on catastrophic failure.
+# ===========================================================================
+def main() -> int:
+    env: Optional[HttpEnvClient] = None
+    llm_client = None
+
+    try:
+        env = get_env_client()
+    except Exception:
+        traceback.print_exc(file=sys.stderr)
+        for tid in TASK_IDS:
+            emit(f"[START] {tid}")
+            emit(f"[END] {tid} | score=0.5000 | status=fatal_no_env")
+        return 0
+
+    try:
+        llm_client = make_llm_client()
+    except Exception:
+        traceback.print_exc(file=sys.stderr)
+        for tid in TASK_IDS:
+            emit(f"[START] {tid}")
+            emit(f"[END] {tid} | score=0.5000 | status=fatal_no_llm")
+        return 0
+
+    for tid in TASK_IDS:
+        try:
+            run_task(env, llm_client, tid)
+        except Exception:
+            # Belt and suspenders — run_task already handles its own errors.
+            traceback.print_exc(file=sys.stderr)
+            emit(f"[START] {tid}")
+            emit(f"[END] {tid} | score=0.5000 | status=outer_crash")
+
+    return 0
+
+
+if __name__ == "__main__":
+    try:
+        sys.exit(main())
+    except SystemExit:
+        raise
+    except Exception:
+        traceback.print_exc(file=sys.stderr)
+        # Last-ditch: still emit safe scores so the validator parses something.
+        for tid in TASK_IDS:
+            print(f"[START] {tid}", flush=True)
+            print(f"[END] {tid} | score=0.5000 | status=outer_fatal", flush=True)
+        sys.exit(0)  # exit 0 — validator requires it
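One detail of `clamp_score` in inference.py may be worth checking: the range test runs before `round(s, 4)`, so an input such as 0.99996 passes the `s >= 1.0` check and then rounds to 1.0, and an input such as 0.00001 rounds to 0.0. The sketch below shows one possible way to keep the rounded value inside (0, 1) by rounding first and clamping afterwards. It is not part of the commit: the name `clamp_score_rounded` and the constants `LO` and `HI` are hypothetical, and unlike the original it raises any value below 0.001 to 0.001 and lowers any value above 0.999 to 0.999.

```python
import math
from typing import Any

LO = 0.001  # assumed lower bound, mirroring the 0.001 used in the commit
HI = 0.999  # assumed upper bound, mirroring the 0.999 used in the commit


def clamp_score_rounded(value: Any) -> float:
    """Round first, then clamp, so the returned value stays in [LO, HI]."""
    try:
        s = float(value)
    except (TypeError, ValueError, OverflowError):
        return 0.5  # unparseable input collapses to the in-range fallback
    if not math.isfinite(s):  # NaN, +inf, -inf
        return 0.5
    return min(HI, max(LO, round(s, 4)))
```

Clamping after rounding means the value formatted with `:.4f` can never print as `0.0000` or `1.0000`.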
openenv.yaml
ADDED
@@ -0,0 +1,28 @@
+name: sql-repair-env
+version: 0.1.0
+description: |
+  OpenEnv environment for SQL query repair. Each task gives the agent a
+  schema, a broken SQL query, and a hint. The agent must submit a corrected
+  query that returns the expected result set. Backed by SQLite in-memory.
+maintainer: krishpotanwar
+runtime:
+  type: docker
+  image: sql-repair-env:latest
+  port: 8000
+endpoints:
+  health: /health
+  tasks: /tasks
+  reset: /reset
+  step: /step
+  grader: /grader
+  baseline: /baseline
+tasks:
+  - id: task_1
+    name: Missing commas in SELECT
+    difficulty: easy
+  - id: task_2
+    name: Wrong column reference in JOIN
+    difficulty: medium
+  - id: task_3
+    name: Aggregate without GROUP BY
+    difficulty: hard
pyproject.toml
ADDED
@@ -0,0 +1,30 @@
+[project]
+name = "sql-repair-env"
+version = "0.1.0"
+description = "OpenEnv environment for SQL query repair tasks"
+readme = "README.md"
+requires-python = ">=3.10"
+authors = [{ name = "krishpotanwar" }]
+license = { text = "Apache-2.0" }
+dependencies = [
+    "openenv-core>=0.2.0",
+    "fastapi>=0.110.0",
+    "uvicorn[standard]>=0.27.0",
+    "pydantic>=2.0.0",
+    "openai>=1.30.0",
+    "requests>=2.31.0",
+    "numpy>=1.24.0",
+]
+
+[project.scripts]
+server = "server.app:main"
+
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+packages = ["server", "sql_env"]
+
+[tool.setuptools.package-data]
+"*" = ["*.yaml", "*.md"]
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+openenv-core>=0.2.0
+fastapi>=0.110.0
+uvicorn[standard]>=0.27.0
+pydantic>=2.0.0
+openai>=1.30.0
+requests>=2.31.0
+numpy>=1.24.0
server/__init__.py
ADDED
@@ -0,0 +1 @@
+"""HTTP server package for SQL Repair OpenEnv environment."""
server/app.py
ADDED
|
@@ -0,0 +1,142 @@
+"""FastAPI server for the SQL Repair OpenEnv environment.
+
+Endpoints (all required by the OpenEnv submission validator):
+  GET  /health    -> {"status": "ok"}
+  GET  /tasks     -> {"tasks": ["task_1", "task_2", "task_3"], "details": [...]}
+  POST /reset     -> reset env to a task (body optional, defaults to task_1)
+  POST /step      -> apply an action, return observation/reward/done
+  POST /grader    -> compute final score for a task (strictly in (0, 1))
+  POST /baseline  -> run all tasks with the broken queries, return scores
+
+Phase 1 hard requirement: /reset MUST accept an empty POST body.
+We achieve that with `Optional[ResetRequest] = Body(default=None)`.
+
+Entry point exposed via [project.scripts] server = "server.app:main".
+"""
+from __future__ import annotations
+
+import os
+from typing import Any, Dict, List, Optional
+
+from fastapi import Body, FastAPI
+from pydantic import BaseModel, Field
+
+from sql_env.env_core import EnvState, MAX_STEPS
+from sql_env.grader import grade_task
+from sql_env.tasks import TASK_IDS, TASKS
+
+app = FastAPI(
+    title="SQL Repair OpenEnv",
+    version="0.1.0",
+    description=(
+        "An OpenEnv environment for SQL query repair. Agents fix broken "
+        "SQL queries against a small SQLite schema."
+    ),
+)
+
+# Single mutable env state instance: the validator runs one session.
+_state = EnvState()
+
+
+# ---------------------------------------------------------------------------
+# Pydantic request models
+# ---------------------------------------------------------------------------
+class ResetRequest(BaseModel):
+    task_id: Optional[str] = Field(default=None, description="Task ID to reset to")
+
+
+class StepAction(BaseModel):
+    action_type: str = Field(default="submit_query")
+    query: str = Field(default="")
+
+
+class StepRequest(BaseModel):
+    action: Dict[str, Any] = Field(default_factory=dict)
+
+
+class GraderRequest(BaseModel):
+    task_id: Optional[str] = Field(default=None)
+
+
+class BaselineRequest(BaseModel):
+    tasks: Optional[List[str]] = Field(default=None)
+
+
+# ---------------------------------------------------------------------------
+# Endpoints
+# ---------------------------------------------------------------------------
+@app.get("/health")
+def health() -> Dict[str, str]:
+    return {"status": "ok"}
+
+
+@app.get("/tasks")
+def list_tasks() -> Dict[str, Any]:
+    return {
+        "tasks": TASK_IDS,
+        "details": [
+            {
+                "id": TASKS[t]["id"],
+                "name": TASKS[t]["name"],
+                "difficulty": TASKS[t]["difficulty"],
+            }
+            for t in TASK_IDS
+        ],
+    }
+
+
+@app.post("/reset")
+def reset(req: Optional[ResetRequest] = Body(default=None)) -> Dict[str, Any]:
+    """Reset the environment. Body is optional; defaults to task_1."""
+    task_id = req.task_id if (req and req.task_id) else "task_1"
+    obs = _state.reset(task_id)
+    return obs
+
+
+@app.post("/step")
+def step(req: Optional[StepRequest] = Body(default=None)) -> Dict[str, Any]:
+    """Apply one action to the environment."""
+    action: Dict[str, Any] = (req.action if req and req.action else {})
+    return _state.step(action)
+
+
+@app.post("/grader")
+def grader(req: Optional[GraderRequest] = Body(default=None)) -> Dict[str, Any]:
+    """Return the strict-(0,1) score for the given task."""
+    task_id = req.task_id if (req and req.task_id) else (_state.task_id or "task_1")
+    score = grade_task(_state, task_id)
+    return {"task_id": task_id, "score": float(score)}
+
+
+@app.post("/baseline")
+def baseline(
+    req: Optional[BaselineRequest] = Body(default=None),
+) -> Dict[str, Any]:
+    """Run all tasks with the broken queries to verify graders work."""
+    task_ids = (req.tasks if (req and req.tasks) else None) or list(TASK_IDS)
+    out: Dict[str, float] = {}
+    for tid in task_ids:
+        if tid not in TASKS:
+            continue
+        local = EnvState()
+        local.reset(tid)
+        # Submit the broken query as a baseline submission
+        local.step({"action_type": "submit_query", "query": TASKS[tid]["broken_query"]})
+        out[tid] = float(grade_task(local, tid))
+    return {"scores": out, "max_steps": MAX_STEPS}
+
+
+# ---------------------------------------------------------------------------
+# Entry point, referenced by [project.scripts] server = "server.app:main"
+# ---------------------------------------------------------------------------
+def main() -> None:
+    """Entry point for `python -m server.app` and the `server` console script."""
+    import uvicorn
+
+    host = os.getenv("HOST", "0.0.0.0")
+    port = int(os.getenv("PORT", "8000"))
+    uvicorn.run(app, host=host, port=port, log_level="info")
+
+
+if __name__ == "__main__":
+    main()
|
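A client could exercise the optional-body behaviour of `/reset` roughly as sketched here. This is a minimal illustration using only the standard library, not code from the commit: the helper names `build_post` and `call` are hypothetical, and the base URL assumes the server was started with the default `PORT` of 8000 from `main()`.

```python
import json
import urllib.request
from typing import Any, Dict, Optional

BASE_URL = "http://localhost:8000"  # assumption: default HOST/PORT from main()


def build_post(path: str, payload: Optional[Dict[str, Any]] = None) -> urllib.request.Request:
    """Build a POST request; with no payload the body is empty, as /reset must allow."""
    if payload is None:
        return urllib.request.Request(BASE_URL + path, data=b"", method="POST")
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        BASE_URL + path,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def call(req: urllib.request.Request) -> Dict[str, Any]:
    """Send the request and decode the JSON reply (needs a running server)."""
    with urllib.request.urlopen(req, timeout=10) as resp:
        return json.loads(resp.read().decode("utf-8"))


# Requests an agent might send during one episode.
reset_req = build_post("/reset")  # empty body: the server falls back to task_1
step_req = build_post(
    "/step",
    {
        "action": {
            "action_type": "submit_query",
            "query": "SELECT id, name, price FROM products ORDER BY id",
        }
    },
)
grade_req = build_post("/grader", {"task_id": "task_1"})
```

With the server running, `call(reset_req)` followed by `call(step_req)` and `call(grade_req)` would walk through one full episode for `task_1`.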
sql_env/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
+"""SQL Repair OpenEnv environment package."""
+from .env_core import EnvState, MAX_STEPS
+from .tasks import TASKS, TASK_IDS
+from .grader import grade_task, SCORE_MIN, SCORE_MAX
+
+__all__ = [
+    "EnvState",
+    "MAX_STEPS",
+    "TASKS",
+    "TASK_IDS",
+    "grade_task",
+    "SCORE_MIN",
+    "SCORE_MAX",
+]
|
sql_env/env_core.py
ADDED
|
@@ -0,0 +1,174 @@
+"""SQLite-backed environment state for SQL repair tasks.
+
+The env exposes a minimal Gym-like API:
+    reset(task_id)  -> observation dict
+    step(action)    -> {observation, reward, done, info}
+
+Per-task state is held in this single instance for simplicity. The
+validator runs only one session at a time.
+"""
+from __future__ import annotations
+
+import sqlite3
+from typing import Any, Dict, List, Optional
+
+from .tasks import TASKS, TASK_IDS
+
+MAX_STEPS = 6
+
+
+def _new_db(task_id: str) -> sqlite3.Connection:
+    """Build a fresh in-memory DB for the given task."""
+    if task_id not in TASKS:
+        raise KeyError(f"Unknown task_id: {task_id}")
+    conn = sqlite3.connect(":memory:")
+    cur = conn.cursor()
+    for stmt in TASKS[task_id]["schema"]:
+        cur.execute(stmt)
+    conn.commit()
+    return conn
+
+
+def _run_query(task_id: str, query: str) -> Dict[str, Any]:
+    """Execute a query against a fresh DB; return rows or error info."""
+    conn = _new_db(task_id)
+    try:
+        cur = conn.execute(query)
+        rows = cur.fetchall()
+        col_names = [d[0] for d in cur.description] if cur.description else []
+        return {"ok": True, "rows": rows, "columns": col_names, "error": None}
+    except Exception as exc:
+        return {"ok": False, "rows": None, "columns": [], "error": str(exc)}
+    finally:
+        conn.close()
+
+
+def _expected_rows(task_id: str) -> List[tuple]:
+    """Compute the canonical (expected) result set for a task."""
+    res = _run_query(task_id, TASKS[task_id]["canonical_query"])
+    if not res["ok"]:
+        # Should never happen: canonical queries are vetted in tests.
+        raise RuntimeError(
+            f"Canonical query for {task_id} failed: {res['error']}"
+        )
+    return res["rows"]
+
+
+class EnvState:
+    """Mutable per-session env state. One instance handles all tasks."""
+
+    def __init__(self) -> None:
+        self.task_id: Optional[str] = None
+        self.step_count: int = 0
+        self.last_query: Optional[str] = None
+        self.last_error: Optional[str] = None
+        self.last_result: Optional[List[tuple]] = None
+        self.solved: bool = False
+        self.expected_rows: List[tuple] = []
+        self.expected_columns: int = 0
+
+    # ------------------------------------------------------------------
+    def reset(self, task_id: Optional[str] = None) -> Dict[str, Any]:
+        tid = task_id or "task_1"
+        if tid not in TASKS:
+            tid = "task_1"
+        task = TASKS[tid]
+
+        self.task_id = tid
+        self.step_count = 0
+        self.last_query = None
+        self.last_error = None
+        self.last_result = None
+        self.solved = False
+        self.expected_rows = _expected_rows(tid)
+        self.expected_columns = (
+            len(self.expected_rows[0]) if self.expected_rows else 0
+        )
+
+        # Surface what the broken query actually does, so the agent has
+        # an error message and a canonical "what went wrong" hint.
+        baseline = _run_query(tid, task["broken_query"])
+
+        return {
+            "task_id": tid,
+            "name": task["name"],
+            "difficulty": task["difficulty"],
+            "schema_sql": "\n".join(task["schema"]),
+            "broken_query": task["broken_query"],
+            "broken_query_error": baseline["error"],
+            "broken_query_executes": baseline["ok"],
+            "hint": task["hint"],
+            "expected_row_count": len(self.expected_rows),
+            "expected_column_count": self.expected_columns,
+            "step_count": 0,
+            "max_steps": MAX_STEPS,
+            "remaining_steps": MAX_STEPS,
+        }
+
+    # ------------------------------------------------------------------
+    def step(self, action: Dict[str, Any]) -> Dict[str, Any]:
+        if self.task_id is None:
+            return {
+                "observation": {"error": "No active task. Call /reset first."},
+                "reward": 0.0,
+                "done": True,
+                "info": {"solved": False, "no_active_task": True},
+            }
+
+        self.step_count += 1
+        action_type = (action or {}).get("action_type", "submit_query")
+        query = ((action or {}).get("query") or "").strip()
+        self.last_query = query
+
+        reward = 0.0
+        result_rows: Optional[List[tuple]] = None
+        error: Optional[str] = None
+
+        if action_type != "submit_query":
+            error = f"Unsupported action_type: {action_type}"
+            reward = -0.05
+        elif not query:
+            error = "Empty query string."
+            reward = -0.05
+        else:
+            res = _run_query(self.task_id, query)
+            if res["ok"]:
+                result_rows = res["rows"]
+                self.last_result = result_rows
+                self.last_error = None
+                if result_rows == self.expected_rows:
+                    reward = 1.0
+                    self.solved = True
+                else:
+                    # executed but wrong rows: small positive reward
+                    reward = 0.4
+            else:
+                error = res["error"]
+                reward = -0.10
+
+        if error is not None:
+            # Record every failure, including invalid actions, so the grader
+            # never scores a stale result left over from an earlier step.
+            self.last_error = error
+            self.last_result = None
+
+        done = self.solved or self.step_count >= MAX_STEPS
+
+        observation = {
+            "task_id": self.task_id,
+            "step_count": self.step_count,
+            "submitted_query": query,
+            "error": error,
+            "executed": error is None and result_rows is not None,
+            "matches_expected": (
+                result_rows == self.expected_rows if result_rows is not None else False
+            ),
+            "result_row_count": len(result_rows) if result_rows is not None else 0,
+            "expected_row_count": len(self.expected_rows),
+            "result_preview": result_rows[:3] if result_rows else None,
+            "expected_preview": self.expected_rows[:3],
+            "remaining_steps": max(0, MAX_STEPS - self.step_count),
+        }
+
+        return {
+            "observation": observation,
+            "reward": float(reward),
+            "done": bool(done),
+            "info": {"solved": self.solved},
+        }
|
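Because `_run_query` builds a new in-memory database for every call, a submitted query cannot damage the data seen by later steps. The sketch below illustrates that isolation in a standalone form; `run_isolated` and the two-row `SCHEMA` are hypothetical stand-ins for illustration, not the module's own helpers.

```python
import sqlite3
from typing import Any, Dict, List

# Hypothetical, shortened schema in the style of task_1.
SCHEMA: List[str] = [
    "CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL, price REAL NOT NULL);",
    "INSERT INTO products VALUES (1, 'Apple', 0.50);",
    "INSERT INTO products VALUES (2, 'Bread', 2.50);",
]


def run_isolated(query: str) -> Dict[str, Any]:
    """Run one query against a freshly built in-memory DB, then discard the DB."""
    conn = sqlite3.connect(":memory:")
    try:
        for stmt in SCHEMA:
            conn.execute(stmt)
        conn.commit()
        cur = conn.execute(query)
        return {"ok": True, "rows": cur.fetchall(), "error": None}
    except sqlite3.Error as exc:
        return {"ok": False, "rows": None, "error": str(exc)}
    finally:
        conn.close()


# A destructive submission only affects its own throwaway database,
dropped = run_isolated("DROP TABLE products")
# so the next query still sees the full, untouched schema.
after = run_isolated("SELECT id, name FROM products ORDER BY id")
```

The trade-off of this design is that the schema is rebuilt on every query, which is cheap for fixtures of a few rows but would need revisiting for larger datasets.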
sql_env/grader.py
ADDED
|
@@ -0,0 +1,64 @@
+"""Strict (0, 1) grader for SQL repair tasks.
+
+Phase 2 hard requirement: scores MUST be in the OPEN interval (0, 1).
+Validator rejects exactly 0.0 and exactly 1.0. NaN/inf are also rejected,
+so we coerce them to 0.5 (a neutral, in-range fallback).
+"""
+from __future__ import annotations
+
+import math
+from typing import Any
+
+# Module-level constants, also used by inference.py for consistency.
+SCORE_MIN: float = 1e-3   # 0.001, strictly > 0
+SCORE_MAX: float = 0.999  # strictly < 1
+
+
+def strict_clamp(value: Any) -> float:
+    """Coerce any input into a float strictly inside (0, 1).
+
+    NaN, inf, -inf, and non-numeric inputs all collapse to 0.5.
+    """
+    try:
+        s = float(value)
+    except (TypeError, ValueError, OverflowError):
+        return 0.5
+    if math.isnan(s) or math.isinf(s):
+        return 0.5
+    # Round first, then clamp: values such as 0.00001 or 0.99999 would
+    # otherwise round to exactly 0.0 or 1.0 and leave the open interval.
+    return min(max(round(s, 4), SCORE_MIN), SCORE_MAX)
+
+
+def grade_task(state, task_id: str) -> float:
+    """Score the current state of an EnvState for the given task.
+
+    Score components (sum, then strict_clamp):
+      - 0.05       : agent submitted at least one query
+      - 0.25       : last query executed without error
+      - 0.60       : result rows matched expected rows
+      - up to 0.09 : efficiency bonus, 0.09 * (MAX_STEPS - steps) / MAX_STEPS
+
+    Worst case (no submission):   0.000 -> clamped to 0.001
+    Best case (1-step solve):     0.975 -> in range
+    Wrong-result executes:        0.30  -> in range
+    """
+    from .env_core import MAX_STEPS  # local import avoids circular
+
+    if state.task_id != task_id:
+        return SCORE_MIN
+
+    raw = 0.0
+    if state.last_query:
+        raw += 0.05
+    if state.last_error is None and state.last_result is not None:
+        raw += 0.25
+    if state.last_result == state.expected_rows and state.expected_rows:
+        raw += 0.60
+    if state.solved and state.step_count > 0:
+        bonus = 0.09 * max(0, MAX_STEPS - state.step_count) / MAX_STEPS
+        raw += bonus
+
+    return strict_clamp(raw)
|
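The score components in `grade_task` can be checked by hand. The sketch below redoes the arithmetic in a standalone function; `raw_score` is a hypothetical helper written for this illustration, `MAX_STEPS = 6` is copied from `sql_env/env_core.py`, and the sketch assumes that matching the expected rows always coincides with `state.solved`.

```python
MAX_STEPS = 6  # same value as sql_env.env_core.MAX_STEPS in this commit


def raw_score(submitted: bool, executed: bool, matched: bool, steps: int) -> float:
    """Sum the score components before clamping (hypothetical helper)."""
    raw = 0.0
    if submitted:
        raw += 0.05
    if executed:
        raw += 0.25
    if matched:
        raw += 0.60
        # Efficiency bonus: only solved episodes earn it, and it shrinks per step.
        raw += 0.09 * max(0, MAX_STEPS - steps) / MAX_STEPS
    return round(raw, 4)


one_step_solve = raw_score(True, True, True, steps=1)   # 0.05 + 0.25 + 0.60 + 0.075
last_step_solve = raw_score(True, True, True, steps=6)  # bonus has shrunk to zero
wrong_rows = raw_score(True, True, False, steps=1)      # executed, wrong result
sql_error = raw_score(True, False, False, steps=1)      # query raised an error
```

Since the first step already counts as one step, the bonus tops out at 0.09 × 5/6 = 0.075, so a solved episode lands between 0.90 and 0.975 and never reaches the `SCORE_MAX` clamp.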
sql_env/tasks.py
ADDED
|
@@ -0,0 +1,81 @@
+"""Task definitions for SQL Repair env.
+
+Each task defines:
+  - schema          : list of CREATE/INSERT statements (executed verbatim)
+  - broken_query    : a SQL query that errors or returns the wrong rows
+  - canonical_query : the reference fix used to compute expected_rows
+  - hint            : short natural-language pointer
+
+Difficulty is tuned so even a vanilla LLM agent (Nemotron-class) can solve
+task_1 reliably, task_2 with effort, and task_3 about half the time,
+ensuring score variance across tasks (Phase 2 likely checks for this).
+"""
+from typing import Dict, List
+
+TASKS: Dict[str, dict] = {
+    "task_1": {
+        "id": "task_1",
+        "name": "Missing commas in SELECT list",
+        "difficulty": "easy",
+        "schema": [
+            "CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL, price REAL NOT NULL);",
+            "INSERT INTO products VALUES (1, 'Apple', 0.50);",
+            "INSERT INTO products VALUES (2, 'Bread', 2.50);",
+            "INSERT INTO products VALUES (3, 'Cheese', 5.00);",
+            "INSERT INTO products VALUES (4, 'Milk', 1.50);",
+            "INSERT INTO products VALUES (5, 'Eggs', 3.00);",
+        ],
+        "broken_query": "SELECT id name price FROM products ORDER BY id",
+        "canonical_query": "SELECT id, name, price FROM products ORDER BY id",
+        "hint": "The SELECT list is missing commas between column names.",
+    },
+    "task_2": {
+        "id": "task_2",
+        "name": "Wrong column reference in JOIN",
+        "difficulty": "medium",
+        "schema": [
+            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, country TEXT);",
+            "CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER NOT NULL, total REAL NOT NULL);",
+            "INSERT INTO users VALUES (1, 'Aarav', 'IN');",
+            "INSERT INTO users VALUES (2, 'Bea', 'US');",
+            "INSERT INTO users VALUES (3, 'Chen', 'CN');",
+            "INSERT INTO orders VALUES (10, 1, 99.00);",
+            "INSERT INTO orders VALUES (11, 1, 49.50);",
+            "INSERT INTO orders VALUES (12, 2, 200.00);",
+            "INSERT INTO orders VALUES (13, 3, 25.00);",
+        ],
+        "broken_query": (
+            "SELECT u.username, o.total "
+            "FROM users u JOIN orders o ON u.id = o.user "
+            "ORDER BY o.id"
+        ),
+        "canonical_query": (
+            "SELECT u.name, o.total "
+            "FROM users u JOIN orders o ON u.id = o.user_id "
+            "ORDER BY o.id"
+        ),
+        "hint": "Two columns are misspelled — check the schema for the real names.",
+    },
+    "task_3": {
+        "id": "task_3",
+        "name": "Aggregate without GROUP BY",
+        "difficulty": "hard",
+        "schema": [
+            "CREATE TABLE sales (id INTEGER PRIMARY KEY, region TEXT NOT NULL, amount REAL NOT NULL);",
+            "INSERT INTO sales VALUES (1, 'north', 100.00);",
+            "INSERT INTO sales VALUES (2, 'north', 50.00);",
+            "INSERT INTO sales VALUES (3, 'south', 200.00);",
+            "INSERT INTO sales VALUES (4, 'south', 75.00);",
+            "INSERT INTO sales VALUES (5, 'east', 150.00);",
+            "INSERT INTO sales VALUES (6, 'east', 25.00);",
+        ],
+        "broken_query": "SELECT region, SUM(amount) AS total FROM sales ORDER BY region",
+        "canonical_query": (
+            "SELECT region, SUM(amount) AS total FROM sales "
+            "GROUP BY region ORDER BY region"
+        ),
+        "hint": "You SELECT a non-aggregate column with an aggregate — add GROUP BY.",
+    },
+}
+
+TASK_IDS: List[str] = list(TASKS.keys())
|
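The broken queries fail in different ways under SQLite, which is why the task docstring says a query "errors or returns the wrong rows". The sketch below runs the `task_1` and `task_3` queries against copies of their schemas to show the difference; `try_query` is a hypothetical helper for this illustration and the `PRODUCTS` fixture is shortened to one row.

```python
import sqlite3
from typing import List, Optional, Tuple


def try_query(schema: List[str], query: str) -> Tuple[Optional[list], Optional[str]]:
    """Return (rows, None) on success or (None, error message) on failure."""
    conn = sqlite3.connect(":memory:")
    try:
        for stmt in schema:
            conn.execute(stmt)
        return conn.execute(query).fetchall(), None
    except sqlite3.Error as exc:
        return None, str(exc)
    finally:
        conn.close()


PRODUCTS = [
    "CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT NOT NULL, price REAL NOT NULL);",
    "INSERT INTO products VALUES (1, 'Apple', 0.50);",
]
SALES = [
    "CREATE TABLE sales (id INTEGER PRIMARY KEY, region TEXT NOT NULL, amount REAL NOT NULL);",
    "INSERT INTO sales VALUES (1, 'north', 100.00);",
    "INSERT INTO sales VALUES (2, 'north', 50.00);",
    "INSERT INTO sales VALUES (3, 'south', 200.00);",
    "INSERT INTO sales VALUES (4, 'south', 75.00);",
    "INSERT INTO sales VALUES (5, 'east', 150.00);",
    "INSERT INTO sales VALUES (6, 'east', 25.00);",
]

# task_1: "id name" parses as a column alias, so "price" is an unexpected token.
rows_1, err_1 = try_query(PRODUCTS, "SELECT id name price FROM products ORDER BY id")

# task_3: SQLite accepts a bare column next to an aggregate and returns ONE row.
rows_3, err_3 = try_query(
    SALES, "SELECT region, SUM(amount) AS total FROM sales ORDER BY region"
)

# The canonical fix groups by region and yields one row per region.
fixed_3, _ = try_query(
    SALES,
    "SELECT region, SUM(amount) AS total FROM sales GROUP BY region ORDER BY region",
)
```

So for `task_3` the agent receives no error message on reset (`broken_query_executes` is true); it has to notice that one row came back where three were expected, which fits the "hard" label.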
tests/__init__.py
ADDED
|
File without changes
|
tests/test_smoke.py
ADDED
|
@@ -0,0 +1,110 @@
+"""Smoke tests for the SQL Repair env.
+
+Run with: python -m pytest tests/ -q
+"""
+from __future__ import annotations
+
+import math
+
+from sql_env.env_core import EnvState, MAX_STEPS
+from sql_env.grader import SCORE_MAX, SCORE_MIN, grade_task, strict_clamp
+from sql_env.tasks import TASK_IDS, TASKS
+
+
+# ---------------------------------------------------------------------------
+# Strict (0, 1) clamp invariants
+# ---------------------------------------------------------------------------
+def test_strict_clamp_handles_extremes():
+    assert strict_clamp(0.0) == SCORE_MIN
+    assert strict_clamp(-1.0) == SCORE_MIN
+    assert strict_clamp(1.0) == SCORE_MAX
+    assert strict_clamp(2.0) == SCORE_MAX
+    assert strict_clamp(float("nan")) == 0.5
+    assert strict_clamp(float("inf")) == 0.5
+    assert strict_clamp(float("-inf")) == 0.5
+    assert strict_clamp("not a number") == 0.5
+    assert strict_clamp(None) == 0.5
+
+
+def test_strict_clamp_passes_through_in_range():
+    for v in [0.001, 0.1, 0.5, 0.7234, 0.999]:
+        out = strict_clamp(v)
+        assert SCORE_MIN <= out <= SCORE_MAX
+        assert 0.0 < out < 1.0
+
+
+# ---------------------------------------------------------------------------
+# Each canonical query reproduces the expected rows
+# ---------------------------------------------------------------------------
+def test_canonical_queries_solve_their_tasks():
+    for tid in TASK_IDS:
+        s = EnvState()
+        s.reset(tid)
+        result = s.step(
+            {"action_type": "submit_query", "query": TASKS[tid]["canonical_query"]}
+        )
+        assert result["info"]["solved"] is True, f"{tid} canonical did not solve"
+        assert result["reward"] == 1.0
+        score = grade_task(s, tid)
+        assert SCORE_MIN <= score <= SCORE_MAX
+        assert score >= 0.85, f"{tid} canonical scored too low: {score}"
+
+
+# ---------------------------------------------------------------------------
+# Broken queries do not solve and grade in (0, 1)
+# ---------------------------------------------------------------------------
+def test_broken_queries_score_in_range_but_not_solved():
+    for tid in TASK_IDS:
+        s = EnvState()
+        s.reset(tid)
+        result = s.step(
+            {"action_type": "submit_query", "query": TASKS[tid]["broken_query"]}
+        )
+        assert result["info"]["solved"] is False
+        score = grade_task(s, tid)
+        assert SCORE_MIN <= score <= SCORE_MAX
+        assert 0.0 < score < 1.0
+
+
+# ---------------------------------------------------------------------------
+# A do-nothing run still produces an in-range score
+# ---------------------------------------------------------------------------
+def test_no_submission_scores_in_range():
+    for tid in TASK_IDS:
+        s = EnvState()
+        s.reset(tid)
+        score = grade_task(s, tid)
+        assert SCORE_MIN <= score <= SCORE_MAX
+        assert 0.0 < score < 1.0
+
+
+# ---------------------------------------------------------------------------
+# Step limit terminates
+# ---------------------------------------------------------------------------
+def test_step_limit_done():
+    s = EnvState()
+    s.reset("task_1")
+    for _ in range(MAX_STEPS):
+        result = s.step({"action_type": "submit_query", "query": "SELECT 1"})
+    assert result["done"] is True
+
+
+# ---------------------------------------------------------------------------
+# Reset accepts unknown task_id by falling back to task_1
+# ---------------------------------------------------------------------------
+def test_reset_unknown_task_falls_back():
+    s = EnvState()
+    obs = s.reset("nonexistent_task")
+    assert obs["task_id"] == "task_1"
+
+
+# ---------------------------------------------------------------------------
+# Empty action does not crash
+# ---------------------------------------------------------------------------
+def test_empty_action_handled():
+    s = EnvState()
+    s.reset("task_1")
+    result = s.step({})
+    assert "observation" in result
+    assert result["reward"] <= 0  # negative or zero reward
+    assert result["observation"]["error"]
|
uv.lock
ADDED
|
The diff for this file is too large to render.