File size: 12,345 Bytes
5669b22 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 | import asyncio
import json
import uuid
from typing import Dict, Optional
from fastapi import WebSocket
from loguru import logger
import aiohttp
from starlette.websockets import WebSocketDisconnect
from .proxy_message_queue import ProxyMessageQueue
class ProxyHandler:
    """
    A proxy handler that allows multiple clients to connect through a single WebSocket connection to the server.
    This enables scenarios like having a web client and a live platform both connected to the same VTuber server.

    All clients share one upstream ``aiohttp`` WebSocket. Every message received
    from the server is broadcast to all registered clients. ``text-input``
    messages from clients are routed through a ``ProxyMessageQueue``; all other
    client messages are forwarded to the server directly. When the last client
    leaves, the upstream connection is closed.
    """

    def __init__(self, server_url: str = "ws://localhost:12393/client-ws"):
        """
        Initialize the proxy handler.

        No network I/O happens here; the upstream connection is opened lazily
        by ``connect_to_server``.

        Args:
            server_url: The WebSocket URL of the actual server
        """
        self.server_url = server_url
        self.server_ws: Optional[aiohttp.ClientWebSocketResponse] = None
        # Registered client sockets, keyed by a generated UUID4 string.
        self.clients: Dict[str, WebSocket] = {}
        self.connected = False
        # Task running forward_server_messages() for the current upstream socket.
        self.server_task: Optional[asyncio.Task] = None
        # Serializes connect_to_server() so concurrent callers open one socket.
        self.lock = asyncio.Lock()
        # Initialize message queue manager
        self.message_queue = ProxyMessageQueue()
        self._heartbeat_task: Optional[asyncio.Task] = None
        # Keeps _maintain_connection() looping; cleared by disconnect() and
        # re-armed by connect_to_server().
        self._running = True
        self._session: Optional[aiohttp.ClientSession] = None

    async def connect_to_server(self):
        """
        Establish a WebSocket connection to the actual server.

        Returns immediately when already connected. The check is repeated under
        ``self.lock`` so concurrent callers open only one socket. On success the
        message queue is (re)initialized, a heartbeat task is started if none is
        alive, and a new task forwards server messages to the clients.

        Raises:
            Exception: Whatever opening the connection raised. The HTTP session
                is closed (and dropped) before the exception is re-raised.
        """
        if self.connected:
            return
        async with self.lock:
            if self.connected:  # Double-check to prevent race conditions
                return
            try:
                # Create session if not exists
                if not self._session:
                    self._session = aiohttp.ClientSession()
                # Retire the previous socket and its reader task first, so the
                # old reader never competes with the new one.
                await self._close_stale_server_connection()
                self.server_ws = await self._session.ws_connect(self.server_url)
                self.connected = True
                # disconnect() clears this flag; re-arm it, otherwise the
                # heartbeat loop would exit immediately after a reconnect.
                self._running = True
                logger.info(f"Proxy connected to server at {self.server_url}")
                # Initialize message queue with our forward function
                self.message_queue.initialize(self.forward_with_broadcast)
                # Reconnects are normally driven from inside the heartbeat task
                # itself; only start a new one when none is alive so heartbeat
                # loops do not pile up on every reconnect.
                if self._heartbeat_task is None or self._heartbeat_task.done():
                    self._heartbeat_task = asyncio.create_task(
                        self._maintain_connection()
                    )
                # Start task to receive messages from server
                self.server_task = asyncio.create_task(self.forward_server_messages())
            except Exception as e:
                logger.error(f"Failed to connect to server: {e}")
                # Never report "connected" while the session is being torn down.
                self.connected = False
                if self._session:
                    await self._session.close()
                    self._session = None
                raise

    async def _close_stale_server_connection(self):
        """Cancel the previous reader task and close the previous upstream socket, if any."""
        old_task = self.server_task
        self.server_task = None
        # Never cancel the task we are running in; that would abort this call.
        if (
            old_task is not None
            and not old_task.done()
            and old_task is not asyncio.current_task()
        ):
            old_task.cancel()
        old_ws = self.server_ws
        if old_ws is not None and not old_ws.closed:
            try:
                await old_ws.close()
            except Exception as e:
                logger.warning(f"Error closing previous server connection: {e}")

    async def _maintain_connection(self):
        """
        Maintain connection with heartbeat and automatic reconnection.

        While ``self._running`` is set, the loop does one of two things:
        - If the upstream socket is open, it sends ``{"type": "heartbeat"}``
          and then sleeps 30 s.
        - Otherwise it calls ``connect_to_server()``, waiting 5 s after each
          failed attempt.

        ``disconnect()`` ends the loop by clearing the flag and cancelling
        this task.
        """
        while self._running:
            try:
                if self.connected and self.server_ws and not self.server_ws.closed:
                    # Send heartbeat
                    await self.server_ws.send_json({"type": "heartbeat"})
                    await asyncio.sleep(30)  # Heartbeat interval
                else:
                    # The socket can already be closed while ``connected`` is
                    # still True (the reader task has not run its cleanup yet).
                    # In that state connect_to_server() would return at once
                    # without awaiting and this loop would spin, starving the
                    # event loop. Clearing the flag forces a real reconnect.
                    self.connected = False
                    # Try to reconnect
                    logger.info("Connection lost, attempting to reconnect...")
                    try:
                        await self.connect_to_server()
                    except Exception as e:
                        logger.error(f"Reconnection failed: {e}")
                        await asyncio.sleep(5)  # Wait before retry
            except Exception as e:
                logger.error(f"Error in connection maintenance: {e}")
                self.connected = False
                await asyncio.sleep(5)

    async def handle_client_connection(self, websocket: WebSocket):
        """
        Handle a new client connection to the proxy.

        The handler proceeds in four steps:
        1. Accepts the socket and registers it under a new UUID.
        2. Ensures the upstream connection exists.
        3. Asks the server for the initial config
           (``request-init-config``).
        4. Relays the client's messages until the client disconnects or an
           error occurs.

        Args:
            websocket: The client's WebSocket connection

        Raises:
            Exception: Whatever ``connect_to_server`` raised when the upstream
                connection could not be opened. The client is unregistered
                first.
        """
        await websocket.accept()
        # Generate a unique client ID
        client_id = str(uuid.uuid4())
        self.clients[client_id] = websocket
        logger.info(
            f"Client {client_id} connected to proxy. Total clients: {len(self.clients)}"
        )
        # Ensure server connection is established
        if not self.connected:
            try:
                await self.connect_to_server()
            except Exception:
                # Don't leave a dead entry behind that would keep receiving
                # broadcasts for a client whose handler has already failed.
                await self.handle_client_disconnect(client_id)
                raise
        if self.connected:
            try:
                init_request = {"type": "request-init-config", "client_id": client_id}
                await self.forward_to_server(init_request, client_id)
            except Exception as e:
                logger.error(f"Failed to request initialization: {e}")
        try:
            # Handle messages from this client
            while True:
                message = await websocket.receive_json()
                if not isinstance(message, dict):
                    # Valid JSON but not an object: there is no "type" to route on.
                    logger.warning(f"Ignoring non-object message from client {client_id}")
                    continue
                message_type = message.get("type")
                if message_type == "text-input":
                    # Process text-input messages through the queue, tagged
                    # with the sender's ID
                    self.message_queue.queue_message(message, client_id)
                    continue
                if message_type == "interrupt-signal":
                    logger.info(
                        "Received interrupt signal, marking conversation as inactive"
                    )
                    # Mark conversation as inactive to allow processing of next message
                    self.message_queue.conversation_active = False
                # Interrupts and all other message types go to the server directly
                await self.forward_to_server(message, client_id)
        except WebSocketDisconnect:
            await self.handle_client_disconnect(client_id)
        except Exception as e:
            logger.error(f"Error handling client connection: {e}")
            await self.handle_client_disconnect(client_id)

    async def handle_client_disconnect(self, client_id: str):
        """
        Handle a client disconnection.

        Unregisters the client; calling this for an unknown or already removed
        ID is harmless. If no clients remain while connected, it disconnects
        from the server.

        Args:
            client_id: The ID of the disconnected client
        """
        self.clients.pop(client_id, None)
        logger.info(
            f"Client {client_id} removed. Remaining clients: {len(self.clients)}"
        )
        # If no clients are connected, disconnect from the server
        if not self.clients and self.connected:
            await self.disconnect()

    async def disconnect(self):
        """
        Disconnect from the server.

        This method does the following:
        - Stops the heartbeat loop.
        - Closes the upstream socket and the HTTP session.
        - Stops the reader task.
        - Stops and clears the message queue.

        The handler can connect again afterwards via ``connect_to_server``.
        """
        self._running = False
        current = asyncio.current_task()
        # Cancel heartbeat task. If we are running inside it, clearing
        # _running above already ends its loop.
        heartbeat, self._heartbeat_task = self._heartbeat_task, None
        if heartbeat is not None and heartbeat is not current:
            heartbeat.cancel()
            try:
                await heartbeat
            except asyncio.CancelledError:
                pass
        if self.server_ws and not self.server_ws.closed:
            await self.server_ws.close()
        # disconnect() can be reached from inside the reader task
        # (broadcast_to_clients -> handle_client_disconnect). Cancelling the
        # current task would abort the cleanup below at its next await. So the
        # reader is only cancelled from other tasks; otherwise it exits on its
        # own because the socket is closed.
        if self.server_task and self.server_task is not current:
            self.server_task.cancel()
        self.server_task = None
        # Close session
        if self._session:
            await self._session.close()
            self._session = None
        self.connected = False
        # Stop and clear the message queue
        self.message_queue.stop()
        self.message_queue.clear()
        logger.info("Proxy disconnected from server")

    async def forward_to_server(self, message: dict, sender_id: Optional[str] = None):
        """
        Forward a message from a client to the server.

        Connects first if there is no upstream connection. A warning is logged
        and the message dropped if the socket is still unavailable.

        Args:
            message: The message to forward
            sender_id: ID of the client sending the message (currently unused
                here; kept for a uniform forwarding signature)

        Raises:
            Exception: Propagated from ``connect_to_server`` or ``send_json``.
        """
        if not self.connected or not self.server_ws:
            await self.connect_to_server()
        ws = self.server_ws
        if ws and not ws.closed:
            await ws.send_json(message)
        else:
            logger.warning(
                f"Server connection unavailable; dropping message of type {message.get('type')!r}"
            )

    async def forward_server_messages(self):
        """
        Forward messages from server to all connected clients.

        Reads from the socket that was current when this task started. It
        stops on an error/close frame or once the proxy is marked disconnected.
        JSON text frames are broadcast. A ``conversation-chain-end`` control
        message also marks the queued conversation as inactive.
        """
        # Bind the socket once so a newer connection is never read from here.
        ws = self.server_ws
        try:
            while self.connected and ws is not None and not ws.closed:
                try:
                    msg = await ws.receive()
                    if msg.type == aiohttp.WSMsgType.TEXT:
                        try:
                            # Parse the message
                            if not msg.data:  # Check if data is empty
                                continue
                            data = json.loads(msg.data)
                            if not data:  # Check if parsed data is empty
                                continue
                            if not isinstance(data, dict):
                                # Only JSON objects can be inspected/broadcast.
                                logger.warning("Ignoring non-object message from server")
                                continue
                            # Check for conversation end signal
                            if (
                                data.get("type") == "control"
                                and data.get("text") == "conversation-chain-end"
                            ):
                                logger.info("Received conversation end signal")
                                self.message_queue.conversation_active = False
                            # Broadcast the message to all clients
                            await self.broadcast_to_clients(data)
                        except json.JSONDecodeError as e:
                            logger.error(f"Failed to parse message data: {e}")
                            continue
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        logger.error(f"WebSocket error: {ws.exception()}")
                        break
                    elif msg.type in (
                        aiohttp.WSMsgType.CLOSE,
                        aiohttp.WSMsgType.CLOSING,
                        aiohttp.WSMsgType.CLOSED,
                    ):
                        # The peer or our side is closing; nothing more will arrive.
                        break
                except Exception as e:
                    logger.error(f"Error processing server message: {e}")
                    await asyncio.sleep(1)
        except Exception as e:
            logger.error(f"Error forwarding server messages: {e}")
        finally:
            # Only reset shared state if this task's socket is still the current
            # one; a stale reader must not mark a newer connection as down.
            if self.server_ws is ws:
                self.connected = False
                self.message_queue.conversation_active = False
            logger.info("Server message forwarding ended")

    @staticmethod
    def _summarize_for_log(message: dict) -> dict:
        """Return a shallow copy of *message* with bulky audio/volume payloads replaced by short placeholders."""
        log_msg = dict(message)
        if "audio" in log_msg:
            audio = log_msg["audio"]
            # ``audio`` may be None (or another non-sized value); don't let the
            # debug summary raise and abort the broadcast.
            size = len(audio) if isinstance(audio, (str, bytes, bytearray, list)) else 0
            log_msg["audio"] = f"[Audio data, {size} bytes truncated]"
        volumes = log_msg.get("volumes")
        if isinstance(volumes, (list, tuple)) and len(volumes) > 10:
            log_msg["volumes"] = f"[{len(volumes)} volume values]"
        return log_msg

    async def broadcast_to_clients(
        self, message: dict, exclude_client: Optional[str] = None
    ):
        """
        Broadcast a message to all connected clients.

        Clients whose send fails are unregistered via
        ``handle_client_disconnect`` after the loop.

        Args:
            message: The message to broadcast (empty messages are ignored)
            exclude_client: Optional client ID to exclude from broadcast
        """
        if not message:  # Add null check
            return
        # Log message, but handle audio data specially to avoid huge logs
        logger.debug(
            f"Broadcasting to clients (excluding {exclude_client}): {self._summarize_for_log(message)}"
        )
        disconnected_clients = []
        # Iterate over a snapshot: other coroutines may add or remove clients
        # while we are suspended in send_json().
        for client_id, websocket in list(self.clients.items()):
            # Skip the excluded client
            if exclude_client and client_id == exclude_client:
                continue
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.error(f"Error sending to client {client_id}: {e}")
                disconnected_clients.append(client_id)
        # Clean up disconnected clients
        for client_id in disconnected_clients:
            await self.handle_client_disconnect(client_id)

    async def forward_with_broadcast(
        self, message: dict, sender_id: Optional[str] = None
    ):
        """
        Forward message to server and handle any necessary broadcasting.

        ``user-input-transcription`` messages are also echoed to every client
        except the sender.

        Args:
            message: The message to forward
            sender_id: ID of the client sending the message
        """
        # Forward to server
        await self.forward_to_server(message, sender_id)
        # For transcription messages, broadcast to other clients
        if message.get("type") == "user-input-transcription":
            await self.broadcast_to_clients(message, exclude_client=sender_id)
|