| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | from dataclasses import dataclass |
| | from typing import Any, Dict, List, Optional |
| |
|
| | from tenacity import retry |
| | from tenacity.stop import stop_after_attempt |
| | from tenacity.wait import wait_exponential |
| |
|
| | from camel.agents import BaseAgent |
| | from camel.configs import ChatGPTConfig |
| | from camel.messages import ChatMessage, MessageType, SystemMessage |
| | from camel.model_backend import ModelBackend, ModelFactory |
| | from camel.typing import ModelType, RoleType |
| | from camel.utils import ( |
| | get_model_token_limit, |
| | num_tokens_from_messages, |
| | openai_api_key_required, |
| | ) |
| |
|
| |
|
@dataclass(frozen=True)
class ChatAgentResponse:
    r"""Result produced by one step of a chat agent.

    Attributes:
        msgs (List[ChatMessage]): Zero, one or several generated messages.
            An empty list means message generation went wrong, a single
            message is the normal mode, and several messages correspond to
            critic mode.
        terminated (bool): Whether the agent decided to terminate the chat
            session.
        info (Dict[str, Any]): Extra information about the chat message.
    """
    msgs: List[ChatMessage]
    terminated: bool
    info: Dict[str, Any]

    @property
    def msg(self):
        r"""Return the single message carried by this response.

        Returns:
            The only element of ``msgs``; ``None`` when ``msgs`` is empty and
            ``info`` is empty as well.

        Raises:
            RuntimeError: If the response is terminated, if ``msgs`` holds
                more than one message, or if ``msgs`` is empty while ``info``
                is non-empty.
        """
        if self.terminated:
            raise RuntimeError(f"error in ChatAgentResponse, info:{self.info!s}")

        count = len(self.msgs)
        if count == 1:
            return self.msgs[0]
        if count > 1:
            raise RuntimeError("Property msg is only available for a single message in msgs")

        # No messages: only an error if ``info`` carries something to report.
        if len(self.info) > 0:
            raise RuntimeError(f"Empty msgs in ChatAgentResponse, info:{self.info!s}")
        return None
