# ConverTA/backend/api/conversation_ws.py
# Last commit cd93e5e by MikelWL: "Config panel MVP: persona selection + prompt additions"
"""WebSocket endpoint for real-time conversation streaming.
This module handles WebSocket connections for AI-to-AI conversations,
providing real-time message streaming between the frontend and backend.
Features:
- Real-time bidirectional messaging
- Connection management with heartbeat
- Message validation and formatting
- Error handling and graceful disconnections
- Conversation-specific channels
Usage:
Connect to: /ws/conversation/{conversation_id}
"""
import json
import logging
from typing import Dict, Set
from datetime import datetime
from uuid import uuid4
from fastapi import WebSocket, WebSocketDisconnect
from fastapi.websockets import WebSocketState
# Module-level logger used by every function and method in this module
# (connection lifecycle, message routing, and error reporting).
logger = logging.getLogger(__name__)
# Connection manager for active WebSocket connections
class ConnectionManager:
    """Manages active WebSocket connections for conversations.

    Sockets are grouped into per-conversation "rooms" kept in memory in
    ``active_connections``. A room is created on the first ``connect`` for a
    conversation ID. It is removed as soon as its last socket is dropped,
    either through ``disconnect`` or because sending to that socket failed.
    """

    def __init__(self):
        """Initialize the manager with no active rooms."""
        # conversation_id -> sockets currently joined to that conversation
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    def _discard(self, websocket: WebSocket, conversation_id: str) -> None:
        """Drop ``websocket`` from its room and delete the room once it is empty."""
        room = self.active_connections.get(conversation_id)
        if room is None:
            return
        room.discard(websocket)
        if not room:
            del self.active_connections[conversation_id]

    async def connect(self, websocket: WebSocket, conversation_id: str):
        """Accept and register a new WebSocket connection.

        After the handshake, the socket joins the conversation's room and is
        sent a ``connection_status`` confirmation. If that confirmation cannot
        be sent, the socket is unregistered before the error propagates, so no
        dead socket is left behind in the room.

        Args:
            websocket: The WebSocket connection
            conversation_id: ID of the conversation to join
        """
        await websocket.accept()
        self.active_connections.setdefault(conversation_id, set()).add(websocket)
        logger.info("WebSocket connected to conversation %s", conversation_id)
        try:
            # Send connection confirmation
            await websocket.send_json({
                "type": "connection_status",
                "status": "connected",
                "conversation_id": conversation_id,
                "timestamp": datetime.now().isoformat(),
                "message": "Connected to conversation",
            })
        except Exception:
            self.disconnect(websocket, conversation_id)
            raise

    def disconnect(self, websocket: WebSocket, conversation_id: str):
        """Remove a WebSocket connection.

        Safe to call repeatedly or for a socket that was never registered.
        Deletes the conversation's room once it has no sockets left.

        Args:
            websocket: The WebSocket connection to remove
            conversation_id: ID of the conversation
        """
        self._discard(websocket, conversation_id)
        logger.info("WebSocket disconnected from conversation %s", conversation_id)

    async def send_to_conversation(self, conversation_id: str, message: dict):
        """Send message to all connections in a conversation.

        A socket is removed from the room if its ``client_state`` is not
        CONNECTED or if sending to it raises. The room itself is deleted if
        that leaves it empty. Send errors are logged, not raised.

        Args:
            conversation_id: Target conversation ID
            message: Message dict to send
        """
        room = self.active_connections.get(conversation_id)
        if room is None:
            logger.warning("No active connections found for conversation %s", conversation_id)
            return

        # Snapshot the room: it can change while each send is awaited.
        targets = list(room)
        logger.info(
            "Sending message to %d connections for conversation %s",
            len(targets), conversation_id,
        )
        dead = []
        for websocket in targets:
            try:
                if websocket.client_state == WebSocketState.CONNECTED:
                    await websocket.send_json(message)
                else:
                    dead.append(websocket)
            except Exception as e:
                logger.error("Error sending message to WebSocket: %s", e)
                dead.append(websocket)

        # A concurrent disconnect() may have deleted the room during the awaits
        # above, so cleanup must not assume the room still exists.
        for websocket in dead:
            self._discard(websocket, conversation_id)
        logger.info(
            "Message sent successfully to %d connections, %d disconnected",
            len(targets) - len(dead), len(dead),
        )

    async def broadcast_to_all(self, message: dict):
        """Broadcast message to all active connections.

        Args:
            message: Message dict to broadcast
        """
        # Iterate over a snapshot of the room IDs. Each send awaits, and a
        # concurrent connect()/disconnect() may add or delete rooms meanwhile,
        # which would otherwise raise "dictionary changed size during iteration".
        for conversation_id in list(self.active_connections):
            if conversation_id in self.active_connections:
                await self.send_to_conversation(conversation_id, message)
