UtkarshSatav commited on
Commit
08b82d0
·
verified ·
1 Parent(s): 6b8948a

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ *.egg-info/
5
+ dist/
6
+ build/
7
+ .venv/
8
+ venv/
9
+ *.db
10
+ *.sqlite
11
+ .env
12
+ .DS_Store
13
+ uv.lock
Dockerfile ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && \
7
+ apt-get install -y --no-install-recommends curl && \
8
+ rm -rf /var/lib/apt/lists/*
9
+
10
+ # Copy requirements first for better caching
11
+ COPY server/requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ # Copy project files
15
+ COPY . .
16
+
17
+ # Expose HF Spaces default port
18
+ EXPOSE 7860
19
+
20
+ # Health check
21
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
22
+ CMD curl -f http://localhost:7860/health || exit 1
23
+
24
+ # Run the FastAPI server on port 7860 (HF Spaces default)
25
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,10 +1,255 @@
1
  ---
2
- title: Sql Env
3
- emoji: 📚
4
- colorFrom: gray
5
- colorTo: blue
6
  sdk: docker
7
  pinned: false
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Sql Env Environment Server
3
+ emoji: 🗜️
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: docker
7
  pinned: false
8
+ app_port: 8000
9
+ base_path: /web
10
+ tags:
11
+ - openenv
12
  ---
13
 
14
+ # Sql Env Environment
15
+
16
+ A simple test environment that echoes back messages. Perfect for testing the env APIs as well as demonstrating environment usage patterns.
17
+
18
+ ## Quick Start
19
+
20
+ The simplest way to use the Sql Env environment is through the `SqlEnv` class:
21
+
22
+ ```python
23
+ from sql_env import SqlAction, SqlEnv
24
+
25
+ try:
26
+ # Create environment from Docker image
27
+ sql_envenv = SqlEnv.from_docker_image("sql_env-env:latest")
28
+
29
+ # Reset
30
+ result = sql_envenv.reset()
31
+ print(f"Reset: {result.observation.echoed_message}")
32
+
33
+ # Send multiple messages
34
+ messages = ["Hello, World!", "Testing echo", "Final message"]
35
+
36
+ for msg in messages:
37
+ result = sql_envenv.step(SqlAction(message=msg))
38
+ print(f"Sent: '{msg}'")
39
+ print(f" → Echoed: '{result.observation.echoed_message}'")
40
+ print(f" → Length: {result.observation.message_length}")
41
+ print(f" → Reward: {result.reward}")
42
+
43
+ finally:
44
+ # Always clean up
45
+ sql_envenv.close()
46
+ ```
47
+
48
+ That's it! The `SqlEnv.from_docker_image()` method handles:
49
+ - Starting the Docker container
50
+ - Waiting for the server to be ready
51
+ - Connecting to the environment
52
+ - Container cleanup when you call `close()`
53
+
54
+ ## Building the Docker Image
55
+
56
+ Before using the environment, you need to build the Docker image:
57
+
58
+ ```bash
59
+ # From project root
60
+ docker build -t sql_env-env:latest -f server/Dockerfile .
61
+ ```
62
+
63
+ ## Deploying to Hugging Face Spaces
64
+
65
+ You can easily deploy your OpenEnv environment to Hugging Face Spaces using the `openenv push` command:
66
+
67
+ ```bash
68
+ # From the environment directory (where openenv.yaml is located)
69
+ openenv push
70
+
71
+ # Or specify options
72
+ openenv push --namespace my-org --private
73
+ ```
74
+
75
+ The `openenv push` command will:
76
+ 1. Validate that the directory is an OpenEnv environment (checks for `openenv.yaml`)
77
+ 2. Prepare a custom build for Hugging Face Docker space (enables web interface)
78
+ 3. Upload to Hugging Face (ensuring you're logged in)
79
+
80
+ ### Prerequisites
81
+
82
+ - Authenticate with Hugging Face: The command will prompt for login if not already authenticated
83
+
84
+ ### Options
85
+
86
+ - `--directory`, `-d`: Directory containing the OpenEnv environment (defaults to current directory)
87
+ - `--repo-id`, `-r`: Repository ID in format 'username/repo-name' (defaults to 'username/env-name' from openenv.yaml)
88
+ - `--base-image`, `-b`: Base Docker image to use (overrides Dockerfile FROM)
89
+ - `--private`: Deploy the space as private (default: public)
90
+
91
+ ### Examples
92
+
93
+ ```bash
94
+ # Push to your personal namespace (defaults to username/env-name from openenv.yaml)
95
+ openenv push
96
+
97
+ # Push to a specific repository
98
+ openenv push --repo-id my-org/my-env
99
+
100
+ # Push with a custom base image
101
+ openenv push --base-image ghcr.io/meta-pytorch/openenv-base:latest
102
+
103
+ # Push as a private space
104
+ openenv push --private
105
+
106
+ # Combine options
107
+ openenv push --repo-id my-org/my-env --base-image custom-base:latest --private
108
+ ```
109
+
110
+ After deployment, your space will be available at:
111
+ `https://huggingface.co/spaces/<repo-id>`
112
+
113
+ The deployed space includes:
114
+ - **Web Interface** at `/web` - Interactive UI for exploring the environment
115
+ - **API Documentation** at `/docs` - Full OpenAPI/Swagger interface
116
+ - **Health Check** at `/health` - Container health monitoring
117
+ - **WebSocket** at `/ws` - Persistent session endpoint for low-latency interactions
118
+
119
+ ## Environment Details
120
+
121
+ ### Action
122
+ **SqlAction**: Contains a single field
123
+ - `message` (str) - The message to echo back
124
+
125
+ ### Observation
126
+ **SqlObservation**: Contains the echo response and metadata
127
+ - `echoed_message` (str) - The message echoed back
128
+ - `message_length` (int) - Length of the message
129
+ - `reward` (float) - Reward based on message length (length × 0.1)
130
+ - `done` (bool) - Always False for echo environment
131
+ - `metadata` (dict) - Additional info like step count
132
+
133
+ ### Reward
134
+ The reward is calculated as: `message_length × 0.1`
135
+ - "Hi" → reward: 0.2
136
+ - "Hello, World!" → reward: 1.3
137
+ - Empty message → reward: 0.0
138
+
139
+ ## Advanced Usage
140
+
141
+ ### Connecting to an Existing Server
142
+
143
+ If you already have a Sql Env environment server running, you can connect directly:
144
+
145
+ ```python
146
+ from sql_env import SqlEnv
147
+
148
+ # Connect to existing server
149
+ sql_envenv = SqlEnv(base_url="<ENV_HTTP_URL_HERE>")
150
+
151
+ # Use as normal
152
+ result = sql_envenv.reset()
153
+ result = sql_envenv.step(SqlAction(message="Hello!"))
154
+ ```
155
+
156
+ Note: When connecting to an existing server, `sql_envenv.close()` will NOT stop the server.
157
+
158
+ ### Using the Context Manager
159
+
160
+ The client supports context manager usage for automatic connection management:
161
+
162
+ ```python
163
+ from sql_env import SqlAction, SqlEnv
164
+
165
+ # Connect with context manager (auto-connects and closes)
166
+ with SqlEnv(base_url="http://localhost:8000") as env:
167
+ result = env.reset()
168
+ print(f"Reset: {result.observation.echoed_message}")
169
+ # Multiple steps with low latency
170
+ for msg in ["Hello", "World", "!"]:
171
+ result = env.step(SqlAction(message=msg))
172
+ print(f"Echoed: {result.observation.echoed_message}")
173
+ ```
174
+
175
+ The client uses WebSocket connections for:
176
+ - **Lower latency**: No HTTP connection overhead per request
177
+ - **Persistent session**: Server maintains your environment state
178
+ - **Efficient for episodes**: Better for many sequential steps
179
+
180
+ ### Concurrent WebSocket Sessions
181
+
182
+ The server supports multiple concurrent WebSocket connections. To enable this,
183
+ modify `server/app.py` to use factory mode:
184
+
185
+ ```python
186
+ # In server/app.py - use factory mode for concurrent sessions
187
+ app = create_app(
188
+ SqlEnvironment, # Pass class, not instance
189
+ SqlAction,
190
+ SqlObservation,
191
+ max_concurrent_envs=4, # Allow 4 concurrent sessions
192
+ )
193
+ ```
194
+
195
+ Then multiple clients can connect simultaneously:
196
+
197
+ ```python
198
+ from sql_env import SqlAction, SqlEnv
199
+ from concurrent.futures import ThreadPoolExecutor
200
+
201
+ def run_episode(client_id: int):
202
+ with SqlEnv(base_url="http://localhost:8000") as env:
203
+ result = env.reset()
204
+ for i in range(10):
205
+ result = env.step(SqlAction(message=f"Client {client_id}, step {i}"))
206
+ return client_id, result.observation.message_length
207
+
208
+ # Run 4 episodes concurrently
209
+ with ThreadPoolExecutor(max_workers=4) as executor:
210
+ results = list(executor.map(run_episode, range(4)))
211
+ ```
212
+
213
+ ## Development & Testing
214
+
215
+ ### Direct Environment Testing
216
+
217
+ Test the environment logic directly without starting the HTTP server:
218
+
219
+ ```bash
220
+ # From the server directory
221
+ python3 server/sql_env_environment.py
222
+ ```
223
+
224
+ This verifies that:
225
+ - Environment resets correctly
226
+ - Step executes actions properly
227
+ - State tracking works
228
+ - Rewards are calculated correctly
229
+
230
+ ### Running Locally
231
+
232
+ Run the server locally for development:
233
+
234
+ ```bash
235
+ uvicorn server.app:app --reload
236
+ ```
237
+
238
+ ## Project Structure
239
+
240
+ ```
241
+ sql_env/
242
+ ├── .dockerignore # Docker build exclusions
243
+ ├── __init__.py # Module exports
244
+ ├── README.md # This file
245
+ ├── openenv.yaml # OpenEnv manifest
246
+ ├── pyproject.toml # Project metadata and dependencies
247
+ ├── uv.lock # Locked dependencies (generated)
248
+ ├── client.py # SqlEnv client
249
+ ├── models.py # Action and Observation models
250
+ └── server/
251
+ ├── __init__.py # Server module exports
252
+ ├── sql_env_environment.py # Core environment logic
253
+ ├── app.py # FastAPI application (HTTP + WebSocket endpoints)
254
+ └── Dockerfile # Container image definition
255
+ ```
__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """SQL Query Writing Environment."""
2
+
3
+ from .models import SQLAction, SQLObservation
4
+
5
+ __all__ = [
6
+ "SQLAction",
7
+ "SQLObservation",
8
+ ]
client.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SQL Query Writing Environment Client."""
2
+
3
+ from typing import Dict
4
+
5
+ from openenv.core import EnvClient
6
+ from openenv.core.client_types import StepResult
7
+ from openenv.core.env_server.types import State
8
+
9
+ from .models import SQLAction, SQLObservation
10
+
11
+
12
+ class SQLEnvClient(
13
+ EnvClient[SQLAction, SQLObservation, State]
14
+ ):
15
+ """
16
+ Client for the SQL Query Writing Environment.
17
+
18
+ Example:
19
+ >>> with SQLEnvClient(base_url="http://localhost:8000") as client:
20
+ ... result = client.reset()
21
+ ... print(result.observation.question)
22
+ ... result = client.step(SQLAction(query="SELECT * FROM customers"))
23
+ ... print(result.observation.query_result)
24
+ """
25
+
26
+ def _step_payload(self, action: SQLAction) -> Dict:
27
+ return {"query": action.query}
28
+
29
+ def _parse_result(self, payload: Dict) -> StepResult[SQLObservation]:
30
+ obs_data = payload.get("observation", {})
31
+ observation = SQLObservation(
32
+ task_name=obs_data.get("task_name", ""),
33
+ question=obs_data.get("question", ""),
34
+ schema_description=obs_data.get("schema_description", ""),
35
+ query_result=obs_data.get("query_result", ""),
36
+ error=obs_data.get("error", ""),
37
+ steps_remaining=obs_data.get("steps_remaining", 0),
38
+ question_index=obs_data.get("question_index", 0),
39
+ total_questions=obs_data.get("total_questions", 0),
40
+ done=payload.get("done", False),
41
+ reward=payload.get("reward"),
42
+ metadata=obs_data.get("metadata", {}),
43
+ )
44
+ return StepResult(
45
+ observation=observation,
46
+ reward=payload.get("reward"),
47
+ done=payload.get("done", False),
48
+ )
49
+
50
+ def _parse_state(self, payload: Dict) -> State:
51
+ return State(
52
+ episode_id=payload.get("episode_id"),
53
+ step_count=payload.get("step_count", 0),
54
+ )
data/schema.sql ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- SQLEnv E-Commerce Database Schema
2
+ -- Designed for text-to-SQL agent training with 3 difficulty levels
3
+
4
+ CREATE TABLE IF NOT EXISTS customers (
5
+ id INTEGER PRIMARY KEY,
6
+ name TEXT NOT NULL,
7
+ email TEXT NOT NULL UNIQUE,
8
+ age INTEGER NOT NULL,
9
+ city TEXT NOT NULL,
10
+ signup_date TEXT NOT NULL -- ISO format: YYYY-MM-DD
11
+ );
12
+
13
+ CREATE TABLE IF NOT EXISTS products (
14
+ id INTEGER PRIMARY KEY,
15
+ name TEXT NOT NULL,
16
+ category TEXT NOT NULL, -- Electronics, Clothing, Books, Home
17
+ price REAL NOT NULL,
18
+ stock INTEGER NOT NULL DEFAULT 0
19
+ );
20
+
21
+ CREATE TABLE IF NOT EXISTS orders (
22
+ id INTEGER PRIMARY KEY,
23
+ customer_id INTEGER NOT NULL,
24
+ order_date TEXT NOT NULL, -- ISO format: YYYY-MM-DD
25
+ status TEXT NOT NULL, -- pending, shipped, delivered, cancelled
26
+ total_amount REAL NOT NULL,
27
+ FOREIGN KEY (customer_id) REFERENCES customers(id)
28
+ );
29
+
30
+ CREATE TABLE IF NOT EXISTS order_items (
31
+ id INTEGER PRIMARY KEY,
32
+ order_id INTEGER NOT NULL,
33
+ product_id INTEGER NOT NULL,
34
+ quantity INTEGER NOT NULL,
35
+ unit_price REAL NOT NULL,
36
+ FOREIGN KEY (order_id) REFERENCES orders(id),
37
+ FOREIGN KEY (product_id) REFERENCES products(id)
38
+ );
39
+
40
+ CREATE TABLE IF NOT EXISTS reviews (
41
+ id INTEGER PRIMARY KEY,
42
+ product_id INTEGER NOT NULL,
43
+ customer_id INTEGER NOT NULL,
44
+ rating INTEGER NOT NULL CHECK (rating >= 1 AND rating <= 5),
45
+ review_text TEXT NOT NULL,
46
+ review_date TEXT NOT NULL, -- ISO format: YYYY-MM-DD
47
+ FOREIGN KEY (product_id) REFERENCES products(id),
48
+ FOREIGN KEY (customer_id) REFERENCES customers(id)
49
+ );
data/seed.sql ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- SQLEnv Deterministic Seed Data
2
+ -- Carefully crafted to support easy/medium/hard SQL tasks
3
+
4
+ -- ============================================================
5
+ -- CUSTOMERS (20 rows)
6
+ -- ============================================================
7
+ INSERT INTO customers (id, name, email, age, city, signup_date) VALUES
8
+ (1, 'Aarav Sharma', 'aarav@example.com', 28, 'Mumbai', '2023-03-15'),
9
+ (2, 'Priya Patel', 'priya@example.com', 35, 'Delhi', '2023-05-20'),
10
+ (3, 'Rahul Kumar', 'rahul@example.com', 42, 'Bangalore', '2023-01-10'),
11
+ (4, 'Sneha Gupta', 'sneha@example.com', 24, 'Mumbai', '2024-02-14'),
12
+ (5, 'Vikram Singh', 'vikram@example.com', 31, 'Chennai', '2023-07-01'),
13
+ (6, 'Ananya Reddy', 'ananya@example.com', 29, 'Hyderabad', '2023-09-05'),
14
+ (7, 'Arjun Nair', 'arjun@example.com', 38, 'Bangalore', '2023-04-18'),
15
+ (8, 'Kavita Joshi', 'kavita@example.com', 45, 'Pune', '2023-02-28'),
16
+ (9, 'Deepak Verma', 'deepak@example.com', 33, 'Delhi', '2024-01-05'),
17
+ (10, 'Meera Iyer', 'meera@example.com', 27, 'Chennai', '2023-11-12'),
18
+ (11, 'Rohan Das', 'rohan@example.com', 36, 'Kolkata', '2023-06-22'),
19
+ (12, 'Pooja Mishra', 'pooja@example.com', 41, 'Pune', '2023-08-15'),
20
+ (13, 'Suresh Menon', 'suresh@example.com', 50, 'Mumbai', '2023-03-30'),
21
+ (14, 'Nisha Agarwal', 'nisha@example.com', 23, 'Delhi', '2024-03-01'),
22
+ (15, 'Amit Pandey', 'amit@example.com', 34, 'Bangalore', '2023-10-10'),
23
+ (16, 'Ritu Chopra', 'ritu@example.com', 30, 'Hyderabad', '2023-12-25'),
24
+ (17, 'Karan Malhotra', 'karan@example.com', 26, 'Mumbai', '2024-01-20'),
25
+ (18, 'Divya Saxena', 'divya@example.com', 39, 'Chennai', '2023-05-05'),
26
+ (19, 'Nikhil Bhat', 'nikhil@example.com', 32, 'Kolkata', '2023-07-14'),
27
+ (20, 'Swati Tiwari', 'swati@example.com', 44, 'Pune', '2023-04-02');
28
+
29
+ -- ============================================================
30
+ -- PRODUCTS (15 rows, 4 categories)
31
+ -- ============================================================
32
+ INSERT INTO products (id, name, category, price, stock) VALUES
33
+ (1, 'Wireless Headphones', 'Electronics', 2499.00, 50),
34
+ (2, 'Smartphone Case', 'Electronics', 499.00, 200),
35
+ (3, 'Bluetooth Speaker', 'Electronics', 3999.00, 30),
36
+ (4, 'USB-C Cable', 'Electronics', 199.00, 500),
37
+ (5, 'Cotton T-Shirt', 'Clothing', 599.00, 150),
38
+ (6, 'Denim Jeans', 'Clothing', 1499.00, 80),
39
+ (7, 'Running Shoes', 'Clothing', 2999.00, 40),
40
+ (8, 'Winter Jacket', 'Clothing', 3499.00, 25),
41
+ (9, 'Python Programming', 'Books', 449.00, 100),
42
+ (10, 'Data Science Handbook', 'Books', 699.00, 60),
43
+ (11, 'Mystery Novel', 'Books', 299.00, 120),
44
+ (12, 'Cooking Recipes', 'Books', 399.00, 90),
45
+ (13, 'Ceramic Mug Set', 'Home', 799.00, 70),
46
+ (14, 'Desk Lamp', 'Home', 1299.00, 45),
47
+ (15, 'Plant Pot', 'Home', 349.00, 110);
48
+
49
+ -- ============================================================
50
+ -- ORDERS (30 rows)
51
+ -- ============================================================
52
+ INSERT INTO orders (id, customer_id, order_date, status, total_amount) VALUES
53
+ (1, 1, '2024-01-15', 'delivered', 2998.00),
54
+ (2, 1, '2024-03-20', 'delivered', 499.00),
55
+ (3, 2, '2024-02-10', 'delivered', 4498.00),
56
+ (4, 3, '2024-01-05', 'delivered', 1798.00),
57
+ (5, 3, '2024-04-12', 'shipped', 3999.00),
58
+ (6, 2, '2024-03-01', 'delivered', 599.00),
59
+ (7, 5, '2024-02-28', 'delivered', 5997.00),
60
+ (8, 5, '2024-05-15', 'shipped', 1299.00),
61
+ (9, 6, '2024-01-20', 'delivered', 898.00),
62
+ (10, 7, '2024-03-10', 'delivered', 2499.00),
63
+ (11, 7, '2024-06-01', 'pending', 699.00),
64
+ (12, 8, '2024-02-14', 'delivered', 4998.00),
65
+ (13, 8, '2024-04-25', 'cancelled', 1499.00),
66
+ (14, 9, '2024-03-05', 'delivered', 848.00),
67
+ (15, 10, '2024-01-30', 'delivered', 2999.00),
68
+ (16, 10, '2024-05-20', 'shipped', 449.00),
69
+ (17, 11, '2024-02-18', 'delivered', 1598.00),
70
+ (18, 12, '2024-03-22', 'delivered', 3499.00),
71
+ (19, 13, '2024-04-08', 'shipped', 798.00),
72
+ (20, 13, '2024-01-12', 'delivered', 2499.00),
73
+ (21, 11, '2024-05-01', 'pending', 599.00),
74
+ (22, 15, '2024-02-05', 'delivered', 1748.00),
75
+ (23, 15, '2024-06-10', 'pending', 299.00),
76
+ (24, 16, '2024-03-15', 'delivered', 999.00),
77
+ (25, 17, '2024-04-20', 'shipped', 2499.00),
78
+ (26, 18, '2024-01-25', 'delivered', 4498.00),
79
+ (27, 18, '2024-05-30', 'delivered', 699.00),
80
+ (28, 19, '2024-02-22', 'delivered', 1299.00),
81
+ (29, 19, '2024-06-05', 'cancelled', 399.00),
82
+ (30, 20, '2024-03-28', 'delivered', 3798.00);
83
+
84
+ -- ============================================================
85
+ -- ORDER_ITEMS (60 rows)
86
+ -- ============================================================
87
+ INSERT INTO order_items (id, order_id, product_id, quantity, unit_price) VALUES
88
+ -- Order 1: customer 1 bought headphones + smartphone case
89
+ (1, 1, 1, 1, 2499.00),
90
+ (2, 1, 2, 1, 499.00),
91
+ -- Order 2: customer 1 bought smartphone case
92
+ (3, 2, 2, 1, 499.00),
93
+ -- Order 3: customer 2 bought headphones + jeans + smartphone case
94
+ (4, 3, 1, 1, 2499.00),
95
+ (5, 3, 6, 1, 1499.00),
96
+ (6, 3, 2, 1, 499.00),
97
+ -- Order 4: customer 3 bought desk lamp + smartphone case
98
+ (7, 4, 14, 1, 1299.00),
99
+ (8, 4, 2, 1, 499.00),
100
+ -- Order 5: customer 3 bought bluetooth speaker
101
+ (9, 5, 3, 1, 3999.00),
102
+ -- Order 6: customer 2 bought t-shirt
103
+ (10, 6, 5, 1, 599.00),
104
+ -- Order 7: customer 5 bought running shoes x2
105
+ (11, 7, 7, 2, 2999.00),
106
+ -- Order 8: customer 5 bought desk lamp
107
+ (12, 8, 14, 1, 1299.00),
108
+ -- Order 9: customer 6 bought mug set + usb cable
109
+ (13, 9, 13, 1, 799.00),
110
+ (14, 9, 4, 1, 199.00), -- total should be 998, but we set 898 for slight difference
111
+ -- Order 10: customer 7 bought headphones
112
+ (15, 10, 1, 1, 2499.00),
113
+ -- Order 11: customer 7 bought data science book
114
+ (16, 11, 10, 1, 699.00),
115
+ -- Order 12: customer 8 bought headphones x2
116
+ (17, 12, 1, 2, 2499.00),
117
+ -- Order 13: customer 8 bought jeans (cancelled)
118
+ (18, 13, 6, 1, 1499.00),
119
+ -- Order 14: customer 9 bought python book + plant pot
120
+ (19, 14, 9, 1, 449.00),
121
+ (20, 14, 15, 1, 349.00),
122
+ -- Order 15: customer 10 bought running shoes
123
+ (21, 15, 7, 1, 2999.00),
124
+ -- Order 16: customer 10 bought python book
125
+ (22, 16, 9, 1, 449.00),
126
+ -- Order 17: customer 11 bought mug set x2
127
+ (23, 17, 13, 2, 799.00),
128
+ -- Order 18: customer 12 bought winter jacket
129
+ (24, 18, 8, 1, 3499.00),
130
+ -- Order 19: customer 13 bought mug set (shipped)
131
+ (25, 19, 13, 1, 799.00),
132
+ -- Order 20: customer 13 bought headphones
133
+ (26, 20, 1, 1, 2499.00),
134
+ -- Order 21: customer 11 bought t-shirt (pending)
135
+ (27, 21, 5, 1, 599.00),
136
+ -- Order 22: customer 15 bought python book + desk lamp
137
+ (28, 22, 9, 1, 449.00),
138
+ (29, 22, 14, 1, 1299.00),
139
+ -- Order 23: customer 15 bought mystery novel (pending)
140
+ (30, 23, 11, 1, 299.00),
141
+ -- Order 24: customer 16 bought t-shirt + cooking book
142
+ (31, 24, 5, 1, 599.00),
143
+ (32, 24, 12, 1, 399.00),
144
+ -- Order 25: customer 17 bought headphones (shipped)
145
+ (33, 25, 1, 1, 2499.00),
146
+ -- Order 26: customer 18 bought bluetooth speaker + smartphone case
147
+ (34, 26, 3, 1, 3999.00),
148
+ (35, 26, 2, 1, 499.00),
149
+ -- Order 27: customer 18 bought data science book
150
+ (36, 27, 10, 1, 699.00),
151
+ -- Order 28: customer 19 bought desk lamp
152
+ (37, 28, 14, 1, 1299.00),
153
+ -- Order 29: customer 19 bought cooking book (cancelled)
154
+ (38, 29, 12, 1, 399.00),
155
+ -- Order 30: customer 20 bought running shoes + mug set
156
+ (39, 30, 7, 1, 2999.00),
157
+ (40, 30, 13, 1, 799.00),
158
+ -- Extra items to diversify category coverage
159
+ -- Order 3 (customer 2): add a book → Electronics + Clothing + Books
160
+ (41, 3, 9, 1, 449.00),
161
+ -- Order 4 (customer 3): add a t-shirt → Home + Electronics + Clothing
162
+ (42, 4, 5, 1, 599.00),
163
+ -- Order 4 (customer 3): add a book → Home + Electronics + Clothing + Books = 4 categories
164
+ (43, 4, 11, 1, 299.00),
165
+ -- Order 24 (customer 16): add a mug set → Clothing + Books + Home
166
+ (44, 24, 13, 1, 799.00),
167
+ -- Order 9 (customer 6): add a book → Home + Electronics + Books
168
+ (45, 9, 12, 1, 399.00),
169
+ -- Order 17 (customer 11): add headphones → Home + Clothing + Electronics
170
+ (46, 17, 4, 1, 199.00);
171
+
172
+ -- ============================================================
173
+ -- REVIEWS (25 rows)
174
+ -- ============================================================
175
+ INSERT INTO reviews (id, product_id, customer_id, rating, review_text, review_date) VALUES
176
+ (1, 1, 1, 5, 'Amazing sound quality, worth every penny!', '2024-02-01'),
177
+ (2, 1, 7, 4, 'Good headphones, battery could be better.', '2024-03-25'),
178
+ (3, 1, 13, 5, 'Best wireless headphones I have owned.', '2024-02-15'),
179
+ (4, 2, 1, 3, 'Decent case but feels a bit cheap.', '2024-04-01'),
180
+ (5, 2, 3, 4, 'Good fit, protects the phone well.', '2024-02-10'),
181
+ (6, 3, 18, 5, 'Incredible bass for the price!', '2024-02-20'),
182
+ (7, 3, 5, 4, 'Great speaker, slightly heavy though.', '2024-06-01'),
183
+ (8, 5, 4, 4, 'Soft cotton, very comfortable.', '2024-03-15'),
184
+ (9, 5, 16, 3, 'Okay quality, shrunk after washing.', '2024-04-05'),
185
+ (10, 6, 2, 5, 'Perfect fit, love the style!', '2024-03-10'),
186
+ (11, 7, 5, 5, 'Super comfortable for running.', '2024-03-20'),
187
+ (12, 7, 10, 4, 'Good shoes but sizing runs a bit large.', '2024-02-15'),
188
+ (13, 7, 20, 5, 'Best running shoes ever!', '2024-04-10'),
189
+ (14, 8, 12, 4, 'Warm and stylish, great for winter.', '2024-04-01'),
190
+ (15, 9, 9, 5, 'Excellent book for learning Python.', '2024-04-10'),
191
+ (16, 9, 15, 4, 'Good content, some chapters feel rushed.', '2024-03-01'),
192
+ (17, 10, 7, 3, 'Covers basics well, lacks depth on ML topics.', '2024-06-15'),
193
+ (18, 11, 15, 4, 'Gripping plot, could not put it down!', '2024-06-20'),
194
+ (19, 13, 6, 5, 'Beautiful mugs, great as a gift set.', '2024-02-10'),
195
+ (20, 13, 11, 4, 'Nice mugs, one had a small chip.', '2024-03-05'),
196
+ (21, 13, 20, 5, 'Excellent quality ceramic.', '2024-04-15'),
197
+ (22, 14, 3, 4, 'Bright and adjustable, good for desk work.', '2024-02-20'),
198
+ (23, 14, 19, 5, 'Perfect desk lamp, love the design.', '2024-03-10'),
199
+ (24, 15, 9, 3, 'Looks nice but drainage hole is too small.', '2024-04-15'),
200
+ (25, 12, 16, 4, 'Great recipes, easy to follow instructions.', '2024-04-20');
data/tasks/advanced_analytics.json ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "task_name": "advanced_analytics",
3
+ "difficulty": "hard",
4
+ "description": "Subqueries, CTEs, window functions, and complex multi-table analytics",
5
+ "max_steps_per_question": 5,
6
+ "questions": [
7
+ {
8
+ "id": "hard_1",
9
+ "question": "Find all customers whose total spending across all orders exceeds the average total spending per customer. Show customer name and total spent, sorted by total spent from highest to lowest.",
10
+ "ground_truth_sql": "SELECT c.name, SUM(o.total_amount) as total_spent FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.id HAVING total_spent > (SELECT AVG(total_spent) FROM (SELECT SUM(total_amount) as total_spent FROM orders GROUP BY customer_id)) ORDER BY total_spent DESC",
11
+ "expected_columns": ["name", "total_spent"],
12
+ "expected_row_count": 9,
13
+ "expected_rows": [
14
+ ["Vikram Singh", 7296.0],
15
+ ["Kavita Joshi", 6497.0],
16
+ ["Rahul Kumar", 5797.0],
17
+ ["Divya Saxena", 5197.0],
18
+ ["Priya Patel", 5097.0],
19
+ ["Swati Tiwari", 3798.0],
20
+ ["Pooja Mishra", 3499.0],
21
+ ["Aarav Sharma", 3497.0],
22
+ ["Meera Iyer", 3448.0]
23
+ ],
24
+ "order_matters": true
25
+ },
26
+ {
27
+ "id": "hard_2",
28
+ "question": "Rank all products by their total revenue (quantity * unit_price from order_items) within each product category. Show category, product name, revenue, and the rank within the category. Sort by category alphabetically, then by rank.",
29
+ "ground_truth_sql": "SELECT p.category, p.name, SUM(oi.quantity * oi.unit_price) as revenue, RANK() OVER (PARTITION BY p.category ORDER BY SUM(oi.quantity * oi.unit_price) DESC) as category_rank FROM products p JOIN order_items oi ON p.id = oi.product_id GROUP BY p.id ORDER BY p.category, category_rank",
30
+ "expected_columns": ["category", "name", "revenue", "category_rank"],
31
+ "expected_row_count": 15,
32
+ "expected_rows": [
33
+ ["Books", "Python Programming", 1796.0, 1],
34
+ ["Books", "Data Science Handbook", 1398.0, 2],
35
+ ["Books", "Cooking Recipes", 1197.0, 3],
36
+ ["Books", "Mystery Novel", 598.0, 4],
37
+ ["Clothing", "Running Shoes", 11996.0, 1],
38
+ ["Clothing", "Winter Jacket", 3499.0, 2],
39
+ ["Clothing", "Denim Jeans", 2998.0, 3],
40
+ ["Clothing", "Cotton T-Shirt", 2396.0, 4],
41
+ ["Electronics", "Wireless Headphones", 17493.0, 1],
42
+ ["Electronics", "Bluetooth Speaker", 7998.0, 2],
43
+ ["Electronics", "Smartphone Case", 2495.0, 3],
44
+ ["Electronics", "USB-C Cable", 398.0, 4],
45
+ ["Home", "Desk Lamp", 5196.0, 1],
46
+ ["Home", "Ceramic Mug Set", 4794.0, 2],
47
+ ["Home", "Plant Pot", 349.0, 3]
48
+ ],
49
+ "order_matters": true
50
+ },
51
+ {
52
+ "id": "hard_3",
53
+ "question": "Calculate the month-over-month growth in order count for 2024. Show the month (as YYYY-MM), the number of orders that month, and the change from the previous month (NULL for the first month). Sort by month.",
54
+ "ground_truth_sql": "SELECT strftime('%Y-%m', order_date) as month, COUNT(*) as order_count, COUNT(*) - LAG(COUNT(*)) OVER (ORDER BY strftime('%Y-%m', order_date)) as growth FROM orders GROUP BY month ORDER BY month",
55
+ "expected_columns": ["month", "order_count", "growth"],
56
+ "expected_row_count": 6,
57
+ "expected_rows": [
58
+ ["2024-01", 6, null],
59
+ ["2024-02", 6, 0],
60
+ ["2024-03", 7, 1],
61
+ ["2024-04", 4, -3],
62
+ ["2024-05", 4, 0],
63
+ ["2024-06", 3, -1]
64
+ ],
65
+ "order_matters": true
66
+ },
67
+ {
68
+ "id": "hard_4",
69
+ "question": "Find all customers who have purchased products from at least 3 different product categories. Show the customer name and the number of distinct categories they bought from, sorted by category count descending then name ascending.",
70
+ "ground_truth_sql": "SELECT c.name, COUNT(DISTINCT p.category) as category_count FROM customers c JOIN orders o ON c.id = o.customer_id JOIN order_items oi ON o.id = oi.order_id JOIN products p ON oi.product_id = p.id GROUP BY c.id HAVING category_count >= 3 ORDER BY category_count DESC, c.name ASC",
71
+ "expected_columns": ["name", "category_count"],
72
+ "expected_row_count": 5,
73
+ "expected_rows": [
74
+ ["Rahul Kumar", 4],
75
+ ["Ananya Reddy", 3],
76
+ ["Priya Patel", 3],
77
+ ["Ritu Chopra", 3],
78
+ ["Rohan Das", 3]
79
+ ],
80
+ "order_matters": true
81
+ },
82
+ {
83
+ "id": "hard_5",
84
+ "question": "For each product category, find the product with the highest average review rating. Show the category, product name, and average rating (rounded to 2 decimal places). Only include products that have at least 2 reviews. Sort by category alphabetically, then by average rating descending.",
85
+ "ground_truth_sql": "SELECT p.category, p.name, ROUND(AVG(r.rating), 2) as avg_rating FROM products p JOIN reviews r ON p.id = r.product_id GROUP BY p.id HAVING COUNT(r.id) >= 2 ORDER BY p.category, avg_rating DESC",
86
+ "expected_columns": ["category", "name", "avg_rating"],
87
+ "expected_row_count": 8,
88
+ "expected_rows": [
89
+ ["Books", "Python Programming", 4.5],
90
+ ["Clothing", "Running Shoes", 4.67],
91
+ ["Clothing", "Cotton T-Shirt", 3.5],
92
+ ["Electronics", "Wireless Headphones", 4.67],
93
+ ["Electronics", "Bluetooth Speaker", 4.5],
94
+ ["Electronics", "Smartphone Case", 3.5],
95
+ ["Home", "Ceramic Mug Set", 4.67],
96
+ ["Home", "Desk Lamp", 4.5]
97
+ ],
98
+ "order_matters": true
99
+ }
100
+ ]
101
+ }
data/tasks/basic_select.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "task_name": "basic_select",
3
+ "difficulty": "easy",
4
+ "description": "Simple SELECT queries with WHERE, ORDER BY, LIMIT",
5
+ "max_steps_per_question": 3,
6
+ "questions": [
7
+ {
8
+ "id": "easy_1",
9
+ "question": "Find the names and ages of all customers older than 30, sorted by age from highest to lowest.",
10
+ "ground_truth_sql": "SELECT name, age FROM customers WHERE age > 30 ORDER BY age DESC",
11
+ "expected_columns": ["name", "age"],
12
+ "expected_row_count": 13,
13
+ "expected_rows": [
14
+ ["Suresh Menon", 50],
15
+ ["Kavita Joshi", 45],
16
+ ["Swati Tiwari", 44],
17
+ ["Rahul Kumar", 42],
18
+ ["Pooja Mishra", 41],
19
+ ["Divya Saxena", 39],
20
+ ["Arjun Nair", 38],
21
+ ["Rohan Das", 36],
22
+ ["Priya Patel", 35],
23
+ ["Amit Pandey", 34],
24
+ ["Deepak Verma", 33],
25
+ ["Nikhil Bhat", 32],
26
+ ["Vikram Singh", 31]
27
+ ],
28
+ "order_matters": true
29
+ },
30
+ {
31
+ "id": "easy_2",
32
+ "question": "List all products in the 'Electronics' category, showing name and price, sorted by price from highest to lowest.",
33
+ "ground_truth_sql": "SELECT name, price FROM products WHERE category = 'Electronics' ORDER BY price DESC",
34
+ "expected_columns": ["name", "price"],
35
+ "expected_row_count": 4,
36
+ "expected_rows": [
37
+ ["Bluetooth Speaker", 3999.0],
38
+ ["Wireless Headphones", 2499.0],
39
+ ["Smartphone Case", 499.0],
40
+ ["USB-C Cable", 199.0]
41
+ ],
42
+ "order_matters": true
43
+ },
44
+ {
45
+ "id": "easy_3",
46
+ "question": "How many orders have the status 'shipped'?",
47
+ "ground_truth_sql": "SELECT COUNT(*) as shipped_count FROM orders WHERE status = 'shipped'",
48
+ "expected_columns": ["shipped_count"],
49
+ "expected_row_count": 1,
50
+ "expected_rows": [[5]],
51
+ "order_matters": false
52
+ },
53
+ {
54
+ "id": "easy_4",
55
+ "question": "What is the most expensive product? Show its name and price.",
56
+ "ground_truth_sql": "SELECT name, price FROM products ORDER BY price DESC LIMIT 1",
57
+ "expected_columns": ["name", "price"],
58
+ "expected_row_count": 1,
59
+ "expected_rows": [["Bluetooth Speaker", 3999.0]],
60
+ "order_matters": false
61
+ },
62
+ {
63
+ "id": "easy_5",
64
+ "question": "Find all customers from Mumbai who signed up after January 1, 2024. Show their name and signup date, sorted by signup date.",
65
+ "ground_truth_sql": "SELECT name, signup_date FROM customers WHERE city = 'Mumbai' AND signup_date > '2024-01-01' ORDER BY signup_date",
66
+ "expected_columns": ["name", "signup_date"],
67
+ "expected_row_count": 2,
68
+ "expected_rows": [
69
+ ["Karan Malhotra", "2024-01-20"],
70
+ ["Sneha Gupta", "2024-02-14"]
71
+ ],
72
+ "order_matters": true
73
+ }
74
+ ]
75
+ }
data/tasks/join_aggregate.json ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "task_name": "join_aggregate",
3
+ "difficulty": "medium",
4
+ "description": "JOIN queries with GROUP BY, HAVING, and aggregate functions",
5
+ "max_steps_per_question": 4,
6
+ "questions": [
7
+ {
8
+ "id": "med_1",
9
+ "question": "What is the average order total for each customer? Show customer name and average total (rounded to 2 decimal places), sorted by average total from highest to lowest.",
10
+ "ground_truth_sql": "SELECT c.name, ROUND(AVG(o.total_amount), 2) as avg_total FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.id ORDER BY avg_total DESC",
11
+ "expected_columns": ["name", "avg_total"],
12
+ "expected_row_count": 18,
13
+ "expected_rows": [
14
+ ["Swati Tiwari", 3798.0],
15
+ ["Vikram Singh", 3648.0],
16
+ ["Pooja Mishra", 3499.0],
17
+ ["Kavita Joshi", 3248.5],
18
+ ["Rahul Kumar", 2898.5],
19
+ ["Divya Saxena", 2598.5],
20
+ ["Priya Patel", 2548.5],
21
+ ["Karan Malhotra", 2499.0],
22
+ ["Aarav Sharma", 1748.5],
23
+ ["Meera Iyer", 1724.0],
24
+ ["Suresh Menon", 1648.5],
25
+ ["Arjun Nair", 1599.0],
26
+ ["Rohan Das", 1098.5],
27
+ ["Amit Pandey", 1023.5],
28
+ ["Ritu Chopra", 999.0],
29
+ ["Ananya Reddy", 898.0],
30
+ ["Nikhil Bhat", 849.0],
31
+ ["Deepak Verma", 848.0]
32
+ ],
33
+ "order_matters": true
34
+ },
35
+ {
36
+ "id": "med_2",
37
+ "question": "Which products have been ordered more than 2 times in total quantity? Show the product name and total quantity ordered, sorted by total quantity from highest to lowest.",
38
+ "ground_truth_sql": "SELECT p.name, SUM(oi.quantity) as total_qty FROM products p JOIN order_items oi ON p.id = oi.product_id GROUP BY p.id HAVING total_qty > 2 ORDER BY total_qty DESC",
39
+ "expected_columns": ["name", "total_qty"],
40
+ "expected_row_count": 8,
41
+ "expected_rows": [
42
+ ["Wireless Headphones", 7],
43
+ ["Ceramic Mug Set", 6],
44
+ ["Smartphone Case", 5],
45
+ ["Desk Lamp", 4],
46
+ ["Python Programming", 4],
47
+ ["Running Shoes", 4],
48
+ ["Cotton T-Shirt", 4],
49
+ ["Cooking Recipes", 3]
50
+ ],
51
+ "order_matters": true
52
+ },
53
+ {
54
+ "id": "med_3",
55
+ "question": "List all customers who have never placed an order. Show their name and email, sorted by name.",
56
+ "ground_truth_sql": "SELECT name, email FROM customers WHERE id NOT IN (SELECT DISTINCT customer_id FROM orders) ORDER BY name",
57
+ "expected_columns": ["name", "email"],
58
+ "expected_row_count": 2,
59
+ "expected_rows": [
60
+ ["Nisha Agarwal", "nisha@example.com"],
61
+ ["Sneha Gupta", "sneha@example.com"]
62
+ ],
63
+ "order_matters": true
64
+ },
65
+ {
66
+ "id": "med_4",
67
+ "question": "What is the total revenue per product category? Calculate revenue as quantity times unit_price from order_items. Show category and revenue (rounded to 2 decimal places), sorted by revenue from highest to lowest.",
68
+ "ground_truth_sql": "SELECT p.category, ROUND(SUM(oi.quantity * oi.unit_price), 2) as revenue FROM products p JOIN order_items oi ON p.id = oi.product_id GROUP BY p.category ORDER BY revenue DESC",
69
+ "expected_columns": ["category", "revenue"],
70
+ "expected_row_count": 4,
71
+ "expected_rows": [
72
+ ["Electronics", 28384.0],
73
+ ["Clothing", 20889.0],
74
+ ["Home", 10339.0],
75
+ ["Books", 4989.0]
76
+ ],
77
+ "order_matters": true
78
+ },
79
+ {
80
+ "id": "med_5",
81
+ "question": "Who are the top 3 customers by total spending? Show customer name and total amount spent across all orders, sorted by total spent from highest to lowest.",
82
+ "ground_truth_sql": "SELECT c.name, SUM(o.total_amount) as total_spent FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.id ORDER BY total_spent DESC LIMIT 3",
83
+ "expected_columns": ["name", "total_spent"],
84
+ "expected_row_count": 3,
85
+ "expected_rows": [
86
+ ["Vikram Singh", 7296.0],
87
+ ["Kavita Joshi", 6497.0],
88
+ ["Rahul Kumar", 5797.0]
89
+ ],
90
+ "order_matters": true
91
+ }
92
+ ]
93
+ }
inference.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference Script for SQL Query Writing Environment
3
+ ===================================================
4
+ MANDATORY — this file must be named `inference.py` and placed in the project root.
5
+
6
+ Uses OpenAI Client for all LLM calls. Reads credentials from environment variables:
7
+ API_BASE_URL The API endpoint for the LLM.
8
+ MODEL_NAME The model identifier to use for inference.
9
+ HF_TOKEN Your Hugging Face / API key.
10
+
11
+ STDOUT FORMAT:
12
+ [START] task=<task_name> env=<benchmark> model=<model_name>
13
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
14
+ [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...,rn>
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import textwrap
20
+ from typing import List, Optional
21
+
22
+ from openai import OpenAI
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Environment — runs locally (no Docker needed for inference)
26
+ # ---------------------------------------------------------------------------
27
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
28
+ os.environ.setdefault("SQL_ENV_TASK", "basic_select")
29
+
30
+ from server.sql_env_environment import SQLEnvironment
31
+ from models import SQLAction
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Configuration
35
+ # ---------------------------------------------------------------------------
36
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
37
+ if not API_KEY:
38
+ try:
39
+ from huggingface_hub import get_token
40
+ API_KEY = get_token()
41
+ except Exception:
42
+ pass
43
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
44
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
45
+
46
+ BENCHMARK = "sql_env"
47
+ TASKS = ["basic_select", "join_aggregate", "advanced_analytics"]
48
+ MAX_STEPS = 8
49
+ TEMPERATURE = 0.3
50
+ MAX_TOKENS = 512
51
+ SUCCESS_SCORE_THRESHOLD = 0.1
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Logging helpers (MANDATORY stdout format)
56
+ # ---------------------------------------------------------------------------
57
+ def log_start(task: str, env: str, model: str) -> None:
58
+ print(f"[START] task={task} env={env} model={model}", flush=True)
59
+
60
+
61
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
62
+ # Sanitize action: remove newlines, truncate for readability
63
+ action_clean = action.replace("\n", " ").replace("\r", "").strip()
64
+ if len(action_clean) > 200:
65
+ action_clean = action_clean[:200] + "..."
66
+ error_val = error if error else "null"
67
+ done_val = str(done).lower()
68
+ print(
69
+ f"[STEP] step={step} action={action_clean} reward={reward:.2f} done={done_val} error={error_val}",
70
+ flush=True,
71
+ )
72
+
73
+
74
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
75
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
76
+ print(
77
+ f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
78
+ flush=True,
79
+ )
80
+
81
+
82
+ # ---------------------------------------------------------------------------
83
+ # System prompt for the LLM
84
+ # ---------------------------------------------------------------------------
85
+ SYSTEM_PROMPT = textwrap.dedent("""
86
+ You are an expert SQL query writer. You are given a database schema and a
87
+ natural language question. Write a single SQL SELECT query that answers
88
+ the question exactly.
89
+
90
+ Rules:
91
+ - Write ONLY the SQL query, nothing else. No explanations, no markdown.
92
+ - Use only SELECT statements (no INSERT, UPDATE, DELETE, etc.)
93
+ - Match the requested column names and sorting exactly.
94
+ - Use standard SQL compatible with SQLite.
95
+ - If the question asks for rounding, use ROUND().
96
+ - If the question asks for sorting, include ORDER BY.
97
+ - Pay attention to whether results should be sorted ascending or descending.
98
+ """).strip()
99
+
100
+
101
+ def build_user_prompt(
102
+ question: str,
103
+ schema: str,
104
+ last_result: str,
105
+ last_error: str,
106
+ feedback: str,
107
+ attempt: int,
108
+ ) -> str:
109
+ """Build the prompt for the LLM."""
110
+ parts = [
111
+ f"DATABASE SCHEMA:\n{schema}\n",
112
+ f"QUESTION: {question}\n",
113
+ ]
114
+
115
+ if attempt > 1:
116
+ parts.append(f"PREVIOUS ATTEMPT RESULT:\n{last_result}\n")
117
+ if last_error:
118
+ parts.append(f"ERROR: {last_error}\n")
119
+ if feedback:
120
+ parts.append(f"FEEDBACK: {feedback}\n")
121
+ parts.append(
122
+ f"This is attempt {attempt}. Fix the query based on the feedback above.\n"
123
+ )
124
+
125
+ parts.append("Write the SQL query:")
126
+ return "\n".join(parts)
127
+
128
+
129
+ def get_sql_from_model(
130
+ client: OpenAI,
131
+ question: str,
132
+ schema: str,
133
+ last_result: str,
134
+ last_error: str,
135
+ feedback: str,
136
+ attempt: int,
137
+ ) -> str:
138
+ """Call the LLM to generate a SQL query."""
139
+ user_prompt = build_user_prompt(
140
+ question, schema, last_result, last_error, feedback, attempt
141
+ )
142
+
143
+ try:
144
+ completion = client.chat.completions.create(
145
+ model=MODEL_NAME,
146
+ messages=[
147
+ {"role": "system", "content": SYSTEM_PROMPT},
148
+ {"role": "user", "content": user_prompt},
149
+ ],
150
+ temperature=TEMPERATURE,
151
+ max_tokens=MAX_TOKENS,
152
+ stream=False,
153
+ )
154
+ text = (completion.choices[0].message.content or "").strip()
155
+
156
+ # Clean up: remove markdown code blocks if present
157
+ if text.startswith("```"):
158
+ lines = text.split("\n")
159
+ # Remove first and last lines (```sql and ```)
160
+ lines = [l for l in lines if not l.strip().startswith("```")]
161
+ text = "\n".join(lines).strip()
162
+
163
+ return text if text else "SELECT 1"
164
+ except Exception as exc:
165
+ print(f"[DEBUG] Model request failed: {exc}", flush=True)
166
+ return "SELECT 1"
167
+
168
+
169
+ def run_task(client: OpenAI, task_name: str) -> None:
170
+ """Run a single task and emit [START]/[STEP]/[END] logs."""
171
+ os.environ["SQL_ENV_TASK"] = task_name
172
+
173
+ env = SQLEnvironment()
174
+ obs = env.reset()
175
+
176
+ rewards: List[float] = []
177
+ steps_taken = 0
178
+ score = 0.0
179
+ success = False
180
+ attempt_on_q = 0
181
+
182
+ log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
183
+
184
+ try:
185
+ last_result = ""
186
+ last_error = ""
187
+ feedback = ""
188
+ attempt_on_q = 1
189
+
190
+ for step in range(1, MAX_STEPS + 1):
191
+ if obs.done:
192
+ break
193
+
194
+ # Get SQL from model
195
+ sql_query = get_sql_from_model(
196
+ client=client,
197
+ question=obs.question,
198
+ schema=obs.schema_description,
199
+ last_result=last_result,
200
+ last_error=last_error,
201
+ feedback=feedback,
202
+ attempt=attempt_on_q,
203
+ )
204
+
205
+ # Step the environment
206
+ obs = env.step(SQLAction(query=sql_query))
207
+
208
+ reward = obs.reward
209
+ done = obs.done
210
+ error = obs.error if obs.error else None
211
+
212
+ rewards.append(reward)
213
+ steps_taken = step
214
+
215
+ log_step(
216
+ step=step,
217
+ action=sql_query,
218
+ reward=reward,
219
+ done=done,
220
+ error=error,
221
+ )
222
+
223
+ # Track state for retry prompting
224
+ last_result = obs.query_result
225
+ last_error = obs.error
226
+ feedback = obs.metadata.get("feedback", "")
227
+
228
+ # Track attempt number for current question
229
+ if reward >= 0.98: # near-perfect, moved to next question
230
+ attempt_on_q = 1
231
+ else:
232
+ attempt_on_q += 1
233
+
234
+ if done:
235
+ break
236
+
237
+ # Calculate normalized score
238
+ max_possible = obs.total_questions # 5 questions, max 1.0 each
239
+ if max_possible > 0:
240
+ score = sum(rewards) / max_possible
241
+ score = min(max(score, 0.0), 1.0)
242
+ success = score >= SUCCESS_SCORE_THRESHOLD
243
+
244
+ except Exception as exc:
245
+ print(f"[DEBUG] Task {task_name} error: {exc}", flush=True)
246
+
247
+ finally:
248
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
249
+
250
+
251
+ def main() -> None:
252
+ """Run inference on all 3 tasks."""
253
+ if not API_KEY:
254
+ print("[ERROR] HF_TOKEN or API_KEY environment variable is required.", flush=True)
255
+ sys.exit(1)
256
+
257
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
258
+
259
+ for task_name in TASKS:
260
+ run_task(client, task_name)
261
+ print("", flush=True) # blank line between tasks
262
+
263
+
264
+ if __name__ == "__main__":
265
+ main()
models.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data models for the SQL Query Writing Environment.
3
+
4
+ Defines the Action (agent submits SQL query) and Observation
5
+ (agent sees schema, question, query results, reward) types.
6
+ """
7
+
8
+ from openenv.core.env_server.types import Action, Observation
9
+ from pydantic import Field
10
+
11
+
12
+ class SQLAction(Action):
13
+ """Action for the SQL environment - a SQL query to execute."""
14
+
15
+ query: str = Field(..., description="SQL SELECT query to execute against the database")
16
+
17
+
18
+ class SQLObservation(Observation):
19
+ """Observation from the SQL environment."""
20
+
21
+ task_name: str = Field(default="", description="Current task identifier (e.g., basic_select)")
22
+ question: str = Field(default="", description="Natural language question to answer with SQL")
23
+ schema_description: str = Field(default="", description="Human-readable database schema")
24
+ query_result: str = Field(default="", description="Result of the last executed query (or error)")
25
+ error: str = Field(default="", description="SQL error message if query failed, empty otherwise")
26
+ steps_remaining: int = Field(default=0, description="Number of attempts remaining for current question")
27
+ question_index: int = Field(default=0, description="Current question number (1-indexed)")
28
+ total_questions: int = Field(default=0, description="Total questions in this task")
openenv.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: sql_env
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 7860
pyproject.toml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ [build-system]
8
+ requires = ["setuptools>=45", "wheel"]
9
+ build-backend = "setuptools.build_meta"
10
+
11
+ [project]
12
+ name = "openenv-sql_env"
13
+ version = "0.1.0"
14
+ description = "Sql Env environment for OpenEnv"
15
+ requires-python = ">=3.10"
16
+ dependencies = [
17
+ # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
18
+ # install from github
19
+ # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
20
+ "openenv-core[core]>=0.2.2",
21
+ # Environment-specific dependencies
22
+ # Add all dependencies needed for your environment here
23
+ # Examples:
24
+ # "numpy>=1.19.0",
25
+ # "torch>=2.0.0",
26
+ # "gymnasium>=0.29.0",
27
+ # "openspiel>=1.0.0",
28
+ # "smolagents>=1.22.0,<2",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pytest>=8.0.0",
34
+ "pytest-cov>=4.0.0",
35
+ ]
36
+
37
+ [project.scripts]
38
+ # Server entry point - enables running via: uv run --project . server
39
+ # or: python -m sql_env.server.app
40
+ server = "sql_env.server.app:main"
41
+
42
+ [tool.setuptools]
43
+ include-package-data = true
44
+ packages = ["sql_env", "sql_env.server"]
45
+ package-dir = { "sql_env" = ".", "sql_env.server" = "server" }
server/Dockerfile ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Multi-stage build using openenv-base
8
+ # This Dockerfile is flexible and works for both:
9
+ # - In-repo environments (with local OpenEnv sources)
10
+ # - Standalone environments (with openenv from PyPI/Git)
11
+ # The build script (openenv build) handles context detection and sets appropriate build args.
12
+
13
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
14
+ FROM ${BASE_IMAGE} AS builder
15
+
16
+ WORKDIR /app
17
+
18
+ # Ensure git is available (required for installing dependencies from VCS)
19
+ RUN apt-get update && \
20
+ apt-get install -y --no-install-recommends git && \
21
+ rm -rf /var/lib/apt/lists/*
22
+
23
+ # Build argument to control whether we're building standalone or in-repo
24
+ ARG BUILD_MODE=in-repo
25
+ ARG ENV_NAME=sql_env
26
+
27
+ # Copy environment code (always at root of build context)
28
+ COPY . /app/env
29
+
30
+ # For in-repo builds, openenv is already vendored in the build context
31
+ # For standalone builds, openenv will be installed via pyproject.toml
32
+ WORKDIR /app/env
33
+
34
+ # Ensure uv is available (for local builds where base image lacks it)
35
+ RUN if ! command -v uv >/dev/null 2>&1; then \
36
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
37
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
38
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
39
+ fi
40
+
41
+ # Install dependencies using uv sync
42
+ # If uv.lock exists, use it; otherwise resolve on the fly
43
+ RUN --mount=type=cache,target=/root/.cache/uv \
44
+ if [ -f uv.lock ]; then \
45
+ uv sync --frozen --no-install-project --no-editable; \
46
+ else \
47
+ uv sync --no-install-project --no-editable; \
48
+ fi
49
+
50
+ RUN --mount=type=cache,target=/root/.cache/uv \
51
+ if [ -f uv.lock ]; then \
52
+ uv sync --frozen --no-editable; \
53
+ else \
54
+ uv sync --no-editable; \
55
+ fi
56
+
57
+ # Final runtime stage
58
+ FROM ${BASE_IMAGE}
59
+
60
+ WORKDIR /app
61
+
62
+ # Copy the virtual environment from builder
63
+ COPY --from=builder /app/env/.venv /app/.venv
64
+
65
+ # Copy the environment code
66
+ COPY --from=builder /app/env /app/env
67
+
68
+ # Set PATH to use the virtual environment
69
+ ENV PATH="/app/.venv/bin:$PATH"
70
+
71
+ # Set PYTHONPATH so imports work correctly
72
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
73
+
74
+ # Health check
75
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
76
+ CMD curl -f http://localhost:8000/health || exit 1
77
+
78
+ # Run the FastAPI server
79
+ # The module path is constructed to work with the /app/env structure
80
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 8000"]
server/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """SQL environment server components."""
server/app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI application for the SQL Query Writing Environment.
3
+
4
+ Endpoints:
5
+ - POST /reset: Reset the environment
6
+ - POST /step: Execute an action (SQL query)
7
+ - GET /state: Get current environment state
8
+ - GET /health: Health check
9
+ - WS /ws: WebSocket endpoint for persistent sessions
10
+
11
+ Usage:
12
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
13
+ """
14
+
15
+ try:
16
+ from openenv.core.env_server.http_server import create_app
17
+ except Exception as e:
18
+ raise ImportError(
19
+ "openenv is required. Install with: pip install openenv-core"
20
+ ) from e
21
+
22
+ try:
23
+ from ..models import SQLAction, SQLObservation
24
+ from .sql_env_environment import SQLEnvironment
25
+ except (ImportError, ModuleNotFoundError):
26
+ from models import SQLAction, SQLObservation
27
+ from server.sql_env_environment import SQLEnvironment
28
+
29
+
30
+ app = create_app(
31
+ SQLEnvironment,
32
+ SQLAction,
33
+ SQLObservation,
34
+ env_name="sql_env",
35
+ max_concurrent_envs=3,
36
+ )
37
+
38
+
39
+ def main(host: str = "0.0.0.0", port: int = 8000):
40
+ """Entry point for direct execution."""
41
+ import uvicorn
42
+ uvicorn.run(app, host=host, port=port)
43
+
44
+
45
+ if __name__ == "__main__":
46
+ import argparse
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument("--port", type=int, default=8000)
49
+ args = parser.parse_args()
50
+ main(port=args.port)
server/database.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQLite database management for SQLEnv.
3
+
4
+ Handles database initialization, query execution, and schema introspection.
5
+ All operations use an in-memory SQLite database that is re-created on each
6
+ environment reset, ensuring deterministic, isolated episodes.
7
+ """
8
+
9
+ import os
10
+ import sqlite3
11
+ from dataclasses import dataclass, field
12
+ from pathlib import Path
13
+ from typing import List, Optional, Tuple
14
+
15
+ DATA_DIR = Path(__file__).resolve().parent.parent / "data"
16
+ SCHEMA_PATH = DATA_DIR / "schema.sql"
17
+ SEED_PATH = DATA_DIR / "seed.sql"
18
+
19
+
20
+ @dataclass
21
+ class QueryResult:
22
+ """Result of executing a SQL query."""
23
+
24
+ columns: List[str] = field(default_factory=list)
25
+ rows: List[Tuple] = field(default_factory=list)
26
+ error: Optional[str] = None
27
+ row_count: int = 0
28
+
29
+ @property
30
+ def success(self) -> bool:
31
+ return self.error is None
32
+
33
+ def to_display_string(self, max_rows: int = 20) -> str:
34
+ """Format result as a readable table string."""
35
+ if self.error:
36
+ return f"ERROR: {self.error}"
37
+ if not self.columns:
38
+ return "(no results)"
39
+
40
+ # Calculate column widths
41
+ col_widths = [len(str(c)) for c in self.columns]
42
+ display_rows = self.rows[:max_rows]
43
+ for row in display_rows:
44
+ for i, val in enumerate(row):
45
+ col_widths[i] = max(col_widths[i], len(str(val)))
46
+
47
+ # Build table
48
+ header = " | ".join(
49
+ str(c).ljust(col_widths[i]) for i, c in enumerate(self.columns)
50
+ )
51
+ separator = "-+-".join("-" * w for w in col_widths)
52
+ lines = [header, separator]
53
+
54
+ for row in display_rows:
55
+ line = " | ".join(
56
+ str(val).ljust(col_widths[i]) for i, val in enumerate(row)
57
+ )
58
+ lines.append(line)
59
+
60
+ if len(self.rows) > max_rows:
61
+ lines.append(f"... ({len(self.rows) - max_rows} more rows)")
62
+
63
+ lines.append(f"\n({self.row_count} row{'s' if self.row_count != 1 else ''})")
64
+ return "\n".join(lines)
65
+
66
+
67
+ class Database:
68
+ """
69
+ Manages an in-memory SQLite database for one episode.
70
+
71
+ Each call to `initialize()` creates a fresh database with the schema
72
+ and seed data, ensuring deterministic state across episodes.
73
+ """
74
+
75
+ def __init__(self):
76
+ self._conn: Optional[sqlite3.Connection] = None
77
+
78
+ def initialize(self) -> None:
79
+ """Create a fresh in-memory database with schema and seed data."""
80
+ self.close()
81
+ self._conn = sqlite3.connect(":memory:")
82
+ self._conn.execute("PRAGMA foreign_keys = ON")
83
+
84
+ schema_sql = SCHEMA_PATH.read_text()
85
+ self._conn.executescript(schema_sql)
86
+
87
+ seed_sql = SEED_PATH.read_text()
88
+ self._conn.executescript(seed_sql)
89
+
90
+ self._conn.commit()
91
+
92
+ def execute_query(self, sql: str, timeout_seconds: float = 5.0) -> QueryResult:
93
+ """
94
+ Execute a SQL query and return the result.
95
+
96
+ Only SELECT statements are allowed. Modification statements
97
+ (INSERT, UPDATE, DELETE, DROP, ALTER, CREATE) are rejected.
98
+
99
+ Args:
100
+ sql: The SQL query string to execute.
101
+ timeout_seconds: Max execution time (unused for SQLite in-memory).
102
+
103
+ Returns:
104
+ QueryResult with columns, rows, and potential error.
105
+ """
106
+ if self._conn is None:
107
+ return QueryResult(error="Database not initialized. Call reset() first.")
108
+
109
+ # Strip and normalize
110
+ stripped = sql.strip().rstrip(";").strip()
111
+ if not stripped:
112
+ return QueryResult(error="Empty query.")
113
+
114
+ # Block modification statements
115
+ first_word = stripped.split()[0].upper()
116
+ blocked = {"INSERT", "UPDATE", "DELETE", "DROP", "ALTER", "CREATE", "TRUNCATE", "REPLACE"}
117
+ if first_word in blocked:
118
+ return QueryResult(
119
+ error=f"Only SELECT queries are allowed. Got: {first_word}"
120
+ )
121
+
122
+ try:
123
+ cursor = self._conn.execute(stripped)
124
+ if cursor.description is None:
125
+ return QueryResult(error="Query did not return results.")
126
+
127
+ columns = [desc[0] for desc in cursor.description]
128
+ rows = cursor.fetchall()
129
+ return QueryResult(
130
+ columns=columns,
131
+ rows=rows,
132
+ row_count=len(rows),
133
+ )
134
+ except sqlite3.Error as e:
135
+ return QueryResult(error=str(e))
136
+
137
+ def get_schema_description(self) -> str:
138
+ """
139
+ Return a human-readable description of the database schema
140
+ including table structures and sample data.
141
+ """
142
+ schema_text = []
143
+ schema_text.append("=== DATABASE SCHEMA ===\n")
144
+
145
+ tables = [
146
+ ("customers", "Customer information"),
147
+ ("products", "Product catalog"),
148
+ ("orders", "Customer orders"),
149
+ ("order_items", "Items within each order"),
150
+ ("reviews", "Product reviews by customers"),
151
+ ]
152
+
153
+ if self._conn is None:
154
+ return "Database not initialized."
155
+
156
+ for table_name, description in tables:
157
+ schema_text.append(f"TABLE: {table_name} -- {description}")
158
+
159
+ # Get column info
160
+ cursor = self._conn.execute(f"PRAGMA table_info({table_name})")
161
+ columns = cursor.fetchall()
162
+ for col in columns:
163
+ # col: (cid, name, type, notnull, default_value, pk)
164
+ col_name = col[1]
165
+ col_type = col[2]
166
+ is_pk = " PRIMARY KEY" if col[5] else ""
167
+ is_nn = " NOT NULL" if col[3] else ""
168
+ schema_text.append(f" {col_name} {col_type}{is_pk}{is_nn}")
169
+
170
+ # Get foreign keys
171
+ cursor = self._conn.execute(f"PRAGMA foreign_key_list({table_name})")
172
+ fks = cursor.fetchall()
173
+ for fk in fks:
174
+ schema_text.append(f" FOREIGN KEY ({fk[3]}) REFERENCES {fk[2]}({fk[4]})")
175
+
176
+ # Show sample data (first 3 rows)
177
+ result = self.execute_query(f"SELECT * FROM {table_name} LIMIT 3")
178
+ if result.success and result.rows:
179
+ schema_text.append(f" Sample data ({result.row_count} rows shown):")
180
+ for row in result.rows:
181
+ schema_text.append(f" {row}")
182
+
183
+ # Show total count
184
+ count_result = self.execute_query(
185
+ f"SELECT COUNT(*) FROM {table_name}"
186
+ )
187
+ if count_result.success and count_result.rows:
188
+ total = count_result.rows[0][0]
189
+ schema_text.append(f" Total rows: {total}")
190
+
191
+ schema_text.append("")
192
+
193
+ # Add relationship summary
194
+ schema_text.append("=== RELATIONSHIPS ===")
195
+ schema_text.append("orders.customer_id -> customers.id")
196
+ schema_text.append("order_items.order_id -> orders.id")
197
+ schema_text.append("order_items.product_id -> products.id")
198
+ schema_text.append("reviews.product_id -> products.id")
199
+ schema_text.append("reviews.customer_id -> customers.id")
200
+ schema_text.append("")
201
+ schema_text.append("=== NOTES ===")
202
+ schema_text.append("- All dates are in ISO format (YYYY-MM-DD)")
203
+ schema_text.append("- Prices are in INR (Indian Rupees)")
204
+ schema_text.append("- Order status: pending, shipped, delivered, cancelled")
205
+ schema_text.append("- Product categories: Electronics, Clothing, Books, Home")
206
+ schema_text.append("- Ratings are integers from 1 to 5")
207
+
208
+ return "\n".join(schema_text)
209
+
210
+ def close(self) -> None:
211
+ """Close the database connection."""
212
+ if self._conn is not None:
213
+ self._conn.close()
214
+ self._conn = None
server/graders.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Multi-component grading system for SQL query evaluation.
3
+
4
+ Scores agent queries against ground truth with partial credit:
5
+ - syntax_score (0.1): Query parses and executes without error
6
+ - column_score (0.2): Fraction of expected columns present
7
+ - row_score (0.3): Fraction of expected rows matching
8
+ - exact_score (0.4): Full result set matches ground truth exactly
9
+
10
+ Total reward per question is in [0.0, 1.0].
11
+ """
12
+
13
+ from typing import Any, List, Optional, Tuple
14
+
15
+ from .database import Database, QueryResult
16
+
17
+
18
+ def _normalize_value(val: Any) -> Any:
19
+ """Normalize a value for comparison (handle float/int equivalence, None)."""
20
+ if val is None:
21
+ return None
22
+ if isinstance(val, float):
23
+ if val == int(val):
24
+ return int(val)
25
+ return round(val, 2)
26
+ if isinstance(val, str):
27
+ return val.strip()
28
+ return val
29
+
30
+
31
+ def _normalize_row(row: Tuple) -> Tuple:
32
+ """Normalize all values in a row."""
33
+ return tuple(_normalize_value(v) for v in row)
34
+
35
+
36
+ def _normalize_column_name(col: str) -> str:
37
+ """Normalize column name for comparison (lowercase, strip)."""
38
+ return col.strip().lower()
39
+
40
+
41
+ def grade_query(
42
+ db: Database,
43
+ agent_sql: str,
44
+ expected_columns: List[str],
45
+ expected_rows: List[List],
46
+ order_matters: bool = True,
47
+ ) -> dict:
48
+ """
49
+ Grade an agent's SQL query against expected results.
50
+
51
+ Args:
52
+ db: Active Database instance.
53
+ agent_sql: The SQL query submitted by the agent.
54
+ expected_columns: List of expected column names.
55
+ expected_rows: List of expected row values (list of lists).
56
+ order_matters: Whether row order affects scoring.
57
+
58
+ Returns:
59
+ Dictionary with:
60
+ - reward: float in [0.0, 1.0]
61
+ - syntax_score: float
62
+ - column_score: float
63
+ - row_score: float
64
+ - exact_score: float
65
+ - query_result: QueryResult object
66
+ - feedback: str describing what was right/wrong
67
+ """
68
+ result = db.execute_query(agent_sql)
69
+
70
+ # Component weights
71
+ W_SYNTAX = 0.1
72
+ W_COLUMN = 0.2
73
+ W_ROW = 0.3
74
+ W_EXACT = 0.4
75
+
76
+ # --- Syntax Score ---
77
+ if not result.success:
78
+ return {
79
+ "reward": 0.0,
80
+ "syntax_score": 0.0,
81
+ "column_score": 0.0,
82
+ "row_score": 0.0,
83
+ "exact_score": 0.0,
84
+ "query_result": result,
85
+ "feedback": f"SQL error: {result.error}",
86
+ }
87
+
88
+ syntax_score = 1.0
89
+
90
+ # --- Column Score ---
91
+ expected_cols_normalized = [_normalize_column_name(c) for c in expected_columns]
92
+ actual_cols_normalized = [_normalize_column_name(c) for c in result.columns]
93
+
94
+ if not expected_cols_normalized:
95
+ column_score = 1.0 if not actual_cols_normalized else 0.0
96
+ else:
97
+ matched_cols = sum(
98
+ 1 for c in expected_cols_normalized if c in actual_cols_normalized
99
+ )
100
+ column_score = matched_cols / len(expected_cols_normalized)
101
+
102
+ # --- Row Score ---
103
+ expected_rows_normalized = [
104
+ _normalize_row(tuple(row)) for row in expected_rows
105
+ ]
106
+ actual_rows_normalized = [_normalize_row(row) for row in result.rows]
107
+
108
+ if not expected_rows_normalized:
109
+ row_score = 1.0 if not actual_rows_normalized else 0.0
110
+ else:
111
+ if order_matters:
112
+ # For ordered results, match position-by-position
113
+ matched_rows = 0
114
+ for i, expected_row in enumerate(expected_rows_normalized):
115
+ if i < len(actual_rows_normalized):
116
+ if _rows_match(expected_row, actual_rows_normalized[i], expected_cols_normalized, actual_cols_normalized):
117
+ matched_rows += 1
118
+ row_score = matched_rows / len(expected_rows_normalized)
119
+ else:
120
+ # For unordered results, check set membership
121
+ matched_rows = 0
122
+ remaining_actual = list(actual_rows_normalized)
123
+ for expected_row in expected_rows_normalized:
124
+ for j, actual_row in enumerate(remaining_actual):
125
+ if _rows_match(expected_row, actual_row, expected_cols_normalized, actual_cols_normalized):
126
+ matched_rows += 1
127
+ remaining_actual.pop(j)
128
+ break
129
+ row_score = matched_rows / len(expected_rows_normalized)
130
+
131
+ # --- Exact Score ---
132
+ exact_score = 0.0
133
+ if column_score == 1.0 and row_score == 1.0:
134
+ # Check exact match: same number of rows and all matched
135
+ if len(actual_rows_normalized) == len(expected_rows_normalized):
136
+ exact_score = 1.0
137
+ else:
138
+ # Extra rows returned — partial exact credit
139
+ exact_score = 0.5
140
+
141
+ # --- Total Reward ---
142
+ reward = (
143
+ W_SYNTAX * syntax_score
144
+ + W_COLUMN * column_score
145
+ + W_ROW * row_score
146
+ + W_EXACT * exact_score
147
+ )
148
+ reward = round(min(max(reward, 0.0), 1.0), 4)
149
+
150
+ # --- Feedback ---
151
+ feedback_parts = []
152
+ if syntax_score == 1.0:
153
+ feedback_parts.append("Query executed successfully.")
154
+ if column_score < 1.0:
155
+ missing = [c for c in expected_cols_normalized if c not in actual_cols_normalized]
156
+ feedback_parts.append(f"Missing columns: {missing}. Expected: {expected_cols_normalized}, Got: {actual_cols_normalized}")
157
+ if row_score < 1.0:
158
+ feedback_parts.append(
159
+ f"Row match: {row_score:.0%} ({int(row_score * len(expected_rows_normalized))}/{len(expected_rows_normalized)} rows correct). "
160
+ f"Expected {len(expected_rows_normalized)} rows, got {len(actual_rows_normalized)}."
161
+ )
162
+ if exact_score == 1.0:
163
+ feedback_parts.append("Perfect match!")
164
+ elif exact_score == 0.5:
165
+ feedback_parts.append(f"All expected rows found but got {len(actual_rows_normalized)} rows instead of {len(expected_rows_normalized)} (extra rows).")
166
+
167
+ return {
168
+ "reward": reward,
169
+ "syntax_score": syntax_score,
170
+ "column_score": column_score,
171
+ "row_score": row_score,
172
+ "exact_score": exact_score,
173
+ "query_result": result,
174
+ "feedback": " ".join(feedback_parts),
175
+ }
176
+
177
+
178
+ def _rows_match(
179
+ expected_row: Tuple,
180
+ actual_row: Tuple,
181
+ expected_cols: List[str],
182
+ actual_cols: List[str],
183
+ ) -> bool:
184
+ """
185
+ Check if an actual row matches an expected row.
186
+
187
+ Handles column reordering: maps expected columns to actual column positions.
188
+ """
189
+ if len(expected_cols) != len(expected_row):
190
+ return False
191
+
192
+ # Build a mapping from expected column index to actual column index
193
+ col_map = {}
194
+ for i, ec in enumerate(expected_cols):
195
+ if ec in actual_cols:
196
+ col_map[i] = actual_cols.index(ec)
197
+ else:
198
+ return False # Missing column
199
+
200
+ for exp_idx, act_idx in col_map.items():
201
+ if act_idx >= len(actual_row):
202
+ return False
203
+ exp_val = expected_row[exp_idx]
204
+ act_val = _normalize_value(actual_row[act_idx])
205
+ if exp_val != act_val:
206
+ # Try numeric comparison with tolerance
207
+ try:
208
+ if abs(float(exp_val) - float(act_val)) < 0.01:
209
+ continue
210
+ except (TypeError, ValueError):
211
+ pass
212
+ return False
213
+
214
+ return True
server/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ openenv-core[core]>=0.2.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+ openai>=1.0.0
server/sql_env_environment.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQL Query Writing Environment.
3
+
4
+ An AI agent receives a database schema and natural language question,
5
+ then writes SQL queries to answer the question. The environment grades
6
+ each query with partial-credit scoring and provides feedback.
7
+ """
8
+
9
+ import json
10
+ import os
11
+ from pathlib import Path
12
+ from uuid import uuid4
13
+
14
+ from openenv.core.env_server.interfaces import Environment
15
+ from openenv.core.env_server.types import State
16
+
17
+ try:
18
+ from ..models import SQLAction, SQLObservation
19
+ except ImportError:
20
+ from models import SQLAction, SQLObservation
21
+
22
+ from .database import Database
23
+ from .graders import grade_query
24
+
25
+ TASKS_DIR = Path(__file__).resolve().parent.parent / "data" / "tasks"
26
+
27
+ # Default task can be overridden via environment variable
28
+ DEFAULT_TASK = os.getenv("SQL_ENV_TASK", "basic_select")
29
+ MAX_TOTAL_STEPS = int(os.getenv("SQL_ENV_MAX_STEPS", "15"))
30
+ STEP_PENALTY = float(os.getenv("SQL_ENV_STEP_PENALTY", "0.02"))
31
+
32
+
33
+ def _load_task(task_name: str) -> dict:
34
+ """Load a task definition from JSON file."""
35
+ task_path = TASKS_DIR / f"{task_name}.json"
36
+ if not task_path.exists():
37
+ available = [f.stem for f in TASKS_DIR.glob("*.json")]
38
+ raise ValueError(
39
+ f"Task '{task_name}' not found. Available: {available}"
40
+ )
41
+ with open(task_path) as f:
42
+ return json.load(f)
43
+
44
+
45
+ class SQLEnvironment(Environment):
46
+ """
47
+ SQL Query Writing Environment.
48
+
49
+ The agent interacts with an e-commerce SQLite database by submitting
50
+ SQL queries to answer natural language questions. Each query is graded
51
+ with a multi-component reward function providing partial credit.
52
+
53
+ Episode flow:
54
+ 1. reset() → loads task, initializes DB, returns first question
55
+ 2. step(SQLAction) → executes query, grades it, returns observation
56
+ 3. Episode ends when all questions answered or max steps reached
57
+ """
58
+
59
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True
60
+
61
+ def __init__(self):
62
+ self._db = Database()
63
+ self._state = State(episode_id=str(uuid4()), step_count=0)
64
+ self._task: dict = {}
65
+ self._questions: list = []
66
+ self._current_q_index: int = 0
67
+ self._q_steps_used: int = 0
68
+ self._max_steps_per_q: int = 3
69
+ self._total_steps: int = 0
70
+ self._rewards: list = []
71
+ self._schema_cache: str = ""
72
+ self._done: bool = False
73
+ self._last_feedback: str = ""
74
+
75
+ def reset(self) -> SQLObservation:
76
+ """
77
+ Reset the environment: initialize DB, load task, return first question.
78
+ """
79
+ self._db.initialize()
80
+ self._state = State(episode_id=str(uuid4()), step_count=0)
81
+
82
+ task_name = os.getenv("SQL_ENV_TASK", DEFAULT_TASK)
83
+ self._task = _load_task(task_name)
84
+ self._questions = self._task["questions"]
85
+ self._max_steps_per_q = self._task.get("max_steps_per_question", 3)
86
+ self._current_q_index = 0
87
+ self._q_steps_used = 0
88
+ self._total_steps = 0
89
+ self._rewards = []
90
+ self._done = False
91
+ self._last_feedback = ""
92
+ self._schema_cache = self._db.get_schema_description()
93
+
94
+ return self._make_observation(
95
+ reward=0.0,
96
+ query_result="",
97
+ error="",
98
+ )
99
+
100
+ def step(self, action: SQLAction) -> SQLObservation: # type: ignore[override]
101
+ """
102
+ Execute the agent's SQL query, grade it, and return observation.
103
+ """
104
+ # Auto-reset if step called before reset (HTTP stateless mode)
105
+ if not self._questions:
106
+ self.reset()
107
+
108
+ if self._done or self._current_q_index >= len(self._questions):
109
+ self._done = True
110
+ return self._make_observation(
111
+ reward=0.0,
112
+ query_result="Episode is over. Call reset() to start a new episode.",
113
+ error="",
114
+ )
115
+
116
+ self._state.step_count += 1
117
+ self._total_steps += 1
118
+ self._q_steps_used += 1
119
+
120
+ # Get current question
121
+ question = self._questions[self._current_q_index]
122
+
123
+ # Grade the query
124
+ grade_result = grade_query(
125
+ db=self._db,
126
+ agent_sql=action.query,
127
+ expected_columns=question["expected_columns"],
128
+ expected_rows=question["expected_rows"],
129
+ order_matters=question.get("order_matters", True),
130
+ )
131
+
132
+ raw_reward = grade_result["reward"]
133
+
134
+ # Apply step penalty (not on first attempt)
135
+ penalty = STEP_PENALTY * (self._q_steps_used - 1)
136
+ reward = max(raw_reward - penalty, 0.0)
137
+ reward = round(reward, 4)
138
+
139
+ self._rewards.append(reward)
140
+ self._last_feedback = grade_result["feedback"]
141
+
142
+ # Format query result for observation
143
+ query_result_str = grade_result["query_result"].to_display_string()
144
+ error_str = grade_result["query_result"].error or ""
145
+
146
+ # Check if we should move to next question
147
+ perfect = grade_result["exact_score"] == 1.0
148
+ out_of_attempts = self._q_steps_used >= self._max_steps_per_q
149
+ move_on = perfect or out_of_attempts
150
+
151
+ if move_on:
152
+ self._current_q_index += 1
153
+ self._q_steps_used = 0
154
+
155
+ # Check if episode is done
156
+ if self._current_q_index >= len(self._questions):
157
+ self._done = True
158
+ if self._total_steps >= MAX_TOTAL_STEPS:
159
+ self._done = True
160
+
161
+ return self._make_observation(
162
+ reward=reward,
163
+ query_result=query_result_str,
164
+ error=error_str,
165
+ )
166
+
167
+ @property
168
+ def state(self) -> State:
169
+ return self._state
170
+
171
+ def _make_observation(
172
+ self,
173
+ reward: float,
174
+ query_result: str,
175
+ error: str,
176
+ ) -> SQLObservation:
177
+ """Build an SQLObservation for the current state."""
178
+ if self._done or not self._questions or self._current_q_index >= len(self._questions):
179
+ # Episode finished or not started
180
+ return SQLObservation(
181
+ task_name=self._task.get("task_name", ""),
182
+ question="Episode complete. All questions answered.",
183
+ schema_description="",
184
+ query_result=query_result,
185
+ error=error,
186
+ steps_remaining=0,
187
+ question_index=len(self._questions),
188
+ total_questions=len(self._questions),
189
+ done=True,
190
+ reward=reward,
191
+ metadata={
192
+ "feedback": self._last_feedback,
193
+ "total_reward": round(sum(self._rewards), 4),
194
+ "rewards": [round(r, 4) for r in self._rewards],
195
+ },
196
+ )
197
+
198
+ question = self._questions[self._current_q_index]
199
+ steps_remaining = self._max_steps_per_q - self._q_steps_used
200
+
201
+ return SQLObservation(
202
+ task_name=self._task.get("task_name", ""),
203
+ question=question["question"],
204
+ schema_description=self._schema_cache,
205
+ query_result=query_result,
206
+ error=error,
207
+ steps_remaining=steps_remaining,
208
+ question_index=self._current_q_index + 1,
209
+ total_questions=len(self._questions),
210
+ done=False,
211
+ reward=reward,
212
+ metadata={
213
+ "feedback": self._last_feedback,
214
+ "question_id": question["id"],
215
+ "difficulty": self._task.get("difficulty", ""),
216
+ },
217
+ )