|
|
"""Thread-safe WebSocket Manager for Gradio Frontend. |
|
|
|
|
|
This module provides a robust WebSocket connection that runs in a background |
|
|
thread with its own event loop, completely separated from Gradio's synchronous |
|
|
environment. Uses thread-safe queues for communication. |
|
|
|
|
|
Architecture: |
|
|
Gradio (Sync) ←→ Message Queues ←→ Background Thread (Async WebSocket) |
|
|
|
|
|
Usage: |
|
|
manager = WebSocketManager("ws://localhost:8000/ws/conversation/123", conversation_id="123")
|
|
manager.start() |
|
|
|
|
|
# Send messages (sync) |
|
|
manager.send_message({"type": "start_conversation", ...}) |
|
|
|
|
|
# Get received messages (sync) |
|
|
messages = manager.get_messages() |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import threading |
|
|
import time |
|
|
import json |
|
|
import queue |
|
|
import logging |
|
|
from typing import Dict, List, Optional |
|
|
from datetime import datetime |
|
|
from enum import Enum |
|
|
|
|
|
import websockets |
|
|
from websockets.exceptions import ConnectionClosed, WebSocketException |
|
|
|
|
|
|
|
|
# Module-level logger, named after this module so log filtering can target it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
class ManagerState(Enum):
    """Lifecycle states of a ``WebSocketManager``.

    Each member's value is the lowercase string reported under the ``"state"``
    key of ``WebSocketManager.get_status()``.
    """

    # Not running: the initial state, and the state after stop().
    STOPPED = "stopped"
    # start() has launched the background thread and is waiting for a connection.
    STARTING = "starting"
    # The WebSocket handshake succeeded; send_message() accepts messages.
    CONNECTED = "connected"
    # The connection was lost; a reconnect may be attempted.
    DISCONNECTED = "disconnected"
    # Startup failed or timed out, or an unrecoverable error occurred.
    ERROR = "error"
|
|
|
|
|
|
|
|
class WebSocketManager:
    """Thread-safe WebSocket manager for Gradio frontend.

    The WebSocket connection lives in a background daemon thread that runs its
    own asyncio event loop. Synchronous callers never touch that loop; they
    exchange messages with it through two ``queue.Queue`` objects:

    * ``outbound_queue``: filled by :meth:`send_message`, drained by the
      background send loop.
    * ``inbound_queue``: filled by the background receive loop (trimmed to the
      newest ``max_messages`` entries), drained by :meth:`get_messages`.
    """

    def __init__(self, url: str, conversation_id: str, extra_headers: Optional[Dict[str, str]] = None):
        """Initialize WebSocket manager.

        Args:
            url: WebSocket server URL
            conversation_id: Unique conversation identifier, stamped onto every
                outgoing message
            extra_headers: Optional headers to send during the WebSocket handshake
        """
        self.url = url
        self.conversation_id = conversation_id
        self.extra_headers = extra_headers

        # Lifecycle state; written from both the caller's thread and the
        # background thread.
        self.state = ManagerState.STOPPED
        self.last_error: Optional[str] = None

        # Background thread, its private event loop, and the live connection.
        self.thread: Optional[threading.Thread] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.websocket = None
        self._stop_event = threading.Event()

        # Cross-thread message queues (see class docstring).
        self.outbound_queue = queue.Queue()
        self.inbound_queue = queue.Queue()
        self.max_messages = 100

        # Counters reported by get_status().
        self.messages_sent = 0
        self.messages_received = 0
        self.connection_time: Optional[datetime] = None

    def start(self, timeout: float = 10.0) -> bool:
        """Start the WebSocket manager in a background thread.

        Blocks until the first connection succeeds, fails, or ``timeout``
        seconds elapse.

        Args:
            timeout: Maximum number of seconds to wait for the initial connection.

        Returns:
            True if connected (or the background thread was already running),
            False on failure or timeout. On timeout the background thread is
            told to stop, so it cannot connect later after failure was reported.
        """
        if self.thread and self.thread.is_alive():
            logger.warning("WebSocket manager already running")
            return True

        try:
            self.state = ManagerState.STARTING
            self._stop_event.clear()

            self.thread = threading.Thread(target=self._run_websocket, daemon=True)
            self.thread.start()

            # monotonic() is immune to wall-clock adjustments.
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if self.state == ManagerState.CONNECTED:
                    logger.info("WebSocket manager started successfully")
                    return True
                if self.state == ManagerState.ERROR:
                    logger.error(f"WebSocket manager failed to start: {self.last_error}")
                    return False
                time.sleep(0.1)

            logger.error("WebSocket manager startup timed out")
            # The background thread may still be retrying; stop it so it cannot
            # flip the state to CONNECTED after we have reported failure.
            self._stop_event.set()
            self.state = ManagerState.ERROR
            self.last_error = "Startup timeout"
            return False

        except Exception as e:
            self.state = ManagerState.ERROR
            self.last_error = str(e)
            logger.error(f"Error starting WebSocket manager: {e}")
            return False

    def stop(self):
        """Stop the WebSocket manager.

        Signals the background thread to finish and waits up to 5 seconds for
        it to exit, then marks the manager as STOPPED.
        """
        logger.info("Stopping WebSocket manager...")
        self._stop_event.set()

        if self.thread and self.thread.is_alive():
            self.thread.join(timeout=5)

        self.state = ManagerState.STOPPED
        logger.info("WebSocket manager stopped")

    def send_message(self, message: Dict) -> bool:
        """Queue a message for sending over the WebSocket (thread-safe).

        The caller's dict is not modified: a shallow copy is enqueued with
        ``conversation_id``, ``timestamp`` and ``client_id`` added (these
        override any same-named keys in ``message``).

        Args:
            message: Message dictionary to send

        Returns:
            True if queued successfully, False if not connected or queuing failed
        """
        if self.state != ManagerState.CONNECTED:
            logger.warning(f"Cannot send message - manager state: {self.state.value}")
            return False

        try:
            outgoing = {
                **message,
                "conversation_id": self.conversation_id,
                "timestamp": datetime.now().isoformat(),
                "client_id": f"gradio_{id(self)}",
            }
            self.outbound_queue.put_nowait(outgoing)
            logger.debug(f"Queued message: {outgoing.get('type', 'unknown')}")
            return True

        except queue.Full:
            logger.error("Outbound message queue is full")
            return False
        except Exception as e:
            logger.error(f"Error queuing message: {e}")
            return False

    def get_messages(self) -> List[Dict]:
        """Drain and return all received messages (thread-safe).

        Each message is returned at most once; the inbound queue is empty
        afterwards (until more messages arrive).

        Returns:
            List of received message dictionaries, oldest first
        """
        messages = []
        try:
            while True:
                messages.append(self.inbound_queue.get_nowait())
        except queue.Empty:
            pass
        except Exception as e:
            logger.error(f"Error getting messages: {e}")
        return messages

    def get_conversation_messages(self) -> List[Dict]:
        """Get only conversation messages from received messages.

        Note: this drains the whole inbound queue via :meth:`get_messages`;
        messages whose ``type`` is not ``"conversation_message"`` are discarded.

        Returns:
            List of conversation message dictionaries
        """
        return [
            msg for msg in self.get_messages()
            if msg.get("type") == "conversation_message"
        ]

    def get_status(self) -> Dict:
        """Get current manager status.

        Returns:
            Status dictionary (state, URL, counters, last error, connection
            time as ISO string, and whether the background thread is alive)
        """
        return {
            "state": self.state.value,
            "url": self.url,
            "conversation_id": self.conversation_id,
            "messages_sent": self.messages_sent,
            "messages_received": self.messages_received,
            "last_error": self.last_error,
            "connection_time": self.connection_time.isoformat() if self.connection_time else None,
            "thread_alive": self.thread.is_alive() if self.thread else False,
        }

    def _run_websocket(self):
        """Thread target: run :meth:`_websocket_main` on a dedicated event loop."""
        logger.info("Starting WebSocket background thread")

        try:
            # A fresh loop owned by this thread, so the connection never shares
            # an event loop with the caller's thread.
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)
            self.loop.run_until_complete(self._websocket_main())

        except Exception as e:
            logger.error(f"Error in WebSocket background thread: {e}")
            self.state = ManagerState.ERROR
            self.last_error = str(e)
        finally:
            if self.loop:
                self.loop.close()

    async def _websocket_main(self):
        """Connect, pump messages, and reconnect with exponential backoff.

        Gives up with ``ManagerState.ERROR`` after ``max_retries`` consecutive
        failed or dropped connections, or immediately on an unexpected error.
        Exits promptly once the stop event is set.
        """
        retry_count = 0
        max_retries = 5

        while not self._stop_event.is_set() and retry_count < max_retries:
            try:
                logger.info(f"Connecting to WebSocket: {self.url}")

                # NOTE(review): ``extra_headers`` is the legacy websockets client
                # keyword; the newer asyncio client (websockets >= 14) calls it
                # ``additional_headers``. Confirm against the installed version.
                async with websockets.connect(
                    self.url,
                    extra_headers=self.extra_headers,
                    ping_interval=20,
                    ping_timeout=10,
                ) as websocket:
                    self.websocket = websocket
                    self.state = ManagerState.CONNECTED
                    self.connection_time = datetime.now()
                    retry_count = 0

                    logger.info("WebSocket connected successfully")

                    await self._pump_messages()

            except (ConnectionClosed, WebSocketException, OSError, asyncio.TimeoutError) as e:
                # OSError covers refused/reset TCP connections (e.g. the server
                # restarting); like protocol-level closures, these are retryable.
                logger.warning(f"WebSocket connection lost: {e}")

            except Exception as e:
                logger.error(f"Unexpected WebSocket error: {e}")
                self.state = ManagerState.ERROR
                self.last_error = str(e)
                break

            finally:
                self.websocket = None

            if self._stop_event.is_set():
                break

            # Either the connection raised, or the send/receive loops ended on
            # their own (they swallow ConnectionClosed). Both mean the link is
            # gone, so never leave the state at CONNECTED and always back off.
            self.state = ManagerState.DISCONNECTED
            retry_count += 1
            if retry_count >= max_retries:
                break  # Give up without a pointless final sleep.

            retry_delay = min(2 ** retry_count, 30)
            logger.info(f"Reconnecting in {retry_delay}s (attempt {retry_count}/{max_retries})")
            await self._sleep_unless_stopped(retry_delay)

        if retry_count >= max_retries:
            self.state = ManagerState.ERROR
            self.last_error = "Max reconnection attempts reached"

        self.websocket = None
        logger.info("WebSocket connection ended")

    async def _pump_messages(self):
        """Run the send and receive loops until either finishes, then cancel the other."""
        send_task = asyncio.create_task(self._send_loop())
        receive_task = asyncio.create_task(self._receive_loop())

        _, pending = await asyncio.wait(
            [send_task, receive_task],
            return_when=asyncio.FIRST_COMPLETED,
        )

        for task in pending:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def _sleep_unless_stopped(self, delay: float):
        """Sleep for ``delay`` seconds, returning early once the stop event is set.

        Polls in 0.1s steps so that stop() (which only waits 5s for the thread)
        is not blocked by a long backoff.
        """
        deadline = time.monotonic() + delay
        while not self._stop_event.is_set():
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return
            await asyncio.sleep(min(remaining, 0.1))

    async def _send_loop(self):
        """Send messages from outbound queue until stopped or the socket closes."""
        while not self._stop_event.is_set():
            try:
                message = self.outbound_queue.get_nowait()
            except queue.Empty:
                # The outbound queue is a blocking thread queue, so poll it
                # rather than block the event loop.
                await asyncio.sleep(0.1)
                continue

            try:
                payload = json.dumps(message)
            except (TypeError, ValueError) as e:
                # json.dumps raises TypeError for unserializable values and
                # ValueError for circular references; skip only this message.
                logger.error(f"Error encoding message: {e}")
                continue

            try:
                await self.websocket.send(payload)
            except (ConnectionClosed, WebSocketException):
                # NOTE: the dequeued message is lost; it is not re-sent after reconnect.
                logger.warning("WebSocket closed during send")
                break
            except Exception as e:
                logger.error(f"Error in send loop: {e}")
                break

            self.messages_sent += 1
            logger.debug(f"Sent message: {message.get('type', 'unknown')}")

    async def _receive_loop(self):
        """Receive messages and put them in the inbound queue until stopped or closed."""
        while not self._stop_event.is_set():
            try:
                message_str = await self.websocket.recv()
                message = json.loads(message_str)
            except (ConnectionClosed, WebSocketException):
                logger.warning("WebSocket closed during receive")
                break
            except ValueError as e:
                # JSONDecodeError (and UnicodeDecodeError for binary frames) are
                # ValueError subclasses: drop this message, keep receiving.
                logger.error(f"Error decoding received message: {e}")
                continue
            except Exception as e:
                logger.error(f"Error in receive loop: {e}")
                break

            if not isinstance(message, dict):
                # Callers treat messages as dicts (msg.get(...)); ignore other JSON values.
                logger.warning(f"Ignoring non-object message of type {type(message).__name__}")
                continue

            try:
                self.inbound_queue.put_nowait(message)
            except queue.Full:
                logger.warning("Inbound message queue is full, dropping message")
                continue

            self.messages_received += 1
            logger.debug(f"Received message: {message.get('type', 'unknown')}")

            # Keep only the newest max_messages entries by dropping the oldest.
            while self.inbound_queue.qsize() > self.max_messages:
                try:
                    self.inbound_queue.get_nowait()
                except queue.Empty:
                    break

    def __del__(self):
        """Best-effort cleanup on destruction; errors are deliberately ignored."""
        try:
            self.stop()
        except Exception:
            pass
|
|
|