Spaces:
Sleeping
Sleeping
File size: 3,871 Bytes
f89b1ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 | import logging
import os
import shutil
import sqlite3
from server.task_specs import TaskScenarioBundle, build_task_scenario
logger = logging.getLogger(__name__)

# Default episode workspace: a "workspace" directory placed next to the
# directory that contains this module (two dirname() hops up from __file__).
WORKSPACE_ROOT = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "workspace",
)
# Alias consulted by setup_workspace() when no explicit directory is given.
WORKSPACE_DIR = WORKSPACE_ROOT
def setup_workspace(
    workspace_dir: str | None = None, *, scenario: TaskScenarioBundle | None = None
) -> str:
    """Reset a workspace directory and seed it for a new episode.

    The directory is created if missing and emptied of every existing entry.
    It is then populated with a fresh ``mock_warehouse.db`` SQLite database
    plus whatever per-task script files the scenario provides.

    Args:
        workspace_dir: Directory to (re)initialise. Falls back to
            ``WORKSPACE_DIR`` when ``None`` or an empty string.
        scenario: Pre-built scenario to seed from. When omitted (or falsy),
            ``build_task_scenario("task_1_easy_anomaly", seed=0)`` is used.

    Returns:
        Path of the freshly created SQLite database file.
    """
    root = workspace_dir or WORKSPACE_DIR
    db_path = os.path.join(root, "mock_warehouse.db")
    bundle = scenario or build_task_scenario("task_1_easy_anomaly", seed=0)

    # Create the directory first so the clearing step always has something to list.
    os.makedirs(root, exist_ok=True)
    _clear_workspace(root)
    _init_database(db_path, bundle)
    _write_seeded_files(root, bundle)

    logger.info(
        "Workspace reset complete: task=%s seed=%s db=%s",
        bundle.task_id,
        bundle.seed,
        db_path,
    )
    return db_path
def _clear_workspace(workspace_dir: str) -> None:
    """Delete every entry inside *workspace_dir* while keeping the directory.

    Real subdirectories are removed recursively. Regular files and symbolic
    links are unlinked, including links that point at directories. As a
    result, a link's target outside the workspace is never followed or
    deleted. Entries that disappear between listing and removal are skipped.
    Any other ``OSError`` propagates to the caller.
    """
    for entry in os.listdir(workspace_dir):
        path = os.path.join(workspace_dir, entry)
        try:
            # os.path.isdir() follows symlinks, and shutil.rmtree() refuses to
            # operate on a symlink (it raises OSError). Treat a link as a plain
            # entry and unlink it instead of recursing into its target.
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
        except FileNotFoundError:
            # Already gone (e.g. removed concurrently); nothing left to do.
            continue
def _init_database(db_path: str, scenario: TaskScenarioBundle) -> None:
    """Create the warehouse tables in the SQLite file at *db_path* and load rows.

    Tables:
      * ``transactions``: filled from ``scenario.task_1.all_rows`` when
        ``scenario.task_1`` is truthy, otherwise given a single placeholder
        row ``(1, 9000, 100.0, "success")``.
      * ``daily_reports``: filled from ``scenario.task_3.all_rows`` when
        ``scenario.task_3`` is truthy, otherwise left empty.

    Each source row is indexed by column name (``row["id"]`` etc.), so rows are
    expected to be mappings keyed by the column names below. All inserts are
    committed together at the end. The connection is closed even when a
    statement fails, in which case the pending inserts are not committed.
    """
    transaction_columns = ("id", "user_id", "amount", "status")
    report_columns = (
        "id",
        "report_date",
        "department",
        "revenue",
        "expenses",
        "headcount",
    )

    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        cursor.execute(
            """
            CREATE TABLE transactions (
                id INTEGER PRIMARY KEY,
                user_id INTEGER NOT NULL,
                amount REAL,
                status TEXT NOT NULL
            )
            """
        )
        cursor.execute(
            """
            CREATE TABLE daily_reports (
                id INTEGER PRIMARY KEY,
                report_date TEXT NOT NULL,
                department TEXT NOT NULL,
                revenue REAL NOT NULL,
                expenses REAL NOT NULL,
                headcount INTEGER NOT NULL
            )
            """
        )

        if scenario.task_1:
            transaction_rows = [
                tuple(row[column] for column in transaction_columns)
                for row in scenario.task_1.all_rows
            ]
        else:
            # No task 1 data: seed a single placeholder transaction.
            transaction_rows = [(1, 9000, 100.0, "success")]
        cursor.executemany(
            "INSERT INTO transactions VALUES (?, ?, ?, ?)", transaction_rows
        )

        if scenario.task_3:
            report_rows = [
                tuple(row[column] for column in report_columns)
                for row in scenario.task_3.all_rows
            ]
            cursor.executemany(
                "INSERT INTO daily_reports VALUES (?, ?, ?, ?, ?, ?)", report_rows
            )

        connection.commit()
    finally:
        connection.close()
def _write_seeded_files(workspace_dir: str, scenario: TaskScenarioBundle) -> None:
    """Write the scenario's starter scripts into *workspace_dir*.

    ``broken_pipeline.py`` receives ``scenario.task_2.broken_script`` and
    ``format_report.py`` receives ``scenario.task_3.broken_script``. Each file
    is written as UTF-8 text, replacing any existing file of that name. A file
    is written only when its matching task component is truthy.
    """
    seeded_scripts = (
        ("broken_pipeline.py", scenario.task_2),
        ("format_report.py", scenario.task_3),
    )
    for filename, task in seeded_scripts:
        if not task:
            continue
        target = os.path.join(workspace_dir, filename)
        with open(target, "w", encoding="utf-8") as handle:
            handle.write(task.broken_script)
if __name__ == "__main__":
    # Script entry point: route INFO-level log records to the default
    # basicConfig handler, then rebuild the default workspace from the
    # default scenario.
    logging.basicConfig(level=logging.INFO)
    setup_workspace(workspace_dir=None)
|