Spaces:
Sleeping
Sleeping
| import logging | |
| from contextlib import asynccontextmanager | |
| from core.embeddings import get_embeddings | |
| from langchain_postgres import PGVector | |
| from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver | |
| from langgraph.store.postgres import AsyncPostgresStore | |
| from psycopg.rows import dict_row | |
| from psycopg_pool import AsyncConnectionPool | |
| from core.settings import settings | |
| logger = logging.getLogger(__name__) | |
def validate_postgres_config() -> None:
    """
    Check the PostgreSQL settings before any connection is attempted.

    Every individual POSTGRES_* connection setting must be present and truthy,
    and the configured minimum pool size may not exceed the maximum.

    Raises:
        ValueError: If a required setting is missing/empty, or if the pool
            size bounds are inverted.
    """
    unset_names = []
    for name in (
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_HOST",
        "POSTGRES_PORT",
        "POSTGRES_DB",
    ):
        # A setting counts as missing when the attribute is absent or falsy.
        if not getattr(settings, name, None):
            unset_names.append(name)

    if unset_names:
        joined = ", ".join(unset_names)
        raise ValueError(
            f"Missing required PostgreSQL configuration: {joined}. "
            "All individual POSTGRES_* environment variables must be set to use PostgreSQL persistence."
        )

    pool_min = settings.POSTGRES_MIN_CONNECTIONS_PER_POOL
    pool_max = settings.POSTGRES_MAX_CONNECTIONS_PER_POOL
    if pool_min > pool_max:
        raise ValueError(
            f"POSTGRES_MIN_CONNECTIONS_PER_POOL ({pool_min}) "
            "must be less than or equal to "
            f"POSTGRES_MAX_CONNECTIONS_PER_POOL ({pool_max})"
        )
def get_postgres_connection_string() -> str:
    """
    Build the PostgreSQL connection URI from settings.

    The user name and password are percent-encoded so that credentials
    containing URI-reserved characters (``@``, ``:``, ``/``, ``?``, ``#``,
    ``%`` ...) do not corrupt the resulting connection string.

    Returns:
        A ``postgresql://`` URI with ``sslmode=require``.

    Raises:
        ValueError: If POSTGRES_PASSWORD is not set.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import quote

    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")

    # safe="" also encodes "/", which quote() would otherwise leave untouched.
    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
    return (
        f"postgresql://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{settings.POSTGRES_DB}?sslmode=require"
    )
@asynccontextmanager
async def get_postgres_saver():
    """
    Provide a PostgreSQL-backed checkpointer built on a connection pool.

    Use as ``async with get_postgres_saver() as saver:``. The pool is opened
    when the context is entered and closed when it exits, including when the
    body raises.

    Yields:
        An AsyncPostgresSaver whose ``setup()`` has already been awaited.

    Raises:
        ValueError: If validate_postgres_config() rejects the settings.
    """
    validate_postgres_config()
    application_name = settings.POSTGRES_APPLICATION_NAME + "-" + "saver"

    async with AsyncConnectionPool(
        get_postgres_connection_string(),
        min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
        max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
        # Langgraph requires autocommit=True and row_factory to be set to dict_row.
        # application_name is passed so the connection can be identified in the
        # Postgres database connection manager.
        kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
        # Makes sure that the connection is still valid before using it.
        check=AsyncConnectionPool.check_connection,
        # Let the `async with` block open the pool instead of the constructor.
        open=False,
    ) as pool:
        # Leaving the `async with` block closes the pool, so no explicit
        # pool.close() is needed here.
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
        yield checkpointer
@asynccontextmanager
async def get_postgres_store():
    """
    Provide a PostgreSQL-backed store built on a connection pool.

    Use as ``async with get_postgres_store() as store:``. The pool is opened
    when the context is entered and closed when it exits, including when the
    body raises.

    Yields:
        An AsyncPostgresStore whose ``setup()`` has already been awaited.

    Raises:
        ValueError: If validate_postgres_config() rejects the settings.
    """
    validate_postgres_config()
    application_name = settings.POSTGRES_APPLICATION_NAME + "-" + "store"

    async with AsyncConnectionPool(
        get_postgres_connection_string(),
        min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
        max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
        # Langgraph requires autocommit=True and row_factory to be set to dict_row.
        # application_name is passed so the connection can be identified in the
        # Postgres database connection manager.
        kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
        # Makes sure that the connection is still valid before using it.
        check=AsyncConnectionPool.check_connection,
        # Let the `async with` block open the pool instead of the constructor.
        open=False,
    ) as pool:
        # Leaving the `async with` block closes the pool, so no explicit
        # pool.close() is needed here.
        store = AsyncPostgresStore(pool)
        await store.setup()
        yield store
def get_pgvector_connection_string() -> str:
    """
    Build the PostgreSQL connection URI used for the vector store.

    Same shape as the regular connection string but with the
    ``postgresql+psycopg`` scheme. The user name and password are
    percent-encoded so URI-reserved characters in credentials do not corrupt
    the resulting URI.

    Returns:
        A ``postgresql+psycopg://`` URI with ``sslmode=require``.

    Raises:
        ValueError: If POSTGRES_PASSWORD is not set.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import quote

    # Same guard as get_postgres_connection_string(): fail with a clear
    # message instead of an AttributeError on None.
    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")

    # safe="" also encodes "/", which quote() would otherwise leave untouched.
    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
    return (
        f"postgresql+psycopg://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{settings.POSTGRES_DB}?sslmode=require"
    )
def load_pgvector_store():
    """
    Create a PGVector store bound to the configured collection.

    The PostgreSQL settings are validated first; the store is then built from
    the pgvector connection string, VECTOR_STORE_COLLECTION_NAME and the
    embeddings returned for DEFAULT_EMBEDDING_MODEL.

    Raises:
        ValueError: If validate_postgres_config() rejects the settings.
    """
    validate_postgres_config()

    connection_uri = get_pgvector_connection_string()
    collection = settings.VECTOR_STORE_COLLECTION_NAME
    embedder = get_embeddings(settings.DEFAULT_EMBEDDING_MODEL)

    vector_store = PGVector(
        connection=connection_uri,
        collection_name=collection,
        embeddings=embedder,
    )
    return vector_store
def load_pgvector_retriever(k: int = 6, *, fetch_k: int = 20, lambda_mult: float = 0.6):
    """
    Build an MMR (maximal marginal relevance) retriever over the PGVector store.

    Args:
        k: Number of documents the retriever should return.
        fetch_k: Size of the candidate pool from which the ``k`` results are
            selected. Raised to ``k`` when smaller, because a pool smaller
            than ``k`` would cap the number of returned documents.
        lambda_mult: Trade-off between relevance and diversity of the results.

    Returns:
        The retriever produced by ``as_retriever`` on the store returned by
        load_pgvector_store().
    """
    store = load_pgvector_store()
    return store.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": k,
            # Candidate pool; never smaller than the number of results requested.
            "fetch_k": max(fetch_k, k),
            # Relevance vs diversity.
            "lambda_mult": lambda_mult,
        },
    )