"""WebSocket endpoint for real-time conversation streaming. This module handles WebSocket connections for AI-to-AI conversations, providing real-time message streaming between the frontend and backend. Features: - Real-time bidirectional messaging - Connection management with heartbeat - Message validation and formatting - Error handling and graceful disconnections - Conversation-specific channels Usage: Connect to: /ws/conversation/{conversation_id} """ import json import logging from typing import Dict, Set from datetime import datetime from uuid import uuid4 from fastapi import WebSocket, WebSocketDisconnect from fastapi.websockets import WebSocketState # Setup logging logger = logging.getLogger(__name__) # Connection manager for active WebSocket connections class ConnectionManager: """Manages active WebSocket connections for conversations.""" def __init__(self): """Initialize connection manager.""" self.active_connections: Dict[str, Set[WebSocket]] = {} async def connect(self, websocket: WebSocket, conversation_id: str): """Accept and register a new WebSocket connection. Args: websocket: The WebSocket connection conversation_id: ID of the conversation to join """ await websocket.accept() if conversation_id not in self.active_connections: self.active_connections[conversation_id] = set() self.active_connections[conversation_id].add(websocket) logger.info(f"WebSocket connected to conversation {conversation_id}") # Send connection confirmation await websocket.send_json({ "type": "connection_status", "status": "connected", "conversation_id": conversation_id, "timestamp": datetime.now().isoformat(), "message": "Connected to conversation" }) def disconnect(self, websocket: WebSocket, conversation_id: str): """Remove a WebSocket connection. Args: websocket: The WebSocket connection to remove conversation_id: ID of the conversation """ if conversation_id in self.active_connections: self.active_connections[conversation_id].discard(websocket) # Clean up empty conversation rooms if not self.active_connections[conversation_id]: del self.active_connections[conversation_id] logger.info(f"WebSocket disconnected from conversation {conversation_id}") async def send_to_conversation(self, conversation_id: str, message: dict): """Send message to all connections in a conversation. Args: conversation_id: Target conversation ID message: Message dict to send """ if conversation_id in self.active_connections: connections_copy = list(self.active_connections[conversation_id]) logger.info(f"Sending message to {len(connections_copy)} connections for conversation {conversation_id}") disconnected = [] for websocket in connections_copy: try: if websocket.client_state == WebSocketState.CONNECTED: await websocket.send_json(message) else: disconnected.append(websocket) except Exception as e: logger.error(f"Error sending message to WebSocket: {e}") disconnected.append(websocket) # Clean up disconnected sockets for websocket in disconnected: self.active_connections[conversation_id].discard(websocket) logger.info(f"Message sent successfully to {len(connections_copy) - len(disconnected)} connections, {len(disconnected)} disconnected") else: logger.warning(f"No active connections found for conversation {conversation_id}") async def broadcast_to_all(self, message: dict): """Broadcast message to all active connections. Args: message: Message dict to broadcast """ for conversation_id in self.active_connections: await self.send_to_conversation(conversation_id, message) # Global connection manager instance manager = ConnectionManager() async def websocket_endpoint(websocket: WebSocket, conversation_id: str): """WebSocket endpoint for conversation streaming. Args: websocket: The WebSocket connection conversation_id: Unique identifier for the conversation """ await manager.connect(websocket, conversation_id) try: while True: # Receive message from client data = await websocket.receive_json() # Validate message format if not validate_message(data): await websocket.send_json({ "type": "error", "error": "Invalid message format", "timestamp": datetime.now().isoformat() }) continue # Add server-side metadata data["timestamp"] = datetime.now().isoformat() data["message_id"] = str(uuid4()) # Log the message logger.info(f"Received message in conversation {conversation_id}: {data['type']}") # Handle different message types await handle_message(data, conversation_id) except WebSocketDisconnect: manager.disconnect(websocket, conversation_id) except Exception as e: logger.error(f"WebSocket error in conversation {conversation_id}: {e}") manager.disconnect(websocket, conversation_id) def validate_message(data: dict) -> bool: """Validate incoming WebSocket message format. Args: data: Message data to validate Returns: True if message is valid """ required_fields = ["type", "content"] if not isinstance(data, dict): return False for field in required_fields: if field not in data: logger.warning(f"Missing required field: {field}") return False # Validate message types valid_types = [ "conversation_message", "typing_indicator", "conversation_control", "heartbeat", "start_conversation" # New message type ] if data["type"] not in valid_types: logger.warning(f"Invalid message type: {data['type']}") return False return True async def handle_message(data: dict, conversation_id: str): """Handle different types of WebSocket messages. Args: data: Validated message data conversation_id: Target conversation ID """ message_type = data["type"] if message_type == "conversation_message": # Forward conversation messages to all participants await manager.send_to_conversation(conversation_id, data) elif message_type == "typing_indicator": # Forward typing indicators await manager.send_to_conversation(conversation_id, data) elif message_type == "conversation_control": # Handle conversation control (start, pause, stop) await handle_conversation_control(data, conversation_id) elif message_type == "start_conversation": # Handle starting a new conversation await handle_start_conversation(data, conversation_id) elif message_type == "heartbeat": # Respond to heartbeat await manager.send_to_conversation(conversation_id, { "type": "heartbeat_response", "timestamp": datetime.now().isoformat() }) async def handle_conversation_control(data: dict, conversation_id: str): """Handle conversation control messages. Args: data: Control message data conversation_id: Target conversation ID """ control_action = data.get("action") try: # Import here to avoid circular imports from .conversation_service import get_conversation_service service = get_conversation_service() if control_action == "stop": success = await service.stop_conversation(conversation_id) if success: await manager.send_to_conversation(conversation_id, { "type": "conversation_control", "action": "stop", "conversation_id": conversation_id, "timestamp": datetime.now().isoformat(), "message": "Conversation stopped" }) else: await manager.send_to_conversation(conversation_id, { "type": "error", "error": "Failed to stop conversation", "timestamp": datetime.now().isoformat() }) elif control_action in ["pause", "resume"]: # For now, just broadcast the action (pause/resume not fully implemented) await manager.send_to_conversation(conversation_id, { "type": "conversation_control", "action": control_action, "conversation_id": conversation_id, "timestamp": datetime.now().isoformat(), "message": f"Conversation {control_action}d" }) logger.info(f"Conversation {conversation_id} {control_action}d") else: logger.warning(f"Unknown control action: {control_action}") await manager.send_to_conversation(conversation_id, { "type": "error", "error": f"Unknown control action: {control_action}", "timestamp": datetime.now().isoformat() }) except Exception as e: logger.error(f"Error handling conversation control: {e}") await manager.send_to_conversation(conversation_id, { "type": "error", "error": f"Control error: {str(e)}", "timestamp": datetime.now().isoformat() }) async def handle_start_conversation(data: dict, conversation_id: str): """Handle starting a new conversation via WebSocket. Args: data: Start conversation message data conversation_id: Target conversation ID """ try: # Import here to avoid circular imports from .conversation_service import get_conversation_service service = get_conversation_service() # Extract required fields surveyor_persona_id = data.get("surveyor_persona_id") patient_persona_id = data.get("patient_persona_id") host = data.get("host") model = data.get("model") surveyor_prompt_addition = data.get("surveyor_prompt_addition") patient_prompt_addition = data.get("patient_prompt_addition") if not surveyor_persona_id or not patient_persona_id: await manager.send_to_conversation(conversation_id, { "type": "error", "error": "Missing required persona IDs", "timestamp": datetime.now().isoformat() }) return # Start the conversation success = await service.start_conversation( conversation_id=conversation_id, surveyor_persona_id=surveyor_persona_id, patient_persona_id=patient_persona_id, host=host, model=model, surveyor_prompt_addition=surveyor_prompt_addition, patient_prompt_addition=patient_prompt_addition, ) if success: logger.info(f"Started conversation {conversation_id} via WebSocket") else: await manager.send_to_conversation(conversation_id, { "type": "error", "error": "Failed to start conversation", "timestamp": datetime.now().isoformat() }) except Exception as e: logger.error(f"Error starting conversation via WebSocket: {e}") await manager.send_to_conversation(conversation_id, { "type": "error", "error": f"Start error: {str(e)}", "timestamp": datetime.now().isoformat() }) # Export the manager for use in other modules __all__ = ["websocket_endpoint", "manager"]