Spaces:
Sleeping
Sleeping
| import hashlib | |
| import time | |
| from dataclasses import dataclass | |
| from typing import List, Union | |
| from uuid import uuid1 | |
| # Reserved role names. The name "Moderator" is special-cased by | |
| # MessagePool.get_visible_messages, which lets that agent see every earlier message. | |
| SYSTEM_NAME = "System" | |
| MODERATOR_NAME = "Moderator" | |
def _hash(input: str):
    """
    Compute the SHA256 digest of a string.

    Parameters:
        input (str): Text to hash. It is encoded with ``str.encode()`` (UTF-8 by default) first.

    Returns:
        str: The digest rendered as a hexadecimal string.
    """
    digest = hashlib.sha256()
    digest.update(input.encode())
    return digest.hexdigest()
@dataclass
class Message:
    """
    Represents a message in the chatArena environment.

    Attributes:
        agent_name (str): Name of the agent who sent the message.
        content (str): Content of the message.
        turn (int): The turn at which the message was sent.
        timestamp (int): Wall time at which the message was sent, in nanoseconds.
            When omitted (or None) it is set to ``time.time_ns()`` at construction time.
        visible_to (Union[str, List[str]]): The receivers of the message. Can be a single agent,
            multiple agents, or 'all'. Defaults to 'all'.
        msg_type (str): Type of the message, e.g., 'text'. Defaults to 'text'.
        logged (bool): Whether the message is logged in the database. Defaults to False.
    """

    agent_name: str
    content: str
    turn: int
    # None is a placeholder resolved in __post_init__. A literal ``time.time_ns()`` default
    # would be evaluated only once, when the class is defined, and then shared by every message.
    timestamp: Union[int, None] = None
    visible_to: Union[str, List[str]] = "all"
    msg_type: str = "text"
    logged: bool = False  # Whether the message is logged in the database

    def __post_init__(self):
        """Stamp the message with the current wall time when no timestamp was supplied."""
        if self.timestamp is None:
            self.timestamp = time.time_ns()

    def msg_hash(self):
        """
        Build an identifier for the message.

        Returns:
            str: The ``_hash`` of the agent name, content, timestamp, turn and message type.
        """
        return _hash(
            f"agent: {self.agent_name}\ncontent: {self.content}\ntimestamp: {str(self.timestamp)}\nturn: {self.turn}\nmsg_type: {self.msg_type}"
        )
class MessagePool:
    """
    A pool to manage the messages in the chatArena environment.

    The pool is essentially a list of messages kept in the order they were appended, and it
    allows a unified treatment of the visibility of the messages. Multiple players can act in
    the same turn (like in rock-paper-scissors): an agent can only see the messages that
    1) were sent before the current turn, and 2) are visible to that agent.
    """

    def __init__(self):
        """Initialize an empty MessagePool with a unique conversation ID."""
        self.conversation_id = str(uuid1())
        self._messages: List["Message"] = []
        self._last_message_idx = 0

    def reset(self):
        """Clear the message pool."""
        self._messages = []

    def append_message(self, message: "Message"):
        """
        Append a message to the pool.

        Parameters:
            message (Message): The message to be added to the pool.
        """
        self._messages.append(message)

    def print(self):
        """Print all the messages in the pool, one ``[sender->receivers]: content`` line each."""
        for message in self._messages:
            print(f"[{message.agent_name}->{message.visible_to}]: {message.content}")

    def last_turn(self):
        """
        Get the turn of the last message in the pool.

        Returns:
            int: The turn of the last message, or 0 when the pool is empty.
        """
        if not self._messages:
            return 0
        return self._messages[-1].turn

    def last_message(self):
        """
        Get the last message in the pool.

        Returns:
            Message: The last message, or None when the pool is empty.
        """
        if not self._messages:
            return None
        return self._messages[-1]

    def get_all_messages(self) -> List["Message"]:
        """
        Get all the messages in the pool.

        Returns:
            List[Message]: The pool's own message list (not a copy).
        """
        return self._messages

    @staticmethod
    def _is_recipient(agent_name, visible_to) -> bool:
        """Tell whether ``visible_to`` ('all', one agent name, or a collection of names) includes the agent."""
        if isinstance(visible_to, str):
            # Compare whole names: ``agent_name in visible_to`` on a string is a substring
            # test, which would let agent "Al" read a message addressed to "Alice".
            return visible_to == "all" or visible_to == agent_name
        return agent_name in visible_to

    def get_visible_messages(self, agent_name, turn: int) -> List["Message"]:
        """
        Get all the messages that are visible to a given agent before a specified turn.

        Parameters:
            agent_name (str): The name of the agent.
            turn (int): The specified turn; only messages with a strictly smaller turn are returned.

        Returns:
            List[Message]: The visible messages, in the order they were appended.
        """
        # The agent named "Moderator" (the value of MODERATOR_NAME) sees every earlier message.
        sees_everything = agent_name == "Moderator"
        return [
            message
            for message in self._messages
            if message.turn < turn
            and (sees_everything or self._is_recipient(agent_name, message.visible_to))
        ]