File size: 5,525 Bytes
5669b22 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | from typing import Type, Literal
from loguru import logger
from .agents.agent_interface import AgentInterface
from .agents.basic_memory_agent import BasicMemoryAgent
from .stateless_llm_factory import LLMFactory as StatelessLLMFactory
from .agents.hume_ai import HumeAIAgent
from .agents.letta_agent import LettaAgent
from ..mcpp.tool_manager import ToolManager
from ..mcpp.tool_executor import ToolExecutor
from typing import Optional
class AgentFactory:
    """Factory that builds a conversation agent from configuration dictionaries."""

    @staticmethod
    def create_agent(
        conversation_agent_choice: str,
        agent_settings: dict,
        llm_configs: dict,
        system_prompt: str,
        live2d_model=None,
        tts_preprocessor_config=None,
        **kwargs,
    ) -> AgentInterface:
        """Create an agent based on the configuration.

        Args:
            conversation_agent_choice: The type of agent to create. Supported
                values: "basic_memory_agent", "mem0_agent", "hume_ai_agent",
                "letta_agent".
            agent_settings: Settings for different types of agents, keyed by
                the agent type name above.
            llm_configs: Pool of LLM configurations, keyed by provider name.
                This mapping (and the dicts inside it) is not modified.
            system_prompt: The system prompt to use.
            live2d_model: Live2D model instance for expression extraction.
            tts_preprocessor_config: Configuration for TTS preprocessing.
            **kwargs: Additional arguments. Keys read here: "system_config"
                (its "tool_prompts" entry), "tool_manager", "tool_executor",
                "mcp_prompt_string" (basic memory agent) and "user_id"
                (mem0 agent, defaults to "default").

        Returns:
            The constructed agent instance.

        Raises:
            ValueError: If the agent type is unsupported, or if required
                settings for the chosen agent are missing.
        """
        logger.info(f"Initializing agent: {conversation_agent_choice}")

        if conversation_agent_choice == "basic_memory_agent":
            # Get the LLM provider choice from agent settings
            basic_memory_settings: dict = agent_settings.get("basic_memory_agent", {})
            llm_provider: Optional[str] = basic_memory_settings.get("llm_provider")
            if not llm_provider:
                raise ValueError("LLM provider not specified for basic memory agent")

            # Look up and validate the provider's config BEFORE using it:
            # calling .pop() on a missing (None) config would raise
            # AttributeError instead of this descriptive ValueError.
            llm_config: Optional[dict] = llm_configs.get(llm_provider)
            if not llm_config:
                raise ValueError(
                    f"Configuration not found for LLM provider: {llm_provider}"
                )

            # Pop from a shallow copy so the caller's shared llm_configs pool
            # is not mutated; otherwise a later agent built from the same pool
            # would silently lose "interrupt_method" and fall back to "user".
            llm_config = dict(llm_config)
            interrupt_method: Literal["system", "user"] = llm_config.pop(
                "interrupt_method", "user"
            )

            # Create the stateless LLM
            llm = StatelessLLMFactory.create_llm(
                llm_provider=llm_provider, system_prompt=system_prompt, **llm_config
            )

            # "system_config" may be absent or explicitly None.
            system_config = kwargs.get("system_config") or {}
            tool_prompts = system_config.get("tool_prompts", {})

            # Extract MCP components/data needed by BasicMemoryAgent from kwargs
            tool_manager: Optional[ToolManager] = kwargs.get("tool_manager")
            tool_executor: Optional[ToolExecutor] = kwargs.get("tool_executor")
            mcp_prompt_string: str = kwargs.get("mcp_prompt_string", "")

            # Create the agent with the LLM and live2d_model
            return BasicMemoryAgent(
                llm=llm,
                system=system_prompt,
                live2d_model=live2d_model,
                tts_preprocessor_config=tts_preprocessor_config,
                faster_first_response=basic_memory_settings.get(
                    "faster_first_response", True
                ),
                segment_method=basic_memory_settings.get("segment_method", "pysbd"),
                use_mcpp=basic_memory_settings.get("use_mcpp", False),
                interrupt_method=interrupt_method,
                tool_prompts=tool_prompts,
                tool_manager=tool_manager,
                tool_executor=tool_executor,
                mcp_prompt_string=mcp_prompt_string,
            )

        elif conversation_agent_choice == "mem0_agent":
            # Imported lazily, so this dependency is only loaded when the
            # mem0 agent is actually selected.
            from .agents.mem0_llm import LLM as Mem0LLM

            mem0_settings = agent_settings.get("mem0_agent", {})
            if not mem0_settings:
                raise ValueError("Mem0 agent settings not found")

            # Validate required settings; report the first missing one.
            required_fields = ["base_url", "model", "mem0_config"]
            for field in required_fields:
                if field not in mem0_settings:
                    raise ValueError(
                        f"Missing required field '{field}' in mem0_agent settings"
                    )

            return Mem0LLM(
                user_id=kwargs.get("user_id", "default"),
                system=system_prompt,
                live2d_model=live2d_model,
                **mem0_settings,
            )

        elif conversation_agent_choice == "hume_ai_agent":
            settings = agent_settings.get("hume_ai_agent", {})
            return HumeAIAgent(
                api_key=settings.get("api_key"),
                host=settings.get("host", "api.hume.ai"),
                config_id=settings.get("config_id"),
                idle_timeout=settings.get("idle_timeout", 15),
            )

        elif conversation_agent_choice == "letta_agent":
            settings = agent_settings.get("letta_agent", {})
            return LettaAgent(
                live2d_model=live2d_model,
                id=settings.get("id"),
                tts_preprocessor_config=tts_preprocessor_config,
                faster_first_response=settings.get("faster_first_response"),
                segment_method=settings.get("segment_method"),
                host=settings.get("host"),
                port=settings.get("port"),
            )

        else:
            raise ValueError(f"Unsupported agent type: {conversation_agent_choice}")
|