# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import abc
import json
import time
import traceback
import uuid
from collections import OrderedDict
from typing import AsyncGenerator, Dict, Any, List, Union, Callable

import aworld.trace as trace
from aworld.config import ToolConfig
from aworld.config.conf import AgentConfig, ConfigDict, ContextRuleConfig, ModelConfig, OptimizationConfig, \
    LlmCompressionConfig
from aworld.core.agent.agent_desc import get_agent_desc
from aworld.core.agent.base import BaseAgent, AgentResult, is_agent_by_name, is_agent
from aworld.core.common import Observation, ActionModel
from aworld.core.context.base import AgentContext
from aworld.core.context.base import Context
from aworld.core.context.processor.prompt_processor import PromptProcessor
from aworld.core.event import eventbus
from aworld.core.event.base import Message, ToolMessage, Constants, AgentMessage
from aworld.core.tool.base import ToolFactory, AsyncTool, Tool
from aworld.core.memory import MemoryItem, MemoryConfig
from aworld.core.tool.tool_desc import get_tool_desc
from aworld.logs.util import logger, color_log, Color, trace_logger
from aworld.mcp_client.utils import sandbox_mcp_tool_desc_transform
from aworld.memory.main import MemoryFactory
from aworld.models.llm import get_llm_model, call_llm_model, acall_llm_model, acall_llm_model_stream
from aworld.models.model_response import ModelResponse, ToolCall
from aworld.models.utils import tool_desc_transform, agent_desc_transform
from aworld.output import Outputs
from aworld.output.base import StepOutput, MessageOutput
from aworld.runners.hook.hook_factory import HookFactory
from aworld.runners.hook.hooks import HookPoint
from aworld.utils.common import sync_exec, nest_dict_counter


class Agent(BaseAgent[Observation, List[ActionModel]]):
    """Basic agent for the unified protocol within the framework."""

    def __init__(self,
                 conf: Union[Dict[str, Any], ConfigDict, AgentConfig],
                 resp_parse_func: Callable[..., Any] = None,
                 **kwargs):
        """An API-style agent implementation using the `Observation` and `List[ActionModel]` protocols.

        Args:
            conf: Agent config; supports AgentConfig, ConfigDict, or dict.
            resp_parse_func: Response parse function for the agent's standard output; transforms the LLM response.
        """
        super(Agent, self).__init__(conf, **kwargs)
        conf = self.conf
        self.model_name = conf.llm_config.llm_model_name if conf.llm_config.llm_model_name else conf.llm_model_name
        self._llm = None
        self.memory = MemoryFactory.from_config(MemoryConfig(provider="inmemory"))
        self.system_prompt: str = kwargs.pop("system_prompt") if kwargs.get("system_prompt") else conf.system_prompt
        self.agent_prompt: str = kwargs.get("agent_prompt") if kwargs.get("agent_prompt") else conf.agent_prompt
        self.event_driven = kwargs.pop('event_driven', conf.get('event_driven', False))
        self.handler: Callable[..., Any] = kwargs.get('handler')
        self.need_reset = kwargs.get('need_reset') if kwargs.get('need_reset') else conf.need_reset
        # Whether to keep contextual information: False means keep, True means reset on every agent step call.
        self.step_reset = kwargs.get('step_reset') if kwargs.get('step_reset') else True
        # tool_name: [tool_action1, tool_action2, ...]
        self.black_tool_actions: Dict[str, List[str]] = kwargs.get("black_tool_actions") if kwargs.get(
            "black_tool_actions") else conf.get('black_tool_actions', {})
        self.resp_parse_func = resp_parse_func if resp_parse_func else self.response_parse
        self.history_messages = kwargs.get("history_messages") if kwargs.get("history_messages") else 100
        self.use_tools_in_prompt = kwargs.get('use_tools_in_prompt', conf.use_tools_in_prompt)
        self.context_rule = kwargs.get("context_rule") if kwargs.get("context_rule") else conf.context_rule
        self.tools_instances = {}
        self.tools_conf = {}
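
    # Usage sketch (illustrative, not part of this module). The config field
    # values below are assumptions for demonstration; see AgentConfig for the
    # full option set.
    #
    #     agent = Agent(
    #         conf=AgentConfig(llm_model_name="gpt-4o", system_prompt="You are helpful."),
    #         resp_parse_func=None,     # falls back to Agent.response_parse
    #         history_messages=50,      # number of memory items replayed into the prompt
    #     )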

    def reset(self, options: Dict[str, Any]):
        super().reset(options)
        self.memory = MemoryFactory.from_config(
            MemoryConfig(provider=options.pop("memory_store") if options.get("memory_store") else "inmemory"))

    def set_tools_instances(self, tools, tools_conf):
        self.tools_instances = tools
        self.tools_conf = tools_conf

    @property
    def llm(self):
        # Lazily create the LLM instance on first access; callers use `self.llm`
        # as the model object (e.g. `call_llm_model(self.llm, ...)`).
        if self._llm is None:
            llm_config = self.conf.llm_config or None
            conf = llm_config if llm_config and (
                    llm_config.llm_provider or llm_config.llm_base_url or llm_config.llm_api_key or llm_config.llm_model_name) else self.conf
            self._llm = get_llm_model(conf)
        return self._llm

    def _env_tool(self):
        """Descriptions of the tools in the environment."""
        return tool_desc_transform(get_tool_desc(),
                                   tools=self.tool_names if self.tool_names else [],
                                   black_tool_actions=self.black_tool_actions)

    def _handoffs_agent_as_tool(self):
        """Descriptions of handoff agents as tools."""
        return agent_desc_transform(get_agent_desc(),
                                    agents=self.handoffs if self.handoffs else [])

    def _mcp_is_tool(self):
        """Descriptions of MCP servers as tools."""
        try:
            return sync_exec(sandbox_mcp_tool_desc_transform, self.mcp_servers, self.mcp_config)
        except Exception:
            logger.error(f"mcp_is_tool error: {traceback.format_exc()}")
            return []

    def desc_transform(self):
        """Transform descriptions of supported tools, agents, and MCP servers into the LLM function-call format."""
        # Stateless tools
        self.tools = self._env_tool()
        # Agents as tools
        self.tools.extend(self._handoffs_agent_as_tool())
        # MCP servers as tools
        self.tools.extend(self._mcp_is_tool())
        # Load into the agent context
        self.agent_context.set_tools(self.tools)
        return self.tools

    async def async_desc_transform(self):
        """Async variant of `desc_transform`: transform descriptions of supported tools, agents, and MCP servers into the LLM function-call format."""
        # Stateless tools
        self.tools = self._env_tool()
        # Agents as tools
        self.tools.extend(self._handoffs_agent_as_tool())
        # MCP servers as tools
        # todo sandbox
        if self.sandbox:
            sand_box = self.sandbox
            mcp_tools = await sand_box.mcpservers.list_tools()
            self.tools.extend(mcp_tools)
        else:
            self.tools.extend(await sandbox_mcp_tool_desc_transform(self.mcp_servers, self.mcp_config))
        # Load into the agent context
        self.agent_context.set_tools(self.tools)
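
    # After `desc_transform`/`async_desc_transform`, `self.tools` holds
    # function-call style descriptions. A minimal sketch of one entry (field
    # values are assumptions based on the OpenAI tools schema that
    # `tool_desc_transform` targets):
    #
    #     {"type": "function",
    #      "function": {"name": "browser__click",
    #                   "description": "Click an element.",
    #                   "parameters": {"type": "object", "properties": {...}}}}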

    def _messages_transform(
            self,
            observation: Observation,
    ):
        agent_prompt = self.agent_context.agent_prompt
        sys_prompt = self.agent_context.sys_prompt
        messages = []
        if sys_prompt:
            messages.append(
                {'role': 'system', 'content': sys_prompt if not self.use_tools_in_prompt else sys_prompt.format(
                    tool_list=self.tools)})

        content = observation.content
        if agent_prompt and '{task}' in agent_prompt:
            content = agent_prompt.format(task=observation.content)
        cur_msg = {'role': 'user', 'content': content}
        # Query from memory.
        # histories = self.memory.get_last_n(self.history_messages, filter={"session_id": self.context.session_id})
        histories = self.memory.get_last_n(self.history_messages)
        messages.extend(histories)

        action_results = observation.action_result
        if action_results:
            for action_result in action_results:
                cur_msg['role'] = 'tool'
                cur_msg['tool_call_id'] = action_result.tool_id

        agent_info = self.context.context_info.get(self.id())
        if (self.use_tools_in_prompt and agent_info and "is_use_tool_prompt" in agent_info and "tool_calls"
                in agent_info and agent_prompt):
            cur_msg['content'] = agent_prompt.format(action_list=agent_info["tool_calls"],
                                                     result=content)

        if observation.images:
            urls = [{'type': 'text', 'text': content}]
            for image_url in observation.images:
                urls.append({'type': 'image_url', 'image_url': {"url": image_url}})
            cur_msg['content'] = urls
        messages.append(cur_msg)

        # Truncate and apply other context processing.
        try:
            messages = self._process_messages(messages=messages, agent_context=self.agent_context, context=self.context)
        except Exception as e:
            logger.warning(f"Failed to process messages in _messages_transform: {e}")
            logger.debug(f"Process messages error details: {traceback.format_exc()}")
        self.agent_context.update_messages(messages)
        return messages

    def messages_transform(self,
                           content: str,
                           image_urls: List[str] = None,
                           **kwargs):
        """Transform the original content into LLM messages in the native format.

        Args:
            content: User content.
            image_urls: List of images encoded using base64.

        Returns:
            Message list for the LLM.
        """
        sys_prompt = self.agent_context.system_prompt
        agent_prompt = self.agent_context.agent_prompt
        messages = []
        if sys_prompt:
            messages.append(
                {'role': 'system', 'content': sys_prompt if not self.use_tools_in_prompt else sys_prompt.format(
                    tool_list=self.tools)})

        histories = self.memory.get_last_n(self.history_messages)
        user_content = content
        if not histories and agent_prompt and '{task}' in agent_prompt:
            user_content = agent_prompt.format(task=content)
        cur_msg = {'role': 'user', 'content': user_content}
        # Query from memory.
        # histories = self.memory.get_last_n(self.history_messages, filter={"session_id": self.context.session_id})
        if histories:
            # By default, use the first tool call only.
            for history in histories:
                if not self.use_tools_in_prompt and "tool_calls" in history.metadata and history.metadata['tool_calls']:
                    messages.append({'role': history.metadata['role'], 'content': history.content,
                                     'tool_calls': [history.metadata["tool_calls"][0]]})
                else:
                    messages.append({'role': history.metadata['role'], 'content': history.content,
                                     "tool_call_id": history.metadata.get("tool_call_id")})

            if not self.use_tools_in_prompt and "tool_calls" in histories[-1].metadata and histories[-1].metadata[
                    'tool_calls']:
                tool_id = histories[-1].metadata["tool_calls"][0].id
                if tool_id:
                    cur_msg['role'] = 'tool'
                    cur_msg['tool_call_id'] = tool_id

            if self.use_tools_in_prompt and "is_use_tool_prompt" in histories[-1].metadata and "tool_calls" in \
                    histories[-1].metadata and agent_prompt:
                cur_msg['content'] = agent_prompt.format(action_list=histories[-1].metadata["tool_calls"],
                                                         result=content)

        if image_urls:
            urls = [{'type': 'text', 'text': content}]
            for image_url in image_urls:
                urls.append({'type': 'image_url', 'image_url': {"url": image_url}})
            cur_msg['content'] = urls
        messages.append(cur_msg)

        # Truncate and apply other context processing.
        try:
            messages = self._process_messages(messages=messages, agent_context=self.agent_context, context=self.context)
        except Exception as e:
            logger.warning(f"Failed to process messages in messages_transform: {e}")
            logger.debug(f"Process messages error details: {traceback.format_exc()}")
        self.agent_context.set_messages(messages)
        return messages
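
    # Resulting message shape (illustrative sketch): a system message first when
    # a system prompt exists, replayed history, then the current turn, e.g.
    #
    #     [{'role': 'system', 'content': '...'},
    #      {'role': 'assistant', 'content': '', 'tool_calls': [...]},
    #      {'role': 'tool', 'tool_call_id': 'call_1', 'content': '4'}]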

    def use_tool_list(self, resp: ModelResponse) -> List[Dict[str, Any]]:
        """Parse a prompt-based tool-call list from the LLM response content (used when `use_tools_in_prompt` is on)."""
        tool_list = []
        try:
            if resp and hasattr(resp, 'content') and resp.content:
                content = resp.content.strip()
            else:
                return tool_list

            content = content.replace('\n', '').replace('\r', '')
            response_json = json.loads(content)
            if "use_tool_list" in response_json:
                use_tool_list = response_json["use_tool_list"]
                if use_tool_list:
                    for use_tool in use_tool_list:
                        tool_name = use_tool["tool"]
                        arguments = use_tool["arguments"]
                        if tool_name and arguments:
                            tool_list.append(use_tool)
            return tool_list
        except Exception:
            logger.debug(f"tool_parse error, content: {resp.content}, \nerror msg: {traceback.format_exc()}")
            return tool_list
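
    # Prompt-mode payload expected by `use_tool_list` (illustrative sketch; the
    # tool name and arguments below are assumptions):
    #
    #     {"use_tool_list": [{"tool": "browser__click",
    #                         "arguments": {"selector": "#submit"},
    #                         "id": "call_1"}]}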

    def response_parse(self, resp: ModelResponse) -> AgentResult:
        """Default parsing of the LLM response."""
        results = []
        if not resp:
            logger.warning("LLM no valid response!")
            return AgentResult(actions=[], current_state=None)

        use_tool_list = self.use_tool_list(resp)
        is_call_tool = False
        content = '' if resp.content is None else resp.content
        if resp.tool_calls:
            is_call_tool = True
            for tool_call in resp.tool_calls:
                full_name: str = tool_call.function.name
                if not full_name:
                    logger.warning("tool call response has no tool name.")
                    continue
                try:
                    params = json.loads(tool_call.function.arguments)
                except Exception:
                    logger.warning(f"{tool_call.function.arguments} parse to json failed.")
                    params = {}
                # Name format in the framework: "<tool_or_agent>__<action>".
                names = full_name.split("__")
                tool_name = names[0]
                if is_agent_by_name(tool_name):
                    param_info = params.get('content', "") + ' ' + params.get('info', '')
                    results.append(ActionModel(tool_name=tool_name,
                                               tool_id=tool_call.id,
                                               agent_name=self.id(),
                                               params=params,
                                               policy_info=content + param_info))
                else:
                    action_name = '__'.join(names[1:]) if len(names) > 1 else ''
                    results.append(ActionModel(tool_name=tool_name,
                                               tool_id=tool_call.id,
                                               action_name=action_name,
                                               agent_name=self.id(),
                                               params=params,
                                               policy_info=content))
        elif use_tool_list and len(use_tool_list) > 0:
            is_call_tool = True
            for use_tool in use_tool_list:
                full_name = use_tool["tool"]
                if not full_name:
                    logger.warning("tool call response has no tool name.")
                    continue
                params = use_tool["arguments"]
                if not params:
                    logger.warning("tool call response has no tool params.")
                    continue
                names = full_name.split("__")
                tool_name = names[0]
                if is_agent_by_name(tool_name):
                    param_info = params.get('content', "") + ' ' + params.get('info', '')
                    results.append(ActionModel(tool_name=tool_name,
                                               tool_id=use_tool.get('id'),
                                               agent_name=self.id(),
                                               params=params,
                                               policy_info=content + param_info))
                else:
                    action_name = '__'.join(names[1:]) if len(names) > 1 else ''
                    results.append(ActionModel(tool_name=tool_name,
                                               tool_id=use_tool.get('id'),
                                               action_name=action_name,
                                               agent_name=self.id(),
                                               params=params,
                                               policy_info=content))
        else:
            if content:
                content = content.replace("```json", "").replace("```", "")
            # No tool call; the agent name is itself.
            results.append(ActionModel(agent_name=self.id(), policy_info=content))

        return AgentResult(actions=results, current_state=None, is_call_tool=is_call_tool)
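
    # Naming convention assumed by `response_parse`: a function-call name is
    # "<tool_or_agent>__<action>", so "browser__click" routes to tool "browser"
    # with action "click", while a bare registered agent name (no "__action")
    # hands off to that agent. The names here are illustrative only.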

    def _log_messages(self, messages: List[Dict[str, Any]]) -> None:
        """Log the sequence of messages for debugging purposes."""
        logger.info(f"[agent] Invoking LLM with {len(messages)} messages:")
        for i, msg in enumerate(messages):
            prefix = msg.get('role')
            logger.info(f"[agent] Message {i + 1}: {prefix} ===================================")
            if isinstance(msg['content'], list):
                for item in msg['content']:
                    if item.get('type') == 'text':
                        logger.info(f"[agent] Text content: {item.get('text')}")
                    elif item.get('type') == 'image_url':
                        image_url = item.get('image_url', {}).get('url', '')
                        if image_url.startswith('data:image'):
                            logger.info("[agent] Image: [Base64 image data]")
                        else:
                            logger.info(f"[agent] Image URL: {image_url[:30]}...")
            else:
                content = str(msg['content'])
                chunk_size = 500
                for j in range(0, len(content), chunk_size):
                    chunk = content[j:j + chunk_size]
                    if j == 0:
                        logger.info(f"[agent] Content: {chunk}")
                    else:
                        logger.info(f"[agent] Content (continued): {chunk}")

            if 'tool_calls' in msg and msg['tool_calls']:
                for tool_call in msg.get('tool_calls'):
                    if isinstance(tool_call, dict):
                        logger.info(f"[agent] Tool call: {tool_call.get('name')} - ID: {tool_call.get('id')}")
                        args = str(tool_call.get('args', {}))[:1000]
                        logger.info(f"[agent] Tool args: {args}...")
                    elif isinstance(tool_call, ToolCall):
                        logger.info(f"[agent] Tool call: {tool_call.function.name} - ID: {tool_call.id}")
                        args = str(tool_call.function.arguments)[:1000]
                        logger.info(f"[agent] Tool args: {args}...")

    def _agent_result(self, actions: List[ActionModel], caller: str):
        if not actions:
            raise Exception(f'{self.id()} no action decision has been made.')

        tools = OrderedDict()
        agents = []
        for action in actions:
            if is_agent(action):
                agents.append(action)
            else:
                if action.tool_name not in tools:
                    tools[action.tool_name] = []
                tools[action.tool_name].append(action)

        _group_name = None
        # Agents and tools exist simultaneously, or there is more than one agent/tool name.
        if (agents and tools) or len(agents) > 1 or len(tools) > 1:
            _group_name = f"{self.id()}_{uuid.uuid1().hex}"

        # Complex case: more than one agent and tool causes ambiguity, so choose the first one.
        if _group_name:
            logger.warning(f"more than one agent and tool causes ambiguity, will choose the first one. {agents}")
            agents = [agents[0]] if agents else []
            for _, v in tools.items():
                actions = v
                break

        if agents:
            return AgentMessage(payload=actions,
                                caller=caller,
                                sender=self.id(),
                                receiver=actions[0].tool_name,
                                session_id=self.context.session_id if self.context else "",
                                headers={"context": self.context})
        else:
            return ToolMessage(payload=actions,
                               caller=caller,
                               sender=self.id(),
                               receiver=actions[0].tool_name,
                               session_id=self.context.session_id if self.context else "",
                               headers={"context": self.context})

    def post_run(self, policy_result: List[ActionModel], policy_input: Observation) -> Message:
        return self._agent_result(
            policy_result,
            policy_input.from_agent_name if policy_input.from_agent_name else policy_input.observer
        )

    async def async_post_run(self, policy_result: List[ActionModel], policy_input: Observation) -> Message:
        return self._agent_result(
            policy_result,
            policy_input.from_agent_name if policy_input.from_agent_name else policy_input.observer
        )

    def policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> List[ActionModel]:
        """The policy of an agent decides which tools to use in the environment, or delegates tasks to other agents.

        Args:
            observation: The state observed from tools in the environment.
            info: Extended information used to assist the agent in deciding a policy.

        Returns:
            ActionModel sequence from the agent policy.
        """
        output = None
        if kwargs.get("output") and isinstance(kwargs.get("output"), StepOutput):
            output = kwargs["output"]

        # Get current step information for trace recording.
        step = kwargs.get("step", 0)
        exp_id = kwargs.get("exp_id", None)
        source_span = trace.get_current_span()

        if hasattr(observation, 'context') and observation.context:
            self.task_histories = observation.context

        try:
            self._run_hooks_sync(self.context, HookPoint.PRE_LLM_CALL)
        except Exception:
            logger.warning(traceback.format_exc())

        self._finished = False
        self.desc_transform()
        images = observation.images if self.conf.use_vision else None
        if self.conf.use_vision and not images and observation.image:
            images = [observation.image]
        observation.images = images
        messages = self.messages_transform(content=observation.content,
                                           image_urls=observation.images)

        self._log_messages(messages)
        self.memory.add(MemoryItem(
            content=messages[-1]['content'],
            metadata={
                "role": messages[-1]['role'],
                "agent_name": self.id(),
                "tool_call_id": messages[-1].get("tool_call_id")
            }
        ))

        llm_response = None
        span_name = f"llm_call_{exp_id}"
        serializable_messages = self._to_serializable(messages)
        with trace.span(span_name) as llm_span:
            llm_span.set_attributes({
                "exp_id": exp_id,
                "step": step,
                "messages": json.dumps(serializable_messages, ensure_ascii=False)
            })
            if source_span:
                source_span.set_attribute("messages", json.dumps(serializable_messages, ensure_ascii=False))

            try:
                llm_response = call_llm_model(
                    self.llm,
                    messages=messages,
                    model=self.model_name,
                    temperature=self.conf.llm_config.llm_temperature,
                    tools=self.tools if not self.use_tools_in_prompt and self.tools else None
                )
                logger.info(f"Execute response: {llm_response.message}")
            except Exception as e:
                logger.warning(traceback.format_exc())
                raise e
            finally:
                if llm_response:
                    # Update token usage.
                    self.update_context_usage(used_context_length=llm_response.usage['total_tokens'])
                    # Update the current step output.
                    self.update_llm_output(llm_response)

                    use_tools = self.use_tool_list(llm_response)
                    is_use_tool_prompt = len(use_tools) > 0
                    if llm_response.error:
                        logger.info(f"llm result error: {llm_response.error}")
                    else:
                        info = {
                            "role": "assistant",
                            "agent_name": self.id(),
                            "tool_calls": llm_response.tool_calls if not self.use_tools_in_prompt else use_tools,
                            "is_use_tool_prompt": is_use_tool_prompt if not self.use_tools_in_prompt else False
                        }
                        self.memory.add(MemoryItem(
                            content=llm_response.content,
                            metadata=info
                        ))
                        # Rewrite the per-agent context info.
                        self.context.context_info[self.id()] = info
                else:
                    logger.error(f"{self.id()} failed to get LLM response")
                    raise RuntimeError(f"{self.id()} failed to get LLM response")

        try:
            self._run_hooks_sync(self.context, HookPoint.POST_LLM_CALL)
        except Exception:
            logger.warning(traceback.format_exc())

        agent_result = sync_exec(self.resp_parse_func, llm_response)
        if not agent_result.is_call_tool:
            self._finished = True

        if output:
            output.add_part(MessageOutput(source=llm_response, json_parse=False))
            output.mark_finished()
        return agent_result.actions
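
    # Synchronous usage sketch (illustrative; assumes the framework has
    # initialized Context/AgentContext and that a model endpoint is reachable):
    #
    #     actions = agent.policy(Observation(content="Summarize today's news."))
    #     print(actions[0].policy_info)  # the answer, or tool-call info when tools are requested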

    async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> List[ActionModel]:
        """The policy of an agent decides which tools to use in the environment, or delegates tasks to other agents.

        Args:
            observation: The state observed from tools in the environment.
            info: Extended information used to assist the agent in deciding a policy.

        Returns:
            ActionModel sequence from the agent policy.
        """
        outputs = None
        if kwargs.get("outputs") and isinstance(kwargs.get("outputs"), Outputs):
            outputs = kwargs.get("outputs")

        # Get current step information for trace recording.
        source_span = trace.get_current_span()

        if hasattr(observation, 'context') and observation.context:
            self.task_histories = observation.context

        try:
            events = []
            async for event in self.run_hooks(self.context, HookPoint.PRE_LLM_CALL):
                events.append(event)
        except Exception:
            logger.warning(traceback.format_exc())

        self._finished = False
        messages = await self._prepare_llm_input(observation, info, **kwargs)

        serializable_messages = self._to_serializable(messages)
        llm_response = None
        if source_span:
            source_span.set_attribute("messages", json.dumps(serializable_messages, ensure_ascii=False))

        try:
            llm_response = await self._call_llm_model(observation, messages, info, **kwargs)
        except Exception as e:
            logger.warning(traceback.format_exc())
            raise e
        finally:
            if llm_response:
                # Update token usage.
                self.update_context_usage(used_context_length=llm_response.usage['total_tokens'])
                # Update the current step output.
                self.update_llm_output(llm_response)

                use_tools = self.use_tool_list(llm_response)
                is_use_tool_prompt = len(use_tools) > 0
                if llm_response.error:
                    logger.info(f"llm result error: {llm_response.error}")
                else:
                    self.memory.add(MemoryItem(
                        content=llm_response.content,
                        metadata={
                            "role": "assistant",
                            "agent_name": self.id(),
                            "tool_calls": llm_response.tool_calls if not self.use_tools_in_prompt else use_tools,
                            "is_use_tool_prompt": is_use_tool_prompt if not self.use_tools_in_prompt else False
                        }
                    ))
            else:
                logger.error(f"{self.id()} failed to get LLM response")
                raise RuntimeError(f"{self.id()} failed to get LLM response")

        try:
            events = []
            async for event in self.run_hooks(self.context, HookPoint.POST_LLM_CALL):
                events.append(event)
        except Exception:
            logger.warning(traceback.format_exc())

        agent_result = sync_exec(self.resp_parse_func, llm_response)
        if not agent_result.is_call_tool:
            self._finished = True
        return agent_result.actions

    def _to_serializable(self, obj):
        if isinstance(obj, dict):
            return {k: self._to_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._to_serializable(i) for i in obj]
        elif hasattr(obj, "to_dict"):
            return obj.to_dict()
        elif hasattr(obj, "model_dump"):
            return obj.model_dump()
        elif hasattr(obj, "dict"):
            return obj.dict()
        else:
            return obj

    async def llm_and_tool_execution(self, observation: Observation, messages: List[Dict[str, str]] = [],
                                     info: Dict[str, Any] = {}, **kwargs) -> List[ActionModel]:
        """Perform a combined LLM call and tool execution.

        Args:
            observation: The state observed from the environment.
            messages: Pre-built LLM messages; when empty, they are prepared from the observation.
            info: Extended information to assist the agent in decision-making.
            **kwargs: Other parameters.

        Returns:
            ActionModel sequence. If a tool is executed, it includes the tool execution result.
        """
        # Get current step information for trace recording.
        llm_response = await self._call_llm_model(observation, messages, info, **kwargs)
        if llm_response:
            use_tools = self.use_tool_list(llm_response)
            is_use_tool_prompt = len(use_tools) > 0
            if llm_response.error:
                logger.info(f"llm result error: {llm_response.error}")
            else:
                self.memory.add(MemoryItem(
                    content=llm_response.content,
                    metadata={
                        "role": "assistant",
                        "agent_name": self.id(),
                        "tool_calls": llm_response.tool_calls if not self.use_tools_in_prompt else use_tools,
                        "is_use_tool_prompt": is_use_tool_prompt if not self.use_tools_in_prompt else False
                    }
                ))
        else:
            logger.error(f"{self.id()} failed to get LLM response")
            raise RuntimeError(f"{self.id()} failed to get LLM response")

        agent_result = sync_exec(self.resp_parse_func, llm_response)
        if not agent_result.is_call_tool:
            self._finished = True
            return agent_result.actions
        else:
            result = await self._execute_tool(agent_result.actions)
            return result

    async def _prepare_llm_input(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs):
        """Prepare the LLM input.

        Args:
            observation: The state observed from the environment.
            info: Extended information to assist the agent in decision-making.
            **kwargs: Other parameters.
        """
        await self.async_desc_transform()
        images = observation.images if self.conf.use_vision else None
        if self.conf.use_vision and not images and observation.image:
            images = [observation.image]
        messages = self.messages_transform(content=observation.content,
                                           image_urls=images)

        self._log_messages(messages)
        self.memory.add(MemoryItem(
            content=messages[-1]['content'],
            metadata={
                "role": messages[-1]['role'],
                "agent_name": self.id(),
                "tool_call_id": messages[-1].get("tool_call_id")
            }
        ))
        return messages

    def _process_messages(self, messages: List[Dict[str, Any]], agent_context: AgentContext = None,
                          context: Context = None) -> List[Dict[str, Any]]:
        origin_messages = messages
        st = time.time()
        with trace.span("llm_context_process", attributes={
            "start_time": st
        }) as compress_span:
            if agent_context.context_rule is None:
                logger.debug('debug|skip process_messages context_rule is None')
                return messages
            origin_len = compressed_len = len(str(messages))
            origin_messages_count = truncated_messages_count = len(messages)
            try:
                prompt_processor = PromptProcessor(agent_context)
                result = prompt_processor.process_messages(messages, context)
                messages = result.processed_messages

                compressed_len = len(str(messages))
                truncated_messages_count = len(messages)
                logger.debug(
                    f'debug|llm_context_process|{origin_len}|{compressed_len}|{origin_messages_count}|{truncated_messages_count}|\n|{origin_messages}\n|{messages}')
                return messages
            finally:
                compress_span.set_attributes({
                    "end_time": time.time(),
                    "duration": time.time() - st,
                    # Message counts
                    "origin_messages_count": origin_messages_count,
                    "truncated_messages_count": truncated_messages_count,
                    "truncated_ratio": round(truncated_messages_count / origin_messages_count, 2),
                    # Character lengths (proxy for token length)
                    "origin_len": origin_len,
                    "compressed_len": compressed_len,
                    "compress_ratio": round(compressed_len / origin_len, 2)
                })
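
    # Context-rule sketch (illustrative): truncation/compression in
    # `_process_messages` is driven by the agent's ContextRuleConfig, e.g.
    #
    #     context_rule = ContextRuleConfig(
    #         optimization_config=OptimizationConfig(enabled=True, max_token_budget_ratio=0.8),
    #         llm_compression_config=LlmCompressionConfig(enabled=True),
    #     )
    #     agent = Agent(conf=conf, context_rule=context_rule)
    #
    # Exact field semantics live in aworld.config.conf; the values above are assumptions.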

    async def _call_llm_model(self, observation: Observation, messages: List[Dict[str, str]] = [],
                              info: Dict[str, Any] = {}, **kwargs) -> ModelResponse:
        """Perform the LLM call.

        Args:
            observation: The state observed from the environment.
            messages: Pre-built LLM messages; when empty, they are prepared from the observation.
            info: Extended information to assist the agent in decision-making.
            **kwargs: Other parameters.

        Returns:
            LLM response.
        """
        outputs = None
        if kwargs.get("outputs") and isinstance(kwargs.get("outputs"), Outputs):
            outputs = kwargs.get("outputs")
        if not messages:
            messages = await self._prepare_llm_input(observation, self.agent_context, **kwargs)

        llm_response = None
        source_span = trace.get_current_span()
        serializable_messages = self._to_serializable(messages)
        if source_span:
            source_span.set_attribute("messages", json.dumps(serializable_messages, ensure_ascii=False))
        try:
            stream_mode = kwargs.get("stream", False)
            if stream_mode:
                llm_response = ModelResponse(id="", model="", content="", tool_calls=[])
                resp_stream = acall_llm_model_stream(
                    self.llm,
                    messages=messages,
                    model=self.model_name,
                    temperature=self.conf.llm_config.llm_temperature,
                    tools=self.tools if not self.use_tools_in_prompt and self.tools else None,
                    stream=True
                )

                async def async_call_llm(resp_stream, json_parse=False):
                    llm_resp = ModelResponse(id="", model="", content="", tool_calls=[])

                    # Async streaming: accumulate chunks while yielding content.
                    async def async_generator():
                        async for chunk in resp_stream:
                            if chunk.content:
                                llm_resp.content += chunk.content
                                yield chunk.content
                            if chunk.tool_calls:
                                llm_resp.tool_calls.extend(chunk.tool_calls)
                            if chunk.error:
                                llm_resp.error = chunk.error
                            llm_resp.id = chunk.id
                            llm_resp.model = chunk.model
                            llm_resp.usage = nest_dict_counter(llm_resp.usage, chunk.usage)

                    return MessageOutput(source=async_generator(), json_parse=json_parse), llm_resp

                output, response = await async_call_llm(resp_stream)
                llm_response = response

                if eventbus is not None and resp_stream:
                    output_message = Message(
                        category=Constants.OUTPUT,
                        payload=output,
                        sender=self.id(),
                        session_id=self.context.session_id if self.context else "",
                        headers={"context": self.context}
                    )
                    await eventbus.publish(output_message)
                elif not self.event_driven and outputs:
                    outputs.add_output(output)
            else:
                llm_response = await acall_llm_model(
                    self.llm,
                    messages=messages,
                    model=self.model_name,
                    temperature=self.conf.llm_config.llm_temperature,
                    tools=self.tools if not self.use_tools_in_prompt and self.tools else None,
                    stream=kwargs.get("stream", False)
                )
                if eventbus is None:
                    logger.warning("=============== eventbus is none ============")
                if eventbus is not None and llm_response:
                    await eventbus.publish(Message(
                        category=Constants.OUTPUT,
                        payload=llm_response,
                        sender=self.id(),
                        session_id=self.context.session_id if self.context else "",
                        headers={"context": self.context}
                    ))
                elif not self.event_driven and outputs:
                    outputs.add_output(MessageOutput(source=llm_response, json_parse=False))

            logger.info(f"Execute response: {json.dumps(llm_response.to_dict(), ensure_ascii=False)}")
        except Exception as e:
            logger.warning(traceback.format_exc())
            raise e
        return llm_response
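
    # Streaming sketch (illustrative): with stream=True, the MessageOutput
    # published on the event bus wraps an async generator of content chunks; a
    # consumer might iterate it roughly like this (attribute names are
    # assumptions about MessageOutput):
    #
    #     async for chunk in output.source:
    #         print(chunk, end="", flush=True)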

    async def _execute_tool(self, actions: List[ActionModel]) -> Any:
        """Execute tool calls.

        Args:
            actions: The action(s) to execute.

        Returns:
            The result of tool execution.
        """
        tool_actions = []
        for act in actions:
            if is_agent(act):
                continue
            tool_actions.append(act)

        msg = None
        terminated = False
        # Group actions by tool name.
        tool_mapping = dict()
        reward = 0.0
        # Directly use existing tool instances, or use tools after creating them.
        for act in tool_actions:
            if not self.tools_instances or act.tool_name not in self.tools_instances:
                # Dynamically created tools only use the module's default config.
                conf = self.tools_conf.get(act.tool_name)
                if not conf:
                    conf = ToolConfig(exit_on_failure=self.task.conf.get('exit_on_failure'))
                tool = ToolFactory(act.tool_name, conf=conf, asyn=conf.use_async if conf else False)
                if isinstance(tool, Tool):
                    tool.reset()
                elif isinstance(tool, AsyncTool):
                    await tool.reset()
                tool_mapping[act.tool_name] = []
                self.tools_instances[act.tool_name] = tool
            if act.tool_name not in tool_mapping:
                tool_mapping[act.tool_name] = []
            tool_mapping[act.tool_name].append(act)

        observation = None
        for tool_name, action in tool_mapping.items():
            # Execute the batched actions with the tool and unpack all return values.
            if isinstance(self.tools_instances[tool_name], Tool):
                message = self.tools_instances[tool_name].step(action)
            elif isinstance(self.tools_instances[tool_name], AsyncTool):
                # todo sandbox
                message = await self.tools_instances[tool_name].step(action, agent=self)
            else:
                logger.warning(f"Unsupported tool type: {self.tools_instances[tool_name]}")
                continue

            observation, reward, terminated, _, info = message.payload
            # Check if there's an exception in info.
            if info.get("exception"):
                color_log(f"Agent {self.id()} _execute_tool failed with exception: {info['exception']}",
                          color=Color.red)
                msg = f"Agent {self.id()} _execute_tool failed with exception: {info['exception']}"
            logger.info(f"Agent {self.id()} _execute_tool finished by tool action: {action}.")
            log_ob = Observation(content='' if observation.content is None else observation.content,
                                 action_result=observation.action_result)
            trace_logger.info(f"{tool_name} observation: {log_ob}", color=Color.green)

        self.memory.add(MemoryItem(
            content=observation.content,
            metadata={
                "role": "tool",
                "agent_name": self.id(),
                "tool_call_id": action[0].tool_id
            }
        ))
        return [ActionModel(agent_name=self.id(), policy_info=observation.content)]
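
    # Grouping sketch (illustrative): actions are batched per tool before
    # stepping, e.g. two "browser" actions and one "search" action become
    # {"browser": [act1, act2], "search": [act3]}, and each tool's `step`
    # receives its own batch. The tool names here are assumptions.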

    def _init_context(self, context: Context):
        super()._init_context(context)

        # Generate a default configuration when context_rule is empty.
        llm_config = self.conf.llm_config
        context_rule = self.context_rule
        if context_rule is None:
            context_rule = ContextRuleConfig(
                optimization_config=OptimizationConfig(
                    enabled=True,
                    max_token_budget_ratio=1.0
                ),
                llm_compression_config=LlmCompressionConfig(
                    enabled=False  # Compression disabled by default
                )
            )
        self.agent_context.set_model_config(llm_config)
        self.agent_context.context_rule = context_rule
        self.agent_context.system_prompt = self.system_prompt
        self.agent_context.agent_prompt = self.agent_prompt
        logger.debug(f'init_context llm_agent {self.name()} {self.agent_context} {self.conf} {self.context_rule}')

    def update_system_prompt(self, system_prompt: str):
        self.system_prompt = system_prompt
        self.agent_context.system_prompt = system_prompt
        logger.info(f"Agent {self.name()} system_prompt updated")

    def update_agent_prompt(self, agent_prompt: str):
        self.agent_prompt = agent_prompt
        self.agent_context.agent_prompt = agent_prompt
        logger.info(f"Agent {self.name()} agent_prompt updated")

    def update_context_rule(self, context_rule: ContextRuleConfig):
        self.agent_context.context_rule = context_rule
        logger.info(f"Agent {self.name()} context_rule updated")

    def update_context_usage(self, used_context_length: int = None, total_context_length: int = None):
        self.agent_context.update_context_usage(used_context_length, total_context_length)
        logger.debug(f"Agent {self.name()} context usage updated: {self.agent_context.context_usage}")

    def update_llm_output(self, llm_response: ModelResponse):
        self.agent_context.set_llm_output(llm_response)
        logger.debug(f"Agent {self.name()} llm output updated: {self.agent_context.llm_output}")

    async def run_hooks(self, context: Context, hook_point: str):
        """Execute hooks asynchronously."""
        from aworld.runners.hook.hook_factory import HookFactory
        from aworld.core.event.base import Message

        # Get all hooks for the specified hook point.
        all_hooks = HookFactory.hooks(hook_point)
        hooks = all_hooks.get(hook_point, [])

        for hook in hooks:
            try:
                # Create a temporary Message object to pass to the hook.
                message = Message(
                    category="agent_hook",
                    payload=None,
                    sender=self.id(),
                    session_id=context.session_id if hasattr(context, 'session_id') else None,
                    headers={"context": self.context}
                )

                # Execute the hook.
                msg = await hook.exec(message, context)
                if msg:
                    logger.debug(f"Hook {hook.point()} executed successfully")
                    yield msg
            except Exception:
                logger.warning(f"Hook {hook.point()} execution failed: {traceback.format_exc()}")

    def _run_hooks_sync(self, context: Context, hook_point: str):
        """Execute hooks synchronously."""
        # Use sync_exec to run the asynchronous hook logic.
        try:
            sync_exec(self.run_hooks, context, hook_point)
        except Exception:
            logger.warning(f"Failed to execute hooks for {hook_point}: {traceback.format_exc()}")