|
|
""" |
|
|
Enterprise-Grade Database Engine with Connection Pooling and Async Support |
|
|
""" |
|
|
import os |
|
|
import logging |
|
|
from typing import Optional, AsyncGenerator |
|
|
from contextlib import asynccontextmanager |
|
|
from sqlalchemy.ext.asyncio import ( |
|
|
create_async_engine, |
|
|
AsyncSession, |
|
|
AsyncEngine, |
|
|
async_sessionmaker |
|
|
) |
|
|
from sqlalchemy.pool import NullPool, QueuePool |
|
|
from sqlalchemy import event, text |
|
|
|
|
|
from .models import Base |
|
|
|
|
|
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class DatabaseConfig:
    """Database configuration sourced from environment variables.

    Reads ``DATABASE_URL`` plus pool-tuning knobs (``DB_POOL_SIZE``,
    ``DB_MAX_OVERFLOW``, ``DB_POOL_TIMEOUT``, ``DB_POOL_RECYCLE``,
    ``DB_POOL_PRE_PING``), ``DB_ECHO`` and ``SQLITE_WAL`` once, at
    construction time.
    """

    @staticmethod
    def _env_bool(name: str, default: str) -> bool:
        """Parse a boolean env var: case-insensitive "true" is True, anything else False."""
        return os.getenv(name, default).lower() == "true"

    def __init__(self):
        # Default to a local SQLite file when DATABASE_URL is unset.
        self.database_url = os.getenv(
            "DATABASE_URL",
            "sqlite+aiosqlite:///./data/cx_agent.db"
        )

        # Heroku-style URLs use the legacy "postgres://" scheme; rewrite it
        # to the asyncpg dialect that SQLAlchemy's async engine requires.
        if self.database_url.startswith("postgres://"):
            self.database_url = self.database_url.replace(
                "postgres://", "postgresql+asyncpg://", 1
            )

        # Connection-pool tuning (ignored for SQLite, which uses NullPool).
        self.pool_size = int(os.getenv("DB_POOL_SIZE", "20"))
        self.max_overflow = int(os.getenv("DB_MAX_OVERFLOW", "10"))
        self.pool_timeout = int(os.getenv("DB_POOL_TIMEOUT", "30"))
        self.pool_recycle = int(os.getenv("DB_POOL_RECYCLE", "3600"))
        self.pool_pre_ping = self._env_bool("DB_POOL_PRE_PING", "true")

        # Echo SQL statements to the log (verbose; debugging only).
        self.echo = self._env_bool("DB_ECHO", "false")

        # Enable SQLite write-ahead logging for better read concurrency.
        self.enable_wal = self._env_bool("SQLITE_WAL", "true")

    def is_sqlite(self) -> bool:
        """Return True if the configured URL targets SQLite."""
        return "sqlite" in self.database_url

    def is_postgres(self) -> bool:
        """Return True if the configured URL targets PostgreSQL."""
        return "postgresql" in self.database_url
