# evoagentx/tools/database_postgresql.py
# PostgreSQL toolkit with a remote-server mode and a local file-based fallback.
import time
import json
from typing import Dict, Any, List, Union, Optional
from pathlib import Path
import psycopg2
import psycopg2.extras
import re
from .database_base import DatabaseBase, DatabaseType, QueryType, DatabaseConnection
from .tool import Tool, Toolkit
from ..core.logging import logger
class PostgreSQLConnection(DatabaseConnection):
    """PostgreSQL-specific connection wrapper built on psycopg2."""

    def __init__(self, connection_string: str, **kwargs):
        super().__init__(connection_string, **kwargs)
        self.conn = None  # live psycopg2 connection, or None when closed

    def connect(self) -> bool:
        """Open a psycopg2 connection; return True on success, False on failure."""
        try:
            self.conn = psycopg2.connect(self.connection_string, **self.connection_params)
        except Exception as e:
            logger.error(f"Failed to connect to PostgreSQL: {str(e)}")
            self._is_connected = False
            return False
        self._is_connected = True
        logger.info("Successfully connected to PostgreSQL")
        return True

    def disconnect(self) -> bool:
        """Close the connection if open; returns False only when closing raised."""
        try:
            if self.conn:
                self.conn.close()
                self.conn = None
            self._is_connected = False
            logger.info("Disconnected from PostgreSQL")
            return True
        except Exception as e:
            logger.error(f"Error disconnecting from PostgreSQL: {str(e)}")
            return False

    def test_connection(self) -> bool:
        """Probe liveness with a trivial SELECT; never raises."""
        if not self.conn:
            return False
        try:
            with self.conn.cursor() as cur:
                cur.execute("SELECT 1;")
            return True
        except Exception:
            return False
class PostgreSQLDatabase(DatabaseBase):
"""
PostgreSQL database implementation with automatic initialization.
Handles remote connections, existing local databases, and new local database creation.
"""
def __init__(self,
             connection_string: str = None,
             database_name: str = None,
             local_path: str = None,
             auto_save: bool = True,
             **kwargs):
    """Set up state, then pick a backend: remote server, existing local
    file-based store, or a brand-new local file-based store."""
    super().__init__(connection_string=connection_string,
                     database_name=database_name,
                     **kwargs)
    self.local_path = Path(local_path) if local_path else None
    self.auto_save = auto_save
    self.connection_params = kwargs
    self.is_local_database = False
    self.conn = None
    self.cursor = None
    self.file_based_mode = False
    self.tables = {}  # in-memory row store used only by file-based mode

    # Dispatch to whichever initializer matches the configuration.
    if self._is_remote_connection():
        self._init_remote_database()
    elif self._is_existing_local_database():
        self._init_existing_local_database()
    else:
        self._init_new_local_database()
def _is_remote_connection(self) -> bool:
    """Return True when the connection string looks like a remote DSN.

    Heuristic: a "user@host" segment or an explicit postgresql:// scheme.
    Fixed to always return a bool — the previous `a and (b or c)` form
    returned None (not False) when the connection string was empty,
    contradicting the annotated return type.
    """
    if not self.connection_string:
        return False
    return "@" in self.connection_string or "postgresql://" in self.connection_string
def _is_existing_local_database(self) -> bool:
    """True when local_path exists and already holds the db_info.json marker."""
    if not self.local_path:
        return False
    if not self.local_path.exists():
        return False
    return (self.local_path / "db_info.json").exists()
def _init_remote_database(self):
    """Initialize remote PostgreSQL connection.

    On failure this now actually falls back to a new local file-based
    database. Previously it only logged "Falling back to local database
    mode" and left the instance uninitialized, unlike
    _init_existing_local_database which does call _init_new_local_database.
    """
    try:
        # Add connection timeout to prevent hanging
        connection_params = self.connection_params.copy()
        connection_params.update({
            'connect_timeout': 5,  # 5 second timeout
            'options': '-c statement_timeout=5000'  # 5 second statement timeout
        })
        self.conn = psycopg2.connect(self.connection_string, **connection_params)
        self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
        if self.database_name:
            # Isolation level 0 == autocommit, required for catalog probes.
            self.conn.set_isolation_level(0)
            self.cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", (self.database_name,))
        self._is_initialized = True
        self.is_local_database = False
        self.file_based_mode = False
        logger.info(f"Connected to remote PostgreSQL: {self.database_name}")
    except Exception as e:
        logger.error(f"Failed to connect to remote PostgreSQL: {str(e)}")
        self._is_initialized = False
        # Don't raise, just log the error and continue with local mode
        logger.info("Falling back to local database mode")
        self._init_new_local_database()
def _init_existing_local_database(self):
    """Load an existing file-based database from local_path."""
    try:
        self.database_name = self.database_name or self.local_path.name
        # Hydrate the in-memory table store from the per-table JSON files.
        self._load_tables_from_files()
        self.file_based_mode = True
        self.is_local_database = True
        self._is_initialized = True
        logger.info(f"Loaded existing local file-based database from: {self.local_path}")
    except Exception as e:
        logger.error(f"Failed to load existing local database: {str(e)}")
        self._is_initialized = False
        logger.info("Falling back to new local database mode")
        self._init_new_local_database()
def _init_new_local_database(self):
    """Create a fresh file-based database directory (with marker file)."""
    try:
        self.local_path = self.local_path or Path("./workplace/postgresql_local")
        self.local_path.mkdir(parents=True, exist_ok=True)
        self.database_name = self.database_name or self.local_path.name
        self._create_db_info_file()
        self.file_based_mode = True
        self.is_local_database = True
        self._is_initialized = True
        logger.info(f"Created new local file-based database at: {self.local_path}")
    except Exception as e:
        logger.error(f"Failed to create new local database: {str(e)}")
        self._is_initialized = False
        logger.info("Database initialization failed, but toolkit is still usable")
def _create_db_info_file(self):
    """Write db_info.json, the marker/metadata file identifying this database."""
    payload = {
        "database_name": self.database_name,
        "created_at": time.time(),
        "local_path": str(self.local_path.absolute()),
        "auto_save": self.auto_save,
        "version": "1.0",
        "mode": "file_based",
    }
    try:
        with open(self.local_path / "db_info.json", 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)
    except Exception as e:
        # Best-effort: a missing marker only affects re-detection on reload.
        logger.warning(f"Failed to create db info file: {str(e)}")
