# scratch_chat/chat_agent/websocket/chat_websocket.py
# Uploaded by WebashalarForML ("Upload 178 files"), commit 330b6e4 (verified).
"""
WebSocket handlers for real-time chat communication.
This module provides Flask-SocketIO event handlers for the multi-language chat agent,
including message processing, language switching, typing indicators, and connection management.
"""
import logging
import json
from datetime import datetime
from typing import Dict, Any, Optional
from uuid import uuid4
from flask import request, session
from flask_socketio import emit, join_room, leave_room, disconnect
from sqlalchemy.exc import SQLAlchemyError
from ..services.chat_agent import ChatAgent, ChatAgentError
from ..services.session_manager import SessionManager, SessionNotFoundError, SessionExpiredError
from .message_validator import MessageValidator
from .connection_manager import ConnectionManager
# Module-scoped logger; its name is this module's dotted import path.
logger: logging.Logger = logging.getLogger(__name__)
class ChatWebSocketHandler:
    """Handle Flask-SocketIO events for real-time chat communication.

    Each event handler identifies the caller by ``request.sid`` and resolves
    the chat session from the connection registered in ``ConnectionManager``
    at connect time; session ids supplied in later event payloads are not
    trusted.
    """

    def __init__(self, chat_agent: ChatAgent, session_manager: SessionManager,
                 connection_manager: ConnectionManager):
        """
        Initialize the WebSocket handler.

        Args:
            chat_agent: Chat agent service for message processing
            session_manager: Session manager for session handling
            connection_manager: Connection manager for WebSocket connections
        """
        self.chat_agent = chat_agent
        self.session_manager = session_manager
        self.connection_manager = connection_manager
        self.message_validator = MessageValidator()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _session_room(session_id: Any) -> str:
        """Return the Socket.IO room name used for all clients of a session."""
        return f"session_{session_id}"

    @staticmethod
    def _emit_internal_error() -> None:
        """Emit the generic INTERNAL_ERROR event (no exception details leaked)."""
        emit('error', {
            'message': 'Internal server error',
            'code': 'INTERNAL_ERROR'
        })

    def _get_connection_or_error(self, client_id: Any) -> Optional[Dict[str, Any]]:
        """Return the registered connection, or emit CONNECTION_NOT_FOUND and return None."""
        connection_info = self.connection_manager.get_connection(client_id)
        if not connection_info:
            emit('error', {'message': 'Connection not found', 'code': 'CONNECTION_NOT_FOUND'})
            return None
        return connection_info

    def _broadcast_to_session_peers(self, event: str) -> None:
        """Emit *event* to the caller's session room, excluding the caller.

        Silently does nothing when the caller has no registered connection.
        """
        client_id = request.sid
        connection_info = self.connection_manager.get_connection(client_id)
        if not connection_info:
            return
        session_id = connection_info['session_id']
        emit(event, {
            'session_id': session_id,
            'timestamp': datetime.utcnow().isoformat()
        }, room=self._session_room(session_id), include_self=False)

    # ------------------------------------------------------------------
    # Event handlers
    # ------------------------------------------------------------------

    def handle_connect(self, auth: Optional[Dict[str, Any]] = None) -> bool:
        """
        Handle WebSocket connection establishment.

        The connection is accepted only when *auth* names an existing,
        non-expired session owned by the given user. On acceptance the client
        is registered, joined to the session room, and sent a
        ``connection_status`` event.

        Args:
            auth: Authentication data containing session_id and user_id

        Returns:
            bool: True if connection is accepted, False otherwise
        """
        client_id = None
        registered = False
        try:
            client_id = request.sid
            user_agent = request.headers.get('User-Agent', 'Unknown')
            logger.info("WebSocket connection attempt from %s", client_id)

            # The auth payload comes straight from the client: require a dict
            # carrying both keys before subscripting it.
            if (not isinstance(auth, dict)
                    or 'session_id' not in auth or 'user_id' not in auth):
                logger.warning("Connection rejected: missing auth data for %s", client_id)
                return False

            session_id = auth['session_id']
            user_id = auth['user_id']

            # Validate session exists, is active, and belongs to this user.
            try:
                chat_session = self.session_manager.get_session(session_id)
            except (SessionNotFoundError, SessionExpiredError) as e:
                logger.warning("Connection rejected: %s", e)
                return False
            if chat_session.user_id != user_id:
                logger.warning("Connection rejected: user %s doesn't own session %s",
                               user_id, session_id)
                return False

            connection_info = {
                'client_id': client_id,
                'session_id': session_id,
                'user_id': user_id,
                'connected_at': datetime.utcnow().isoformat(),
                'user_agent': user_agent,
                'language': chat_session.language
            }
            self.connection_manager.add_connection(client_id, connection_info)
            registered = True

            # Join session room for targeted messaging.
            join_room(self._session_room(session_id))
            self.session_manager.update_session_activity(session_id)

            emit('connection_status', {
                'status': 'connected',
                'session_id': session_id,
                'language': chat_session.language,
                'message_count': chat_session.message_count,
                'timestamp': datetime.utcnow().isoformat()
            })
            logger.info("WebSocket connection established for session %s", session_id)
            return True
        except Exception as e:
            logger.exception("Error handling WebSocket connection: %s", e)
            if registered:
                # The connection is being rejected; do not leave a stale
                # registry entry behind for this client.
                try:
                    self.connection_manager.remove_connection(client_id)
                except Exception:
                    logger.exception("Failed to roll back connection for %s", client_id)
            return False

    def handle_disconnect(self) -> None:
        """Handle WebSocket disconnection: leave the room, unregister, touch the session."""
        try:
            client_id = request.sid
            connection_info = self.connection_manager.get_connection(client_id)
            if not connection_info:
                logger.warning("Disconnect from unknown client %s", client_id)
                return

            session_id = connection_info['session_id']
            try:
                leave_room(self._session_room(session_id))
            finally:
                # Always drop the registry entry, even if leaving the room failed.
                self.connection_manager.remove_connection(client_id)

            # Final activity update; the session may already have expired.
            try:
                self.session_manager.update_session_activity(session_id)
            except (SessionNotFoundError, SessionExpiredError):
                logger.debug("Session %s no longer active at disconnect", session_id)

            logger.info("WebSocket disconnected for session %s", session_id)
        except Exception as e:
            logger.exception("Error handling WebSocket disconnect: %s", e)

    def handle_message(self, data: Optional[Dict[str, Any]] = None) -> None:
        """
        Handle an incoming chat message and stream the agent's reply.

        The target session comes from the registered connection; any session
        id inside *data* is not used. Emits ``message_received`` and
        ``processing_status``, then relays chunks from
        ``ChatAgent.stream_response`` as ``response_start`` /
        ``response_chunk`` / ``response_complete``, stopping at the first
        chunk of type ``error``.

        Args:
            data: Message payload, checked by ``MessageValidator.validate_message``;
                an optional ``language`` key overrides the language for this message.
        """
        try:
            client_id = request.sid
            connection_info = self._get_connection_or_error(client_id)
            if connection_info is None:
                return
            session_id = connection_info['session_id']

            validation_result = self.message_validator.validate_message(data)
            if not validation_result['valid']:
                emit('error', {
                    'message': 'Invalid message format',
                    'details': validation_result['errors'],
                    'code': 'INVALID_MESSAGE'
                })
                return

            message_content = validation_result['sanitized_content']
            # Optional language override; only read it from a dict payload.
            language = data.get('language') if isinstance(data, dict) else None

            emit('message_received', {
                'message_id': str(uuid4()),
                'timestamp': datetime.utcnow().isoformat()
            })
            emit('processing_status', {
                'status': 'processing',
                'session_id': session_id,
                'timestamp': datetime.utcnow().isoformat()
            })

            try:
                for response_chunk in self.chat_agent.stream_response(
                    session_id, message_content, language
                ):
                    chunk_type = response_chunk['type']
                    if chunk_type == 'start':
                        emit('response_start', {
                            'session_id': session_id,
                            'language': response_chunk['language'],
                            'timestamp': response_chunk['timestamp']
                        })
                    elif chunk_type == 'chunk':
                        emit('response_chunk', {
                            'content': response_chunk['content'],
                            'timestamp': response_chunk['timestamp']
                        })
                    elif chunk_type == 'complete':
                        emit('response_complete', {
                            'message_id': response_chunk['message_id'],
                            'total_chunks': response_chunk['total_chunks'],
                            'processing_time': response_chunk['processing_time'],
                            'timestamp': response_chunk['timestamp']
                        })
                    elif chunk_type == 'error':
                        emit('error', {
                            'message': 'Processing error',
                            'details': response_chunk['error'],
                            'code': 'PROCESSING_ERROR'
                        })
                        break
            except ChatAgentError as e:
                logger.error("Chat agent error processing message: %s", e)
                emit('error', {
                    'message': 'Failed to process message',
                    'details': str(e),
                    'code': 'CHAT_AGENT_ERROR'
                })

            self.connection_manager.update_connection_activity(client_id)
        except Exception as e:
            logger.exception("Error handling message: %s", e)
            self._emit_internal_error()

    def handle_language_switch(self, data: Optional[Dict[str, Any]] = None) -> None:
        """
        Handle programming language context switching.

        Args:
            data: Payload checked by ``MessageValidator.validate_language_switch``,
                whose result supplies the new language.
        """
        try:
            client_id = request.sid
            connection_info = self._get_connection_or_error(client_id)
            if connection_info is None:
                return
            session_id = connection_info['session_id']

            validation_result = self.message_validator.validate_language_switch(data)
            if not validation_result['valid']:
                emit('error', {
                    'message': 'Invalid language switch request',
                    'details': validation_result['errors'],
                    'code': 'INVALID_LANGUAGE_SWITCH'
                })
                return

            new_language = validation_result['language']

            try:
                switch_result = self.chat_agent.switch_language(session_id, new_language)

                # Keep the cached connection record in sync with the session.
                connection_info['language'] = new_language
                self.connection_manager.update_connection(client_id, connection_info)

                emit('language_switched', {
                    'previous_language': switch_result['previous_language'],
                    'new_language': switch_result['new_language'],
                    'message': switch_result['message'],
                    'timestamp': switch_result['timestamp']
                })
                logger.info("Language switched to %s for session %s", new_language, session_id)
            except ChatAgentError as e:
                logger.error("Error switching language: %s", e)
                emit('error', {
                    'message': 'Failed to switch language',
                    'details': str(e),
                    'code': 'LANGUAGE_SWITCH_ERROR'
                })

            self.connection_manager.update_connection_activity(client_id)
        except Exception as e:
            logger.exception("Error handling language switch: %s", e)
            self._emit_internal_error()

    def handle_typing_start(self, data: Optional[Dict[str, Any]] = None) -> None:
        """
        Broadcast a ``user_typing`` indicator to the other clients of the session.

        Args:
            data: Typing data (currently unused but reserved for future use)
        """
        try:
            self._broadcast_to_session_peers('user_typing')
        except Exception as e:
            logger.exception("Error handling typing start: %s", e)

    def handle_typing_stop(self, data: Optional[Dict[str, Any]] = None) -> None:
        """
        Broadcast a ``user_typing_stop`` indicator to the other clients of the session.

        Args:
            data: Typing data (currently unused but reserved for future use)
        """
        try:
            self._broadcast_to_session_peers('user_typing_stop')
        except Exception as e:
            logger.exception("Error handling typing stop: %s", e)

    def handle_ping(self, data: Optional[Dict[str, Any]] = None) -> None:
        """
        Handle ping requests for connection health checks.

        Always answers with ``pong``; the client's ``timestamp`` is echoed
        back when *data* is a dict, otherwise ``client_timestamp`` is None.

        Args:
            data: Ping data, optionally containing ``timestamp``
        """
        try:
            client_id = request.sid
            self.connection_manager.update_connection_activity(client_id)

            # A payload-less or non-dict ping must still get a pong.
            client_timestamp = data.get('timestamp') if isinstance(data, dict) else None
            emit('pong', {
                'timestamp': datetime.utcnow().isoformat(),
                'client_timestamp': client_timestamp
            })
        except Exception as e:
            logger.exception("Error handling ping: %s", e)

    def handle_get_session_info(self, data: Optional[Dict[str, Any]] = None) -> None:
        """
        Handle session information requests by emitting ``session_info``.

        Args:
            data: Request data (currently unused)
        """
        try:
            client_id = request.sid
            connection_info = self._get_connection_or_error(client_id)
            if connection_info is None:
                return
            session_id = connection_info['session_id']

            try:
                session_info = self.chat_agent.get_session_info(session_id)
                emit('session_info', {
                    'session': session_info['session'],
                    'language_context': session_info['language_context'],
                    'statistics': session_info['statistics'],
                    'supported_languages': session_info['supported_languages'],
                    'timestamp': datetime.utcnow().isoformat()
                })
            except ChatAgentError as e:
                logger.error("Error getting session info: %s", e)
                emit('error', {
                    'message': 'Failed to get session info',
                    'details': str(e),
                    'code': 'SESSION_INFO_ERROR'
                })
        except Exception as e:
            logger.exception("Error handling session info request: %s", e)
            self._emit_internal_error()
def create_chat_websocket_handler(chat_agent: ChatAgent, session_manager: SessionManager,
                                  connection_manager: ConnectionManager) -> ChatWebSocketHandler:
    """
    Build a ChatWebSocketHandler wired to the supplied services.

    Args:
        chat_agent: Service that processes and streams chat messages
        session_manager: Service that looks up and refreshes chat sessions
        connection_manager: Registry of active WebSocket connections

    Returns:
        ChatWebSocketHandler: A new handler bound to the three services
    """
    handler = ChatWebSocketHandler(
        chat_agent=chat_agent,
        session_manager=session_manager,
        connection_manager=connection_manager,
    )
    return handler