|
|
import time |
|
|
import json |
|
|
from typing import Dict, Any, List, Union, Optional |
|
|
from pathlib import Path |
|
|
import psycopg2 |
|
|
import psycopg2.extras |
|
|
import re |
|
|
|
|
|
from .database_base import DatabaseBase, DatabaseType, QueryType, DatabaseConnection |
|
|
from .tool import Tool, Toolkit |
|
|
from ..core.logging import logger |
|
|
|
|
|
class PostgreSQLConnection(DatabaseConnection):
    """Manage a single psycopg2 connection to PostgreSQL.

    The connection string and any extra keyword arguments are handed to
    ``DatabaseConnection``; ``connect`` later passes ``self.connection_string``
    and ``self.connection_params`` to ``psycopg2.connect``.
    """

    def __init__(self, connection_string: str, **kwargs):
        super().__init__(connection_string, **kwargs)
        # Live psycopg2 connection once connect() succeeds; None otherwise.
        self.conn = None

    def connect(self) -> bool:
        """Open the connection; return True on success, False (logged) on any error."""
        try:
            self.conn = psycopg2.connect(self.connection_string, **self.connection_params)
            self._is_connected = True
            logger.info("Successfully connected to PostgreSQL")
            return True
        except Exception as exc:
            logger.error(f"Failed to connect to PostgreSQL: {str(exc)}")
            self._is_connected = False
            return False

    def disconnect(self) -> bool:
        """Close the connection if one is open.

        Returns True when there is nothing to close or closing succeeded;
        errors are logged and reported as False.
        """
        try:
            if not self.conn:
                return True
            self.conn.close()
            self.conn = None
            self._is_connected = False
            logger.info("Disconnected from PostgreSQL")
            return True
        except Exception as exc:
            logger.error(f"Error disconnecting from PostgreSQL: {str(exc)}")
            return False

    def test_connection(self) -> bool:
        """Return True if ``SELECT 1`` succeeds on the open connection."""
        if not self.conn:
            return False
        try:
            with self.conn.cursor() as probe:
                probe.execute("SELECT 1;")
        except Exception:
            return False
        return True
|
|
|
|
|
class PostgreSQLDatabase(DatabaseBase): |
|
|
""" |
|
|
PostgreSQL database implementation with automatic initialization. |
|
|
Handles remote connections, existing local databases, and new local database creation. |
|
|
""" |
|
|
def __init__(self,
             connection_string: str = None,
             database_name: str = None,
             local_path: str = None,
             auto_save: bool = True,
             **kwargs):
    """Create the database wrapper and run the matching initializer.

    Args:
        connection_string: PostgreSQL connection string. When
            ``_is_remote_connection`` accepts it, a server connection is made.
        database_name: Logical database name (file-based mode defaults it to
            the local directory name).
        local_path: Directory holding the JSON-backed tables.
        auto_save: In file-based mode, save a table to disk after each write.
        **kwargs: Forwarded to ``DatabaseBase`` and kept as
            ``connection_params`` for ``psycopg2.connect``.
    """
    super().__init__(connection_string=connection_string,
                     database_name=database_name,
                     **kwargs)

    self.local_path = Path(local_path) if local_path else None
    self.auto_save = auto_save
    self.connection_params = kwargs

    # Runtime state; the _init_* helpers overwrite these as appropriate.
    self.is_local_database = False
    self.conn = None
    self.cursor = None
    self.file_based_mode = False
    self.tables = {}

    # Remote server first, then an existing local database, else a new one.
    if self._is_remote_connection():
        initializer = self._init_remote_database
    elif self._is_existing_local_database():
        initializer = self._init_existing_local_database
    else:
        initializer = self._init_new_local_database
    initializer()
|
|
|
|
|
def _is_remote_connection(self) -> bool: |
|
|
return self.connection_string and ("@" in self.connection_string or "postgresql://" in self.connection_string) |
|
|
|
|
|
def _is_existing_local_database(self) -> bool: |
|
|
if not self.local_path: |
|
|
return False |
|
|
if not self.local_path.exists(): |
|
|
return False |
|
|
db_info_file = self.local_path / "db_info.json" |
|
|
return db_info_file.exists() |
|
|
|
|
|
def _init_remote_database(self):
    """Connect to the remote PostgreSQL server named by ``connection_string``.

    ``connect_timeout`` (5 s) and ``options='-c statement_timeout=5000'`` are
    applied only as defaults, so values passed to the constructor are no
    longer overwritten.

    When ``database_name`` is set:
      * the connection is switched to isolation level 0 (psycopg2's autocommit);
      * the name is looked up in ``pg_database``, and a missing database is
        logged as a warning (the lookup result used to be discarded).

    On any failure:
      * the partially opened connection is closed;
      * the instance falls back to file-based mode, reusing an existing local
        database when ``local_path`` holds one and creating a new one
        otherwise.

    Previously the "Falling back" message was logged but no fallback
    happened, which left the instance permanently uninitialized.
    """
    try:
        connection_params = dict(self.connection_params)
        connection_params.setdefault('connect_timeout', 5)
        connection_params.setdefault('options', '-c statement_timeout=5000')

        self.conn = psycopg2.connect(self.connection_string, **connection_params)
        self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

        if self.database_name:
            self.conn.set_isolation_level(0)  # 0 == ISOLATION_LEVEL_AUTOCOMMIT
            self.cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", (self.database_name,))
            if self.cursor.fetchone() is None:
                logger.warning(f"Database {self.database_name} not found on the PostgreSQL server")

        self._is_initialized = True
        self.is_local_database = False
        self.file_based_mode = False
        logger.info(f"Connected to remote PostgreSQL: {self.database_name}")
    except Exception as e:
        logger.error(f"Failed to connect to remote PostgreSQL: {str(e)}")
        self._is_initialized = False

        # Release whatever was opened before the failure.
        if self.conn is not None:
            try:
                self.conn.close()
            except Exception as close_error:
                logger.warning(f"Error closing failed PostgreSQL connection: {str(close_error)}")
        self.conn = None
        self.cursor = None

        logger.info("Falling back to local database mode")
        if self._is_existing_local_database():
            self._init_existing_local_database()
        else:
            self._init_new_local_database()
|
|
|
|
|
def _init_existing_local_database(self):
    """Open a file-based database whose directory already has ``db_info.json``.

    ``database_name`` defaults to the directory name, and every table JSON
    file is loaded into ``self.tables``. If anything raises, the error is
    logged and ``_init_new_local_database`` is used instead.
    """
    try:
        self.database_name = self.database_name or self.local_path.name
        self._load_tables_from_files()

        self._is_initialized = True
        self.is_local_database = self.file_based_mode = True
        logger.info(f"Loaded existing local file-based database from: {self.local_path}")
    except Exception as err:
        logger.error(f"Failed to load existing local database: {str(err)}")
        self._is_initialized = False
        logger.info("Falling back to new local database mode")
        self._init_new_local_database()
|
|
|
|
|
def _init_new_local_database(self):
    """Create (or reuse) a local directory and switch to file-based mode.

    ``local_path`` defaults to ``./workplace/postgresql_local`` and
    ``database_name`` to that directory's name. ``db_info.json`` is then
    (re)written. Failures are logged and leave ``_is_initialized`` False.
    """
    try:
        self.local_path = self.local_path or Path("./workplace/postgresql_local")
        self.local_path.mkdir(parents=True, exist_ok=True)
        self.database_name = self.database_name or self.local_path.name

        self._create_db_info_file()

        self._is_initialized = True
        self.is_local_database = self.file_based_mode = True
        logger.info(f"Created new local file-based database at: {self.local_path}")
    except Exception as err:
        logger.error(f"Failed to create new local database: {str(err)}")
        self._is_initialized = False
        logger.info("Database initialization failed, but toolkit is still usable")
