File size: 7,443 Bytes
2ed8996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
"""
Database configuration for AegisLM SaaS Backend.

Production-ready async SQLAlchemy setup with connection pooling,
health monitoring, and proper session management.
Includes SQLite fallback for high availability.
"""

from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import text
import redis.asyncio as redis
import logging
import time
from sqlalchemy import event
from typing import Dict, Any

from core.config import settings
from core.fallback_database import (
    get_db as get_db_with_fallback,
    init_databases as init_fallback_databases,
    close_databases as close_fallback_databases,
    check_database_health,
    get_current_database_type,
    switch_to_primary,
    switch_to_fallback
)

logger = logging.getLogger(__name__)


# Create async engine with connection pooling (legacy - kept for compatibility)
# When settings.DEBUG is true the engine is given NullPool and SQL echo is
# switched on; otherwise poolclass=None is passed (presumably SQLAlchemy then
# picks its default pool for the dialect — confirm against the installed version).
async_engine = create_async_engine(
    settings.DATABASE_URL,
    pool_pre_ping=True,  # presumably tests a pooled connection before handing it out — confirm
    pool_recycle=3600,  # presumably seconds (one hour) before a connection is replaced — confirm
    poolclass=NullPool if settings.DEBUG else None,
    echo=settings.DEBUG
)

# Create async session factory (legacy - kept for compatibility)
# Bound to async_engine; get_primary_db_only() opens its sessions from this factory.
AsyncSessionLocal = async_sessionmaker(
    async_engine,
    class_=AsyncSession,
    expire_on_commit=False  # presumably keeps loaded ORM attributes usable after commit — confirm
)

# Base model for declarative models
# NOTE(review): declarative_base is imported from sqlalchemy.ext.declarative in
# this file; SQLAlchemy 1.4+ also exposes it as sqlalchemy.orm.declarative_base —
# confirm whether the import should be migrated.
Base = declarative_base()

# Redis connection
# Module-level client shared through get_redis(); it stays None until
# get_redis() creates the client on first use.
redis_client: redis.Redis = None

# Database connection monitoring
# Running counters read and updated by the pool event listeners in this module
# and exposed through get_pool_metrics(). Every counter starts at zero.
pool_metrics = dict.fromkeys(
    (
        "connections_created",
        "connections_closed",
        "peak_connections",
        "total_queries",
        "slow_queries",
    ),
    0,
)

# Setup connection monitoring
@event.listens_for(async_engine.sync_engine, "connect")
def receive_connect(dbapi_connection, connection_record):
    """Record a newly created DBAPI connection in ``pool_metrics``.

    Stamps the connection record with its creation time, increments
    ``connections_created`` and raises ``peak_connections`` when the number
    of currently open connections (created minus closed) exceeds it.

    Args:
        dbapi_connection: Raw DBAPI connection passed by the event (unused).
        connection_record: Pool record whose ``info`` dict receives
            ``connect_time``.
    """
    connection_record.info['connect_time'] = time.time()
    pool_metrics['connections_created'] += 1
    open_connections = pool_metrics['connections_created'] - pool_metrics['connections_closed']
    pool_metrics['peak_connections'] = max(pool_metrics['peak_connections'], open_connections)

@event.listens_for(async_engine.sync_engine, "close")
def receive_close(dbapi_connection, connection_record):
    """Count a DBAPI connection being closed by the pool.

    Without this listener ``connections_closed`` is never incremented, so
    the "created minus closed" figure used for ``peak_connections`` would
    simply equal the cumulative number of connections ever created.

    Args:
        dbapi_connection: Raw DBAPI connection passed by the event (unused).
        connection_record: Pool record for the connection (unused).
    """
    pool_metrics['connections_closed'] += 1

@event.listens_for(async_engine.sync_engine, "checkout")
def receive_checkout(dbapi_connection, connection_record, connection_proxy):
    """Stamp the connection record with the wall-clock time of this checkout.

    The value is stored under ``checkout_time`` in ``connection_record.info``
    and read back by the ``checkin`` listener to measure how long the
    connection was held.
    """
    checked_out_at = time.time()
    connection_record.info['checkout_time'] = checked_out_at

