|
|
import asyncio |
|
|
import json |
|
|
from pydantic import create_model, Field |
|
|
from typing import Optional, Callable, Type, List, Any |
|
|
|
|
|
from .agent import Agent |
|
|
from ..core.logging import logger |
|
|
from ..core.registry import MODULE_REGISTRY, ACTION_FUNCTION_REGISTRY |
|
|
from ..models.model_configs import LLMConfig |
|
|
from ..actions.action import Action, ActionOutput, ActionInput |
|
|
from ..utils.utils import generate_dynamic_class_name, make_parent_folder |
|
|
from ..core.message import Message, MessageType |
|
|
|
|
|
|
|
|
class ActionAgent(Agent): |
|
|
""" |
|
|
ActionAgent is a specialized agent that executes a provided function directly without LLM. |
|
|
It creates an action that uses the provided function as the execution backbone. |
|
|
|
|
|
Attributes: |
|
|
name (str): The name of the agent. |
|
|
description (str): A description of the agent's purpose and capabilities. |
|
|
inputs (List[dict]): List of input specifications, where each dict contains: |
|
|
- name (str): Name of the input parameter |
|
|
- type (str): Type of the input |
|
|
- description (str): Description of what the input represents |
|
|
- required (bool, optional): Whether this input is required (default: True) |
|
|
outputs (List[dict]): List of output specifications, where each dict contains: |
|
|
- name (str): Name of the output field |
|
|
- type (str): Type of the output |
|
|
- description (str): Description of what the output represents |
|
|
- required (bool, optional): Whether this output is required (default: True) |
|
|
execute_func (Callable): The function to execute the agent. |
|
|
async_execute_func (Callable, Optional): Async version of the function. If not provided, |
|
|
an async wrapper will be automatically created around execute_func. |
|
|
llm_config (LLMConfig, optional): Configuration for the language model (minimal usage). |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
name: str, |
|
|
description: str, |
|
|
inputs: List[dict], |
|
|
outputs: List[dict], |
|
|
execute_func: Callable, |
|
|
async_execute_func: Optional[Callable] = None, |
|
|
llm_config: Optional[LLMConfig] = None, |
|
|
**kwargs |
|
|
): |
|
|
|
|
|
if not callable(execute_func): |
|
|
raise ValueError("execute_func must be callable") |
|
|
|
|
|
if async_execute_func is not None and not callable(async_execute_func): |
|
|
raise ValueError("async_execute_func must be callable") |
|
|
|
|
|
|
|
|
self._validate_inputs_outputs(inputs, outputs) |
|
|
|
|
|
|
|
|
is_human = llm_config is None |
|
|
|
|
|
|
|
|
super().__init__( |
|
|
name=name, |
|
|
description=description, |
|
|
llm_config=llm_config, |
|
|
is_human=is_human, |
|
|
**kwargs |
|
|
) |
|
|
|
|
|
|
|
|
self.execute_func = execute_func |
|
|
self.async_execute_func = async_execute_func |
|
|
self.inputs = inputs |
|
|
self.outputs = outputs |
|
|
|
|
|
|
|
|
action = self._create_function_action_with_params( |
|
|
name, execute_func, async_execute_func, inputs, outputs |
|
|
) |
|
|
self.add_action(action) |
|
|
|
|
|
def init_llm(self): |
|
|
pass |
|
|
|
|
|
def _validate_inputs_outputs(self, inputs: List[dict], outputs: List[dict]): |
|
|
"""Validate the structure of inputs and outputs.""" |
|
|
|
|
|
if inputs is None: |
|
|
inputs = [] |
|
|
|
|
|
if outputs is None: |
|
|
outputs = [] |
|
|
|
|
|
|
|
|
for i, input_field in enumerate(inputs): |
|
|
if not isinstance(input_field, dict): |
|
|
raise ValueError(f"Input field {i} must be a dictionary, got {type(input_field)}") |
|
|
|
|
|
required_keys = ["name", "type", "description"] |
|
|
for key in required_keys: |
|
|
if key not in input_field: |
|
|
raise ValueError(f"Input field {i} missing required key '{key}'") |
|
|
|
|
|
if not isinstance(input_field["name"], str): |
|
|
raise ValueError(f"Input field {i} 'name' must be a string, got {type(input_field['name'])}") |
|
|
|
|
|
if not isinstance(input_field["type"], str): |
|
|
raise ValueError(f"Input field {i} 'type' must be a string, got {type(input_field['type'])}") |
|
|
|
|
|
if not isinstance(input_field["description"], str): |
|
|
raise ValueError(f"Input field {i} 'description' must be a string, got {type(input_field['description'])}") |
|
|
|
|
|
|
|
|
input_names = [field["name"] for field in inputs] |
|
|
if len(input_names) != len(set(input_names)): |
|
|
raise ValueError(f"Duplicate input names found: {[name for name in input_names if input_names.count(name) > 1]}") |
|
|
|
|
|
|
|
|
for i, output_field in enumerate(outputs): |
|
|
if not isinstance(output_field, dict): |
|
|
raise ValueError(f"Output field {i} must be a dictionary, got {type(output_field)}") |
|
|
|
|
|
required_keys = ["name", "type", "description"] |
|
|
for key in required_keys: |
|
|
if key not in output_field: |
|
|
raise ValueError(f"Output field {i} missing required key '{key}'") |
|
|
|
|
|
if not isinstance(output_field["name"], str): |
|
|
raise ValueError(f"Output field {i} 'name' must be a string, got {type(output_field['name'])}") |
|
|
|
|
|
if not isinstance(output_field["type"], str): |
|
|
raise ValueError(f"Output field {i} 'type' must be a string, got {type(output_field['type'])}") |
|
|
|
|
|
if not isinstance(output_field["description"], str): |
|
|
raise ValueError(f"Output field {i} 'description' must be a string, got {type(output_field['description'])}") |
|
|
|
|
|
|
|
|
output_names = [field["name"] for field in outputs] |
|
|
if len(output_names) != len(set(output_names)): |
|
|
raise ValueError(f"Duplicate output names found: {[name for name in output_names if output_names.count(name) > 1]}") |
|
|
|
|
|
def _create_function_action_input_type(self, name: str, inputs: List[dict]) -> Type[ActionInput]: |
|
|
"""Create ActionInput type from input specifications.""" |
|
|
action_input_fields = {} |
|
|
for field in inputs: |
|
|
required = field.get("required", True) |
|
|
if required: |
|
|
action_input_fields[field["name"]] = (str, Field(description=field["description"])) |
|
|
else: |
|
|
action_input_fields[field["name"]] = (Optional[str], Field(default=None, description=field["description"])) |
|
|
|
|
|
action_input_type = create_model( |
|
|
self._get_unique_class_name( |
|
|
generate_dynamic_class_name(f"{name} action_input") |
|
|
), |
|
|
**action_input_fields, |
|
|
__base__=ActionInput |
|
|
) |
|
|
return action_input_type |
|
|
|
|
|
def _create_function_action_output_type(self, name: str, outputs: List[dict]) -> Type[ActionOutput]: |
|
|
"""Create ActionOutput type from output specifications.""" |
|
|
action_output_fields = {} |
|
|
for field in outputs: |
|
|
required = field.get("required", True) |
|
|
if required: |
|
|
action_output_fields[field["name"]] = (Any, Field(description=field["description"])) |
|
|
else: |
|
|
action_output_fields[field["name"]] = (Optional[Any], Field(default=None, description=field["description"])) |
|
|
|
|
|
action_output_type = create_model( |
|
|
self._get_unique_class_name( |
|
|
generate_dynamic_class_name(f"{name} action_output") |
|
|
), |
|
|
**action_output_fields, |
|
|
__base__=ActionOutput |
|
|
) |
|
|
return action_output_type |
|
|
|
|
|
def _create_execute_method(self, execute_func: Callable): |
|
|
"""Create the execute method for the action.""" |
|
|
def execute_method(action_self, llm=None, inputs=None, sys_msg=None, return_prompt=False, **kwargs): |
|
|
|
|
|
if inputs is None: |
|
|
inputs = {} |
|
|
|
|
|
|
|
|
required_inputs = action_self.inputs_format.get_required_input_names() |
|
|
missing_inputs = [input_name for input_name in required_inputs if input_name not in inputs] |
|
|
if missing_inputs: |
|
|
raise ValueError(f"Missing required inputs: {missing_inputs}") |
|
|
|
|
|
|
|
|
filtered_inputs = {} |
|
|
for input_name, input_value in inputs.items(): |
|
|
if input_name in [field["name"] for field in self.inputs]: |
|
|
filtered_inputs[input_name] = input_value |
|
|
else: |
|
|
logger.warning(f"Unexpected input '{input_name}' provided") |
|
|
|
|
|
|
|
|
try: |
|
|
result = execute_func(**filtered_inputs) |
|
|
except Exception as e: |
|
|
|
|
|
try: |
|
|
|
|
|
output_fields = action_self.outputs_format.get_attrs() |
|
|
if "error" in output_fields: |
|
|
error_output = action_self.outputs_format( |
|
|
error=f"Function execution failed: {str(e)}" |
|
|
) |
|
|
elif len(output_fields) > 0: |
|
|
|
|
|
first_field = output_fields[0] |
|
|
error_output = action_self.outputs_format(**{first_field: f"Error: {str(e)}"}) |
|
|
else: |
|
|
|
|
|
error_output = action_self.outputs_format() |
|
|
except Exception as create_error: |
|
|
|
|
|
logger.error(f"Failed to create error output: {create_error}") |
|
|
error_output = action_self.outputs_format() |
|
|
return error_output, "Function execution" |
|
|
|
|
|
|
|
|
if isinstance(result, dict): |
|
|
|
|
|
output = action_self.outputs_format(**result) |
|
|
else: |
|
|
|
|
|
output_fields = action_self.outputs_format.get_attrs() |
|
|
if len(output_fields) > 0: |
|
|
first_field = output_fields[0] |
|
|
output = action_self.outputs_format(**{first_field: result}) |
|
|
else: |
|
|
|
|
|
output = action_self.outputs_format() |
|
|
|
|
|
return output, "Function execution" |
|
|
|
|
|
return execute_method |
|
|
|
|
|
def _create_async_execute_method(self, async_execute_func: Callable, execute_func: Callable): |
|
|
"""Create the async execute method for the action.""" |
|
|
async def async_execute_method(action_self, llm=None, inputs=None, sys_msg=None, return_prompt=False, **kwargs): |
|
|
|
|
|
if inputs is None: |
|
|
inputs = {} |
|
|
|
|
|
|
|
|
required_inputs = action_self.inputs_format.get_required_input_names() |
|
|
missing_inputs = [input_name for input_name in required_inputs if input_name not in inputs] |
|
|
if missing_inputs: |
|
|
raise ValueError(f"Missing required inputs: {missing_inputs}") |
|
|
|
|
|
|
|
|
filtered_inputs = {} |
|
|
for input_name, input_value in inputs.items(): |
|
|
if input_name in [field["name"] for field in self.inputs]: |
|
|
filtered_inputs[input_name] = input_value |
|
|
else: |
|
|
logger.warning(f"Unexpected input '{input_name}' provided") |
|
|
|
|
|
|
|
|
try: |
|
|
if async_execute_func is not None: |
|
|
result = await async_execute_func(**filtered_inputs) |
|
|
else: |
|
|
|
|
|
loop = asyncio.get_event_loop() |
|
|
result = await loop.run_in_executor(None, lambda: execute_func(**filtered_inputs)) |
|
|
except Exception as e: |
|
|
|
|
|
try: |
|
|
|
|
|
output_fields = action_self.outputs_format.get_attrs() |
|
|
if "error" in output_fields: |
|
|
error_output = action_self.outputs_format( |
|
|
error=f"Async function execution failed: {str(e)}" |
|
|
) |
|
|
elif len(output_fields) > 0: |
|
|
|
|
|
first_field = list(output_fields.keys())[0] |
|
|
error_output = action_self.outputs_format(**{first_field: f"Error: {str(e)}"}) |
|
|
else: |
|
|
|
|
|
error_output = action_self.outputs_format() |
|
|
except Exception as create_error: |
|
|
|
|
|
logger.error(f"Failed to create error output: {create_error}") |
|
|
error_output = action_self.outputs_format() |
|
|
return error_output, "Async function execution" |
|
|
|
|
|
|
|
|
if isinstance(result, dict): |
|
|
|
|
|
output = action_self.outputs_format(**result) |
|
|
else: |
|
|
|
|
|
output_fields = action_self.outputs_format.get_attrs() |
|
|
if len(output_fields) > 0: |
|
|
first_field = output_fields[0] |
|
|
output = action_self.outputs_format(**{first_field: result}) |
|
|
else: |
|
|
|
|
|
output = action_self.outputs_format() |
|
|
|
|
|
return output, "Async function execution" |
|
|
|
|
|
return async_execute_method |
|
|
|
|
|
def _create_function_action_with_params(self, name: str, execute_func: Callable, async_execute_func: Callable, inputs: List[dict], outputs: List[dict]) -> Action: |
|
|
"""Create an action that executes the provided function with given parameters.""" |
|
|
|
|
|
|
|
|
action_input_type = self._create_function_action_input_type(name, inputs) |
|
|
action_output_type = self._create_function_action_output_type(name, outputs) |
|
|
|
|
|
|
|
|
action_cls_name = self._get_unique_class_name( |
|
|
generate_dynamic_class_name(f"{name} function action") |
|
|
) |
|
|
|
|
|
|
|
|
function_action_cls = create_model( |
|
|
action_cls_name, |
|
|
__base__=Action |
|
|
) |
|
|
|
|
|
|
|
|
function_action = function_action_cls( |
|
|
name=action_cls_name, |
|
|
description=f"Executes {execute_func.__name__} function", |
|
|
inputs_format=action_input_type, |
|
|
outputs_format=action_output_type |
|
|
) |
|
|
|
|
|
|
|
|
execute_method = self._create_execute_method(execute_func) |
|
|
async_execute_method = self._create_async_execute_method(async_execute_func, execute_func) |
|
|
|
|
|
|
|
|
function_action.execute = execute_method.__get__(function_action, type(function_action)) |
|
|
function_action.async_execute = async_execute_method.__get__(function_action, type(function_action)) |
|
|
|
|
|
return function_action |
|
|
|
|
|
def _create_function_action(self, name: str, execute_func: Callable, async_execute_func: Callable, inputs: List[dict], outputs: List[dict]) -> Action: |
|
|
"""Create an action that executes the provided function.""" |
|
|
return self._create_function_action_with_params( |
|
|
name, |
|
|
execute_func, |
|
|
async_execute_func, |
|
|
inputs, |
|
|
outputs |
|
|
) |
|
|
|
|
|
def get_config(self) -> dict: |
|
|
"""Get configuration for the ActionAgent.""" |
|
|
|
|
|
config = super().get_config() |
|
|
|
|
|
|
|
|
config.update({ |
|
|
"class_name": "ActionAgent", |
|
|
"execute_func_name": self.execute_func.__name__ if self.execute_func else None, |
|
|
"async_execute_func_name": self.async_execute_func.__name__ if self.async_execute_func else None, |
|
|
"inputs": self.inputs, |
|
|
"outputs": self.outputs |
|
|
}) |
|
|
return config |
|
|
|
|
|
def save_module(self, path: str, ignore: List[str] = [], **kwargs) -> str: |
|
|
"""Save the ActionAgent configuration to a JSON file. |
|
|
|
|
|
Args: |
|
|
path: File path where the configuration should be saved |
|
|
ignore: List of keys to exclude from the saved configuration |
|
|
**kwargs (Any): Additional parameters for the save operation |
|
|
|
|
|
Returns: |
|
|
The path where the configuration was saved |
|
|
""" |
|
|
config = self.get_config() |
|
|
|
|
|
|
|
|
config.update({ |
|
|
"class_name": "ActionAgent", |
|
|
"execute_func_name": self.execute_func.__name__ if self.execute_func else None, |
|
|
"async_execute_func_name": self.async_execute_func.__name__ if self.async_execute_func else None, |
|
|
"inputs": self.inputs, |
|
|
"outputs": self.outputs |
|
|
}) |
|
|
|
|
|
|
|
|
for ignore_key in ignore: |
|
|
config.pop(ignore_key, None) |
|
|
|
|
|
|
|
|
make_parent_folder(path) |
|
|
with open(path, 'w', encoding='utf-8') as f: |
|
|
json.dump(config, f, indent=4, ensure_ascii=False) |
|
|
|
|
|
return path |
|
|
|
|
|
@classmethod |
|
|
def load_module(cls, path: str, llm_config: LLMConfig = None, **kwargs) -> "ActionAgent": |
|
|
"""Load the ActionAgent from a JSON file. |
|
|
|
|
|
Args: |
|
|
path: The path of the file |
|
|
llm_config: The LLMConfig instance (optional) |
|
|
**kwargs: Additional keyword arguments |
|
|
|
|
|
Returns: |
|
|
ActionAgent: The loaded agent instance |
|
|
|
|
|
Raises: |
|
|
KeyError: If required functions are not found in the registry |
|
|
""" |
|
|
|
|
|
with open(path, 'r', encoding='utf-8') as f: |
|
|
config = json.load(f) |
|
|
|
|
|
|
|
|
execute_func_name = config.get("execute_func_name") |
|
|
async_execute_func_name = config.get("async_execute_func_name") |
|
|
|
|
|
|
|
|
execute_func = None |
|
|
async_execute_func = None |
|
|
|
|
|
if execute_func_name: |
|
|
if not ACTION_FUNCTION_REGISTRY.has_function(execute_func_name): |
|
|
raise KeyError(f"Function '{execute_func_name}' not found in registry. Please register it first.") |
|
|
execute_func = ACTION_FUNCTION_REGISTRY.get_function(execute_func_name) |
|
|
|
|
|
if async_execute_func_name: |
|
|
if not ACTION_FUNCTION_REGISTRY.has_function(async_execute_func_name): |
|
|
raise KeyError(f"Function '{async_execute_func_name}' not found in registry. Please register it first.") |
|
|
async_execute_func = ACTION_FUNCTION_REGISTRY.get_function(async_execute_func_name) |
|
|
|
|
|
|
|
|
agent = cls( |
|
|
name=config["name"], |
|
|
description=config["description"], |
|
|
inputs=config["inputs"], |
|
|
outputs=config["outputs"], |
|
|
execute_func=execute_func, |
|
|
async_execute_func=async_execute_func, |
|
|
llm_config=llm_config, |
|
|
**kwargs |
|
|
) |
|
|
|
|
|
return agent |
|
|
|
|
|
def __call__(self, inputs: dict = None, return_msg_type: MessageType = MessageType.UNKNOWN, **kwargs) -> Message: |
|
|
""" |
|
|
Call the main function action. |
|
|
|
|
|
Args: |
|
|
inputs (dict): The inputs to the function action. |
|
|
return_msg_type (MessageType): The type of message to return. |
|
|
**kwargs (Any): Additional keyword arguments. |
|
|
|
|
|
Returns: |
|
|
Message: The output of the function action. |
|
|
""" |
|
|
inputs = inputs or {} |
|
|
return super().__call__(action_name=self.main_action_name, action_input_data=inputs, return_msg_type=return_msg_type, **kwargs) |
|
|
|
|
|
@property |
|
|
def main_action_name(self) -> str: |
|
|
""" |
|
|
Get the name of the main function action for this agent. |
|
|
|
|
|
Returns: |
|
|
The name of the main function action |
|
|
""" |
|
|
for action in self.actions: |
|
|
if action.name != self.cext_action_name: |
|
|
return action.name |
|
|
raise ValueError("Couldn't find the main action name!") |
|
|
|
|
|
def _get_unique_class_name(self, candidate_name: str) -> str: |
|
|
""" |
|
|
Get a unique class name by checking if it already exists in the registry. |
|
|
If it does, append "Vx" to make it unique. |
|
|
""" |
|
|
if not MODULE_REGISTRY.has_module(candidate_name): |
|
|
return candidate_name |
|
|
|
|
|
counter = 1 |
|
|
while True: |
|
|
new_name = f"{candidate_name}V{counter}" |
|
|
if not MODULE_REGISTRY.has_module(new_name): |
|
|
return new_name |
|
|
counter += 1 |