File size: 4,231 Bytes
fbdfc24
531a2b2
 
 
 
fbdfc24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# database/postgres_checkpointer.py - LangGraph checkpointer backed by PostgreSQL, with an in-memory fallback
# Add project root to Python path
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from psycopg_pool import AsyncConnectionPool
from psycopg.rows import dict_row
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver  # βœ… Correct import
from langgraph.checkpoint.memory import MemorySaver
import logging
from typing import Optional

# Module-level logger named after this module; used for init/fallback/close messages below.
logger = logging.getLogger(__name__)

class PostgresCheckpointer:
    """Own a LangGraph checkpointer backed by PostgreSQL, with an in-memory fallback.

    ``initialize()`` opens an ``AsyncConnectionPool`` and wraps it in an
    ``AsyncPostgresSaver``. If any step of that fails, the half-built pool is
    closed again and a ``MemorySaver`` is installed instead. Callers still get a
    working checkpointer, but its state is not persisted.
    """

    def __init__(self, database_url: str, max_connections: int = 10, min_connections: int = 2):
        """Store connection settings; no I/O happens until ``initialize()``.

        Args:
            database_url: Connection string handed to the pool as ``conninfo``.
            max_connections: Passed to the pool as ``max_size``.
            min_connections: Passed to the pool as ``min_size``.
        """
        self.database_url = database_url
        self.max_connections = max_connections
        self.min_connections = min_connections
        # Set only while PostgreSQL is actually in use; stays None in fallback mode.
        self.pool: Optional[AsyncConnectionPool] = None
        # AsyncPostgresSaver when PostgreSQL is available, MemorySaver otherwise.
        self.checkpointer: "Optional[AsyncPostgresSaver | MemorySaver]" = None
        self._is_initialized = False

    async def initialize(self) -> bool:
        """Create the PostgreSQL pool and checkpointer, or fall back to memory.

        If PostgreSQL is already in use this is a no-op, so repeated calls do
        not open (and leak) additional pools. In fallback mode, calling it
        again retries PostgreSQL.

        Returns:
            True when a checkpointer (PostgreSQL or in-memory) is ready; False
            only when even the in-memory fallback could not be created.
        """
        if self._is_initialized and self.pool is not None:
            return True

        pool = None
        try:
            pool = AsyncConnectionPool(
                conninfo=self.database_url,
                max_size=self.max_connections,
                min_size=self.min_connections,
                # AsyncPostgresSaver expects autocommit connections that return dict rows.
                kwargs={"row_factory": dict_row, "autocommit": True},
                open=False,  # opened explicitly below so errors surface inside this try
            )
            await pool.open()
            checkpointer = AsyncPostgresSaver(pool)
            await checkpointer.setup()  # AsyncPostgresSaver's schema setup
        except Exception as e:
            logger.error("❌ PostgreSQL initialization failed: %s", e)
            if pool is not None:
                # Never keep an open (or half-open) pool around while running on memory.
                await self._close_pool_quietly(pool)
            return self._fall_back_to_memory()

        # Only publish the pool/checkpointer once the whole sequence succeeded.
        self.pool = pool
        self.checkpointer = checkpointer
        self._is_initialized = True
        logger.info("✅ PostgreSQL checkpointer initialized successfully with AsyncPostgresSaver")
        return True

    @staticmethod
    async def _close_pool_quietly(pool: "AsyncConnectionPool") -> None:
        """Close *pool*, logging (not raising) errors so the original init failure wins."""
        try:
            await pool.close()
        except Exception as close_error:
            logger.warning("Failed to close PostgreSQL pool after init error: %s", close_error)

    def _fall_back_to_memory(self) -> bool:
        """Install an in-process MemorySaver; return False only if that also fails."""
        try:
            # MemorySaver also implements the async checkpoint methods, so it can
            # stand in for AsyncPostgresSaver. Reuse an existing one so that a
            # failed PostgreSQL retry does not discard in-memory state.
            if not isinstance(self.checkpointer, MemorySaver):
                self.checkpointer = MemorySaver()
        except Exception as fallback_error:
            logger.error("❌ Even fallback failed: %s", fallback_error)
            return False
        logger.warning("🔄 Falling back to in-memory checkpointer")
        self._is_initialized = True
        return True

    async def close(self):
        """Close the pool (if any) and reset to the uninitialized state.

        Safe to call repeatedly and in fallback mode. State is cleared before the
        pool is closed, so a later ``initialize()`` never reuses a stale
        checkpointer that still points at a closed pool.
        """
        pool, self.pool = self.pool, None
        self.checkpointer = None
        self._is_initialized = False
        if pool is not None:
            await pool.close()
            logger.info("✅ PostgreSQL connection pool closed")

    async def health_check(self) -> dict:
        """Probe PostgreSQL with ``SELECT 1``.

        Returns:
            A dict with ``status`` (str) and ``healthy`` (bool). On success it also
            contains ``connection_count``, the ``pool_size`` from the pool's stats.
        """
        if not self._is_initialized:
            return {"status": "uninitialized", "healthy": False}
        if self.pool is None:
            # initialize() fell back to MemorySaver: there is no database to probe.
            return {"status": "degraded: using in-memory checkpointer", "healthy": False}

        try:
            async with self.pool.connection() as conn:
                await conn.execute("SELECT 1")
        except Exception as e:
            return {"status": f"unhealthy: {e}", "healthy": False}

        stats = self.pool.get_stats()
        return {
            "status": "healthy",
            "healthy": True,
            "connection_count": stats.get("pool_size", "unknown"),
        }

    def is_initialized(self) -> bool:
        """Return True when ``initialize()`` succeeded (either backend) and not yet closed."""
        return self._is_initialized and self.checkpointer is not None

    def get_checkpointer(self) -> "AsyncPostgresSaver | MemorySaver":
        """Return the active checkpointer (PostgreSQL-backed or in-memory fallback).

        Raises:
            RuntimeError: if ``initialize()`` has not completed successfully.
        """
        if not self.is_initialized():
            raise RuntimeError("Checkpointer not initialized. Call initialize() first.")
        return self.checkpointer