File size: 12,342 Bytes
0a865e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca1ff64
0a865e9
 
 
 
 
ca1ff64
0a865e9
 
 
ca1ff64
0a865e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca1ff64
0a865e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca1ff64
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
"""Thread-safe WebSocket Manager for Gradio Frontend.

This module provides a robust WebSocket connection that runs in a background
thread with its own event loop, completely separated from Gradio's synchronous
environment. Uses thread-safe queues for communication.

Architecture:
    Gradio (Sync) ←→ Message Queues ←→ Background Thread (Async WebSocket)

Usage:
    manager = WebSocketManager("ws://localhost:8000/ws/conversation/123", "123")
    manager.start()

    # Send messages (sync)
    manager.send_message({"type": "start_conversation", ...})

    # Get received messages (sync)
    messages = manager.get_messages()
"""

import asyncio
import threading
import time
import json
import queue
import logging
from typing import Dict, List, Optional
from datetime import datetime
from enum import Enum

import websockets
from websockets.exceptions import ConnectionClosed, WebSocketException

# Module-scoped logger: handlers and levels are configured by the application,
# this module only emits records under its own dotted name.
logger = logging.getLogger(name=__name__)


class ManagerState(Enum):
    """Lifecycle states a WebSocket manager reports through its ``state``."""

    # Not running: the initial state, and the state after stop().
    STOPPED = "stopped"
    # start() has launched the background thread and is awaiting the handshake.
    STARTING = "starting"
    # Handshake completed; outgoing messages are accepted.
    CONNECTED = "connected"
    # The connection dropped and a reconnect may follow.
    DISCONNECTED = "disconnected"
    # Startup failed, retries were exhausted, or an unexpected error occurred.
    ERROR = "error"


