File size: 18,374 Bytes
06a06a0
 
 
 
 
 
 
 
 
 
 
 
 
 
d1c4aa1
06a06a0
d1c4aa1
 
06a06a0
 
 
d1c4aa1
06a06a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1c4aa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c06331
d1c4aa1
4c06331
d1c4aa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
"""
Main TTS service class with all API endpoints.
"""

import io
import base64
import warnings
from typing import Optional

import modal
from fastapi.responses import StreamingResponse, Response
from fastapi import HTTPException, File, UploadFile, Form

from .config import app, image
from .models import TTSRequest, TTSResponse, HealthResponse, FullTextTTSRequest, FullTextTTSResponse
from .audio_utils import AudioUtils
from .text_processing import TextChunker
from .audio_concatenator import AudioConcatenator

# Container-only imports: deferred via Modal's image context so this module can
# be imported locally (e.g. at deploy time) without GPU-only dependencies
# installed — presumably chatterbox/torch only exist inside the image; confirm.
with image.imports():
    from chatterbox.tts import ChatterboxTTS
    import torch  # required at model-load time inside the container
    # Suppress specific transformers deprecation warnings
    warnings.filterwarnings("ignore", message=".*past_key_values.*", category=FutureWarning)


@app.cls(
    gpu="a10g", 
    scaledown_window=60 * 5, 
    enable_memory_snapshot=True
    )
@modal.concurrent(
    max_inputs=10
    )
