# backend/src/memory/postgres.py
# Author: anujjoshi3105 — commit 22dcdfd ("initial")
import logging
from contextlib import asynccontextmanager
from urllib.parse import quote

from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres import AsyncPostgresStore
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

from core.settings import settings
# Module-level logger, named after this module per the standard logging convention.
logger = logging.getLogger(__name__)
def validate_postgres_config() -> None:
    """
    Validate that the PostgreSQL configuration is present and consistent.

    Connection settings are satisfied either by ``POSTGRES_URL`` or by the full
    set of individual ``POSTGRES_*`` variables. The connection-pool bounds are
    checked in both cases, because the pools built from this configuration use
    ``POSTGRES_MIN_CONNECTIONS_PER_POOL`` / ``POSTGRES_MAX_CONNECTIONS_PER_POOL``
    regardless of how the connection string is supplied.

    Raises:
        ValueError: If ``POSTGRES_URL`` is unset and any individual connection
            variable is missing/falsy, or if the minimum pool size exceeds the
            maximum pool size.
    """
    if not settings.POSTGRES_URL:
        required_vars = [
            "POSTGRES_USER",
            "POSTGRES_PASSWORD",
            "POSTGRES_HOST",
            "POSTGRES_PORT",
            "POSTGRES_DB",
        ]
        missing = [var for var in required_vars if not getattr(settings, var, None)]
        if missing:
            raise ValueError(
                f"Missing required PostgreSQL configuration: {', '.join(missing)}. "
                "Either POSTGRES_URL or all individual POSTGRES_* environment variables must be set to use PostgreSQL persistence."
            )

    # Pool bounds must be validated even when POSTGRES_URL is set: the pools are
    # sized from these settings in every configuration mode.
    min_conns = settings.POSTGRES_MIN_CONNECTIONS_PER_POOL
    max_conns = settings.POSTGRES_MAX_CONNECTIONS_PER_POOL
    if min_conns > max_conns:
        raise ValueError(
            f"POSTGRES_MIN_CONNECTIONS_PER_POOL ({min_conns}) must be less than or equal to POSTGRES_MAX_CONNECTIONS_PER_POOL ({max_conns})"
        )
def get_postgres_connection_string() -> str:
    """
    Build and return the PostgreSQL connection string from settings.

    If ``POSTGRES_URL`` is set, its secret value is returned verbatim. Otherwise
    a ``postgresql://user:password@host:port/db`` URI is assembled from the
    individual ``POSTGRES_*`` settings. The user, password and database name
    are percent-encoded so that characters with special meaning in a URI
    (``@``, ``:``, ``/``, ``#``, ``%``, ...) cannot corrupt the connection string.

    Returns:
        The connection string/URI to pass to the connection pool.

    Raises:
        ValueError: If ``POSTGRES_URL`` is unset and ``POSTGRES_PASSWORD`` is None.
    """
    if settings.POSTGRES_URL:
        return settings.POSTGRES_URL.get_secret_value()
    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")
    # safe="" so that "/" is escaped as well; percent-escapes are decoded by the
    # PostgreSQL URI parser, so the credentials arrive unchanged.
    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
    database = quote(str(settings.POSTGRES_DB), safe="")
    return (
        f"postgresql://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{database}"
    )
@asynccontextmanager
async def get_postgres_saver():
    """
    Yield a LangGraph ``AsyncPostgresSaver`` backed by a connection pool.

    A connection pool is used instead of a single connection for more resilient
    connections. ``setup()`` is awaited on the checkpointer before it is yielded,
    and the pool is closed when the context exits.

    Raises:
        ValueError: If ``validate_postgres_config`` rejects the configuration.
    """
    validate_postgres_config()
    application_name = settings.POSTGRES_APPLICATION_NAME + "-" + "saver"
    # The pool's async context manager opens the pool and closes it on exit,
    # including when setup() or the caller's body raises, so no separate
    # try/finally pool.close() is needed.
    async with AsyncConnectionPool(
        get_postgres_connection_string(),
        min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
        max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
        # LangGraph requires autocommit=True and row_factory=dict_row.
        # application_name lets you identify these connections on the Postgres
        # server (e.g. in pg_stat_activity).
        kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
        # Verify that a connection is still valid before handing it out.
        check=AsyncConnectionPool.check_connection,
    ) as pool:
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
        yield checkpointer
@asynccontextmanager
async def get_postgres_store():
    """
    Yield a LangGraph ``AsyncPostgresStore`` backed by a connection pool.

    This is an async context manager: use it as
    ``async with get_postgres_store() as store: ...``. A connection pool is used
    instead of a single connection for more resilient connections. ``setup()``
    is awaited on the store before it is yielded, and the pool is closed when
    the context exits.

    Raises:
        ValueError: If ``validate_postgres_config`` rejects the configuration.
    """
    validate_postgres_config()
    application_name = settings.POSTGRES_APPLICATION_NAME + "-" + "store"
    # The pool's async context manager opens the pool and closes it on exit,
    # including when setup() or the caller's body raises, so no separate
    # try/finally pool.close() is needed.
    async with AsyncConnectionPool(
        get_postgres_connection_string(),
        min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
        max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
        # LangGraph requires autocommit=True and row_factory=dict_row.
        # application_name lets you identify these connections on the Postgres
        # server (e.g. in pg_stat_activity).
        kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
        # Verify that a connection is still valid before handing it out.
        check=AsyncConnectionPool.check_connection,
    ) as pool:
        store = AsyncPostgresStore(pool)
        await store.setup()
        yield store