|
|
import json |
|
|
from collections.abc import Sequence |
|
|
from uuid import UUID |
|
|
|
|
|
from langchain_core.chat_history import BaseChatMessageHistory |
|
|
from langchain_core.messages import BaseMessage |
|
|
from loguru import logger |
|
|
from sqlalchemy import delete |
|
|
from sqlmodel import Session, col, select |
|
|
from sqlmodel.ext.asyncio.session import AsyncSession |
|
|
|
|
|
from langflow.schema.message import Message |
|
|
from langflow.services.database.models.message.model import MessageRead, MessageTable |
|
|
from langflow.services.deps import async_session_scope, session_scope |
|
|
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER |
|
|
|
|
|
|
|
|
def _get_variable_query(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
):
    """Build the SELECT statement used to look up stored messages.

    Rows whose ``error`` column is true are always excluded. Every other
    filter is applied only when its argument is truthy.

    Args:
        sender: Keep only rows whose ``sender`` equals this value.
        sender_name: Keep only rows whose ``sender_name`` equals this value.
        session_id: Keep only rows whose ``session_id`` equals this value.
        order_by: Name of the ``MessageTable`` attribute to sort on. A falsy
            value skips ordering entirely.
        order: ``"DESC"`` sorts descending; any other value sorts ascending.
        flow_id: Keep only rows whose ``flow_id`` equals this value.
        limit: Maximum number of rows. A falsy value (including ``0``)
            applies no limit.

    Returns:
        The composed statement; it is not executed here.

    Raises:
        AttributeError: If ``order_by`` does not name an attribute of
            ``MessageTable``.
    """
    # ``== False`` is intentional: the comparison has to produce a SQL
    # expression for ``where``, which a Python ``is False`` / ``not`` cannot do.
    stmt = select(MessageTable).where(MessageTable.error == False)  # noqa: E712
    if sender:
        stmt = stmt.where(MessageTable.sender == sender)
    if sender_name:
        stmt = stmt.where(MessageTable.sender_name == sender_name)
    if session_id:
        stmt = stmt.where(MessageTable.session_id == session_id)
    if flow_id:
        stmt = stmt.where(MessageTable.flow_id == flow_id)
    if order_by:
        # Resolved once and named ``sort_column`` so the module-level ``col``
        # helper imported from sqlmodel is not shadowed inside this function.
        sort_column = getattr(MessageTable, order_by)
        stmt = stmt.order_by(sort_column.desc() if order == "DESC" else sort_column.asc())
    if limit:
        stmt = stmt.limit(limit)
    return stmt
|
|
|
|
|
|
|
|
def get_messages(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
) -> list[Message]:
    """Fetch stored messages matching the given filters (synchronous).

    All arguments are forwarded unchanged to ``_get_variable_query``.

    Args:
        sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User").
        sender_name (Optional[str]): The name of the sender.
        session_id (Optional[str]): The session ID associated with the messages.
        order_by (Optional[str]): The field to order the messages by. Defaults to "timestamp".
        order (Optional[str]): The sort direction. Defaults to "DESC".
        flow_id (Optional[UUID]): The flow ID associated with the messages.
        limit (Optional[int]): The maximum number of messages to retrieve.

    Returns:
        List[Message]: One ``Message`` per matching row, built from the row's ``model_dump()``.
    """
    with session_scope() as session:
        query = _get_variable_query(
            sender,
            sender_name,
            session_id,
            order_by,
            order,
            flow_id,
            limit,
        )
        found: list[Message] = []
        # Rows are converted while the session is still open.
        for row in session.exec(query):
            found.append(Message(**row.model_dump()))
        return found
|
|
|
|
|
|
|
|
async def aget_messages(
    sender: str | None = None,
    sender_name: str | None = None,
    session_id: str | None = None,
    order_by: str | None = "timestamp",
    order: str | None = "DESC",
    flow_id: UUID | None = None,
    limit: int | None = None,
) -> list[Message]:
    """Fetch stored messages matching the given filters (asynchronous).

    All arguments are forwarded unchanged to ``_get_variable_query``.

    Args:
        sender (Optional[str]): The sender of the messages (e.g., "Machine" or "User").
        sender_name (Optional[str]): The name of the sender.
        session_id (Optional[str]): The session ID associated with the messages.
        order_by (Optional[str]): The field to order the messages by. Defaults to "timestamp".
        order (Optional[str]): The sort direction. Defaults to "DESC".
        flow_id (Optional[UUID]): The flow ID associated with the messages.
        limit (Optional[int]): The maximum number of messages to retrieve.

    Returns:
        List[Message]: One ``Message`` per matching row, each created with
        ``await Message.create(**row.model_dump())``.
    """
    async with async_session_scope() as session:
        query = _get_variable_query(
            sender,
            sender_name,
            session_id,
            order_by,
            order,
            flow_id,
            limit,
        )
        rows = await session.exec(query)
        found: list[Message] = []
        # Messages are created one at a time, in row order, while the session is open.
        for row in rows:
            found.append(await Message.create(**row.model_dump()))
        return found
