""" WebRTC WebSocket Handler for Real-time Audio Streaming Integrates with FastAPI for unmute.sh-style voice interaction """ import asyncio import json import logging from typing import Dict, Optional import websockets from fastapi import WebSocket, WebSocketDisconnect import numpy as np import soundfile as sf import tempfile import os from datetime import datetime logger = logging.getLogger(__name__) class WebRTCHandler: """Handles WebRTC WebSocket connections for real-time audio streaming""" def __init__(self): self.active_connections: Dict[str, WebSocket] = {} self.audio_buffers: Dict[str, list] = {} self.stt_service_url = "https://pgits-stt-gpu-service.hf.space" self.stt_websocket_url = "wss://pgits-stt-gpu-service.hf.space/ws/stt" self.stt_connections: Dict[str, websockets.WebSocketClientProtocol] = {} self.tts_service_url = "https://pgits-tts-gpu-service.hf.space" self.tts_websocket_url = "wss://pgits-tts-gpu-service.hf.space/ws/tts" self.tts_connections: Dict[str, websockets.WebSocketClientProtocol] = {} async def connect(self, websocket: WebSocket, client_id: str): """Accept WebSocket connection and initialize audio buffer""" await websocket.accept() self.active_connections[client_id] = websocket self.audio_buffers[client_id] = [] logger.info(f"🔌 WebRTC client {client_id} connected") # Send connection confirmation await self.send_message(client_id, { "type": "connection_confirmed", "client_id": client_id, "timestamp": datetime.now().isoformat(), "services": { "stt": self.stt_service_url, "status": "ready" } }) async def disconnect(self, client_id: str): """Clean up connection and buffers""" if client_id in self.active_connections: del self.active_connections[client_id] if client_id in self.audio_buffers: del self.audio_buffers[client_id] # Clean up STT connection if exists await self.disconnect_from_stt_service(client_id) # Clean up TTS connection if exists await self.disconnect_from_tts_service(client_id) logger.info(f"🔌 WebRTC client {client_id} disconnected") async def send_message(self, client_id: str, message: dict): """Send JSON message to client""" if client_id in self.active_connections: websocket = self.active_connections[client_id] try: await websocket.send_text(json.dumps(message)) except Exception as e: logger.error(f"Failed to send message to {client_id}: {e}") await self.disconnect(client_id) async def handle_audio_chunk(self, client_id: str, audio_data: bytes, sample_rate: int = 16000): """Process incoming audio chunk for real-time STT""" try: logger.info(f"🎤 Received {len(audio_data)} bytes from {client_id}") # MediaRecorder typically produces WebM/OGG/WAV format, not raw PCM # For WebRTC demo, we'll save the audio data temporarily and process it with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as tmp_file: tmp_file.write(audio_data) tmp_file_path = tmp_file.name try: # Process the audio file directly (WebRTC demo mode) transcription = await self.process_audio_file_webrtc(tmp_file_path, sample_rate) if transcription: # Send transcription back to client await self.send_message(client_id, { "type": "transcription", "text": transcription, "timestamp": datetime.now().isoformat(), "audio_size": len(audio_data), "format": "webm/audio" }) logger.info(f"📝 Transcription sent to {client_id}: {transcription[:50]}...") else: # Send error message await self.send_message(client_id, { "type": "error", "message": "Audio processing failed", "timestamp": datetime.now().isoformat() }) finally: # Clean up temporary file if os.path.exists(tmp_file_path): os.unlink(tmp_file_path) except Exception as e: logger.error(f"Error processing audio chunk for {client_id}: {e}") await self.send_message(client_id, { "type": "error", "message": f"Audio processing error: {str(e)}", "timestamp": datetime.now().isoformat() }) async def connect_to_stt_service(self, client_id: str) -> bool: """Connect to the STT WebSocket service""" try: logger.info(f"🔌 Connecting to STT service for client {client_id}: {self.stt_websocket_url}") # Connect to STT WebSocket service with shorter timeout stt_ws = await asyncio.wait_for( websockets.connect(self.stt_websocket_url), timeout=5.0 ) self.stt_connections[client_id] = stt_ws # Wait for connection confirmation with timeout confirmation = await asyncio.wait_for(stt_ws.recv(), timeout=10.0) confirmation_data = json.loads(confirmation) if confirmation_data.get("type") == "stt_connection_confirmed": logger.info(f"✅ STT service connected for client {client_id}") return True else: logger.warning(f"⚠️ Unexpected STT confirmation: {confirmation_data}") return False except asyncio.TimeoutError: logger.error(f"❌ STT service connection timeout for {client_id} - service may be cold starting or WebSocket endpoints not available") return False except websockets.exceptions.WebSocketException as e: logger.error(f"❌ STT WebSocket error for {client_id}: {e}") logger.info(f"🔍 Debug: Attempted connection to {self.stt_websocket_url}") return False except Exception as e: logger.error(f"❌ Failed to connect to STT service for {client_id}: {e}") logger.info(f"🔍 Debug: STT service URL: {self.stt_websocket_url}") return False async def disconnect_from_stt_service(self, client_id: str): """Disconnect from STT WebSocket service""" if client_id in self.stt_connections: try: stt_ws = self.stt_connections[client_id] await stt_ws.close() del self.stt_connections[client_id] logger.info(f"🔌 Disconnected from STT service for client {client_id}") except Exception as e: logger.error(f"Error disconnecting from STT service: {e}") async def send_audio_to_stt_service(self, client_id: str, audio_data: bytes) -> Optional[str]: """Send audio data to STT service and get transcription""" if client_id not in self.stt_connections: # Try to connect if not already connected success = await self.connect_to_stt_service(client_id) if not success: return None try: stt_ws = self.stt_connections[client_id] # Convert audio bytes to base64 for WebSocket transmission import base64 audio_b64 = base64.b64encode(audio_data).decode('utf-8') # Send STT audio chunk message message = { "type": "stt_audio_chunk", "audio_data": audio_b64, "language": "auto", "model_size": "base" } await stt_ws.send(json.dumps(message)) logger.info(f"📤 Sent {len(audio_data)} bytes to STT service") # Wait for transcription response response = await stt_ws.recv() response_data = json.loads(response) if response_data.get("type") == "stt_transcription": transcription = response_data.get("text", "") logger.info(f"📝 STT transcription received: {transcription[:50]}...") return transcription elif response_data.get("type") == "stt_error": error_msg = response_data.get("message", "Unknown STT error") logger.error(f"❌ STT service error: {error_msg}") return None else: logger.warning(f"⚠️ Unexpected STT response: {response_data}") return None except Exception as e: logger.error(f"❌ Error communicating with STT service: {e}") # Cleanup connection on error await self.disconnect_from_stt_service(client_id) return None # TTS WebSocket Methods async def connect_to_tts_service(self, client_id: str) -> bool: """Connect to the TTS WebSocket service""" try: logger.info(f"🔌 Connecting to TTS service for client {client_id}: {self.tts_websocket_url}") # Connect to TTS WebSocket service tts_ws = await websockets.connect(self.tts_websocket_url) self.tts_connections[client_id] = tts_ws # Wait for connection confirmation confirmation = await tts_ws.recv() confirmation_data = json.loads(confirmation) if confirmation_data.get("type") == "tts_connection_confirmed": logger.info(f"✅ TTS service connected for client {client_id}") return True else: logger.warning(f"⚠️ Unexpected TTS confirmation: {confirmation_data}") return False except Exception as e: logger.error(f"❌ Failed to connect to TTS service for {client_id}: {e}") return False async def disconnect_from_tts_service(self, client_id: str): """Disconnect from TTS WebSocket service""" if client_id in self.tts_connections: try: tts_ws = self.tts_connections[client_id] await tts_ws.close() del self.tts_connections[client_id] logger.info(f"🔌 Disconnected from TTS service for client {client_id}") except Exception as e: logger.error(f"Error disconnecting from TTS service: {e}") async def send_text_to_tts_service(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]: """Send text to TTS service and get audio response""" if client_id not in self.tts_connections: # Try to connect if not already connected success = await self.connect_to_tts_service(client_id) if not success: return None try: tts_ws = self.tts_connections[client_id] # Send TTS synthesis message message = { "type": "tts_synthesize", "text": text, "voice_preset": voice_preset } await tts_ws.send(json.dumps(message)) logger.info(f"📤 Sent text to TTS service: {text[:50]}...") # Wait for audio response response = await tts_ws.recv() response_data = json.loads(response) if response_data.get("type") == "tts_audio_response": # Decode base64 audio data audio_b64 = response_data.get("audio_data", "") audio_bytes = base64.b64decode(audio_b64) logger.info(f"🔊 TTS audio received: {len(audio_bytes)} bytes") return audio_bytes elif response_data.get("type") == "tts_error": error_msg = response_data.get("message", "Unknown TTS error") logger.error(f"❌ TTS service error: {error_msg}") return None else: logger.warning(f"⚠️ Unexpected TTS response: {response_data}") return None except Exception as e: logger.error(f"❌ Error communicating with TTS service: {e}") # Cleanup connection on error await self.disconnect_from_tts_service(client_id) return None async def play_tts_response(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6"): """Generate TTS audio and send to client for playback""" try: logger.info(f"🔊 Generating TTS response for client {client_id}: {text[:50]}...") # Try WebSocket FIRST - this is the primary method we want to use logger.info("🌐 Attempting WebSocket TTS (PRIMARY)") audio_data = await self.send_text_to_tts_service(client_id, text, voice_preset) if not audio_data: logger.info("🔄 WebSocket failed, trying HTTP API fallback") audio_data = await self.try_http_tts_fallback(text, voice_preset) if audio_data: # Convert audio to base64 for WebSocket transmission audio_b64 = base64.b64encode(audio_data).decode('utf-8') # Send audio playback message to client await self.send_message(client_id, { "type": "tts_playback", "audio_data": audio_b64, "audio_format": "wav", "text": text, "voice_preset": voice_preset, "timestamp": datetime.now().isoformat(), "audio_size": len(audio_data) }) logger.info(f"🔊 TTS playback sent to {client_id} ({len(audio_data)} bytes)") else: logger.warning(f"⚠️ TTS service failed to generate audio for: {text[:50]}...") # Send error message await self.send_message(client_id, { "type": "tts_error", "message": "TTS audio generation failed", "text": text, "timestamp": datetime.now().isoformat() }) except Exception as e: logger.error(f"❌ TTS playback error for {client_id}: {e}") await self.send_message(client_id, { "type": "tts_error", "message": f"TTS playback error: {str(e)}", "timestamp": datetime.now().isoformat() }) async def process_audio_file_webrtc(self, audio_file_path: str, sample_rate: int) -> Optional[str]: """Process audio file with real STT service via WebSocket""" try: logger.info(f"🎤 WebRTC: Processing audio file {audio_file_path} with real STT") # Read audio file data with open(audio_file_path, 'rb') as f: audio_data = f.read() file_size = len(audio_data) logger.info(f"🎤 Audio file size: {file_size} bytes") # Use a temporary client ID for this STT call temp_client_id = f"temp_{datetime.now().isoformat()}" try: # Try WebSocket FIRST - this is the primary method we want to use logger.info("🌐 Attempting WebSocket STT (PRIMARY)") transcription = await self.send_audio_to_stt_service(temp_client_id, audio_data) if transcription: logger.info(f"✅ WebSocket STT transcription: {transcription}") return transcription # Fallback to HTTP API only if WebSocket fails logger.info("🔄 WebSocket failed, trying HTTP API fallback") http_transcription = await self.try_http_stt_fallback(audio_file_path) if http_transcription: logger.info(f"✅ HTTP STT transcription (fallback): {http_transcription}") return f"[HTTP] {http_transcription}" else: logger.error("❌ Both WebSocket and HTTP STT failed - using minimal fallback") # Final fallback - but make it more realistic for TTS return "I'm having trouble processing that audio. Could you please try again?" finally: # Cleanup temporary connection await self.disconnect_from_stt_service(temp_client_id) except Exception as e: logger.error(f"WebRTC audio file processing failed: {e}") return None async def try_http_stt_fallback(self, audio_file_path: str) -> Optional[str]: """Fallback to HTTP API if WebSocket fails""" try: import requests import aiohttp import asyncio # Convert to async HTTP request def make_request(): api_url = f"{self.stt_service_url}/api/predict" with open(audio_file_path, 'rb') as audio_file: files = {'data': audio_file} data = {'data': '["auto", "base", true]'} # [language, model_size, timestamps] response = requests.post(api_url, files=files, data=data, timeout=30) return response # Run in thread to avoid blocking loop = asyncio.get_event_loop() response = await loop.run_in_executor(None, make_request) if response.status_code == 200: result = response.json() logger.info(f"📝 HTTP STT result: {result}") # Extract transcription from Gradio API format if result and 'data' in result and len(result['data']) > 1: transcription = result['data'][1] # [status, transcription, timestamps] if transcription and transcription.strip(): logger.info(f"✅ HTTP STT transcription: {transcription}") return transcription except Exception as e: logger.error(f"❌ HTTP STT fallback failed: {e}") return None async def try_http_tts_fallback(self, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]: """Fallback to HTTP API if TTS WebSocket fails""" try: import requests import asyncio # Convert to async HTTP request def make_request(): api_url = f"{self.tts_service_url}/api/predict" data = {'data': f'["{text}", "{voice_preset}"]'} # [text, voice_preset] response = requests.post(api_url, data=data, timeout=60) # TTS takes longer return response # Run in thread to avoid blocking loop = asyncio.get_event_loop() response = await loop.run_in_executor(None, make_request) if response.status_code == 200: result = response.json() logger.info(f"🔊 HTTP TTS result received") # Extract audio file path from Gradio API format if result and 'data' in result and len(result['data']) > 0: audio_file_path = result['data'][0] # Should be a file path if audio_file_path and isinstance(audio_file_path, str): # Download the audio file if audio_file_path.startswith('http'): audio_response = requests.get(audio_file_path, timeout=30) if audio_response.status_code == 200: logger.info(f"✅ HTTP TTS audio downloaded: {len(audio_response.content)} bytes") return audio_response.content except Exception as e: logger.error(f"❌ HTTP TTS fallback failed: {e}") return None async def process_audio_chunk_real_time(self, audio_array: np.ndarray, sample_rate: int) -> Optional[str]: """Legacy method - kept for compatibility""" try: logger.info(f"🎤 WebRTC: Processing {len(audio_array)} samples at {sample_rate}Hz") duration = len(audio_array) / sample_rate transcription = f"WebRTC test: Audio array ({duration:.1f}s, {sample_rate}Hz)" return transcription except Exception as e: logger.error(f"WebRTC audio processing failed: {e}") return None async def handle_message(self, client_id: str, message_data: dict): """Handle different types of WebSocket messages""" message_type = message_data.get("type") if message_type == "audio_chunk": # Real-time audio data audio_data = message_data.get("audio_data") # Base64 encoded sample_rate = message_data.get("sample_rate", 16000) if audio_data: # Decode base64 audio data import base64 audio_bytes = base64.b64decode(audio_data) await self.handle_audio_chunk(client_id, audio_bytes, sample_rate) elif message_type == "start_recording": # Client started recording await self.send_message(client_id, { "type": "recording_started", "timestamp": datetime.now().isoformat() }) logger.info(f"🎤 Recording started for {client_id}") elif message_type == "stop_recording": # Client stopped recording await self.send_message(client_id, { "type": "recording_stopped", "timestamp": datetime.now().isoformat() }) logger.info(f"🎤 Recording stopped for {client_id}") elif message_type == "tts_request": # Client requesting TTS playback text = message_data.get("text", "") voice_preset = message_data.get("voice_preset", "v2/en_speaker_6") if text.strip(): await self.play_tts_response(client_id, text, voice_preset) else: await self.send_message(client_id, { "type": "tts_error", "message": "Empty text provided for TTS", "timestamp": datetime.now().isoformat() }) elif message_type == "get_tts_voices": # Client requesting available TTS voices await self.send_message(client_id, { "type": "tts_voices_list", "voices": ["v2/en_speaker_6", "v2/en_speaker_9", "v2/en_speaker_3", "v2/en_speaker_1"], "timestamp": datetime.now().isoformat() }) else: logger.warning(f"Unknown message type from {client_id}: {message_type}") # Global WebRTC handler instance webrtc_handler = WebRTCHandler()