| from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession |
| from sqlalchemy.ext.declarative import declarative_base |
| from sqlalchemy.orm import sessionmaker |
| from sqlalchemy import MetaData, select |
|
|
| from backend.config.settings import settings |
|
|
| |
def _to_async_url(url: str) -> str:
    """Rewrite a synchronous database URL so it uses an asyncio-capable driver.

    * ``sqlite://...``      -> ``sqlite+aiosqlite://...`` (file or in-memory)
    * ``postgresql://...``  -> ``postgresql+asyncpg://...``
    * ``postgres://...``    -> ``postgresql+asyncpg://...``

    Any other URL, including one that already names a driver (for example
    ``sqlite+aiosqlite:///x``), is returned unchanged.

    Only the scheme prefix is inspected. A plain substring ``replace`` would
    also match inside ``aiosqlite:///`` and corrupt an already-async URL.
    A substring ``"sqlite" in url`` test would misclassify any URL that merely
    contains the word (e.g. a Postgres database name).
    """
    if url.startswith("sqlite://"):
        return "sqlite+aiosqlite://" + url[len("sqlite://"):]
    for scheme in ("postgresql://", "postgres://"):
        if url.startswith(scheme):
            return "postgresql+asyncpg://" + url[len(scheme):]
    return url


# Async-driver form of the configured database URL, used to build the engine.
SQLALCHEMY_DATABASE_URL = _to_async_url(settings.DATABASE_URL)
|
|
| |
# Async engine shared by the session factory and init_db().
#
# ``check_same_thread=False`` is a sqlite3 connect option that lets the
# connection be used from a thread other than the one that created it. Other
# drivers (e.g. asyncpg) do not accept it, so it is passed only for SQLite
# URLs. The check matches the scheme prefix rather than doing a substring
# test, so a non-SQLite URL that merely contains "sqlite" is not misdetected.
# An empty connect_args dict is equivalent to not passing the argument.
#
# NOTE(review): echo=True logs every SQL statement; consider driving it from a
# settings flag rather than hard-coding it on for all environments.
engine = create_async_engine(
    SQLALCHEMY_DATABASE_URL,
    echo=True,
    connect_args=(
        {"check_same_thread": False}
        if SQLALCHEMY_DATABASE_URL.startswith("sqlite")
        else {}
    ),
)
|
|
| |
# Factory producing AsyncSession objects bound to ``engine``.
# expire_on_commit=False keeps loaded attribute values on ORM instances after
# a commit, so reading them later does not trigger an implicit reload (which
# an AsyncSession cannot perform without an explicit await).
async_session = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
| |
# sqlalchemy.orm.declarative_base is the supported location since
# SQLAlchemy 1.4. The sqlalchemy.ext.declarative alias is deprecated and
# warns when called on 2.0. It is aliased here so it does not rebind the
# module-level name imported at the top of the file.
from sqlalchemy.orm import declarative_base as _orm_declarative_base  # noqa: E402

# MetaData that collects the tables of every model declared on ``Base``;
# init_db() calls Base.metadata.create_all to create them.
metadata = MetaData()
Base = _orm_declarative_base(metadata=metadata)
|
|
| |
async def get_db():
    """Yield one ``AsyncSession`` for a unit of work, committing on success.

    This is an async generator:

    * The session is yielded to the consumer.
    * When the consumer resumes the generator without an error, the
      transaction is committed.
    * If the consumer's code or the commit itself raises an ``Exception``,
      the transaction is rolled back and the exception is re-raised.

    Closing the session is left to the ``async with`` block, which closes it
    on every exit path. An explicit ``close()`` in a ``finally`` clause would
    only duplicate that work. Exits via ``BaseException`` (e.g. task
    cancellation) skip the explicit rollback. Closing the session still
    discards the uncommitted transaction.
    """
    async with async_session() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
|
|
| |
async def _first_or_create(session, model, criterion, factory, *, refresh=False):
    """Return the first ``model`` row matching ``criterion``, creating one if absent.

    Args:
        session: the AsyncSession to query and write through.
        model: ORM class to select from.
        criterion: SQL filter expression, e.g. ``Role.name == "admin"``.
        factory: zero-argument callable building the new instance. It is
            called only when no row matches, so work such as password hashing
            is skipped for rows that already exist.
        refresh: when true, a newly created instance is reloaded from the
            database after its commit.

    Returns:
        The existing or newly created (and committed) instance.
    """
    result = await session.execute(select(model).filter(criterion))
    instance = result.scalars().first()
    if instance is None:
        instance = factory()
        session.add(instance)
        await session.commit()
        if refresh:
            await session.refresh(instance)
    return instance


async def init_db():
    """Create all tables on ``Base.metadata`` and seed default roles and users.

    Seeds, when missing:

    * the ``admin`` and ``user`` roles (looked up by name);
    * ``admin@example.com`` and ``user@example.com`` accounts (looked up by
      email), linked to those roles.

    Rows that already exist are left untouched, so a second run finds them
    and inserts nothing. Each inserted row is committed on its own.
    """
    async with engine.begin() as conn:
        # create_all is a synchronous API; run_sync runs it on the async connection.
        await conn.run_sync(Base.metadata.create_all)

    async with async_session() as session:
        # Imported here rather than at module top. Presumably this avoids a
        # circular import with the model modules; verify before hoisting.
        from backend.core.security import get_password_hash
        from backend.db.models.user import User
        from backend.db.models.role import Role

        # SECURITY(review): the accounts below are seeded with well-known
        # passwords. Ensure they are changed or disabled outside development.

        # Roles are refreshed after insert; their ``id`` becomes the users' role_id.
        admin_role = await _first_or_create(
            session,
            Role,
            Role.name == "admin",
            lambda: Role(name="admin", description="Administrator role"),
            refresh=True,
        )
        await _first_or_create(
            session,
            User,
            User.email == "admin@example.com",
            lambda: User(
                email="admin@example.com",
                username="admin",
                hashed_password=get_password_hash("admin123"),
                role_id=admin_role.id,
                is_active=True,
            ),
        )

        user_role = await _first_or_create(
            session,
            Role,
            Role.name == "user",
            lambda: Role(name="user", description="Regular user role"),
            refresh=True,
        )
        await _first_or_create(
            session,
            User,
            User.email == "user@example.com",
            lambda: User(
                email="user@example.com",
                username="user",
                hashed_password=get_password_hash("user123"),
                role_id=user_role.id,
                is_active=True,
            ),
        )