|
|
|
|
|
|
|
|
def add_messages(messages: Message | list[Message], flow_id: str | None = None):
    """Persist one or more messages (synchronous).

    Args:
        messages: A single ``Message`` or a list of them; a single message is wrapped in a list.
        flow_id: Flow ID passed to ``MessageTable.from_message`` for every message.

    Returns:
        A list of ``Message`` objects rebuilt from the stored records.

    Raises:
        ValueError: If any element is not a ``Message`` instance.
        Exception: Any error raised while converting or storing is logged and re-raised.
    """
    batch = messages if isinstance(messages, list) else [messages]

    if any(not isinstance(item, Message) for item in batch):
        types = ", ".join(str(type(item)) for item in batch)
        error_text = f"The messages must be instances of Message. Found: {types}"
        raise ValueError(error_text)

    try:
        table_rows = [MessageTable.from_message(item, flow_id=flow_id) for item in batch]
        with session_scope() as session:
            stored = add_messagetables(table_rows, session)
        return [Message(**record.model_dump()) for record in stored]
    except Exception as e:
        logger.exception(e)
        raise
|
|
|
|
|
|
|
|
async def aadd_messages(messages: Message | list[Message], flow_id: str | None = None):
    """Persist one or more messages (asynchronous).

    Args:
        messages: A single ``Message`` or a list of them; a single message is wrapped in a list.
        flow_id: Flow ID passed to ``MessageTable.from_message`` for every message.

    Returns:
        A list of ``Message`` objects, each created with ``Message.create`` from a stored record.

    Raises:
        ValueError: If any element is not a ``Message`` instance.
        Exception: Any error raised while converting or storing is logged and re-raised.
    """
    batch = messages if isinstance(messages, list) else [messages]

    if any(not isinstance(item, Message) for item in batch):
        types = ", ".join(str(type(item)) for item in batch)
        error_text = f"The messages must be instances of Message. Found: {types}"
        raise ValueError(error_text)

    try:
        table_rows = [MessageTable.from_message(item, flow_id=flow_id) for item in batch]
        async with async_session_scope() as session:
            stored = await aadd_messagetables(table_rows, session)
        created = []
        for record in stored:
            created.append(await Message.create(**record.model_dump()))
        return created
    except Exception as e:
        logger.exception(e)
        raise
|
|
|
|
|
|
|
|
def update_messages(messages: Message | list[Message]) -> list[Message]:
    """Apply the set, non-None fields of each message to its stored row (synchronous).

    Each message is looked up by ``message.id``. A found row is updated,
    committed and refreshed individually; a missing row only produces a warning.

    Args:
        messages: A single ``Message`` or a list of them.

    Returns:
        One entry per row that was found and updated. The entries are
        ``MessageRead`` objects validated from the refreshed rows.
    """
    if not isinstance(messages, list):
        messages = [messages]

    with session_scope() as session:
        refreshed_rows: list[MessageTable] = []
        for message in messages:
            row = session.get(MessageTable, message.id)
            if not row:
                logger.warning(f"Message with id {message.id} not found")
                continue
            changes = message.model_dump(exclude_unset=True, exclude_none=True)
            row.sqlmodel_update(changes)
            session.add(row)
            session.commit()
            session.refresh(row)
            refreshed_rows.append(row)
        return [MessageRead.model_validate(row, from_attributes=True) for row in refreshed_rows]
