Spaces:
Sleeping
Sleeping
File size: 4,250 Bytes
a83c934 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
"""Database and vector store configuration
Provides async database engine, session management, and Qdrant client setup.
"""
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
create_async_engine,
async_sessionmaker,
)
from sqlalchemy.orm import declarative_base
from qdrant_client import AsyncQdrantClient
from qdrant_client.models import Distance, VectorParams
from src.config.settings import settings
from src.utils.logger import get_logger
# Module-level logger for engine/session/Qdrant lifecycle messages
logger = get_logger(__name__)
# SQLAlchemy declarative base; ORM models presumably subclass this (defined elsewhere)
Base = declarative_base()
# Lazily created singletons below start as None and are populated on first
# call to the matching get_*() accessor in this module.
# Global engine instance (see get_engine)
_engine: AsyncEngine | None = None
# Global session maker, bound to _engine (see get_session_maker)
_async_session_maker: async_sessionmaker[AsyncSession] | None = None
# Global Qdrant client (see get_qdrant_client)
_qdrant_client: AsyncQdrantClient | None = None
def get_engine() -> AsyncEngine:
    """Return the process-wide async SQLAlchemy engine, creating it on first use.

    The engine is built from ``settings.async_database_url``. The driver is
    asked for ``ssl="require"`` when ``settings.database_url`` contains
    ``"neon.tech"`` or when ``settings.is_production`` is true. SQL echo is
    on outside production; connections are pre-pinged, with a pool of 5
    and up to 10 overflow connections.

    Returns:
        The cached AsyncEngine instance.
    """
    global _engine
    if _engine is not None:
        return _engine

    # Neon-hosted and production databases must be reached over SSL.
    wants_ssl = "neon.tech" in settings.database_url or settings.is_production
    # SSL mode is handed to the driver (asyncpg) via connect_args.
    driver_args = {"ssl": "require"} if wants_ssl else {}

    _engine = create_async_engine(
        settings.async_database_url,
        echo=not settings.is_production,  # echo SQL only outside production
        pool_pre_ping=True,  # validate pooled connections before reuse
        pool_size=5,
        max_overflow=10,
        connect_args=driver_args,
    )
    logger.info("Database engine created")
    return _engine
def get_session_maker() -> async_sessionmaker[AsyncSession]:
    """Return the shared AsyncSession factory, building it on first call.

    The factory is bound to the engine from ``get_engine()`` and produces
    ``AsyncSession`` objects with ``expire_on_commit=False``.

    Returns:
        The cached async_sessionmaker instance.
    """
    global _async_session_maker
    if _async_session_maker is not None:
        return _async_session_maker

    _async_session_maker = async_sessionmaker(
        get_engine(),
        class_=AsyncSession,
        expire_on_commit=False,  # don't expire loaded instances on commit
    )
    logger.info("Session maker created")
    return _async_session_maker
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields one async database session per unit of work.

    A session is opened from the shared session maker (``get_session_maker``),
    yielded to the caller, and closed when the ``async with`` block exits,
    whether the caller finishes normally or an exception is thrown into the
    generator.

    Yields:
        AsyncSession instance
    """
    session_maker = get_session_maker()
    # AsyncSession's async context manager closes the session on exit, so a
    # separate try/finally ``close()`` would only close it a second time.
    async with session_maker() as session:
        yield session
def get_qdrant_client() -> AsyncQdrantClient:
    """Return the shared async Qdrant client, constructing it on first call.

    The client is configured from ``settings.qdrant_url`` and
    ``settings.qdrant_api_key`` with ``timeout=30.0``.

    Returns:
        The cached AsyncQdrantClient instance.
    """
    global _qdrant_client
    if _qdrant_client is not None:
        return _qdrant_client

    client_options = {
        "url": settings.qdrant_url,
        "api_key": settings.qdrant_api_key,
        "timeout": 30.0,
    }
    _qdrant_client = AsyncQdrantClient(**client_options)
    logger.info("Qdrant client created")
    return _qdrant_client
async def init_qdrant_collection() -> None:
    """Ensure the configured Qdrant collection exists.

    Lists the server's collections and, if ``settings.qdrant_collection_name``
    is not among them, creates it with ``settings.vector_size``-dimensional
    vectors compared by cosine distance. An existing collection is left
    as-is (its vector configuration is not checked).
    """
    qdrant = get_qdrant_client()

    listing = await qdrant.get_collections()
    known_names = {entry.name for entry in listing.collections}

    if settings.qdrant_collection_name in known_names:
        logger.info(f"Qdrant collection already exists: {settings.qdrant_collection_name}")
        return

    vector_params = VectorParams(
        size=settings.vector_size,
        distance=Distance.COSINE,
    )
    await qdrant.create_collection(
        collection_name=settings.qdrant_collection_name,
        vectors_config=vector_params,
    )
    logger.info(f"Created Qdrant collection: {settings.qdrant_collection_name}")
async def close_database_connections() -> None:
    """Close all database and vector-store connections gracefully.

    Disposes the SQLAlchemy engine's connection pool and closes the Qdrant
    client, and clears every module-level singleton. The session maker is
    cleared as well, because it is bound to the engine being disposed. The
    next ``get_engine()`` / ``get_session_maker()`` / ``get_qdrant_client()``
    call therefore builds fresh, consistent instances.

    The Qdrant client is closed even if disposing the engine raises; the
    engine error is then propagated.
    """
    global _engine, _async_session_maker, _qdrant_client

    # Detach the singletons first so later accessor calls never hand out
    # objects that are in the middle of being shut down.
    engine, _engine = _engine, None
    client, _qdrant_client = _qdrant_client, None
    # The session maker references the old engine; leaving it cached would
    # bind future sessions to an engine this function no longer tracks.
    _async_session_maker = None

    try:
        if engine is not None:
            await engine.dispose()
            logger.info("Database engine disposed")
    finally:
        if client is not None:
            await client.close()
            logger.info("Qdrant client closed")
|