def _load_tables_from_files(self):
    """Populate self.tables from every per-table JSON file under local_path."""
    try:
        for json_file in self.local_path.glob("*.json"):
            if json_file.name == "db_info.json":  # metadata file, not a table
                continue
            table_name = json_file.stem
            with open(json_file, 'r', encoding='utf-8') as f:
                loaded_data = json.load(f)
            # A table must be a list of rows; anything else is reset to empty.
            if isinstance(loaded_data, list):
                self.tables[table_name] = loaded_data
            else:
                logger.warning(f"Table {table_name} file contains non-list data: {type(loaded_data)}, converting to empty list")
                self.tables[table_name] = []
    except Exception as e:
        logger.warning(f"Error loading tables from files: {str(e)}")
def _save_table_to_file(self, table_name: str):
    """Persist one table's rows to <local_path>/<table_name>.json."""
    if table_name not in self.tables:
        return  # nothing to persist
    try:
        target = self.local_path / f"{table_name}.json"
        with open(target, 'w', encoding='utf-8') as f:
            json.dump(self.tables[table_name], f, indent=2, ensure_ascii=False)
    except Exception as e:
        logger.error(f"Error saving table {table_name}: {str(e)}")
def _parse_sql_query(self, sql: str) -> Dict[str, Any]:
    """Parse a small SQL subset into a dict for file-based execution.

    Supports CREATE TABLE, INSERT, SELECT (plain, equi-JOIN, CROSS JOIN),
    UPDATE and DELETE; anything else yields {"type": "UNKNOWN"}.

    Fix: CROSS JOIN is now detected BEFORE the generic JOIN branch.
    Previously `if "JOIN" in upper_sql` also matched CROSS JOIN queries,
    the ON-based regex then failed to match, and the dedicated CROSS JOIN
    branch (an elif) was unreachable — CROSS JOIN always parsed as UNKNOWN.
    """
    sql = sql.strip()
    upper_sql = sql.upper()
    # CREATE TABLE t (col1 type, col2 type, ...)
    if upper_sql.startswith("CREATE TABLE"):
        # NOTE(review): the non-greedy (.*?) stops at the FIRST ')', so
        # column types with parens like VARCHAR(50) truncate the column
        # list — confirm callers stick to simple column definitions.
        match = re.search(r"CREATE TABLE (?:IF NOT EXISTS )?(\w+) *\((.*?)\)", sql, re.IGNORECASE | re.DOTALL)
        if match:
            table = match.group(1).lower()
            col_defs = [c.strip() for c in match.group(2).split(',') if c.strip()]
            col_names = [c.split()[0] for c in col_defs]
            return {"type": "CREATE", "table": table, "columns": col_names}
    # INSERT INTO t (a, b) VALUES (...), (...)
    elif upper_sql.startswith("INSERT"):
        match = re.search(r"INSERT INTO (\w+) *\((.*?)\) *VALUES", sql, re.IGNORECASE | re.DOTALL)
        if match:
            table = match.group(1).lower()
            columns = [c.strip() for c in match.group(2).split(',')]
            values_match = re.search(r"VALUES\s*(.*)", sql, re.IGNORECASE | re.DOTALL)
            if values_match:
                # Each parenthesized group is one row. Values are split on
                # bare commas, so quoted literals containing commas are
                # not supported by this parser.
                value_groups = re.findall(r'\(([^)]+)\)', values_match.group(1))
                all_values = [[v.strip().strip("'\"") for v in group.split(',')]
                              for group in value_groups]
                return {"type": "INSERT", "table": table, "columns": columns, "values": all_values}
    # SELECT, including JOIN variants
    elif upper_sql.startswith("SELECT"):
        # CROSS JOIN must be tested first: "CROSS JOIN" also contains "JOIN".
        if "CROSS JOIN" in upper_sql:
            match = re.search(r"SELECT (.*?) FROM (\w+)(?:\s+(\w+))?\s+CROSS\s+JOIN\s+(\w+)(?:\s+(\w+))?(?: WHERE (.*?))?(?: ORDER BY (.*?))?(?: LIMIT (\d+))?", sql, re.IGNORECASE | re.DOTALL)
            if match:
                return {
                    "type": "SELECT_CROSS_JOIN",
                    "columns": [c.strip() for c in match.group(1).split(',')],
                    "table1": match.group(2).lower(),
                    "alias1": match.group(3),
                    "table2": match.group(4).lower(),
                    "alias2": match.group(5),
                    "where": match.group(6),
                    "order_by": match.group(7),
                    "limit": match.group(8)
                }
        elif "JOIN" in upper_sql:
            # Equi-join with an ON clause; optional table aliases and an
            # optional join-type keyword (INNER/LEFT/...) before JOIN.
            match = re.search(r"SELECT (.*?) FROM (\w+)(?:\s+(\w+))?\s+(?:(\w+)\s+)?JOIN\s+(\w+)(?:\s+(\w+))?\s+ON\s+(.*?)(?: WHERE (.*?))?(?: ORDER BY (.*?))?(?: LIMIT (\d+))?", sql, re.IGNORECASE | re.DOTALL)
            if match:
                return {
                    "type": "SELECT_JOIN",
                    "columns": [c.strip() for c in match.group(1).split(',')],
                    "table1": match.group(2).lower(),
                    "alias1": match.group(3),
                    "join_type": match.group(4) or "INNER",
                    "table2": match.group(5).lower(),
                    "alias2": match.group(6),
                    "join_condition": match.group(7),
                    "where": match.group(8),
                    "order_by": match.group(9),
                    "limit": match.group(10)
                }
        else:
            # Plain single-table SELECT with optional clauses.
            match = re.search(r"SELECT (.*?) FROM (\w+)(?: WHERE (.*?))?(?: GROUP BY (.*?))?(?: ORDER BY (.*?))?(?: LIMIT (\d+))?", sql, re.IGNORECASE | re.DOTALL)
            if match:
                return {
                    "type": "SELECT",
                    "table": match.group(2).lower(),
                    "columns": [c.strip() for c in match.group(1).split(',')],
                    "where": match.group(3),
                    "group_by": match.group(4),
                    "order_by": match.group(5),
                    "limit": match.group(6)
                }
    # UPDATE t SET a = 'x' [WHERE ...]
    elif upper_sql.startswith("UPDATE"):
        match = re.search(r"UPDATE (\w+) SET (.*?)(?: WHERE (.*?))?$", sql, re.IGNORECASE | re.DOTALL)
        if match:
            return {"type": "UPDATE", "table": match.group(1).lower(),
                    "set": match.group(2), "where": match.group(3)}
    # DELETE FROM t [WHERE ...]
    elif upper_sql.startswith("DELETE"):
        match = re.search(r"DELETE FROM (\w+)(?: WHERE (.*?))?", sql, re.IGNORECASE | re.DOTALL)
        if match:
            return {"type": "DELETE", "table": match.group(1).lower(), "where": match.group(2)}
    return {"type": "UNKNOWN"}
