peterlllmm committed on
Commit
8196470
·
verified ·
1 Parent(s): c2c369a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -207
app.py CHANGED
@@ -1,249 +1,178 @@
1
  import nltk
2
- nltk.download('punkt')
3
- # Explicitly download 'punkt_tab' as it's often required by sent_tokenize
4
- nltk.download('punkt_tab')
5
 
6
  import random
7
  import numpy as np
8
  import torch
9
- from chatterbox.src.chatterbox.tts import ChatterboxTTS
10
- import gradio as gr
11
  import io
 
12
  import soundfile as sf
13
- from pydub import AudioSegment
14
  from nltk.tokenize import sent_tokenize
15
- import os # Added for temporary file handling
 
16
 
17
- # Determine the device to run on (GPU if available, otherwise CPU)
 
 
 
 
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
- print(f"?? Running on device: {DEVICE}")
20
 
21
- # --- Global Model Initialization ---
 
 
22
  MODEL = None
23
 
24
- def get_or_load_model():
25
- """Loads the ChatterboxTTS model if it hasn't been loaded already,
26
- and ensures it's on the correct device. This helps avoid reloading
27
- the model multiple times which can be slow."""
28
  global MODEL
29
  if MODEL is None:
30
- print("Model not loaded, initializing...")
31
- try:
32
- # Load the model and move it to the determined device (CPU or CUDA)
33
- MODEL = ChatterboxTTS.from_pretrained(DEVICE)
34
- # Ensure the model is explicitly on the correct device after loading
35
- if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
36
- MODEL.to(DEVICE)
37
- print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
38
- except Exception as e:
39
- print(f"Error loading model: {e}")
40
- # Re-raise the exception to indicate a critical failure
41
- raise
42
  return MODEL
43
 
44
- # Attempt to load the model at startup of the script.
45
- # This ensures the model is ready when the Gradio interface starts.
46
- try:
47
- get_or_load_model()
48
- except Exception as e:
49
- print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
50
 
51
- def set_seed(seed: int):
52
- """Sets the random seed for reproducibility across torch, numpy, and random."""
 
 
53
  torch.manual_seed(seed)
54
  if DEVICE == "cuda":
55
- torch.cuda.manual_seed(seed)
56
  torch.cuda.manual_seed_all(seed)
57
  random.seed(seed)
58
  np.random.seed(seed)
59
 
60
- def generate_tts_audio(
61
- text_input: str,
62
- audio_prompt_path_input: str = None,
63
- exaggeration_input: float = 0.5,
64
- temperature_input: float = 0.8,
65
- seed_num_input: int = 0,
66
- cfgw_input: float = 0.5
67
- ) -> str: # Return type changed to str (filepath)
68
- """
69
- Generate high-quality speech audio from text using ChatterboxTTS model with optional reference audio styling.
70
- Handles long scripts by chunking text, generating audio for each chunk, and combining them into an MP3.
71
-
72
- Args:
73
- text_input (str): The text to synthesize into speech.
74
- audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
75
- exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
76
- temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
77
- seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
78
- cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5.
79
- Returns:
80
- str: Filepath to the generated combined MP3 audio waveform.
81
- """
82
- current_model = get_or_load_model()
83
- if current_model is None:
84
- raise RuntimeError("TTS model is not loaded. Please check the startup logs for errors.")
85
-
86
- if seed_num_input != 0:
87
- set_seed(int(seed_num_input))
88
-
89
- print(f"Generating audio for text: '{text_input[:100]}...' (first 100 chars)")
90
- print(f"Audio prompt path received: {audio_prompt_path_input}") # Debug print for the received path
91
-
92
- generate_kwargs = {
93
- "exaggeration": exaggeration_input,
94
- "temperature": temperature_input,
95
- "cfg_weight": cfgw_input,
96
  }
97
 
98
- processed_audio_prompt_path = None
99
- if audio_prompt_path_input:
 
 
 
100
  try:
