Spaces:
Sleeping
Sleeping
| # coding: utf-8 | |
| # Copyright (c) 2025 inclusionAI. | |
| import copy | |
| import json | |
| import traceback | |
| from typing import Dict, Any, List, Union | |
| from examples.tools.common import Agents | |
| from aworld.core.agent.base import AgentResult | |
| from aworld.agents.llm_agent import Agent | |
| from aworld.models.llm import call_llm_model | |
| from aworld.config.conf import AgentConfig, ConfigDict | |
| from aworld.core.common import Observation, ActionModel | |
| from aworld.logs.util import logger | |
| from examples.plan_execute.prompts import * | |
| from examples.plan_execute.utils import extract_pattern | |
| class ExecuteAgent(Agent): | |
| def __init__(self, conf: Union[Dict[str, Any], ConfigDict, AgentConfig], **kwargs): | |
| super(ExecuteAgent, self).__init__(conf, **kwargs) | |
| def id(self) -> str: | |
| return Agents.EXECUTE.value | |
| def reset(self, options: Dict[str, Any]): | |
| """Execute agent reset need query task as input.""" | |
| super().reset(options) | |
| self.system_prompt = execute_system_prompt.format(task=self.task) | |
| self.step_reset = False | |
| async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> Union[ | |
| List[ActionModel], None]: | |
| await self.async_desc_transform() | |
| return self._common(observation, info) | |
| def policy(self, | |
| observation: Observation, | |
| info: Dict[str, Any] = None, | |
| **kwargs) -> List[ActionModel] | None: | |
| self.desc_transform() | |
| return self._common(observation, info) | |
| def _common(self, observation, info): | |
| self._finished = False | |
| content = observation.content | |
| llm_result = None | |
| ## build input of llm | |
| input_content = [ | |
| {'role': 'system', 'content': self.system_prompt}, | |
| ] | |
| for traj in self.trajectory: | |
| # Handle multiple messages in content | |
| if isinstance(traj[0].content, list): | |
| input_content.extend(traj[0].content) | |
| else: | |
| input_content.append(traj[0].content) | |
| if traj[-1].tool_calls is not None: | |
| input_content.append( | |
| {'role': 'assistant', 'content': '', 'tool_calls': traj[-1].tool_calls}) | |
| else: | |
| input_content.append({'role': 'assistant', 'content': traj[-1].content}) | |
| if content is None: | |
| content = observation.action_result[0].error | |
| if not self.trajectory: | |
| new_messages = [{"role": "user", "content": content}] | |
| input_content.extend(new_messages) | |
| else: | |
| # Collect existing tool_call_ids from input_content | |
| existing_tool_call_ids = { | |
| msg.get("tool_call_id") for msg in input_content | |
| if msg.get("role") == "tool" and msg.get("tool_call_id") | |
| } | |
| new_messages = [] | |
| for traj in self.trajectory: | |
| if traj[-1].tool_calls is not None: | |
| # Handle multiple tool calls | |
| for tool_call in traj[-1].tool_calls: | |
| # Only add if this tool_call_id doesn't exist in input_content | |
| if tool_call.id not in existing_tool_call_ids: | |
| new_messages.append({ | |
| "role": "tool", | |
| "content": content, | |
| "tool_call_id": tool_call.id | |
| }) | |
| if new_messages: | |
| input_content.extend(new_messages) | |
| else: | |
| input_content.append({"role": "user", "content": content}) | |
| # Validate tool_calls and tool messages pairing | |
| assistant_tool_calls = [] | |
| tool_responses = [] | |
| for msg in input_content: | |
| if msg.get("role") == "assistant" and msg.get("tool_calls"): | |
| assistant_tool_calls.extend(msg["tool_calls"]) | |
| elif msg.get("role") == "tool": | |
| tool_responses.append(msg.get("tool_call_id")) | |
| # Check if all tool_calls have corresponding responses | |
| tool_call_ids = {call.id for call in assistant_tool_calls} | |
| tool_response_ids = set(tool_responses) | |
| if tool_call_ids != tool_response_ids: | |
| missing_calls = tool_call_ids - tool_response_ids | |
| extra_responses = tool_response_ids - tool_call_ids | |
| error_msg = f"Tool calls and responses mismatch. Missing responses for tool_calls: {missing_calls}, Extra responses: {extra_responses}" | |
| logger.error(error_msg) | |
| raise ValueError(error_msg) | |
| tool_calls = [] | |
| try: | |
| llm_result = call_llm_model(self.llm, input_content, model=self.model_name, | |
| tools=self.tools, temperature=0) | |
| logger.info(f"Execute response: {llm_result.message}") | |
| res = self.response_parse(llm_result) | |
| content = res.actions[0].policy_info | |
| tool_calls = llm_result.tool_calls | |
| except Exception as e: | |
| logger.warning(traceback.format_exc()) | |
| finally: | |
| if llm_result: | |
| ob = copy.deepcopy(observation) | |
| ob.content = new_messages | |
| self.trajectory.append((ob, info, llm_result)) | |
| else: | |
| logger.warning("no result to record!") | |
| res = [] | |
| if tool_calls: | |
| for tool_call in tool_calls: | |
| tool_action_name: str = tool_call.function.name | |
| if not tool_action_name: | |
| continue | |
| names = tool_action_name.split("__") | |
| tool_name = names[0] | |
| action_name = '__'.join(names[1:]) if len(names) > 1 else '' | |
| params = json.loads(tool_call.function.arguments) | |
| res.append(ActionModel(agent_name=Agents.EXECUTE.value, | |
| tool_name=tool_name, | |
| action_name=action_name, | |
| params=params)) | |
| if res: | |
| res[0].policy_info = content | |
| self._finished = False | |
| elif content: | |
| policy_info = extract_pattern(content, "final_answer") | |
| if policy_info: | |
| res.append(ActionModel(agent_name=Agents.EXECUTE.value, | |
| policy_info=policy_info)) | |
| self._finished = True | |
| else: | |
| res.append(ActionModel(agent_name=Agents.EXECUTE.value, | |
| policy_info=content)) | |
| logger.info(f">>> execute result: {res}") | |
| result = AgentResult(actions=res, | |
| current_state=None) | |
| return result.actions | |
| class PlanAgent(Agent): | |
| def __init__(self, conf: Union[Dict[str, Any], ConfigDict, AgentConfig], **kwargs): | |
| super(PlanAgent, self).__init__(conf, **kwargs) | |
| def id(self) -> str: | |
| return Agents.PLAN.value | |
| def reset(self, options: Dict[str, Any]): | |
| """Execute agent reset need query task as input.""" | |
| super().reset(options) | |
| self.system_prompt = plan_system_prompt.format(task=self.task) | |
| self.done_prompt = plan_done_prompt.format(task=self.task) | |
| self.postfix_prompt = plan_postfix_prompt.format(task=self.task) | |
| self.first_prompt = init_prompt | |
| self.first = True | |
| self.step_reset = False | |
| async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}, **kwargs) -> Union[ | |
| List[ActionModel], None]: | |
| await self.async_desc_transform() | |
| return self._common(observation, info) | |
| def policy(self, | |
| observation: Observation, | |
| info: Dict[str, Any] = None, | |
| **kwargs) -> List[ActionModel] | None: | |
| self._finished = False | |
| self.desc_transform() | |
| return self._common(observation, info) | |
| def _common(self, observation, info): | |
| llm_result = None | |
| input_content = [ | |
| {'role': 'system', 'content': self.system_prompt}, | |
| ] | |
| # build input of llm based history | |
| for traj in self.trajectory: | |
| input_content.append({'role': 'user', 'content': traj[0].content}) | |
| # plan agent no tool to call, use content | |
| input_content.append({'role': 'assistant', 'content': traj[-1].content}) | |
| message = observation.content | |
| if self.first_prompt: | |
| message = self.first_prompt | |
| self.first_prompt = None | |
| input_content.append({"role": "user", "content": message}) | |
| try: | |
| llm_result = call_llm_model(self.llm, messages=input_content, model=self.model_name) | |
| logger.info(f"Plan response: {llm_result.message}") | |
| except Exception as e: | |
| logger.warning(traceback.format_exc()) | |
| raise e | |
| finally: | |
| if llm_result: | |
| ob = copy.deepcopy(observation) | |
| ob.content = message | |
| self.trajectory.append((ob, info, llm_result)) | |
| else: | |
| logger.warning("no result to record!") | |
| res = self.response_parse(llm_result) | |
| content = res.actions[0].policy_info | |
| if "TASK_DONE" not in content: | |
| content += self.done_prompt | |
| else: | |
| # The task is done, and the assistant agent need to give the final answer about the original task | |
| content += self.postfix_prompt | |
| if not self.first: | |
| self._finished = True | |
| self.first = False | |
| logger.info(f">>> plan result: {content}") | |
| result = AgentResult(actions=[ActionModel(agent_name=Agents.PLAN.value, | |
| tool_name=Agents.EXECUTE.value, | |
| policy_info=content)], | |
| current_state=None) | |
| return result.actions | |