shreyas231219 commited on
Commit
615101f
·
verified ·
1 Parent(s): 873275a

Upload folder using huggingface_hub

Browse files
Files changed (14) hide show
  1. Dockerfile +61 -0
  2. README.md +104 -12
  3. __init__.py +16 -0
  4. baseline_inference_groq.py +193 -0
  5. client.py +55 -0
  6. inference.py +200 -0
  7. models.py +38 -0
  8. openenv.yaml +50 -0
  9. pyproject.toml +32 -0
  10. server/__init__.py +11 -0
  11. server/app.py +76 -0
  12. server/environment.py +397 -0
  13. server/requirements.txt +2 -0
  14. uv.lock +0 -0
Dockerfile ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# SQL/Data Cleaning Sandbox Dockerfile for Hugging Face Spaces

# Use official Python 3.11 slim image
FROM python:3.11-slim

# Set working directory
WORKDIR /app

# Install required system packages
RUN apt-get update && \
    apt-get install -y --no-install-recommends git curl build-essential && \
    rm -rf /var/lib/apt/lists/*

# Copy project files
COPY . /app/

# Install python dependencies directly bypassing complex managers to ensure
# maximum Hugging Face compatibility.
# NOTE: requirement specifiers MUST be quoted — an unquoted '>=' is parsed by
# the shell as output redirection, so pip would install unpinned packages and
# stray files like '=2.31.0' would be created in the image.
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir uvicorn "openenv-core[core]>=0.2.2" "requests>=2.31.0" "openai>=1.0.0" "groq>=0.4.0" python-dotenv

# OpenEnv needs the workspace in PYTHONPATH
ENV PYTHONPATH="/app"
# Default fallback task
ENV TASK_ID="easy"

# Hugging Face Spaces exposes port 7860
EXPOSE 7860

# Command to run the OpenEnv Server directly
ENV ENABLE_WEB_INTERFACE=true
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,12 +1,104 @@
1
- ---
2
- title: Meta Pytorch Openenv
3
- emoji: 🐠
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: docker
7
- pinned: false
8
- license: apache-2.0
9
- short_description: Openenv from scratch for testing
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Meta-Pytorch-Openenv
3
+ emoji: 🦀
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: docker
7
+ app_port: 7860
8
+ base_path: /web
9
+ ---
10
+ # SQL / Data Cleaning Sandbox
11
+
12
+ An **OpenEnv**-compliant environment where AI agents clean messy SQLite databases
13
+ using SQL queries and Python code.
14
+
15
+ ## Overview
16
+
17
+ | Feature | Details |
18
+ |---|---|
19
+ | **Interface** | `step()` / `reset()` / `state()` |
20
+ | **Action space** | `{ tool: "sql" \| "python", command: "..." }` |
21
+ | **Observation** | `{ output, error, current_step, max_steps, task_description }` |
22
+ | **Reward** | 0.0 - 1.0 with **partial progress signals** |
23
+ | **Tasks** | 3 (easy, medium, hard) |
24
+
25
+ ## Tasks
26
+
27
+ ### Easy - Data Triage
28
+ > Find the total revenue from the `sales` table for January 2024.
29
+
30
+ **Grader**: Checks if the computed total matches the expected float value (1000.00).
31
+
32
+ ### Medium - Data Cleaning
33
+ > Fix duplicate emails, NULL ages, and uppercase emails in the `users` table.
34
+
35
+ **Grader**: Partial scoring:
36
+ - 0.3 for all emails lowercase
37
+ - 0.4 for no duplicate emails
38
+ - 0.3 for no NULL ages
39
+
40
+ ### Hard - Schema Migration
41
+ > Normalize `flat_orders` into `customers` + `orders` tables with foreign keys.
42
+
43
+ **Grader**: Partial scoring:
44
+ - 0.2 for correct `customers` schema
45
+ - 0.2 for correct `orders` schema
46
+ - 0.2 for 4 unique customers
47
+ - 0.2 for 6 orders migrated
48
+ - 0.2 for valid FK integrity
49
+
50
+ ## Quick Start
51
+
52
+ ### Local Development
53
+
54
+ ```bash
55
+ # Install dependencies
56
+ pip install openenv-core
57
+
58
+ # Run the server (defaults to the 'easy' task)
59
+ cd sql_sandbox
60
+ TASK_ID=easy python -m server.app
61
+
62
+ # Switch tasks via env var
63
+ TASK_ID=medium python -m server.app
64
+ TASK_ID=hard python -m server.app
65
+ ```
66
+
67
+ ### Docker (Hugging Face Spaces Ready)
68
+
69
+ ```bash
70
+ # Build
71
+ docker build -t sql-sandbox:latest .
72
+
73
+ # Run on HF Spaces default port 7860
74
+ docker run -p 7860:7860 sql-sandbox:latest
75
+ ```
76
+
77
+ ## Baseline Inference
78
+
79
+ Runs GPT-4o on all three tasks and prints reproducible scores:
80
+
81
+ ```bash
82
+ export HF_TOKEN=sk-...
83
+ export MODEL_NAME=gpt-4o
84
+ python inference.py --url http://localhost:7860
85
+ ```
86
+
87
+ ## Project Structure
88
+
89
+ ```
90
+ sql_sandbox/
91
+ ├── __init__.py # Package exports
92
+ ├── models.py # Action & Observation Pydantic models
93
+ ├── client.py # EnvClient subclass
94
+ ├── openenv.yaml # OpenEnv manifest
95
+ ├── pyproject.toml # Dependencies
96
+ ├── inference.py # GPT-4o baseline script
97
+ ├── README.md # This file
98
+ └── server/
99
+ ├── __init__.py
100
+ ├── app.py # FastAPI application
101
+ ├── environment.py # Core environment logic + graders
102
+ ├── requirements.txt
103
+ └── Dockerfile
104
+ ```
__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """SQL/Data Cleaning Sandbox Environment."""
8
+
9
+ from .client import SqlSandboxEnv
10
+ from .models import SqlSandboxAction, SqlSandboxObservation
11
+
12
+ __all__ = [
13
+ "SqlSandboxAction",
14
+ "SqlSandboxObservation",
15
+ "SqlSandboxEnv",
16
+ ]
baseline_inference_groq.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline Inference Script for SQL/Data Cleaning Sandbox -- Groq Edition.
3
+
4
+ Uses Groq (llama-3.3-70b-versatile) to solve all three tasks and prints
5
+ reproducible scores via the OpenEnv WebSocket client.
6
+
7
+ Usage:
8
+ set GROQ_API_KEY=gsk-... # Windows
9
+ export GROQ_API_KEY=gsk-... # Linux/macOS
10
+ python baseline_inference_groq.py # local server
11
+ python baseline_inference_groq.py --url https://... # remote server
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import os
17
+ import sys
18
+
19
+ from dotenv import load_dotenv
20
+ load_dotenv()
21
+
22
+ from groq import Groq
23
+
24
+ from client import SqlSandboxEnv
25
+ from models import SqlSandboxAction
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # System prompt shared across all tasks
30
+ # ---------------------------------------------------------------------------
31
+ SYSTEM_PROMPT = """\
32
+ You are a data engineering assistant working inside a SQLite sandbox.
33
+
34
+ You can execute two types of actions:
35
+ 1. {"tool": "sql", "command": "<SQL query>"}
36
+ 2. {"tool": "python", "command": "<Python code>"}
37
+
38
+ Rules:
39
+ - Respond with EXACTLY ONE JSON object per turn -- no markdown, no explanation.
40
+ - In Python code, the variables `conn` (sqlite3.Connection) and `cursor`
41
+ (sqlite3.Cursor) are already available. Do NOT call sqlite3.connect().
42
+ - SQLite STRFTIME months are zero-padded: use '01' not '1', or use LIKE '2024-01-%'.
43
+ - When you believe the task is fully complete, send:
44
+ {"tool": "sql", "command": "SELECT 'DONE'"}
45
+ """
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # Core agent loop -- one task, one WebSocket session
50
+ # ---------------------------------------------------------------------------
51
+ def _run_task_agent(base_url: str, task_id: str, max_turns: int = 15) -> float:
52
+ """
53
+ Open a fresh WebSocket session, reset the environment to the given task,
54
+ then run an LLM agent loop until done or max_turns is reached.
55
+ Returns the final reward (0.0 - 1.0).
56
+ """
57
+ client_llm = Groq(api_key=os.environ["GROQ_API_KEY"])
58
+ final_reward = 0.0
59
+
60
+ # Each task gets its own WebSocket session to avoid state leakage
61
+ with SqlSandboxEnv(base_url=base_url).sync() as env:
62
+ # reset() with task_id seeds the correct DB table for this task
63
+ reset_resp = env.reset(task_id=task_id)
64
+ task_desc = reset_resp.observation.task_description
65
+
66
+ messages = [
67
+ {"role": "system", "content": SYSTEM_PROMPT},
68
+ {"role": "user", "content": f"Task: {task_desc}\n\nBegin."},
69
+ ]
70
+
71
+ print(f"\n --- Session: {task_id} ---")
72
+
73
+ for turn in range(max_turns):
74
+ # 1. Ask the LLM
75
+ response = client_llm.chat.completions.create(
76
+ model="llama-3.3-70b-versatile",
77
+ messages=messages,
78
+ temperature=0.0,
79
+ max_tokens=512,
80
+ )
81
+ assistant_msg = response.choices[0].message.content.strip()
82
+
83
+ # 2. Parse action JSON (handle optional markdown fences)
84
+ try:
85
+ raw = assistant_msg
86
+ if raw.startswith("```"):
87
+ raw = raw.split("```")[1]
88
+ if raw.startswith("json"):
89
+ raw = raw[4:]
90
+ action_data = json.loads(raw)
91
+ tool = action_data["tool"]
92
+ command = action_data["command"]
93
+ except (json.JSONDecodeError, KeyError):
94
+ # Feed parse error back to LLM, do NOT count as a step
95
+ messages.append({"role": "assistant", "content": assistant_msg})
96
+ messages.append({
97
+ "role": "user",
98
+ "content": (
99
+ 'Invalid JSON. Reply with exactly one JSON object:\n'
100
+ '{"tool": "sql" | "python", "command": "..."}'
101
+ ),
102
+ })
103
+ continue
104
+
105
+ # 3. Execute the action via OpenEnv step()
106
+ step_resp = env.step(SqlSandboxAction(tool=tool, command=command))
107
+
108
+ reward = step_resp.reward or 0.0
109
+ done = step_resp.done
110
+ output = step_resp.observation.output or ""
111
+ error = step_resp.observation.error or ""
112
+
113
+ final_reward = reward
114
+ print(f" [Turn {turn+1:02d}] tool={tool:<6} | reward={reward:.4f} | done={done}")
115
+
116
+ if done:
117
+ break
118
+
119
+ # 4. Feed result back to LLM for the next turn
120
+ messages.append({"role": "assistant", "content": assistant_msg})
121
+ feedback = f"Output:\n{output[:1500]}"
122
+ if error:
123
+ feedback += f"\nError:\n{error[:500]}"
124
+ feedback += f"\nReward so far: {reward:.4f}"
125
+ messages.append({"role": "user", "content": feedback})
126
+
127
+ return final_reward
128
+
129
+
130
+ # ---------------------------------------------------------------------------
131
+ # Per-difficulty entry points (called by main, importable for custom use)
132
+ # ---------------------------------------------------------------------------
133
+ def easy_run(base_url: str, max_turns: int = 15) -> float:
134
+ print(f"\n{'='*50}\nRunning task: easy\n{'='*50}")
135
+ score = _run_task_agent(base_url, "easy", max_turns)
136
+ print(f" Final score: {score:.4f}")
137
+ return score
138
+
139
+
140
+ def med_run(base_url: str, max_turns: int = 15) -> float:
141
+ print(f"\n{'='*50}\nRunning task: medium\n{'='*50}")
142
+ score = _run_task_agent(base_url, "medium", max_turns)
143
+ print(f" Final score: {score:.4f}")
144
+ return score
145
+
146
+
147
+ def hard_run(base_url: str, max_turns: int = 15) -> float:
148
+ print(f"\n{'='*50}\nRunning task: hard\n{'='*50}")
149
+ score = _run_task_agent(base_url, "hard", max_turns)
150
+ print(f" Final score: {score:.4f}")
151
+ return score
152
+
153
+
154
+ # ---------------------------------------------------------------------------
155
+ # CLI entry point
156
+ # ---------------------------------------------------------------------------
157
+ def main():
158
+ parser = argparse.ArgumentParser(
159
+ description="Groq baseline inference for the SQL/Data Cleaning Sandbox"
160
+ )
161
+ parser.add_argument(
162
+ "--url",
163
+ default="http://localhost:8000",
164
+ help="Base URL of the running environment server (default: http://localhost:8000)",
165
+ )
166
+ parser.add_argument(
167
+ "--max-turns",
168
+ type=int,
169
+ default=15,
170
+ help="Maximum agent turns per task (default: 15)",
171
+ )
172
+ args = parser.parse_args()
173
+
174
+ if "GROQ_API_KEY" not in os.environ:
175
+ print("ERROR: GROQ_API_KEY environment variable is not set.")
176
+ sys.exit(1)
177
+
178
+ results: dict[str, float] = {}
179
+ results["easy"] = easy_run(args.url, args.max_turns)
180
+ results["medium"] = med_run(args.url, args.max_turns)
181
+ results["hard"] = hard_run(args.url, args.max_turns)
182
+
183
+ avg = sum(results.values()) / len(results)
184
+ print(f"\n{'='*50}")
185
+ print("RESULTS SUMMARY")
186
+ print(f"{'='*50}")
187
+ for task_id, score in results.items():
188
+ print(f" {task_id:<10}: {score:.4f}")
189
+ print(f" {'average':<10}: {avg:.4f}")
190
+
191
+
192
+ if __name__ == "__main__":
193
+ main()
client.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """SQL Sandbox Environment Client."""
8
+
9
+ from typing import Dict
10
+
11
+ from openenv.core import EnvClient
12
+ from openenv.core.client_types import StepResult
13
+ from openenv.core.env_server.types import State
14
+
15
+ from models import SqlSandboxAction, SqlSandboxObservation
16
+
17
+
18
class SqlSandboxEnv(EnvClient[SqlSandboxAction, SqlSandboxObservation, State]):
    """
    Client for the SQL/Data Cleaning Sandbox.

    Example:
        >>> with SqlSandboxEnv(base_url="http://localhost:8000") as client:
        ...     result = client.reset()
        ...     print(result.observation.task_description)
        ...     result = client.step(SqlSandboxAction(tool="sql", command="SELECT * FROM sales"))
        ...     print(result.observation.output)
    """

    def _step_payload(self, action: SqlSandboxAction) -> Dict:
        # Serialize the action into the wire format the server expects.
        return {"tool": action.tool, "command": action.command}

    def _parse_result(self, payload: Dict) -> StepResult[SqlSandboxObservation]:
        # Read the shared top-level fields once, then build the observation
        # from the nested dict, defaulting every missing key defensively.
        obs_payload = payload.get("observation", {})
        reward = payload.get("reward")
        done = payload.get("done", False)

        observation = SqlSandboxObservation(
            output=obs_payload.get("output", ""),
            error=obs_payload.get("error"),
            current_step=obs_payload.get("current_step", 0),
            max_steps=obs_payload.get("max_steps", 20),
            task_description=obs_payload.get("task_description", ""),
            done=done,
            reward=reward,
            metadata=obs_payload.get("metadata", {}),
        )
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict) -> State:
        # The server reports the episode id and elapsed step count.
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
inference.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline Inference Script for SQL/Data Cleaning Sandbox OpenAI Edition.
3
+
4
+ Uses OpenAI (gpt-4o) to solve all three tasks and prints reproducible
5
+ scores via the OpenEnv WebSocket client.
6
+
7
+ Usage:
8
+ set HF_TOKEN=sk-... # Windows
9
+ export HF_TOKEN=sk-... # Linux/macOS
10
+ python inference.py # local server
11
+ python inference.py --url https://... # remote server
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import os
17
+ import sys
18
+
19
+ from dotenv import load_dotenv
20
+ load_dotenv()
21
+
22
+ from openai import OpenAI
23
+
24
+ from client import SqlSandboxEnv
25
+ from models import SqlSandboxAction
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # System prompt shared across all tasks
30
+ # ---------------------------------------------------------------------------
31
+ SYSTEM_PROMPT = """\
32
+ You are a data engineering assistant working inside a SQLite sandbox.
33
+
34
+ You can execute two types of actions:
35
+ 1. {"tool": "sql", "command": "<SQL query>"}
36
+ 2. {"tool": "python", "command": "<Python code>"}
37
+
38
+ Rules:
39
+ - Respond with EXACTLY ONE JSON object per turn no markdown, no explanation.
40
+ - In Python code, the variables `conn` (sqlite3.Connection) and `cursor`
41
+ (sqlite3.Cursor) are already available. Do NOT call sqlite3.connect().
42
+ - SQLite STRFTIME months are zero-padded: use '01' not '1', or use LIKE '2024-01-%'.
43
+ - When you believe the task is fully complete, send:
44
+ {"tool": "sql", "command": "SELECT 'DONE'"}
45
+ """
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # Core agent loop one task, one WebSocket session
50
+ # ---------------------------------------------------------------------------
51
def _run_task_agent(base_url: str, task_id: str, max_turns: int = 15) -> float:
    """
    Open a fresh WebSocket session, reset the environment to the given task,
    then run an LLM agent loop until done or max_turns is reached.
    Returns the final reward (0.0 - 1.0).

    Args:
        base_url: Base URL of the running OpenEnv server.
        task_id: Task to seed on reset ("easy", "medium", or "hard").
        max_turns: Upper bound on LLM turns. NOTE: an invalid-JSON reply skips
            the env step but still consumes one loop turn.
    """
    # HF_TOKEN takes precedence over OPENAI_API_KEY; main() verifies one is set.
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
    # API_BASE_URL may be unset — presumably OpenAI's client then uses its
    # default endpoint; verify against the installed openai version.
    api_base_url = os.environ.get("API_BASE_URL")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o")

    client_llm = OpenAI(
        api_key=api_key,
        base_url=api_base_url,
    )
    final_reward = 0.0

    # Each task gets its own WebSocket session to avoid state leakage
    with SqlSandboxEnv(base_url=base_url).sync() as env:
        # reset() with task_id seeds the correct DB table for this task
        reset_resp = env.reset(task_id=task_id)
        task_desc = reset_resp.observation.task_description

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Task: {task_desc}\n\nBegin."},
        ]

        print(f"\n --- Session: {task_id} ---")

        for turn in range(max_turns):
            # 1. Ask the LLM (temperature=0.0 for reproducible runs)
            response = client_llm.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=512,
            )
            assistant_msg = response.choices[0].message.content.strip()

            # 2. Parse action JSON (handle optional markdown fences)
            try:
                raw = assistant_msg
                if raw.startswith("```"):
                    # Take the fenced payload and drop an optional "json" tag.
                    raw = raw.split("```")[1]
                    if raw.startswith("json"):
                        raw = raw[4:]
                action_data = json.loads(raw)
                tool = action_data["tool"]
                command = action_data["command"]
            except (json.JSONDecodeError, KeyError):
                # Feed parse error back to LLM, do NOT count as an env step
                # (the loop turn is still consumed — see docstring).
                messages.append({"role": "assistant", "content": assistant_msg})
                messages.append({
                    "role": "user",
                    "content": (
                        'Invalid JSON. Reply with exactly one JSON object:\n'
                        '{"tool": "sql" | "python", "command": "..."}'
                    ),
                })
                continue

            # 3. Execute the action via OpenEnv step()
            step_resp = env.step(SqlSandboxAction(tool=tool, command=command))

            reward = step_resp.reward or 0.0
            done = step_resp.done
            output = step_resp.observation.output or ""
            error = step_resp.observation.error or ""

            # The grader reports cumulative progress; the latest reward wins.
            final_reward = reward
            print(f" [Turn {turn+1:02d}] tool={tool:<6} | reward={reward:.4f} | done={done}")

            if done:
                break

            # 4. Feed (truncated) result back to LLM for the next turn
            messages.append({"role": "assistant", "content": assistant_msg})
            feedback = f"Output:\n{output[:1500]}"
            if error:
                feedback += f"\nError:\n{error[:500]}"
            feedback += f"\nReward so far: {reward:.4f}"
            messages.append({"role": "user", "content": feedback})

    return final_reward
135
+
136
+
137
+ # ---------------------------------------------------------------------------
138
+ # Per-difficulty entry points (called by main, importable for custom use)
139
+ # ---------------------------------------------------------------------------
140
def _run_and_report(task_id: str, base_url: str, max_turns: int) -> float:
    """Run one task session, print its banner and final score, return the score."""
    print(f"\n{'='*50}\nRunning task: {task_id}\n{'='*50}")
    score = _run_task_agent(base_url, task_id, max_turns)
    print(f" Final score: {score:.4f}")
    return score


def easy_run(base_url: str, max_turns: int = 15) -> float:
    """Run the 'easy' task (data triage) and return its final reward."""
    return _run_and_report("easy", base_url, max_turns)


def med_run(base_url: str, max_turns: int = 15) -> float:
    """Run the 'medium' task (data cleaning) and return its final reward."""
    return _run_and_report("medium", base_url, max_turns)


def hard_run(base_url: str, max_turns: int = 15) -> float:
    """Run the 'hard' task (schema migration) and return its final reward."""
    return _run_and_report("hard", base_url, max_turns)


# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main():
    """Run all three tasks against a server and print a results summary.

    Exits with status 1 if neither HF_TOKEN nor OPENAI_API_KEY is set.
    """
    parser = argparse.ArgumentParser(
        description="OpenAI baseline inference for the SQL/Data Cleaning Sandbox"
    )
    parser.add_argument(
        "--url",
        default="http://localhost:8000",
        help="Base URL of the running environment server (default: http://localhost:8000)",
    )
    parser.add_argument(
        "--max-turns",
        type=int,
        default=15,
        help="Maximum agent turns per task (default: 15)",
    )
    args = parser.parse_args()

    # Fail fast with a clear message rather than an auth error mid-run.
    if not os.environ.get("HF_TOKEN") and not os.environ.get("OPENAI_API_KEY"):
        print("ERROR: HF_TOKEN (or OPENAI_API_KEY) environment variable is not set.")
        sys.exit(1)

    results: dict[str, float] = {}
    results["easy"] = easy_run(args.url, args.max_turns)
    results["medium"] = med_run(args.url, args.max_turns)
    results["hard"] = hard_run(args.url, args.max_turns)

    avg = sum(results.values()) / len(results)
    print(f"\n{'='*50}")
    print("RESULTS SUMMARY")
    print(f"{'='*50}")
    for task_id, score in results.items():
        print(f" {task_id:<10}: {score:.4f}")
    print(f" {'average':<10}: {avg:.4f}")


if __name__ == "__main__":
    main()
+ main()
models.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Data models for the SQL/Data Cleaning Sandbox Environment.
9
+
10
+ Agents interact by sending SQL queries or Python snippets to clean
11
+ messy databases and generate reports.
12
+ """
13
+
14
+ from typing import Literal, Optional
15
+
16
+ from openenv.core.env_server.types import Action, Observation
17
+ from pydantic import Field
18
+
19
+
20
class SqlSandboxAction(Action):
    """Action for the SQL Sandbox: run a SQL query or a Python snippet.

    Exactly one tool is selected per action; `command` is the source text
    handed to that tool by the server.
    """

    # Which interpreter executes `command`.
    tool: Literal["sql", "python"] = Field(
        ..., description="Tool to use: 'sql' for SQLite queries, 'python' for Python scripts"
    )
    # The SQL or Python source to run; required (no default).
    command: str = Field(
        ..., description="The SQL query or Python code to execute"
    )


class SqlSandboxObservation(Observation):
    """Observation returned after each step (and after reset)."""

    # Captured stdout / formatted query result of the last command.
    output: str = Field(default="", description="stdout / query result")
    # Error text when the command failed; None on success.
    error: Optional[str] = Field(default=None, description="stderr or error message")
    # Progress counters reported by the server environment.
    current_step: int = Field(default=0, description="Current step number")
    max_steps: int = Field(default=20, description="Maximum allowed steps")
    # Human-readable description of the active task (easy/medium/hard).
    task_description: str = Field(default="", description="Current task description")
openenv.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: sql_sandbox
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 8000  # NOTE: the Hugging Face Spaces Dockerfile serves on 7860 — keep these in sync
7
+
8
+ description: >
9
+ SQL/Data Cleaning Sandbox - a real-world OpenEnv environment where AI agents
10
+ clean messy databases via SQL and Python. Three tasks from easy to hard with
11
+ partial-progress grading (0.0-1.0).
12
+
13
+ reward_range: [0.0, 1.0]
14
+
15
+ tasks:
16
+ - id: easy
17
+ name: Data Triage
18
+ difficulty: easy
19
+ description: Find the total revenue from sales for January 2024.
20
+ - id: medium
21
+ name: Data Cleaning
22
+ difficulty: medium
23
+ description: Fix duplicate emails, null ages, and case inconsistencies in the users table.
24
+ - id: hard
25
+ name: Schema Migration
26
+ difficulty: hard
27
+ description: Normalize a flat orders table into customers + orders with foreign keys.
28
+
29
+ action_space:
30
+ type: object
31
+ properties:
32
+ tool:
33
+ type: string
34
+ enum: [sql, python]
35
+ command:
36
+ type: string
37
+
38
+ observation_space:
39
+ type: object
40
+ properties:
41
+ output:
42
+ type: string
43
+ error:
44
+ type: string
45
+ current_step:
46
+ type: integer
47
+ max_steps:
48
+ type: integer
49
+ task_description:
50
+ type: string
pyproject.toml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=45", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "openenv-sql-sandbox"
7
+ version = "0.1.0"
8
+ description = "SQL/Data Cleaning Sandbox - An OpenEnv environment for agentic data engineering evaluation"
9
+ requires-python = ">=3.10"
10
+ dependencies = [
11
+ "openenv-core[core]>=0.2.2",
12
+ "requests>=2.31.0",
13
+ ]
14
+
15
+ [project.optional-dependencies]
16
+ dev = [
17
+ "pytest>=8.0.0",
18
+ "pytest-cov>=4.0.0",
19
+ ]
20
+ inference = [
21
+ "openai>=1.0.0",
22
+ "requests>=2.31.0",
23
+ "groq>=0.4.0",
24
+ ]
25
+
26
+ [project.scripts]
27
+ server = "sql_sandbox.server.app:main"
28
+
29
+ [tool.setuptools]
30
+ include-package-data = true
31
+ packages = ["sql_sandbox", "sql_sandbox.server"]
32
+ package-dir = { "sql_sandbox" = ".", "sql_sandbox.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Sql Sandbox environment server components."""
8
+
9
+ from .environment import SqlSandboxEnvironment
10
+
11
+ __all__ = ["SqlSandboxEnvironment"]
server/app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
FastAPI application for the Sql Sandbox Environment.

This module creates an HTTP server that exposes the SqlSandboxEnvironment
over HTTP and WebSocket endpoints, compatible with EnvClient.

Endpoints:
- POST /reset: Reset the environment
- POST /step: Execute an action
- GET /state: Get current environment state
- GET /schema: Get action/observation schemas
- WS /ws: WebSocket endpoint for persistent sessions
- POST /set_task/{task_id}: Select the TASK_ID used by later environments

Usage:
    # Development (with auto-reload):
    uvicorn server.app:app --reload --host 0.0.0.0 --port 8000

    # Production:
    uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4

    # Or run directly:
    python -m server.app
"""

try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:  # pragma: no cover
    raise ImportError(
        "openenv is required for the web interface. Install dependencies with '\n uv sync\n'"
    ) from e

# Support both layouts: package-relative imports when installed as a package,
# flat imports when running from /app with PYTHONPATH=/app (Docker image).
try:
    from ..models import SqlSandboxAction, SqlSandboxObservation
    from .environment import SqlSandboxEnvironment
except (ImportError, ModuleNotFoundError):
    from models import SqlSandboxAction, SqlSandboxObservation
    from server.environment import SqlSandboxEnvironment


# Create the app with web interface and README integration
app = create_app(
    SqlSandboxEnvironment,
    SqlSandboxAction,
    SqlSandboxObservation,
    env_name="sql_sandbox",
    max_concurrent_envs=10,  # increase this number to allow more concurrent WebSocket sessions
)

import os
@app.post("/set_task/{task_id}")
def set_task(task_id: str):
    # NOTE(review): mutates process-wide state — this selects the task for
    # environments created AFTER this call, not for sessions already open,
    # and is shared by all concurrent clients. Confirm TASK_ID is read at
    # environment construction time (see server/environment.py).
    os.environ["TASK_ID"] = task_id
    return {"status": "ok", "task_id": task_id}


def main():
    """
    Entry point for direct execution via uv run or python -m.

    This function enables running the server without Docker:
        uv run --project . server
        python -m sql_sandbox.server.app
    """
    import uvicorn

    # Serves on 8000 for local use; the Docker image runs uvicorn on 7860.
    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    main()
server/environment.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ SQL/Data Cleaning Sandbox Environment Implementation.
9
+
10
+ Three tasks (easy, medium, hard) for AI agents:
11
+ 1. Data Triage — query revenue from sales data
12
+ 2. Data Cleaning — fix duplicates & nulls in a users table
13
+ 3. Schema Migration — normalize a flat table into two related tables
14
+ """
15
+
16
+ import io
17
+ import os
18
+ import sqlite3
19
+ import sys
20
+ import tempfile
21
+ import traceback
22
+ from contextlib import redirect_stderr, redirect_stdout
23
+ from uuid import uuid4
24
+
25
+ from openenv.core.env_server.interfaces import Environment
26
+ from openenv.core.env_server.types import State
27
+
28
+ try:
29
+ from ..models import SqlSandboxAction, SqlSandboxObservation
30
+ except ImportError:
31
+ from models import SqlSandboxAction, SqlSandboxObservation
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Task definitions
35
+ # ---------------------------------------------------------------------------
36
# Task registry: each entry carries the prompt shown to the agent
# ("description") and the episode step budget ("max_steps").
TASKS = {
    "easy": {
        "id": "easy",
        # Single-query analytics task over seeded 'sales' data.
        "description": (
            "Find the total revenue from the 'sales' table for January 2024. "
            "The table has columns: id, product, amount, sale_date (YYYY-MM-DD). "
            "Return the exact total as a single number by running a SQL query. "
            "The expected result should be a SELECT query that returns one number."
        ),
        "max_steps": 10,
    },
    "medium": {
        "id": "medium",
        # In-place cleanup of the messy 'users' table.
        "description": (
            "The 'users' table has duplicate emails and NULL values in the 'age' column. "
            "Clean the data so that: (1) all emails are lowercase, "
            "(2) duplicate emails are removed (keep the row with the lowest id), "
            "(3) all NULL ages are replaced with 0. "
            "Use SQL or Python to fix the table in-place."
        ),
        "max_steps": 15,
    },
    "hard": {
        "id": "hard",
        # Schema migration: split 'flat_orders' into customers + orders.
        "description": (
            "The 'flat_orders' table has columns: order_id, order_date, "
            "customer_name, customer_email, product, quantity, price. "
            "Normalize this into two tables: 'customers' (id INTEGER PRIMARY KEY, "
            "name TEXT, email TEXT UNIQUE) and 'orders' (id INTEGER PRIMARY KEY, "
            "customer_id INTEGER REFERENCES customers(id), order_date TEXT, "
            "product TEXT, quantity INTEGER, price REAL). "
            "Maintain foreign key integrity and migrate all data."
        ),
        "max_steps": 20,
    },
}
72
+
73
+ # ---------------------------------------------------------------------------
74
+ # Seed data generators
75
+ # ---------------------------------------------------------------------------
76
+
77
+ def _seed_easy(conn: sqlite3.Connection):
78
+ """Create sales table with known data."""
79
+ conn.execute("DROP TABLE IF EXISTS sales")
80
+ conn.execute(
81
+ "CREATE TABLE sales (id INTEGER PRIMARY KEY, product TEXT, amount REAL, sale_date TEXT)"
82
+ )
83
+ rows = [
84
+ (1, "Widget A", 150.00, "2024-01-05"),
85
+ (2, "Widget B", 250.50, "2024-01-12"),
86
+ (3, "Widget C", 99.99, "2024-01-20"),
87
+ (4, "Widget A", 150.00, "2024-01-28"),
88
+ (5, "Widget D", 349.51, "2024-01-15"),
89
+ (6, "Widget A", 200.00, "2024-02-03"),
90
+ (7, "Widget B", 75.00, "2023-12-30"),
91
+ ]
92
+ conn.executemany("INSERT INTO sales VALUES (?,?,?,?)", rows)
93
+ conn.commit()
94
+
95
+
96
+ def _seed_medium(conn: sqlite3.Connection):
97
+ """Create users table with messy data."""
98
+ conn.execute("DROP TABLE IF EXISTS users")
99
+ conn.execute(
100
+ "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT, age INTEGER)"
101
+ )
102
+ rows = [
103
+ (1, "Alice", "Alice@Example.com", 30),
104
+ (2, "Bob", "bob@example.com", None),
105
+ (3, "Charlie", "charlie@test.com", 25),
106
+ (4, "Alice Dup", "alice@example.com", 28),
107
+ (5, "Dave", "DAVE@Test.COM", None),
108
+ (6, "Eve", "eve@example.com", 35),
109
+ (7, "Dave Dup", "dave@test.com", 40),
110
+ (8, "Frank", "frank@example.com", None),
111
+ ]
112
+ conn.executemany("INSERT INTO users VALUES (?,?,?,?)", rows)
113
+ conn.commit()
114
+
115
+
116
+ def _seed_hard(conn: sqlite3.Connection):
117
+ """Create flat_orders table."""
118
+ conn.execute("DROP TABLE IF EXISTS flat_orders")
119
+ conn.execute("DROP TABLE IF EXISTS customers")
120
+ conn.execute("DROP TABLE IF EXISTS orders")
121
+ conn.execute(
122
+ "CREATE TABLE flat_orders ("
123
+ "order_id INTEGER, order_date TEXT, customer_name TEXT, "
124
+ "customer_email TEXT, product TEXT, quantity INTEGER, price REAL)"
125
+ )
126
+ rows = [
127
+ (1, "2024-01-10", "Alice", "alice@example.com", "Laptop", 1, 999.99),
128
+ (2, "2024-01-11", "Bob", "bob@example.com", "Mouse", 2, 25.50),
129
+ (3, "2024-01-12", "Alice", "alice@example.com", "Keyboard", 1, 75.00),
130
+ (4, "2024-01-13", "Charlie", "charlie@example.com", "Monitor", 1, 300.00),
131
+ (5, "2024-01-14", "Bob", "bob@example.com", "Webcam", 1, 50.00),
132
+ (6, "2024-01-15", "Diana", "diana@example.com", "USB Hub", 3, 15.99),
133
+ ]
134
+ conn.executemany("INSERT INTO flat_orders VALUES (?,?,?,?,?,?,?)", rows)
135
+ conn.commit()
136
+
137
+
138
# Maps task id -> seeding function; used by SqlSandboxEnvironment.reset().
SEED_FNS = {"easy": _seed_easy, "medium": _seed_medium, "hard": _seed_hard}
139
+
140
+ # ---------------------------------------------------------------------------
141
+ # Graders
142
+ # ---------------------------------------------------------------------------
143
+
144
EASY_EXPECTED = 1000.00  # 150 + 250.5 + 99.99 + 150 + 349.51


def grade_easy(conn: sqlite3.Connection, last_output: str) -> float:
    """Return 1.0 when the expected Jan-2024 revenue appears in the output, else 0.0.

    The database is not inspected for this task; only the text of the
    agent's last query result is scanned for a number within 0.01 of
    EASY_EXPECTED. The ``conn`` parameter exists for grader-signature
    uniformity.
    """
    if not last_output:
        return 0.0

    import re

    try:
        for token in re.findall(r"[-+]?\d*\.\d+|\d+", last_output):
            if abs(float(token) - EASY_EXPECTED) < 0.01:
                return 1.0
    except Exception:
        # Any parsing hiccup is treated as "answer not found".
        pass
    return 0.0
163
+
164
+
165
def grade_medium(conn: sqlite3.Connection, last_output: str) -> float:
    """Score cleanup of the 'users' table in [0, 1].

    Weights: lowercase emails 0.3, no duplicate emails 0.4, no NULL
    ages 0.3. An empty table (or any query failure) scores 0.
    """
    weighted_checks = [
        # Each query counts violations; zero violations earns the weight.
        ("SELECT COUNT(*) FROM users WHERE email != LOWER(email)", 0.3),
        (
            "SELECT COUNT(*) FROM (SELECT LOWER(email) as e FROM users GROUP BY e HAVING COUNT(*) > 1)",
            0.4,
        ),
        ("SELECT COUNT(*) FROM users WHERE age IS NULL", 0.3),
    ]
    total = 0.0
    try:
        if conn.execute("SELECT COUNT(*) FROM users").fetchone()[0] == 0:
            return 0.0
        for query, weight in weighted_checks:
            if conn.execute(query).fetchone()[0] == 0:
                total += weight
    except Exception:
        pass
    return round(total, 2)
197
+
198
+
199
def grade_hard(conn: sqlite3.Connection, last_output: str) -> float:
    """Score the flat_orders -> customers/orders migration in [0, 1].

    Five equally-weighted (0.2) checks: customers columns, orders
    columns, 4 customers, 6 orders, and referential integrity of
    orders.customer_id. A failing query stops further checks (same
    try-scope as the original).
    """
    total = 0.0
    try:
        def table_columns(name: str) -> set:
            return {row[1] for row in conn.execute(f"PRAGMA table_info({name})").fetchall()}

        def scalar(query: str):
            return conn.execute(query).fetchone()[0]

        if {"id", "name", "email"} <= table_columns("customers"):
            total += 0.2
        if {"id", "customer_id", "order_date", "product", "quantity", "price"} <= table_columns("orders"):
            total += 0.2
        if scalar("SELECT COUNT(*) FROM customers") == 4:
            total += 0.2
        if scalar("SELECT COUNT(*) FROM orders") == 6:
            total += 0.2
        if scalar(
            "SELECT COUNT(*) FROM orders WHERE customer_id NOT IN (SELECT id FROM customers)"
        ) == 0:
            total += 0.2
    except Exception:
        pass
    return round(total, 2)
234
+
235
+
236
# Maps task id -> grading function; each returns a partial reward in [0, 1].
GRADERS = {"easy": grade_easy, "medium": grade_medium, "hard": grade_hard}
237
+
238
+ # ---------------------------------------------------------------------------
239
+ # Environment
240
+ # ---------------------------------------------------------------------------
241
+
242
class SqlSandboxEnvironment(Environment):
    """
    SQL / Data Cleaning Sandbox — a real-world OpenEnv environment.

    The agent sends SQL or Python commands to clean messy databases.
    Partial progress rewards are given after each step.

    Each instance owns a private SQLite file under the system temp dir,
    so concurrent sessions never share database state.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Per-instance DB file keeps concurrent environments isolated.
        self._db_path = os.path.join(tempfile.gettempdir(), f"sqlsandbox_{uuid4().hex[:8]}.db")
        self._conn: sqlite3.Connection | None = None
        self._task_id = os.environ.get("TASK_ID", "easy")
        self._task = TASKS[self._task_id]
        self._max_steps = self._task["max_steps"]
        self._done = False
        self._last_reward = 0.0

    # ---- helpers -----------------------------------------------------------

    def _get_conn(self) -> sqlite3.Connection:
        """Lazily open (and cache) the SQLite connection, with FK enforcement on."""
        if self._conn is None:
            self._conn = sqlite3.connect(self._db_path)
            self._conn.execute("PRAGMA foreign_keys = ON")
        return self._conn

    def _partial_reward(self, last_output: str) -> float:
        """Run the current task's grader against the live DB to score progress."""
        return GRADERS[self._task_id](self._get_conn(), last_output)

    def _exec_sql(self, query: str) -> tuple[str, str | None]:
        """Execute one SQL statement and return (output, error).

        SELECT-like statements are rendered as a small text table; other
        statements report the affected row count.
        """
        try:
            conn = self._get_conn()
            cur = conn.execute(query)
            if cur.description:
                cols = [d[0] for d in cur.description]
                rows = cur.fetchall()
                header = " | ".join(cols)
                body = "\n".join(" | ".join(str(c) for c in r) for r in rows)
                output = f"{header}\n{body}" if rows else header + "\n(no rows)"
            else:
                # BUGFIX: use the per-statement cur.rowcount instead of
                # conn.total_changes, which is cumulative over the whole
                # connection and over-reported. DDL yields -1 -> clamp to 0.
                output = f"OK {max(cur.rowcount, 0)} row(s) affected"
            conn.commit()
            return output, None
        except Exception as e:
            return "", str(e)

    def _exec_python(self, code: str) -> tuple[str, str | None]:
        """Execute agent-supplied Python with DB helpers in scope.

        Returns (captured stdout, error-or-None). SECURITY NOTE: exec()
        of untrusted code is intentional here — this is a sandbox
        environment, and isolation is expected from the container, not
        this process.
        """
        stdout_buf, stderr_buf = io.StringIO(), io.StringIO()
        try:
            conn = self._get_conn()
            cursor = conn.cursor()
            globs = {
                "__builtins__": __builtins__,
                "sqlite3": sqlite3,
                "DB_PATH": self._db_path,
                "conn": conn,
                "cursor": cursor,
            }
            with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
                exec(code, globs)

            # Automatically commit any schema changes the agent's code made
            conn.commit()

            out = stdout_buf.getvalue()
            err = stderr_buf.getvalue() or None
            return out, err
        except Exception:
            return stdout_buf.getvalue(), traceback.format_exc()

    # ---- OpenEnv interface -------------------------------------------------
    def reset(self, **kwargs) -> SqlSandboxObservation:
        """Reset the environment, optionally switching task.

        Args:
            task_id: Optional task selector ("easy" | "medium" | "hard").
                Falls back to the TASK_ID environment variable, then "easy".

        Raises:
            ValueError: if an unknown task_id is supplied.
        """
        # 1. Close current connection to ensure file handles are released
        if self._conn:
            self._conn.close()
            self._conn = None

        # 2. Resolve the task: explicit kwarg wins over the process env.
        # (Fixes the 'easy task persistence' bug; also validates the id so
        # a bad value raises a clear ValueError instead of a bare KeyError.)
        task_id = kwargs.get("task_id") or os.environ.get("TASK_ID", "easy")
        if task_id not in TASKS:
            raise ValueError(f"Unknown task_id {task_id!r}; expected one of {sorted(TASKS)}")
        self._task_id = task_id
        self._task = TASKS[self._task_id]
        self._max_steps = self._task["max_steps"]

        # 3. Re-initialize episode state
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._done = False
        self._last_reward = 0.0

        # 4. Open fresh connection and re-seed for the specific task_id.
        # Seed functions use 'DROP TABLE IF EXISTS' which handles cleanup.
        SEED_FNS[self._task_id](self._get_conn())

        return SqlSandboxObservation(
            output=f"Environment ready. Task: {self._task['description']}",
            error=None,
            current_step=0,
            max_steps=self._max_steps,
            task_description=self._task["description"],
            done=False,
            reward=0.0,
        )

    def step(self, action: SqlSandboxAction) -> SqlSandboxObservation:  # type: ignore[override]
        """Execute one agent action and return the graded observation."""
        if self._done:
            # BUGFIX: do not advance step_count for calls after the episode
            # has ended; just echo the terminal state.
            return SqlSandboxObservation(
                output="Episode already finished. Call reset().",
                error=None,
                current_step=self._state.step_count,
                max_steps=self._max_steps,
                task_description=self._task["description"],
                done=True,
                reward=self._last_reward,
            )

        self._state.step_count += 1
        step = self._state.step_count

        # Dispatch to the requested tool.
        if action.tool == "sql":
            output, error = self._exec_sql(action.command)
        else:
            output, error = self._exec_python(action.command)

        # Compute partial reward from the grader.
        grade = self._partial_reward(output)

        # Terminate on step-budget exhaustion or full task completion
        # (done is judged on the raw grade, before any error penalty).
        done = step >= self._max_steps or grade >= 1.0
        if done:
            self._done = True

        # Small penalty for errors to discourage random guessing.
        # BUGFIX: apply the penalty before storing, so the stored
        # _last_reward matches the reward actually returned.
        reward = max(0.0, grade - 0.05) if error else grade
        self._last_reward = reward

        return SqlSandboxObservation(
            output=output[:4000],  # cap output size
            error=error[:2000] if error else None,
            current_step=step,
            max_steps=self._max_steps,
            task_description=self._task["description"],
            done=done,
            reward=round(reward, 4),
        )

    @property
    def state(self) -> State:
        """Current episode state (episode id + step count)."""
        return self._state
server/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ openenv-core[core]>=0.2.2
2
+ requests>=2.31.0
uv.lock ADDED
The diff for this file is too large to render. See raw diff