Spaces:
Paused
Paused
| from sqlalchemy import create_engine, MetaData | |
| from sqlalchemy.ext.declarative import declarative_base | |
| from sqlalchemy.orm import sessionmaker | |
| import logging | |
# Declarative base class from SQLAlchemy. It is not referenced in this
# section; presumably ORM model classes elsewhere subclass it — verify.
Base = declarative_base()
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class DBConnector:
    """Thin wrapper around a SQLAlchemy engine and session factory.

    An instance starts unconnected. Call ``connect()`` to build the engine,
    then use ``get_session()``, ``execute_query()`` or
    ``execute_sql_statements()``, and finally ``close()``.
    """

    def __init__(self, conn_string=None):
        """Initialize database connection state without connecting.

        Args:
            conn_string: Optional SQLAlchemy database URL. It may also be
                supplied later through ``connect()``.
        """
        self.conn_string = conn_string
        self.engine = None    # built by connect(), cleared by close()
        self.Session = None   # sessionmaker bound to self.engine
        self.metadata = None  # MetaData created by connect()

    def _fix_mysql_connection_string(self, conn_string):
        """Fix MySQL connection string to use available drivers.

        Only a bare ``mysql://`` URL (one that names no DBAPI driver) is
        rewritten. Its scheme is replaced by the first importable driver, in
        order of preference: PyMySQL, mysql-connector-python, MySQLdb. Every
        other URL is returned unchanged. That includes non-MySQL URLs and
        MySQL URLs that already pin a driver, such as ``mysql+pymysql://``.

        Args:
            conn_string: SQLAlchemy database URL.

        Returns:
            The URL, possibly with a driver-qualified scheme.

        Raises:
            ImportError: If ``conn_string`` is a bare ``mysql://`` URL and none
                of the supported drivers can be imported.
        """
        import importlib

        bare_scheme = 'mysql://'
        if not conn_string.startswith(bare_scheme):
            # A URL that already names a driver must be left alone. A blind
            # str.replace('mysql://', ...) would turn "mysql+pymysql://"
            # into "mysql+pymysql+pymysql://", because the text "pymysql://"
            # contains "mysql://".
            return conn_string

        rest = conn_string[len(bare_scheme):]
        # (module to probe, scheme to use, log message), in order of
        # preference. MySQLdb (mysqlclient) is SQLAlchemy's default MySQL
        # DBAPI, so its entry keeps the bare scheme.
        drivers = (
            ('pymysql', 'mysql+pymysql://',
             "Using PyMySQL driver for MySQL connection"),
            ('mysql.connector', 'mysql+mysqlconnector://',
             "Using mysql-connector-python driver for MySQL connection"),
            ('MySQLdb', 'mysql://',
             "Using MySQLdb driver for MySQL connection"),
        )
        for module_name, scheme, message in drivers:
            try:
                importlib.import_module(module_name)
            except ImportError:
                continue
            logger.info(message)
            # Rewrite only the scheme prefix, never text later in the URL.
            return scheme + rest

        # No driver is available, so give an actionable error message.
        raise ImportError(
            "\n"
            "No MySQL driver found. Please install one of the following:\n"
            "- PyMySQL: pip install PyMySQL\n"
            "- mysql-connector-python: pip install mysql-connector-python\n"
            "- MySQLdb: pip install mysqlclient\n"
            "PyMySQL is recommended for most use cases.\n"
        )

    def connect(self, conn_string=None):
        """Create the engine and session factory for a connection string.

        Args:
            conn_string: SQLAlchemy URL. When given, it replaces the one
                stored on the instance.

        Returns:
            True on success. False if a DBAPI driver is missing or engine
            creation fails; the error is logged, not raised.

        Raises:
            ValueError: If no connection string is available.
        """
        if conn_string:
            self.conn_string = conn_string
        if not self.conn_string:
            raise ValueError("Database connection string must be provided")
        try:
            fixed_conn_string = self._fix_mysql_connection_string(self.conn_string)
            # NOTE: create_engine() does not open a connection by itself, so
            # connectivity problems surface on first use, not here.
            self.engine = create_engine(fixed_conn_string)
            self.Session = sessionmaker(bind=self.engine)
            self.metadata = MetaData()
            # NOTE(review): MetaData.bind was deprecated in SQLAlchemy 1.4 and
            # removed in 2.0; confirm this assignment is still needed.
            self.metadata.bind = self.engine
            # Log only the text after the last '@' so that credentials
            # embedded in the URL are not written to the log.
            logger.info(f"Connected to database: {self.conn_string.split('@')[-1]}")
            return True
        except ImportError as e:
            logger.error(f"Missing database driver: {str(e)}")
            return False
        except Exception as e:
            logger.error(f"Failed to connect to database: {str(e)}")
            return False

    def get_session(self):
        """Get a new session for database operations.

        Raises:
            ConnectionError: If ``connect()`` has not succeeded, or if
                ``close()`` was called since.
        """
        if not self.Session:
            raise ConnectionError("Database connection not established")
        return self.Session()

    def execute_query(self, query):
        """Execute raw SQL query and return results.

        Args:
            query: SQL text, wrapped in ``sqlalchemy.text()`` before execution.

        Returns:
            All rows produced by the query, from ``fetchall()``.

        Raises:
            ConnectionError: If no engine is available.
            Exception: Any execution error is logged and re-raised.
        """
        if not self.engine:
            raise ConnectionError("Database connection not established")
        # NOTE(review): fetchall() only works for row-returning statements,
        # and nothing is committed here. Writes should go through
        # execute_sql_statements() instead.
        try:
            from sqlalchemy import text
            with self.engine.connect() as connection:
                result = connection.execute(text(query))
                return result.fetchall()
        except Exception as e:
            logger.error(f"Query execution error: {str(e)}")
            raise

    def execute_sql_statements(self, sql_statements):
        """Execute multiple SQL statements in a transaction.

        Blank or whitespace-only statements are skipped. All statements run
        in one session, which is committed once at the end. It is rolled
        back if any statement fails, and closed either way.

        Args:
            sql_statements: Iterable of SQL strings.

        Returns:
            A list with one SQLAlchemy result object per executed statement.
            NOTE(review): the session is closed before returning, so these
            results are presumably only useful for metadata such as
            ``rowcount``, not for fetching rows. Confirm with callers.

        Raises:
            ConnectionError: If no engine is available.
            Exception: Any execution error is logged and re-raised after
                rollback.
        """
        if not self.engine:
            raise ConnectionError("Database connection not established")
        from sqlalchemy import text  # hoisted out of the per-statement loop

        session = self.get_session()
        results = []
        try:
            for statement in sql_statements:
                if statement.strip():
                    result = session.execute(text(statement))
                    results.append(result)
            session.commit()
            return results
        except Exception as e:
            session.rollback()
            logger.error(f"SQL execution error: {str(e)}")
            raise
        finally:
            session.close()

    def close(self):
        """Close the database connection.

        Disposes the engine's connection pool, if any, and then clears the
        engine, session factory and metadata. Afterwards ``get_session()``
        and the ``execute_*`` methods raise ConnectionError; without the
        reset, a disposed engine would silently create a new pool on next
        use. Safe to call more than once.
        """
        if self.engine:
            self.engine.dispose()
            logger.info("Database connection closed")
        self.engine = None
        self.Session = None
        self.metadata = None