Spaces:
Sleeping
Sleeping
File size: 18,374 Bytes
06a06a0 d1c4aa1 06a06a0 d1c4aa1 06a06a0 d1c4aa1 06a06a0 d1c4aa1 4c06331 d1c4aa1 4c06331 d1c4aa1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 |
"""
Main TTS service class with all API endpoints.
"""
import io
import base64
import warnings
from typing import Optional
import modal
from fastapi.responses import StreamingResponse, Response
from fastapi import HTTPException, File, UploadFile, Form
from .config import app, image
from .models import TTSRequest, TTSResponse, HealthResponse, FullTextTTSRequest, FullTextTTSResponse
from .audio_utils import AudioUtils
from .text_processing import TextChunker
from .audio_concatenator import AudioConcatenator
with image.imports():
from chatterbox.tts import ChatterboxTTS
import torch # Add torch import here
# Suppress specific transformers deprecation warnings
warnings.filterwarnings("ignore", message=".*past_key_values.*", category=FutureWarning)
@app.cls(
    gpu="a10g",
    scaledown_window=60 * 5,
    enable_memory_snapshot=True
)
@modal.concurrent(
    max_inputs=10
)
class ChatterboxTTSService:
    """
    Advanced text-to-speech service using the Chatterbox TTS model.

    Provides multiple endpoints for different use cases including
    voice cloning, file uploads, JSON responses, and server-side
    chunked processing of arbitrarily long texts.
    """

    @modal.enter()
    def load(self):
        """Load the Chatterbox TTS model on container startup."""
        print("Loading Chatterbox TTS model...")
        # Suppress transformers deprecation warnings emitted during generation.
        warnings.filterwarnings("ignore", message=".*past_key_values.*", category=FutureWarning)
        warnings.filterwarnings("ignore", message=".*tuple of tuples.*", category=FutureWarning)
        self.model = ChatterboxTTS.from_pretrained(device="cuda")
        print(f"Model loaded successfully! Sample rate: {self.model.sr}")

    def _validate_text_input(self, text: str) -> None:
        """Raise HTTP 400 if the input text is missing or whitespace-only."""
        if not text or not text.strip():
            raise HTTPException(status_code=400, detail="Text cannot be empty")

    def _process_voice_prompt(self, voice_prompt_base64: Optional[str]) -> Optional[str]:
        """Decode a base64-encoded voice prompt and persist it to a temp file.

        Args:
            voice_prompt_base64: Base64 audio payload, or None/empty for no prompt.

        Returns:
            Path to the saved temp audio file, or None when no prompt was given.

        Raises:
            HTTPException: 400 if the payload cannot be decoded/saved.
        """
        if not voice_prompt_base64:
            return None
        try:
            audio_data = base64.b64decode(voice_prompt_base64)
            return AudioUtils.save_temp_audio_file(audio_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid voice prompt audio: {str(e)}")

    def _generate_audio(self, text: str, audio_prompt_path: Optional[str] = None,
                        cleanup: bool = True):
        """Generate audio for ``text``, optionally conditioned on a voice prompt.

        Args:
            text: Text to synthesize.
            audio_prompt_path: Optional path to a voice-prompt audio file.
            cleanup: When True (the default, matching the original behavior),
                the temp prompt file is deleted as soon as generation finishes
                or fails. Pass False when the same prompt file is shared by
                several concurrent calls and the caller owns the cleanup.

        Returns:
            The waveform object produced by ``self.model.generate``.
        """
        print(f"Generating audio for text: {text[:50]}...")
        try:
            if audio_prompt_path:
                wav = self.model.generate(text, audio_prompt_path=audio_prompt_path)
                if cleanup:
                    AudioUtils.cleanup_temp_file(audio_prompt_path)
            else:
                wav = self.model.generate(text)
            return wav
        except Exception:
            if audio_prompt_path and cleanup:
                AudioUtils.cleanup_temp_file(audio_prompt_path)
            raise  # bare raise preserves the original traceback

    @modal.fastapi_endpoint(docs=True, method="GET")
    def health(self) -> HealthResponse:
        """Health check endpoint to verify model status."""
        return HealthResponse(
            status="healthy",
            # The model attribute only exists after @modal.enter() has run.
            model_loaded=hasattr(self, 'model') and self.model is not None
        )

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_audio(self, request: TTSRequest) -> StreamingResponse:
        """
        Generate speech audio from text with optional voice prompt.

        Args:
            request: TTSRequest containing text and optional voice prompt.

        Returns:
            StreamingResponse with generated audio as WAV file.

        Raises:
            HTTPException: 400 for invalid input, 500 on generation failure.
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)
            # Generate audio (temp prompt file is cleaned up by the helper).
            wav = self._generate_audio(request.text, audio_prompt_path)
            # Serialize the waveform to an in-memory WAV buffer.
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
            return StreamingResponse(
                io.BytesIO(buffer.read()),
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=generated_speech.wav",
                    # assumes wav[0] is the sample sequence — TODO confirm shape contract
                    "X-Audio-Duration": str(len(wav[0]) / self.model.sr)
                }
            )
        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_with_file(
        self,
        text: str = Form(..., description="Text to convert to speech"),
        voice_prompt: Optional[UploadFile] = File(None, description="Optional voice prompt audio file")
    ) -> StreamingResponse:
        """
        Generate speech audio from text with optional voice prompt file upload.

        Args:
            text: Text to convert to speech.
            voice_prompt: Optional audio file (WAV/MP3/MPEG) for voice cloning.

        Returns:
            StreamingResponse with generated audio as WAV file.

        Raises:
            HTTPException: 400 for invalid input/file type, 500 on failure.
        """
        try:
            self._validate_text_input(text)
            # Handle voice prompt file if provided.
            audio_prompt_path = None
            if voice_prompt:
                if voice_prompt.content_type not in ["audio/wav", "audio/mpeg", "audio/mp3"]:
                    raise HTTPException(
                        status_code=400,
                        detail="Voice prompt must be WAV, MP3, or MPEG audio file"
                    )
                # Read and save the uploaded file to a temp path for the model.
                audio_data = voice_prompt.file.read()
                audio_prompt_path = AudioUtils.save_temp_audio_file(audio_data)
            # Generate audio (temp prompt file is cleaned up by the helper).
            wav = self._generate_audio(text, audio_prompt_path)
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
            return StreamingResponse(
                io.BytesIO(buffer.read()),
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=generated_speech.wav",
                    "X-Audio-Duration": str(len(wav[0]) / self.model.sr)
                }
            )
        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_json(self, request: TTSRequest) -> TTSResponse:
        """
        Generate speech audio and return as JSON with base64 encoded audio.

        Unlike the streaming endpoints, errors are reported in-band via
        ``success=False`` rather than as HTTP error responses.

        Args:
            request: TTSRequest containing text and optional voice prompt.

        Returns:
            TTSResponse with base64 encoded audio data (or an error message).
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)
            # Generate audio (temp prompt file is cleaned up by the helper).
            wav = self._generate_audio(request.text, audio_prompt_path)
            # Encode the WAV bytes for JSON transport.
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
            audio_base64 = base64.b64encode(buffer.read()).decode('utf-8')
            duration = len(wav[0]) / self.model.sr
            return TTSResponse(
                success=True,
                message="Audio generated successfully",
                audio_base64=audio_base64,
                duration_seconds=duration
            )
        except HTTPException as http_exc:
            return TTSResponse(success=False, message=str(http_exc.detail))
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            return TTSResponse(success=False, message=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate(self, prompt: str):
        """
        Legacy endpoint for backward compatibility.

        Generate audio waveform from the input text (no voice prompt,
        no extra headers).
        """
        try:
            # Generate audio waveform from the input text.
            wav = self.model.generate(prompt)
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
            # Return the audio as a streaming response with appropriate MIME type.
            return StreamingResponse(
                io.BytesIO(buffer.read()),
                media_type="audio/wav",
            )
        except Exception as e:
            print(f"Error in legacy endpoint: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_audio_file(self, request: TTSRequest) -> Response:
        """
        Generate speech audio and return it as a complete file.

        Unlike the streaming endpoint, this returns the entire file at once.

        Args:
            request: TTSRequest containing text and optional voice prompt.

        Returns:
            Response with complete audio file data.

        Raises:
            HTTPException: 400 for invalid input, 500 on generation failure.
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)
            # Generate audio (temp prompt file is cleaned up by the helper).
            wav = self._generate_audio(request.text, audio_prompt_path)
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
            audio_data = buffer.read()
            duration = len(wav[0]) / self.model.sr
            # Return the complete audio file in one response body.
            return Response(
                content=audio_data,
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=generated_speech.wav",
                    "X-Audio-Duration": str(duration)
                }
            )
        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_full_text_audio(self, request: FullTextTTSRequest) -> StreamingResponse:
        """
        Generate speech audio from full text with server-side chunking and
        parallel processing.

        This endpoint handles texts of any length by:
        1. Chunking the text intelligently (respecting sentence/paragraph boundaries)
        2. Processing chunks in parallel using GPU resources
        3. Concatenating audio chunks with proper transitions
        4. Returning the final audio file

        Args:
            request: FullTextTTSRequest containing text and processing parameters.

        Returns:
            StreamingResponse with final concatenated audio as WAV file.

        Raises:
            HTTPException: 400 for invalid input, 500 on any processing failure.
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)
            print(f"Processing full text ({len(request.text)} chars) with server-side chunking...")

            # Split the text according to the request's chunking parameters.
            chunker = TextChunker(
                max_chunk_size=request.max_chunk_size,
                overlap_sentences=request.overlap_sentences
            )
            text_chunks = chunker.chunk_text(request.text)
            chunk_info = chunker.get_chunk_info(text_chunks)
            print(f"Split text into {len(text_chunks)} chunks for processing")

            if len(text_chunks) == 1:
                # Single chunk: generate directly; _generate_audio owns the
                # temp prompt file cleanup in this path.
                final_audio = self._generate_audio(text_chunks[0], audio_prompt_path)
                audio_chunks = [final_audio]  # kept for the response headers
            else:
                import concurrent.futures

                def process_chunk(chunk_text: str):
                    # cleanup=False: every chunk shares the same temp prompt
                    # file. BUG FIX — previously the first chunk to finish
                    # deleted the file, breaking all chunks still running.
                    return self._generate_audio(chunk_text, audio_prompt_path, cleanup=False)

                try:
                    # Fan the chunks out to worker threads; the GIL is released
                    # during the model's native generation work.
                    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                        future_to_chunk = {
                            executor.submit(process_chunk, chunk): i
                            for i, chunk in enumerate(text_chunks)
                        }
                        # Collect results back into original chunk order.
                        results = [None] * len(text_chunks)
                        for future in concurrent.futures.as_completed(future_to_chunk):
                            chunk_index = future_to_chunk[future]
                            try:
                                results[chunk_index] = future.result()
                            except Exception as exc:
                                print(f'Chunk {chunk_index} generated an exception: {exc}')
                                raise HTTPException(status_code=500, detail=f"Failed to process chunk {chunk_index}: {str(exc)}")
                finally:
                    # The shared prompt file is removed exactly once, after all
                    # chunk work has finished (successfully or not).
                    if audio_prompt_path:
                        AudioUtils.cleanup_temp_file(audio_prompt_path)

                audio_chunks = [result for result in results if result is not None]
                if len(audio_chunks) != len(text_chunks):
                    raise HTTPException(status_code=500, detail=f"Only {len(audio_chunks)} out of {len(text_chunks)} chunks processed successfully")

                # Stitch the per-chunk waveforms together with configured
                # silence gaps and fades.
                print("Concatenating audio chunks...")
                concatenator = AudioConcatenator(
                    silence_duration=request.silence_duration,
                    fade_duration=request.fade_duration
                )
                final_audio = concatenator.concatenate_audio_chunks(audio_chunks, self.model.sr)

            # --- Normalize the audio into a (1, L) tensor for saving ---
            import torch
            import numpy as np

            processed_tensor = final_audio
            # Unwrap nested single-element tuples (model/concatenator output
            # packaging varies between code paths).
            while isinstance(processed_tensor, tuple) and len(processed_tensor) == 1:
                processed_tensor = processed_tensor[0]
            # Convert NumPy output to a float32 torch tensor.
            if isinstance(processed_tensor, np.ndarray):
                processed_tensor = torch.from_numpy(processed_tensor.astype(np.float32))
            if not isinstance(processed_tensor, torch.Tensor):
                raise TypeError(f"Audio data after concatenation is not a tensor. Got type: {type(processed_tensor)}")

            # Ensure shape (C, L) with exactly one channel.
            if processed_tensor.ndim == 1:  # (L,) -> (1, L)
                audio_to_save = processed_tensor.unsqueeze(0)
            elif processed_tensor.ndim == 2:  # (C, L)
                if processed_tensor.shape[0] == 0:
                    raise ValueError(f"Audio tensor has 0 channels: {processed_tensor.shape}")
                if processed_tensor.shape[0] > 1:  # stereo/multi-channel
                    print(f"Multi-channel audio (shape {processed_tensor.shape}) detected. Taking the first channel.")
                    audio_to_save = processed_tensor[0, :].unsqueeze(0)  # -> (1, L)
                else:  # already (1, L)
                    audio_to_save = processed_tensor
            else:
                raise ValueError(f"Unexpected audio tensor dimensions: {processed_tensor.ndim}, shape: {processed_tensor.shape}")

            buffer = AudioUtils.save_audio_to_buffer(audio_to_save, self.model.sr)
            duration = audio_to_save.shape[1] / self.model.sr  # samples / rate
            # Reset buffer position before streaming it out.
            buffer.seek(0)

            print(f"Full text processing complete! Final audio duration: {duration:.2f} seconds")
            return StreamingResponse(
                io.BytesIO(buffer.read()),
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=generated_full_text_speech.wav",
                    "X-Audio-Duration": str(duration),
                    "X-Chunks-Processed": str(len(audio_chunks)),
                    "X-Total-Characters": str(len(request.text))
                }
            )
        except HTTPException as http_exc:
            print(f"HTTP exception in full text generation: {http_exc.detail}")
            raise http_exc
        except Exception as e:
            error_msg = f"Full text audio generation failed: {str(e)}"
            print(f"Exception in full text generation: {error_msg}")
            raise HTTPException(status_code=500, detail=error_msg)
|