ALM-2 / backend /core /database.py
ACA050's picture
Upload 520 files
2ed8996 verified
"""
Database configuration for AegisLM SaaS Backend.
Production-ready async SQLAlchemy setup with connection pooling,
health monitoring, and proper session management.
Includes SQLite fallback for high availability.
"""
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import text
import redis.asyncio as redis
import logging
import time
from sqlalchemy import event
from typing import Dict, Any
from core.config import settings
from core.fallback_database import (
get_db as get_db_with_fallback,
init_databases as init_fallback_databases,
close_databases as close_fallback_databases,
check_database_health,
get_current_database_type,
switch_to_primary,
switch_to_fallback
)
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Legacy engine for the primary database, retained so existing imports of
# `async_engine` keep working. In DEBUG mode it uses NullPool and echoes SQL;
# otherwise poolclass=None leaves the pool implementation to SQLAlchemy.
async_engine = create_async_engine(
    settings.DATABASE_URL,
    echo=settings.DEBUG,
    poolclass=None if not settings.DEBUG else NullPool,
    pool_recycle=3600,
    pool_pre_ping=True,
)
# Legacy session factory bound to `async_engine` (kept for compatibility).
# Sessions are created with expire_on_commit=False.
AsyncSessionLocal = async_sessionmaker(
    bind=async_engine,
    expire_on_commit=False,
    class_=AsyncSession,
)
# Declarative base class for the project's ORM models.
# NOTE(review): declarative_base is imported from sqlalchemy.ext.declarative;
# SQLAlchemy 2.0 exposes it from sqlalchemy.orm and deprecates the old
# location — confirm the installed version and migrate the import if so.
Base = declarative_base()
# Shared Redis client. It starts as None and is created lazily by get_redis().
redis_client: "redis.Redis | None" = None
# Process-wide counters for database connection monitoring. They are updated
# by the pool event listeners registered on the engine and exposed through
# get_pool_metrics().
pool_metrics = dict(
    connections_created=0,
    connections_closed=0,
    peak_connections=0,
    total_queries=0,
    slow_queries=0,
)
# Setup connection monitoring
@event.listens_for(async_engine.sync_engine, "connect")
def receive_connect(dbapi_connection, connection_record):
    """
    Record a newly created DBAPI connection.

    Stamps the creation time on the connection record, bumps the
    'connections_created' counter and refreshes 'peak_connections' from the
    number of connections currently open (created minus closed).
    """
    connection_record.info['connect_time'] = time.time()
    pool_metrics['connections_created'] += 1
    open_connections = (
        pool_metrics['connections_created'] - pool_metrics['connections_closed']
    )
    pool_metrics['peak_connections'] = max(
        pool_metrics['peak_connections'], open_connections
    )


@event.listens_for(async_engine.sync_engine, "close")
def receive_close(dbapi_connection, connection_record):
    """
    Count a DBAPI connection being closed by the pool.

    Without this listener 'connections_closed' never moves, so the
    "currently open" figure used for 'peak_connections' would simply equal
    the total number of connections ever created.
    """
    pool_metrics['connections_closed'] += 1
@event.listens_for(async_engine.sync_engine, "checkout")
def receive_checkout(dbapi_connection, connection_record, connection_proxy):
    """Stamp the moment a connection is handed out by the pool."""
    checked_out_at = time.time()
    connection_record.info['checkout_time'] = checked_out_at
@event.listens_for(async_engine.sync_engine, "checkin")
def receive_checkin(dbapi_connection, connection_record):
    """
    Update usage counters when a connection is returned to the pool.

    A checkout that lasted longer than one second is counted under
    'slow_queries'; every check-in is counted under 'total_queries'.
    """
    started_at = connection_record.info.get('checkout_time')
    # Only consult the clock when a checkout timestamp was actually recorded.
    held_too_long = bool(started_at) and (time.time() - started_at) > 1.0
    if held_too_long:
        pool_metrics['slow_queries'] += 1
    pool_metrics['total_queries'] += 1
async def get_pool_metrics() -> Dict[str, Any]:
    """
    Report the legacy engine's pool gauges together with the event counters.

    Returns:
        Dict with 'pool_size', 'checked_in', 'checked_out', 'overflow',
        'utilization' (checked-out connections as a percentage of the pool
        size, or 0 when the size is 0) and 'metrics' (the module-level
        pool_metrics dict).
    """
    engine_pool = async_engine.pool

    def gauge(method_name):
        # A pool class without this counter method reports 0 for it.
        if not hasattr(engine_pool, method_name):
            return 0
        return getattr(engine_pool, method_name)()

    size = gauge('size')
    idle = gauge('checkedin')
    busy = gauge('checkedout')
    spill = gauge('overflow')

    if size > 0:
        utilization = busy / size * 100
    else:
        utilization = 0

    report = {
        'pool_size': size,
        'checked_in': idle,
        'checked_out': busy,
        'overflow': spill,
        'utilization': utilization,
        'metrics': pool_metrics,
    }
    return report