101
- # Load the input audio using pydub
102
- audio = AudioSegment.from_file(audio_prompt_path_input)
103
- # Create a temporary WAV file to ensure compatibility with ChatterboxTTS
104
- temp_wav_path = "temp_prompt.wav"
105
- audio.export(temp_wav_path, format="wav")
106
- processed_audio_prompt_path = temp_wav_path
107
- print(f"Converted audio prompt to temporary WAV: {processed_audio_prompt_path}")
108
- except Exception as e:
109
- print(f"Warning: Could not process audio prompt file '{audio_prompt_path_input}'. Error: {e}")
110
- print("Proceeding without audio prompt (will use default voice).")
111
- # If conversion fails, ensure the audio prompt path is not used
112
- processed_audio_prompt_path = None
113
-
114
- if processed_audio_prompt_path:
115
- generate_kwargs["audio_prompt_path"] = processed_audio_prompt_path
116
-
117
- all_audio_segments = []
118
- # Split text into sentences for more natural chunking
119
- sentences = sent_tokenize(text_input)
120
-
121
- # Chatterbox model has an implicit input limit, typically around 300 characters.
122
- # We'll chunk sentences to stay within this limit.
123
- MAX_CHARS_PER_MODEL_INPUT = 300
124
-
125
- current_chunk_sentences = []
126
- current_chunk_char_count = 0
127
-
128
- for sentence in sentences:
129
- # If adding the current sentence exceeds the max chars, process the current chunk
130
- # and ensure current_chunk_sentences is not empty to avoid creating empty chunks
131
- if current_chunk_char_count + len(sentence) + 1 > MAX_CHARS_PER_MODEL_INPUT and current_chunk_sentences: # +1 for space
132
- chunk_text = " ".join(current_chunk_sentences)
133
- print(f"Processing chunk (chars: {len(chunk_text)}): '{chunk_text[:50]}...'")
134
- wav_tensor = current_model.generate(chunk_text, **generate_kwargs)
135
- wav_numpy = wav_tensor.squeeze(0).cpu().numpy()
136
-
137
- # Convert numpy array to AudioSegment via an in-memory WAV buffer
138
- buffer = io.BytesIO()
139
- sf.write(buffer, wav_numpy, current_model.sr, format='WAV')
140
- buffer.seek(0) # Rewind the buffer to the beginning
141
- audio_segment = AudioSegment.from_wav(buffer)
142
- all_audio_segments.append(audio_segment)
143
-
144
- # Start a new chunk with the current sentence
145
- current_chunk_sentences = [sentence]
146
- current_chunk_char_count = len(sentence)
147
  else:
148
- current_chunk_sentences.append(sentence)
149
- # Add 1 for space between sentences, but only if it's not the very first sentence in a chunk
150
- current_chunk_char_count += len(sentence) + (1 if current_chunk_sentences else 0)
 
 
 
 
151
 
152
- # Process the last remaining chunk
153
- if current_chunk_sentences:
154
- chunk_text = " ".join(current_chunk_sentences)
155
- print(f"Processing final chunk (chars: {len(chunk_text)}): '{chunk_text[:50]}...'")
156
- wav_tensor = current_model.generate(chunk_text, **generate_kwargs)
157
- wav_numpy = wav_tensor.squeeze(0).cpu().numpy()
 
 
 
 
 
158
 
159
  buffer = io.BytesIO()
160
- sf.write(buffer, wav_numpy, current_model.sr, format='WAV')
161
  buffer.seek(0)
162
- audio_segment = AudioSegment.from_wav(buffer)
163
- all_audio_segments.append(audio_segment)
164
-
165
- if not all_audio_segments:
166
- raise ValueError("No audio segments were generated. Please ensure the input text is not empty or too short.")
167
 
168
- # Concatenate all audio segments
169
- combined_audio = all_audio_segments[0]
170
- for i in range(1, len(all_audio_segments)):
171
- combined_audio += all_audio_segments[i]
172
 
173
- # Export to MP3 format
174
- output_filename = "combined_chatterbox_output.mp3"
175
- combined_audio.export(output_filename, format="mp3")
176
 
177
- print(f"Combined audio generated and saved as {output_filename}")
 
 
 
 
178
 
179
- # Clean up the temporary WAV file if it was created
180
- if processed_audio_prompt_path and os.path.exists(processed_audio_prompt_path):
181
- os.remove(processed_audio_prompt_path)
182
- print(f"Cleaned up temporary prompt file: {processed_audio_prompt_path}")
183
 
184
- return output_filename # Return the filepath for Gradio
185
 
186
- # --- Gradio Interface Definition ---
 
 
187
  with gr.Blocks() as demo:
188
- gr.Markdown(
189
- """
190
- # Chatterbox TTS Demo
191
- Generate high-quality speech from text with reference audio styling.
192
- Now supports longer scripts and MP3 output!
193
- """
 
 
 
 
 
 
194
  )