|
|
|
|
|
def _create_db_info_file(self): |
|
|
"""Create database info file""" |
|
|
try: |
|
|
db_info = { |
|
|
"database_name": self.database_name, |
|
|
"created_at": time.time(), |
|
|
"local_path": str(self.local_path.absolute()), |
|
|
"auto_save": self.auto_save, |
|
|
"version": "1.0", |
|
|
"mode": "file_based" |
|
|
} |
|
|
info_file = self.local_path / "db_info.json" |
|
|
with open(info_file, 'w', encoding='utf-8') as f: |
|
|
json.dump(db_info, f, indent=2, ensure_ascii=False) |
|
|
except Exception as e: |
|
|
logger.warning(f"Failed to create db info file: {str(e)}") |
|
|
|
|
|
def _load_tables_from_files(self): |
|
|
"""Load tables from JSON files""" |
|
|
try: |
|
|
for json_file in self.local_path.glob("*.json"): |
|
|
if json_file.name == "db_info.json": |
|
|
continue |
|
|
table_name = json_file.stem |
|
|
with open(json_file, 'r', encoding='utf-8') as f: |
|
|
loaded_data = json.load(f) |
|
|
|
|
|
if not isinstance(loaded_data, list): |
|
|
logger.warning(f"Table {table_name} file contains non-list data: {type(loaded_data)}, converting to empty list") |
|
|
self.tables[table_name] = [] |
|
|
else: |
|
|
self.tables[table_name] = loaded_data |
|
|
except Exception as e: |
|
|
logger.warning(f"Error loading tables from files: {str(e)}") |
|
|
|
|
|
def _save_table_to_file(self, table_name: str): |
|
|
"""Save table data to JSON file""" |
|
|
try: |
|
|
if table_name in self.tables: |
|
|
table_file = self.local_path / f"{table_name}.json" |
|
|
with open(table_file, 'w', encoding='utf-8') as f: |
|
|
json.dump(self.tables[table_name], f, indent=2, ensure_ascii=False) |
|
|
except Exception as e: |
|
|
logger.error(f"Error saving table {table_name}: {str(e)}") |
|
|
|
|
|
|
|
|
|
|
|
def _parse_sql_query(self, sql: str) -> Dict[str, Any]: |
|
|
"""Enhanced SQL parser for file-based mode - now supports JOINs and complex queries""" |
|
|
sql = sql.strip() |
|
|
upper_sql = sql.upper() |
|
|
|
|
|
|
|
|
if upper_sql.startswith("CREATE TABLE"): |
|
|
match = re.search(r"CREATE TABLE (?:IF NOT EXISTS )?(\w+) *\((.*?)\)", sql, re.IGNORECASE | re.DOTALL) |
|
|
if match: |
|
|
table = match.group(1).lower() |
|
|
columns = match.group(2) |
|
|
col_defs = [c.strip() for c in columns.split(',') if c.strip()] |
|
|
col_names = [c.split()[0] for c in col_defs] |
|
|
return {"type": "CREATE", "table": table, "columns": col_names} |
|
|
|
|
|
|
|
|
elif upper_sql.startswith("INSERT"): |
|
|
match = re.search(r"INSERT INTO (\w+) *\((.*?)\) *VALUES", sql, re.IGNORECASE | re.DOTALL) |
|
|
if match: |
|
|
table = match.group(1).lower() |
|
|
columns = [c.strip() for c in match.group(2).split(',')] |
|
|
values_match = re.search(r"VALUES\s*(.*)", sql, re.IGNORECASE | re.DOTALL) |
|
|
if values_match: |
|
|
values_str = values_match.group(1) |
|
|
value_groups = re.findall(r'\(([^)]+)\)', values_str) |
|
|
all_values = [] |
|
|
for group in value_groups: |
|
|
values = [v.strip().strip("'\"") for v in group.split(',')] |
|
|
all_values.append(values) |
|
|
return {"type": "INSERT", "table": table, "columns": columns, "values": all_values} |
|
|
|
|
|
|
|
|
elif upper_sql.startswith("SELECT"): |
|
|
|
|
|
if "JOIN" in upper_sql: |
|
|
|
|
|
match = re.search(r"SELECT (.*?) FROM (\w+)(?:\s+(\w+))?\s+(?:(\w+)\s+)?JOIN\s+(\w+)(?:\s+(\w+))?\s+ON\s+(.*?)(?: WHERE (.*?))?(?: ORDER BY (.*?))?(?: LIMIT (\d+))?", sql, re.IGNORECASE | re.DOTALL) |
|
|
if match: |
|
|
columns = [c.strip() for c in match.group(1).split(',')] |
|
|
table1 = match.group(2).lower() |
|
|
alias1 = match.group(3) |
|
|
join_type = match.group(4) or "INNER" |
|
|
table2 = match.group(5).lower() |
|
|
alias2 = match.group(6) |
|
|
join_condition = match.group(7) |
|
|
where = match.group(8) |
|
|
order_by = match.group(9) |
|
|
limit = match.group(10) |
|
|
|
|
|
return { |
|
|
"type": "SELECT_JOIN", |
|
|
"columns": columns, |
|
|
"table1": table1, |
|
|
"alias1": alias1, |
|
|
"join_type": join_type, |
|
|
"table2": table2, |
|
|
"alias2": alias2, |
|
|
"join_condition": join_condition, |
|
|
"where": where, |
|
|
"order_by": order_by, |
|
|
"limit": limit |
|
|
} |
|
|
|
|
|
|
|
|
elif "CROSS JOIN" in upper_sql: |
|
|
match = re.search(r"SELECT (.*?) FROM (\w+)(?:\s+(\w+))?\s+CROSS\s+JOIN\s+(\w+)(?:\s+(\w+))?(?: WHERE (.*?))?(?: ORDER BY (.*?))?(?: LIMIT (\d+))?", sql, re.IGNORECASE | re.DOTALL) |
|
|
if match: |
|
|
columns = [c.strip() for c in match.group(1).split(',')] |
|
|
table1 = match.group(2).lower() |
|
|
alias1 = match.group(3) |
|
|
table2 = match.group(4).lower() |
|
|
alias2 = match.group(5) |
|
|
where = match.group(6) |
|
|
order_by = match.group(7) |
|
|
limit = match.group(8) |
|
|
|
|
|
return { |
|
|
"type": "SELECT_CROSS_JOIN", |
|
|
"columns": columns, |
|
|
"table1": table1, |
|
|
"alias1": alias1, |
|
|
"table2": table2, |
|
|
"alias2": alias2, |
|
|
"where": where, |
|
|
"order_by": order_by, |
|
|
"limit": limit |
|
|
} |
|
|
|
|
|
|
|
|
else: |
|
|
match = re.search(r"SELECT (.*?) FROM (\w+)(?: WHERE (.*?))?(?: GROUP BY (.*?))?(?: ORDER BY (.*?))?(?: LIMIT (\d+))?", sql, re.IGNORECASE | re.DOTALL) |
|
|
if match: |
|
|
columns = [c.strip() for c in match.group(1).split(',')] |
|
|
table = match.group(2).lower() |
|
|
where = match.group(3) |
|
|
group_by = match.group(4) |
|
|
order_by = match.group(5) |
|
|
limit = match.group(6) |
|
|
return {"type": "SELECT", "table": table, "columns": columns, "where": where, "group_by": group_by, "order_by": order_by, "limit": limit} |
|
|
|
|
|
|
|
|
elif upper_sql.startswith("UPDATE"): |
|
|
match = re.search(r"UPDATE (\w+) SET (.*?)(?: WHERE (.*?))?$", sql, re.IGNORECASE | re.DOTALL) |
|
|
if match: |
|
|
table = match.group(1).lower() |
|
|
set_clause = match.group(2) |
|
|
where = match.group(3) |
|
|
return {"type": "UPDATE", "table": table, "set": set_clause, "where": where} |
|
|
|
|
|
|
|
|
elif upper_sql.startswith("DELETE"): |
|
|
match = re.search(r"DELETE FROM (\w+)(?: WHERE (.*?))?", sql, re.IGNORECASE | re.DOTALL) |
|
|
if match: |
|
|
table = match.group(1).lower() |
|
|
where = match.group(2) |
|
|
return {"type": "DELETE", "table": table, "where": where} |
|
|
|
|
|
return {"type": "UNKNOWN"} |
|
|
|
|
|
def _apply_where_filter(self, rows: List[Dict], where: str) -> List[Dict]: |
|
|
"""Apply WHERE filter to rows""" |
|
|
if not where: |
|
|
return rows |
|
|
|
|
|
|
|
|
if not isinstance(rows, list): |
|
|
logger.warning(f"_apply_where_filter: rows is not a list: {type(rows)}") |
|
|
return [] |
|
|
|
|
|
|
|
|
valid_rows = [r for r in rows if isinstance(r, dict)] |
|
|
if len(valid_rows) != len(rows): |
|
|
logger.warning(f"_apply_where_filter: filtered out {len(rows) - len(valid_rows)} non-dict rows") |
|
|
|
|
|
|
|
|
m = re.match(r"(\w+) *([=><]+) *'?([\w@.\- ]+)'?", where) |
|
|
if m: |
|
|
col, op, val = m.group(1), m.group(2), m.group(3) |
|
|
if op == "=": |
|
|
return [r for r in valid_rows if str(r.get(col, "")) == val] |
|
|
elif op == ">": |
|
|
try: |
|
|
val_num = int(val) |
|
|
return [r for r in valid_rows if int(r.get(col, 0)) > val_num] |
|
|
except ValueError: |
|
|
pass |
|
|
elif op == "<": |
|
|
try: |
|
|
val_num = int(val) |
|
|
return [r for r in valid_rows if int(r.get(col, 0)) < val_num] |
|
|
except ValueError: |
|
|
pass |
|
|
return valid_rows |
|
|
|
|
|
def _apply_column_selection(self, rows: List[Dict], columns: List[str]) -> List[Dict]: |
|
|
"""Apply column selection to rows""" |
|
|
if columns == ['*']: |
|
|
return rows |
|
|
|
|
|
|
|
|
if not isinstance(rows, list): |
|
|
logger.warning(f"_apply_column_selection: rows is not a list: {type(rows)}") |
|
|
return [] |
|
|
|
|
|
|
|
|
valid_rows = [r for r in rows if isinstance(r, dict)] |
|
|
if len(valid_rows) != len(rows): |
|
|
logger.warning(f"_apply_column_selection: filtered out {len(rows) - len(valid_rows)} non-dict rows") |
|
|
|
|
|
filtered_rows = [] |
|
|
for row in valid_rows: |
|
|
filtered_row = {} |
|
|
for col in columns: |
|
|
if col in row: |
|
|
filtered_row[col] = row[col] |
|
|
filtered_rows.append(filtered_row) |
|
|
return filtered_rows |
|
|
|
|
|
def _apply_group_by(self, rows: List[Dict], group_by: str) -> List[Dict]: |
|
|
"""Apply GROUP BY aggregation to rows""" |
|
|
if not group_by: |
|
|
return rows |
|
|
|
|
|
|
|
|
if not isinstance(rows, list): |
|
|
logger.warning(f"_apply_group_by: rows is not a list: {type(rows)}") |
|
|
return [] |
|
|
|
|
|
|
|
|
valid_rows = [r for r in rows if isinstance(r, dict)] |
|
|
if len(valid_rows) != len(rows): |
|
|
logger.warning(f"_apply_group_by: filtered out {len(rows) - len(valid_rows)} non-dict rows") |
|
|
|
|
|
group_col = group_by.strip() |
|
|
groups = {} |
|
|
for row in valid_rows: |
|
|
group_val = row.get(group_col, "Unknown") |
|
|
if group_val not in groups: |
|
|
groups[group_val] = [] |
|
|
groups[group_val].append(row) |
|
|
|
|
|
result = [] |
|
|
for group_val, group_rows in groups.items(): |
|
|
group_result = {group_col: group_val} |
|
|
|
|
|
group_result["employee_count"] = len(group_rows) |
|
|
salaries = [float(r.get("salary", 0)) for r in group_rows if r.get("salary") is not None] |
|
|
group_result["avg_salary"] = sum(salaries) / len(salaries) if salaries else 0 |
|
|
group_result["max_salary"] = max(salaries) if salaries else 0 |
|
|
result.append(group_result) |
|
|
|
|
|
return result |
|
|
|
|
|
def _execute_join_query(self, parsed: Dict) -> Dict[str, Any]: |
|
|
"""Execute JOIN query in file-based mode""" |
|
|
try: |
|
|
table1 = parsed["table1"] |
|
|
table2 = parsed["table2"] |
|
|
columns = parsed["columns"] |
|
|
join_condition = parsed["join_condition"] |
|
|
where = parsed.get("where") |
|
|
|
|
|
|
|
|
rows1 = self.tables.get(table1, []) |
|
|
rows2 = self.tables.get(table2, []) |
|
|
|
|
|
|
|
|
if not isinstance(rows1, list): |
|
|
logger.warning(f"Table {table1} contains non-list data: {type(rows1)}") |
|
|
rows1 = [] |
|
|
if not isinstance(rows2, list): |
|
|
logger.warning(f"Table {table2} contains non-list data: {type(rows2)}") |
|
|
rows2 = [] |
|
|
|
|
|
|
|
|
join_match = re.match(r"(\w+)\.(\w+)\s*=\s*(\w+)\.(\w+)", join_condition) |
|
|
if not join_match: |
|
|
return {"error": "Invalid join condition format"} |
|
|
|
|
|
col1, col2 = join_match.group(2), join_match.group(4) |
|
|
|
|
|
|
|
|
result_rows = [] |
|
|
for row1 in rows1: |
|
|
|
|
|
if not isinstance(row1, dict): |
|
|
logger.warning(f"Skipping non-dict row1 in JOIN: {type(row1)}") |
|
|
continue |
|
|
for row2 in rows2: |
|
|
|
|
|
if not isinstance(row2, dict): |
|
|
logger.warning(f"Skipping non-dict row2 in JOIN: {type(row2)}") |
|
|
continue |
|
|
if str(row1.get(col1, "")) == str(row2.get(col2, "")): |
|
|
|
|
|
combined_row = {} |
|
|
for col in columns: |
|
|
if '.' in col: |
|
|
|
|
|
table_alias, col_name = col.split('.', 1) |
|
|
if table_alias == parsed.get("alias1") or table_alias == table1: |
|
|
combined_row[col] = row1.get(col_name, "") |
|
|
elif table_alias == parsed.get("alias2") or table_alias == table2: |
|
|
combined_row[col] = row2.get(col_name, "") |
|
|
else: |
|
|
|
|
|
if col in row1: |
|
|
combined_row[col] = row1[col] |
|
|
elif col in row2: |
|
|
combined_row[col] = row2[col] |
|
|
result_rows.append(combined_row) |
|
|
|
|
|
|
|
|
if where: |
|
|
result_rows = self._apply_where_filter(result_rows, where) |
|
|
|
|
|
return result_rows |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error executing JOIN query: {str(e)}") |
|
|
return {"error": str(e)} |
|
|
|
|
|
def _execute_cross_join_query(self, parsed: Dict) -> Dict[str, Any]: |
|
|
"""Execute CROSS JOIN query in file-based mode""" |
|
|
try: |
|
|
table1 = parsed["table1"] |
|
|
table2 = parsed["table2"] |
|
|
columns = parsed["columns"] |
|
|
where = parsed.get("where") |
|
|
|
|
|
|
|
|
rows1 = self.tables.get(table1, []) |
|
|
rows2 = self.tables.get(table2, []) |
|
|
|
|
|
|
|
|
if not isinstance(rows1, list): |
|
|
logger.warning(f"Table {table1} contains non-list data: {type(rows1)}") |
|
|
rows1 = [] |
|
|
if not isinstance(rows2, list): |
|
|
logger.warning(f"Table {table2} contains non-list data: {type(rows2)}") |
|
|
rows2 = [] |
|
|
|
|
|
|
|
|
result_rows = [] |
|
|
for row1 in rows1: |
|
|
|
|
|
if not isinstance(row1, dict): |
|
|
logger.warning(f"Skipping non-dict row1 in CROSS JOIN: {type(row1)}") |
|
|
continue |
|
|
for row2 in rows2: |
|
|
|
|
|
if not isinstance(row2, dict): |
|
|
logger.warning(f"Skipping non-dict row2 in CROSS JOIN: {type(row2)}") |
|
|
continue |
|
|
|
|
|
combined_row = {} |
|
|
for col in columns: |
|
|
if '.' in col: |
|
|
|
|
|
table_alias, col_name = col.split('.', 1) |
|
|
if table_alias == parsed.get("alias1") or table_alias == table1: |
|
|
combined_row[col] = row1.get(col_name, "") |
|
|
elif table_alias == parsed.get("alias2") or table_alias == table2: |
|
|
combined_row[col] = row2.get(col_name, "") |
|
|
else: |
|
|
|
|
|
if col in row1: |
|
|
combined_row[col] = row1[col] |
|
|
elif col in row2: |
|
|
combined_row[col] = row2[col] |
|
|
result_rows.append(combined_row) |
|
|
|
|
|
|
|
|
if where: |
|
|
result_rows = self._apply_where_filter(result_rows, where) |
|
|
|
|
|
return result_rows |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error executing CROSS JOIN query: {str(e)}") |
|
|
return {"error": str(e)} |
|
|
|
|
|
def _get_database_type(self) -> DatabaseType:
    """Report the backend kind handled by this class (always PostgreSQL)."""
    backend = DatabaseType.POSTGRESQL
    return backend