class WebSocketManager:
    """Thread-safe WebSocket manager for Gradio frontend.

    The connection lives on a daemon thread that runs its own asyncio event
    loop. Synchronous callers interact only through ``send_message`` and
    ``get_messages``, which exchange dicts with the background thread via two
    ``queue.Queue`` instances. ``state`` is written by the background thread
    and polled by ``start``.
    """

    def __init__(self, url: str, conversation_id: str, extra_headers: Optional[Dict[str, str]] = None):
        """Initialize WebSocket manager.

        Args:
            url: WebSocket server URL
            conversation_id: Unique conversation identifier, stamped on every
                outgoing message
            extra_headers: Optional headers to send during the WebSocket handshake
        """
        self.url = url
        self.conversation_id = conversation_id
        self.extra_headers = extra_headers

        # State management (written by both the caller and the background thread)
        self.state = ManagerState.STOPPED
        self.last_error = None

        # Background thread and event loop
        self.thread = None
        self.loop = None
        self.websocket = None
        self._stop_event = threading.Event()

        # Thread-safe message queues. Both are unbounded; the receive loop
        # trims the inbound one to max_messages by dropping the oldest entries.
        self.outbound_queue = queue.Queue()  # Messages to send
        self.inbound_queue = queue.Queue()   # Received messages
        self.max_messages = 100

        # Statistics
        self.messages_sent = 0
        self.messages_received = 0
        self.connection_time = None

    def start(self, timeout: float = 10.0) -> bool:
        """Start the WebSocket manager in a background thread.

        Blocks until the first connection succeeds, fails, or ``timeout``
        seconds elapse. On timeout, the background thread is told to stop so
        it cannot connect later after failure has already been reported.

        Args:
            timeout: Seconds to wait for the initial connection.

        Returns:
            True if connected (or already running), False otherwise.
        """
        if self.thread and self.thread.is_alive():
            logger.warning("WebSocket manager already running")
            return True

        try:
            self.state = ManagerState.STARTING
            self._stop_event.clear()

            # Start background thread
            self.thread = threading.Thread(target=self._run_websocket, daemon=True)
            self.thread.start()

            # Poll the state published by the background thread. A monotonic
            # clock keeps the timeout immune to wall-clock adjustments.
            deadline = time.monotonic() + timeout
            while time.monotonic() < deadline:
                if self.state == ManagerState.CONNECTED:
                    logger.info("WebSocket manager started successfully")
                    return True
                if self.state == ManagerState.ERROR:
                    logger.error(f"WebSocket manager failed to start: {self.last_error}")
                    return False
                time.sleep(0.1)

            logger.error("WebSocket manager startup timed out")
            # Abort the still-running background thread; otherwise it could
            # flip the state to CONNECTED after we have reported failure.
            self._stop_event.set()
            self.state = ManagerState.ERROR
            self.last_error = "Startup timeout"
            return False

        except Exception as e:
            self.state = ManagerState.ERROR
            self.last_error = str(e)
            logger.error(f"Error starting WebSocket manager: {e}")
            return False

    def stop(self):
        """Stop the WebSocket manager, waiting up to 5 s for the thread to exit."""
        logger.info("Stopping WebSocket manager...")
        self._stop_event.set()

        # Thread.join() on the calling thread raises RuntimeError. That can
        # happen if stop() is reached from the background thread itself
        # (e.g. via __del__), so only join from a different thread.
        if (
            self.thread
            and self.thread.is_alive()
            and self.thread is not threading.current_thread()
        ):
            self.thread.join(timeout=5)

        self.state = ManagerState.STOPPED
        logger.info("WebSocket manager stopped")

    def send_message(self, message: Dict) -> bool:
        """Queue a message to be sent by the background thread (thread-safe).

        The caller's dict is not modified. A copy is enqueued with
        ``conversation_id``, ``timestamp`` and ``client_id`` added; these keys
        override any caller-supplied values of the same name.

        Args:
            message: Message dictionary to send

        Returns:
            True if queued successfully, False if not connected or on error
        """
        if self.state != ManagerState.CONNECTED:
            logger.warning(f"Cannot send message - manager state: {self.state.value}")
            return False

        try:
            # Copy rather than update in place. The queued object is read by
            # another thread, so a caller that reuses or mutates its dict must
            # not be able to race with json.dumps on the background thread.
            outgoing = {
                **message,
                "conversation_id": self.conversation_id,
                "timestamp": datetime.now().isoformat(),
                "client_id": f"gradio_{id(self)}",
            }

            # Queue for background thread to send
            self.outbound_queue.put_nowait(outgoing)
            logger.debug(f"Queued message: {outgoing.get('type', 'unknown')}")
            return True

        except queue.Full:
            logger.error("Outbound message queue is full")
            return False
        except Exception as e:
            logger.error(f"Error queuing message: {e}")
            return False

    def get_messages(self) -> List[Dict]:
        """Drain and return every message received since the last call (thread-safe).

        Returns:
            List of received message dictionaries, oldest first (may be empty)
        """
        messages = []

        try:
            while True:
                messages.append(self.inbound_queue.get_nowait())
        except queue.Empty:
            pass
        except Exception as e:
            logger.error(f"Error getting messages: {e}")

        return messages

    def get_conversation_messages(self) -> List[Dict]:
        """Return only the received messages whose ``type`` is ``conversation_message``.

        Note: this drains the inbound queue. Messages of any other type are
        discarded, not kept for a later ``get_messages()`` call.

        Returns:
            List of conversation message dictionaries
        """
        all_messages = self.get_messages()
        return [
            msg for msg in all_messages
            if msg.get("type") == "conversation_message"
        ]

    def get_status(self) -> Dict:
        """Get current manager status.

        Returns:
            Status dictionary with state, endpoint, counters, last error,
            connection time (ISO string or None) and background-thread liveness
        """
        return {
            "state": self.state.value,
            "url": self.url,
            "conversation_id": self.conversation_id,
            "messages_sent": self.messages_sent,
            "messages_received": self.messages_received,
            "last_error": self.last_error,
            "connection_time": self.connection_time.isoformat() if self.connection_time else None,
            "thread_alive": self.thread.is_alive() if self.thread else False
        }

    def _run_websocket(self):
        """Thread target: run ``_websocket_main`` on a dedicated event loop."""
        logger.info("Starting WebSocket background thread")

        try:
            # Create new event loop for this thread
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)

            # Run the WebSocket connection
            self.loop.run_until_complete(self._websocket_main())

        except Exception as e:
            logger.error(f"Error in WebSocket background thread: {e}")
            self.state = ManagerState.ERROR
            self.last_error = str(e)
        finally:
            if self.loop:
                self.loop.close()

    async def _websocket_main(self):
        """Connect, run the send/receive loops, and reconnect with backoff.

        Every session end triggers a backoff before reconnecting, whether it
        surfaced as an exception or as a loop exiting after swallowing
        ``ConnectionClosed``. The retry counter resets on each successful
        handshake. The loop gives up after ``max_retries`` consecutive
        failures, when stop is requested, or on an unexpected exception.
        """
        retry_count = 0
        max_retries = 5

        while not self._stop_event.is_set() and retry_count < max_retries:
            try:
                logger.info(f"Connecting to WebSocket: {self.url}")

                # NOTE(review): ``extra_headers`` is the keyword used by the
                # legacy websockets client; newer releases name it
                # ``additional_headers`` - confirm against the pinned version.
                async with websockets.connect(
                    self.url,
                    extra_headers=self.extra_headers,
                    ping_interval=20,
                    ping_timeout=10
                ) as websocket:
                    self.websocket = websocket
                    self.state = ManagerState.CONNECTED
                    self.connection_time = datetime.now()
                    retry_count = 0  # Reset on successful connection

                    logger.info("WebSocket connected successfully")
                    await self._run_session()

            except (ConnectionClosed, WebSocketException) as e:
                logger.warning(f"WebSocket connection lost: {e}")
            except Exception as e:
                # NOTE(review): transport failures such as a refused TCP
                # connection (OSError) land here and end the manager instead
                # of being retried - confirm that is intended.
                logger.error(f"Unexpected WebSocket error: {e}")
                self.state = ManagerState.ERROR
                self.last_error = str(e)
                break

            # Reached after any session end or connection failure. Without this
            # a server-side close (which the loops swallow) would leave the
            # state CONNECTED and reconnect immediately with no backoff.
            if self._stop_event.is_set():
                break

            self.state = ManagerState.DISCONNECTED
            retry_count += 1
            if retry_count >= max_retries:
                break  # No point sleeping before giving up.

            retry_delay = min(2 ** retry_count, 30)  # Exponential backoff
            logger.info(f"Reconnecting in {retry_delay}s (attempt {retry_count}/{max_retries})")
            await self._sleep_unless_stopped(retry_delay)

        if retry_count >= max_retries:
            self.state = ManagerState.ERROR
            self.last_error = "Max reconnection attempts reached"
        elif self.state == ManagerState.CONNECTED:
            # The socket is closed by now; never leave a stale CONNECTED
            # behind (e.g. when a start() timeout raced a late handshake).
            self.state = ManagerState.DISCONNECTED

        self.websocket = None
        logger.info("WebSocket connection ended")

    async def _run_session(self):
        """Run the send and receive loops until one finishes, then cancel the other."""
        send_task = asyncio.create_task(self._send_loop())
        receive_task = asyncio.create_task(self._receive_loop())

        _done, pending = await asyncio.wait(
            [send_task, receive_task],
            return_when=asyncio.FIRST_COMPLETED
        )

        for task in pending:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def _sleep_unless_stopped(self, delay: float):
        """Sleep up to ``delay`` seconds, returning early once stop is requested.

        The stop flag is a ``threading.Event`` and cannot be awaited, so it is
        polled every 0.1 s. This lets ``stop()`` finish within its join
        timeout even during a long backoff.
        """
        deadline = time.monotonic() + delay
        while not self._stop_event.is_set():
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return
            await asyncio.sleep(min(remaining, 0.1))

    async def _send_loop(self):
        """Send queued outbound messages until stop is requested or the socket fails."""
        while not self._stop_event.is_set():
            try:
                message = self.outbound_queue.get_nowait()
            except queue.Empty:
                # No messages to send, sleep briefly
                await asyncio.sleep(0.1)
                continue

            try:
                payload = json.dumps(message)
            except (TypeError, ValueError) as e:
                # json.dumps raises TypeError for unserializable values and
                # ValueError e.g. for circular references. (The json module has
                # no JSONEncodeError.) Drop this message but keep the session.
                logger.error(f"Error encoding message: {e}")
                continue

            try:
                await self.websocket.send(payload)
                self.messages_sent += 1
                logger.debug(f"Sent message: {message.get('type', 'unknown')}")
            except (ConnectionClosed, WebSocketException):
                logger.warning("WebSocket closed during send")
                break
            except Exception as e:
                logger.error(f"Error in send loop: {e}")
                break

    async def _receive_loop(self):
        """Receive, decode and enqueue JSON-object messages until the socket fails."""
        while not self._stop_event.is_set():
            try:
                message_str = await self.websocket.recv()
                message = json.loads(message_str)
            except (ConnectionClosed, WebSocketException):
                logger.warning("WebSocket closed during receive")
                break
            except json.JSONDecodeError as e:
                logger.error(f"Error decoding received message: {e}")
                continue
            except Exception as e:
                logger.error(f"Error in receive loop: {e}")
                break

            # Consumers call .get() on every message, so only JSON objects are
            # accepted; other JSON values are skipped instead of killing the loop.
            if not isinstance(message, dict):
                logger.warning(f"Ignoring non-object message: {type(message).__name__}")
                continue

            # Add to inbound queue (with size limit)
            try:
                self.inbound_queue.put_nowait(message)
            except queue.Full:
                logger.warning("Inbound message queue is full, dropping message")
                continue

            self.messages_received += 1
            logger.debug(f"Received message: {message.get('type', 'unknown')}")

            # Keep queue size manageable by discarding the oldest messages
            while self.inbound_queue.qsize() > self.max_messages:
                try:
                    self.inbound_queue.get_nowait()
                except queue.Empty:
                    break

    def __del__(self):
        """Best-effort cleanup when the manager is garbage-collected."""
        try:
            self.stop()
        except Exception:
            # Destructors must not raise. Attributes may be missing if
            # __init__ failed, and globals may be gone at interpreter exit.
            pass