Meta-Pytorch-Openenv / server /environment.py
shreyas231219's picture
Upload folder using huggingface_hub
615101f verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
SQL/Data Cleaning Sandbox Environment Implementation.
Three tasks (easy -> medium -> hard) for AI agents:
1. Data Triage: query revenue from sales data
2. Data Cleaning: fix duplicates & nulls in a users table
3. Schema Migration: normalize a flat table into two related tables
"""
import io
import os
import sqlite3
import sys
import tempfile
import traceback
from contextlib import redirect_stderr, redirect_stdout
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
try:
from ..models import SqlSandboxAction, SqlSandboxObservation
except ImportError:
from models import SqlSandboxAction, SqlSandboxObservation
# ---------------------------------------------------------------------------
# Task definitions
# ---------------------------------------------------------------------------
# Task catalogue keyed by difficulty.  Every entry exposes its own "id", the
# natural-language "description" shown to the agent, and the "max_steps"
# budget after which an episode is terminated.
TASKS = {
    task_id: {"id": task_id, "description": description, "max_steps": max_steps}
    for task_id, max_steps, description in (
        (
            "easy",
            10,
            (
                "Find the total revenue from the 'sales' table for January 2024. "
                "The table has columns: id, product, amount, sale_date (YYYY-MM-DD). "
                "Return the exact total as a single number by running a SQL query. "
                "The expected result should be a SELECT query that returns one number."
            ),
        ),
        (
            "medium",
            15,
            (
                "The 'users' table has duplicate emails and NULL values in the 'age' column. "
                "Clean the data so that: (1) all emails are lowercase, "
                "(2) duplicate emails are removed (keep the row with the lowest id), "
                "(3) all NULL ages are replaced with 0. "
                "Use SQL or Python to fix the table in-place."
            ),
        ),
        (
            "hard",
            20,
            (
                "The 'flat_orders' table has columns: order_id, order_date, "
                "customer_name, customer_email, product, quantity, price. "
                "Normalize this into two tables: 'customers' (id INTEGER PRIMARY KEY, "
                "name TEXT, email TEXT UNIQUE) and 'orders' (id INTEGER PRIMARY KEY, "
                "customer_id INTEGER REFERENCES customers(id), order_date TEXT, "
                "product TEXT, quantity INTEGER, price REAL). "
                "Maintain foreign key integrity and migrate all data."
            ),
        ),
    )
}
# ---------------------------------------------------------------------------
# Seed data generators
# ---------------------------------------------------------------------------
def _seed_easy(conn: sqlite3.Connection):
    """Rebuild the ``sales`` table and load its fixed set of seven rows.

    Five of the rows are dated January 2024; the remaining two fall outside
    that month (one in February 2024, one in December 2023).  Any existing
    ``sales`` table is dropped first, and the work is committed before
    returning.
    """
    schema_statements = (
        "DROP TABLE IF EXISTS sales",
        "CREATE TABLE sales (id INTEGER PRIMARY KEY, product TEXT, amount REAL, sale_date TEXT)",
    )
    for statement in schema_statements:
        conn.execute(statement)

    sales_fixture = (
        (1, "Widget A", 150.00, "2024-01-05"),
        (2, "Widget B", 250.50, "2024-01-12"),
        (3, "Widget C", 99.99, "2024-01-20"),
        (4, "Widget A", 150.00, "2024-01-28"),
        (5, "Widget D", 349.51, "2024-01-15"),
        (6, "Widget A", 200.00, "2024-02-03"),
        (7, "Widget B", 75.00, "2023-12-30"),
    )
    conn.executemany("INSERT INTO sales VALUES (?,?,?,?)", sales_fixture)
    conn.commit()
def _seed_medium(conn: sqlite3.Connection):
    """Rebuild the ``users`` table and fill it with deliberately messy rows.

    The fixture mixes upper- and lower-case emails, includes addresses that
    collide once lowercased, and leaves several ``age`` values NULL.  Any
    existing ``users`` table is dropped first, and the work is committed
    before returning.
    """
    messy_users = (
        (1, "Alice", "Alice@Example.com", 30),
        (2, "Bob", "bob@example.com", None),
        (3, "Charlie", "charlie@test.com", 25),
        (4, "Alice Dup", "alice@example.com", 28),
        (5, "Dave", "DAVE@Test.COM", None),
        (6, "Eve", "eve@example.com", 35),
        (7, "Dave Dup", "dave@test.com", 40),
        (8, "Frank", "frank@example.com", None),
    )

    conn.execute("DROP TABLE IF EXISTS users")
    create_users = "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT, age INTEGER)"
    conn.execute(create_users)
    conn.executemany("INSERT INTO users VALUES (?,?,?,?)", messy_users)
    conn.commit()
def _seed_hard(conn: sqlite3.Connection):
    """Rebuild the denormalized ``flat_orders`` table with six order rows.

    ``customers`` and ``orders`` are dropped as well, so a previous attempt
    at the migration does not carry over.  The fixture holds six orders
    placed by four distinct customers.  The work is committed before
    returning.
    """
    # Clear the source table and both migration targets.
    for table in ("flat_orders", "customers", "orders"):
        conn.execute(f"DROP TABLE IF EXISTS {table}")

    create_flat_orders = (
        "CREATE TABLE flat_orders ("
        "order_id INTEGER, order_date TEXT, customer_name TEXT, "
        "customer_email TEXT, product TEXT, quantity INTEGER, price REAL)"
    )
    conn.execute(create_flat_orders)

    order_fixture = (
        (1, "2024-01-10", "Alice", "alice@example.com", "Laptop", 1, 999.99),
        (2, "2024-01-11", "Bob", "bob@example.com", "Mouse", 2, 25.50),
        (3, "2024-01-12", "Alice", "alice@example.com", "Keyboard", 1, 75.00),
        (4, "2024-01-13", "Charlie", "charlie@example.com", "Monitor", 1, 300.00),
        (5, "2024-01-14", "Bob", "bob@example.com", "Webcam", 1, 50.00),
        (6, "2024-01-15", "Diana", "diana@example.com", "USB Hub", 3, 15.99),
    )
    conn.executemany("INSERT INTO flat_orders VALUES (?,?,?,?,?,?,?)", order_fixture)
    conn.commit()
# Dispatch table: task id -> function that (re)creates that task's fixture.
SEED_FNS = dict(easy=_seed_easy, medium=_seed_medium, hard=_seed_hard)
# ---------------------------------------------------------------------------
# Graders
# ---------------------------------------------------------------------------
EASY_EXPECTED = 1000.00  # 150 + 250.5 + 99.99 + 150 + 349.51


def grade_easy(conn: sqlite3.Connection, last_output: str) -> float:
    """Score the easy task from the text output of the agent's last action.

    Every numeric token in ``last_output`` is compared against
    ``EASY_EXPECTED``; the grade is 1.0 as soon as one token lies within
    0.01 of it, otherwise 0.0.

    Args:
        conn: Unused; accepted so all graders share one signature.
        last_output: Text produced by the agent's most recent command.

    Returns:
        1.0 when the expected total appears in the output, else 0.0.
    """
    import re

    if not last_output:
        return 0.0

    # The optional sign belongs to the whole number (integer or decimal).
    # Previously the sign was only attached to decimals, so "-1000" was read
    # as "1000" and wrongly earned full credit.
    number_pattern = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)"
    try:
        for token in re.findall(number_pattern, last_output):
            if abs(float(token) - EASY_EXPECTED) < 0.01:
                return 1.0
    except (TypeError, ValueError):
        # Non-text output or an unparsable token simply earns no credit.
        return 0.0
    return 0.0
def grade_medium(conn: sqlite3.Connection, last_output: str) -> float:
    """Score the cleanliness of the ``users`` table.

    An empty or missing table scores 0.0.  Otherwise three checks each add
    their weight when they find no offending rows:

    * 0.3 - every email equals its lowercase form
    * 0.4 - no two rows share the same lowercased email
    * 0.3 - no row has a NULL age

    Checks run in that order inside one ``try``; if a query fails, the score
    accumulated so far is returned.  ``last_output`` is not used.
    """
    # (query counting violations, weight granted when the count is zero)
    weighted_checks = (
        ("SELECT COUNT(*) FROM users WHERE email != LOWER(email)", 0.3),
        (
            "SELECT COUNT(*) FROM (SELECT LOWER(email) as e FROM users GROUP BY e HAVING COUNT(*) > 1)",
            0.4,
        ),
        ("SELECT COUNT(*) FROM users WHERE age IS NULL", 0.3),
    )

    earned = 0.0
    try:
        (row_total,) = conn.execute("SELECT COUNT(*) FROM users").fetchone()
        if row_total == 0:
            return 0.0
        for violation_sql, weight in weighted_checks:
            violations = conn.execute(violation_sql).fetchone()[0]
            if violations == 0:
                earned += weight
    except Exception:
        pass
    return round(earned, 2)
def grade_hard(conn: sqlite3.Connection, last_output: str) -> float:
    """Score the schema migration from ``flat_orders`` to two tables.

    Five checks are each worth 0.2:

    1. ``customers`` has at least the columns id, name, email.
    2. ``orders`` has at least the columns id, customer_id, order_date,
       product, quantity, price.
    3. ``customers`` holds exactly 4 rows.
    4. ``orders`` holds exactly 6 rows.
    5. Every ``orders`` row points at an existing customer.

    Each check is evaluated on its own: a query that fails (for example
    because a table does not exist yet) forfeits only that check instead of
    aborting the ones after it.

    Args:
        conn: Connection to the sandbox database.
        last_output: Unused; accepted so all graders share one signature.

    Returns:
        The accumulated score, rounded to two decimals.
    """

    def column_names(table: str) -> set:
        """Return the column names of ``table`` (empty if it is missing)."""
        try:
            return {info[1] for info in conn.execute(f"PRAGMA table_info({table})")}
        except sqlite3.Error:
            return set()

    def scalar(sql: str):
        """Return the first value of the first row, or None if the query fails."""
        try:
            row = conn.execute(sql).fetchone()
        except sqlite3.Error:
            return None
        return row[0] if row else None

    score = 0.0

    # Schema checks (0.2 each).
    if {"id", "name", "email"} <= column_names("customers"):
        score += 0.2
    if {"id", "customer_id", "order_date", "product", "quantity", "price"} <= column_names("orders"):
        score += 0.2

    # Row-count checks (0.2 each): 4 unique customers, 6 orders.
    if scalar("SELECT COUNT(*) FROM customers") == 4:
        score += 0.2
    if scalar("SELECT COUNT(*) FROM orders") == 6:
        score += 0.2

    # FK integrity (0.2).  NOT EXISTS is used instead of NOT IN because
    # `customer_id NOT IN (...)` evaluates to NULL for a NULL customer_id,
    # which silently let orphaned orders pass the check.
    orphan_count = scalar(
        "SELECT COUNT(*) FROM orders AS o "
        "WHERE NOT EXISTS (SELECT 1 FROM customers AS c WHERE c.id = o.customer_id)"
    )
    if orphan_count == 0:
        score += 0.2

    return round(score, 2)
# Dispatch table: task id -> grader returning a score in [0.0, 1.0].
GRADERS = dict(easy=grade_easy, medium=grade_medium, hard=grade_hard)
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
class SqlSandboxEnvironment(Environment):
    """
    SQL / Data Cleaning Sandbox - a real-world OpenEnv environment.

    The agent sends SQL or Python commands to clean messy databases.
    Partial progress rewards are given after each step.

    Each instance works against its own SQLite file in the system temp
    directory.  The active task comes from the ``TASK_ID`` environment
    variable (default ``"easy"``) and can be switched with
    ``reset(task_id=...)``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        """Set up episode bookkeeping; the database is opened lazily."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._db_path = os.path.join(tempfile.gettempdir(), f"sqlsandbox_{uuid4().hex[:8]}.db")
        self._conn: sqlite3.Connection | None = None
        self._task_id = os.environ.get("TASK_ID", "easy")
        self._task = TASKS[self._task_id]
        self._max_steps = self._task["max_steps"]
        self._done = False
        self._last_reward = 0.0

    # ---- helpers -----------------------------------------------------------
    def _get_conn(self) -> sqlite3.Connection:
        """Return the open connection, creating it on first use.

        Foreign-key enforcement is switched on for every new connection.
        """
        if self._conn is None:
            self._conn = sqlite3.connect(self._db_path)
            self._conn.execute("PRAGMA foreign_keys = ON")
        return self._conn

    def _partial_reward(self, last_output: str) -> float:
        """Run the grader to compute partial progress."""
        return GRADERS[self._task_id](self._get_conn(), last_output)

    def _exec_sql(self, query: str) -> tuple[str, str | None]:
        """Execute one SQL statement and render its result as text.

        Returns:
            ``(output, error)``.  For statements that produce rows, the
            output is a ``" | "``-separated header line followed by one line
            per row (or ``(no rows)``).  For other statements it reports how
            many rows that statement affected.  On failure the output is
            empty and ``error`` carries the exception message.
        """
        try:
            conn = self._get_conn()
            cur = conn.execute(query)
            if cur.description:
                cols = [d[0] for d in cur.description]
                rows = cur.fetchall()
                header = " | ".join(cols)
                body = "\n".join(" | ".join(str(c) for c in r) for r in rows)
                output = f"{header}\n{body}" if rows else header + "\n(no rows)"
            else:
                # cur.rowcount is per statement.  conn.total_changes (used
                # before) is cumulative for the whole connection and so
                # over-reported, e.g. counting the seeded rows.  rowcount is
                # -1 for statements without a row count (DDL), shown as 0.
                affected = max(cur.rowcount, 0)
                output = f"OK {affected} row(s) affected"
            conn.commit()
            return output, None
        except Exception as e:
            return "", str(e)

    def _exec_python(self, code: str) -> tuple[str, str | None]:
        """Run agent-supplied Python with ``conn``/``cursor`` pre-bound.

        The code sees ``sqlite3``, ``DB_PATH``, ``conn`` and ``cursor`` as
        globals.  Captured stdout is the output; anything written to stderr
        is reported as the error.  If the code raises, the stdout collected
        so far is returned together with the formatted traceback.
        """
        stdout_buf, stderr_buf = io.StringIO(), io.StringIO()
        try:
            conn = self._get_conn()
            cursor = conn.cursor()
            globs = {
                "__builtins__": __builtins__,
                "sqlite3": sqlite3,
                "DB_PATH": self._db_path,
                "conn": conn,
                "cursor": cursor,
            }
            # SECURITY NOTE(review): exec() runs arbitrary code with full
            # builtins; this is only acceptable if the process itself is
            # sandboxed.  redirect_stdout/redirect_stderr swap sys.stdout and
            # sys.stderr process-wide, so output from concurrently running
            # sessions in the same process can be mixed up.
            with redirect_stdout(stdout_buf), redirect_stderr(stderr_buf):
                exec(code, globs)
            # Automatically commit any schema changes the LLM's python code made
            conn.commit()
            out = stdout_buf.getvalue()
            err = stderr_buf.getvalue() or None
            return out, err
        except Exception:
            return stdout_buf.getvalue(), traceback.format_exc()

    # ---- OpenEnv interface -------------------------------------------------
    def reset(self, **kwargs) -> SqlSandboxObservation:
        """Start a new episode, optionally switching task.

        Args:
            **kwargs: ``task_id`` selects the task.  When it is missing or
                ``None``, the ``TASK_ID`` environment variable is used,
                falling back to ``"easy"``.

        Returns:
            The initial observation carrying the task description.

        Raises:
            KeyError: If the resolved task id is not defined in ``TASKS``.
                Nothing is modified in that case.
        """
        # 1. Resolve and validate the task before touching any state, so a
        #    bad task id cannot leave the environment half-reset.
        task_id = kwargs.get("task_id") or os.environ.get("TASK_ID", "easy")
        if task_id not in TASKS:
            raise KeyError(f"Unknown task_id {task_id!r}; expected one of {sorted(TASKS)}")

        # 2. Close current connection to ensure file handles are released
        if self._conn:
            self._conn.close()
            self._conn = None

        # 3. Start from an empty database.  The seed functions only drop
        #    their own tables, so without this, tables from a previous task
        #    (or ones the agent created) would leak into the new episode.
        try:
            os.remove(self._db_path)
        except OSError:
            # Best effort: a missing file is fine, and if the file cannot be
            # removed the seed function's DROP TABLE statements still reset
            # the task's own tables.
            pass

        # 4. Update task context
        self._task_id = task_id
        self._task = TASKS[task_id]
        self._max_steps = self._task["max_steps"]

        # 5. Re-initialize episode state
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._done = False
        self._last_reward = 0.0

        # 6. Open fresh connection and seed for the specific task_id
        conn = self._get_conn()
        SEED_FNS[self._task_id](conn)

        return SqlSandboxObservation(
            output=f"Environment ready. Task: {self._task['description']}",
            error=None,
            current_step=0,
            max_steps=self._max_steps,
            task_description=self._task["description"],
            done=False,
            reward=0.0,
        )

    def step(self, action: SqlSandboxAction) -> SqlSandboxObservation:  # type: ignore[override]
        """Execute one agent action and grade the resulting state.

        ``action.tool == "sql"`` runs ``action.command`` as SQL; any other
        tool value runs it as Python.  The episode ends when the step budget
        is used up or the grader returns at least 1.0.  Once finished,
        further calls return a fixed "already finished" observation with the
        last stored reward.
        """
        self._state.step_count += 1
        step = self._state.step_count

        if self._done:
            return SqlSandboxObservation(
                output="Episode already finished. Call reset().",
                error=None,
                current_step=step,
                max_steps=self._max_steps,
                task_description=self._task["description"],
                done=True,
                reward=self._last_reward,
            )

        # Execute action
        if action.tool == "sql":
            output, error = self._exec_sql(action.command)
        else:
            output, error = self._exec_python(action.command)

        # Compute partial reward
        reward = self._partial_reward(output)

        # Check termination (the stored last reward is the unpenalised one)
        done = step >= self._max_steps or reward >= 1.0
        if done:
            self._done = True
            self._last_reward = reward

        # Small penalty for errors to discourage random guessing
        if error:
            reward = max(0.0, reward - 0.05)

        return SqlSandboxObservation(
            output=output[:4000],  # cap output size
            error=error[:2000] if error else None,
            current_step=step,
            max_steps=self._max_steps,
            task_description=self._task["description"],
            done=done,
            reward=round(reward, 4),
        )

    @property
    def state(self) -> State:
        """Return the current episode state (episode id and step count)."""
        return self._state