# zenith-backend — core/database_replicas.py
"""
Database Read Replica Configuration
Provides read replica support for improved read performance
and load distribution across database instances.
Features:
- Multiple read replica support
- Automatic read/write splitting
- Connection failover
- Health checks for replicas
- Latency-based routing
"""
import functools
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import QueuePool
logger = logging.getLogger(__name__)
class ReplicaStatus(Enum):
    """Health state of a single read replica.

    The string values are what status endpoints report to callers.
    """

    HEALTHY = "healthy"      # responding within the latency threshold
    DEGRADED = "degraded"    # responding, but slower than the threshold
    UNHEALTHY = "unhealthy"  # health-check query failed
    UNKNOWN = "unknown"      # never checked yet
@dataclass
class DatabaseReplica:
    """Configuration plus live health metrics for one read replica."""

    name: str  # unique key used for the engine/session registries
    url: str  # SQLAlchemy connection URL for this replica
    weight: int = 1  # For weighted routing
    latency_threshold_ms: float = 100.0  # Max acceptable latency
    health_check_interval: int = 30  # seconds
    # NOTE(review): defaults to "now" at construction even though status is
    # UNKNOWN (no check has run yet); datetime.utcnow is naive UTC and is
    # deprecated in Python 3.12+ — confirm before migrating to aware datetimes.
    last_checked: datetime = field(default_factory=datetime.utcnow)
    status: ReplicaStatus = ReplicaStatus.UNKNOWN
    avg_latency_ms: float = 0.0  # exponential moving average of probe latency
    request_count: int = 0  # sessions handed out for this replica
    error_count: int = 0  # failed health checks
@dataclass
class DatabaseConfig:
    """Database configuration with replica support."""

    primary_url: str  # SQLAlchemy URL of the writable primary
    replicas: List[DatabaseReplica] = field(default_factory=list)  # read replicas
    pool_size: int = 20  # connections kept in each engine's pool
    max_overflow: int = 10  # extra connections allowed beyond pool_size
    pool_timeout: int = 30  # seconds to wait for a pooled connection
    pool_recycle: int = 3600  # recycle connections older than this (seconds)
    enable_read_write_split: bool = True  # route reads to replicas when True
    replica_selection_strategy: str = "latency"  # latency, round_robin, random
