#!/usr/bin/env python3
"""
Standalone WebSocket-only TTS Service
Simplified service without Gradio, MCP, or web interfaces
Following unmute.sh WebRTC pattern for HuggingFace Spaces
"""

import asyncio
import json
import uuid
import base64
import tempfile
import os
import logging
import time
from datetime import datetime
from typing import Optional, Dict, Any

import torch
from transformers import AutoProcessor, BarkModel
import soundfile as sf
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import spaces
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Version info
__version__ = "1.0.0"
__service__ = "TTS WebSocket Service"

# Voice preset used when a client omits one or asks for an unknown one.
DEFAULT_VOICE_PRESET = "v2/en_speaker_6"


class TTSWebSocketService:
    """Standalone TTS service with a WebSocket-only interface.

    Holds the Bark model/processor, the list of allowed voice presets and a
    registry of connected WebSocket clients keyed by a generated client ID.
    """

    def __init__(self):
        self.model = None
        self.processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # client_id -> WebSocket for every currently connected client
        self.active_connections: Dict[str, WebSocket] = {}
        # Only these presets are ever passed to the Bark processor.
        self.available_voices = [f"v2/en_speaker_{i}" for i in range(10)]

        logger.info(f"🔊 {__service__} v{__version__} initializing...")
        logger.info(f"Device: {self.device}")
        logger.info(f"Available voices: {len(self.available_voices)}")

    async def load_model(self):
        """Load the Bark processor and model once (no-op if already loaded)."""
        if self.model is None:
            logger.info("Loading Bark TTS model...")
            self.processor = AutoProcessor.from_pretrained("suno/bark")
            self.model = BarkModel.from_pretrained("suno/bark")
            if self.device == "cuda":
                self.model = self.model.to(self.device)
            logger.info(f"✅ Bark model loaded on {self.device}")

    def _resolve_voice_preset(self, voice_preset: Any) -> str:
        """Return *voice_preset* if it is an allowed voice, else the default.

        Client-supplied values are forwarded to the Bark processor, so
        anything outside ``available_voices`` is replaced rather than passed
        through unchecked.
        """
        if voice_preset in self.available_voices:
            return voice_preset
        return DEFAULT_VOICE_PRESET

    @staticmethod
    def _consume_audio_file(audio_path: str) -> bytes:
        """Read the file at *audio_path* and delete it, even if reading fails."""
        try:
            with open(audio_path, "rb") as audio_file:
                return audio_file.read()
        finally:
            if os.path.exists(audio_path):
                os.unlink(audio_path)

    @spaces.GPU(duration=30)
    async def synthesize_speech(
        self,
        text: str,
        voice_preset: str = DEFAULT_VOICE_PRESET,
        sample_rate: int = 24000,
    ) -> tuple[Optional[str], str, Dict[str, Any]]:
        """Synthesize speech from text using Bark with ZeroGPU.

        Args:
            text: Text to speak; whitespace-only text is rejected.
            voice_preset: Bark voice preset name passed to the processor.
            sample_rate: Sample rate written into the WAV header.

        Returns:
            ``(wav_path, "success", timing_info)`` on success. The caller owns
            the temporary WAV file at ``wav_path`` and must delete it.
            ``(None, "error", {"error": message})`` on any failure; exceptions
            are caught and logged, never raised.
        """
        try:
            if not text.strip():
                return None, "error", {"error": "Empty text provided"}

            start_time = time.time()

            # Lazily load the model in case startup pre-loading did not run.
            if self.model is None:
                await self.load_model()

            logger.info(f"Synthesizing: '{text[:50]}...' with {voice_preset}")

            # Process text with voice preset
            inputs = self.processor(
                text,
                voice_preset=voice_preset,
                return_tensors="pt"
            )
            if self.device == "cuda":
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # NOTE(review): generate() is synchronous and runs on the event-loop
            # thread here, so other connections stall while it runs. Consider
            # offloading it once its interaction with spaces.GPU is confirmed.
            with torch.no_grad():
                audio_array = self.model.generate(**inputs)
            audio_array = audio_array.cpu().numpy().squeeze()

            # Close our handle before soundfile opens the path. Remove the file
            # if the write fails, because callers only clean up on success.
            fd, temp_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            try:
                sf.write(temp_path, audio_array, sample_rate)
            except Exception:
                os.unlink(temp_path)
                raise

            processing_time = time.time() - start_time
            timing_info = {
                "processing_time": processing_time,
                "start_time": datetime.fromtimestamp(start_time).isoformat(),
                "end_time": datetime.now().isoformat(),
                "voice_preset": voice_preset,
                "device": self.device,
                "text_length": len(text),
                "sample_rate": sample_rate
            }

            logger.info(f"Speech synthesis completed in {processing_time:.2f}s")
            return temp_path, "success", timing_info

        except Exception as e:
            logger.exception(f"TTS synthesis error: {str(e)}")
            return None, "error", {"error": str(e)}

    async def connect_websocket(self, websocket: WebSocket) -> str:
        """Accept *websocket*, register it, send a confirmation; return its ID."""
        client_id = str(uuid.uuid4())
        await websocket.accept()
        self.active_connections[client_id] = websocket

        await websocket.send_text(json.dumps({
            "type": "tts_connection_confirmed",
            "client_id": client_id,
            "service": __service__,
            "version": __version__,
            "available_voices": self.available_voices,
            "device": self.device,
            "message": "TTS WebSocket connected and ready"
        }))

        logger.info(f"Client {client_id} connected")
        return client_id

    async def disconnect_websocket(self, client_id: str):
        """Forget *client_id*; logs only if it was still registered."""
        if self.active_connections.pop(client_id, None) is not None:
            logger.info(f"Client {client_id} disconnected")

    async def process_tts_message(self, client_id: str, message: Dict[str, Any]):
        """Handle a ``tts_synthesize`` request and reply on the client's socket.

        Expects ``message["text"]`` and optionally ``message["voice_preset"]``.
        An unknown preset falls back to the default voice. Replies with
        ``tts_synthesis_complete`` (base64 WAV) or ``tts_synthesis_error``.
        """
        try:
            websocket = self.active_connections[client_id]

            text = message.get("text", "").strip()
            voice_preset = self._resolve_voice_preset(
                message.get("voice_preset", DEFAULT_VOICE_PRESET)
            )

            if not text:
                await websocket.send_text(json.dumps({
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": "No text provided for synthesis"
                }))
                return

            audio_path, status, timing = await self.synthesize_speech(
                text, voice_preset
            )

            if status == "success" and audio_path:
                audio_data = self._consume_audio_file(audio_path)
                # Base64 so the binary WAV fits in a JSON text frame.
                audio_b64 = base64.b64encode(audio_data).decode('utf-8')

                await websocket.send_text(json.dumps({
                    "type": "tts_synthesis_complete",
                    "client_id": client_id,
                    "audio_data": audio_b64,
                    "audio_format": "wav",
                    "text": text,
                    "voice_preset": voice_preset,
                    "audio_size": len(audio_data),
                    "timing": timing,
                    "status": "success"
                }))

                logger.info(f"TTS synthesis sent to {client_id} ({len(audio_data)} bytes)")
            else:
                await websocket.send_text(json.dumps({
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": "Speech synthesis failed",
                    "timing": timing
                }))

        except Exception as e:
            logger.error(f"Error processing TTS for {client_id}: {str(e)}")
            if client_id in self.active_connections:
                websocket = self.active_connections[client_id]
                await websocket.send_text(json.dumps({
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": f"Processing error: {str(e)}"
                }))

    async def process_streaming_tts_message(self, client_id: str, message: Dict[str, Any]):
        """Handle a ``tts_streaming_text`` request (unmute.sh methodology).

        When ``is_final`` is true (the default) and chunks are present, all
        ``text_chunks`` are joined and synthesized in one call, and the
        audio is sent as ``tts_streaming_response``. Otherwise only a
        ``tts_streaming_progress`` acknowledgement is sent; no text is
        buffered server-side.
        """
        try:
            websocket = self.active_connections[client_id]

            text_chunks = message.get("text_chunks", [])
            voice_preset = self._resolve_voice_preset(
                message.get("voice_preset", DEFAULT_VOICE_PRESET)
            )
            is_final = message.get("is_final", True)

            if is_final and text_chunks:
                # A bare string is one chunk; joining it directly would
                # insert a space between every character.
                chunks = [text_chunks] if isinstance(text_chunks, str) else text_chunks
                # UNMUTE.SH FLUSH TRICK: Process all accumulated text at once
                complete_text = " ".join(chunks).strip()
                logger.info(f"🔊 TTS STREAMING: Final synthesis for {client_id}: '{complete_text[:50]}...'")

                audio_path, status, timing = await self.synthesize_speech(
                    complete_text, voice_preset
                )

                if status == "success" and audio_path:
                    audio_data = self._consume_audio_file(audio_path)
                    audio_b64 = base64.b64encode(audio_data).decode('utf-8')

                    await websocket.send_text(json.dumps({
                        "type": "tts_streaming_response",
                        "client_id": client_id,
                        "audio_data": audio_b64,
                        "audio_format": "wav",
                        "text": complete_text,
                        "text_chunks": text_chunks,
                        "voice_preset": voice_preset,
                        "audio_size": len(audio_data),
                        "timing": timing,
                        "is_final": is_final,
                        "streaming_method": "unmute.sh_flush_trick",
                        "status": "success"
                    }))

                    logger.info(f"🔊 TTS STREAMING: Final audio sent to {client_id} ({len(audio_data)} bytes)")
                else:
                    await websocket.send_text(json.dumps({
                        "type": "tts_streaming_error",
                        "client_id": client_id,
                        "message": f"TTS streaming synthesis failed: {status}",
                        # Surface the underlying cause, as the other error replies do.
                        "error": timing.get("error", "Speech synthesis failed"),
                        "text": complete_text,
                        "is_final": is_final
                    }))
            else:
                # Send partial progress update (no audio yet)
                await websocket.send_text(json.dumps({
                    "type": "tts_streaming_progress",
                    "client_id": client_id,
                    "text_chunks": text_chunks,
                    "is_final": is_final,
                    "message": f"Accumulating text chunks: {len(text_chunks)}"
                }))

        except Exception as e:
            logger.error(f"Error processing streaming TTS for {client_id}: {str(e)}")
            if client_id in self.active_connections:
                websocket = self.active_connections[client_id]
                await websocket.send_text(json.dumps({
                    "type": "tts_streaming_error",
                    "client_id": client_id,
                    "error": f"Streaming processing error: {str(e)}"
                }))


