tts-gpu-service / websocket_tts_server.py
Peter Michael Gits
feat: Add standalone WebSocket-only TTS service v1.0.0
390e1c5
#!/usr/bin/env python3
"""
Standalone WebSocket-only TTS Service
Simplified service without Gradio, MCP, or web interfaces
Following unmute.sh WebRTC pattern for HuggingFace Spaces
"""
import asyncio
import json
import uuid
import base64
import tempfile
import os
import logging
import time
from datetime import datetime
from typing import Optional, Dict, Any
import torch
from transformers import AutoProcessor, BarkModel
import soundfile as sf
import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import spaces
import uvicorn
# Version info reported by the service banner, HTTP routes and WebSocket greeting.
__version__ = "1.0.0"
__service__ = "TTS WebSocket Service"

# Configure root logging once at import time and keep a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TTSWebSocketService:
    """Standalone TTS service with a WebSocket-only interface.

    Holds the lazily loaded Bark processor/model, the table of connected
    WebSocket clients keyed by a generated client id, and the handlers for the
    ``tts_synthesize`` and ``tts_streaming_text`` request types.
    """

    # Voice used when a request omits ``voice_preset`` or names an unknown one.
    DEFAULT_VOICE = "v2/en_speaker_6"

    def __init__(self):
        self.model = None
        self.processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # client_id -> accepted WebSocket; maintained by connect/disconnect_websocket.
        self.active_connections: Dict[str, WebSocket] = {}
        # Whitelist of Bark presets accepted from clients: v2/en_speaker_0 .. _9.
        self.available_voices = [f"v2/en_speaker_{i}" for i in range(10)]
        logger.info(f"🔊 {__service__} v{__version__} initializing...")
        logger.info(f"Device: {self.device}")
        logger.info(f"Available voices: {len(self.available_voices)}")

    async def load_model(self):
        """Load the Bark processor and model once; later calls are no-ops.

        The loading calls are synchronous (nothing is awaited here), so the
        event loop is blocked while the weights load.
        """
        if self.model is None:
            logger.info("Loading Bark TTS model...")
            self.processor = AutoProcessor.from_pretrained("suno/bark")
            self.model = BarkModel.from_pretrained("suno/bark")
            if self.device == "cuda":
                self.model = self.model.to(self.device)
            logger.info(f"✅ Bark model loaded on {self.device}")

    def _resolve_voice_preset(self, voice_preset: Any) -> str:
        """Return *voice_preset* if it is a whitelisted voice, else DEFAULT_VOICE.

        The preset is handed straight to the Bark processor, so only names from
        ``available_voices`` are allowed through.
        """
        if isinstance(voice_preset, str) and voice_preset in self.available_voices:
            return voice_preset
        return self.DEFAULT_VOICE

    @staticmethod
    def _remove_file(path: Optional[str]) -> None:
        """Best-effort delete of a temporary file; a missing file is ignored."""
        if path and os.path.exists(path):
            try:
                os.unlink(path)
            except OSError:
                logger.warning(f"Could not remove temporary file {path}")

    def _consume_audio_file(self, audio_path: str) -> bytes:
        """Read the WAV file at *audio_path* and delete it, even if reading fails."""
        try:
            with open(audio_path, "rb") as audio_file:
                return audio_file.read()
        finally:
            self._remove_file(audio_path)

    @spaces.GPU(duration=30)
    async def synthesize_speech(
        self,
        text: str,
        voice_preset: str = DEFAULT_VOICE,
        sample_rate: int = 24000,
    ) -> tuple[Optional[str], str, Dict[str, Any]]:
        """Synthesize *text* with Bark and write the audio to a temporary WAV file.

        Args:
            text: Text to speak; whitespace-only text is rejected.
            voice_preset: Bark voice preset name passed to the processor.
            sample_rate: Sample rate written into the WAV header.

        Returns:
            ``(wav_path, "success", timing_info)`` on success; the caller owns
            ``wav_path`` and must delete it. ``(None, "error", {"error": msg})``
            on any failure -- exceptions are caught, logged and not re-raised.
        """
        temp_path: Optional[str] = None
        try:
            if not text.strip():
                return None, "error", {"error": "Empty text provided"}
            start_time = time.time()

            # Lazily load in case the startup pre-load has not run.
            if self.model is None:
                await self.load_model()

            logger.info(f"Synthesizing: '{text[:50]}...' with {voice_preset}")
            inputs = self.processor(
                text,
                voice_preset=voice_preset,
                return_tensors="pt",
            )
            if self.device == "cuda":
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            # NOTE(review): generate() is synchronous and runs on the event-loop
            # thread, so other clients wait while a synthesis is in progress.
            with torch.no_grad():
                audio_array = self.model.generate(**inputs)
            audio_array = audio_array.cpu().numpy().squeeze()

            # delete=False: the file must outlive this function; the path is returned.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
                temp_path = tmp_file.name
                sf.write(temp_path, audio_array, sample_rate)

            processing_time = time.time() - start_time
            timing_info = {
                "processing_time": processing_time,
                "start_time": datetime.fromtimestamp(start_time).isoformat(),
                "end_time": datetime.now().isoformat(),
                "voice_preset": voice_preset,
                "device": self.device,
                "text_length": len(text),
                "sample_rate": sample_rate,
            }
            logger.info(f"Speech synthesis completed in {processing_time:.2f}s")
            return temp_path, "success", timing_info
        except Exception as e:
            logger.exception(f"TTS synthesis error: {str(e)}")
            # The temp file was created with delete=False; don't leak it on failure.
            self._remove_file(temp_path)
            return None, "error", {"error": str(e)}

    async def connect_websocket(self, websocket: WebSocket) -> str:
        """Accept *websocket*, register it under a new client id, and greet it.

        Returns:
            The generated client id (a random UUID4 string).
        """
        client_id = str(uuid.uuid4())
        await websocket.accept()
        self.active_connections[client_id] = websocket
        await websocket.send_text(json.dumps({
            "type": "tts_connection_confirmed",
            "client_id": client_id,
            "service": __service__,
            "version": __version__,
            "available_voices": self.available_voices,
            "device": self.device,
            "message": "TTS WebSocket connected and ready",
        }))
        logger.info(f"Client {client_id} connected")
        return client_id

    async def disconnect_websocket(self, client_id: str):
        """Forget *client_id*; unknown ids are ignored."""
        if self.active_connections.pop(client_id, None) is not None:
            logger.info(f"Client {client_id} disconnected")

    async def process_tts_message(self, client_id: str, message: Dict[str, Any]):
        """Handle a ``tts_synthesize`` request and reply on the client's socket.

        Replies with ``tts_synthesis_complete`` (base64 WAV in ``audio_data``)
        on success, or ``tts_synthesis_error`` otherwise. An unknown
        ``voice_preset`` falls back to DEFAULT_VOICE.
        """
        try:
            websocket = self.active_connections[client_id]

            raw_text = message.get("text", "")
            # Non-string "text" (null, numbers, objects) is treated as missing.
            text = raw_text.strip() if isinstance(raw_text, str) else ""
            if not text:
                await websocket.send_text(json.dumps({
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": "No text provided for synthesis",
                }))
                return

            voice_preset = self._resolve_voice_preset(
                message.get("voice_preset", self.DEFAULT_VOICE)
            )
            audio_path, status, timing = await self.synthesize_speech(text, voice_preset)

            if status != "success" or not audio_path:
                await websocket.send_text(json.dumps({
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": "Speech synthesis failed",
                    "timing": timing,
                }))
                return

            audio_data = self._consume_audio_file(audio_path)
            await websocket.send_text(json.dumps({
                "type": "tts_synthesis_complete",
                "client_id": client_id,
                "audio_data": base64.b64encode(audio_data).decode("utf-8"),
                "audio_format": "wav",
                "text": text,
                "voice_preset": voice_preset,
                "audio_size": len(audio_data),
                "timing": timing,
                "status": "success",
            }))
            logger.info(f"TTS synthesis sent to {client_id} ({len(audio_data)} bytes)")
        except Exception as e:
            logger.exception(f"Error processing TTS for {client_id}: {str(e)}")
            if client_id in self.active_connections:
                websocket = self.active_connections[client_id]
                await websocket.send_text(json.dumps({
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": f"Processing error: {str(e)}",
                }))

    async def process_streaming_tts_message(self, client_id: str, message: Dict[str, Any]):
        """Handle a ``tts_streaming_text`` request (unmute.sh "flush trick").

        No text is buffered server-side: each message carries its own
        ``text_chunks``. While ``is_final`` is false (or no chunks were sent)
        only a ``tts_streaming_progress`` echo is returned. When ``is_final`` is
        true the chunks are joined with spaces and synthesized in one pass,
        replying ``tts_streaming_response`` or ``tts_streaming_error``.
        """
        try:
            websocket = self.active_connections[client_id]

            text_chunks = message.get("text_chunks", [])
            # A bare string would otherwise be joined character by character.
            if isinstance(text_chunks, str):
                text_chunks = [text_chunks]
            elif text_chunks is None:
                text_chunks = []
            # Same whitelist as process_tts_message; the preset reaches the processor.
            voice_preset = self._resolve_voice_preset(
                message.get("voice_preset", self.DEFAULT_VOICE)
            )
            is_final = message.get("is_final", True)

            if not (is_final and text_chunks):
                # Partial progress update (no audio yet).
                await websocket.send_text(json.dumps({
                    "type": "tts_streaming_progress",
                    "client_id": client_id,
                    "text_chunks": text_chunks,
                    "is_final": is_final,
                    "message": f"Accumulating text chunks: {len(text_chunks)}",
                }))
                return

            # UNMUTE.SH FLUSH TRICK: synthesize all chunks at once.
            complete_text = " ".join(text_chunks).strip()
            logger.info(f"🔊 TTS STREAMING: Final synthesis for {client_id}: '{complete_text[:50]}...'")
            audio_path, status, timing = await self.synthesize_speech(complete_text, voice_preset)

            if status != "success" or not audio_path:
                await websocket.send_text(json.dumps({
                    "type": "tts_streaming_error",
                    "client_id": client_id,
                    "message": f"TTS streaming synthesis failed: {status}",
                    # "error" matches the key used by every other error reply.
                    "error": timing.get("error", "Speech synthesis failed"),
                    "text": complete_text,
                    "is_final": is_final,
                }))
                return

            audio_data = self._consume_audio_file(audio_path)
            await websocket.send_text(json.dumps({
                "type": "tts_streaming_response",
                "client_id": client_id,
                "audio_data": base64.b64encode(audio_data).decode("utf-8"),
                "audio_format": "wav",
                "text": complete_text,
                "text_chunks": text_chunks,
                "voice_preset": voice_preset,
                "audio_size": len(audio_data),
                "timing": timing,
                "is_final": is_final,
                "streaming_method": "unmute.sh_flush_trick",
                "status": "success",
            }))
            logger.info(f"🔊 TTS STREAMING: Final audio sent to {client_id} ({len(audio_data)} bytes)")
        except Exception as e:
            logger.exception(f"Error processing streaming TTS for {client_id}: {str(e)}")
            if client_id in self.active_connections:
                websocket = self.active_connections[client_id]
                await websocket.send_text(json.dumps({
                    "type": "tts_streaming_error",
                    "client_id": client_id,
                    "error": f"Streaming processing error: {str(e)}",
                }))