|
|
|
|
|
|
|
|
async def aupdate_messages(messages: Message | list[Message]) -> list[Message]:
    """Apply the set, non-None fields of each message to its stored row (asynchronous).

    Each message is looked up by ``message.id``. A found row is updated,
    committed and refreshed individually; a missing row only produces a warning.

    Args:
        messages: A single ``Message`` or a list of them.

    Returns:
        One entry per row that was found and updated. The entries are
        ``MessageRead`` objects validated from the refreshed rows.
    """
    if not isinstance(messages, list):
        messages = [messages]

    async with async_session_scope() as session:
        refreshed_rows: list[MessageTable] = []
        for message in messages:
            row = await session.get(MessageTable, message.id)
            if not row:
                logger.warning(f"Message with id {message.id} not found")
                continue
            changes = message.model_dump(exclude_unset=True, exclude_none=True)
            row.sqlmodel_update(changes)
            session.add(row)
            await session.commit()
            await session.refresh(row)
            refreshed_rows.append(row)
        return [MessageRead.model_validate(row, from_attributes=True) for row in refreshed_rows]
|
|
|
|
|
|
|
|
def add_messagetables(messages: list[MessageTable], session: Session):
    """Insert rows one commit at a time and return them as read models (synchronous).

    Args:
        messages: Table rows to insert.
        session: Open session used for add/commit/refresh.

    Returns:
        A list of ``MessageRead`` objects validated from the inserted rows, in input order.

    Raises:
        Exception: The first failure during add/commit/refresh is logged and
            re-raised; later rows are not attempted.
    """
    try:
        for row in messages:
            session.add(row)
            session.commit()
            session.refresh(row)
    except Exception as e:
        logger.exception(e)
        raise

    # Decode JSON text back into Python values before validation.
    for row in messages:
        raw_properties = row.properties
        row.properties = json.loads(raw_properties) if isinstance(raw_properties, str) else raw_properties
        row.content_blocks = [json.loads(block) if isinstance(block, str) else block for block in row.content_blocks]
        row.category = row.category or ""

    return [MessageRead.model_validate(row, from_attributes=True) for row in messages]
|
|
|
|
|
|
|
|
async def aadd_messagetables(messages: list[MessageTable], session: AsyncSession):
    """Insert all rows in a single commit and return them as read models (asynchronous).

    Args:
        messages: Table rows to insert.
        session: Open async session used for add/commit/refresh.

    Returns:
        A list of ``MessageRead`` objects validated from the inserted rows, in input order.

    Raises:
        Exception: Any failure during add/commit/refresh is logged and re-raised.
    """
    try:
        for row in messages:
            session.add(row)
        # One commit for the whole batch, then reload every row.
        await session.commit()
        for row in messages:
            await session.refresh(row)
    except Exception as e:
        logger.exception(e)
        raise

    # Decode JSON text back into Python values before validation.
    for row in messages:
        raw_properties = row.properties
        row.properties = json.loads(raw_properties) if isinstance(raw_properties, str) else raw_properties
        row.content_blocks = [json.loads(block) if isinstance(block, str) else block for block in row.content_blocks]
        row.category = row.category or ""

    return [MessageRead.model_validate(row, from_attributes=True) for row in messages]
|
|
|
|
|
|
|
|
def delete_messages(session_id: str) -> None:
    """Delete every stored message belonging to a session (synchronous).

    Args:
        session_id (str): The session ID whose messages are deleted.
    """
    with session_scope() as session:
        belongs_to_session = col(MessageTable.session_id) == session_id
        stmt = delete(MessageTable).where(belongs_to_session).execution_options(synchronize_session="fetch")
        # No explicit commit here; presumably session_scope commits on exit — confirm.
        session.exec(stmt)
|
|
|
|
|
|
|
|
async def adelete_messages(session_id: str) -> None:
    """Delete every stored message belonging to a session (asynchronous).

    Args:
        session_id (str): The session ID whose messages are deleted.
    """
    async with async_session_scope() as session:
        belongs_to_session = col(MessageTable.session_id) == session_id
        removal = delete(MessageTable).where(belongs_to_session).execution_options(synchronize_session="fetch")
        # No explicit commit here; presumably async_session_scope commits on exit — confirm.
        await session.exec(removal)
|
|
|
|
|
|
|
|
async def delete_message(id_: str) -> None:
    """Delete a single stored message by its ID; do nothing if it does not exist.

    Args:
        id_ (str): The ID of the message to delete.
    """
    async with async_session_scope() as session:
        row = await session.get(MessageTable, id_)
        if not row:
            return
        await session.delete(row)
        await session.commit()