# Global connection manager instance. It is created at import time and shared
# by websocket_endpoint and the message handlers in this module, and it is
# listed in __all__ so other modules can use it too. Its registry is a plain
# in-memory dict. NOTE(review): with several server worker processes, each
# worker would only see its own sockets; confirm the deployment runs a single
# worker.
manager = ConnectionManager()
async def websocket_endpoint(websocket: WebSocket, conversation_id: str):
    """WebSocket endpoint for conversation streaming.

    The endpoint registers the socket with the shared ``manager``, then loops
    reading JSON messages. Each message is validated, stamped with a
    server-side ``timestamp`` and ``message_id``, and dispatched through
    ``handle_message``. A payload that is not valid JSON, or that fails
    validation, gets an ``error`` reply and the loop continues. The socket is
    always unregistered when the loop ends, however it ends.

    Args:
        websocket: The WebSocket connection
        conversation_id: Unique identifier for the conversation
    """
    try:
        await manager.connect(websocket, conversation_id)
        while True:
            # Receive message from client
            try:
                data = await websocket.receive_json()
            except json.JSONDecodeError:
                # Undecodable JSON is a bad payload, not a broken connection:
                # reject it below like any other invalid message.
                data = None

            # Validate message format
            if data is None or not validate_message(data):
                await websocket.send_json({
                    "type": "error",
                    "error": "Invalid message format",
                    "timestamp": datetime.now().isoformat()
                })
                continue

            # Add server-side metadata
            data["timestamp"] = datetime.now().isoformat()
            data["message_id"] = str(uuid4())

            logger.info("Received message in conversation %s: %s", conversation_id, data["type"])

            # Handle different message types
            await handle_message(data, conversation_id)
    except WebSocketDisconnect:
        # Normal client-initiated close; cleanup happens in ``finally``.
        pass
    except Exception:
        logger.exception(f"WebSocket error in conversation {conversation_id}")
    finally:
        manager.disconnect(websocket, conversation_id)
def validate_message(data: dict) -> bool:
    """Check whether an incoming WebSocket payload is acceptable.

    A payload is accepted only if it is a dict containing both a ``"type"``
    and a ``"content"`` key, and its ``"type"`` is one of the message kinds
    this module dispatches. A rejection caused by a missing field or an
    unknown type is logged as a warning; a non-dict payload is rejected
    silently.

    Args:
        data: Payload to validate; non-dict values are rejected.

    Returns:
        True if the payload passes every check, otherwise False.
    """
    if not isinstance(data, dict):
        return False

    # Check the fields in a fixed order so the first missing one is reported.
    missing = next((key for key in ("type", "content") if key not in data), None)
    if missing is not None:
        logger.warning(f"Missing required field: {missing}")
        return False

    known_types = (
        "conversation_message",
        "typing_indicator",
        "conversation_control",
        "heartbeat",
        "start_conversation",
    )
    message_type = data["type"]
    if message_type in known_types:
        return True

    logger.warning(f"Invalid message type: {message_type}")
    return False
async def handle_message(data: dict, conversation_id: str):
    """Route a validated client message to the handler for its ``type``.

    Routing by ``type``:

    * ``conversation_message`` and ``typing_indicator`` payloads are relayed
      unchanged to every socket in the conversation.
    * ``conversation_control`` and ``start_conversation`` are delegated to
      their dedicated handlers.
    * ``heartbeat`` is answered with a ``heartbeat_response`` sent to the
      whole conversation.
    * Any other type is ignored.

    Args:
        data: Validated message data (must contain a ``"type"`` key)
        conversation_id: Target conversation ID
    """
    kind = data["type"]

    # Plain relays: echo the payload to all participants.
    if kind in ("conversation_message", "typing_indicator"):
        await manager.send_to_conversation(conversation_id, data)
        return

    if kind == "conversation_control":
        await handle_conversation_control(data, conversation_id)
        return

    if kind == "start_conversation":
        await handle_start_conversation(data, conversation_id)
        return

    if kind == "heartbeat":
        # NOTE: the reply is sent to every socket in the conversation, not
        # only to the sender of the heartbeat.
        reply = {
            "type": "heartbeat_response",
            "timestamp": datetime.now().isoformat(),
        }
        await manager.send_to_conversation(conversation_id, reply)