# One shared service instance backs every HTTP route and WebSocket handler.
tts_service = TTSWebSocketService()

# FastAPI application exposing the health routes and the /ws/tts socket.
app = FastAPI(
    title="TTS WebSocket Service",
    version=__version__,
    description="Standalone WebSocket-only Text-to-Speech service",
)

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for a public deployment.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
@app.on_event("startup")
async def startup_event():
    """Pre-load the Bark model when the application starts.

    Any exception raised by ``load_model`` propagates out of this hook.
    """
    logger.info(f"🚀 {__service__} v{__version__} starting...")
    logger.info("Pre-loading Bark TTS model for optimal performance...")
    await tts_service.load_model()
    logger.info("✅ Service ready for WebSocket connections")
@app.get("/")
async def root():
    """Describe the service: identity, endpoint paths, voices and device."""
    endpoint_paths = {"websocket": "/ws/tts", "health": "/health"}
    payload = dict(
        service=__service__,
        version=__version__,
        status="ready",
        endpoints=endpoint_paths,
        available_voices=tts_service.available_voices,
        device=tts_service.device,
    )
    return payload
@app.get("/health")
async def health_check():
    """Report liveness plus model, connection, voice and device details."""
    service = tts_service
    report = {"service": __service__, "version": __version__, "status": "healthy"}
    report["model_loaded"] = service.model is not None
    report["active_connections"] = len(service.active_connections)
    report["available_voices"] = len(service.available_voices)
    report["device"] = service.device
    report["timestamp"] = datetime.now().isoformat()
    return report
