Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Standalone WebSocket-only TTS Service | |
| Simplified service without Gradio, MCP, or web interfaces | |
| Following unmute.sh WebRTC pattern for HuggingFace Spaces | |
| """ | |
| import asyncio | |
| import json | |
| import uuid | |
| import base64 | |
| import tempfile | |
| import os | |
| import logging | |
| import time | |
| from datetime import datetime | |
| from typing import Optional, Dict, Any | |
| import torch | |
| from transformers import AutoProcessor, BarkModel | |
| import soundfile as sf | |
| import numpy as np | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import spaces | |
| import uvicorn | |
# Service identity, used in log lines and in response payloads
__version__ = "1.0.0"
__service__ = "TTS WebSocket Service"

# Root logging at INFO; this module logs under its own name
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TTSWebSocketService:
    """Standalone Bark text-to-speech service with a WebSocket-only interface.

    One instance holds a single Bark processor/model pair, tracks accepted
    WebSocket connections by a generated client id, and answers synthesis
    requests with base64-encoded WAV audio inside JSON text frames.
    """

    # Voice used when a request omits the preset or names an unknown one.
    DEFAULT_VOICE = "v2/en_speaker_6"

    def __init__(self):
        """Create empty model state, pick the torch device and list voices."""
        self.model = None
        self.processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.active_connections: Dict[str, "WebSocket"] = {}
        self.available_voices = [f"v2/en_speaker_{i}" for i in range(10)]
        # Created on first use, inside the running event loop.
        self._synthesis_lock: Optional[asyncio.Lock] = None
        logger.info(f"{__service__} v{__version__} initializing...")
        logger.info(f"Device: {self.device}")
        logger.info(f"Available voices: {len(self.available_voices)}")

    # ------------------------------------------------------------------
    # Model loading and synthesis
    # ------------------------------------------------------------------

    async def load_model(self):
        """Load the Bark processor and model once; later calls are no-ops.

        The blocking download/initialisation runs in a worker thread so the
        event loop stays responsive while it happens.
        """
        if self.model is not None:
            return
        logger.info("Loading Bark TTS model...")
        await asyncio.to_thread(self._load_model_blocking)
        logger.info(f"Bark model loaded on {self.device}")

    def _load_model_blocking(self):
        """Blocking part of `load_model`; publishes the model only when ready."""
        processor = AutoProcessor.from_pretrained("suno/bark")
        model = BarkModel.from_pretrained("suno/bark")
        if self.device == "cuda":
            model = model.to(self.device)
        self.processor = processor
        # Assigned last so `self.model is not None` implies "fully loaded".
        self.model = model

    def _get_synthesis_lock(self) -> asyncio.Lock:
        """Return the lock that serializes synthesis, creating it on first use."""
        lock = getattr(self, "_synthesis_lock", None)
        if lock is None:
            lock = self._synthesis_lock = asyncio.Lock()
        return lock

    def _render_wav(self, text: str, voice_preset: str, sample_rate: int) -> str:
        """Run Bark on `text` and write the audio to a temp WAV (blocking).

        Returns the path of the written file; the caller owns and deletes it.
        """
        # NOTE(review): `spaces` is imported by this file and the docstrings
        # mention ZeroGPU, but nothing visible here uses it - confirm whether
        # generation is expected to run under a `spaces.GPU` wrapper.
        inputs = self.processor(
            text,
            voice_preset=voice_preset,
            return_tensors="pt"
        )
        if self.device == "cuda":
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            audio_array = self.model.generate(**inputs)
        audio_array = audio_array.cpu().numpy().squeeze()

        fd, temp_path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)  # soundfile reopens the file by path
        try:
            sf.write(temp_path, audio_array, sample_rate)
        except Exception:
            # Do not leave a half-written file behind when encoding fails.
            os.unlink(temp_path)
            raise
        return temp_path

    async def synthesize_speech(
        self,
        text: str,
        voice_preset: str = "v2/en_speaker_6",
        sample_rate: int = 24000
    ) -> tuple[Optional[str], str, Dict[str, Any]]:
        """Synthesize `text` with Bark and write it to a temporary WAV file.

        Returns:
            `(path, "success", timing_info)` on success, where `path` is a
            temp file the caller must delete, or `(None, "error",
            {"error": message})` on failure. Exceptions are logged and
            reported through the return value, never raised.
        """
        try:
            if not text.strip():
                return None, "error", {"error": "Empty text provided"}

            # One generation at a time: the blocking work now runs in a
            # worker thread, so requests must be serialized explicitly.
            async with self._get_synthesis_lock():
                start_time = time.time()
                if self.model is None:
                    await self.load_model()
                logger.info(f"Synthesizing: '{text[:50]}...' with {voice_preset}")
                temp_path = await asyncio.to_thread(
                    self._render_wav, text, voice_preset, sample_rate
                )
                processing_time = time.time() - start_time

            timing_info = {
                "processing_time": processing_time,
                "start_time": datetime.fromtimestamp(start_time).isoformat(),
                "end_time": datetime.now().isoformat(),
                "voice_preset": voice_preset,
                "device": self.device,
                "text_length": len(text),
                "sample_rate": sample_rate
            }
            logger.info(f"Speech synthesis completed in {processing_time:.2f}s")
            return temp_path, "success", timing_info
        except Exception as e:
            logger.error(f"TTS synthesis error: {str(e)}")
            return None, "error", {"error": str(e)}

    # ------------------------------------------------------------------
    # Connection bookkeeping
    # ------------------------------------------------------------------

    async def connect_websocket(self, websocket: "WebSocket") -> str:
        """Accept the socket, register it and send the confirmation frame.

        Returns the generated client id under which the socket is stored.
        """
        client_id = str(uuid.uuid4())
        await websocket.accept()
        self.active_connections[client_id] = websocket
        await self._send_json(websocket, {
            "type": "tts_connection_confirmed",
            "client_id": client_id,
            "service": __service__,
            "version": __version__,
            "available_voices": self.available_voices,
            "device": self.device,
            "message": "TTS WebSocket connected and ready"
        })
        logger.info(f"Client {client_id} connected")
        return client_id

    async def disconnect_websocket(self, client_id: str):
        """Forget a connection; unknown client ids are ignored."""
        if self.active_connections.pop(client_id, None) is not None:
            logger.info(f"Client {client_id} disconnected")

    # ------------------------------------------------------------------
    # Small helpers shared by both request handlers
    # ------------------------------------------------------------------

    @staticmethod
    async def _send_json(websocket, payload: Dict[str, Any]):
        """Serialize `payload` to JSON and send it as one text frame."""
        await websocket.send_text(json.dumps(payload))

    def _resolve_voice(self, voice_preset) -> str:
        """Return `voice_preset` if it is a known voice, else the default."""
        if voice_preset in self.available_voices:
            return voice_preset
        return self.DEFAULT_VOICE

    @staticmethod
    def _normalize_chunks(chunks) -> list:
        """Coerce the `text_chunks` field into a list.

        A missing/null value becomes `[]`; a bare string becomes a
        one-element list so it is not joined character by character.
        """
        if chunks is None:
            return []
        if isinstance(chunks, str):
            return [chunks]
        return list(chunks)

    @staticmethod
    def _consume_audio_file(audio_path: str) -> bytes:
        """Read the generated audio file and delete it, even if reading fails."""
        try:
            with open(audio_path, 'rb') as audio_file:
                return audio_file.read()
        finally:
            if os.path.exists(audio_path):
                os.unlink(audio_path)

    # ------------------------------------------------------------------
    # Request handlers
    # ------------------------------------------------------------------

    async def process_tts_message(self, client_id: str, message: Dict[str, Any]):
        """Handle a one-shot synthesis request and reply on the client's socket.

        Reads `text` and optional `voice_preset` from `message`; replies with
        `tts_synthesis_complete` (base64 WAV) or `tts_synthesis_error`.
        """
        try:
            websocket = self.active_connections[client_id]

            raw_text = message.get("text", "")
            # A non-string "text" is treated like missing text.
            text = raw_text.strip() if isinstance(raw_text, str) else ""
            if not text:
                await self._send_json(websocket, {
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": "No text provided for synthesis"
                })
                return

            voice_preset = self._resolve_voice(message.get("voice_preset"))
            audio_path, status, timing = await self.synthesize_speech(
                text,
                voice_preset
            )

            if status == "success" and audio_path:
                audio_data = self._consume_audio_file(audio_path)
                await self._send_json(websocket, {
                    "type": "tts_synthesis_complete",
                    "client_id": client_id,
                    "audio_data": base64.b64encode(audio_data).decode('utf-8'),
                    "audio_format": "wav",
                    "text": text,
                    "voice_preset": voice_preset,
                    "audio_size": len(audio_data),
                    "timing": timing,
                    "status": "success"
                })
                logger.info(f"TTS synthesis sent to {client_id} ({len(audio_data)} bytes)")
            else:
                await self._send_json(websocket, {
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": "Speech synthesis failed",
                    "timing": timing
                })
        except Exception as e:
            logger.error(f"Error processing TTS for {client_id}: {str(e)}")
            websocket = self.active_connections.get(client_id)
            if websocket is not None:
                await self._send_json(websocket, {
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": f"Processing error: {str(e)}"
                })

    async def process_streaming_tts_message(self, client_id: str, message: Dict[str, Any]):
        """Handle a streaming request: acknowledge partial text, synthesize on final.

        While `is_final` is false (or no chunks were sent) only a
        `tts_streaming_progress` frame is returned. On the final message all
        chunks are joined with spaces and synthesized in one pass.
        """
        try:
            websocket = self.active_connections[client_id]

            text_chunks = self._normalize_chunks(message.get("text_chunks"))
            # Same fallback rule as the non-streaming handler.
            voice_preset = self._resolve_voice(message.get("voice_preset"))
            is_final = message.get("is_final", True)

            if not (is_final and text_chunks):
                # Partial progress update (no audio yet)
                await self._send_json(websocket, {
                    "type": "tts_streaming_progress",
                    "client_id": client_id,
                    "text_chunks": text_chunks,
                    "is_final": is_final,
                    "message": f"Accumulating text chunks: {len(text_chunks)}"
                })
                return

            # Flush: process all accumulated text at once
            complete_text = " ".join(text_chunks).strip()
            logger.info(f"TTS STREAMING: Final synthesis for {client_id}: '{complete_text[:50]}...'")
            audio_path, status, timing = await self.synthesize_speech(
                complete_text,
                voice_preset
            )

            if status == "success" and audio_path:
                audio_data = self._consume_audio_file(audio_path)
                await self._send_json(websocket, {
                    "type": "tts_streaming_response",
                    "client_id": client_id,
                    "audio_data": base64.b64encode(audio_data).decode('utf-8'),
                    "audio_format": "wav",
                    "text": complete_text,
                    "text_chunks": text_chunks,
                    "voice_preset": voice_preset,
                    "audio_size": len(audio_data),
                    "timing": timing,
                    "is_final": is_final,
                    "streaming_method": "unmute.sh_flush_trick",
                    "status": "success"
                })
                logger.info(f"TTS STREAMING: Final audio sent to {client_id} ({len(audio_data)} bytes)")
            else:
                failure = f"TTS streaming synthesis failed: {status}"
                await self._send_json(websocket, {
                    "type": "tts_streaming_error",
                    "client_id": client_id,
                    "message": failure,
                    # Mirrors "message" so every error frame carries "error".
                    "error": failure,
                    "text": complete_text,
                    "is_final": is_final
                })
        except Exception as e:
            logger.error(f"Error processing streaming TTS for {client_id}: {str(e)}")
            websocket = self.active_connections.get(client_id)
            if websocket is not None:
                await self._send_json(websocket, {
                    "type": "tts_streaming_error",
                    "client_id": client_id,
                    "error": f"Streaming processing error: {str(e)}"
                })
# Single shared service instance used by every endpoint below
tts_service = TTSWebSocketService()

# FastAPI application hosting the health and WebSocket endpoints
app = FastAPI(
    title="TTS WebSocket Service",
    description="Standalone WebSocket-only Text-to-Speech service",
    version=__version__
)

# CORS for the HTTP endpoints. Credentials are disabled because a wildcard
# origin must not be combined with allow_credentials=True; nothing visible
# in this file uses cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Registered as a startup handler; without registration the function is
# never called and the model is not pre-loaded.
@app.on_event("startup")
async def startup_event():
    """Pre-load the Bark model before the app starts serving requests."""
    logger.info(f"{__service__} v{__version__} starting...")
    logger.info("Pre-loading Bark TTS model for optimal performance...")
    await tts_service.load_model()
    logger.info("Service ready for WebSocket connections")
@app.get("/")
async def root():
    """Return a service summary: identity, endpoints, voices and device."""
    return {
        "service": __service__,
        "version": __version__,
        "status": "ready",
        "endpoints": {
            "websocket": "/ws/tts",
            "health": "/health"
        },
        "available_voices": tts_service.available_voices,
        "device": tts_service.device
    }
@app.get("/health")
async def health_check():
    """Return detailed health data: model state, connection count, timestamp."""
    return {
        "service": __service__,
        "version": __version__,
        "status": "healthy",
        "model_loaded": tts_service.model is not None,
        "active_connections": len(tts_service.active_connections),
        "available_voices": len(tts_service.available_voices),
        "device": tts_service.device,
        "timestamp": datetime.now().isoformat()
    }
@app.websocket("/ws/tts")
async def websocket_tts_endpoint(websocket: WebSocket):
    """Main TTS WebSocket endpoint.

    Accepts the connection, then dispatches each JSON text frame by its
    "type": "tts_synthesize", "tts_streaming_text" or "ping". Unknown types
    are logged and ignored. The connection is always deregistered on exit.
    """
    client_id = None
    try:
        client_id = await tts_service.connect_websocket(websocket)

        while True:
            try:
                data = await websocket.receive_text()
                message = json.loads(data)

                # Valid JSON that is not an object (list, number, ...) has no
                # "type" to dispatch on; report it instead of dropping the
                # connection on an AttributeError.
                if not isinstance(message, dict):
                    await websocket.send_text(json.dumps({
                        "type": "tts_synthesis_error",
                        "client_id": client_id,
                        "error": "Message must be a JSON object"
                    }))
                    continue

                message_type = message.get("type", "unknown")
                if message_type == "tts_synthesize":
                    await tts_service.process_tts_message(client_id, message)
                elif message_type == "tts_streaming_text":
                    await tts_service.process_streaming_tts_message(client_id, message)
                elif message_type == "ping":
                    await websocket.send_text(json.dumps({
                        "type": "pong",
                        "client_id": client_id,
                        "timestamp": datetime.now().isoformat()
                    }))
                else:
                    logger.warning(f"Unknown message type from {client_id}: {message_type}")
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": "Invalid JSON message format"
                }))
            except Exception as e:
                # Any other per-message failure ends the session.
                logger.error(f"Error handling message from {client_id}: {str(e)}")
                break
    except WebSocketDisconnect:
        logger.info(f"Client {client_id} disconnected normally")
    except Exception as e:
        logger.error(f"WebSocket error for {client_id}: {str(e)}")
    finally:
        if client_id:
            await tts_service.disconnect_websocket(client_id)
if __name__ == "__main__":
    # PORT may be overridden by the environment; 7860 is the HuggingFace
    # Spaces standard port.
    port = int(os.environ.get("PORT", 7860))
    logger.info(f"Starting {__service__} v{__version__} on port {port}")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info"
    )