"""Handlers that start conversations and interrupt running ones."""

import asyncio
import json
from typing import Dict, Optional, Callable

import numpy as np
from fastapi import WebSocket
from loguru import logger

from ..chat_group import ChatGroupManager
from ..chat_history_manager import store_message
from ..service_context import ServiceContext
from .group_conversation import process_group_conversation
from .single_conversation import process_single_conversation
from .conversation_utils import EMOJI_LIST
from .types import GroupConversationState
from prompts import prompt_loader


async def handle_conversation_trigger(
    msg_type: str,
    data: dict,
    client_uid: str,
    context: ServiceContext,
    websocket: WebSocket,
    client_contexts: Dict[str, ServiceContext],
    client_connections: Dict[str, WebSocket],
    chat_group_manager: ChatGroupManager,
    received_data_buffers: Dict[str, np.ndarray],
    current_conversation_tasks: Dict[str, Optional[asyncio.Task]],
    broadcast_to_group: Callable,
) -> None:
    """Handle triggers that start a conversation.

    The user input is derived from ``msg_type``:

    * ``"ai-speak-signal"``: a prompt loaded through ``prompt_loader`` (or a
      built-in fallback sentence), tagged with metadata asking downstream
      code to skip memory and history storage.
    * ``"text-input"``: the ``"text"`` entry of ``data``.
    * anything else (mic-audio-end): the buffered data stored for
      ``client_uid`` in ``received_data_buffers``; the buffer is then reset.

    A conversation task is then scheduled with ``asyncio.create_task`` and
    recorded in ``current_conversation_tasks`` -- keyed by the group id when
    the client is in a group with more than one member, otherwise by
    ``client_uid``.

    Args:
        msg_type: Type of the incoming trigger message.
        data: Payload of the incoming message (``"text"`` / ``"images"``).
        client_uid: Identifier of the client that sent the trigger.
        context: Service context of the triggering client.
        websocket: Connection of the triggering client.
        client_contexts: Service contexts keyed by client uid.
        client_connections: WebSocket connections keyed by client uid.
        chat_group_manager: Used to look up the client's group.
        received_data_buffers: Per-client input buffers.
        current_conversation_tasks: Registry of conversation tasks; mutated.
        broadcast_to_group: Callable forwarded to the group conversation.
    """
    metadata = None

    if msg_type == "ai-speak-signal":
        try:
            # Get proactive speak prompt from config
            prompt_name = "proactive_speak_prompt"
            prompt_file = context.system_config.tool_prompts.get(prompt_name)
            if prompt_file:
                user_input = prompt_loader.load_util(prompt_file)
            else:
                logger.warning("Proactive speak prompt not configured, using default")
                user_input = "Please say something."
        except Exception as e:
            # Best effort: any failure while loading falls back to the default.
            logger.error(f"Error loading proactive speak prompt: {e}")
            user_input = "Please say something."

        # Add metadata to indicate this is a proactive speak request
        # that should be skipped in both memory and history
        metadata = {
            "proactive_speak": True,
            "skip_memory": True,  # Skip storing in AI's internal memory
            "skip_history": True,  # Skip storing in local conversation history
        }

        await websocket.send_text(
            json.dumps(
                {
                    "type": "full-text",
                    "text": "AI wants to speak something...",
                }
            )
        )
    elif msg_type == "text-input":
        user_input = data.get("text", "")
    else:  # mic-audio-end
        user_input = received_data_buffers[client_uid]
        # Reset the buffer so the next utterance starts empty.
        received_data_buffers[client_uid] = np.array([])

    images = data.get("images")
    session_emoji = np.random.choice(EMOJI_LIST)

    group = chat_group_manager.get_client_group(client_uid)
    if group and len(group.members) > 1:
        # Use group_id as task key for group conversations
        task_key = group.group_id
        # The registry may hold None for a key (values are Optional), so
        # look the task up safely instead of calling .done() on None.
        existing_task = current_conversation_tasks.get(task_key)
        if existing_task is None or existing_task.done():
            logger.info(f"Starting new group conversation for {task_key}")

            current_conversation_tasks[task_key] = asyncio.create_task(
                process_group_conversation(
                    client_contexts=client_contexts,
                    client_connections=client_connections,
                    broadcast_func=broadcast_to_group,
                    group_members=group.members,
                    initiator_client_uid=client_uid,
                    user_input=user_input,
                    images=images,
                    session_emoji=session_emoji,
                    metadata=metadata,
                )
            )
    else:
        # Use client_uid as task key for individual conversations
        current_conversation_tasks[client_uid] = asyncio.create_task(
            process_single_conversation(
                context=context,
                websocket_send=websocket.send_text,
                client_uid=client_uid,
                user_input=user_input,
                images=images,
                session_emoji=session_emoji,
                metadata=metadata,
            )
        )


