File size: 13,646 Bytes
f4b172d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e12532
f4b172d
 
9e12532
 
 
 
f4b172d
9e12532
 
f4b172d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e12532
 
 
f4b172d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e12532
 
 
f4b172d
 
 
 
 
 
 
 
 
 
 
 
 
9e12532
 
 
 
 
f4b172d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e12532
 
 
f4b172d
9e12532
 
 
f4b172d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e12532
 
 
f4b172d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e12532
 
f4b172d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
"""
Database Read Replica Configuration

Provides read replica support for improved read performance
and load distribution across database instances.

Features:
- Multiple read replica support
- Automatic read/write splitting
- Connection failover
- Health checks for replicas
- Latency-based routing
"""

import functools
import logging
import os
import random
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import QueuePool

logger = logging.getLogger(__name__)


class ReplicaStatus(Enum):
    """
    Status of a read replica.

    The values are the lowercase strings reported by
    ReadReplicaManager.get_all_replica_status().
    """

    # Latest health check succeeded within the replica's latency threshold.
    HEALTHY = "healthy"
    # Latest health check succeeded but was slower than the latency threshold.
    DEGRADED = "degraded"
    # Latest health check raised, or no engine exists for the replica.
    UNHEALTHY = "unhealthy"
    # Initial state of a DatabaseReplica, before any health check has run.
    UNKNOWN = "unknown"


@dataclass
class DatabaseReplica:
    """
    Configuration and live health state for a database replica.

    ``name`` through ``health_check_interval`` are configuration; the
    remaining fields are runtime state that ReadReplicaManager mutates in
    place on the instance.
    """

    # Key under which the manager stores this replica's engine and session factory.
    name: str
    # Database URL handed to create_engine().
    url: str
    # For weighted routing.
    # NOTE(review): no selection strategy in this file reads it — confirm intended use.
    weight: int = 1
    # Max acceptable latency; a successful health check slower than this marks
    # the replica DEGRADED instead of HEALTHY.
    latency_threshold_ms: float = 100.0
    health_check_interval: int = 30  # seconds
    # Naive UTC timestamp (datetime.utcnow); defaults to construction time, so
    # it is populated even before the first health check.
    last_checked: datetime = field(default_factory=datetime.utcnow)
    # Written by health checks; replicas still UNKNOWN are not auto-selected.
    status: ReplicaStatus = ReplicaStatus.UNKNOWN
    # Exponential moving average of health-check round-trip time.
    avg_latency_ms: float = 0.0
    # Number of sessions handed out for this replica by automatic selection.
    request_count: int = 0
    # Number of failed health checks.
    error_count: int = 0


@dataclass
class DatabaseConfig:
    """
    Database configuration with replica support.

    Consumed by ReadReplicaManager, which creates one engine for
    ``primary_url`` and one per entry in ``replicas``.
    """

    # URL of the primary (writer) database.
    primary_url: str
    # Read replicas; an empty list means all reads fall back to the primary.
    replicas: List[DatabaseReplica] = field(default_factory=list)
    # QueuePool settings for the primary engine. Replica engines reuse
    # pool_timeout/pool_recycle and get a reduced pool_size/max_overflow.
    pool_size: int = 20
    max_overflow: int = 10
    pool_timeout: int = 30
    pool_recycle: int = 3600
    # When False, get_replica_session() always returns None.
    enable_read_write_split: bool = True
    # Any other value makes selection fall back to the first usable replica.
    replica_selection_strategy: str = "latency"  # latency, round_robin, random


