|
|
from __future__ import annotations |
|
|
|
|
|
from contextlib import asynccontextmanager, contextmanager |
|
|
from dataclasses import dataclass |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
from alembic.util.exc import CommandError |
|
|
from loguru import logger |
|
|
from sqlmodel import Session, text |
|
|
from sqlmodel.ext.asyncio.session import AsyncSession |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from langflow.services.database.service import DatabaseService |
|
|
|
|
|
|
|
|
def initialize_database(*, fix_migration: bool = False) -> None:
    """Create the database tables, check the schema, and run migrations.

    The steps run in this order on the service returned by
    ``get_db_service()``:

    1. ``create_db_and_tables()`` -- an error whose text contains
       ``"already exists"`` is ignored.
    2. ``check_schema_health()``.
    3. ``run_migrations(fix=fix_migration)`` -- when Alembic raises a
       ``CommandError`` mentioning overlapping or unlocatable revisions, the
       ``alembic_version`` table is dropped and the migrations are run once
       more.

    Args:
        fix_migration: Forwarded as ``fix`` to ``run_migrations``.

    Raises:
        RuntimeError: If table creation fails for a reason other than the
            tables already existing, or if the schema health check raises.
        CommandError: If Alembic fails with an error that is not one of the
            two recoverable revision errors.
        Exception: Any other migration error whose text does not contain
            ``"already exists"``; also anything raised by the retried
            migration run.
    """
    logger.debug("Initializing database")
    # Imported inside the function rather than at module level
    # (presumably to avoid a circular import -- TODO confirm).
    from langflow.services.deps import get_db_service

    database_service: DatabaseService = get_db_service()

    try:
        database_service.create_db_and_tables()
    except Exception as exc:
        # Tables that are already present are not a failure.
        if "already exists" not in str(exc):
            msg = "Error creating DB and tables"
            logger.exception(msg)
            raise RuntimeError(msg) from exc

    try:
        database_service.check_schema_health()
    except Exception as exc:
        msg = "Error checking schema health"
        logger.exception(msg)
        raise RuntimeError(msg) from exc

    try:
        database_service.run_migrations(fix=fix_migration)
    except CommandError as exc:
        error_text = str(exc)
        recoverable_markers = (
            "overlaps with other requested revisions",
            "Can't locate revision identified by",
        )
        # Any other Alembic failure cannot be handled here.
        if not any(marker in error_text for marker in recoverable_markers):
            raise
        # The revision stored in the database is wrong: drop the
        # alembic_version table and run the migrations again.
        logger.warning("Wrong revision in DB, deleting alembic_version table and running migrations again")
        with session_getter(database_service) as session:
            session.exec(text("DROP TABLE alembic_version"))
            # session_getter never commits; without this the DROP is rolled
            # back on close for backends with transactional DDL and the retry
            # below would fail on the same revision.
            session.commit()
        database_service.run_migrations(fix=fix_migration)
    except Exception as exc:
        # NOTE(review): the re-raise is assumed to belong inside this `if`,
        # mirroring the table-creation handler above, so that "already
        # exists" errors are ignored -- confirm against the intended behavior.
        if "already exists" not in str(exc):
            logger.exception(exc)
            raise
    logger.debug("Database initialized")
|
|
|
|
|
|
|
|
@contextmanager
def session_getter(db_service: DatabaseService):
    """Yield a synchronous ``Session`` bound to ``db_service.engine``.

    If the body of the ``with`` block raises, the error is logged, the
    session is rolled back, and the exception is re-raised. The session is
    closed on exit in every case. Nothing is committed here; callers that
    need their changes persisted must call ``session.commit()`` themselves.

    Args:
        db_service: Object exposing the ``engine`` the session is bound to.

    Yields:
        The newly created session.

    Raises:
        Exception: Whatever the ``with`` body raised, after the rollback.
    """
    # Build the session before entering the ``try``: if construction fails,
    # the real error propagates instead of being masked by an
    # UnboundLocalError from the cleanup clauses touching an unassigned name.
    session = Session(db_service.engine)
    try:
        yield session
    except Exception:
        logger.exception("Session rollback because of exception")
        session.rollback()
        raise
    finally:
        session.close()
|
|
|
|
|
|
|
|
@asynccontextmanager
async def async_session_getter(db_service: DatabaseService):
    """Yield an ``AsyncSession`` bound to ``db_service.async_engine``.

    The session is created with ``expire_on_commit=False``. If the body of
    the ``async with`` block raises, the error is logged, the session is
    rolled back, and the exception is re-raised. The session is closed on
    exit in every case. Nothing is committed here; callers must commit
    explicitly.

    Args:
        db_service: Object exposing the ``async_engine`` the session is
            bound to.

    Yields:
        The newly created asynchronous session.

    Raises:
        Exception: Whatever the ``async with`` body raised, after the
            rollback.
    """
    # Build the session before entering the ``try``: if construction fails,
    # the real error propagates instead of being masked by an
    # UnboundLocalError from the cleanup clauses touching an unassigned name.
    session = AsyncSession(db_service.async_engine, expire_on_commit=False)
    try:
        yield session
    except Exception:
        logger.exception("Session rollback because of exception")
        await session.rollback()
        raise
    finally:
        await session.close()
|
|
|
|
|
|
|
|
@dataclass
class Result:
    """Outcome record holding a name, a type label, and a success flag."""

    # Identifier of the item this result refers to
    # (presumably a schema element such as a column -- confirm with callers).
    name: str
    # Category label for the item; the set of valid values is not defined here.
    type: str
    # Whether the outcome was successful.
    success: bool
|
|
|
|
|
|
|
|
@dataclass
class TableResults:
    """Group of ``Result`` records collected under one table name."""

    # Name of the table the results belong to.
    table_name: str
    # Individual results recorded for that table.
    results: list[Result]
|
|
|