|
|
from __future__ import annotations |
|
|
|
|
|
import os |
|
|
from datetime import datetime, timezone |
|
|
from typing import TYPE_CHECKING |
|
|
|
|
|
from loguru import logger |
|
|
from sqlmodel import Session, select |
|
|
|
|
|
from langflow.services.auth import utils as auth_utils |
|
|
from langflow.services.base import Service |
|
|
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate |
|
|
from langflow.services.variable.base import VariableService |
|
|
from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from collections.abc import Sequence |
|
|
from uuid import UUID |
|
|
|
|
|
from sqlmodel.ext.asyncio.session import AsyncSession |
|
|
|
|
|
from langflow.services.settings.service import SettingsService |
|
|
|
|
|
|
|
|
class DatabaseVariableService(VariableService, Service):
    """Variable service that stores per-user variables through SQLModel sessions.

    Values are passed through ``auth_utils.encrypt_api_key`` before they are
    written and through ``auth_utils.decrypt_api_key`` when ``get_variable``
    reads them back. Every write method commits the session it is given.
    """

    def __init__(self, settings_service: SettingsService):
        """Keep a reference to the settings service used for flags and encryption."""
        self.settings_service = settings_service

    @staticmethod
    def _select_by_name(user_id: UUID | str, name: str):
        """Build the SELECT statement matching one user's variable by name."""
        return select(Variable).where(Variable.user_id == user_id, Variable.name == name)

    async def initialize_user_variables(self, user_id: UUID | str, session: AsyncSession) -> None:
        """Copy the configured environment variables into the user's stored variables.

        Does nothing (besides logging) when ``store_environment_variables`` is
        falsy. Otherwise, for each name in ``variables_to_get_from_environment``
        that has a non-blank value in ``os.environ``, the stripped value is
        written: an existing variable is updated, a missing one is created with
        type ``CREDENTIAL_TYPE`` and no default fields.

        Args:
            user_id: Owner of the variables.
            session: Async session used for the lookups and writes.
        """
        settings = self.settings_service.settings
        if not settings.store_environment_variables:
            logger.info("Skipping environment variable storage.")
            return

        logger.info("Storing environment variables in the database.")
        for var_name in settings.variables_to_get_from_environment:
            # Single lookup: a missing name and a whitespace-only value are both skipped.
            value = os.environ.get(var_name, "").strip()
            if not value:
                continue

            existing = (await session.exec(self._select_by_name(user_id, var_name))).first()
            try:
                if existing:
                    await self.update_variable(user_id, var_name, value, session)
                else:
                    await self.create_variable(
                        user_id=user_id,
                        name=var_name,
                        value=value,
                        default_fields=[],
                        type_=CREDENTIAL_TYPE,
                        session=session,
                    )
                logger.info(f"Processed {var_name} variable from environment.")
            except Exception as e:
                # Best effort: a failure on one variable is logged and the loop
                # moves on to the remaining names.
                logger.exception(f"Error processing {var_name} variable: {e!s}")

    def get_variable(
        self,
        user_id: UUID | str,
        name: str,
        field: str,
        session: Session,
    ) -> str:
        """Return the decrypted value of the user's variable called ``name``.

        Args:
            user_id: Owner of the variable.
            name: Name of the variable to read.
            field: Identifier of the field the value is requested for; here it is
                only compared against ``"session_id"``.
            session: Synchronous session used for the lookup.

        Raises:
            ValueError: No such variable exists, or its stored value is empty.
            TypeError: The variable has type ``CREDENTIAL_TYPE`` and ``field`` is
                ``"session_id"``.
        """
        variable = session.exec(self._select_by_name(user_id, name)).first()

        if not variable or not variable.value:
            msg = f"{name} variable not found."
            raise ValueError(msg)

        if variable.type == CREDENTIAL_TYPE and field == "session_id":
            msg = (
                f"variable {name} of type 'Credential' cannot be used in a Session ID field "
                "because its purpose is to prevent the exposure of values."
            )
            raise TypeError(msg)

        # The stored value is ciphertext; hand back the decrypted form.
        return auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)

    async def get_all(self, user_id: UUID | str, session: AsyncSession) -> list[Variable | None]:
        """Return every ``Variable`` row owned by ``user_id`` (values are not decrypted here)."""
        stmt = select(Variable).where(Variable.user_id == user_id)
        return list((await session.exec(stmt)).all())

    def list_variables_sync(self, user_id: UUID | str, session: Session) -> list[str | None]:
        """Return the names of the user's variables using a synchronous session."""
        variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
        return [variable.name for variable in variables if variable]

    async def list_variables(self, user_id: UUID | str, session: AsyncSession) -> list[str | None]:
        """Return the names of the user's variables using an async session."""
        variables = await self.get_all(user_id=user_id, session=session)
        return [variable.name for variable in variables if variable]

    async def update_variable(
        self,
        user_id: UUID | str,
        name: str,
        value: str,
        session: AsyncSession,
    ) -> Variable:
        """Replace the value of the user's variable ``name`` with the encryption of ``value``.

        Commits the session and returns the refreshed row.

        Raises:
            ValueError: The user has no variable with that name.
        """
        variable = (await session.exec(self._select_by_name(user_id, name))).first()
        if not variable:
            msg = f"{name} variable not found."
            raise ValueError(msg)

        variable.value = auth_utils.encrypt_api_key(value, settings_service=self.settings_service)
        session.add(variable)
        await session.commit()
        await session.refresh(variable)
        return variable

    async def update_variable_fields(
        self,
        user_id: UUID | str,
        variable_id: UUID | str,
        variable: VariableUpdate,
        session: AsyncSession,
    ) -> Variable:
        """Apply the explicitly-set fields of ``variable`` to the stored row ``variable_id``.

        ``updated_at`` is set to the current UTC time. ``value`` is always
        written: a missing or empty value is stored as the encryption of ``""``.
        The ``variable`` argument itself is not modified.

        The row is fetched with ``.one()``, so the lookup raises (rather than
        returning ``None``) when it does not match exactly one row for this user.

        Returns:
            The refreshed row after commit.
        """
        query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id)
        db_variable = (await session.exec(query)).one()
        db_variable.updated_at = datetime.now(timezone.utc)

        encrypted = auth_utils.encrypt_api_key(variable.value or "", settings_service=self.settings_service)

        # Override the value in the dumped dict instead of assigning it back onto
        # the caller's model: the persisted data is the same, but the caller's
        # object keeps its plaintext and can be reused without double-encrypting.
        changes = variable.model_dump(exclude_unset=True)
        changes["value"] = encrypted
        for key, new_value in changes.items():
            setattr(db_variable, key, new_value)

        session.add(db_variable)
        await session.commit()
        await session.refresh(db_variable)
        return db_variable

    async def delete_variable(
        self,
        user_id: UUID | str,
        name: str,
        session: AsyncSession,
    ) -> None:
        """Delete the user's variable called ``name`` and commit.

        Raises:
            ValueError: The user has no variable with that name.
        """
        variable = (await session.exec(self._select_by_name(user_id, name))).first()
        if not variable:
            msg = f"{name} variable not found."
            raise ValueError(msg)
        await session.delete(variable)
        await session.commit()

    async def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: AsyncSession) -> None:
        """Delete the user's variable with id ``variable_id`` and commit.

        Raises:
            ValueError: The user has no variable with that id.
        """
        stmt = select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)
        variable = (await session.exec(stmt)).first()
        if not variable:
            msg = f"{variable_id} variable not found."
            raise ValueError(msg)
        await session.delete(variable)
        await session.commit()

    async def create_variable(
        self,
        user_id: UUID | str,
        name: str,
        value: str,
        *,
        default_fields: Sequence[str] = (),
        type_: str = GENERIC_TYPE,
        session: AsyncSession,
    ) -> Variable:
        """Create a variable for ``user_id`` holding the encryption of ``value``.

        Args:
            user_id: Owner of the new variable.
            name: Variable name.
            value: Plaintext value; it is encrypted before being stored.
            default_fields: Stored on the row as a list copy of this sequence.
            type_: Variable type; defaults to ``GENERIC_TYPE``.
            session: Async session used for the insert.

        Returns:
            The refreshed row after commit.
        """
        variable_base = VariableCreate(
            name=name,
            type=type_,
            value=auth_utils.encrypt_api_key(value, settings_service=self.settings_service),
            default_fields=list(default_fields),
        )
        variable = Variable.model_validate(variable_base, from_attributes=True, update={"user_id": user_id})
        session.add(variable)
        await session.commit()
        await session.refresh(variable)
        return variable
|
|
|