| import inspect |
| from collections import OrderedDict |
| from typing import Callable, Dict, List, Union |
|
|
| from lagent.actions.base_action import BaseAction |
| from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction |
| from lagent.hooks import Hook, RemovableHandle |
| from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall |
| from lagent.utils import create_object |
|
|
|
|
class ActionExecutor:
    """Dispatch function calls carried by agent messages to actions.

    Args:
        actions (Union[BaseAction, List[BaseAction], Dict, List[Dict]]): The
            action or actions to register. Each entry is passed through
            ``create_object``. A list or tuple is copied first, so the
            caller's sequence is never modified.
        invalid_action (Union[BaseAction, Dict], optional): Action used when
            the requested name is unregistered and matches neither
            ``no_action.name`` nor ``finish_action.name``. Defaults to
            ``dict(type=InvalidAction)``.
        no_action (Union[BaseAction, Dict], optional): Action used when the
            requested name equals ``no_action.name``. Defaults to
            ``dict(type=NoAction)``.
        finish_action (Union[BaseAction, Dict], optional): Action used when
            the requested name equals ``finish_action.name``. Defaults to
            ``dict(type=FinishAction)``.
        finish_in_action (bool, optional): Whether to also register the
            finish action among the regular actions. Defaults to False.
        hooks (List[Dict], optional): Hook configs (or instances). Each is
            passed through ``create_object`` and registered in order.
    """

    def __init__(
        self,
        actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]],
        invalid_action: Union[BaseAction, Dict] = dict(type=InvalidAction),
        no_action: Union[BaseAction, Dict] = dict(type=NoAction),
        finish_action: Union[BaseAction, Dict] = dict(type=FinishAction),
        finish_in_action: bool = False,
        hooks: List[Dict] = None,
    ):
        # Build a fresh list rather than appending to / overwriting the
        # caller's list in place; otherwise reusing one config list leaks
        # instantiated actions (and repeated finish actions) across executors.
        if isinstance(actions, (list, tuple)):
            actions = list(actions)
        else:
            actions = [actions]
        finish_action = create_object(finish_action)
        if finish_in_action:
            actions.append(finish_action)
        actions = [create_object(action) for action in actions]
        # Keyed by each action's own ``name``; a later duplicate name wins.
        self.actions = {action.name: action for action in actions}

        self.invalid_action = create_object(invalid_action)
        self.no_action = create_object(no_action)
        self.finish_action = finish_action
        # Hooks keyed by RemovableHandle id, iterated in registration order.
        self._hooks: Dict[int, Hook] = OrderedDict()
        for hook in hooks or ():
            self.register_hook(create_object(hook))

    def description(self) -> List[Dict]:
        """Return one description dict per callable API.

        A toolkit action contributes one shallow copy per entry of its
        ``description['api_list']``, renamed to ``"<action>.<api>"``. Any
        other action contributes a shallow copy of its ``description``.
        """
        actions = []
        for action_name, action in self.actions.items():
            if action.is_toolkit:
                for api in action.description['api_list']:
                    api_desc = api.copy()
                    api_desc['name'] = f"{action_name}.{api_desc['name']}"
                    actions.append(api_desc)
            else:
                action_desc = action.description.copy()
                actions.append(action_desc)
        return actions

    def __contains__(self, name: str):
        """Return whether an action is registered under ``name``."""
        return name in self.actions

    def keys(self):
        """Return the registered action names as a list."""
        return list(self.actions.keys())

    def __setitem__(self, name: str, action: Union[BaseAction, Dict]):
        """Register ``action`` (instantiated via ``create_object``).

        Note: ``name`` is ignored; the action is stored under its own
        ``action.name``.
        """
        action = create_object(action)
        self.actions[action.name] = action

    def __delitem__(self, name: str):
        """Unregister the action named ``name`` (raises KeyError if absent)."""
        del self.actions[name]

    def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Run the action addressed by ``name`` with ``parameters``.

        Args:
            name (str): ``"<action>"`` to call the action's ``run`` API, or
                ``"<action>.<api>"`` to call a specific API.
            parameters: Arguments forwarded unchanged to the action.

        Returns:
            ActionReturn: The action's result. Results from registered
            actions are marked ``ActionValidCode.OPEN``; results from the
            no/finish/invalid fallbacks keep whatever ``valid`` value those
            actions set.
        """
        # Split on the first dot only: a name with several dots used to raise
        # "too many values to unpack" here instead of reaching any action.
        action_name, api_name = (
            name.split('.', 1) if '.' in name else (name, 'run'))
        if action_name not in self:
            if name == self.no_action.name:
                return self.no_action(parameters)
            if name == self.finish_action.name:
                return self.finish_action(parameters)
            return self.invalid_action(parameters)
        action_return = self.actions[action_name](parameters, api_name)
        action_return.valid = ActionValidCode.OPEN
        return action_return

    def __call__(self,
                 message: AgentMessage,
                 session_id=0,
                 **kwargs) -> AgentMessage:
        """Execute the function call carried by ``message``.

        Each ``before_action`` hook may replace the incoming message and each
        ``after_action`` hook may replace the response; a falsy hook result
        leaves the current message unchanged.

        Args:
            message (AgentMessage): Message whose ``content`` is a
                ``FunctionCall`` or a dict with ``'name'`` and
                ``'parameters'`` keys.
            session_id: Passed through to every hook.
            **kwargs: Forwarded to :meth:`forward`.

        Returns:
            AgentMessage: The (possibly hook-replaced) response. A non-message
            result from :meth:`forward` is wrapped with this class's name as
            ``sender``.
        """
        for hook in self._hooks.values():
            result = hook.before_action(self, message, session_id)
            if result:
                message = result

        content = message.content
        assert isinstance(content, FunctionCall) or (
            isinstance(content, dict) and 'name' in content
            and 'parameters' in content), (
                'message.content must be a FunctionCall or a dict with '
                "'name' and 'parameters' keys")
        if isinstance(content, dict):
            name = content.get('name')
            parameters = content.get('parameters')
        else:
            name = content.name
            parameters = content.parameters

        response_message = self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )

        for hook in self._hooks.values():
            result = hook.after_action(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def register_hook(self, hook: Callable):
        """Register ``hook`` and return a handle that can remove it later."""
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle
|
|
|
|
class AsyncActionExecutor(ActionExecutor):
    """Asynchronous variant of the action executor.

    Actions whose ``__call__`` is a coroutine function are awaited, and so
    are hooks whose ``before_action`` / ``after_action`` are coroutine
    functions. Synchronous ones are called directly.
    """

    @staticmethod
    async def _maybe_await(func, *args):
        """Call ``func(*args)``, awaiting it if ``func`` is a coroutine function."""
        if inspect.iscoroutinefunction(func):
            return await func(*args)
        return func(*args)

    async def forward(self, name, parameters, **kwargs) -> ActionReturn:
        """Run the action addressed by ``name`` with ``parameters``.

        ``name`` is ``"<action>"`` (calls the ``run`` API) or
        ``"<action>.<api>"``. Unregistered names fall back to the no-action,
        the finish action, or the invalid action. Fallbacks are now awaited
        too when they are async; previously their coroutine was returned
        un-awaited. Only results from registered actions are marked
        ``ActionValidCode.OPEN``.
        """
        # Split on the first dot only; several dots used to crash unpacking.
        action_name, api_name = (
            name.split('.', 1) if '.' in name else (name, 'run'))
        if action_name not in self:
            if name == self.no_action.name:
                fallback = self.no_action
            elif name == self.finish_action.name:
                fallback = self.finish_action
            else:
                fallback = self.invalid_action
            return await self._maybe_await(fallback.__call__, parameters)
        action = self.actions[action_name]
        action_return = await self._maybe_await(action.__call__, parameters,
                                                api_name)
        action_return.valid = ActionValidCode.OPEN
        return action_return

    async def __call__(self,
                       message: AgentMessage,
                       session_id=0,
                       **kwargs) -> AgentMessage:
        """Asynchronously execute the function call carried by ``message``.

        Each ``before_action`` hook may replace the incoming message and each
        ``after_action`` hook may replace the response; a falsy hook result
        leaves the current message unchanged. A non-message result from
        :meth:`forward` is wrapped with this class's name as ``sender``.
        """
        for hook in self._hooks.values():
            result = await self._maybe_await(hook.before_action, self,
                                             message, session_id)
            if result:
                message = result

        content = message.content
        assert isinstance(content, FunctionCall) or (
            isinstance(content, dict) and 'name' in content
            and 'parameters' in content), (
                'message.content must be a FunctionCall or a dict with '
                "'name' and 'parameters' keys")
        if isinstance(content, dict):
            name = content.get('name')
            parameters = content.get('parameters')
        else:
            name = content.name
            parameters = content.parameters

        response_message = await self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )

        for hook in self._hooks.values():
            result = await self._maybe_await(hook.after_action, self,
                                             response_message, session_id)
            if result:
                response_message = result
        return response_message
|
|