Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| SQL/Data Cleaning Sandbox Environment Implementation. | |
| Three tasks (easy, medium, hard) for AI agents: | |
| 1. Data Triage: query revenue from sales data | |
| 2. Data Cleaning: fix duplicates & nulls in a users table | |
| 3. Schema Migration: normalize a flat table into two related tables | |
| """ | |
| import io | |
| import os | |
| import sqlite3 | |
| import sys | |
| import tempfile | |
| import traceback | |
| from contextlib import redirect_stderr, redirect_stdout | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| try: | |
| from ..models import SqlSandboxAction, SqlSandboxObservation | |
| except ImportError: | |
| from models import SqlSandboxAction, SqlSandboxObservation | |
| # --------------------------------------------------------------------------- | |
| # Task definitions | |
| # --------------------------------------------------------------------------- | |
# Agent-facing instructions for each task, split at sentence boundaries.
_EASY_DESCRIPTION = (
    "Find the total revenue from the 'sales' table for January 2024. "
    "The table has columns: id, product, amount, sale_date (YYYY-MM-DD). "
    "Return the exact total as a single number by running a SQL query. "
    "The expected result should be a SELECT query that returns one number."
)
_MEDIUM_DESCRIPTION = (
    "The 'users' table has duplicate emails and NULL values in the 'age' column. "
    "Clean the data so that: (1) all emails are lowercase, (2) duplicate emails "
    "are removed (keep the row with the lowest id), (3) all NULL ages are "
    "replaced with 0. "
    "Use SQL or Python to fix the table in-place."
)
_HARD_DESCRIPTION = (
    "The 'flat_orders' table has columns: order_id, order_date, customer_name, "
    "customer_email, product, quantity, price. "
    "Normalize this into two tables: 'customers' (id INTEGER PRIMARY KEY, name "
    "TEXT, email TEXT UNIQUE) and 'orders' (id INTEGER PRIMARY KEY, customer_id "
    "INTEGER REFERENCES customers(id), order_date TEXT, product TEXT, quantity "
    "INTEGER, price REAL). "
    "Maintain foreign key integrity and migrate all data."
)

# Task registry keyed by task id. Each entry carries the id itself, the
# description shown to the agent, and the step budget for one episode.
TASKS = {
    task_id: {"id": task_id, "description": description, "max_steps": max_steps}
    for task_id, description, max_steps in (
        ("easy", _EASY_DESCRIPTION, 10),
        ("medium", _MEDIUM_DESCRIPTION, 15),
        ("hard", _HARD_DESCRIPTION, 20),
    )
}
| # --------------------------------------------------------------------------- | |
| # Seed data generators | |
| # --------------------------------------------------------------------------- | |
| def _seed_easy(conn: sqlite3.Connection): | |
| """Create sales table with known data.""" | |
| conn.execute("DROP TABLE IF EXISTS sales") | |
| conn.execute( | |
| "CREATE TABLE sales (id INTEGER PRIMARY KEY, product TEXT, amount REAL, sale_date TEXT)" | |
| ) | |
| rows = [ | |
| (1, "Widget A", 150.00, "2024-01-05"), | |
| (2, "Widget B", 250.50, "2024-01-12"), | |
| (3, "Widget C", 99.99, "2024-01-20"), | |
| (4, "Widget A", 150.00, "2024-01-28"), | |
| (5, "Widget D", 349.51, "2024-01-15"), | |
| (6, "Widget A", 200.00, "2024-02-03"), | |
| (7, "Widget B", 75.00, "2023-12-30"), | |
| ] | |
| conn.executemany("INSERT INTO sales VALUES (?,?,?,?)", rows) | |
| conn.commit() | |
| def _seed_medium(conn: sqlite3.Connection): | |
| """Create users table with messy data.""" | |
| conn.execute("DROP TABLE IF EXISTS users") | |
| conn.execute( | |
| "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT, age INTEGER)" | |
| ) | |
| rows = [ | |
| (1, "Alice", "Alice@Example.com", 30), | |
| (2, "Bob", "bob@example.com", None), | |
| (3, "Charlie", "charlie@test.com", 25), | |
| (4, "Alice Dup", "alice@example.com", 28), | |
| (5, "Dave", "DAVE@Test.COM", None), | |
| (6, "Eve", "eve@example.com", 35), | |
| (7, "Dave Dup", "dave@test.com", 40), | |
| (8, "Frank", "frank@example.com", None), | |
| ] | |
| conn.executemany("INSERT INTO users VALUES (?,?,?,?)", rows) | |
| conn.commit() | |
| def _seed_hard(conn: sqlite3.Connection): | |
| """Create flat_orders table.""" | |
| conn.execute("DROP TABLE IF EXISTS flat_orders") | |
| conn.execute("DROP TABLE IF EXISTS customers") | |
| conn.execute("DROP TABLE IF EXISTS orders") | |
| conn.execute( | |
| "CREATE TABLE flat_orders (" | |
| "order_id INTEGER, order_date TEXT, customer_name TEXT, " | |
| "customer_email TEXT, product TEXT, quantity INTEGER, price REAL)" | |
| ) | |
| rows = [ | |
| (1, "2024-01-10", "Alice", "alice@example.com", "Laptop", 1, 999.99), | |
| (2, "2024-01-11", "Bob", "bob@example.com", "Mouse", 2, 25.50), | |
| (3, "2024-01-12", "Alice", "alice@example.com", "Keyboard", 1, 75.00), | |
| (4, "2024-01-13", "Charlie", "charlie@example.com", "Monitor", 1, 300.00), | |
| (5, "2024-01-14", "Bob", "bob@example.com", "Webcam", 1, 50.00), | |
| (6, "2024-01-15", "Diana", "diana@example.com", "USB Hub", 3, 15.99), | |
| ] | |
| conn.executemany("INSERT INTO flat_orders VALUES (?,?,?,?,?,?,?)", rows) | |
| conn.commit() | |
# Task id -> function that (re)creates that task's seed tables on a connection.
SEED_FNS = {
    "easy": _seed_easy,
    "medium": _seed_medium,
    "hard": _seed_hard,
}
| # --------------------------------------------------------------------------- | |
| # Graders | |
| # --------------------------------------------------------------------------- | |
EASY_EXPECTED = 1000.00  # 150 + 250.5 + 99.99 + 150 + 349.51


def grade_easy(conn: sqlite3.Connection, last_output: str) -> float:
    """Score the "easy" task from the text of the agent's last command.

    Every signed decimal number found in ``last_output`` is compared with
    ``EASY_EXPECTED``; a match within 0.01 earns full credit.

    Args:
        conn: Unused; accepted so all graders share one signature.
        last_output: Text produced by the agent's most recent command.

    Returns:
        1.0 if any number in the output equals the expected January 2024
        revenue, otherwise 0.0 (including for empty output).
    """
    import re  # local import: the module's import block does not provide ``re``

    if not last_output:
        return 0.0

    # The optional sign is attached to BOTH the integer and the decimal form,
    # so "-1000" is parsed as -1000.0 and rejected rather than having its
    # digits matched on their own as a positive 1000.
    number_pattern = r"[-+]?(?:\d+\.?\d*|\.\d+)"
    for token in re.findall(number_pattern, last_output):
        # Every token the pattern can produce is a valid float literal.
        if abs(float(token) - EASY_EXPECTED) < 0.01:
            return 1.0
    return 0.0
# Distinct (lowercased) emails present in the seeded ``users`` table. A correct
# clean-up keeps exactly one row for each of them.
MEDIUM_EXPECTED_EMAILS = frozenset(
    {
        "alice@example.com",
        "bob@example.com",
        "charlie@test.com",
        "dave@test.com",
        "eve@example.com",
        "frank@example.com",
    }
)


def grade_medium(conn: sqlite3.Connection, last_output: str) -> float:
    """Score the "medium" task by inspecting the ``users`` table.

    Credit is awarded per criterion:

    * 0.3 - every email equals its lowercase form;
    * 0.4 - no two rows share an email (case-insensitively) AND every
      seeded distinct email is still present, so deleting rows wholesale
      does not count as deduplication;
    * 0.3 - no row has a NULL age.

    Args:
        conn: Connection to the sandbox database.
        last_output: Unused; accepted so all graders share one signature.

    Returns:
        The accumulated score rounded to two decimals. 0.0 if the table is
        empty or missing. If a later query fails (for example a required
        column was dropped), the credit earned so far is returned.
    """
    score = 0.0
    try:
        row_total = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
        if row_total == 0:
            return 0.0

        # Lowercase emails (0.3)
        mixed_case = conn.execute(
            "SELECT COUNT(*) FROM users WHERE email != LOWER(email)"
        ).fetchone()[0]
        if mixed_case == 0:
            score += 0.3

        # Deduplicated without losing any distinct email (0.4)
        duplicate_groups = conn.execute(
            "SELECT COUNT(*) FROM (SELECT LOWER(email) as e FROM users GROUP BY e HAVING COUNT(*) > 1)"
        ).fetchone()[0]
        surviving_emails = {
            row[0] for row in conn.execute("SELECT DISTINCT LOWER(email) FROM users")
        }
        if duplicate_groups == 0 and MEDIUM_EXPECTED_EMAILS <= surviving_emails:
            score += 0.4

        # No NULL ages (0.3)
        null_ages = conn.execute(
            "SELECT COUNT(*) FROM users WHERE age IS NULL"
        ).fetchone()[0]
        if null_ages == 0:
            score += 0.3
    except sqlite3.Error:
        # Missing table/column or an unusable connection: keep the credit
        # earned before the failing query. Other exception types indicate a
        # grader bug and are allowed to propagate.
        pass
    return round(score, 2)
def grade_hard(conn: sqlite3.Connection, last_output: str) -> float:
    """Score the "hard" task by inspecting the normalised schema and data.

    Five independent checks are worth 0.2 each:

    1. ``customers`` has the columns id, name, email;
    2. ``orders`` has the columns id, customer_id, order_date, product,
       quantity, price;
    3. ``customers`` holds exactly 4 rows;
    4. ``orders`` holds exactly 6 rows;
    5. ``orders`` is non-empty and every row's ``customer_id`` matches a
       ``customers.id`` (a NULL ``customer_id`` counts as a violation).

    Each check is evaluated on its own, so a failing query (for example a
    missing table) forfeits only the credit of the checks that need it.

    Args:
        conn: Connection to the sandbox database.
        last_output: Unused; accepted so all graders share one signature.

    Returns:
        The accumulated score rounded to two decimals.
    """

    def table_columns(table: str) -> set:
        # PRAGMA table_info yields no rows for a table that does not exist.
        try:
            return {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}
        except sqlite3.Error:
            return set()

    def scalar(sql: str):
        # First column of the first row, or None when the query fails.
        try:
            return conn.execute(sql).fetchone()[0]
        except sqlite3.Error:
            return None

    order_count = scalar("SELECT COUNT(*) FROM orders")
    # NOT EXISTS (rather than NOT IN) so that NULLs on either side cannot turn
    # the predicate into NULL and silently hide a dangling reference.
    orphan_count = scalar(
        "SELECT COUNT(*) FROM orders AS o WHERE NOT EXISTS "
        "(SELECT 1 FROM customers AS c WHERE c.id = o.customer_id)"
    )

    checks = (
        {"id", "name", "email"} <= table_columns("customers"),
        {"id", "customer_id", "order_date", "product", "quantity", "price"}
        <= table_columns("orders"),
        scalar("SELECT COUNT(*) FROM customers") == 4,
        order_count == 6,
        # Integrity over zero rows is vacuous, so it earns nothing.
        bool(order_count) and orphan_count == 0,
    )

    score = 0.0
    for passed in checks:
        if passed:
            score += 0.2
    return round(score, 2)
# Task id -> grader called as grader(conn, last_output), returning a float score.
GRADERS = {
    "easy": grade_easy,
    "medium": grade_medium,
    "hard": grade_hard,
}
| # --------------------------------------------------------------------------- | |
| # Environment | |
| # --------------------------------------------------------------------------- | |
class SqlSandboxEnvironment(Environment):
    """
    SQL / Data Cleaning Sandbox: an OpenEnv environment backed by SQLite.

    The agent sends SQL or Python commands that run against a per-instance
    SQLite database file. After every step the current task's grader is run
    and its score is returned as a partial-progress reward.

    The constructor neither creates nor seeds any table; call ``reset()``
    before the first ``step()``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Pick the task from the ``TASK_ID`` environment variable (default
        ``"easy"``) and choose a unique database path in the temp directory.

        Raises:
            KeyError: If ``TASK_ID`` names a task that is not in ``TASKS``.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._db_path = os.path.join(tempfile.gettempdir(), f"sqlsandbox_{uuid4().hex[:8]}.db")
        self._conn: sqlite3.Connection | None = None
        self._task_id = os.environ.get("TASK_ID", "easy")
        self._task = TASKS[self._task_id]
        self._max_steps = self._task["max_steps"]
        self._done = False
        self._last_reward = 0.0

    # ---- helpers -----------------------------------------------------------
    def _get_conn(self) -> sqlite3.Connection:
        """Return the open connection, creating it on first use.

        Foreign-key enforcement is switched on for every new connection.
        """
        if self._conn is None:
            self._conn = sqlite3.connect(self._db_path)
            self._conn.execute("PRAGMA foreign_keys = ON")
        return self._conn

    def _partial_reward(self, last_output: str) -> float:
        """Run the current task's grader to compute partial progress."""
        return GRADERS[self._task_id](self._get_conn(), last_output)

    def _exec_sql(self, query: str) -> tuple[str, str | None]:
        """Execute one SQL statement and render its result as text.

        Row-returning statements are rendered as a `` | ``-separated header
        line followed by one line per row (or ``(no rows)``). Other
        statements report the number of rows that statement changed.

        Returns:
            ``(output, None)`` on success, ``("", message)`` on failure.
        """
        try:
            conn = self._get_conn()
            cur = conn.execute(query)
            if cur.description:
                cols = [d[0] for d in cur.description]
                rows = cur.fetchall()
                header = " | ".join(cols)
                body = "\n".join(" | ".join(str(c) for c in r) for r in rows)
                output = f"{header}\n{body}" if rows else header + "\n(no rows)"
            else:
                # cursor.rowcount is per statement (and -1 for statements such
                # as DDL), whereas Connection.total_changes is cumulative for
                # the whole connection and would over-report.
                affected = max(cur.rowcount, 0)
                output = f"OK {affected} row(s) affected"
            # Commit in both branches so row-returning DML (e.g. RETURNING)
            # is persisted too; this is a no-op after a plain SELECT.
            conn.commit()
            return output, None
        except Exception as e:
            # Boundary for agent-supplied SQL: any failure is reported back to
            # the agent as text instead of crashing the environment.
            return "", str(e)

    def _exec_python(self, code: str) -> tuple[str, str | None]:
        """Execute agent-supplied Python with the sandbox database in scope.

        The code sees ``sqlite3``, ``DB_PATH``, ``conn`` and ``cursor`` as
        globals. stdout is returned as the output; anything written to
        stderr, or the traceback of an uncaught exception, is returned as the
        error. Pending changes are committed after a successful run.
        """
        stdout_buf, stderr_buf = io.StringIO(), io.StringIO()
        try:
            conn = self._get_conn()
            cursor = conn.cursor()
            globs = {
                "__builtins__": __builtins__,
                "sqlite3": sqlite3,
                "DB_PATH": self._db_path,
                "conn": conn,
                "cursor": cursor,
            }
            # SECURITY NOTE(review): exec() runs the agent's code with full
            # builtins inside this process, so it is not an isolation
            # boundary. Only deploy where the agent is trusted or the whole
            # process is itself sandboxed.
            with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
                exec(code, globs)
            # Automatically commit any changes the agent's code made.
            conn.commit()
            out = stdout_buf.getvalue()
            err = stderr_buf.getvalue() or None
            return out, err
        except Exception:
            return stdout_buf.getvalue(), traceback.format_exc()

    # ---- OpenEnv interface -------------------------------------------------
    def reset(self, **kwargs) -> SqlSandboxObservation:
        """Start a new episode, optionally switching task.

        Args:
            **kwargs: ``task_id`` selects the task. When it is absent or
                falsy the ``TASK_ID`` environment variable is used, then
                ``"easy"``.

        Returns:
            The initial observation carrying the task description.

        Raises:
            KeyError: If the task id is unknown. The check happens before any
                state is touched, so a failed reset leaves the current
                episode and connection intact.
        """
        # 1. Resolve and validate the task first (may raise KeyError).
        task_id = kwargs.get("task_id") or os.environ.get("TASK_ID", "easy")
        task = TASKS[task_id]

        # 2. Close the current connection so file handles are released.
        if self._conn is not None:
            self._conn.close()
            self._conn = None

        # 3. Switch task context.
        self._task_id = task_id
        self._task = task
        self._max_steps = task["max_steps"]

        # 4. Re-initialize episode state.
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._done = False
        self._last_reward = 0.0

        # 5. Open a fresh connection and re-seed for this task. The seed
        #    functions drop and recreate their own tables.
        conn = self._get_conn()
        SEED_FNS[self._task_id](conn)

        return SqlSandboxObservation(
            output=f"Environment ready. Task: {self._task['description']}",
            error=None,
            current_step=0,
            max_steps=self._max_steps,
            task_description=self._task["description"],
            done=False,
            reward=0.0,
        )

    def step(self, action: SqlSandboxAction) -> SqlSandboxObservation:  # type: ignore[override]
        """Run one agent command and grade the resulting database state.

        ``action.tool == "sql"`` runs ``action.command`` as SQL; any other
        tool value runs it as Python. The episode ends when the step budget
        is used up or the grader reports a score of at least 1.0. A command
        that produced an error has 0.05 deducted from its reward (floored at
        0.0).

        Once the episode is finished, further calls execute nothing, leave
        the step counter unchanged and repeat the final reward.
        """
        if self._done:
            return SqlSandboxObservation(
                output="Episode already finished. Call reset().",
                error=None,
                current_step=self._state.step_count,
                max_steps=self._max_steps,
                task_description=self._task["description"],
                done=True,
                reward=self._last_reward,
            )

        self._state.step_count += 1
        step = self._state.step_count

        # Execute action
        if action.tool == "sql":
            output, error = self._exec_sql(action.command)
        else:
            output, error = self._exec_python(action.command)

        # Grade the current state; termination uses the unpenalized score.
        score = self._partial_reward(output)
        done = step >= self._max_steps or score >= 1.0

        # Small penalty for errors to discourage random guessing.
        reward = round(max(0.0, score - 0.05) if error else score, 4)

        if done:
            self._done = True
            # Remember exactly what was reported so replays stay consistent.
            self._last_reward = reward

        return SqlSandboxObservation(
            output=output[:4000],  # cap output size
            error=error[:2000] if error else None,
            current_step=step,
            max_steps=self._max_steps,
            task_description=self._task["description"],
            done=done,
            reward=reward,
        )

    def state(self) -> State:
        """Return the current episode state (episode id and step count)."""
        # NOTE(review): this is a plain method; if the Environment base class
        # declares ``state`` as a property, callers reading ``env.state``
        # would get a bound method. Confirm against the base interface.
        return self._state