Spaces:
Running
Running
| from __future__ import annotations | |
| from contextlib import asynccontextmanager, contextmanager | |
| from dataclasses import dataclass | |
| from typing import TYPE_CHECKING | |
| from alembic.util.exc import CommandError | |
| from loguru import logger | |
| from sqlmodel import Session, text | |
| from sqlmodel.ext.asyncio.session import AsyncSession | |
| if TYPE_CHECKING: | |
| from langflow.services.database.service import DatabaseService | |
| def initialize_database(*, fix_migration: bool = False) -> None: | |
| logger.debug("Initializing database") | |
| from langflow.services.deps import get_db_service | |
| database_service: DatabaseService = get_db_service() | |
| try: | |
| database_service.create_db_and_tables() | |
| except Exception as exc: | |
| # if the exception involves tables already existing | |
| # we can ignore it | |
| if "already exists" not in str(exc): | |
| msg = "Error creating DB and tables" | |
| logger.exception(msg) | |
| raise RuntimeError(msg) from exc | |
| try: | |
| database_service.check_schema_health() | |
| except Exception as exc: | |
| msg = "Error checking schema health" | |
| logger.exception(msg) | |
| raise RuntimeError(msg) from exc | |
| try: | |
| database_service.run_migrations(fix=fix_migration) | |
| except CommandError as exc: | |
| # if "overlaps with other requested revisions" or "Can't locate revision identified by" | |
| # are not in the exception, we can't handle it | |
| if "overlaps with other requested revisions" not in str( | |
| exc | |
| ) and "Can't locate revision identified by" not in str(exc): | |
| raise | |
| # This means there's wrong revision in the DB | |
| # We need to delete the alembic_version table | |
| # and run the migrations again | |
| logger.warning("Wrong revision in DB, deleting alembic_version table and running migrations again") | |
| with session_getter(database_service) as session: | |
| session.exec(text("DROP TABLE alembic_version")) | |
| database_service.run_migrations(fix=fix_migration) | |
| except Exception as exc: | |
| # if the exception involves tables already existing | |
| # we can ignore it | |
| if "already exists" not in str(exc): | |
| logger.exception(exc) | |
| raise | |
| logger.debug("Database initialized") | |
| def session_getter(db_service: DatabaseService): | |
| try: | |
| session = Session(db_service.engine) | |
| yield session | |
| except Exception: | |
| logger.exception("Session rollback because of exception") | |
| session.rollback() | |
| raise | |
| finally: | |
| session.close() | |
| async def async_session_getter(db_service: DatabaseService): | |
| try: | |
| session = AsyncSession(db_service.async_engine, expire_on_commit=False) | |
| yield session | |
| except Exception: | |
| logger.exception("Session rollback because of exception") | |
| await session.rollback() | |
| raise | |
| finally: | |
| await session.close() | |
| class Result: | |
| name: str | |
| type: str | |
| success: bool | |
| class TableResults: | |
| table_name: str | |
| results: list[Result] | |