File size: 7,640 Bytes
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8441ef
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b404e8f
 
 
 
 
 
 
 
 
 
 
f9ad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""

Database Connection Module - Multi-Database Support.



This module provides:

- SQLAlchemy engine and session management for MySQL, PostgreSQL, and SQLite

- Connection pooling (for MySQL/PostgreSQL)

- SSL/TLS support

- Connection health checking

"""

import logging
from contextlib import contextmanager
from typing import Optional, Generator
from sqlalchemy import create_engine, text, event
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool, StaticPool
from sqlalchemy.exc import OperationalError, SQLAlchemyError

import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import DatabaseConfig, DatabaseType, config

logger = logging.getLogger(__name__)


class DatabaseConnection:
    """

    Manages database connections with connection pooling.

    

    Supports MySQL, PostgreSQL, and SQLite.

    """
    
    def __init__(self, db_config: Optional[DatabaseConfig] = None):
        """

        Initialize database connection manager.

        

        Args:

            db_config: Database configuration. Uses global config if not provided.

        """
        self.config = db_config or config.database
        self._engine: Optional[Engine] = None
        self._session_factory: Optional[sessionmaker] = None
    
    def _create_engine(self) -> Engine:
        """

        Create SQLAlchemy engine with appropriate settings for each database type.

        

        Returns:

            Configured SQLAlchemy Engine instance

        """
        connect_args = {}
        
        if self.config.db_type == DatabaseType.POSTGRESQL:
            # PostgreSQL-specific settings
            if self.config.ssl_ca:
                connect_args["sslmode"] = "verify-full"
                connect_args["sslrootcert"] = self.config.ssl_ca
            
            engine = create_engine(
                self.config.connection_string,
                poolclass=QueuePool,
                pool_size=5,
                max_overflow=10,
                pool_timeout=30,
                pool_recycle=1800,
                pool_pre_ping=True,
                connect_args=connect_args,
                echo=False
            )
            
        elif self.config.db_type == DatabaseType.SQLITE:
            # SQLite-specific settings
            # We use StaticPool for SQLite to avoid issues with multiple threads 
            # if using in-memory or a single file connection
            engine = create_engine(
                self.config.connection_string,
                poolclass=StaticPool,
                connect_args={"check_same_thread": False},
                echo=False
            )
            
        else:  # MySQL (default)
            # MySQL-specific settings (SSL for Aiven)
            if self.config.ssl_ca:
                connect_args["ssl"] = {
                    "ca": self.config.ssl_ca,
                    "check_hostname": True,
                    "verify_mode": True
                }
            
            engine = create_engine(
                self.config.connection_string,
                poolclass=QueuePool,
                pool_size=5,
                max_overflow=10,
                pool_timeout=30,
                pool_recycle=1800,
                pool_pre_ping=True,
                connect_args=connect_args,
                echo=False
            )
        
        return engine
    
    @property
    def engine(self) -> Engine:
        """Get or create the SQLAlchemy engine."""
        if self._engine is None:
            self._engine = self._create_engine()
        return self._engine
    
    @property
    def session_factory(self) -> sessionmaker:
        """Get or create the session factory."""
        if self._session_factory is None:
            self._session_factory = sessionmaker(
                bind=self.engine,
                autocommit=False,
                autoflush=False
            )
        return self._session_factory
    
    @property
    def db_type(self) -> DatabaseType:
        """Get the current database type."""
        return self.config.db_type
    
    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """

        Context manager for database sessions.

        

        Yields:

            SQLAlchemy Session instance

            

        Example:

            with db.get_session() as session:

                result = session.execute(text("SELECT * FROM users"))

        """
        session = self.session_factory()
        try:
            yield session
            session.commit()
        except SQLAlchemyError as e:
            session.rollback()
            logger.error(f"Database session error: {e}")
            raise
        finally:
            session.close()
    
    def execute_query(self, query: str, params: Optional[dict] = None) -> list:
        """

        Execute a read-only SQL query and return results.

        

        Args:

            query: SQL query string (must be SELECT)

            params: Optional query parameters for parameterized queries

            

        Returns:

            List of result rows as dictionaries

        """
        with self.get_session() as session:
            result = session.execute(text(query), params or {})
            # Convert rows to dictionaries for easier handling
            columns = result.keys()
            return [dict(zip(columns, row)) for row in result.fetchall()]

    def execute_write(self, query: str, params: Optional[dict] = None) -> bool:
        """

        Execute a write operation (INSERT, UPDATE, DELETE, CREATE).

        

        Args:

            query: SQL query string

            params: Optional query parameters

            

        Returns:

            bool: True if successful

        """
        with self.get_session() as session:
            session.execute(text(query), params or {})
            session.commit()
            return True
    
    def test_connection(self) -> tuple[bool, str]:
        """

        Test database connectivity.

        

        Returns:

            tuple: (success: bool, message: str)

        """
        try:
            with self.get_session() as session:
                result = session.execute(text("SELECT 1 as health_check"))
                row = result.fetchone()
                if row and row[0] == 1:
                    db_type = self.config.db_type.value.upper()
                    return True, f"{db_type} connection successful"
                return False, "Unexpected result from health check query"
        except OperationalError as e:
            logger.error(f"Database connection failed: {e}")
            return False, f"Connection failed: {str(e)}"
        except Exception as e:
            logger.error(f"Unexpected error during connection test: {e}")
            return False, f"Unexpected error: {str(e)}"
    
    def close(self):
        """Close all connections and dispose of the engine."""
        if self._engine:
            self._engine.dispose()
            self._engine = None
            self._session_factory = None
            logger.info("Database connections closed")


# Create a global database connection instance
db_connection = DatabaseConnection()


def get_db() -> DatabaseConnection:
    """Get the global database connection instance."""
    return db_connection