Spaces:
Sleeping
Sleeping
Peter Michael Gits
feat: Add Streamlit-native WebRTC speech-to-text using unmute.sh patterns
"""
WebRTC WebSocket Handler for Real-time Audio Streaming
Integrates with FastAPI for unmute.sh-style voice interaction
"""
import asyncio
import json
import logging
from typing import Dict, Optional
import websockets
from fastapi import WebSocket, WebSocketDisconnect
import numpy as np
import soundfile as sf  # NOTE(review): not referenced in this file's visible code — confirm before removing
import tempfile
import os
from datetime import datetime

# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
class WebRTCHandler:
    """Handles WebRTC WebSocket connections for real-time audio streaming"""

    def __init__(self):
        # Browser-facing WebSocket connections, keyed by client_id.
        self.active_connections: Dict[str, WebSocket] = {}
        # Per-client list of raw audio chunk bytes, buffered until stop_recording.
        self.audio_buffers: Dict[str, list] = {}
        # Upstream speech-to-text service endpoints (HTTP API and WebSocket).
        self.stt_service_url = "https://pgits-stt-gpu-service.hf.space"
        self.stt_websocket_url = "wss://pgits-stt-gpu-service.hf.space/ws/stt"
        # Open WebSocket connections to the STT service, keyed by client_id.
        self.stt_connections: Dict[str, websockets.WebSocketClientProtocol] = {}
        # Upstream text-to-speech service endpoints (HTTP API and WebSocket).
        self.tts_service_url = "https://pgits-tts-gpu-service.hf.space"
        self.tts_websocket_url = "wss://pgits-tts-gpu-service.hf.space/ws/tts"
        # Open WebSocket connections to the TTS service, keyed by client_id.
        self.tts_connections: Dict[str, websockets.WebSocketClientProtocol] = {}
    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept WebSocket connection and initialize audio buffer.

        Registers the socket under *client_id* and sends a
        "connection_confirmed" message advertising the STT endpoint.
        """
        await websocket.accept()
        self.active_connections[client_id] = websocket
        # Fresh buffer; any chunks from a previous session are discarded.
        self.audio_buffers[client_id] = []
        logger.info(f"π WebRTC client {client_id} connected")
        # Send connection confirmation
        await self.send_message(client_id, {
            "type": "connection_confirmed",
            "client_id": client_id,
            "timestamp": datetime.now().isoformat(),
            "services": {
                "stt": self.stt_service_url,
                "status": "ready"
            }
        })
| async def disconnect(self, client_id: str): | |
| """Clean up connection and buffers""" | |
| if client_id in self.active_connections: | |
| del self.active_connections[client_id] | |
| if client_id in self.audio_buffers: | |
| del self.audio_buffers[client_id] | |
| # Clean up STT connection if exists | |
| await self.disconnect_from_stt_service(client_id) | |
| # Clean up TTS connection if exists | |
| await self.disconnect_from_tts_service(client_id) | |
| logger.info(f"π WebRTC client {client_id} disconnected") | |
| async def send_message(self, client_id: str, message: dict): | |
| """Send JSON message to client""" | |
| if client_id in self.active_connections: | |
| websocket = self.active_connections[client_id] | |
| try: | |
| await websocket.send_text(json.dumps(message)) | |
| except Exception as e: | |
| logger.error(f"Failed to send message to {client_id}: {e}") | |
| await self.disconnect(client_id) | |
| async def handle_audio_chunk(self, client_id: str, audio_data: bytes, sample_rate: int = 16000): | |
| """Process incoming audio chunk using unmute.sh streaming methodology""" | |
| try: | |
| logger.info(f"π€ Received {len(audio_data)} bytes from {client_id}") | |
| # UNMUTE.SH METHODOLOGY: Buffer chunks for streaming STT processing | |
| if client_id not in self.audio_buffers: | |
| self.audio_buffers[client_id] = [] | |
| # Add chunk to buffer | |
| self.audio_buffers[client_id].append(audio_data) | |
| # Send partial transcription acknowledgment (unmute.sh style) | |
| await self.send_message(client_id, { | |
| "type": "chunk_buffered", | |
| "chunk_size": len(audio_data), | |
| "buffer_chunks": len(self.audio_buffers[client_id]), | |
| "timestamp": datetime.now().isoformat() | |
| }) | |
| logger.info(f"π¦ Buffered chunk for {client_id} ({len(self.audio_buffers[client_id])} total chunks)") | |
| except Exception as e: | |
| logger.error(f"Error buffering audio chunk for {client_id}: {e}") | |
| await self.send_message(client_id, { | |
| "type": "error", | |
| "message": f"Audio buffering error: {str(e)}", | |
| "timestamp": datetime.now().isoformat() | |
| }) | |
    async def process_buffered_audio_with_flush(self, client_id: str):
        """Process all buffered audio chunks with unmute.sh flush trick.

        Joins the chunks buffered by handle_audio_chunk() into one blob,
        writes it to a temporary file, transcribes it, and sends the client
        either a final "transcription" or a "transcription_error" message.
        The client's buffer is cleared at the end whether or not the
        transcription succeeded.
        """
        try:
            if client_id not in self.audio_buffers or not self.audio_buffers[client_id]:
                logger.info(f"No audio chunks to process for {client_id}")
                return
            # Combine all audio chunks into one complete audio file
            all_audio_data = b''.join(self.audio_buffers[client_id])
            total_chunks = len(self.audio_buffers[client_id])
            logger.info(f"π Processing {total_chunks} buffered chunks ({len(all_audio_data)} bytes total) with flush trick")
            # Create temporary file for complete audio
            # NOTE(review): the .webm suffix assumes the browser sent
            # WebM-encoded chunks — confirm against the client-side recorder.
            with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as tmp_file:
                tmp_file.write(all_audio_data)
                tmp_file_path = tmp_file.name
            try:
                # Process complete audio with unmute.sh methodology (is_final=True for flush trick)
                transcription = await self.process_audio_file_webrtc_with_flush(tmp_file_path)
                # The STT helper signals failure with None/empty text or an
                # "ERROR"-prefixed string; only real text is sent as final.
                if transcription and transcription.strip() and not transcription.startswith("ERROR"):
                    # Send final transcription back to client
                    await self.send_message(client_id, {
                        "type": "transcription",
                        "text": transcription.strip(),
                        "timestamp": datetime.now().isoformat(),
                        "audio_size": len(all_audio_data),
                        "format": "webm/audio",
                        "is_final": True,  # unmute.sh flush trick marker
                        "chunks_processed": total_chunks
                    })
                    logger.info(f"π Final transcription sent to {client_id}: {transcription[:50]}...")
                else:
                    # Send error message
                    await self.send_message(client_id, {
                        "type": "transcription_error",
                        "message": f"Audio processing failed: {transcription if transcription else 'No result'}",
                        "timestamp": datetime.now().isoformat()
                    })
            finally:
                # Clean up the temp file and always clear the buffer so stale
                # audio never leaks into the next recording.
                if os.path.exists(tmp_file_path):
                    os.unlink(tmp_file_path)
                # Clear the buffer
                self.audio_buffers[client_id] = []
                logger.info(f"π§Ή Cleared audio buffer for {client_id}")
        except Exception as e:
            logger.error(f"Error processing buffered audio for {client_id}: {e}")
            await self.send_message(client_id, {
                "type": "transcription_error",
                "message": f"Buffered audio processing error: {str(e)}",
                "timestamp": datetime.now().isoformat()
            })
    async def process_audio_file_webrtc_with_flush(self, audio_file_path: str) -> Optional[str]:
        """Process audio file using unmute.sh flush trick methodology.

        Delegates the actual STT call to the project's MCP audio handler.
        Returns the transcription, or an "ERROR: ..."-prefixed string on
        failure (callers test for that prefix rather than catching).
        """
        try:
            # Import the MCP audio handler for processing
            # (imported here rather than at module top — presumably to avoid
            # a circular or heavy import; confirm before hoisting).
            from core.mcp_audio_handler import mcp_audio_handler
            # Use the real STT service with flush trick (is_final=True)
            result = await mcp_audio_handler.speech_to_text(audio_file_path)
            logger.info(f"π FLUSH TRICK: STT service returned: {result[:100] if result else 'None'}...")
            return result
        except Exception as e:
            logger.error(f"Error in flush trick audio processing: {e}")
            return f"ERROR: Flush trick processing failed - {str(e)}"
    async def connect_to_stt_service(self, client_id: str) -> bool:
        """Connect to the STT WebSocket service.

        Opens a WebSocket to the STT backend, stores it in
        self.stt_connections[client_id], and waits for the service's
        "stt_connection_confirmed" handshake message.

        Returns True on a confirmed connection; False on timeout, a
        WebSocket-level failure, or an unexpected handshake payload.
        """
        try:
            logger.info(f"π Connecting to STT service for client {client_id}: {self.stt_websocket_url}")
            # Connect to STT WebSocket service with shorter timeout
            stt_ws = await asyncio.wait_for(
                websockets.connect(self.stt_websocket_url),
                timeout=5.0
            )
            # Stored before the handshake; note a failed handshake leaves
            # this entry in place until disconnect_from_stt_service() runs.
            self.stt_connections[client_id] = stt_ws
            # Wait for connection confirmation with timeout
            confirmation = await asyncio.wait_for(stt_ws.recv(), timeout=10.0)
            confirmation_data = json.loads(confirmation)
            if confirmation_data.get("type") == "stt_connection_confirmed":
                logger.info(f"β STT service connected for client {client_id}")
                return True
            else:
                logger.warning(f"β οΈ Unexpected STT confirmation: {confirmation_data}")
                return False
        except asyncio.TimeoutError:
            logger.error(f"β STT service connection timeout for {client_id} - service may be cold starting or WebSocket endpoints not available")
            return False
        except websockets.exceptions.WebSocketException as e:
            # A 503 during the HTTP upgrade is treated as a cold-starting
            # Hugging Face space (string match on the exception text).
            if "503" in str(e):
                logger.error(f"β STT service unavailable (HTTP 503) for {client_id} - service may be cold starting")
                logger.info(f"π Try again in a few moments - Hugging Face services need time to start")
            else:
                logger.error(f"β STT WebSocket error for {client_id}: {e}")
                logger.info(f"π Debug: Attempted connection to {self.stt_websocket_url}")
            return False
        except Exception as e:
            logger.error(f"β Failed to connect to STT service for {client_id}: {e}")
            logger.info(f"π Debug: STT service URL: {self.stt_websocket_url}")
            return False
| async def disconnect_from_stt_service(self, client_id: str): | |
| """Disconnect from STT WebSocket service""" | |
| if client_id in self.stt_connections: | |
| try: | |
| stt_ws = self.stt_connections[client_id] | |
| await stt_ws.close() | |
| del self.stt_connections[client_id] | |
| logger.info(f"π Disconnected from STT service for client {client_id}") | |
| except Exception as e: | |
| logger.error(f"Error disconnecting from STT service: {e}") | |
    async def send_audio_to_stt_service(self, client_id: str, audio_data: bytes) -> Optional[str]:
        """Send audio data to STT service and get transcription.

        Lazily opens the STT WebSocket for this client if needed, sends the
        audio as a base64 "stt_audio_chunk" message, and waits for exactly
        one reply. Returns the transcription text, or None on connection
        failure, an "stt_error" reply, or an unexpected payload.
        """
        if client_id not in self.stt_connections:
            # Try to connect if not already connected
            success = await self.connect_to_stt_service(client_id)
            if not success:
                return None
        try:
            stt_ws = self.stt_connections[client_id]
            # Convert audio bytes to base64 for WebSocket transmission
            import base64
            audio_b64 = base64.b64encode(audio_data).decode('utf-8')
            # Send STT audio chunk message
            message = {
                "type": "stt_audio_chunk",
                "audio_data": audio_b64,
                "language": "auto",   # presumably requests language auto-detection — confirm with the STT API
                "model_size": "base"
            }
            await stt_ws.send(json.dumps(message))
            logger.info(f"π€ Sent {len(audio_data)} bytes to STT service")
            # Wait for transcription response
            # NOTE(review): no timeout on this recv — an unresponsive STT
            # service would block this coroutine indefinitely.
            response = await stt_ws.recv()
            response_data = json.loads(response)
            if response_data.get("type") == "stt_transcription":
                transcription = response_data.get("text", "")
                logger.info(f"π STT transcription received: {transcription[:50]}...")
                return transcription
            elif response_data.get("type") == "stt_error":
                error_msg = response_data.get("message", "Unknown STT error")
                logger.error(f"β STT service error: {error_msg}")
                return None
            else:
                logger.warning(f"β οΈ Unexpected STT response: {response_data}")
                return None
        except Exception as e:
            logger.error(f"β Error communicating with STT service: {e}")
            # Cleanup connection on error
            await self.disconnect_from_stt_service(client_id)
            return None
    # TTS WebSocket Methods
    async def connect_to_tts_service(self, client_id: str) -> bool:
        """Connect to the TTS WebSocket service.

        Opens a WebSocket to the TTS backend, stores it in
        self.tts_connections[client_id], and waits for the
        "tts_connection_confirmed" handshake. Returns True only when the
        handshake succeeds; False on timeout, a rejected upgrade, or any
        other failure (all logged with troubleshooting hints).
        """
        try:
            logger.info(f"π Connecting to TTS service for client {client_id}: {self.tts_websocket_url}")
            # Connect to TTS WebSocket service with timeout
            tts_ws = await asyncio.wait_for(
                websockets.connect(self.tts_websocket_url),
                timeout=10.0
            )
            # Stored before the handshake; a failed handshake leaves this
            # entry in place until disconnect_from_tts_service() runs.
            self.tts_connections[client_id] = tts_ws
            # Wait for connection confirmation
            confirmation = await asyncio.wait_for(tts_ws.recv(), timeout=15.0)
            confirmation_data = json.loads(confirmation)
            if confirmation_data.get("type") == "tts_connection_confirmed":
                logger.info(f"β TTS service connected for client {client_id}")
                return True
            else:
                logger.warning(f"β οΈ Unexpected TTS confirmation: {confirmation_data}")
                return False
        except asyncio.TimeoutError:
            logger.error(f"β TTS service connection timeout - service may not be in WebSocket mode")
            logger.info(f"π‘ TTS service needs TTS_SERVICE_MODE=websocket environment variable")
            return False
        except websockets.exceptions.InvalidStatusCode as e:
            # The server answered the HTTP upgrade with a non-101 status.
            # NOTE(review): newer `websockets` releases renamed this exception
            # to InvalidStatus — verify against the pinned library version.
            logger.error(f"β TTS WebSocket endpoint not available: {e}")
            logger.info(f"π‘ TTS service may be running in Gradio-only mode instead of WebSocket mode")
            return False
        except Exception as e:
            logger.error(f"β Failed to connect to TTS service for {client_id}: {e}")
            logger.info(f"π‘ Check if TTS service is running and configured with TTS_SERVICE_MODE=websocket")
            return False
| async def disconnect_from_tts_service(self, client_id: str): | |
| """Disconnect from TTS WebSocket service""" | |
| if client_id in self.tts_connections: | |
| try: | |
| tts_ws = self.tts_connections[client_id] | |
| await tts_ws.close() | |
| del self.tts_connections[client_id] | |
| logger.info(f"π Disconnected from TTS service for client {client_id}") | |
| except Exception as e: | |
| logger.error(f"Error disconnecting from TTS service: {e}") | |
    async def send_text_to_tts_service(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]:
        """Send text to TTS service and get audio response.

        Lazily opens the TTS WebSocket for this client if needed, sends a
        "tts_synthesize" request, and waits for exactly one reply. Returns
        the decoded audio bytes, or None on connection failure, a
        "tts_error" reply, or an unexpected payload.
        """
        if client_id not in self.tts_connections:
            # Try to connect if not already connected
            success = await self.connect_to_tts_service(client_id)
            if not success:
                return None
        try:
            tts_ws = self.tts_connections[client_id]
            # Send TTS synthesis message
            message = {
                "type": "tts_synthesize",
                "text": text,
                "voice_preset": voice_preset
            }
            await tts_ws.send(json.dumps(message))
            logger.info(f"π€ Sent text to TTS service: {text[:50]}...")
            # Wait for audio response
            # NOTE(review): no timeout on this recv — an unresponsive TTS
            # service would block this coroutine indefinitely.
            response = await tts_ws.recv()
            response_data = json.loads(response)
            if response_data.get("type") == "tts_audio_response":
                # Decode base64 audio data
                import base64
                audio_b64 = response_data.get("audio_data", "")
                audio_bytes = base64.b64decode(audio_b64)
                logger.info(f"π TTS audio received: {len(audio_bytes)} bytes")
                return audio_bytes
            elif response_data.get("type") == "tts_error":
                error_msg = response_data.get("message", "Unknown TTS error")
                logger.error(f"β TTS service error: {error_msg}")
                return None
            else:
                logger.warning(f"β οΈ Unexpected TTS response: {response_data}")
                return None
        except Exception as e:
            logger.error(f"β Error communicating with TTS service: {e}")
            # Cleanup connection on error
            await self.disconnect_from_tts_service(client_id)
            return None
    async def play_tts_response(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6"):
        """Generate TTS audio and send to client for playback.

        Tries the TTS WebSocket first, falls back to the HTTP API, then
        pushes the audio to the client as a base64 "tts_playback" message.
        Sends a "tts_error" message if both paths fail or an unexpected
        exception occurs.
        """
        try:
            logger.info(f"π Generating TTS response for client {client_id}: {text[:50]}...")
            # Try WebSocket FIRST - this is the primary method we want to use
            logger.info("π Attempting WebSocket TTS (PRIMARY)")
            audio_data = await self.send_text_to_tts_service(client_id, text, voice_preset)
            if not audio_data:
                logger.info("π WebSocket failed, trying HTTP API fallback")
                audio_data = await self.try_http_tts_fallback(text, voice_preset)
            if audio_data:
                # Convert audio to base64 for WebSocket transmission
                import base64
                audio_b64 = base64.b64encode(audio_data).decode('utf-8')
                # Send audio playback message to client
                # NOTE(review): audio_format is reported as "wav" regardless of
                # what the TTS service actually produced — confirm.
                await self.send_message(client_id, {
                    "type": "tts_playback",
                    "audio_data": audio_b64,
                    "audio_format": "wav",
                    "text": text,
                    "voice_preset": voice_preset,
                    "timestamp": datetime.now().isoformat(),
                    "audio_size": len(audio_data)
                })
                logger.info(f"π TTS playback sent to {client_id} ({len(audio_data)} bytes)")
            else:
                logger.warning(f"β οΈ TTS service failed to generate audio for: {text[:50]}...")
                # Send error message
                await self.send_message(client_id, {
                    "type": "tts_error",
                    "message": "TTS audio generation failed",
                    "text": text,
                    "timestamp": datetime.now().isoformat()
                })
        except Exception as e:
            logger.error(f"β TTS playback error for {client_id}: {e}")
            await self.send_message(client_id, {
                "type": "tts_error",
                "message": f"TTS playback error: {str(e)}",
                "timestamp": datetime.now().isoformat()
            })
    async def process_audio_file_webrtc(self, audio_file_path: str, sample_rate: int) -> Optional[str]:
        """Process audio file with real STT service via WebSocket.

        Reads the file, sends it through the WebSocket STT path under a
        throwaway client id, then falls back to the HTTP API (result
        prefixed with "[HTTP] ") and finally to a canned apology string.
        Returns None only on an unexpected exception.

        NOTE(review): *sample_rate* is accepted but never used here —
        confirm whether callers rely on it before removing.
        """
        try:
            logger.info(f"π€ WebRTC: Processing audio file {audio_file_path} with real STT")
            # Read audio file data
            with open(audio_file_path, 'rb') as f:
                audio_data = f.read()
            file_size = len(audio_data)
            logger.info(f"π€ Audio file size: {file_size} bytes")
            # Use a temporary client ID for this STT call
            temp_client_id = f"temp_{datetime.now().isoformat()}"
            try:
                # Try WebSocket FIRST - this is the primary method we want to use
                logger.info("π Attempting WebSocket STT (PRIMARY)")
                transcription = await self.send_audio_to_stt_service(temp_client_id, audio_data)
                if transcription:
                    logger.info(f"β WebSocket STT transcription: {transcription}")
                    return transcription
                # Fallback to HTTP API only if WebSocket fails
                logger.info("π WebSocket failed, trying HTTP API fallback")
                http_transcription = await self.try_http_stt_fallback(audio_file_path)
                if http_transcription:
                    logger.info(f"β HTTP STT transcription (fallback): {http_transcription}")
                    # Prefix marks the result as coming from the HTTP path.
                    return f"[HTTP] {http_transcription}"
                else:
                    logger.error("β Both WebSocket and HTTP STT failed - using minimal fallback")
                    # Final fallback - but make it more realistic for TTS
                    return "I'm having trouble processing that audio. Could you please try again?"
            finally:
                # Cleanup temporary connection
                await self.disconnect_from_stt_service(temp_client_id)
        except Exception as e:
            logger.error(f"WebRTC audio file processing failed: {e}")
            return None
| async def try_http_stt_fallback(self, audio_file_path: str) -> Optional[str]: | |
| """Fallback to HTTP API if WebSocket fails""" | |
| try: | |
| import requests | |
| import aiohttp | |
| import asyncio | |
| # Convert to async HTTP request | |
| def make_request(): | |
| api_url = f"{self.stt_service_url}/api/predict" | |
| with open(audio_file_path, 'rb') as audio_file: | |
| files = {'data': audio_file} | |
| data = {'data': '["auto", "base", true]'} # [language, model_size, timestamps] | |
| response = requests.post(api_url, files=files, data=data, timeout=30) | |
| return response | |
| # Run in thread to avoid blocking | |
| loop = asyncio.get_event_loop() | |
| response = await loop.run_in_executor(None, make_request) | |
| if response.status_code == 200: | |
| result = response.json() | |
| logger.info(f"π HTTP STT result: {result}") | |
| # Extract transcription from Gradio API format | |
| if result and 'data' in result and len(result['data']) > 1: | |
| transcription = result['data'][1] # [status, transcription, timestamps] | |
| if transcription and transcription.strip(): | |
| logger.info(f"β HTTP STT transcription: {transcription}") | |
| return transcription | |
| except Exception as e: | |
| logger.error(f"β HTTP STT fallback failed: {e}") | |
| return None | |
| async def try_http_tts_fallback(self, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]: | |
| """Fallback to HTTP API if TTS WebSocket fails""" | |
| try: | |
| import requests | |
| import asyncio | |
| # Convert to async HTTP request | |
| def make_request(): | |
| api_url = f"{self.tts_service_url}/api/predict" | |
| data = {'data': f'["{text}", "{voice_preset}"]'} # [text, voice_preset] | |
| response = requests.post(api_url, data=data, timeout=60) # TTS takes longer | |
| return response | |
| # Run in thread to avoid blocking | |
| loop = asyncio.get_event_loop() | |
| response = await loop.run_in_executor(None, make_request) | |
| if response.status_code == 200: | |
| result = response.json() | |
| logger.info(f"π HTTP TTS result received") | |
| # Extract audio file path from Gradio API format | |
| if result and 'data' in result and len(result['data']) > 0: | |
| audio_file_path = result['data'][0] # Should be a file path | |
| if audio_file_path and isinstance(audio_file_path, str): | |
| # Download the audio file | |
| if audio_file_path.startswith('http'): | |
| audio_response = requests.get(audio_file_path, timeout=30) | |
| if audio_response.status_code == 200: | |
| logger.info(f"β HTTP TTS audio downloaded: {len(audio_response.content)} bytes") | |
| return audio_response.content | |
| except Exception as e: | |
| logger.error(f"β HTTP TTS fallback failed: {e}") | |
| return None | |
| async def process_audio_chunk_real_time(self, audio_array: np.ndarray, sample_rate: int) -> Optional[str]: | |
| """Legacy method - kept for compatibility""" | |
| try: | |
| logger.info(f"π€ WebRTC: Processing {len(audio_array)} samples at {sample_rate}Hz") | |
| duration = len(audio_array) / sample_rate | |
| transcription = f"WebRTC test: Audio array ({duration:.1f}s, {sample_rate}Hz)" | |
| return transcription | |
| except Exception as e: | |
| logger.error(f"WebRTC audio processing failed: {e}") | |
| return None | |
    async def handle_message(self, client_id: str, message_data: dict):
        """Dispatch one parsed client message by its "type" field.

        Supported types: audio_chunk, start_recording, stop_recording,
        tts_request, get_tts_voices. Unknown types are logged and ignored.
        """
        message_type = message_data.get("type")
        if message_type == "audio_chunk":
            # Real-time audio data
            audio_data = message_data.get("audio_data")  # Base64 encoded
            sample_rate = message_data.get("sample_rate", 16000)
            if audio_data:
                # Decode base64 audio data
                import base64
                audio_bytes = base64.b64decode(audio_data)
                await self.handle_audio_chunk(client_id, audio_bytes, sample_rate)
        elif message_type == "start_recording":
            # Client started recording
            await self.send_message(client_id, {
                "type": "recording_started",
                "timestamp": datetime.now().isoformat()
            })
            logger.info(f"π€ Recording started for {client_id}")
        elif message_type == "stop_recording":
            # Client stopped recording - UNMUTE.SH FLUSH TRICK
            logger.info(f"π€ Recording stopped for {client_id} - applying unmute.sh flush trick")
            # Process all buffered audio with flush trick; the transcription
            # message is sent before the recording_stopped acknowledgment.
            await self.process_buffered_audio_with_flush(client_id)
            await self.send_message(client_id, {
                "type": "recording_stopped",
                "timestamp": datetime.now().isoformat()
            })
        elif message_type == "tts_request":
            # Client requesting TTS playback
            text = message_data.get("text", "")
            voice_preset = message_data.get("voice_preset", "v2/en_speaker_6")
            if text.strip():
                await self.play_tts_response(client_id, text, voice_preset)
            else:
                await self.send_message(client_id, {
                    "type": "tts_error",
                    "message": "Empty text provided for TTS",
                    "timestamp": datetime.now().isoformat()
                })
        elif message_type == "get_tts_voices":
            # Client requesting available TTS voices
            # Hard-coded preset list (looks like Bark voice presets — confirm);
            # keep in sync with what the TTS service actually supports.
            await self.send_message(client_id, {
                "type": "tts_voices_list",
                "voices": ["v2/en_speaker_6", "v2/en_speaker_9", "v2/en_speaker_3", "v2/en_speaker_1"],
                "timestamp": datetime.now().isoformat()
            })
        else:
            logger.warning(f"Unknown message type from {client_id}: {message_type}")
# Global WebRTC handler instance
# Module-level singleton — presumably imported by the FastAPI WebSocket
# route(s); confirm usage at call sites.
webrtc_handler = WebRTCHandler()