@event.listens_for(async_engine.sync_engine, "checkin")
def receive_checkin(dbapi_connection, connection_record):
    """Update the usage counters when a connection is returned to the pool.

    Each checkin with a recorded ``checkout_time`` counts as one entry in
    ``total_queries``; if the connection was held for more than one second
    it also counts in ``slow_queries``. Records without a checkout stamp
    are ignored.
    """
    borrowed_at = connection_record.info.get('checkout_time')
    if not borrowed_at:
        return

    held_seconds = time.time() - borrowed_at
    pool_metrics['total_queries'] += 1
    if held_seconds > 1.0:  # Slow query threshold
        pool_metrics['slow_queries'] += 1

async def get_pool_metrics() -> Dict[str, Any]:
    """Report the primary engine's pool gauges plus the running counters.

    Any gauge method the pool object does not have is reported as 0.

    Returns:
        Dict with ``pool_size``, ``checked_in``, ``checked_out``,
        ``overflow``, ``utilization`` (checked-out connections as a
        percentage of the pool size, 0 when the size is not positive) and
        ``metrics`` (the module-level ``pool_metrics`` dict itself, not a
        copy).
    """
    engine_pool = async_engine.pool

    def gauge(method_name):
        # Call the named pool method when it exists; otherwise read as 0.
        if not hasattr(engine_pool, method_name):
            return 0
        return getattr(engine_pool, method_name)()

    size = gauge('size')
    idle = gauge('checkedin')
    in_use = gauge('checkedout')
    spilled = gauge('overflow')

    if size > 0:
        utilization = in_use / size * 100
    else:
        utilization = 0

    return {
        'pool_size': size,
        'checked_in': idle,
        'checked_out': in_use,
        'overflow': spilled,
        'utilization': utilization,
        'metrics': pool_metrics
    }

async def get_redis() -> redis.Redis:
    """Return the shared Redis client, creating it on first use.

    The client is built from ``settings.REDIS_URL`` with UTF-8 encoding and
    ``decode_responses=True``, then cached in the module-level
    ``redis_client`` for later calls.
    """
    global redis_client
    if redis_client is not None:
        return redis_client

    redis_client = redis.from_url(
        settings.REDIS_URL,
        encoding="utf-8",
        decode_responses=True,
    )
    return redis_client

async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """
    Session dependency that delegates to the fallback-aware provider.

    Re-yields, unchanged, every session produced by ``get_db_with_fallback()``
    (``get_db`` from ``core.fallback_database``).

    NOTE(review): an exception thrown into this generator at the ``yield``
    is raised here rather than inside ``get_db_with_fallback()`` — confirm
    the provider's cleanup does not rely on seeing that exception.

    Yields:
        AsyncSession: Database session (primary or fallback)
    """
    fallback_sessions = get_db_with_fallback()
    async for db_session in fallback_sessions:
        yield db_session


async def get_primary_db_only() -> AsyncGenerator[AsyncSession, None]:
    """
    Session dependency that always uses ``AsyncSessionLocal`` (no fallback).

    The session is opened with ``async with`` around the ``yield``. If an
    ``Exception`` is raised at the ``yield``, the session is rolled back
    and the exception is re-raised.

    Yields:
        AsyncSession: Session created by ``AsyncSessionLocal``
    """
    async with AsyncSessionLocal() as primary_session:
        try:
            yield primary_session
        except Exception:
            # Undo pending work before the error propagates to the caller.
            await primary_session.rollback()
            raise


def get_sync_db() -> sessionmaker:
    """
    Build a synchronous session factory for migrations/admin tasks.

    The factory is bound to ``async_engine.sync_engine`` (the synchronous
    ``Engine`` that the async engine wraps). A synchronous ``sessionmaker``
    cannot drive the ``AsyncEngine`` object itself, whose connection methods
    are awaitable.

    NOTE(review): if ``settings.DATABASE_URL`` uses an async driver, sessions
    from this factory presumably only work inside SQLAlchemy's greenlet
    bridge (e.g. ``run_sync``) — confirm against the callers.

    Returns:
        sessionmaker: Factory producing sync sessions with autocommit and
        autoflush disabled (call it to obtain a ``Session``).
    """
    return sessionmaker(
        bind=async_engine.sync_engine,
        autocommit=False,
        autoflush=False,
    )