|
|
|
|
|
|
|
|
def store_message(
    message: Message,
    flow_id: str | None = None,
) -> list[Message]:
    """Store a message, updating it when it already has an ID (synchronous).

    Args:
        message (Message): The message to store.
        flow_id (Optional[str]): The flow ID associated with the message.
            When running from the CustomComponent you can access this using `self.graph.flow_id`.

    Returns:
        List[Message]: The result of ``update_messages`` when ``message.id`` is set,
        otherwise the result of ``add_messages``. An empty list if ``message`` is falsy.

    Raises:
        ValueError: If any of the required fields (session_id, sender, sender_name) is missing.
    """
    if not message:
        logger.warning("No message provided.")
        return []

    # Insertion order doubles as the order in which fields are checked and reported.
    field_descriptions = {
        "session_id": "session_id (unique conversation identifier)",
        "sender": f"sender (e.g., '{MESSAGE_SENDER_USER}' or '{MESSAGE_SENDER_AI}')",
        "sender_name": "sender_name (display name, e.g., 'User' or 'Assistant')",
    }
    absent = [text for name, text in field_descriptions.items() if not getattr(message, name)]
    if absent:
        missing = ", ".join(absent)
        msg = (
            f"It looks like we're missing some important information: {missing}. "
            "Please ensure that your message includes all the required fields."
        )
        raise ValueError(msg)

    has_id = hasattr(message, "id") and message.id
    if has_id:
        return update_messages([message])
    return add_messages([message], flow_id=flow_id)
|
|
|
|
|
|
|
|
async def astore_message(
    message: Message,
    flow_id: str | None = None,
) -> list[Message]:
    """Store a message, updating it when it already has an ID (asynchronous).

    Args:
        message (Message): The message to store.
        flow_id (Optional[str]): The flow ID associated with the message.
            When running from the CustomComponent you can access this using `self.graph.flow_id`.

    Returns:
        List[Message]: The result of ``aupdate_messages`` when ``message.id`` is set,
        otherwise the result of ``aadd_messages``. An empty list if ``message`` is falsy.

    Raises:
        ValueError: If any of the required fields (session_id, sender, sender_name) is missing.
    """
    if not message:
        logger.warning("No message provided.")
        return []

    has_required_fields = message.session_id and message.sender and message.sender_name
    if not has_required_fields:
        msg = "All of session_id, sender, and sender_name must be provided."
        raise ValueError(msg)

    has_id = hasattr(message, "id") and message.id
    if has_id:
        return await aupdate_messages([message])
    return await aadd_messages([message], flow_id=flow_id)
|
|
|
|
|
|
|
|
class LCBuiltinChatMemory(BaseChatMessageHistory):
    """LangChain chat history backed by this module's message-store functions.

    Reads, writes and clears are all scoped to one ``session_id``; ``flow_id``
    is attached to messages when they are stored.
    """

    def __init__(
        self,
        flow_id: str,
        session_id: str,
    ) -> None:
        """Remember the flow and session this history is bound to."""
        self.session_id = session_id
        self.flow_id = flow_id

    @property
    def messages(self) -> list[BaseMessage]:
        """Return the session's stored messages as LangChain messages, skipping errored ones."""
        # NOTE(review): get_messages defaults to order="DESC" on "timestamp";
        # confirm that is the order consumers of this history expect.
        history: list[BaseMessage] = []
        for stored in get_messages(session_id=self.session_id):
            if stored.error:
                continue
            history.append(stored.to_lc_message())
        return history

    async def aget_messages(self) -> list[BaseMessage]:
        """Async variant of ``messages``."""
        # The bare name resolves to the module-level coroutine, not this method.
        fetched = await aget_messages(session_id=self.session_id)
        history: list[BaseMessage] = []
        for stored in fetched:
            if stored.error:
                continue
            history.append(stored.to_lc_message())
        return history

    def add_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Convert each LangChain message, tag it with this session, and store it."""
        for lc_message in messages:
            converted = Message.from_lc_message(lc_message)
            converted.session_id = self.session_id
            store_message(converted, flow_id=self.flow_id)

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Async variant of ``add_messages``; messages are stored one after another."""
        for lc_message in messages:
            converted = Message.from_lc_message(lc_message)
            converted.session_id = self.session_id
            await astore_message(converted, flow_id=self.flow_id)

    def clear(self) -> None:
        """Delete every stored message for this session."""
        delete_messages(self.session_id)

    async def aclear(self) -> None:
        """Async variant of ``clear``."""
        await adelete_messages(self.session_id)
|
|
|