dataclerk-openenv / tests /test_env.py
vedhanth66's picture
Initial commit: DataClerk OpenEnv
2a352e6
"""
DataClerk OpenEnv — Test suite (stdlib-only, no pydantic mock needed since pydantic
is installed in the Docker/Space environment; this file works stand-alone too).
Run with: python tests/test_env.py
Or: python -m pytest tests/ -v
"""
from __future__ import annotations
import os
import sys
import unittest
# ── Path ──────────────────────────────────────────────────────────────────────
# Repository root is the parent of the directory that holds this file; putting
# it first on sys.path lets `import app...` resolve when run as a plain script.
_ROOT = os.path.split(os.path.split(os.path.abspath(__file__))[0])[0]
sys.path.insert(0, _ROOT)

# Dedicated database file for this suite, exported through DB_PATH before any
# `app` module is imported (presumably read at import time — verify in app/).
_TEST_DB = "/tmp/dataclerk_test_suite.db"
os.environ["DB_PATH"] = _TEST_DB
def _ensure_db() -> None:
    """Seed the suite's database file by calling ``seed_database(_TEST_DB)``."""
    # Imported inside the function so the import happens only once the
    # module-level path/DB_PATH setup above has already run.
    import app.database as database

    database.seed_database(_TEST_DB)
