Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| import shutil | |
| import sqlite3 | |
| from server.task_specs import TaskScenarioBundle, build_task_scenario | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# The "workspace" directory that sits one level above the directory
# containing this file.
WORKSPACE_ROOT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "workspace")
# Alias of WORKSPACE_ROOT; setup_workspace() falls back to it when the caller
# passes no workspace_dir.
WORKSPACE_DIR = WORKSPACE_ROOT
def setup_workspace(
    workspace_dir: str | None = None, *, scenario: TaskScenarioBundle | None = None
) -> str:
    """Reset an episode workspace and populate it from a seeded scenario.

    The workspace directory is created when missing, emptied, and then
    refilled with a fresh ``mock_warehouse.db`` plus whatever script files the
    scenario provides.

    Args:
        workspace_dir: Directory to reset. A falsy value (``None`` or ``""``)
            selects the module-level ``WORKSPACE_DIR``.
        scenario: Scenario bundle to materialise. A falsy value selects
            ``build_task_scenario("task_1_easy_anomaly", seed=0)``.

    Returns:
        Path of the SQLite database created inside the workspace.
    """
    root = workspace_dir if workspace_dir else WORKSPACE_DIR

    # Resolve the scenario before touching the filesystem, so a failure while
    # building it leaves any existing workspace untouched.
    bundle = scenario
    if not bundle:
        bundle = build_task_scenario("task_1_easy_anomaly", seed=0)

    db_path = os.path.join(root, "mock_warehouse.db")

    os.makedirs(root, exist_ok=True)
    _clear_workspace(root)
    _init_database(db_path, bundle)
    _write_seeded_files(root, bundle)

    logger.info(
        "Workspace reset complete: task=%s seed=%s db=%s",
        bundle.task_id,
        bundle.seed,
        db_path,
    )
    return db_path
def _clear_workspace(workspace_dir: str) -> None:
    """Delete every entry directly inside *workspace_dir*, keeping the directory.

    Real sub-directories are removed recursively. Regular files and symbolic
    links (including links that point at directories) are unlinked, leaving
    the link targets untouched.

    Args:
        workspace_dir: Existing directory whose contents should be removed.

    Raises:
        OSError: If *workspace_dir* cannot be listed, or an entry cannot be
            removed for a reason other than having already disappeared.
    """
    # Snapshot the listing first so the directory is not mutated while the
    # scandir iterator is still open.
    with os.scandir(workspace_dir) as listing:
        entries = list(listing)

    for entry in entries:
        try:
            # follow_symlinks=False matters: shutil.rmtree refuses to run on a
            # symlink, so a link to a directory must be unlinked like a file
            # rather than treated as a directory.
            if entry.is_dir(follow_symlinks=False):
                shutil.rmtree(entry.path)
            else:
                os.remove(entry.path)
        except FileNotFoundError:
            # The entry vanished between listing and removal; nothing to do.
            continue
def _init_database(db_path: str, scenario: TaskScenarioBundle) -> None:
    """Create the warehouse schema at *db_path* and load the scenario rows.

    Two tables are created: ``transactions`` and ``daily_reports``.
    ``transactions`` receives the rows of ``scenario.task_1`` when that task
    is present, otherwise a single placeholder row. ``daily_reports`` is
    filled only when ``scenario.task_3`` is present. Everything is committed
    in one go and the connection is always closed.

    Args:
        db_path: Location of the SQLite database file to create.
        scenario: Scenario bundle supplying the seed rows.
    """
    transaction_columns = ("id", "user_id", "amount", "status")
    report_columns = (
        "id",
        "report_date",
        "department",
        "revenue",
        "expenses",
        "headcount",
    )
    insert_transaction = "INSERT INTO transactions VALUES (?, ?, ?, ?)"
    insert_report = "INSERT INTO daily_reports VALUES (?, ?, ?, ?, ?, ?)"

    def as_params(record, columns):
        # Pull the named fields out of a row mapping, in column order.
        return tuple(record[name] for name in columns)

    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        cursor.execute(
            """
            CREATE TABLE transactions (
                id INTEGER PRIMARY KEY,
                user_id INTEGER NOT NULL,
                amount REAL,
                status TEXT NOT NULL
            )
            """
        )
        cursor.execute(
            """
            CREATE TABLE daily_reports (
                id INTEGER PRIMARY KEY,
                report_date TEXT NOT NULL,
                department TEXT NOT NULL,
                revenue REAL NOT NULL,
                expenses REAL NOT NULL,
                headcount INTEGER NOT NULL
            )
            """
        )

        # The transactions table is never left empty: without task 1 data a
        # single placeholder row is inserted instead.
        if scenario.task_1:
            transaction_rows = [
                as_params(record, transaction_columns)
                for record in scenario.task_1.all_rows
            ]
        else:
            transaction_rows = [(1, 9000, 100.0, "success")]
        cursor.executemany(insert_transaction, transaction_rows)

        if scenario.task_3:
            report_rows = [
                as_params(record, report_columns)
                for record in scenario.task_3.all_rows
            ]
            cursor.executemany(insert_report, report_rows)

        connection.commit()
    finally:
        connection.close()
def _write_seeded_files(workspace_dir: str, scenario: TaskScenarioBundle) -> None:
    """Write the scenario's seeded script files into *workspace_dir*.

    ``broken_pipeline.py`` is written from ``scenario.task_2`` and
    ``format_report.py`` from ``scenario.task_3``; a file is skipped when its
    task is absent (falsy). Files are written as UTF-8 text.

    Args:
        workspace_dir: Directory that receives the files.
        scenario: Scenario bundle supplying each task's ``broken_script``.
    """

    def dump_script(filename: str, task) -> None:
        # Nothing to write when the scenario does not include this task.
        if not task:
            return
        destination = os.path.join(workspace_dir, filename)
        with open(destination, "w", encoding="utf-8") as handle:
            handle.write(task.broken_script)

    dump_script("broken_pipeline.py", scenario.task_2)
    dump_script("format_report.py", scenario.task_3)
# Script entry point: enable INFO-level logging, then reset the default
# workspace using the default scenario.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    setup_workspace()