File size: 12,345 Bytes
5669b22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import asyncio
import json
import uuid
from typing import Dict, Optional
from fastapi import WebSocket
from loguru import logger
import aiohttp
from starlette.websockets import WebSocketDisconnect

from .proxy_message_queue import ProxyMessageQueue


class ProxyHandler:
    """
    A proxy handler that allows multiple clients to connect through a single WebSocket connection to the server.

    This enables scenarios like having a web client and a live platform both connected to the same VTuber server.

    All clients share one upstream aiohttp WebSocket (``server_ws``). Text
    (JSON object) messages from the server are broadcast to every connected
    client. Client ``text-input`` messages are queued through
    ``ProxyMessageQueue``; all other client messages are forwarded directly.
    """

    # Seconds between heartbeat messages while the upstream link is alive.
    HEARTBEAT_INTERVAL = 30
    # Seconds to wait before retrying after a failed reconnect or maintenance error.
    RECONNECT_DELAY = 5

    def __init__(self, server_url: str = "ws://localhost:12393/client-ws"):
        """
        Initialize the proxy handler.

        No network activity happens here; the upstream connection is opened
        lazily by ``connect_to_server``.

        Args:
            server_url: The WebSocket URL of the actual server
        """
        self.server_url = server_url
        self.server_ws: Optional[aiohttp.ClientWebSocketResponse] = None
        self.clients: Dict[str, WebSocket] = {}
        self.connected = False
        self.server_task: Optional[asyncio.Task] = None
        # Serializes connection establishment in connect_to_server().
        self.lock = asyncio.Lock()

        # Initialize message queue manager
        self.message_queue = ProxyMessageQueue()
        self._heartbeat_task: Optional[asyncio.Task] = None
        # Cleared by disconnect() to stop the heartbeat loop; set again on (re)connect.
        self._running = True
        self._session: Optional[aiohttp.ClientSession] = None

    def _server_alive(self) -> bool:
        """Return True if the proxy is marked connected and the upstream socket is open."""
        return (
            self.connected
            and self.server_ws is not None
            and not self.server_ws.closed
        )

    def _has_upstream_resources(self) -> bool:
        """Return True if there is a connection, session, or heartbeat task to tear down."""
        heartbeat_running = (
            self._heartbeat_task is not None and not self._heartbeat_task.done()
        )
        return self.connected or self._session is not None or heartbeat_running

    async def _close_server_ws(self) -> None:
        """Close the upstream WebSocket if it is open, logging (not raising) errors."""
        ws = self.server_ws
        if ws is not None and not ws.closed:
            try:
                await ws.close()
            except Exception as e:
                logger.warning(f"Error closing server WebSocket: {e}")

    async def connect_to_server(self):
        """
        Establish a WebSocket connection to the actual server.

        Returns immediately if a live connection already exists. Otherwise any
        stale socket is closed, a new one is opened, the message queue is
        (re)initialized, the server-reader task is started, and the heartbeat
        task is started unless one is already running.

        Raises:
            Exception: whatever fails while connecting (e.g. from
                ``ws_connect``); the HTTP session is closed before re-raising.
        """
        if self._server_alive():
            return

        async with self.lock:
            if self._server_alive():  # Double-check to prevent race conditions
                return

            # Close a stale socket (e.g. after a failed heartbeat) before replacing it,
            # otherwise it leaks and its reader task keeps waiting on it.
            await self._close_server_ws()
            self.connected = False

            try:
                # Create session if not exists
                if not self._session:
                    self._session = aiohttp.ClientSession()
                self.server_ws = await self._session.ws_connect(self.server_url)
                self.connected = True
                # disconnect() clears this flag; without resetting it the heartbeat
                # loop would exit immediately after any earlier disconnect.
                self._running = True
                logger.info(f"Proxy connected to server at {self.server_url}")

                # Initialize message queue with our forward function
                self.message_queue.initialize(self.forward_with_broadcast)

                # Start heartbeat task unless one is already running. Reconnects are
                # issued from inside the heartbeat task, which must not spawn a twin.
                if self._heartbeat_task is None or self._heartbeat_task.done():
                    self._heartbeat_task = asyncio.create_task(
                        self._maintain_connection()
                    )

                # Start task to receive messages from server
                self.server_task = asyncio.create_task(self.forward_server_messages())
            except Exception as e:
                logger.error(f"Failed to connect to server: {e}")
                self.connected = False
                if self._session:
                    await self._session.close()
                    self._session = None
                raise

    async def _maintain_connection(self):
        """
        Maintain connection with heartbeat and automatic reconnection.

        Loops until ``_running`` is cleared (by ``disconnect``) or the task is
        cancelled. While the upstream socket is alive it sends
        ``{"type": "heartbeat"}`` every ``HEARTBEAT_INTERVAL`` seconds;
        otherwise it tries to reconnect, waiting ``RECONNECT_DELAY`` seconds
        after each failure.
        """
        while self._running:
            try:
                if self._server_alive():
                    # Send heartbeat
                    await self.server_ws.send_json({"type": "heartbeat"})
                    await asyncio.sleep(self.HEARTBEAT_INTERVAL)
                else:
                    # Mark the link down so connect_to_server() really reconnects rather
                    # than returning early, which would spin this loop without yielding.
                    self.connected = False
                    logger.info("Connection lost, attempting to reconnect...")
                    try:
                        await self.connect_to_server()
                    except Exception as e:
                        logger.error(f"Reconnection failed: {e}")
                        await asyncio.sleep(self.RECONNECT_DELAY)  # Wait before retry
            except Exception as e:
                logger.error(f"Error in connection maintenance: {e}")
                self.connected = False
                await asyncio.sleep(self.RECONNECT_DELAY)

    async def handle_client_connection(self, websocket: WebSocket):
        """
        Handle a new client connection to the proxy.

        Accepts the socket, registers it under a fresh UUID, ensures the
        upstream connection exists, requests the initial config for this client,
        then relays the client's messages until it disconnects.

        Args:
            websocket: The client's WebSocket connection

        Raises:
            Exception: propagated from ``connect_to_server`` if the upstream
                server cannot be reached; the client is unregistered first.
        """
        await websocket.accept()

        # Generate a unique client ID
        client_id = str(uuid.uuid4())
        self.clients[client_id] = websocket
        logger.info(
            f"Client {client_id} connected to proxy. Total clients: {len(self.clients)}"
        )

        # Ensure server connection is established (no-op when already alive)
        try:
            await self.connect_to_server()
        except Exception:
            # Don't leave a dead client registered when the upstream is unreachable.
            await self.handle_client_disconnect(client_id)
            raise

        if self.connected:
            try:
                init_request = {"type": "request-init-config", "client_id": client_id}
                await self.forward_to_server(init_request, client_id)
            except Exception as e:
                logger.error(f"Failed to request initialization: {e}")

        try:
            # Handle messages from this client
            while True:
                message = await websocket.receive_json()
                if not isinstance(message, dict):
                    logger.warning(f"Ignoring non-object message from client {client_id}")
                    continue

                msg_type = message.get("type")
                if msg_type == "text-input":
                    # Queue the message with the sender's ID
                    self.message_queue.queue_message(message, client_id)
                    continue

                if msg_type == "interrupt-signal":
                    logger.info(
                        "Received interrupt signal, marking conversation as inactive"
                    )
                    # Mark conversation as inactive to allow processing of next message
                    self.message_queue.conversation_active = False

                # Interrupts and all other message types go straight to the server
                await self.forward_to_server(message, client_id)

        except WebSocketDisconnect:
            await self.handle_client_disconnect(client_id)
        except Exception as e:
            logger.error(f"Error handling client connection: {e}")
            await self.handle_client_disconnect(client_id)

    async def handle_client_disconnect(self, client_id: str):
        """
        Handle a client disconnection.

        Unregisters the client (idempotent). When the last client leaves, the
        upstream link is torn down, including the case where the link is
        currently down and the heartbeat task is retrying.

        Args:
            client_id: The ID of the disconnected client
        """
        self.clients.pop(client_id, None)
        logger.info(
            f"Client {client_id} removed. Remaining clients: {len(self.clients)}"
        )

        # `_running` is cleared at the very start of disconnect(), so a concurrent
        # second call for the same event does not re-enter teardown.
        if not self.clients and self._running and self._has_upstream_resources():
            await self.disconnect()

    async def disconnect(self):
        """
        Disconnect from the server and release upstream resources.

        Stops the heartbeat loop, closes the upstream socket and HTTP session,
        cancels the server-reader task, then stops and clears the message queue.
        A task never cancels or awaits itself here, so this is safe to reach
        from inside the reader task (via broadcast_to_clients).
        """
        self._running = False
        current = asyncio.current_task()

        # Cancel heartbeat task
        heartbeat, self._heartbeat_task = self._heartbeat_task, None
        if heartbeat is not None and heartbeat is not current:
            heartbeat.cancel()
            try:
                await heartbeat
            except asyncio.CancelledError:
                pass

        await self._close_server_ws()

        # Cancelling the current task would abort the remaining cleanup at the next
        # await; when we are the reader, it exits on its own once `connected` drops.
        reader, self.server_task = self.server_task, None
        if reader is not None and reader is not current:
            reader.cancel()

        # Close session
        if self._session:
            await self._session.close()
            self._session = None

        self.connected = False
        # Stop and clear the message queue
        self.message_queue.stop()
        self.message_queue.clear()
        logger.info("Proxy disconnected from server")

    async def forward_to_server(self, message: dict, sender_id: Optional[str] = None):
        """
        Forward a message from a client to the server.

        Reconnects first if the upstream link is not alive. If no open socket
        is available afterwards, the message is dropped with a warning.

        Args:
            message: The message to forward
            sender_id: ID of the sending client. Not used here; accepted so
                callers can pass it uniformly (broadcast exclusion is handled
                by ``forward_with_broadcast``).

        Raises:
            Exception: propagated from ``connect_to_server`` if reconnecting fails.
        """
        if not self._server_alive():
            await self.connect_to_server()

        ws = self.server_ws
        if ws is None or ws.closed:
            logger.warning(
                f"Server connection unavailable; dropping message of type {message.get('type')!r}"
            )
            return
        await ws.send_json(message)

    async def _handle_server_text(self, raw: str) -> None:
        """
        Parse one text frame from the server and broadcast it to every client.

        Empty frames, empty JSON values and non-object JSON are skipped. A
        ``{"type": "control", "text": "conversation-chain-end"}`` message also
        marks the queued conversation as inactive before being broadcast.
        """
        if not raw:  # Check if data is empty
            return
        try:
            data = json.loads(raw)
        except json.JSONDecodeError as e:
            logger.error(f"Failed to parse message data: {e}")
            return
        if not data:  # Check if parsed data is empty
            return
        if not isinstance(data, dict):
            logger.warning(
                f"Ignoring non-object message from server: {type(data).__name__}"
            )
            return

        # Check for conversation end signal
        if (
            data.get("type") == "control"
            and data.get("text") == "conversation-chain-end"
        ):
            logger.info("Received conversation end signal")
            self.message_queue.conversation_active = False

        # Broadcast the message to all clients
        await self.broadcast_to_clients(data)

    async def forward_server_messages(self):
        """
        Forward messages from server to all connected clients.

        Reads from the socket that was current when the task started, until it
        closes, errors, or the proxy is marked disconnected.
        """
        ws = self.server_ws
        try:
            while self.connected and ws is not None and not ws.closed:
                try:
                    msg = await ws.receive()

                    if msg.type == aiohttp.WSMsgType.TEXT:
                        await self._handle_server_text(msg.data)
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        logger.error(f"WebSocket error: {ws.exception()}")
                        break
                    elif msg.type in (
                        aiohttp.WSMsgType.CLOSE,
                        aiohttp.WSMsgType.CLOSING,
                        aiohttp.WSMsgType.CLOSED,
                    ):
                        break
                except Exception as e:
                    logger.error(f"Error processing server message: {e}")
                    await asyncio.sleep(1)
        except Exception as e:
            logger.error(f"Error forwarding server messages: {e}")
        finally:
            # After a reconnect a newer reader task owns the connection; only reset
            # shared state if our socket is still the current one.
            if self.server_ws is ws:
                self.connected = False
                self.message_queue.conversation_active = False
            logger.info("Server message forwarding ended")

    @staticmethod
    def _summarize_for_log(message: dict) -> dict:
        """Return a shallow copy of *message* with bulky audio/volume payloads summarized."""
        summary = dict(message)
        if "audio" in summary:
            audio = summary["audio"]
            size = len(audio) if hasattr(audio, "__len__") else 0
            summary["audio"] = f"[Audio data, {size} bytes truncated]"
        volumes = summary.get("volumes")
        if isinstance(volumes, (list, tuple)) and len(volumes) > 10:
            summary["volumes"] = f"[{len(volumes)} volume values]"
        return summary

    async def broadcast_to_clients(
        self, message: dict, exclude_client: Optional[str] = None
    ):
        """
        Broadcast a message to all connected clients.

        Clients whose send fails are unregistered afterwards via
        ``handle_client_disconnect``.

        Args:
            message: The message to broadcast
            exclude_client: Optional client ID to exclude from broadcast
        """
        if not message:  # Add null check
            return

        # Log message, but summarize audio/volume payloads to avoid huge logs
        logger.debug(
            f"Broadcasting to clients (excluding {exclude_client}): {self._summarize_for_log(message)}"
        )

        disconnected_clients = []
        # Iterate over a snapshot: each await lets other coroutines add or remove
        # clients, which would otherwise raise "dictionary changed size during iteration".
        for client_id, websocket in list(self.clients.items()):
            # Skip the excluded client
            if exclude_client and client_id == exclude_client:
                continue

            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.error(f"Error sending to client {client_id}: {e}")
                disconnected_clients.append(client_id)

        # Clean up disconnected clients
        for client_id in disconnected_clients:
            await self.handle_client_disconnect(client_id)

    async def forward_with_broadcast(
        self, message: dict, sender_id: Optional[str] = None
    ):
        """
        Forward message to server and handle any necessary broadcasting.

        Used as the message queue's forward callback. Besides forwarding,
        ``user-input-transcription`` messages are echoed to every client except
        the sender.

        Args:
            message: The message to forward
            sender_id: ID of the client sending the message
        """
        # Forward to server
        await self.forward_to_server(message, sender_id)

        # For transcription messages, broadcast to other clients
        if message.get("type") == "user-input-transcription":
            await self.broadcast_to_clients(message, exclude_client=sender_id)