UtkarshSatav commited on
Commit
41fadd7
Β·
verified Β·
1 Parent(s): 1c5c280

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. README.md +262 -46
README.md CHANGED
@@ -12,45 +12,76 @@ tags:
12
 
13
  # SQLEnv β€” SQL Query Writing Environment for AI Agents
14
 
15
- An OpenEnv-compatible reinforcement learning environment where an AI agent learns to write correct SQL queries from natural language questions against a realistic e-commerce database.
16
 
17
- ## Overview
18
 
19
- The agent receives a database schema and a natural language question, submits SQL queries, and gets graded with rich partial-credit scoring.
20
 
21
- **3 difficulty levels, 5 questions each:**
 
 
 
 
 
22
 
23
- | Task | Difficulty | SQL Features |
24
- |---|---|---|
25
- | `basic_select` | Easy | WHERE, ORDER BY, LIMIT, COUNT |
26
- | `join_aggregate` | Medium | JOIN, GROUP BY, HAVING, AVG, SUM |
27
- | `advanced_analytics` | Hard | Subqueries, RANK(), LAG(), PARTITION BY |
28
 
29
- ## API Endpoints
30
 
31
- | Endpoint | Method | Description |
32
- |---|---|---|
33
- | `/health` | GET | Health check |
34
- | `/reset` | POST | Reset environment, get first question |
35
- | `/step` | POST | Submit SQL query, get graded result |
36
- | `/state` | GET | Current episode state |
37
- | `/docs` | GET | Interactive API documentation |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- ## Action Space
 
 
 
 
40
 
41
  ```json
42
- {"action": {"query": "SELECT name, age FROM customers WHERE age > 30 ORDER BY age DESC"}}
 
 
 
 
43
  ```
44
 
45
- ## Observation Space
 
 
 
 
 
 
46
 
47
  ```json
48
  {
49
  "observation": {
50
  "task_name": "basic_select",
51
- "question": "Find all customers older than 30...",
52
- "schema_description": "=== DATABASE SCHEMA === ...",
53
- "query_result": "name | age ...",
54
  "error": "",
55
  "steps_remaining": 2,
56
  "question_index": 1,
@@ -61,47 +92,232 @@ The agent receives a database schema and a natural language question, submits SQ
61
  }
62
  ```
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  ## Reward Function
65
 
66
- Multi-component partial credit scoring (0.0 to 1.0):
67
 
68
- | Component | Weight | What It Measures |
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  |---|---|---|
70
- | Syntax | 0.1 | Query executes without error |
71
- | Columns | 0.2 | Expected columns present |
72
- | Rows | 0.3 | Expected rows match |
73
- | Exact | 0.4 | Full result set matches ground truth |
 
 
 
 
 
 
74
 
75
  ## Database
76
 
77
- Realistic e-commerce database with 5 tables:
78
- - **customers** (20 rows) - name, email, age, city, signup_date
79
- - **products** (15 rows) - name, category, price, stock
80
- - **orders** (30 rows) - customer_id, order_date, status, total_amount
81
- - **order_items** (46 rows) - order_id, product_id, quantity, unit_price
82
- - **reviews** (25 rows) - product_id, customer_id, rating, review_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  ## Baseline Scores
85
 
86
- Tested with Llama 3.3 70B (via Groq):
87
 
88
- | Task | Score |
89
- |---|---|
90
- | basic_select | 1.000 |
91
- | join_aggregate | 1.000 |
92
- | advanced_analytics | 0.969 |
 
 
93
 
94
- ## Setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  ```bash
97
- pip install openenv-core
98
- cd sql_env
 
99
  uvicorn server.app:app --host 0.0.0.0 --port 7860
