""" WebRTC WebSocket Handler for Real-time Audio Streaming Integrates with FastAPI for unmute.sh-style voice interaction """ import asyncio import json import logging from typing import Dict, Optional import websockets from fastapi import WebSocket, WebSocketDisconnect import numpy as np import soundfile as sf import tempfile import os from datetime import datetime logger = logging.getLogger(__name__) class WebRTCHandler: """Handles WebRTC WebSocket connections for real-time audio streaming""" def __init__(self): self.active_connections: Dict[str, WebSocket] = {} self.audio_buffers: Dict[str, list] = {} self.stt_service_url = "https://pgits-stt-gpu-service.hf.space" self.stt_websocket_url = "wss://pgits-stt-gpu-service.hf.space/ws/stt" self.stt_connections: Dict[str, websockets.WebSocketClientProtocol] = {} self.tts_service_url = "https://pgits-tts-gpu-service.hf.space" self.tts_websocket_url = "wss://pgits-tts-gpu-service.hf.space/ws/tts" self.tts_connections: Dict[str, websockets.WebSocketClientProtocol] = {} async def connect(self, websocket: WebSocket, client_id: str): """Accept WebSocket connection and initialize audio buffer""" await websocket.accept() self.active_connections[client_id] = websocket self.audio_buffers[client_id] = [] logger.info(f"๐Ÿ”Œ WebRTC client {client_id} connected") # Send connection confirmation await self.send_message(client_id, { "type": "connection_confirmed", "client_id": client_id, "timestamp": datetime.now().isoformat(), "services": { "stt": self.stt_service_url, "status": "ready" } }) async def disconnect(self, client_id: str): """Clean up connection and buffers""" if client_id in self.active_connections: del self.active_connections[client_id] if client_id in self.audio_buffers: del self.audio_buffers[client_id] # Clean up STT connection if exists await self.disconnect_from_stt_service(client_id) # Clean up TTS connection if exists await self.disconnect_from_tts_service(client_id) logger.info(f"๐Ÿ”Œ WebRTC client {client_id} disconnected") async def send_message(self, client_id: str, message: dict): """Send JSON message to client""" if client_id in self.active_connections: websocket = self.active_connections[client_id] try: await websocket.send_text(json.dumps(message)) except Exception as e: logger.error(f"Failed to send message to {client_id}: {e}") await self.disconnect(client_id) async def handle_audio_chunk(self, client_id: str, audio_data: bytes, sample_rate: int = 16000): """Process incoming audio chunk using unmute.sh streaming methodology""" try: logger.info(f"๐ŸŽค Received {len(audio_data)} bytes from {client_id}") # UNMUTE.SH METHODOLOGY: Buffer chunks for streaming STT processing if client_id not in self.audio_buffers: self.audio_buffers[client_id] = [] # Add chunk to buffer self.audio_buffers[client_id].append(audio_data) # Send partial transcription acknowledgment (unmute.sh style) await self.send_message(client_id, { "type": "chunk_buffered", "chunk_size": len(audio_data), "buffer_chunks": len(self.audio_buffers[client_id]), "timestamp": datetime.now().isoformat() }) logger.info(f"๐Ÿ“ฆ Buffered chunk for {client_id} ({len(self.audio_buffers[client_id])} total chunks)") except Exception as e: logger.error(f"Error buffering audio chunk for {client_id}: {e}") await self.send_message(client_id, { "type": "error", "message": f"Audio buffering error: {str(e)}", "timestamp": datetime.now().isoformat() }) async def process_buffered_audio_with_flush(self, client_id: str): """Process all buffered audio chunks with unmute.sh flush trick""" try: if client_id not in self.audio_buffers or not self.audio_buffers[client_id]: logger.info(f"No audio chunks to process for {client_id}") return # Combine all audio chunks into one complete audio file all_audio_data = b''.join(self.audio_buffers[client_id]) total_chunks = len(self.audio_buffers[client_id]) logger.info(f"๐Ÿ”„ Processing {total_chunks} buffered chunks ({len(all_audio_data)} bytes total) with flush trick") # Create temporary file for complete audio with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as tmp_file: tmp_file.write(all_audio_data) tmp_file_path = tmp_file.name try: # Process complete audio with unmute.sh methodology (is_final=True for flush trick) transcription = await self.process_audio_file_webrtc_with_flush(tmp_file_path) if transcription and transcription.strip() and not transcription.startswith("ERROR"): # Send final transcription back to client await self.send_message(client_id, { "type": "transcription", "text": transcription.strip(), "timestamp": datetime.now().isoformat(), "audio_size": len(all_audio_data), "format": "webm/audio", "is_final": True, # unmute.sh flush trick marker "chunks_processed": total_chunks }) logger.info(f"๐Ÿ“ Final transcription sent to {client_id}: {transcription[:50]}...") else: # Send error message await self.send_message(client_id, { "type": "transcription_error", "message": f"Audio processing failed: {transcription if transcription else 'No result'}", "timestamp": datetime.now().isoformat() }) finally: # Clean up if os.path.exists(tmp_file_path): os.unlink(tmp_file_path) # Clear the buffer self.audio_buffers[client_id] = [] logger.info(f"๐Ÿงน Cleared audio buffer for {client_id}") except Exception as e: logger.error(f"Error processing buffered audio for {client_id}: {e}") await self.send_message(client_id, { "type": "transcription_error", "message": f"Buffered audio processing error: {str(e)}", "timestamp": datetime.now().isoformat() }) async def process_audio_file_webrtc_with_flush(self, audio_file_path: str) -> Optional[str]: """Process audio file using unmute.sh flush trick methodology""" try: # Import the MCP audio handler for processing from core.mcp_audio_handler import mcp_audio_handler # Use the real STT service with flush trick (is_final=True) result = await mcp_audio_handler.speech_to_text(audio_file_path) logger.info(f"๐Ÿš€ FLUSH TRICK: STT service returned: {result[:100] if result else 'None'}...") return result except Exception as e: logger.error(f"Error in flush trick audio processing: {e}") return f"ERROR: Flush trick processing failed - {str(e)}" async def connect_to_stt_service(self, client_id: str) -> bool: """Connect to the STT WebSocket service""" try: logger.info(f"๐Ÿ”Œ Connecting to STT service for client {client_id}: {self.stt_websocket_url}") # Connect to STT WebSocket service with shorter timeout stt_ws = await asyncio.wait_for( websockets.connect(self.stt_websocket_url), timeout=5.0 ) self.stt_connections[client_id] = stt_ws # Wait for connection confirmation with timeout confirmation = await asyncio.wait_for(stt_ws.recv(), timeout=10.0) confirmation_data = json.loads(confirmation) if confirmation_data.get("type") == "stt_connection_confirmed": logger.info(f"โœ… STT service connected for client {client_id}") return True else: logger.warning(f"โš ๏ธ Unexpected STT confirmation: {confirmation_data}") return False except asyncio.TimeoutError: logger.error(f"โŒ STT service connection timeout for {client_id} - service may be cold starting or WebSocket endpoints not available") return False except websockets.exceptions.WebSocketException as e: if "503" in str(e): logger.error(f"โŒ STT service unavailable (HTTP 503) for {client_id} - service may be cold starting") logger.info(f"๐Ÿ”„ Try again in a few moments - Hugging Face services need time to start") else: logger.error(f"โŒ STT WebSocket error for {client_id}: {e}") logger.info(f"๐Ÿ” Debug: Attempted connection to {self.stt_websocket_url}") return False except Exception as e: logger.error(f"โŒ Failed to connect to STT service for {client_id}: {e}") logger.info(f"๐Ÿ” Debug: STT service URL: {self.stt_websocket_url}") return False async def disconnect_from_stt_service(self, client_id: str): """Disconnect from STT WebSocket service""" if client_id in self.stt_connections: try: stt_ws = self.stt_connections[client_id] await stt_ws.close() del self.stt_connections[client_id] logger.info(f"๐Ÿ”Œ Disconnected from STT service for client {client_id}") except Exception as e: logger.error(f"Error disconnecting from STT service: {e}") async def send_audio_to_stt_service(self, client_id: str, audio_data: bytes) -> Optional[str]: """Send audio data to STT service and get transcription""" if client_id not in self.stt_connections: # Try to connect if not already connected success = await self.connect_to_stt_service(client_id) if not success: return None try: stt_ws = self.stt_connections[client_id] # Convert audio bytes to base64 for WebSocket transmission import base64 audio_b64 = base64.b64encode(audio_data).decode('utf-8') # Send STT audio chunk message message = { "type": "stt_audio_chunk", "audio_data": audio_b64, "language": "auto", "model_size": "base" } await stt_ws.send(json.dumps(message)) logger.info(f"๐Ÿ“ค Sent {len(audio_data)} bytes to STT service") # Wait for transcription response response = await stt_ws.recv() response_data = json.loads(response) if response_data.get("type") == "stt_transcription": transcription = response_data.get("text", "") logger.info(f"๐Ÿ“ STT transcription received: {transcription[:50]}...") return transcription elif response_data.get("type") == "stt_error": error_msg = response_data.get("message", "Unknown STT error") logger.error(f"โŒ STT service error: {error_msg}") return None else: logger.warning(f"โš ๏ธ Unexpected STT response: {response_data}") return None except Exception as e: logger.error(f"โŒ Error communicating with STT service: {e}") # Cleanup connection on error await self.disconnect_from_stt_service(client_id) return None # TTS WebSocket Methods async def connect_to_tts_service(self, client_id: str) -> bool: """Connect to the TTS WebSocket service""" try: logger.info(f"๐Ÿ”Œ Connecting to TTS service for client {client_id}: {self.tts_websocket_url}") # Connect to TTS WebSocket service with timeout tts_ws = await asyncio.wait_for( websockets.connect(self.tts_websocket_url), timeout=10.0 ) self.tts_connections[client_id] = tts_ws # Wait for connection confirmation confirmation = await asyncio.wait_for(tts_ws.recv(), timeout=15.0) confirmation_data = json.loads(confirmation) if confirmation_data.get("type") == "tts_connection_confirmed": logger.info(f"โœ… TTS service connected for client {client_id}") return True else: logger.warning(f"โš ๏ธ Unexpected TTS confirmation: {confirmation_data}") return False except asyncio.TimeoutError: logger.error(f"โŒ TTS service connection timeout - service may not be in WebSocket mode") logger.info(f"๐Ÿ’ก TTS service needs TTS_SERVICE_MODE=websocket environment variable") return False except websockets.exceptions.InvalidStatusCode as e: logger.error(f"โŒ TTS WebSocket endpoint not available: {e}") logger.info(f"๐Ÿ’ก TTS service may be running in Gradio-only mode instead of WebSocket mode") return False except Exception as e: logger.error(f"โŒ Failed to connect to TTS service for {client_id}: {e}") logger.info(f"๐Ÿ’ก Check if TTS service is running and configured with TTS_SERVICE_MODE=websocket") return False async def disconnect_from_tts_service(self, client_id: str): """Disconnect from TTS WebSocket service""" if client_id in self.tts_connections: try: tts_ws = self.tts_connections[client_id] await tts_ws.close() del self.tts_connections[client_id] logger.info(f"๐Ÿ”Œ Disconnected from TTS service for client {client_id}") except Exception as e: logger.error(f"Error disconnecting from TTS service: {e}") async def send_text_to_tts_service(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]: """Send text to TTS service and get audio response""" if client_id not in self.tts_connections: # Try to connect if not already connected success = await self.connect_to_tts_service(client_id) if not success: return None try: tts_ws = self.tts_connections[client_id] # Send TTS synthesis message message = { "type": "tts_synthesize", "text": text, "voice_preset": voice_preset } await tts_ws.send(json.dumps(message)) logger.info(f"๐Ÿ“ค Sent text to TTS service: {text[:50]}...") # Wait for audio response response = await tts_ws.recv() response_data = json.loads(response) if response_data.get("type") == "tts_audio_response": # Decode base64 audio data import base64 audio_b64 = response_data.get("audio_data", "") audio_bytes = base64.b64decode(audio_b64) logger.info(f"๐Ÿ”Š TTS audio received: {len(audio_bytes)} bytes") return audio_bytes elif response_data.get("type") == "tts_error": error_msg = response_data.get("message", "Unknown TTS error") logger.error(f"โŒ TTS service error: {error_msg}") return None else: logger.warning(f"โš ๏ธ Unexpected TTS response: {response_data}") return None except Exception as e: logger.error(f"โŒ Error communicating with TTS service: {e}") # Cleanup connection on error await self.disconnect_from_tts_service(client_id) return None async def play_tts_response(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6"): """Generate TTS audio and send to client for playback""" try: logger.info(f"๐Ÿ”Š Generating TTS response for client {client_id}: {text[:50]}...") # Try WebSocket FIRST - this is the primary method we want to use logger.info("๐ŸŒ Attempting WebSocket TTS (PRIMARY)") audio_data = await self.send_text_to_tts_service(client_id, text, voice_preset) if not audio_data: logger.info("๐Ÿ”„ WebSocket failed, trying HTTP API fallback") audio_data = await self.try_http_tts_fallback(text, voice_preset) if audio_data: # Convert audio to base64 for WebSocket transmission import base64 audio_b64 = base64.b64encode(audio_data).decode('utf-8') # Send audio playback message to client await self.send_message(client_id, { "type": "tts_playback", "audio_data": audio_b64, "audio_format": "wav", "text": text, "voice_preset": voice_preset, "timestamp": datetime.now().isoformat(), "audio_size": len(audio_data) }) logger.info(f"๐Ÿ”Š TTS playback sent to {client_id} ({len(audio_data)} bytes)") else: logger.warning(f"โš ๏ธ TTS service failed to generate audio for: {text[:50]}...") # Send error message await self.send_message(client_id, { "type": "tts_error", "message": "TTS audio generation failed", "text": text, "timestamp": datetime.now().isoformat() }) except Exception as e: logger.error(f"โŒ TTS playback error for {client_id}: {e}") await self.send_message(client_id, { "type": "tts_error", "message": f"TTS playback error: {str(e)}", "timestamp": datetime.now().isoformat() }) async def process_audio_file_webrtc(self, audio_file_path: str, sample_rate: int) -> Optional[str]: """Process audio file with real STT service via WebSocket""" try: logger.info(f"๐ŸŽค WebRTC: Processing audio file {audio_file_path} with real STT") # Read audio file data with open(audio_file_path, 'rb') as f: audio_data = f.read() file_size = len(audio_data) logger.info(f"๐ŸŽค Audio file size: {file_size} bytes") # Use a temporary client ID for this STT call temp_client_id = f"temp_{datetime.now().isoformat()}" try: # Try WebSocket FIRST - this is the primary method we want to use logger.info("๐ŸŒ Attempting WebSocket STT (PRIMARY)") transcription = await self.send_audio_to_stt_service(temp_client_id, audio_data) if transcription: logger.info(f"โœ… WebSocket STT transcription: {transcription}") return transcription # Fallback to HTTP API only if WebSocket fails logger.info("๐Ÿ”„ WebSocket failed, trying HTTP API fallback") http_transcription = await self.try_http_stt_fallback(audio_file_path) if http_transcription: logger.info(f"โœ… HTTP STT transcription (fallback): {http_transcription}") return f"[HTTP] {http_transcription}" else: logger.error("โŒ Both WebSocket and HTTP STT failed - using minimal fallback") # Final fallback - but make it more realistic for TTS return "I'm having trouble processing that audio. Could you please try again?" finally: # Cleanup temporary connection await self.disconnect_from_stt_service(temp_client_id) except Exception as e: logger.error(f"WebRTC audio file processing failed: {e}") return None async def try_http_stt_fallback(self, audio_file_path: str) -> Optional[str]: """Fallback to HTTP API if WebSocket fails""" try: import requests import aiohttp import asyncio # Convert to async HTTP request def make_request(): api_url = f"{self.stt_service_url}/api/predict" with open(audio_file_path, 'rb') as audio_file: files = {'data': audio_file} data = {'data': '["auto", "base", true]'} # [language, model_size, timestamps] response = requests.post(api_url, files=files, data=data, timeout=30) return response # Run in thread to avoid blocking loop = asyncio.get_event_loop() response = await loop.run_in_executor(None, make_request) if response.status_code == 200: result = response.json() logger.info(f"๐Ÿ“ HTTP STT result: {result}") # Extract transcription from Gradio API format if result and 'data' in result and len(result['data']) > 1: transcription = result['data'][1] # [status, transcription, timestamps] if transcription and transcription.strip(): logger.info(f"โœ… HTTP STT transcription: {transcription}") return transcription except Exception as e: logger.error(f"โŒ HTTP STT fallback failed: {e}") return None async def try_http_tts_fallback(self, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]: """Fallback to HTTP API if TTS WebSocket fails""" try: import requests import asyncio # Convert to async HTTP request def make_request(): api_url = f"{self.tts_service_url}/api/predict" data = {'data': f'["{text}", "{voice_preset}"]'} # [text, voice_preset] response = requests.post(api_url, data=data, timeout=60) # TTS takes longer return response # Run in thread to avoid blocking loop = asyncio.get_event_loop() response = await loop.run_in_executor(None, make_request) if response.status_code == 200: result = response.json() logger.info(f"๐Ÿ”Š HTTP TTS result received") # Extract audio file path from Gradio API format if result and 'data' in result and len(result['data']) > 0: audio_file_path = result['data'][0] # Should be a file path if audio_file_path and isinstance(audio_file_path, str): # Download the audio file if audio_file_path.startswith('http'): audio_response = requests.get(audio_file_path, timeout=30) if audio_response.status_code == 200: logger.info(f"โœ… HTTP TTS audio downloaded: {len(audio_response.content)} bytes") return audio_response.content except Exception as e: logger.error(f"โŒ HTTP TTS fallback failed: {e}") return None async def process_audio_chunk_real_time(self, audio_array: np.ndarray, sample_rate: int) -> Optional[str]: """Legacy method - kept for compatibility""" try: logger.info(f"๐ŸŽค WebRTC: Processing {len(audio_array)} samples at {sample_rate}Hz") duration = len(audio_array) / sample_rate transcription = f"WebRTC test: Audio array ({duration:.1f}s, {sample_rate}Hz)" return transcription except Exception as e: logger.error(f"WebRTC audio processing failed: {e}") return None async def handle_message(self, client_id: str, message_data: dict): """Handle different types of WebSocket messages""" message_type = message_data.get("type") if message_type == "audio_chunk": # Real-time audio data audio_data = message_data.get("audio_data") # Base64 encoded sample_rate = message_data.get("sample_rate", 16000) if audio_data: # Decode base64 audio data import base64 audio_bytes = base64.b64decode(audio_data) await self.handle_audio_chunk(client_id, audio_bytes, sample_rate) elif message_type == "start_recording": # Client started recording await self.send_message(client_id, { "type": "recording_started", "timestamp": datetime.now().isoformat() }) logger.info(f"๐ŸŽค Recording started for {client_id}") elif message_type == "stop_recording": # Client stopped recording - UNMUTE.SH FLUSH TRICK logger.info(f"๐ŸŽค Recording stopped for {client_id} - applying unmute.sh flush trick") # Process all buffered audio with flush trick await self.process_buffered_audio_with_flush(client_id) await self.send_message(client_id, { "type": "recording_stopped", "timestamp": datetime.now().isoformat() }) elif message_type == "tts_request": # Client requesting TTS playback text = message_data.get("text", "") voice_preset = message_data.get("voice_preset", "v2/en_speaker_6") if text.strip(): await self.play_tts_response(client_id, text, voice_preset) else: await self.send_message(client_id, { "type": "tts_error", "message": "Empty text provided for TTS", "timestamp": datetime.now().isoformat() }) elif message_type == "get_tts_voices": # Client requesting available TTS voices await self.send_message(client_id, { "type": "tts_voices_list", "voices": ["v2/en_speaker_6", "v2/en_speaker_9", "v2/en_speaker_3", "v2/en_speaker_1"], "timestamp": datetime.now().isoformat() }) else: logger.warning(f"Unknown message type from {client_id}: {message_type}") # Global WebRTC handler instance webrtc_handler = WebRTCHandler()