File size: 6,014 Bytes
a2ec7b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
PostgreSQL Data Loading Utilities for MCPMark Tasks
===================================================

Common utilities for loading data into PostgreSQL databases from CSV files
and setting up schemas in prepare_environment.py scripts.
"""

import csv
import os
import psycopg2
from pathlib import Path
from typing import Dict, List, Any, Optional
import logging

# Module-level logger named after this module (via ``__name__``); the helpers
# below report progress and errors through it.
logger = logging.getLogger(__name__)


def get_connection_params() -> dict:
    """Build psycopg2 connection keyword arguments from the environment.

    Reads POSTGRES_HOST (default "localhost"), POSTGRES_PORT (default 5432,
    converted to ``int``), POSTGRES_DATABASE, POSTGRES_USERNAME and
    POSTGRES_PASSWORD. Variables without a default map to ``None`` when unset.

    Returns:
        Dict with the keys ``host``, ``port``, ``database``, ``user`` and
        ``password``, suitable for ``psycopg2.connect(**params)``.
    """
    env = os.environ
    host = env.get("POSTGRES_HOST", "localhost")
    port = int(env.get("POSTGRES_PORT", 5432))
    params = dict(
        host=host,
        port=port,
        database=env.get("POSTGRES_DATABASE"),
        user=env.get("POSTGRES_USERNAME"),
        password=env.get("POSTGRES_PASSWORD"),
    )
    return params


def execute_schema_sql(conn, schema_sql: str):
    """Execute schema SQL and commit it, rolling back on failure.

    Args:
        conn: Open psycopg2 database connection.
        schema_sql: SQL text to execute (e.g. CREATE TABLE statements).

    Raises:
        Exception: Whatever ``execute``/``commit`` raised (typically
            ``psycopg2.Error``), re-raised after the transaction is rolled back.
    """
    try:
        with conn.cursor() as cur:
            cur.execute(schema_sql)
        conn.commit()
    except Exception:
        # A failed statement leaves the transaction in an aborted state; roll
        # back so the connection stays usable for subsequent commands.
        conn.rollback()
        raise
    logger.info("✅ Database schema created successfully")


def load_csv_to_table(
    conn, 
    csv_file_path: Path, 
    table_name: str, 
    columns: Optional[List[str]] = None,
    skip_header: bool = True
):
    """
    Load CSV data into a PostgreSQL table with ``COPY ... FROM STDIN WITH CSV``.
    
    Args:
        conn: Database connection (psycopg2)
        csv_file_path: Path to the CSV file (``str`` is also accepted); read as UTF-8
        table_name: Target table name. It is interpolated into the SQL
            unquoted, so it must come from trusted code.
        columns: Column names in CSV field order (if None or empty, COPY
            targets all columns of the table)
        skip_header: Whether to skip the first CSV record

    Raises:
        FileNotFoundError: If ``csv_file_path`` does not exist.
        Exception: Any error from COPY/commit (typically ``psycopg2.Error``),
            re-raised after the transaction is rolled back.
    """
    csv_file_path = Path(csv_file_path)
    if not csv_file_path.exists():
        raise FileNotFoundError(f"CSV file not found: {csv_file_path}")

    # Identifiers cannot be bound as query parameters, so they are interpolated.
    column_list = f" ({', '.join(columns)})" if columns else ""
    copy_sql = f"COPY {table_name}{column_list} FROM STDIN WITH CSV"

    try:
        with conn.cursor() as cur, open(csv_file_path, 'r', encoding='utf-8') as f:
            if skip_header:
                # csv.reader pulls lines from the file only as needed, so this
                # consumes exactly one record (even if a quoted header field
                # spans several lines) and leaves the file positioned at the
                # first data row for copy_expert. The None default keeps an
                # empty file from leaking a bare StopIteration.
                next(csv.reader(f), None)
            cur.copy_expert(copy_sql, f)
        conn.commit()
    except Exception:
        # Leave the connection usable instead of stuck in an aborted transaction.
        conn.rollback()
        raise
    logger.info(f"✅ Loaded data from {csv_file_path.name} into {table_name}")


def insert_data_from_dict(conn, table_name: str, data: List[Dict[str, Any]]):
    """
    Insert rows from a list of dictionaries into a table and commit.

    The column list is taken from the keys of the first record. Every record
    must provide those keys, and extra keys in later records are ignored.
    Rows rejected by a constraint conflict are skipped
    (``ON CONFLICT DO NOTHING``), and the logged count reflects only the rows
    actually inserted.

    Args:
        conn: Database connection
        table_name: Target table name (interpolated unquoted; must be trusted)
        data: List of dictionaries with column_name: value pairs. An empty
            list (or None) is a no-op.

    Raises:
        KeyError: If a record lacks a column present in the first record.
        Exception: Any database error (typically ``psycopg2.Error``).
            In both cases the transaction is rolled back before re-raising.
    """
    if not data:
        return

    columns = list(data[0])
    placeholders = ', '.join(['%s'] * len(columns))
    insert_sql = (
        f"INSERT INTO {table_name} ({', '.join(columns)}) "
        f"VALUES ({placeholders}) ON CONFLICT DO NOTHING"
    )

    inserted = 0
    try:
        with conn.cursor() as cur:
            for row in data:
                cur.execute(insert_sql, [row[col] for col in columns])
                # rowcount is 0 when ON CONFLICT DO NOTHING skipped the row.
                inserted += max(cur.rowcount, 0)
        conn.commit()
    except Exception:
        # Avoid leaving a half-inserted batch in an open/aborted transaction.
        conn.rollback()
        raise

    skipped = len(data) - inserted
    if skipped:
        logger.info(
            f"✅ Inserted {inserted} rows into {table_name} "
            f"({skipped} skipped due to conflicts)"
        )
    else:
        logger.info(f"✅ Inserted {inserted} rows into {table_name}")


def create_table_with_data(
    conn, 
    table_name: str, 
    schema_sql: str, 
    data: Optional[List[Dict[str, Any]]] = None,
    data_from_csv: Optional[Path] = None
):
    """
    Create a table and optionally load data into it.

    When data is loaded, the loader's commit also commits the CREATE
    statement, so table creation and data load land together. When there is
    nothing to load, the CREATE is committed explicitly. Non-empty ``data``
    takes precedence over ``data_from_csv``.

    Args:
        conn: Database connection
        table_name: Table name (used for loading and log messages)
        schema_sql: CREATE TABLE SQL statement
        data: Optional list of dictionaries to insert
        data_from_csv: Optional CSV file to load

    Raises:
        Exception: Any error from executing ``schema_sql`` (typically
            ``psycopg2.Error``), re-raised after rollback, or any error raised
            by the data loaders.
    """
    try:
        with conn.cursor() as cur:
            cur.execute(schema_sql)
    except Exception:
        conn.rollback()
        raise
    logger.info(f"✅ Created table {table_name}")

    if data:
        insert_data_from_dict(conn, table_name, data)
    elif data_from_csv:
        load_csv_to_table(conn, data_from_csv, table_name)
    else:
        # Nothing to load: without this commit the CREATE TABLE would sit in
        # an open transaction and be discarded when the connection is closed.
        conn.commit()


def setup_database_with_config(setup_config: Dict[str, Any]):
    """
    Set up database using a configuration dictionary.

    Connects with the parameters from ``get_connection_params()``, creates each
    configured table in dict order and loads its data. The connection is always
    closed, including when a step fails.

    Args:
        setup_config: Dictionary with 'tables' key containing table configurations
        
    Example config:
    {
        "tables": {
            "artists": {
                "schema": "CREATE TABLE artists (id SERIAL PRIMARY KEY, name VARCHAR(120))",
                "data": [{"id": 1, "name": "Iron Maiden"}],
                "data_from_csv": "data/artists.csv"  # alternative to data
            }
        }
    }

    Relative ``data_from_csv`` paths are resolved against the current working
    directory. A missing, None or empty ``data_from_csv`` means "no CSV".

    Raises:
        ValueError: If POSTGRES_DATABASE is not set.
        psycopg2.Error: On database errors (logged, then re-raised).
        Exception: Any other setup error (logged, then re-raised).
    """
    conn_params = get_connection_params()
    if not conn_params["database"]:
        raise ValueError("❌ No database specified in POSTGRES_DATABASE environment variable")

    conn = None
    try:
        conn = psycopg2.connect(**conn_params)

        for table_name, config in setup_config["tables"].items():
            csv_file_path = None
            csv_setting = config.get("data_from_csv")
            if csv_setting:
                csv_file_path = Path(csv_setting)
                if not csv_file_path.is_absolute():
                    # Assume relative to current working directory (task directory)
                    csv_file_path = Path.cwd() / csv_file_path

            create_table_with_data(
                conn,
                table_name,
                config["schema"],
                data=config.get("data"),
                data_from_csv=csv_file_path,
            )

    except psycopg2.Error as e:
        logger.error(f"❌ Database error during setup: {e}")
        raise
    except Exception as e:
        logger.error(f"❌ Setup error: {e}")
        raise
    finally:
        # Release the connection on every path; previously it leaked whenever
        # a table failed to load. Uncommitted work is discarded by close().
        if conn is not None:
            conn.close()

    logger.info("🎉 Database setup completed successfully")