100
  ```
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  ## Environment Variables
103
 
104
  | Variable | Default | Description |
105
  |---|---|---|
106
- | SQL_ENV_TASK | basic_select | Task to load |
107
- | SQL_ENV_MAX_STEPS | 15 | Max steps per episode |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  # SQLEnv β€” SQL Query Writing Environment for AI Agents
14
 
15
+ An OpenEnv-compatible reinforcement learning environment where AI agents learn to write correct SQL queries from natural language questions. Built on a realistic e-commerce database with deterministic, partial-credit grading across 3 difficulty levels.
16
 
17
+ ## Why This Environment Matters
18
 
19
+ **Text-to-SQL is a real-world task** performed by millions of data analysts, engineers, and business users daily. Natural language interfaces to databases are a $2B+ market (Tableau Ask Data, ThoughtSpot, etc.), yet there is no standardized RL benchmark for training and evaluating text-to-SQL agents in the OpenEnv ecosystem.
20
 
21
+ SQLEnv fills this gap by providing:
22
+ - A **realistic domain** (e-commerce) that any evaluator can understand
23
+ - **Deterministic grading** β€” same query always produces the same score
24
+ - **Rich reward signal** β€” not just pass/fail, but 4-component partial credit
25
+ - **Meaningful difficulty progression** β€” from simple WHERE clauses to window functions
26
+ - **Immediate practical value** β€” can be used to train/evaluate text-to-SQL copilots
27
 
28
+ **Who benefits:** AI/ML researchers training agents, data teams evaluating LLM SQL capability, educators teaching SQL with automated grading, and enterprises building natural language BI tools.
 
 
 
 
29
 
30
+ ## Interactive Demo
31
 
32
+ Visit the **[Live Space](https://huggingface.co/spaces/UtkarshSatav/sql-env)** to try the environment directly in your browser:
33
+ - Select a difficulty level (easy / medium / hard)
34
+ - Read the natural language question
35
+ - Write SQL queries and get instant graded feedback
36
+ - See color-coded rewards and visual progress tracking
37
+
38
+ ## Environment Design
39
+
40
+ ### Episode Flow
41
+
42
+ ```
43
+ Agent Environment
44
+ | |
45
+ |-------- reset() ------------->| Initialize DB, load task, return Q1
46
+ |<------- observation ----------| Schema + question + 0 reward
47
+ | |
48
+ |--- step(SQLAction(query)) --->| Execute SQL, grade against ground truth
49
+ |<------- observation ----------| Result + reward + feedback
50
+ | |
51
+ | (retry or next question) | Move on after perfect score or max attempts
52
+ | |
53
+ |<------- done=true ------------| All 5 questions answered
54
+ ```
55
 
56
+ Each episode consists of 5 questions. The agent gets multiple attempts per question (3 for easy, 4 for medium, 5 for hard). A small step penalty (-0.02 per retry) encourages efficiency.
57
+
58
+ ### Action Space
59
+
60
+ The agent submits a single SQL SELECT query:
61
 
62
  ```json
63
+ {
64
+ "action": {
65
+ "query": "SELECT name, age FROM customers WHERE age > 30 ORDER BY age DESC"
66
+ }
67
+ }
68
  ```
69
 
70
+ - Only SELECT statements are allowed (INSERT/UPDATE/DELETE/DROP are blocked)
71
+ - Queries execute against an in-memory SQLite database
72
+ - The agent sees the full database schema in every observation
73
+
74
+ ### Observation Space
75
+
76
+ After each step, the agent receives:
77
 
78
  ```json
