spagestic commited on
Commit
de90621
·
1 Parent(s): 3924db1

feat: Improve TTS audio generation with enhanced progress tracking and error handling

Browse files
Files changed (1) hide show
  1. ui/chatterbox/generate_tts_audio.py +56 -32
ui/chatterbox/generate_tts_audio.py CHANGED
@@ -11,79 +11,103 @@ def generate_tts_audio(text_input: str, audio_prompt_input, progress=None):
11
 
12
  if not text_input or len(text_input.strip()) == 0:
13
  raise gr.Error("Please enter some text to synthesize.")
14
-
15
  if len(text_input) > 1000:
16
  raise gr.Error("Text is too long. Maximum 1000 characters allowed.")
17
 
18
- if progress:
19
- progress(0.1, desc="Preparing request...")
20
 
21
  try:
22
  if audio_prompt_input is None:
23
- if progress:
24
- progress(0.3, desc="Sending request to API...")
25
  payload = {"text": text_input}
26
  response = requests.post(
27
  GENERATE_AUDIO_ENDPOINT,
28
  json=payload,
29
  headers={"Content-Type": "application/json"},
30
- timeout=90 # Increased timeout for better reliability
 
31
  )
32
  if response.status_code != 200:
33
  raise gr.Error(f"API Error: {response.status_code} - {response.text}")
34
 
35
- if progress:
36
- progress(0.8, desc="Processing audio response...")
 
 
 
 
37
 
 
38
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
39
- temp_file.write(response.content)
 
 
 
 
 
 
 
 
 
 
 
 
40
  temp_path = temp_file.name
41
-
 
42
  audio_data, sample_rate = sf.read(temp_path)
43
  os.unlink(temp_path)
44
-
45
- if progress:
46
- progress(1.0, desc="Complete!")
47
-
48
  return (sample_rate, audio_data)
49
- else:
50
- if progress:
51
- progress(0.3, desc="Preparing voice prompt...")
52
 
 
 
53
  files = {'text': (None, text_input)}
54
  with open(audio_prompt_input, 'rb') as f:
55
  audio_content = f.read()
56
  files['voice_prompt'] = ('voice_prompt.wav', audio_content, 'audio/wav')
57
 
58
- if progress:
59
- progress(0.5, desc="Sending request with voice cloning...")
60
-
61
  response = requests.post(
62
  GENERATE_WITH_FILE_ENDPOINT,
63
  files=files,
64
- timeout=150 # Longer timeout for voice cloning
 
65
  )
66
  if response.status_code != 200:
67
  raise gr.Error(f"API Error: {response.status_code} - {response.text}")
68
 
69
- if progress:
70
- progress(0.8, desc="Processing cloned voice response...")
 
 
 
 
71
 
 
72
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
73
- temp_file.write(response.content)
 
 
 
 
 
 
 
 
 
 
 
74
  temp_path = temp_file.name
75
-
76
  audio_data, sample_rate = sf.read(temp_path)
77
  os.unlink(temp_path)
78
-
79
- if progress:
80
- progress(1.0, desc="Voice cloning complete!")
81
-
82
  return (sample_rate, audio_data)
83
 
84
  except requests.exceptions.Timeout:
85
- raise gr.Error("⏱️ Request timed out. The TTS API is taking too long to respond. This usually happens when the service is under heavy load. Please try again in a few moments.")
86
  except requests.exceptions.ConnectionError:
87
- raise gr.Error("🔌 Unable to connect to the TTS API. Please check your internet connection and try again.")
88
  except Exception as e:
89
- raise gr.Error(f"Error generating audio: {str(e)}")
 
11
 
12
  if not text_input or len(text_input.strip()) == 0:
13
  raise gr.Error("Please enter some text to synthesize.")
 
14
  if len(text_input) > 1000:
15
  raise gr.Error("Text is too long. Maximum 1000 characters allowed.")
16
 
17
+ if progress: progress(0.1, desc="Preparing request...")
 
18
 
19
  try:
20
  if audio_prompt_input is None:
21
+ if progress: progress(0.3, desc="Sending request to API...")
 
