File size: 6,937 Bytes
8f3c067
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""Simplified TTS audio generation that uses the enhanced API endpoints."""
import os
import requests
import tempfile
import soundfile as sf
import numpy as np
import gradio as gr


def _raise_for_api_error(response):
    """Raise a user-facing ``gr.Error`` unless *response* has HTTP status 200."""
    if response.status_code != 200:
        raise gr.Error(f"API Error: {response.status_code} - {response.text}")


def _download_wav_audio(response, progress=None, progress_value=0.0, label="", loading_desc=None):
    """
    Stream a WAV response body to a temporary file and decode it with soundfile.

    The temporary file is always removed, whether downloading or decoding
    succeeds or fails.

    Args:
        response: A ``requests`` response opened with ``stream=True``.
        progress: Optional progress callback.
        progress_value: Progress fraction reported while bytes arrive.
        label: Word used in the download message
            ("Downloading {label} audio... (NKB)").
        loading_desc: If given, reported at 0.9 after the download finishes
            and before decoding.

    Returns:
        Tuple of (audio_data, sample_rate) as returned by ``soundfile.read``.
    """
    # delete=False so soundfile can reopen the file by path after it is
    # closed; the finally clause below is what guarantees removal.
    temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    temp_path = temp_file.name
    try:
        bytes_downloaded = 0
        with temp_file:
            for chunk in response.iter_content(chunk_size=8192):
                if not chunk:  # skip empty chunks
                    continue
                temp_file.write(chunk)
                bytes_downloaded += len(chunk)
                if progress:
                    progress(
                        progress_value,
                        desc=f"Downloading {label} audio... ({bytes_downloaded // 1024}KB)",
                    )

        if progress and loading_desc:
            progress(0.9, desc=loading_desc)

        return sf.read(temp_path)
    finally:
        os.unlink(temp_path)


def generate_simple_tts_audio(text_input: str, audio_prompt_input=None, progress=None):
    """
    Generate TTS audio using the enhanced API that handles chunking and concatenation server-side.

    Without an audio prompt, the full text is POSTed as JSON to the
    ``FULL_TEXT_TTS_ENDPOINT`` URL. With an audio prompt, the text and the
    prompt file are POSTed as multipart form data to the
    ``GENERATE_WITH_FILE_ENDPOINT`` URL. Both URLs are read from environment
    variables.

    Args:
        text_input: The text to convert to speech (any length)
        audio_prompt_input: Optional path to an audio file used as the
            voice-cloning prompt
        progress: Optional progress callback, called as
            ``progress(fraction, desc=...)``

    Returns:
        Tuple of (sample_rate, audio_data)

    Raises:
        gr.Error: If the text is empty, the API returns a non-200 status, the
            request times out or cannot connect, or any other error occurs.
    """
    FULL_TEXT_ENDPOINT = os.getenv("FULL_TEXT_TTS_ENDPOINT", "YOUR-MODAL-ENDPOINT-URL/generate_full_text_audio")
    GENERATE_WITH_FILE_ENDPOINT = os.getenv("GENERATE_WITH_FILE_ENDPOINT", "YOUR-MODAL-ENDPOINT-URL/generate_with_file")

    if not text_input or not text_input.strip():
        raise gr.Error("Please enter some text to synthesize.")

    if progress:
        progress(0.1, desc="Preparing request for full-text processing...")

    try:
        if audio_prompt_input is None:
            # Full-text endpoint: the server chunks, synthesizes and concatenates.
            if progress:
                progress(0.3, desc="Sending full text to enhanced TTS API...")

            payload = {
                "text": text_input,
                "max_chunk_size": 800,
                "silence_duration": 0.5,
                "fade_duration": 0.1,
                "overlap_sentences": 0,
            }

            # Context manager closes the streamed connection on every path.
            with requests.post(
                FULL_TEXT_ENDPOINT,
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=300,  # Longer timeout for full-text processing
                stream=True,
            ) as response:
                _raise_for_api_error(response)

                if progress:
                    progress(0.6, desc="Server processing text chunks in parallel...")

                chunks_processed = response.headers.get('X-Chunks-Processed', 'unknown')
                total_chars = response.headers.get('X-Total-Characters', len(text_input))

                if progress:
                    progress(0.7, desc=f"Processing {chunks_processed} chunks ({total_chars} characters)...")

                audio_data, sample_rate = _download_wav_audio(
                    response,
                    progress,
                    progress_value=0.7,
                    label="processed",
                    loading_desc="Loading final audio...",
                )

            if progress:
                progress(1.0, desc=f"Complete! Processed {chunks_processed} chunks into final audio.")

            return (sample_rate, audio_data)

        # Voice cloning still uses the original multipart endpoint.
        if progress:
            progress(0.3, desc="Preparing voice cloning request...")

        with open(audio_prompt_input, 'rb') as f:
            audio_content = f.read()
        files = {
            'text': (None, text_input),
            'voice_prompt': ('voice_prompt.wav', audio_content, 'audio/wav'),
        }

        if progress:
            progress(0.5, desc="Sending request with voice cloning...")

        with requests.post(
            GENERATE_WITH_FILE_ENDPOINT,
            files=files,
            timeout=180,
            stream=True,
        ) as response:
            _raise_for_api_error(response)

            if progress:
                progress(0.8, desc="Processing cloned voice...")

            audio_data, sample_rate = _download_wav_audio(
                response,
                progress,
                progress_value=0.8,
                label="cloned",
            )

        if progress:
            progress(1.0, desc="Voice cloning complete!")

        return (sample_rate, audio_data)

    except gr.Error:
        # Already user-facing (e.g. "API Error: ..."); don't re-wrap it below.
        raise
    except requests.exceptions.Timeout as e:
        raise gr.Error("Request timed out. The API might be processing a large text. Please try again.") from e
    except requests.exceptions.ConnectionError as e:
        raise gr.Error("Unable to connect to the API. Please check if the endpoint URL is correct.") from e
    except Exception as e:
        raise gr.Error(f"Error generating audio: {str(e)}") from e


def get_api_processing_info(text: str, max_chunk_size: int = 800) -> dict:
    """
    Estimate how the server will process *text*, without generating audio.

    The estimate is computed locally; no API call is made.

    Args:
        text: The text to analyze. ``None`` or empty text counts as length 0;
            surrounding whitespace is ignored.
        max_chunk_size: Maximum characters per server-side chunk. Defaults to
            800, the value ``generate_simple_tts_audio`` sends to the API.

    Returns:
        Dictionary with ``text_length``, ``estimated_chunks`` (always at
        least 1), ``processing_mode`` and ``benefits``; or
        ``{"error": ...}`` if the text cannot be analyzed (e.g. it is not a
        string) or ``max_chunk_size`` is not positive.
    """
    try:
        if max_chunk_size <= 0:
            raise ValueError("max_chunk_size must be positive")

        # This could be enhanced to call an API info endpoint
        text_length = len(text.strip()) if text else 0
        # Ceiling division: e.g. 801 characters need two chunks of <= 800.
        # This is a lower bound if the server splits at sentence boundaries.
        estimated_chunks = max(1, -(-text_length // max_chunk_size))

        return {
            "text_length": text_length,
            "estimated_chunks": estimated_chunks,
            "processing_mode": "server_side_parallel_gpu",
            "benefits": [
                "Server-side GPU acceleration",
                "Parallel chunk processing",
                "Automatic audio concatenation",
                "Optimized for large texts",
                "No client-side resource usage",
            ],
        }
    except Exception as e:
        return {"error": f"Failed to analyze text: {str(e)}"}