|
|
|
|
|
|
|
|
class DatabaseManager:
    """Singleton database manager with connection pooling.

    The first instantiation builds the async engine and session factory from
    an environment-driven :class:`DatabaseConfig`; every later
    ``DatabaseManager()`` call returns the same instance and engine.
    """

    _instance: Optional["DatabaseManager"] = None
    _engine: Optional[AsyncEngine] = None
    _session_factory: Optional[async_sessionmaker[AsyncSession]] = None

    def __new__(cls):
        # Classic singleton: every DatabaseManager() yields the same object.
        # NOTE(review): creation is not lock-guarded -- assumed to first run
        # on a single thread/event loop before any concurrent use.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every DatabaseManager() call; initialize only once.
        if self._engine is None:
            self._initialize()

    def _initialize(self):
        """Create the async engine and session factory from configuration."""
        config = DatabaseConfig()
        # Kept for _register_event_listeners (per-connection SQLite PRAGMAs).
        self._config = config

        engine_kwargs = {
            "echo": config.echo,
            "future": True,
        }

        if config.is_sqlite():
            logger.info(f"Initializing SQLite database: {config.database_url}")
            # SQLite has no server-side pool; NullPool avoids holding file
            # handles open between requests.
            engine_kwargs.update({
                "poolclass": NullPool,
                "connect_args": {
                    "check_same_thread": False,
                    "timeout": 30,
                },
            })
            # BUG FIX: PRAGMAs were previously passed as a "pragmas" key in
            # connect_args, but sqlite3.connect() accepts no such argument and
            # raised TypeError on every connection.  They are now executed
            # per-connection in _register_event_listeners().
        else:
            logger.info(f"Initializing database: {config.database_url}")
            # BUG FIX: plain QueuePool cannot be used with an asyncio engine
            # (SQLAlchemy rejects it at create time).  Omitting "poolclass"
            # lets create_async_engine use its default AsyncAdaptedQueuePool,
            # which honors the same sizing parameters below.
            engine_kwargs.update({
                "pool_size": config.pool_size,
                "max_overflow": config.max_overflow,
                "pool_timeout": config.pool_timeout,
                "pool_recycle": config.pool_recycle,
                "pool_pre_ping": config.pool_pre_ping,
            })

        self._engine = create_async_engine(
            config.database_url,
            **engine_kwargs
        )

        # BUG FIX: the "autocommit" parameter was removed from Session in
        # SQLAlchemy 2.0 (the only line that provides async_sessionmaker), so
        # passing autocommit=False raised TypeError.  expire_on_commit=False
        # keeps ORM objects usable after the session commits/closes.
        self._session_factory = async_sessionmaker(
            self._engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False
        )

        self._register_event_listeners()

        logger.info("Database engine initialized successfully")

    def _register_event_listeners(self):
        """Register SQLAlchemy event listeners on the underlying sync engine."""
        config = self._config

        # SQLite PRAGMAs must be issued on each raw DBAPI connection; the
        # "connect" event is the supported place to do that.
        sqlite_pragmas = {}
        if config.is_sqlite() and config.enable_wal:
            sqlite_pragmas = {
                "journal_mode": "WAL",
                "synchronous": "NORMAL",
                "cache_size": -64000,
                "foreign_keys": 1,
                "busy_timeout": 5000,
            }

        @event.listens_for(self._engine.sync_engine, "connect")
        def receive_connect(dbapi_conn, connection_record):
            """Apply per-connection SQLite PRAGMAs and log new connections."""
            if sqlite_pragmas:
                cursor = dbapi_conn.cursor()
                try:
                    for pragma, value in sqlite_pragmas.items():
                        cursor.execute(f"PRAGMA {pragma}={value}")
                finally:
                    cursor.close()
            logger.debug("New database connection established")

        @event.listens_for(self._engine.sync_engine, "close")
        def receive_close(dbapi_conn, connection_record):
            """Event listener for closed connections."""
            logger.debug("Database connection closed")

    @property
    def engine(self) -> AsyncEngine:
        """The shared AsyncEngine.

        Raises:
            RuntimeError: if the engine has not been initialized.
        """
        if self._engine is None:
            raise RuntimeError("Database engine not initialized")
        return self._engine

    @property
    def session_factory(self) -> async_sessionmaker[AsyncSession]:
        """The shared session factory.

        Raises:
            RuntimeError: if the factory has not been initialized.
        """
        if self._session_factory is None:
            raise RuntimeError("Session factory not initialized")
        return self._session_factory

    async def create_tables(self):
        """Create all tables declared on ``Base.metadata`` (no-op if present)."""
        logger.info("Creating database tables...")
        # Use the guarded property so an uninitialized manager fails with a
        # clear RuntimeError instead of an AttributeError on None.
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        logger.info("Database tables created successfully")

    async def drop_tables(self):
        """Drop all database tables (use with caution!)."""
        logger.warning("Dropping all database tables...")
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.drop_all)
        logger.info("Database tables dropped")

    async def health_check(self) -> bool:
        """Return True if a trivial ``SELECT 1`` round-trip succeeds."""
        try:
            async with self.get_session() as session:
                await session.execute(text("SELECT 1"))
            return True
        except Exception as e:
            logger.error(f"Database health check failed: {e}")
            return False

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an AsyncSession: commit on success, rollback on error, always close."""
        session = self.session_factory()
        try:
            yield session
            await session.commit()
        except Exception as e:
            await session.rollback()
            logger.error(f"Database session error: {e}")
            raise
        finally:
            await session.close()

    async def close(self):
        """Dispose the engine, closing all pooled connections."""
        if self._engine is not None:
            await self._engine.dispose()
            logger.info("Database engine closed")
|
|
|
|
|
|
|
|
|
|
|
# Lazily-created module-level singleton; access via get_db_manager().
_db_manager: Optional[DatabaseManager] = None
|
|
|
|
|
|
|
|
def get_db_manager() -> DatabaseManager:
    """Return the module-wide DatabaseManager, creating it lazily on first use."""
    global _db_manager
    if _db_manager is None:
        # DatabaseManager is itself a singleton; this simply caches the
        # canonical instance at module level for cheap repeated access.
        _db_manager = DatabaseManager()
    return _db_manager
|
|
|
|
|
|
|
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a database session from the global manager.

    Convenience async generator (e.g. as a FastAPI dependency); commit,
    rollback and close are handled by the manager's context manager.
    """
    manager = get_db_manager()
    async with manager.get_session() as session:
        yield session
|
|
|
|
|
|
|
|
async def init_database():
    """Initialize the database by creating any missing tables."""
    manager = get_db_manager()
    await manager.create_tables()
    logger.info("Database initialized")
|
|
|
|
|
|
|
|
async def close_database():
    """Shut down the global database manager and release its connections."""
    manager = get_db_manager()
    await manager.close()
    logger.info("Database closed")
|
|
|