class ReadReplicaManager:
    """
    Manages database read replicas with automatic health checks
    and intelligent routing.

    One SQLAlchemy engine is created for the primary (writer) and one per
    configured replica. Replica state (status, latency, counters) lives on
    the ``DatabaseReplica`` objects in ``config.replicas`` and is refreshed
    only while the background monitor started by ``_start_monitoring()`` is
    running. A replica whose status is still UNKNOWN or is UNHEALTHY is never
    chosen by automatic selection.
    """

    # Weight given to the newest sample in the latency moving average.
    _LATENCY_SMOOTHING = 0.1
    # Seconds between wake-ups of the monitor loop; also the shortest
    # effective per-replica check interval.
    _MONITOR_TICK_SECONDS = 1.0

    def __init__(self, config: "DatabaseConfig"):
        """
        Create engines and session factories for ``config``.

        Args:
            config: Primary URL, replica list and pool settings.
        """
        self.config = config
        self._primary_engine = None
        self._primary_session_factory = None
        self._replica_engines: Dict[str, Any] = {}
        self._replica_sessions: Dict[str, sessionmaker] = {}
        # Guards the monitor start/stop state and the shared counters.
        self._lock = threading.Lock()
        self._monitor_thread = None
        self._running = False
        # Set to wake the current monitor thread and make it exit.
        self._stop_event = threading.Event()
        # Monotonic deadline of the next health check, keyed by replica name.
        self._next_check_due: Dict[str, float] = {}
        # Ever-increasing position used by the round-robin strategy.
        self._round_robin_cursor = 0

        # Initialize engines
        self._initialize_engines()

    @staticmethod
    def _replica_pool_size(primary_pool_size: int) -> int:
        """
        Return the pool size for a replica engine: half the primary's.

        The result never drops to 0 for a positive primary size, because
        QueuePool treats ``pool_size=0`` as "no limit". A non-positive
        primary size is passed through unchanged.
        """
        if primary_pool_size <= 0:
            return primary_pool_size
        return max(1, primary_pool_size // 2)

    @classmethod
    def _smooth_latency(cls, previous_ms: float, sample_ms: float) -> float:
        """
        Fold ``sample_ms`` into the exponential moving average.

        A non-positive ``previous_ms`` means no sample has been recorded yet,
        so the first sample seeds the average instead of being blended with
        the 0.0 placeholder (which would make a new replica look faster than
        it is to the "latency" strategy).
        """
        if previous_ms <= 0:
            return sample_ms
        alpha = cls._LATENCY_SMOOTHING
        return previous_ms * (1 - alpha) + sample_ms * alpha

    def _initialize_engines(self):
        """Initialize SQLAlchemy engines for primary and replicas."""
        cfg = self.config
        shared_options = {
            "poolclass": QueuePool,
            "pool_timeout": cfg.pool_timeout,
            "pool_recycle": cfg.pool_recycle,
            "pool_pre_ping": True,
            "echo": False,
        }

        # Primary (writer) engine
        self._primary_engine = create_engine(
            cfg.primary_url,
            pool_size=cfg.pool_size,
            max_overflow=cfg.max_overflow,
            **shared_options,
        )

        # Replica (reader) engines use a smaller pool than the primary.
        replica_pool_size = self._replica_pool_size(cfg.pool_size)
        replica_overflow = cfg.max_overflow // 2
        for replica in cfg.replicas:
            engine = create_engine(
                replica.url,
                pool_size=replica_pool_size,
                max_overflow=replica_overflow,
                **shared_options,
            )
            self._replica_engines[replica.name] = engine
            self._replica_sessions[replica.name] = sessionmaker(
                autocommit=False, autoflush=False, bind=engine
            )

            logger.info(f"Initialized replica engine: {replica.name}")

    def _start_monitoring(self):
        """
        Start background health check monitoring.

        Idempotent: a second call while the monitor is running does nothing.
        The monitor runs in a daemon thread until ``stop_monitoring()``.
        """
        with self._lock:
            if self._running:
                return

            self._running = True
            # A fresh event per thread, so a monitor that is still winding
            # down from an earlier stop cannot be revived by this start.
            stop_event = threading.Event()
            self._stop_event = stop_event
            self._monitor_thread = threading.Thread(
                target=self._monitor_loop,
                args=(stop_event,),
                name="replica-health-monitor",
                daemon=True,
            )
            self._monitor_thread.start()

        logger.info("Replica health monitoring started")

    def _monitor_loop(self, stop_event):
        """
        Body of the monitor thread.

        Wakes every ``_MONITOR_TICK_SECONDS`` and checks each replica whose
        ``health_check_interval`` has elapsed since its previous check.
        Replicas that have never been checked are checked on the first tick.
        """
        while not stop_event.is_set():
            for replica in self.config.replicas:
                due = self._next_check_due.get(replica.name)
                if due is not None and time.monotonic() < due:
                    continue
                try:
                    self._check_replica_health(replica)
                except Exception as e:
                    # Keep the thread alive; one bad replica must not stop
                    # monitoring of the others.
                    logger.error(f"Replica monitoring error: {e}")
                interval = max(
                    float(replica.health_check_interval),
                    self._MONITOR_TICK_SECONDS,
                )
                self._next_check_due[replica.name] = time.monotonic() + interval
            # Interruptible sleep: returns early once stop is requested.
            stop_event.wait(self._MONITOR_TICK_SECONDS)

    def stop_monitoring(self, timeout: Optional[float] = 5.0) -> None:
        """
        Stop the background health monitor if it is running.

        Args:
            timeout: Maximum seconds to wait for the monitor thread to exit;
                ``None`` waits indefinitely.
        """
        with self._lock:
            if not self._running:
                return
            self._running = False
            thread, self._monitor_thread = self._monitor_thread, None
            self._stop_event.set()

        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout)
        logger.info("Replica health monitoring stopped")

    def close(self) -> None:
        """Stop monitoring and dispose the connection pools of all engines."""
        self.stop_monitoring()
        for engine in self._replica_engines.values():
            engine.dispose()
        if self._primary_engine is not None:
            self._primary_engine.dispose()

    def _check_replica_health(self, replica: "DatabaseReplica"):
        """
        Check health and latency of a replica.

        Runs ``SELECT 1`` on the replica's engine and updates ``replica`` in
        place: ``status``, ``avg_latency_ms`` on success, ``error_count`` on
        failure, and ``last_checked`` on every check. Never raises for
        database errors; they are logged and recorded as UNHEALTHY.
        """
        engine = self._replica_engines.get(replica.name)
        if engine is None:
            replica.status = ReplicaStatus.UNHEALTHY
            replica.last_checked = datetime.utcnow()
            return

        started = time.perf_counter()
        try:
            # Simple health check query. It must be wrapped in text():
            # SQLAlchemy 2.x connections reject plain SQL strings, which would
            # otherwise mark every replica UNHEALTHY on every check.
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            latency_ms = (time.perf_counter() - started) * 1000
        except Exception as e:
            replica.status = ReplicaStatus.UNHEALTHY
            replica.error_count += 1
            logger.warning(f"Replica {replica.name} health check failed: {e}")
        else:
            replica.avg_latency_ms = self._smooth_latency(
                replica.avg_latency_ms, latency_ms
            )
            # Status reflects the latest sample, not the smoothed average.
            if latency_ms > replica.latency_threshold_ms:
                replica.status = ReplicaStatus.DEGRADED
            else:
                replica.status = ReplicaStatus.HEALTHY
        finally:
            # Naive UTC, matching the DatabaseReplica default.
            replica.last_checked = datetime.utcnow()

    def _select_replica(self, candidates):
        """
        Pick one replica from a non-empty ``candidates`` list.

        Applies ``config.replica_selection_strategy``: "latency" returns the
        lowest ``avg_latency_ms``, "round_robin" rotates through the list,
        "random" picks uniformly. Any other value returns the first
        candidate.
        """
        strategy = self.config.replica_selection_strategy

        if strategy == "latency":
            return min(candidates, key=lambda r: r.avg_latency_ms)

        if strategy == "round_robin":
            # Take a ticket under the lock so concurrent callers get
            # distinct positions.
            with self._lock:
                position = self._round_robin_cursor
                self._round_robin_cursor = position + 1
            return candidates[position % len(candidates)]

        if strategy == "random":
            return random.choice(candidates)

        return candidates[0]

    def get_healthy_replica(self) -> Optional["DatabaseReplica"]:
        """
        Get the healthiest available replica based on strategy.

        Returns:
            A replica whose status is HEALTHY or DEGRADED, chosen by the
            configured strategy, or None when no replica qualifies.
        """
        usable = [
            r
            for r in self.config.replicas
            if r.status in (ReplicaStatus.HEALTHY, ReplicaStatus.DEGRADED)
        ]

        if not usable:
            return None

        return self._select_replica(usable)

    def get_primary_session(self) -> "Session":
        """Get a session for the primary (writer) database."""
        # The factory is created on first use and then reused, instead of
        # building a new sessionmaker for every session.
        if self._primary_session_factory is None:
            self._primary_session_factory = sessionmaker(
                autocommit=False, autoflush=False, bind=self._primary_engine
            )
        return self._primary_session_factory()

    def get_replica_session(
        self, replica_name: Optional[str] = None
    ) -> Optional["Session"]:
        """
        Get a session for a read replica.

        Args:
            replica_name: Specific replica name, or None for auto-selection.
                A named replica is returned regardless of its health status.

        Returns:
            Session object or None if no replica available (read/write
            splitting disabled, unknown name, or no usable replica).
        """
        if not self.config.enable_read_write_split:
            return None

        if replica_name:
            # Specific replica requested
            session_factory = self._replica_sessions.get(replica_name)
            return session_factory() if session_factory else None

        # Auto-select replica
        replica = self.get_healthy_replica()
        if replica is None:
            return None

        session_factory = self._replica_sessions.get(replica.name)
        if not session_factory:
            return None

        # += is a read-modify-write; serialize it across request threads.
        with self._lock:
            replica.request_count += 1
        return session_factory()

    def get_all_replica_status(self) -> List[Dict[str, Any]]:
        """
        Get status of all replicas.

        Returns:
            One dict per configured replica with the keys ``name``,
            ``status``, ``latency_ms``, ``requests``, ``errors`` and
            ``last_checked`` (ISO-8601 string or None).
        """
        return [
            {
                "name": r.name,
                "status": r.status.value,
                "latency_ms": round(r.avg_latency_ms, 2),
                "requests": r.request_count,
                "errors": r.error_count,
                "last_checked": r.last_checked.isoformat() if r.last_checked else None,
            }
            for r in self.config.replicas
        ]

    @property
    def primary_engine(self):
        """Get the primary database engine."""
        return self._primary_engine

    @property
    def replica_engines(self):
        """Get all replica engines, keyed by replica name."""
        return self._replica_engines


