---
title: NL2SQL Bench
emoji: 📊
colorFrom: blue
colorTo: indigo
sdk: docker
pinned: false
---
# NL2SQL-Bench
**Natural Language to SQL Analytics Environment for RL Training**
[![openenv](https://img.shields.io/badge/openenv-compatible-blue)](https://github.com/meta-pytorch/OpenEnv)
[![Python 3.10+](https://img.shields.io/badge/python-3.10+-green)](https://www.python.org)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow)](LICENSE)
---
## What is NL2SQL-Bench?
NL2SQL-Bench is an OpenEnv-compliant RL training environment where an AI agent must iteratively write and refine **SQLite queries** to answer natural-language business questions against a synthetic e-commerce database.
This fills a gap in the OpenEnv ecosystem, which currently has no SQL query environment. Data-driven companies employ analysts to translate business questions into SQL, so training agents to do this well (and to recover from their own errors) is immediately valuable.
**Why it's a great RL domain:**
- Rewards are **100% deterministic**: no LLM-as-judge, no subjectivity
- Multi-turn episodes create a **dense reward signal** across the trajectory
- The error → fix → retry loop is a novel mechanic not present in existing environments
- Three clearly graduated difficulty levels challenge models across the full skill range
---
## Environment Description
The agent interacts with a **synthetic e-commerce SQLite database** containing ~150 customers, 64 products across 8 categories, ~600 orders, ~1000 order items, and ~400 reviews. The database is seeded deterministically (seed=42) so results are reproducible across any machine.
The agent receives a natural-language question and iteratively submits SQL queries. Each query is executed and graded against the ground truth, and the reward, together with either the SQLite error or the result rows, is returned as the next observation.
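The execute → feedback → retry loop can be illustrated with a self-contained toy. This is a minimal sketch, not the project's `server/environment.py`: the `ToyEpisode` class, its `step` return tuple, and the rule that an exact match ends the episode early are assumptions made for illustration, and partial rewards are omitted.

```python
import sqlite3
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple

Row = Tuple[Any, ...]


@dataclass
class ToyEpisode:
    """Hypothetical stand-in for one episode; not the project's environment.py."""

    conn: sqlite3.Connection
    gold_sql: str
    max_steps: int = 5
    step_count: int = 0
    done: bool = False

    def step(self, query: str) -> Tuple[Optional[str], List[Row], bool]:
        """Execute one query and return (error, rows, done) as feedback."""
        self.step_count += 1
        try:
            rows, error = self.conn.execute(query).fetchall(), None
        except sqlite3.Error as exc:
            rows, error = [], str(exc)
        # Toy grading (assumption): end early only on an exact match with the gold result.
        solved = error is None and rows == self.conn.execute(self.gold_sql).fetchall()
        self.done = solved or self.step_count >= self.max_steps
        return error, rows, self.done


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT, tier TEXT)")
conn.executemany(
    "INSERT INTO customers VALUES (?, ?, ?)",
    [(1, "Ada", "gold"), (2, "Bo", "silver"), (3, "Cy", "gold")],
)
ep = ToyEpisode(conn, "SELECT name FROM customers WHERE tier = 'gold' ORDER BY name")

# Step 1: a typo produces an SQLite error that is fed back to the agent.
err1, rows1, done1 = ep.step("SELECT nme FROM customers")
# Step 2: the corrected query matches the gold result and ends the episode.
err2, rows2, done2 = ep.step("SELECT name FROM customers WHERE tier = 'gold' ORDER BY name")
```

The error string from step 1 is exactly what a policy would see in `last_error` before retrying.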
---
## Database Schema
```
categories(id, name)
products(id, name, category_id, price, stock_quantity)
customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at)
orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled},
       created_at, total_amount)
order_items(id, order_id, product_id, quantity, unit_price)
reviews(id, product_id, customer_id, rating∈1-5, created_at)
```
All dates are stored as ISO-8601 strings, so they sort and compare chronologically as plain text. CTEs and window functions (available since SQLite 3.25) are supported.
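The two properties above, text-comparable dates and CTE/window-function support, can be checked with Python's built-in `sqlite3` module. This sketch assumes a `YYYY-MM-DDTHH:MM:SS` timestamp format and a trimmed-down `orders` table; the real column types and exact date format in `schema.sql` may differ.

```python
import sqlite3

# Window functions were added in SQLite 3.25.0; Python uses the library it was built with.
if sqlite3.sqlite_version_info < (3, 25, 0):
    raise RuntimeError(f"SQLite {sqlite3.sqlite_version} lacks window functions")

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, created_at TEXT, total_amount REAL)")
conn.executemany(
    "INSERT INTO orders (created_at, total_amount) VALUES (?, ?)",
    [
        ("2024-01-15T10:00:00", 50.0),
        ("2023-12-31T23:59:59", 20.0),
        ("2024-02-01T08:30:00", 30.0),
    ],
)

# ISO-8601 text compares chronologically, so a year filter is a plain string range.
in_2024 = conn.execute(
    "SELECT id FROM orders "
    "WHERE created_at >= '2024-01-01' AND created_at < '2025-01-01' ORDER BY id"
).fetchall()

# A CTE feeding a window function: running total in date order.
running = conn.execute(
    """
    WITH dated AS (SELECT created_at, total_amount FROM orders)
    SELECT created_at, SUM(total_amount) OVER (ORDER BY created_at) AS running_total
    FROM dated
    ORDER BY created_at
    """
).fetchall()
```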
---
## Action & Observation Space
### Action
```python
from dataclasses import dataclass

@dataclass
class NL2SQLAction(Action):  # Action: OpenEnv base class from openenv-core
    query: str  # A SQLite SELECT query string
```
### Observation
```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class NL2SQLObservation(Observation):  # Observation: OpenEnv base class
    question: str               # The NL question to answer
    schema_context: str         # Compact schema description
    task_name: str              # Active task identifier
    last_query: str             # SQL submitted on previous step
    last_result: List[Dict]     # Up to 10 result rows
    last_error: Optional[str]   # SQLite error string or None
    result_columns: List[str]   # Column names of last_result
    step: int                   # Current step (0 after reset)
    max_steps: int              # Maximum steps per episode
    done: bool                  # Episode ended?
    reward: Optional[float]     # Step reward in [0.0, 1.0]
    score: float                # Cumulative normalised score
```
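One way an agent might consume this observation is to fold the question, schema, and last error or result back into the next prompt. This is a hypothetical sketch: `ObsSketch` mirrors only a subset of the documented fields without the OpenEnv base class, and `render_prompt` and its wording are illustrative assumptions, not part of `inference.py`.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class ObsSketch:
    """Subset of the documented NL2SQLObservation fields (no OpenEnv base class)."""

    question: str
    schema_context: str
    last_query: str = ""
    last_result: List[Dict[str, Any]] = field(default_factory=list)
    last_error: Optional[str] = None
    step: int = 0
    max_steps: int = 5


def render_prompt(obs: ObsSketch) -> str:
    """Hypothetical prompt builder that turns one observation into the next LLM turn."""
    parts = [f"Schema:\n{obs.schema_context}", f"Question: {obs.question}"]
    if obs.last_query:
        parts.append(f"Previous query:\n{obs.last_query}")
        if obs.last_error:
            # Surface the SQLite error verbatim so the model can repair its query.
            parts.append(f"It failed with: {obs.last_error}")
        else:
            # last_result holds at most 10 rows, so show a short preview only.
            parts.append(f"First rows returned: {obs.last_result[:3]}")
    parts.append(
        f"Step {obs.step + 1} of {obs.max_steps}. Reply with a single SQLite SELECT query."
    )
    return "\n\n".join(parts)


prompt = render_prompt(
    ObsSketch(
        question="Return the top 5 most expensive products.",
        schema_context="products(id, name, category_id, price, stock_quantity)",
        last_query="SELECT nme FROM products",
        last_error="no such column: nme",
        step=1,
    )
)
```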
---
## Tasks & Expected Difficulty
### Task 1: `simple-filter` (easy)
Single-table SELECT queries with WHERE, ORDER BY, LIMIT. Tests basic SQL fluency. Example questions:
- "List all gold-tier customers ordered by name alphabetically."
- "Return the top 5 most expensive products."
**Expected solve rate (frontier model, 5 steps):** ~80%
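A candidate answer to the second easy question, run against a tiny in-memory `products` table. The selected columns and the `id` tie-break are assumptions; the ground-truth query's exact column list and tie-breaking are not documented, and both matter for `columns_match` and order-sensitive grading.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT, "
    "category_id INTEGER, price REAL, stock_quantity INTEGER)"
)
# Eight toy products priced 5.0, 10.0, ..., 40.0 (made-up data).
conn.executemany(
    "INSERT INTO products (name, category_id, price, stock_quantity) VALUES (?, ?, ?, ?)",
    [(f"Product {i}", 1 + i % 8, 5.0 * i, 10) for i in range(1, 9)],
)

# "Return the top 5 most expensive products."
# The secondary sort on id gives a deterministic order if prices tie (an assumption).
top5 = conn.execute(
    "SELECT id, name, price FROM products ORDER BY price DESC, id LIMIT 5"
).fetchall()
```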
### Task 2: `join-aggregation` (medium)
Multi-table JOINs with GROUP BY, HAVING, and aggregation functions. Example questions:
- "How many orders has each customer placed? Include customers with zero orders."
- "Which customers have spent more than $500 total on delivered orders?"
**Expected solve rate (frontier model, 5 steps):** ~55%
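Candidate queries for both medium questions, on a handful of made-up rows. The first shows the usual trap: a `LEFT JOIN` keeps customers with no orders, and `COUNT(o.id)` (not `COUNT(*)`) turns their NULL matches into 0. The column lists and aliases are assumptions about what the ground truth returns.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER,
                         status TEXT, total_amount REAL);
    INSERT INTO customers VALUES (1, 'Ada'), (2, 'Bo'), (3, 'Cy');
    INSERT INTO orders VALUES
        (1, 1, 'delivered', 300), (2, 1, 'delivered', 250), (3, 2, 'cancelled', 900);
    """
)

# "How many orders has each customer placed? Include customers with zero orders."
order_counts = conn.execute(
    """
    SELECT c.id, c.name, COUNT(o.id) AS order_count
    FROM customers c LEFT JOIN orders o ON o.customer_id = c.id
    GROUP BY c.id, c.name
    ORDER BY c.id
    """
).fetchall()

# "Which customers have spent more than $500 total on delivered orders?"
# WHERE filters rows before grouping; HAVING filters the aggregated groups.
big_spenders = conn.execute(
    """
    SELECT c.id, c.name, SUM(o.total_amount) AS total_spent
    FROM customers c JOIN orders o ON o.customer_id = c.id
    WHERE o.status = 'delivered'
    GROUP BY c.id, c.name
    HAVING SUM(o.total_amount) > 500
    """
).fetchall()
```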
### Task 3: `analytics-window` (hard)
CTEs, window functions (DENSE_RANK, ROW_NUMBER, running SUM), and nested subqueries. Example questions:
- "Rank customers by total spending using DENSE_RANK."
- "Show monthly revenue and running total for delivered orders in 2024."
**Expected solve rate (frontier model, 5 steps):** ~30%
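Candidate queries for both hard questions, each built from a CTE plus a window function. Counting only delivered orders toward "total spending" is an assumption, as are the output columns. The toy data includes a tie so the `DENSE_RANK` behaviour (tied customers share a rank and no rank is skipped) is visible.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER, status TEXT,
                         created_at TEXT, total_amount REAL);
    INSERT INTO orders VALUES
        (1, 1, 'delivered', '2024-01-05', 100),
        (2, 2, 'delivered', '2024-01-20', 50),
        (3, 1, 'cancelled', '2024-02-02', 999),
        (4, 3, 'delivered', '2024-02-14', 150),
        (5, 2, 'delivered', '2024-03-01', 100);
    """
)

# "Show monthly revenue and running total for delivered orders in 2024."
monthly = conn.execute(
    """
    WITH monthly AS (
        SELECT strftime('%Y-%m', created_at) AS month, SUM(total_amount) AS revenue
        FROM orders
        WHERE status = 'delivered'
          AND created_at >= '2024-01-01' AND created_at < '2025-01-01'
        GROUP BY month
    )
    SELECT month, revenue, SUM(revenue) OVER (ORDER BY month) AS running_total
    FROM monthly
    ORDER BY month
    """
).fetchall()

# "Rank customers by total spending using DENSE_RANK."
ranked = conn.execute(
    """
    WITH spend AS (
        SELECT customer_id, SUM(total_amount) AS total_spent
        FROM orders
        WHERE status = 'delivered'
        GROUP BY customer_id
    )
    SELECT customer_id, total_spent,
           DENSE_RANK() OVER (ORDER BY total_spent DESC) AS spend_rank
    FROM spend
    ORDER BY spend_rank, customer_id
    """
).fetchall()
```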
---
## Reward Function
Rewards are computed by deterministic comparison of the agent's result set against the ground truth:
| Component | Score | Description |
|---|---|---|
| `syntax_ok` | +0.10 | Query runs without SQLite error |
| `columns_match` | +0.20 | Returned column names match ground truth |
| `row_count_match` | +0.20 | Number of rows matches |
| `exact_match` | +0.50 | Full result set equals ground truth |
| `step_penalty` | −0.05/step | Deducted per step beyond the first |
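The earned components are summed, the step penalty is subtracted, and the final reward is clamped to `[0.0, 1.0]`. Order sensitivity follows the ground-truth query: if it uses ORDER BY, rows must appear in the same order; otherwise rows are compared regardless of order.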
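A minimal sketch of the rubric in the table. The function name, its parameters, the 1-based `step`, and the multiset comparison (`collections.Counter`) for order-agnostic queries are assumptions; `server/grader.py` may differ, for example in how it normalises column names or whether `exact_match` also requires matching columns.

```python
from collections import Counter
from typing import Any, List, Optional, Sequence, Tuple

Row = Tuple[Any, ...]


def compute_reward(
    error: Optional[str],
    columns: Sequence[str],
    rows: List[Row],
    gold_columns: Sequence[str],
    gold_rows: List[Row],
    step: int,
    order_sensitive: bool,
) -> float:
    """Hypothetical reward sketch; `step` is 1-based (the first query is step 1)."""
    score = 0.0
    if error is None:
        score += 0.10  # syntax_ok
        if list(columns) == list(gold_columns):
            score += 0.20  # columns_match
        if len(rows) == len(gold_rows):
            score += 0.20  # row_count_match
        # exact_match: ordered comparison for ORDER BY ground truths, multiset otherwise.
        if order_sensitive:
            same = rows == gold_rows
        else:
            same = Counter(rows) == Counter(gold_rows)
        if same:
            score += 0.50
    score -= 0.05 * max(0, step - 1)  # step_penalty
    return max(0.0, min(1.0, score))  # clamp to [0.0, 1.0]
```

With this shape, a fully correct answer scores 1.0 on step 1 and 0.9 on step 3, while a query that fails to run scores 0.0 on any step.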
---
## Baseline Scores
Approximate scores expected from running `inference.py` with `Qwen/Qwen2.5-72B-Instruct` via the Hugging Face router:
| Task | Expected Score |
|---|---|
| `simple-filter` | ~0.70 |
| `join-aggregation` | ~0.45 |
| `analytics-window` | ~0.25 |
---
## Setup & Usage
### Prerequisites
- Python 3.10+
- Docker (for containerised deployment)
- A HuggingFace account + token
### Local Development (no Docker)
```bash
# Clone the repository
git clone https://huggingface.co/spaces/your-username/nl2sql-bench
cd nl2sql-bench
# Quick start
chmod +x scripts/run_local.sh
./scripts/run_local.sh
# Or manually:
python3 -m venv .venv && source .venv/bin/activate
pip install openenv-core fastapi "uvicorn[standard]" openai pydantic
export PYTHONPATH="$PWD:$PWD/server"   # absolute paths, so they still resolve after the cd below
cd server && uvicorn app:app --reload --port 8000
```
### Test the Running Server
```bash
# Run smoke tests
chmod +x scripts/smoke_test.sh
./scripts/smoke_test.sh http://localhost:8000
# Run full test suite
pip install pytest pytest-asyncio
PYTHONPATH=".:server" pytest tests/ -v
```
### Docker
```bash
# Build
docker build -t nl2sql-bench:latest .
# Run
docker run -p 7860:7860 nl2sql-bench:latest
# Test
./scripts/smoke_test.sh http://localhost:7860
```
### Pre-submission Validation
```bash
# Run the official validator (replace with your HF Space URL)
chmod +x pre_validation_script.sh
./pre_validation_script.sh https://your-username-nl2sql-bench.hf.space .
```
### Running the Baseline Inference
```bash
# Set the required environment variables
export API_BASE_URL="https://router.huggingface.co/v1"
export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
export HF_TOKEN="hf_your_token_here"
export SPACE_URL="https://your-username-nl2sql-bench.hf.space"
python inference.py
```
### Using the Client Programmatically
```python
import asyncio

from client import NL2SQLEnv
from models import NL2SQLAction


async def main():
    async with NL2SQLEnv(base_url="http://localhost:8000") as env:
        # Start a new episode and show the question to answer
        result = await env.reset()
        print(result.observation.question)

        # Submit one query and inspect the feedback
        result = await env.step(NL2SQLAction(
            query="SELECT id, name FROM customers WHERE tier='gold' ORDER BY name"
        ))
        print(f"Reward: {result.reward:.2f}")
        print(f"Done: {result.done}")
        print(f"Error: {result.observation.last_error}")


asyncio.run(main())
```
---
## Project Structure
```
nl2sql-bench/
├── models.py              # NL2SQLAction, NL2SQLObservation, NL2SQLState
├── client.py              # NL2SQLEnv(HTTPEnvClient)
├── inference.py           # Baseline inference script (mandatory name)
├── openenv.yaml           # OpenEnv manifest
├── pyproject.toml
├── Dockerfile             # HF Spaces compatible (port 7860)
├── .env.example
├── server/
│   ├── app.py             # FastAPI entry point
│   ├── environment.py     # Core RL environment logic
│   ├── grader.py          # Deterministic reward computation
│   ├── requirements.txt
│   ├── db/
│   │   ├── schema.sql     # 6-table e-commerce schema
│   │   └── seed.py        # Deterministic data generator (seed=42)
│   └── tasks/
│       ├── base.py        # BaseTask + registry
│       ├── easy.py        # simple-filter (5 examples)
│       ├── medium.py      # join-aggregation (5 examples)
│       └── hard.py        # analytics-window (5 examples)
├── tests/
│   ├── conftest.py
│   └── test_all.py        # 30+ pytest tests
└── scripts/
    ├── run_local.sh       # Local dev server
    └── smoke_test.sh      # Endpoint smoke tests
```
---
## Design Decisions
**Why in-memory SQLite?** It adds no external runtime dependency (SQLite ships with Python's standard library), it is deterministic, and it runs comfortably within the 2 vCPU / 8 GB resource limit. The database loads in ~50 ms.
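Determinism comes from seeding a private random generator before building the in-memory database. The sketch below is hypothetical (`build_db`, the country list, and the single-table layout are illustrative, not the contents of `server/db/seed.py`); it shows that the same seed yields identical rows across runs, which holds for a given Python version's `random` implementation.

```python
import random
import sqlite3


def build_db(seed: int = 42, n_customers: int = 150) -> sqlite3.Connection:
    """Hypothetical seeder: the same seed produces identical rows on every run."""
    rng = random.Random(seed)  # private RNG, unaffected by global random state
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT, country TEXT, tier TEXT)"
    )
    countries = ["US", "UK", "DE", "IN", "BR"]  # made-up values for illustration
    tiers = ["bronze", "silver", "gold"]
    conn.executemany(
        "INSERT INTO customers (name, country, tier) VALUES (?, ?, ?)",
        [
            (f"Customer {i}", rng.choice(countries), rng.choice(tiers))
            for i in range(1, n_customers + 1)
        ],
    )
    conn.commit()
    return conn


rows_a = build_db().execute("SELECT * FROM customers ORDER BY id").fetchall()
rows_b = build_db().execute("SELECT * FROM customers ORDER BY id").fetchall()
print(rows_a == rows_b)  # True: two independent builds with seed=42 match
```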
**Why multi-turn (up to 5 steps)?** A single-shot SQL environment gives one reward per question and no chance to recover from a mistake. Multi-turn episodes with error feedback give the agent, and the GRPO trainer, a richer signal: the model learns not just to write SQL, but to debug and refine its queries.
**Why a step penalty?** Without it, an agent that stumbles onto the right answer on step 5 scores the same as one that gets it on step 1. With it, an exact match on step 5 earns 1.0 − 4 × 0.05 = 0.80 instead of 1.0. The penalty creates pressure to solve efficiently, which is realistic.
**Why order-sensitive comparison for ORDER BY queries?** Business questions that say "rank by spending" expect ranked output. Order-agnostic comparison would give spurious credit to results that contain the right rows in the wrong order.
---
## License
MIT. See [LICENSE](LICENSE).