""" Database configuration for AegisLM SaaS Backend. Production-ready async SQLAlchemy setup with connection pooling, health monitoring, and proper session management. Includes SQLite fallback for high availability. """ from typing import AsyncGenerator from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import NullPool from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import text import redis.asyncio as redis import logging import time from sqlalchemy import event from typing import Dict, Any from core.config import settings from core.fallback_database import ( get_db as get_db_with_fallback, init_databases as init_fallback_databases, close_databases as close_fallback_databases, check_database_health, get_current_database_type, switch_to_primary, switch_to_fallback ) logger = logging.getLogger(__name__) # Create async engine with connection pooling (legacy - kept for compatibility) async_engine = create_async_engine( settings.DATABASE_URL, pool_pre_ping=True, pool_recycle=3600, poolclass=NullPool if settings.DEBUG else None, echo=settings.DEBUG ) # Create async session factory (legacy - kept for compatibility) AsyncSessionLocal = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False ) # Base model for declarative models Base = declarative_base() # Redis connection redis_client: redis.Redis = None # Database connection monitoring pool_metrics = { 'connections_created': 0, 'connections_closed': 0, 'peak_connections': 0, 'total_queries': 0, 'slow_queries': 0 } # Setup connection monitoring @event.listens_for(async_engine.sync_engine, "connect") def receive_connect(dbapi_connection, connection_record): connection_record.info['connect_time'] = time.time() pool_metrics['connections_created'] += 1 current_connections = pool_metrics['connections_created'] - pool_metrics['connections_closed'] pool_metrics['peak_connections'] = max(pool_metrics['peak_connections'], current_connections) @event.listens_for(async_engine.sync_engine, "checkout") def receive_checkout(dbapi_connection, connection_record, connection_proxy): connection_record.info['checkout_time'] = time.time() @event.listens_for(async_engine.sync_engine, "checkin") def receive_checkin(dbapi_connection, connection_record): checkout_time = connection_record.info.get('checkout_time') if checkout_time: checkout_duration = time.time() - checkout_time if checkout_duration > 1.0: # Slow query threshold pool_metrics['slow_queries'] += 1 pool_metrics['total_queries'] += 1 async def get_pool_metrics() -> Dict[str, Any]: """Get connection pool metrics.""" pool = async_engine.pool current_size = pool.size() if hasattr(pool, 'size') else 0 checked_in = pool.checkedin() if hasattr(pool, 'checkedin') else 0 checked_out = pool.checkedout() if hasattr(pool, 'checkedout') else 0 overflow = pool.overflow() if hasattr(pool, 'overflow') else 0 return { 'pool_size': current_size, 'checked_in': checked_in, 'checked_out': checked_out, 'overflow': overflow, 'utilization': (checked_out / current_size * 100) if current_size > 0 else 0, 'metrics': pool_metrics } async def get_redis() -> redis.Redis: """Get Redis connection.""" global redis_client if redis_client is None: redis_client = redis.from_url( settings.REDIS_URL, encoding="utf-8", decode_responses=True ) return redis_client async def get_db() -> AsyncGenerator[AsyncSession, None]: """ Dependency to get async database session with automatic fallback. Yields: AsyncSession: Database session (primary or fallback) """ async for session in get_db_with_fallback(): yield session async def get_primary_db_only() -> AsyncGenerator[AsyncSession, None]: """ Dependency to get primary PostgreSQL session only (no fallback). Yields: AsyncSession: Primary database session only """ async with AsyncSessionLocal() as session: try: yield session except Exception: await session.rollback() raise def get_sync_db() -> sessionmaker: """ Get sync database session for migrations/admin tasks. Returns: Session: Sync database session """ return sessionmaker( async_engine, autocommit=False, autoflush=False, ) async def init_db() -> None: """Initialize database with fallback support.""" try: await init_fallback_databases() logger.info("Database initialization completed with fallback support") except Exception as e: logger.error(f"Database initialization failed: {e}") raise async def close_db() -> None: """Close database connections.""" await close_fallback_databases() if redis_client: await redis_client.close() async def check_db_health() -> bool: """ Check database health with fallback support. Returns: bool: True if any database is healthy """ try: is_healthy, db_type = await check_database_health() return is_healthy except Exception: return False async def check_redis_health() -> bool: """ Check Redis health with detailed diagnostics. Returns: bool: True if healthy """ try: redis_conn = await get_redis() # Test basic connectivity start_time = time.time() await redis_conn.ping() response_time = (time.time() - start_time) * 1000 # Test read/write capability test_key = "health_check_test" await redis_conn.setex(test_key, 10, "test_value") test_value = await redis_conn.get(test_key) await redis_conn.delete(test_key) if test_value != "test_value": logger.error("Redis read/write test failed") return False # Log performance metrics logger.info(f"Redis health check passed - Response time: {response_time:.2f}ms") return True except Exception as e: logger.error(f"Redis health check failed: {e}") return False async def get_redis_metrics() -> Dict[str, Any]: """ Get Redis performance metrics. Returns: Dict containing Redis metrics """ try: redis_conn = await get_redis() info = await redis_conn.info() return { 'connected_clients': info.get('connected_clients', 0), 'used_memory': info.get('used_memory', 0), 'used_memory_human': info.get('used_memory_human', '0B'), 'total_commands_processed': info.get('total_commands_processed', 0), 'keyspace_hits': info.get('keyspace_hits', 0), 'keyspace_misses': info.get('keyspace_misses', 0), 'hit_rate': ( info.get('keyspace_hits', 0) / (info.get('keyspace_hits', 0) + info.get('keyspace_misses', 1)) ) * 100 } except Exception as e: logger.error(f"Failed to get Redis metrics: {e}") return {}