|
|
"""WebSocket endpoint for real-time conversation streaming. |
|
|
|
|
|
This module handles WebSocket connections for AI-to-AI conversations, |
|
|
providing real-time message streaming between the frontend and backend. |
|
|
|
|
|
Features: |
|
|
- Real-time bidirectional messaging |
|
|
- Connection management with heartbeat |
|
|
- Message validation and formatting |
|
|
- Error handling and graceful disconnections |
|
|
- Conversation-specific channels |
|
|
|
|
|
Usage: |
|
|
Connect to: /ws/conversation/{conversation_id} |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
from typing import Dict, Set |
|
|
from datetime import datetime |
|
|
from uuid import uuid4 |
|
|
|
|
|
from fastapi import WebSocket, WebSocketDisconnect |
|
|
from fastapi.websockets import WebSocketState |
|
|
|
|
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class ConnectionManager:
    """Manages active WebSocket connections for conversations.

    Sockets are grouped by conversation ID so a single message can be fanned
    out to every client subscribed to that conversation. A group is removed
    as soon as it becomes empty, so ``active_connections`` only holds
    conversations that currently have at least one registered socket.
    """

    def __init__(self):
        """Initialize connection manager with an empty registry."""
        # conversation_id -> sockets currently subscribed to that conversation.
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept and register a new WebSocket connection.

        Completes the handshake, adds the socket to the conversation's group
        and sends it a ``connection_status`` greeting.

        Args:
            websocket: The WebSocket connection
            conversation_id: ID of the conversation to join

        Raises:
            Re-raises any exception from sending the greeting, after first
            unregistering the socket so the registry never keeps a socket
            whose very first write failed.
        """
        await websocket.accept()

        self.active_connections.setdefault(conversation_id, set()).add(websocket)
        logger.info(f"WebSocket connected to conversation {conversation_id}")

        try:
            await websocket.send_json({
                "type": "connection_status",
                "status": "connected",
                "conversation_id": conversation_id,
                "timestamp": datetime.now().isoformat(),
                "message": "Connected to conversation",
            })
        except Exception:
            self.disconnect(websocket, conversation_id)
            raise

    def disconnect(self, websocket: WebSocket, conversation_id: str):
        """Remove a WebSocket connection.

        Safe to call more than once and for unknown conversations. When the
        last socket of a conversation is removed, the conversation entry is
        dropped.

        Args:
            websocket: The WebSocket connection to remove
            conversation_id: ID of the conversation
        """
        connections = self.active_connections.get(conversation_id)
        if connections is None:
            return

        connections.discard(websocket)
        if not connections:
            del self.active_connections[conversation_id]

        logger.info(f"WebSocket disconnected from conversation {conversation_id}")

    async def send_to_conversation(self, conversation_id: str, message: dict):
        """Send message to all connections in a conversation.

        Sockets that are no longer connected, or whose send raises, are
        pruned from the conversation afterwards. If pruning empties the
        conversation, its entry is removed (matching ``disconnect``).

        Args:
            conversation_id: Target conversation ID
            message: Message dict to send
        """
        connections = self.active_connections.get(conversation_id)
        if not connections:
            logger.warning(f"No active connections found for conversation {conversation_id}")
            return

        # Snapshot: the live set may change while we await each send below.
        connections_copy = list(connections)
        logger.info(f"Sending message to {len(connections_copy)} connections for conversation {conversation_id}")

        disconnected = []
        for websocket in connections_copy:
            try:
                if websocket.client_state == WebSocketState.CONNECTED:
                    await websocket.send_json(message)
                else:
                    disconnected.append(websocket)
            except Exception as e:
                logger.error(f"Error sending message to WebSocket: {e}")
                disconnected.append(websocket)

        # Re-read the registry: a concurrent disconnect() may have removed this
        # conversation while we were awaiting, so indexing directly could
        # raise KeyError.
        remaining = self.active_connections.get(conversation_id)
        if remaining is not None:
            remaining.difference_update(disconnected)
            if not remaining:
                del self.active_connections[conversation_id]

        logger.info(f"Message sent successfully to {len(connections_copy) - len(disconnected)} connections, {len(disconnected)} disconnected")

    async def broadcast_to_all(self, message: dict):
        """Broadcast message to all active connections.

        Args:
            message: Message dict to broadcast
        """
        # Iterate over a snapshot of the keys: connects/disconnects can mutate
        # the dict while we await, which would otherwise raise RuntimeError.
        for conversation_id in list(self.active_connections):
            await self.send_to_conversation(conversation_id, message)
