anujjoshi3105's picture
fix: postgres db url
6dd2bd0
import logging
from contextlib import asynccontextmanager
from urllib.parse import quote

from langchain_postgres import PGVector
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres import AsyncPostgresStore
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

from core.embeddings import get_embeddings
from core.settings import settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def validate_postgres_config() -> None:
    """
    Ensure the settings hold a complete, consistent PostgreSQL configuration.

    Raises:
        ValueError: if any of POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST,
            POSTGRES_PORT or POSTGRES_DB is absent or falsy, or if
            POSTGRES_MIN_CONNECTIONS_PER_POOL is greater than
            POSTGRES_MAX_CONNECTIONS_PER_POOL.
    """
    needed = (
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_HOST",
        "POSTGRES_PORT",
        "POSTGRES_DB",
    )
    # getattr with a default so an attribute that does not exist counts as missing.
    absent = [name for name in needed if not getattr(settings, name, None)]
    if absent:
        raise ValueError(
            f"Missing required PostgreSQL configuration: {', '.join(absent)}. "
            "All individual POSTGRES_* environment variables must be set to use PostgreSQL persistence."
        )

    lower = settings.POSTGRES_MIN_CONNECTIONS_PER_POOL
    upper = settings.POSTGRES_MAX_CONNECTIONS_PER_POOL
    if lower <= upper:
        return
    raise ValueError(
        f"POSTGRES_MIN_CONNECTIONS_PER_POOL ({lower}) must be less than or equal to "
        f"POSTGRES_MAX_CONNECTIONS_PER_POOL ({upper})"
    )
def get_postgres_connection_string() -> str:
    """
    Build the PostgreSQL connection URI (libpq/psycopg form) from settings.

    The user name and password are percent-encoded so that characters with a
    special meaning inside a URI (``@``, ``:``, ``/``, ``%``, ``?``, ``#``,
    spaces, ...) cannot break the URI structure.

    Returns:
        A ``postgresql://<user>:<password>@<host>:<port>/<db>?sslmode=require`` URI.

    Raises:
        ValueError: if POSTGRES_PASSWORD is not set.
    """
    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")
    # safe="" makes quote() escape "/" as well; libpq decodes %XX sequences in
    # the userinfo part of a connection URI.
    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
    return (
        f"postgresql://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{settings.POSTGRES_DB}?sslmode=require"
    )
@asynccontextmanager
async def _postgres_connection_pool(purpose: str):
    """
    Open an AsyncConnectionPool configured the way LangGraph's Postgres backends need.

    Args:
        purpose: Suffix appended to POSTGRES_APPLICATION_NAME (e.g. ``"saver"``) so
            the connections of this pool can be identified on the server side.

    Yields:
        An open AsyncConnectionPool. It is closed when the context exits.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete or inconsistent.
    """
    validate_postgres_config()
    application_name = f"{settings.POSTGRES_APPLICATION_NAME}-{purpose}"
    async with AsyncConnectionPool(
        get_postgres_connection_string(),
        min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
        max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
        # LangGraph requires autocommit=True and row_factory set to dict_row.
        # application_name is passed so you can identify the connection in your
        # Postgres database connection manager.
        kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
        # Makes sure the connection is still valid before handing it out.
        check=AsyncConnectionPool.check_connection,
        # Let `async with` open the pool on entry; opening an async pool in its
        # constructor is deprecated by psycopg_pool.
        open=False,
    ) as pool:
        # Leaving the `async with` closes the pool, so no explicit close() is needed.
        yield pool


@asynccontextmanager
async def get_postgres_saver():
    """
    Initialize and yield a PostgreSQL checkpoint saver backed by a connection pool.

    ``setup()`` is awaited on the saver before it is yielded. The underlying pool
    (application name suffix ``"saver"``) is closed when the context exits.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete or inconsistent.
    """
    async with _postgres_connection_pool("saver") as pool:
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
        yield checkpointer


@asynccontextmanager
async def get_postgres_store():
    """
    Initialize and yield a PostgreSQL store backed by a connection pool.

    Use with ``async with``. ``setup()`` is awaited on the store before it is
    yielded. The underlying pool (application name suffix ``"store"``) is closed
    when the context exits.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete or inconsistent.
    """
    async with _postgres_connection_pool("store") as pool:
        store = AsyncPostgresStore(pool)
        await store.setup()
        yield store
def get_pgvector_connection_string() -> str:
    """
    Build the SQLAlchemy-style PostgreSQL URI (psycopg driver) used by PGVector.

    The user name and password are percent-encoded so that URI-special
    characters in them cannot corrupt the URI.

    Returns:
        A ``postgresql+psycopg://<user>:<password>@<host>:<port>/<db>?sslmode=require`` URI.

    Raises:
        ValueError: if POSTGRES_PASSWORD is not set.
    """
    # Fail with a clear message instead of an AttributeError on None.
    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")
    # SQLAlchemy URL parsing percent-decodes the user name and password.
    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
    return (
        f"postgresql+psycopg://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{settings.POSTGRES_DB}?sslmode=require"
    )
def load_pgvector_store():
    """
    Create a PGVector vector store on the configured PostgreSQL database.

    The collection is ``settings.VECTOR_STORE_COLLECTION_NAME``. The embeddings
    come from ``get_embeddings(settings.DEFAULT_EMBEDDING_MODEL)``.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete or inconsistent.
    """
    validate_postgres_config()
    # Evaluated in the same order as the keyword arguments were originally.
    url = get_pgvector_connection_string()
    collection = settings.VECTOR_STORE_COLLECTION_NAME
    embedder = get_embeddings(settings.DEFAULT_EMBEDDING_MODEL)
    return PGVector(connection=url, collection_name=collection, embeddings=embedder)
def load_pgvector_retriever(k: int = 6, fetch_k: int = 20, lambda_mult: float = 0.6):
    """
    Return an MMR (maximal marginal relevance) retriever over the PGVector store.

    Args:
        k: Number of documents the retriever returns.
        fetch_k: Number of candidate documents fetched before MMR re-ranking.
            It is raised to ``k`` when smaller, because MMR can only select from
            the fetched candidates and would otherwise return fewer than ``k``.
        lambda_mult: Relevance vs. diversity trade-off for MMR. LangChain
            documents 1 as minimum diversity and 0 as maximum diversity.

    Returns:
        The retriever produced by ``store.as_retriever(search_type="mmr", ...)``.
    """
    store = load_pgvector_store()
    return store.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": k,
            "fetch_k": max(fetch_k, k),  # candidates; never fewer than k
            "lambda_mult": lambda_mult,  # relevance vs diversity
        },
    )