File size: 5,644 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from enum import Enum
from pydantic import Field, model_validator
from datetime import datetime
from typing import Optional, Callable, Any, List, Union
from .module import BaseModule
from .module_utils import generate_id, get_timestamp
class MessageType(Enum):
    """Closed set of message categories carried in ``Message.msg_type``.

    The string values are what ``Message.to_dict`` writes out and what
    ``Message.validate_data`` accepts when rebuilding a message from a dict.
    """

    REQUEST = "request"
    RESPONSE = "response"
    COMMAND = "command"
    ERROR = "error"
    UNKNOWN = "unknown"  # default type used by Message when none is given
    INPUT = "input"
class Message(BaseModule):
    """
    The base class for messages.

    Attributes:
        content (Any): the content of the message; it must support ``str()``.
        agent (str): the sender of the message, normally set as the agent name.
        action (str): the trigger of the message, normally set as the action name.
        prompt (str | List[dict]): the prompt used to obtain the generated text.
        next_actions (List[str]): the following actions.
        msg_type (MessageType): the type of the message, such as ``MessageType.REQUEST``,
            ``MessageType.RESPONSE``, ``MessageType.COMMAND`` etc. The plain string value
            (e.g. ``"request"``) is also accepted on construction.
        wf_goal (str): the goal of the whole workflow.
        wf_task (str): the name of a task in the workflow, i.e., the ``name`` of a WorkFlowNode instance.
        wf_task_desc (str): the description of a task in the workflow, i.e., the ``description`` of a WorkFlowNode instance.
        message_id (str): the unique identifier of the message (defaults to ``generate_id()``).
        timestamp (str): the timestamp of the message (defaults to ``get_timestamp()``).
        conversation_id (str): identifier defaulting to ``generate_id()``; presumably groups
            the messages of one conversation -- confirm against callers.
    """

    content: Any
    agent: Optional[str] = None
    # receivers: Optional[Union[str, List[str]]] = None
    action: Optional[str] = None
    prompt: Optional[Union[str, List[dict]]] = None
    next_actions: Optional[List[str]] = None
    msg_type: Optional[MessageType] = MessageType.UNKNOWN
    wf_goal: Optional[str] = None
    wf_task: Optional[str] = None
    wf_task_desc: Optional[str] = None
    message_id: Optional[str] = Field(default_factory=generate_id)
    timestamp: Optional[str] = Field(default_factory=get_timestamp)
    conversation_id: Optional[str] = Field(default_factory=generate_id)

    def __str__(self) -> str:
        """Return the human-readable form produced by ``to_str``."""
        return self.to_str()

    def __eq__(self, other: object) -> bool:
        """Two messages are equal when they share the same ``message_id``."""
        if not isinstance(other, Message):
            # Let Python try the reflected comparison instead of raising
            # AttributeError on objects that have no ``message_id``.
            return NotImplemented
        return self.message_id == other.message_id

    def __hash__(self) -> int:
        """Hash on ``message_id`` so hashing stays consistent with ``__eq__``."""
        # ``__hash__`` must return an int; returning the id string itself
        # makes ``hash(msg)`` raise TypeError.
        return hash(self.message_id)

    def to_str(self) -> str:
        """
        Render the message as newline-separated ``Label: value`` lines.

        Empty/None fields are skipped, and the type line is omitted when the
        type is ``MessageType.UNKNOWN``.
        """
        msg_part = []
        if self.timestamp:
            msg_part.append(f"[{self.timestamp}]")
        if self.agent:
            msg_part.append(f"Agent: {self.agent}")
        if self.msg_type and self.msg_type != MessageType.UNKNOWN:
            msg_part.append(f"Type: {self.msg_type}")
        if self.action:
            msg_part.append(f"Action: {self.action}")
        if self.wf_goal:
            msg_part.append(f"Goal: {self.wf_goal}")
        if self.wf_task:
            msg_part.append(f"Task: {self.wf_task} ({self.wf_task_desc or 'No description'})")
        if self.content:
            msg_part.append(f"Content: {str(self.content)}")
        return "\n".join(msg_part)

    def to_dict(self, exclude_none: bool = True, ignore: Optional[List[str]] = None, **kwargs) -> dict:
        """
        Convert the Message to a dictionary for saving.

        Args:
            exclude_none (bool): forwarded to ``BaseModule.to_dict``.
            ignore (Optional[List[str]]): field names to leave out; forwarded to
                ``BaseModule.to_dict``. ``None`` means "ignore nothing".
            **kwargs: forwarded unchanged to ``BaseModule.to_dict``.

        Returns:
            dict: the serialized message, with ``msg_type`` stored as its string value.
        """
        # A fresh list per call avoids the shared-mutable-default pitfall.
        ignore = [] if ignore is None else ignore
        data = super().to_dict(exclude_none=exclude_none, ignore=ignore, **kwargs)
        # Store the enum as its plain string value, but do not re-introduce the
        # key when the caller explicitly asked for it to be ignored.
        if self.msg_type and "msg_type" not in ignore:
            data["msg_type"] = self.msg_type.value
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_data(cls, data: Any) -> Any:
        """
        Convert a string ``msg_type`` (e.g. ``"request"``) into a ``MessageType``.

        Non-dict input is passed through untouched. ``MessageType(...)`` raises
        ``ValueError`` for a string that is not a valid member value.
        """
        if not isinstance(data, dict):
            return data
        raw_type = data.get("msg_type")
        if raw_type and isinstance(raw_type, str):
            # Work on a shallow copy so the caller's dict is not modified.
            data = {**data, "msg_type": MessageType(raw_type)}
        return data

    @classmethod
    def sort_by_timestamp(cls, messages: List['Message'], reverse: bool = False) -> List['Message']:
        """
        sort the messages based on the timestamp.

        The list is sorted in place and the same list object is returned.
        Every ``timestamp`` must be a string in ``"%Y-%m-%d %H:%M:%S"`` format
        (presumably what ``get_timestamp`` produces -- confirm); otherwise
        ``datetime.strptime`` raises.

        Args:
            messages (List[Message]): the messages to be sorted.
            reverse (bool): If True, sort the messages in descending order. Otherwise, sort the messages in ascending order.
        """
        timestamp_format = "%Y-%m-%d %H:%M:%S"
        messages.sort(key=lambda msg: datetime.strptime(msg.timestamp, timestamp_format), reverse=reverse)
        return messages

    @classmethod
    def sort(cls, messages: List['Message'], key: Optional[Callable[['Message'], Any]] = None, reverse: bool = False) -> List['Message']:
        """
        sort the messages using key or timestamp (by default).

        The list is sorted in place and the same list object is returned.

        Args:
            messages (List[Message]): the messages to be sorted.
            key (Optional[Callable[['Message'], Any]]): the function used to sort messages.
            reverse (bool): If True, sort the messages in descending order. Otherwise, sort the messages in ascending order.
        """
        if key is None:
            return cls.sort_by_timestamp(messages, reverse=reverse)
        messages.sort(key=key, reverse=reverse)
        return messages

    @classmethod
    def merge(cls, messages: List[List['Message']], sort: bool=False, key: Optional[Callable[['Message'], Any]] = None, reverse: bool=False) -> List['Message']:
        """
        merge different message list.

        The input lists are not modified; a new flat list is returned.

        Args:
            messages (List[List[Message]]): the message lists to be merged.
            sort (bool): whether to sort the merged messages.
            key (Optional[Callable[['Message'], Any]]): the function used to sort messages.
            reverse (bool): If True, sort the messages in descending order. Otherwise, sort the messages in ascending order.
        """
        # Flatten in one linear pass; ``sum(messages, [])`` copies the
        # accumulated list on every addition.
        merged_messages = [msg for group in messages for msg in group]
        if sort:
            merged_messages = cls.sort(merged_messages, key=key, reverse=reverse)
        return merged_messages
|