# test_ui/src/open_llm_vtuber/proxy_handler.py
# Uploaded by britto224 ("Upload 130 files"), commit 5669b22 (verified).
import asyncio
import json
import uuid
from typing import Dict, Optional
from fastapi import WebSocket
from loguru import logger
import aiohttp
from starlette.websockets import WebSocketDisconnect
from .proxy_message_queue import ProxyMessageQueue
class ProxyHandler:
    """
    A proxy handler that allows multiple clients to connect through a single WebSocket connection to the server.
    This enables scenarios like having a web client and a live platform both connected to the same VTuber server.

    Lifecycle, as implemented below:
      * The first client to connect opens the upstream connection
        (``connect_to_server``). That starts a server-reader task
        (``forward_server_messages``) and a heartbeat/reconnect task
        (``_maintain_connection``).
      * Every non-empty JSON object received from the upstream server is
        broadcast to all connected clients.
      * When the last client leaves, the upstream connection and background
        tasks are torn down (``disconnect``). A later client starts them again.
    """

    def __init__(self, server_url: str = "ws://localhost:12393/client-ws"):
        """
        Initialize the proxy handler.

        Args:
            server_url: The WebSocket URL of the actual server
        """
        self.server_url = server_url
        # Upstream socket to the real server; None until the first connect.
        self.server_ws: Optional[aiohttp.ClientWebSocketResponse] = None
        # Connected downstream clients, keyed by a per-connection UUID string.
        self.clients: Dict[str, WebSocket] = {}
        self.connected = False
        # Task running forward_server_messages() for the current upstream socket.
        self.server_task: Optional[asyncio.Task] = None
        # Serialises connect_to_server() so concurrent callers open one socket.
        self.lock = asyncio.Lock()
        # Initialize message queue manager
        self.message_queue = ProxyMessageQueue()
        # Task running _maintain_connection() (heartbeat + reconnect loop).
        self._heartbeat_task: Optional[asyncio.Task] = None
        # Loop flag for _maintain_connection(): cleared by disconnect(),
        # set again by connect_to_server().
        self._running = True
        self._session: Optional[aiohttp.ClientSession] = None

    async def connect_to_server(self):
        """Establish a WebSocket connection to the actual server.

        Does nothing if already connected. On success it:
          * (re)initialises the message queue,
          * starts a reader task for the new socket,
          * starts the heartbeat task unless one is already running. This
            method is itself called from inside the heartbeat loop when it
            reconnects.

        Raises:
            Exception: whatever opening the connection raises. The HTTP
                session is closed and discarded before re-raising.
        """
        if self.connected:
            return
        async with self.lock:
            if self.connected:  # Double-check to prevent race conditions
                return
            try:
                if not self._session:
                    self._session = aiohttp.ClientSession()

                # Retire any previous socket and reader so a stale reader can
                # neither race the new one nor mark the new link as dead.
                stale_reader = self.server_task
                if (
                    stale_reader is not None
                    and not stale_reader.done()
                    and stale_reader is not asyncio.current_task()
                ):
                    stale_reader.cancel()
                if self.server_ws is not None and not self.server_ws.closed:
                    await self.server_ws.close()

                self.server_ws = await self._session.ws_connect(self.server_url)
                self.connected = True
                # disconnect() clears this flag; re-arm it, otherwise a fresh
                # heartbeat loop would exit immediately.
                self._running = True
                logger.info(f"Proxy connected to server at {self.server_url}")

                # Initialize message queue with our forward function.
                # NOTE(review): this runs on every (re)connect; confirm
                # ProxyMessageQueue.initialize is safe to call repeatedly.
                self.message_queue.initialize(self.forward_with_broadcast)

                # Start the heartbeat only if none is alive. When we are called
                # from the heartbeat loop itself, starting another would make
                # the loops multiply on every reconnect.
                if self._heartbeat_task is None or self._heartbeat_task.done():
                    self._heartbeat_task = asyncio.create_task(
                        self._maintain_connection()
                    )

                # Start task to receive messages from server
                self.server_task = asyncio.create_task(self.forward_server_messages())
            except Exception as e:
                logger.error(f"Failed to connect to server: {e}")
                if self._session:
                    await self._session.close()
                    self._session = None
                raise

    async def _maintain_connection(self):
        """Maintain connection with heartbeat and automatic reconnection.

        Loops while ``self._running`` is set:
          * While connected, it sends ``{"type": "heartbeat"}`` and sleeps
            30 seconds.
          * Otherwise it calls ``connect_to_server`` and waits 5 seconds after
            a failed attempt.

        Errors other than cancellation are logged and followed by a 5-second
        pause.
        """
        while self._running:
            try:
                if self.connected and self.server_ws and not self.server_ws.closed:
                    await self.server_ws.send_json({"type": "heartbeat"})
                    await asyncio.sleep(30)  # Heartbeat interval
                else:
                    logger.info("Connection lost, attempting to reconnect...")
                    try:
                        await self.connect_to_server()
                    except Exception as e:
                        logger.error(f"Reconnection failed: {e}")
                        await asyncio.sleep(5)  # Wait before retry
            except Exception as e:
                # E.g. the heartbeat send failed: treat the link as dead so the
                # next iteration goes down the reconnect branch.
                logger.error(f"Error in connection maintenance: {e}")
                self.connected = False
                await asyncio.sleep(5)

    async def handle_client_connection(self, websocket: WebSocket):
        """
        Handle a new client connection to the proxy.

        Steps:
          1. Accept the socket and register it under a fresh UUID.
          2. Make sure the upstream connection exists.
          3. Ask the server for its initial config.
          4. Relay the client's JSON messages until the client goes away.

        ``text-input`` messages go through the message queue. Everything else,
        including ``interrupt-signal``, is forwarded to the server directly.

        Args:
            websocket: The client's WebSocket connection

        Raises:
            Exception: whatever ``connect_to_server`` raises when the upstream
                connection cannot be opened. The client is unregistered first.
        """
        await websocket.accept()
        client_id = str(uuid.uuid4())
        self.clients[client_id] = websocket
        logger.info(
            f"Client {client_id} connected to proxy. Total clients: {len(self.clients)}"
        )

        # Ensure server connection is established. On failure, unregister the
        # client before propagating. Otherwise it would stay in self.clients
        # forever (its receive loop below never runs) and would be included in
        # every broadcast.
        if not self.connected:
            try:
                await self.connect_to_server()
            except Exception:
                await self.handle_client_disconnect(client_id)
                raise

        if self.connected:
            try:
                init_request = {"type": "request-init-config", "client_id": client_id}
                await self.forward_to_server(init_request, client_id)
            except Exception as e:
                logger.error(f"Failed to request initialization: {e}")

        try:
            while True:
                message = await websocket.receive_json()
                msg_type = message.get("type")

                if msg_type == "text-input":
                    # Queue the message with the sender's ID.
                    self.message_queue.queue_message(message, client_id)
                    continue

                if msg_type == "interrupt-signal":
                    logger.info(
                        "Received interrupt signal, marking conversation as inactive"
                    )
                    # Mark conversation as inactive to allow processing of next message
                    self.message_queue.conversation_active = False

                # Interrupts and all other message types go straight upstream.
                await self.forward_to_server(message, client_id)
        except WebSocketDisconnect:
            await self.handle_client_disconnect(client_id)
        except Exception as e:
            logger.error(f"Error handling client connection: {e}")
            await self.handle_client_disconnect(client_id)

    async def handle_client_disconnect(self, client_id: str):
        """
        Handle a client disconnection.

        Unregisters the client. If no clients remain, it tears down the
        upstream connection and background tasks. Calling it again for an
        already-removed client only logs again.

        Args:
            client_id: The ID of the disconnected client
        """
        self.clients.pop(client_id, None)
        logger.info(
            f"Client {client_id} removed. Remaining clients: {len(self.clients)}"
        )
        # Also tear down when the link has already dropped but the heartbeat
        # task is still alive (retrying connect_to_server). Otherwise it keeps
        # reconnecting with nobody left to serve.
        heartbeat_alive = (
            self._heartbeat_task is not None and not self._heartbeat_task.done()
        )
        if not self.clients and (self.connected or heartbeat_alive):
            await self.disconnect()

    async def disconnect(self):
        """Disconnect from the server.

        Steps:
          * Stop the heartbeat loop.
          * Close the upstream socket.
          * Cancel the reader task.
          * Close the HTTP session.
          * Stop and clear the message queue.

        Can be called from inside the reader or heartbeat task (e.g. reader ->
        broadcast -> client cleanup). The calling task is never cancelled by
        itself, so the cleanup below always runs to completion.
        """
        self._running = False
        current = asyncio.current_task()

        # Cancel the heartbeat task and wait for it to exit. A task cannot
        # await itself, hence the identity check.
        if self._heartbeat_task and self._heartbeat_task is not current:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass

        if self.server_ws and not self.server_ws.closed:
            await self.server_ws.close()

        # Cancelling the *current* task would abort the rest of this cleanup at
        # the next await, so only cancel the reader if we are not running in it.
        # When we are, it exits on its own because `connected` becomes False.
        if self.server_task and self.server_task is not current:
            self.server_task.cancel()

        # Close session
        if self._session:
            await self._session.close()
            self._session = None

        self.connected = False
        # Stop and clear the message queue
        self.message_queue.stop()
        self.message_queue.clear()
        logger.info("Proxy disconnected from server")

    async def forward_to_server(self, message: dict, sender_id: Optional[str] = None):
        """
        Forward a message from a client to the server.

        Connects first if there is no upstream connection. If the socket is
        still unusable afterwards, the message is dropped and a warning is
        logged.

        Args:
            message: The message to forward
            sender_id: ID of the client sending the message. Not used here;
                accepted so all forwarding helpers share one call shape.

        Raises:
            Exception: whatever ``connect_to_server`` or ``send_json`` raises.
        """
        if not self.connected or not self.server_ws:
            await self.connect_to_server()
        if self.server_ws and not self.server_ws.closed:
            await self.server_ws.send_json(message)
        else:
            # Don't lose messages silently.
            logger.warning(
                f"Server connection unavailable; dropped message of type "
                f"{message.get('type')!r}"
            )

    async def forward_server_messages(self):
        """Forward messages from server to all connected clients.

        Reads from the upstream socket that is current when the task starts.
        Each non-empty JSON-object text frame is broadcast to the clients. A
        ``control`` message whose ``text`` is ``"conversation-chain-end"`` also
        clears the queue's ``conversation_active`` flag.

        Stops when any of these happens:
          * the socket errors or closes,
          * the proxy disconnects,
          * ``connect_to_server`` replaces the socket.

        On exit it clears ``conversation_active``. It resets ``connected`` only
        if it still owns the current socket, so a stale reader cannot mark a
        newer connection as dead.
        """
        ws = self.server_ws
        try:
            while (
                self.connected
                and ws is not None
                and not ws.closed
                and self.server_ws is ws
            ):
                try:
                    msg = await ws.receive()
                    if msg.type == aiohttp.WSMsgType.TEXT:
                        try:
                            if not msg.data:  # Check if data is empty
                                continue
                            data = json.loads(msg.data)
                            if not data:  # Check if parsed data is empty
                                continue
                            if not isinstance(data, dict):
                                logger.warning(
                                    f"Ignoring non-object message from server: {type(data).__name__}"
                                )
                                continue
                            # Check for conversation end signal
                            if (
                                data.get("type") == "control"
                                and data.get("text") == "conversation-chain-end"
                            ):
                                logger.info("Received conversation end signal")
                                self.message_queue.conversation_active = False
                            # Broadcast the message to all clients
                            await self.broadcast_to_clients(data)
                        except json.JSONDecodeError as e:
                            logger.error(f"Failed to parse message data: {e}")
                            continue
                    elif msg.type == aiohttp.WSMsgType.ERROR:
                        logger.error(f"WebSocket error: {ws.exception()}")
                        break
                    elif msg.type in (
                        aiohttp.WSMsgType.CLOSE,
                        aiohttp.WSMsgType.CLOSING,
                        aiohttp.WSMsgType.CLOSED,
                    ):
                        # All three mean the socket is going away. Calling
                        # receive() again during CLOSING would compete with
                        # close() for the same reader.
                        break
                except Exception as e:
                    logger.error(f"Error processing server message: {e}")
                    await asyncio.sleep(1)
        except Exception as e:
            logger.error(f"Error forwarding server messages: {e}")
        finally:
            if self.server_ws is ws:
                self.connected = False
            self.message_queue.conversation_active = False
            logger.info("Server message forwarding ended")

    @staticmethod
    def _summarize_for_log(message: dict) -> dict:
        """Return a shallow copy of *message* with bulky fields summarised for logging.

        The ``audio`` value is replaced by its length. A ``volumes`` list or
        tuple longer than 10 is replaced by its count.
        """
        log_msg = dict(message)
        if "audio" in log_msg:
            audio = log_msg["audio"]
            # The payload is not validated here, so audio may be None or some
            # other unsized value. Don't let len() crash the broadcast.
            size = len(audio) if isinstance(audio, (str, bytes, bytearray, list)) else 0
            log_msg["audio"] = f"[Audio data, {size} bytes truncated]"
        volumes = log_msg.get("volumes")
        if isinstance(volumes, (list, tuple)) and len(volumes) > 10:
            log_msg["volumes"] = f"[{len(volumes)} volume values]"
        return log_msg

    async def broadcast_to_clients(
        self, message: dict, exclude_client: Optional[str] = None
    ):
        """
        Broadcast a message to all connected clients.

        Clients whose send fails are removed through
        ``handle_client_disconnect``. That may tear down the upstream
        connection if no clients remain.

        Args:
            message: The message to broadcast
            exclude_client: Optional client ID to exclude from broadcast
        """
        if not message:
            return

        logger.debug(
            f"Broadcasting to clients (excluding {exclude_client}): "
            f"{self._summarize_for_log(message)}"
        )

        # Iterate over a snapshot. send_json awaits, and meanwhile other tasks
        # may add or remove clients. Mutating the live dict during iteration
        # would raise "dictionary changed size during iteration".
        disconnected_clients = []
        for client_id, websocket in list(self.clients.items()):
            if exclude_client and client_id == exclude_client:
                continue
            try:
                await websocket.send_json(message)
            except Exception as e:
                logger.error(f"Error sending to client {client_id}: {e}")
                disconnected_clients.append(client_id)

        # Clean up disconnected clients
        for client_id in disconnected_clients:
            await self.handle_client_disconnect(client_id)

    async def forward_with_broadcast(
        self, message: dict, sender_id: Optional[str] = None
    ):
        """
        Forward message to server and handle any necessary broadcasting.

        This is the callback passed to ``message_queue.initialize``.
        ``user-input-transcription`` messages are additionally echoed to every
        client except the sender.

        Args:
            message: The message to forward
            sender_id: ID of the client sending the message
        """
        await self.forward_to_server(message, sender_id)
        if message.get("type") == "user-input-transcription":
            await self.broadcast_to_clients(message, exclude_client=sender_id)