File size: 3,871 Bytes
f89b1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import logging
import os
import shutil
import sqlite3

from server.task_specs import TaskScenarioBundle, build_task_scenario

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# The "workspace" directory inside the parent of the directory that holds this file.
WORKSPACE_ROOT = os.path.join(os.path.dirname(os.path.dirname(__file__)), "workspace")
# Default workspace used by setup_workspace() when no directory is passed;
# currently the same path as WORKSPACE_ROOT.
WORKSPACE_DIR = WORKSPACE_ROOT


def setup_workspace(
    workspace_dir: str | None = None, *, scenario: TaskScenarioBundle | None = None
) -> str:
    """Reset an episode workspace and populate it from a task scenario.

    The target directory is created if needed, emptied, and then seeded with
    a fresh SQLite database plus any scenario-provided script files.

    Args:
        workspace_dir: Directory to reset. A falsy value selects
            ``WORKSPACE_DIR``.
        scenario: Scenario bundle to seed from. A falsy value selects the
            bundle built for ``"task_1_easy_anomaly"`` with ``seed=0``.

    Returns:
        Path of the ``mock_warehouse.db`` file created inside the workspace.
    """
    root = workspace_dir if workspace_dir else WORKSPACE_DIR
    db_file = os.path.join(root, "mock_warehouse.db")
    bundle = (
        scenario if scenario else build_task_scenario("task_1_easy_anomaly", seed=0)
    )

    os.makedirs(root, exist_ok=True)

    # Remove any existing contents before seeding the new episode.
    _clear_workspace(root)
    _init_database(db_file, bundle)
    _write_seeded_files(root, bundle)

    logger.info(
        "Workspace reset complete: task=%s seed=%s db=%s",
        bundle.task_id,
        bundle.seed,
        db_file,
    )
    return db_file


def _clear_workspace(workspace_dir: str) -> None:
    """Delete every entry directly inside ``workspace_dir``.

    The directory itself is kept. Real sub-directories are removed
    recursively; regular files and symbolic links are unlinked. A symbolic
    link is never followed, so whatever it points to is left untouched.

    Args:
        workspace_dir: Existing directory whose contents should be removed.

    Raises:
        OSError: If the directory cannot be listed or an entry cannot be
            removed for a reason other than it already being gone.
    """
    for entry in os.listdir(workspace_dir):
        path = os.path.join(workspace_dir, entry)
        try:
            # os.path.isdir() follows symlinks, but shutil.rmtree() refuses to
            # operate on a symlink and raises OSError. Treat a link to a
            # directory like a file and just unlink it.
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
        except FileNotFoundError:
            # The entry vanished between listing and removal; nothing to do.
            continue


def _init_database(db_path: str, scenario: TaskScenarioBundle) -> None:
    """Create the warehouse schema at ``db_path`` and load scenario rows.

    Two tables are created: ``transactions`` and ``daily_reports``.
    ``transactions`` is filled from ``scenario.task_1.all_rows`` when
    ``scenario.task_1`` is truthy, otherwise with a single placeholder row.
    ``daily_reports`` is filled from ``scenario.task_3.all_rows`` only when
    ``scenario.task_3`` is truthy.

    Args:
        db_path: Location of the SQLite database file to open.
        scenario: Scenario bundle supplying the rows to insert.
    """
    transactions_ddl = """
            CREATE TABLE transactions (
                id INTEGER PRIMARY KEY,
                user_id INTEGER NOT NULL,
                amount REAL,
                status TEXT NOT NULL
            )
            """
    reports_ddl = """
            CREATE TABLE daily_reports (
                id INTEGER PRIMARY KEY,
                report_date TEXT NOT NULL,
                department TEXT NOT NULL,
                revenue REAL NOT NULL,
                expenses REAL NOT NULL,
                headcount INTEGER NOT NULL
            )
            """
    transaction_columns = ("id", "user_id", "amount", "status")
    report_columns = (
        "id",
        "report_date",
        "department",
        "revenue",
        "expenses",
        "headcount",
    )

    connection = sqlite3.connect(db_path)
    try:
        cursor = connection.cursor()
        cursor.execute(transactions_ddl)
        cursor.execute(reports_ddl)

        if scenario.task_1:
            transaction_rows = [
                tuple(record[column] for column in transaction_columns)
                for record in scenario.task_1.all_rows
            ]
        else:
            # No task-1 data: keep the table non-empty with one placeholder row.
            transaction_rows = [(1, 9000, 100.0, "success")]
        cursor.executemany(
            "INSERT INTO transactions VALUES (?, ?, ?, ?)", transaction_rows
        )

        if scenario.task_3:
            report_rows = [
                tuple(record[column] for column in report_columns)
                for record in scenario.task_3.all_rows
            ]
            cursor.executemany(
                "INSERT INTO daily_reports VALUES (?, ?, ?, ?, ?, ?)", report_rows
            )

        connection.commit()
    finally:
        # Always release the connection, even if schema creation or inserts fail.
        connection.close()


def _write_seeded_files(workspace_dir: str, scenario: TaskScenarioBundle) -> None:
    """Write the scenario's seeded script files into ``workspace_dir``.

    ``broken_pipeline.py`` is written from ``scenario.task_2.broken_script``
    when ``scenario.task_2`` is truthy, and ``format_report.py`` from
    ``scenario.task_3.broken_script`` when ``scenario.task_3`` is truthy.
    Files are written as UTF-8 text and overwrite any existing file.

    Args:
        workspace_dir: Directory that receives the files.
        scenario: Scenario bundle supplying the script contents.
    """

    def _emit(filename, task) -> None:
        # Open the destination first, then read the script text from the task.
        destination = os.path.join(workspace_dir, filename)
        with open(destination, "w", encoding="utf-8") as handle:
            handle.write(task.broken_script)

    if scenario.task_2:
        _emit("broken_pipeline.py", scenario.task_2)

    if scenario.task_3:
        _emit("format_report.py", scenario.task_3)


# When run as a script: enable INFO-level logging and reset the default
# workspace using setup_workspace()'s default arguments.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    setup_workspace()