peterlllmm committed on
Commit
c2c369a
·
verified ·
1 Parent(s): 89f3c44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +167 -42
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import nltk
2
  nltk.download('punkt')
 
3
  nltk.download('punkt_tab')
4
 
5
  import random
@@ -11,35 +12,47 @@ import io
11
  import soundfile as sf
12
  from pydub import AudioSegment
13
  from nltk.tokenize import sent_tokenize
14
- import os
15
 
16
- # Determine the device
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
- print(f"🚀 Running on device: {DEVICE}")
19
 
20
  # --- Global Model Initialization ---
21
  MODEL = None
22
 
23
  def get_or_load_model():
 
 
 
24
  global MODEL
25
  if MODEL is None:
 
26
  try:
 
27
  MODEL = ChatterboxTTS.from_pretrained(DEVICE)
 
28
  if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
29
  MODEL.to(DEVICE)
 
30
  except Exception as e:
31
  print(f"Error loading model: {e}")
 
32
  raise
33
  return MODEL
34
 
 
 
35
  try:
36
  get_or_load_model()
37
  except Exception as e:
38
- print(f"CRITICAL: Model failed to load: {e}")
39
 
40
  def set_seed(seed: int):
 
41
  torch.manual_seed(seed)
42
  if DEVICE == "cuda":
 
43
  torch.cuda.manual_seed_all(seed)
44
  random.seed(seed)
45
  np.random.seed(seed)
@@ -47,78 +60,190 @@ def set_seed(seed: int):
47
  def generate_tts_audio(
48
  text_input: str,
49
  audio_prompt_path_input: str = None,
50
- exaggeration_input: float = 0.8, # Defaulted for 'Elder' gravitas
51
- temperature_input: float = 0.7, # Defaulted for consistency
52
  seed_num_input: int = 0,
53
- cfgw_input: float = 0.3, # Defaulted for 'Slow' pace
54
- pause_duration: float = 1.0 # NEW: Silence between chunks
55
- ) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  current_model = get_or_load_model()
 
 
 
57
  if seed_num_input != 0:
58
  set_seed(int(seed_num_input))
59
 
 
 
 
60
  generate_kwargs = {
61
  "exaggeration": exaggeration_input,
62
  "temperature": temperature_input,
63
  "cfg_weight": cfgw_input,
64
  }
65
 
 
66
  if audio_prompt_path_input:
67
- generate_kwargs["audio_prompt_path"] = audio_prompt_path_input
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  all_audio_segments = []
 
70
  sentences = sent_tokenize(text_input)
71
-
72
- # Create the 'Silence' segment (in milliseconds)
73
- silence_gap = AudioSegment.silent(duration=int(pause_duration * 1000))
 
 
 
 
74
 
75
  for sentence in sentences:
76
- print(f"Processing: '{sentence[:30]}...'")
77
- wav_tensor = current_model.generate(sentence, **generate_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  wav_numpy = wav_tensor.squeeze(0).cpu().numpy()
79
 
80
- # Convert to AudioSegment
81
  buffer = io.BytesIO()
82
  sf.write(buffer, wav_numpy, current_model.sr, format='WAV')
83
  buffer.seek(0)
84
  audio_segment = AudioSegment.from_wav(buffer)
85
-
86
- # Add the clip + the pause
87
  all_audio_segments.append(audio_segment)
88
- all_audio_segments.append(silence_gap)
89
 
90
- # Combine all
91
- combined_audio = AudioSegment.empty()
92
- for seg in all_audio_segments:
93
- combined_audio += seg
 
 
 
94
 
95
- output_filename = "elderly_voice_output.mp3"
 
96
  combined_audio.export(output_filename, format="mp3")
97
- return output_filename
98
 
99
- # --- Gradio Interface ---
100
- with gr.Blocks(title="Elderly TTS Studio") as demo:
101
- gr.Markdown("# 👴 Elderly Wisdom TTS Generator")
102
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  with gr.Row():
104
  with gr.Column():
105
- text = gr.Textbox(label="Script", lines=8, placeholder="Enter the story here...")
106
- ref_wav = gr.Audio(type="filepath", label="Voice Reference (Upload a raspy/old voice)")
107
-
108
- with gr.Accordion("Character Pacing (Elder Settings)", open=True):
109
- cfg_weight = gr.Slider(0.1, 1.0, step=.05, label="Pace (0.3 is slow/elderly)", value=0.3)
110
- pause_len = gr.Slider(0, 3.0, step=.1, label="Breath Pause (Seconds between sentences)", value=1.0)
111
- exaggeration = gr.Slider(0.2, 2.0, step=.05, label="Gravitas/Emotion", value=0.8)
112
-
113
- run_btn = gr.Button("Generate 8-Minute Audio", variant="primary")
114
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  with gr.Column():
116
- audio_output = gr.Audio(label="Final Combined MP3")
 
117
 
 
118
  run_btn.click(
119
  fn=generate_tts_audio,
120
- inputs=[text, ref_wav, exaggeration, gr.State(0.7), gr.State(0), cfg_weight, pause_len],
 
 
 
 
 
 
 
121
  outputs=[audio_output],
122
  )
123
 
124
- demo.launch(share=True)
 
 
 
 
1
import nltk

# sent_tokenize needs both tokenizer bundles on recent NLTK releases:
# 'punkt' alone raises a LookupError asking for 'punkt_tab'.
for _nltk_pkg in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_pkg)

