from __future__ import annotations

import asyncio
import re
import sqlite3
import time
from contextlib import asynccontextmanager, contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING

import sqlalchemy as sa
from alembic import command, util
from alembic.config import Config
from loguru import logger
from sqlalchemy import event, inspect
from sqlalchemy.dialects import sqlite as dialect_sqlite
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel import Session, SQLModel, create_engine, select, text
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.initial_setup.constants import STARTER_FOLDER_NAME
from langflow.services.base import Service
from langflow.services.database import models
from langflow.services.database.models.user.crud import get_user_by_username
from langflow.services.database.utils import Result, TableResults
from langflow.services.deps import get_settings_service
from langflow.services.utils import teardown_superuser

if TYPE_CHECKING:
    from langflow.services.settings.service import SettingsService


class DatabaseService(Service):
    name = "database_service"

    def __init__(self, settings_service: SettingsService):
        self.settings_service = settings_service
        if settings_service.settings.database_url is None:
            msg = "No database URL provided"
            raise ValueError(msg)
        self.database_url: str = settings_service.settings.database_url
        self._sanitize_database_url()

        # Locate the alembic scripts and config relative to the langflow package root.
        langflow_dir = Path(__file__).parent.parent.parent
        self.script_location = langflow_dir / "alembic"
        self.alembic_cfg_path = langflow_dir / "alembic.ini"

        # Register the connect hook before creating engines so SQLite pragmas
        # are applied to every new connection.
        event.listen(Engine, "connect", self.on_connection)
        self.engine = self._create_engine()
        self.async_engine = self._create_async_engine()

        # Resolve the alembic log file relative to the langflow directory
        # unless an absolute path was configured.
        alembic_log_file = self.settings_service.settings.alembic_log_file
        if Path(alembic_log_file).is_absolute():
            self.alembic_log_path = Path(alembic_log_file)
        else:
            self.alembic_log_path = Path(langflow_dir) / alembic_log_file

    def reload_engine(self) -> None:
        """Recreate both engines, e.g. after the database URL changes."""
        self._sanitize_database_url()
        self.engine = self._create_engine()
        self.async_engine = self._create_async_engine()

    def _sanitize_database_url(self):
        if self.database_url.startswith("postgres://"):
            self.database_url = self.database_url.replace("postgres://", "postgresql://")
            logger.warning(
                "Fixed postgres dialect in database URL. Replacing postgres:// with postgresql://. "
                "To avoid this warning, update the database URL."
            )

    def _create_engine(self) -> Engine:
        """Create the sync engine for the database."""
        return create_engine(
            self.database_url,
            connect_args=self._get_connect_args(),
            pool_size=self.settings_service.settings.pool_size,
            max_overflow=self.settings_service.settings.max_overflow,
        )

    def _create_async_engine(self) -> AsyncEngine:
        """Create the async engine for the database."""
        url_components = self.database_url.split("://", maxsplit=1)
        if url_components[0].startswith("sqlite"):
            database_url = "sqlite+aiosqlite://"
            kwargs = {}
        else:
            kwargs = {
                "pool_size": self.settings_service.settings.pool_size,
                "max_overflow": self.settings_service.settings.max_overflow,
            }
            # Re-attach the scheme separator for non-postgresql dialects so the
            # rebuilt URL stays valid.
            database_url = (
                "postgresql+psycopg://" if url_components[0].startswith("postgresql") else url_components[0] + "://"
            )
        database_url += url_components[1]
        return create_async_engine(
            database_url,
            connect_args=self._get_connect_args(),
            **kwargs,
        )

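    # Examples of the URL rewriting above (illustrative values only):
    #   sqlite:///./langflow.db        -> sqlite+aiosqlite:///./langflow.db
    #   postgresql://user:pw@host/db   -> postgresql+psycopg://user:pw@host/db
    # Any other dialect keeps its original scheme unchanged.
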
    def _get_connect_args(self):
        """SQLite needs thread-safety and timeout options; other dialects use defaults."""
        if self.settings_service.settings.database_url and self.settings_service.settings.database_url.startswith(
            "sqlite"
        ):
            connect_args = {
                "check_same_thread": False,
                "timeout": self.settings_service.settings.db_connect_timeout,
            }
        else:
            connect_args = {}
        return connect_args

    def on_connection(self, dbapi_connection, _connection_record) -> None:
        """Apply configured PRAGMAs to every new SQLite connection."""
        if isinstance(dbapi_connection, sqlite3.Connection | dialect_sqlite.aiosqlite.AsyncAdapt_aiosqlite_connection):
            pragmas: dict = self.settings_service.settings.sqlite_pragmas or {}
            pragmas_list = [f"PRAGMA {key} = {val}" for key, val in pragmas.items()]
            logger.debug(f"sqlite connection, setting pragmas: {pragmas_list}")
            if pragmas_list:
                cursor = dbapi_connection.cursor()
                try:
                    for pragma in pragmas_list:
                        try:
                            cursor.execute(pragma)
                        except sqlite3.OperationalError:
                            # The raw DBAPI cursor raises sqlite3.OperationalError,
                            # not SQLAlchemy's wrapped OperationalError.
                            logger.exception(f"Failed to set PRAGMA {pragma}")
                finally:
                    cursor.close()

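    # Illustrative effect, assuming settings.sqlite_pragmas is
    # {"synchronous": "NORMAL", "journal_mode": "WAL"}; each new SQLite
    # connection then executes:
    #   PRAGMA synchronous = NORMAL
    #   PRAGMA journal_mode = WAL
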
    @contextmanager
    def with_session(self):
        """Yield a synchronous session bound to the sync engine."""
        with Session(self.engine) as session:
            yield session

    @asynccontextmanager
    async def with_async_session(self):
        """Yield an async session; expire_on_commit=False keeps objects usable after commit."""
        async with AsyncSession(self.async_engine, expire_on_commit=False) as session:
            yield session

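    # Sketch of how the two session helpers are typically consumed (assuming a
    # constructed service instance named `db_service`):
    #
    #   with db_service.with_session() as session:
    #       flows = session.exec(select(models.Flow)).all()
    #
    #   async with db_service.with_async_session() as session:
    #       flows = (await session.exec(select(models.Flow))).all()
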
    async def assign_orphaned_flows_to_superuser(self) -> None:
        """Assign orphaned flows to the default superuser when auto login is enabled."""
        settings_service = get_settings_service()

        if not settings_service.auth_settings.AUTO_LOGIN:
            return

        async with self.with_async_session() as session:
            # Flows without an owner, excluding the read-only starter examples.
            stmt = (
                select(models.Flow)
                .join(models.Folder)
                .where(
                    models.Flow.user_id == None,  # noqa: E711 -- "== None" compiles to SQL "IS NULL"
                    models.Folder.name != STARTER_FOLDER_NAME,
                )
            )
            orphaned_flows = (await session.exec(stmt)).all()

            if not orphaned_flows:
                return

            logger.debug("Assigning orphaned flows to the default superuser")

            superuser_username = settings_service.auth_settings.SUPERUSER
            superuser = await get_user_by_username(session, superuser_username)

            if not superuser:
                error_message = "Default superuser not found"
                logger.error(error_message)
                raise RuntimeError(error_message)

            # Track the superuser's existing flow names to avoid collisions.
            existing_names: set[str] = set(
                (await session.exec(select(models.Flow.name).where(models.Flow.user_id == superuser.id))).all()
            )

            for flow in orphaned_flows:
                flow.user_id = superuser.id
                flow.name = self._generate_unique_flow_name(flow.name, existing_names)
                existing_names.add(flow.name)
                session.add(flow)

            await session.commit()
            logger.debug("Successfully assigned orphaned flows to the default superuser")

    def _generate_unique_flow_name(self, original_name: str, existing_names: set[str]) -> str:
        """Generate a unique flow name by adding or incrementing a numeric suffix."""
        if original_name not in existing_names:
            return original_name

        # If the name already ends in " (n)", start from n + 1; otherwise append " (1)".
        match = re.search(r"^(.*) \((\d+)\)$", original_name)
        if match:
            base_name, current_number = match.groups()
            new_name = f"{base_name} ({int(current_number) + 1})"
        else:
            new_name = f"{original_name} (1)"

        # Keep incrementing the suffix until the name is unique.
        while new_name in existing_names:
            match = re.match(r"^(.*) \((\d+)\)$", new_name)
            if match is None:
                error_message = "Invalid format: match is None"
                raise ValueError(error_message)
            base_name, current_number = match.groups()
            new_name = f"{base_name} ({int(current_number) + 1})"

        return new_name

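    # Illustrative renaming, assuming the superuser already owns
    # {"My Flow", "My Flow (1)"}:
    #   "My Flow"  -> "My Flow (2)"   (suffix incremented until unique)
    #   "Other"    -> "Other"         (no collision, name kept as-is)
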
    def check_schema_health(self) -> bool:
        """Return True if the core tables exist with all model-declared columns."""
        inspector = inspect(self.engine)

        model_mapping: dict[str, type[SQLModel]] = {
            "flow": models.Flow,
            "user": models.User,
            "apikey": models.ApiKey,
        }

        # Tables from older schema versions that should no longer exist.
        legacy_tables = ["flowstyle"]

        for table, model in model_mapping.items():
            expected_columns = list(model.model_fields.keys())

            try:
                available_columns = [col["name"] for col in inspector.get_columns(table)]
            except sa.exc.NoSuchTableError:
                logger.debug(f"Missing table: {table}")
                return False

            for column in expected_columns:
                if column not in available_columns:
                    logger.debug(f"Missing column: {column} in table {table}")
                    return False

        for table in legacy_tables:
            if table in inspector.get_table_names():
                logger.warning(f"Legacy table exists: {table}")

        return True

    def init_alembic(self, alembic_cfg) -> None:
        logger.info("Initializing alembic")
        command.ensure_version(alembic_cfg)
        command.upgrade(alembic_cfg, "head")
        logger.info("Alembic initialized")

    def run_migrations(self, *, fix=False) -> None:
        """Run Alembic migrations, initializing Alembic on first run.

        With fix=True, attempt a downgrade/upgrade repair when the models and
        the database schema disagree.
        """
        # Alembic output is redirected to a log file because stdout may not be
        # available (e.g. when running as a service).
        with self.alembic_log_path.open("w", encoding="utf-8") as buffer:
            alembic_cfg = Config(stdout=buffer)
            alembic_cfg.set_main_option("script_location", str(self.script_location))
            # Escape % so configparser interpolation does not mangle the URL.
            alembic_cfg.set_main_option("sqlalchemy.url", self.database_url.replace("%", "%%"))

            should_initialize_alembic = False
            with self.with_session() as session:
                # If the alembic_version table is missing, alembic was never run.
                try:
                    session.exec(text("SELECT * FROM alembic_version"))
                except Exception:
                    logger.debug("Alembic not initialized")
                    should_initialize_alembic = True

            if should_initialize_alembic:
                try:
                    self.init_alembic(alembic_cfg)
                except Exception as exc:
                    msg = "Error initializing alembic"
                    logger.exception(msg)
                    raise RuntimeError(msg) from exc
            else:
                logger.info("Alembic already initialized")

            logger.info(f"Running DB migrations in {self.script_location}")

            # First pass: check for drift and, on known errors, upgrade to head.
            try:
                buffer.write(f"{datetime.now(tz=timezone.utc).astimezone().isoformat()}: Checking migrations\n")
                command.check(alembic_cfg)
            except Exception as exc:
                logger.debug(f"Error checking migrations: {exc}")
                if isinstance(exc, util.exc.CommandError | util.exc.AutogenerateDiffsDetected):
                    command.upgrade(alembic_cfg, "head")
                    time.sleep(3)

            # Second pass: if drift remains, fail unless fix was requested.
            try:
                buffer.write(f"{datetime.now(tz=timezone.utc).astimezone()}: Checking migrations\n")
                command.check(alembic_cfg)
            except util.exc.AutogenerateDiffsDetected as exc:
                logger.exception("Error checking migrations")
                if not fix:
                    msg = f"There's a mismatch between the models and the database.\n{exc}"
                    raise RuntimeError(msg) from exc

            if fix:
                self.try_downgrade_upgrade_until_success(alembic_cfg)

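    # Roughly the CLI analogy for the happy path above, assuming the same
    # script_location and sqlalchemy.url (a loose sketch, not an exact
    # equivalent):
    #   alembic check          # detect model/schema drift
    #   alembic upgrade head   # apply pending migrations, then re-check
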
    def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5) -> None:
        # Step back one extra revision on each attempt, then re-upgrade to head,
        # until the models and the schema agree or retries are exhausted.
        for i in range(1, retries + 1):
            try:
                command.check(alembic_cfg)
                break
            except util.exc.AutogenerateDiffsDetected:
                logger.warning("AutogenerateDiffsDetected")
                command.downgrade(alembic_cfg, f"-{i}")
                time.sleep(3)
                command.upgrade(alembic_cfg, "head")

    def run_migrations_test(self):
        # Collect every SQLModel subclass exported by the models package and
        # verify its table and columns exist in the database.
        sql_models = [
            model for model in models.__dict__.values() if isinstance(model, type) and issubclass(model, SQLModel)
        ]
        return [TableResults(sql_model.__tablename__, self.check_table(sql_model)) for sql_model in sql_models]

    def check_table(self, model):
        results = []
        inspector = inspect(self.engine)
        table_name = model.__tablename__
        expected_columns = list(model.model_fields.keys())
        available_columns = []
        try:
            available_columns = [col["name"] for col in inspector.get_columns(table_name)]
            results.append(Result(name=table_name, type="table", success=True))
        except sa.exc.NoSuchTableError:
            logger.exception(f"Missing table: {table_name}")
            results.append(Result(name=table_name, type="table", success=False))

        for column in expected_columns:
            if column not in available_columns:
                logger.error(f"Missing column: {column} in table {table_name}")
                results.append(Result(name=column, type="column", success=False))
            else:
                results.append(Result(name=column, type="column", success=True))
        return results

    def create_db_and_tables(self) -> None:
        inspector = inspect(self.engine)
        table_names = inspector.get_table_names()
        current_tables = ["flow", "user", "apikey", "folder", "message", "variable", "transaction", "vertex_build"]

        if table_names and all(table in table_names for table in current_tables):
            logger.debug("Database and tables already exist")
            return

        logger.debug("Creating database and tables")

        for table in SQLModel.metadata.sorted_tables:
            try:
                table.create(self.engine, checkfirst=True)
            except OperationalError as oe:
                logger.warning(f"Table {table} already exists, skipping. Exception: {oe}")
            except Exception as exc:
                msg = f"Error creating table {table}"
                logger.exception(msg)
                raise RuntimeError(msg) from exc

        # Re-inspect to confirm every expected table was actually created.
        inspector = inspect(self.engine)
        table_names = inspector.get_table_names()
        for table in current_tables:
            if table not in table_names:
                logger.error("Something went wrong creating the database and tables.")
                logger.error("Please check your database settings.")
                msg = "Something went wrong creating the database and tables."
                raise RuntimeError(msg)

        logger.debug("Database and tables created successfully")

    async def teardown(self) -> None:
        logger.debug("Tearing down database")
        try:
            settings_service = get_settings_service()
            async with self.with_async_session() as session:
                await teardown_superuser(settings_service, session)
        except Exception:
            logger.exception("Error tearing down database")

        await self.async_engine.dispose()
        # Engine.dispose is synchronous; run it off the event loop.
        await asyncio.to_thread(self.engine.dispose)
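
# Minimal usage sketch (illustrative only; in practice Langflow constructs this
# service through its service manager rather than by hand):
#
#   service = DatabaseService(get_settings_service())
#   service.create_db_and_tables()
#   service.run_migrations()
#   assert service.check_schema_health()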