"""
Database configuration for AegisLM SaaS Backend.

Production-ready async SQLAlchemy setup with connection pooling,
health monitoring, and proper session management.
Includes SQLite fallback for high availability.
"""
|
|
import logging
import time
import uuid
from typing import Any, AsyncGenerator, Dict, Optional

import redis.asyncio as redis
from sqlalchemy import event, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool

from core.config import settings
from core.fallback_database import (
    check_database_health,
    close_databases as close_fallback_databases,
    get_current_database_type,
    get_db as get_db_with_fallback,
    init_databases as init_fallback_databases,
    switch_to_fallback,
    switch_to_primary,
)
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Async engine built from settings.DATABASE_URL. In debug mode NullPool is
# used; otherwise poolclass=None leaves the pool choice to SQLAlchemy.
async_engine = create_async_engine(
    settings.DATABASE_URL,
    echo=settings.DEBUG,
    poolclass=NullPool if settings.DEBUG else None,
    pool_pre_ping=True,
    pool_recycle=3600,
)


# Factory for AsyncSession objects bound to the engine above.
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


# Declarative base class for ORM models.
Base = declarative_base()
|
|
| |
# Shared Redis client. Starts out as None and is filled in lazily by
# get_redis(), hence the Optional annotation.
redis_client: Optional[redis.Redis] = None
|
|
| |
# Process-wide counters, all starting at zero, updated by the pool event
# listeners defined in this module.
pool_metrics = dict.fromkeys(
    (
        'connections_created',
        'connections_closed',
        'peak_connections',
        'total_queries',
        'slow_queries',
    ),
    0,
)
|
|
| |
@event.listens_for(async_engine.sync_engine, "connect")
def receive_connect(dbapi_connection, connection_record):
    """Record a newly created DBAPI connection in ``pool_metrics``.

    Stamps the connection record with its creation time, increments
    ``connections_created`` and refreshes ``peak_connections`` from the
    number of connections currently open (created minus closed).
    """
    connection_record.info['connect_time'] = time.time()
    pool_metrics['connections_created'] += 1
    open_connections = (
        pool_metrics['connections_created'] - pool_metrics['connections_closed']
    )
    pool_metrics['peak_connections'] = max(
        pool_metrics['peak_connections'], open_connections
    )


@event.listens_for(async_engine.sync_engine, "close")
def receive_close(dbapi_connection, connection_record):
    """Count a DBAPI connection being closed by the pool.

    Without this listener ``connections_closed`` never changes, so the
    open-connection figure used for ``peak_connections`` would simply equal
    the total number of connections ever created.
    """
    pool_metrics['connections_closed'] += 1
|
|
@event.listens_for(async_engine.sync_engine, "checkout")
def receive_checkout(dbapi_connection, connection_record, connection_proxy):
    """Stamp the connection record with the time it was checked out."""
    checked_out_at = time.time()
    connection_record.info['checkout_time'] = checked_out_at
|
|
@event.listens_for(async_engine.sync_engine, "checkin")
def receive_checkin(dbapi_connection, connection_record):
    """Update usage counters when a connection is checked back in.

    If the record carries a checkout timestamp and the connection was held
    for more than one second, ``slow_queries`` is incremented. Every
    check-in increments ``total_queries``.
    """
    checked_out_at = connection_record.info.get('checkout_time')
    # Short-circuits before reading the clock when no timestamp is stored.
    held_too_long = bool(checked_out_at) and (time.time() - checked_out_at) > 1.0
    if held_too_long:
        pool_metrics['slow_queries'] += 1
    pool_metrics['total_queries'] += 1
|
|
async def get_pool_metrics() -> Dict[str, Any]:
    """Report the engine's pool gauges together with the module counters.

    A pool object that lacks one of the gauge methods (``size``,
    ``checkedin``, ``checkedout``, ``overflow``) reports 0 for that figure.

    Returns:
        Dict with ``pool_size``, ``checked_in``, ``checked_out``,
        ``overflow``, ``utilization`` (checked-out connections as a
        percentage of the pool size; 0 when the size is not positive) and
        ``metrics`` (the module-level ``pool_metrics`` dict).
    """
    pool = async_engine.pool

    def read_gauge(method_name):
        # Call the gauge method when the pool provides it, otherwise 0.
        if hasattr(pool, method_name):
            return getattr(pool, method_name)()
        return 0

    pool_size = read_gauge('size')
    idle = read_gauge('checkedin')
    in_use = read_gauge('checkedout')
    overflow_count = read_gauge('overflow')

    if pool_size > 0:
        utilization = in_use / pool_size * 100
    else:
        utilization = 0

    return {
        'pool_size': pool_size,
        'checked_in': idle,
        'checked_out': in_use,
        'overflow': overflow_count,
        'utilization': utilization,
        'metrics': pool_metrics,
    }
|
|
async def get_redis() -> redis.Redis:
    """Return the shared Redis client, creating it on first use.

    The client is built from ``settings.REDIS_URL`` with UTF-8 encoding and
    decoded responses, stored in the module-level ``redis_client`` and
    reused by later calls.
    """
    global redis_client
    if redis_client is not None:
        return redis_client

    redis_client = redis.from_url(
        settings.REDIS_URL,
        decode_responses=True,
        encoding="utf-8",
    )
    return redis_client
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency to get async database session with automatic fallback.

    Delegates to the fallback-aware session generator and closes it
    explicitly once this generator finishes, is closed, or has an exception
    thrown into it. A plain ``async for`` does not close the iterator it
    wraps, so without this the delegate's cleanup would only run whenever
    the abandoned generator happened to be finalized.

    Yields:
        AsyncSession: Database session (primary or fallback)
    """
    delegate = get_db_with_fallback()
    try:
        async for session in delegate:
            yield session
    finally:
        # Guarded lookup so an iterator without aclose() behaves as before.
        close_delegate = getattr(delegate, "aclose", None)
        if close_delegate is not None:
            await close_delegate()
|
|
|
|
async def get_primary_db_only() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency that yields a session from ``AsyncSessionLocal`` only.

    No fallback is attempted here. If the consumer raises while holding the
    session, the session is rolled back and the exception is re-raised.

    Yields:
        AsyncSession: Session bound to the primary engine
    """
    async with AsyncSessionLocal() as primary_session:
        try:
            yield primary_session
        except Exception:
            # Undo pending work before letting the error propagate.
            await primary_session.rollback()
            raise
