|
|
import logging
import os
from pathlib import Path
from typing import Dict, List, Any, Optional

import pandas as pd
import psycopg2
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from psycopg2.extras import RealDictCursor
|
|
|
|
|
|
|
|
# Populate the process environment via python-dotenv before any of the
# settings below are looked up.
load_dotenv()

# Each constant mirrors the environment variable of the same name and is
# None when that variable is unset. DatabaseInterface opens each value as a
# UTF-8 text file and executes its contents as a SQL query.
LIST_SCHEMA = os.environ.get('LIST_SCHEMA')
LIST_DATABASE_INFOS = os.environ.get('LIST_DATABASE_INFOS')
TABLE_IN_SCHEMA = os.environ.get('TABLE_IN_SCHEMA')
COLUMN_IN_TABLE = os.environ.get('COLUMN_IN_TABLE')
|
|
|
|
|
|
|
|
class DatabaseInterface:
    """Thin PostgreSQL access layer built on psycopg2.

    Every public method opens its own connection through
    ``get_db_connection`` and closes it before returning. The instance also
    owns a ``FastMCP`` object named ``"ecommerce-mcp-server"`` (``self.mcp``).

    NOTE(review): table names, view names and query text are interpolated
    into SQL as-is (identifiers cannot be bound as parameters). Callers must
    only pass values they trust.
    """

    # Core tables that must never be dropped through this interface.
    _PROTECTED_TABLES = frozenset({"transactions", "customers", "articles"})

    # Diagnostics go through logging instead of print().
    # NOTE(review): presumably this object is served over an MCP stdio
    # transport, where stray stdout output would corrupt the protocol
    # stream -- confirm against the server entry point.
    _log = logging.getLogger(__name__)

    def __init__(self, db_config: Optional[Dict[str, Any]] = None):
        """Build the interface from an explicit config or from the environment.

        Args:
            db_config: Keyword arguments for ``psycopg2.connect``. When
                omitted (or empty) the configuration is read from ``DB_HOST``,
                ``DB_PORT`` (default 5432), ``DB_NAME``, ``DB_USER`` and
                ``DB_PASSWORD``.

        Raises:
            ValueError: If ``DB_PORT`` is not an integer, or if any of
                ``host``, ``database``, ``user`` or ``password`` is missing
                or empty.
        """
        self.mcp = FastMCP("ecommerce-mcp-server")

        if db_config:
            self.db_config = db_config
        else:
            # ``or`` (rather than a getenv default) also covers DB_PORT being
            # set to an empty string, which would make int() blow up.
            raw_port = os.getenv('DB_PORT') or 5432
            try:
                port = int(raw_port)
            except ValueError:
                raise ValueError(
                    f"DB_PORT must be an integer, got {raw_port!r}"
                ) from None

            self.db_config = {
                'host': os.getenv('DB_HOST'),
                'port': port,
                'database': os.getenv('DB_NAME'),
                'user': os.getenv('DB_USER'),
                'password': os.getenv('DB_PASSWORD')
            }

        required_fields = ['host', 'database', 'user', 'password']
        missing_fields = [field for field in required_fields if not self.db_config.get(field)]
        if missing_fields:
            raise ValueError(f"Missing required database configuration: {missing_fields}")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _is_protected_table(cls, table_name: str) -> bool:
        """Return True when *table_name* refers to one of the protected tables.

        The comparison ignores surrounding whitespace, a schema qualifier,
        double quotes and letter case, so spellings such as ``Customers`` or
        ``public.customers`` cannot slip past the guard. This is deliberately
        conservative: a same-named table in any schema is treated as protected.
        """
        bare_name = table_name.strip().rsplit(".", 1)[-1].strip().strip('"').lower()
        return bare_name in cls._PROTECTED_TABLES

    @staticmethod
    def _view_identifier(view_name: str) -> str:
        """Return *view_name* in a form that is valid inside a SQL statement.

        ``create_view`` accepts hyphens and leading digits, neither of which
        is legal in an unquoted identifier, so those names are double-quoted.
        Other names are returned untouched so they keep the usual unquoted
        identifier handling. Quoting is safe because the caller has already
        restricted the name to letters, digits, ``_`` and ``-``.
        """
        if "-" in view_name or view_name[0].isdigit():
            return f'"{view_name}"'
        return view_name

    @staticmethod
    def _split_sql_statements(sql_text: str) -> List[str]:
        """Split a SQL script into individual statements.

        Semicolons are treated as separators only when they appear outside
        of quoted strings/identifiers, ``E'...'`` escape strings,
        dollar-quoted bodies (``$$ ... $$`` / ``$tag$ ... $tag$``), ``--``
        line comments and ``/* ... */`` block comments (nesting supported).
        Fragments that contain nothing but whitespace and comments are
        dropped, because executing them fails with an "empty query" error.

        Returns:
            The statements in file order, stripped of surrounding whitespace
            and without the terminating semicolon.
        """
        statements: List[str] = []
        pieces: List[str] = []
        has_code = False  # True once the current fragment has non-comment text
        pos, end = 0, len(sql_text)

        def is_ident_char(c: str) -> bool:
            return c.isalnum() or c == "_"

        def flush() -> None:
            nonlocal has_code
            statement = "".join(pieces).strip()
            if has_code and statement:
                statements.append(statement)
            pieces.clear()
            has_code = False

        while pos < end:
            ch = sql_text[pos]
            pair = sql_text[pos:pos + 2]
            prev = sql_text[pos - 1] if pos else ""
            stop = pos + 1  # exclusive end of the token starting at pos

            if pair == "--":
                newline = sql_text.find("\n", pos)
                stop = end if newline == -1 else newline
            elif pair == "/*":
                stop, depth = pos + 2, 1
                while stop < end and depth:
                    inner = sql_text[stop:stop + 2]
                    if inner == "/*":
                        depth, stop = depth + 1, stop + 2
                    elif inner == "*/":
                        depth, stop = depth - 1, stop + 2
                    else:
                        stop += 1
            elif ch in ("'", '"'):
                before = sql_text[pos - 2] if pos >= 2 else ""
                # Only E'...' strings treat backslash as an escape character.
                backslash_escapes = (
                    ch == "'" and prev in ("e", "E") and not is_ident_char(before)
                )
                while stop < end:
                    c = sql_text[stop]
                    if backslash_escapes and c == "\\":
                        stop += 2
                    elif c == ch:
                        stop += 1
                        # A doubled quote is an escaped quote, not the end.
                        if sql_text[stop:stop + 1] != ch:
                            break
                        stop += 1
                    else:
                        stop += 1
                has_code = True
            elif ch == "$" and not is_ident_char(prev):
                # A '$' directly after an identifier character belongs to
                # that identifier; otherwise it may open a dollar-quoted body.
                close = sql_text.find("$", pos + 1)
                tag_body = sql_text[pos + 1:close] if close != -1 else None
                is_tag = tag_body is not None and (
                    tag_body == ""
                    or (
                        tag_body.replace("_", "a").isalnum()
                        and not tag_body[0].isdigit()
                    )
                )
                if is_tag:
                    tag = sql_text[pos:close + 1]
                    closing = sql_text.find(tag, close + 1)
                    stop = end if closing == -1 else closing + len(tag)
                has_code = True
            elif ch == ";":
                flush()
                pos += 1
                continue
            else:
                has_code = has_code or not ch.isspace()

            pieces.append(sql_text[pos:stop])
            pos = stop

        flush()
        return statements

    def _rollback_quietly(self, conn) -> None:
        """Roll back *conn*, logging (not raising) if the rollback itself fails.

        Used inside error handlers so that a broken connection does not hide
        the error that is about to be reported to the caller.
        """
        try:
            conn.rollback()
        except Exception:
            self._log.debug("Rollback failed", exc_info=True)

    def _fetch_scalar_from_sql_file(
        self,
        setting_name: str,
        sql_file: Optional[str],
        params: Optional[Dict[str, Any]] = None,
    ):
        """Run the query stored in *sql_file* and return its first value.

        Args:
            setting_name: Name of the environment variable *sql_file* came
                from; only used in the error message.
            sql_file: Path of a UTF-8 text file containing one SQL query.
            params: Optional named parameters bound by psycopg2.

        Returns:
            The first column of the first row, or ``None`` when the query
            produces no rows.

        Raises:
            ValueError: If *sql_file* is empty or ``None``.
            OSError: If the file cannot be read.
            ConnectionError: If the database connection cannot be opened.
        """
        if not sql_file:
            raise ValueError(
                f"Environment variable {setting_name} is not set; "
                "it must point to a SQL file"
            )
        query = Path(sql_file).read_text(encoding="utf-8")

        conn = self.get_db_connection()
        try:
            with conn.cursor() as cur:
                cur.execute(query, params)
                row = cur.fetchone()
                return row[0] if row else None
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_db_connection(self):
        """Open a new psycopg2 connection using ``self.db_config``.

        Raises:
            ConnectionError: If psycopg2 fails to connect; the original
                ``psycopg2.Error`` is chained as the cause.
        """
        try:
            return psycopg2.connect(**self.db_config)
        except psycopg2.Error as e:
            raise ConnectionError(f"Failed to connect to database: {str(e)}") from e

    def list_schemas(self):
        """Run the ``LIST_SCHEMA`` SQL file and return its first result value."""
        self._log.debug("LIST_SCHEMA SQL file: %s", LIST_SCHEMA)
        return self._fetch_scalar_from_sql_file('LIST_SCHEMA', LIST_SCHEMA)

    def list_database_info(self):
        """Run the ``LIST_DATABASE_INFOS`` SQL file and return its first result value."""
        return self._fetch_scalar_from_sql_file('LIST_DATABASE_INFOS', LIST_DATABASE_INFOS)

    def list_tables_in_schema(self, schema_name: str):
        """Run the ``TABLE_IN_SCHEMA`` SQL file for *schema_name*.

        The value is bound to the ``schema_name`` named parameter of the
        query; the first column of the first row is returned.
        """
        return self._fetch_scalar_from_sql_file(
            'TABLE_IN_SCHEMA', TABLE_IN_SCHEMA, {'schema_name': schema_name}
        )

    def list_columns_in_table(self, schema_name: str, table_name: str):
        """Run the ``COLUMN_IN_TABLE`` SQL file for one table.

        The values are bound to the ``schema_name`` and ``table_name`` named
        parameters; the first column of the first row is returned.
        """
        return self._fetch_scalar_from_sql_file(
            'COLUMN_IN_TABLE',
            COLUMN_IN_TABLE,
            {'schema_name': schema_name, 'table_name': table_name},
        )

    def create_view(self, view_name: str, query: str, validate_only: bool = False):
        """Create (or replace) a SQL view, or only check that its query plans.

        Args:
            view_name: Name of the view; letters, digits, underscores and
                hyphens only.
            query: SELECT statement that defines the view.
            validate_only: When true, run ``EXPLAIN`` on *query* instead of
                creating anything.

        Returns:
            A human-readable status string starting with ✅ on success or ❌
            on failure. Errors are reported in the string, never raised.
        """
        if not view_name or not view_name.strip():
            return "❌ View name cannot be empty"
        if not query or not query.strip():
            return "❌ Query cannot be empty"
        if not view_name.replace('_', '').replace('-', '').isalnum():
            return "❌ View name contains invalid characters. Use only letters, numbers, underscores, and hyphens"

        # Connect outside the try/finally blocks below so that a failed
        # connect can never reach ``conn.close()`` with ``conn`` unbound.
        try:
            conn = self.get_db_connection()
        except Exception as connection_error:
            return f"❌ Connection error: {str(connection_error)}"

        if validate_only:
            try:
                with conn.cursor() as cur:
                    cur.execute(f"EXPLAIN {query}")
                return f"✅ Query is valid for view '{view_name}'"
            except Exception as e:
                return f"❌ Invalid query: {str(e)}"
            finally:
                conn.close()

        identifier = self._view_identifier(view_name)
        try:
            with conn.cursor() as cur:
                # Drop first: CREATE OR REPLACE VIEW cannot change the
                # column list of an existing view.
                cur.execute(f"DROP VIEW IF EXISTS {identifier} CASCADE")
                cur.execute(f"CREATE VIEW {identifier} AS {query}")
            conn.commit()
            return f"✅ View '{view_name}' created successfully"
        except psycopg2.Error as db_error:
            self._rollback_quietly(conn)
            return f"❌ Database error creating view '{view_name}': {str(db_error)}"
        except Exception as e:
            self._rollback_quietly(conn)
            return f"❌ Error creating view '{view_name}': {str(e)}"
        finally:
            conn.close()

    def execute_sql_file(self, file_path: str):
        """Execute every statement of a SQL script in a single transaction.

        The script is split with ``_split_sql_statements``; all statements
        are committed together, or rolled back together on the first error.

        Returns:
            A status string starting with ✅ (including the number of
            statements executed) or ❌ (file missing, unreadable, connection
            failure, or execution error).
        """
        sql_path = Path(file_path)
        if not sql_path.exists():
            return f"❌ SQL file not found: {file_path}"

        try:
            sql_content = sql_path.read_text(encoding="utf-8")
        except (OSError, UnicodeError) as e:
            return f"❌ Error reading SQL file: {str(e)}"

        statements = self._split_sql_statements(sql_content)

        try:
            conn = self.get_db_connection()
        except Exception as e:
            return f"❌ Connection error: {str(e)}"

        try:
            with conn.cursor() as cur:
                for statement in statements:
                    cur.execute(statement)
            conn.commit()
            return f"✅ Successfully executed {len(statements)} SQL commands"
        except Exception as e:
            self._rollback_quietly(conn)
            return f"❌ Error executing SQL file: {str(e)}"
        finally:
            conn.close()

    def read_only_query(self, query):
        """Run *query* inside a READ ONLY transaction and return all rows.

        NOTE(review): the read-only flag applies to the current transaction
        only; a multi-statement *query* that ends the transaction itself is
        not covered by it.

        Raises:
            ConnectionError: If the connection cannot be opened.
            psycopg2.Error: If the query fails (including write attempts).
        """
        # Acquire the connection before the try block so a failed connect
        # propagates as ConnectionError instead of an UnboundLocalError
        # raised from the ``finally`` clause.
        conn = self.get_db_connection()
        try:
            with conn.cursor() as cur:
                cur.execute("SET TRANSACTION READ ONLY")
                cur.execute(query)
                return cur.fetchall()
        finally:
            conn.close()

    def create_table_from_query(self, table_name: str, source_query: str, drop_if_exists: bool = True) -> str:
        """Create a permanent table from a SELECT query.

        Args:
            table_name: Name of the table to create.
            source_query: SELECT statement whose result populates the table.
            drop_if_exists: Drop an existing table of the same name first.
                Protected tables are never dropped, so creating over one of
                them fails with a database error instead.

        Returns:
            A status string starting with ✅ (including the row count) or ❌.
        """
        try:
            conn = self.get_db_connection()
        except Exception as e:
            return f"❌ Connection error: {str(e)}"

        try:
            with conn.cursor() as cur:
                if drop_if_exists and not self._is_protected_table(table_name):
                    cur.execute(f"DROP TABLE IF EXISTS {table_name}")

                cur.execute(f"CREATE TABLE {table_name} AS {source_query}")
                conn.commit()

                cur.execute(f"SELECT COUNT(*) FROM {table_name}")
                count = cur.fetchone()[0]

            message = f"✅ Table '{table_name}' created successfully with {count} rows"
            self._log.info("%s", message)
            return message
        except Exception as e:
            self._rollback_quietly(conn)
            return f"❌ Error creating table: {str(e)}"
        finally:
            conn.close()

    def drop_table(self, table_name: str, cascade: bool = False) -> str:
        """Drop a table unless it is one of the protected tables.

        Args:
            table_name: Name of the table to drop.
            cascade: Append ``CASCADE`` so dependent objects are dropped too.

        Returns:
            A status string starting with ✅ or ❌.
        """
        # Reject protected tables up front; no connection is needed for that.
        if self._is_protected_table(table_name):
            return f"❌ Table '{table_name}' is a system table and cannot be dropped"

        try:
            conn = self.get_db_connection()
        except Exception as e:
            return f"❌ Connection error: {str(e)}"

        cascade_clause = " CASCADE" if cascade else ""
        try:
            with conn.cursor() as cur:
                cur.execute(f"DROP TABLE IF EXISTS {table_name}{cascade_clause}")
            conn.commit()
            return f"✅ Table '{table_name}' dropped successfully"
        except Exception as e:
            self._rollback_quietly(conn)
            return f"❌ Error dropping table: {str(e)}"
        finally:
            conn.close()
|
|
|