Spaces:
Sleeping
Sleeping
File size: 5,184 Bytes
361bd3e 6dd2bd0 361bd3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import logging
from contextlib import asynccontextmanager
from urllib.parse import quote

from langchain_postgres import PGVector
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.store.postgres import AsyncPostgresStore
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

from core.embeddings import get_embeddings
from core.settings import settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def validate_postgres_config() -> None:
    """
    Check that every PostgreSQL setting needed for persistence is present.

    Raises:
        ValueError: if any required ``POSTGRES_*`` setting is absent or falsy
            (e.g. an empty string), or if the configured minimum pool size is
            greater than the configured maximum pool size.
    """
    required_settings = (
        "POSTGRES_USER",
        "POSTGRES_PASSWORD",
        "POSTGRES_HOST",
        "POSTGRES_PORT",
        "POSTGRES_DB",
    )
    # A setting counts as missing when it is not defined at all or is falsy.
    absent = [name for name in required_settings if not getattr(settings, name, None)]
    if absent:
        joined = ", ".join(absent)
        raise ValueError(
            f"Missing required PostgreSQL configuration: {joined}. "
            "All individual POSTGRES_* environment variables must be set to use PostgreSQL persistence."
        )

    lower = settings.POSTGRES_MIN_CONNECTIONS_PER_POOL
    upper = settings.POSTGRES_MAX_CONNECTIONS_PER_POOL
    if lower > upper:
        raise ValueError(
            f"POSTGRES_MIN_CONNECTIONS_PER_POOL ({lower}) must be less than or equal to POSTGRES_MAX_CONNECTIONS_PER_POOL ({upper})"
        )
def get_postgres_connection_string() -> str:
    """Build and return the libpq connection URI for PostgreSQL from settings.

    The user name and password are percent-encoded so that characters that are
    meaningful inside a URI (``@``, ``:``, ``/``, ``#``, ``?``, ``%`` ...) cannot
    break or silently change how the URI is parsed. The connection always
    requests ``sslmode=require``.

    Returns:
        A ``postgresql://user:password@host:port/db?sslmode=require`` string.

    Raises:
        ValueError: if ``settings.POSTGRES_PASSWORD`` is ``None``.
    """
    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")
    # safe="" is required: quote() leaves "/" unescaped by default, which would
    # still corrupt the authority part of the URI.
    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
    return (
        f"postgresql://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{settings.POSTGRES_DB}?sslmode=require"
    )
@asynccontextmanager
async def _postgres_pool(role: str):
    """Open a connection pool configured the way LangGraph's Postgres backends expect.

    Args:
        role: Suffix appended (after a ``-``) to ``settings.POSTGRES_APPLICATION_NAME``
            so the connections of each pool can be told apart on the server.

    Yields:
        An open ``AsyncConnectionPool``. It is closed when the context exits.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete.
    """
    validate_postgres_config()
    application_name = settings.POSTGRES_APPLICATION_NAME + "-" + role
    async with AsyncConnectionPool(
        get_postgres_connection_string(),
        min_size=settings.POSTGRES_MIN_CONNECTIONS_PER_POOL,
        max_size=settings.POSTGRES_MAX_CONNECTIONS_PER_POOL,
        # LangGraph requires autocommit=True and row_factory=dict_row.
        # application_name is passed so you can identify the connection in your
        # Postgres database connection manager.
        kwargs={"autocommit": True, "row_factory": dict_row, "application_name": application_name},
        # Make sure a connection is still valid before handing it out.
        check=AsyncConnectionPool.check_connection,
        # Let `async with` open the pool instead of the constructor, as
        # psycopg_pool recommends for async pools.
        open=False,
    ) as pool:
        # Leaving `async with` closes the pool, so no explicit close() is needed.
        yield pool


@asynccontextmanager
async def get_postgres_saver():
    """Yield an ``AsyncPostgresSaver`` backed by its own connection pool.

    Use as ``async with get_postgres_saver() as checkpointer: ...``.
    ``setup()`` is awaited before the saver is yielded, and the underlying pool
    is closed when the context exits.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete.
    """
    async with _postgres_pool("saver") as pool:
        checkpointer = AsyncPostgresSaver(pool)
        await checkpointer.setup()
        yield checkpointer


@asynccontextmanager
async def get_postgres_store():
    """Yield an ``AsyncPostgresStore`` backed by its own connection pool.

    Use as ``async with get_postgres_store() as store: ...``.
    ``setup()`` is awaited before the store is yielded, and the underlying pool
    is closed when the context exits.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete.
    """
    async with _postgres_pool("store") as pool:
        store = AsyncPostgresStore(pool)
        await store.setup()
        yield store
def get_pgvector_connection_string() -> str:
    """Build and return the SQLAlchemy/psycopg URL used by the pgvector store.

    Mirrors the plain libpq connection string but uses the ``postgresql+psycopg``
    dialect prefix. The user name and password are percent-encoded so that
    special characters (``@``, ``:``, ``/``, ``#``, ``?``, ``%`` ...) cannot break
    URL parsing. The connection always requests ``sslmode=require``.

    Returns:
        A ``postgresql+psycopg://user:password@host:port/db?sslmode=require`` string.

    Raises:
        ValueError: if ``settings.POSTGRES_PASSWORD`` is ``None``.
    """
    # Fail with a clear message rather than an AttributeError on None.
    if settings.POSTGRES_PASSWORD is None:
        raise ValueError("POSTGRES_PASSWORD is not set")
    # safe="" also escapes "/", which quote() would otherwise leave as-is.
    user = quote(str(settings.POSTGRES_USER), safe="")
    password = quote(settings.POSTGRES_PASSWORD.get_secret_value(), safe="")
    return (
        f"postgresql+psycopg://{user}:{password}@"
        f"{settings.POSTGRES_HOST}:{settings.POSTGRES_PORT}/"
        f"{settings.POSTGRES_DB}?sslmode=require"
    )
def load_pgvector_store():
    """Create a ``PGVector`` store for the configured vector collection.

    The PostgreSQL settings are validated first. The store connects with the
    URL from ``get_pgvector_connection_string()``, uses the collection named by
    ``settings.VECTOR_STORE_COLLECTION_NAME``, and embeds with the model named by
    ``settings.DEFAULT_EMBEDDING_MODEL``.

    Raises:
        ValueError: if the PostgreSQL configuration is incomplete.
    """
    validate_postgres_config()
    connection_url = get_pgvector_connection_string()
    collection = settings.VECTOR_STORE_COLLECTION_NAME
    embedder = get_embeddings(settings.DEFAULT_EMBEDDING_MODEL)
    return PGVector(
        embeddings=embedder,
        collection_name=collection,
        connection=connection_url,
    )
def load_pgvector_retriever(k: int = 6, *, fetch_k: int = 20, lambda_mult: float = 0.6):
    """Return an MMR (maximal marginal relevance) retriever over the pgvector store.

    Args:
        k: Number of documents the retriever should return.
        fetch_k: Number of candidate documents fetched before MMR re-ranking.
            It is raised to ``k`` when smaller, because MMR can only select
            from the fetched candidates.
        lambda_mult: Relevance-vs-diversity trade-off passed to MMR
            (closer to 1 favours relevance, closer to 0 favours diversity).

    Returns:
        The retriever produced by the store's ``as_retriever``.
    """
    store = load_pgvector_store()
    return store.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": k,
            # Fetching fewer candidates than k would silently cap results below k.
            "fetch_k": max(fetch_k, k),
            "lambda_mult": lambda_mult,
        },
    )