195
- with gr.Row():
196
- with gr.Column():
197
- text = gr.Textbox(
198
- value="""
199
- The quick brown fox jumps over the lazy dog. This is a common pangram used to display all letters of the alphabet.
200
- Now, let's try a slightly longer passage to test the new chunking functionality.
201
- This should demonstrate how the system handles multiple sentences and combines them seamlessly.
202
- We are aiming for a natural flow, even with extended input.
203
- The sun dipped below the horizon, painting the sky in hues of orange and purple.
204
- A gentle breeze rustled through the leaves, carrying the scent of night-blooming jasmine.
205
- Soon, the stars would emerge, tiny pinpricks of light in the vast, dark canvas above.
206
- """,
207
- label="Text to synthesize (can be long)",
208
- lines=10 # Increased lines for longer text input
209
- )
210
- # Gradio's Audio component handles file uploads directly.
211
- # The 'value' here is a placeholder for the demo.
212
- ref_wav = gr.Audio(
213
- sources=["upload", "microphone"],
214
- type="filepath",
215
- label="Reference Audio File (Optional)",
216
- value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac" # Default example audio
217
- )
218
- exaggeration = gr.Slider(
219
- 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
220
- )
221
- cfg_weight = gr.Slider(
222
- 0.2, 1, step=.05, label="CFG/Pace", value=0.5
223
- )
224
- with gr.Accordion("More options", open=False):
225
- seed_num = gr.Number(value=0, label="Random seed (0 for random)")
226
- temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
227
- run_btn = gr.Button("Generate", variant="primary")
228
- with gr.Column():
229
- # Output type is now implicitly a filepath to the MP3
230
- audio_output = gr.Audio(label="Output Audio (MP3)")
231
-
232
- # Define the action when the "Generate" button is clicked
233
- run_btn.click(
234
- fn=generate_tts_audio,
235
- inputs=[
236
- text,
237
- ref_wav,
238
- exaggeration,
239
- temp,
240
- seed_num,
241
- cfg_weight,
242
- ],
243
- outputs=[audio_output],
244
  )
245
 
246
- # Launch the Gradio interface.
247
- # Use share=True to get a public URL for the app, essential for Colab.
248
- # debug=True can be useful for seeing more detailed error messages in the Colab output.
249
- demo.launch(share=True, debug=True)
 
1
  import nltk
2
+ nltk.download("punkt")
 
 
3
 
4
  import random
5
  import numpy as np
6
  import torch
 
 
7
  import io
8
+ import os
9
  import soundfile as sf
 
10
  from nltk.tokenize import sent_tokenize
11
+ from pydub import AudioSegment
12
+ import gradio as gr
13
 
14
+ from chatterbox.src.chatterbox.tts import ChatterboxTTS
15
+
16
# ===============================
# DEVICE
# ===============================
# Prefer the GPU when available; all model tensors go to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {DEVICE}")

# ===============================
# LOAD MODEL ONCE
# ===============================
# Module-level handle populated lazily by get_model(), so the model is
# created at most once per process.
MODEL = None
26
 
27
def get_model():
    """Return the shared ChatterboxTTS model, loading it on first use."""
    global MODEL
    # Fast path: model already initialized.
    if MODEL is not None:
        return MODEL
    print("Loading Chatterbox model...")
    MODEL = ChatterboxTTS.from_pretrained(DEVICE)
    # Some model versions expose .to(); move explicitly to the chosen device.
    if hasattr(MODEL, "to"):
        MODEL.to(DEVICE)
    print("Model ready.")
    return MODEL
36
 
37
# Load the model eagerly at startup so the first UI request is fast.
get_model()
 
 
 
 
 
38
 
