Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitignore +13 -0
- Dockerfile +25 -0
- README.md +250 -5
- __init__.py +8 -0
- client.py +54 -0
- data/schema.sql +49 -0
- data/seed.sql +200 -0
- data/tasks/advanced_analytics.json +101 -0
- data/tasks/basic_select.json +75 -0
- data/tasks/join_aggregate.json +93 -0
- inference.py +265 -0
- models.py +28 -0
- openenv.yaml +6 -0
- pyproject.toml +45 -0
- server/Dockerfile +80 -0
- server/__init__.py +1 -0
- server/app.py +50 -0
- server/database.py +214 -0
- server/graders.py +214 -0
- server/requirements.txt +4 -0
- server/sql_env_environment.py +217 -0
.gitignore
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
*.egg-info/
|
| 5 |
+
dist/
|
| 6 |
+
build/
|
| 7 |
+
.venv/
|
| 8 |
+
venv/
|
| 9 |
+
*.db
|
| 10 |
+
*.sqlite
|
| 11 |
+
.env
|
| 12 |
+
.DS_Store
|
| 13 |
+
uv.lock
|
Dockerfile
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install system dependencies
|
| 6 |
+
RUN apt-get update && \
|
| 7 |
+
apt-get install -y --no-install-recommends curl && \
|
| 8 |
+
rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
# Copy requirements first for better caching
|
| 11 |
+
COPY server/requirements.txt .
|
| 12 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
+
|
| 14 |
+
# Copy project files
|
| 15 |
+
COPY . .
|
| 16 |
+
|
| 17 |
+
# Expose HF Spaces default port
|
| 18 |
+
EXPOSE 7860
|
| 19 |
+
|
| 20 |
+
# Health check
|
| 21 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
|
| 22 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 23 |
+
|
| 24 |
+
# Run the FastAPI server on port 7860 (HF Spaces default)
|
| 25 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,255 @@
|
|
| 1 |
---
|
| 2 |
-
title: Sql Env
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Sql Env Environment Server
|
| 3 |
+
emoji: 🗜️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: gray
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
app_port: 7860
|
| 9 |
+
base_path: /web
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# Sql Env Environment
|
| 15 |
+
|
| 16 |
+
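A text-to-SQL environment: the agent is given a natural-language question about an e-commerce database and answers it by submitting SQL queries, with tasks at three difficulty levels. Note: the usage examples below still follow the OpenEnv echo template (`SqlEnv`, `SqlAction(message=...)`, `echoed_message`) and do not match the names this package actually defines (`SQLEnvClient`, `SQLAction(query=...)`, `SQLObservation`).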
|
| 17 |
+
|
| 18 |
+
## Quick Start
|
| 19 |
+
|
| 20 |
+
The simplest way to use the Sql Env environment is through the `SqlEnv` class:
|
| 21 |
+
|
| 22 |
+
```python
|
| 23 |
+
from sql_env import SqlAction, SqlEnv
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
# Create environment from Docker image
|
| 27 |
+
sql_envenv = SqlEnv.from_docker_image("sql_env-env:latest")
|
| 28 |
+
|
| 29 |
+
# Reset
|
| 30 |
+
result = sql_envenv.reset()
|
| 31 |
+
print(f"Reset: {result.observation.echoed_message}")
|
| 32 |
+
|
| 33 |
+
# Send multiple messages
|
| 34 |
+
messages = ["Hello, World!", "Testing echo", "Final message"]
|
| 35 |
+
|
| 36 |
+
for msg in messages:
|
| 37 |
+
result = sql_envenv.step(SqlAction(message=msg))
|
| 38 |
+
print(f"Sent: '{msg}'")
|
| 39 |
+
print(f" → Echoed: '{result.observation.echoed_message}'")
|
| 40 |
+
print(f" → Length: {result.observation.message_length}")
|
| 41 |
+
print(f" → Reward: {result.reward}")
|
| 42 |
+
|
| 43 |
+
finally:
|
| 44 |
+
# Always clean up
|
| 45 |
+
sql_envenv.close()
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
That's it! The `SqlEnv.from_docker_image()` method handles:
|
| 49 |
+
- Starting the Docker container
|
| 50 |
+
- Waiting for the server to be ready
|
| 51 |
+
- Connecting to the environment
|
| 52 |
+
- Container cleanup when you call `close()`
|
| 53 |
+
|
| 54 |
+
## Building the Docker Image
|
| 55 |
+
|
| 56 |
+
Before using the environment, you need to build the Docker image:
|
| 57 |
+
|
| 58 |
+
```bash
|
| 59 |
+
# From project root
|
| 60 |
+
docker build -t sql_env-env:latest -f server/Dockerfile .
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
## Deploying to Hugging Face Spaces
|
| 64 |
+
|
| 65 |
+
You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
# From the environment directory (where openenv.yaml is located)
|
| 69 |
+
openenv push
|
| 70 |
+
|
| 71 |
+
# Or specify options
|
| 72 |
+
openenv push --namespace my-org --private
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
The `openenv push` command will:
|
| 76 |
+
1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
|
| 77 |
+
2. Prepare a custom build for Hugging Face Docker space (enables web interface)
|
| 78 |
+
3. Upload to Hugging Face (ensuring you're logged in)
|
| 79 |
+
|
| 80 |
+
### Prerequisites
|
| 81 |
+
|
| 82 |
+
- Authenticate with Hugging Face: The command will prompt for login if not already authenticated
|
| 83 |
+
|
| 84 |
+
### Options
|
| 85 |
+
|
| 86 |
+
- `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
|
| 87 |
+
- `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
|
| 88 |
+
- `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
|
| 89 |
+
- `--private`: Deploy the space as private (default: public)
|
| 90 |
+
|
| 91 |
+
### Examples
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
# Push to your personal namespace (defaults to username/env-name from openenv.yaml)
|
| 95 |
+
openenv push
|
| 96 |
+
|
| 97 |
+
# Push to a specific repository
|
| 98 |
+
openenv push --repo-id my-org/my-env
|
| 99 |
+
|
| 100 |
+
# Push with a custom base image
|
| 101 |
+
openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
|
| 102 |
+
|
| 103 |
+
# Push as a private space
|
| 104 |
+
openenv push --private
|
| 105 |
+
|
| 106 |
+
# Combine options
|
| 107 |
+
openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
After deployment, your space will be available at:
|
| 111 |
+
`https://huggingface.co/spaces/<repo-id>`
|
| 112 |
+
|
| 113 |
+
The deployed space includes:
|
| 114 |
+
- **Web Interface** at `/web` - Interactive UI for exploring the environment
|
| 115 |
+
- **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
|
| 116 |
+
- **Health Check** at `/health` - Container health monitoring
|
| 117 |
+
- **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
|
| 118 |
+
|
| 119 |
+
## Environment Details
|
| 120 |
+
|
| 121 |
+
### Action
|
| 122 |
+
**SqlAction**: Contains a single field
|
| 123 |
+
- `message` (str) - The message to echo back
|
| 124 |
+
|
| 125 |
+
### Observation
|
| 126 |
+
**SqlObservation**: Contains the echo response and metadata
|
| 127 |
+
- `echoed_message` (str) - The message echoed back
|
| 128 |
+
- `message_length` (int) - Length of the message
|
| 129 |
+
- `reward` (float) - Reward based on message length (length × 0.1)
|
| 130 |
+
- `done` (bool) - Always False for echo environment
|
| 131 |
+
- `metadata` (dict) - Additional info like step count
|
| 132 |
+
|
| 133 |
+
### Reward
|
| 134 |
+
The reward is calculated as: `message_length × 0.1`
|
| 135 |
+
- "Hi" → reward: 0.2
|
| 136 |
+
- "Hello, World!" → reward: 1.3
|
| 137 |
+
- Empty message → reward: 0.0
|
| 138 |
+
|
| 139 |
+
## Advanced Usage
|
| 140 |
+
|
| 141 |
+
### Connecting to an Existing Server
|
| 142 |
+
|
| 143 |
+
If you already have a Sql Env environment server running, you can connect directly:
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
from sql_env import SqlEnv
|
| 147 |
+
|
| 148 |
+
# Connect to existing server
|
| 149 |
+
sql_envenv = SqlEnv(base_url="<ENV_HTTP_URL_HERE>")
|
| 150 |
+
|
| 151 |
+
# Use as normal
|
| 152 |
+
result = sql_envenv.reset()
|
| 153 |
+
result = sql_envenv.step(SqlAction(message="Hello!"))
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
Note: When connecting to an existing server, `sql_envenv.close()` will NOT stop the server.
|
| 157 |
+
|
| 158 |
+
### Using the Context Manager
|
| 159 |
+
|
| 160 |
+
The client supports context manager usage for automatic connection management:
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
from sql_env import SqlAction, SqlEnv
|
| 164 |
+
|
| 165 |
+
# Connect with context manager (auto-connects and closes)
|
| 166 |
+
with SqlEnv(base_url="http://localhost:8000") as env:
|
| 167 |
+
result = env.reset()
|
| 168 |
+
print(f"Reset: {result.observation.echoed_message}")
|
| 169 |
+
# Multiple steps with low latency
|
| 170 |
+
for msg in ["Hello", "World", "!"]:
|
| 171 |
+
result = env.step(SqlAction(message=msg))
|
| 172 |
+
print(f"Echoed: {result.observation.echoed_message}")
|
| 173 |
+
```
|
| 174 |
+
|
| 175 |
+
The client uses WebSocket connections for:
|
| 176 |
+
- **Lower latency**: No HTTP connection overhead per request
|
| 177 |
+
- **Persistent session**: Server maintains your environment state
|
| 178 |
+
- **Efficient for episodes**: Better for many sequential steps
|
| 179 |
+
|
| 180 |
+
### Concurrent WebSocket Sessions
|
| 181 |
+
|
| 182 |
+
The server supports multiple concurrent WebSocket connections. To enable this,
|
| 183 |
+
modify `server/app.py` to use factory mode:
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
# In server/app.py - use factory mode for concurrent sessions
|
| 187 |
+
app = create_app(
|
| 188 |
+
SqlEnvironment, # Pass class, not instance
|
| 189 |
+
SqlAction,
|
| 190 |
+
SqlObservation,
|
| 191 |
+
max_concurrent_envs=4, # Allow 4 concurrent sessions
|
| 192 |
+
)
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
Then multiple clients can connect simultaneously:
|
| 196 |
+
|
| 197 |
+
```python
|
| 198 |
+
from sql_env import SqlAction, SqlEnv
|
| 199 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 200 |
+
|
| 201 |
+
def run_episode(client_id: int):
|
| 202 |
+
with SqlEnv(base_url="http://localhost:8000") as env:
|
| 203 |
+
result = env.reset()
|
| 204 |
+
for i in range(10):
|
| 205 |
+
result = env.step(SqlAction(message=f"Client {client_id}, step {i}"))
|
| 206 |
+
return client_id, result.observation.message_length
|
| 207 |
+
|
| 208 |
+
# Run 4 episodes concurrently
|
| 209 |
+
with ThreadPoolExecutor(max_workers=4) as executor:
|
| 210 |
+
results = list(executor.map(run_episode, range(4)))
|
| 211 |
+
```
|
| 212 |
+
|
| 213 |
+
## Development & Testing
|
| 214 |
+
|
| 215 |
+
### Direct Environment Testing
|
| 216 |
+
|
| 217 |
+
Test the environment logic directly without starting the HTTP server:
|
| 218 |
+
|
| 219 |
+
```bash
|
| 220 |
+
# From the project root
|
| 221 |
+
python3 server/sql_env_environment.py
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
This verifies that:
|
| 225 |
+
- Environment resets correctly
|
| 226 |
+
- Step executes actions properly
|
| 227 |
+
- State tracking works
|
| 228 |
+
- Rewards are calculated correctly
|
| 229 |
+
|
| 230 |
+
### Running Locally
|
| 231 |
+
|
| 232 |
+
Run the server locally for development:
|
| 233 |
+
|
| 234 |
+
```bash
|
| 235 |
+
uvicorn server.app:app --reload
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
## Project Structure
|
| 239 |
+
|
| 240 |
+
```
|
| 241 |
+
sql_env/
|
| 242 |
+
├── .dockerignore # Docker build exclusions
|
| 243 |
+
├── __init__.py # Module exports
|
| 244 |
+
├── README.md # This file
|
| 245 |
+
├── openenv.yaml # OpenEnv manifest
|
| 246 |
+
├── pyproject.toml # Project metadata and dependencies
|
| 247 |
+
├── uv.lock # Locked dependencies (generated)
|
| 248 |
+
├── client.py # SqlEnv client
|
| 249 |
+
├── models.py # Action and Observation models
|
| 250 |
+
└── server/
|
| 251 |
+
├── __init__.py # Server module exports
|
| 252 |
+
├── sql_env_environment.py # Core environment logic
|
| 253 |
+
├── app.py # FastAPI application (HTTP + WebSocket endpoints)
|
| 254 |
+
└── Dockerfile # Container image definition
|
| 255 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQL Query Writing Environment."""
|
| 2 |
+
|
| 3 |
+
from .models import SQLAction, SQLObservation
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"SQLAction",
|
| 7 |
+
"SQLObservation",
|
| 8 |
+
]
|
client.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQL Query Writing Environment Client."""
|
| 2 |
+
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from openenv.core import EnvClient
|
| 6 |
+
from openenv.core.client_types import StepResult
|
| 7 |
+
from openenv.core.env_server.types import State
|
| 8 |
+
|
| 9 |
+
from .models import SQLAction, SQLObservation
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SQLEnvClient(EnvClient[SQLAction, SQLObservation, State]):
    """
    Client for the SQL Query Writing Environment.

    Converts ``SQLAction`` objects into request payloads and turns the
    dictionaries returned by the server back into ``SQLObservation`` and
    ``State`` objects.

    Example:
        >>> with SQLEnvClient(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.question)
        ...     result = client.step(SQLAction(query="SELECT * FROM customers"))
        ...     print(result.observation.query_result)
    """

    def _step_payload(self, action: SQLAction) -> Dict:
        """Return the request body for a step: the action's SQL query."""
        return dict(query=action.query)

    def _parse_result(self, payload: Dict) -> StepResult[SQLObservation]:
        """
        Build a ``StepResult`` from a server response dictionary.

        Observation fields are read from ``payload["observation"]``; any
        field that is absent falls back to an empty string (text fields),
        zero (counters) or an empty dict (metadata). ``reward`` and ``done``
        are read from the top level of ``payload`` and copied onto both the
        observation and the returned result.
        """
        raw_obs = payload.get("observation", {})
        reward = payload.get("reward")
        done = payload.get("done", False)

        text_fields = (
            "task_name",
            "question",
            "schema_description",
            "query_result",
            "error",
        )
        counter_fields = (
            "steps_remaining",
            "question_index",
            "total_questions",
        )

        obs_kwargs = {name: raw_obs.get(name, "") for name in text_fields}
        obs_kwargs.update((name, raw_obs.get(name, 0)) for name in counter_fields)

        observation = SQLObservation(
            **obs_kwargs,
            done=done,
            reward=reward,
            metadata=raw_obs.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        """Build a ``State`` from the ``episode_id`` and ``step_count`` keys."""
        episode_id = payload.get("episode_id")
        step_count = payload.get("step_count", 0)
        return State(episode_id=episode_id, step_count=step_count)
|
data/schema.sql
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- SQLEnv E-Commerce Database Schema
|
| 2 |
+
-- Designed for text-to-SQL agent training with 3 difficulty levels
|
| 3 |
+
|
| 4 |
+
CREATE TABLE IF NOT EXISTS customers (
|
| 5 |
+
id INTEGER PRIMARY KEY,
|
| 6 |
+
name TEXT NOT NULL,
|
| 7 |
+
email TEXT NOT NULL UNIQUE,
|
| 8 |
+
age INTEGER NOT NULL,
|
| 9 |
+
city TEXT NOT NULL,
|
| 10 |
+
signup_date TEXT NOT NULL -- ISO format: YYYY-MM-DD
|
| 11 |
+
);
|
| 12 |
+
|
| 13 |
+
CREATE TABLE IF NOT EXISTS products (
|
| 14 |
+
id INTEGER PRIMARY KEY,
|
| 15 |
+
name TEXT NOT NULL,
|
| 16 |
+
category TEXT NOT NULL, -- Electronics, Clothing, Books, Home
|
| 17 |
+
price REAL NOT NULL,
|
| 18 |
+
stock INTEGER NOT NULL DEFAULT 0
|
| 19 |
+
);
|
| 20 |
+
|
| 21 |
+
CREATE TABLE IF NOT EXISTS orders (
|
| 22 |
+
id INTEGER PRIMARY KEY,
|
| 23 |
+
customer_id INTEGER NOT NULL,
|
| 24 |
+
order_date TEXT NOT NULL, -- ISO format: YYYY-MM-DD
|
| 25 |
+
status TEXT NOT NULL, -- pending, shipped, delivered, cancelled
|
| 26 |
+
total_amount REAL NOT NULL,
|
| 27 |
+
FOREIGN KEY (customer_id) REFERENCES customers(id)
|
| 28 |
+
);
|
| 29 |
+
|
| 30 |
+
CREATE TABLE IF NOT EXISTS order_items (
|
| 31 |
+
id INTEGER PRIMARY KEY,
|
| 32 |
+
order_id INTEGER NOT NULL,
|
| 33 |
+
product_id INTEGER NOT NULL,
|
| 34 |
+
quantity INTEGER NOT NULL,
|
| 35 |
+
unit_price REAL NOT NULL,
|
| 36 |
+
FOREIGN KEY (order_id) REFERENCES orders(id),
|
| 37 |
+
FOREIGN KEY (product_id) REFERENCES products(id)
|
| 38 |
+
);
|
| 39 |
+
|
| 40 |
+
CREATE TABLE IF NOT EXISTS reviews (
|
| 41 |
+
id INTEGER PRIMARY KEY,
|
| 42 |
+
product_id INTEGER NOT NULL,
|
| 43 |
+
customer_id INTEGER NOT NULL,
|
| 44 |
+
rating INTEGER NOT NULL CHECK (rating >= 1 AND rating <= 5),
|
| 45 |
+
review_text TEXT NOT NULL,
|
| 46 |
+
review_date TEXT NOT NULL, -- ISO format: YYYY-MM-DD
|
| 47 |
+
FOREIGN KEY (product_id) REFERENCES products(id),
|
| 48 |
+
FOREIGN KEY (customer_id) REFERENCES customers(id)
|
| 49 |
+
);
|
data/seed.sql
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- SQLEnv Deterministic Seed Data
|
| 2 |
+
-- Carefully crafted to support easy/medium/hard SQL tasks
|
| 3 |
+
|
| 4 |
+
-- ============================================================
|
| 5 |
+
-- CUSTOMERS (20 rows)
|
| 6 |
+
-- ============================================================
|
| 7 |
+
INSERT INTO customers (id, name, email, age, city, signup_date) VALUES
|
| 8 |
+
(1, 'Aarav Sharma', 'aarav@example.com', 28, 'Mumbai', '2023-03-15'),
|
| 9 |
+
(2, 'Priya Patel', 'priya@example.com', 35, 'Delhi', '2023-05-20'),
|
| 10 |
+
(3, 'Rahul Kumar', 'rahul@example.com', 42, 'Bangalore', '2023-01-10'),
|
| 11 |
+
(4, 'Sneha Gupta', 'sneha@example.com', 24, 'Mumbai', '2024-02-14'),
|
| 12 |
+
(5, 'Vikram Singh', 'vikram@example.com', 31, 'Chennai', '2023-07-01'),
|
| 13 |
+
(6, 'Ananya Reddy', 'ananya@example.com', 29, 'Hyderabad', '2023-09-05'),
|
| 14 |
+
(7, 'Arjun Nair', 'arjun@example.com', 38, 'Bangalore', '2023-04-18'),
|
| 15 |
+
(8, 'Kavita Joshi', 'kavita@example.com', 45, 'Pune', '2023-02-28'),
|
| 16 |
+
(9, 'Deepak Verma', 'deepak@example.com', 33, 'Delhi', '2024-01-05'),
|
| 17 |
+
(10, 'Meera Iyer', 'meera@example.com', 27, 'Chennai', '2023-11-12'),
|
| 18 |
+
(11, 'Rohan Das', 'rohan@example.com', 36, 'Kolkata', '2023-06-22'),
|
| 19 |
+
(12, 'Pooja Mishra', 'pooja@example.com', 41, 'Pune', '2023-08-15'),
|
| 20 |
+
(13, 'Suresh Menon', 'suresh@example.com', 50, 'Mumbai', '2023-03-30'),
|
| 21 |
+
(14, 'Nisha Agarwal', 'nisha@example.com', 23, 'Delhi', '2024-03-01'),
|
| 22 |
+
(15, 'Amit Pandey', 'amit@example.com', 34, 'Bangalore', '2023-10-10'),
|
| 23 |
+
(16, 'Ritu Chopra', 'ritu@example.com', 30, 'Hyderabad', '2023-12-25'),
|
| 24 |
+
(17, 'Karan Malhotra', 'karan@example.com', 26, 'Mumbai', '2024-01-20'),
|
| 25 |
+
(18, 'Divya Saxena', 'divya@example.com', 39, 'Chennai', '2023-05-05'),
|
| 26 |
+
(19, 'Nikhil Bhat', 'nikhil@example.com', 32, 'Kolkata', '2023-07-14'),
|
| 27 |
+
(20, 'Swati Tiwari', 'swati@example.com', 44, 'Pune', '2023-04-02');
|
| 28 |
+
|
| 29 |
+
-- ============================================================
|
| 30 |
+
-- PRODUCTS (15 rows, 4 categories)
|
| 31 |
+
-- ============================================================
|
| 32 |
+
INSERT INTO products (id, name, category, price, stock) VALUES
|
| 33 |
+
(1, 'Wireless Headphones', 'Electronics', 2499.00, 50),
|
| 34 |
+
(2, 'Smartphone Case', 'Electronics', 499.00, 200),
|
| 35 |
+
(3, 'Bluetooth Speaker', 'Electronics', 3999.00, 30),
|
| 36 |
+
(4, 'USB-C Cable', 'Electronics', 199.00, 500),
|
| 37 |
+
(5, 'Cotton T-Shirt', 'Clothing', 599.00, 150),
|
| 38 |
+
(6, 'Denim Jeans', 'Clothing', 1499.00, 80),
|
| 39 |
+
(7, 'Running Shoes', 'Clothing', 2999.00, 40),
|
| 40 |
+
(8, 'Winter Jacket', 'Clothing', 3499.00, 25),
|
| 41 |
+
(9, 'Python Programming', 'Books', 449.00, 100),
|
| 42 |
+
(10, 'Data Science Handbook', 'Books', 699.00, 60),
|
| 43 |
+
(11, 'Mystery Novel', 'Books', 299.00, 120),
|
| 44 |
+
(12, 'Cooking Recipes', 'Books', 399.00, 90),
|
| 45 |
+
(13, 'Ceramic Mug Set', 'Home', 799.00, 70),
|
| 46 |
+
(14, 'Desk Lamp', 'Home', 1299.00, 45),
|
| 47 |
+
(15, 'Plant Pot', 'Home', 349.00, 110);
|
| 48 |
+
|
| 49 |
+
-- ============================================================
|
| 50 |
+
-- ORDERS (30 rows)
|
| 51 |
+
-- ============================================================
|
| 52 |
+
INSERT INTO orders (id, customer_id, order_date, status, total_amount) VALUES
|
| 53 |
+
(1, 1, '2024-01-15', 'delivered', 2998.00),
|
| 54 |
+
(2, 1, '2024-03-20', 'delivered', 499.00),
|
| 55 |
+
(3, 2, '2024-02-10', 'delivered', 4498.00),
|
| 56 |
+
(4, 3, '2024-01-05', 'delivered', 1798.00),
|
| 57 |
+
(5, 3, '2024-04-12', 'shipped', 3999.00),
|
| 58 |
+
(6, 2, '2024-03-01', 'delivered', 599.00),
|
| 59 |
+
(7, 5, '2024-02-28', 'delivered', 5997.00),
|
| 60 |
+
(8, 5, '2024-05-15', 'shipped', 1299.00),
|
| 61 |
+
(9, 6, '2024-01-20', 'delivered', 898.00),
|
| 62 |
+
(10, 7, '2024-03-10', 'delivered', 2499.00),
|
| 63 |
+
(11, 7, '2024-06-01', 'pending', 699.00),
|
| 64 |
+
(12, 8, '2024-02-14', 'delivered', 4998.00),
|
| 65 |
+
(13, 8, '2024-04-25', 'cancelled', 1499.00),
|
| 66 |
+
(14, 9, '2024-03-05', 'delivered', 848.00),
|
| 67 |
+
(15, 10, '2024-01-30', 'delivered', 2999.00),
|
| 68 |
+
(16, 10, '2024-05-20', 'shipped', 449.00),
|
| 69 |
+
(17, 11, '2024-02-18', 'delivered', 1598.00),
|
| 70 |
+
(18, 12, '2024-03-22', 'delivered', 3499.00),
|
| 71 |
+
(19, 13, '2024-04-08', 'shipped', 798.00),
|
| 72 |
+
(20, 13, '2024-01-12', 'delivered', 2499.00),
|
| 73 |
+
(21, 11, '2024-05-01', 'pending', 599.00),
|
| 74 |
+
(22, 15, '2024-02-05', 'delivered', 1748.00),
|
| 75 |
+
(23, 15, '2024-06-10', 'pending', 299.00),
|
| 76 |
+
(24, 16, '2024-03-15', 'delivered', 999.00),
|
| 77 |
+
(25, 17, '2024-04-20', 'shipped', 2499.00),
|
| 78 |
+
(26, 18, '2024-01-25', 'delivered', 4498.00),
|
| 79 |
+
(27, 18, '2024-05-30', 'delivered', 699.00),
|
| 80 |
+
(28, 19, '2024-02-22', 'delivered', 1299.00),
|
| 81 |
+
(29, 19, '2024-06-05', 'cancelled', 399.00),
|
| 82 |
+
(30, 20, '2024-03-28', 'delivered', 3798.00);
|
| 83 |
+
|
| 84 |
+
-- ============================================================
|
| 85 |
+
-- ORDER_ITEMS (46 rows)
|
| 86 |
+
-- ============================================================
|
| 87 |
+
INSERT INTO order_items (id, order_id, product_id, quantity, unit_price) VALUES
|
| 88 |
+
-- Order 1: customer 1 bought headphones + smartphone case
|
| 89 |
+
(1, 1, 1, 1, 2499.00),
|
| 90 |
+
(2, 1, 2, 1, 499.00),
|
| 91 |
+
-- Order 2: customer 1 bought smartphone case
|
| 92 |
+
(3, 2, 2, 1, 499.00),
|
| 93 |
+
-- Order 3: customer 2 bought headphones + jeans + smartphone case
|
| 94 |
+
(4, 3, 1, 1, 2499.00),
|
| 95 |
+
(5, 3, 6, 1, 1499.00),
|
| 96 |
+
(6, 3, 2, 1, 499.00),
|
| 97 |
+
-- Order 4: customer 3 bought desk lamp + smartphone case
|
| 98 |
+
(7, 4, 14, 1, 1299.00),
|
| 99 |
+
(8, 4, 2, 1, 499.00),
|
| 100 |
+
-- Order 5: customer 3 bought bluetooth speaker
|
| 101 |
+
(9, 5, 3, 1, 3999.00),
|
| 102 |
+
-- Order 6: customer 2 bought t-shirt
|
| 103 |
+
(10, 6, 5, 1, 599.00),
|
| 104 |
+
-- Order 7: customer 5 bought running shoes x2
|
| 105 |
+
(11, 7, 7, 2, 2999.00),
|
| 106 |
+
-- Order 8: customer 5 bought desk lamp
|
| 107 |
+
(12, 8, 14, 1, 1299.00),
|
| 108 |
+
-- Order 9: customer 6 bought mug set + usb cable
|
| 109 |
+
(13, 9, 13, 1, 799.00),
|
| 110 |
+
(14, 9, 4, 1, 199.00), -- these two items sum to 998.00, but orders.total_amount for order 9 is deliberately set to 898.00 so the two figures differ
|
| 111 |
+
-- Order 10: customer 7 bought headphones
|
| 112 |
+
(15, 10, 1, 1, 2499.00),
|
| 113 |
+
-- Order 11: customer 7 bought data science book
|
| 114 |
+
(16, 11, 10, 1, 699.00),
|
| 115 |
+
-- Order 12: customer 8 bought headphones x2
|
| 116 |
+
(17, 12, 1, 2, 2499.00),
|
| 117 |
+
-- Order 13: customer 8 bought jeans (cancelled)
|
| 118 |
+
(18, 13, 6, 1, 1499.00),
|
| 119 |
+
-- Order 14: customer 9 bought python book + plant pot
|
| 120 |
+
(19, 14, 9, 1, 449.00),
|
| 121 |
+
(20, 14, 15, 1, 349.00),
|
| 122 |
+
-- Order 15: customer 10 bought running shoes
|
| 123 |
+
(21, 15, 7, 1, 2999.00),
|
| 124 |
+
-- Order 16: customer 10 bought python book
|
| 125 |
+
(22, 16, 9, 1, 449.00),
|
| 126 |
+
-- Order 17: customer 11 bought mug set x2
|
| 127 |
+
(23, 17, 13, 2, 799.00),
|
| 128 |
+
-- Order 18: customer 12 bought winter jacket
|
| 129 |
+
(24, 18, 8, 1, 3499.00),
|
| 130 |
+
-- Order 19: customer 13 bought mug set (shipped)
|
| 131 |
+
(25, 19, 13, 1, 799.00),
|
| 132 |
+
-- Order 20: customer 13 bought headphones
|
| 133 |
+
(26, 20, 1, 1, 2499.00),
|
| 134 |
+
-- Order 21: customer 11 bought t-shirt (pending)
|
| 135 |
+
(27, 21, 5, 1, 599.00),
|
| 136 |
+
-- Order 22: customer 15 bought python book + desk lamp
|
| 137 |
+
(28, 22, 9, 1, 449.00),
|
| 138 |
+
(29, 22, 14, 1, 1299.00),
|
| 139 |
+
-- Order 23: customer 15 bought mystery novel (pending)
|
| 140 |
+
(30, 23, 11, 1, 299.00),
|
| 141 |
+
-- Order 24: customer 16 bought t-shirt + cooking book
|
| 142 |
+
(31, 24, 5, 1, 599.00),
|
| 143 |
+
(32, 24, 12, 1, 399.00),
|
| 144 |
+
-- Order 25: customer 17 bought headphones (shipped)
|
| 145 |
+
(33, 25, 1, 1, 2499.00),
|
| 146 |
+
-- Order 26: customer 18 bought bluetooth speaker + smartphone case
|
| 147 |
+
(34, 26, 3, 1, 3999.00),
|
| 148 |
+
(35, 26, 2, 1, 499.00),
|
| 149 |
+
-- Order 27: customer 18 bought data science book
|
| 150 |
+
(36, 27, 10, 1, 699.00),
|
| 151 |
+
-- Order 28: customer 19 bought desk lamp
|
| 152 |
+
(37, 28, 14, 1, 1299.00),
|
| 153 |
+
-- Order 29: customer 19 bought cooking book (cancelled)
|
| 154 |
+
(38, 29, 12, 1, 399.00),
|
| 155 |
+
-- Order 30: customer 20 bought running shoes + mug set
|
| 156 |
+
(39, 30, 7, 1, 2999.00),
|
| 157 |
+
(40, 30, 13, 1, 799.00),
|
| 158 |
+
-- Extra items to diversify category coverage
|
| 159 |
+
-- Order 3 (customer 2): add a book → Electronics + Clothing + Books
|
| 160 |
+
(41, 3, 9, 1, 449.00),
|
| 161 |
+
-- Order 4 (customer 3): add a t-shirt → Home + Electronics + Clothing
|
| 162 |
+
(42, 4, 5, 1, 599.00),
|
| 163 |
+
-- Order 4 (customer 3): add a book → Home + Electronics + Clothing + Books = 4 categories
|
| 164 |
+
(43, 4, 11, 1, 299.00),
|
| 165 |
+
-- Order 24 (customer 16): add a mug set → Clothing + Books + Home
|
| 166 |
+
(44, 24, 13, 1, 799.00),
|
| 167 |
+
-- Order 9 (customer 6): add a book → Home + Electronics + Books
|
| 168 |
+
(45, 9, 12, 1, 399.00),
|
| 169 |
+
-- Order 17 (customer 11): add a USB-C cable → customer 11 now spans Home + Clothing + Electronics (Clothing comes from order 21)
|
| 170 |
+
(46, 17, 4, 1, 199.00);
|
| 171 |
+
|
| 172 |
+
-- ============================================================
|
| 173 |
+
-- REVIEWS (25 rows)
|
| 174 |
+
-- ============================================================
|
| 175 |
+
INSERT INTO reviews (id, product_id, customer_id, rating, review_text, review_date) VALUES
|
| 176 |
+
(1, 1, 1, 5, 'Amazing sound quality, worth every penny!', '2024-02-01'),
|
| 177 |
+
(2, 1, 7, 4, 'Good headphones, battery could be better.', '2024-03-25'),
|
| 178 |
+
(3, 1, 13, 5, 'Best wireless headphones I have owned.', '2024-02-15'),
|
| 179 |
+
(4, 2, 1, 3, 'Decent case but feels a bit cheap.', '2024-04-01'),
|
| 180 |
+
(5, 2, 3, 4, 'Good fit, protects the phone well.', '2024-02-10'),
|
| 181 |
+
(6, 3, 18, 5, 'Incredible bass for the price!', '2024-02-20'),
|
| 182 |
+
(7, 3, 5, 4, 'Great speaker, slightly heavy though.', '2024-06-01'),
|
| 183 |
+
(8, 5, 4, 4, 'Soft cotton, very comfortable.', '2024-03-15'),
|
| 184 |
+
(9, 5, 16, 3, 'Okay quality, shrunk after washing.', '2024-04-05'),
|
| 185 |
+
(10, 6, 2, 5, 'Perfect fit, love the style!', '2024-03-10'),
|
| 186 |
+
(11, 7, 5, 5, 'Super comfortable for running.', '2024-03-20'),
|
| 187 |
+
(12, 7, 10, 4, 'Good shoes but sizing runs a bit large.', '2024-02-15'),
|
| 188 |
+
(13, 7, 20, 5, 'Best running shoes ever!', '2024-04-10'),
|
| 189 |
+
(14, 8, 12, 4, 'Warm and stylish, great for winter.', '2024-04-01'),
|
| 190 |
+
(15, 9, 9, 5, 'Excellent book for learning Python.', '2024-04-10'),
|
| 191 |
+
(16, 9, 15, 4, 'Good content, some chapters feel rushed.', '2024-03-01'),
|
| 192 |
+
(17, 10, 7, 3, 'Covers basics well, lacks depth on ML topics.', '2024-06-15'),
|
| 193 |
+
(18, 11, 15, 4, 'Gripping plot, could not put it down!', '2024-06-20'),
|
| 194 |
+
(19, 13, 6, 5, 'Beautiful mugs, great as a gift set.', '2024-02-10'),
|
| 195 |
+
(20, 13, 11, 4, 'Nice mugs, one had a small chip.', '2024-03-05'),
|
| 196 |
+
(21, 13, 20, 5, 'Excellent quality ceramic.', '2024-04-15'),
|
| 197 |
+
(22, 14, 3, 4, 'Bright and adjustable, good for desk work.', '2024-02-20'),
|
| 198 |
+
(23, 14, 19, 5, 'Perfect desk lamp, love the design.', '2024-03-10'),
|
| 199 |
+
(24, 15, 9, 3, 'Looks nice but drainage hole is too small.', '2024-04-15'),
|
| 200 |
+
(25, 12, 16, 4, 'Great recipes, easy to follow instructions.', '2024-04-20');
|
data/tasks/advanced_analytics.json
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"task_name": "advanced_analytics",
|
| 3 |
+
"difficulty": "hard",
|
| 4 |
+
"description": "Subqueries, CTEs, window functions, and complex multi-table analytics",
|
| 5 |
+
"max_steps_per_question": 5,
|
| 6 |
+
"questions": [
|
| 7 |
+
{
|
| 8 |
+
"id": "hard_1",
|
| 9 |
+
"question": "Find all customers whose total spending across all orders exceeds the average total spending per customer. Show customer name and total spent, sorted by total spent from highest to lowest.",
|
| 10 |
+
"ground_truth_sql": "SELECT c.name, SUM(o.total_amount) as total_spent FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.id HAVING total_spent > (SELECT AVG(total_spent) FROM (SELECT SUM(total_amount) as total_spent FROM orders GROUP BY customer_id)) ORDER BY total_spent DESC",
|
| 11 |
+
"expected_columns": ["name", "total_spent"],
|
| 12 |
+
"expected_row_count": 9,
|
| 13 |
+
"expected_rows": [
|
| 14 |
+
["Vikram Singh", 7296.0],
|
| 15 |
+
["Kavita Joshi", 6497.0],
|
| 16 |
+
["Rahul Kumar", 5797.0],
|
| 17 |
+
["Divya Saxena", 5197.0],
|
| 18 |
+
["Priya Patel", 5097.0],
|
| 19 |
+
["Swati Tiwari", 3798.0],
|
| 20 |
+
["Pooja Mishra", 3499.0],
|
| 21 |
+
["Aarav Sharma", 3497.0],
|
| 22 |
+
["Meera Iyer", 3448.0]
|
| 23 |
+
],
|
| 24 |
+
"order_matters": true
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"id": "hard_2",
|
| 28 |
+
"question": "Rank all products by their total revenue (quantity * unit_price from order_items) within each product category. Show category, product name, revenue, and the rank within the category. Sort by category alphabetically, then by rank.",
|
| 29 |
+
"ground_truth_sql": "SELECT p.category, p.name, SUM(oi.quantity * oi.unit_price) as revenue, RANK() OVER (PARTITION BY p.category ORDER BY SUM(oi.quantity * oi.unit_price) DESC) as category_rank FROM products p JOIN order_items oi ON p.id = oi.product_id GROUP BY p.id ORDER BY p.category, category_rank",
|
| 30 |
+
"expected_columns": ["category", "name", "revenue", "category_rank"],
|
| 31 |
+
"expected_row_count": 15,
|
| 32 |
+
"expected_rows": [
|
| 33 |
+
["Books", "Python Programming", 1796.0, 1],
|
| 34 |
+
["Books", "Data Science Handbook", 1398.0, 2],
|
| 35 |
+
["Books", "Cooking Recipes", 1197.0, 3],
|
| 36 |
+
["Books", "Mystery Novel", 598.0, 4],
|
| 37 |
+
["Clothing", "Running Shoes", 11996.0, 1],
|
| 38 |
+
["Clothing", "Winter Jacket", 3499.0, 2],
|
| 39 |
+
["Clothing", "Denim Jeans", 2998.0, 3],
|
| 40 |
+
["Clothing", "Cotton T-Shirt", 2396.0, 4],
|
| 41 |
+
["Electronics", "Wireless Headphones", 17493.0, 1],
|
| 42 |
+
["Electronics", "Bluetooth Speaker", 7998.0, 2],
|
| 43 |
+
["Electronics", "Smartphone Case", 2495.0, 3],
|
| 44 |
+
["Electronics", "USB-C Cable", 398.0, 4],
|
| 45 |
+
["Home", "Desk Lamp", 5196.0, 1],
|
| 46 |
+
["Home", "Ceramic Mug Set", 4794.0, 2],
|
| 47 |
+
["Home", "Plant Pot", 349.0, 3]
|
| 48 |
+
],
|
| 49 |
+
"order_matters": true
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"id": "hard_3",
|
| 53 |
+
"question": "Calculate the month-over-month growth in order count for 2024. Show the month (as YYYY-MM), the number of orders that month, and the change from the previous month (NULL for the first month). Sort by month.",
|
| 54 |
+
"ground_truth_sql": "SELECT strftime('%Y-%m', order_date) as month, COUNT(*) as order_count, COUNT(*) - LAG(COUNT(*)) OVER (ORDER BY strftime('%Y-%m', order_date)) as growth FROM orders GROUP BY month ORDER BY month",
|
| 55 |
+
"expected_columns": ["month", "order_count", "growth"],
|
| 56 |
+
"expected_row_count": 6,
|
| 57 |
+
"expected_rows": [
|
| 58 |
+
["2024-01", 6, null],
|
| 59 |
+
["2024-02", 6, 0],
|
| 60 |
+
["2024-03", 7, 1],
|
| 61 |
+
["2024-04", 4, -3],
|
| 62 |
+
["2024-05", 4, 0],
|
| 63 |
+
["2024-06", 3, -1]
|
| 64 |
+
],
|
| 65 |
+
"order_matters": true
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"id": "hard_4",
|
| 69 |
+
"question": "Find all customers who have purchased products from at least 3 different product categories. Show the customer name and the number of distinct categories they bought from, sorted by category count descending then name ascending.",
|
| 70 |
+
"ground_truth_sql": "SELECT c.name, COUNT(DISTINCT p.category) as category_count FROM customers c JOIN orders o ON c.id = o.customer_id JOIN order_items oi ON o.id = oi.order_id JOIN products p ON oi.product_id = p.id GROUP BY c.id HAVING category_count >= 3 ORDER BY category_count DESC, c.name ASC",
|
| 71 |
+
"expected_columns": ["name", "category_count"],
|
| 72 |
+
"expected_row_count": 5,
|
| 73 |
+
"expected_rows": [
|
| 74 |
+
["Rahul Kumar", 4],
|
| 75 |
+
["Ananya Reddy", 3],
|
| 76 |
+
["Priya Patel", 3],
|
| 77 |
+
["Ritu Chopra", 3],
|
| 78 |
+
["Rohan Das", 3]
|
| 79 |
+
],
|
| 80 |
+
"order_matters": true
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"id": "hard_5",
|
| 84 |
+
"question": "For every product that has at least 2 reviews, show its category, product name, and average review rating (rounded to 2 decimal places). Sort by category alphabetically, then by average rating descending.",
|
| 85 |
+
"ground_truth_sql": "SELECT p.category, p.name, ROUND(AVG(r.rating), 2) as avg_rating FROM products p JOIN reviews r ON p.id = r.product_id GROUP BY p.id HAVING COUNT(r.id) >= 2 ORDER BY p.category, avg_rating DESC",
|
| 86 |
+
"expected_columns": ["category", "name", "avg_rating"],
|
| 87 |
+
"expected_row_count": 8,
|
| 88 |
+
"expected_rows": [
|
| 89 |
+
["Books", "Python Programming", 4.5],
|
| 90 |
+
["Clothing", "Running Shoes", 4.67],
|
| 91 |
+
["Clothing", "Cotton T-Shirt", 3.5],
|
| 92 |
+
["Electronics", "Wireless Headphones", 4.67],
|
| 93 |
+
["Electronics", "Bluetooth Speaker", 4.5],
|
| 94 |
+
["Electronics", "Smartphone Case", 3.5],
|
| 95 |
+
["Home", "Ceramic Mug Set", 4.67],
|
| 96 |
+
["Home", "Desk Lamp", 4.5]
|
| 97 |
+
],
|
| 98 |
+
"order_matters": true
|
| 99 |
+
}
|
| 100 |
+
]
|
| 101 |
+
}
|
data/tasks/basic_select.json
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"task_name": "basic_select",
|
| 3 |
+
"difficulty": "easy",
|
| 4 |
+
"description": "Simple SELECT queries with WHERE, ORDER BY, LIMIT",
|
| 5 |
+
"max_steps_per_question": 3,
|
| 6 |
+
"questions": [
|
| 7 |
+
{
|
| 8 |
+
"id": "easy_1",
|
| 9 |
+
"question": "Find the names and ages of all customers older than 30, sorted by age from highest to lowest.",
|
| 10 |
+
"ground_truth_sql": "SELECT name, age FROM customers WHERE age > 30 ORDER BY age DESC",
|
| 11 |
+
"expected_columns": ["name", "age"],
|
| 12 |
+
"expected_row_count": 13,
|
| 13 |
+
"expected_rows": [
|
| 14 |
+
["Suresh Menon", 50],
|
| 15 |
+
["Kavita Joshi", 45],
|
| 16 |
+
["Swati Tiwari", 44],
|
| 17 |
+
["Rahul Kumar", 42],
|
| 18 |
+
["Pooja Mishra", 41],
|
| 19 |
+
["Divya Saxena", 39],
|
| 20 |
+
["Arjun Nair", 38],
|
| 21 |
+
["Rohan Das", 36],
|
| 22 |
+
["Priya Patel", 35],
|
| 23 |
+
["Amit Pandey", 34],
|
| 24 |
+
["Deepak Verma", 33],
|
| 25 |
+
["Nikhil Bhat", 32],
|
| 26 |
+
["Vikram Singh", 31]
|
| 27 |
+
],
|
| 28 |
+
"order_matters": true
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"id": "easy_2",
|
| 32 |
+
"question": "List all products in the 'Electronics' category, showing name and price, sorted by price from highest to lowest.",
|
| 33 |
+
"ground_truth_sql": "SELECT name, price FROM products WHERE category = 'Electronics' ORDER BY price DESC",
|
| 34 |
+
"expected_columns": ["name", "price"],
|
| 35 |
+
"expected_row_count": 4,
|
| 36 |
+
"expected_rows": [
|
| 37 |
+
["Bluetooth Speaker", 3999.0],
|
| 38 |
+
["Wireless Headphones", 2499.0],
|
| 39 |
+
["Smartphone Case", 499.0],
|
| 40 |
+
["USB-C Cable", 199.0]
|
| 41 |
+
],
|
| 42 |
+
"order_matters": true
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"id": "easy_3",
|
| 46 |
+
"question": "How many orders have the status 'shipped'?",
|
| 47 |
+
"ground_truth_sql": "SELECT COUNT(*) as shipped_count FROM orders WHERE status = 'shipped'",
|
| 48 |
+
"expected_columns": ["shipped_count"],
|
| 49 |
+
"expected_row_count": 1,
|
| 50 |
+
"expected_rows": [[5]],
|
| 51 |
+
"order_matters": false
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"id": "easy_4",
|
| 55 |
+
"question": "What is the most expensive product? Show its name and price.",
|
| 56 |
+
"ground_truth_sql": "SELECT name, price FROM products ORDER BY price DESC LIMIT 1",
|
| 57 |
+
"expected_columns": ["name", "price"],
|
| 58 |
+
"expected_row_count": 1,
|
| 59 |
+
"expected_rows": [["Bluetooth Speaker", 3999.0]],
|
| 60 |
+
"order_matters": false
|
| 61 |
+
},
|
| 62 |
+
{
|
| 63 |
+
"id": "easy_5",
|
| 64 |
+
"question": "Find all customers from Mumbai who signed up after January 1, 2024. Show their name and signup date, sorted by signup date.",
|
| 65 |
+
"ground_truth_sql": "SELECT name, signup_date FROM customers WHERE city = 'Mumbai' AND signup_date > '2024-01-01' ORDER BY signup_date",
|
| 66 |
+
"expected_columns": ["name", "signup_date"],
|
| 67 |
+
"expected_row_count": 2,
|
| 68 |
+
"expected_rows": [
|
| 69 |
+
["Karan Malhotra", "2024-01-20"],
|
| 70 |
+
["Sneha Gupta", "2024-02-14"]
|
| 71 |
+
],
|
| 72 |
+
"order_matters": true
|
| 73 |
+
}
|
| 74 |
+
]
|
| 75 |
+
}
|
data/tasks/join_aggregate.json
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"task_name": "join_aggregate",
|
| 3 |
+
"difficulty": "medium",
|
| 4 |
+
"description": "JOIN queries with GROUP BY, HAVING, and aggregate functions",
|
| 5 |
+
"max_steps_per_question": 4,
|
| 6 |
+
"questions": [
|
| 7 |
+
{
|
| 8 |
+
"id": "med_1",
|
| 9 |
+
"question": "What is the average order total for each customer? Show customer name and average total (rounded to 2 decimal places), sorted by average total from highest to lowest.",
|
| 10 |
+
"ground_truth_sql": "SELECT c.name, ROUND(AVG(o.total_amount), 2) as avg_total FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.id ORDER BY avg_total DESC",
|
| 11 |
+
"expected_columns": ["name", "avg_total"],
|
| 12 |
+
"expected_row_count": 18,
|
| 13 |
+
"expected_rows": [
|
| 14 |
+
["Swati Tiwari", 3798.0],
|
| 15 |
+
["Vikram Singh", 3648.0],
|
| 16 |
+
["Pooja Mishra", 3499.0],
|
| 17 |
+
["Kavita Joshi", 3248.5],
|
| 18 |
+
["Rahul Kumar", 2898.5],
|
| 19 |
+
["Divya Saxena", 2598.5],
|
| 20 |
+
["Priya Patel", 2548.5],
|
| 21 |
+
["Karan Malhotra", 2499.0],
|
| 22 |
+
["Aarav Sharma", 1748.5],
|
| 23 |
+
["Meera Iyer", 1724.0],
|
| 24 |
+
["Suresh Menon", 1648.5],
|
| 25 |
+
["Arjun Nair", 1599.0],
|
| 26 |
+
["Rohan Das", 1098.5],
|
| 27 |
+
["Amit Pandey", 1023.5],
|
| 28 |
+
["Ritu Chopra", 999.0],
|
| 29 |
+
["Ananya Reddy", 898.0],
|
| 30 |
+
["Nikhil Bhat", 849.0],
|
| 31 |
+
["Deepak Verma", 848.0]
|
| 32 |
+
],
|
| 33 |
+
"order_matters": true
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"id": "med_2",
|
| 37 |
+
"question": "Which products have been ordered more than 2 times in total quantity? Show the product name and total quantity ordered, sorted by total quantity from highest to lowest.",
|
| 38 |
+
"ground_truth_sql": "SELECT p.name, SUM(oi.quantity) as total_qty FROM products p JOIN order_items oi ON p.id = oi.product_id GROUP BY p.id HAVING total_qty > 2 ORDER BY total_qty DESC",
|
| 39 |
+
"expected_columns": ["name", "total_qty"],
|
| 40 |
+
"expected_row_count": 8,
|
| 41 |
+
"expected_rows": [
|
| 42 |
+
["Wireless Headphones", 7],
|
| 43 |
+
["Ceramic Mug Set", 6],
|
| 44 |
+
["Smartphone Case", 5],
|
| 45 |
+
["Desk Lamp", 4],
|
| 46 |
+
["Python Programming", 4],
|
| 47 |
+
["Running Shoes", 4],
|
| 48 |
+
["Cotton T-Shirt", 4],
|
| 49 |
+
["Cooking Recipes", 3]
|
| 50 |
+
],
|
| 51 |
+
"order_matters": true
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"id": "med_3",
|
| 55 |
+
"question": "List all customers who have never placed an order. Show their name and email, sorted by name.",
|
| 56 |
+
"ground_truth_sql": "SELECT name, email FROM customers WHERE id NOT IN (SELECT DISTINCT customer_id FROM orders) ORDER BY name",
|
| 57 |
+
"expected_columns": ["name", "email"],
|
| 58 |
+
"expected_row_count": 2,
|
| 59 |
+
"expected_rows": [
|
| 60 |
+
["Nisha Agarwal", "nisha@example.com"],
|
| 61 |
+
["Sneha Gupta", "sneha@example.com"]
|
| 62 |
+
],
|
| 63 |
+
"order_matters": true
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"id": "med_4",
|
| 67 |
+
"question": "What is the total revenue per product category? Calculate revenue as quantity times unit_price from order_items. Show category and revenue (rounded to 2 decimal places), sorted by revenue from highest to lowest.",
|
| 68 |
+
"ground_truth_sql": "SELECT p.category, ROUND(SUM(oi.quantity * oi.unit_price), 2) as revenue FROM products p JOIN order_items oi ON p.id = oi.product_id GROUP BY p.category ORDER BY revenue DESC",
|
| 69 |
+
"expected_columns": ["category", "revenue"],
|
| 70 |
+
"expected_row_count": 4,
|
| 71 |
+
"expected_rows": [
|
| 72 |
+
["Electronics", 28384.0],
|
| 73 |
+
["Clothing", 20889.0],
|
| 74 |
+
["Home", 10339.0],
|
| 75 |
+
["Books", 4989.0]
|
| 76 |
+
],
|
| 77 |
+
"order_matters": true
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"id": "med_5",
|
| 81 |
+
"question": "Who are the top 3 customers by total spending? Show customer name and total amount spent across all orders, sorted by total spent from highest to lowest.",
|
| 82 |
+
"ground_truth_sql": "SELECT c.name, SUM(o.total_amount) as total_spent FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.id ORDER BY total_spent DESC LIMIT 3",
|
| 83 |
+
"expected_columns": ["name", "total_spent"],
|
| 84 |
+
"expected_row_count": 3,
|
| 85 |
+
"expected_rows": [
|
| 86 |
+
["Vikram Singh", 7296.0],
|
| 87 |
+
["Kavita Joshi", 6497.0],
|
| 88 |
+
["Rahul Kumar", 5797.0]
|
| 89 |
+
],
|
| 90 |
+
"order_matters": true
|
| 91 |
+
}
|
| 92 |
+
]
|
| 93 |
+
}
|
inference.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference Script for SQL Query Writing Environment
|
| 3 |
+
===================================================
|
| 4 |
+
MANDATORY — this file must be named `inference.py` and placed in the project root.
|
| 5 |
+
|
| 6 |
+
Uses OpenAI Client for all LLM calls. Reads credentials from environment variables:
|
| 7 |
+
API_BASE_URL The API endpoint for the LLM.
|
| 8 |
+
MODEL_NAME The model identifier to use for inference.
|
| 9 |
+
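HF_TOKEN Your Hugging Face / API key (API_KEY is accepted as a fallback; if neither is set, the token cached by huggingface_hub is tried).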
|
| 10 |
+
|
| 11 |
+
STDOUT FORMAT:
|
| 12 |
+
[START] task=<task_name> env=<benchmark> model=<model_name>
|
| 13 |
+
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 14 |
+
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...,rn>
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
import textwrap
|
| 20 |
+
from typing import List, Optional
|
| 21 |
+
|
| 22 |
+
from openai import OpenAI
|
| 23 |
+
|
| 24 |
+
# ---------------------------------------------------------------------------
|
| 25 |
+
# Environment — runs locally (no Docker needed for inference)
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 28 |
+
os.environ.setdefault("SQL_ENV_TASK", "basic_select")
|
| 29 |
+
|
| 30 |
+
from server.sql_env_environment import SQLEnvironment
|
| 31 |
+
from models import SQLAction
|
| 32 |
+
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
# Configuration
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# Credential lookup order: HF_TOKEN, then API_KEY, then whatever token
# huggingface_hub has cached locally (best effort -- the package is optional).
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
if not API_KEY:
    try:
        from huggingface_hub import get_token
    except Exception:
        pass
    else:
        try:
            API_KEY = get_token()
        except Exception:
            pass

# Endpoint and model fall back to defaults when the variable is unset or empty.
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.environ.get("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"

# Benchmark identity and the tasks run by this script, in order.
BENCHMARK = "sql_env"
TASKS = [
    "basic_select",
    "join_aggregate",
    "advanced_analytics",
]

# Per-task step budget and LLM sampling settings.
MAX_STEPS = 8
TEMPERATURE = 0.3
MAX_TOKENS = 512

# A task is reported as successful once its normalized score reaches this value.
SUCCESS_SCORE_THRESHOLD = 0.1
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ---------------------------------------------------------------------------
|
| 55 |
+
# Logging helpers (MANDATORY stdout format)
|
| 56 |
+
# ---------------------------------------------------------------------------
|
| 57 |
+
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens a task run.

    Args:
        task: Name of the task being run.
        env: Benchmark / environment identifier.
        model: Identifier of the model used for inference.
    """
    record = f"[START] task={task} env={env} model={model}"
    print(record, flush=True)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record on stdout as a single line.

    Both free-text fields (the action and the error) are flattened to one line
    so that a multi-line SQL query or error message cannot split the record.

    Args:
        step: 1-based step number within the task.
        action: The SQL text that was submitted; truncated to 200 characters
            (plus ``...``) for readability.
        reward: Reward returned for this step, printed with two decimals.
        done: Whether the episode finished after this step.
        error: Error message for the step, or ``None``/empty when there was
            none; printed as ``null`` in that case.
    """

    def _one_line(text: str) -> str:
        # Newlines become spaces, carriage returns are dropped.
        return text.replace("\n", " ").replace("\r", "").strip()

    action_clean = _one_line(action)
    if len(action_clean) > 200:
        action_clean = action_clean[:200] + "..."

    # The error gets the same flattening as the action; an error that is
    # empty after flattening is reported as "null" like a missing one.
    error_val = (_one_line(error) if error else "") or "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action_clean} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes a task run.

    Args:
        success: Whether the task counted as successful.
        steps: Number of steps that were taken.
        score: Final normalized score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    formatted_rewards = [f"{value:.2f}" for value in rewards]
    rewards_csv = ",".join(formatted_rewards)
    success_val = str(success).lower()
    record = f"[END] success={success_val} steps={steps} score={score:.3f} rewards={rewards_csv}"
    print(record, flush=True)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# ---------------------------------------------------------------------------
|
| 83 |
+
# System prompt for the LLM
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
SYSTEM_PROMPT = textwrap.dedent("""
|
| 86 |
+
You are an expert SQL query writer. You are given a database schema and a
|
| 87 |
+
natural language question. Write a single SQL SELECT query that answers
|
| 88 |
+
the question exactly.
|
| 89 |
+
|
| 90 |
+
Rules:
|
| 91 |
+
- Write ONLY the SQL query, nothing else. No explanations, no markdown.
|
| 92 |
+
- Use only SELECT statements (no INSERT, UPDATE, DELETE, etc.)
|
| 93 |
+
- Match the requested column names and sorting exactly.
|
| 94 |
+
- Use standard SQL compatible with SQLite.
|
| 95 |
+
- If the question asks for rounding, use ROUND().
|
| 96 |
+
- If the question asks for sorting, include ORDER BY.
|
| 97 |
+
- Pay attention to whether results should be sorted ascending or descending.
|
| 98 |
+
""").strip()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def build_user_prompt(
    question: str,
    schema: str,
    last_result: str,
    last_error: str,
    feedback: str,
    attempt: int,
) -> str:
    """Assemble the user message sent to the LLM.

    The schema and question are always included. From the second attempt on,
    the previous result is added, followed by the error and feedback when they
    are non-empty, and a line stating the attempt number.

    Args:
        question: Natural-language question to answer.
        schema: Textual description of the database schema.
        last_result: Result text of the previous attempt.
        last_error: Error text of the previous attempt ("" if none).
        feedback: Feedback on the previous attempt ("" if none).
        attempt: 1-based attempt number for the current question.

    Returns:
        The prompt sections joined with newlines.
    """
    sections = [f"DATABASE SCHEMA:\n{schema}\n", f"QUESTION: {question}\n"]

    is_retry = attempt > 1
    if is_retry:
        retry_context = [f"PREVIOUS ATTEMPT RESULT:\n{last_result}\n"]
        if last_error:
            retry_context.append(f"ERROR: {last_error}\n")
        if feedback:
            retry_context.append(f"FEEDBACK: {feedback}\n")
        retry_context.append(
            f"This is attempt {attempt}. Fix the query based on the feedback above.\n"
        )
        sections.extend(retry_context)

    sections.append("Write the SQL query:")
    return "\n".join(sections)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def get_sql_from_model(
    client: OpenAI,
    question: str,
    schema: str,
    last_result: str,
    last_error: str,
    feedback: str,
    attempt: int,
) -> str:
    """Ask the LLM for a SQL query answering ``question``.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        question: Natural-language question to answer.
        schema: Textual description of the database schema.
        last_result: Result text of the previous attempt.
        last_error: Error text of the previous attempt.
        feedback: Feedback on the previous attempt.
        attempt: 1-based attempt number for the current question.

    Returns:
        The model's reply with Markdown fence lines removed, or the
        placeholder ``SELECT 1`` when the reply is empty or the request fails.
    """
    fallback_sql = "SELECT 1"
    prompt = build_user_prompt(
        question, schema, last_result, last_error, feedback, attempt
    )
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]

    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=chat,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        reply = (response.choices[0].message.content or "").strip()

        # A reply that opens with a code fence has every fence line dropped
        # (the opening ```sql, the closing ```, and any in between).
        if reply.startswith("```"):
            kept = [row for row in reply.split("\n") if not row.strip().startswith("```")]
            reply = "\n".join(kept).strip()

        return reply or fallback_sql
    except Exception as exc:
        # Best effort: a failed request still yields a runnable query so the
        # episode can continue.
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return fallback_sql
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def run_task(client: OpenAI, task_name: str) -> None:
    """Run a single task and emit [START]/[STEP]/[END] logs.

    The task is selected through the ``SQL_ENV_TASK`` environment variable,
    a fresh ``SQLEnvironment`` is created and reset, and up to ``MAX_STEPS``
    queries are requested from the model and submitted. The final score is
    the sum of step rewards divided by ``obs.total_questions``, clamped to
    [0, 1]. The ``[END]`` line is always printed, even if a step raises.

    Args:
        client: OpenAI-compatible client used to generate queries.
        task_name: Name of the task to run.
    """
    os.environ["SQL_ENV_TASK"] = task_name

    env = SQLEnvironment()
    obs = env.reset()

    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)

    try:
        last_result = ""
        last_error = ""
        feedback = ""
        attempt_on_q = 1

        for step in range(1, MAX_STEPS + 1):
            if obs.done:
                break

            # Remember which question this query answers so a change of
            # question can be detected after the step.
            asked_question = obs.question

            # Get SQL from model
            sql_query = get_sql_from_model(
                client=client,
                question=asked_question,
                schema=obs.schema_description,
                last_result=last_result,
                last_error=last_error,
                feedback=feedback,
                attempt=attempt_on_q,
            )

            # Step the environment
            obs = env.step(SQLAction(query=sql_query))

            reward = obs.reward
            done = obs.done
            error = obs.error if obs.error else None

            rewards.append(reward)
            steps_taken = step

            log_step(
                step=step,
                action=sql_query,
                reward=reward,
                done=done,
                error=error,
            )

            # Track state for retry prompting
            last_result = obs.query_result
            last_error = obs.error
            feedback = obs.metadata.get("feedback", "")

            # Restart the attempt counter whenever the next prompt is about a
            # different question. A near-perfect reward is treated as "moved
            # on"; the question text changing covers the other way to advance
            # (presumably the per-question attempt budget running out --
            # verify against the environment). Without this, the new
            # question's prompt would carry the old question's feedback.
            moved_on = reward >= 0.98 or obs.question != asked_question
            if moved_on:
                attempt_on_q = 1
            else:
                attempt_on_q += 1

            if done:
                break

        # Calculate normalized score (assumes at most 1.0 per question)
        max_possible = obs.total_questions
        if max_possible > 0:
            score = sum(rewards) / max_possible
            score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD

    except Exception as exc:
        print(f"[DEBUG] Task {task_name} error: {exc}", flush=True)

    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def main() -> None:
    """Run inference on every task in ``TASKS``, one after another.

    Exits with status 1 when no API key is configured.
    """
    if not API_KEY:
        print("[ERROR] HF_TOKEN or API_KEY environment variable is required.", flush=True)
        sys.exit(1)

    llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    for name in TASKS:
        run_task(llm_client, name)
        # Blank line between tasks keeps the log sections visually apart.
        print("", flush=True)


if __name__ == "__main__":
    main()
|
models.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data models for the SQL Query Writing Environment.
|
| 3 |
+
|
| 4 |
+
Defines the Action (agent submits SQL query) and Observation
|
| 5 |
+
(agent sees schema, question, query results, reward) types.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from openenv.core.env_server.types import Action, Observation
|
| 9 |
+
from pydantic import Field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SQLAction(Action):
    """Agent action: one SQL query to execute against the database."""

    # Required field (no default).
    query: str = Field(
        ...,
        description="SQL SELECT query to execute against the database",
    )
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SQLObservation(Observation):
    """What the agent sees after a reset or a step.

    Every field has a default, so an observation can be built with only the
    values that are relevant at that point.
    """

    # --- What is being asked -------------------------------------------
    task_name: str = Field(
        default="",
        description="Current task identifier (e.g., basic_select)",
    )
    question: str = Field(
        default="",
        description="Natural language question to answer with SQL",
    )
    schema_description: str = Field(
        default="",
        description="Human-readable database schema",
    )

    # --- Outcome of the last query -------------------------------------
    query_result: str = Field(
        default="",
        description="Result of the last executed query (or error)",
    )
    error: str = Field(
        default="",
        description="SQL error message if query failed, empty otherwise",
    )

    # --- Progress ------------------------------------------------------
    steps_remaining: int = Field(
        default=0,
        description="Number of attempts remaining for current question",
    )
    question_index: int = Field(
        default=0,
        description="Current question number (1-indexed)",
    )
    total_questions: int = Field(
        default=0,
        description="Total questions in this task",
    )
|
openenv.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: sql_env
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 7860
|
pyproject.toml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
[build-system]
|
| 8 |
+
requires = ["setuptools>=45", "wheel"]
|
| 9 |
+
build-backend = "setuptools.build_meta"
|
| 10 |
+
|
| 11 |
+
[project]
|
| 12 |
+
name = "openenv-sql_env"
|
| 13 |
+
version = "0.1.0"
|
| 14 |
+
description = "SQL query-writing environment for OpenEnv"
|
| 15 |
+
requires-python = ">=3.10"
|
| 16 |
+
dependencies = [
|
| 17 |
+
# Core OpenEnv runtime (provides FastAPI server + HTTP client types)
|
| 18 |
+
# install from github
|
| 19 |
+
# "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
|
| 20 |
+
"openenv-core[core]>=0.2.2",
|
| 21 |
+
# Environment-specific dependencies
|
| 22 |
+
# Add all dependencies needed for your environment here
|
| 23 |
+
# Examples:
|
| 24 |
+
# "numpy>=1.19.0",
|
| 25 |
+
# "torch>=2.0.0",
|
| 26 |
+
# "gymnasium>=0.29.0",
|
| 27 |
+
# "openspiel>=1.0.0",
|
| 28 |
+
# "smolagents>=1.22.0,<2",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=8.0.0",
|
| 34 |
+
"pytest-cov>=4.0.0",
|
| 35 |
+
]
|
| 36 |
+
|
| 37 |
+
[project.scripts]
|
| 38 |
+
# Server entry point - enables running via: uv run --project . server
|
| 39 |
+
# or: python -m sql_env.server.app
|
| 40 |
+
server = "sql_env.server.app:main"
|
| 41 |
+
|
| 42 |
+
[tool.setuptools]
|
| 43 |
+
include-package-data = true
|
| 44 |
+
packages = ["sql_env", "sql_env.server"]
|
| 45 |
+
package-dir = { "sql_env" = ".", "sql_env.server" = "server" }
|
server/Dockerfile
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Multi-stage build using openenv-base
|
| 8 |
+
# This Dockerfile is flexible and works for both:
|
| 9 |
+
# - In-repo environments (with local OpenEnv sources)
|
| 10 |
+
# - Standalone environments (with openenv from PyPI/Git)
|
| 11 |
+
# The build script (openenv build) handles context detection and sets appropriate build args.
|
| 12 |
+
|
| 13 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 14 |
+
FROM ${BASE_IMAGE} AS builder
|
| 15 |
+
|
| 16 |
+
WORKDIR /app
|
| 17 |
+
|
| 18 |
+
# Ensure git is available (required for installing dependencies from VCS)
|
| 19 |
+
RUN apt-get update && \
|
| 20 |
+
apt-get install -y --no-install-recommends git && \
|
| 21 |
+
rm -rf /var/lib/apt/lists/*
|
| 22 |
+
|
| 23 |
+
# Build argument to control whether we're building standalone or in-repo
|
| 24 |
+
ARG BUILD_MODE=in-repo
|
| 25 |
+
ARG ENV_NAME=sql_env
|
| 26 |
+
|
| 27 |
+
# Copy environment code (always at root of build context)
|
| 28 |
+
COPY . /app/env
|
| 29 |
+
|
| 30 |
+
# For in-repo builds, openenv is already vendored in the build context
|
| 31 |
+
# For standalone builds, openenv will be installed via pyproject.toml
|
| 32 |
+
WORKDIR /app/env
|
| 33 |
+
|
| 34 |
+
# Ensure uv is available (for local builds where base image lacks it)
|
| 35 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 36 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 37 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 38 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
# Install dependencies using uv sync
|
| 42 |
+
# If uv.lock exists, use it; otherwise resolve on the fly
|
| 43 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 44 |
+
if [ -f uv.lock ]; then \
|
| 45 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 46 |
+
else \
|
| 47 |
+
uv sync --no-install-project --no-editable; \
|
| 48 |
+
fi
|
| 49 |
+
|
| 50 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 51 |
+
if [ -f uv.lock ]; then \
|
| 52 |
+
uv sync --frozen --no-editable; \
|
| 53 |
+
else \
|
| 54 |
+
uv sync --no-editable; \
|
| 55 |
+
fi
|
| 56 |
+
|
| 57 |
+
# Final runtime stage
|
| 58 |
+
FROM ${BASE_IMAGE}
|
| 59 |
+
|
| 60 |
+
WORKDIR /app
|
| 61 |
+
|
| 62 |
+
# Copy the virtual environment from builder
|
| 63 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 64 |
+
|
| 65 |
+
# Copy the environment code
|
| 66 |
+
COPY --from=builder /app/env /app/env
|
| 67 |
+
|
| 68 |
+
# Set PATH to use the virtual environment
|
| 69 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 70 |
+
|
| 71 |
+
# Set PYTHONPATH so imports work correctly
|
| 72 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 73 |
+
|
| 74 |
+
# Health check
|
| 75 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 76 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 77 |
+
|
| 78 |
+
# Run the FastAPI server
|
| 79 |
+
# The module path is constructed to work with the /app/env structure
|
| 80 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""SQL environment server components."""
|
server/app.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FastAPI application for the SQL Query Writing Environment.
|
| 3 |
+
|
| 4 |
+
Endpoints:
|
| 5 |
+
- POST /reset: Reset the environment
|
| 6 |
+
- POST /step: Execute an action (SQL query)
|
| 7 |
+
- GET /state: Get current environment state
|
| 8 |
+
- GET /health: Health check
|
| 9 |
+
- WS /ws: WebSocket endpoint for persistent sessions
|
| 10 |
+
|
| 11 |
+
Usage:
|
| 12 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from openenv.core.env_server.http_server import create_app
|
| 17 |
+
except Exception as e:
|
| 18 |
+
raise ImportError(
|
| 19 |
+
"openenv is required. Install with: pip install openenv-core"
|
| 20 |
+
) from e
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
from ..models import SQLAction, SQLObservation
|
| 24 |
+
from .sql_env_environment import SQLEnvironment
|
| 25 |
+
except (ImportError, ModuleNotFoundError):
|
| 26 |
+
from models import SQLAction, SQLObservation
|
| 27 |
+
from server.sql_env_environment import SQLEnvironment
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Module-level application object served by uvicorn (`server.app:app`).
# It is built from the environment class plus its action/observation models;
# `max_concurrent_envs=3` presumably caps simultaneous sessions — confirm
# against the openenv `create_app` documentation.
app = create_app(
    SQLEnvironment,
    SQLAction,
    SQLObservation,
    env_name="sql_env",
    max_concurrent_envs=3,
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main(host: str = "0.0.0.0", port: int = 8000):
    """
    Serve the module-level `app` with uvicorn.

    Args:
        host: Interface address to bind to.
        port: TCP port to listen on.
    """
    # Imported lazily so that importing this module does not require uvicorn.
    from uvicorn import run as serve

    serve(app, host=host, port=port)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
    import argparse

    # Expose both bind address and port; `main` already accepts a host, so the
    # CLI should not force callers to edit code to change it. Defaults match
    # the defaults of `main`, so existing invocations behave the same.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    main(host=args.host, port=args.port)
|
server/database.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQLite database management for SQLEnv.
|
| 3 |
+
|
| 4 |
+
Handles database initialization, query execution, and schema introspection.
|
| 5 |
+
All operations use an in-memory SQLite database that is re-created on each
|
| 6 |
+
environment reset, ensuring deterministic, isolated episodes.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sqlite3
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from typing import List, Optional, Tuple
|
| 14 |
+
|
| 15 |
+
DATA_DIR = Path(__file__).resolve().parent.parent / "data"
|
| 16 |
+
SCHEMA_PATH = DATA_DIR / "schema.sql"
|
| 17 |
+
SEED_PATH = DATA_DIR / "seed.sql"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class QueryResult:
    """Outcome of running one SQL statement: either tabular data or an error."""

    columns: List[str] = field(default_factory=list)
    rows: List[Tuple] = field(default_factory=list)
    error: Optional[str] = None
    row_count: int = 0

    @property
    def success(self) -> bool:
        """True when no error was recorded."""
        return self.error is None

    def to_display_string(self, max_rows: int = 20) -> str:
        """
        Render the result as a fixed-width text table.

        Args:
            max_rows: Maximum number of data rows to print; the remainder is
                summarized in a "... (N more rows)" line.

        Returns:
            The error text prefixed with "ERROR: ", "(no results)" when there
            are no columns, or the formatted table followed by a row count.
        """
        if self.error:
            return f"ERROR: {self.error}"
        if not self.columns:
            return "(no results)"

        # Stringify everything once, then size each column to its widest cell.
        headings = [str(name) for name in self.columns]
        shown = [[str(cell) for cell in record] for record in self.rows[:max_rows]]

        widths = [len(text) for text in headings]
        for cells in shown:
            for pos, text in enumerate(cells):
                if len(text) > widths[pos]:
                    widths[pos] = len(text)

        def render(cells):
            # Pad every cell to its column width and join with the divider.
            return " | ".join(text.ljust(widths[pos]) for pos, text in enumerate(cells))

        out = [render(headings), "-+-".join("-" * width for width in widths)]
        out.extend(render(cells) for cells in shown)

        hidden = len(self.rows) - max_rows
        if hidden > 0:
            out.append(f"... ({hidden} more rows)")

        plural = "" if self.row_count == 1 else "s"
        out.append(f"\n({self.row_count} row{plural})")
        return "\n".join(out)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class Database:
    """
    Manages an in-memory SQLite database for one episode.

    Each call to `initialize()` creates a fresh database with the schema
    and seed data, ensuring deterministic state across episodes. Once the
    seed data is loaded the connection is locked down with a SQLite
    authorizer so that agent-supplied SQL can only read.
    """

    # Schema-introspection pragmas that remain permitted after the read-only
    # guard is installed (used by `get_schema_description`).
    _READ_ONLY_PRAGMAS = frozenset({"table_info", "foreign_key_list"})

    def __init__(self):
        self._conn: Optional[sqlite3.Connection] = None

    @staticmethod
    def _authorize_read_only(action, arg1, arg2, db_name, source):
        """
        SQLite authorizer callback that permits read-only work only.

        SELECT compilation, column reads, function calls and recursive CTEs
        are allowed, as are the two introspection pragmas listed in
        `_READ_ONLY_PRAGMAS`. Every other action (INSERT/UPDATE/DELETE, DDL,
        ATTACH, other PRAGMAs, transactions, ...) is denied.
        """
        if action in (
            sqlite3.SQLITE_SELECT,
            sqlite3.SQLITE_READ,
            sqlite3.SQLITE_FUNCTION,
            sqlite3.SQLITE_RECURSIVE,
        ):
            return sqlite3.SQLITE_OK
        if (
            action == sqlite3.SQLITE_PRAGMA
            and str(arg1).lower() in Database._READ_ONLY_PRAGMAS
        ):
            return sqlite3.SQLITE_OK
        return sqlite3.SQLITE_DENY

    def initialize(self) -> None:
        """
        Create a fresh in-memory database with schema and seed data.

        Raises:
            OSError: If the schema or seed file cannot be read.
            sqlite3.Error: If the schema or seed script fails to execute.
        """
        self.close()
        conn = sqlite3.connect(":memory:")
        try:
            conn.execute("PRAGMA foreign_keys = ON")
            conn.executescript(SCHEMA_PATH.read_text())
            conn.executescript(SEED_PATH.read_text())
            conn.commit()
            # Install the guard only after seeding: from here on the
            # connection refuses anything that is not a read.
            conn.set_authorizer(self._authorize_read_only)
        except Exception:
            # Never leave a half-seeded database behind for later queries.
            conn.close()
            raise
        self._conn = conn

    def execute_query(self, sql: str, timeout_seconds: float = 5.0) -> QueryResult:
        """
        Execute a SQL query and return the result.

        Only SELECT statements are allowed. Modification statements
        (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE) are rejected, first by
        a keyword check that gives a friendly message and then by the
        connection's authorizer, which also catches statements hidden behind
        a leading comment or a `WITH` clause.

        Args:
            sql: The SQL query string to execute.
            timeout_seconds: Max execution time (unused for SQLite in-memory).

        Returns:
            QueryResult with columns, rows, and potential error.
        """
        if self._conn is None:
            return QueryResult(error="Database not initialized. Call reset() first.")

        # Strip and normalize
        stripped = sql.strip().rstrip(";").strip()
        if not stripped:
            return QueryResult(error="Empty query.")

        # Block modification statements
        first_word = stripped.split()[0].upper()
        blocked = {"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE", "REPLACE"}
        if first_word in blocked:
            return QueryResult(
                error=f"Only SELECT queries are allowed. Got: {first_word}"
            )

        try:
            cursor = self._conn.execute(stripped)
            if cursor.description is None:
                return QueryResult(error="Query did not return results.")

            columns = [desc[0] for desc in cursor.description]
            rows = cursor.fetchall()
        except (sqlite3.Error, sqlite3.Warning, ValueError) as e:
            # sqlite3.Warning (multiple statements on Python < 3.12) is not a
            # sqlite3.Error, and a NUL byte in the query raises ValueError on
            # older Pythons; both must become a QueryResult, not a crash.
            message = str(e)
            if message == "not authorized":
                message = "Only read-only SELECT queries are allowed."
            return QueryResult(error=message)

        return QueryResult(
            columns=columns,
            rows=rows,
            row_count=len(rows),
        )

    def get_schema_description(self) -> str:
        """
        Return a human-readable description of the database schema
        including table structures and sample data.
        """
        if self._conn is None:
            return "Database not initialized."

        schema_text = []
        schema_text.append("=== DATABASE SCHEMA ===\n")

        tables = [
            ("customers", "Customer information"),
            ("products", "Product catalog"),
            ("orders", "Customer orders"),
            ("order_items", "Items within each order"),
            ("reviews", "Product reviews by customers"),
        ]

        for table_name, description in tables:
            schema_text.append(f"TABLE: {table_name} -- {description}")

            # Get column info
            cursor = self._conn.execute(f"PRAGMA table_info({table_name})")
            columns = cursor.fetchall()
            for col in columns:
                # col: (cid, name, type, notnull, default_value, pk)
                col_name = col[1]
                col_type = col[2]
                is_pk = " PRIMARY KEY" if col[5] else ""
                is_nn = " NOT NULL" if col[3] else ""
                schema_text.append(f"  {col_name} {col_type}{is_pk}{is_nn}")

            # Get foreign keys
            cursor = self._conn.execute(f"PRAGMA foreign_key_list({table_name})")
            fks = cursor.fetchall()
            for fk in fks:
                schema_text.append(f"  FOREIGN KEY ({fk[3]}) REFERENCES {fk[2]}({fk[4]})")

            # Show sample data (first 3 rows)
            result = self.execute_query(f"SELECT * FROM {table_name} LIMIT 3")
            if result.success and result.rows:
                schema_text.append(f"  Sample data ({result.row_count} rows shown):")
                for row in result.rows:
                    schema_text.append(f"    {row}")

            # Show total count
            count_result = self.execute_query(
                f"SELECT COUNT(*) FROM {table_name}"
            )
            if count_result.success and count_result.rows:
                total = count_result.rows[0][0]
                schema_text.append(f"  Total rows: {total}")

            schema_text.append("")

        # Add relationship summary
        schema_text.append("=== RELATIONSHIPS ===")
        schema_text.append("orders.customer_id -> customers.id")
        schema_text.append("order_items.order_id -> orders.id")
        schema_text.append("order_items.product_id -> products.id")
        schema_text.append("reviews.product_id -> products.id")
        schema_text.append("reviews.customer_id -> customers.id")
        schema_text.append("")
        schema_text.append("=== NOTES ===")
        schema_text.append("- All dates are in ISO format (YYYY-MM-DD)")
        schema_text.append("- Prices are in INR (Indian Rupees)")
        schema_text.append("- Order status: pending, shipped, delivered, cancelled")
        schema_text.append("- Product categories: Electronics, Clothing, Books, Home")
        schema_text.append("- Ratings are integers from 1 to 5")

        return "\n".join(schema_text)

    def close(self) -> None:
        """Close the database connection."""
        if self._conn is not None:
            self._conn.close()
            self._conn = None
|
server/graders.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-component grading system for SQL query evaluation.
|
| 3 |
+
|
| 4 |
+
Scores agent queries against ground truth with partial credit:
|
| 5 |
+
- syntax_score (0.1): Query parses and executes without error
|
| 6 |
+
- column_score (0.2): Fraction of expected columns present
|
| 7 |
+
- row_score (0.3): Fraction of expected rows matching
|
| 8 |
+
- exact_score (0.4): Full result set matches ground truth exactly
|
| 9 |
+
|
| 10 |
+
Total reward per question is in [0.0, 1.0].
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from typing import Any, List, Optional, Tuple
|
| 14 |
+
|
| 15 |
+
from .database import Database, QueryResult
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _normalize_value(val: Any) -> Any:
|
| 19 |
+
"""Normalize a value for comparison (handle float/int equivalence, None)."""
|
| 20 |
+
if val is None:
|
| 21 |
+
return None
|
| 22 |
+
if isinstance(val, float):
|
| 23 |
+
if val == int(val):
|
| 24 |
+
return int(val)
|
| 25 |
+
return round(val, 2)
|
| 26 |
+
if isinstance(val, str):
|
| 27 |
+
return val.strip()
|
| 28 |
+
return val
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _normalize_row(row: Tuple) -> Tuple:
    """Return a new tuple with `_normalize_value` applied to every cell of `row`."""
    return tuple(map(_normalize_value, row))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _normalize_column_name(col: str) -> str:
|
| 37 |
+
"""Normalize column name for comparison (lowercase, strip)."""
|
| 38 |
+
return col.strip().lower()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def grade_query(
    db: Database,
    agent_sql: str,
    expected_columns: List[str],
    expected_rows: List[List],
    order_matters: bool = True,
) -> dict:
    """
    Grade an agent's SQL query against expected results.

    Args:
        db: Active Database instance.
        agent_sql: The SQL query submitted by the agent.
        expected_columns: List of expected column names.
        expected_rows: List of expected row values (list of lists).
        order_matters: Whether row order affects scoring.

    Returns:
        Dictionary with:
        - reward: float in [0.0, 1.0]
        - syntax_score: float
        - column_score: float
        - row_score: float
        - exact_score: float
        - query_result: QueryResult object
        - feedback: str describing what was right/wrong
    """
    result = db.execute_query(agent_sql)

    # Component weights
    W_SYNTAX = 0.1
    W_COLUMN = 0.2
    W_ROW = 0.3
    W_EXACT = 0.4

    # --- Syntax Score ---
    if not result.success:
        return {
            "reward": 0.0,
            "syntax_score": 0.0,
            "column_score": 0.0,
            "row_score": 0.0,
            "exact_score": 0.0,
            "query_result": result,
            "feedback": f"SQL error: {result.error}",
        }

    syntax_score = 1.0

    # --- Column Score ---
    expected_cols_normalized = [_normalize_column_name(c) for c in expected_columns]
    actual_cols_normalized = [_normalize_column_name(c) for c in result.columns]

    if not expected_cols_normalized:
        column_score = 1.0 if not actual_cols_normalized else 0.0
    else:
        matched_cols = sum(
            1 for c in expected_cols_normalized if c in actual_cols_normalized
        )
        column_score = matched_cols / len(expected_cols_normalized)

    # --- Row Score ---
    expected_rows_normalized = [
        _normalize_row(tuple(row)) for row in expected_rows
    ]
    actual_rows_normalized = [_normalize_row(row) for row in result.rows]

    # Kept as an integer so the feedback can report the exact count instead of
    # reconstructing it from the (lossy) floating-point ratio.
    matched_rows = 0
    if not expected_rows_normalized:
        row_score = 1.0 if not actual_rows_normalized else 0.0
    else:
        if order_matters:
            # For ordered results, match position-by-position; zip stops at
            # the shorter list, so missing actual rows simply do not match.
            matched_rows = sum(
                1
                for expected_row, actual_row in zip(
                    expected_rows_normalized, actual_rows_normalized
                )
                if _rows_match(
                    expected_row,
                    actual_row,
                    expected_cols_normalized,
                    actual_cols_normalized,
                )
            )
        else:
            # For unordered results, consume each actual row at most once so
            # duplicates in the expected set need duplicates in the answer.
            remaining_actual = list(actual_rows_normalized)
            for expected_row in expected_rows_normalized:
                for j, actual_row in enumerate(remaining_actual):
                    if _rows_match(
                        expected_row,
                        actual_row,
                        expected_cols_normalized,
                        actual_cols_normalized,
                    ):
                        matched_rows += 1
                        remaining_actual.pop(j)
                        break
        row_score = matched_rows / len(expected_rows_normalized)

    # --- Exact Score ---
    exact_score = 0.0
    if column_score == 1.0 and row_score == 1.0:
        # Check exact match: same number of rows and all matched
        if len(actual_rows_normalized) == len(expected_rows_normalized):
            exact_score = 1.0
        else:
            # Extra rows returned — partial exact credit
            exact_score = 0.5

    # --- Total Reward ---
    reward = (
        W_SYNTAX * syntax_score
        + W_COLUMN * column_score
        + W_ROW * row_score
        + W_EXACT * exact_score
    )
    reward = round(min(max(reward, 0.0), 1.0), 4)

    # --- Feedback ---
    feedback_parts = []
    if syntax_score == 1.0:
        feedback_parts.append("Query executed successfully.")
    if column_score < 1.0:
        missing = [c for c in expected_cols_normalized if c not in actual_cols_normalized]
        feedback_parts.append(f"Missing columns: {missing}. Expected: {expected_cols_normalized}, Got: {actual_cols_normalized}")
    if row_score < 1.0:
        feedback_parts.append(
            f"Row match: {row_score:.0%} ({matched_rows}/{len(expected_rows_normalized)} rows correct). "
            f"Expected {len(expected_rows_normalized)} rows, got {len(actual_rows_normalized)}."
        )
    if exact_score == 1.0:
        feedback_parts.append("Perfect match!")
    elif exact_score == 0.5:
        feedback_parts.append(f"All expected rows found but got {len(actual_rows_normalized)} rows instead of {len(expected_rows_normalized)} (extra rows).")

    return {
        "reward": reward,
        "syntax_score": syntax_score,
        "column_score": column_score,
        "row_score": row_score,
        "exact_score": exact_score,
        "query_result": result,
        "feedback": " ".join(feedback_parts),
    }
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _rows_match(
    expected_row: Tuple,
    actual_row: Tuple,
    expected_cols: List[str],
    actual_cols: List[str],
) -> bool:
    """
    Check if an actual row matches an expected row.

    Handles column reordering: maps expected columns to actual column positions.
    When a column name occurs more than once (e.g. `c.name, p.name` both
    surface as `name`), the k-th expected occurrence is paired with the k-th
    actual occurrence; if the actual result has fewer occurrences, the last
    one is reused.

    Args:
        expected_row: Normalized expected values, one per expected column.
        actual_row: Values returned by the agent's query.
        expected_cols: Normalized expected column names.
        actual_cols: Normalized column names of the agent's result.

    Returns:
        True when every expected column exists in the actual result and its
        value is equal, or numerically within 0.01, to the expected value.
    """
    if len(expected_cols) != len(expected_row):
        return False

    # Left-to-right positions of every actual column, grouped by name.
    positions = {}
    for idx, name in enumerate(actual_cols):
        positions.setdefault(name, []).append(idx)

    # Build a mapping from expected column index to actual column index
    col_map = {}
    seen = {}
    for i, ec in enumerate(expected_cols):
        candidates = positions.get(ec)
        if not candidates:
            return False  # Missing column
        occurrence = seen.get(ec, 0)
        seen[ec] = occurrence + 1
        col_map[i] = candidates[min(occurrence, len(candidates) - 1)]

    for exp_idx, act_idx in col_map.items():
        if act_idx >= len(actual_row):
            return False
        exp_val = expected_row[exp_idx]
        act_val = _normalize_value(actual_row[act_idx])
        if exp_val != act_val:
            # Try numeric comparison with tolerance
            try:
                if abs(float(exp_val) - float(act_val)) < 0.01:
                    continue
            except (TypeError, ValueError):
                pass
            return False

    return True
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv-core[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
openai>=1.0.0
|
server/sql_env_environment.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQL Query Writing Environment.
|
| 3 |
+
|
| 4 |
+
An AI agent receives a database schema and natural language question,
|
| 5 |
+
then writes SQL queries to answer the question. The environment grades
|
| 6 |
+
each query with partial-credit scoring and provides feedback.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from uuid import uuid4
|
| 13 |
+
|
| 14 |
+
from openenv.core.env_server.interfaces import Environment
|
| 15 |
+
from openenv.core.env_server.types import State
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from ..models import SQLAction, SQLObservation
|
| 19 |
+
except ImportError:
|
| 20 |
+
from models import SQLAction, SQLObservation
|
| 21 |
+
|
| 22 |
+
from .database import Database
|
| 23 |
+
from .graders import grade_query
|
| 24 |
+
|
| 25 |
+
TASKS_DIR = Path(__file__).resolve().parent.parent / "data" / "tasks"
|
| 26 |
+
|
| 27 |
+
# Default task can be overridden via environment variable
|
| 28 |
+
DEFAULT_TASK = os.getenv("SQL_ENV_TASK", "basic_select")
|
| 29 |
+
MAX_TOTAL_STEPS = int(os.getenv("SQL_ENV_MAX_STEPS", "15"))
|
| 30 |
+
STEP_PENALTY = float(os.getenv("SQL_ENV_STEP_PENALTY", "0.02"))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _load_task(task_name: str) -> dict:
    """
    Load a task definition from JSON file.

    Args:
        task_name: Stem of a JSON file inside the tasks directory.

    Returns:
        The parsed task definition.

    Raises:
        ValueError: If no `<task_name>.json` exists; the message lists the
            available task names in sorted order.
    """
    task_path = TASKS_DIR / f"{task_name}.json"
    if not task_path.exists():
        available = sorted(f.stem for f in TASKS_DIR.glob("*.json"))
        raise ValueError(
            f"Task '{task_name}' not found. Available: {available}"
        )
    # JSON is UTF-8 by specification; do not depend on the process locale.
    with open(task_path, encoding="utf-8") as f:
        return json.load(f)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SQLEnvironment(Environment):
|
| 46 |
+
"""
|
| 47 |
+
SQL Query Writing Environment.
|
| 48 |
+
|
| 49 |
+
The agent interacts with an e-commerce SQLite database by submitting
|
| 50 |
+
SQL queries to answer natural language questions. Each query is graded
|
| 51 |
+
with a multi-component reward function providing partial credit.
|
| 52 |
+
|
| 53 |
+
Episode flow:
|
| 54 |
+
1. reset() → loads task, initializes DB, returns first question
|
| 55 |
+
2. step(SQLAction) → executes query, grades it, returns observation
|
| 56 |
+
3. Episode ends when all questions answered or max steps reached
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 60 |
+
|
| 61 |
+
def __init__(self):
    """Create the database wrapper and zeroed episode bookkeeping (no task loaded yet)."""
    # Collaborators
    self._db = Database()
    self._state = State(episode_id=str(uuid4()), step_count=0)

    # Task definition and its question list; filled in by reset().
    self._task: dict = {}
    self._questions: list = []
    self._schema_cache: str = ""

    # Progress through the question list.
    self._current_q_index: int = 0
    self._q_steps_used: int = 0
    self._total_steps: int = 0
    self._max_steps_per_q: int = 3

    # Per-episode results.
    self._rewards: list = []
    self._last_feedback: str = ""
    self._done: bool = False
|
| 74 |
+
|
| 75 |
+
def reset(self) -> SQLObservation:
    """
    Start a new episode.

    Re-creates the database, issues a new episode id, loads the task named
    by the SQL_ENV_TASK environment variable (falling back to DEFAULT_TASK),
    clears all progress counters, and returns the initial observation.
    """
    self._db.initialize()
    self._state = State(episode_id=str(uuid4()), step_count=0)

    # The variable is re-read on every reset, so it can change between episodes.
    self._task = _load_task(os.getenv("SQL_ENV_TASK", DEFAULT_TASK))
    self._questions = self._task["questions"]
    self._max_steps_per_q = self._task.get("max_steps_per_question", 3)

    # Clear per-episode progress.
    self._current_q_index = self._q_steps_used = self._total_steps = 0
    self._rewards = []
    self._done = False
    self._last_feedback = ""

    self._schema_cache = self._db.get_schema_description()

    return self._make_observation(reward=0.0, query_result="", error="")
|
| 99 |
+
|
| 100 |
+
def step(self, action: SQLAction) -> SQLObservation:  # type: ignore[override]
    """
    Grade one SQL query against the current question and advance the episode.

    If no questions are loaded yet, reset() is invoked first. When the
    episode has already finished, a zero-reward observation telling the
    caller to reset is returned and no counters change.

    Otherwise the query in ``action.query`` is graded with grade_query().
    The graded reward is reduced by STEP_PENALTY for every attempt after the
    first on the same question, floored at 0.0 and rounded to 4 decimals.
    The environment moves to the next question when the grader reports an
    exact score of 1.0 or the per-question attempt budget is used up. The
    episode ends once every question has been consumed or the total step
    count reaches MAX_TOTAL_STEPS.
    """
    # Auto-reset if step called before reset (HTTP stateless mode)
    if not self._questions:
        self.reset()

    if self._done or self._current_q_index >= len(self._questions):
        self._done = True
        return self._make_observation(
            reward=0.0,
            query_result="Episode is over. Call reset() to start a new episode.",
            error="",
        )

    # Account for this attempt before grading.
    self._state.step_count += 1
    self._total_steps += 1
    self._q_steps_used += 1
    attempt = self._q_steps_used

    current = self._questions[self._current_q_index]
    graded = grade_query(
        db=self._db,
        agent_sql=action.query,
        expected_columns=current["expected_columns"],
        expected_rows=current["expected_rows"],
        order_matters=current.get("order_matters", True),
    )

    # The first attempt is penalty-free; each retry costs one STEP_PENALTY.
    retry_penalty = STEP_PENALTY * (attempt - 1)
    reward = round(max(graded["reward"] - retry_penalty, 0.0), 4)

    self._rewards.append(reward)
    self._last_feedback = graded["feedback"]

    # Render the execution outcome for the observation.
    outcome = graded["query_result"]
    query_result_str = outcome.to_display_string()
    error_str = outcome.error or ""

    # Advance on a perfect answer or once the attempt budget is exhausted.
    solved = graded["exact_score"] == 1.0
    budget_spent = attempt >= self._max_steps_per_q
    if solved or budget_spent:
        self._current_q_index += 1
        self._q_steps_used = 0

    # Terminate when questions run out or the global step cap is hit.
    no_questions_left = self._current_q_index >= len(self._questions)
    step_cap_hit = self._total_steps >= MAX_TOTAL_STEPS
    if no_questions_left or step_cap_hit:
        self._done = True

    return self._make_observation(
        reward=reward,
        query_result=query_result_str,
        error=error_str,
    )
|
| 166 |
+
|
| 167 |
+
@property
def state(self) -> State:
    """Return the State object tracking the current episode id and step count."""
    return self._state
|
| 170 |
+
|
| 171 |
+
def _make_observation(
    self,
    reward: float,
    query_result: str,
    error: str,
) -> SQLObservation:
    """
    Assemble the SQLObservation that reflects the environment's current state.

    While a question is active, the observation carries that question, the
    cached schema description, the attempts left for it, and metadata with
    the last grader feedback, the question id and the task difficulty.

    When the episode is finished, has no questions, or the question index
    has run past the end, a terminal observation (``done=True``) is returned
    instead, with metadata holding the feedback, the summed reward and the
    per-step reward list.
    """
    questions = self._questions
    total = len(questions)
    task_name = self._task.get("task_name", "")

    in_progress = not (
        self._done or not questions or self._current_q_index >= total
    )

    if in_progress:
        active = questions[self._current_q_index]
        attempts_left = self._max_steps_per_q - self._q_steps_used
        return SQLObservation(
            task_name=task_name,
            question=active["question"],
            schema_description=self._schema_cache,
            query_result=query_result,
            error=error,
            steps_remaining=attempts_left,
            question_index=self._current_q_index + 1,
            total_questions=total,
            done=False,
            reward=reward,
            metadata={
                "feedback": self._last_feedback,
                "question_id": active["id"],
                "difficulty": self._task.get("difficulty", ""),
            },
        )

    # Episode finished or not started
    episode_summary = {
        "feedback": self._last_feedback,
        "total_reward": round(sum(self._rewards), 4),
        "rewards": [round(value, 4) for value in self._rewards],
    }
    return SQLObservation(
        task_name=task_name,
        question="Episode complete. All questions answered.",
        schema_description="",
        query_result=query_result,
        error=error,
        steps_remaining=0,
        question_index=total,
        total_questions=total,
        done=True,
        reward=reward,
        metadata=episode_summary,
    )
|