| import json | |
| from typing import Callable, Dict, List, Union | |
| from pydantic import BaseModel, Field | |
| from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction | |
| from lagent.agents.agent import Agent, AsyncAgent | |
| from lagent.agents.aggregator import DefaultAggregator | |
| from lagent.hooks import ActionPreprocessor | |
| from lagent.llms import BaseLLM | |
| from lagent.memory import Memory | |
| from lagent.prompts.parsers.json_parser import JSONParser | |
| from lagent.prompts.prompt_template import PromptTemplate | |
| from lagent.schema import AgentMessage | |
| select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括: | |
| {action_info} | |
| {output_format} | |
| 开始!""" | |
| output_format_template = """如果使用工具请遵循以下格式回复: | |
| {function_format} | |
| 如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复 | |
| {finish_format}""" | |
| class ReAct(Agent): | |
| def __init__( | |
| self, | |
| llm: Union[BaseLLM, Dict], | |
| actions: Union[BaseAction, List[BaseAction]], | |
| template: Union[PromptTemplate, str] = None, | |
| memory: Dict = dict(type=Memory), | |
| output_format: Dict = dict(type=JSONParser), | |
| aggregator: Dict = dict(type=DefaultAggregator), | |
| hooks: List = [dict(type=ActionPreprocessor)], | |
| finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content | |
| or 'conclusion' in m.formatted, | |
| max_turn: int = 5, | |
| **kwargs | |
| ): | |
| self.max_turn = max_turn | |
| self.finish_condition = finish_condition | |
| self.actions = ActionExecutor(actions=actions, hooks=hooks) | |
| self.select_agent = Agent( | |
| llm=llm, | |
| template=template.format( | |
| action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction() | |
| ), | |
| output_format=output_format, | |
| memory=memory, | |
| aggregator=aggregator, | |
| hooks=hooks, | |
| ) | |
| super().__init__(**kwargs) | |
| def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: | |
| for _ in range(self.max_turn): | |
| message = self.select_agent(message, session_id=session_id, **kwargs) | |
| if self.finish_condition(message): | |
| return message | |
| message = self.actions(message, session_id=session_id) | |
| return message | |
| class AsyncReAct(AsyncAgent): | |
| def __init__( | |
| self, | |
| llm: Union[BaseLLM, Dict], | |
| actions: Union[BaseAction, List[BaseAction]], | |
| template: Union[PromptTemplate, str] = None, | |
| memory: Dict = dict(type=Memory), | |
| output_format: Dict = dict(type=JSONParser), | |
| aggregator: Dict = dict(type=DefaultAggregator), | |
| hooks: List = [dict(type=ActionPreprocessor)], | |
| finish_condition: Callable[[AgentMessage], bool] = lambda m: 'conclusion' in m.content | |
| or 'conclusion' in m.formatted, | |
| max_turn: int = 5, | |
| **kwargs | |
| ): | |
| self.max_turn = max_turn | |
| self.finish_condition = finish_condition | |
| self.actions = AsyncActionExecutor(actions=actions, hooks=hooks) | |
| self.select_agent = AsyncAgent( | |
| llm=llm, | |
| template=template.format( | |
| action_info=json.dumps(self.actions.description()), output_format=output_format.format_instruction() | |
| ), | |
| output_format=output_format, | |
| memory=memory, | |
| aggregator=aggregator, | |
| hooks=hooks, | |
| ) | |
| super().__init__(**kwargs) | |
| async def forward(self, message: AgentMessage, session_id=0, **kwargs) -> AgentMessage: | |
| for _ in range(self.max_turn): | |
| message = await self.select_agent(message, session_id=session_id, **kwargs) | |
| if self.finish_condition(message): | |
| return message | |
| message = await self.actions(message, session_id=session_id) | |
| return message | |
| if __name__ == '__main__': | |
| import asyncio | |
| from lagent.llms import GPTAPI, AsyncGPTAPI | |
| class ActionCall(BaseModel): | |
| name: str = Field(description='调用的函数名称') | |
| parameters: Dict = Field(description='调用函数的参数') | |
| class ActionFormat(BaseModel): | |
| thought_process: str = Field( | |
| description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。' | |
| ) | |
| action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。') | |
| class FinishFormat(BaseModel): | |
| thought_process: str = Field( | |
| description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。' | |
| ) | |
| conclusion: str = Field(description='总结当前的搜索结果,回答问题。') | |
| prompt_template = PromptTemplate(select_action_template) | |
| output_format = JSONParser(output_format_template, function_format=ActionFormat, finish_format=FinishFormat) | |
| agent = ReAct( | |
| llm=dict( | |
| type=GPTAPI, | |
| model_type='gpt-4o-2024-05-13', | |
| max_new_tokens=4096, | |
| proxies=dict(), | |
| retry=1000, | |
| ), | |
| template=prompt_template, | |
| output_format=output_format, | |
| aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'), | |
| actions=[dict(type='lagent.actions.PythonInterpreter')], | |
| ) | |
| response = agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5')) | |
| print(response) | |
| response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢')) | |
| print(response) | |
| async_agent = AsyncReAct( | |
| llm=dict( | |
| type=AsyncGPTAPI, | |
| model_type='gpt-4o-2024-05-13', | |
| max_new_tokens=4096, | |
| proxies=dict(), | |
| retry=1000, | |
| ), | |
| template=prompt_template, | |
| output_format=output_format, | |
| aggregator=dict(type='lagent.agents.aggregator.DefaultAggregator'), | |
| actions=[dict(type='lagent.actions.AsyncPythonInterpreter')], | |
| ) | |
| response = asyncio.run(async_agent(AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))) | |
| print(async_agent.state_dict()) | |