|
|
|
|
|
|
|
|
|
|
|
# Single shared connection registry used by the endpoint and every message
# handler in this module.
manager: ConnectionManager = ConnectionManager()
|
|
|
|
|
|
|
|
async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
    """WebSocket endpoint for conversation streaming.

    Registers the socket with ``manager``, then loops reading JSON messages
    from the client. For each message, the loop:

    * replies with an ``error`` and keeps listening if the JSON is malformed
      or the message fails ``validate_message``;
    * otherwise stamps the message with a server-side ``timestamp`` and a
      fresh ``message_id`` and dispatches it via ``handle_message``.

    The socket is always unregistered from ``manager`` when the loop ends.
    On an unexpected server-side error the traceback is logged and the socket
    is closed with code 1011 (internal error).

    Args:
        websocket: The WebSocket connection
        conversation_id: Unique identifier for the conversation
    """
    try:
        await manager.connect(websocket, conversation_id)

        while True:
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError:
                # A malformed frame is a client mistake, not a reason to drop
                # the connection; report it and keep listening.
                await websocket.send_json({
                    "type": "error",
                    "error": "Invalid JSON",
                    "timestamp": datetime.now().isoformat(),
                })
                continue

            if not validate_message(data):
                await websocket.send_json({
                    "type": "error",
                    "error": "Invalid message format",
                    "timestamp": datetime.now().isoformat(),
                })
                continue

            # Server-side metadata overrides anything the client supplied.
            data["timestamp"] = datetime.now().isoformat()
            data["message_id"] = str(uuid4())

            logger.info(f"Received message in conversation {conversation_id}: {data['type']}")

            await handle_message(data, conversation_id)

    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in ``finally``.
        pass
    except Exception as e:
        logger.exception(f"WebSocket error in conversation {conversation_id}: {e}")
        # Best-effort close so the client sees an explicit internal-error
        # close frame instead of an abrupt drop.
        if websocket.application_state == WebSocketState.CONNECTED:
            try:
                await websocket.close(code=1011)
            except Exception:
                logger.debug("Failed to close WebSocket after error", exc_info=True)
    finally:
        manager.disconnect(websocket, conversation_id)
|
|
|
|
|
|
|
|
def validate_message(data: dict) -> bool:
    """Check that an incoming WebSocket payload has an acceptable shape.

    A payload is accepted when all of the following hold:

    * it is a dict;
    * it carries both a ``type`` and a ``content`` key;
    * its ``type`` is one of the message kinds this module dispatches.

    Args:
        data: Decoded JSON payload received from the client.

    Returns:
        ``True`` if every check passes, otherwise ``False``. A warning is
        logged for the first missing field or for an unknown type.
    """
    if not isinstance(data, dict):
        return False

    # Report only the first absent field, checking "type" before "content".
    absent = next((key for key in ("type", "content") if key not in data), None)
    if absent is not None:
        logger.warning(f"Missing required field: {absent}")
        return False

    accepted_types = (
        "conversation_message",
        "typing_indicator",
        "conversation_control",
        "heartbeat",
        "start_conversation",
    )

    kind = data["type"]
    if kind in accepted_types:
        return True

    logger.warning(f"Invalid message type: {kind}")
    return False
|
|
|
|
|
|
|
|
async def handle_message(data: dict, conversation_id: str):
    """Route a validated message to the handler for its ``type``.

    * ``conversation_message`` and ``typing_indicator`` are relayed unchanged
      to every connection in the conversation.
    * ``conversation_control`` and ``start_conversation`` are delegated to
      their dedicated handlers.
    * ``heartbeat`` is answered with a ``heartbeat_response`` sent to the
      whole conversation (not only the sender).
    * Any other type is silently ignored.

    Args:
        data: Validated message data
        conversation_id: Target conversation ID
    """
    kind = data["type"]

    if kind in ("conversation_message", "typing_indicator"):
        # Relay messages are echoed verbatim to all participants.
        await manager.send_to_conversation(conversation_id, data)
    elif kind == "conversation_control":
        await handle_conversation_control(data, conversation_id)
    elif kind == "start_conversation":
        await handle_start_conversation(data, conversation_id)
    elif kind == "heartbeat":
        reply = {
            "type": "heartbeat_response",
            "timestamp": datetime.now().isoformat(),
        }
        await manager.send_to_conversation(conversation_id, reply)
