iLOVE2D's picture
Upload 2846 files
5374a2d verified
from enum import Enum
from pydantic import Field, model_validator
from datetime import datetime
from typing import Optional, Callable, Any, List, Union
from .module import BaseModule
from .module_utils import generate_id, get_timestamp
class MessageType(Enum):
    """
    The closed set of categories a message can carry.

    Every member's value is the lowercase string form of its name, so a
    member can be recovered from its serialized form with
    ``MessageType("request")`` and serialized again through ``.value``.
    """

    REQUEST = "request"
    RESPONSE = "response"
    COMMAND = "command"
    ERROR = "error"
    # Fallback category for messages whose type was never specified.
    UNKNOWN = "unknown"
    INPUT = "input"
class Message(BaseModule):
    """
    The base class for messages.

    Attributes:
        content (Any): the content of the message; it is rendered with ``str()`` by ``to_str``.
        agent (Optional[str]): the sender of the message, normally set as the agent name.
        action (Optional[str]): the trigger of the message, normally set as the action name.
        prompt (Optional[Union[str, List[dict]]]): the prompt used to obtain the generated text.
        next_actions (Optional[List[str]]): the following actions.
        msg_type (Optional[MessageType]): the type of the message, such as "request", "response", "command" etc.
            A plain string is converted to the matching ``MessageType`` member before validation.
        wf_goal (Optional[str]): the goal of the whole workflow.
        wf_task (Optional[str]): the name of a task in the workflow, i.e., the ``name`` of a WorkFlowNode instance.
        wf_task_desc (Optional[str]): the description of a task in the workflow, i.e., the ``description`` of a WorkFlowNode instance.
        message_id (Optional[str]): the unique identifier of the message (produced by ``generate_id`` by default).
            Equality and hashing of messages are based on this field only.
        timestamp (Optional[str]): the timestamp of the message (produced by ``get_timestamp`` by default).
            ``sort_by_timestamp`` parses it with the format ``"%Y-%m-%d %H:%M:%S"``.
        conversation_id (Optional[str]): an identifier produced by ``generate_id`` by default;
            presumably shared by the messages of one conversation (confirm against callers).
    """

    content: Any
    agent: Optional[str] = None
    # receivers: Optional[Union[str, List[str]]] = None
    action: Optional[str] = None
    prompt: Optional[Union[str, List[dict]]] = None
    next_actions: Optional[List[str]] = None
    msg_type: Optional[MessageType] = MessageType.UNKNOWN
    wf_goal: Optional[str] = None
    wf_task: Optional[str] = None
    wf_task_desc: Optional[str] = None
    message_id: Optional[str] = Field(default_factory=generate_id)
    timestamp: Optional[str] = Field(default_factory=get_timestamp)
    conversation_id: Optional[str] = Field(default_factory=generate_id)

    def __str__(self) -> str:
        """Return the human-readable rendering produced by ``to_str``."""
        return self.to_str()

    def __eq__(self, other: object) -> bool:
        """Two messages are equal when they share the same ``message_id``."""
        if not isinstance(other, Message):
            # Defer to the other operand instead of failing with AttributeError
            # when a message is compared with an unrelated object.
            return NotImplemented
        return self.message_id == other.message_id

    def __hash__(self) -> int:
        """Hash on ``message_id`` so that hashing stays consistent with ``__eq__``."""
        # __hash__ must return an int; returning the raw string id makes
        # hash(), set() and dict keys raise TypeError.
        return hash(self.message_id)

    def to_str(self) -> str:
        """
        Render the message as a multi-line string.

        Only the fields that are set (truthy) are included, one per line, in the order:
        timestamp, agent, type (skipped when it is ``MessageType.UNKNOWN``), action,
        workflow goal, workflow task and content.

        Returns:
            str: the rendered message; an empty string when no field is set.
        """
        msg_part = []
        if self.timestamp:
            msg_part.append(f"[{self.timestamp}]")
        if self.agent:
            msg_part.append(f"Agent: {self.agent}")
        if self.msg_type and self.msg_type != MessageType.UNKNOWN:
            msg_part.append(f"Type: {self.msg_type}")
        if self.action:
            msg_part.append(f"Action: {self.action}")
        if self.wf_goal:
            msg_part.append(f"Goal: {self.wf_goal}")
        if self.wf_task:
            msg_part.append(f"Task: {self.wf_task} ({self.wf_task_desc or 'No description'})")
        if self.content:
            msg_part.append(f"Content: {str(self.content)}")
        return "\n".join(msg_part)

    def to_dict(self, exclude_none: bool = True, ignore: Optional[List[str]] = None, **kwargs) -> dict:
        """
        Convert the Message to a dictionary for saving.

        Args:
            exclude_none (bool): forwarded to the parent ``to_dict``.
            ignore (Optional[List[str]]): field names forwarded to the parent ``to_dict``;
                ``None`` (the default) is treated as an empty list.
            **kwargs: forwarded unchanged to the parent ``to_dict``.

        Returns:
            dict: the parent's dictionary, with ``msg_type`` stored as the enum's string value.
        """
        # A fresh list per call avoids sharing one mutable default between calls.
        ignore = [] if ignore is None else ignore
        data = super().to_dict(exclude_none=exclude_none, ignore=ignore, **kwargs)
        # Store the plain string value, but do not re-add a field the caller asked to ignore.
        if self.msg_type and "msg_type" not in ignore:
            data["msg_type"] = self.msg_type.value
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_data(cls, data: Any) -> Any:
        """
        Convert a string ``msg_type`` into the matching ``MessageType`` member before validation.

        Args:
            data (Any): the raw input handed to the model; only dict input is inspected.

        Returns:
            Any: the input, with ``msg_type`` replaced by a ``MessageType`` when it was a non-empty string.

        Raises:
            ValueError: if the string does not match any ``MessageType`` value.
        """
        # A "before" validator may receive non-dict input; pass it through untouched.
        if not isinstance(data, dict):
            return data
        raw_type = data.get("msg_type")
        if raw_type and isinstance(raw_type, str):
            # Work on a shallow copy so the caller's dictionary is not modified.
            data = {**data, "msg_type": MessageType(raw_type)}
        return data

    @classmethod
    def sort_by_timestamp(cls, messages: List['Message'], reverse: bool = False) -> List['Message']:
        """
        sort the messages based on the timestamp.

        The list is sorted in place and the same list object is returned.

        Args:
            messages (List[Message]): the messages to be sorted.
            reverse (bool): If True, sort the messages in descending order. Otherwise, sort the messages in ascending order.

        Returns:
            List[Message]: the input list, now sorted.

        Raises:
            TypeError: if a message has no timestamp (``None``).
            ValueError: if a timestamp does not match ``"%Y-%m-%d %H:%M:%S"``.
        """
        messages.sort(key=lambda msg: datetime.strptime(msg.timestamp, "%Y-%m-%d %H:%M:%S"), reverse=reverse)
        return messages

    @classmethod
    def sort(cls, messages: List['Message'], key: Optional[Callable[['Message'], Any]] = None, reverse: bool = False) -> List['Message']:
        """
        sort the messages using key or timestamp (by default).

        The list is sorted in place and the same list object is returned.

        Args:
            messages (List[Message]): the messages to be sorted.
            key (Optional[Callable[['Message'], Any]]): the function used to sort messages.
            reverse (bool): If True, sort the messages in descending order. Otherwise, sort the messages in ascending order.

        Returns:
            List[Message]: the input list, now sorted.
        """
        if key is None:
            return cls.sort_by_timestamp(messages, reverse=reverse)
        messages.sort(key=key, reverse=reverse)
        return messages

    @classmethod
    def merge(cls, messages: List[List['Message']], sort: bool = False, key: Optional[Callable[['Message'], Any]] = None, reverse: bool = False) -> List['Message']:
        """
        merge different message list.

        The input lists are left untouched; a new list is always returned.

        Args:
            messages (List[List[Message]]): the message lists to be merged.
            sort (bool): whether to sort the merged messages.
            key (Optional[Callable[['Message'], Any]]): the function used to sort messages.
            reverse (bool): If True, sort the messages in descending order. Otherwise, sort the messages in ascending order.

        Returns:
            List[Message]: all messages of the input lists, in input order unless ``sort`` is True.
        """
        # Flatten in a single pass; sum(messages, []) copies the accumulated
        # list on every addition and is quadratic in the total length.
        merged_messages = [msg for group in messages for msg in group]
        if sort:
            merged_messages = cls.sort(merged_messages, key=key, reverse=reverse)
        return merged_messages