agent_world_model_env / server /db_manager.py
ChilleD's picture
Upload folder using huggingface_hub
d57737f verified
"""
SQLite database management for AWM environments.
Creates databases from schema + sample data, manages snapshots for
verifier comparison (initial vs final state).
"""
import logging
import os
import shutil
import sqlite3
from typing import Any
logger = logging.getLogger(__name__)
def create_database(db_path: str, db_schema: dict, sample_data: Any) -> str:
"""Create a SQLite database from schema and populate with sample data.
Args:
db_path: Path where the .db file will be created.
db_schema: Schema dict from gen_db.jsonl (contains "tables" list with "ddl" and "indexes").
sample_data: Sample data from gen_sample.jsonl (list of SQL INSERT dicts or raw statements) .
Returns:
The db_path on success.
"""
os.makedirs(os.path.dirname(db_path), exist_ok=True)
if os.path.exists(db_path):
os.remove(db_path)
conn = sqlite3.connect(db_path)
try:
_create_schema(conn, db_schema)
_insert_sample_data(conn, db_path, sample_data)
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
logger.info(f"Created database: {db_path}")
return db_path
def _create_schema(conn: sqlite3.Connection, db_schema: dict) -> None:
cursor = conn.cursor()
tables = db_schema.get("tables", [])
for table in tables:
ddl = table.get("ddl", "").strip()
if ddl:
try:
cursor.execute(ddl)
except sqlite3.Error as e:
logger.warning(f"Failed to execute DDL: {e}\n DDL: {ddl}")
indexes = table.get("indexes", [])
for idx in indexes:
idx_stmt = str(idx).strip()
if idx_stmt:
try:
cursor.execute(idx_stmt)
except sqlite3.Error as e:
logger.warning(f"Failed to create index: {e}\n Index: {idx_stmt}")
def _insert_sample_data(
conn: sqlite3.Connection, db_path: str, sample_data: Any
) -> None:
"""Insert sample data. Handles the AWM sample_data format which is a list
of dicts with 'table_name' and 'insert_statements' keys."""
if not sample_data:
return
cursor = conn.cursor()
# AWM format wraps the list inside {"tables": [...]}; unwrap if needed.
if isinstance(sample_data, dict) and "tables" in sample_data:
sample_data = sample_data["tables"]
if isinstance(sample_data, list):
for item in sample_data:
if isinstance(item, dict):
statements = item.get("insert_statements", [])
table_name = item.get("table_name", "unknown")
for stmt in statements:
stmt = str(stmt).strip()
if stmt:
try:
cursor.execute(stmt)
except sqlite3.Error as e:
logger.warning(
f"Failed to insert into {table_name}: {e}\n SQL: {stmt}"
)
elif isinstance(item, str):
item = item.strip()
if item:
try:
cursor.execute(item)
except sqlite3.Error as e:
logger.warning(f"Failed to execute SQL: {e}\n SQL: {item}")
def save_snapshot(db_path: str, snapshot_path: str) -> str:
"""Copy the database file as a snapshot for later verifier comparison."""
os.makedirs(os.path.dirname(snapshot_path), exist_ok=True)
shutil.copy2(db_path, snapshot_path)
return snapshot_path
def cleanup_session_dir(session_dir: str) -> None:
"""Remove the session temp directory and all contents."""
if session_dir and os.path.isdir(session_dir):
shutil.rmtree(session_dir, ignore_errors=True)
logger.debug(f"Cleaned up session dir: {session_dir}")