Approximetal committed on
Commit 16fa5ac · verified · 1 Parent(s): 1914d13

Update inference_gradio.py

Files changed (1)
  1. inference_gradio.py +1100 -479
inference_gradio.py CHANGED
@@ -1,576 +1,1197 @@
1
- import gc
2
- import os
3
- import platform
4
- import psutil
5
- import tempfile
6
- from glob import glob
7
- import traceback
8
- import click
9
  import gradio as gr
10
- import torch
11
- import torchaudio
12
- import soundfile as sf
13
- from pathlib import Path
14
-
15
- from cached_path import cached_path
16
 
17
  from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
18
-
19
- # Global variables
20
- tts_api = None
21
- last_checkpoint = ""
22
- last_device = ""
23
- last_ema = None
24
-
25
- # Device detection
26
- device = (
27
- "cuda"
28
- if torch.cuda.is_available()
29
- else "xpu"
30
- if torch.xpu.is_available()
31
- else "mps"
32
- if torch.backends.mps.is_available()
33
- else "cpu"
34
  )
35
 
36
- REPO_ROOT = Path(__file__).resolve().parent
37
 
38
- # HF location for large TTS checkpoints (too big for Space storage)
39
- HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/pretrained_models"
40
 
41
- # Point to the espeak-ng-data bundled under `pretrained_models` (the locally shipped dictionaries).
42
- # The shared library is provided by the system-installed espeak-ng (via apt); PHONEMIZER_ESPEAK_LIBRARY is not forced,
43
- # to avoid incompatibilities between a locally copied .so and the Space base image.
44
- ESPEAK_DATA_DIR = Path(PRETRAINED_ROOT) / "espeak-ng-data"
45
- os.environ["ESPEAK_DATA_PATH"] = str(ESPEAK_DATA_DIR)
46
- os.environ["ESPEAKNG_DATA_PATH"] = str(ESPEAK_DATA_DIR)
47
 
48
 
49
- class UVR5:
50
- """Small wrapper around the bundled uvr5 implementation for denoising."""
51
 
52
- def __init__(self, model_dir: Path, code_dir: Path):
53
- self.model = self.load_model(str(model_dir), str(code_dir))
54
 
55
- def load_model(self, model_dir: str, code_dir: str):
56
- import sys
57
- import json
58
 
59
  if code_dir not in sys.path:
60
  sys.path.append(code_dir)
61
-
62
  from multiprocess_cuda_infer import ModelData, Inference
63
-
64
  model_path = os.path.join(model_dir, "Kim_Vocal_1.onnx")
65
  config_path = os.path.join(model_dir, "MDX-Net-Kim-Vocal1.json")
66
  with open(config_path, "r", encoding="utf-8") as f:
67
  configs = json.load(f)
68
  model_data = ModelData(
69
  model_path=model_path,
70
- audio_path=model_dir,
71
- result_path=model_dir,
72
- device="cpu",
73
- process_method="MDX-Net",
74
- base_dir=model_dir, # keep base_dir and model_dir the same (paths under `pretrained_models`)
75
- **configs,
76
  )
77
 
78
- uvr5_model = Inference(model_data, "cpu")
79
  uvr5_model.load_model(model_path, 1)
80
  return uvr5_model
81
-
82
  def denoise(self, audio_info):
83
- print("denoise UVR5: ", audio_info)
84
  input_audio = load_wav(audio_info, sr=44100, channel=2)
85
- output_audio = self.model.demix_base({0: input_audio.squeeze()}, is_match_mix=False)
86
  return output_audio.squeeze().T.numpy(), 44100
87
 
88
- denoise_model = UVR5(
89
- model_dir=str(Path(PRETRAINED_ROOT) / "uvr5"),
90
- code_dir=str(REPO_ROOT / "uvr5"),
91
- )
92
 
93
  def load_wav(audio_info, sr=16000, channel=1):
94
- print("load audio:", audio_info)
95
- audio, raw_sr = torchaudio.load(audio_info)
96
  audio = audio.T if len(audio.shape) > 1 and audio.shape[1] == 2 else audio
97
- audio = audio / torch.max(torch.abs(audio))
98
- audio = audio.squeeze().float()
99
  if channel == 1 and len(audio.shape) == 2: # stereo to mono
100
  audio = audio.mean(dim=0, keepdim=True)
101
  elif channel == 2 and len(audio.shape) == 1:
102
  audio = torch.stack((audio, audio)) # mono to stereo
103
- if raw_sr != sr:
104
  audio = torchaudio.functional.resample(audio.squeeze(), raw_sr, sr)
105
  audio = torch.clip(audio, -0.999, 0.999).squeeze()
106
  return audio
107
 
108
 
109
- def denoise(audio_info):
110
- save_path = "./denoised_audio.wav"
111
- denoised_audio, sr = denoise_model.denoise(audio_info)
112
- sf.write(save_path, denoised_audio, sr, format='wav', subtype='PCM_24')
113
- print("save denoised audio:", save_path)
114
- return save_path
 
115
 
116
- def cancel_denoise(audio_info):
117
- return audio_info
118
 
119
 
120
- def get_checkpoints_project(project_name=None, is_gradio=True):
121
- """Get available checkpoint files"""
122
- checkpoint_dir = [str(CKPTS_ROOT)]
123
- # Remote ckpt locations on HF (used when local ckpts are not present)
124
- remote_ckpts = {
125
- "multilingual_grl": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_grl/multilingual_grl.safetensors",
126
- "multilingual_prosody": f"{HF_PRETRAINED_ROOT}/ckpts/multilingual_prosody/multilingual_prosody.safetensors",
127
  }
128
 
129
- if project_name is None:
130
- # Look for checkpoints in local directory
131
- files_checkpoints = []
132
- for path in checkpoint_dir:
133
- if os.path.isdir(path):
134
- files_checkpoints.extend(glob(os.path.join(path, "**/*.pt"), recursive=True))
135
- files_checkpoints.extend(glob(os.path.join(path, "**/*.safetensors"), recursive=True))
136
- break
137
- # Fallback to remote ckpts if none found locally
138
- if not files_checkpoints:
139
- files_checkpoints = list(remote_ckpts.values())
140
  else:
141
- files_checkpoints = []
142
- if os.path.isdir(checkpoint_dir[0]):
143
- files_checkpoints = glob(os.path.join(checkpoint_dir[0], project_name, "*.pt"))
144
- files_checkpoints.extend(glob(os.path.join(checkpoint_dir[0], project_name, "*.safetensors")))
145
- # If no local ckpts for this project, try remote mapping
146
- if not files_checkpoints:
147
- ckpt = remote_ckpts.get(project_name)
148
- files_checkpoints = [ckpt] if ckpt is not None else []
149
- print("files_checkpoints:", project_name, files_checkpoints)
150
- # Separate pretrained and regular checkpoints
151
- pretrained_checkpoints = [f for f in files_checkpoints if "pretrained_" in os.path.basename(f)]
152
- regular_checkpoints = [
153
- f
154
- for f in files_checkpoints
155
- if "pretrained_" not in os.path.basename(f) and "model_last.pt" not in os.path.basename(f)
156
  ]
157
- last_checkpoint = [f for f in files_checkpoints if "model_last.pt" in os.path.basename(f)]
158
 
159
- # Sort regular checkpoints by number
160
- try:
161
- regular_checkpoints = sorted(
162
- regular_checkpoints, key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
163
- )
164
- except (IndexError, ValueError):
165
- regular_checkpoints = sorted(regular_checkpoints)
166
 
167
- # Combine in order: pretrained, regular, last
168
- files_checkpoints = pretrained_checkpoints + regular_checkpoints + last_checkpoint
169
 
170
- select_checkpoint = None if not files_checkpoints else files_checkpoints[-1]
171
 
172
- if is_gradio:
173
- return gr.update(choices=files_checkpoints, value=select_checkpoint)
174
 
175
- return files_checkpoints, select_checkpoint
176
 
177
 
178
- def get_available_projects():
179
- """Get available project names from data directory"""
180
- data_paths = [
181
- str(Path(PRETRAINED_ROOT) / "data"),
182
- ]
183
 
184
- project_list = []
185
- for data_path in data_paths:
186
- if os.path.isdir(data_path):
187
- for folder in os.listdir(data_path):
188
- path_folder = os.path.join(data_path, folder)
189
- if "test" not in folder:
190
- project_list.append(folder)
191
  break
192
- # Fallback: if no local data dir, default to known HF projects
193
- if not project_list:
194
- project_list = ["multilingual_grl", "multilingual_prosody"]
195
- project_list.sort()
196
- print("project_list:", project_list)
197
- return project_list
198
-
199
-
200
- def infer(
201
- project, file_checkpoint, exp_name, ref_text, ref_audio, denoise_audio, gen_text, nfe_step, use_ema, separate_langs, frontend, speed, cfg_strength, use_acc_grl, ref_ratio, no_ref_audio, sway_sampling_coef, use_prosody_encoder, seed
202
- ):
203
- global last_checkpoint, last_device, tts_api, last_ema
204
-
205
- # Resolve checkpoint path (local or HF URL)
206
- ckpt_path = file_checkpoint
207
- if isinstance(ckpt_path, str) and ckpt_path.startswith("hf://"):
208
- try:
209
- ckpt_resolved = str(cached_path(ckpt_path))
210
- except Exception as e:
211
- traceback.print_exc()
212
- return None, f"Error downloading checkpoint: {str(e)}", ""
213
- else:
214
- ckpt_resolved = ckpt_path
215
 
216
- if not os.path.isfile(ckpt_resolved):
217
- return None, "Checkpoint not found!", ""
218
 
219
- if denoise_audio:
220
- ref_audio = denoise_audio
221
 
222
- device_test = device # Use the global device
223
 
224
- if last_checkpoint != ckpt_resolved or last_device != device_test or last_ema != use_ema or tts_api is None:
225
- if last_checkpoint != ckpt_resolved:
226
- last_checkpoint = ckpt_resolved
227
 
228
- if last_device != device_test:
229
- last_device = device_test
 
230
 
231
- if last_ema != use_ema:
232
- last_ema = use_ema
233
 
234
- # Automatically enable prosody encoder when using the prosody checkpoint
235
- use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
236
 
237
- # Resolve vocab file (local)
238
- local_vocab = Path(PRETRAINED_ROOT) / "data" / project / "vocab.txt"
239
- if not local_vocab.is_file():
240
- return None, "Vocab file not found!", ""
241
- vocab_file = str(local_vocab)
242
 
243
- # Resolve prosody encoder config & weights (local)
244
- local_prosody_cfg = Path(CKPTS_ROOT) / "prosody_encoder" / "pretssel_cfg.json"
245
- local_prosody_ckpt = Path(CKPTS_ROOT) / "prosody_encoder" / "prosody_encoder_UnitY2.pt"
246
- if not local_prosody_cfg.is_file() or not local_prosody_ckpt.is_file():
247
- return None, "Prosody encoder files not found!", ""
248
- prosody_cfg_path = str(local_prosody_cfg)
249
- prosody_ckpt_path = str(local_prosody_ckpt)
250
 
251
- try:
252
- tts_api = TTS(
253
- model=exp_name,
254
- ckpt_file=ckpt_resolved,
255
- vocab_file=vocab_file,
256
- device=device_test,
257
- use_ema=use_ema,
258
- frontend=frontend,
259
- use_prosody_encoder=use_prosody_encoder,
260
- prosody_cfg_path=prosody_cfg_path,
261
- prosody_ckpt_path=prosody_ckpt_path,
262
- )
263
- except Exception as e:
264
- traceback.print_exc()
265
- return None, f"Error loading model: {str(e)}", ""
266
-
267
- print("Model loaded >>", device_test, file_checkpoint, use_ema)
268
-
269
- if seed == -1: # -1 used for random
270
- seed = None
271
-
272
- try:
273
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
274
- tts_api.infer(
275
- ref_file=ref_audio,
276
- ref_text=ref_text.strip(),
277
- gen_text=gen_text.strip(),
278
- nfe_step=nfe_step,
279
- separate_langs=separate_langs,
280
- speed=speed,
281
- cfg_strength=cfg_strength,
282
- sway_sampling_coef=sway_sampling_coef,
283
- use_acc_grl=use_acc_grl,
284
- ref_ratio=ref_ratio,
285
- no_ref_audio=no_ref_audio,
286
- use_prosody_encoder=use_prosody_encoder,
287
- file_wave=f.name,
288
- seed=seed,
289
- )
290
- return f.name, f"Device: {tts_api.device}", str(tts_api.seed)
291
- except Exception as e:
292
- traceback.print_exc()
293
- return None, f"Inference error: {str(e)}", ""
294
-
295
-
296
- def get_gpu_stats():
297
- """Get GPU statistics"""
298
- gpu_stats = ""
299
-
300
- if torch.cuda.is_available():
301
- gpu_count = torch.cuda.device_count()
302
- for i in range(gpu_count):
303
- gpu_name = torch.cuda.get_device_name(i)
304
- gpu_properties = torch.cuda.get_device_properties(i)
305
- total_memory = gpu_properties.total_memory / (1024**3) # in GB
306
- allocated_memory = torch.cuda.memory_allocated(i) / (1024**2) # in MB
307
- reserved_memory = torch.cuda.memory_reserved(i) / (1024**2) # in MB
308
-
309
- gpu_stats += (
310
- f"GPU {i} Name: {gpu_name}\n"
311
- f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
312
- f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
313
- f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
314
- )
315
- elif torch.xpu.is_available():
316
- gpu_count = torch.xpu.device_count()
317
- for i in range(gpu_count):
318
- gpu_name = torch.xpu.get_device_name(i)
319
- gpu_properties = torch.xpu.get_device_properties(i)
320
- total_memory = gpu_properties.total_memory / (1024**3) # in GB
321
- allocated_memory = torch.xpu.memory_allocated(i) / (1024**2) # in MB
322
- reserved_memory = torch.xpu.memory_reserved(i) / (1024**2) # in MB
323
-
324
- gpu_stats += (
325
- f"GPU {i} Name: {gpu_name}\n"
326
- f"Total GPU memory (GPU {i}): {total_memory:.2f} GB\n"
327
- f"Allocated GPU memory (GPU {i}): {allocated_memory:.2f} MB\n"
328
- f"Reserved GPU memory (GPU {i}): {reserved_memory:.2f} MB\n\n"
329
- )
330
- elif torch.backends.mps.is_available():
331
- gpu_count = 1
332
- gpu_stats += "MPS GPU\n"
333
- total_memory = psutil.virtual_memory().total / (
334
- 1024**3
335
- ) # Total system memory (MPS doesn't have its own memory)
336
- allocated_memory = 0
337
- reserved_memory = 0
338
-
339
- gpu_stats += (
340
- f"Total system memory: {total_memory:.2f} GB\n"
341
- f"Allocated GPU memory (MPS): {allocated_memory:.2f} MB\n"
342
- f"Reserved GPU memory (MPS): {reserved_memory:.2f} MB\n"
343
- )
344
 
345
- else:
346
- gpu_stats = "No GPU available"
347
 
348
- return gpu_stats
349
 
350
 
351
- def get_cpu_stats():
352
- """Get CPU statistics"""
353
- cpu_usage = psutil.cpu_percent(interval=1)
354
- memory_info = psutil.virtual_memory()
355
- memory_used = memory_info.used / (1024**2)
356
- memory_total = memory_info.total / (1024**2)
357
- memory_percent = memory_info.percent
358
 
359
- pid = os.getpid()
360
- process = psutil.Process(pid)
361
- nice_value = process.nice()
362
 
363
- cpu_stats = (
364
- f"CPU Usage: {cpu_usage:.2f}%\n"
365
- f"System Memory: {memory_used:.2f} MB used / {memory_total:.2f} MB total ({memory_percent}% used)\n"
366
- f"Process Priority (Nice value): {nice_value}"
367
- )
368
 
369
- return cpu_stats
370
 
371
 
372
- def get_combined_stats():
373
- """Get combined system stats"""
374
- gpu_stats = get_gpu_stats()
375
- cpu_stats = get_cpu_stats()
376
- combined_stats = f"### GPU Stats\n{gpu_stats}\n\n### CPU Stats\n{cpu_stats}"
377
- return combined_stats
378
 
 
379
 
380
- # Create Gradio interface
381
- with gr.Blocks(title="LEMAS-TTS Inference") as app:
382
- gr.Markdown(
383
- """
384
- # Zero-Shot TTS
385
 
386
- Set seed to -1 for random generation.
387
- """
388
- )
389
- with gr.Accordion("Model configuration", open=False):
390
- # Model configuration
391
- with gr.Row():
392
- exp_name = gr.Radio(
393
- label="Model",
394
- choices=["multilingual_grl", "multilingual_prosody"],
395
- value="multilingual_grl",
396
- visible=False,
397
- )
398
- # Project selection
399
- available_projects = get_available_projects()
400
 
401
- # Get initial checkpoints
402
- list_checkpoints, checkpoint_select = get_checkpoints_project(available_projects[0] if available_projects else None, False)
 
403
 
404
- with gr.Row():
405
- with gr.Column(scale=1):
406
- # load_models_btn = gr.Button(value="Load models")
407
- cm_project = gr.Dropdown(
408
- choices=available_projects,
409
- value=available_projects[0] if available_projects else None,
410
- label="Project",
411
- allow_custom_value=True,
412
- scale=4
413
- )
414
-
415
- with gr.Column(scale=5):
416
- cm_checkpoint = gr.Dropdown(
417
- choices=list_checkpoints, value=checkpoint_select, label="Checkpoints", allow_custom_value=True # scale=4,
418
- )
419
- bt_checkpoint_refresh = gr.Button("Refresh", scale=1)
420
 
 
 
421
  with gr.Row():
422
- ch_use_ema = gr.Checkbox(label="Use EMA", visible=False, value=True, scale=2, info="Turn off at early stage might offer better results")
423
- frontend = gr.Radio(label="Frontend", visible=False, choices=["phone", "char", "bpe"], value="phone", scale=3)
424
- separate_langs = gr.Checkbox(label="Separate Languages", visible=False, value=True, scale=2, info="separate language tokens")
425
 
426
- # Inference parameters
427
  with gr.Row():
428
- nfe_step = gr.Number(label="NFE Step", scale=1, value=64)
429
- speed = gr.Slider(label="Speed", scale=3, value=1.0, minimum=0.5, maximum=1.5, step=0.1)
430
- cfg_strength = gr.Slider(label="CFG Strength", scale=2, value=5.0, minimum=0.0, maximum=10.0, step=1)
431
- sway_sampling_coef = gr.Slider(label="Sway Sampling Coef", scale=2, value=3, minimum=2, maximum=5, step=0.1)
432
- ref_ratio = gr.Slider(label="Ref Ratio", scale=2, value=1.0, minimum=0.0, maximum=1.0, step=0.1)
433
- no_ref_audio = gr.Checkbox(label="No Reference Audio", visible=False, value=False, scale=1, info="No mel condition")
434
- use_acc_grl = gr.Checkbox(label="Use accent grl condition", visible=False, value=True, scale=1, info="Use accent grl condition")
435
- use_prosody_encoder = gr.Checkbox(label="Use prosody encoder", visible=False, value=False, scale=1, info="Use prosody encoder")
436
- seed = gr.Number(label="Random Seed", scale=1, value=-1, minimum=-1)
437
-
438
-
439
- # Input fields
440
- ref_text = gr.Textbox(label="Reference Text", placeholder="Enter the text for the reference audio...")
441
- ref_audio = gr.Audio(label="Reference Audio", type="filepath", interactive=True, show_download_button=True, editable=True)
442
 
443
 
444
- with gr.Accordion("Denoise audio (Optional / Recommend)", open=True):
445
  with gr.Row():
446
- denoise_btn = gr.Button(value="Denoise")
447
- cancel_btn = gr.Button(value="Cancel Denoise")
448
- denoise_audio = gr.Audio(label="Denoised Audio", value=None, type="filepath", interactive=True, show_download_button=True, editable=True)
449
-
450
- gen_text = gr.Textbox(label="Text to Generate", placeholder="Enter the text you want to generate...")
451
-
452
- # Inference button and outputs
453
- with gr.Row():
454
- txt_info_gpu = gr.Textbox("", label="Device Info")
455
- seed_info = gr.Textbox(label="Used Random Seed")
456
- check_button_infer = gr.Button("Generate Audio", variant="primary")
457
-
458
- gen_audio = gr.Audio(label="Generated Audio", type="filepath", interactive=True, show_download_button=True, editable=True)
459
-
460
- # Examples
461
- def _resolve_example(name: str) -> str:
462
- local = Path(PRETRAINED_ROOT) / "data" / "test_examples" / name
463
- return str(local) if local.is_file() else ""
464
-
465
- examples = gr.Examples(
466
- examples=[
467
- ["em, #1 I have a list of YouTubers, and I'm gonna be going to their houses and raiding them by.",
468
- _resolve_example("en.wav"),
469
- "我有一份 YouTuber 名单,我打算去他们家,对他们进行突袭。",
470
- ],
471
- ["Te voy a dar un tip #1 que le copia a John Rockefeller, uno de los empresarios más picudos de la historia.",
472
- _resolve_example("es.wav"),
473
- "我要给你一个从历史上最精明的商人之一约翰·洛克菲勒那里抄来的秘诀。",
474
- ],
475
- ["Nova, #1 dia 25 desse mês vai rolar operação the last Frontier.",
476
- _resolve_example("pt.wav"),
477
- "新消息,本月二十五日,'最后的边疆行动'将启动。",
478
- ],
479
- ],
480
- inputs=[
481
- ref_text,
482
- ref_audio,
483
- gen_text,
484
- ],
485
- outputs=[gen_audio, txt_info_gpu, seed_info],
486
- fn=infer,
487
- cache_examples=False
488
- )
489
-
490
- # System Info section at the bottom
491
- gr.Markdown("---")
492
- gr.Markdown("## System Information")
493
- with gr.Accordion("Update System Stats", open=False):
494
- update_button = gr.Button("Update System Stats", scale=1)
495
- output_box = gr.Textbox(label="GPU and CPU Information", lines=5, scale=5)
496
 
497
- def update_stats():
498
- return get_combined_stats()
499
-
500
-
501
- denoise_btn.click(fn=denoise,
502
- inputs=[ref_audio],
503
  outputs=[denoise_audio])
504
 
505
- cancel_btn.click(fn=cancel_denoise,
506
- inputs=[ref_audio],
507
  outputs=[denoise_audio])
508
 
509
- # Event handlers
510
- check_button_infer.click(
511
- fn=infer,
512
- inputs=[
513
- cm_project,
514
- cm_checkpoint,
515
- exp_name,
516
- ref_text,
517
- ref_audio,
518
- denoise_audio,
519
- gen_text,
520
- nfe_step,
521
- ch_use_ema,
522
- separate_langs,
523
- frontend,
524
- speed,
525
- cfg_strength,
526
- use_acc_grl,
527
- ref_ratio,
528
- no_ref_audio,
529
- sway_sampling_coef,
530
- use_prosody_encoder,
531
- seed,
532
- ],
533
- outputs=[gen_audio, txt_info_gpu, seed_info],
534
- )
535
-
536
- bt_checkpoint_refresh.click(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
537
- cm_project.change(fn=get_checkpoints_project, inputs=[cm_project], outputs=[cm_checkpoint])
538
-
539
- ref_audio.change(
540
- fn=lambda x: None,
541
- inputs=[ref_audio],
542
- outputs=[denoise_audio]
543
  )
544
-
545
- update_button.click(fn=update_stats, outputs=output_box)
546
-
547
- # Auto-load system stats on startup
548
- app.load(fn=update_stats, outputs=output_box)
549
-
550
-
551
- @click.command()
552
- @click.option("--port", "-p", default=7860, type=int, help="Port to run the app on")
553
- @click.option("--host", "-H", default="0.0.0.0", help="Host to run the app on")
554
- @click.option(
555
- "--share",
556
- "-s",
557
- default=False,
558
- is_flag=True,
559
- help="Share the app via Gradio share link",
560
- )
561
- @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
562
- def main(port, host, share, api):
563
- global app
564
- print("Starting LEMAS-TTS Inference Interface...")
565
- print(f"Device: {device}")
566
- app.queue(api_open=api).launch(
567
- server_name=host,
568
- server_port=port,
569
- share=share,
570
- show_api=api,
571
- allowed_paths=[str(Path(PRETRAINED_ROOT) / "data")],
572
- )
573
 
574
 
575
  if __name__ == "__main__":
576
- main()
1
+ import os, gc
2
+ import re, time
3
+ import logging
4
+ from num2words import num2words
5
  import gradio as gr
6
+ import torch, torchaudio
7
+ import numpy as np
8
+ import random
9
+ from scipy.io import wavfile
10
+ import onnx
11
+ import onnxruntime as ort
12
+ import copy
13
+ import uroman as ur
14
+ import jieba, zhconv
15
+ from pypinyin.core import Pinyin
16
+ from pypinyin import Style
17
 
18
  from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
19
+ from lemas_tts.infer.edit_multilingual import gen_wav_multilingual
20
+ from lemas_tts.infer.text_norm.txt2pinyin import (
21
+ MyConverter,
22
+ _PAUSE_SYMBOL,
23
+ change_tone_in_bu_or_yi,
24
+ get_phoneme_from_char_and_pinyin,
25
  )
26
+ from lemas_tts.infer.text_norm.cn_tn import NSWNormalizer
27
+ # import io
28
+ # import uuid
29
+ _JIEBA_DICT = os.path.join(
30
+ os.path.dirname(__file__),
31
+ "lemas_tts",
32
+ "infer",
33
+ "text_norm",
34
+ "jieba_dict.txt",
35
+ )
36
+ if os.path.isfile(_JIEBA_DICT):
37
+ jieba.set_dictionary(_JIEBA_DICT)
38
 
39
+ # from inference_tts_scale import inference_one_sample as inference_tts
40
+ import langid
41
+ langid.set_languages(['es','pt','zh','en','de','fr','it', 'ru', 'id', 'vi'])
42
 
 
 
43
 
44
+ os.environ['CURL_CA_BUNDLE'] = ''
45
+ DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
46
+ TMP_PATH = os.getenv("TMP_PATH", "./demo/temp")
47
+ MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
48
 
49
+ device = "cuda" if torch.cuda.is_available() else "cpu"
50
+ ASR_DEVICE = "cpu" # force whisperx/pyannote to CPU to avoid cuDNN issues
51
+ whisper_model, align_model = None, None
52
+ tts_edit_model = None
53
 
54
+ _whitespace_re = re.compile(r"\s+")
55
+ alpha_pattern = re.compile(r"[a-zA-Z]")
56
+
57
+ formatter = ("%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s")
58
+ logging.basicConfig(format=formatter, level=logging.INFO)
59
+
60
+ # def get_random_string():
61
+ # return "".join(str(uuid.uuid4()).split("-"))
62
+
63
+ def seed_everything(seed):
64
+ if seed != -1:
65
+ os.environ['PYTHONHASHSEED'] = str(seed)
66
+ random.seed(seed)
67
+ np.random.seed(seed)
68
+ torch.manual_seed(seed)
69
+ torch.cuda.manual_seed(seed)
70
+ torch.backends.cudnn.benchmark = False
71
+ torch.backends.cudnn.deterministic = True
72
 
73
 
74
+ class UVR5:
75
+ """Small wrapper around the bundled uvr5 implementation for denoising."""
 
76
 
77
+ def __init__(self, model_dir):
78
+ code_dir = os.path.join(os.path.dirname(__file__), "uvr5")
79
+ self.model = self.load_model(model_dir, code_dir)
80
+
81
+ def load_model(self, model_dir, code_dir):
82
+ import sys, json
83
  if code_dir not in sys.path:
84
  sys.path.append(code_dir)
 
85
  from multiprocess_cuda_infer import ModelData, Inference
 
86
  model_path = os.path.join(model_dir, "Kim_Vocal_1.onnx")
87
  config_path = os.path.join(model_dir, "MDX-Net-Kim-Vocal1.json")
88
  with open(config_path, "r", encoding="utf-8") as f:
89
  configs = json.load(f)
90
  model_data = ModelData(
91
  model_path=model_path,
92
+ audio_path = model_dir,
93
+ result_path = model_dir,
94
+ device = 'cpu',
95
+ process_method = "MDX-Net",
96
+ base_dir=model_dir,
97
+ **configs
98
  )
99
 
100
+ uvr5_model = Inference(model_data, 'cpu')
101
  uvr5_model.load_model(model_path, 1)
102
  return uvr5_model
103
+
104
  def denoise(self, audio_info):
 
105
  input_audio = load_wav(audio_info, sr=44100, channel=2)
106
+ output_audio = self.model.demix_base({0:input_audio.squeeze()}, is_match_mix=False)
107
+ # transform = torchaudio.transforms.Resample(44100, 16000)
108
+ # output_audio = transform(output_audio)
109
  return output_audio.squeeze().T.numpy(), 44100
110
 
111
+
112
+ class DeepFilterNet:
113
+ def __init__(self, model_path):
114
+ self.hop_size = 480
115
+ self.fft_size = 960
116
+ self.model = self.load_model(model_path)
117
+
118
+
119
+ def load_model(self, model_path, threads=1):
120
+ sess_options = ort.SessionOptions()
121
+ sess_options.intra_op_num_threads = threads
122
+ sess_options.graph_optimization_level = (ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED)
123
+ sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
124
+
125
+ model = onnx.load_model(model_path)
126
+ ort_session = ort.InferenceSession(
127
+ model.SerializeToString(),
128
+ sess_options,
129
+ providers=["CPUExecutionProvider"], # ["CUDAExecutionProvider"], #
130
+ )
131
+
132
+ input_names = ["input_frame", "states", "atten_lim_db"]
133
+ output_names = ["enhanced_audio_frame", "new_states", "lsnr"]
134
+ return ort_session
135
+
136
+
137
+ def denoise(self, audio_info):
138
+ wav = load_wav(audio_info, 48000)
139
+ orig_len = wav.shape[-1]
140
+ hop_size_divisible_padding_size = (self.hop_size - orig_len % self.hop_size) % self.hop_size
141
+ orig_len += hop_size_divisible_padding_size
142
+ wav = torch.nn.functional.pad(
143
+ wav, (0, self.fft_size + hop_size_divisible_padding_size)
144
+ )
145
+ chunked_audio = torch.split(wav, self.hop_size)
146
+ # chunked_audio = torch.split(wav, int(wav.shape[-1]/2))
147
+
148
+ state = np.zeros(45304,dtype=np.float32)
149
+ atten_lim_db = np.zeros(1,dtype=np.float32)
150
+ enhanced = []
151
+ for frame in chunked_audio:
152
+ out = self.model.run(None,input_feed={"input_frame":frame.numpy(),"states":state,"atten_lim_db":atten_lim_db})
153
+ enhanced.append(torch.tensor(out[0]))
154
+ state = out[1]
155
+
156
+ enhanced_audio = torch.cat(enhanced).unsqueeze(0) # [t] -> [1, t] typical mono format
157
+
158
+ d = self.fft_size - self.hop_size
159
+ enhanced_audio = enhanced_audio[:, d: orig_len + d]
160
+
161
+ return enhanced_audio.squeeze().numpy(), 48000
162
+
163
+
164
+ class TextNorm():
165
+ def __init__(self):
166
+ my_pinyin = Pinyin(MyConverter())
167
+ self.pinyin_parser = my_pinyin.pinyin
168
+
169
+ def sil_type(self, time_s):
170
+ if round(time_s) < 0.4:
171
+ return ""
172
+ elif round(time_s) >= 0.4 and round(time_s) < 0.8:
173
+ return "#1"
174
+ elif round(time_s) >= 0.8 and round(time_s) < 1.5:
175
+ return "#2"
176
+ elif round(time_s) >= 1.5 and round(time_s) < 3.0:
177
+ return "#3"
178
+ elif round(time_s) >= 3.0:
179
+ return "#4"
180
+
181
+
182
+ def add_sil_raw(self, sub_list, start_time, end_time, target_transcript):
183
+ txt = []
184
+ txt_list = [x["word"] for x in sub_list]
185
+ sil = self.sil_type(sub_list[0]["start"])
186
+ if len(sil) > 0:
187
+ txt.append(sil)
188
+ txt.append(txt_list[0])
189
+ for i in range(1, len(sub_list)):
190
+ if sub_list[i]["start"] >= start_time and sub_list[i]["end"] <= end_time:
191
+ txt.append(target_transcript)
192
+ target_transcript = ""
193
+ else:
194
+ sil = self.sil_type(sub_list[i]["start"] - sub_list[i-1]["end"])
195
+ if len(sil) > 0:
196
+ txt.append(sil)
197
+ txt.append(txt_list[i])
198
+ return ' '.join(txt)
199
+
200
+ def add_sil(self, sub_list, start_time, end_time, target_transcript, src_lang, tar_lang):
201
+ txts = []
202
+ txt_list = [x["word"] for x in sub_list]
203
+ sil = self.sil_type(sub_list[0]["start"])
204
+ if len(sil) > 0:
205
+ txts.append([src_lang, sil])
206
+
207
+ if sub_list[0]["start"] < start_time:
208
+ txts.append([src_lang, txt_list[0]])
209
+ for i in range(1, len(sub_list)):
210
+ if sub_list[i]["start"] >= start_time and sub_list[i]["end"] <= end_time:
211
+ txts.append([tar_lang, target_transcript])
212
+ target_transcript = ""
213
+ else:
214
+ sil = self.sil_type(sub_list[i]["start"] - sub_list[i-1]["end"])
215
+ if len(sil) > 0:
216
+ txts.append([src_lang, sil])
217
+ txts.append([src_lang, txt_list[i]])
218
+
219
+ target_txt = [txts[0]]
220
+ for txt in txts[1:]:
221
+ if txt[1] == "":
222
+ continue
223
+ if txt[0] != target_txt[-1][0]:
224
+ target_txt.append([txt[0], ""])
225
+ target_txt[-1][-1] += " " + txt[1]
226
+
227
+ return target_txt
228
+
229
+
230
+ def get_prompt(self, sub_list, start_time, end_time, src_lang):
231
+ txts = []
232
+ txt_list = [x["word"] for x in sub_list]
233
+
234
+ if start_time <= sub_list[0]["start"]:
235
+ sil = self.sil_type(sub_list[0]["start"])
236
+ if len(sil) > 0:
237
+ txts.append([src_lang, sil])
238
+ txts.append([src_lang, txt_list[0]])
239
+
240
+ for i in range(1, len(sub_list)):
241
+ # if sub_list[i]["start"] <= start_time and sub_list[i]["end"] <= end_time:
242
+ # txts.append([tar_lang, target_transcript])
243
+ # target_transcript = ""
244
+ if sub_list[i]["start"] >= start_time and sub_list[i]["end"] <= end_time:
245
+ sil = self.sil_type(sub_list[i]["start"] - sub_list[i-1]["end"])
246
+ if len(sil) > 0:
247
+ txts.append([src_lang, sil])
248
+ txts.append([src_lang, txt_list[i]])
249
+
250
+ target_txt = [txts[0]]
251
+ for txt in txts[1:]:
252
+ if txt[1] == "":
253
+ continue
254
+ if txt[0] != target_txt[-1][0]:
255
+ target_txt.append([txt[0], ""])
256
+ target_txt[-1][-1] += " " + txt[1]
257
+ return target_txt
258
+
259
+
260
+ def txt2pinyin(self, text):
261
+ txts, phonemes = [], []
262
+ texts = re.split(r"(#\d)", text)
263
+ print("before norm: ", texts)
264
+ for text in texts:
265
+ if text in {'#1', '#2', '#3', '#4'}:
266
+ txts.append(text)
267
+ phonemes.append(text)
268
+ continue
269
+ text = NSWNormalizer(text.strip()).normalize()
270
+
271
+ text_list = list(jieba.cut(text))
272
+ print("jieba cut: ", text, text_list)
273
+ for words in text_list:
274
+ if words in _PAUSE_SYMBOL:
275
+ # phonemes.append('#2')
276
+ phonemes[-1] += _PAUSE_SYMBOL[words]
277
+ txts[-1] += words
278
+ elif re.search("[\u4e00-\u9fa5]+", words):
279
+ pinyin = self.pinyin_parser(words, style=Style.TONE3, errors="ignore")
280
+ new_pinyin = []
281
+ for x in pinyin:
282
+ x = "".join(x)
283
+ if "#" not in x:
284
+ new_pinyin.append(x)
285
+ else:
286
+ phonemes.append(words)
287
+ continue
288
+ new_pinyin = change_tone_in_bu_or_yi(words, new_pinyin) if len(words)>1 and words[-1] not in {"一","不"} else new_pinyin
289
+ phoneme = get_phoneme_from_char_and_pinyin(words, new_pinyin)
290
+ phonemes += phoneme
291
+ txts += list(words)
292
+ elif re.search(r"[a-zA-Z]", words) or re.search(r"#[1-4]", words):
293
+ phonemes.append(words)
294
+ txts.append(words)
295
+ # phonemes.append("#1")
296
+ # phones = " ".join(phonemes)
297
+ return txts, phonemes
298
+
299
+
300
+
301
+ def chunk_text(text, max_chars=135):
302
+ """
303
+ Splits the input text into chunks, each with a maximum number of characters.
304
+
305
+ Args:
306
+ text (str): The text to be split.
307
+ max_chars (int): The maximum number of characters per chunk.
308
+
309
+ Returns:
310
+ List[str]: A list of text chunks.
311
+ """
312
+ chunks = []
313
+ current_chunk = ""
314
+ # Split the text into sentences based on punctuation followed by whitespace
315
+ sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)
316
+
317
+ for sentence in sentences:
318
+ if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
319
+ current_chunk += sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
320
+ else:
321
+ if current_chunk:
322
+ chunks.append(current_chunk.strip())
323
+ current_chunk = sentence + " " if sentence and len(sentence[-1].encode("utf-8")) == 1 else sentence
324
+
325
+ if current_chunk:
326
+ chunks.append(current_chunk.strip())
327
+
328
+ return chunks
329
+
330
+
331
+ class MMSAlignModel:
332
+ def __init__(self):
333
+ from torchaudio.pipelines import MMS_FA as bundle
334
+ self.mms_model = bundle.get_model()
335
+ self.mms_model.to(device)
336
+ self.mms_tokenizer = bundle.get_tokenizer()
337
+ self.mms_aligner = bundle.get_aligner()
338
+ self.text_normalizer = ur.Uroman()
339
+
340
+
341
+ def text_normalization(self, text_list):
342
+ text_normalized = []
343
+ for word in text_list:
344
+ text_char = ''
345
+ for c in word:
346
+ if c.isalpha() or c=="'":
347
+ text_char += c.lower()
348
+ elif c == "-":
349
+ text_char += '*'
350
+ text_char = text_char if len(text_char) > 0 else "*"
351
+ text_normalized.append(text_char)
352
+ assert len(text_normalized) == len(text_list), f"normalized text len != raw text len: {len(text_normalized)} != {text_list}"
353
+ return text_normalized
354
+
355
+ def compute_alignments(self, waveform: torch.Tensor, tokens):
356
+ with torch.inference_mode():
357
+ emission, _ = self.mms_model(waveform.to(device))
358
+ token_spans = self.mms_aligner(emission[0], tokens)
359
+ return emission, token_spans
360
+
361
+
362
+ def align(self, data, wav):
363
+ waveform = load_wav(wav, 16000).unsqueeze(0)
364
+ raw_text = data['text'][0]
365
+ text = " ".join(data['text'][1]).replace("-", " ")
366
+ text = re.sub("\s+", " ", text)
367
+ text_normed = self.text_normalizer.romanize_string(text, lcode=data["lang"])
368
+ # text_normed = re.sub("[\d_.,!$£%?#−/]", '', text_normed)
369
+ filter_re = re.compile("[^a-z^*^'^ ]")
370
+ text_normed = filter_re.sub('', text_normed.lower())
371
+ text_normed = re.sub("\s+", " ", text_normed)
372
+ text_normed = text_normed.split()
373
+ assert len(text_normed) == len(raw_text), f"normalized text len != raw text len: {len(text_normed)} != {len(raw_text)}"
374
+ tokens = self.mms_tokenizer(text_normed)
375
+ with torch.inference_mode():
376
+ emission, _ = self.mms_model(waveform.to(device))
377
+ token_spans = self.mms_aligner(emission[0], tokens)
378
+ num_frames = emission.size(1)
379
+ ratio = waveform.size(1) / num_frames
380
+ res = []
381
+ for i in range(len(token_spans)):
382
+ score = round(sum([x.score for x in token_spans[i]]) / len(token_spans[i]), ndigits=3)
383
+ start = round(waveform.size(-1) * token_spans[i][0].start / num_frames / 16000, ndigits=3)
384
+ end = round(waveform.size(-1) * token_spans[i][-1].end / num_frames / 16000, ndigits=3)
385
+ res.append({"word": raw_text[i], "start": start, "end": end, "score": score})
386
+
387
+ res = {"lang":data["lang"], "start": 0, "end": round(waveform.shape[-1]/16000, ndigits=3), "text_raw":data["text_raw"], "text": text, "words": res}
388
+ return res
389
+
390
+
391
+ class WhisperxModel:
392
+ def __init__(self, model_name):
393
+ from whisperx import load_model
394
+ from pathlib import Path
395
+ prompt = None # "This might be a blend of Simplified Chinese and English speech, do not translate, only transcription be allowed."
396
+
397
+ # Prefer a local VAD model (to avoid network download / 301 issues)
398
+ vad_fp = Path(MODELS_PATH) / "whisperx-vad-segmentation.bin"
399
+ if not vad_fp.is_file():
400
+ logging.warning(
401
+ "Local whisperx VAD not found at %s, falling back to default download path.",
402
+ vad_fp,
403
+ )
404
+ vad_fp = None
405
+
406
+ self.model = load_model(
407
+ model_name,
408
+ ASR_DEVICE,
409
+ compute_type="float32",
410
+ asr_options={
411
+ "suppress_numerals": False,
412
+ "max_new_tokens": None,
413
+ "clip_timestamps": None,
414
+ "initial_prompt": prompt,
415
+ "append_punctuations": ".。,,!!??::、",
416
+ "hallucination_silence_threshold": None,
417
+ "multilingual": True,
418
+ "hotwords": None
419
+ },
420
+ vad_model_fp=str(vad_fp) if vad_fp is not None else None,
421
+ )
422
+
423
+ def transcribe(self, audio_info, lang=None):
424
+ audio = load_wav(audio_info).numpy()
425
+ if lang is None:
426
+ lang = self.model.detect_language(audio)
427
+
428
+ segments = self.model.transcribe(audio, batch_size=8, language=lang)["segments"]
429
+ transcript = " ".join([segment["text"] for segment in segments])
430
+
431
+ if lang not in {'es','pt','zh','en','de','fr','it', 'ar', 'ru', 'ja', 'ko', 'hi', 'th', 'id', 'vi'}:
432
+ lang = langid.classify(transcript)[0]
433
+ segments = self.model.transcribe(audio, batch_size=8, language=lang)["segments"]
434
+ transcript = " ".join([segment["text"] for segment in segments])
435
+ logging.debug(f"whisperx: {segments}")
436
+
437
+ transcript = zhconv.convert(transcript, 'zh-hans')
438
+ transcript = transcript.replace("-", " ")
439
+ transcript = re.sub(_whitespace_re, " ", transcript)
440
+ transcript = transcript[1:] if transcript[0] == " " else transcript
441
+ segments = {'lang':lang, 'text_raw':transcript}
442
+ if lang == "zh":
443
+ segments["text"] = text_norm.txt2pinyin(transcript)
444
+ else:
445
+ transcript = replace_numbers_with_words(transcript, lang=lang).split(' ')
446
+ segments["text"] = (transcript, transcript)
447
+
448
+ return align_model.align(segments, audio_info)
449
+
450
 
451
  def load_wav(audio_info, sr=16000, channel=1):
452
+ raw_sr, audio = audio_info
 
453
  audio = audio.T if len(audio.shape) > 1 and audio.shape[1] == 2 else audio
454
+ audio = audio / np.max(np.abs(audio))
455
+ audio = torch.from_numpy(audio).squeeze().float()
456
  if channel == 1 and len(audio.shape) == 2: # stereo to mono
457
  audio = audio.mean(dim=0, keepdim=True)
458
  elif channel == 2 and len(audio.shape) == 1:
459
  audio = torch.stack((audio, audio)) # mono to stereo
460
+ if raw_sr != sr:
461
  audio = torchaudio.functional.resample(audio.squeeze(), raw_sr, sr)
462
  audio = torch.clip(audio, -0.999, 0.999).squeeze()
463
  return audio
464
 
465
 
466
+ def update_word_time(lst, cut_time, edit_start, edit_end):
467
+ for i in range(len(lst)):
468
+ lst[i]["start"] = round(lst[i]["start"] - cut_time, ndigits=3)
469
+ lst[i]["end"] = round(lst[i]["end"] - cut_time, ndigits=3)
470
+ edit_start = max(round(edit_start - cut_time, ndigits=3), 0)
471
+ edit_end = round(edit_end - cut_time, ndigits=3)
472
+ return lst, edit_start, edit_end
473
 
474
+
475
+ # def update_word_time2(lst, cut_time, edit_start, edit_end):
476
+ # for i in range(len(lst)):
477
+ # lst[i]["start"] = round(lst[i]["start"] + cut_time, ndigits=3)
478
+ # return lst, edit_start, edit_end
479
 
480
 
481
+ def get_audio_slice(audio, words_info, start_time, end_time, max_len=10, sr=16000, code_sr=50):
482
+ audio_dur = audio.shape[-1] / sr
483
+ sub_list = []
484
+ # If less than 5 s remains after the edited span, keep everything to the end and trim from the front of the audio.
485
+ if audio_dur - end_time <= max_len/2:
486
+ for word in reversed(words_info):
487
+ if word['start'] > start_time or audio_dur - word['start'] < max_len:
488
+ sub_list = [word] + sub_list
489
+
490
+ # If less than 5 s precedes the edited span, keep everything from the start and trim from the back of the audio.
491
+ elif start_time <=max_len/2:
492
+ for word in words_info:
493
+ if word['end'] < max(end_time, max_len):
494
+ sub_list += [word]
495
+
496
+ # If both head and tail are longer than 5 s, keep 5 s on each side of the edited span.
497
+ else:
498
+ for word in words_info:
499
+ if word['start'] > start_time - max_len/2 and word['end'] < end_time + max_len/2:
500
+ sub_list += [word]
501
+ audio = audio.squeeze()
502
+
503
+ start = int(sub_list[0]['start']*sr)
504
+ end = int(sub_list[-1]['end']*sr)
505
+ # print("wav cuts:", start, end, (end-start) % int(sr/code_sr))
506
+ end -= (end-start) % int(sr/code_sr) # round down so the slice covers a whole number of code-rate frames
507
+
508
+ sub_list, start_time, end_time = update_word_time(sub_list, sub_list[0]['start'], start_time, end_time)
509
+ audio = audio.squeeze()
510
+ # print("after update_word_time:", sub_list, start_time, end_time, (end-start)/sr)
511
+
512
+ return (audio[:start], audio[start:end], audio[end:]), (sub_list, start_time, end_time)
513
+
514
+
515
+ def load_models(lemas_model_name, whisper_model_name, alignment_model_name, denoise_model_name): # , audiosr_name):
516
+
517
+ global transcribe_model, align_model, denoise_model, text_norm, tts_edit_model
518
+ # if voicecraft_model:
519
+ # del denoise_model
520
+ # del transcribe_model
521
+ # del align_model
522
+ # del voicecraft_model
523
+ # del audiosr
524
+ torch.cuda.empty_cache()
525
+ gc.collect()
526
+
527
+ if denoise_model_name == "UVR5":
528
+ # Prefer the generic MODELS_PATH root for denoiser assets so that
529
+ # HF Spaces (where pretrained models are often mounted separately)
530
+ # and local runs share the same layout.
531
+ denoise_root = MODELS_PATH # e.g. "./pretrained_models" or env override
532
+ denoise_model = UVR5(os.path.join(denoise_root, "uvr5"))
533
+ elif denoise_model_name == "DeepFilterNet":
534
+ denoise_model = DeepFilterNet("./pretrained_models/denoiser_model.onnx")
535
+
536
+ if alignment_model_name == "MMS":
537
+ align_model = MMSAlignModel()
538
+ else:
539
+ align_model = WhisperxAlignModel()
540
+
541
+ text_norm = TextNorm()
542
+
543
+ transcribe_model = WhisperxModel(whisper_model_name)
544
+
545
+ # Load LEMAS-TTS editing model (selected multilingual variant)
546
+ from pathlib import Path
547
+
548
+ ckpt_dir = Path(CKPTS_ROOT) / lemas_model_name
549
+ ckpt_candidates = sorted(
550
+ list(ckpt_dir.glob("*.safetensors")) + list(ckpt_dir.glob("*.pt"))
551
+ )
552
+ if not ckpt_candidates:
553
+ raise gr.Error(f"No LEMAS-TTS ckpt found under {ckpt_dir}")
554
+ ckpt_file = str(ckpt_candidates[-1])
555
+
556
+ vocab_file = Path(PRETRAINED_ROOT) / "data" / lemas_model_name / "vocab.txt"
557
+ if not vocab_file.is_file():
558
+ raise gr.Error(f"Vocab file not found: {vocab_file}")
559
+
560
+ prosody_cfg = Path(CKPTS_ROOT) / "prosody_encoder" / "pretssel_cfg.json"
561
+ prosody_ckpt = Path(CKPTS_ROOT) / "prosody_encoder" / "prosody_encoder_UnitY2.pt"
562
+
563
+ # Decide whether to enable the prosody encoder:
564
+ # - multilingual_prosody: True (if assets exist)
565
+ # - multilingual_grl: False (GRL-only variant)
566
+ # - others: fall back to presence of assets.
567
+ if lemas_model_name.endswith("prosody"):
568
+ use_prosody = prosody_cfg.is_file() and prosody_ckpt.is_file()
569
+ elif lemas_model_name.endswith("grl"):
570
+ use_prosody = False
571
+ else:
572
+ use_prosody = prosody_cfg.is_file() and prosody_ckpt.is_file()
573
+
574
+ tts_edit_model = TTS(
575
+ model=lemas_model_name,
576
+ ckpt_file=ckpt_file,
577
+ vocab_file=str(vocab_file),
578
+ device=device,
579
+ use_prosody_encoder=use_prosody,
580
+ prosody_cfg_path=str(prosody_cfg) if use_prosody else "",
581
+ prosody_ckpt_path=str(prosody_ckpt) if use_prosody else "",
582
+ ode_method="euler",
583
+ use_ema=True,
584
+ frontend="phone",
585
+ )
586
+ logging.info(f"Loaded LEMAS-TTS edit model from {ckpt_file}")
587
+
588
+ return gr.Accordion()
589
+
590
+
591
+ def get_transcribe_state(segments):
592
+ logging.info("===========After Align===========")
593
+ logging.info(segments)
594
+ return {
595
+ "segments": segments,
596
+ "transcript": segments["text_raw"],
597
+ "words_info": segments["words"],
598
+ "transcript_with_start_time": " ".join([f"{word['start']} {word['word']}" for word in segments["words"]]),
599
+ "transcript_with_end_time": " ".join([f"{word['word']} {word['end']}" for word in segments["words"]]),
600
+ "word_bounds": [f"{word['start']} {word['word']} {word['end']}" for word in segments["words"]]
601
  }
602
 
603
+
604
+ def transcribe(seed, audio_info):
605
+ if transcribe_model is None:
606
+ raise gr.Error("Transcription model not loaded")
607
+ seed_everything(seed)
608
+
609
+ segments = transcribe_model.transcribe(audio_info)
610
+ state = get_transcribe_state(segments)
611
+
612
+ return [
613
+ state["transcript"], state["transcript_with_start_time"], state["transcript_with_end_time"],
614
+ # gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # prompt_to_word
615
+ gr.Dropdown(value=state["word_bounds"][0], choices=state["word_bounds"], interactive=True), # edit_from_word
616
+ gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # edit_to_word
617
+ state
618
+ ]
619
+
620
+ def align(transcript, audio_info, state):
621
+ lang = state["segments"]["lang"]
622
+ # print("realign: ", transcript, state)
623
+ transcript = re.sub(_whitespace_re, " ", transcript)
624
+ transcript = transcript[1:] if transcript[0] == " " else transcript
625
+ segments = {'lang':lang, 'text':transcript, 'text_raw':transcript}
626
+ if lang == "zh":
627
+ segments["text"] = text_norm.txt2pinyin(transcript)
628
  else:
629
+ transcript = replace_numbers_with_words(transcript)
630
+ segments["text"] = (transcript.split(' '), transcript.split(' '))
631
+ # print("text:", segments["text"])
632
+ segments = align_model.align(segments, audio_info)
633
+
634
+ state = get_transcribe_state(segments)
635
+
636
+ return [
637
+ state["transcript"], state["transcript_with_start_time"], state["transcript_with_end_time"],
638
+ # gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # prompt_to_word
639
+ gr.Dropdown(value=state["word_bounds"][0], choices=state["word_bounds"], interactive=True), # edit_from_word
640
+ gr.Dropdown(value=state["word_bounds"][-1], choices=state["word_bounds"], interactive=True), # edit_to_word
641
+ state
642
  ]
 
643
 
644
 
645
+ def denoise(audio_info):
646
+ denoised_audio, sr = denoise_model.denoise(audio_info)
647
+ denoised_audio = denoised_audio # .squeeze().numpy()
648
+ return (sr, denoised_audio)
649
+
650
+ def cancel_denoise(audio_info):
651
+ return audio_info
652
 
653
+ def get_output_audio(audio_tensors, sr):
654
+ result = torch.cat(audio_tensors, -1)
655
+ result = result.squeeze().cpu().numpy()
656
+ result = (result * np.iinfo(np.int16).max).astype(np.int16)
657
+ print("save result:", result.shape)
658
+ # wavfile.write(os.path.join(TMP_PATH, "output.wav"), sr, result)
659
+ return (int(sr), result)
660
+
661
 
662
+ def get_edit_audio_part(audio_info, edit_start, edit_end):
663
+ sr, raw_wav = audio_info
664
+ raw_wav = raw_wav[int(edit_start*sr):int(edit_end*sr)]
665
+ return (sr, raw_wav)
666
 
 
667
 
668
+ def crossfade_concat(chunk1, chunk2, overlap):
669
+ # Compute fade-out and fade-in coefficients for the crossfade
670
+ fade_out = torch.cos(torch.linspace(0, torch.pi / 2, overlap)) ** 2
671
+ fade_in = torch.cos(torch.linspace(torch.pi / 2, 0, overlap)) ** 2
672
+ chunk2[:overlap] = chunk1[-overlap:] * fade_out + chunk2[:overlap] * fade_in
673
+ chunk = torch.cat((chunk1[:-overlap], chunk2), dim=0)
674
+ return chunk
675
 
676
+ def replace_numbers_with_words(sentence, lang="en"):
677
+ sentence = re.sub(r'(\d+)', r' \1 ', sentence) # add spaces around numbers
678
+ def replace_with_words(match):
679
+ num = match.group(0)
680
+ try:
681
+ return num2words(num, lang=lang) # Convert numbers to words
682
+ except:
683
+ return num # In case num2words fails (unlikely with digits but just to be safe)
684
+ return re.sub(r'\b\d+\b', replace_with_words, sentence) # Regular expression that matches numbers
685
+
686
+
687
+ def run(seed, nfe_step, speed, cfg_strength, sway_sampling_coef, ref_ratio,
688
+ audio_info, denoised_audio, transcribe_state, transcript, smart_transcript,
689
+ mode, start_time, end_time,
690
+ split_text, selected_sentence, audio_tensors):
691
+ if tts_edit_model is None:
692
+ raise gr.Error("LEMAS-TTS edit model not loaded")
693
+ if smart_transcript and (transcribe_state is None):
694
+ raise gr.Error("Can't use smart transcript: whisper transcript not found")
695
+
696
+ # if mode == "Rerun":
697
+ # colon_position = selected_sentence.find(':')
698
+ # selected_sentence_idx = int(selected_sentence[:colon_position])
699
+ # sentences = [selected_sentence[colon_position + 1:]]
700
+
701
+ # Choose base audio (denoised if duration matches)
702
+ audio_base = audio_info
703
+ audio_dur = round(audio_info[1].shape[0] / audio_info[0], ndigits=3)
704
+ if denoised_audio is not None:
705
+ denoised_dur = round(denoised_audio[1].shape[0] / denoised_audio[0], ndigits=3)
706
+ if audio_dur == denoised_dur or (
707
+ denoised_audio[0] != audio_info[0] and abs(audio_dur - denoised_dur) < 0.1
708
+ ):
709
+ audio_base = denoised_audio
710
+ logging.info("use denoised audio")
711
+
712
+ raw_sr, raw_wav = audio_base
713
+ print("audio_dur: ", audio_dur, raw_sr, raw_wav.shape, start_time, end_time)
714
 
715
+ # Build target text by replacing the selected span with `transcript`
716
+ words = transcribe_state["words_info"]
717
+ if not words:
718
+ raise gr.Error("No word-level alignment found; please run Transcribe first.")
719
+
720
+ start_time = float(start_time)
721
+ end_time = float(end_time)
722
+ if end_time <= start_time:
723
+ raise gr.Error("Edit end time must be greater than start time.")
724
+
725
+ # Find word indices covering the selected region
726
+ start_idx = 0
727
+ for i, w in enumerate(words):
728
+ if w["end"] > start_time:
729
+ start_idx = i
730
731
 
732
+ end_idx = len(words)
733
+ for i in range(len(words) - 1, -1, -1):
734
+ if words[i]["start"] < end_time:
735
+ end_idx = i + 1
736
+ break
737
+ if end_idx <= start_idx:
738
+ end_idx = min(start_idx + 1, len(words))
739
+
740
+ word_start_sec = float(words[start_idx]["start"])
741
+ word_end_sec = float(words[end_idx - 1]["end"])
742
+
743
+ # Edit span in seconds (relative to full utterance)
744
+ edit_start = max(0.0, word_start_sec - 0.1)
745
+ edit_end = min(word_end_sec + 0.1, audio_dur)
746
+ parts_to_edit = [(edit_start, edit_end)]
747
+
748
+ display_text = transcribe_state["segments"]["text_raw"].strip()
749
+ txt_list = display_text.split(" ") if display_text else [w["word"] for w in words]
750
+
751
+ prefix = " ".join(txt_list[:start_idx]).strip()
752
+ suffix = " ".join(txt_list[end_idx:]).strip()
753
+ new_phrase = transcript.strip()
754
+
755
+ pieces = []
756
+ if prefix:
757
+ pieces.append(prefix)
758
+ if new_phrase:
759
+ pieces.append(new_phrase)
760
+ if suffix:
761
+ pieces.append(suffix)
762
+ target_text = " ".join(pieces)
763
+
764
+ logging.info(
765
+ "target_text: %s (start_idx=%d, end_idx=%d, parts_to_edit=%s)",
766
+ target_text,
767
+ start_idx,
768
+ end_idx,
769
+ parts_to_edit,
770
+ )
771
 
772
+ # Prepare audio for LEMAS-TTS editing (mono, target SR)
773
+ segment_audio = load_wav(audio_base, sr=tts_edit_model.target_sample_rate)
774
+
775
+ seed_val = None if seed == -1 else int(seed)
776
+
777
+ # Decide whether to use prosody encoder at inference based on how TTS was built
778
+ use_prosody_flag = bool(getattr(tts_edit_model, "use_prosody_encoder", False))
779
+
780
+ wav_out, _ = gen_wav_multilingual(
781
+ tts_edit_model,
782
+ segment_audio,
783
+ tts_edit_model.target_sample_rate,
784
+ target_text,
785
+ parts_to_edit,
786
+ speed=float(speed),
787
+ nfe_step=int(nfe_step),
788
+ cfg_strength=float(cfg_strength),
789
+ sway_sampling_coef=float(sway_sampling_coef),
790
+ ref_ratio=float(ref_ratio),
791
+ no_ref_audio=False,
792
+ use_acc_grl=False,
793
+ use_prosody_encoder_flag=use_prosody_flag,
794
+ seed=seed_val,
795
+ )
796
 
797
+ wav_np = wav_out.cpu().numpy()
798
+ wav_np = np.clip(wav_np, -0.999, 0.999)
799
+ wav_int16 = (wav_np * np.iinfo(np.int16).max).astype(np.int16)
800
+ out_sr = int(tts_edit_model.target_sample_rate)
801
 
802
+ output_audio = (out_sr, wav_int16)
803
+ sentences = [f"0: {target_text}"]
804
+ audio_tensors = [torch.from_numpy(wav_np)]
805
 
806
+ component = gr.Dropdown(choices=sentences, value=sentences[0])
807
+ return output_audio, target_text, component, audio_tensors
808
+
809
 
810
+ def update_input_audio(audio_info):
811
+ if audio_info is None:
812
+ return 0, 0, 0
813
+ elif type(audio_info) is str:
814
+ info = torchaudio.info(audio_info)
815
+ max_time = round(info.num_frames / info.sample_rate, 2)
816
+ elif type(audio_info) is tuple:
817
+ max_time = round(audio_info[1].shape[0] / audio_info[0], 2)
818
+ return [
819
+ # gr.Slider(maximum=max_time, value=max_time),
820
+ gr.Slider(maximum=max_time, value=0),
821
+ gr.Slider(maximum=max_time, value=max_time),
822
+ ]
823
 
 
 
824
 
825
+ def change_mode(mode):
826
+ # tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor
827
+ return [
828
+ gr.Group(visible=mode != "Edit"),
829
+ gr.Group(visible=mode == "Edit"),
830
+ gr.Radio(visible=mode == "Edit"),
831
+ gr.Radio(visible=mode == "Long TTS"),
832
+ gr.Group(visible=mode == "Long TTS"),
833
+ ]
834
 
835
 
836
+ def load_sentence(selected_sentence, audio_tensors):
837
+ if selected_sentence is None:
838
+ return None
839
+ colon_position = selected_sentence.find(':')
840
+ selected_sentence_idx = int(selected_sentence[:colon_position])
841
+ # Use LEMAS-TTS target sample rate if available, otherwise default to 16000
842
+ sr = getattr(tts_edit_model, "target_sample_rate", 16000)
843
+ return get_output_audio([audio_tensors[selected_sentence_idx]], sr)
844
 
845
 
846
+ def update_bound_word(is_first_word, selected_word, edit_word_mode):
847
+ if selected_word is None:
848
+ return None
849
 
850
+ word_start_time = float(selected_word.split(' ')[0])
851
+ word_end_time = float(selected_word.split(' ')[-1])
852
+ if edit_word_mode == "Replace half":
853
+ bound_time = (word_start_time + word_end_time) / 2
854
+ elif is_first_word:
855
+ bound_time = word_start_time
856
+ else:
857
+ bound_time = word_end_time
858
 
859
+ return bound_time
860
 
861
 
862
+ def update_bound_words(from_selected_word, to_selected_word, edit_word_mode):
863
+ return [
864
+ update_bound_word(True, from_selected_word, edit_word_mode),
865
+ update_bound_word(False, to_selected_word, edit_word_mode),
866
+ ]
867
 
 
868
 
869
+ smart_transcript_info = """
870
+ If enabled, the target transcript will be constructed for you:</br>
871
+ - In TTS and Long TTS mode just write the text you want to synthesize.</br>
872
+ - In Edit mode just write the text to replace selected editing segment.</br>
873
+ If disabled, you should write the target transcript yourself:</br>
874
+ - In TTS mode write prompt transcript followed by generation transcript.</br>
875
+ - In Long TTS select split by newline (<b>SENTENCE SPLIT WON'T WORK</b>) and start each line with a prompt transcript.</br>
876
+ - In Edit mode write full prompt</br>
877
+ """
878
+
879
+ demo_original_transcript = ""
880
+
881
+ demo_text = {
882
+ "TTS": {
883
+ "smart": "take over the stage for half an hour,",
884
+ "regular": "Gwynplaine had, besides, for his work and for his feats of strength, take over the stage for half an hour."
885
+ },
886
+ "Edit": {
887
+ "smart": "Just write it line-by-line.",
888
+ "regular": "照片、医疗记录、神经重塑的易损性,这是某种数据库啊!还有PRELESS的脑部扫描、生物管型、神经重塑."
889
+ },
890
+ "Long TTS": {
891
+ "smart": "You can run the model on a big text!\n"
892
+ "Just write it line-by-line. Or sentence-by-sentence.\n"
893
+ "If some sentences sound odd, just rerun the model on them, no need to generate the whole text again!",
894
+ "regular": "Gwynplaine had, besides, for his work and for his feats of strength, You can run the model on a big text!\n"
895
+ "Gwynplaine had, besides, for his work and for his feats of strength, Just write it line-by-line. Or sentence-by-sentence.\n"
896
+ "Gwynplaine had, besides, for his work and for his feats of strength, If some sentences sound odd, just rerun the model on them, no need to generate the whole text again!"
897
+ }
898
+ }
899
 
900
+ all_demo_texts = {vv for k, v in demo_text.items() for kk, vv in v.items()}
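+ # Flattened set of every demo string; update_demo only swaps the transcript while it still matches one of these defaults.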
 
 
 
 
 
901
 
902
+ demo_words = ['0.069 Gwynplain 0.611', '0.671 had, 0.912', '0.952 besides, 1.414', '1.494 for 1.634', '1.695 his 1.835', '1.915 work 2.136', '2.196 and 2.297', '2.337 for 2.517', '2.557 his 2.678', '2.758 feats 3.019', '3.079 of 3.139', '3.2 strength, 3.561', '4.022 round 4.263', '4.303 his 4.444', '4.524 neck 4.705', '4.745 and 4.825', '4.905 over 5.086', '5.146 his 5.266', '5.307 shoulders, 5.768', '6.23 an 6.33', '6.531 esclavine 7.133', '7.213 of 7.293', '7.353 leather. 7.614']
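+ # Each demo_words entry above encodes "<start_time> <word> <end_time>" from the demo clip's word-level alignment.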
903
 
904
+ demo_words_info = [{'word': 'Gwynplain', 'start': 0.069, 'end': 0.611, 'score': 0.833}, {'word': 'had,', 'start': 0.671, 'end': 0.912, 'score': 0.879}, {'word': 'besides,', 'start': 0.952, 'end': 1.414, 'score': 0.863}, {'word': 'for', 'start': 1.494, 'end': 1.634, 'score': 0.89}, {'word': 'his', 'start': 1.695, 'end': 1.835, 'score': 0.669}, {'word': 'work', 'start': 1.915, 'end': 2.136, 'score': 0.916}, {'word': 'and', 'start': 2.196, 'end': 2.297, 'score': 0.766}, {'word': 'for', 'start': 2.337, 'end': 2.517, 'score': 0.808}, {'word': 'his', 'start': 2.557, 'end': 2.678, 'score': 0.786}, {'word': 'feats', 'start': 2.758, 'end': 3.019, 'score': 0.97}, {'word': 'of', 'start': 3.079, 'end': 3.139, 'score': 0.752}, {'word': 'strength,', 'start': 3.2, 'end': 3.561, 'score': 0.742}, {'word': 'round', 'start': 4.022, 'end': 4.263, 'score': 0.916}, {'word': 'his', 'start': 4.303, 'end': 4.444, 'score': 0.666}, {'word': 'neck', 'start': 4.524, 'end': 4.705, 'score': 0.908}, {'word': 'and', 'start': 4.745, 'end': 4.825, 'score': 0.882}, {'word': 'over', 'start': 4.905, 'end': 5.086, 'score': 0.847}, {'word': 'his', 'start': 5.146, 'end': 5.266, 'score': 0.791}, {'word': 'shoulders,', 'start': 5.307, 'end': 5.768, 'score': 0.729}, {'word': 'an', 'start': 6.23, 'end': 6.33, 'score': 0.854}, {'word': 'esclavine', 'start': 6.531, 'end': 7.133, 'score': 0.803}, {'word': 'of', 'start': 7.213, 'end': 7.293, 'score': 0.772}, {'word': 'leather.', 'start': 7.353, 'end': 7.614, 'score': 0.896}]
 
 
 
 
905
906
 
907
+ def update_demo(mode, smart_transcript, edit_word_mode, transcript, edit_from_word, edit_to_word):
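+ # Refresh the demo transcript and word selections for the chosen mode, leaving any user-modified values untouched.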
908
+ if transcript not in all_demo_texts:
909
+ return transcript, edit_from_word, edit_to_word
910
 
911
+ replace_half = edit_word_mode == "Replace half"
912
+ change_edit_from_word = edit_from_word == demo_words[2] or edit_from_word == demo_words[3]
913
+ change_edit_to_word = edit_to_word == demo_words[11] or edit_to_word == demo_words[12]
914
+ demo_edit_from_word_value = demo_words[2] if replace_half else demo_words[3]
915
+ demo_edit_to_word_value = demo_words[12] if replace_half else demo_words[11]
916
+ return [
917
+ demo_text[mode]["smart" if smart_transcript else "regular"],
918
+ demo_edit_from_word_value if change_edit_from_word else edit_from_word,
919
+ demo_edit_to_word_value if change_edit_to_word else edit_to_word,
920
+ ]
921
 
922
+ def get_app():
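+ # Assemble the Gradio Blocks UI: model loading, input/transcription controls, edit controls, output panel, and event wiring.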
923
+ with gr.Blocks() as app:
924
  with gr.Row():
925
+ with gr.Column(scale=2):
926
+ load_models_btn = gr.Button(value="Load models")
927
+ with gr.Column(scale=5):
928
+ with gr.Accordion("Select models", open=False) as models_selector:
929
+ # For LEMAS-TTS editing, we expose a simple model selector
930
+ # between the two multilingual variants.
931
+ with gr.Row():
932
+ lemas_model_choice = gr.Radio(
933
+ label="Edit Model",
934
+ choices=["multilingual_grl", "multilingual_prosody"],
935
+ value="multilingual_grl",
936
+ interactive=True,
937
+ scale=3,
938
+ )
939
+ denoise_model_choice = gr.Radio(label="Denoise Model", scale=2, value="UVR5", choices=["UVR5", "DeepFilterNet"]) # "830M", "330M_TTSEnhanced", "830M_TTSEnhanced"])
940
+ # whisper_backend_choice = gr.Radio(label="Whisper backend", value="", choices=["whisperX", "whisper"])
941
+ whisper_model_choice = gr.Radio(label="Whisper model", scale=3, value="medium", choices=["base", "small", "medium", "large"])
942
+ align_model_choice = gr.Radio(label="Forced alignment model", scale=2, value="MMS", choices=["whisperX", "MMS"], visible=False)
943
 
 
944
  with gr.Row():
945
+ with gr.Column(scale=2):
946
+ # Use a numpy waveform as default value to avoid Gradio's
947
+ # InvalidPathError with local filesystem paths.
948
+ _demo_value = None
949
+ demo_candidates = [
950
+ os.path.join(DEMO_PATH, "V-00013_en-US.wav"),
951
+ os.path.join(os.path.dirname(__file__), "..", "VoiceCraft", "demo", "V-00013_en-US.wav"),
952
+ ]
953
+ for demo_path in demo_candidates:
954
+ try:
955
+ if not os.path.isfile(demo_path):
956
+ continue
957
+ _demo_wav, _demo_sr = torchaudio.load(demo_path)
958
+ if _demo_wav.dim() > 1 and _demo_wav.shape[0] > 1:
959
+ _demo_wav = _demo_wav.mean(dim=0, keepdim=True)
960
+ _demo_value = (_demo_sr, _demo_wav.squeeze(0).numpy())
961
+ break
962
+ except Exception:
963
+ continue
964
+
965
+ input_audio = gr.Audio(
966
+ value=_demo_value,
967
+ label="Input Audio",
968
+ interactive=True,
969
+ type="numpy",
970
+ )
971
 
972
+ with gr.Row():
973
+ transcribe_btn = gr.Button(value="Transcribe")
974
+ align_btn = gr.Button(value="ReAlign")
975
+ with gr.Group():
976
+ original_transcript = gr.Textbox(label="Original transcript", lines=5, interactive=True, value=demo_original_transcript,
977
+ info="Use the whisperX model to get the transcript. Fix it and re-align if necessary.")
978
+ with gr.Accordion("Word start time", open=False, visible=False):
979
+ transcript_with_start_time = gr.Textbox(label="Start time", lines=5, interactive=False, info="Start time before each word")
980
+ with gr.Accordion("Word end time", open=False, visible=False):
981
+ transcript_with_end_time = gr.Textbox(label="End time", lines=5, interactive=False, info="End time after each word")
982
+
983
+ with gr.Row():
984
+ denoise_btn = gr.Button(value="Denoise")
985
+ cancel_btn = gr.Button(value="Cancel Denoise")
986
+ denoise_audio = gr.Audio(label="Denoised Audio", value=None, interactive=False, type="numpy")
987
+
988
+ with gr.Column(scale=3):
989
+ with gr.Group():
990
+ transcript_inbox = gr.Textbox(label="Text", lines=5, value=demo_text["Edit"]["smart"])
991
+ with gr.Row(visible=False):
992
+ smart_transcript = gr.Checkbox(label="Smart transcript", value=True)
993
+ with gr.Accordion(label="?", open=False):
994
+ info = gr.Markdown(value=smart_transcript_info)
995
+
996
+ mode = gr.Radio(label="Mode", choices=["Edit"], value="Edit", visible=False)
997
+ with gr.Row(visible=False):
998
+ split_text = gr.Radio(label="Split text", choices=["Newline", "Sentence"], value="Newline",
999
+ info="Split text into parts and run TTS for each part.", visible=True)
1000
+ edit_word_mode = gr.Radio(label="Edit word mode", choices=["Replace half", "Replace all"], value="Replace all",
1001
+ info="What to do with first and last word", visible=False)
1002
+
1003
+ # with gr.Group(visible=False) as tts_mode_controls:
1004
+ # with gr.Row():
1005
+ # edit_from_word = gr.Dropdown(label="First word in prompt", choices=demo_words, value=demo_words[12], interactive=True)
1006
+ # edit_to_word = gr.Dropdown(label="Last word in prompt", choices=demo_words, value=demo_words[18], interactive=True)
1007
+ # with gr.Row():
1008
+ # edit_start_time = gr.Slider(label="Prompt start time", minimum=0, maximum=7.614, step=0.001, value=4.022)
1009
+ # edit_end_time = gr.Slider(label="Prompt end time", minimum=0, maximum=7.614, step=0.001, value=5.768)
1010
+ # with gr.Row():
1011
+ # check_btn = gr.Button(value="Check prompt",scale=1)
1012
+ # edit_audio = gr.Audio(label="Prompt Audio", scale=3)
1013
+
1014
+ # with gr.Group() as edit_mode_controls:
1015
+ with gr.Row():
1016
+ edit_from_word = gr.Dropdown(label="First word to edit", choices=demo_words, value=demo_words[12], interactive=True)
1017
+ edit_to_word = gr.Dropdown(label="Last word to edit", choices=demo_words, value=demo_words[18], interactive=True)
1018
+ with gr.Row():
1019
+ edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.614, step=0.001, value=4.022)
1020
+ edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.614, step=0.001, value=5.768)
1021
+ # Put the button and audio in separate columns so that
1022
+ # the tall audio widget does not overlap the clickable
1023
+ # area of the button.
1024
+ with gr.Row():
1025
+ with gr.Column(scale=1):
1026
+ check_btn = gr.Button(value="Check edit words")
1027
+ with gr.Column(scale=3):
1028
+ edit_audio = gr.Audio(label="Edit word(s)", scale=3, type="numpy")
1029
+
1030
+ run_btn = gr.Button(value="Run", variant="primary")
1031
+
1032
+ with gr.Column(scale=2):
1033
+ output_audio = gr.Audio(label="Output Audio", type="numpy")
1034
+ with gr.Accordion("Inference transcript", open=True):
1035
+ inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False, info="Inference was performed on this transcript.")
1036
+ with gr.Group(visible=False) as long_tts_sentence_editor:
1037
+ sentence_selector = gr.Dropdown(label="Sentence", value=None,
1038
+ info="Select sentence you want to regenerate")
1039
+ sentence_audio = gr.Audio(label="Sentence Audio", scale=2, type="numpy")
1040
+ rerun_btn = gr.Button(value="Rerun")
1041
 
 
1042
  with gr.Row():
1043
+ with gr.Accordion("Generation Parameters - change these if you are unhappy with the generation", open=False):
1044
+ with gr.Row():
1045
+ nfe_step = gr.Number(
1046
+ label="NFE Step",
1047
+ value=64,
1048
+ precision=0,
1049
+ info="Number of function evaluations (sampling steps).",
1050
+ )
1051
+ speed = gr.Slider(
1052
+ label="Speed",
1053
+ minimum=0.5,
1054
+ maximum=1.5,
1055
+ step=0.1,
1056
+ value=1.0,
1057
+ info="Placeholder for future use; currently not applied.",
1058
+ )
1059
+ cfg_strength = gr.Slider(
1060
+ label="CFG Strength",
1061
+ minimum=0.0,
1062
+ maximum=10.0,
1063
+ step=0.5,
1064
+ value=5.0,
1065
+ info="Classifier-free guidance strength.",
1066
+ )
1067
+
1068
+ with gr.Row():
1069
+ sway_sampling_coef = gr.Slider(
1070
+ label="Sway",
1071
+ minimum=2.0,
1072
+ maximum=5.0,
1073
+ step=0.1,
1074
+ value=3.0,
1075
+ info="Sampling sway coefficient.",
1076
+ )
1077
+ ref_ratio = gr.Slider(
1078
+ label="Ref Ratio",
1079
+ minimum=0.0,
1080
+ maximum=1.0,
1081
+ step=0.05,
1082
+ value=1.0,
1083
+ info="How much to rely on reference audio (if used).",
1084
+ )
1085
+ seed = gr.Number(
1086
+ label="Seed",
1087
+ value=-1,
1088
+ precision=0,
1089
+ info="-1 for random, otherwise fixed seed.",
1090
+ )
1091
+
1092
+
1093
+ audio_tensors = gr.State()
1094
+ transcribe_state = gr.State(value={"words_info": demo_words_info, "lang":"zh"})
1095
+
1096
+
1097
+ edit_word_mode.change(fn=update_demo,
1098
+ inputs=[mode, smart_transcript, edit_word_mode, transcript_inbox, edit_from_word, edit_to_word],
1099
+ outputs=[transcript_inbox, edit_from_word, edit_to_word])
1100
+ smart_transcript.change(
1101
+ fn=update_demo,
1102
+ inputs=[mode, smart_transcript, edit_word_mode, transcript_inbox, edit_from_word, edit_to_word],
1103
+ outputs=[transcript_inbox, edit_from_word, edit_to_word],
1104
+ )
1105
 
1106
+ load_models_btn.click(fn=load_models,
1107
+ inputs=[lemas_model_choice, whisper_model_choice, align_model_choice, denoise_model_choice], # audiosr_choice],
1108
+ outputs=[models_selector])
1109
+
1110
+ input_audio.upload(fn=update_input_audio,
1111
+ inputs=[input_audio],
1112
+ outputs=[edit_start_time, edit_end_time]) # prompt_end_time
1113
+
1114
+ transcribe_btn.click(fn=transcribe,
1115
+ inputs=[seed, input_audio],
1116
+ outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time,
1117
+ edit_from_word, edit_to_word, transcribe_state]) # prompt_to_word
1118
+ align_btn.click(fn=align,
1119
+ inputs=[original_transcript, input_audio, transcribe_state],
1120
+ outputs=[original_transcript, transcript_with_start_time, transcript_with_end_time,
1121
+ edit_from_word, edit_to_word, transcribe_state]) # prompt_to_word
1122
+
1123
+ denoise_btn.click(fn=denoise,
1124
+ inputs=[input_audio],
1125
  outputs=[denoise_audio])
1126
 
1127
+ cancel_btn.click(fn=cancel_denoise,
1128
+ inputs=[input_audio],
1129
  outputs=[denoise_audio])
1130
 
1131
+ # mode.change(fn=change_mode,
1132
+ # inputs=[mode],
1133
+ # outputs=[tts_mode_controls, edit_mode_controls, edit_word_mode, split_text, long_tts_sentence_editor])
1134
+
1135
+ check_btn.click(fn=get_edit_audio_part,
1136
+ inputs=[input_audio, edit_start_time, edit_end_time],
1137
+ outputs=[edit_audio])
1138
+
1139
+ run_btn.click(fn=run,
1140
+ inputs=[
1141
+ seed, nfe_step, speed, cfg_strength, sway_sampling_coef, ref_ratio,
1142
+ input_audio, denoise_audio, transcribe_state, transcript_inbox, smart_transcript,
1143
+ mode, edit_start_time, edit_end_time,
1144
+ split_text, sentence_selector, audio_tensors
1145
+ ],
1146
+ outputs=[output_audio, inference_transcript, sentence_selector, audio_tensors])
1147
+
1148
+ sentence_selector.change(
1149
+ fn=load_sentence,
1150
+ inputs=[sentence_selector, audio_tensors],
1151
+ outputs=[sentence_audio],
1152
  )
1153
+ rerun_btn.click(fn=run,
1154
+ inputs=[
1155
+ seed, nfe_step, speed, cfg_strength, sway_sampling_coef, ref_ratio,
1156
+ input_audio, denoise_audio, transcribe_state, transcript_inbox, smart_transcript,
1157
+ gr.State(value="Rerun"), edit_start_time, edit_end_time,
1158
+ split_text, sentence_selector, audio_tensors
1159
+ ],
1160
+ outputs=[output_audio, inference_transcript, sentence_audio, audio_tensors])
1161
+
1162
+ # prompt_to_word.change(fn=update_bound_word,
1163
+ # inputs=[gr.State(False), prompt_to_word, gr.State("Replace all")],
1164
+ # outputs=[prompt_end_time])
1165
+ edit_from_word.change(fn=update_bound_word,
1166
+ inputs=[gr.State(True), edit_from_word, edit_word_mode],
1167
+ outputs=[edit_start_time])
1168
+ edit_to_word.change(fn=update_bound_word,
1169
+ inputs=[gr.State(False), edit_to_word, edit_word_mode],
1170
+ outputs=[edit_end_time])
1171
+ edit_word_mode.change(fn=update_bound_words,
1172
+ inputs=[edit_from_word, edit_to_word, edit_word_mode],
1173
+ outputs=[edit_start_time, edit_end_time])
1174
+
1175
+ return app
1176
 
1177
 
1178
  if __name__ == "__main__":
1179
+ import argparse
1180
+
1181
+ parser = argparse.ArgumentParser(description="VoiceCraft gradio app.")
1182
+
1183
+ parser.add_argument("--demo-path", default="./demo", help="Path to demo directory")
1184
+ parser.add_argument("--tmp-path", default="/cto_labs/vistring/zhaozhiyuan/outputs/voicecraft/tmp", help="Path to tmp directory")
1185
+ parser.add_argument("--models-path", default="/cto_labs/vistring/zhaozhiyuan/outputs/voicecraft/pretrain/VoiceCraft", help="Path to voicecraft models directory")
1186
+ parser.add_argument("--port", default=41020, type=int, help="App port")
1187
+ parser.add_argument("--share", action="store_true", help="Launch with public url")
1188
+ parser.add_argument("--server_name", default="0.0.0.0", type=str, help="Server name for launching the app. 127.0.0.1 for localhost; 0.0.0.0 to allow access from other machines on the local network. May also give access to external users depending on the firewall settings.")
1189
+
1190
+ os.environ["USER"] = os.getenv("USER", "user")
1191
+ args = parser.parse_args()
1192
+ DEMO_PATH = args.demo_path
1193
+ TMP_PATH = args.tmp_path
1194
+ MODELS_PATH = args.models_path
1195
+
1196
+ app = get_app()
1197
+ app.queue().launch(share=args.share, server_name=args.server_name, server_port=args.port)