import random
 
12
import soundfile as sf
from pydub import AudioSegment
from nltk.tokenize import sent_tokenize
import os  # needed to clean up the temporary voice-prompt WAV file

# Determine the device to run on (GPU if available, otherwise CPU).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# BUGFIX: the rocket emoji was mojibake'd to "??" by a lossy encoding;
# restored from the previous revision of this file.
print(f"🚀 Running on device: {DEVICE}")

# --- Global Model Initialization ---
# Cached model instance; populated lazily by get_or_load_model().
MODEL = None
23
 
24
def get_or_load_model():
    """Return the global ChatterboxTTS model, loading it on first use.

    The model is cached in the module-level MODEL variable so repeated
    calls are cheap; any loading error is logged and re-raised.
    """
    global MODEL
    if MODEL is not None:
        return MODEL

    print("Model not loaded, initializing...")
    try:
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        # Some builds expose a device attribute; move explicitly when the
        # reported device differs from the one we selected.
        if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
            MODEL.to(DEVICE)
        print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        # Re-raise: a missing model is unrecoverable for this app.
        raise
    return MODEL
43
 
44
+ # Attempt to load the model at startup of the script.
45
+ # This ensures the model is ready when the Gradio interface starts.
46
  try:
47
  get_or_load_model()
48
  except Exception as e:
49
+ print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
50
 
51
def set_seed(seed: int):
    """Seed every RNG source (torch CPU/CUDA, random, numpy) for reproducible runs."""
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        # manual_seed_all seeds every visible GPU, including the current one,
        # so the separate torch.cuda.manual_seed(seed) call was redundant.
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
 
