"""PostgreSQL-backed persistence helpers.

Provides async context managers for the LangGraph checkpointer and store
(both backed by a psycopg connection pool) and loaders for the PGVector
vector store and its MMR retriever. All connection details are read from
``core.settings.settings``.
"""

import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from urllib.parse import quote

from langchain_postgres import PGVector
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres import AsyncPostgresStore
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

from core.embeddings import get_embeddings
from core.settings import settings

logger = logging.getLogger(__name__)


def validate_postgres_config() -> None:
    """
    Validate that all required PostgreSQL configuration is present.

    Raises:
        ValueError: If any required ``POSTGRES_*`` setting is missing or
            falsy, or if the configured minimum pool size exceeds the
            maximum pool size.
    """
    required_vars = [
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_HOST",
        "POSTGRES_PORT",
        "POSTGRES_DB",
    ]

    # A setting counts as missing when it is absent or falsy (None, "", 0).
    missing = [var for var in required_vars if not getattr(settings, var, None)]
    if missing:
        raise ValueError(
            f"Missing required PostgreSQL configuration: {', '.join(missing)}. "
            "All individual POSTGRES_* environment variables must be set to use PostgreSQL persistence."
        )

    min_size = settings.POSTGRES_MIN_CONNECTIONS_PER_POOL
    max_size = settings.POSTGRES_MAX_CONNECTIONS_PER_POOL
    if min_size > max_size:
        raise ValueError(
            f"POSTGRES_MIN_CONNECTIONS_PER_POOL ({min_size}) must be less than or equal to "
            f"POSTGRES_MAX_CONNECTIONS_PER_POOL ({max_size})"
        )


def _build_connection_url(scheme: str) -> str:
    """Build a PostgreSQL connection URL with the given URI scheme.

    The user name and password are percent-encoded so that reserved URI
    characters (``@``, ``:``, ``/``, ``?``, ``#``, ``%``, spaces, ...) in the
    credentials cannot corrupt the URL structure.

    Raises:
        ValueError: If ``POSTGRES_PASSWORD`` is not set.
    """
    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")

    # safe="" encodes every reserved character, including "/" and "+".
    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
    return (
        f"{scheme}://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{settings.POSTGRES_DB}?sslmode=require"
    )


def get_postgres_connection_string() -> str:
    """Build and return the PostgreSQL connection string from settings.

    Raises:
        ValueError: If ``POSTGRES_PASSWORD`` is not set.
    """
    return _build_connection_url("postgresql")


@asynccontextmanager
async def _open_connection_pool(component: str) -> AsyncIterator[AsyncConnectionPool]:
    """Validate the configuration and yield an open connection pool.

    Args:
        component: Suffix appended to ``POSTGRES_APPLICATION_NAME`` so the
            pool's connections can be told apart in the database's
            connection list.

    The pool is closed when the ``async with`` block below exits, so callers
    need no additional cleanup.
    """
    validate_postgres_config()
    application_name = settings.POSTGRES_APPLICATION_NAME + "-" + component

    async with AsyncConnectionPool(
        get_postgres_connection_string(),
        min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
        max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
        # Langgraph requires autocommit=True and row_factory to be set to dict_row.
        # application_name is passed so you can identify the connection in your
        # Postgres database connection manager.
        kwargs={
            "autocommit": True,
            "row_factory": dict_row,
            "application_name": application_name,
        },
        # Makes sure that the connection is still valid before using it.
        check=AsyncConnectionPool.check_connection,
    ) as pool:
        yield pool


@asynccontextmanager
async def get_postgres_saver() -> AsyncIterator[AsyncPostgresSaver]:
    """Initialize and yield a PostgreSQL saver backed by a connection pool.

    ``setup()`` is awaited on the saver before it is yielded. The underlying
    pool is closed when the context exits.

    Raises:
        ValueError: If the PostgreSQL configuration is incomplete or invalid.
    """
    async with _open_connection_pool("saver") as pool:
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
        yield checkpointer


@asynccontextmanager
async def get_postgres_store() -> AsyncIterator[AsyncPostgresStore]:
    """Initialize and yield a PostgreSQL store backed by a connection pool.

    ``setup()`` is awaited on the store before it is yielded. The underlying
    pool is closed when the context exits.

    Raises:
        ValueError: If the PostgreSQL configuration is incomplete or invalid.
    """
    async with _open_connection_pool("store") as pool:
        store = AsyncPostgresStore(pool)
        await store.setup()
        yield store


def get_pgvector_connection_string() -> str:
    """Build and return the PostgreSQL connection string for vectors from settings.

    Uses the ``postgresql+psycopg`` scheme.

    Raises:
        ValueError: If ``POSTGRES_PASSWORD`` is not set.
    """
    return _build_connection_url("postgresql+psycopg")


def load_pgvector_store() -> PGVector:
    """Get a PostgreSQL vector store instance.

    The collection name and embedding model are taken from settings.

    Raises:
        ValueError: If the PostgreSQL configuration is incomplete or invalid.
    """
    validate_postgres_config()
    return PGVector(
        connection=get_pgvector_connection_string(),
        collection_name=settings.VECTOR_STORE_COLLECTION_NAME,
        embeddings=get_embeddings(settings.DEFAULT_EMBEDDING_MODEL),
    )


def load_pgvector_retriever(k: int = 6, *, fetch_k: int = 20, lambda_mult: float = 0.6):
    """Return an MMR retriever over the PGVector store.

    Args:
        k: Number of documents the retriever should return.
        fetch_k: Number of candidates fetched before MMR selection. Raised to
            ``k`` when smaller, so the candidate pool can always supply ``k``
            results.
        lambda_mult: Relevance-versus-diversity weighting passed to the MMR
            search.
    """
    store = load_pgvector_store()
    return store.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": k,
            "fetch_k": max(fetch_k, k),  # candidates
            "lambda_mult": lambda_mult,  # relevance vs diversity
        },
    )