|
|
import logging |
|
|
from contextlib import asynccontextmanager |
|
|
|
|
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver |
|
|
from langgraph.store.postgres import AsyncPostgresStore |
|
|
from psycopg.rows import dict_row |
|
|
from psycopg_pool import AsyncConnectionPool |
|
|
|
|
|
from core.settings import settings |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def validate_postgres_config() -> None:
    """
    Validate the PostgreSQL configuration held in ``settings``.

    Two things are checked:

    1. Connection target: either ``POSTGRES_URL`` is set, or every individual
       ``POSTGRES_*`` variable (user, password, host, port, db) is set.
    2. Pool bounds: ``POSTGRES_MIN_CONNECTIONS_PER_POOL`` must not exceed
       ``POSTGRES_MAX_CONNECTIONS_PER_POOL``. This check applies regardless of
       how the connection target is configured.

    Raises:
        ValueError: If required configuration is missing or the pool bounds
            are inconsistent.
    """
    # The individual variables are only required when no full URL is given.
    if not settings.POSTGRES_URL:
        required_vars = (
            "POSTGRES_USER",
            "POSTGRES_PASSWORD",
            "POSTGRES_HOST",
            "POSTGRES_PORT",
            "POSTGRES_DB",
        )
        missing = [var for var in required_vars if not getattr(settings, var, None)]
        if missing:
            raise ValueError(
                f"Missing required PostgreSQL configuration: {', '.join(missing)}. "
                "Either POSTGRES_URL or all individual POSTGRES_* environment variables must be set to use PostgreSQL persistence."
            )

    # Pool sizing is independent of the connection target, so it is validated
    # on both paths (previously it was skipped whenever POSTGRES_URL was set).
    min_size = settings.POSTGRES_MIN_CONNECTIONS_PER_POOL
    max_size = settings.POSTGRES_MAX_CONNECTIONS_PER_POOL
    if min_size > max_size:
        raise ValueError(
            f"POSTGRES_MIN_CONNECTIONS_PER_POOL ({min_size}) must be less than or equal to POSTGRES_MAX_CONNECTIONS_PER_POOL ({max_size})"
        )
|
|
|
|
|
|
|
|
def get_postgres_connection_string() -> str:
    """
    Build and return the PostgreSQL connection string from settings.

    If ``POSTGRES_URL`` is set, its secret value is returned unchanged.
    Otherwise a ``postgresql://user:password@host:port/db`` URI is assembled
    from the individual ``POSTGRES_*`` settings. The user, password and
    database name are percent-encoded so that reserved URI characters in them
    (for example ``@``, ``:``, ``/``, ``#``, ``?``, ``%``) cannot corrupt the URI.

    Returns:
        The connection URI as a plain string (it contains the password).

    Raises:
        ValueError: If no URL is configured and ``POSTGRES_PASSWORD`` is unset.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import quote

    if settings.POSTGRES_URL:
        return settings.POSTGRES_URL.get_secret_value()

    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")

    # safe="" also encodes "/", which quote() leaves untouched by default.
    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
    database = quote(str(settings.POSTGRES_DB), safe="")

    # Host and port are left as-is: encoding would break bracketed IPv6 hosts.
    return (
        f"postgresql://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{database}"
    )
|
|
|
|
|
|
|
|
@asynccontextmanager
async def get_postgres_saver():
    """
    Yield a PostgreSQL checkpointer backed by a connection pool.

    Async context manager: validates the configuration, opens a dedicated
    ``AsyncConnectionPool`` (for more resilient connections than a single
    connection), runs ``AsyncPostgresSaver.setup()`` and yields the saver.
    The pool is closed when the context exits, whether normally or on error.

    Yields:
        An ``AsyncPostgresSaver`` bound to the pool.

    Raises:
        ValueError: If the PostgreSQL configuration is invalid.
    """
    validate_postgres_config()
    # Distinct application name so saver connections can be told apart from
    # the store's connections on the server side.
    application_name = settings.POSTGRES_APPLICATION_NAME + "-" + "saver"

    pool = AsyncConnectionPool(
        get_postgres_connection_string(),
        min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
        max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
        # NOTE(review): autocommit + dict_row are presumably what the LangGraph
        # Postgres classes expect from caller-supplied connections - confirm.
        kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
        # Check each connection when it is handed out so broken ones are replaced.
        check=AsyncConnectionPool.check_connection,
        # Opening an async pool in the constructor is deprecated; the
        # ``async with`` below opens it instead.
        open=False,
    )

    # The context manager opens the pool on entry and closes it on exit, so no
    # explicit ``pool.close()`` is needed.
    async with pool:
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
        yield checkpointer
|
|
|
|
|
|
|
|
@asynccontextmanager
async def get_postgres_store():
    """
    Yield a PostgreSQL store backed by a connection pool.

    Async context manager: validates the configuration, opens a dedicated
    ``AsyncConnectionPool`` (for more resilient connections than a single
    connection), runs ``AsyncPostgresStore.setup()`` and yields the store.
    The pool is closed when the context exits, whether normally or on error.

    Yields:
        An ``AsyncPostgresStore`` bound to the pool.

    Raises:
        ValueError: If the PostgreSQL configuration is invalid.
    """
    validate_postgres_config()
    # Distinct application name so store connections can be told apart from
    # the saver's connections on the server side.
    application_name = settings.POSTGRES_APPLICATION_NAME + "-" + "store"

    pool = AsyncConnectionPool(
        get_postgres_connection_string(),
        min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
        max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
        # NOTE(review): autocommit + dict_row are presumably what the LangGraph
        # Postgres classes expect from caller-supplied connections - confirm.
        kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
        # Check each connection when it is handed out so broken ones are replaced.
        check=AsyncConnectionPool.check_connection,
        # Opening an async pool in the constructor is deprecated; the
        # ``async with`` below opens it instead.
        open=False,
    )

    # The context manager opens the pool on entry and closes it on exit, so no
    # explicit ``pool.close()`` is needed.
    async with pool:
        store = AsyncPostgresStore(pool)
        await store.setup()
        yield store
|
|
|