InfiAgent / src /infiagent /agent /base_agent.py
g3eIL's picture
Upload 80 files
77320e4 verified
import asyncio
from abc import ABC, abstractmethod
from typing import Dict, Callable, Union, AsyncGenerator
from ..exceptions.exceptions import InputErrorException
from ..prompt import PromptTemplate
from ..schemas import AgentOutput, AgentType, AgentResponse
from ..llm.base_llm import BaseLLM
from ..tools import BaseTool
from ..utils import Config, get_logger
import os
from importlib import import_module
logger = get_logger()
LLM_CONF_OVERRIDE_KEY = ['psm', 'dc', 'temperature', 'top_p', 'top_k', 'max_tokens']
class BaseAgent(ABC):
"""Base Agent class defining the essential attributes and methods for an ALM Agent.
"""
def __init__(self, **kwargs):
"""
Initializes an instance of the Agent class.
"""
# Set default values
default_config = {
'name': 'agent',
'type': AgentType.react,
'version': '',
'description': '',
'prompt_template': None,
'auth': {}
}
# Update default values with provided config
default_config.update(kwargs)
# Access configuration data with a known default value
auth = default_config['auth']
self._set_auth_env(auth)
self._name: str = default_config['name']
self._type: AgentType = default_config['type']
self._version: str = default_config['version']
self._description: str = default_config['description']
self.__prompt_template: Union[PromptTemplate, None] = \
self._get_prompt_template(default_config['prompt_template'])
self.__llm: Union[BaseLLM, None] = None
self.__plugins_map: Dict = {}
self.__plugin_tool_function = {}
self.__plugin_tool_async_function = {}
self.__plugin_tool_description = None
@property
def name(self) -> str:
return self._name
@property
def type(self) -> AgentType:
return self._type
@property
def version(self) -> str:
return self._version
@property
def description(self) -> str:
return self._description
@property
def prompt_template(self) -> PromptTemplate:
return self.__prompt_template
@property
def llm(self) -> Union[BaseLLM, None]:
return self.__llm
@llm.setter
def llm(self, llm_client: BaseLLM):
if llm_client is None or not isinstance(llm_client, BaseLLM):
raise InputErrorException("Invalid llm client {}".format(type(llm_client)))
self.__llm = llm_client
@property
def plugins_map(self) -> Dict:
return self.__plugins_map.copy() # Return a copy to prevent external modification
def add_plugin(self, tool_name: str, tool):
if not tool_name or not tool:
raise InputErrorException("Adding invalid tool name: {}, type {}".format(tool_name, type(tool)))
self.__plugins_map[tool_name] = tool
def _set_auth_env(self, obj):
"""This method sets environment variables for authentication.
"""
for key in obj:
os.environ[key] = obj.get(key)
def _get_prompt_template(self, obj):
"""This method returns a prompt template instance based on the provided configuration.
"""
assert isinstance(obj, dict) or isinstance(obj, PromptTemplate)
if isinstance(obj, dict):
return {
key: self._parse_prompt_template(obj[key]) for key in obj
}
elif isinstance(obj, PromptTemplate):
ans = self._parse_prompt_template(obj)
return ans
else:
raise InputErrorException("Invalid PromptTemplate, it should be a dict or PromptTemplate. But get {}"
.format(type(obj)))
def _parse_prompt_template(self, obj: Union[dict, PromptTemplate]):
"""This method parses the prompt template configuration and returns a prompt template instance.
"""
assert isinstance(obj, dict) or isinstance(obj, PromptTemplate)
if isinstance(obj, PromptTemplate):
return obj
return PromptTemplate(input_variables=obj['input_variables'],
template=obj['template'],
validate_template=bool(obj.get('validate_template', True)))
@classmethod
def _get_basic_instance_from_config(cls, config_data):
agent_module_name = config_data.get("module_name", None)
agent_class_name = config_data.get("class_name", None)
if not agent_module_name or not agent_class_name:
raise InputErrorException("Agent module_name and class_name required, please check your config")
module = import_module(agent_module_name)
clazz = getattr(module, agent_class_name)
agent_instance = clazz(**config_data)
return agent_instance
@classmethod
def from_config_path_and_kwargs(cls, config_path, **kwargs):
config_data = Config.load(config_path)
logger.info(f"Use config from path {config_path} to init agent : {config_data}")
agent_instance = cls._get_basic_instance_from_config(config_data)
if 'llm' in config_data and 'params' in config_data['llm']:
for param in LLM_CONF_OVERRIDE_KEY:
if param in kwargs and kwargs[param]:
logger.info(f"Overwrite with new {param} {kwargs[param]}")
config_data['llm']['params'][param] = kwargs[param]
assert isinstance(agent_instance, BaseAgent)
agent_instance._init_llm(config_data.get("llm", {}))
agent_instance._init_plugins(config_data.get('plugins', []))
return agent_instance
def _init_llm(self, obj):
"""
This method parses the Language Model Manager (LLM) configuration and returns an LLM instance.
:param obj: A configuration dictionary or string.
:type obj: dict or str
:raises ValueError: If the specified LLM is not supported.
:return: An LLM instance.
:rtype: BaseLLM
"""
if isinstance(obj, str):
name = obj
model_params = dict()
else:
name = obj.get('model_name', None)
model_params = obj.get('params', dict())
module_name = obj['module_name']
class_name = obj['class_name']
module = import_module(module_name)
clazz = getattr(module, class_name)
llm = clazz(model_name=name, params=model_params)
self.llm = llm
def _init_plugins(self, configs):
"""
This method parses the plugin configuration and add each plugin into the plugins_map.
"""
assert isinstance(configs, list)
for plugin_config in configs:
if plugin_config.get('type', "") == 'agent':
# Agent as plugin
agent = BaseAgent.from_config_path_and_kwargs(plugin_config['config'])
self.plugins_map[plugin_config['name']] = agent
else:
# Tools as plugin
params = plugin_config.get('params', dict())
tool = BaseTool.from_config(config_input=plugin_config['config'], **params)
self.plugins_map[tool.name] = tool
@classmethod
async def async_from_config_path_and_kwargs(cls, config_path, **kwargs):
config_data = Config.load(config_path)
logger.info(f"Use config from path {config_path} to init agent : {config_data}")
agent_instance = cls._get_basic_instance_from_config(config_data)
# override default config with user input
if 'llm' in config_data and 'params' in config_data['llm']:
for param in LLM_CONF_OVERRIDE_KEY:
if param in kwargs and kwargs[param]:
logger.info(f"Overwrite with new {param} {kwargs[param]}")
config_data['llm']['params'][param] = kwargs[param]
# Create tasks for llm and each individual plugin
llm_config = config_data.get("llm", {})
plugin_configs = config_data.get('plugins', [])
# Create tasks for llm and each individual plugin
llm_task = asyncio.create_task(cls._async_init_llm(llm_config))
plugin_tasks = [asyncio.create_task(cls._async_init_plugin(plugin_config)) for
plugin_config in plugin_configs]
# Gather results
llm, *plugins = await asyncio.gather(llm_task, *plugin_tasks)
agent_instance.llm = llm
for plugin in plugins:
plugin_name, plugin_instance = plugin
agent_instance.add_plugin(plugin_name, plugin_instance)
return agent_instance
@classmethod
async def _async_init_llm(cls, llm_config):
llm_model_name = llm_config.get("module_name", None)
llm_class_name = llm_config.get("class_name", None)
if not llm_model_name or not llm_class_name:
raise InputErrorException("Agent LLM module_name and class_name required, please check your config")
module = import_module(llm_model_name)
clazz = getattr(module, llm_class_name)
assert issubclass(clazz, BaseLLM), f"{clazz} is not a subclass of BaseLLM"
llm_instance = await clazz.create(config_data=llm_config)
return llm_instance
@classmethod
async def _async_init_plugin(cls, plugin_config):
if plugin_config.get('type', "") == 'agent':
# Agent as plugin
agent = await BaseAgent.async_from_config_path_and_kwargs(plugin_config['config'])
return plugin_config['name'], agent
else:
# Tool as plugin
params = plugin_config.get('params', dict())
name = plugin_config.get('name', None)
config = plugin_config['config']
tool = await BaseTool.async_from_config(config_input=config, **params)
if name is None:
name = tool.name
logger.info("Init tool with name [{}], and description [{}]".format(name, tool.description))
return name, tool
@abstractmethod
def run(self, *args, **kwargs) -> [AgentResponse, None]:
"""Abstract method to be overridden by child classes for running the agent.
:return: The output of the agent.
:rtype: AgentOutput
"""
pass
async def async_run(self, *args, **kwargs) -> AsyncGenerator[AgentResponse, None]:
"""Abstract method to be overridden by child classes for running the agent.
:return: The output of the agent.
"""
yield self.run(*args, **kwargs)
def _get_plugin_function_map(self, method_name: str) -> Dict[str, Callable]:
if method_name == "run" and self.__plugin_tool_function:
return self.__plugin_tool_function
elif method_name == "async_run" and self.__plugin_tool_async_function:
return self.__plugin_tool_async_function
function_map = {}
for name, plugin_tool in self.plugins_map.items():
if isinstance(plugin_tool, (BaseTool, BaseAgent)):
function_map[name] = getattr(plugin_tool, method_name)
else:
logger.warning(f"No support for plugin name {name} of type {type(plugin_tool)}")
if method_name == "run":
self.__plugin_tool_function = function_map
elif method_name == "async_run":
self.__plugin_tool_async_function = function_map
return function_map
def get_plugin_tool_function(self) -> Dict[str, Callable]:
"""Format the function map for the function API.
:return: The function map.
:rtype: Dict[str, Callable]
"""
return self._get_plugin_function_map("run")
def get_plugin_tool_async_function(self) -> Dict[str, Callable]:
"""Format the function map for the function API.
:return: The function map.
:rtype: Dict[str, Callable]
"""
return self._get_plugin_function_map("async_run")
def _get_plugin_description(self):
if self.__plugin_tool_description:
return self.__plugin_tool_description
descriptions = ""
try:
for plugin_name, plugin in self.plugins_map.items():
descriptions += f"{plugin_name}[input]: {plugin.description}\n"
except Exception as e:
err_msg = "Failed to get plugin tool name and description. error: {}".format(str(e))
raise InputErrorException(err_msg) from e
self.__plugin_tool_description = descriptions
return descriptions
def clear(self):
"""
Clear and reset the agent.
"""
pass