oicui committed on
Commit
6fc220a
·
verified ·
1 Parent(s): 1afd111

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -52
app.py CHANGED
@@ -1,40 +1,44 @@
1
  import random
2
  import numpy as np
3
  import torch
4
- from chatterbox.src.chatterbox.tts import ChatterboxTTS
5
  import gradio as gr
6
  import spaces
 
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
  print(f"🚀 Running on device: {DEVICE}")
10
 
11
- # --- Global Model Initialization ---
 
 
12
  MODEL = None
13
 
14
  def get_or_load_model():
15
- """Loads the ChatterboxTTS model if it hasn't been loaded already,
16
- and ensures it's on the correct device."""
17
  global MODEL
18
  if MODEL is None:
19
  print("Model not loaded, initializing...")
20
  try:
21
  MODEL = ChatterboxTTS.from_pretrained(DEVICE)
22
- if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
23
  MODEL.to(DEVICE)
24
- print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
25
  except Exception as e:
26
  print(f"Error loading model: {e}")
27
  raise
28
  return MODEL
29
 
30
- # Attempt to load the model at startup.
31
  try:
32
  get_or_load_model()
33
  except Exception as e:
34
- print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
 
 
 
 
 
35
 
36
  def set_seed(seed: int):
37
- """Sets the random seed for reproducibility across torch, numpy, and random."""
38
  torch.manual_seed(seed)
39
  if DEVICE == "cuda":
40
  torch.cuda.manual_seed(seed)
@@ -42,6 +46,19 @@ def set_seed(seed: int):
42
  random.seed(seed)
43
  np.random.seed(seed)
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  @spaces.GPU
46
  def generate_tts_audio(
47
  text_input: str,
@@ -51,85 +68,120 @@ def generate_tts_audio(
51
  seed_num_input: int = 0,
52
  cfgw_input: float = 0.5,
53
  vad_trim_input: bool = False,
54
- ) -> tuple[int, np.ndarray]:
55
- """
56
- Generate high-quality speech audio from text using ChatterboxTTS model with optional reference audio styling.
57
-
58
- This tool synthesizes natural-sounding speech from input text. When a reference audio file
59
- is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
60
- maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided.
61
-
62
- Args:
63
- text_input (str): The text to synthesize into speech (maximum 300 characters)
64
- audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
65
- exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
66
- temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
67
- seed_num_input (int, optional): Random seed for reproducible results (0 for random generation). Defaults to 0.
68
- cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5.
69
-
70
- Returns:
71
- tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
72
- """
73
- current_model = get_or_load_model()
74
 
 
75
  if current_model is None:
76
  raise RuntimeError("TTS model is not loaded.")
77
 
78
- if seed_num_input != 0:
79
- set_seed(int(seed_num_input))
 
 
 
 
 
80
 
81
- print(f"Generating audio for text: '{text_input[:50]}...'")
 
 
 
82
 
83
- # Handle optional audio prompt
84
  generate_kwargs = {
85
  "exaggeration": exaggeration_input,
86
  "temperature": temperature_input,
87
  "cfg_weight": cfgw_input,
88
  "vad_trim": vad_trim_input,
89
  }
90
-
91
  if audio_prompt_path_input:
92
  generate_kwargs["audio_prompt_path"] = audio_prompt_path_input
93
 
94
- wav = current_model.generate(
95
- text_input[:300], # Truncate text to max chars
96
- **generate_kwargs
97
- )
 
 
 
 
 
 
 
 
 
 
 
 
98
  print("Audio generation complete.")
99
- return (current_model.sr, wav.squeeze(0).numpy())
 
 
 
 
 
 
100
 
101
  with gr.Blocks() as demo:
102
  gr.Markdown(
103
  """
104
- # Chatterbox TTS Demo
105
- Generate high-quality speech from text with reference audio styling.
106
  """
107
  )
 
108
  with gr.Row():
109
  with gr.Column():
 
 
110
  text = gr.Textbox(
111
- value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
112
- label="Text to synthesize (max chars 300)",
113
- max_lines=5
114
  )
 
 
115
  ref_wav = gr.Audio(
116
  sources=["upload", "microphone"],
117
  type="filepath",
118
  label="Reference Audio File (Optional)",
119
  value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
120
  )
121
- exaggeration = gr.Slider(
122
- 0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
123
- )
124
- cfg_weight = gr.Slider(
125
- 0.2, 1, step=.05, label="CFG/Pace", value=0.5
126
- )
127
 
 
 
 
 
128
  with gr.Accordion("More options", open=False):
129
- seed_num = gr.Number(value=0, label="Random seed (0 for random)")
 
 
 
 
 
 
 
 
 
130
  temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
131
  vad_trim = gr.Checkbox(label="Ref VAD trimming", value=False)
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  run_btn = gr.Button("Generate", variant="primary")
134
 
135
  with gr.Column():
@@ -145,8 +197,13 @@ with gr.Blocks() as demo:
145
  seed_num,
146
  cfg_weight,
147
  vad_trim,
 
 
 
 
 
 
148
  ],
