|
|
import asyncio |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Dict, Callable, Union, AsyncGenerator |
|
|
|
|
|
from ..exceptions.exceptions import InputErrorException |
|
|
from ..prompt import PromptTemplate |
|
|
from ..schemas import AgentOutput, AgentType, AgentResponse |
|
|
|
|
|
from ..llm.base_llm import BaseLLM |
|
|
|
|
|
from ..tools import BaseTool |
|
|
from ..utils import Config, get_logger |
|
|
|
|
|
import os |
|
|
from importlib import import_module |
|
|
|
|
|
# Module-level logger shared by the code in this file.
logger = get_logger()

# Names of LLM parameters that may be overridden through keyword arguments
# when an agent is built from a config file; each name listed here is looked
# up in the caller's kwargs and, when set, written into config['llm']['params'].
LLM_CONF_OVERRIDE_KEY = [
    'psm',
    'dc',
    'temperature',
    'top_p',
    'top_k',
    'max_tokens',
]
|
|
|
|
|
|
|
|
class BaseAgent(ABC):
    """Base Agent class defining the essential attributes and methods for an ALM Agent.

    An agent owns a prompt template, an LLM client and a registry of plugins
    (tools or other agents). Concrete agents implement :meth:`run`.
    Instances are normally built from a config file through
    :meth:`from_config_path_and_kwargs` or its async counterpart.
    """

    def __init__(self, **kwargs):
        """
        Initializes an instance of the Agent class.

        Recognised keyword arguments (all optional): ``name``, ``type``,
        ``version``, ``description``, ``prompt_template`` and ``auth``.
        Any other keyword is accepted and ignored by this initializer.

        :param kwargs: Agent configuration values.
        :raises InputErrorException: If ``prompt_template`` is neither None,
            a dict nor a PromptTemplate.
        """
        settings = {
            'name': 'agent',
            'type': AgentType.react,
            'version': '',
            'description': '',
            'prompt_template': None,
            'auth': {}
        }
        settings.update(kwargs)

        # Credentials are exported to the process environment before the rest
        # of the agent state is set up.
        self._set_auth_env(settings['auth'])

        self._name: str = settings['name']
        self._type: AgentType = settings['type']
        self._version: str = settings['version']
        self._description: str = settings['description']
        self.__prompt_template = self._get_prompt_template(settings['prompt_template'])
        self.__llm: Union[BaseLLM, None] = None
        self.__plugins_map: Dict = {}
        # Lazily built caches; reset whenever a plugin is added.
        self.__plugin_tool_function = {}
        self.__plugin_tool_async_function = {}
        self.__plugin_tool_description = None

    @property
    def name(self) -> str:
        """Name of the agent."""
        return self._name

    @property
    def type(self) -> "AgentType":
        """Type of the agent."""
        return self._type

    @property
    def version(self) -> str:
        """Version string of the agent."""
        return self._version

    @property
    def description(self) -> str:
        """Description of the agent."""
        return self._description

    @property
    def prompt_template(self) -> Union["PromptTemplate", Dict[str, "PromptTemplate"], None]:
        """Prompt template(s) of the agent.

        Either a single PromptTemplate, a dict of PromptTemplate values keyed
        as in the configuration, or None when no template was configured.
        """
        return self.__prompt_template

    @property
    def llm(self) -> Union["BaseLLM", None]:
        """LLM client of the agent, or None if it has not been set yet."""
        return self.__llm

    @llm.setter
    def llm(self, llm_client: "BaseLLM"):
        """Set the LLM client.

        :raises InputErrorException: If ``llm_client`` is not a BaseLLM instance
            (this includes None).
        """
        if not isinstance(llm_client, BaseLLM):
            raise InputErrorException("Invalid llm client {}".format(type(llm_client)))
        self.__llm = llm_client

    @property
    def plugins_map(self) -> Dict:
        """Shallow copy of the plugin registry.

        Mutating the returned dict does not change the agent; use
        :meth:`add_plugin` to register a plugin.
        """
        return self.__plugins_map.copy()

    def add_plugin(self, tool_name: str, tool):
        """Register a plugin under ``tool_name``.

        :param tool_name: Key under which the plugin is stored.
        :param tool: The plugin object (a tool or another agent).
        :raises InputErrorException: If ``tool_name`` or ``tool`` is falsy.
        """
        if not tool_name or not tool:
            raise InputErrorException("Adding invalid tool name: {}, type {}".format(tool_name, type(tool)))
        self.__plugins_map[tool_name] = tool
        # The cached function maps and description were computed from the old
        # registry, so they must be rebuilt on next access.
        self._invalidate_plugin_caches()

    def _invalidate_plugin_caches(self):
        """Drop the cached plugin function maps and plugin description."""
        # Fresh dicts are assigned (rather than clearing in place) so maps
        # already handed out to callers are left untouched.
        self.__plugin_tool_function = {}
        self.__plugin_tool_async_function = {}
        self.__plugin_tool_description = None

    def _set_auth_env(self, obj):
        """This method sets environment variables for authentication.

        Every key/value pair of ``obj`` is written to ``os.environ``. Values
        are converted with ``str`` because ``os.environ`` only accepts
        strings; entries whose value is None are skipped.

        :param obj: Mapping of environment variable names to values, or a
            falsy value (None, empty dict) to export nothing.
        """
        if not obj:
            return
        for key, value in obj.items():
            if value is None:
                continue
            os.environ[str(key)] = str(value)

    def _get_prompt_template(self, obj):
        """This method returns a prompt template instance based on the provided configuration.

        :param obj: None, a PromptTemplate, or a dict whose values are each a
            PromptTemplate or a prompt template configuration dict.
        :return: None when ``obj`` is None; a dict with the same keys and
            parsed values when ``obj`` is a dict; otherwise the PromptTemplate.
        :raises InputErrorException: If ``obj`` has any other type.
        """
        if obj is None:
            # No template configured (this is the constructor default).
            return None
        if isinstance(obj, dict):
            return {key: self._parse_prompt_template(value) for key, value in obj.items()}
        if isinstance(obj, PromptTemplate):
            return self._parse_prompt_template(obj)
        raise InputErrorException("Invalid PromptTemplate, it should be a dict or PromptTemplate. But get {}"
                                  .format(type(obj)))

    def _parse_prompt_template(self, obj: Union[dict, "PromptTemplate"]):
        """This method parses the prompt template configuration and returns a prompt template instance.

        :param obj: A PromptTemplate (returned unchanged) or a dict with the
            keys ``input_variables`` and ``template`` and the optional key
            ``validate_template`` (defaults to True).
        :return: A PromptTemplate instance.
        :raises InputErrorException: If ``obj`` is neither a dict nor a PromptTemplate.
        """
        if isinstance(obj, PromptTemplate):
            return obj
        if not isinstance(obj, dict):
            raise InputErrorException("Invalid PromptTemplate, it should be a dict or PromptTemplate. But get {}"
                                      .format(type(obj)))
        return PromptTemplate(input_variables=obj['input_variables'],
                              template=obj['template'],
                              validate_template=bool(obj.get('validate_template', True)))

    @classmethod
    def _get_basic_instance_from_config(cls, config_data):
        """Instantiate the agent class named in ``config_data``.

        The class is looked up via ``module_name`` / ``class_name`` and is
        called with the whole config as keyword arguments.

        :param config_data: Loaded agent configuration.
        :return: The new agent instance (LLM and plugins not yet initialised).
        :raises InputErrorException: If ``module_name`` or ``class_name`` is missing.
        """
        agent_module_name = config_data.get("module_name", None)
        agent_class_name = config_data.get("class_name", None)
        if not agent_module_name or not agent_class_name:
            raise InputErrorException("Agent module_name and class_name required, please check your config")

        module = import_module(agent_module_name)
        clazz = getattr(module, agent_class_name)
        return clazz(**config_data)

    @staticmethod
    def _apply_llm_overrides(config_data, overrides):
        """Copy caller-supplied LLM parameters into ``config_data['llm']['params']``.

        Only the names in ``LLM_CONF_OVERRIDE_KEY`` are considered, and only
        when the config already has an ``llm`` section with a ``params`` entry.

        :param config_data: Loaded agent configuration; modified in place.
        :param overrides: Keyword arguments passed by the caller.
        """
        if 'llm' not in config_data or 'params' not in config_data['llm']:
            return
        llm_params = config_data['llm']['params']
        for param in LLM_CONF_OVERRIDE_KEY:
            # NOTE(review): falsy values (0, "", None) are treated as "not
            # provided", so e.g. temperature=0 cannot be passed this way.
            if param in overrides and overrides[param]:
                logger.info(f"Overwrite with new {param} {overrides[param]}")
                llm_params[param] = overrides[param]

    @classmethod
    def from_config_path_and_kwargs(cls, config_path, **kwargs):
        """Build a fully initialised agent from a config file (synchronous).

        :param config_path: Path handed to ``Config.load``.
        :param kwargs: Optional LLM parameter overrides, see ``LLM_CONF_OVERRIDE_KEY``.
        :return: The agent with its LLM and plugins initialised.
        :raises InputErrorException: If the config is incomplete or the
            configured class is not a BaseAgent.
        """
        config_data = Config.load(config_path)
        # NOTE(review): the config is logged verbatim, including any 'auth' section.
        logger.info(f"Use config from path {config_path} to init agent : {config_data}")
        agent_instance = cls._get_basic_instance_from_config(config_data)

        cls._apply_llm_overrides(config_data, kwargs)

        if not isinstance(agent_instance, BaseAgent):
            raise InputErrorException(
                "Configured agent class must be a subclass of BaseAgent, but get {}".format(type(agent_instance)))
        agent_instance._init_llm(config_data.get("llm", {}))
        agent_instance._init_plugins(config_data.get('plugins', []))
        return agent_instance

    def _init_llm(self, obj):
        """
        This method parses the LLM configuration, builds the LLM instance and stores it on ``self.llm``.

        :param obj: A configuration dict with the keys ``module_name`` and
            ``class_name`` and the optional keys ``model_name`` and ``params``.
        :type obj: dict
        :raises InputErrorException: If ``obj`` is not a dict, if
            ``module_name`` or ``class_name`` is missing, or if the created
            object is not a BaseLLM.
        """
        if not isinstance(obj, dict):
            # A bare model name carries no module/class to import, so it
            # cannot be resolved to an LLM class.
            raise InputErrorException(
                "Agent LLM config should be a dict with module_name and class_name, but get {}".format(type(obj)))

        module_name = obj.get('module_name', None)
        class_name = obj.get('class_name', None)
        if not module_name or not class_name:
            raise InputErrorException("Agent LLM module_name and class_name required, please check your config")

        name = obj.get('model_name', None)
        model_params = obj.get('params', dict())

        module = import_module(module_name)
        clazz = getattr(module, class_name)

        # The llm setter validates that the instance is a BaseLLM.
        self.llm = clazz(model_name=name, params=model_params)

    def _init_plugins(self, configs):
        """
        This method parses the plugin configuration and add each plugin into the plugins_map.

        Entries with ``type == 'agent'`` are loaded as nested agents and
        registered under the entry's ``name``; all other entries are loaded
        as tools and registered under the tool's own ``name`` attribute.

        :param configs: List of plugin configuration dicts.
        :raises InputErrorException: If ``configs`` is not a list or a plugin
            cannot be registered.
        """
        if not isinstance(configs, list):
            raise InputErrorException("Plugins config should be a list, but get {}".format(type(configs)))
        for plugin_config in configs:
            if plugin_config.get('type', "") == 'agent':
                agent = BaseAgent.from_config_path_and_kwargs(plugin_config['config'])
                # Register through add_plugin: the plugins_map property
                # returns a copy, so writing into it would be lost.
                self.add_plugin(plugin_config['name'], agent)
            else:
                params = plugin_config.get('params', dict())
                tool = BaseTool.from_config(config_input=plugin_config['config'], **params)
                # NOTE(review): the async path prefers plugin_config['name']
                # over tool.name; this path always uses tool.name.
                self.add_plugin(tool.name, tool)

    @classmethod
    async def async_from_config_path_and_kwargs(cls, config_path, **kwargs):
        """Build a fully initialised agent from a config file (asynchronous).

        The LLM and every plugin are initialised as separate asyncio tasks
        and awaited together.

        :param config_path: Path handed to ``Config.load``.
        :param kwargs: Optional LLM parameter overrides, see ``LLM_CONF_OVERRIDE_KEY``.
        :return: The agent with its LLM and plugins initialised.
        :raises InputErrorException: If the config is incomplete.
        """
        config_data = Config.load(config_path)
        # NOTE(review): the config is logged verbatim, including any 'auth' section.
        logger.info(f"Use config from path {config_path} to init agent : {config_data}")
        agent_instance = cls._get_basic_instance_from_config(config_data)

        cls._apply_llm_overrides(config_data, kwargs)

        llm_config = config_data.get("llm", {})
        plugin_configs = config_data.get('plugins', [])

        llm_task = asyncio.create_task(cls._async_init_llm(llm_config))
        plugin_tasks = [asyncio.create_task(cls._async_init_plugin(plugin_config))
                        for plugin_config in plugin_configs]

        # gather returns results in argument order: the LLM first, then one
        # (name, instance) pair per plugin.
        llm, *plugins = await asyncio.gather(llm_task, *plugin_tasks)

        agent_instance.llm = llm
        for plugin_name, plugin_instance in plugins:
            agent_instance.add_plugin(plugin_name, plugin_instance)
        return agent_instance

    @classmethod
    async def _async_init_llm(cls, llm_config):
        """Create the LLM instance described by ``llm_config``.

        :param llm_config: Dict with ``module_name`` and ``class_name``; the
            whole dict is passed to the class's ``create`` coroutine.
        :return: The awaited result of ``clazz.create(config_data=llm_config)``.
        :raises InputErrorException: If ``module_name`` or ``class_name`` is
            missing, or the class is not a subclass of BaseLLM.
        """
        llm_module_name = llm_config.get("module_name", None)
        llm_class_name = llm_config.get("class_name", None)
        if not llm_module_name or not llm_class_name:
            raise InputErrorException("Agent LLM module_name and class_name required, please check your config")
        module = import_module(llm_module_name)
        clazz = getattr(module, llm_class_name)
        if not issubclass(clazz, BaseLLM):
            raise InputErrorException(f"{clazz} is not a subclass of BaseLLM")
        return await clazz.create(config_data=llm_config)

    @classmethod
    async def _async_init_plugin(cls, plugin_config):
        """Create one plugin from its configuration.

        :param plugin_config: Plugin configuration dict. ``type == 'agent'``
            loads a nested agent; anything else loads a tool.
        :return: A ``(name, plugin)`` tuple. For tools the name is the
            configured ``name`` or, when absent, the tool's own name.
        """
        if plugin_config.get('type', "") == 'agent':
            agent = await BaseAgent.async_from_config_path_and_kwargs(plugin_config['config'])
            return plugin_config['name'], agent

        params = plugin_config.get('params', dict())
        name = plugin_config.get('name', None)
        config = plugin_config['config']

        tool = await BaseTool.async_from_config(config_input=config, **params)

        if name is None:
            name = tool.name
        logger.info("Init tool with name [{}], and description [{}]".format(name, tool.description))
        return name, tool

    @abstractmethod
    def run(self, *args, **kwargs) -> Union["AgentResponse", None]:
        """Abstract method to be overridden by child classes for running the agent.

        :return: The output of the agent, or None.
        :rtype: AgentResponse
        """
        pass

    async def async_run(self, *args, **kwargs) -> AsyncGenerator["AgentResponse", None]:
        """Default asynchronous entry point; child classes may override it.

        This default implementation calls :meth:`run` with the same arguments
        and yields its result as the single item of the generator.

        :return: An async generator over the output of the agent.
        """
        yield self.run(*args, **kwargs)

    def _get_plugin_function_map(self, method_name: str) -> Dict[str, Callable]:
        """Map each plugin name to the plugin's bound ``method_name`` method.

        Results for ``"run"`` and ``"async_run"`` are cached until a plugin
        is added. Plugins that are neither a BaseTool nor a BaseAgent are
        skipped with a warning.

        :param method_name: Name of the method to collect from every plugin.
        :return: Dict from plugin name to callable.
        """
        if method_name == "run" and self.__plugin_tool_function:
            return self.__plugin_tool_function
        elif method_name == "async_run" and self.__plugin_tool_async_function:
            return self.__plugin_tool_async_function

        function_map = {}
        for name, plugin_tool in self.plugins_map.items():
            if isinstance(plugin_tool, (BaseTool, BaseAgent)):
                function_map[name] = getattr(plugin_tool, method_name)
            else:
                logger.warning(f"No support for plugin name {name} of type {type(plugin_tool)}")

        if method_name == "run":
            self.__plugin_tool_function = function_map
        elif method_name == "async_run":
            self.__plugin_tool_async_function = function_map

        return function_map

    def get_plugin_tool_function(self) -> Dict[str, Callable]:
        """Format the function map for the function API.

        :return: The function map from plugin name to its ``run`` method.
        :rtype: Dict[str, Callable]
        """
        return self._get_plugin_function_map("run")

    def get_plugin_tool_async_function(self) -> Dict[str, Callable]:
        """Format the async function map for the function API.

        :return: The function map from plugin name to its ``async_run`` method.
        :rtype: Dict[str, Callable]
        """
        return self._get_plugin_function_map("async_run")

    def _get_plugin_description(self):
        """Return one ``"<name>[input]: <description>\\n"`` line per plugin, concatenated.

        A non-empty result is cached until a plugin is added.

        :return: The combined plugin description text.
        :rtype: str
        :raises InputErrorException: If reading a plugin's description fails.
        """
        if self.__plugin_tool_description:
            return self.__plugin_tool_description

        try:
            descriptions = "".join(
                f"{plugin_name}[input]: {plugin.description}\n"
                for plugin_name, plugin in self.plugins_map.items()
            )
        except Exception as e:
            err_msg = "Failed to get plugin tool name and description. error: {}".format(str(e))
            raise InputErrorException(err_msg) from e

        self.__plugin_tool_description = descriptions
        return descriptions

    def clear(self):
        """
        Clear and reset the agent.

        The base implementation does nothing.
        """
        pass
|
|
|