|
|
|
|
def get_sync_db() -> sessionmaker:
    """
    Get a sync session factory for migrations/admin tasks.

    The factory is bound to ``async_engine.sync_engine``: a synchronous
    ``Session`` needs a synchronous ``Engine`` as its bind and cannot drive
    the ``AsyncEngine`` wrapper directly.

    NOTE(review): the underlying engine is still created from
    ``settings.DATABASE_URL``. If that URL names an async-only driver,
    sessions from this factory presumably have to run inside SQLAlchemy's
    sync bridge (e.g. ``AsyncConnection.run_sync``) - confirm against the
    callers.

    Returns:
        sessionmaker: Factory producing sync ``Session`` objects with
        autocommit and autoflush disabled
    """
    return sessionmaker(
        bind=async_engine.sync_engine,
        autocommit=False,
        autoflush=False,
    )
|
|
|
|
async def init_db() -> None:
    """Run the fallback-aware database initialization.

    Success is logged at INFO level. On failure the error is logged and the
    original exception is re-raised.
    """
    try:
        await init_fallback_databases()
    except Exception as exc:
        logger.error(f"Database initialization failed: {exc}")
        raise
    else:
        logger.info("Database initialization completed with fallback support")
|
|
|
|
async def close_db() -> None:
    """Close database connections and the shared Redis client.

    The Redis client is closed even when closing the databases raises. The
    module-level reference is cleared first, so a later ``get_redis()``
    call builds a fresh client instead of returning the closed one.
    """
    global redis_client
    try:
        await close_fallback_databases()
    finally:
        client, redis_client = redis_client, None
        if client is not None:
            # Prefer aclose() where the client offers it (newer redis-py
            # releases deprecate close()); fall back to close() otherwise.
            closer = getattr(client, "aclose", None) or client.close
            await closer()
|
|
|
|
async def check_db_health() -> bool:
    """
    Check database health with fallback support.

    A failure of the underlying check is logged and reported as unhealthy
    instead of being raised, mirroring how ``check_redis_health`` reports
    its failures.

    Returns:
        bool: The health flag returned by ``check_database_health()``, or
        False if that call raises
    """
    try:
        # The second tuple element (database type) is not needed here.
        is_healthy, _db_type = await check_database_health()
    except Exception as e:
        logger.error(f"Database health check failed: {e}")
        return False
    return is_healthy
|
|
|
|
async def check_redis_health() -> bool:
    """
    Check Redis health with a ping and a write/read/delete round trip.

    Each check uses its own throwaway key. With a single fixed key,
    concurrent checks (for example several application instances sharing
    one Redis) could delete each other's probe value and report a false
    failure.

    Returns:
        bool: True if healthy; False if the round trip returns the wrong
        value or any Redis call raises
    """
    try:
        redis_conn = await get_redis()

        # perf_counter is monotonic, so the measured latency is not skewed
        # by wall-clock adjustments.
        start_time = time.perf_counter()
        await redis_conn.ping()
        response_time = (time.perf_counter() - start_time) * 1000

        # The 10 second expiry cleans up the key if the delete below is
        # never reached.
        test_key = f"health_check_test:{uuid.uuid4().hex}"
        await redis_conn.setex(test_key, 10, "test_value")
        test_value = await redis_conn.get(test_key)
        await redis_conn.delete(test_key)

        if test_value != "test_value":
            logger.error("Redis read/write test failed")
            return False

        logger.info(f"Redis health check passed - Response time: {response_time:.2f}ms")
        return True

    except Exception as e:
        logger.error(f"Redis health check failed: {e}")
        return False
|
|
async def get_redis_metrics() -> Dict[str, Any]:
    """
    Get Redis performance metrics.

    Returns:
        Dict with client, memory, command and keyspace figures taken from
        the Redis ``info()`` reply, plus ``hit_rate`` (keyspace hits as a
        percentage of hits + misses; 0.0 when there have been no lookups).
        An empty dict is returned if the metrics cannot be fetched.
    """
    try:
        redis_conn = await get_redis()
        info = await redis_conn.info()

        hits = info.get('keyspace_hits', 0)
        misses = info.get('keyspace_misses', 0)
        lookups = hits + misses
        # A server that has served no lookups reports 0 hits and 0 misses.
        # Guard the division so that case does not discard every metric.
        hit_rate = (hits / lookups) * 100 if lookups else 0.0

        return {
            'connected_clients': info.get('connected_clients', 0),
            'used_memory': info.get('used_memory', 0),
            'used_memory_human': info.get('used_memory_human', '0B'),
            'total_commands_processed': info.get('total_commands_processed', 0),
            'keyspace_hits': hits,
            'keyspace_misses': misses,
            'hit_rate': hit_rate,
        }

    except Exception as e:
        logger.error(f"Failed to get Redis metrics: {e}")
        return {}
|
|