# Source path: voiceCalendar/webrtc/server/websocket_handler.py
# Author: Peter Michael Gits
# Commit message: feat: Add Streamlit-native WebRTC speech-to-text using unmute.sh patterns
# Commit: 21fac9b
"""
WebRTC WebSocket Handler for Real-time Audio Streaming
Integrates with FastAPI for unmute.sh-style voice interaction
"""
import asyncio
import base64
import json
import logging
import os
import tempfile
import uuid
from datetime import datetime
from typing import Dict, Optional

import numpy as np
import soundfile as sf
import websockets
from fastapi import WebSocket, WebSocketDisconnect
# Module-level logger named after this module; used throughout WebRTCHandler.
logger = logging.getLogger(__name__)
class WebRTCHandler:
    """Handles WebRTC WebSocket connections for real-time audio streaming.

    Per-client state is keyed by ``client_id``:

    * ``active_connections`` - the browser-facing FastAPI ``WebSocket``.
    * ``audio_buffers`` - raw audio chunks received since the last flush.
    * ``stt_connections`` / ``tts_connections`` - outbound WebSocket
      connections to the remote STT / TTS services, opened lazily on first use.

    ``play_tts_response`` and ``process_audio_file_webrtc`` fall back to the
    services' HTTP APIs (``try_http_*_fallback``) when the WebSocket path
    yields nothing.
    """

    # Seconds to wait for a reply on an already-open service socket. The values
    # mirror the HTTP fallback timeouts so a stalled service degrades to the
    # fallback path instead of blocking the caller forever.
    stt_response_timeout = 30.0
    tts_response_timeout = 60.0

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}
        self.audio_buffers: Dict[str, list] = {}
        self.stt_service_url = "https://pgits-stt-gpu-service.hf.space"
        self.stt_websocket_url = "wss://pgits-stt-gpu-service.hf.space/ws/stt"
        self.stt_connections: Dict[str, websockets.WebSocketClientProtocol] = {}
        self.tts_service_url = "https://pgits-tts-gpu-service.hf.space"
        self.tts_websocket_url = "wss://pgits-tts-gpu-service.hf.space/ws/tts"
        self.tts_connections: Dict[str, websockets.WebSocketClientProtocol] = {}

    # ------------------------------------------------------------------
    # Browser-facing connection management
    # ------------------------------------------------------------------

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept a browser WebSocket and register an empty audio buffer for it.

        Sends a ``connection_confirmed`` message advertising the STT service URL.
        """
        await websocket.accept()
        self.active_connections[client_id] = websocket
        self.audio_buffers[client_id] = []
        logger.info(f"πŸ”Œ WebRTC client {client_id} connected")
        await self.send_message(client_id, {
            "type": "connection_confirmed",
            "client_id": client_id,
            "timestamp": datetime.now().isoformat(),
            "services": {
                "stt": self.stt_service_url,
                "status": "ready"
            }
        })

    async def disconnect(self, client_id: str):
        """Drop all state for ``client_id`` and close its STT/TTS service sockets.

        Safe to call for unknown or already-disconnected clients.
        """
        self.active_connections.pop(client_id, None)
        self.audio_buffers.pop(client_id, None)
        await self.disconnect_from_stt_service(client_id)
        await self.disconnect_from_tts_service(client_id)
        logger.info(f"πŸ”Œ WebRTC client {client_id} disconnected")

    async def send_message(self, client_id: str, message: dict):
        """Send ``message`` to the client as JSON text; no-op if not connected.

        Any failure (serialisation or send) is logged and treated as a dead
        connection: the client is disconnected.
        """
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return
        try:
            await websocket.send_text(json.dumps(message))
        except Exception as e:
            logger.error(f"Failed to send message to {client_id}: {e}")
            await self.disconnect(client_id)

    # ------------------------------------------------------------------
    # Audio buffering and final ("flush") transcription
    # ------------------------------------------------------------------

    async def handle_audio_chunk(self, client_id: str, audio_data: bytes, sample_rate: int = 16000):
        """Buffer one incoming audio chunk and acknowledge it to the client.

        Chunks are only accumulated here; transcription happens on
        ``stop_recording`` via ``process_buffered_audio_with_flush``.

        Args:
            client_id: Client whose buffer receives the chunk (created if missing).
            audio_data: Raw audio bytes.
            sample_rate: Accepted for interface compatibility; currently unused.
        """
        try:
            logger.info(f"🎀 Received {len(audio_data)} bytes from {client_id}")
            buffer = self.audio_buffers.setdefault(client_id, [])
            buffer.append(audio_data)
            await self.send_message(client_id, {
                "type": "chunk_buffered",
                "chunk_size": len(audio_data),
                "buffer_chunks": len(buffer),
                "timestamp": datetime.now().isoformat()
            })
            # Use the local reference: a failed send disconnects the client and
            # removes its entry from ``audio_buffers``.
            logger.info(f"πŸ“¦ Buffered chunk for {client_id} ({len(buffer)} total chunks)")
        except Exception as e:
            logger.error(f"Error buffering audio chunk for {client_id}: {e}")
            await self.send_message(client_id, {
                "type": "error",
                "message": f"Audio buffering error: {str(e)}",
                "timestamp": datetime.now().isoformat()
            })

    async def process_buffered_audio_with_flush(self, client_id: str):
        """Transcribe everything buffered for ``client_id`` as one final utterance.

        The buffered chunks are joined, written to a temporary ``.webm`` file
        and passed to ``process_audio_file_webrtc_with_flush``. The client then
        receives either a ``transcription`` message (``is_final: True``) or a
        ``transcription_error`` message. The buffer is emptied and the
        temporary file deleted in every case.
        """
        chunks = self.audio_buffers.get(client_id)
        if not chunks:
            logger.info(f"No audio chunks to process for {client_id}")
            return
        # Detach the buffer before awaiting the STT call. If chunks are handled
        # concurrently, they then start the next utterance instead of being
        # wiped afterwards. A client that disconnects mid-way is also not
        # re-registered by this method.
        self.audio_buffers[client_id] = []
        try:
            all_audio_data = b''.join(chunks)
            total_chunks = len(chunks)
            logger.info(f"πŸ”„ Processing {total_chunks} buffered chunks ({len(all_audio_data)} bytes total) with flush trick")
            tmp_file_path = None
            try:
                # delete=False keeps the file after closing so it can be handed
                # to the STT handler by path; it is removed in ``finally``.
                with tempfile.NamedTemporaryFile(suffix='.webm', delete=False) as tmp_file:
                    tmp_file_path = tmp_file.name
                    tmp_file.write(all_audio_data)
                transcription = await self.process_audio_file_webrtc_with_flush(tmp_file_path)
                if transcription and transcription.strip() and not transcription.startswith("ERROR"):
                    await self.send_message(client_id, {
                        "type": "transcription",
                        "text": transcription.strip(),
                        "timestamp": datetime.now().isoformat(),
                        "audio_size": len(all_audio_data),
                        "format": "webm/audio",
                        "is_final": True,  # unmute.sh flush trick marker
                        "chunks_processed": total_chunks
                    })
                    logger.info(f"πŸ“ Final transcription sent to {client_id}: {transcription[:50]}...")
                else:
                    await self.send_message(client_id, {
                        "type": "transcription_error",
                        "message": f"Audio processing failed: {transcription if transcription else 'No result'}",
                        "timestamp": datetime.now().isoformat()
                    })
            finally:
                if tmp_file_path and os.path.exists(tmp_file_path):
                    os.unlink(tmp_file_path)
                logger.info(f"🧹 Cleared audio buffer for {client_id}")
        except Exception as e:
            logger.error(f"Error processing buffered audio for {client_id}: {e}")
            await self.send_message(client_id, {
                "type": "transcription_error",
                "message": f"Buffered audio processing error: {str(e)}",
                "timestamp": datetime.now().isoformat()
            })

    async def process_audio_file_webrtc_with_flush(self, audio_file_path: str) -> Optional[str]:
        """Transcribe a complete audio file via ``core.mcp_audio_handler``.

        Returns:
            The handler's result, or a string starting with ``"ERROR"`` if the
            import or the call fails (callers test for that prefix).
        """
        try:
            # Imported inside the try, so a missing ``core`` package turns into
            # an "ERROR..." result instead of an exception.
            from core.mcp_audio_handler import mcp_audio_handler
            # NOTE(review): no is_final/flush flag is passed here; the whole
            # recording is transcribed in a single call.
            result = await mcp_audio_handler.speech_to_text(audio_file_path)
            logger.info(f"πŸš€ FLUSH TRICK: STT service returned: {result[:100] if result else 'None'}...")
            return result
        except Exception as e:
            logger.error(f"Error in flush trick audio processing: {e}")
            return f"ERROR: Flush trick processing failed - {str(e)}"

    # ------------------------------------------------------------------
    # Shared helpers for outbound service sockets
    # ------------------------------------------------------------------

    @staticmethod
    async def _close_quietly(ws) -> None:
        """Close an outbound service socket, ignoring errors; no-op for None."""
        if ws is None:
            return
        try:
            await ws.close()
        except Exception as e:
            logger.debug(f"Ignoring error while closing service socket: {e}")

    # ------------------------------------------------------------------
    # STT WebSocket methods
    # ------------------------------------------------------------------

    async def connect_to_stt_service(self, client_id: str) -> bool:
        """Open and confirm an STT service WebSocket for ``client_id``.

        The socket is registered in ``stt_connections`` only after the service
        replies with ``stt_connection_confirmed``. On any failure it is closed,
        so a half-open connection is never reused by later calls.

        Returns:
            True if the connection is ready for use, False otherwise.
        """
        stt_ws = None
        connected = False
        try:
            logger.info(f"πŸ”Œ Connecting to STT service for client {client_id}: {self.stt_websocket_url}")
            stt_ws = await asyncio.wait_for(
                websockets.connect(self.stt_websocket_url),
                timeout=5.0
            )
            confirmation = await asyncio.wait_for(stt_ws.recv(), timeout=10.0)
            confirmation_data = json.loads(confirmation)
            if confirmation_data.get("type") == "stt_connection_confirmed":
                self.stt_connections[client_id] = stt_ws
                connected = True
                logger.info(f"βœ… STT service connected for client {client_id}")
                return True
            logger.warning(f"⚠️ Unexpected STT confirmation: {confirmation_data}")
        except asyncio.TimeoutError:
            logger.error(f"❌ STT service connection timeout for {client_id} - service may be cold starting or WebSocket endpoints not available")
        except websockets.exceptions.WebSocketException as e:
            if "503" in str(e):
                logger.error(f"❌ STT service unavailable (HTTP 503) for {client_id} - service may be cold starting")
                logger.info(f"πŸ”„ Try again in a few moments - Hugging Face services need time to start")
            else:
                logger.error(f"❌ STT WebSocket error for {client_id}: {e}")
                logger.info(f"πŸ” Debug: Attempted connection to {self.stt_websocket_url}")
        except Exception as e:
            logger.error(f"❌ Failed to connect to STT service for {client_id}: {e}")
            logger.info(f"πŸ” Debug: STT service URL: {self.stt_websocket_url}")
        finally:
            if not connected:
                await self._close_quietly(stt_ws)
        return False

    async def disconnect_from_stt_service(self, client_id: str):
        """Close and forget the STT service socket for ``client_id``, if any.

        The entry is removed before closing, so a failing ``close()`` cannot
        leave a broken socket registered for reuse.
        """
        stt_ws = self.stt_connections.pop(client_id, None)
        if stt_ws is None:
            return
        try:
            await stt_ws.close()
            logger.info(f"πŸ”Œ Disconnected from STT service for client {client_id}")
        except Exception as e:
            logger.error(f"Error disconnecting from STT service: {e}")

    async def send_audio_to_stt_service(self, client_id: str, audio_data: bytes) -> Optional[str]:
        """Send one audio payload to the STT service and wait for its reply.

        Connects first if ``client_id`` has no STT socket yet. The audio is
        sent base64-encoded in an ``stt_audio_chunk`` message. The reply is
        awaited for at most ``stt_response_timeout`` seconds.

        Returns:
            The transcription text, or None on failure. If an exception occurs,
            the STT socket for ``client_id`` is also closed.
        """
        if client_id not in self.stt_connections:
            if not await self.connect_to_stt_service(client_id):
                return None
        try:
            stt_ws = self.stt_connections[client_id]
            message = {
                "type": "stt_audio_chunk",
                "audio_data": base64.b64encode(audio_data).decode('utf-8'),
                "language": "auto",
                "model_size": "base"
            }
            await stt_ws.send(json.dumps(message))
            logger.info(f"πŸ“€ Sent {len(audio_data)} bytes to STT service")
            # Bounded wait: an unresponsive service must not hang the caller.
            response = await asyncio.wait_for(stt_ws.recv(), timeout=self.stt_response_timeout)
            response_data = json.loads(response)
            response_type = response_data.get("type")
            if response_type == "stt_transcription":
                transcription = response_data.get("text", "")
                logger.info(f"πŸ“ STT transcription received: {transcription[:50]}...")
                return transcription
            if response_type == "stt_error":
                error_msg = response_data.get("message", "Unknown STT error")
                logger.error(f"❌ STT service error: {error_msg}")
                return None
            logger.warning(f"⚠️ Unexpected STT response: {response_data}")
            return None
        except Exception as e:
            logger.error(f"❌ Error communicating with STT service: {e}")
            # Drop the socket: after an error (or a timed-out reply) its
            # request/response pairing can no longer be trusted.
            await self.disconnect_from_stt_service(client_id)
            return None

    # ------------------------------------------------------------------
    # TTS WebSocket methods
    # ------------------------------------------------------------------

    async def connect_to_tts_service(self, client_id: str) -> bool:
        """Open and confirm a TTS service WebSocket for ``client_id``.

        The socket is registered in ``tts_connections`` only after the service
        replies with ``tts_connection_confirmed``. On any failure it is closed.

        Returns:
            True if the connection is ready for use, False otherwise.
        """
        tts_ws = None
        connected = False
        try:
            logger.info(f"πŸ”Œ Connecting to TTS service for client {client_id}: {self.tts_websocket_url}")
            tts_ws = await asyncio.wait_for(
                websockets.connect(self.tts_websocket_url),
                timeout=10.0
            )
            confirmation = await asyncio.wait_for(tts_ws.recv(), timeout=15.0)
            confirmation_data = json.loads(confirmation)
            if confirmation_data.get("type") == "tts_connection_confirmed":
                self.tts_connections[client_id] = tts_ws
                connected = True
                logger.info(f"βœ… TTS service connected for client {client_id}")
                return True
            logger.warning(f"⚠️ Unexpected TTS confirmation: {confirmation_data}")
        except asyncio.TimeoutError:
            logger.error(f"❌ TTS service connection timeout - service may not be in WebSocket mode")
            logger.info(f"πŸ’‘ TTS service needs TTS_SERVICE_MODE=websocket environment variable")
        except websockets.exceptions.InvalidHandshake as e:
            # Base class of both the legacy InvalidStatusCode and the newer
            # InvalidStatus raised when the server rejects the upgrade.
            logger.error(f"❌ TTS WebSocket endpoint not available: {e}")
            logger.info(f"πŸ’‘ TTS service may be running in Gradio-only mode instead of WebSocket mode")
        except Exception as e:
            logger.error(f"❌ Failed to connect to TTS service for {client_id}: {e}")
            logger.info(f"πŸ’‘ Check if TTS service is running and configured with TTS_SERVICE_MODE=websocket")
        finally:
            if not connected:
                await self._close_quietly(tts_ws)
        return False

    async def disconnect_from_tts_service(self, client_id: str):
        """Close and forget the TTS service socket for ``client_id``, if any.

        The entry is removed before closing, so a failing ``close()`` cannot
        leave a broken socket registered for reuse.
        """
        tts_ws = self.tts_connections.pop(client_id, None)
        if tts_ws is None:
            return
        try:
            await tts_ws.close()
            logger.info(f"πŸ”Œ Disconnected from TTS service for client {client_id}")
        except Exception as e:
            logger.error(f"Error disconnecting from TTS service: {e}")

    async def send_text_to_tts_service(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]:
        """Ask the TTS service to synthesise ``text`` and return the audio bytes.

        Connects first if needed. The reply is awaited for at most
        ``tts_response_timeout`` seconds, and its base64 ``audio_data`` is
        decoded.

        Returns:
            Decoded audio bytes, or None on failure. If an exception occurs,
            the TTS socket for ``client_id`` is also closed.
        """
        if client_id not in self.tts_connections:
            if not await self.connect_to_tts_service(client_id):
                return None
        try:
            tts_ws = self.tts_connections[client_id]
            message = {
                "type": "tts_synthesize",
                "text": text,
                "voice_preset": voice_preset
            }
            await tts_ws.send(json.dumps(message))
            logger.info(f"πŸ“€ Sent text to TTS service: {text[:50]}...")
            # Bounded wait: an unresponsive service must not hang the caller.
            response = await asyncio.wait_for(tts_ws.recv(), timeout=self.tts_response_timeout)
            response_data = json.loads(response)
            response_type = response_data.get("type")
            if response_type == "tts_audio_response":
                audio_bytes = base64.b64decode(response_data.get("audio_data", ""))
                logger.info(f"πŸ”Š TTS audio received: {len(audio_bytes)} bytes")
                return audio_bytes
            if response_type == "tts_error":
                error_msg = response_data.get("message", "Unknown TTS error")
                logger.error(f"❌ TTS service error: {error_msg}")
                return None
            logger.warning(f"⚠️ Unexpected TTS response: {response_data}")
            return None
        except Exception as e:
            logger.error(f"❌ Error communicating with TTS service: {e}")
            await self.disconnect_from_tts_service(client_id)
            return None

    async def play_tts_response(self, client_id: str, text: str, voice_preset: str = "v2/en_speaker_6"):
        """Synthesise ``text`` and send it to the client as a ``tts_playback`` message.

        Tries the TTS WebSocket first and then the HTTP fallback. Sends a
        ``tts_error`` message if neither produces audio.
        """
        try:
            logger.info(f"πŸ”Š Generating TTS response for client {client_id}: {text[:50]}...")
            logger.info("🌐 Attempting WebSocket TTS (PRIMARY)")
            audio_data = await self.send_text_to_tts_service(client_id, text, voice_preset)
            if not audio_data:
                logger.info("πŸ”„ WebSocket failed, trying HTTP API fallback")
                audio_data = await self.try_http_tts_fallback(text, voice_preset)
            if audio_data:
                await self.send_message(client_id, {
                    "type": "tts_playback",
                    "audio_data": base64.b64encode(audio_data).decode('utf-8'),
                    "audio_format": "wav",
                    "text": text,
                    "voice_preset": voice_preset,
                    "timestamp": datetime.now().isoformat(),
                    "audio_size": len(audio_data)
                })
                logger.info(f"πŸ”Š TTS playback sent to {client_id} ({len(audio_data)} bytes)")
            else:
                logger.warning(f"⚠️ TTS service failed to generate audio for: {text[:50]}...")
                await self.send_message(client_id, {
                    "type": "tts_error",
                    "message": "TTS audio generation failed",
                    "text": text,
                    "timestamp": datetime.now().isoformat()
                })
        except Exception as e:
            logger.error(f"❌ TTS playback error for {client_id}: {e}")
            await self.send_message(client_id, {
                "type": "tts_error",
                "message": f"TTS playback error: {str(e)}",
                "timestamp": datetime.now().isoformat()
            })

    # ------------------------------------------------------------------
    # One-shot file transcription and HTTP fallbacks
    # ------------------------------------------------------------------

    async def process_audio_file_webrtc(self, audio_file_path: str, sample_rate: int) -> Optional[str]:
        """Transcribe a whole audio file: STT WebSocket first, then HTTP.

        A throwaway STT connection is opened under a unique temporary client id
        and always closed afterwards. If both paths fail, a canned apology
        sentence is returned. None is returned only if reading the file fails.
        ``sample_rate`` is currently unused.
        """
        try:
            logger.info(f"🎀 WebRTC: Processing audio file {audio_file_path} with real STT")
            with open(audio_file_path, 'rb') as f:
                audio_data = f.read()
            file_size = len(audio_data)
            logger.info(f"🎀 Audio file size: {file_size} bytes")
            # uuid4 rather than a timestamp: concurrent calls on a coarse clock
            # could otherwise share (and close) each other's STT connection.
            temp_client_id = f"temp_{uuid.uuid4().hex}"
            try:
                logger.info("🌐 Attempting WebSocket STT (PRIMARY)")
                transcription = await self.send_audio_to_stt_service(temp_client_id, audio_data)
                if transcription:
                    logger.info(f"βœ… WebSocket STT transcription: {transcription}")
                    return transcription
                logger.info("πŸ”„ WebSocket failed, trying HTTP API fallback")
                http_transcription = await self.try_http_stt_fallback(audio_file_path)
                if http_transcription:
                    logger.info(f"βœ… HTTP STT transcription (fallback): {http_transcription}")
                    return f"[HTTP] {http_transcription}"
                logger.error("❌ Both WebSocket and HTTP STT failed - using minimal fallback")
                return "I'm having trouble processing that audio. Could you please try again?"
            finally:
                await self.disconnect_from_stt_service(temp_client_id)
        except Exception as e:
            logger.error(f"WebRTC audio file processing failed: {e}")
            return None

    async def try_http_stt_fallback(self, audio_file_path: str) -> Optional[str]:
        """Transcribe ``audio_file_path`` via the STT service's Gradio HTTP API.

        The blocking ``requests`` call runs in the default executor.

        Returns:
            The transcription (element 1 of the response's ``data`` list), or
            None on any failure.
        """
        try:
            # Imported inside the try: a missing ``requests`` package only
            # disables this fallback.
            import requests

            def make_request():
                api_url = f"{self.stt_service_url}/api/predict"
                with open(audio_file_path, 'rb') as audio_file:
                    files = {'data': audio_file}
                    data = {'data': '["auto", "base", true]'}  # [language, model_size, timestamps]
                    return requests.post(api_url, files=files, data=data, timeout=30)

            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, make_request)
            if response.status_code == 200:
                result = response.json()
                logger.info(f"πŸ“ HTTP STT result: {result}")
                # Gradio API format: data = [status, transcription, timestamps]
                if result and 'data' in result and len(result['data']) > 1:
                    transcription = result['data'][1]
                    if isinstance(transcription, str) and transcription.strip():
                        logger.info(f"βœ… HTTP STT transcription: {transcription}")
                        return transcription
        except Exception as e:
            logger.error(f"❌ HTTP STT fallback failed: {e}")
        return None

    async def try_http_tts_fallback(self, text: str, voice_preset: str = "v2/en_speaker_6") -> Optional[bytes]:
        """Synthesise ``text`` via the TTS service's Gradio HTTP API.

        Both the prediction request and the follow-up download of the audio
        URL in ``data[0]`` run in the default executor, so neither blocks the
        event loop.

        Returns:
            The downloaded audio bytes, or None on any failure.
        """
        try:
            import requests

            def make_request():
                api_url = f"{self.tts_service_url}/api/predict"
                # json.dumps escapes quotes, backslashes and newlines in the
                # text; hand-formatting the list produced invalid JSON for them.
                data = {'data': json.dumps([text, voice_preset])}  # [text, voice_preset]
                return requests.post(api_url, data=data, timeout=60)  # TTS takes longer

            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, make_request)
            if response.status_code == 200:
                result = response.json()
                logger.info("πŸ”Š HTTP TTS result received")
                if result and 'data' in result and len(result['data']) > 0:
                    audio_url = result['data'][0]
                    if audio_url and isinstance(audio_url, str) and audio_url.startswith('http'):
                        def download():
                            return requests.get(audio_url, timeout=30)

                        audio_response = await loop.run_in_executor(None, download)
                        if audio_response.status_code == 200:
                            logger.info(f"βœ… HTTP TTS audio downloaded: {len(audio_response.content)} bytes")
                            return audio_response.content
        except Exception as e:
            logger.error(f"❌ HTTP TTS fallback failed: {e}")
        return None

    async def process_audio_chunk_real_time(self, audio_array: np.ndarray, sample_rate: int) -> Optional[str]:
        """Legacy method, kept for compatibility: returns a placeholder string
        describing the array's duration rather than a real transcription."""
        try:
            logger.info(f"🎀 WebRTC: Processing {len(audio_array)} samples at {sample_rate}Hz")
            duration = len(audio_array) / sample_rate
            return f"WebRTC test: Audio array ({duration:.1f}s, {sample_rate}Hz)"
        except Exception as e:
            logger.error(f"WebRTC audio processing failed: {e}")
            return None

    # ------------------------------------------------------------------
    # Message dispatch
    # ------------------------------------------------------------------

    async def handle_message(self, client_id: str, message_data: dict):
        """Dispatch one JSON message received from the browser client.

        Supported ``type`` values: ``audio_chunk``, ``start_recording``,
        ``stop_recording``, ``tts_request`` and ``get_tts_voices``. Any other
        type is logged and ignored. Malformed audio payloads are reported to
        the client as an ``error`` message instead of raising.
        """
        message_type = message_data.get("type")

        if message_type == "audio_chunk":
            audio_data = message_data.get("audio_data")  # base64-encoded
            sample_rate = message_data.get("sample_rate", 16000)
            if audio_data:
                try:
                    audio_bytes = base64.b64decode(audio_data)
                except (ValueError, TypeError) as e:
                    # binascii.Error subclasses ValueError. Report the bad
                    # payload rather than letting it escape to the caller.
                    logger.error(f"Invalid audio payload from {client_id}: {e}")
                    await self.send_message(client_id, {
                        "type": "error",
                        "message": f"Invalid audio data: {e}",
                        "timestamp": datetime.now().isoformat()
                    })
                    return
                await self.handle_audio_chunk(client_id, audio_bytes, sample_rate)

        elif message_type == "start_recording":
            await self.send_message(client_id, {
                "type": "recording_started",
                "timestamp": datetime.now().isoformat()
            })
            logger.info(f"🎀 Recording started for {client_id}")

        elif message_type == "stop_recording":
            logger.info(f"🎀 Recording stopped for {client_id} - applying unmute.sh flush trick")
            # The final transcription (or error) is sent before recording_stopped.
            await self.process_buffered_audio_with_flush(client_id)
            await self.send_message(client_id, {
                "type": "recording_stopped",
                "timestamp": datetime.now().isoformat()
            })

        elif message_type == "tts_request":
            # ``or ""`` also covers an explicit JSON null.
            text = message_data.get("text") or ""
            voice_preset = message_data.get("voice_preset", "v2/en_speaker_6")
            if isinstance(text, str) and text.strip():
                await self.play_tts_response(client_id, text, voice_preset)
            else:
                await self.send_message(client_id, {
                    "type": "tts_error",
                    "message": "Empty text provided for TTS",
                    "timestamp": datetime.now().isoformat()
                })

        elif message_type == "get_tts_voices":
            await self.send_message(client_id, {
                "type": "tts_voices_list",
                "voices": ["v2/en_speaker_6", "v2/en_speaker_9", "v2/en_speaker_3", "v2/en_speaker_1"],
                "timestamp": datetime.now().isoformat()
            })

        else:
            logger.warning(f"Unknown message type from {client_id}: {message_type}")
# Global WebRTC handler instance, shared by every importer of this module.
# Construction only initialises empty per-client dicts and service URLs;
# no network connections are opened at import time.
webrtc_handler = WebRTCHandler()