|
|
|
|
|
def connect(self) -> bool:
    """Report whether initialization succeeded.

    No connection is opened here. The constructor's ``_init_*`` helpers set
    ``_is_initialized``, and this method simply returns that flag.
    """
    is_ready = self._is_initialized
    return is_ready
|
|
|
|
|
def disconnect(self) -> bool:
    """Close the server connection, if any, and mark the database uninitialized.

    Returns True on success. Any exception is logged and False is returned.
    NOTE(review): SOURCE's indentation was lost. This assumes only the close
    and the reference resets depend on an open connection -- confirm.
    """
    try:
        connection = self.conn
        if connection:
            connection.close()
            self.conn = None
            self.cursor = None
        self._is_initialized = False
        logger.info("Disconnected from PostgreSQL")
        return True
    except Exception as err:
        logger.error(f"Error disconnecting: {str(err)}")
        return False
|
|
|
|
|
def test_connection(self) -> bool: |
|
|
if self.file_based_mode: |
|
|
return self._is_initialized |
|
|
try: |
|
|
if self.conn: |
|
|
with self.conn.cursor() as cur: |
|
|
cur.execute("SELECT 1;") |
|
|
return True |
|
|
return False |
|
|
except Exception: |
|
|
return False |
|
|
|
|
|
def execute_query(self, query: Union[str, Dict, List], query_type: QueryType = None, **kwargs) -> Dict[str, Any]:
    """Run a query against the active backend.

    Args:
        query: One of:
            * an SQL string;
            * a ``{"sql": ..., "params": ...}`` dict;
            * a list of those, executed in order. Only the last statement's
              result is returned, and list items of any other type are
              skipped.
        query_type: Passed to the result formatter; in server mode it
            defaults to ``QueryType.SELECT``.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        ``format_query_result(...)`` on success, ``format_error_result(...)``
        otherwise. In server mode the transaction is committed on success and
        rolled back on error.
    """
    if not self._is_initialized:
        return self.format_error_result("Database not initialized")

    if self.file_based_mode:
        return self._execute_file_based_query(query, query_type)

    if self.conn is None:
        return self.format_error_result("PostgreSQL server not available")

    def run_one(cursor, statement) -> None:
        """Execute a string, or a dict carrying optional bind parameters."""
        if isinstance(statement, str):
            cursor.execute(statement)
            return
        sql_text = statement.get("sql")
        bind = statement.get("params", None)
        # Falsy params (None, (), {}) mean "execute without parameters".
        if bind:
            cursor.execute(sql_text, bind)
        else:
            cursor.execute(sql_text)

    started = time.time()
    try:
        with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            if isinstance(query, (str, dict)):
                run_one(cur, query)
            elif isinstance(query, list):
                for item in query:
                    if isinstance(item, (str, dict)):
                        run_one(cur, item)
            else:
                return self.format_error_result("Unsupported query format", query_type)

            # Statements without a result set report their affected row count.
            result = cur.fetchall() if cur.description else {"rowcount": cur.rowcount}
            self.conn.commit()

        elapsed = time.time() - started
        return self.format_query_result(result, query_type or QueryType.SELECT, execution_time=elapsed)

    except Exception as exc:
        elapsed = time.time() - started
        logger.error(f"Error executing PostgreSQL query: {str(exc)}")
        try:
            if self.conn:
                self.conn.rollback()
        except Exception as rollback_error:
            logger.warning(f"Error during rollback: {str(rollback_error)}")
        return self.format_error_result(str(exc), query_type, execution_time=elapsed)