149
- outputs=[audio_output],
150
  )
151
 
152
  demo.launch(mcp_server=True)
 
1
  import random
2
  import numpy as np
3
  import torch
 
4
  import gradio as gr
5
  import spaces
6
+ from chatterbox.src.chatterbox.tts import ChatterboxTTS
7
 
8
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
9
  print(f"🚀 Running on device: {DEVICE}")
10
 
11
+ # ---------------------------------------
12
+ # GLOBAL MODEL LOAD
13
+ # ---------------------------------------
14
  MODEL = None
15
 
16
  def get_or_load_model():
 
 
17
  global MODEL
18
  if MODEL is None:
19
  print("Model not loaded, initializing...")
20
  try:
21
  MODEL = ChatterboxTTS.from_pretrained(DEVICE)
22
+ if hasattr(MODEL, "to") and str(MODEL.device) != DEVICE:
23
  MODEL.to(DEVICE)
24
+ print("Model loaded successfully.")
25
  except Exception as e:
26
  print(f"Error loading model: {e}")
27
  raise
28
  return MODEL
29
 
30
+
31
  try:
32
  get_or_load_model()
33
  except Exception as e:
34
+ print(f"CRITICAL startup load failed: {e}")
35
+
36
+
37
+ # ---------------------------------------
38
+ # UTILITIES
39
+ # ---------------------------------------
40
 
41
  def set_seed(seed: int):
 
42
  torch.manual_seed(seed)
43
  if DEVICE == "cuda":
44
  torch.cuda.manual_seed(seed)
 
46
  random.seed(seed)
47
  np.random.seed(seed)
48
 
49
+ def chunk_text(text: str, chunk_size: int):
50
+ return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]
51
+
52
+ def concat_audio(chunks):
53
+ if not chunks:
54
+ return None
55
+ return np.concatenate(chunks, axis=-1)
56
+
57
+
58
+ # ---------------------------------------
59
+ # MAIN TTS FUNCTION
60
+ # ---------------------------------------
61
+
62
  @spaces.GPU