def _apply_where_filter(self, rows: List[Dict], where: str) -> List[Dict]:
    """Apply a simple single-condition WHERE filter to rows.

    Supports =, !=, >, <, >= and <= against a bare or quoted literal.
    Fix: the original regex `[=><]+` captured ">=" and "<=" (and never
    matched "!=") but only "=", ">" and "<" were handled, so conditions
    like "age >= 18" silently returned the rows unfiltered.
    Non-dict rows are dropped; unparseable conditions leave rows as-is.
    """
    if not where:
        return rows
    # Ensure rows is a list of dictionaries
    if not isinstance(rows, list):
        logger.warning(f"_apply_where_filter: rows is not a list: {type(rows)}")
        return []
    # Filter out any non-dictionary items
    valid_rows = [r for r in rows if isinstance(r, dict)]
    if len(valid_rows) != len(rows):
        logger.warning(f"_apply_where_filter: filtered out {len(rows) - len(valid_rows)} non-dict rows")
    # Handle simple conditions: col = 'val', col >= val, etc.
    m = re.match(r"(\w+) *([=><!]+) *'?([\w@.\- ]+)'?", where)
    if not m:
        return valid_rows
    col, op, val = m.group(1), m.group(2), m.group(3)
    if op == "=":
        return [r for r in valid_rows if str(r.get(col, "")) == val]
    if op == "!=":
        return [r for r in valid_rows if str(r.get(col, "")) != val]
    comparators = {
        ">": lambda x, y: x > y,
        "<": lambda x, y: x < y,
        ">=": lambda x, y: x >= y,
        "<=": lambda x, y: x <= y,
    }
    if op in comparators:
        try:
            val_num = int(val)
        except ValueError:
            return valid_rows  # non-numeric literal: cannot compare, no filter
        cmp = comparators[op]
        filtered = []
        for r in valid_rows:
            try:
                if cmp(int(r.get(col, 0)), val_num):
                    filtered.append(r)
            except (TypeError, ValueError):
                # Row value not coercible to int: exclude it rather than crash.
                pass
        return filtered
    return valid_rows
def _apply_column_selection(self, rows: List[Dict], columns: List[str]) -> List[Dict]:
    """Project each row onto the requested columns ('*' keeps rows unchanged)."""
    if columns == ['*']:
        return rows
    # Ensure rows is a list of dictionaries
    if not isinstance(rows, list):
        logger.warning(f"_apply_column_selection: rows is not a list: {type(rows)}")
        return []
    # Filter out any non-dictionary items
    valid_rows = [r for r in rows if isinstance(r, dict)]
    if len(valid_rows) != len(rows):
        logger.warning(f"_apply_column_selection: filtered out {len(rows) - len(valid_rows)} non-dict rows")
    # Columns absent from a row are simply omitted from its projection.
    return [{col: row[col] for col in columns if col in row} for row in valid_rows]
def _apply_group_by(self, rows: List[Dict], group_by: str) -> List[Dict]:
    """Group rows by a single column and attach fixed aggregate fields.

    NOTE(review): the aggregate keys (employee_count, avg_salary,
    max_salary) and the "salary" source column are hard-coded, so the
    numbers are only meaningful for employee-style tables — confirm this
    is the intended scope.
    """
    if not group_by:
        return rows
    if not isinstance(rows, list):
        logger.warning(f"_apply_group_by: rows is not a list: {type(rows)}")
        return []
    valid_rows = [r for r in rows if isinstance(r, dict)]
    if len(valid_rows) != len(rows):
        logger.warning(f"_apply_group_by: filtered out {len(rows) - len(valid_rows)} non-dict rows")
    group_col = group_by.strip()
    # Bucket rows by the grouping value; missing values group under "Unknown".
    buckets = {}
    for row in valid_rows:
        buckets.setdefault(row.get(group_col, "Unknown"), []).append(row)
    result = []
    for group_val, group_rows in buckets.items():
        salaries = [float(r.get("salary", 0)) for r in group_rows if r.get("salary") is not None]
        result.append({
            group_col: group_val,
            "employee_count": len(group_rows),
            "avg_salary": sum(salaries) / len(salaries) if salaries else 0,
            "max_salary": max(salaries) if salaries else 0,
        })
    return result