async def init_db() -> None:
    """Run the fallback-aware database initialisation.

    Awaits ``init_fallback_databases()`` and logs the outcome.

    Raises:
        Exception: Whatever the initialisation raised; it is logged and then
            re-raised unchanged.
    """
    try:
        await init_fallback_databases()
        logger.info("Database initialization completed with fallback support")
    except Exception as init_error:
        logger.error(f"Database initialization failed: {init_error}")
        raise


async def close_db() -> None:
    """Close database connections and the shared Redis client.

    The Redis client is closed even when ``close_fallback_databases()``
    raises, and the module-level ``redis_client`` is reset to ``None`` so a
    later ``get_redis()`` call builds a fresh client instead of handing out
    the closed one.

    Raises:
        Exception: Whatever ``close_fallback_databases()`` or the Redis
            ``close()`` call raises.
    """
    global redis_client
    try:
        await close_fallback_databases()
    finally:
        # Detach the client from the module first so it is never reused,
        # even if closing it fails.
        client, redis_client = redis_client, None
        if client is not None:
            await client.close()


async def check_db_health() -> bool:
    """
    Check database health with fallback support.

    Delegates to ``check_database_health()``, which returns a
    ``(is_healthy, db_type)`` pair; only the health flag is used here. Any
    exception from the check is logged and reported as unhealthy rather
    than propagated.

    Returns:
        bool: The health flag reported by ``check_database_health()``, or
        False if the check itself raised.
    """
    try:
        is_healthy, _db_type = await check_database_health()
    except Exception as e:
        # Log the cause so an unhealthy result can be diagnosed, matching
        # the Redis health check's behaviour.
        logger.error(f"Database health check failed: {e}")
        return False
    return is_healthy


async def check_redis_health() -> bool:
    """
    Check Redis health with detailed diagnostics.

    Pings the server (timing the round trip), then writes, reads back and
    deletes a short-lived probe key. Each call uses its own unique key so
    that concurrent probes cannot delete each other's value and report a
    false failure.

    Returns:
        bool: True if the ping and the read/write round trip succeed; False
        if the value read back differs or any step raises (the error is
        logged).
    """
    import uuid  # local import: only this probe needs it

    try:
        redis_conn = await get_redis()

        # Test basic connectivity
        start_time = time.time()
        await redis_conn.ping()
        response_time = (time.time() - start_time) * 1000

        # Test read/write capability with a per-probe key (10 second TTL as
        # a safety net should the delete below never run).
        test_key = f"health_check_test:{uuid.uuid4().hex}"
        await redis_conn.setex(test_key, 10, "test_value")
        test_value = await redis_conn.get(test_key)
        await redis_conn.delete(test_key)

        # The str comparison relies on the client from get_redis() being
        # created with decode_responses=True.
        if test_value != "test_value":
            logger.error("Redis read/write test failed")
            return False

        # Log performance metrics
        logger.info(f"Redis health check passed - Response time: {response_time:.2f}ms")

        return True

    except Exception as e:
        logger.error(f"Redis health check failed: {e}")
        return False

async def get_redis_metrics() -> Dict[str, Any]:
    """
    Get Redis performance metrics.

    Reads the server's INFO output and extracts a few counters. The hit
    rate is ``keyspace_hits`` as a percentage of all keyspace lookups, and
    is reported as 0.0 when no lookups have been recorded yet (both
    counters zero), instead of failing with a division by zero and
    discarding every metric.

    Returns:
        Dict containing Redis metrics, or an empty dict if the metrics
        could not be retrieved (the error is logged).
    """
    try:
        redis_conn = await get_redis()
        info = await redis_conn.info()

        hits = info.get('keyspace_hits', 0)
        misses = info.get('keyspace_misses', 0)
        lookups = hits + misses
        # A server with no key lookups yet has hits == misses == 0.
        hit_rate = (hits / lookups) * 100 if lookups > 0 else 0.0

        return {
            'connected_clients': info.get('connected_clients', 0),
            'used_memory': info.get('used_memory', 0),
            'used_memory_human': info.get('used_memory_human', '0B'),
            'total_commands_processed': info.get('total_commands_processed', 0),
            'keyspace_hits': hits,
            'keyspace_misses': misses,
            'hit_rate': hit_rate
        }

    except Exception as e:
        logger.error(f"Failed to get Redis metrics: {e}")
        return {}