|
|
|
|
|
def _execute_file_based_query(self, query: Union[str, Dict, List], query_type: QueryType = None) -> Dict[str, Any]:
    """Execute one SQL string against the in-memory, JSON-backed tables.

    Only plain SQL strings are supported; dict and list queries return an
    error. The statement is parsed by ``_parse_sql_query`` and dispatched on
    its ``type``: CREATE, INSERT, SELECT (optionally GROUP BY), SELECT_JOIN,
    SELECT_CROSS_JOIN, UPDATE or DELETE. Modified tables are saved when
    ``auto_save`` is on.

    Fixes compared with the previous version:
      * INSERT keeps a caller-supplied ``id`` (an empty value, ``NULL`` or
        ``DEFAULT`` means "assign one"). It otherwise assigns
        ``max(existing integer id) + 1``. It used to overwrite every ``id``
        with ``len(table) + 1``, which clobbered explicit ids and produced
        duplicates after a DELETE.
      * UPDATE and DELETE understand ``=``, ``!=``, ``<>``, ``<``, ``>``,
        ``<=`` and ``>=``; DELETE previously ignored ``<``.
      * UPDATE and DELETE no longer crash on non-numeric values.
      * A WHERE clause they cannot parse is rejected with an error. Before,
        UPDATE silently modified every row.
      * DELETE on a table that does not exist no longer creates it.

    Returns:
        ``format_query_result(...)`` on success, ``format_error_result(...)``
        otherwise. ``query_type`` defaults to ``QueryType.SELECT``.
    """
    import operator

    start_time = time.time()

    def to_number(value):
        """Return *value* as a float, or None when it is not numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def build_predicate(where):
        """Compile a single ``col <op> value`` WHERE clause into a row test.

        Returns a callable ``row -> bool`` (always True for an empty clause),
        or None when the clause is not understood, so callers can refuse to
        modify data instead of touching every row.
        """
        if not where:
            return lambda row: True
        m = re.match(
            r"\s*([\w.]+)\s*(>=|<=|<>|!=|=|>|<)\s*"
            r"(?:'((?:[^']|'')*)'|\"([^\"]*)\"|([\w@.\- ]+?))\s*;?\s*$",
            where,
            re.DOTALL,
        )
        if not m:
            return None
        col, op = m.group(1), m.group(2)
        if m.group(3) is not None:
            val = m.group(3).replace("''", "'")
        elif m.group(4) is not None:
            val = m.group(4)
        else:
            val = m.group(5)
        num_val = to_number(val)

        def lookup(row, default):
            if col in row:
                return row[col]
            if "." in col:
                return row.get(col.split(".", 1)[1], default)
            return default

        def equals(row):
            raw = lookup(row, "")
            if str(raw) == val:
                return True
            num_raw = to_number(raw)
            return num_val is not None and num_raw is not None and num_raw == num_val

        if op == "=":
            return equals
        if op in ("!=", "<>"):
            return lambda row: not equals(row)
        compare = {">": operator.gt, "<": operator.lt, ">=": operator.ge, "<=": operator.le}[op]

        def ordered(row):
            raw = lookup(row, None)
            if num_val is not None:
                num_raw = to_number(raw)
                return num_raw is not None and compare(num_raw, num_val)
            return raw is not None and compare(str(raw), val)

        return ordered

    def next_row_id(rows):
        """Return one more than the largest integer ``id`` in *rows* (1 if none)."""
        highest = 0
        for existing in rows:
            if isinstance(existing, dict):
                try:
                    highest = max(highest, int(existing.get("id")))
                except (TypeError, ValueError):
                    continue
        return highest + 1

    try:
        if not isinstance(query, str):
            return self.format_error_result("Unsupported query format in file-based mode", query_type)

        parsed = self._parse_sql_query(query)
        query_type = query_type or QueryType.SELECT

        # Guard against a parser result without the expected shape.
        if not isinstance(parsed, dict) or "type" not in parsed:
            logger.error(f"_execute_file_based_query: parsed is not a valid dict: {parsed}")
            return self.format_error_result(f"Failed to parse SQL query: {query}", query_type)

        logger.debug(f"Executing {parsed['type']} query: {parsed}")
        kind = parsed["type"]

        if kind == "CREATE":
            table_name = parsed["table"]
            columns = parsed.get("columns", ["id"])
            if table_name not in self.tables:
                self.tables[table_name] = []
            if not isinstance(self.tables[table_name], list):
                logger.warning(f"Reinitializing table {table_name} as list (was {type(self.tables[table_name])})")
                self.tables[table_name] = []
            # Column names live next to the data under a reserved key.
            self.tables[f"__schema__{table_name}"] = columns
            if self.auto_save:
                self._save_table_to_file(table_name)
            result = {"rowcount": 0}

        elif kind == "INSERT":
            table_name = parsed["table"]
            columns = parsed["columns"]
            if table_name not in self.tables:
                self.tables[table_name] = []
            if not isinstance(self.tables[table_name], list):
                logger.warning(f"Reinitializing table {table_name} as list (was {type(self.tables[table_name])})")
                self.tables[table_name] = []

            table = self.tables[table_name]
            new_id = next_row_id(table)
            valid_rows = 0
            for values in parsed["values"]:
                if len(values) != len(columns):
                    logger.warning(f"Skipping invalid row: {values} (expected {len(columns)} values, got {len(values)})")
                    continue
                if not isinstance(values, list):
                    logger.warning(f"Skipping non-list values: {type(values)}")
                    continue
                row = dict(zip(columns, values))
                supplied_id = row.get("id")
                if supplied_id is None or str(supplied_id).strip().upper() in ("", "NULL", "DEFAULT"):
                    row["id"] = new_id
                    new_id += 1
                else:
                    try:
                        row["id"] = int(supplied_id)
                        new_id = max(new_id, row["id"] + 1)
                    except (TypeError, ValueError):
                        pass  # non-integer ids are stored exactly as supplied
                table.append(row)
                valid_rows += 1

            if self.auto_save:
                self._save_table_to_file(table_name)
            result = {"rowcount": valid_rows}

        elif kind == "SELECT":
            table_name = parsed["table"]
            columns = parsed["columns"]
            where = parsed.get("where")
            group_by = parsed.get("group_by")
            rows = self.tables.get(table_name, [])
            if not isinstance(rows, list):
                logger.warning(f"Table {table_name} contains non-list data: {type(rows)}")
                rows = []

            logger.debug(f"SELECT query: table={table_name}, columns={columns}, where={where}, group_by={group_by}")
            logger.debug(f"Rows from table: {type(rows)}, length={len(rows) if isinstance(rows, list) else 'N/A'}")
            if isinstance(rows, list) and rows:
                logger.debug(f"First row type: {type(rows[0])}, content: {rows[0]}")

            if where:
                rows = self._apply_where_filter(rows, where)
            if group_by:
                result = self._apply_group_by(rows, group_by)
            else:
                result = {"data": self._apply_column_selection(rows, columns)}

        elif kind == "SELECT_JOIN":
            logger.debug(f"Executing JOIN query: {parsed}")
            join_result = self._execute_join_query(parsed)
            if isinstance(join_result, dict) and "error" in join_result:
                result = {"error": join_result["error"]}
            else:
                result = {"data": join_result}

        elif kind == "SELECT_CROSS_JOIN":
            logger.debug(f"Executing CROSS JOIN query: {parsed}")
            cross_join_result = self._execute_cross_join_query(parsed)
            if isinstance(cross_join_result, dict) and "error" in cross_join_result:
                result = {"error": cross_join_result["error"]}
            else:
                result = {"data": cross_join_result}

        elif kind == "UPDATE":
            table_name = parsed["table"]
            where = parsed.get("where")
            rows = self.tables.get(table_name, [])
            if not isinstance(rows, list):
                logger.warning(f"Table {table_name} contains non-list data: {type(rows)}")
                rows = []

            predicate = build_predicate(where)
            if predicate is None:
                return self.format_error_result(f"Unsupported WHERE clause in file-based mode: {where}", query_type)

            updates = dict(re.findall(r"(\w+) *= *'?([\w@.\- ]+)'?", parsed["set"]))
            count = 0
            for r in rows:
                if not isinstance(r, dict):
                    logger.warning(f"Skipping non-dict row in UPDATE: {type(r)}")
                    continue
                if predicate(r):
                    r.update(updates)
                    count += 1

            if self.auto_save:
                self._save_table_to_file(table_name)
            result = {"rowcount": count}

        elif kind == "DELETE":
            table_name = parsed["table"]
            where = parsed.get("where")
            rows = self.tables.get(table_name, [])
            if not isinstance(rows, list):
                logger.warning(f"Table {table_name} contains non-list data: {type(rows)}")
                rows = []

            predicate = build_predicate(where)
            if predicate is None:
                return self.format_error_result(f"Unsupported WHERE clause in file-based mode: {where}", query_type)

            # Keep the rows the predicate rejects; non-dict rows are dropped.
            kept = [r for r in rows if isinstance(r, dict) and not predicate(r)]
            deleted_count = len(rows) - len(kept)
            if table_name in self.tables:
                self.tables[table_name] = kept

            if self.auto_save:
                self._save_table_to_file(table_name)
            result = {"rowcount": deleted_count}

        else:
            return self.format_error_result("Unsupported query type in file-based mode", query_type)

        execution_time = time.time() - start_time
        return self.format_query_result(result, query_type, execution_time=execution_time)

    except Exception as e:
        execution_time = time.time() - start_time
        logger.error(f"Error executing file-based query: {str(e)}")
        logger.error(f"Query that caused error: {query}")
        logger.error(f"Query type: {query_type}")
        import traceback
        logger.error(f"Traceback: {traceback.format_exc()}")
        return self.format_error_result(str(e), query_type, execution_time=execution_time)
|
|
|
|
|
def get_database_info(self) -> Dict[str, Any]:
    """Describe the active database.

    File-based mode reports the database name and the number of in-memory
    tables. The internal ``__schema__<table>`` entries written by CREATE
    TABLE are not counted as tables.

    Server mode queries ``current_database()``, ``current_user`` and the
    number of tables in the ``public`` schema. The reported
    ``connection_string`` has its password masked, for both URI
    (``user:***@host``) and key/value (``password=***``) forms; it
    previously exposed credentials verbatim to callers.
    """
    def redact(conn_str):
        """Mask the password in a libpq URI or key/value connection string."""
        if not conn_str:
            return conn_str
        masked = re.sub(r"(://[^:/@\s]*:)[^@\s]*@", r"\1***@", conn_str)
        return re.sub(
            r"(\bpassword\s*=\s*)('(?:[^'\\]|\\.)*'|[^\s&]+)",
            r"\1***",
            masked,
            flags=re.IGNORECASE,
        )

    try:
        if not self._is_initialized:
            return self.format_error_result("Database not initialized")

        if self.file_based_mode:
            info = {
                "database": self.database_name,
                "user": "file_based",
                "table_count": sum(1 for name in self.tables if not str(name).startswith("__schema__")),
                "connection_string": "file_based",
                "is_connected": True,
                "mode": "file_based"
            }
        else:
            if self.conn is None:
                return self.format_error_result("PostgreSQL server not available")

            with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute("SELECT current_database() as database, current_user as user")
                db_info = cur.fetchone()
                cur.execute("SELECT COUNT(*) as table_count FROM information_schema.tables WHERE table_schema = 'public'")
                table_count = cur.fetchone()["table_count"]

            info = {
                "database": db_info["database"],
                "user": db_info["user"],
                "table_count": table_count,
                "connection_string": redact(self.connection_string),
                "is_connected": self._is_initialized
            }

        return self.format_query_result(info, QueryType.SELECT)
    except Exception as e:
        return self.format_error_result(str(e))
|
|
|
|
|
def list_collections(self) -> List[str]:
    """Return the table names of the active database.

    File-based mode lists the in-memory tables. It skips the internal
    ``__schema__<table>`` entries that CREATE TABLE stores alongside the
    data; these were previously reported as tables. Server mode lists the
    tables of the ``public`` schema. Errors are logged and yield ``[]``.
    """
    try:
        if self.file_based_mode:
            return [name for name in self.tables if not str(name).startswith("__schema__")]

        if not self._is_initialized or self.conn is None:
            return []

        # Default (tuple) cursor here, so positional row[0] access is valid.
        with self.conn.cursor() as cur:
            cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
            return [row[0] for row in cur.fetchall()]
    except Exception as e:
        logger.error(f"Error listing tables: {str(e)}")
        return []
|
|
|
|
|
def get_collection_info(self, collection_name: str) -> Dict[str, Any]:
    """Return the row count and column information for one table.

    File-based mode:
      * ``row_count`` is the number of stored rows.
      * ``columns`` is ``id`` followed by the column names recorded by CREATE
        TABLE (under ``__schema__<table>``), or just ``["id"]`` when no
        schema was recorded.

    Server mode:
      * ``collection_name`` must be a plain or schema-qualified SQL
        identifier, because it is interpolated into ``SELECT COUNT(*)``.
        Anything else is rejected. The name comes from tool input and was
        previously formatted into SQL unchecked, which allowed SQL injection.
      * Unquoted names keep PostgreSQL's usual case-folding.
      * Columns come from ``information_schema.columns``.
    """
    try:
        if not self._is_initialized:
            return self.format_error_result("Database not initialized")

        if self.file_based_mode:
            if collection_name not in self.tables:
                return self.format_error_result(f"Table {collection_name} not found")
            recorded = self.tables.get(f"__schema__{collection_name}")
            columns = list(recorded) if isinstance(recorded, list) else []
            if "id" not in columns:
                columns.insert(0, "id")  # INSERT always assigns an id
            info = {
                "table_name": collection_name,
                "row_count": len(self.tables[collection_name]),
                "columns": columns
            }
        else:
            if self.conn is None:
                return self.format_error_result("PostgreSQL server not available")

            if not isinstance(collection_name, str) or not re.fullmatch(
                r"[^\W\d][\w$]*(?:\.[^\W\d][\w$]*)?", collection_name
            ):
                return self.format_error_result(f"Invalid table name: {collection_name}")

            with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(f"SELECT COUNT(*) as row_count FROM {collection_name}")
                row_count = cur.fetchone()["row_count"]
                cur.execute("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s", (collection_name,))
                columns = cur.fetchall()

            info = {
                "table_name": collection_name,
                "row_count": row_count,
                "columns": columns
            }

        return self.format_query_result(info, QueryType.SELECT)
    except Exception as e:
        return self.format_error_result(str(e))
|
|
|
|
|
def get_schema(self, collection_name: str = None) -> Dict[str, Any]:
    """Return column-name -> data-type mappings for one table or for all tables.

    Args:
        collection_name: Table to describe. When falsy, every table is
            described: every entry of ``self.tables`` in file-based mode, or
            every table/view in the ``public`` schema otherwise.

    Returns:
        ``format_query_result`` output whose data is either
        ``{"table_name": ..., "schema": {...}}`` (single table) or
        ``{"database_name": ..., "schemas": {table: {...}}}`` (all tables),
        or ``format_error_result`` output on failure.
    """
    try:
        if not self._is_initialized:
            return self.format_error_result("Database not initialized")

        if self.file_based_mode:
            # In-memory tables carry no declared column types; each one is
            # reported with a single integer ``id`` column.
            if collection_name:
                if collection_name not in self.tables:
                    return self.format_error_result(f"Table {collection_name} not found")
                return self.format_query_result(
                    {"table_name": collection_name, "schema": {"id": "integer"}}, QueryType.SELECT
                )
            schemas = {table_name: {"id": "integer"} for table_name in self.tables}
            return self.format_query_result(
                {"database_name": self.database_name, "schemas": schemas}, QueryType.SELECT
            )

        if self.conn is None:
            return self.format_error_result("PostgreSQL server not available")

        with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
            if collection_name:
                cur.execute(
                    "SELECT column_name, data_type FROM information_schema.columns "
                    "WHERE table_name = %s ORDER BY ordinal_position",
                    (collection_name,),
                )
                schema = {col["column_name"]: col["data_type"] for col in cur.fetchall()}
                return self.format_query_result({"table_name": collection_name, "schema": schema}, QueryType.SELECT)

            cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
            # RealDictCursor rows are dicts keyed by column name and do not
            # support integer indexing, so row[0] would raise KeyError.
            schemas = {row["table_name"]: {} for row in cur.fetchall()}

            # Fetch the columns of every public table in one round trip, in
            # declaration order. Restricting to 'public' keeps same-named
            # tables in other schemas from leaking into the result.
            cur.execute(
                "SELECT table_name, column_name, data_type FROM information_schema.columns "
                "WHERE table_schema = 'public' ORDER BY table_name, ordinal_position"
            )
            for col in cur.fetchall():
                table_columns = schemas.get(col["table_name"])
                if table_columns is not None:
                    table_columns[col["column_name"]] = col["data_type"]

        return self.format_query_result({"database_name": self.database_name, "schemas": schemas}, QueryType.SELECT)
    except Exception as e:
        return self.format_error_result(str(e))
|
|
|
|
|
def get_supported_query_types(self) -> List[QueryType]:
    """List the query kinds this backend accepts.

    Row-level statements (SELECT, INSERT, UPDATE, DELETE) come first,
    followed by the schema-changing ones (CREATE, DROP, ALTER, INDEX).
    """
    supported = [QueryType.SELECT, QueryType.INSERT, QueryType.UPDATE, QueryType.DELETE]
    supported += [QueryType.CREATE, QueryType.DROP, QueryType.ALTER, QueryType.INDEX]
    return supported
|
|
|
|
|
def get_capabilities(self) -> Dict[str, Any]:
    """Return the base-class capabilities extended with PostgreSQL flags.

    SQL support is always advertised. Transactions and indexing are reported
    only outside file-based mode, while ``schema_flexible`` and
    ``file_based_mode`` mirror ``self.file_based_mode``.
    """
    capabilities = super().get_capabilities()
    file_mode = self.file_based_mode
    capabilities["supports_sql"] = True
    capabilities["supports_transactions"] = not file_mode
    capabilities["supports_indexing"] = not file_mode
    capabilities["schema_flexible"] = file_mode
    capabilities["file_based_mode"] = file_mode
    return capabilities
|
|
|
|
|
|
|
|
class PostgreSQLExecuteTool(Tool):
    """Tool that runs a caller-supplied SQL statement on the wrapped database."""

    name: str = "postgresql_execute"
    description: str = "Execute arbitrary SQL queries on PostgreSQL."
    inputs: Dict[str, Dict[str, str]] = {
        "query": {"type": "string", "description": "SQL query to execute (can be SELECT, INSERT, UPDATE, DELETE, etc.)"},
        "query_type": {"type": "string", "description": "Type of query (select, insert, update, delete, create, drop, alter, index) - auto-detected if not provided"},
    }
    required: Optional[List[str]] = ["query"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, query: str, query_type: str = None) -> Dict[str, Any]:
        """Execute ``query`` and return the database's result dict.

        Args:
            query: The SQL text, passed through unchanged.
            query_type: Optional ``QueryType`` value, matched after lowercasing.
                When omitted it is passed on as ``None``.

        Returns:
            The dict produced by ``execute_query``, or a
            ``{"success": False, "error": ..., "data": None}`` dict.
        """
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}

            resolved_type = None
            if query_type:
                try:
                    resolved_type = QueryType(query_type.lower())
                except ValueError:
                    return {"success": False, "error": f"Invalid query type: {query_type}", "data": None}

            return self.database.execute_query(query=query, query_type=resolved_type)
        except Exception as exc:
            logger.error(f"Error in postgresql_execute tool: {str(exc)}")
            return {"success": False, "error": str(exc), "data": None}
|
|
|
|
|
class PostgreSQLFindTool(Tool):
    """Tool that assembles and runs a SELECT query from separate clause fragments.

    ``table_name``, ``columns``, ``where`` and ``sort`` are inserted into the
    SQL text verbatim, so they must come from a caller that is trusted to run
    SQL. ``limit`` and ``offset`` are checked to be non-negative integers
    before they are written into the statement.
    """

    name: str = "postgresql_find"
    description: str = "Find (SELECT) rows from a PostgreSQL table."
    inputs: Dict[str, Dict[str, str]] = {
        "table_name": {"type": "string", "description": "Table name to query"},
        "where": {"type": "string", "description": "WHERE clause (optional, e.g., 'age > 18')"},
        "columns": {"type": "string", "description": "Comma-separated columns to select (default '*')"},
        "limit": {"type": "integer", "description": "Maximum number of rows to return (optional)"},
        "offset": {"type": "integer", "description": "Number of rows to skip (optional)"},
        "sort": {"type": "string", "description": "ORDER BY clause (optional, e.g., 'age ASC')"},
    }
    required: Optional[List[str]] = ["table_name"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    @staticmethod
    def _non_negative_int(value: Any, field: str) -> int:
        """Return ``value`` as an int >= 0.

        Accepts ints, integral floats and strings of decimal digits. Bools are
        rejected.

        Raises:
            ValueError: if ``value`` is not a non-negative whole number.
        """
        number = None
        if isinstance(value, bool):
            number = None
        elif isinstance(value, int):
            number = int(value)
        elif isinstance(value, float) and value.is_integer():
            number = int(value)
        elif isinstance(value, str) and value.strip().isdecimal():
            number = int(value.strip())
        if number is None or number < 0:
            raise ValueError(f"{field} must be a non-negative integer, got {value!r}")
        return number

    def __call__(self, table_name: str, where: str = None, columns: str = "*", limit: int = None, offset: int = None, sort: str = None) -> Dict[str, Any]:
        """Run ``SELECT columns FROM table_name [WHERE] [ORDER BY] [LIMIT] [OFFSET]``.

        Args:
            table_name: Table (or other FROM target) to read from.
            where: Optional WHERE condition, without the keyword.
            columns: Select list; ``None`` or blank means ``*``.
            limit: Optional maximum row count (non-negative integer).
            offset: Optional number of rows to skip (non-negative integer).
            sort: Optional ORDER BY expression, without the keywords.

        Returns:
            The dict produced by ``execute_query``, or a
            ``{"success": False, "error": ..., "data": None}`` dict.
        """
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}

            # limit/offset are written into the SQL text, so only whole numbers
            # are allowed through.
            try:
                limit_value = None if limit is None else self._non_negative_int(limit, "limit")
                offset_value = None if offset is None else self._non_negative_int(offset, "offset")
            except ValueError as exc:
                return {"success": False, "error": str(exc), "data": None}

            # Missing or blank fragments mean "not provided". Otherwise they
            # would yield "SELECT None FROM ..." or a dangling "WHERE ".
            if columns is None or (isinstance(columns, str) and not columns.strip()):
                columns = "*"

            sql = f"SELECT {columns} FROM {table_name}"
            if where and str(where).strip():
                sql += f" WHERE {where}"
            if sort and str(sort).strip():
                sql += f" ORDER BY {sort}"
            if limit_value is not None:
                sql += f" LIMIT {limit_value}"
            if offset_value is not None:
                sql += f" OFFSET {offset_value}"

            return self.database.execute_query(sql, QueryType.SELECT)
        except Exception as e:
            logger.error(f"Error in postgresql_find tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
|
|
|
class PostgreSQLUpdateTool(Tool):
    """Tool that runs an UPDATE built from a SET clause and an optional WHERE clause."""

    name: str = "postgresql_update"
    description: str = "Update rows in a PostgreSQL table."
    inputs: Dict[str, Dict[str, str]] = {
        "table_name": {"type": "string", "description": "Table name to update"},
        "set": {"type": "string", "description": "SET clause (e.g., 'status = \'active\'')"},
        "where": {"type": "string", "description": "WHERE clause (optional)"},
    }
    required: Optional[List[str]] = ["table_name", "set"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, table_name: str, set: str, where: str = None) -> Dict[str, Any]:
        """Run ``UPDATE table_name SET set [WHERE where]``.

        The fragments are inserted verbatim. Without ``where`` the statement
        applies to every row of the table.

        Returns:
            The dict produced by ``execute_query``, or a
            ``{"success": False, "error": ..., "data": None}`` dict.
        """
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}

            statement_parts = [f"UPDATE {table_name} SET {set}"]
            if where:
                statement_parts.append(f"WHERE {where}")
            return self.database.execute_query(" ".join(statement_parts), QueryType.UPDATE)
        except Exception as exc:
            logger.error(f"Error in postgresql_update tool: {str(exc)}")
            return {"success": False, "error": str(exc), "data": None}
|
|
|
|
|
class PostgreSQLCreateTool(Tool):
    """Tool that runs a caller-supplied CREATE statement."""

    name: str = "postgresql_create"
    description: str = "Create a table or other object in PostgreSQL."
    inputs: Dict[str, Dict[str, str]] = {
        "query": {"type": "string", "description": "CREATE statement (e.g., CREATE TABLE ...)"},
    }
    required: Optional[List[str]] = ["query"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, query: str) -> Dict[str, Any]:
        """Execute ``query`` tagged as ``QueryType.CREATE``.

        Returns:
            The dict produced by ``execute_query``, or a
            ``{"success": False, "error": ..., "data": None}`` dict.
        """
        try:
            if self.database:
                return self.database.execute_query(query, QueryType.CREATE)
            return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
        except Exception as exc:
            logger.error(f"Error in postgresql_create tool: {str(exc)}")
            return {"success": False, "error": str(exc), "data": None}
|
|
|
|
|
class PostgreSQLDeleteTool(Tool):
    """Tool that runs a DELETE against one table with an optional WHERE clause."""

    name: str = "postgresql_delete"
    description: str = "Delete rows from a PostgreSQL table."
    inputs: Dict[str, Dict[str, str]] = {
        "table_name": {"type": "string", "description": "Table name to delete from"},
        "where": {"type": "string", "description": "WHERE clause (optional)"},
    }
    required: Optional[List[str]] = ["table_name"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, table_name: str, where: str = None) -> Dict[str, Any]:
        """Run ``DELETE FROM table_name [WHERE where]``.

        The fragments are inserted verbatim. Without ``where`` every row of the
        table is deleted.

        Returns:
            The dict produced by ``execute_query``, or a
            ``{"success": False, "error": ..., "data": None}`` dict.
        """
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}

            statement_parts = [f"DELETE FROM {table_name}"]
            if where:
                statement_parts.append(f"WHERE {where}")
            return self.database.execute_query(" ".join(statement_parts), QueryType.DELETE)
        except Exception as exc:
            logger.error(f"Error in postgresql_delete tool: {str(exc)}")
            return {"success": False, "error": str(exc), "data": None}
|
|
|
|
|
class PostgreSQLInfoTool(Tool):
    """Tool that reports database, table, schema or capability information."""

    name: str = "postgresql_info"
    description: str = "Get PostgreSQL database and table information."
    inputs: Dict[str, Dict[str, str]] = {
        "info_type": {"type": "string", "description": "Type of information (database, tables, table, schema, capabilities)"},
        "table_name": {"type": "string", "description": "Table name for table-specific info (optional)"},
    }
    required: Optional[List[str]] = []

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, info_type: str = "database", table_name: str = None) -> Dict[str, Any]:
        """Return the requested kind of information.

        Args:
            info_type: One of ``database``, ``tables``, ``table``, ``schema`` or
                ``capabilities``. Matching is case-insensitive and ignores
                surrounding whitespace. ``None`` or a blank string means
                ``database``.
            table_name: Required for ``table``. Optional for ``schema``, where
                omitting it describes every table. Ignored otherwise.

        Returns:
            The dict produced by the database call, or a
            ``{"success": False, "error": ..., "data": None}`` dict.
        """
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}

            # Callers that forward every optional argument may send None or "";
            # treat that as the documented default instead of crashing on .lower().
            kind = (info_type or "").strip().lower() or "database"

            if kind == "database":
                return self.database.get_database_info()
            if kind == "tables":
                tables = self.database.list_collections()
                return {"success": True, "data": tables, "table_count": len(tables)}
            if kind == "table":
                if not table_name:
                    # Report the missing argument rather than an "invalid info type".
                    return {"success": False, "error": "table_name is required when info_type is 'table'", "data": None}
                return self.database.get_collection_info(table_name)
            if kind == "schema":
                return self.database.get_schema(table_name)
            if kind == "capabilities":
                return {"success": True, "data": self.database.get_capabilities()}
            return {"success": False, "error": f"Invalid info type: {kind}", "data": None}
        except Exception as e:
            logger.error(f"Error in postgresql_info tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
|
|
|
class PostgreSQLToolkit(Toolkit): |
|
|
def __init__(self,
             name: str = "PostgreSQLToolkit",
             connection_string: str = None,
             database_name: str = None,
             local_path: str = None,
             auto_save: bool = True,
             **kwargs):
    """Build the toolkit: one PostgreSQLDatabase shared by six PostgreSQL tools.

    Args:
        name: Toolkit name passed to the base ``Toolkit``.
        connection_string: Forwarded to ``PostgreSQLDatabase``.
        database_name: Forwarded to ``PostgreSQLDatabase``.
        local_path: Forwarded to ``PostgreSQLDatabase``.
        auto_save: Forwarded to ``PostgreSQLDatabase``.
        **kwargs: Extra keyword arguments forwarded to ``PostgreSQLDatabase``.

    ``_cleanup`` is registered with ``atexit`` so the database is
    disconnected when the interpreter exits.
    """
    import atexit

    db = PostgreSQLDatabase(
        connection_string=connection_string,
        database_name=database_name,
        local_path=local_path,
        auto_save=auto_save,
        **kwargs,
    )

    tool_classes = (
        PostgreSQLExecuteTool,
        PostgreSQLFindTool,
        PostgreSQLUpdateTool,
        PostgreSQLCreateTool,
        PostgreSQLDeleteTool,
        PostgreSQLInfoTool,
    )
    super().__init__(name=name, tools=[tool_cls(database=db) for tool_cls in tool_classes])

    self.database = db
    self.connection_string = connection_string
    self.database_name = database_name
    self.local_path = local_path
    self.auto_save = auto_save

    atexit.register(self._cleanup)
|
|
def _cleanup(self): |
|
|
try: |
|
|
if self.database: |
|
|
self.database.disconnect() |
|
|
logger.info("Disconnected from PostgreSQL database") |
|
|
except Exception as e: |
|
|
logger.warning(f"Error during cleanup: {str(e)}") |
|
|
def get_capabilities(self) -> Dict[str, Any]:
    """Return the database's capabilities plus the toolkit's local-storage settings.

    Adds ``is_local_database``, ``local_path`` (as a string, or ``None`` when
    unset) and ``auto_save`` to the dict from ``database.get_capabilities()``.
    Without a database an ``{"error": ...}`` dict is returned instead.
    """
    db = self.database
    if not db:
        return {"error": "PostgreSQL database not initialized"}

    caps = db.get_capabilities()
    caps["is_local_database"] = db.is_local_database
    caps["local_path"] = str(db.local_path) if db.local_path else None
    caps["auto_save"] = db.auto_save
    return caps
|
|
def connect(self) -> bool:
    """Connect the wrapped database. Returns ``False`` when there is none."""
    db = self.database
    if not db:
        return False
    return db.connect()
|
|
def disconnect(self) -> bool:
    """Disconnect the wrapped database. Returns ``False`` when there is none."""
    db = self.database
    if not db:
        return False
    return db.disconnect()
|
|
def test_connection(self) -> bool: |
|
|
return self.database.test_connection() if self.database else False |
|
|
def get_database(self) -> PostgreSQLDatabase:
    """Return the ``PostgreSQLDatabase`` instance shared by this toolkit's tools."""
    db = self.database
    return db
|
|
def get_local_info(self) -> Dict[str, Any]:
    """Summarize how the toolkit's database is stored and configured.

    Returns:
        ``is_local_database``, ``local_path`` (string or ``None``),
        ``auto_save``, ``database_name`` and ``connection_string``. Returns
        ``{"error": "Database not initialized"}`` when there is no database.
    """
    db = self.database
    if not db:
        return {"error": "Database not initialized"}

    return {
        "is_local_database": db.is_local_database,
        "local_path": str(db.local_path) if db.local_path else None,
        "auto_save": db.auto_save,
        "database_name": self.database_name,
        "connection_string": self.connection_string,
    }