| import contextlib |
| from collections.abc import AsyncIterator |
|
|
| from sqlalchemy.ext.asyncio import ( |
| AsyncConnection, |
| AsyncSession, |
| async_sessionmaker, |
| create_async_engine, |
| ) |
| from sqlalchemy.orm import declarative_base |
|
|
| from app.core.config import settings |
|
|
# Declarative base produced by SQLAlchemy's declarative_base(). Presumably the
# project's ORM models subclass it — confirm against the model modules.
Base = declarative_base()
|
|
| |
|
|
|
|
class DatabaseSessionManager:
    """Hold the async SQLAlchemy engine and the session factory bound to it.

    The engine is built from ``settings.SQLALCHEMY_DATABASE_URI`` when the
    manager is constructed. Once :meth:`close` has run, both the engine and
    the session factory are ``None`` and the other methods raise.
    """

    def __init__(self):
        engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URI)
        self._engine = engine
        self._sessionmaker = async_sessionmaker(bind=engine, autocommit=False)

    async def close(self):
        """Dispose of the engine and mark the manager as uninitialized.

        Raises:
            Exception: if the engine has already been cleared.
        """
        engine = self._engine
        if engine is None:
            raise Exception("DatabaseSessionManager is not initialized")

        await engine.dispose()

        # Clear both handles so connect() and session() fail fast afterwards.
        self._engine = self._sessionmaker = None

    @contextlib.asynccontextmanager
    async def connect(self) -> AsyncIterator[AsyncConnection]:
        """Yield a connection obtained from ``engine.begin()``.

        If the caller's block raises, the connection is rolled back
        explicitly and the exception is re-raised.

        Raises:
            Exception: if the engine has been cleared by :meth:`close`.
        """
        engine = self._engine
        if engine is None:
            raise Exception("DatabaseSessionManager is not initialized")

        async with engine.begin() as conn:
            try:
                yield conn
            except Exception:
                await conn.rollback()
                raise

    @contextlib.asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        """Yield a new session from the factory and always close it.

        If the caller's block raises, the session is rolled back and the
        exception is re-raised. No commit is issued here; committing is left
        to the caller.

        Raises:
            Exception: if the session factory has been cleared by
                :meth:`close`.
        """
        factory = self._sessionmaker
        if factory is None:
            raise Exception("DatabaseSessionManager is not initialized")

        db_session = factory()
        try:
            yield db_session
        except Exception:
            await db_session.rollback()
            raise
        finally:
            await db_session.close()
|
|
|
|
# Module-level manager instance used by get_db_session() below. Constructing it
# here means the engine is created when this module is imported.
sessionmanager = DatabaseSessionManager()
|
|
|
|
async def get_db_session():
    """Yield one session from the module-level ``sessionmanager``.

    This is an async generator: the session is produced by
    ``sessionmanager.session()``, which handles rollback on error and closes
    the session once the generator is finalized.
    """
    session_scope = sessionmanager.session()
    async with session_scope as db_session:
        yield db_session
|
|