22
  payload = {"text": text_input}
23
  response = requests.post(
24
  GENERATE_AUDIO_ENDPOINT,
25
  json=payload,
26
  headers={"Content-Type": "application/json"},
27
+ timeout=120,
28
+ stream=True
29
  )
30
  if response.status_code != 200:
31
  raise gr.Error(f"API Error: {response.status_code} - {response.text}")
32
 
33
+ if progress: progress(0.6, desc="Streaming audio response...")
34
+
35
+ # Get content length if available for progress tracking
36
+ content_length = response.headers.get('content-length')
37
+ if content_length:
38
+ content_length = int(content_length)
39
 
40
+ bytes_downloaded = 0
41
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
42
+ for chunk in response.iter_content(chunk_size=8192):
43
+ if chunk:
44
+ temp_file.write(chunk)
45
+ bytes_downloaded += len(chunk)
46
+
47
+ # Update progress based on bytes downloaded
48
+ if content_length and progress:
49
+ download_progress = min(0.3, (bytes_downloaded / content_length) * 0.3)
50
+ progress(0.6 + download_progress, desc=f"Downloading audio... ({bytes_downloaded // 1024}KB)")
51
+ elif progress:
52
+ # If no content length, just show bytes downloaded
53
+ progress(0.6, desc=f"Downloading audio... ({bytes_downloaded // 1024}KB)")
54
+
55
  temp_path = temp_file.name
56
+
57
+ if progress: progress(0.9, desc="Processing audio...")
58
  audio_data, sample_rate = sf.read(temp_path)
59
  os.unlink(temp_path)
60
+ if progress: progress(1.0, desc="Complete!")
 
 
 
61
  return (sample_rate, audio_data)
 
 
 
62
 
63
+ else:
64
+ if progress: progress(0.3, desc="Preparing voice prompt...")
65
  files = {'text': (None, text_input)}
66
  with open(audio_prompt_input, 'rb') as f:
67
  audio_content = f.read()
68
  files['voice_prompt'] = ('voice_prompt.wav', audio_content, 'audio/wav')
69
 
70
+ if progress: progress(0.5, desc="Sending request with voice cloning...")
 
 
71
  response = requests.post(
72
  GENERATE_WITH_FILE_ENDPOINT,
73
  files=files,
74
+ timeout=180,
75
+ stream=True
76
  )
77
  if response.status_code != 200:
78
  raise gr.Error(f"API Error: {response.status_code} - {response.text}")
79
 
80
+ if progress: progress(0.8, desc="Streaming cloned voice response...")
81
+
82
+ # Get content length if available for progress tracking
83
+ content_length = response.headers.get('content-length')
84
+ if content_length:
85
+ content_length = int(content_length)
86
 
87
+ bytes_downloaded = 0
88
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
89
+ for chunk in response.iter_content(chunk_size=8192):
90
+ if chunk:
91
+ temp_file.write(chunk)
92
+ bytes_downloaded += len(chunk)
93
+
94
+ # Update progress based on bytes downloaded for voice cloning
95
+ if content_length and progress:
96
+ download_progress = min(0.15, (bytes_downloaded / content_length) * 0.15)
97
+ progress(0.8 + download_progress, desc=f"Downloading cloned audio... ({bytes_downloaded // 1024}KB)")
98
+ elif progress:
99
+ progress(0.8, desc=f"Downloading cloned audio... ({bytes_downloaded // 1024}KB)")
100
+
101
  temp_path = temp_file.name
102
+
103
  audio_data, sample_rate = sf.read(temp_path)
104
  os.unlink(temp_path)
105
+ if progress: progress(1.0, desc="Voice cloning complete!")
 
 
 
106
  return (sample_rate, audio_data)
107
 
108
  except requests.exceptions.Timeout:
109
+ raise gr.Error("Request timed out. The API might be under heavy load. Please try again.")
110
  except requests.exceptions.ConnectionError:
111
+ raise gr.Error("Unable to connect to the API. Please check if the endpoint URL is correct.")
112
  except Exception as e:
113
+ raise gr.Error(f"Error generating audio: {str(e)}")