class ChatterboxTTSService:
    """
    Advanced text-to-speech service using Chatterbox TTS model.
    
    Provides multiple endpoints for different use cases including
    voice cloning, file uploads, JSON responses, and server-side
    chunked processing of arbitrarily long texts.
    """
    
    @modal.enter()
    def load(self):
        """Load the Chatterbox TTS model once, on container startup."""
        print("Loading Chatterbox TTS model...")
        
        # Suppress transformers deprecation warnings emitted during generation.
        warnings.filterwarnings("ignore", message=".*past_key_values.*", category=FutureWarning)
        warnings.filterwarnings("ignore", message=".*tuple of tuples.*", category=FutureWarning)
        
        self.model = ChatterboxTTS.from_pretrained(device="cuda")
        print(f"Model loaded successfully! Sample rate: {self.model.sr}")

    def _validate_text_input(self, text: str) -> None:
        """Raise HTTP 400 if *text* is missing, empty, or whitespace-only."""
        if not text or len(text.strip()) == 0:
            raise HTTPException(status_code=400, detail="Text cannot be empty")

    def _process_voice_prompt(self, voice_prompt_base64: Optional[str]) -> Optional[str]:
        """Decode a base64-encoded voice prompt and persist it to a temp file.

        Args:
            voice_prompt_base64: Base64 audio payload, or None/empty.

        Returns:
            Path of the temporary audio file, or None when no prompt was given.

        Raises:
            HTTPException: 400 when the payload cannot be decoded/saved.
        """
        if not voice_prompt_base64:
            return None
            
        try:
            audio_data = base64.b64decode(voice_prompt_base64)
            return AudioUtils.save_temp_audio_file(audio_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid voice prompt audio: {str(e)}")

    def _generate_audio(self, text: str, audio_prompt_path: Optional[str] = None,
                        cleanup: bool = True):
        """Generate a waveform for *text*, optionally conditioned on a voice prompt.

        Args:
            text: Text to synthesize.
            audio_prompt_path: Optional path to a voice-prompt audio file.
            cleanup: When True (default, preserves historical behavior) the
                prompt file is deleted after this call — on success or error.
                Callers that reuse one prompt file across several calls (e.g.
                chunked processing) MUST pass False and delete the file
                themselves; otherwise the first call removes the file out from
                under the remaining ones.

        Returns:
            The raw waveform object produced by the model.
        """
        print(f"Generating audio for text: {text[:50]}...")
        
        try:
            if audio_prompt_path:
                return self.model.generate(text, audio_prompt_path=audio_prompt_path)
            return self.model.generate(text)
        finally:
            # Single cleanup path for both success and failure; any exception
            # from generate() propagates with its original traceback.
            if cleanup and audio_prompt_path:
                AudioUtils.cleanup_temp_file(audio_prompt_path)

    def _wav_streaming_response(self, wav,
                                filename: str = "generated_speech.wav") -> StreamingResponse:
        """Package a generated waveform as a streaming WAV download.

        Shared by the streaming endpoints to avoid duplicating buffer/header
        construction. The X-Audio-Duration header is seconds, derived from the
        first channel length and the model sample rate.
        """
        buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
        return StreamingResponse(
            io.BytesIO(buffer.read()),
            media_type="audio/wav",
            headers={
                "Content-Disposition": f"attachment; filename={filename}",
                "X-Audio-Duration": str(len(wav[0]) / self.model.sr)
            }
        )

    @modal.fastapi_endpoint(docs=True, method="GET")
    def health(self) -> HealthResponse:
        """Health check endpoint to verify model status."""
        return HealthResponse(
            status="healthy", 
            model_loaded=hasattr(self, 'model') and self.model is not None
        )

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_audio(self, request: TTSRequest) -> StreamingResponse:
        """
        Generate speech audio from text with optional voice prompt.
        
        Args:
            request: TTSRequest containing text and optional voice prompt
            
        Returns:
            StreamingResponse with generated audio as WAV file

        Raises:
            HTTPException: 400 on bad input, 500 on generation failure.
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)
            
            wav = self._generate_audio(request.text, audio_prompt_path)
            return self._wav_streaming_response(wav)

        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_with_file(
        self, 
        text: str = Form(..., description="Text to convert to speech"),
        voice_prompt: Optional[UploadFile] = File(None, description="Optional voice prompt audio file")
    ) -> StreamingResponse:
        """
        Generate speech audio from text with optional voice prompt file upload.
        
        Args:
            text: Text to convert to speech
            voice_prompt: Optional audio file for voice cloning
            
        Returns:
            StreamingResponse with generated audio as WAV file

        Raises:
            HTTPException: 400 on bad input or unsupported upload type,
                500 on generation failure.
        """
        try:
            self._validate_text_input(text)
            
            # Handle voice prompt file if provided
            audio_prompt_path = None
            if voice_prompt:
                if voice_prompt.content_type not in ["audio/wav", "audio/mpeg", "audio/mp3"]:
                    raise HTTPException(
                        status_code=400, 
                        detail="Voice prompt must be WAV, MP3, or MPEG audio file"
                    )
                
                # Read and save the uploaded file
                audio_data = voice_prompt.file.read()
                audio_prompt_path = AudioUtils.save_temp_audio_file(audio_data)

            wav = self._generate_audio(text, audio_prompt_path)
            return self._wav_streaming_response(wav)

        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_json(self, request: TTSRequest) -> TTSResponse:
        """
        Generate speech audio and return as JSON with base64 encoded audio.
        
        Unlike the streaming endpoints, failures are reported in-band via
        TTSResponse(success=False) rather than HTTP error statuses.

        Args:
            request: TTSRequest containing text and optional voice prompt
            
        Returns:
            TTSResponse with base64 encoded audio data
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)
            
            wav = self._generate_audio(request.text, audio_prompt_path)

            # Convert to base64
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
            audio_base64 = base64.b64encode(buffer.read()).decode('utf-8')
            duration = len(wav[0]) / self.model.sr

            return TTSResponse(
                success=True,
                message="Audio generated successfully",
                audio_base64=audio_base64,
                duration_seconds=duration
            )

        except HTTPException as http_exc:
            return TTSResponse(success=False, message=str(http_exc.detail))
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            return TTSResponse(success=False, message=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")  
    def generate(self, prompt: str):
        """
        Legacy endpoint for backward compatibility.
        Generate audio waveform from the input text (no validation, no
        voice prompt, no Content-Disposition header).
        """
        try:
            # Generate audio waveform from the input text
            wav = self.model.generate(prompt)

            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)

            # Return the audio as a streaming response with appropriate MIME type.
            return StreamingResponse(
                io.BytesIO(buffer.read()),
                media_type="audio/wav",
            )
        except Exception as e:
            print(f"Error in legacy endpoint: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")
            
    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_audio_file(self, request: TTSRequest) -> Response:
        """
        Generate speech audio from text with optional voice prompt and return as a complete file.
        
        Unlike the streaming endpoint, this returns the entire file at once.
        
        Args:
            request: TTSRequest containing text and optional voice prompt
            
        Returns:
            Response with complete audio file data
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)
            
            wav = self._generate_audio(request.text, audio_prompt_path)

            # Materialize the full file in memory before responding.
            buffer = AudioUtils.save_audio_to_buffer(wav, self.model.sr)
            audio_data = buffer.read()
            duration = len(wav[0]) / self.model.sr
            
            return Response(
                content=audio_data,
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=generated_speech.wav",
                    "X-Audio-Duration": str(duration)
                }
            )

        except HTTPException:
            raise
        except Exception as e:
            print(f"Error generating audio: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")

    @modal.fastapi_endpoint(docs=True, method="POST")
    def generate_full_text_audio(self, request: FullTextTTSRequest) -> StreamingResponse:
        """
        Generate speech audio from full text with server-side chunking and parallel processing.
        
        This endpoint handles texts of any length by:
        1. Chunking the text intelligently (respecting sentence/paragraph boundaries)
        2. Processing chunks in parallel using GPU resources
        3. Concatenating audio chunks with proper transitions
        4. Returning the final audio file
        
        Args:
            request: FullTextTTSRequest containing text and processing parameters
            
        Returns:
            StreamingResponse with final concatenated audio as WAV file
        """
        try:
            self._validate_text_input(request.text)
            audio_prompt_path = self._process_voice_prompt(request.voice_prompt_base64)
            
            try:
                print(f"Processing full text ({len(request.text)} chars) with server-side chunking...")
                
                # Initialize text chunker with request parameters
                chunker = TextChunker(
                    max_chunk_size=request.max_chunk_size,
                    overlap_sentences=request.overlap_sentences
                )
                
                text_chunks = chunker.chunk_text(request.text)
                chunk_info = chunker.get_chunk_info(text_chunks)  # NOTE(review): currently unused
                print(f"Split text into {len(text_chunks)} chunks for processing")
                
                audio_chunks = []
                if len(text_chunks) == 1:
                    # Single chunk: synthesize directly, no concatenation needed.
                    wav = self._generate_audio(text_chunks[0], audio_prompt_path, cleanup=False)
                    final_audio = wav
                    audio_chunks = [wav]  # For consistent response headers below
                else:
                    import concurrent.futures
                    
                    def process_chunk(chunk_text: str):
                        """Synthesize one chunk. cleanup=False: the prompt file is
                        shared by every chunk and must not be deleted mid-run."""
                        return self._generate_audio(chunk_text, audio_prompt_path, cleanup=False)
                    
                    # Fan out chunks across a small thread pool; the GIL is
                    # released during the native generation work.
                    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
                        future_to_chunk = {
                            executor.submit(process_chunk, chunk): i 
                            for i, chunk in enumerate(text_chunks)
                        }
                        
                        # Collect results back into original chunk order.
                        results = [None] * len(text_chunks)
                        for future in concurrent.futures.as_completed(future_to_chunk):
                            chunk_index = future_to_chunk[future]
                            try:
                                results[chunk_index] = future.result()
                            except Exception as exc:
                                print(f'Chunk {chunk_index} generated an exception: {exc}')
                                raise HTTPException(status_code=500, detail=f"Failed to process chunk {chunk_index}: {str(exc)}")
                    
                    audio_chunks = [result for result in results if result is not None]
                    
                    if len(audio_chunks) != len(text_chunks):
                        raise HTTPException(status_code=500, detail=f"Only {len(audio_chunks)} out of {len(text_chunks)} chunks processed successfully")
                    
                    print("Concatenating audio chunks...")
                    concatenator = AudioConcatenator(
                        silence_duration=request.silence_duration,
                        fade_duration=request.fade_duration
                    )
                    final_audio = concatenator.concatenate_audio_chunks(audio_chunks, self.model.sr)
            finally:
                # Delete the shared prompt file exactly once, after all chunks
                # finish. Previously each chunk deleted it via _generate_audio,
                # so in the multi-chunk case the first completed chunk removed
                # the file while the remaining chunks still needed it.
                if audio_prompt_path:
                    AudioUtils.cleanup_temp_file(audio_prompt_path)
            
            # --- Normalize model/concatenator output to a (1, L) tensor ---
            import torch
            import numpy as np

            processed_tensor = final_audio
            # Unwrap single-element tuples (the model may return (wav,) shapes).
            while isinstance(processed_tensor, tuple) and len(processed_tensor) == 1:
                processed_tensor = processed_tensor[0]

            # Convert to PyTorch tensor if it's a NumPy array
            if isinstance(processed_tensor, np.ndarray):
                processed_tensor = torch.from_numpy(processed_tensor.astype(np.float32))

            if not isinstance(processed_tensor, torch.Tensor):
                raise TypeError(f"Audio data after concatenation is not a tensor. Got type: {type(processed_tensor)}")

            # Ensure correct shape (C, L) for torchaudio.save
            if processed_tensor.ndim == 1:  # Shape (L,)
                audio_to_save = processed_tensor.unsqueeze(0)  # Convert to (1, L)
            elif processed_tensor.ndim == 2:  # Shape (C, L)
                if processed_tensor.shape[0] == 0:
                    raise ValueError(f"Audio tensor has 0 channels: {processed_tensor.shape}")
                if processed_tensor.shape[0] > 1:  # Stereo/multi-channel: keep channel 0 only
                    print(f"Multi-channel audio (shape {processed_tensor.shape}) detected. Taking the first channel.")
                    audio_to_save = processed_tensor[0, :].unsqueeze(0)  # Result is (1, L)
                else:  # Already (1, L)
                    audio_to_save = processed_tensor
            else:
                raise ValueError(f"Unexpected audio tensor dimensions: {processed_tensor.ndim}, shape: {processed_tensor.shape}")

            buffer = AudioUtils.save_audio_to_buffer(audio_to_save, self.model.sr)
            duration = audio_to_save.shape[1] / self.model.sr  # Use shape[1] for length
            
            # Reset buffer position for reading
            buffer.seek(0)

            print(f"Full text processing complete! Final audio duration: {duration:.2f} seconds")
            return StreamingResponse(
                io.BytesIO(buffer.read()),
                media_type="audio/wav",
                headers={
                    "Content-Disposition": "attachment; filename=generated_full_text_speech.wav",
                    "X-Audio-Duration": str(duration),
                    "X-Chunks-Processed": str(len(audio_chunks)),
                    "X-Total-Characters": str(len(request.text))
                }
            )

        except HTTPException as http_exc:
            print(f"HTTP exception in full text generation: {http_exc.detail}")
            raise http_exc
        except Exception as e:
            error_msg = f"Full text audio generation failed: {str(e)}"
            print(f"Exception in full text generation: {error_msg}")
            raise HTTPException(status_code=500, detail=error_msg)