# Initialize service
tts_service = TTSWebSocketService()

# Create FastAPI app
app = FastAPI(
    title="TTS WebSocket Service",
    description="Standalone WebSocket-only Text-to-Speech service",
    version=__version__
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Pre-load the Bark model so the first request does not pay for it."""
    logger.info(f"🚀 {__service__} v{__version__} starting...")
    logger.info("Pre-loading Bark TTS model for optimal performance...")
    await tts_service.load_model()
    logger.info("✅ Service ready for WebSocket connections")


@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "service": __service__,
        "version": __version__,
        "status": "ready",
        "endpoints": {
            "websocket": "/ws/tts",
            "health": "/health"
        },
        "available_voices": tts_service.available_voices,
        "device": tts_service.device
    }


@app.get("/health")
async def health_check():
    """Detailed health check"""
    return {
        "service": __service__,
        "version": __version__,
        "status": "healthy",
        "model_loaded": tts_service.model is not None,
        "active_connections": len(tts_service.active_connections),
        "available_voices": len(tts_service.available_voices),
        "device": tts_service.device,
        "timestamp": datetime.now().isoformat()
    }


@app.websocket("/ws/tts")
async def websocket_tts_endpoint(websocket: WebSocket):
    """Main TTS WebSocket endpoint.

    Dispatches JSON messages by ``type``: ``tts_synthesize``,
    ``tts_streaming_text`` and ``ping``. Malformed messages get an error
    reply and the connection stays open.
    """
    client_id = None
    try:
        client_id = await tts_service.connect_websocket(websocket)

        while True:
            try:
                data = await websocket.receive_text()
                message = json.loads(data)

                # Valid JSON that is not an object (e.g. a list) has no .get;
                # reject it instead of dropping the whole connection.
                if not isinstance(message, dict):
                    await websocket.send_text(json.dumps({
                        "type": "tts_synthesis_error",
                        "client_id": client_id,
                        "error": "Message must be a JSON object"
                    }))
                    continue

                message_type = message.get("type", "unknown")

                if message_type == "tts_synthesize":
                    await tts_service.process_tts_message(client_id, message)
                elif message_type == "tts_streaming_text":
                    await tts_service.process_streaming_tts_message(client_id, message)
                elif message_type == "ping":
                    await websocket.send_text(json.dumps({
                        "type": "pong",
                        "client_id": client_id,
                        "timestamp": datetime.now().isoformat()
                    }))
                else:
                    logger.warning(f"Unknown message type from {client_id}: {message_type}")

            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": "Invalid JSON message format"
                }))
            except Exception as e:
                logger.error(f"Error handling message from {client_id}: {str(e)}")
                break

    except WebSocketDisconnect:
        logger.info(f"Client {client_id} disconnected normally")
    except Exception as e:
        logger.error(f"WebSocket error for {client_id}: {str(e)}")
    finally:
        if client_id:
            await tts_service.disconnect_websocket(client_id)


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # HuggingFace Spaces standard port

    logger.info(f"🔊 Starting {__service__} v{__version__} on port {port}")

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info"
    )