class ReadReplicaManager:
    """
    Manages database read replicas with automatic health checks
    and intelligent routing.

    Writes always use the primary engine; reads may be served by any
    replica whose last health check reported HEALTHY or DEGRADED.
    """

    def __init__(self, config: DatabaseConfig):
        self.config = config
        self._primary_engine = None
        self._primary_sessionmaker: Optional[sessionmaker] = None
        self._replica_engines: Dict[str, Any] = {}
        self._replica_sessions: Dict[str, sessionmaker] = {}
        self._lock = threading.Lock()
        self._monitor_thread = None
        self._running = False
        self._round_robin_cursor = 0  # shared cursor for the round_robin strategy
        # Initialize engines
        self._initialize_engines()

    def _initialize_engines(self):
        """Initialize SQLAlchemy engines for primary and replicas."""
        # Primary (writer) engine
        self._primary_engine = create_engine(
            self.config.primary_url,
            poolclass=QueuePool,
            pool_size=self.config.pool_size,
            max_overflow=self.config.max_overflow,
            pool_timeout=self.config.pool_timeout,
            pool_recycle=self.config.pool_recycle,
            pool_pre_ping=True,  # validate connections before handing them out
            echo=False,
        )
        # Build the writer session factory once instead of on every
        # get_primary_session() call.
        self._primary_sessionmaker = sessionmaker(
            autocommit=False, autoflush=False, bind=self._primary_engine
        )
        # Replica (reader) engines
        for replica in self.config.replicas:
            engine = create_engine(
                replica.url,
                poolclass=QueuePool,
                # Smaller pool for replicas, but never 0 (a 0 pool_size has
                # special meaning in SQLAlchemy's QueuePool).
                pool_size=max(1, self.config.pool_size // 2),
                max_overflow=max(0, self.config.max_overflow // 2),
                pool_timeout=self.config.pool_timeout,
                pool_recycle=self.config.pool_recycle,
                pool_pre_ping=True,
                echo=False,
            )
            self._replica_engines[replica.name] = engine
            self._replica_sessions[replica.name] = sessionmaker(
                autocommit=False, autoflush=False, bind=engine
            )
            logger.info(f"Initialized replica engine: {replica.name}")

    def _start_monitoring(self):
        """Start background health check monitoring (idempotent)."""
        if self._running:
            return
        self._running = True

        def monitor():
            while self._running:
                try:
                    now = datetime.utcnow()
                    for replica in self.config.replicas:
                        # Honor each replica's configured interval instead of
                        # probing every replica once a second; UNKNOWN replicas
                        # are probed immediately.
                        elapsed = (now - replica.last_checked).total_seconds()
                        if (
                            replica.status is ReplicaStatus.UNKNOWN
                            or elapsed >= replica.health_check_interval
                        ):
                            self._check_replica_health(replica)
                except Exception as e:
                    logger.error(f"Replica monitoring error: {e}")
                time.sleep(1)  # wake once a second; actual probes are throttled

        self._monitor_thread = threading.Thread(target=monitor, daemon=True)
        self._monitor_thread.start()
        logger.info("Replica health monitoring started")

    def _check_replica_health(self, replica: DatabaseReplica):
        """Probe a replica with a trivial query and update its metrics."""
        start_time = time.time()
        try:
            engine = self._replica_engines.get(replica.name)
            if not engine:
                replica.status = ReplicaStatus.UNHEALTHY
                return
            # Simple health check query. SQLAlchemy 2.x requires textual SQL
            # to be wrapped in text(); a bare string raises an error.
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            # Calculate latency
            latency_ms = (time.time() - start_time) * 1000
            replica.last_checked = datetime.utcnow()
            # Exponential moving average smooths out one-off spikes.
            replica.avg_latency_ms = replica.avg_latency_ms * 0.9 + latency_ms * 0.1
            if latency_ms > replica.latency_threshold_ms:
                replica.status = ReplicaStatus.DEGRADED
            else:
                replica.status = ReplicaStatus.HEALTHY
        except Exception as e:
            replica.status = ReplicaStatus.UNHEALTHY
            replica.error_count += 1
            # Record the failed attempt so retries follow the check interval
            # rather than firing every monitor tick.
            replica.last_checked = datetime.utcnow()
            logger.warning(f"Replica {replica.name} health check failed: {e}")

    def get_healthy_replica(self) -> Optional[DatabaseReplica]:
        """Get the healthiest available replica based on strategy.

        DEGRADED replicas remain eligible so reads can still be spread when
        every replica is slow; UNHEALTHY/UNKNOWN replicas are excluded.
        """
        healthy = [
            r
            for r in self.config.replicas
            if r.status in (ReplicaStatus.HEALTHY, ReplicaStatus.DEGRADED)
        ]
        if not healthy:
            return None
        strategy = self.config.replica_selection_strategy
        if strategy == "latency":
            # Return replica with lowest smoothed latency
            return min(healthy, key=lambda r: r.avg_latency_ms)
        if strategy == "round_robin":
            # True round-robin: advance a shared cursor under the lock so
            # concurrent callers rotate through the healthy set.
            with self._lock:
                chosen = healthy[self._round_robin_cursor % len(healthy)]
                self._round_robin_cursor += 1
            return chosen
        if strategy == "random":
            import random

            return random.choice(healthy)
        return healthy[0]

    def get_primary_session(self) -> Session:
        """Get a session for the primary (writer) database."""
        return self._primary_sessionmaker()

    def get_replica_session(
        self, replica_name: Optional[str] = None
    ) -> Optional[Session]:
        """
        Get a session for a read replica.

        Args:
            replica_name: Specific replica name, or None for auto-selection

        Returns:
            Session object or None if no replica available
        """
        if not self.config.enable_read_write_split:
            return None
        if replica_name:
            # Specific replica requested
            session_factory = self._replica_sessions.get(replica_name)
            if session_factory:
                return session_factory()
            return None
        # Auto-select replica
        replica = self.get_healthy_replica()
        if replica:
            session_factory = self._replica_sessions.get(replica.name)
            if session_factory:
                replica.request_count += 1
                return session_factory()
        return None

    def get_all_replica_status(self) -> List[Dict[str, Any]]:
        """Get a status snapshot of all replicas (for monitoring endpoints)."""
        return [
            {
                "name": r.name,
                "status": r.status.value,
                "latency_ms": round(r.avg_latency_ms, 2),
                "requests": r.request_count,
                "errors": r.error_count,
                "last_checked": r.last_checked.isoformat() if r.last_checked else None,
            }
            for r in self.config.replicas
        ]

    @property
    def primary_engine(self):
        """Get the primary database engine."""
        return self._primary_engine

    @property
    def replica_engines(self):
        """Get all replica engines."""
        return self._replica_engines
class ReadWriteSessionManager:
    """
    Context manager for automatic read/write session routing.

    Usage:
        with ReadWriteSessionManager(replica_manager) as sessions:
            rows = sessions.read_session.query(User).all()  # SELECTs
            sessions.write_session.add(obj)                 # INSERT/UPDATE/DELETE
    """

    def __init__(self, replica_manager: ReadReplicaManager):
        self.replica_manager = replica_manager
        self._read_session: Optional[Session] = None
        self._write_session: Optional[Session] = None

    def __enter__(self):
        self._write_session = self.replica_manager.get_primary_session()
        if self.replica_manager.config.enable_read_write_split:
            self._read_session = self.replica_manager.get_replica_session()
        # Fallback to primary if no replica is available (or splitting is off).
        if not self._read_session:
            self._read_session = self._write_session
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # When the read session fell back to the primary, both attributes
        # reference the same Session — close it only once.
        if self._read_session is not None and self._read_session is not self._write_session:
            self._read_session.close()
        if self._write_session is not None:
            self._write_session.close()
        return False  # never suppress exceptions

    @property
    def read_session(self) -> Session:
        """Get the read session (replica or primary)."""
        return self._read_session

    @property
    def write_session(self) -> Session:
        """Get the write session (primary only)."""
        return self._write_session
# Global replica manager instance, lazily created by get_replica_manager().
_replica_manager: Optional[ReadReplicaManager] = None
def get_replica_manager() -> ReadReplicaManager:
    """Return the process-wide replica manager, creating it on first use.

    Also starts the background health-check thread the first time.
    """
    global _replica_manager
    if _replica_manager is None:
        _replica_manager = ReadReplicaManager(load_database_config())
        _replica_manager._start_monitoring()
    return _replica_manager
def load_database_config() -> DatabaseConfig:
    """Build a DatabaseConfig from environment variables.

    REPLICA_URLS is a comma-separated list; replica names are numbered by
    their position in that list (empty entries are skipped but still counted,
    matching the original numbering scheme).
    """
    primary_url = os.getenv(
        "DATABASE_URL", "postgresql://postgres:postgres@localhost:5432/zenith"
    )
    # Parse replica URLs from environment
    raw = os.getenv("REPLICA_URLS", "")
    replicas = [
        DatabaseReplica(
            name=f"replica_{position + 1}",
            url=candidate.strip(),
            weight=1,
            latency_threshold_ms=100.0,
        )
        for position, candidate in enumerate(raw.split(","))
        if candidate.strip()
    ]

    def _int_env(name: str, default: str) -> int:
        # All pool knobs are plain integers in the environment.
        return int(os.getenv(name, default))

    return DatabaseConfig(
        primary_url=primary_url,
        replicas=replicas,
        pool_size=_int_env("DB_POOL_SIZE", "20"),
        max_overflow=_int_env("DB_MAX_OVERFLOW", "10"),
        pool_timeout=_int_env("DB_POOL_TIMEOUT", "30"),
        pool_recycle=_int_env("DB_POOL_RECYCLE", "3600"),
        enable_read_write_split=os.getenv("ENABLE_READ_WRITE_SPLIT", "true").lower()
        == "true",
        replica_selection_strategy=os.getenv("REPLICA_SELECTION_STRATEGY", "latency"),
    )
# Query routing decorator for automatic read/write splitting
def route_query(read_only: bool = True):
    """
    Decorator to route queries to appropriate database.

    Injects ``_read_session`` and ``_write_session`` keyword arguments into
    the wrapped function; both sessions are closed when the call returns.

    Args:
        read_only: Currently informational only — both sessions are always
            injected. Kept for forward compatibility with stricter routing.

    Usage:
        @route_query(read_only=True)
        def get_users(_read_session=None, _write_session=None):
            return _read_session.query(User).all()
    """

    def decorator(func):
        @functools.wraps(func)  # preserve the wrapped function's metadata
        def wrapper(*args, **kwargs):
            manager = get_replica_manager()
            with ReadWriteSessionManager(manager) as sessions:
                # Inject sessions into function arguments
                kwargs["_read_session"] = sessions.read_session
                kwargs["_write_session"] = sessions.write_session
                return func(*args, **kwargs)

        return wrapper

    return decorator
# Example configuration for docker-compose
EXAMPLE_REPLICA_CONFIG = """
# Add to docker-compose.yml for read replicas
services:
backend:
environment:
- DATABASE_URL=postgresql://postgres:postgres@primary:5432/zenith
- REPLICA_URLS=postgresql://postgres:postgres@replica1:5432/zenith,postgresql://postgres:postgres@replica2:5432/zenith
- ENABLE_READ_WRITE_SPLIT=true
- REPLICA_SELECTION_STRATEGY=latency
primary:
image: postgres:15
environment:
POSTGRES_DB: zenith
volumes:
- primary_data:/var/lib/postgresql/data
replica1:
image: postgres:15
environment:
POSTGRES_DB: zenith
POSTGRES_HOST_AUTH_METHOD: trust
command: |
bash -c "postgres &
sleep 5 &&
pg_basebackup -h primary -D /var/lib/postgresql/data -U replication -Fp -Xs -R"
depends_on:
- primary
volumes:
- replica1_data:/var/lib/postgresql/data
replica2:
image: postgres:15
environment:
POSTGRES_DB: zenith
POSTGRES_HOST_AUTH_METHOD: trust
command: |
bash -c "postgres &
sleep 5 &&
pg_basebackup -h primary -D /var/lib/postgresql/data -U replication -Fp -Xs -R"
depends_on:
- primary
volumes:
- replica2_data:/var/lib/postgresql/data
volumes:
primary_data:
replica1_data:
replica2_data:
"""
# Public API of this module.
__all__ = [
    "DatabaseConfig",
    "DatabaseReplica",
    "ReadReplicaManager",
    "ReadWriteSessionManager",
    "get_replica_manager",
    "load_database_config",
    "route_query",
    "ReplicaStatus",
]