| | import inspect |
| | from collections import OrderedDict |
| | from typing import Callable, Dict, List, Union |
| |
|
| | from lagent.actions.base_action import BaseAction |
| | from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction |
| | from lagent.hooks import Hook, RemovableHandle |
| | from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall |
| | from lagent.utils import create_object |
| |
|
| |
|
class ActionExecutor:
    """The action executor class.

    Keeps a name -> action mapping and dispatches function-call messages to
    the matching action, running registered hooks before and after each call.

    Args:
        actions (Union[BaseAction, List[BaseAction], Dict, List[Dict]]): The
            action or actions. Each item is passed through ``create_object``.
            A list argument is copied, so the caller's list is never modified.
        invalid_action (BaseAction, optional): The action used when the
            requested name matches nothing else. Defaults to InvalidAction().
        no_action (BaseAction, optional): The no action.
            Defaults to NoAction().
        finish_action (BaseAction, optional): The finish action. Defaults to
            FinishAction().
        finish_in_action (bool, optional): Whether the finish action is in the
            action list. Defaults to False.
        hooks (List[Dict], optional): Hook configs, each passed through
            ``create_object`` and registered in order. Defaults to None.
    """

    def __init__(
        self,
        actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]],
        invalid_action: BaseAction = dict(type=InvalidAction),
        no_action: BaseAction = dict(type=NoAction),
        finish_action: BaseAction = dict(type=FinishAction),
        finish_in_action: bool = False,
        hooks: List[Dict] = None,
    ):
        # Work on a private copy: appending the finish action and replacing
        # configs with instances must not leak back into the caller's list.
        actions = list(actions) if isinstance(actions, list) else [actions]
        finish_action = create_object(finish_action)
        if finish_in_action:
            actions.append(finish_action)
        instances = [create_object(action) for action in actions]
        # Later actions with the same name overwrite earlier ones.
        self.actions = {action.name: action for action in instances}

        self.invalid_action = create_object(invalid_action)
        self.no_action = create_object(no_action)
        self.finish_action = finish_action
        self._hooks: Dict[int, Hook] = OrderedDict()
        if hooks:
            for hook in hooks:
                self.register_hook(create_object(hook))

    def description(self) -> List[Dict]:
        """Return one description dict per callable API.

        Toolkit actions contribute one entry per item of
        ``description['api_list']``, renamed to ``'<action>.<api>'``; other
        actions contribute a copy of their ``description``. Entries are
        shallow copies, so renaming never touches the action's own data.
        """
        actions = []
        for action_name, action in self.actions.items():
            if action.is_toolkit:
                for api in action.description['api_list']:
                    api_desc = api.copy()
                    api_desc['name'] = f"{action_name}.{api_desc['name']}"
                    actions.append(api_desc)
            else:
                actions.append(action.description.copy())
        return actions

    def __contains__(self, name: str):
        """Return whether an action is registered under ``name``."""
        return name in self.actions

    def keys(self):
        """Return the registered action names as a list."""
        return list(self.actions.keys())

    def __setitem__(self, name: str, action: Union[BaseAction, Dict]):
        """Register ``action`` (instance or config) via ``create_object``.

        NOTE(review): the mapping key is ``action.name``; the ``name``
        argument is ignored.
        """
        action = create_object(action)
        self.actions[action.name] = action

    def __delitem__(self, name: str):
        """Unregister the action stored under ``name`` (KeyError if absent)."""
        del self.actions[name]

    def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Dispatch ``parameters`` to the action addressed by ``name``.

        ``name`` is either ``'<action>'`` (calls the ``run`` API) or
        ``'<action>.<api>'``. Only the first dot separates action from API,
        so a name with extra dots is passed on as an API name instead of
        raising ``ValueError`` during unpacking. Names matching no registered
        action go to the no/finish action when they equal that action's name
        exactly, and otherwise to the invalid action.

        Returns:
            ActionReturn: The action's result. Results from registered
            actions are marked ``ActionValidCode.OPEN``.
        """
        action_name, api_name = (
            name.split('.', 1) if '.' in name else (name, 'run'))
        if action_name in self:
            action_return = self.actions[action_name](parameters, api_name)
            action_return.valid = ActionValidCode.OPEN
        elif name == self.no_action.name:
            action_return = self.no_action(parameters)
        elif name == self.finish_action.name:
            action_return = self.finish_action(parameters)
        else:
            action_return = self.invalid_action(parameters)
        return action_return

    def __call__(self,
                 message: AgentMessage,
                 session_id=0,
                 **kwargs) -> AgentMessage:
        """Execute the function call carried by ``message``.

        Args:
            message (AgentMessage): Message whose ``content`` is a
                ``FunctionCall`` or a dict with ``'name'`` and
                ``'parameters'`` keys.
            session_id (int): Passed through to every hook. Defaults to 0.
            **kwargs: Forwarded to :meth:`forward`.

        Returns:
            AgentMessage: The result wrapped in an ``AgentMessage`` sent by
            this class (unless ``forward`` already returned one), possibly
            replaced by an ``after_action`` hook.
        """
        # A hook may substitute the message by returning a truthy value.
        for hook in self._hooks.values():
            result = hook.before_action(self, message, session_id)
            if result:
                message = result

        content = message.content
        assert isinstance(content, FunctionCall) or (
            isinstance(content, dict) and 'name' in content
            and 'parameters' in content)
        if isinstance(content, dict):
            name = content.get('name')
            parameters = content.get('parameters')
        else:
            name = content.name
            parameters = content.parameters

        response_message = self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )

        for hook in self._hooks.values():
            result = hook.after_action(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def register_hook(self, hook: Callable):
        """Store ``hook`` under a fresh ``RemovableHandle`` id.

        Returns:
            RemovableHandle: The handle bound to this executor's hook
            registry (presumably used to unregister the hook).
        """
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle
| |
|
| |
|
class AsyncActionExecutor(ActionExecutor):
    """Asynchronous variant of :class:`ActionExecutor`.

    Actions whose ``__call__`` is a coroutine function are awaited, and so
    are hooks whose ``before_action``/``after_action`` are coroutine
    functions. Synchronous ones are called directly.
    """

    @staticmethod
    async def _invoke(action, *args):
        """Call ``action(*args)``, awaiting it when ``__call__`` is async."""
        if inspect.iscoroutinefunction(action.__call__):
            return await action(*args)
        return action(*args)

    async def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Dispatch ``parameters`` to the action addressed by ``name``.

        ``name`` is either ``'<action>'`` (calls the ``run`` API) or
        ``'<action>.<api>'``. Only the first dot separates action from API,
        so a name with extra dots is passed on as an API name instead of
        raising ``ValueError`` during unpacking. Unregistered names go to the
        no/finish action on an exact name match, otherwise to the invalid
        action. All four kinds of action may be sync or async.

        Returns:
            ActionReturn: The action's result. Results from registered
            actions are marked ``ActionValidCode.OPEN``.
        """
        action_name, api_name = (
            name.split('.', 1) if '.' in name else (name, 'run'))
        if action_name in self:
            action_return = await self._invoke(
                self.actions[action_name], parameters, api_name)
            action_return.valid = ActionValidCode.OPEN
        elif name == self.no_action.name:
            action_return = await self._invoke(self.no_action, parameters)
        elif name == self.finish_action.name:
            action_return = await self._invoke(self.finish_action, parameters)
        else:
            action_return = await self._invoke(self.invalid_action,
                                               parameters)
        return action_return

    async def __call__(self,
                       message: AgentMessage,
                       session_id=0,
                       **kwargs) -> AgentMessage:
        """Execute the function call carried by ``message``.

        Args:
            message (AgentMessage): Message whose ``content`` is a
                ``FunctionCall`` or a dict with ``'name'`` and
                ``'parameters'`` keys.
            session_id (int): Passed through to every hook. Defaults to 0.
            **kwargs: Forwarded to :meth:`forward`.

        Returns:
            AgentMessage: The result wrapped in an ``AgentMessage`` sent by
            this class (unless ``forward`` already returned one), possibly
            replaced by an ``after_action`` hook.
        """
        # A hook may substitute the message by returning a truthy value.
        for hook in self._hooks.values():
            if inspect.iscoroutinefunction(hook.before_action):
                result = await hook.before_action(self, message, session_id)
            else:
                result = hook.before_action(self, message, session_id)
            if result:
                message = result

        content = message.content
        assert isinstance(content, FunctionCall) or (
            isinstance(content, dict) and 'name' in content
            and 'parameters' in content)
        if isinstance(content, dict):
            name = content.get('name')
            parameters = content.get('parameters')
        else:
            name = content.name
            parameters = content.parameters

        response_message = await self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )

        for hook in self._hooks.values():
            if inspect.iscoroutinefunction(hook.after_action):
                result = await hook.after_action(self, response_message,
                                                 session_id)
            else:
                result = hook.after_action(self, response_message, session_id)
            if result:
                response_message = result
        return response_message
| |
|