"""Seed an isolated episode workspace: a SQLite database plus broken scripts.

Each call to :func:`setup_workspace` wipes the target directory and rebuilds
it from a ``TaskScenarioBundle``.
"""

import logging
import os
import shutil
import sqlite3

from server.task_specs import TaskScenarioBundle, build_task_scenario

logger = logging.getLogger(__name__)

# Default workspace: a "workspace" directory placed next to the directory that
# contains this module.
WORKSPACE_ROOT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "workspace")
WORKSPACE_DIR = WORKSPACE_ROOT

# Row inserted into `transactions` when the scenario carries no task_1 data.
_PLACEHOLDER_TRANSACTIONS = [(1, 9000, 100.0, "success")]


def setup_workspace(
    workspace_dir: str | None = None, *, scenario: TaskScenarioBundle | None = None
) -> str:
    """Initialise an isolated episode workspace from the seeded scenario.

    WARNING: every existing entry inside the target directory is deleted
    before the scenario is written.

    Args:
        workspace_dir: Directory to (re)build. Falls back to ``WORKSPACE_DIR``
            when ``None`` or an empty string.
        scenario: Scenario bundle to seed from. When ``None``, defaults to
            ``build_task_scenario("task_1_easy_anomaly", seed=0)``.

    Returns:
        Path of the freshly created ``mock_warehouse.db`` SQLite file.
    """
    target_workspace = workspace_dir or WORKSPACE_DIR
    target_db_path = os.path.join(target_workspace, "mock_warehouse.db")
    if scenario is None:
        resolved_scenario = build_task_scenario("task_1_easy_anomaly", seed=0)
    else:
        resolved_scenario = scenario

    os.makedirs(target_workspace, exist_ok=True)
    _clear_workspace(target_workspace)
    _init_database(target_db_path, resolved_scenario)
    _write_seeded_files(target_workspace, resolved_scenario)

    logger.info(
        "Workspace reset complete: task=%s seed=%s db=%s",
        resolved_scenario.task_id,
        resolved_scenario.seed,
        target_db_path,
    )
    return target_db_path


def _clear_workspace(workspace_dir: str) -> None:
    """Delete every entry directly inside *workspace_dir*, keeping the directory itself."""
    for entry in os.listdir(workspace_dir):
        path = os.path.join(workspace_dir, entry)
        try:
            # os.path.isdir follows symlinks, but shutil.rmtree raises OSError
            # on a symlink; unlink such links instead of recursing into them.
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
        except FileNotFoundError:
            # Entry disappeared between listdir() and removal; nothing to do.
            continue


def _init_database(db_path: str, scenario: TaskScenarioBundle) -> None:
    """Create the ``transactions`` and ``daily_reports`` tables and load scenario rows.

    ``transactions`` is filled from ``scenario.task_1.all_rows`` (or a single
    placeholder row when task_1 is absent); ``daily_reports`` is filled from
    ``scenario.task_3.all_rows`` only when task_3 is present.
    """
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute(
            """
            CREATE TABLE transactions (
                id INTEGER PRIMARY KEY,
                user_id INTEGER NOT NULL,
                amount REAL,
                status TEXT NOT NULL
            )
            """
        )
        cur.execute(
            """
            CREATE TABLE daily_reports (
                id INTEGER PRIMARY KEY,
                report_date TEXT NOT NULL,
                department TEXT NOT NULL,
                revenue REAL NOT NULL,
                expenses REAL NOT NULL,
                headcount INTEGER NOT NULL
            )
            """
        )

        if scenario.task_1:
            transaction_rows = [
                (row["id"], row["user_id"], row["amount"], row["status"])
                for row in scenario.task_1.all_rows
            ]
        else:
            transaction_rows = _PLACEHOLDER_TRANSACTIONS
        cur.executemany("INSERT INTO transactions VALUES (?, ?, ?, ?)", transaction_rows)

        if scenario.task_3:
            cur.executemany(
                "INSERT INTO daily_reports VALUES (?, ?, ?, ?, ?, ?)",
                [
                    (
                        row["id"],
                        row["report_date"],
                        row["department"],
                        row["revenue"],
                        row["expenses"],
                        row["headcount"],
                    )
                    for row in scenario.task_3.all_rows
                ],
            )

        conn.commit()
    finally:
        conn.close()


def _write_seeded_files(workspace_dir: str, scenario: TaskScenarioBundle) -> None:
    """Write the broken scripts for task_2 and task_3, when present, into *workspace_dir*."""
    if scenario.task_2:
        _write_text(
            os.path.join(workspace_dir, "broken_pipeline.py"),
            scenario.task_2.broken_script,
        )
    if scenario.task_3:
        _write_text(
            os.path.join(workspace_dir, "format_report.py"),
            scenario.task_3.broken_script,
        )


def _write_text(path: str, content: str) -> None:
    """Write *content* to *path* as UTF-8, replacing any existing file."""
    with open(path, "w", encoding="utf-8") as f:
        f.write(content)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    setup_workspace()