# dataops-env / data/init_db.py
# Uploaded by visheshrathi via huggingface_hub (commit f89b1ac).
import logging
import os
import shutil
import sqlite3
from server.task_specs import TaskScenarioBundle, build_task_scenario
logger = logging.getLogger(__name__)

# Two directory levels above this file; the workspace lives directly beneath it.
_PROJECT_DIR = os.path.dirname(os.path.dirname(__file__))
WORKSPACE_ROOT = os.path.join(_PROJECT_DIR, "workspace")
# Alias kept so callers importing either name resolve to the same default path.
WORKSPACE_DIR = WORKSPACE_ROOT
def setup_workspace(
    workspace_dir: str | None = None, *, scenario: TaskScenarioBundle | None = None
) -> str:
    """Wipe a workspace directory and re-seed it from a task scenario.

    Args:
        workspace_dir: Directory to rebuild. Falsy values (``None`` or ``""``)
            fall back to the module-level ``WORKSPACE_DIR``.
        scenario: Scenario bundle to seed from. When falsy, the bundle for
            ``"task_1_easy_anomaly"`` with ``seed=0`` is built instead.

    Returns:
        Path of the freshly created ``mock_warehouse.db`` inside the workspace.

    Note:
        Every existing entry inside the workspace directory is deleted before
        the database and seeded scripts are written.
    """
    root = workspace_dir if workspace_dir else WORKSPACE_DIR
    db_file = os.path.join(root, "mock_warehouse.db")
    bundle = (
        scenario
        if scenario
        else build_task_scenario("task_1_easy_anomaly", seed=0)
    )

    # Create the directory if needed, empty it, then rebuild its contents.
    os.makedirs(root, exist_ok=True)
    _clear_workspace(root)
    _init_database(db_file, bundle)
    _write_seeded_files(root, bundle)

    logger.info(
        "Workspace reset complete: task=%s seed=%s db=%s",
        bundle.task_id,
        bundle.seed,
        db_file,
    )
    return db_file
def _clear_workspace(workspace_dir: str) -> None:
    """Delete every entry inside *workspace_dir*, keeping the directory itself.

    Real sub-directories are removed recursively. Files and symlinks, including
    symlinks that point at directories, are unlinked. The link is removed and its
    target is left untouched: ``shutil.rmtree`` refuses to operate on a symlink,
    and following the link could delete data outside the workspace.

    Entries that disappear between listing and removal are skipped.

    Args:
        workspace_dir: Existing directory whose contents should be removed.
    """
    for entry in os.listdir(workspace_dir):
        path = os.path.join(workspace_dir, entry)
        try:
            # os.path.isdir follows symlinks, so exclude links explicitly.
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)
        except FileNotFoundError:
            # Entry vanished concurrently; nothing left to remove.
            continue
# DDL text is kept verbatim: SQLite stores it in sqlite_master and it is
# visible through schema introspection.
_TRANSACTIONS_DDL = """
CREATE TABLE transactions (
id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
amount REAL,
status TEXT NOT NULL
)
"""
_DAILY_REPORTS_DDL = """
CREATE TABLE daily_reports (
id INTEGER PRIMARY KEY,
report_date TEXT NOT NULL,
department TEXT NOT NULL,
revenue REAL NOT NULL,
expenses REAL NOT NULL,
headcount INTEGER NOT NULL
)
"""
# Column order used when turning scenario row mappings into insert tuples.
_TRANSACTION_FIELDS = ("id", "user_id", "amount", "status")
_DAILY_REPORT_FIELDS = (
    "id",
    "report_date",
    "department",
    "revenue",
    "expenses",
    "headcount",
)


def _init_database(db_path: str, scenario: TaskScenarioBundle) -> None:
    """Create and populate the SQLite warehouse at *db_path*.

    Creates the ``transactions`` and ``daily_reports`` tables. ``transactions``
    receives ``scenario.task_1.all_rows`` when ``scenario.task_1`` is truthy,
    otherwise a single placeholder row. ``daily_reports`` is filled from
    ``scenario.task_3.all_rows`` only when ``scenario.task_3`` is truthy.

    Each row is read by key (``row["id"]`` etc.), so rows are presumably
    mappings keyed by column name.

    ``CREATE TABLE`` is used without ``IF NOT EXISTS``, so the database must not
    already contain these tables.
    """
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        for ddl in (_TRANSACTIONS_DDL, _DAILY_REPORTS_DDL):
            cur.execute(ddl)

        if scenario.task_1:
            txn_rows = [
                tuple(record[field] for field in _TRANSACTION_FIELDS)
                for record in scenario.task_1.all_rows
            ]
        else:
            # Placeholder so the table is never empty for non-task-1 scenarios.
            txn_rows = [(1, 9000, 100.0, "success")]
        cur.executemany("INSERT INTO transactions VALUES (?, ?, ?, ?)", txn_rows)

        if scenario.task_3:
            report_rows = [
                tuple(record[field] for field in _DAILY_REPORT_FIELDS)
                for record in scenario.task_3.all_rows
            ]
            cur.executemany(
                "INSERT INTO daily_reports VALUES (?, ?, ?, ?, ?, ?)",
                report_rows,
            )

        conn.commit()
    finally:
        # Always release the connection; uncommitted work is discarded on error.
        conn.close()
def _write_seeded_files(workspace_dir: str, scenario: TaskScenarioBundle) -> None:
    """Write each task's broken starter script into *workspace_dir*.

    ``scenario.task_2.broken_script`` goes to ``broken_pipeline.py`` and
    ``scenario.task_3.broken_script`` goes to ``format_report.py``. A script is
    written only when its task attribute is truthy. Existing files are
    overwritten, and text is encoded as UTF-8.
    """
    targets = (
        ("task_2", "broken_pipeline.py"),
        ("task_3", "format_report.py"),
    )
    for attr_name, file_name in targets:
        # Read each task lazily, in order, right before its file would be written.
        task = getattr(scenario, attr_name)
        if not task:
            continue
        destination = os.path.join(workspace_dir, file_name)
        with open(destination, "w", encoding="utf-8") as handle:
            handle.write(task.broken_script)
def _main() -> None:
    """Command-line entry point: enable INFO logging and rebuild the default workspace."""
    logging.basicConfig(level=logging.INFO)
    setup_workspace()


if __name__ == "__main__":
    _main()