@app.websocket("/ws/tts")
async def websocket_tts_endpoint(websocket: WebSocket):
    """Main TTS WebSocket endpoint.

    Accepts the connection, then reads JSON text frames in a loop and
    dispatches on ``message["type"]``: ``tts_synthesize`` and
    ``tts_streaming_text`` go to the service handlers, and ``ping`` gets a
    ``pong`` reply. Unknown types are logged and ignored. Malformed JSON, or
    JSON that is not an object, gets a ``tts_synthesis_error`` reply. The
    client is always unregistered when the loop ends.
    """
    client_id = None
    try:
        client_id = await tts_service.connect_websocket(websocket)

        while True:
            try:
                data = await websocket.receive_text()
                message = json.loads(data)
                if not isinstance(message, dict):
                    # Valid JSON that is not an object (list, number, string) has
                    # no "type"; reply instead of dropping the connection.
                    await websocket.send_text(json.dumps({
                        "type": "tts_synthesis_error",
                        "client_id": client_id,
                        "error": "Invalid JSON message format",
                    }))
                    continue

                message_type = message.get("type", "unknown")
                if message_type == "tts_synthesize":
                    await tts_service.process_tts_message(client_id, message)
                elif message_type == "tts_streaming_text":
                    await tts_service.process_streaming_tts_message(client_id, message)
                elif message_type == "ping":
                    await websocket.send_text(json.dumps({
                        "type": "pong",
                        "client_id": client_id,
                        "timestamp": datetime.now().isoformat(),
                    }))
                else:
                    logger.warning(f"Unknown message type from {client_id}: {message_type}")
            except WebSocketDisconnect:
                break
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({
                    "type": "tts_synthesis_error",
                    "client_id": client_id,
                    "error": "Invalid JSON message format",
                }))
            except Exception as e:
                # Any other failure (e.g. the socket is already closed) ends the session.
                logger.exception(f"Error handling message from {client_id}: {str(e)}")
                break
    except WebSocketDisconnect:
        logger.info(f"Client {client_id} disconnected normally")
    except Exception as e:
        logger.exception(f"WebSocket error for {client_id}: {str(e)}")
    finally:
        if client_id:
            await tts_service.disconnect_websocket(client_id)
if __name__ == "__main__":
    # 7860 is the HuggingFace Spaces standard port; an empty PORT also falls back.
    port = int(os.environ.get("PORT") or 7860)
    logger.info(f"🔊 Starting {__service__} v{__version__} on port {port}")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        log_level="info",
    )