# Seed the test database first; the imports and overrides below then pin both
# app modules to that same file. Keep this statement order intact.
_ensure_db()
import app.database as _db_mod
import app.tasks as _tasks_mod
# Explicitly rebind each module's DB_PATH to the test database
# (NOTE(review): presumably a safeguard in case the modules captured another
# path at import time — confirm against app/database.py and app/tasks.py).
_db_mod.DB_PATH = _TEST_DB
_tasks_mod.DB_PATH = _TEST_DB
_tasks_mod._CACHE = None # reset so graders read the test DB
# ─────────────────────────────────────────────
# Database tests
# ─────────────────────────────────────────────
class TestDatabase(unittest.TestCase):
def test_row_counts(self):
import sqlite3
conn = sqlite3.connect(_TEST_DB)
self.assertEqual(conn.execute("SELECT COUNT(*) FROM customers").fetchone()[0], 200)
self.assertEqual(conn.execute("SELECT COUNT(*) FROM products").fetchone()[0], 37)
self.assertEqual(conn.execute("SELECT COUNT(*) FROM orders").fetchone()[0], 1800)
self.assertEqual(conn.execute("SELECT COUNT(*) FROM support_tickets").fetchone()[0], 600)
conn.close()
def test_deterministic_seeding(self):
import sqlite3
path2 = _TEST_DB + ".dup"
if os.path.exists(path2):
os.remove(path2)
from app.database import seed_database
seed_database(path2)
c1 = sqlite3.connect(_TEST_DB)
c2 = sqlite3.connect(path2)
r1 = c1.execute("SELECT total_amount FROM orders ORDER BY id LIMIT 10").fetchall()
r2 = c2.execute("SELECT total_amount FROM orders ORDER BY id LIMIT 10").fetchall()
self.assertEqual(r1, r2)
c1.close()
c2.close()
os.remove(path2)
def test_schema_summary(self):
schema = _db_mod.get_schema_summary(_TEST_DB)
self.assertIn("customers", schema)
self.assertIn("orders", schema)
self.assertTrue(any("email" in c for c in schema["customers"]))
def test_idempotent_seed(self):
from app.database import seed_database
seed_database(_TEST_DB) # second call β€” must not duplicate
import sqlite3
conn = sqlite3.connect(_TEST_DB)
self.assertEqual(conn.execute("SELECT COUNT(*) FROM customers").fetchone()[0], 200)
conn.close()
# ─────────────────────────────────────────────
# Grader tests
# ─────────────────────────────────────────────
class TestGraders(unittest.TestCase):
def test_task1_perfect_answer(self):
answer = (
"1. Electronics: $330812.35\n"
"2. Clothing: $84556.61\n"
"3. Home & Garden: $69583.10"
)
queries = [
"SELECT p.category, SUM(oi.quantity*oi.unit_price) FROM orders o "
"JOIN order_items oi ON oi.order_id=o.id "
"JOIN products p ON p.id=oi.product_id "
"WHERE o.status='completed' GROUP BY p.category ORDER BY 2 DESC LIMIT 3"
]
score, _ = _tasks_mod._grade_task1(answer, queries)
self.assertGreaterEqual(score, 0.70)
def test_task1_empty_returns_zero(self):
score, _ = _tasks_mod._grade_task1("", [])
self.assertEqual(score, 0.0)
def test_task1_partial_credit(self):
answer = "Electronics was top with around 330000 in revenue"
score, _ = _tasks_mod._grade_task1(answer, [])
self.assertGreater(score, 0.05)
self.assertLess(score, 0.60)
def test_task2_correct_answer(self):
answer = "26 at-risk customers, average lifetime value $3023.28"
score, _ = _tasks_mod._grade_task2(answer, ["WITH x AS (SELECT 1 AS n)"])
self.assertGreaterEqual(score, 0.55)
def test_task2_empty_returns_zero(self):
score, _ = _tasks_mod._grade_task2("", [])
self.assertEqual(score, 0.0)
def test_task3_complete_report(self):
answer = (
"low priority is slowest at 14.25 days. "
"urgent is fastest at 1.96 days. "
"Sports has the highest refund rate at 12.21%. "
"standard: 111, premium: 33, enterprise: 15. Total: 159."
)
queries = [
"SELECT priority, AVG(julianday(resolved_at)-julianday(created_at)) "
"FROM support_tickets WHERE status IN ('resolved','closed') GROUP BY priority",
"SELECT p.category, 100.0*SUM(CASE WHEN o.status='refunded' THEN 1 ELSE 0 END)/COUNT(*) "
"FROM orders o JOIN order_items oi ON oi.order_id=o.id "
"JOIN products p ON p.id=oi.product_id GROUP BY p.category ORDER BY 2 DESC",
"SELECT c.tier, COUNT(*) FROM customers c "
"WHERE c.id IN (SELECT customer_id FROM orders WHERE status='completed' "
"GROUP BY customer_id HAVING COUNT(*)>=3) "
"AND c.id IN (SELECT customer_id FROM support_tickets "
"GROUP BY customer_id HAVING COUNT(*)>=2) GROUP BY c.tier",
"SELECT * FROM customers LIMIT 3",
"SELECT * FROM support_tickets LIMIT 3",
]
score, _ = _tasks_mod._grade_task3(answer, queries)
self.assertGreaterEqual(score, 0.60)
def test_scores_always_in_bounds(self):
for ans in ["", "Electronics 999999", "a" * 5000]:
for fn in [
_tasks_mod._grade_task1,
_tasks_mod._grade_task2,
_tasks_mod._grade_task3,
]:
s, _ = fn(ans, [])
self.assertGreaterEqual(s, 0.0)
self.assertLessEqual(s, 1.0)
def test_three_tasks_registered(self):
self.assertIn("revenue_analysis", _tasks_mod.TASKS)
self.assertIn("customer_risk_analysis", _tasks_mod.TASKS)
self.assertIn("business_health_report", _tasks_mod.TASKS)
def test_difficulty_progression(self):
self.assertEqual(_tasks_mod.TASKS["revenue_analysis"]["difficulty"], "easy")
self.assertEqual(_tasks_mod.TASKS["customer_risk_analysis"]["difficulty"], "medium")
self.assertEqual(_tasks_mod.TASKS["business_health_report"]["difficulty"], "hard")
# ─────────────────────────────────────────────
# Environment tests
# ─────────────────────────────────────────────
class TestEnvironment(unittest.TestCase):
@classmethod
def setUpClass(cls):
import app.environment as env_mod
env_mod.DB_PATH = _TEST_DB
env_mod.get_schema_summary = lambda db=_TEST_DB: _db_mod.get_schema_summary(_TEST_DB)
cls.env_mod = env_mod
from app.models import SQLAction
cls.Action = SQLAction
def _env(self, task="revenue_analysis"):
env = self.env_mod.DataClerkEnvironment(task)
env.reset()
return env
def _act(self, **kw):
return self.Action(**kw)
# ── reset ──────────────────────────────────────────────────────────────────
def test_reset_clean_state(self):
env = self._env()
self.assertEqual(env.step, 0)
self.assertEqual(env.query_count, 0)
self.assertFalse(env.done)
self.assertEqual(env.total_reward, 0.0)
def test_reset_returns_observation(self):
env = self.env_mod.DataClerkEnvironment("revenue_analysis")
obs, info = env.reset()
self.assertEqual(obs.task_id, "revenue_analysis")
self.assertIn("customers", obs.available_tables)
self.assertIn("orders", obs.schema_summary)
# ── action types ───────────────────────────────────────────────────────────
def test_list_tables(self):
env = self._env()
_, r, done, info = env.step_env(self._act(action_type="list_tables"))
self.assertFalse(done)
self.assertGreater(r, 0)
self.assertIn("orders", info["tables"])
def test_describe_table(self):
env = self._env()
_, r, done, info = env.step_env(
self._act(action_type="describe_table", table_name="orders")
)
self.assertFalse(done)
self.assertGreater(r, 0)
self.assertIn("columns", info)
def test_describe_unknown_table_penalised(self):
env = self._env()
_, r, _, _ = env.step_env(
self._act(action_type="describe_table", table_name="nonexistent")
)
self.assertLess(r, 0)
def test_execute_sql_count(self):
env = self._env()
obs, r, done, _ = env.step_env(
self._act(action_type="execute_sql", sql_query="SELECT COUNT(*) FROM orders")
)
self.assertFalse(done)
self.assertGreater(r, 0)
self.assertIsNotNone(obs.last_query_result)
self.assertEqual(obs.last_query_result.rows[0][0], 1800)
def test_execute_cte_allowed(self):
env = self._env()
cte = ("WITH stats AS (SELECT customer_id, COUNT(*) AS n FROM orders GROUP BY customer_id) "
"SELECT AVG(n) FROM stats")
obs, r, _, _ = env.step_env(self._act(action_type="execute_sql", sql_query=cte))
self.assertGreaterEqual(r, 0)
self.assertIsNotNone(obs.last_query_result)
def test_sql_error_penalised(self):
env = self._env()
_, r, _, _ = env.step_env(
self._act(action_type="execute_sql", sql_query="SELECT * FROM no_table")
)
self.assertLess(r, 0)
def test_dangerous_sql_blocked(self):
env = self._env()
_, r, _, _ = env.step_env(
self._act(action_type="execute_sql", sql_query="DROP TABLE customers")
)
self.assertLess(r, 0)
import sqlite3
conn = sqlite3.connect(_TEST_DB)
count = conn.execute("SELECT COUNT(*) FROM customers").fetchone()[0]
conn.close()
self.assertEqual(count, 200, "DROP was executed β€” data was modified!")
def test_duplicate_query_penalised(self):
env = self._env()
sql = "SELECT id FROM customers LIMIT 1"
_, r1, _, _ = env.step_env(self._act(action_type="execute_sql", sql_query=sql))
_, r2, _, _ = env.step_env(self._act(action_type="execute_sql", sql_query=sql))
self.assertGreater(r1, 0)
self.assertLess(r2, 0)
def test_submit_ends_episode(self):
env = self._env()
_, r, done, info = env.step_env(
self._act(action_type="submit_answer",
answer="Electronics 330812, Clothing 84556, Home & Garden 69583")
)
self.assertTrue(done)
self.assertIn("final_score", info)
self.assertGreaterEqual(info["final_score"], 0.0)
self.assertLessEqual(info["final_score"], 1.0)
def test_submit_good_answer_scores_high(self):
env = self._env()
# Run a good query first
env.step_env(self._act(
action_type="execute_sql",
sql_query=(
"SELECT p.category, ROUND(SUM(oi.quantity*oi.unit_price),2) "
"FROM orders o JOIN order_items oi ON oi.order_id=o.id "
"JOIN products p ON p.id=oi.product_id "
"WHERE o.status='completed' AND o.created_at>=date('2025-06-15','-180 days') "
"GROUP BY p.category ORDER BY 2 DESC LIMIT 3"
)
))
_, r, _, info = env.step_env(self._act(
action_type="submit_answer",
answer="1. Electronics $330812.35 2. Clothing $84556.61 3. Home & Garden $69583.10"
))
self.assertGreaterEqual(info["final_score"], 0.50)
# ── episode lifecycle ──────────────────────────────────────────────────────
def test_timeout_at_max_steps(self):
env = self.env_mod.DataClerkEnvironment("revenue_analysis") # max_steps=8
env.reset()
done = False
for i in range(10):
_, _, done, _ = env.step_env(
self._act(action_type="execute_sql", sql_query=f"SELECT {i+1}")
)
if done:
self.assertLessEqual(i + 1, 8)
break
self.assertTrue(done)
def test_step_after_done_is_noop(self):
env = self._env()
env.step_env(self._act(action_type="submit_answer", answer="done"))
_, r2, done2, _ = env.step_env(self._act(action_type="list_tables"))
self.assertTrue(done2)
self.assertEqual(r2, 0.0)
def test_state_tracks_progress(self):
env = self._env()
env.step_env(self._act(action_type="list_tables"))
env.step_env(self._act(action_type="execute_sql", sql_query="SELECT * FROM products LIMIT 5"))
state = env.state()
self.assertEqual(state.step, 2)
self.assertEqual(state.query_count, 1)
self.assertFalse(state.done)
def test_all_tasks_work(self):
for task_id in ["revenue_analysis", "customer_risk_analysis", "business_health_report"]:
env = self.env_mod.DataClerkEnvironment(task_id)
obs, _ = env.reset()
self.assertEqual(obs.task_id, task_id)
self.assertGreater(obs.max_steps, 0)
def test_unknown_action_penalised(self):
env = self._env()
_, r, _, _ = env.step_env(self._act(action_type="fly_to_moon"))
self.assertLess(r, 0)
if __name__ == "__main__":
    # Script entry point (`python tests/test_env.py`): run every test in this
    # module with verbosity level 2.
    unittest.main(verbosity=2)