mrfakename committed on
Commit cb99a43 · verified · 1 Parent(s): 60340e5

Update app.py

Files changed (1):
  1. app.py +1060 -56
app.py CHANGED
@@ -1,22 +1,54 @@
- import spaces
  import gradio as gr
  import numpy as np
- import torch
- import tempfile
  import soundfile as sf
  import torchaudio
- import json
- import uuid

- from f5_tts.model import DiT
  from f5_tts.infer.utils_infer import (
-     load_vocoder,
      load_model,
      preprocess_ref_audio_text,
-     infer_process,
      remove_silence_for_generated_wav,
      save_spectrogram,
  )

  DEFAULT_TTS_MODEL_CFG = [
      "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",
@@ -25,26 +57,125 @@ DEFAULT_TTS_MODEL_CFG = [
  ]


- from cached_path import cached_path
- ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
- F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
- F5TTS_ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_path)


- vocoder = load_vocoder()

- @spaces.GPU
- def infer(ref_audio_orig, ref_text, gen_text, remove_silence, seed, cross_fade_duration=0.15, nfe_step=32, speed=1):
      if not ref_audio_orig:
-         return None, None, ref_text
      torch.manual_seed(seed)
      if not gen_text.strip():
-         return None, None, ref_text
-     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text)
-     ema_model = F5TTS_ema_model
      final_wave, final_sample_rate, combined_spectrogram = infer_process(
          ref_audio,
          ref_text,
@@ -54,50 +185,139 @@ def infer(ref_audio_orig, ref_text, gen_text, remove_silence, seed, cross_fade_d
          cross_fade_duration=cross_fade_duration,
          nfe_step=nfe_step,
          speed=speed,
-         show_info=print,
          progress=gr.Progress(),
      )
      if remove_silence:
-         temp_wav_path = tempfile.gettempdir() + "/" + str(uuid.uuid4()) + ".wav"
-         sf.write(temp_wav_path, final_wave, final_sample_rate)
-         remove_silence_for_generated_wav(temp_wav_path)
-         final_wave, _ = torchaudio.load(temp_wav_path)
      final_wave = final_wave.squeeze().cpu().numpy()
-     return (final_sample_rate, final_wave), ref_text

- def basic_tts(ref_audio, ref_text, gen_text, remove_silence, seed, cross_fade_duration, nfe_step, speed):
-     if seed is None or seed < 0 or seed > 2**31 - 1:
-         seed = np.random.randint(0, 2**31 - 1)
-     audio_out, ref_text_out = infer(
-         ref_audio,
-         ref_text,
-         gen_text,
-         remove_silence,
-         seed=seed,
-         cross_fade_duration=cross_fade_duration,
-         nfe_step=nfe_step,
-         speed=speed,
-     )
-     return audio_out, ref_text_out, seed

- with gr.Blocks() as app:
-     gr.Markdown("""
-     # F5-TTS Simple Demo
-
-     Hi everyone, sorry for the issues lately - looks like there are some temporary issues with ZeroGPU on Hugging Face. For now I've switched to a very simple voice cloning demo without voice chat, etc. - once these issues are resolved I plan to switch back to the full demo!
-
-     Upload a reference audio and enter text to synthesize speech in the reference voice.
-     """)
      ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
-     gen_text_input = gr.Textbox(label="Text to Generate", lines=4)
-     ref_text_input = gr.Textbox(label="Reference Text (optional)", lines=2)
-     remove_silence = gr.Checkbox(label="Remove Silences", value=False)
-     seed_input = gr.Number(label="Seed (optional)", value=0, precision=0)
-     cross_fade_duration = gr.Slider(label="Cross-Fade Duration (s)", minimum=0.0, maximum=1.0, value=0.15, step=0.01)
-     nfe_slider = gr.Slider(label="NFE Steps", minimum=4, maximum=64, value=32, step=2)
-     speed_slider = gr.Slider(label="Speed", minimum=0.3, maximum=2.0, value=1.0, step=0.1)
      generate_btn = gr.Button("Synthesize", variant="primary")
      audio_output = gr.Audio(label="Synthesized Audio")
      generate_btn.click(
          basic_tts,
          inputs=[
@@ -105,13 +325,797 @@ with gr.Blocks() as app:
              ref_text_input,
              gen_text_input,
              remove_silence,
              seed_input,
-             cross_fade_duration,
              nfe_slider,
              speed_slider,
          ],
-         outputs=[audio_output, ref_text_input, seed_input],
      )

  if __name__ == "__main__":
-     app.queue().launch()
+ # ruff: noqa: E402
+ # Above allows ruff to ignore E402: module level import not at top of file
+
+ import gc
+ import json
+ import os
+ import re
+ import tempfile
+ from collections import OrderedDict
+ from functools import lru_cache
+ from importlib.resources import files
+
+ import click
  import gradio as gr
  import numpy as np
  import soundfile as sf
+ import torch
  import torchaudio
+ from cached_path import cached_path
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+ try:
+     import spaces
+
+     USING_SPACES = True
+ except ImportError:
+     USING_SPACES = False
+
+
+ def gpu_decorator(func):
+     if USING_SPACES:
+         return spaces.GPU(func)
+     else:
+         return func
+
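+ # gpu_decorator is applied to every inference entry point below: on Hugging Face
+ # Spaces it schedules the call on ZeroGPU via spaces.GPU; in a local run it is a no-op.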

  from f5_tts.infer.utils_infer import (
+     infer_process,
      load_model,
+     load_vocoder,
      preprocess_ref_audio_text,
      remove_silence_for_generated_wav,
      save_spectrogram,
+     tempfile_kwargs,
  )
+ from f5_tts.model import DiT, UNetT
+
+
+ DEFAULT_TTS_MODEL = "F5-TTS_v1"
+ tts_model_choice = DEFAULT_TTS_MODEL

  DEFAULT_TTS_MODEL_CFG = [
      "hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors",

  ]


+ # load models
+
+ vocoder = load_vocoder()


+ def load_f5tts():
+     ckpt_path = str(cached_path(DEFAULT_TTS_MODEL_CFG[0]))
+     F5TTS_model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
+     return load_model(DiT, F5TTS_model_cfg, ckpt_path)


+ def load_e2tts():
+     ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
+     E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4, text_mask_padding=False, pe_attn_head=1)
+     return load_model(UNetT, E2TTS_model_cfg, ckpt_path)


+ def load_custom(ckpt_path: str, vocab_path="", model_cfg=None):
+     ckpt_path, vocab_path = ckpt_path.strip(), vocab_path.strip()
+     if ckpt_path.startswith("hf://"):
+         ckpt_path = str(cached_path(ckpt_path))
+     if vocab_path.startswith("hf://"):
+         vocab_path = str(cached_path(vocab_path))
+     if model_cfg is None:
+         model_cfg = json.loads(DEFAULT_TTS_MODEL_CFG[2])
+     elif isinstance(model_cfg, str):
+         model_cfg = json.loads(model_cfg)
+     return load_model(DiT, model_cfg, ckpt_path, vocab_file=vocab_path)
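+ # A hypothetical example: load_custom("hf://user/repo/model.safetensors",
+ # vocab_path="hf://user/repo/vocab.txt") resolves both hf:// paths through
+ # cached_path and falls back to the default DiT config when none is given.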
+
+
+ F5TTS_ema_model = load_f5tts()
+ E2TTS_ema_model = load_e2tts() if USING_SPACES else None
+ custom_ema_model, pre_custom_path = None, ""
+
+ chat_model_state = None
+ chat_tokenizer_state = None
+
+
+ @gpu_decorator
+ def chat_model_inference(messages, model, tokenizer):
+     """Generate a response using the loaded chat model"""
+     text = tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True,
+     )
+
+     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+     generated_ids = model.generate(
+         **model_inputs,
+         max_new_tokens=512,
+         temperature=0.7,
+         top_p=0.95,
+     )
+
+     generated_ids = [
+         output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+     ]
+     return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
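+ # Note: only the newly generated continuation is decoded; the prompt tokens are
+ # sliced off first, so the return value is just the assistant's reply text.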
+
+
+ @gpu_decorator
+ def load_text_from_file(file):
+     if file:
+         with open(file, "r", encoding="utf-8") as f:
+             text = f.read().strip()
+     else:
+         text = ""
+     return gr.update(value=text)
+
+
+ @lru_cache(maxsize=1000)  # NOTE: all params of infer() must be hashable for caching
+ @gpu_decorator
+ def infer(
+     ref_audio_orig,
+     ref_text,
+     gen_text,
+     model,
+     remove_silence,
+     seed,
+     cross_fade_duration=0.15,
+     nfe_step=32,
+     speed=1,
+     show_info=gr.Info,
+ ):
      if not ref_audio_orig:
+         gr.Warning("Please provide reference audio.")
+         return gr.update(), gr.update(), ref_text, seed
+
+     # Set inference seed
+     if seed < 0 or seed > 2**31 - 1:
+         gr.Warning("Seed must be in range 0 ~ 2147483647. Using a random seed instead.")
+         seed = np.random.randint(0, 2**31 - 1)
      torch.manual_seed(seed)
+     used_seed = seed
+
      if not gen_text.strip():
+         gr.Warning("Please enter text to generate or upload a text file.")
+         return gr.update(), gr.update(), ref_text, used_seed
+
+     ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
+
+     if model == DEFAULT_TTS_MODEL:
+         ema_model = F5TTS_ema_model
+     elif model == "E2-TTS":
+         global E2TTS_ema_model
+         if E2TTS_ema_model is None:
+             show_info("Loading E2-TTS model...")
+             E2TTS_ema_model = load_e2tts()
+         ema_model = E2TTS_ema_model
+     elif isinstance(model, tuple) and model[0] == "Custom":
+         assert not USING_SPACES, "Only official checkpoints allowed in Spaces."
+         global custom_ema_model, pre_custom_path
+         if pre_custom_path != model[1]:
+             show_info("Loading Custom TTS model...")
+             custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3])
+             pre_custom_path = model[1]
+         ema_model = custom_ema_model
+
      final_wave, final_sample_rate, combined_spectrogram = infer_process(
          ref_audio,
          ref_text,

          cross_fade_duration=cross_fade_duration,
          nfe_step=nfe_step,
          speed=speed,
+         show_info=show_info,
          progress=gr.Progress(),
      )
+
+     # Remove silence
      if remove_silence:
+         with tempfile.NamedTemporaryFile(suffix=".wav", **tempfile_kwargs) as f:
+             temp_path = f.name
+         try:
+             sf.write(temp_path, final_wave, final_sample_rate)
+             remove_silence_for_generated_wav(f.name)
+             final_wave, _ = torchaudio.load(f.name)
+         finally:
+             os.unlink(temp_path)
      final_wave = final_wave.squeeze().cpu().numpy()

+     # Save the spectrogram
+     with tempfile.NamedTemporaryFile(suffix=".png", **tempfile_kwargs) as tmp_spectrogram:
+         spectrogram_path = tmp_spectrogram.name
+     save_spectrogram(combined_spectrogram, spectrogram_path)

+     return (final_sample_rate, final_wave), spectrogram_path, ref_text, used_seed
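+ # Because infer() sits behind lru_cache, calling it again with identical
+ # (hashable) arguments, seed included, replays the cached result instead of
+ # re-synthesizing; the multi-style cherry-pick workflow below relies on this.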
 
 

+
+ with gr.Blocks() as app_tts:
+     gr.Markdown("# Batched TTS")
      ref_audio_input = gr.Audio(label="Reference Audio", type="filepath")
+     with gr.Row():
+         gen_text_input = gr.Textbox(
+             label="Text to Generate",
+             lines=10,
+             max_lines=40,
+             scale=4,
+         )
+         gen_text_file = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
      generate_btn = gr.Button("Synthesize", variant="primary")
+     with gr.Accordion("Advanced Settings", open=False):
+         with gr.Row():
+             ref_text_input = gr.Textbox(
+                 label="Reference Text",
+                 info="Leave blank to automatically transcribe the reference audio. If you enter text or upload a file, it will override automatic transcription.",
+                 lines=2,
+                 scale=4,
+             )
+             ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1)
+         with gr.Row():
+             randomize_seed = gr.Checkbox(
+                 label="Randomize Seed",
+                 info="Check to use a random seed for each generation. Uncheck to use the seed specified.",
+                 value=True,
+                 scale=3,
+             )
+             seed_input = gr.Number(show_label=False, value=0, precision=0, scale=1)
+             with gr.Column(scale=4):
+                 remove_silence = gr.Checkbox(
+                     label="Remove Silences",
+                     info="If undesired long silences are produced, turn on to automatically detect and crop them.",
+                     value=False,
+                 )
+         speed_slider = gr.Slider(
+             label="Speed",
+             minimum=0.3,
+             maximum=2.0,
+             value=1.0,
+             step=0.1,
+             info="Adjust the speed of the audio.",
+         )
+         nfe_slider = gr.Slider(
+             label="NFE Steps",
+             minimum=4,
+             maximum=64,
+             value=32,
+             step=2,
+             info="Set the number of denoising steps.",
+         )
+         cross_fade_duration_slider = gr.Slider(
+             label="Cross-Fade Duration (s)",
+             minimum=0.0,
+             maximum=1.0,
+             value=0.15,
+             step=0.01,
+             info="Set the duration of the cross-fade between audio clips.",
+         )
+
      audio_output = gr.Audio(label="Synthesized Audio")
+     spectrogram_output = gr.Image(label="Spectrogram")
+
+     @gpu_decorator
+     def basic_tts(
+         ref_audio_input,
+         ref_text_input,
+         gen_text_input,
+         remove_silence,
+         randomize_seed,
+         seed_input,
+         cross_fade_duration_slider,
+         nfe_slider,
+         speed_slider,
+     ):
+         if randomize_seed:
+             seed_input = np.random.randint(0, 2**31 - 1)
+
+         audio_out, spectrogram_path, ref_text_out, used_seed = infer(
+             ref_audio_input,
+             ref_text_input,
+             gen_text_input,
+             tts_model_choice,
+             remove_silence,
+             seed=seed_input,
+             cross_fade_duration=cross_fade_duration_slider,
+             nfe_step=nfe_slider,
+             speed=speed_slider,
+         )
+         return audio_out, spectrogram_path, ref_text_out, used_seed
+
303
+ gen_text_file.upload(
304
+ load_text_from_file,
305
+ inputs=[gen_text_file],
306
+ outputs=[gen_text_input],
307
+ )
308
+
309
+ ref_text_file.upload(
310
+ load_text_from_file,
311
+ inputs=[ref_text_file],
312
+ outputs=[ref_text_input],
313
+ )
314
+
315
+ ref_audio_input.clear(
316
+ lambda: [None, None],
317
+ None,
318
+ [ref_text_input, ref_text_file],
319
+ )
320
+
321
  generate_btn.click(
322
  basic_tts,
323
  inputs=[
 
325
  ref_text_input,
326
  gen_text_input,
327
  remove_silence,
328
+ randomize_seed,
329
  seed_input,
330
+ cross_fade_duration_slider,
331
  nfe_slider,
332
  speed_slider,
333
  ],
334
+ outputs=[audio_output, spectrogram_output, ref_text_input, seed_input],
335
+ )
+
+
+ def parse_speechtypes_text(gen_text):
+     # Pattern to find {str} or {"name": str, "seed": int, "speed": float}
+     pattern = r"(\{.*?\})"
+
+     # Split the text by the pattern
+     tokens = re.split(pattern, gen_text)
+
+     segments = []
+
+     current_type_dict = {
+         "name": "Regular",
+         "seed": -1,
+         "speed": 1.0,
+     }
+
+     for i in range(len(tokens)):
+         if i % 2 == 0:
+             # This is text
+             text = tokens[i].strip()
+             if text:
+                 current_type_dict["text"] = text
+                 segments.append(current_type_dict)
+         else:
+             # This is a type marker
+             type_str = tokens[i].strip()
+             try:  # if it is a type dict
+                 current_type_dict = json.loads(type_str)
+             except json.decoder.JSONDecodeError:
+                 type_str = type_str[1:-1]  # remove the braces {}
+                 current_type_dict = {"name": type_str, "seed": -1, "speed": 1.0}
+
+     return segments
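+ # Example:
+ #     parse_speechtypes_text('{Happy} Hello there! {"name": "Sad", "seed": 7, "speed": 0.9} Bye.')
+ # yields two segments:
+ #     [{"name": "Happy", "seed": -1, "speed": 1.0, "text": "Hello there!"},
+ #      {"name": "Sad", "seed": 7, "speed": 0.9, "text": "Bye."}]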
+
+
+ with gr.Blocks() as app_multistyle:
+     # New section for multistyle generation
+     gr.Markdown(
+         """
+     # Multiple Speech-Type Generation
+
+     This section allows you to generate multiple speech types or multiple people's voices. Enter your text in the format shown below, or upload a .txt file with the same format. The system will generate speech using the appropriate type. If unspecified, the model will use the regular speech type. The current speech type will be used until the next speech type is specified.
+     """
+     )
+
+     with gr.Row():
+         gr.Markdown(
+             """
+             **Example Input:** <br>
+             {Regular} Hello, I'd like to order a sandwich please. <br>
+             {Surprised} What do you mean you're out of bread? <br>
+             {Sad} I really wanted a sandwich though... <br>
+             {Angry} You know what, darn you and your little shop! <br>
+             {Whisper} I'll just go back home and cry now. <br>
+             {Shouting} Why me?!
+             """
+         )
+
+         gr.Markdown(
+             """
+             **Example Input 2:** <br>
+             {"name": "Speaker1_Happy", "seed": -1, "speed": 1} Hello, I'd like to order a sandwich please. <br>
+             {"name": "Speaker2_Regular", "seed": -1, "speed": 1} Sorry, we're out of bread. <br>
+             {"name": "Speaker1_Sad", "seed": -1, "speed": 1} I really wanted a sandwich though... <br>
+             {"name": "Speaker2_Whisper", "seed": -1, "speed": 1} I'll give you the last one I was hiding.
+             """
+         )
+
+     gr.Markdown(
+         'Upload different audio clips for each speech type. The first speech type is mandatory. You can add additional speech types by clicking the "Add Speech Type" button.'
+     )
+
+     # Regular speech type (mandatory)
+     with gr.Row(variant="compact") as regular_row:
+         with gr.Column(scale=1, min_width=160):
+             regular_name = gr.Textbox(value="Regular", label="Speech Type Name")
+             regular_insert = gr.Button("Insert Label", variant="secondary")
+         with gr.Column(scale=3):
+             regular_audio = gr.Audio(label="Regular Reference Audio", type="filepath")
+         with gr.Column(scale=3):
+             regular_ref_text = gr.Textbox(label="Reference Text (Regular)", lines=4)
+             with gr.Row():
+                 regular_seed_slider = gr.Slider(
+                     show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
+                 )
+                 regular_speed_slider = gr.Slider(
+                     show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
+                 )
+         with gr.Column(scale=1, min_width=160):
+             regular_ref_text_file = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
+
+     # Bookkeeping lists for all speech types (max 100)
+     max_speech_types = 100
+     speech_type_rows = [regular_row]
+     speech_type_names = [regular_name]
+     speech_type_audios = [regular_audio]
+     speech_type_ref_texts = [regular_ref_text]
+     speech_type_ref_text_files = [regular_ref_text_file]
+     speech_type_seeds = [regular_seed_slider]
+     speech_type_speeds = [regular_speed_slider]
+     speech_type_delete_btns = [None]
+     speech_type_insert_btns = [regular_insert]
+
+     # Additional speech types (99 more)
+     for i in range(max_speech_types - 1):
+         with gr.Row(variant="compact", visible=False) as row:
+             with gr.Column(scale=1, min_width=160):
+                 name_input = gr.Textbox(label="Speech Type Name")
+                 insert_btn = gr.Button("Insert Label", variant="secondary")
+                 delete_btn = gr.Button("Delete Type", variant="stop")
+             with gr.Column(scale=3):
+                 audio_input = gr.Audio(label="Reference Audio", type="filepath")
+             with gr.Column(scale=3):
+                 ref_text_input = gr.Textbox(label="Reference Text", lines=4)
+                 with gr.Row():
+                     seed_input = gr.Slider(
+                         show_label=False, minimum=-1, maximum=999, value=-1, step=1, info="Seed, -1 for random"
+                     )
+                     speed_input = gr.Slider(
+                         show_label=False, minimum=0.3, maximum=2.0, value=1.0, step=0.1, info="Adjust the speed"
+                     )
+             with gr.Column(scale=1, min_width=160):
+                 ref_text_file_input = gr.File(label="Load Reference Text from File (.txt)", file_types=[".txt"])
+         speech_type_rows.append(row)
+         speech_type_names.append(name_input)
+         speech_type_audios.append(audio_input)
+         speech_type_ref_texts.append(ref_text_input)
+         speech_type_ref_text_files.append(ref_text_file_input)
+         speech_type_seeds.append(seed_input)
+         speech_type_speeds.append(speed_input)
+         speech_type_delete_btns.append(delete_btn)
+         speech_type_insert_btns.append(insert_btn)
+
+     # Global logic for all speech types
+     for i in range(max_speech_types):
+         speech_type_audios[i].clear(
+             lambda: [None, None],
+             None,
+             [speech_type_ref_texts[i], speech_type_ref_text_files[i]],
+         )
+         speech_type_ref_text_files[i].upload(
+             load_text_from_file,
+             inputs=[speech_type_ref_text_files[i]],
+             outputs=[speech_type_ref_texts[i]],
+         )
+
+     # Button to add speech type
+     add_speech_type_btn = gr.Button("Add Speech Type")
+
+     # Count of visible speech types; it only grows (deleted rows are hidden, not reclaimed)
+     speech_type_count = 1
+
+     # Function to add a speech type
+     def add_speech_type_fn():
+         row_updates = [gr.update() for _ in range(max_speech_types)]
+         global speech_type_count
+         if speech_type_count < max_speech_types:
+             row_updates[speech_type_count] = gr.update(visible=True)
+             speech_type_count += 1
+         else:
+             gr.Warning("Exhausted maximum number of speech types. Consider restarting the app.")
+         return row_updates
+
+     add_speech_type_btn.click(add_speech_type_fn, outputs=speech_type_rows)
+
+     # Function to delete a speech type
+     def delete_speech_type_fn():
+         return gr.update(visible=False), None, None, None, None
+
+     # Update delete button clicks and ref text file changes
+     for i in range(1, len(speech_type_delete_btns)):
+         speech_type_delete_btns[i].click(
+             delete_speech_type_fn,
+             outputs=[
+                 speech_type_rows[i],
+                 speech_type_names[i],
+                 speech_type_audios[i],
+                 speech_type_ref_texts[i],
+                 speech_type_ref_text_files[i],
+             ],
+         )
+
+     # Text input for the prompt
+     with gr.Row():
+         gen_text_input_multistyle = gr.Textbox(
+             label="Text to Generate",
+             lines=10,
+             max_lines=40,
+             scale=4,
+             placeholder="Enter the script with speaker names (or emotion types) at the start of each block, e.g.:\n\n{Regular} Hello, I'd like to order a sandwich please.\n{Surprised} What do you mean you're out of bread?\n{Sad} I really wanted a sandwich though...\n{Angry} You know what, darn you and your little shop!\n{Whisper} I'll just go back home and cry now.\n{Shouting} Why me?!",
+         )
+         gen_text_file_multistyle = gr.File(label="Load Text to Generate from File (.txt)", file_types=[".txt"], scale=1)
+
+     def make_insert_speech_type_fn(index):
+         def insert_speech_type_fn(current_text, speech_type_name, speech_type_seed, speech_type_speed):
+             current_text = current_text or ""
+             if not speech_type_name:
+                 gr.Warning("Please enter a speech type name before inserting.")
+                 return current_text
+             speech_type_dict = {
+                 "name": speech_type_name,
+                 "seed": speech_type_seed,
+                 "speed": speech_type_speed,
+             }
+             updated_text = current_text + json.dumps(speech_type_dict) + " "
+             return updated_text
+
+         return insert_speech_type_fn
+
+     for i, insert_btn in enumerate(speech_type_insert_btns):
+         insert_fn = make_insert_speech_type_fn(i)
+         insert_btn.click(
+             insert_fn,
+             inputs=[gen_text_input_multistyle, speech_type_names[i], speech_type_seeds[i], speech_type_speeds[i]],
+             outputs=gen_text_input_multistyle,
+         )
+
+     with gr.Accordion("Advanced Settings", open=True):
+         with gr.Row():
+             with gr.Column():
+                 show_cherrypick_multistyle = gr.Checkbox(
+                     label="Show Cherry-pick Interface",
+                     info="Turn on to show an interface for picking seeds from previous generations.",
+                     value=False,
+                 )
+             with gr.Column():
+                 remove_silence_multistyle = gr.Checkbox(
+                     label="Remove Silences",
+                     info="Turn on to automatically detect and crop long silences.",
+                     value=True,
+                 )
+
+     # Generate button
+     generate_multistyle_btn = gr.Button("Generate Multi-Style Speech", variant="primary")
+
+     # Output audio
+     audio_output_multistyle = gr.Audio(label="Synthesized Audio")
+
+     # Used seed gallery
+     cherrypick_interface_multistyle = gr.Textbox(
+         label="Cherry-pick Interface",
+         lines=10,
+         max_lines=40,
+         show_copy_button=True,
+         interactive=False,
+         visible=False,
+     )
+
+     # Logic control to show/hide the cherrypick interface
+     show_cherrypick_multistyle.change(
+         lambda is_visible: gr.update(visible=is_visible),
+         show_cherrypick_multistyle,
+         cherrypick_interface_multistyle,
+     )
+
+     # Function to load text to generate from file
+     gen_text_file_multistyle.upload(
+         load_text_from_file,
+         inputs=[gen_text_file_multistyle],
+         outputs=[gen_text_input_multistyle],
+     )
+
+     @gpu_decorator
+     def generate_multistyle_speech(
+         gen_text,
+         *args,
+     ):
+         speech_type_names_list = args[:max_speech_types]
+         speech_type_audios_list = args[max_speech_types : 2 * max_speech_types]
+         speech_type_ref_texts_list = args[2 * max_speech_types : 3 * max_speech_types]
+         remove_silence = args[3 * max_speech_types]
+         # Collect the speech types and their audios into a dict
+         speech_types = OrderedDict()
+
+         ref_text_idx = 0
+         for name_input, audio_input, ref_text_input in zip(
+             speech_type_names_list, speech_type_audios_list, speech_type_ref_texts_list
+         ):
+             if name_input and audio_input:
+                 speech_types[name_input] = {"audio": audio_input, "ref_text": ref_text_input}
+             else:
+                 speech_types[f"@{ref_text_idx}@"] = {"audio": "", "ref_text": ""}
+             ref_text_idx += 1
+
+         # Parse the gen_text into segments
+         segments = parse_speechtypes_text(gen_text)
+
+         # For each segment, generate speech
+         generated_audio_segments = []
+         current_type_name = "Regular"
+         inference_meta_data = ""
+
+         for segment in segments:
+             name = segment["name"]
+             seed_input = segment["seed"]
+             speed = segment["speed"]
+             text = segment["text"]
+
+             if name in speech_types:
+                 current_type_name = name
+             else:
+                 gr.Warning(f"Type {name} is not available; using Regular as default.")
+                 current_type_name = "Regular"
+
+             try:
+                 ref_audio = speech_types[current_type_name]["audio"]
+             except KeyError:
+                 gr.Warning(f"Please provide reference audio for type {current_type_name}.")
+                 return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
+             ref_text = speech_types[current_type_name].get("ref_text", "")
+
+             if seed_input == -1:
+                 seed_input = np.random.randint(0, 2**31 - 1)
+
+             # Generate or retrieve speech for this segment
+             audio_out, _, ref_text_out, used_seed = infer(
+                 ref_audio,
+                 ref_text,
+                 text,
+                 tts_model_choice,
+                 remove_silence,
+                 seed=seed_input,
+                 cross_fade_duration=0,
+                 speed=speed,
+                 show_info=print,  # print instead of gr.Info: avoids scrolling the page to top while generating
+             )
+             sr, audio_data = audio_out
+
+             generated_audio_segments.append(audio_data)
+             speech_types[current_type_name]["ref_text"] = ref_text_out
+             inference_meta_data += json.dumps(dict(name=name, seed=used_seed, speed=speed)) + f" {text}\n"
+
+         # Concatenate all audio segments
+         if generated_audio_segments:
+             final_audio_data = np.concatenate(generated_audio_segments)
+             return (
+                 [(sr, final_audio_data)]
+                 + [speech_types[name]["ref_text"] for name in speech_types]
+                 + [inference_meta_data]
+             )
+         else:
+             gr.Warning("No audio generated.")
+             return [None] + [speech_types[name]["ref_text"] for name in speech_types] + [None]
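+     # The returned lists line up positionally with the click() outputs below:
+     # one audio track, one ref-text per speech-type slot, then the cherry-pick log.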
+
+     generate_multistyle_btn.click(
+         generate_multistyle_speech,
+         inputs=[
+             gen_text_input_multistyle,
+         ]
+         + speech_type_names
+         + speech_type_audios
+         + speech_type_ref_texts
+         + [
+             remove_silence_multistyle,
+         ],
+         outputs=[audio_output_multistyle] + speech_type_ref_texts + [cherrypick_interface_multistyle],
+     )
+
+     # Validation function to disable Generate button if speech types are missing
+     def validate_speech_types(gen_text, regular_name, *args):
+         speech_type_names_list = args
+
+         # Collect the speech type names
+         speech_types_available = set()
+         if regular_name:
+             speech_types_available.add(regular_name)
+         for name_input in speech_type_names_list:
+             if name_input:
+                 speech_types_available.add(name_input)
+
+         # Parse the gen_text to get the speech types used
+         segments = parse_speechtypes_text(gen_text)
+         speech_types_in_text = set(segment["name"] for segment in segments)
+
+         # Check if all speech types in text are available
+         missing_speech_types = speech_types_in_text - speech_types_available
+
+         if missing_speech_types:
+             # Disable the generate button
+             return gr.update(interactive=False)
+         else:
+             # Enable the generate button
+             return gr.update(interactive=True)
+
+     gen_text_input_multistyle.change(
+         validate_speech_types,
+         inputs=[gen_text_input_multistyle, regular_name] + speech_type_names,
+         outputs=generate_multistyle_btn,
+     )
+
+
+ with gr.Blocks() as app_chat:
+     gr.Markdown(
+         """
+ # Voice Chat
+ Have a conversation with an AI using your reference voice!
+ 1. Upload a reference audio clip and optionally its transcript (via text or .txt file).
+ 2. Load the chat model.
+ 3. Record your message through your microphone or type it.
+ 4. The AI will respond using the reference voice.
+ """
+     )
+
+     chat_model_name_list = [
+         "Qwen/Qwen2.5-3B-Instruct",
+         "microsoft/Phi-4-mini-instruct",
+     ]
+
+     @gpu_decorator
+     def load_chat_model(chat_model_name):
+         show_info = gr.Info
+         global chat_model_state, chat_tokenizer_state
+         if chat_model_state is not None:
+             chat_model_state = None
+             chat_tokenizer_state = None
+             gc.collect()
+             torch.cuda.empty_cache()
+
+         show_info(f"Loading chat model: {chat_model_name}")
+         chat_model_state = AutoModelForCausalLM.from_pretrained(chat_model_name, torch_dtype="auto", device_map="auto")
+         chat_tokenizer_state = AutoTokenizer.from_pretrained(chat_model_name)
+         show_info(f"Chat model {chat_model_name} loaded successfully!")
+
+         return gr.update(visible=False), gr.update(visible=True)
+
+     if USING_SPACES:
+         load_chat_model(chat_model_name_list[0])
+
+     chat_model_name_input = gr.Dropdown(
+         choices=chat_model_name_list,
+         value=chat_model_name_list[0],
+         label="Chat Model Name",
+         info="Enter the name of a HuggingFace chat model",
+         allow_custom_value=not USING_SPACES,
+     )
+     load_chat_model_btn = gr.Button("Load Chat Model", variant="primary", visible=not USING_SPACES)
+     chat_interface_container = gr.Column(visible=USING_SPACES)
+
+     chat_model_name_input.change(
+         lambda: gr.update(visible=True),
+         None,
+         load_chat_model_btn,
+         show_progress="hidden",
+     )
+     load_chat_model_btn.click(
+         load_chat_model, inputs=[chat_model_name_input], outputs=[load_chat_model_btn, chat_interface_container]
+     )
+
+     with chat_interface_container:
+         with gr.Row():
+             with gr.Column():
+                 ref_audio_chat = gr.Audio(label="Reference Audio", type="filepath")
+             with gr.Column():
+                 with gr.Accordion("Advanced Settings", open=False):
+                     with gr.Row():
+                         ref_text_chat = gr.Textbox(
+                             label="Reference Text",
+                             info="Optional: Leave blank to auto-transcribe",
+                             lines=2,
+                             scale=3,
+                         )
+                         ref_text_file_chat = gr.File(
+                             label="Load Reference Text from File (.txt)", file_types=[".txt"], scale=1
+                         )
+                     with gr.Row():
+                         randomize_seed_chat = gr.Checkbox(
+                             label="Randomize Seed",
+                             value=True,
+                             info="Uncheck to use the seed specified.",
+                             scale=3,
+                         )
+                         seed_input_chat = gr.Number(show_label=False, value=0, precision=0, scale=1)
+                     remove_silence_chat = gr.Checkbox(
+                         label="Remove Silences",
+                         value=True,
+                     )
+                     system_prompt_chat = gr.Textbox(
+                         label="System Prompt",
+                         value="You are not an AI assistant, you are whoever the user says you are. You must stay in character. Keep your responses concise since they will be spoken out loud.",
+                         lines=2,
+                     )
+
+         chatbot_interface = gr.Chatbot(label="Conversation", type="messages")
+
+         with gr.Row():
+             with gr.Column():
+                 audio_input_chat = gr.Microphone(
+                     label="Speak your message",
+                     type="filepath",
+                 )
+                 audio_output_chat = gr.Audio(autoplay=True)
+             with gr.Column():
+                 text_input_chat = gr.Textbox(
+                     label="Type your message",
+                     lines=1,
+                 )
+                 send_btn_chat = gr.Button("Send Message")
+                 clear_btn_chat = gr.Button("Clear Conversation")
+
+         # Append the user's turn (speech or text) to the conversation state
+         @gpu_decorator
+         def process_audio_input(conv_state, audio_path, text):
+             """Handle audio or text input from user"""
+
+             if not audio_path and not text.strip():
+                 return conv_state
+
+             if audio_path:
+                 text = preprocess_ref_audio_text(audio_path, text)[1]
+             if not text.strip():
+                 return conv_state
+
+             conv_state.append({"role": "user", "content": text})
+             return conv_state
+
+         # Use model and tokenizer from state to get text response
+         @gpu_decorator
+         def generate_text_response(conv_state, system_prompt):
+             """Generate text response from AI"""
+
+             system_prompt_state = [{"role": "system", "content": system_prompt}]
+             response = chat_model_inference(system_prompt_state + conv_state, chat_model_state, chat_tokenizer_state)
+
+             conv_state.append({"role": "assistant", "content": response})
+             return conv_state
+
+         @gpu_decorator
+         def generate_audio_response(conv_state, ref_audio, ref_text, remove_silence, randomize_seed, seed_input):
+             """Generate TTS audio for AI response"""
+             if not conv_state or not ref_audio:
+                 return None, ref_text, seed_input
+
+             last_ai_response = conv_state[-1]["content"]
+             if not last_ai_response or conv_state[-1]["role"] != "assistant":
+                 return None, ref_text, seed_input
+
+             if randomize_seed:
+                 seed_input = np.random.randint(0, 2**31 - 1)
+
+             audio_result, _, ref_text_out, used_seed = infer(
+                 ref_audio,
+                 ref_text,
+                 last_ai_response,
+                 tts_model_choice,
+                 remove_silence,
+                 seed=seed_input,
+                 cross_fade_duration=0.15,
+                 speed=1.0,
+                 show_info=print,  # print instead of gr.Info: avoids scrolling the page to top while generating
+             )
+             return audio_result, ref_text_out, used_seed
+
+         def clear_conversation():
+             """Reset the conversation"""
+             return [], None
+
+         ref_text_file_chat.upload(
+             load_text_from_file,
+             inputs=[ref_text_file_chat],
+             outputs=[ref_text_chat],
+         )
+
+         for user_operation in [audio_input_chat.stop_recording, text_input_chat.submit, send_btn_chat.click]:
+             user_operation(
+                 process_audio_input,
+                 inputs=[chatbot_interface, audio_input_chat, text_input_chat],
+                 outputs=[chatbot_interface],
+             ).then(
+                 generate_text_response,
+                 inputs=[chatbot_interface, system_prompt_chat],
+                 outputs=[chatbot_interface],
+             ).then(
+                 generate_audio_response,
+                 inputs=[
+                     chatbot_interface,
+                     ref_audio_chat,
+                     ref_text_chat,
+                     remove_silence_chat,
+                     randomize_seed_chat,
+                     seed_input_chat,
+                 ],
+                 outputs=[audio_output_chat, ref_text_chat, seed_input_chat],
+             ).then(
+                 lambda: [None, None],
+                 None,
+                 [audio_input_chat, text_input_chat],
+             )
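+         # Each entry point (stop recording, text submit, send button) runs the same
+         # chain: append the user turn, generate the text reply, synthesize it, clear inputs.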
+
+         # Handle clear button or system prompt change and reset conversation
+         for user_operation in [clear_btn_chat.click, system_prompt_chat.change, chatbot_interface.clear]:
+             user_operation(
+                 clear_conversation,
+                 outputs=[chatbot_interface, audio_output_chat],
+             )
+
+
+ with gr.Blocks() as app_credits:
+     gr.Markdown("""
+ # Credits
+
+ * [mrfakename](https://github.com/fakerybakery) for the original [online demo](https://huggingface.co/spaces/mrfakename/E2-F5-TTS)
+ * [RootingInLoad](https://github.com/RootingInLoad) for initial chunk generation and podcast app exploration
+ * [jpgallegoar](https://github.com/jpgallegoar) for multiple speech-type generation & voice chat
+ """)
+
+
+ with gr.Blocks() as app:
+     gr.Markdown(
+         f"""
+ # E2/F5 TTS
+
+ This is {"a local web UI for [F5 TTS](https://github.com/SWivid/F5-TTS)" if not USING_SPACES else "an online demo for [F5-TTS](https://github.com/SWivid/F5-TTS)"} with advanced batch processing support. This app supports the following TTS models:
+
+ * [F5-TTS](https://arxiv.org/abs/2410.06885) (A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching)
+ * [E2 TTS](https://arxiv.org/abs/2406.18009) (Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS)
+
+ The checkpoints currently support English and Chinese.
+
+ If you're having issues, try converting your reference audio to WAV or MP3 and clipping it to 12 s with the ✂ control in the bottom right corner (otherwise the automatic trim may give a suboptimal result).
+
+ **NOTE: Reference text will be automatically transcribed with Whisper if not provided. For best results, keep your reference clips short (<12 s). Ensure the audio is fully uploaded before generating.**
+ """
+     )
+
+     last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info_v1.txt")
+
+     def load_last_used_custom():
+         try:
+             custom = []
+             with open(last_used_custom, "r", encoding="utf-8") as f:
+                 for line in f:
+                     custom.append(line.strip())
+             return custom
+         except FileNotFoundError:
+             last_used_custom.parent.mkdir(parents=True, exist_ok=True)
+             return DEFAULT_TTS_MODEL_CFG
+
+     def switch_tts_model(new_choice):
+         global tts_model_choice
+         if new_choice == "Custom":  # override in case webpage is refreshed
+             custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
+             tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
+             return (
+                 gr.update(visible=True, value=custom_ckpt_path),
+                 gr.update(visible=True, value=custom_vocab_path),
+                 gr.update(visible=True, value=custom_model_cfg),
+             )
+         else:
+             tts_model_choice = new_choice
+             return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+
+     def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
+         global tts_model_choice
+         tts_model_choice = ("Custom", custom_ckpt_path, custom_vocab_path, custom_model_cfg)
+         with open(last_used_custom, "w", encoding="utf-8") as f:
+             f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
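+     # The cache file holds three lines, in the same order as DEFAULT_TTS_MODEL_CFG:
+     # checkpoint path, vocab path, and the model config as a JSON string.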
+
+     with gr.Row():
+         if not USING_SPACES:
+             choose_tts_model = gr.Radio(
+                 choices=[DEFAULT_TTS_MODEL, "E2-TTS", "Custom"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
+             )
+         else:
+             choose_tts_model = gr.Radio(
+                 choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
+             )
+         custom_ckpt_path = gr.Dropdown(
+             choices=[DEFAULT_TTS_MODEL_CFG[0]],
+             value=load_last_used_custom()[0],
+             allow_custom_value=True,
+             label="Model: local_path | hf://user_id/repo_id/model_ckpt",
+             visible=False,
+         )
+         custom_vocab_path = gr.Dropdown(
+             choices=[DEFAULT_TTS_MODEL_CFG[1]],
+             value=load_last_used_custom()[1],
+             allow_custom_value=True,
+             label="Vocab: local_path | hf://user_id/repo_id/vocab_file",
+             visible=False,
+         )
+         custom_model_cfg = gr.Dropdown(
+             choices=[
+                 DEFAULT_TTS_MODEL_CFG[2],
+                 json.dumps(
+                     dict(
+                         dim=1024,
+                         depth=22,
+                         heads=16,
+                         ff_mult=2,
+                         text_dim=512,
+                         text_mask_padding=False,
+                         conv_layers=4,
+                         pe_attn_head=1,
+                     )
+                 ),
+                 json.dumps(
+                     dict(
+                         dim=768,
+                         depth=18,
+                         heads=12,
+                         ff_mult=2,
+                         text_dim=512,
+                         text_mask_padding=False,
+                         conv_layers=4,
+                         pe_attn_head=1,
+                     )
+                 ),
+             ],
+             value=load_last_used_custom()[2],
+             allow_custom_value=True,
+             label="Config: in a dictionary form",
+             visible=False,
+         )
+
+     choose_tts_model.change(
+         switch_tts_model,
+         inputs=[choose_tts_model],
+         outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
+         show_progress="hidden",
      )
+     custom_ckpt_path.change(
+         set_custom_model,
+         inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
+         show_progress="hidden",
+     )
+     custom_vocab_path.change(
+         set_custom_model,
+         inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
+         show_progress="hidden",
+     )
+     custom_model_cfg.change(
+         set_custom_model,
+         inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
+         show_progress="hidden",
+     )
+
+     gr.TabbedInterface(
+         [app_tts, app_multistyle, app_chat, app_credits],
+         ["Basic-TTS", "Multi-Speech", "Voice-Chat", "Credits"],
+     )
+
+
+ @click.command()
+ @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
+ @click.option("--host", "-H", default=None, help="Host to run the app on")
+ @click.option(
+     "--share",
+     "-s",
+     default=False,
+     is_flag=True,
+     help="Share the app via Gradio share link",
+ )
+ @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
+ @click.option(
+     "--root_path",
+     "-r",
+     default=None,
+     type=str,
+     help='The root path (or "mount point") of the application, if it\'s not served from the root ("/") of the domain. Often used when the application is behind a reverse proxy that forwards requests to the application, e.g. set "/myapp" or full URL for application served at "https://example.com/myapp".',
+ )
+ @click.option(
+     "--inbrowser",
+     "-i",
+     is_flag=True,
+     default=False,
+     help="Automatically launch the interface in the default web browser",
+ )
+ def main(port, host, share, api, root_path, inbrowser):
+     global app
+     print("Starting app...")
+     app.queue(api_open=api).launch(
+         server_name=host,
+         server_port=port,
+         share=share,
+         show_api=api,
+         root_path=root_path,
+         inbrowser=inbrowser,
+     )
+
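+ # Example local launch using the flags defined above (assuming this file is run directly):
+ #     python app.py --port 7860 --inbrowser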

  if __name__ == "__main__":
+     if not USING_SPACES:
+         main()
+     else:
+         app.queue().launch()