63
  def generate_tts_audio(
64
  text_input: str,
 
68
  seed_num_input: int = 0,
69
  cfgw_input: float = 0.5,
70
  vad_trim_input: bool = False,
71
+ enable_chunking: bool = False,
72
+ chunk_size_value: int = 250,
73
+ ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ current_model = get_or_load_model()
76
  if current_model is None:
77
  raise RuntimeError("TTS model is not loaded.")
78
 
79
+ # -------------------------
80
+ # SEED HANDLING
81
+ # -------------------------
82
+ if seed_num_input == 0:
83
+ used_seed = random.randint(1, 2**31 - 1)
84
+ else:
85
+ used_seed = int(seed_num_input)
86
 
87
+ print(f"Using seed: {used_seed}")
88
+ set_seed(used_seed)
89
+
90
+ print(f"Generating audio for text (preview): '{text_input[:50]}...'")
91
 
 
92
  generate_kwargs = {
93
  "exaggeration": exaggeration_input,
94
  "temperature": temperature_input,
95
  "cfg_weight": cfgw_input,
96
  "vad_trim": vad_trim_input,
97
  }
 
98
  if audio_prompt_path_input:
99
  generate_kwargs["audio_prompt_path"] = audio_prompt_path_input
100
 
101
+ # -------------------------
102
+ # CHUNK PROCESSING
103
+ # -------------------------
104
+ if enable_chunking:
105
+ print(f"Chunking enabled — chunk size = {chunk_size_value}")
106
+ text_chunks = chunk_text(text_input, int(chunk_size_value))
107
+ else:
108
+ text_chunks = [text_input]
109
+
110
+ audio_segments = []
111
+ for i, chunk in enumerate(text_chunks):
112
+ print(f"Rendering chunk {i+1}/{len(text_chunks)}...")
113
+ wav = current_model.generate(chunk, **generate_kwargs)
114
+ audio_segments.append(wav.squeeze(0).numpy())
115
+
116
+ final_audio = concat_audio(audio_segments)
117
  print("Audio generation complete.")
118
+
119
+ return current_model.sr, final_audio, used_seed
120
+
121
+
122
+ # ---------------------------------------
123
+ # UI
124
+ # ---------------------------------------
125
 
126
  with gr.Blocks() as demo:
127
  gr.Markdown(
128
  """
129
+ # Chatterbox TTS Demo — Enhanced Version
130
+ Supports unlimited text, chunking & random seed viewer.
131
  """
132
  )
133
+
134
  with gr.Row():
135
  with gr.Column():
136
+
137
+ # MAIN TEXT
138
  text = gr.Textbox(
139
+ value="Now let's make my mum's favourite...",
140
+ label="Text to synthesize",
141
+ max_lines=10
142
  )
143
+
144
+ # REFERENCE AUDIO
145
  ref_wav = gr.Audio(
146
  sources=["upload", "microphone"],
147
  type="filepath",
148
  label="Reference Audio File (Optional)",
149
  value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
150
  )
 
 
 
 
 
 
151
 
152
+ exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration", value=.5)
153
+ cfg_weight = gr.Slider(0.2, 1, step=.05, label="CFG/Pace", value=0.5)
154
+
155
+ # ADVANCED OPTIONS
156
  with gr.Accordion("More options", open=False):
157
+
158
+ seed_num = gr.Number(value=0, label="Random seed (0 = random)")
159
+
160
+ # NEW — SEED DISPLAY (READ ONLY)
161
+ seed_display = gr.Textbox(
162
+ value="",
163
+ label="Seed Used (auto-filled)",
164
+ interactive=False
165
+ )
166
+
167
  temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
168
  vad_trim = gr.Checkbox(label="Ref VAD trimming", value=False)
169
 
170
+ # NEW — ENABLE CHUNKING
171
+ enable_chunking = gr.Checkbox(
172
+ label="Enable Text Chunking (split long text)",
173
+ value=False
174
+ )
175
+
176
+ # NEW — CHUNK SIZE SLIDER
177
+ chunk_size = gr.Slider(
178
+ minimum=100,
179
+ maximum=300,
180
+ value=250,
181
+ step=10,
182
+ label="Chunk Size (characters) — Text chunking for long conversations"
183
+ )
184
+
185
  run_btn = gr.Button("Generate", variant="primary")
186
 
187
  with gr.Column():
 
197
  seed_num,
198
  cfg_weight,
199
  vad_trim,
200
+ enable_chunking,
201
+ chunk_size,
202
+ ],
203
+ outputs=[
204
+ audio_output,
205
+ seed_display, # NEW: seed returned to UI
206
  ],
 
207
  )
208
 
209
  demo.launch(mcp_server=True)