60
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5
) -> str:
    """Generate speech for *text_input* with ChatterboxTTS and save it as an MP3.

    Long scripts are split into sentence chunks of at most
    _MAX_CHARS_PER_MODEL_INPUT characters, synthesized one at a time, and the
    resulting clips are concatenated into a single file.

    Args:
        text_input: The text to synthesize into speech.
        audio_prompt_path_input: Optional file path or URL to reference audio
            that defines the target voice style.
        exaggeration_input: Speech expressiveness (neutral = 0.5; extreme
            values may be unstable).
        temperature_input: Sampling randomness (higher = more varied).
        seed_num_input: Random seed for reproducible results; 0 means random.
        cfgw_input: CFG/pace guidance weight.

    Returns:
        Filepath of the combined MP3 audio.

    Raises:
        RuntimeError: If the TTS model is not loaded.
        ValueError: If no audio segments could be generated.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded. Please check the startup logs for errors.")

    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    print(f"Generating audio for text: '{text_input[:100]}...' (first 100 chars)")
    print(f"Audio prompt path received: {audio_prompt_path_input}")

    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
    }

    processed_audio_prompt_path = _prepare_audio_prompt(audio_prompt_path_input)
    if processed_audio_prompt_path:
        generate_kwargs["audio_prompt_path"] = processed_audio_prompt_path

    try:
        chunks = _chunk_sentences(sent_tokenize(text_input), _MAX_CHARS_PER_MODEL_INPUT)

        all_audio_segments = []
        for chunk_text in chunks:
            print(f"Processing chunk (chars: {len(chunk_text)}): '{chunk_text[:50]}...'")
            all_audio_segments.append(
                _synthesize_chunk(current_model, chunk_text, generate_kwargs)
            )

        if not all_audio_segments:
            raise ValueError("No audio segments were generated. Please ensure the input text is not empty or too short.")

        # Concatenate all audio segments into one clip.
        combined_audio = all_audio_segments[0]
        for segment in all_audio_segments[1:]:
            combined_audio += segment

        output_filename = "combined_chatterbox_output.mp3"
        combined_audio.export(output_filename, format="mp3")
        print(f"Combined audio generated and saved as {output_filename}")
        return output_filename
    finally:
        # BUGFIX: the original removed the temp prompt only on success; doing
        # it in `finally` means failed runs no longer leak the file.
        if processed_audio_prompt_path and os.path.exists(processed_audio_prompt_path):
            os.remove(processed_audio_prompt_path)
            print(f"Cleaned up temporary prompt file: {processed_audio_prompt_path}")


# Chatterbox has an implicit per-call input limit, typically ~300 characters.
_MAX_CHARS_PER_MODEL_INPUT = 300


def _prepare_audio_prompt(audio_prompt_path_input):
    """Convert a user-supplied voice prompt to a temporary WAV file.

    Returns the temp-file path, or None when no prompt was given or the
    conversion failed (generation then falls back to the default voice).
    """
    if not audio_prompt_path_input:
        return None
    try:
        # pydub can read many container formats; re-export as WAV for the model.
        audio = AudioSegment.from_file(audio_prompt_path_input)
        temp_wav_path = "temp_prompt.wav"
        audio.export(temp_wav_path, format="wav")
        print(f"Converted audio prompt to temporary WAV: {temp_wav_path}")
        return temp_wav_path
    except Exception as e:
        print(f"Warning: Could not process audio prompt file '{audio_prompt_path_input}'. Error: {e}")
        print("Proceeding without audio prompt (will use default voice).")
        return None


def _chunk_sentences(sentences, max_chars):
    """Greedily pack *sentences* into chunks of at most *max_chars* characters.

    One joining space is counted between sentences. A single sentence longer
    than *max_chars* still forms its own chunk.

    BUGFIX: the original appended the sentence to the chunk BEFORE testing
    whether the chunk was empty, so the +1 separator was charged even for the
    first sentence of a chunk, contradicting its own comment.
    """
    chunks = []
    current_sentences = []
    current_len = 0
    for sentence in sentences:
        added = len(sentence) + (1 if current_sentences else 0)
        if current_sentences and current_len + added > max_chars:
            chunks.append(" ".join(current_sentences))
            current_sentences = [sentence]
            current_len = len(sentence)
        else:
            current_sentences.append(sentence)
            current_len += added
    if current_sentences:
        chunks.append(" ".join(current_sentences))
    return chunks


def _synthesize_chunk(model, chunk_text, generate_kwargs):
    """Run the TTS model on one text chunk and return it as an AudioSegment."""
    wav_tensor = model.generate(chunk_text, **generate_kwargs)
    wav_numpy = wav_tensor.squeeze(0).cpu().numpy()
    # Round-trip through an in-memory WAV so pydub can ingest the numpy audio.
    buffer = io.BytesIO()
    sf.write(buffer, wav_numpy, model.sr, format='WAV')
    buffer.seek(0)
    return AudioSegment.from_wav(buffer)
185
+
186
# --- Gradio Interface Definition ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chatterbox TTS Demo
        Generate high-quality speech from text with reference audio styling.
        Now supports longer scripts and MP3 output!
        """
    )
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="""
                The quick brown fox jumps over the lazy dog. This is a common pangram used to display all letters of the alphabet.
                Now, let's try a slightly longer passage to test the new chunking functionality.
                This should demonstrate how the system handles multiple sentences and combines them seamlessly.
                We are aiming for a natural flow, even with extended input.
                The sun dipped below the horizon, painting the sky in hues of orange and purple.
                A gentle breeze rustled through the leaves, carrying the scent of night-blooming jasmine.
                Soon, the stars would emerge, tiny pinpricks of light in the vast, dark canvas above.
                """,
                label="Text to synthesize (can be long)",
                lines=10  # taller textbox for long scripts
            )
            # gr.Audio with type="filepath" hands generate_tts_audio a path
            # to the uploaded/recorded file; 'value' is just a demo default.
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"  # Default example audio
            )
            exaggeration = gr.Slider(
                0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
            )
            cfg_weight = gr.Slider(
                0.2, 1, step=.05, label="CFG/Pace", value=0.5
            )
            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
            run_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            # Output is a filepath to the generated MP3 (matches the string
            # returned by generate_tts_audio).
            audio_output = gr.Audio(label="Output Audio (MP3)")

    # Wire the button. The list order must match generate_tts_audio's
    # positional parameters: text, prompt path, exaggeration, temperature,
    # seed, cfg weight.
    run_btn.click(
        fn=generate_tts_audio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
        ],
        outputs=[audio_output],
    )

# Launch the Gradio interface.
# share=True exposes a public URL (essential for Colab); debug=True prints
# detailed error messages to the console.
demo.launch(share=True, debug=True)