import os import json from typing import Callable from loguru import logger from fastapi import WebSocket from prompts import prompt_loader from .live2d_model import Live2dModel from .asr.asr_interface import ASRInterface from .tts.tts_interface import TTSInterface from .vad.vad_interface import VADInterface from .agent.agents.agent_interface import AgentInterface from .translate.translate_interface import TranslateInterface from .mcpp.server_registry import ServerRegistry from .mcpp.tool_manager import ToolManager from .mcpp.mcp_client import MCPClient from .mcpp.tool_executor import ToolExecutor from .mcpp.tool_adapter import ToolAdapter from .asr.asr_factory import ASRFactory from .tts.tts_factory import TTSFactory from .vad.vad_factory import VADFactory from .agent.agent_factory import AgentFactory from .translate.translate_factory import TranslateFactory from .config_manager import ( Config, AgentConfig, CharacterConfig, SystemConfig, ASRConfig, TTSConfig, VADConfig, TranslatorConfig, read_yaml, validate_config, ) class ServiceContext: """Initializes, stores, and updates the asr, tts, and llm instances and other configurations for a connected client.""" def __init__(self): self.config: Config = None self.system_config: SystemConfig = None self.character_config: CharacterConfig = None self.live2d_model: Live2dModel = None self.asr_engine: ASRInterface = None self.tts_engine: TTSInterface = None self.agent_engine: AgentInterface = None # translate_engine can be none if translation is disabled self.vad_engine: VADInterface | None = None self.translate_engine: TranslateInterface | None = None self.mcp_server_registery: ServerRegistry | None = None self.tool_adapter: ToolAdapter | None = None self.tool_manager: ToolManager | None = None self.mcp_client: MCPClient | None = None self.tool_executor: ToolExecutor | None = None # the system prompt is a combination of the persona prompt and live2d expression prompt self.system_prompt: str = None # Store the generated MCP prompt string (if MCP enabled) self.mcp_prompt: str = "" self.history_uid: str = "" # Add history_uid field self.send_text: Callable = None self.client_uid: str = None def __str__(self): return ( f"ServiceContext:\n" f" System Config: {'Loaded' if self.system_config else 'Not Loaded'}\n" f" Details: {json.dumps(self.system_config.model_dump(), indent=6) if self.system_config else 'None'}\n" f" Live2D Model: {self.live2d_model.model_info if self.live2d_model else 'Not Loaded'}\n" f" ASR Engine: {type(self.asr_engine).__name__ if self.asr_engine else 'Not Loaded'}\n" f" Config: {json.dumps(self.character_config.asr_config.model_dump(), indent=6) if self.character_config.asr_config else 'None'}\n" f" TTS Engine: {type(self.tts_engine).__name__ if self.tts_engine else 'Not Loaded'}\n" f" Config: {json.dumps(self.character_config.tts_config.model_dump(), indent=6) if self.character_config.tts_config else 'None'}\n" f" LLM Engine: {type(self.agent_engine).__name__ if self.agent_engine else 'Not Loaded'}\n" f" Agent Config: {json.dumps(self.character_config.agent_config.model_dump(), indent=6) if self.character_config.agent_config else 'None'}\n" f" VAD Engine: {type(self.vad_engine).__name__ if self.vad_engine else 'Not Loaded'}\n" f" Agent Config: {json.dumps(self.character_config.vad_config.model_dump(), indent=6) if self.character_config.vad_config else 'None'}\n" f" System Prompt: {self.system_prompt or 'Not Set'}\n" f" MCP Enabled: {'Yes' if self.mcp_client else 'No'}" ) # ==== Initializers async def _init_mcp_components(self, use_mcpp, enabled_servers): """Initializes MCP components based on configuration, dynamically fetching tool info.""" logger.debug( f"Initializing MCP components: use_mcpp={use_mcpp}, enabled_servers={enabled_servers}" ) # Reset MCP components first self.mcp_server_registery = None self.tool_manager = None self.mcp_client = None self.tool_executor = None self.json_detector = None self.mcp_prompt = "" if use_mcpp and enabled_servers: # 1. Initialize ServerRegistry self.mcp_server_registery = ServerRegistry() logger.info("ServerRegistry initialized or referenced.") # 2. Use ToolAdapter to get the MCP prompt and tools if not self.tool_adapter: logger.error( "ToolAdapter not initialized before calling _init_mcp_components." ) self.mcp_prompt = "[Error: ToolAdapter not initialized]" return # Exit if ToolAdapter is mandatory and not initialized try: ( mcp_prompt_string, openai_tools, claude_tools, ) = await self.tool_adapter.get_tools(enabled_servers) # Store the generated prompt string self.mcp_prompt = mcp_prompt_string logger.info( f"Dynamically generated MCP prompt string (length: {len(self.mcp_prompt)})." ) logger.info( f"Dynamically formatted tools - OpenAI: {len(openai_tools)}, Claude: {len(claude_tools)}." ) # 3. Initialize ToolManager with the fetched formatted tools _, raw_tools_dict = await self.tool_adapter.get_server_and_tool_info( enabled_servers ) self.tool_manager = ToolManager( formatted_tools_openai=openai_tools, formatted_tools_claude=claude_tools, initial_tools_dict=raw_tools_dict, ) logger.info("ToolManager initialized with dynamically fetched tools.") except Exception as e: logger.error( f"Failed during dynamic MCP tool construction: {e}", exc_info=True ) # Ensure dependent components are not created if construction fails self.tool_manager = None self.mcp_prompt = "[Error constructing MCP tools/prompt]" # 4. Initialize MCPClient if self.mcp_server_registery: self.mcp_client = MCPClient( self.mcp_server_registery, self.send_text, self.client_uid ) logger.info("MCPClient initialized for this session.") else: logger.error( "MCP enabled but ServerRegistry not available. MCPClient not created." ) self.mcp_client = None # Ensure it's None # 5. Initialize ToolExecutor if self.mcp_client and self.tool_manager: self.tool_executor = ToolExecutor(self.mcp_client, self.tool_manager) logger.info("ToolExecutor initialized for this session.") else: logger.warning( "MCPClient or ToolManager not available. ToolExecutor not created." ) self.tool_executor = None # Ensure it's None logger.info("StreamJSONDetector initialized for this session.") elif use_mcpp and not enabled_servers: logger.warning( "use_mcpp is True, but mcp_enabled_servers list is empty. MCP components not initialized." ) else: logger.debug( "MCP components not initialized (use_mcpp is False or no enabled servers)." ) async def close(self): """Clean up resources, especially the MCPClient.""" logger.info("Closing ServiceContext resources...") if self.mcp_client: logger.info(f"Closing MCPClient for context instance {id(self)}...") await self.mcp_client.aclose() self.mcp_client = None if self.agent_engine and hasattr(self.agent_engine, "close"): await self.agent_engine.close() # Ensure agent resources are also closed logger.info("ServiceContext closed.") async def load_cache( self, config: Config, system_config: SystemConfig, character_config: CharacterConfig, live2d_model: Live2dModel, asr_engine: ASRInterface, tts_engine: TTSInterface, vad_engine: VADInterface, agent_engine: AgentInterface, translate_engine: TranslateInterface | None, mcp_server_registery: ServerRegistry | None = None, tool_adapter: ToolAdapter | None = None, send_text: Callable = None, client_uid: str = None, ) -> None: """ Load the ServiceContext with the reference of the provided instances. Pass by reference so no reinitialization will be done. """ if not character_config: raise ValueError("character_config cannot be None") if not system_config: raise ValueError("system_config cannot be None") self.config = config self.system_config = system_config self.character_config = character_config self.live2d_model = live2d_model self.asr_engine = asr_engine self.tts_engine = tts_engine self.vad_engine = vad_engine self.agent_engine = agent_engine self.translate_engine = translate_engine # Load potentially shared components by reference self.mcp_server_registery = mcp_server_registery self.tool_adapter = tool_adapter self.send_text = send_text self.client_uid = client_uid # Initialize session-specific MCP components await self._init_mcp_components( self.character_config.agent_config.agent_settings.basic_memory_agent.use_mcpp, self.character_config.agent_config.agent_settings.basic_memory_agent.mcp_enabled_servers, ) logger.debug(f"Loaded service context with cache: {character_config}") async def load_from_config(self, config: Config) -> None: """ Load the ServiceContext with the config. Reinitialize the instances if the config is different. Parameters: - config (Dict): The configuration dictionary. """ if not self.config: self.config = config if not self.system_config: self.system_config = config.system_config if not self.character_config: self.character_config = config.character_config # update all sub-configs # init live2d from character config self.init_live2d(config.character_config.live2d_model_name) # init asr from character config self.init_asr(config.character_config.asr_config) # init tts from character config self.init_tts(config.character_config.tts_config) # init vad from character config self.init_vad(config.character_config.vad_config) # Initialize shared ToolAdapter if it doesn't exist yet if ( not self.tool_adapter and config.character_config.agent_config.agent_settings.basic_memory_agent.use_mcpp ): if not self.mcp_server_registery: logger.info( "Initializing shared ServerRegistry within load_from_config." ) self.mcp_server_registery = ServerRegistry() logger.info("Initializing shared ToolAdapter within load_from_config.") self.tool_adapter = ToolAdapter(server_registery=self.mcp_server_registery) # Initialize MCP Components before initializing Agent await self._init_mcp_components( config.character_config.agent_config.agent_settings.basic_memory_agent.use_mcpp, config.character_config.agent_config.agent_settings.basic_memory_agent.mcp_enabled_servers, ) # init agent from character config await self.init_agent( config.character_config.agent_config, config.character_config.persona_prompt, ) self.init_translate( config.character_config.tts_preprocessor_config.translator_config ) # store typed config references self.config = config self.system_config = config.system_config or self.system_config self.character_config = config.character_config def init_live2d(self, live2d_model_name: str) -> None: logger.info(f"Initializing Live2D: {live2d_model_name}") try: self.live2d_model = Live2dModel(live2d_model_name) self.character_config.live2d_model_name = live2d_model_name except Exception as e: logger.critical(f"Error initializing Live2D: {e}") logger.critical("Try to proceed without Live2D...") def init_asr(self, asr_config: ASRConfig) -> None: if not self.asr_engine or (self.character_config.asr_config != asr_config): logger.info(f"Initializing ASR: {asr_config.asr_model}") self.asr_engine = ASRFactory.get_asr_system( asr_config.asr_model, **getattr(asr_config, asr_config.asr_model).model_dump(), ) # saving config should be done after successful initialization self.character_config.asr_config = asr_config else: logger.info("ASR already initialized with the same config.") def init_tts(self, tts_config: TTSConfig) -> None: if not self.tts_engine or (self.character_config.tts_config != tts_config): logger.info(f"Initializing TTS: {tts_config.tts_model}") self.tts_engine = TTSFactory.get_tts_engine( tts_config.tts_model, **getattr(tts_config, tts_config.tts_model.lower()).model_dump(), ) # saving config should be done after successful initialization self.character_config.tts_config = tts_config else: logger.info("TTS already initialized with the same config.") def init_vad(self, vad_config: VADConfig) -> None: if vad_config.vad_model is None: logger.info("VAD is disabled.") self.vad_engine = None return if not self.vad_engine or (self.character_config.vad_config != vad_config): logger.info(f"Initializing VAD: {vad_config.vad_model}") self.vad_engine = VADFactory.get_vad_engine( vad_config.vad_model, **getattr(vad_config, vad_config.vad_model.lower()).model_dump(), ) # saving config should be done after successful initialization self.character_config.vad_config = vad_config else: logger.info("VAD already initialized with the same config.") async def init_agent(self, agent_config: AgentConfig, persona_prompt: str) -> None: """Initialize or update the LLM engine based on agent configuration.""" logger.info(f"Initializing Agent: {agent_config.conversation_agent_choice}") if ( self.agent_engine is not None and agent_config == self.character_config.agent_config and persona_prompt == self.character_config.persona_prompt ): logger.debug("Agent already initialized with the same config.") return system_prompt = await self.construct_system_prompt(persona_prompt) # Pass avatar to agent factory avatar = self.character_config.avatar or "" # Get avatar from config try: self.agent_engine = AgentFactory.create_agent( conversation_agent_choice=agent_config.conversation_agent_choice, agent_settings=agent_config.agent_settings.model_dump(), llm_configs=agent_config.llm_configs.model_dump(), system_prompt=system_prompt, live2d_model=self.live2d_model, tts_preprocessor_config=self.character_config.tts_preprocessor_config, character_avatar=avatar, system_config=self.system_config.model_dump(), tool_manager=self.tool_manager, tool_executor=self.tool_executor, mcp_prompt_string=self.mcp_prompt, ) logger.debug(f"Agent choice: {agent_config.conversation_agent_choice}") logger.debug(f"System prompt: {system_prompt}") # Save the current configuration self.character_config.agent_config = agent_config self.system_prompt = system_prompt except Exception as e: logger.error(f"Failed to initialize agent: {e}") raise def init_translate(self, translator_config: TranslatorConfig) -> None: """Initialize or update the translation engine based on the configuration.""" if not translator_config.translate_audio: logger.debug("Translation is disabled.") return if ( not self.translate_engine or self.character_config.tts_preprocessor_config.translator_config != translator_config ): logger.info( f"Initializing Translator: {translator_config.translate_provider}" ) self.translate_engine = TranslateFactory.get_translator( translator_config.translate_provider, getattr( translator_config, translator_config.translate_provider ).model_dump(), ) self.character_config.tts_preprocessor_config.translator_config = ( translator_config ) else: logger.info("Translation already initialized with the same config.") # ==== utils async def construct_system_prompt(self, persona_prompt: str) -> str: """ Append tool prompts to persona prompt. Parameters: - persona_prompt (str): The persona prompt. Returns: - str: The system prompt with all tool prompts appended. """ logger.debug(f"constructing persona_prompt: '''{persona_prompt}'''") for prompt_name, prompt_file in self.system_config.tool_prompts.items(): if ( prompt_name == "group_conversation_prompt" or prompt_name == "proactive_speak_prompt" ): continue prompt_content = prompt_loader.load_util(prompt_file) if prompt_name == "live2d_expression_prompt": prompt_content = prompt_content.replace( "[]", self.live2d_model.emo_str ) if prompt_name == "mcp_prompt": continue persona_prompt += prompt_content logger.debug("\n === System Prompt ===") logger.debug(persona_prompt) return persona_prompt async def handle_config_switch( self, websocket: WebSocket, config_file_name: str, ) -> None: """ Handle the configuration switch request. Change the configuration to a new config and notify the client. Parameters: - websocket (WebSocket): The WebSocket connection. - config_file_name (str): The name of the configuration file. """ try: new_character_config_data = None if config_file_name == "conf.yaml": # Load base config new_character_config_data = read_yaml("conf.yaml").get( "character_config" ) else: # Load alternative config and merge with base config characters_dir = self.system_config.config_alts_dir file_path = os.path.normpath( os.path.join(characters_dir, config_file_name) ) if not file_path.startswith(characters_dir): raise ValueError("Invalid configuration file path") alt_config_data = read_yaml(file_path).get("character_config") # Start with original config data and perform a deep merge new_character_config_data = deep_merge( self.config.character_config.model_dump(), alt_config_data ) if new_character_config_data: new_config = { "system_config": self.system_config.model_dump(), "character_config": new_character_config_data, } new_config = validate_config(new_config) await self.load_from_config(new_config) # Await the async load logger.debug(f"New config: {self}") logger.debug( f"New character config: {self.character_config.model_dump()}" ) # Send responses to client await websocket.send_text( json.dumps( { "type": "set-model-and-conf", "model_info": self.live2d_model.model_info, "conf_name": self.character_config.conf_name, "conf_uid": self.character_config.conf_uid, } ) ) await websocket.send_text( json.dumps( { "type": "config-switched", "message": f"Switched to config: {config_file_name}", } ) ) logger.info(f"Configuration switched to {config_file_name}") else: raise ValueError( f"Failed to load configuration from {config_file_name}" ) except Exception as e: logger.error(f"Error switching configuration: {e}") logger.debug(self) await websocket.send_text( json.dumps( { "type": "error", "message": f"Error switching configuration: {str(e)}", } ) ) raise e def deep_merge(dict1, dict2): """ Recursively merges dict2 into dict1, prioritizing values from dict2. """ result = dict1.copy() for key, value in dict2.items(): if key in result and isinstance(result[key], dict) and isinstance(value, dict): result[key] = deep_merge(result[key], value) else: result[key] = value return result