|
|
|
|
|
|
|
|
async def handle_conversation_control(data: dict, conversation_id: str):
    """Apply a ``conversation_control`` request and report the outcome.

    The action is read from ``data["action"]``:

    * ``"stop"`` calls ``stop_conversation`` on the conversation service and
      broadcasts either a stop confirmation or an error.
    * ``"pause"`` / ``"resume"`` only broadcast a confirmation and log it; no
      service call is made here. NOTE(review): confirm whether these are
      meant to reach the service.
    * Any other value broadcasts an "unknown control action" error.

    Any exception raised along the way, including failure to import or obtain
    the service, is logged and broadcast to the conversation as an error.

    Args:
        data: Control message data
        conversation_id: Target conversation ID
    """
    action = data.get("action")

    try:
        # Imported lazily, presumably to avoid a circular import (TODO: confirm).
        from .conversation_service import get_conversation_service
        conversation_service = get_conversation_service()

        if action == "stop":
            stopped = await conversation_service.stop_conversation(conversation_id)
            reply = (
                {
                    "type": "conversation_control",
                    "action": "stop",
                    "conversation_id": conversation_id,
                    "timestamp": datetime.now().isoformat(),
                    "message": "Conversation stopped",
                }
                if stopped
                else {
                    "type": "error",
                    "error": "Failed to stop conversation",
                    "timestamp": datetime.now().isoformat(),
                }
            )
            await manager.send_to_conversation(conversation_id, reply)

        elif action in ("pause", "resume"):
            # "pause" -> "paused", "resume" -> "resumed".
            await manager.send_to_conversation(conversation_id, {
                "type": "conversation_control",
                "action": action,
                "conversation_id": conversation_id,
                "timestamp": datetime.now().isoformat(),
                "message": f"Conversation {action}d",
            })
            logger.info(f"Conversation {conversation_id} {action}d")

        else:
            logger.warning(f"Unknown control action: {action}")
            await manager.send_to_conversation(conversation_id, {
                "type": "error",
                "error": f"Unknown control action: {action}",
                "timestamp": datetime.now().isoformat(),
            })

    except Exception as e:
        logger.error(f"Error handling conversation control: {e}")
        await manager.send_to_conversation(conversation_id, {
            "type": "error",
            "error": f"Control error: {str(e)}",
            "timestamp": datetime.now().isoformat(),
        })
|
|
|
|
|
|
|
|
async def handle_start_conversation(data: dict, conversation_id: str):
    """Start a conversation in response to a ``start_conversation`` message.

    The following fields are read from *data*:

    * ``surveyor_persona_id`` and ``patient_persona_id`` (both required);
    * ``host``, ``model``, ``surveyor_prompt_addition`` and
      ``patient_prompt_addition`` (optional).

    They are forwarded as keyword arguments to the conversation service's
    ``start_conversation``. Both persona IDs must be truthy; otherwise an
    error is broadcast and the service is not called.

    On success only a log line is written here. If the service reports
    failure, or anything raises, an ``error`` message is broadcast to the
    conversation.

    Args:
        data: Start conversation message data
        conversation_id: Target conversation ID
    """
    try:
        # Imported lazily, presumably to avoid a circular import (TODO: confirm).
        from .conversation_service import get_conversation_service
        conversation_service = get_conversation_service()

        # Same keys, same order as the keyword arguments passed below.
        options = {
            key: data.get(key)
            for key in (
                "surveyor_persona_id",
                "patient_persona_id",
                "host",
                "model",
                "surveyor_prompt_addition",
                "patient_prompt_addition",
            )
        }

        if not (options["surveyor_persona_id"] and options["patient_persona_id"]):
            await manager.send_to_conversation(conversation_id, {
                "type": "error",
                "error": "Missing required persona IDs",
                "timestamp": datetime.now().isoformat(),
            })
            return

        started = await conversation_service.start_conversation(
            conversation_id=conversation_id,
            **options,
        )

        if not started:
            await manager.send_to_conversation(conversation_id, {
                "type": "error",
                "error": "Failed to start conversation",
                "timestamp": datetime.now().isoformat(),
            })
        else:
            logger.info(f"Started conversation {conversation_id} via WebSocket")

    except Exception as e:
        logger.error(f"Error starting conversation via WebSocket: {e}")
        await manager.send_to_conversation(conversation_id, {
            "type": "error",
            "error": f"Start error: {str(e)}",
            "timestamp": datetime.now().isoformat(),
        })
|
|
|
|
|
|
|
|
|
|
|
# Public API: the endpoint coroutine and the shared ConnectionManager instance.
__all__ = [
    "websocket_endpoint",
    "manager",
]
|
|
|