79
  {
80
  "observation": {
81
  "task_name": "basic_select",
82
+ "question": "Find the names and ages of all customers older than 30, sorted by age from highest to lowest.",
83
+ "schema_description": "=== DATABASE SCHEMA ===\n\nTABLE: customers -- Customer information\n id INTEGER PRIMARY KEY\n name TEXT NOT NULL\n ...",
84
+ "query_result": "name | age\n--------------+----\nSuresh Menon | 50\nKavita Joshi | 45\n...",
85
  "error": "",
86
  "steps_remaining": 2,
87
  "question_index": 1,
 
92
  }
93
  ```
94
 
95
+ Key fields:
96
+ - `question` β€” Natural language question to answer with SQL
97
+ - `schema_description` β€” Full database schema with sample data and relationships
98
+ - `query_result` β€” Formatted table output of the executed query (or error message)
99
+ - `error` β€” SQL error string if query failed, empty otherwise
100
+ - `steps_remaining` β€” Attempts left for this question
101
+ - `metadata.feedback` β€” Human-readable explanation of what was right/wrong
102
+
103
+ ## Tasks (3 Difficulty Levels)
104
+
105
+ ### Task 1: `basic_select` (Easy)
106
+ Simple SELECT queries with WHERE, ORDER BY, LIMIT, COUNT.
107
+
108
+ **Example questions:**
109
+ - "Find the names and ages of all customers older than 30, sorted by age descending"
110
+ - "List all products in the Electronics category, sorted by price descending"
111
+ - "How many orders have the status shipped?"
112
+
113
+ **Max attempts per question:** 3
114
+
115
+ ### Task 2: `join_aggregate` (Medium)
116
+ JOIN queries with GROUP BY, HAVING, and aggregate functions.
117
+
118
+ **Example questions:**
119
+ - "What is the average order total for each customer?"
120
+ - "Which products have been ordered more than 2 times in total quantity?"
121
+ - "List all customers who have never placed an order"
122
+
123
+ **Max attempts per question:** 4
124
+
125
+ ### Task 3: `advanced_analytics` (Hard)
126
+ Subqueries, CTEs, window functions, and complex multi-table analytics.
127
+
128
+ **Example questions:**
129
+ - "Find all customers whose total spending exceeds the average customer spending"
130
+ - "Rank all products by revenue within each category using RANK()"
131
+ - "Calculate month-over-month growth in order count using LAG()"
132
+
133
+ **Max attempts per question:** 5
134
+
135
  ## Reward Function
136
 
137
+ The reward is **NOT binary**. It provides rich gradient signal across the full trajectory using 4 weighted components:
138
 
139
+ ```
140
+ Total Reward = 0.1 Γ— syntax + 0.2 Γ— columns + 0.3 Γ— rows + 0.4 Γ— exact
141
+ ```
142
+
143
+ | Component | Weight | Score Range | What It Measures |
144
+ |---|---|---|---|
145
+ | **Syntax** | 0.1 | 0 or 1 | Query parses and executes without SQL error |
146
+ | **Columns** | 0.2 | 0.0–1.0 | Fraction of expected column names present in result |
147
+ | **Rows** | 0.3 | 0.0–1.0 | Fraction of expected rows matching (position-aware for ordered queries) |
148
+ | **Exact** | 0.4 | 0, 0.5, or 1 | Full result set matches ground truth (0.5 if extra rows) |
149
+
150
+ **Score examples:**
151
+
152
+ | Scenario | Reward | Breakdown |
153
  |---|---|---|
154
+ | Syntax error (`SELEC * FORM`) | 0.00 | All components zero |
155
+ | Valid SQL, wrong columns | 0.10 | Syntax only |
156
+ | Right columns, partial rows | 0.40 | Syntax + columns + partial rows |
157
+ | Right columns, all rows, extra rows | 0.80 | Syntax + columns + rows + partial exact |
158
+ | Perfect match | 1.00 | All components maximum |
159
+
160
+ **Additional mechanics:**
161
+ - **Step penalty:** -0.02 per retry attempt (encourages getting it right the first time)
162
+ - **Deterministic:** Same query always produces the same score
163
+ - **Handles edge cases:** NULL values, float tolerance (Β±0.01), column reordering
164
 
165
  ## Database
166
 
167
+ Realistic e-commerce database with 5 interconnected tables:
168
+
169
+ ```
170
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
171
+ β”‚ customers β”‚ β”‚ orders β”‚ β”‚ products β”‚
172
+ β”‚ (20 rows) │───<β”‚ (30 rows) β”‚ β”‚ (15 rows) β”‚
173
+ β”‚ name, email β”‚ β”‚ order_date β”‚ β”‚ name, price β”‚
174
+ β”‚ age, city β”‚ β”‚ status β”‚ β”‚ category β”‚
175
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
176
+ β”‚
177
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
178
+ β”‚ order_items │───>β”‚ reviews β”‚
179
+ β”‚ (46 rows) β”‚ β”‚ (25 rows) β”‚
180
+ β”‚ quantity β”‚ β”‚ rating 1-5 β”‚
181
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
182
+ ```
183
+
184
+ **Key data characteristics:**
185
+ - 4 product categories: Electronics, Clothing, Books, Home
186
+ - 4 order statuses: pending, shipped, delivered, cancelled
187
+ - 7 cities across India
188
+ - 2 customers with no orders (for LEFT JOIN / NOT IN queries)
189
+ - 5 customers spanning 3+ product categories (for complex analytics)
190
+ - Prices in INR, dates in ISO format
191
+
192
+ All data is **deterministic** β€” seeded identically on every `reset()`.
193
 
194
  ## Baseline Scores
195
 
196
+ Tested with **Llama 3.3 70B** (via Groq, temperature=0.3):
197
 
198
+ | Task | Score | Steps | Notes |
199
+ |---|---|---|---|
200
+ | `basic_select` | **1.000** | 6 | Solved all 5 questions, 1 retry on Q3 |
201
+ | `join_aggregate` | **1.000** | 8 | Used all 8 steps, retried Q2 multiple times |
202
+ | `advanced_analytics` | **0.969** | 8 | Near-perfect, retries on RANK() and LAG() queries |
203
+
204
+ Also tested with **Qwen 2.5 72B** (via HuggingFace):
205
 
206
+ | Task | Score | Notes |
207
+ |---|---|---|
208
+ | `basic_select` | **1.000** | Perfect |
209
+ | `join_aggregate` | **0.967** | Near-perfect |
210
+ | `advanced_analytics` | **0.623** | Partial (API credits ran out mid-task) |
211
+
212
+ These scores demonstrate:
213
+ - Easy task is solvable by current LLMs (validates environment works)
214
+ - Hard task challenges even 70B models (meaningful benchmark)
215
+ - Scores vary across runs and models (not a constant grader)
216
+
217
+ ## API Endpoints
218
+
219
+ | Endpoint | Method | Description |
220
+ |---|---|---|
221
+ | `/` | GET | Interactive Gradio playground |
222
+ | `/health` | GET | Health check β†’ `{"status": "healthy"}` |
223
+ | `/reset` | POST | Reset environment, returns initial observation |
224
+ | `/step` | POST | Submit SQL query, returns graded observation |
225
+ | `/state` | GET | Current episode state (episode_id, step_count) |
226
+ | `/schema` | GET | Action/observation JSON schemas |
227
+ | `/docs` | GET | Interactive Swagger API documentation |
228
+ | `/ws` | WS | WebSocket for persistent sessions |
229
+
230
+ ## Setup & Usage
231
+
232
+ ### Quick Start (Local)
233
 
234
  ```bash
235
+ git clone https://github.com/UtkarshSatav/Scaler-Hackathon-SQL.env-.git
236
+ cd Scaler-Hackathon-SQL.env-
237
+ pip install openenv-core[core] fastapi uvicorn gradio openai
238
  uvicorn server.app:app --host 0.0.0.0 --port 7860
239
  ```
240
 
241
+ ### Docker
242
+
243
+ ```bash
244
+ docker build -t sql-env:latest -f Dockerfile .
245
+ docker run -p 7860:7860 sql-env:latest
246
+ ```
247
+
248
+ ### Using the Client
249
+
250
+ ```python
251
+ from client import SQLEnvClient
252
+ from models import SQLAction
253
+
254
+ with SQLEnvClient(base_url="http://localhost:7860") as env:
255
+ result = env.reset()
256
+ print(result.observation.question)
257
+
258
+ result = env.step(SQLAction(query="SELECT * FROM customers"))
259
+ print(f"Reward: {result.reward}")
260
+ ```
261
+
262
+ ### Running Inference
263
+
264
+ ```bash
265
+ export API_BASE_URL="https://api.groq.com/openai/v1"
266
+ export API_KEY="your_key"
267
+ export MODEL_NAME="llama-3.3-70b-versatile"
268
+ python inference.py
269
+ ```
270
+
271
  ## Environment Variables
272
 
273
  | Variable | Default | Description |
274
  |---|---|---|
275
+ | `SQL_ENV_TASK` | `basic_select` | Task to load: basic_select, join_aggregate, advanced_analytics |
276
+ | `SQL_ENV_MAX_STEPS` | `15` | Maximum total steps per episode |
277
+ | `SQL_ENV_STEP_PENALTY` | `0.02` | Penalty per retry attempt |
278
+ | `API_BASE_URL` | `https://router.huggingface.co/v1` | LLM API endpoint (for inference.py) |
279
+ | `MODEL_NAME` | `Qwen/Qwen2.5-72B-Instruct` | Model for inference |
280
+ | `HF_TOKEN` / `API_KEY` | (required) | API key for LLM calls |
281
+
282
+ ## Project Structure
283
+
284
+ ```
285
+ sql_env/
286
+ β”œβ”€β”€ inference.py # Mandatory baseline inference script
287
+ β”œβ”€β”€ Dockerfile # HF Spaces deployment container
288
+ β”œβ”€β”€ openenv.yaml # OpenEnv metadata
289
+ β”œβ”€β”€ README.md # This file
290
+ β”œβ”€β”€ models.py # SQLAction, SQLObservation (typed Pydantic models)
291
+ β”œβ”€β”€ client.py # SQLEnvClient for WebSocket connections
292
+ β”œβ”€β”€ data/
293
+ β”‚ β”œβ”€β”€ schema.sql # Database table definitions
294
+ β”‚ β”œβ”€β”€ seed.sql # Deterministic seed data
295
+ β”‚ └── tasks/
296
+ β”‚ β”œβ”€β”€ basic_select.json # 5 easy questions + ground truth
297
+ β”‚ β”œβ”€β”€ join_aggregate.json # 5 medium questions + ground truth
298
+ β”‚ └── advanced_analytics.json # 5 hard questions + ground truth
299
+ β”œβ”€β”€ server/
300
+ β”‚ β”œβ”€β”€ app.py # FastAPI + Gradio app
301
+ β”‚ β”œβ”€β”€ sql_env_environment.py # SQLEnvironment (reset/step/state)
302
+ β”‚ β”œβ”€β”€ database.py # SQLite management
303
+ β”‚ β”œβ”€β”€ graders.py # Multi-component reward function
304
+ β”‚ β”œβ”€β”€ gradio_ui.py # Interactive web UI
305
+ β”‚ └── Dockerfile # Alternative Dockerfile (openenv scaffold)
306
+ └── tests/
307
+ β”œβ”€β”€ test_database.py # 8 tests
308
+ β”œβ”€β”€ test_graders.py # 13 tests
309
+ β”œβ”€β”€ test_environment.py # 15 tests
310
+ β”œβ”€β”€ test_inference.py # 8 tests
311
+ └── test_server.py # 5 tests (49 total, all passing)
312
+ ```
313
+
314
+ ## Technical Decisions
315
+
316
+ | Decision | Rationale |
317
+ |---|---|
318
+ | SQLite (in-memory) | Zero external deps, deterministic, resets instantly |
319
+ | E-commerce domain | Universally understood, naturally scales in complexity |
320
+ | 4-component reward | Provides gradient signal, not just binary pass/fail |
321
+ | Multiple attempts per question | Allows agent to learn from errors within an episode |
322
+ | Fixed seed data | Reproducible results β€” same data every reset() |
323
+ | Schema in every observation | No hidden information, agent sees full context |