async def get_redis() -> redis.Redis:
    """
    Return the shared Redis client, creating it on first use.

    Returns:
        redis.Redis: Client built from settings.REDIS_URL with UTF-8
        decoded responses; the same instance is reused on later calls.
    """
    global redis_client
    if redis_client is not None:
        return redis_client

    redis_client = redis.from_url(
        settings.REDIS_URL,
        decode_responses=True,
        encoding="utf-8",
    )
    return redis_client
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Session dependency that delegates to the fallback-aware provider.

    Yields:
        AsyncSession: Each session produced by get_db_with_fallback()
        (primary or fallback database).
    """
    async for db_session in get_db_with_fallback():
        yield db_session
async def get_primary_db_only() -> AsyncGenerator[AsyncSession, None]:
    """
    Session dependency that uses the legacy primary engine only (no fallback).

    The session is rolled back if the consumer raises an Exception while
    holding it, and the exception is then re-raised.

    Yields:
        AsyncSession: Session created by AsyncSessionLocal
    """
    async with AsyncSessionLocal() as primary_session:
        try:
            yield primary_session
        except Exception:
            # Undo pending work before the error reaches the caller.
            await primary_session.rollback()
            raise
def get_sync_db() -> sessionmaker:
    """
    Build a synchronous session factory for migrations/admin tasks.

    The factory is bound to the synchronous Engine that `async_engine`
    proxies (`async_engine.sync_engine`). A plain `sessionmaker` produces
    synchronous Session objects, which cannot drive an AsyncEngine directly.

    NOTE(review): when the configured driver is async-only, sessions from
    this factory presumably still need a greenlet-aware context (for example
    inside `AsyncConnection.run_sync`) — confirm against the callers.

    Returns:
        sessionmaker: Factory that creates sync Session objects
        (autocommit and autoflush disabled); call it to obtain a session.
    """
    return sessionmaker(
        bind=async_engine.sync_engine,
        autocommit=False,
        autoflush=False,
    )
async def init_db() -> None:
    """
    Initialize the databases through the fallback-aware initializer.

    Any failure is logged and then re-raised to the caller.
    """
    try:
        await init_fallback_databases()
    except Exception as exc:
        logger.error(f"Database initialization failed: {exc}")
        raise
    else:
        logger.info("Database initialization completed with fallback support")
async def close_db() -> None:
    """
    Close database connections and the shared Redis client.

    The Redis client is closed even when shutting down the databases raises,
    and the module-level reference is cleared so that a later get_redis()
    call builds a fresh client instead of reusing a closed one.
    """
    global redis_client
    try:
        await close_fallback_databases()
    finally:
        # Detach the shared client before closing it so no caller can pick up
        # a client that is in the middle of shutting down.
        client, redis_client = redis_client, None
        if client is not None:
            await client.close()
async def check_db_health() -> bool:
    """
    Check database health with fallback support.

    A failing probe is logged (matching check_redis_health) rather than being
    swallowed silently, so an unhealthy result can be diagnosed.

    Returns:
        bool: True if any database is healthy, False if none is or the
        health probe itself raised
    """
    try:
        # The probe also reports which database answered; only the health
        # flag is needed here.
        is_healthy, _db_type = await check_database_health()
    except Exception as exc:
        logger.error(f"Database health check failed: {exc}")
        return False
    return is_healthy
async def check_redis_health() -> bool:
    """
    Check Redis health with detailed diagnostics.

    Pings the server (timing the round trip), then writes, reads back and
    deletes a short-lived probe key to confirm read/write capability.

    Returns:
        bool: True if healthy, False if the round trip mismatched or any
        step raised
    """
    import uuid  # local import: only this probe needs it

    try:
        redis_conn = await get_redis()

        # Test basic connectivity. perf_counter is monotonic, so the measured
        # latency cannot be skewed by wall-clock adjustments.
        started = time.perf_counter()
        await redis_conn.ping()
        response_time = (time.perf_counter() - started) * 1000

        # Test read/write capability. Each probe uses its own key so that
        # concurrent health checks against the same Redis cannot delete each
        # other's value and report a false failure. The 10-second TTL cleans
        # up the key if a later step fails before the delete.
        test_key = f"health_check_test:{uuid.uuid4().hex}"
        await redis_conn.setex(test_key, 10, "test_value")
        test_value = await redis_conn.get(test_key)
        await redis_conn.delete(test_key)

        if test_value != "test_value":
            logger.error("Redis read/write test failed")
            return False

        # Log performance metrics
        logger.info(f"Redis health check passed - Response time: {response_time:.2f}ms")
        return True
    except Exception as e:
        logger.error(f"Redis health check failed: {e}")
        return False
async def get_redis_metrics() -> Dict[str, Any]:
    """
    Get Redis performance metrics.

    Returns:
        Dict containing Redis metrics taken from the server's INFO output,
        plus 'hit_rate' (keyspace hits as a percentage of all keyspace
        lookups, 0.0 when there have been no lookups). An empty dict is
        returned if the metrics cannot be fetched.
    """
    try:
        redis_conn = await get_redis()
        info = await redis_conn.info()

        hits = info.get('keyspace_hits', 0)
        misses = info.get('keyspace_misses', 0)
        lookups = hits + misses
        # A server that has served no lookups reports 0 hits and 0 misses;
        # guard the division so the whole metrics call does not fail on it.
        hit_rate = (hits / lookups) * 100 if lookups else 0.0

        return {
            'connected_clients': info.get('connected_clients', 0),
            'used_memory': info.get('used_memory', 0),
            'used_memory_human': info.get('used_memory_human', '0B'),
            'total_commands_processed': info.get('total_commands_processed', 0),
            'keyspace_hits': hits,
            'keyspace_misses': misses,
            'hit_rate': hit_rate,
        }
    except Exception as e:
        logger.error(f"Failed to get Redis metrics: {e}")
        return {}