# ChatCal.ai-1 / webrtc/server/websocket_handler.py
# Peter Michael Gits
# fix: Improve fallback handling and prevent TTS errors v0.5.4
# 08815de
"""
WebRTC WebSocket Handler for Real-time Audio Streaming
Integrates with FastAPI for unmute.sh-style voice interaction
"""
import asyncio
import base64
import json
import logging
import os
import tempfile
import uuid
from datetime import datetime
from typing import Dict, Optional

import numpy as np
import soundfile as sf
import websockets
from fastapi import WebSocket, WebSocketDisconnect
# Module-level logger named after this module (via __name__), so its output
# can be configured/filtered independently of other modules.
logger = logging.getLogger(__name__)
class WebRTCHandler:
    """Handle WebRTC WebSocket connections for real-time audio streaming.

    Browser clients are tracked by ``client_id``. Audio sent by a client is
    transcribed through a remote STT service, and text can be synthesized
    through a remote TTS service. For both services a WebSocket connection
    is tried first, and an HTTP ``/api/predict`` request is used as a
    fallback. Results are pushed back to the client as JSON text messages.
    """

    def __init__(self):
        # Browser-side FastAPI sockets, keyed by client id.
        self.active_connections: Dict[str, WebSocket] = {}
        # Per-client audio buffers. Created on connect and dropped on
        # disconnect; nothing in this class appends to them.
        self.audio_buffers: Dict[str, list] = {}
        self.stt_service_url = "https://pgits-stt-gpu-service.hf.space"
        self.stt_websocket_url = "wss://pgits-stt-gpu-service.hf.space/ws/stt"
        # Upstream STT sockets, keyed by client id (or a temporary per-request id).
        self.stt_connections: Dict[str, websockets.WebSocketClientProtocol] = {}
        self.tts_service_url = "https://pgits-tts-gpu-service.hf.space"
        self.tts_websocket_url = "wss://pgits-tts-gpu-service.hf.space/ws/tts"
        # Upstream TTS sockets, keyed by client id.
        self.tts_connections: Dict[str, websockets.WebSocketClientProtocol] = {}
        # Timeouts (seconds) for upstream WebSocket calls. Without them a
        # stalled service hangs the request forever and the HTTP fallback
        # never runs. The response timeouts mirror the HTTP fallback
        # timeouts used below (30 s for STT, 60 s for TTS).
        self.ws_connect_timeout = 5.0
        self.ws_handshake_timeout = 10.0
        self.stt_response_timeout = 30.0
        self.tts_response_timeout = 60.0

    @staticmethod
    async def _close_quietly(ws) -> None:
        """Best-effort close of an upstream socket; ``None`` is ignored."""
        if ws is None:
            return
        try:
            await ws.close()
        except Exception as e:
            logger.debug(f"Ignoring error while closing upstream socket: {e}")

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept a client WebSocket, register it, and confirm the connection.

        Args:
            websocket: The FastAPI WebSocket of the browser client.
            client_id: Identifier under which all per-client state is kept.
        """
        await websocket.accept()
        self.active_connections[client_id] = websocket
        self.audio_buffers[client_id] = []
        logger.info(f"πŸ”Œ WebRTC client {client_id} connected")
        # Send connection confirmation
        await self.send_message(client_id, {
            "type": "connection_confirmed",
            "client_id": client_id,
            "timestamp": datetime.now().isoformat(),
            "services": {
                "stt": self.stt_service_url,
                "status": "ready"
            }
        })

    async def disconnect(self, client_id: str):
        """Drop all state for ``client_id`` and close its upstream sockets."""
        self.active_connections.pop(client_id, None)
        self.audio_buffers.pop(client_id, None)
        await self.disconnect_from_stt_service(client_id)
        await self.disconnect_from_tts_service(client_id)
        logger.info(f"πŸ”Œ WebRTC client {client_id} disconnected")

    async def send_message(self, client_id: str, message: dict):
        """Send ``message`` as JSON text to the client.

        Unknown clients are ignored. If sending fails, the client is
        disconnected.
        """
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return
        try:
            await websocket.send_text(json.dumps(message))
        except Exception as e:
            logger.error(f"Failed to send message to {client_id}: {e}")
            await self.disconnect(client_id)

    async def handle_audio_chunk(self, client_id: str, audio_data: bytes, sample_rate: int = 16000):
        """Transcribe an encoded audio chunk and send the result to the client.

        The bytes are written to a temporary ``.webm`` file, which is always
        removed afterwards. On success a ``transcription`` message is sent;
        otherwise an ``error`` message is sent.
        """
        tmp_file_path = None
        try:
            logger.info(f"🎀 Received {len(audio_data)} bytes from {client_id}")
            # MediaRecorder typically produces WebM/OGG/WAV rather than raw PCM,
            # so the bytes are handed to the STT pipeline as a file.
            with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as tmp_file:
                # Record the path before writing so a failed write is cleaned up too.
                tmp_file_path = tmp_file.name
                tmp_file.write(audio_data)

            transcription = await self.process_audio_file_webrtc(tmp_file_path, sample_rate)
            if transcription:
                await self.send_message(client_id, {
                    "type": "transcription",
                    "text": transcription,
                    "timestamp": datetime.now().isoformat(),
                    "audio_size": len(audio_data),
                    "format": "webm/audio"
                })
                logger.info(f"πŸ“ Transcription sent to {client_id}: {transcription[:50]}...")
            else:
                await self.send_message(client_id, {
                    "type": "error",
                    "message": "Audio processing failed",
                    "timestamp": datetime.now().isoformat()
                })
        except Exception as e:
            logger.error(f"Error processing audio chunk for {client_id}: {e}")
            await self.send_message(client_id, {
                "type": "error",
                "message": f"Audio processing error: {str(e)}",
                "timestamp": datetime.now().isoformat()
            })
        finally:
            if tmp_file_path is not None and os.path.exists(tmp_file_path):
                try:
                    os.unlink(tmp_file_path)
                except OSError as e:
                    logger.warning(f"Could not remove temp file {tmp_file_path}: {e}")

    async def connect_to_stt_service(self, client_id: str) -> bool:
        """Open and handshake an upstream STT WebSocket for ``client_id``.

        The socket is registered in ``stt_connections`` only after the
        service answers ``stt_connection_confirmed``. On any failure the
        socket is closed and nothing is registered, so a later call can
        retry cleanly.

        Returns:
            True if the connection was confirmed, otherwise False.
        """
        stt_ws = None
        confirmed = False
        try:
            logger.info(f"πŸ”Œ Connecting to STT service for client {client_id}: {self.stt_websocket_url}")
            stt_ws = await asyncio.wait_for(
                websockets.connect(self.stt_websocket_url),
                timeout=self.ws_connect_timeout
            )
            confirmation = await asyncio.wait_for(stt_ws.recv(), timeout=self.ws_handshake_timeout)
            confirmation_data = json.loads(confirmation)
            if confirmation_data.get("type") == "stt_connection_confirmed":
                self.stt_connections[client_id] = stt_ws
                confirmed = True
                logger.info(f"βœ… STT service connected for client {client_id}")
                return True
            logger.warning(f"⚠️ Unexpected STT confirmation: {confirmation_data}")
            return False
        except asyncio.TimeoutError:
            logger.error(f"❌ STT service connection timeout for {client_id} - service may be cold starting or WebSocket endpoints not available")
            return False
        except websockets.exceptions.WebSocketException as e:
            logger.error(f"❌ STT WebSocket error for {client_id}: {e}")
            logger.info(f"πŸ” Debug: Attempted connection to {self.stt_websocket_url}")
            return False
        except Exception as e:
            logger.error(f"❌ Failed to connect to STT service for {client_id}: {e}")
            logger.info(f"πŸ” Debug: STT service URL: {self.stt_websocket_url}")
            return False
        finally:
            # Never leak a socket whose handshake did not succeed.
            if not confirmed:
                await self._close_quietly(stt_ws)

    async def disconnect_from_stt_service(self, client_id: str):
        """Close and forget the upstream STT socket for ``client_id``, if any.

        The entry is removed before closing, so a failing ``close()`` can
        never leave a broken socket registered for reuse.
        """
        stt_ws = self.stt_connections.pop(client_id, None)
        if stt_ws is None:
            return
        try:
            await stt_ws.close()
            logger.info(f"πŸ”Œ Disconnected from STT service for client {client_id}")
        except Exception as e:
            logger.error(f"Error disconnecting from STT service: {e}")

    async def send_audio_to_stt_service(self, client_id: str, audio_data: bytes) -> Optional[str]:
        """Send encoded audio to the STT WebSocket and return its transcription.

        Connects first if needed. Returns the ``text`` of an
        ``stt_transcription`` response. Returns None on a service error, an
        unexpected response, a timeout, or any transport failure; on a
        failure the upstream socket is dropped.
        """
        if client_id not in self.stt_connections:
            if not await self.connect_to_stt_service(client_id):
                return None
        try:
            stt_ws = self.stt_connections[client_id]
            # Audio bytes travel as base64 inside a JSON message.
            audio_b64 = base64.b64encode(audio_data).decode('utf-8')
            message = {
                "type": "stt_audio_chunk",
                "audio_data": audio_b64,
                "language": "auto",
                "model_size": "base"
            }
            await stt_ws.send(json.dumps(message))
            logger.info(f"πŸ“€ Sent {len(audio_data)} bytes to STT service")

            response = await asyncio.wait_for(stt_ws.recv(), timeout=self.stt_response_timeout)
            response_data = json.loads(response)
            response_type = response_data.get("type")
            if response_type == "stt_transcription":
                transcription = response_data.get("text", "")
                logger.info(f"πŸ“ STT transcription received: {transcription[:50]}...")
                return transcription
            if response_type == "stt_error":
                error_msg = response_data.get("message", "Unknown STT error")
                logger.error(f"❌ STT service error: {error_msg}")
                return None
            logger.warning(f"⚠️ Unexpected STT response: {response_data}")
            return None
        except Exception as e:
            logger.error(f"❌ Error communicating with STT service: {e}")
            # Drop the socket: after a failure or timeout it may be out of sync.
            await self.disconnect_from_stt_service(client_id)
            return None

    # TTS WebSocket Methods
    async def connect_to_tts_service(self, client_id: str) -> bool:
        """Open and handshake an upstream TTS WebSocket for ``client_id``.

        The socket is registered in ``tts_connections`` only after the
        service answers ``tts_connection_confirmed``. On any failure the
        socket is closed and nothing is registered.

        Returns:
            True if the connection was confirmed, otherwise False.
        """
        tts_ws = None
        confirmed = False
        try:
            logger.info(f"πŸ”Œ Connecting to TTS service for client {client_id}: {self.tts_websocket_url}")
            tts_ws = await asyncio.wait_for(
                websockets.connect(self.tts_websocket_url),
                timeout=self.ws_connect_timeout
            )
            confirmation = await asyncio.wait_for(tts_ws.recv(), timeout=self.ws_handshake_timeout)
            confirmation_data = json.loads(confirmation)
            if confirmation_data.get("type") == "tts_connection_confirmed":
                self.tts_connections[client_id] = tts_ws
                confirmed = True
                logger.info(f"βœ… TTS service connected for client {client_id}")
                return True
            logger.warning(f"⚠️ Unexpected TTS confirmation: {confirmation_data}")
            return False
        except asyncio.TimeoutError:
            logger.error(f"❌ TTS service connection timeout for {client_id}")
            return False
        except Exception as e:
            logger.error(f"❌ Failed to connect to TTS service for {client_id}: {e}")
            return False
        finally:
            # Never leak a socket whose handshake did not succeed.
            if not confirmed:
                await self._close_quietly(tts_ws)

    async def disconnect_from_tts_service(self, client_id: str):
        """Close and forget the upstream TTS socket for ``client_id``, if any.

        The entry is removed before closing, so a failing ``close()`` can
        never leave a broken socket registered for reuse.
        """
        tts_ws = self.tts_connections.pop(client_id, None)
        if tts_ws is None:
            return
        try:
            await tts_ws.close()
            logger.info(f"πŸ”Œ Disconnected from TTS service for client {client_id}")
        except Exception as e:
            logger.error(f"Error disconnecting from TTS service: {e}")

    async def send_text_to_tts_service(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]:
        """Synthesize ``text`` via the TTS WebSocket and return the audio bytes.

        Connects first if needed. Returns the base64-decoded ``audio_data``
        of a ``tts_audio_response``. Returns None on a service error, an
        unexpected response, a timeout, or any transport failure; on a
        failure the upstream socket is dropped.
        """
        if client_id not in self.tts_connections:
            if not await self.connect_to_tts_service(client_id):
                return None
        try:
            tts_ws = self.tts_connections[client_id]
            message = {
                "type": "tts_synthesize",
                "text": text,
                "voice_preset": voice_preset
            }
            await tts_ws.send(json.dumps(message))
            logger.info(f"πŸ“€ Sent text to TTS service: {text[:50]}...")

            response = await asyncio.wait_for(tts_ws.recv(), timeout=self.tts_response_timeout)
            response_data = json.loads(response)
            response_type = response_data.get("type")
            if response_type == "tts_audio_response":
                audio_b64 = response_data.get("audio_data", "")
                audio_bytes = base64.b64decode(audio_b64)
                logger.info(f"πŸ”Š TTS audio received: {len(audio_bytes)} bytes")
                return audio_bytes
            if response_type == "tts_error":
                error_msg = response_data.get("message", "Unknown TTS error")
                logger.error(f"❌ TTS service error: {error_msg}")
                return None
            logger.warning(f"⚠️ Unexpected TTS response: {response_data}")
            return None
        except Exception as e:
            logger.error(f"❌ Error communicating with TTS service: {e}")
            # Drop the socket: after a failure or timeout it may be out of sync.
            await self.disconnect_from_tts_service(client_id)
            return None

    async def play_tts_response(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6"):
        """Synthesize ``text`` and send it to the client as a ``tts_playback`` message.

        Tries the WebSocket TTS first, then the HTTP fallback. If neither
        yields audio, or anything raises, a ``tts_error`` message is sent
        instead.
        """
        try:
            logger.info(f"πŸ”Š Generating TTS response for client {client_id}: {text[:50]}...")
            logger.info("🌐 Attempting WebSocket TTS (PRIMARY)")
            audio_data = await self.send_text_to_tts_service(client_id, text, voice_preset)
            if not audio_data:
                logger.info("πŸ”„ WebSocket failed, trying HTTP API fallback")
                audio_data = await self.try_http_tts_fallback(text, voice_preset)

            if audio_data:
                audio_b64 = base64.b64encode(audio_data).decode('utf-8')
                await self.send_message(client_id, {
                    "type": "tts_playback",
                    "audio_data": audio_b64,
                    "audio_format": "wav",
                    "text": text,
                    "voice_preset": voice_preset,
                    "timestamp": datetime.now().isoformat(),
                    "audio_size": len(audio_data)
                })
                logger.info(f"πŸ”Š TTS playback sent to {client_id} ({len(audio_data)} bytes)")
            else:
                logger.warning(f"⚠️ TTS service failed to generate audio for: {text[:50]}...")
                await self.send_message(client_id, {
                    "type": "tts_error",
                    "message": "TTS audio generation failed",
                    "text": text,
                    "timestamp": datetime.now().isoformat()
                })
        except Exception as e:
            logger.error(f"❌ TTS playback error for {client_id}: {e}")
            await self.send_message(client_id, {
                "type": "tts_error",
                "message": f"TTS playback error: {str(e)}",
                "timestamp": datetime.now().isoformat()
            })

    async def process_audio_file_webrtc(self, audio_file_path: str, sample_rate: int) -> Optional[str]:
        """Transcribe an audio file: WebSocket STT first, then HTTP fallback.

        Returns the WebSocket transcription, or the HTTP transcription
        prefixed with ``[HTTP] ``. If both fail, returns a canned "please try
        again" sentence. Returns None only if reading the file (or another
        unexpected step) raises. ``sample_rate`` is currently unused.
        """
        try:
            logger.info(f"🎀 WebRTC: Processing audio file {audio_file_path} with real STT")
            with open(audio_file_path, 'rb') as f:
                audio_data = f.read()
            file_size = len(audio_data)
            logger.info(f"🎀 Audio file size: {file_size} bytes")

            # Unique per-call id so concurrent requests never share (or tear
            # down) each other's upstream STT socket.
            temp_client_id = f"temp_{uuid.uuid4().hex}"
            try:
                logger.info("🌐 Attempting WebSocket STT (PRIMARY)")
                transcription = await self.send_audio_to_stt_service(temp_client_id, audio_data)
                if transcription:
                    logger.info(f"βœ… WebSocket STT transcription: {transcription}")
                    return transcription

                logger.info("πŸ”„ WebSocket failed, trying HTTP API fallback")
                http_transcription = await self.try_http_stt_fallback(audio_file_path)
                if http_transcription:
                    logger.info(f"βœ… HTTP STT transcription (fallback): {http_transcription}")
                    return f"[HTTP] {http_transcription}"

                logger.error("❌ Both WebSocket and HTTP STT failed - using minimal fallback")
                # Final fallback phrased as speech so it is safe to feed into TTS.
                return "I'm having trouble processing that audio. Could you please try again?"
            finally:
                await self.disconnect_from_stt_service(temp_client_id)
        except Exception as e:
            logger.error(f"WebRTC audio file processing failed: {e}")
            return None

    async def try_http_stt_fallback(self, audio_file_path: str) -> Optional[str]:
        """Transcribe ``audio_file_path`` via the STT service's HTTP API.

        Posts the file to ``/api/predict`` in a worker thread, because
        ``requests`` is blocking. Returns element 1 of the response's
        ``data`` list if it is a non-blank string, otherwise None. Errors
        are logged and yield None.
        """
        try:
            # Imported lazily, as before. The previously co-imported `aiohttp` was
            # unused and made this fallback fail whenever it was not installed.
            import requests

            def make_request():
                api_url = f"{self.stt_service_url}/api/predict"
                with open(audio_file_path, 'rb') as audio_file:
                    files = {'data': audio_file}
                    data = {'data': '["auto", "base", true]'}  # [language, model_size, timestamps]
                    return requests.post(api_url, files=files, data=data, timeout=30)

            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, make_request)
            if response.status_code == 200:
                result = response.json()
                logger.info(f"πŸ“ HTTP STT result: {result}")
                # Gradio API format: {"data": [status, transcription, timestamps]}
                data = result.get('data') if isinstance(result, dict) else None
                if isinstance(data, list) and len(data) > 1:
                    transcription = data[1]
                    if isinstance(transcription, str) and transcription.strip():
                        logger.info(f"βœ… HTTP STT transcription: {transcription}")
                        return transcription
        except Exception as e:
            logger.error(f"❌ HTTP STT fallback failed: {e}")
        return None

    async def try_http_tts_fallback(self, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]:
        """Synthesize ``text`` via the TTS service's HTTP API.

        Posts ``[text, voice_preset]`` to ``/api/predict``. If element 0 of
        the response's ``data`` is an ``http`` URL, the audio is downloaded
        from it. All blocking ``requests`` calls run in a worker thread.
        Returns the audio bytes, or None on any failure.
        """
        try:
            import requests

            api_url = f"{self.tts_service_url}/api/predict"
            # json.dumps escapes quotes, backslashes and newlines in user text;
            # interpolating into a JSON literal produced invalid JSON for those.
            payload = {'data': json.dumps([text, voice_preset])}  # [text, voice_preset]

            def make_request():
                return requests.post(api_url, data=payload, timeout=60)  # TTS takes longer

            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, make_request)
            if response.status_code == 200:
                result = response.json()
                logger.info("πŸ”Š HTTP TTS result received")
                # Gradio API format: {"data": [audio_file_path_or_url, ...]}
                data = result.get('data') if isinstance(result, dict) else None
                if isinstance(data, list) and data:
                    audio_file_path = data[0]
                    if isinstance(audio_file_path, str) and audio_file_path.startswith('http'):
                        # Download off the event loop; requests.get blocks.
                        audio_response = await loop.run_in_executor(
                            None, lambda: requests.get(audio_file_path, timeout=30)
                        )
                        if audio_response.status_code == 200:
                            logger.info(f"βœ… HTTP TTS audio downloaded: {len(audio_response.content)} bytes")
                            return audio_response.content
        except Exception as e:
            logger.error(f"❌ HTTP TTS fallback failed: {e}")
        return None

    async def process_audio_chunk_real_time(self, audio_array: np.ndarray, sample_rate: int) -> Optional[str]:
        """Legacy method - kept for compatibility.

        Returns a placeholder string describing the array's duration rather
        than a real transcription, or None if the computation raises (e.g.
        ``sample_rate`` of 0).
        """
        try:
            logger.info(f"🎀 WebRTC: Processing {len(audio_array)} samples at {sample_rate}Hz")
            duration = len(audio_array) / sample_rate
            return f"WebRTC test: Audio array ({duration:.1f}s, {sample_rate}Hz)"
        except Exception as e:
            logger.error(f"WebRTC audio processing failed: {e}")
            return None

    async def handle_message(self, client_id: str, message_data: dict):
        """Dispatch one decoded JSON message from a client by its ``type``.

        Supported types: ``audio_chunk`` (base64 audio to transcribe),
        ``start_recording``, ``stop_recording``, ``tts_request`` and
        ``get_tts_voices``. Unknown types are logged and ignored.
        """
        message_type = message_data.get("type")
        if message_type == "audio_chunk":
            audio_data = message_data.get("audio_data")  # Base64 encoded
            sample_rate = message_data.get("sample_rate", 16000)
            if audio_data:
                try:
                    audio_bytes = base64.b64decode(audio_data)
                except (ValueError, TypeError) as e:
                    # binascii.Error (bad padding) subclasses ValueError;
                    # TypeError covers non-string payloads.
                    logger.warning(f"Invalid base64 audio from {client_id}: {e}")
                    await self.send_message(client_id, {
                        "type": "error",
                        "message": "Invalid audio data encoding",
                        "timestamp": datetime.now().isoformat()
                    })
                    return
                await self.handle_audio_chunk(client_id, audio_bytes, sample_rate)
        elif message_type == "start_recording":
            await self.send_message(client_id, {
                "type": "recording_started",
                "timestamp": datetime.now().isoformat()
            })
            logger.info(f"🎀 Recording started for {client_id}")
        elif message_type == "stop_recording":
            await self.send_message(client_id, {
                "type": "recording_stopped",
                "timestamp": datetime.now().isoformat()
            })
            logger.info(f"🎀 Recording stopped for {client_id}")
        elif message_type == "tts_request":
            text = message_data.get("text", "")
            voice_preset = message_data.get("voice_preset", "v2/en_speaker_6")
            # Non-string text (e.g. null) is treated like empty text instead of raising.
            if isinstance(text, str) and text.strip():
                await self.play_tts_response(client_id, text, voice_preset)
            else:
                await self.send_message(client_id, {
                    "type": "tts_error",
                    "message": "Empty text provided for TTS",
                    "timestamp": datetime.now().isoformat()
                })
        elif message_type == "get_tts_voices":
            await self.send_message(client_id, {
                "type": "tts_voices_list",
                "voices": ["v2/en_speaker_6", "v2/en_speaker_9", "v2/en_speaker_3", "v2/en_speaker_1"],
                "timestamp": datetime.now().isoformat()
            })
        else:
            logger.warning(f"Unknown message type from {client_id}: {message_type}")
# Global WebRTC handler instance, created once at import time. It holds all
# per-client connection state in its dicts, so callers should reuse this
# object rather than constructing additional handlers.
webrtc_handler = WebRTCHandler()