Spaces:
Runtime error
Runtime error
| from dataclasses import dataclass | |
| import sys | |
| from typing import Dict, List, Optional, Union | |
| import logging | |
| from autogen import Agent, ConversableAgent | |
| logger = logging.getLogger(__name__) | |
@dataclass
class GroupChat:
    """A group chat class that contains the following data fields:
    - agents: a list of participating agents.
    - messages: a list of messages in the group chat.
    - max_round: the maximum number of rounds.
    - admin_name: the name of the admin agent if there is one. Default is "Admin".
        KeyboardInterrupt will make the admin agent take over.
    - func_call_filter: whether to enforce function call filter. Default is True.
        When set to True and when a message is a function call suggestion,
        the next speaker will be chosen from an agent which contains the corresponding function name
        in its `function_map`.

    NOTE(review): `max_round` and `admin_name` are not read by any method of this
    class; presumably they are consumed by the chat driver elsewhere.
    """

    agents: List[Agent]
    messages: List[Dict]
    max_round: int = 10
    admin_name: str = "Admin"
    func_call_filter: bool = True

    @property
    def agent_names(self) -> List[str]:
        """Return the names of the agents in the group chat, in roster order."""
        return [agent.name for agent in self.agents]

    def reset(self):
        """Reset the group chat by clearing the message history in place."""
        self.messages.clear()

    def agent_by_name(self, name: str) -> Agent:
        """Return the first agent in the roster whose ``name`` equals ``name``.

        Raises:
            ValueError: if no agent in the group chat has that name.
        """
        return self.agents[self.agent_names.index(name)]

    def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent:
        """Return the agent that follows ``agent`` in round-robin roster order.

        Order is always that of ``self.agents``; when ``agents`` is a subset of the
        roster, agents not in ``agents`` are skipped.

        Args:
            agent: the reference agent (typically the last speaker); must be in the roster.
            agents: the candidate agents that may be returned.

        Returns:
            The next eligible agent, or None if no member of ``agents`` is in the roster.

        Raises:
            ValueError: if ``agent.name`` is not the name of any roster agent.
        """
        if agents == self.agents:
            # Fast path: every roster agent is eligible, so just step to the next index.
            return agents[(self.agent_names.index(agent.name) + 1) % len(agents)]
        offset = self.agent_names.index(agent.name) + 1
        n_roster = len(self.agents)
        for step in range(n_roster):
            candidate = self.agents[(offset + step) % n_roster]
            if candidate in agents:
                return candidate
        return None

    def select_speaker_msg(self, agents: List[Agent]):
        """Return the system message used to ask an LLM to pick the next speaker from ``agents``."""
        return f"""You are in a role play game. The following roles are available:
{self._participant_roles()}.
Ignoring the order in which the above roles appear.
Think about the dependency relationships between different roles.
Read the following conversation.
Then select the next role from {[agent.name for agent in agents]} to play. Only return the role."""

    def select_speaker(self, last_speaker: Agent, selector: ConversableAgent):
        """Select the next speaker.

        When ``func_call_filter`` is enabled and the last message suggests a function
        call, candidates are narrowed to agents that can execute that function (or,
        if none can, to every agent with a ``function_map``); a single remaining
        candidate is returned directly. Otherwise ``selector`` is asked, through
        ``generate_oai_reply``, to choose among the candidates based on the last five
        messages. If it gives no final reply, or names an unknown agent, the agent
        after ``last_speaker`` in round-robin order is chosen instead.

        Args:
            last_speaker: the agent that spoke last; used for the round-robin fallback.
            selector: the agent whose system message is updated and whose LLM picks the speaker.

        Raises:
            ValueError: if a function call is suggested but no agent has a ``function_map``.
        """
        if self.func_call_filter and self.messages and "function_call" in self.messages[-1]:
            func_name = self.messages[-1]["function_call"]["name"]
            # Agents whose function_map contains the suggested function.
            agents = [agent for agent in self.agents if agent.can_execute_function(func_name)]
            if len(agents) == 1:
                # Only one agent can execute the function.
                return agents[0]
            elif not agents:
                # Nobody registered that function: fall back to any agent with a function_map.
                agents = [agent for agent in self.agents if agent.function_map]
                if len(agents) == 1:
                    return agents[0]
                elif not agents:
                    # Report the function name from the function_call payload; the message
                    # itself is not guaranteed to carry a top-level "name" key.
                    raise ValueError(
                        f"No agent can execute the function {func_name}. "
                        "Please check the function_map of the agents."
                    )
        else:
            agents = self.agents
            # Warn if GroupChat is underpopulated
            n_agents = len(agents)
            if n_agents < 3:
                logger.warning(
                    f"GroupChat is underpopulated with {n_agents} agents. Direct communication would be more efficient."
                )
        selector.update_system_message(self.select_speaker_msg(agents))
        # Select the next speaker based on the last five messages of the conversation.
        prompt = self.messages[-5:] + [
            {
                "role": "system",
                "content": f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.",
            }
        ]
        logger.debug("Speaker selection prompt: %s", prompt)
        final, name = selector.generate_oai_reply(prompt)
        if not final:
            return self.next_agent(last_speaker, agents)
        try:
            return self.agent_by_name(name)
        except ValueError:
            # The reply did not match any agent name; fall back to round-robin.
            return self.next_agent(last_speaker, agents)

    def _participant_roles(self):
        """Return one "name: system_message" line per roster agent."""
        return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents])
class GroupChatManager(ConversableAgent):
    """(In preview) A chat manager agent that can manage a group chat of multiple agents."""

    def __init__(
        self,
        groupchat: GroupChat,
        name: Optional[str] = "chat_manager",
        max_consecutive_auto_reply: Optional[int] = sys.maxsize,
        human_input_mode: Optional[str] = "NEVER",
        system_message: Optional[str] = "Group chat manager.",
        **kwargs,
    ):
        """Create a manager bound to ``groupchat``.

        Args:
            groupchat: the group chat this manager manages.
            name: agent name.
            max_consecutive_auto_reply: forwarded to ``ConversableAgent``; defaults to ``sys.maxsize``.
            human_input_mode: forwarded to ``ConversableAgent``; defaults to "NEVER".
            system_message: forwarded to ``ConversableAgent``, then replaced (see below).
            **kwargs: forwarded unchanged to ``ConversableAgent``.
        """
        super().__init__(
            name=name,
            max_consecutive_auto_reply=max_consecutive_auto_reply,
            human_input_mode=human_input_mode,
            system_message=system_message,
            **kwargs,
        )
        self.groupchat = groupchat
        # NOTE(review): this presumably overwrites the `system_message` passed above with the
        # speaker-selection prompt covering every roster agent — confirm that is intended.
        self.update_system_message(self.groupchat.select_speaker_msg(self.groupchat.agents))

    def broadcast(
        self,
        message: Optional[str] = None,
        sender: Optional[Agent] = None,
    ) -> None:
        """Send ``message`` to every group-chat agent except ``sender``.

        Each send uses ``request_reply=False``, so the recipients are not asked to reply.
        Nothing is returned.
        """
        for agent in self.groupchat.agents:
            if agent != sender:
                self.send(message, agent, request_reply=False)