class ReadWriteSessionManager:
    """
    Context manager for automatic read/write session routing.

    On entry a write session is opened on the primary and a read session on
    a replica; when no replica session is available the read session is the
    write session itself. On exit every opened session is closed exactly
    once. Nothing is committed here: callers must commit on the write
    session themselves.

    Usage:
        with ReadWriteSessionManager(replica_manager) as sessions:
            # sessions.read_session for SELECT queries
            # sessions.write_session for INSERT/UPDATE/DELETE
    """

    def __init__(self, replica_manager: "ReadReplicaManager"):
        """
        Args:
            replica_manager: Source of the primary and replica sessions.
        """
        self.replica_manager = replica_manager
        self._read_session: Optional[Session] = None
        self._write_session: Optional[Session] = None

    def __enter__(self):
        """Open the sessions and return this manager."""
        write_session = self.replica_manager.get_primary_session()

        try:
            replica_session = None
            if self.replica_manager.config.enable_read_write_split:
                replica_session = self.replica_manager.get_replica_session()
        except Exception:
            # __exit__ is not called when __enter__ raises, so release the
            # primary session here instead of leaking it.
            write_session.close()
            raise

        self._write_session = write_session
        # Fallback to primary if no replica available. Assigned
        # unconditionally so a reused manager never keeps a stale session.
        self._read_session = (
            replica_session if replica_session is not None else write_session
        )

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the sessions; exceptions from the body propagate."""
        read_session, write_session = self._read_session, self._write_session
        try:
            # When reads fell back to the primary both attributes are the
            # same object; close it only once (below).
            if read_session is not None and read_session is not write_session:
                read_session.close()
        finally:
            # Runs even if closing the read session raised.
            if write_session is not None:
                write_session.close()

    @property
    def read_session(self) -> "Session":
        """Get the read session (replica or primary); None before entry."""
        return self._read_session

    @property
    def write_session(self) -> "Session":
        """Get the write session (primary only); None before entry."""
        return self._write_session


# Process-wide manager, built lazily by get_replica_manager()
_replica_manager: Optional[ReadReplicaManager] = None


def get_replica_manager() -> ReadReplicaManager:
    """
    Return the shared ReadReplicaManager, building it on first use.

    The first call loads the configuration with load_database_config(),
    constructs the manager, stores it in the module-level global and starts
    its health monitoring; later calls return the stored instance.
    """
    global _replica_manager
    if _replica_manager is not None:
        return _replica_manager

    _replica_manager = ReadReplicaManager(load_database_config())
    _replica_manager._start_monitoring()
    return _replica_manager


def _env_int(name: str, default: int) -> int:
    """Read integer env var ``name``; unset or blank yields ``default``."""
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as err:
        # Same exception type as a bare int() call, but naming the variable.
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err


def _env_flag(name: str, default: bool) -> bool:
    """Read boolean env var ``name``; true/1/yes/on (any case) mean True."""
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")


def load_database_config() -> DatabaseConfig:
    """
    Load database configuration from environment.

    Variables read:
        DATABASE_URL: Primary URL (defaults to a local development database).
        REPLICA_URLS: Comma-separated replica URLs; blank entries are ignored
            and the rest are named ``replica_1``, ``replica_2``, ... in order.
        DB_POOL_SIZE, DB_MAX_OVERFLOW, DB_POOL_TIMEOUT, DB_POOL_RECYCLE:
            Integers; unset or blank falls back to 20 / 10 / 30 / 3600.
        ENABLE_READ_WRITE_SPLIT: Boolean flag, default true.
        REPLICA_SELECTION_STRATEGY: Strategy name, trimmed and lower-cased,
            default "latency".

    Returns:
        The populated DatabaseConfig.

    Raises:
        ValueError: If an integer variable holds a non-integer value.
    """
    # NOTE: the fallback URL embeds development credentials; set DATABASE_URL
    # explicitly in any real deployment.
    primary_url = os.getenv(
        "DATABASE_URL", "postgresql://postgres:postgres@localhost:5432/zenith"
    )

    # Parse replica URLs from environment, dropping blank entries first so
    # the generated names stay consecutive.
    replica_urls = [
        url.strip() for url in os.getenv("REPLICA_URLS", "").split(",") if url.strip()
    ]
    replicas = [
        DatabaseReplica(
            name=f"replica_{position}",
            url=url,
            weight=1,
            latency_threshold_ms=100.0,
        )
        for position, url in enumerate(replica_urls, start=1)
    ]

    strategy = os.getenv("REPLICA_SELECTION_STRATEGY", "latency").strip().lower()

    return DatabaseConfig(
        primary_url=primary_url,
        replicas=replicas,
        pool_size=_env_int("DB_POOL_SIZE", 20),
        max_overflow=_env_int("DB_MAX_OVERFLOW", 10),
        pool_timeout=_env_int("DB_POOL_TIMEOUT", 30),
        pool_recycle=_env_int("DB_POOL_RECYCLE", 3600),
        enable_read_write_split=_env_flag("ENABLE_READ_WRITE_SPLIT", True),
        replica_selection_strategy=strategy or "latency",
    )


# Query routing decorator for automatic read/write splitting
def route_query(read_only: bool = True):
    """
    Decorator to route queries to appropriate database.

    Each call of the decorated function runs inside a
    ReadWriteSessionManager built from the global replica manager. The two
    sessions are passed as the keyword arguments ``_read_session`` and
    ``_write_session`` (overwriting any values the caller supplied), and both
    are closed when the call returns or raises. Nothing is committed
    automatically.

    Args:
        read_only: NOTE(review): currently not consulted; both sessions are
            injected regardless of its value. Confirm the intended semantics.

    Usage:
        @route_query(read_only=True)
        def get_users(_read_session=None, _write_session=None):
            return _read_session.query(User).all()
    """

    def decorator(func):
        # Preserve the wrapped function's name, docstring and signature
        # metadata for introspection, logging and framework registration.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            manager = get_replica_manager()

            with ReadWriteSessionManager(manager) as sessions:
                # Inject sessions into function arguments
                kwargs["_read_session"] = sessions.read_session
                kwargs["_write_session"] = sessions.write_session

                return func(*args, **kwargs)

        return wrapper

    return decorator


# Example configuration for docker-compose: one primary plus two read
# replicas, wired to the environment variables load_database_config() reads.
# This string is illustrative only; nothing in this file reads it and it is
# not listed in __all__.
# NOTE(review): each replica `command` starts postgres and then runs
# pg_basebackup into the same data directory, and no `replication` role is
# created on the primary — looks untested; confirm before copying.
EXAMPLE_REPLICA_CONFIG = """
# Add to docker-compose.yml for read replicas