| |
|
| |
|
class ChatAgent(BaseAgent):
    r"""Class for managing conversations of CAMEL Chat Agents.

    Args:
        system_message (SystemMessage): The system message for the chat agent.
        model (ModelType, optional): The LLM model to use for generating
            responses. (default: :obj:`ModelType.GPT_3_5_TURBO`)
        model_config (Any, optional): Configuration options for the LLM model.
            When falsy, a default :obj:`ChatGPTConfig` is used.
            (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
    """

    def __init__(
        self,
        system_message: SystemMessage,
        model: Optional[ModelType] = None,
        model_config: Optional[Any] = None,
        message_window_size: Optional[int] = None,
    ) -> None:
        self.system_message: SystemMessage = system_message
        self.role_name: str = system_message.role_name
        self.role_type: RoleType = system_message.role_type
        self.model: ModelType = (model if model is not None else ModelType.GPT_3_5_TURBO)
        self.model_config: ChatGPTConfig = model_config or ChatGPTConfig()
        self.model_token_limit: int = get_model_token_limit(self.model)
        self.message_window_size: Optional[int] = message_window_size
        # The backend factory receives the config's attributes as a plain dict.
        self.model_backend: ModelBackend = ModelFactory.create(self.model, self.model_config.__dict__)
        # Becomes True once a step exceeds the model token limit.
        self.terminated: bool = False
        # Becomes True when the last line of a reply starts with "<INFO>".
        self.info: bool = False
        self.init_messages()

    def reset(self) -> List[MessageType]:
        r"""Resets the :obj:`ChatAgent` to its initial state and returns the
        stored messages.

        Clears the ``terminated`` and ``info`` flags and reinitializes the
        stored messages to just the system message.

        Returns:
            List[MessageType]: The stored messages.
        """
        self.terminated = False
        self.info = False
        self.init_messages()
        return self.stored_messages

    def get_info(
        self,
        id: Optional[str],
        usage: Optional[Dict[str, int]],
        termination_reasons: List[str],
        num_tokens: int,
    ) -> Dict[str, Any]:
        r"""Returns a dictionary containing information about the chat session.

        Args:
            id (str, optional): The ID of the chat session. (The parameter
                name shadows the builtin but is kept for keyword callers.)
            usage (Dict[str, int], optional): Information about the usage of
                the LLM model.
            termination_reasons (List[str]): The reasons for the termination of
                the chat session.
            num_tokens (int): The number of tokens used in the chat session.

        Returns:
            Dict[str, Any]: The chat session information, keyed by
                ``"id"``, ``"usage"``, ``"termination_reasons"`` and
                ``"num_tokens"``.
        """
        return {
            "id": id,
            "usage": usage,
            "termination_reasons": termination_reasons,
            "num_tokens": num_tokens,
        }

    def init_messages(self) -> None:
        r"""Initializes the stored messages list with the initial system
        message.
        """
        self.stored_messages: List[MessageType] = [self.system_message]

    def update_messages(self, message: ChatMessage) -> List[MessageType]:
        r"""Updates the stored messages list with a new message.

        Args:
            message (ChatMessage): The new message to add to the stored
                messages.

        Returns:
            List[MessageType]: The updated stored messages (the live list,
                not a copy).
        """
        self.stored_messages.append(message)
        return self.stored_messages

    @retry(wait=wait_exponential(min=5, max=60), stop=stop_after_attempt(5))
    @openai_api_key_required
    def step(
        self,
        input_message: ChatMessage,
    ) -> ChatAgentResponse:
        r"""Performs a single step in the chat session by generating a response
        to the input message.

        The input message is appended to the stored messages. If the step
        fails, that append is undone before the exception propagates. This
        matters because ``@retry`` re-runs the whole method, and otherwise
        each retry would add another copy of the input to the history.

        Args:
            input_message (ChatMessage): The input message to the agent.

        Returns:
            ChatAgentResponse: A struct
                containing the output messages, a boolean indicating whether
                the chat session has terminated, and information about the chat
                session.
        """
        history_len = len(self.stored_messages)
        messages = self.update_messages(input_message)
        try:
            return self._generate_response(messages)
        except Exception:
            # Roll back so a retried (or re-issued) step starts from the same
            # history instead of accumulating duplicate input messages.
            del self.stored_messages[history_len:]
            raise

    def _generate_response(self, messages: List[MessageType]) -> ChatAgentResponse:
        r"""Windows ``messages``, queries the model when the prompt fits the
        token limit, and packages the outcome.

        Args:
            messages (List[MessageType]): The full stored history, including
                the newest input message. It is not mutated.

        Returns:
            ChatAgentResponse: The generated messages, termination flag and
                session info.

        Raises:
            RuntimeError: If the backend returns something other than a dict
                with a non-empty ``"choices"`` entry.
        """
        window = self.message_window_size
        if window is not None and len(messages) > window:
            # ``messages[-0:]`` would be the *whole* list, so a zero-sized
            # window is special-cased to keep only the system message.
            recent = messages[-window:] if window else []
            messages = [self.system_message] + recent
        openai_messages = [message.to_openai_message() for message in messages]
        num_tokens = num_tokens_from_messages(openai_messages, self.model)

        output_messages: List[ChatMessage]
        info: Dict[str, Any]

        if num_tokens < self.model_token_limit:
            response = self.model_backend.run(messages=openai_messages)
            if not isinstance(response, dict) or not response.get("choices"):
                raise RuntimeError("OpenAI returned unexpected struct")
            output_messages = [
                ChatMessage(role_name=self.role_name, role_type=self.role_type,
                            meta_dict=dict(), **dict(choice["message"]))
                for choice in response["choices"]
            ]
            info = self.get_info(
                response["id"],
                response["usage"],
                [str(choice["finish_reason"]) for choice in response["choices"]],
                num_tokens,
            )

            # Only the last line of the first reply is checked for the marker;
            # content may be None (e.g. function-call replies), treat as "".
            last_line = (output_messages[0].content or "").split("\n")[-1]
            if last_line.startswith("<INFO>"):
                self.info = True
        else:
            self.terminated = True
            output_messages = []

            info = self.get_info(
                None,
                None,
                ["max_tokens_exceeded_by_camel"],
                num_tokens,
            )

        return ChatAgentResponse(output_messages, self.terminated, info)

    def __repr__(self) -> str:
        r"""Returns a string representation of the :obj:`ChatAgent`.

        Returns:
            str: The string representation of the :obj:`ChatAgent`.
        """
        return f"ChatAgent({self.role_name}, {self.role_type}, {self.model})"
| |
|