async def handle_conversation_control(data: dict, conversation_id: str):
    """Handle conversation control messages.

    Behaviour depends on ``data["action"]``:

    * ``"stop"``: asks the conversation service to stop the conversation and
      broadcasts a ``conversation_control`` confirmation, or an ``error`` if
      the service reports failure.
    * ``"pause"`` / ``"resume"``: only broadcast to the conversation. The
      service is not consulted, since pause/resume are not fully implemented.
    * Anything else: answered with an ``error`` message.

    Unexpected exceptions are logged with their traceback and reported to the
    conversation as a ``Control error``.

    Args:
        data: Control message data
        conversation_id: Target conversation ID
    """
    control_action = data.get("action")
    try:
        if control_action == "stop":
            # Import here to avoid circular imports. Only "stop" needs the
            # service, so the other actions do not depend on it loading.
            from .conversation_service import get_conversation_service

            service = get_conversation_service()
            success = await service.stop_conversation(conversation_id)
            if success:
                await manager.send_to_conversation(conversation_id, {
                    "type": "conversation_control",
                    "action": "stop",
                    "conversation_id": conversation_id,
                    "timestamp": datetime.now().isoformat(),
                    "message": "Conversation stopped"
                })
            else:
                await manager.send_to_conversation(conversation_id, {
                    "type": "error",
                    "error": "Failed to stop conversation",
                    "timestamp": datetime.now().isoformat()
                })
        elif control_action in ["pause", "resume"]:
            # For now, just broadcast the action (pause/resume not fully implemented)
            await manager.send_to_conversation(conversation_id, {
                "type": "conversation_control",
                "action": control_action,
                "conversation_id": conversation_id,
                "timestamp": datetime.now().isoformat(),
                # "pause" -> "paused", "resume" -> "resumed"
                "message": f"Conversation {control_action}d"
            })
            logger.info("Conversation %s %sd", conversation_id, control_action)
        else:
            logger.warning("Unknown control action: %s", control_action)
            await manager.send_to_conversation(conversation_id, {
                "type": "error",
                "error": f"Unknown control action: {control_action}",
                "timestamp": datetime.now().isoformat()
            })
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error handling conversation control: %s", e)
        await manager.send_to_conversation(conversation_id, {
            "type": "error",
            "error": f"Control error: {str(e)}",
            "timestamp": datetime.now().isoformat()
        })
async def handle_start_conversation(data: dict, conversation_id: str):
    """Handle starting a new conversation via WebSocket.

    The handler requires ``surveyor_persona_id`` and ``patient_persona_id`` in
    ``data`` (both must be truthy) and checks them before the conversation
    service is touched. ``host``, ``model``, ``surveyor_prompt_addition`` and
    ``patient_prompt_addition`` are passed through as-is, or None if absent,
    to the service's ``start_conversation``. On success this handler only
    writes a log line. On failure, or on an unexpected exception, it
    broadcasts an ``error`` message to the conversation; exceptions are also
    logged with their traceback.

    Args:
        data: Start conversation message data
        conversation_id: Target conversation ID
    """
    try:
        surveyor_persona_id = data.get("surveyor_persona_id")
        patient_persona_id = data.get("patient_persona_id")

        # Reject incomplete requests up front, without loading the service.
        if not surveyor_persona_id or not patient_persona_id:
            await manager.send_to_conversation(conversation_id, {
                "type": "error",
                "error": "Missing required persona IDs",
                "timestamp": datetime.now().isoformat()
            })
            return

        # Import here to avoid circular imports
        from .conversation_service import get_conversation_service

        service = get_conversation_service()
        success = await service.start_conversation(
            conversation_id=conversation_id,
            surveyor_persona_id=surveyor_persona_id,
            patient_persona_id=patient_persona_id,
            host=data.get("host"),
            model=data.get("model"),
            surveyor_prompt_addition=data.get("surveyor_prompt_addition"),
            patient_prompt_addition=data.get("patient_prompt_addition"),
        )
        if success:
            logger.info("Started conversation %s via WebSocket", conversation_id)
        else:
            await manager.send_to_conversation(conversation_id, {
                "type": "error",
                "error": "Failed to start conversation",
                "timestamp": datetime.now().isoformat()
            })
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Error starting conversation via WebSocket: %s", e)
        await manager.send_to_conversation(conversation_id, {
            "type": "error",
            "error": f"Start error: {str(e)}",
            "timestamp": datetime.now().isoformat()
        })
# Public API of this module, as seen by `from ... import *`. It contains the
# endpoint coroutine (presumably mounted elsewhere on the
# /ws/conversation/{conversation_id} route named in the module docstring) and
# the shared connection manager, so other modules can push messages into
# conversations. Helpers such as validate_message and handle_message are not
# listed.
__all__ = ["websocket_endpoint", "manager"]