async def handle_individual_interrupt(
    client_uid: str,
    current_conversation_tasks: Dict[str, Optional[asyncio.Task]],
    context: ServiceContext,
    heard_response: str,
) -> None:
    """Handle interruption of an individual (non-group) conversation.

    Does nothing unless ``client_uid`` has an entry in
    ``current_conversation_tasks``. Otherwise the task is cancelled if it is
    still running, the agent engine is told what was heard, and -- when the
    context has a ``history_uid`` -- the heard text plus an interruption
    marker are written to the chat history.

    Args:
        client_uid: Identifier of the interrupting client.
        current_conversation_tasks: Registry of conversation tasks.
        context: Service context of the interrupting client.
        heard_response: The part of the AI response heard before the
            interruption.
    """
    if client_uid not in current_conversation_tasks:
        return

    task = current_conversation_tasks[client_uid]
    if task and not task.done():
        task.cancel()
        logger.info("🛑 Conversation task was successfully interrupted")

    try:
        context.agent_engine.handle_interrupt(heard_response)
    except Exception as e:
        # A failing agent must not prevent the history from being updated.
        logger.error(f"Error handling interrupt: {e}")

    if context.history_uid:
        store_message(
            conf_uid=context.character_config.conf_uid,
            history_uid=context.history_uid,
            role="ai",
            content=heard_response,
            name=context.character_config.character_name,
            avatar=context.character_config.avatar,
        )
        store_message(
            conf_uid=context.character_config.conf_uid,
            history_uid=context.history_uid,
            role="system",
            content="[Interrupted by user]",
        )


async def handle_group_interrupt(
    group_id: str,
    heard_response: str,
    current_conversation_tasks: Dict[str, Optional[asyncio.Task]],
    chat_group_manager: ChatGroupManager,
    client_contexts: Dict[str, ServiceContext],
    broadcast_to_group: Callable,
) -> None:
    """Handles interruption for a group conversation.

    Returns immediately when no running task is registered for ``group_id``.
    Otherwise the task is cancelled and awaited, its registry entry and the
    group conversation state are removed, every member's agent engine is
    notified, the heard text and an interruption marker are stored in each
    member's history (attributed to the current speaker, or to the first
    group member when the speaker is unknown), and an ``interrupt-signal``
    message is broadcast to the group.

    Args:
        group_id: Identifier of the group whose conversation is interrupted.
        heard_response: The part of the AI response heard before the
            interruption.
        current_conversation_tasks: Registry of conversation tasks; mutated.
        chat_group_manager: Used to look up the group by id.
        client_contexts: Service contexts keyed by client uid.
        broadcast_to_group: Awaitable callable taking a member list and a
            message dict.
    """
    task = current_conversation_tasks.get(group_id)
    if not task or task.done():
        return

    # Get state and speaker info before cancellation
    state = GroupConversationState.get_state(group_id)
    current_speaker_uid = state.current_speaker_uid if state else None

    # Get context from current speaker
    context = None
    group = chat_group_manager.get_group_by_id(group_id)

    if current_speaker_uid:
        context = client_contexts.get(current_speaker_uid)
        if context:
            logger.info(f"Found current speaker context for {current_speaker_uid}")

    if not context and group and group.members:
        logger.warning(f"No context found for group {group_id}, using first member")
        context = client_contexts.get(next(iter(group.members)))

    # Now cancel the task
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        logger.info(f"🛑 Group conversation {group_id} cancelled successfully.")
    except Exception as e:
        # The task failed for another reason while shutting down; log it and
        # carry on so cleanup and the interrupt broadcast still happen.
        logger.error(
            f"Group conversation {group_id} raised while being cancelled: {e}"
        )

    current_conversation_tasks.pop(group_id, None)
    GroupConversationState.remove_state(group_id)  # Clean up state after we've used it

    if not group:
        # Without a group there are no members to update or notify.
        logger.warning(f"Group {group_id} not found, skipping interrupt broadcast")
        return

    # Store messages with speaker info
    if context:
        for member_uid in group.members:
            member_ctx = client_contexts.get(member_uid)
            if member_uid not in client_contexts:
                continue
            try:
                member_ctx.agent_engine.handle_interrupt(heard_response)
                # Same guard as the individual interrupt path: only write
                # history when the member actually has a history to write to.
                if member_ctx.history_uid:
                    store_message(
                        conf_uid=member_ctx.character_config.conf_uid,
                        history_uid=member_ctx.history_uid,
                        role="ai",
                        content=heard_response,
                        name=context.character_config.character_name,
                        avatar=context.character_config.avatar,
                    )
                    store_message(
                        conf_uid=member_ctx.character_config.conf_uid,
                        history_uid=member_ctx.history_uid,
                        role="system",
                        content="[Interrupted by user]",
                    )
            except Exception as e:
                # One member failing must not block the remaining members.
                logger.error(f"Error handling interrupt for {member_uid}: {e}")

    await broadcast_to_group(
        list(group.members),
        {
            "type": "interrupt-signal",
            "text": "conversation-interrupted",
        },
    )