File size: 3,932 Bytes
d57737f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
SQLite database management for AWM environments.

Creates databases from schema + sample data, manages snapshots for
verifier comparison (initial vs final state).
"""

import logging
import os
import shutil
import sqlite3
from typing import Any

# Module-level logger; DDL/index/insert failures below are reported here as
# warnings instead of aborting database creation.
logger = logging.getLogger(__name__)


def create_database(db_path: str, db_schema: dict, sample_data: Any) -> str:
    """Create a SQLite database from schema and populate it with sample data.

    Any existing file at *db_path* is deleted first. Individual statements
    that SQLite rejects are logged and skipped by ``_create_schema`` /
    ``_insert_sample_data``; any other failure removes the partially built
    file and re-raises the original exception.

    Args:
        db_path: Path where the .db file will be created. Missing parent
            directories are created; a bare filename is allowed.
        db_schema: Schema dict from gen_db.jsonl (contains a "tables" list whose
            entries have "ddl" and "indexes").
        sample_data: Sample data from gen_sample.jsonl (list of SQL INSERT dicts
            or raw statements).

    Returns:
        The db_path on success.
    """
    parent = os.path.dirname(db_path)
    if parent:
        # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
        os.makedirs(parent, exist_ok=True)

    if os.path.exists(db_path):
        os.remove(db_path)

    conn = sqlite3.connect(db_path)
    built = False
    try:
        _create_schema(conn, db_schema)
        _insert_sample_data(conn, db_path, sample_data)
        conn.commit()
        built = True
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
        if not built:
            # sqlite3 only opens implicit transactions for DML, so schema
            # statements may already be persisted even after rollback(); drop
            # the half-built file rather than leave it looking valid.
            try:
                os.remove(db_path)
            except OSError as e:
                logger.warning("Failed to remove partial database %s: %s", db_path, e)

    logger.info("Created database: %s", db_path)
    return db_path


def _create_schema(conn: sqlite3.Connection, db_schema: dict) -> None:
    """Execute each table's DDL and index statements from *db_schema* on *conn*.

    Expected shape: ``{"tables": [{"ddl": "CREATE TABLE ...",
    "indexes": ["CREATE INDEX ...", ...]}, ...]}``. A missing or ``None``
    schema, ``"tables"``, ``"ddl"`` or ``"indexes"`` value is treated as empty.
    Statements are stripped and executed one at a time. Statements SQLite
    rejects are logged and skipped, so the rest of the schema is still created.
    Nothing is committed here; that is left to the caller.

    Args:
        conn: Open connection to execute the statements on.
        db_schema: Schema dict in the shape described above.
    """
    cursor = conn.cursor()
    tables = (db_schema or {}).get("tables") or []
    for table in tables:
        ddl = str(table.get("ddl") or "").strip()
        if ddl:
            # Older Python versions raise sqlite3.Warning (not an sqlite3.Error
            # subclass) for strings holding several statements; treat it like
            # any other rejected statement instead of aborting the whole build.
            try:
                cursor.execute(ddl)
            except (sqlite3.Error, sqlite3.Warning) as e:
                logger.warning("Failed to execute DDL: %s\n  DDL: %s", e, ddl)

        for idx in table.get("indexes") or []:
            idx_stmt = str(idx).strip()
            if idx_stmt:
                try:
                    cursor.execute(idx_stmt)
                except (sqlite3.Error, sqlite3.Warning) as e:
                    logger.warning(
                        "Failed to create index: %s\n  Index: %s", e, idx_stmt
                    )


def _insert_sample_data(
    conn: sqlite3.Connection, db_path: str, sample_data: Any
) -> None:
    """Execute sample-data SQL statements on *conn*.

    Accepted shapes for *sample_data*:

    * ``{"tables": [...]}``: the AWM wrapper. The inner list is processed.
    * A list whose items are either dicts with ``"table_name"`` and
      ``"insert_statements"`` keys, or raw SQL strings. Items of any other
      type are skipped.

    ``"insert_statements"`` may be a list of statements or a single string.
    Each statement is stripped and executed on its own. Statements SQLite
    rejects are logged and skipped, so one bad row does not abort the load.
    Nothing is committed here.

    Args:
        conn: Open connection to execute on.
        db_path: Not used by this function; kept for caller compatibility.
        sample_data: Sample data in one of the shapes above. Falsy values are
            ignored; any other unsupported shape is ignored with a warning.
    """
    if not sample_data:
        return

    # AWM format wraps the list inside {"tables": [...]}; unwrap if needed.
    if isinstance(sample_data, dict) and "tables" in sample_data:
        sample_data = sample_data["tables"]
        if not sample_data:
            return

    if not isinstance(sample_data, list):
        logger.warning(
            "Ignoring sample data of unsupported type %s", type(sample_data).__name__
        )
        return

    cursor = conn.cursor()
    for item in sample_data:
        if isinstance(item, dict):
            table_name = item.get("table_name", "unknown")
            statements = item.get("insert_statements") or []
            if isinstance(statements, str):
                # A bare string would otherwise be iterated character by character.
                statements = [statements]
            for raw in statements:
                stmt = str(raw).strip()
                if not stmt:
                    continue
                # sqlite3.Warning is caught too: older Python versions raise it
                # (not an sqlite3.Error) for multi-statement strings.
                try:
                    cursor.execute(stmt)
                except (sqlite3.Error, sqlite3.Warning) as e:
                    logger.warning(
                        "Failed to insert into %s: %s\n  SQL: %s", table_name, e, stmt
                    )
        elif isinstance(item, str):
            stmt = item.strip()
            if not stmt:
                continue
            try:
                cursor.execute(stmt)
            except (sqlite3.Error, sqlite3.Warning) as e:
                logger.warning("Failed to execute SQL: %s\n  SQL: %s", e, stmt)


def save_snapshot(db_path: str, snapshot_path: str) -> str:
    """Copy the database file to *snapshot_path* for later verifier comparison.

    Missing parent directories of *snapshot_path* are created, and an existing
    file there is overwritten. File metadata is copied along with the
    contents (``shutil.copy2``).

    Args:
        db_path: Source database file.
        snapshot_path: Destination path; may be a bare filename.

    Returns:
        snapshot_path.
    """
    parent = os.path.dirname(snapshot_path)
    if parent:
        # os.makedirs("") raises FileNotFoundError for a bare filename.
        os.makedirs(parent, exist_ok=True)
    # NOTE(review): this is a raw file copy. If the database is still open
    # elsewhere in WAL mode or mid-write, pages not yet checkpointed into the
    # main file are missed; sqlite3.Connection.backup() would give a
    # consistent snapshot. Confirm how callers use this.
    shutil.copy2(db_path, snapshot_path)
    return snapshot_path


def cleanup_session_dir(session_dir: str) -> None:
    """Recursively delete *session_dir* if it names an existing directory.

    Empty/``None`` paths and paths that are not directories are ignored.
    Deletion is best-effort: errors raised while removing are suppressed.
    """
    if not session_dir or not os.path.isdir(session_dir):
        return
    shutil.rmtree(session_dir, ignore_errors=True)
    logger.debug(f"Cleaned up session dir: {session_dir}")