def _execute_join_query(self, parsed: Dict) -> Dict[str, Any]:
    """Execute an equi-JOIN query in file-based mode.

    Performs a nested-loop join of two in-memory tables using the single
    "t1.col = t2.col" condition from the parsed query, then applies any
    WHERE filter. Returns the joined rows as a list on success, or
    {"error": ...} on failure (note: the declared Dict return type does
    not cover the list success case).
    """
    try:
        table1 = parsed["table1"]
        table2 = parsed["table2"]
        columns = parsed["columns"]
        join_condition = parsed["join_condition"]
        where = parsed.get("where")
        # Get table data
        rows1 = self.tables.get(table1, [])
        rows2 = self.tables.get(table2, [])
        # Ensure rows are lists
        if not isinstance(rows1, list):
            logger.warning(f"Table {table1} contains non-list data: {type(rows1)}")
            rows1 = []
        if not isinstance(rows2, list):
            logger.warning(f"Table {table2} contains non-list data: {type(rows2)}")
            rows2 = []
        # Parse join condition: table1.col = table2.col
        join_match = re.match(r"(\w+)\.(\w+)\s*=\s*(\w+)\.(\w+)", join_condition)
        if not join_match:
            return {"error": "Invalid join condition format"}
        # Only the column names are used; the table/alias qualifiers in the
        # ON clause are not validated against table1/table2.
        col1, col2 = join_match.group(2), join_match.group(4)
        # Perform JOIN — O(len(rows1) * len(rows2)) nested loop.
        result_rows = []
        for row1 in rows1:
            # Ensure row1 is a dictionary
            if not isinstance(row1, dict):
                logger.warning(f"Skipping non-dict row1 in JOIN: {type(row1)}")
                continue
            for row2 in rows2:
                # Ensure row2 is a dictionary
                if not isinstance(row2, dict):
                    logger.warning(f"Skipping non-dict row2 in JOIN: {type(row2)}")
                    continue
                # Join keys are compared as strings.
                if str(row1.get(col1, "")) == str(row2.get(col2, "")):
                    # Combine rows
                    combined_row = {}
                    for col in columns:
                        if '.' in col:
                            # Handle aliased columns: table.col
                            table_alias, col_name = col.split('.', 1)
                            if table_alias == parsed.get("alias1") or table_alias == table1:
                                combined_row[col] = row1.get(col_name, "")
                            elif table_alias == parsed.get("alias2") or table_alias == table2:
                                combined_row[col] = row2.get(col_name, "")
                        else:
                            # Handle simple columns (row1 wins when both
                            # tables contain the column).
                            if col in row1:
                                combined_row[col] = row1[col]
                            elif col in row2:
                                combined_row[col] = row2[col]
                    result_rows.append(combined_row)
        # Apply WHERE filter if specified
        if where:
            result_rows = self._apply_where_filter(result_rows, where)
        return result_rows
    except Exception as e:
        logger.error(f"Error executing JOIN query: {str(e)}")
        return {"error": str(e)}
def _execute_cross_join_query(self, parsed: Dict) -> Dict[str, Any]:
    """Execute a CROSS JOIN query in file-based mode.

    Produces the Cartesian product of the two in-memory tables (every
    row1 paired with every row2), then applies any WHERE filter.
    Returns the combined rows as a list on success, or {"error": ...}
    on failure (note: the declared Dict return type does not cover the
    list success case).
    """
    try:
        table1 = parsed["table1"]
        table2 = parsed["table2"]
        columns = parsed["columns"]
        where = parsed.get("where")
        # Get table data
        rows1 = self.tables.get(table1, [])
        rows2 = self.tables.get(table2, [])
        # Ensure rows are lists
        if not isinstance(rows1, list):
            logger.warning(f"Table {table1} contains non-list data: {type(rows1)}")
            rows1 = []
        if not isinstance(rows2, list):
            logger.warning(f"Table {table2} contains non-list data: {type(rows2)}")
            rows2 = []
        # Perform CROSS JOIN — result has len(rows1) * len(rows2) rows.
        result_rows = []
        for row1 in rows1:
            # Ensure row1 is a dictionary
            if not isinstance(row1, dict):
                logger.warning(f"Skipping non-dict row1 in CROSS JOIN: {type(row1)}")
                continue
            for row2 in rows2:
                # Ensure row2 is a dictionary
                if not isinstance(row2, dict):
                    logger.warning(f"Skipping non-dict row2 in CROSS JOIN: {type(row2)}")
                    continue
                # Combine rows
                combined_row = {}
                for col in columns:
                    if '.' in col:
                        # Handle aliased columns: table.col
                        table_alias, col_name = col.split('.', 1)
                        if table_alias == parsed.get("alias1") or table_alias == table1:
                            combined_row[col] = row1.get(col_name, "")
                        elif table_alias == parsed.get("alias2") or table_alias == table2:
                            combined_row[col] = row2.get(col_name, "")
                    else:
                        # Handle simple columns (row1 wins when both tables
                        # contain the column).
                        if col in row1:
                            combined_row[col] = row1[col]
                        elif col in row2:
                            combined_row[col] = row2[col]
                result_rows.append(combined_row)
        # Apply WHERE filter if specified
        if where:
            result_rows = self._apply_where_filter(result_rows, where)
        return result_rows
    except Exception as e:
        logger.error(f"Error executing CROSS JOIN query: {str(e)}")
        return {"error": str(e)}
def _get_database_type(self) -> DatabaseType:
    """Identify this backend as PostgreSQL for the DatabaseBase machinery."""
    return DatabaseType.POSTGRESQL
def connect(self) -> bool:
    """Report connection status; the actual connection is made in __init__."""
    return self._is_initialized
def disconnect(self) -> bool:
    """Close the remote connection (no-op in file-based mode); True on success."""
    try:
        if self.conn:
            self.conn.close()
            self.conn = None
            self.cursor = None
        self._is_initialized = False
        logger.info("Disconnected from PostgreSQL")
        return True
    except Exception as e:
        logger.error(f"Error disconnecting: {str(e)}")
        return False
def test_connection(self) -> bool:
    """Cheap liveness probe; file-based mode is 'alive' whenever initialized."""
    if self.file_based_mode:
        return self._is_initialized
    if self.conn is None:
        return False
    try:
        with self.conn.cursor() as cur:
            cur.execute("SELECT 1;")
        return True
    except Exception:
        return False
def execute_query(self, query: Union[str, Dict, List], query_type: QueryType = None, **kwargs) -> Dict[str, Any]:
    """Execute a query and return a formatted result dict.

    Accepts a raw SQL string, a {"sql": ..., "params": ...} dict
    (parameterized through psycopg2), or a list mixing both forms.
    File-based mode delegates to _execute_file_based_query; remote mode
    commits on success and rolls back on error. For a list of queries,
    all statements run on one cursor but only the LAST statement's
    result is returned.
    """
    if not self._is_initialized:
        return self.format_error_result("Database not initialized")
    # For file-based mode, keep existing logic
    if self.file_based_mode:
        return self._execute_file_based_query(query, query_type)
    # For remote PostgreSQL, use direct psycopg2 execution
    if self.conn is None:
        return self.format_error_result("PostgreSQL server not available")
    start_time = time.time()
    try:
        # RealDictCursor yields each row as a dict keyed by column name.
        with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            # Handle different query formats
            if isinstance(query, str):
                # Direct SQL string - execute as-is
                cur.execute(query)
            elif isinstance(query, dict):
                # Dict with SQL and params - use parameterized query
                sql = query.get("sql")
                params = query.get("params", None)
                if params:
                    cur.execute(sql, params)
                else:
                    cur.execute(sql)
            elif isinstance(query, list):
                # List of queries - execute each one
                for q in query:
                    if isinstance(q, str):
                        cur.execute(q)
                    elif isinstance(q, dict):
                        sql = q.get("sql")
                        params = q.get("params", None)
                        if params:
                            cur.execute(sql, params)
                        else:
                            cur.execute(sql)
            else:
                return self.format_error_result("Unsupported query format", query_type)
            # cur.description is only set for row-returning statements.
            if cur.description:
                result = cur.fetchall()
            else:
                result = {"rowcount": cur.rowcount}
            self.conn.commit()
        execution_time = time.time() - start_time
        return self.format_query_result(result, query_type or QueryType.SELECT, execution_time=execution_time)
    except Exception as e:
        execution_time = time.time() - start_time
        logger.error(f"Error executing PostgreSQL query: {str(e)}")
        # Rollback on error
        try:
            if self.conn:
                self.conn.rollback()
        except Exception as rollback_error:
            logger.warning(f"Error during rollback: {str(rollback_error)}")
        return self.format_error_result(str(e), query_type, execution_time=execution_time)