39
# ===============================
# SEED
# ===============================
def set_seed(seed):
    """Seed Python, NumPy and torch RNGs for reproducible generation."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        # Seed every visible CUDA device, not just the current one.
        torch.cuda.manual_seed_all(seed)
48
 
49
# ===============================
# PODCAST SAFE SETTINGS
# ===============================
MAX_CHARS = 220   # max characters per synthesis chunk; stable for chatterbox
SILENCE_MS = 350  # pause inserted between chunks (ms) for a natural cadence
FADE_IN = 30      # per-chunk fade-in (ms) to soften joins
FADE_OUT = 60     # per-chunk fade-out (ms)
56
+
57
# ===============================
# MAIN TTS FUNCTION
# ===============================
def chunk_sentences(sentences, max_chars):
    """Greedily pack sentences into chunks of roughly max_chars characters.

    A single sentence longer than max_chars becomes its own oversized
    chunk rather than being split mid-sentence. Empty or whitespace-only
    chunks are never emitted (the previous inline version could emit an
    empty first chunk when the first sentence alone exceeded the limit).
    """
    chunks = []
    current = ""
    for sentence in sentences:
        if len(current) + len(sentence) < max_chars:
            current += " " + sentence
        else:
            # Flush the accumulated chunk, but never an empty one.
            if current.strip():
                chunks.append(current.strip())
            current = sentence
    if current.strip():
        chunks.append(current.strip())
    return chunks


def generate_tts(
    text,
    ref_audio=None,
    exaggeration=0.4,
    temperature=0.7,
    seed=0,
    cfg_weight=0.6,
):
    """Synthesize `text` into an MP3 file using the ChatterboxTTS model.

    Long input is split into sentence-aligned chunks of up to MAX_CHARS
    characters; each chunk is synthesized separately, faded in/out, and
    joined with SILENCE_MS of silence for a natural podcast cadence.

    Args:
        text: Full story/script text to synthesize.
        ref_audio: Optional path to a reference audio file defining the
            target voice; converted to WAV before being handed to the model.
        exaggeration: Emotion/expressiveness control.
        temperature: Sampling randomness.
        seed: Random seed; 0 keeps generation non-deterministic.
        cfg_weight: CFG/pace guidance weight.

    Returns:
        Filepath of the generated MP3 ("story_voice.mp3").

    Raises:
        ValueError: If the input text yields no speakable chunks.
    """
    model = get_model()

    if seed != 0:
        set_seed(int(seed))

    kwargs = {
        "exaggeration": exaggeration,
        "temperature": temperature,
        "cfg_weight": cfg_weight,
    }

    # --------------------------------
    # Handle reference voice
    # --------------------------------
    # Convert the uploaded file to WAV, which the model accepts reliably;
    # on failure fall back to the default voice instead of crashing.
    temp_prompt = None
    if ref_audio:
        try:
            audio = AudioSegment.from_file(ref_audio)
            temp_prompt = "voice_prompt.wav"
            audio.export(temp_prompt, format="wav")
            kwargs["audio_prompt_path"] = temp_prompt
        except Exception:  # was a bare except, which also caught KeyboardInterrupt/SystemExit
            print("Reference audio failed — using default voice.")

    try:
        # --------------------------------
        # Sentence chunking
        # --------------------------------
        chunks = chunk_sentences(sent_tokenize(text), MAX_CHARS)
        if not chunks:
            # Previously an empty input silently exported an empty MP3.
            raise ValueError("Input text produced no speakable content.")
        print(f"Total chunks: {len(chunks)}")

        # --------------------------------
        # Generate audio per chunk
        # --------------------------------
        final_audio = AudioSegment.empty()
        silence = AudioSegment.silent(duration=SILENCE_MS)

        for i, chunk in enumerate(chunks):
            print(f"Generating chunk {i+1}/{len(chunks)}")

            wav = model.generate(chunk, **kwargs)
            wav_np = wav.squeeze(0).cpu().numpy()

            # Round-trip through an in-memory WAV buffer to obtain a
            # pydub segment from the raw numpy audio.
            buffer = io.BytesIO()
            sf.write(buffer, wav_np, model.sr, format="WAV")
            buffer.seek(0)

            segment = AudioSegment.from_wav(buffer)
            segment = segment.fade_in(FADE_IN).fade_out(FADE_OUT)
            final_audio += segment + silence

        # --------------------------------
        # Export
        # --------------------------------
        output_path = "story_voice.mp3"
        final_audio.export(output_path, format="mp3", bitrate="192k")
        return output_path
    finally:
        # Remove the temporary prompt WAV even if synthesis raised, so a
        # failure mid-generation does not leak the file.
        if temp_prompt and os.path.exists(temp_prompt):
            os.remove(temp_prompt)
144
 
145
# ===============================
# GRADIO UI
# ===============================
with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Storyteller / Podcast Chatterbox TTS")

    # --- Inputs (creation order fixes on-page layout) ---
    text = gr.Textbox(
        label="Story Text",
        lines=12,
        placeholder="Paste your full story here...",
    )
    ref = gr.Audio(
        sources=["upload", "microphone"],
        type="filepath",
        label="Reference Voice (optional)",
    )
    exaggeration = gr.Slider(0.25, 1.0, value=0.4, step=0.05, label="Emotion")
    temperature = gr.Slider(0.3, 1.2, value=0.7, step=0.05, label="Variation")
    cfg = gr.Slider(0.3, 1.0, value=0.6, step=0.05, label="Voice Stability")
    seed = gr.Number(value=0, label="Seed (0 = random)")

    # --- Action + output ---
    btn = gr.Button("Generate Voice")
    out = gr.Audio(label="Final Audio")

    # Input order must match generate_tts's positional parameters.
    btn.click(
        fn=generate_tts,
        inputs=[text, ref, exaggeration, temperature, seed, cfg],
        outputs=out,
    )

demo.launch(share=True)