| import asyncio
|
| import json
|
| import uuid
|
| from typing import Dict, Optional
|
| from fastapi import WebSocket
|
| from loguru import logger
|
| import aiohttp
|
| from starlette.websockets import WebSocketDisconnect
|
|
|
| from .proxy_message_queue import ProxyMessageQueue
|
|
|
|
|
class ProxyHandler:
    """
    A proxy handler that allows multiple clients to connect through a single WebSocket connection to the server.
    This enables scenarios like having a web client and a live platform both connected to the same VTuber server.

    The handler owns one upstream aiohttp WebSocket plus two background tasks:
    a reader (`forward_server_messages`) that fans server messages out to every
    registered client, and a heartbeat (`_maintain_connection`) that pings the
    server and reconnects when the link is down.
    """

    def __init__(self, server_url: str = "ws://localhost:12393/client-ws"):
        """
        Initialize the proxy handler.

        Args:
            server_url: The WebSocket URL of the actual server
        """
        self.server_url = server_url
        self.server_ws: Optional[aiohttp.ClientWebSocketResponse] = None
        # Connected client sockets keyed by the UUID assigned on accept.
        self.clients: Dict[str, WebSocket] = {}
        self.connected = False
        self.server_task: Optional[asyncio.Task] = None
        # Serializes connection attempts so only one handshake runs at a time.
        self.lock = asyncio.Lock()

        self.message_queue = ProxyMessageQueue()
        self._heartbeat_task: Optional[asyncio.Task] = None
        self._running = True
        self._session: Optional[aiohttp.ClientSession] = None

    async def connect_to_server(self):
        """
        Establish a WebSocket connection to the actual server.

        Safe to call repeatedly: returns immediately when already connected.
        On success the message queue is (re)initialized, the heartbeat task is
        started if none is alive, and a fresh reader task is started.

        Raises:
            Exception: Whatever the connection attempt raised, after the
                session has been closed and the connected flag cleared.
        """
        if self.connected:
            return

        async with self.lock:
            # Re-check under the lock: another caller may have connected
            # while this one was waiting.
            if self.connected:
                return

            try:
                if not self._session:
                    self._session = aiohttp.ClientSession()
                self.server_ws = await self._session.ws_connect(self.server_url)
                self.connected = True
                # disconnect() clears this flag; re-arm it so the heartbeat
                # loop keeps running after a disconnect/reconnect cycle.
                self._running = True
                logger.info(f"Proxy connected to server at {self.server_url}")

                self.message_queue.initialize(self.forward_with_broadcast)

                # Reconnects are driven from inside the heartbeat task itself,
                # so only spawn a heartbeat when none is currently alive.
                if self._heartbeat_task is None or self._heartbeat_task.done():
                    self._heartbeat_task = asyncio.create_task(
                        self._maintain_connection()
                    )

                # Retire a reader still bound to the previous socket before
                # starting one for the new socket.
                previous_reader = self.server_task
                if (
                    previous_reader is not None
                    and not previous_reader.done()
                    and previous_reader is not asyncio.current_task()
                ):
                    previous_reader.cancel()
                self.server_task = asyncio.create_task(self.forward_server_messages())
            except Exception as e:
                logger.error(f"Failed to connect to server: {e}")
                # Setup may have failed after the handshake succeeded; never
                # report a connection whose session is about to be closed.
                self.connected = False
                if self._session:
                    await self._session.close()
                    self._session = None
                raise

    async def _maintain_connection(self):
        """Maintain connection with heartbeat and automatic reconnection"""
        while self._running:
            try:
                if self.connected and self.server_ws and not self.server_ws.closed:
                    await self.server_ws.send_json({"type": "heartbeat"})
                    await asyncio.sleep(30)
                else:
                    logger.info("Connection lost, attempting to reconnect...")
                    # The socket can be closed while the flag is still set.
                    # Clear it so connect_to_server() really reconnects
                    # instead of returning at once, which would make this
                    # loop spin without ever yielding to the event loop.
                    self.connected = False
                    try:
                        await self.connect_to_server()
                    except Exception as e:
                        logger.error(f"Reconnection failed: {e}")
                        await asyncio.sleep(5)
            except Exception as e:
                logger.error(f"Error in connection maintenance: {e}")
                self.connected = False
                await asyncio.sleep(5)

    async def handle_client_connection(self, websocket: WebSocket):
        """
        Handle a new client connection to the proxy.

        Accepts the socket, registers the client under a fresh UUID, makes
        sure the upstream connection exists, requests initial configuration,
        then relays client messages until the client disconnects or an error
        occurs.

        Args:
            websocket: The client's WebSocket connection

        Raises:
            Exception: If the initial upstream connection cannot be
                established (the client is unregistered first).
        """
        await websocket.accept()

        client_id = str(uuid.uuid4())
        self.clients[client_id] = websocket
        logger.info(
            f"Client {client_id} connected to proxy. Total clients: {len(self.clients)}"
        )

        if not self.connected:
            try:
                await self.connect_to_server()
            except Exception:
                # Do not leave an orphaned entry in self.clients when the
                # upstream server is unreachable.
                await self.handle_client_disconnect(client_id)
                raise

        if self.connected:
            try:
                init_request = {"type": "request-init-config", "client_id": client_id}
                await self.forward_to_server(init_request, client_id)
            except Exception as e:
                logger.error(f"Failed to request initialization: {e}")

        try:
            while True:
                message = await websocket.receive_json()
                message_type = message.get("type")

                if message_type == "text-input":
                    # Text input goes through the queue rather than straight
                    # to the server.
                    self.message_queue.queue_message(message, client_id)
                elif message_type == "interrupt-signal":
                    logger.info(
                        "Received interrupt signal, marking conversation as inactive"
                    )
                    self.message_queue.conversation_active = False
                    await self.forward_to_server(message, client_id)
                else:
                    await self.forward_to_server(message, client_id)

        except WebSocketDisconnect:
            await self.handle_client_disconnect(client_id)
        except Exception as e:
            logger.error(f"Error handling client connection: {e}")
            await self.handle_client_disconnect(client_id)

    async def handle_client_disconnect(self, client_id: str):
        """
        Handle a client disconnection.

        Removes the client (a no-op for unknown IDs) and tears down the
        upstream connection once the last client is gone.

        Args:
            client_id: The ID of the disconnected client
        """
        self.clients.pop(client_id, None)
        logger.info(
            f"Client {client_id} removed. Remaining clients: {len(self.clients)}"
        )

        if not self.clients and self.connected:
            await self.disconnect()

    async def disconnect(self):
        """
        Disconnect from the server.

        Stops the heartbeat, closes the upstream socket and session, cancels
        the reader task and stops and clears the message queue.
        """
        self._running = False
        # disconnect() can be reached from inside one of the background tasks
        # (e.g. reader -> broadcast_to_clients -> handle_client_disconnect).
        # Cancelling the task that is executing this coroutine would abort
        # the cleanup below at its next await, so such a task is skipped and
        # left to finish on its own.
        current = asyncio.current_task()

        heartbeat, self._heartbeat_task = self._heartbeat_task, None
        if heartbeat is not None and heartbeat is not current:
            heartbeat.cancel()
            try:
                await heartbeat
            except asyncio.CancelledError:
                pass

        if self.server_ws and not self.server_ws.closed:
            await self.server_ws.close()

        reader, self.server_task = self.server_task, None
        if reader is not None and reader is not current:
            reader.cancel()

        if self._session:
            await self._session.close()
            self._session = None

        self.connected = False

        self.message_queue.stop()
        self.message_queue.clear()
        logger.info("Proxy disconnected from server")

    async def forward_to_server(self, message: dict, sender_id: Optional[str] = None):
        """
        Forward a message from a client to the server.

        Connects first when no upstream connection is recorded. If the socket
        is still unavailable afterwards the message is dropped with a warning.

        Args:
            message: The message to forward
            sender_id: ID of the client sending the message. Not used by this
                method; accepted so it shares a signature with
                `forward_with_broadcast`.
        """
        if not self.connected or not self.server_ws:
            await self.connect_to_server()

        if self.server_ws and not self.server_ws.closed:
            await self.server_ws.send_json(message)
        else:
            logger.warning("Server connection is not open; message was not forwarded")

    async def forward_server_messages(self):
        """Forward messages from server to all connected clients"""
        # Bind to the socket that was current when this task started, so that
        # after a reconnect a stale reader neither calls receive() on the new
        # socket nor resets the new connection's state when it finishes.
        ws = self.server_ws
        closing_types = (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        )
        try:
            while self.connected and ws is not None and not ws.closed:
                try:
                    msg = await ws.receive()

                    if msg.type == aiohttp.WSMsgType.TEXT:
                        try:
                            if not msg.data:
                                continue

                            data = json.loads(msg.data)
                            if not data:
                                continue

                            if not isinstance(data, dict):
                                # Everything downstream treats the payload as
                                # a JSON object; skip anything else.
                                logger.warning(
                                    "Ignoring non-object JSON message from server"
                                )
                                continue

                            if (
                                data.get("type") == "control"
                                and data.get("text") == "conversation-chain-end"
                            ):
                                logger.info("Received conversation end signal")
                                self.message_queue.conversation_active = False

                            await self.broadcast_to_clients(data)
                        except json.JSONDecodeError as e:
                            logger.error(f"Failed to parse message data: {e}")
                            continue
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        logger.error(f"WebSocket error: {ws.exception()}")
                        break
                    elif msg.type in closing_types:
                        break
                except Exception as e:
                    logger.error(f"Error processing server message: {e}")
                    await asyncio.sleep(1)
        except Exception as e:
            logger.error(f"Error forwarding server messages: {e}")
        finally:
            # Only the reader of the current socket may mark the proxy as
            # disconnected.
            if self.server_ws is ws:
                self.connected = False
                self.message_queue.conversation_active = False
            logger.info("Server message forwarding ended")

    async def broadcast_to_clients(
        self, message: dict, exclude_client: Optional[str] = None
    ):
        """
        Broadcast a message to all connected clients.

        Clients whose send fails are treated as disconnected and removed
        after the broadcast completes.

        Args:
            message: The message to broadcast
            exclude_client: Optional client ID to exclude from broadcast
        """
        if not message:
            return

        disconnected_clients = []

        # Build a trimmed copy for logging so large audio / volume payloads
        # do not flood the debug log. `or` guards against a key that is
        # present but holds None.
        if "audio" in message:
            audio_size = len(message.get("audio") or "")
            log_msg = {
                **{k: v for k, v in message.items() if k != "audio"},
                "audio": f"[Audio data, {audio_size} bytes truncated]",
            }
        else:
            log_msg = message.copy()

        volumes = log_msg.get("volumes") or []
        if len(volumes) > 10:
            log_msg["volumes"] = f"[{len(volumes)} volume values]"

        logger.debug(f"Broadcasting to clients (excluding {exclude_client}): {log_msg}")

        # Iterate over a snapshot: clients may join or leave while a send is
        # awaited, and mutating the dict mid-iteration raises RuntimeError.
        for client_id, websocket in list(self.clients.items()):
            if exclude_client and client_id == exclude_client:
                continue

            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.error(f"Error sending to client {client_id}: {e}")
                disconnected_clients.append(client_id)

        for client_id in disconnected_clients:
            await self.handle_client_disconnect(client_id)

    async def forward_with_broadcast(
        self, message: dict, sender_id: Optional[str] = None
    ):
        """
        Forward message to server and handle any necessary broadcasting

        Messages of type "user-input-transcription" are additionally echoed
        to every client except the sender.

        Args:
            message: The message to forward
            sender_id: ID of the client sending the message
        """
        await self.forward_to_server(message, sender_id)

        if message.get("type") == "user-input-transcription":
            await self.broadcast_to_clients(message, exclude_client=sender_id)
|
|
|