def _execute_file_based_query(self, query: Union[str, Dict, List], query_type: QueryType = None) -> Dict[str, Any]:
    """Execute query in file-based mode.

    Only string SQL is supported here; the statement is parsed by
    _parse_sql_query and interpreted against the in-memory ``self.tables``
    store. Mutating statements persist the affected table to JSON when
    ``auto_save`` is enabled.
    """
    start_time = time.time()
    try:
        if isinstance(query, str):
            parsed = self._parse_sql_query(query)
            query_type = query_type or QueryType.SELECT
            # Ensure parsed is a dictionary with a type key
            if not isinstance(parsed, dict) or "type" not in parsed:
                logger.error(f"_execute_file_based_query: parsed is not a valid dict: {parsed}")
                return self.format_error_result(f"Failed to parse SQL query: {query}", query_type)
            logger.debug(f"Executing {parsed['type']} query: {parsed}")
            if parsed["type"] == "CREATE":
                table_name = parsed["table"]
                columns = parsed.get("columns", ["id"])
                if table_name not in self.tables:
                    self.tables[table_name] = []
                # Ensure table is always a list
                if not isinstance(self.tables[table_name], list):
                    logger.warning(f"Reinitializing table {table_name} as list (was {type(self.tables[table_name])})")
                    self.tables[table_name] = []
                # Store schema as a hidden key. NOTE(review): it lives beside
                # real tables in self.tables, so it also appears in
                # list_collections() output.
                self.tables[f"__schema__{table_name}"] = columns
                if self.auto_save:
                    self._save_table_to_file(table_name)
                result = {"rowcount": 0}
            elif parsed["type"] == "INSERT":
                table_name = parsed["table"]
                columns = parsed["columns"]
                all_values = parsed["values"]
                if table_name not in self.tables:
                    self.tables[table_name] = []
                # Ensure table is always a list
                if not isinstance(self.tables[table_name], list):
                    logger.warning(f"Reinitializing table {table_name} as list (was {type(self.tables[table_name])})")
                    self.tables[table_name] = []
                valid_rows = 0
                # Insert all rows
                for values in all_values:
                    # Skip invalid rows (should have same number of values as columns)
                    if len(values) != len(columns):
                        logger.warning(f"Skipping invalid row: {values} (expected {len(columns)} values, got {len(values)})")
                        continue
                    # Ensure values is a list
                    if not isinstance(values, list):
                        logger.warning(f"Skipping non-list values: {type(values)}")
                        continue
                    row = {col: val for col, val in zip(columns, values)}
                    # NOTE(review): ids derive from the current row count, so
                    # they can collide with surviving rows after a DELETE.
                    row["id"] = len(self.tables[table_name]) + 1
                    self.tables[table_name].append(row)
                    valid_rows += 1
                if self.auto_save:
                    self._save_table_to_file(table_name)
                result = {"rowcount": valid_rows}
            elif parsed["type"] == "SELECT":
                table_name = parsed["table"]
                columns = parsed["columns"]
                where = parsed.get("where")
                group_by = parsed.get("group_by")
                rows = self.tables.get(table_name, [])
                # Ensure rows is always a list
                if not isinstance(rows, list):
                    logger.warning(f"Table {table_name} contains non-list data: {type(rows)}")
                    rows = []
                # Debug logging
                logger.debug(f"SELECT query: table={table_name}, columns={columns}, where={where}, group_by={group_by}")
                logger.debug(f"Rows from table: {type(rows)}, length={len(rows) if isinstance(rows, list) else 'N/A'}")
                if isinstance(rows, list) and rows:
                    logger.debug(f"First row type: {type(rows[0])}, content: {rows[0]}")
                # Apply WHERE filter
                if where:
                    rows = self._apply_where_filter(rows, where)
                # Handle basic aggregation
                if group_by:
                    # NOTE(review): group-by results are returned bare while the
                    # plain path wraps rows in {"data": ...} — confirm consumers
                    # handle both shapes.
                    result = self._apply_group_by(rows, group_by)
                else:
                    # Apply column selection
                    result = {"data": self._apply_column_selection(rows, columns)}
            elif parsed["type"] == "SELECT_JOIN":
                # Handle JOIN queries
                logger.debug(f"Executing JOIN query: {parsed}")
                join_result = self._execute_join_query(parsed)
                if isinstance(join_result, dict) and "error" in join_result:
                    result = {"error": join_result["error"]}
                else:
                    result = {"data": join_result}
            elif parsed["type"] == "SELECT_CROSS_JOIN":
                # Handle CROSS JOIN queries
                logger.debug(f"Executing CROSS JOIN query: {parsed}")
                cross_join_result = self._execute_cross_join_query(parsed)
                if isinstance(cross_join_result, dict) and "error" in cross_join_result:
                    result = {"error": cross_join_result["error"]}
                else:
                    result = {"data": cross_join_result}
            elif parsed["type"] == "UPDATE":
                table_name = parsed["table"]
                set_clause = parsed["set"]
                where = parsed.get("where")
                rows = self.tables.get(table_name, [])
                # Ensure rows is always a list
                if not isinstance(rows, list):
                    logger.warning(f"Table {table_name} contains non-list data: {type(rows)}")
                    rows = []
                # Parse set_clause: col1 = 'val1', col2 = 'val2'
                updates = dict(re.findall(r"(\w+) *= *'?([\w@.\- ]+)'?", set_clause))
                count = 0
                for r in rows:
                    # Ensure r is a dictionary
                    if not isinstance(r, dict):
                        logger.warning(f"Skipping non-dict row in UPDATE: {type(r)}")
                        continue
                    # Rows match by default; a parseable WHERE can veto them.
                    match = True
                    if where:
                        m = re.match(r"(\w+) *([=><]+) *'?([\w@.\- ]+)'?", where)
                        if m:
                            col, op, val = m.group(1), m.group(2), m.group(3)
                            if op == "=" and str(r.get(col, "")) != val:
                                match = False
                            elif op == ">" and int(r.get(col, 0)) <= int(val):
                                match = False
                            elif op == "<" and int(r.get(col, 0)) >= int(val):
                                match = False
                    if match:
                        r.update(updates)
                        count += 1
                if self.auto_save:
                    self._save_table_to_file(table_name)
                result = {"rowcount": count}
            elif parsed["type"] == "DELETE":
                table_name = parsed["table"]
                where = parsed.get("where")
                rows = self.tables.get(table_name, [])
                # Ensure rows is always a list
                if not isinstance(rows, list):
                    logger.warning(f"Table {table_name} contains non-list data: {type(rows)}")
                    rows = []
                if where:
                    m = re.match(r"(\w+) *([=><]+) *'?([\w@.\- ]+)'?", where)
                    if m:
                        col, op, val = m.group(1), m.group(2), m.group(3)
                        if op == "=":
                            new_rows = [r for r in rows if isinstance(r, dict) and str(r.get(col, "")) != val]
                        elif op == ">":
                            try:
                                val_num = int(val)
                                new_rows = [r for r in rows if isinstance(r, dict) and int(r.get(col, 0)) <= val_num]
                            except ValueError:
                                new_rows = rows
                        else:
                            # NOTE(review): '<' (and >=/<=) conditions are not
                            # implemented for DELETE; they delete nothing.
                            new_rows = rows
                        deleted_count = len(rows) - len(new_rows)
                        self.tables[table_name] = new_rows
                    else:
                        # Unparseable WHERE: delete nothing.
                        deleted_count = 0
                else:
                    # No WHERE clause: delete every row in the table.
                    deleted_count = len(rows)
                    self.tables[table_name] = []
                if self.auto_save:
                    self._save_table_to_file(table_name)
                result = {"rowcount": deleted_count}
            else:
                return self.format_error_result("Unsupported query type in file-based mode", query_type)
            execution_time = time.time() - start_time
            return self.format_query_result(result, query_type, execution_time=execution_time)
        else:
            return self.format_error_result("Unsupported query format in file-based mode", query_type)
    except Exception as e:
        execution_time = time.time() - start_time
        logger.error(f"Error executing file-based query: {str(e)}")
        logger.error(f"Query that caused error: {query}")
        logger.error(f"Query type: {query_type}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        return self.format_error_result(str(e), query_type, execution_time=execution_time)
def get_database_info(self) -> Dict[str, Any]:
    """Return summary info (name, user, table count, connection) for the DB."""
    try:
        if not self._is_initialized:
            return self.format_error_result("Database not initialized")
        if self.file_based_mode:
            info = {
                "database": self.database_name,
                "user": "file_based",  # no real DB user exists in file mode
                "table_count": len(self.tables),
                "connection_string": "file_based",
                "is_connected": True,
                "mode": "file_based"
            }
        else:
            if self.conn is None:
                return self.format_error_result("PostgreSQL server not available")
            with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute("SELECT current_database() as database, current_user as user")
                db_info = cur.fetchone()
                cur.execute("SELECT COUNT(*) as table_count FROM information_schema.tables WHERE table_schema = 'public'")
                table_count = cur.fetchone()["table_count"]
                info = {
                    "database": db_info["database"],
                    "user": db_info["user"],
                    "table_count": table_count,
                    "connection_string": self.connection_string,
                    "is_connected": self._is_initialized
                }
        return self.format_query_result(info, QueryType.SELECT)
    except Exception as e:
        return self.format_error_result(str(e))
def list_collections(self) -> List[str]:
    """Return table names: in-memory keys in file mode, public-schema
    tables otherwise. Never raises; errors yield an empty list.

    NOTE(review): in file mode this also includes the internal
    "__schema__<table>" keys stored by CREATE TABLE handling.
    """
    try:
        if self.file_based_mode:
            return list(self.tables.keys())
        if not self._is_initialized or self.conn is None:
            return []
        with self.conn.cursor() as cur:
            cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
            return [row[0] for row in cur.fetchall()]
    except Exception as e:
        logger.error(f"Error listing tables: {str(e)}")
        return []
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
    """Return row count and column info for one table; errors come back
    via format_error_result rather than exceptions."""
    try:
        if not self._is_initialized:
            return self.format_error_result("Database not initialized")
        if self.file_based_mode:
            if collection_name in self.tables:
                row_count = len(self.tables[collection_name])
                info = {
                    "table_name": collection_name,
                    "row_count": row_count,
                    "columns": ["id"]  # file mode does not track real column schemas
                }
            else:
                return self.format_error_result(f"Table {collection_name} not found")
        else:
            if self.conn is None:
                return self.format_error_result("PostgreSQL server not available")
            with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                # SECURITY(review): collection_name is interpolated directly
                # into the SQL (identifiers cannot be bound as parameters);
                # callers must supply trusted table names only.
                cur.execute(f"SELECT COUNT(*) as row_count FROM {collection_name}")
                row_count = cur.fetchone()["row_count"]
                cur.execute("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s", (collection_name,))
                columns = cur.fetchall()
                info = {
                    "table_name": collection_name,
                    "row_count": row_count,
                    "columns": columns
                }
        return self.format_query_result(info, QueryType.SELECT)
    except Exception as e:
        return self.format_error_result(str(e))
def get_schema(self, collection_name: str = None) -> Dict[str, Any]:
    """Return the schema of one table, or of every table when name is None.

    Fix: the remote "all tables" path iterated `row[0]` over rows fetched
    from a RealDictCursor. RealDictCursor returns each row as a dict keyed
    by column name, so `row[0]` raised KeyError(0) and every call came
    back as an error result. It now reads `row["table_name"]`.
    """
    try:
        if not self._is_initialized:
            return self.format_error_result("Database not initialized")
        if self.file_based_mode:
            # File mode does not persist real column types; every table is
            # reported with the synthetic {"id": "integer"} schema.
            if collection_name:
                if collection_name in self.tables:
                    schema = {"id": "integer"}
                    return self.format_query_result({"table_name": collection_name, "schema": schema}, QueryType.SELECT)
                else:
                    return self.format_error_result(f"Table {collection_name} not found")
            else:
                schemas = {table_name: {"id": "integer"} for table_name in self.tables}
                return self.format_query_result({"database_name": self.database_name, "schemas": schemas}, QueryType.SELECT)
        else:
            if self.conn is None:
                return self.format_error_result("PostgreSQL server not available")
            with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                if collection_name:
                    cur.execute("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s", (collection_name,))
                    columns = cur.fetchall()
                    schema = {col["column_name"]: col["data_type"] for col in columns}
                    return self.format_query_result({"table_name": collection_name, "schema": schema}, QueryType.SELECT)
                else:
                    cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
                    # RealDictCursor rows are dicts, not tuples.
                    tables = [row["table_name"] for row in cur.fetchall()]
                    schemas = {}
                    # One catalog query per table; fine for small schemas.
                    for table in tables:
                        cur.execute("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s", (table,))
                        columns = cur.fetchall()
                        schemas[table] = {col["column_name"]: col["data_type"] for col in columns}
                    return self.format_query_result({"database_name": self.database_name, "schemas": schemas}, QueryType.SELECT)
    except Exception as e:
        return self.format_error_result(str(e))
def get_supported_query_types(self) -> List[QueryType]:
    """Full SQL verb coverage: DML (select/insert/update/delete) plus DDL."""
    return [
        QueryType.SELECT, QueryType.INSERT, QueryType.UPDATE, QueryType.DELETE,
        QueryType.CREATE, QueryType.DROP, QueryType.ALTER, QueryType.INDEX,
    ]
def get_capabilities(self) -> Dict[str, Any]:
    """Merge PostgreSQL-specific capability flags into the base set.

    Transactions and indexing are only claimed for a real server;
    the file-based store is schema-flexible instead.
    """
    file_mode = self.file_based_mode
    caps = super().get_capabilities()
    caps.update({
        "supports_sql": True,
        "supports_transactions": not file_mode,
        "supports_indexing": not file_mode,
        "schema_flexible": file_mode,
        "file_based_mode": file_mode,
    })
    return caps
# Tool classes
class PostgreSQLExecuteTool(Tool):
    """Tool wrapper that forwards raw SQL to a PostgreSQLDatabase."""
    name: str = "postgresql_execute"
    description: str = "Execute arbitrary SQL queries on PostgreSQL."
    inputs: Dict[str, Dict[str, str]] = {
        "query": {"type": "string", "description": "SQL query to execute (can be SELECT, INSERT, UPDATE, DELETE, etc.)"},
        "query_type": {"type": "string", "description": "Type of query (select, insert, update, delete, create, drop, alter, index) - auto-detected if not provided"}
    }
    required: Optional[List[str]] = ["query"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, query: str, query_type: str = None) -> Dict[str, Any]:
        """Run `query`; `query_type` (if given) must be a valid QueryType value."""
        if not self.database:
            return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
        try:
            # The SQL string is passed straight through; psycopg2 (or the
            # file-based interpreter) does the actual handling.
            parsed_type = None
            if query_type:
                try:
                    parsed_type = QueryType(query_type.lower())
                except ValueError:
                    return {"success": False, "error": f"Invalid query type: {query_type}", "data": None}
            return self.database.execute_query(query=query, query_type=parsed_type)
        except Exception as e:
            logger.error(f"Error in postgresql_execute tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
class PostgreSQLFindTool(Tool):
    """SELECT rows from a table, assembling the statement from simple clause inputs."""

    name: str = "postgresql_find"
    description: str = "Find (SELECT) rows from a PostgreSQL table."
    inputs: Dict[str, Dict[str, str]] = {
        "table_name": {"type": "string", "description": "Table name to query"},
        "where": {"type": "string", "description": "WHERE clause (optional, e.g., 'age > 18')"},
        "columns": {"type": "string", "description": "Comma-separated columns to select (default '*')"},
        "limit": {"type": "integer", "description": "Maximum number of rows to return (optional)"},
        "offset": {"type": "integer", "description": "Number of rows to skip (optional)"},
        "sort": {"type": "string", "description": "ORDER BY clause (optional, e.g., 'age ASC')"}
    }
    required: Optional[List[str]] = ["table_name"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, table_name: str, where: str = None, columns: str = "*", limit: int = None, offset: int = None, sort: str = None) -> Dict[str, Any]:
        """Build and execute a SELECT from the given clause fragments."""
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            # NOTE(security): clause fragments are interpolated verbatim into
            # the SQL string — callers must not feed untrusted input here.
            parts = [f"SELECT {columns} FROM {table_name}"]
            if where:
                parts.append(f"WHERE {where}")
            if sort:
                parts.append(f"ORDER BY {sort}")
            if limit is not None:
                parts.append(f"LIMIT {limit}")
            if offset is not None:
                parts.append(f"OFFSET {offset}")
            return self.database.execute_query(" ".join(parts), QueryType.SELECT)
        except Exception as e:
            logger.error(f"Error in postgresql_find tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
class PostgreSQLUpdateTool(Tool):
    """UPDATE rows in a table from SET/WHERE clause fragments."""

    name: str = "postgresql_update"
    description: str = "Update rows in a PostgreSQL table."
    inputs: Dict[str, Dict[str, str]] = {
        "table_name": {"type": "string", "description": "Table name to update"},
        "set": {"type": "string", "description": "SET clause (e.g., 'status = \'active\'')"},
        "where": {"type": "string", "description": "WHERE clause (optional)"}
    }
    required: Optional[List[str]] = ["table_name", "set"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    # The parameter is named `set` (shadowing the builtin) because that is
    # the key declared in `inputs` and callers pass it by keyword.
    def __call__(self, table_name: str, set: str, where: str = None) -> Dict[str, Any]:
        """Build and execute an UPDATE; clauses are interpolated verbatim."""
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            statement = f"UPDATE {table_name} SET {set}"
            if where:
                statement = f"{statement} WHERE {where}"
            return self.database.execute_query(statement, QueryType.UPDATE)
        except Exception as e:
            logger.error(f"Error in postgresql_update tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
class PostgreSQLCreateTool(Tool):
    """Execute a CREATE statement (table, index, or other object)."""

    name: str = "postgresql_create"
    description: str = "Create a table or other object in PostgreSQL."
    inputs: Dict[str, Dict[str, str]] = {
        "query": {"type": "string", "description": "CREATE statement (e.g., CREATE TABLE ...)"}
    }
    required: Optional[List[str]] = ["query"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, query: str) -> Dict[str, Any]:
        """Pass *query* through to the database layer tagged as a CREATE."""
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            return self.database.execute_query(query, QueryType.CREATE)
        except Exception as e:
            logger.error(f"Error in postgresql_create tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
class PostgreSQLDeleteTool(Tool):
    """DELETE rows from a table, with an optional WHERE clause fragment."""

    name: str = "postgresql_delete"
    description: str = "Delete rows from a PostgreSQL table."
    inputs: Dict[str, Dict[str, str]] = {
        "table_name": {"type": "string", "description": "Table name to delete from"},
        "where": {"type": "string", "description": "WHERE clause (optional)"}
    }
    required: Optional[List[str]] = ["table_name"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, table_name: str, where: str = None) -> Dict[str, Any]:
        """Build and execute a DELETE; without *where* every row is removed."""
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            statement = f"DELETE FROM {table_name}"
            if where:
                statement = f"{statement} WHERE {where}"
            return self.database.execute_query(statement, QueryType.DELETE)
        except Exception as e:
            logger.error(f"Error in postgresql_delete tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
class PostgreSQLInfoTool(Tool):
    """Report database-, table-, schema-, or capability-level metadata."""

    name: str = "postgresql_info"
    description: str = "Get PostgreSQL database and table information."
    inputs: Dict[str, Dict[str, str]] = {
        "info_type": {"type": "string", "description": "Type of information (database, tables, table, schema, capabilities)"},
        "table_name": {"type": "string", "description": "Table name for table-specific info (optional)"}
    }
    required: Optional[List[str]] = []

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, info_type: str = "database", table_name: str = None) -> Dict[str, Any]:
        """Dispatch on *info_type* (case-insensitive) and return a result dict.

        Supported values: "database", "tables", "table" (requires
        table_name), "schema", "capabilities". Any other value — or an
        exception from the database layer — yields a
        {"success": False, ...} dict.
        """
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            info_type = info_type.lower()
            if info_type == "database":
                return self.database.get_database_info()
            if info_type == "tables":
                tables = self.database.list_collections()
                return {"success": True, "data": tables, "table_count": len(tables)}
            if info_type == "table":
                # Fix: a missing table_name previously fell through to the
                # generic "Invalid info type: table" error, which was
                # misleading — report the real problem instead.
                if not table_name:
                    return {"success": False, "error": "table_name is required when info_type is 'table'", "data": None}
                return self.database.get_collection_info(table_name)
            if info_type == "schema":
                return self.database.get_schema(table_name)
            if info_type == "capabilities":
                return {"success": True, "data": self.database.get_capabilities()}
            return {"success": False, "error": f"Invalid info type: {info_type}", "data": None}
        except Exception as e:
            logger.error(f"Error in postgresql_info tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
class PostgreSQLToolkit(Toolkit):
    """Toolkit bundling the PostgreSQL tools around one shared database handle."""

    def __init__(self,
                 name: str = "PostgreSQLToolkit",
                 connection_string: str = None,
                 database_name: str = None,
                 local_path: str = None,
                 auto_save: bool = True,
                 **kwargs):
        # One database instance backs every tool in the kit.
        database = PostgreSQLDatabase(
            connection_string=connection_string,
            database_name=database_name,
            local_path=local_path,
            auto_save=auto_save,
            **kwargs
        )
        tool_classes = (
            PostgreSQLExecuteTool,
            PostgreSQLFindTool,
            PostgreSQLUpdateTool,
            PostgreSQLCreateTool,
            PostgreSQLDeleteTool,
            PostgreSQLInfoTool,
        )
        super().__init__(name=name, tools=[cls(database=database) for cls in tool_classes])
        self.database = database
        self.connection_string = connection_string
        self.database_name = database_name
        self.local_path = local_path
        self.auto_save = auto_save
        # Make sure the connection is released when the interpreter exits.
        import atexit
        atexit.register(self._cleanup)

    def _cleanup(self):
        """Best-effort disconnect at interpreter shutdown; never raises."""
        try:
            if self.database:
                self.database.disconnect()
                logger.info("Disconnected from PostgreSQL database")
        except Exception as e:
            logger.warning(f"Error during cleanup: {str(e)}")

    def get_capabilities(self) -> Dict[str, Any]:
        """Capabilities of the underlying database plus toolkit-local settings."""
        if not self.database:
            return {"error": "PostgreSQL database not initialized"}
        capabilities = self.database.get_capabilities()
        capabilities["is_local_database"] = self.database.is_local_database
        capabilities["local_path"] = str(self.database.local_path) if self.database.local_path else None
        capabilities["auto_save"] = self.database.auto_save
        return capabilities

    def connect(self) -> bool:
        """Open the shared connection; False when no database is configured."""
        if not self.database:
            return False
        return self.database.connect()

    def disconnect(self) -> bool:
        """Close the shared connection; False when no database is configured."""
        if not self.database:
            return False
        return self.database.disconnect()

    def test_connection(self) -> bool:
        """Probe the shared connection; False when no database is configured."""
        if not self.database:
            return False
        return self.database.test_connection()

    def get_database(self) -> PostgreSQLDatabase:
        """Expose the shared PostgreSQLDatabase instance."""
        return self.database

    def get_local_info(self) -> Dict[str, Any]:
        """Summarize local-mode settings and connection parameters."""
        if not self.database:
            return {"error": "Database not initialized"}
        return {
            "is_local_database": self.database.is_local_database,
            "local_path": str(self.database.local_path) if self.database.local_path else None,
            "auto_save": self.database.auto_save,
            "database_name": self.database_name,
            "connection_string": self.connection_string
        }