services:
  backend:
    environment:
      - DATABASE_URL=postgresql://postgres:postgres@primary:5432/zenith
      - REPLICA_URLS=postgresql://postgres:postgres@replica1:5432/zenith,postgresql://postgres:postgres@replica2:5432/zenith
      - ENABLE_READ_WRITE_SPLIT=true
      - REPLICA_SELECTION_STRATEGY=latency

  primary:
    image: postgres:15
    environment:
      POSTGRES_DB: zenith
    volumes:
      - primary_data:/var/lib/postgresql/data

  replica1:
    image: postgres:15
    environment:
      POSTGRES_DB: zenith
      POSTGRES_HOST_AUTH_METHOD: trust
    command: |
      bash -c "postgres &
               sleep 5 &&
               pg_basebackup -h primary -D /var/lib/postgresql/data -U replication -Fp -Xs -R"
    depends_on:
      - primary
    volumes:
      - replica1_data:/var/lib/postgresql/data

  replica2:
    image: postgres:15
    environment:
      POSTGRES_DB: zenith
      POSTGRES_HOST_AUTH_METHOD: trust
    command: |
      bash -c "postgres &
               sleep 5 &&
               pg_basebackup -h primary -D /var/lib/postgresql/data -U replication -Fp -Xs -R"
    depends_on:
      - primary
    volumes:
      - replica2_data:/var/lib/postgresql/data

volumes:
  primary_data:
  replica1_data:
  replica2_data:
"""


# Public API of this module. EXAMPLE_REPLICA_CONFIG and the module-level
# _replica_manager singleton are intentionally left out.
__all__ = [
    "DatabaseConfig",
    "DatabaseReplica",
    "ReadReplicaManager",
    "ReadWriteSessionManager",
    "get_replica_manager",
    "load_database_config",
    "route_query",
    "ReplicaStatus",
]