Approximetal committed
Commit c9c7e92 · verified · 1 Parent(s): 1f053ff

Upload gradio_mix.py with huggingface_hub

Files changed (1):
  gradio_mix.py (+66 -82)
gradio_mix.py CHANGED
@@ -36,16 +36,10 @@ _JIEBA_DICT = os.path.join(
 if os.path.isfile(_JIEBA_DICT):
     jieba.set_dictionary(_JIEBA_DICT)
 
-# import sys
-# sys.path.append("/cto_labs/vistring/zhaozhiyuan/code/SpeechAugment/versatile_audio_super_resolution")
-# from inference import Predictor
-
 # from inference_tts_scale import inference_one_sample as inference_tts
 import langid
-langid.set_languages(['es','pt','zh','en','de','fr','it', 'ar', 'ru', 'ja', 'ko', 'hi', 'th', 'id', 'vi'])
+langid.set_languages(['es','pt','zh','en','de','fr','it', 'ru', 'id', 'vi'])
 
-# import nltk
-# nltk.download('punkt')
 
 os.environ['CURL_CA_BUNDLE'] = ''
 DEMO_PATH = os.getenv("DEMO_PATH", "./demo")
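Note on the `langid` change: `set_languages` restricts classification to the listed codes, so dropping 'ar', 'ja', 'ko', 'hi' and 'th' changes which labels detection can return, not just a config detail. A minimal sketch of the constrained behavior (the sample strings are illustrative):

```python
import langid

# Constrain detection to the subset this commit keeps.
langid.set_languages(['es', 'pt', 'zh', 'en', 'de', 'fr', 'it', 'ru', 'id', 'vi'])

# classify() returns a (language_code, score) pair; input in a language
# outside the constrained set is mapped to the closest remaining code.
print(langid.classify("This is an English sentence."))      # e.g. ('en', ...)
print(langid.classify("Ceci est une phrase en français."))  # e.g. ('fr', ...)
```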
@@ -77,57 +71,6 @@ def seed_everything(seed):
     torch.backends.cudnn.deterministic = True
 
 
-# class AudioSR:
-#     def __init__(self, model_name):
-#         code_dir = "/cto_labs/vistring/zhaozhiyuan/code/SpeechAugment/versatile_audio_super_resolution"
-#         self.model = self.load_model(model_name, code_dir)
-#         self.sr = 48000
-#         self.chunk_size = 10.24
-#         self.overlap = 0.16
-#         self.guidance_scale = 1
-#         self.ddim_steps = 20
-#         self.multiband_ensemble = False
-
-#     def load_model(self, model_name, code_dir):
-#         import sys, json
-#         sys.path.append(code_dir)
-#         from inference import Predictor
-#         sr_model = Predictor()
-#         sr_model.setup(model_name)
-#         return sr_model
-
-#     def audiosr(self, in_wav, src_sr, tar_sr, chunk_size=10.24, overlap=0.16, seed=0, guidance_scale=1, ddim_steps=20, multiband_ensemble=False):
-#         if seed == 0:
-#             seed = random.randint(0, 2**32 - 1)
-#             print(f"Setting seed to: {seed}")
-#         print(f"overlap = {overlap}")
-#         print(f"guidance_scale = {guidance_scale}")
-#         print(f"ddim_steps = {ddim_steps}")
-#         print(f"chunk_size = {chunk_size}")
-#         print(f"multiband_ensemble = {multiband_ensemble}")
-#         print(f"in_wav.shape = {in_wav.shape}")
-
-#         in_wav = torchaudio.functional.resample(in_wav.squeeze(), src_sr, 24000)
-#         in_wav = in_wav.squeeze().numpy()
-
-#         out_wav = self.model.process_audio(
-#             in_wav, 24000,
-#             chunk_size=chunk_size,
-#             overlap=overlap,
-#             seed=seed,
-#             guidance_scale=guidance_scale,
-#             ddim_steps=ddim_steps,
-#             multiband_ensemble=multiband_ensemble,
-#         )
-#         out_wav = out_wav[:int(self.sr * in_wav.shape[0] / 24000)].T
-#         if tar_sr != self.sr:
-#             out_wav = torchaudio.functional.resample(torch.from_numpy(out_wav).squeeze(), self.sr, tar_sr)
-#         else:
-#             out_wav = torch.from_numpy(out_wav)
-#         print(f"out.shape = {out_wav.shape} tar_sr={tar_sr}")
-#         return out_wav.squeeze()
-
-
 class UVR5:
     """Small wrapper around the bundled uvr5 implementation for denoising."""
 
@@ -465,7 +408,7 @@ class WhisperxModel:
         ASR_DEVICE,
         compute_type="float32",
         asr_options={
-            "suppress_numerals": True,
+            "suppress_numerals": False,
             "max_new_tokens": None,
             "clip_timestamps": None,
             "initial_prompt": prompt,
@@ -481,10 +424,7 @@ class WhisperxModel:
         audio = load_wav(audio_info).numpy()
         if lang is None:
             lang = self.model.detect_language(audio)
-        if lang == 'zh':
-            self.model.options._replace(initial_prompt="简体中文:")
-        else:
-            self.model.options._replace(initial_prompt=None)
+
         segments = self.model.transcribe(audio, batch_size=8, language=lang)["segments"]
         transcript = " ".join([segment["text"] for segment in segments])
 
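For context on the `suppress_numerals` flip: it is a whisperX ASR option; `True` makes the decoder spell numbers out as words, `False` keeps digits verbatim. A minimal load sketch under assumed model size, device, and file name (only the `asr_options` keys are taken from the diff):

```python
import whisperx

# asr_options are merged into whisperX's decoding defaults.
model = whisperx.load_model(
    "medium",
    device="cuda",
    compute_type="float32",
    asr_options={
        "suppress_numerals": False,  # keep "42" rather than "forty two"
        "max_new_tokens": None,
        "clip_timestamps": None,
        "initial_prompt": None,
    },
)
result = model.transcribe("audio.wav", batch_size=8)
print(" ".join(seg["text"] for seg in result["segments"]))
```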
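The deleted `zh` branch was also a silent no-op: faster-whisper's transcription options object is a named tuple, and `_replace` returns a modified copy instead of mutating in place, so the result was thrown away. A tiny demonstration of the pitfall (the `Options` type is a stand-in, not whisperX's class):

```python
from typing import NamedTuple, Optional

class Options(NamedTuple):
    initial_prompt: Optional[str] = None

opts = Options()

# _replace builds a *new* tuple; without rebinding, nothing changes.
opts._replace(initial_prompt="简体中文:")
assert opts.initial_prompt is None  # result was discarded

# The effective form rebinds the name to the copy.
opts = opts._replace(initial_prompt="简体中文:")
assert opts.initial_prompt == "简体中文:"
```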
@@ -587,7 +527,7 @@ def load_models(lemas_model_name, whisper_model_name, alignment_model_name, deno
     if denoise_model_name == "UVR5":
         denoise_model = UVR5(os.path.join(str(PRETRAINED_ROOT), "uvr5"))
     elif denoise_model_name == "DeepFilterNet":
-        denoise_model = DeepFilterNet("./audio_preprocess/denoiser_model.onnx")
+        denoise_model = DeepFilterNet("./pretrained_models/denoiser_model.onnx")
 
     if alignment_model_name == "MMS":
         align_model = MMSAlignModel()
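The `DeepFilterNet` wrapper itself is outside this diff; if it loads the `.onnx` file with onnxruntime, a defensive loader along these lines (entirely an assumption about the wrapper's internals) makes the relocated path fail loudly when the file is missing:

```python
import os
import onnxruntime as ort

def load_denoiser(path: str = "./pretrained_models/denoiser_model.onnx") -> ort.InferenceSession:
    # Raise a clear error instead of onnxruntime's generic load failure.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Denoiser ONNX model not found: {path}")
    return ort.InferenceSession(path, providers=["CPUExecutionProvider"])
```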
@@ -615,18 +555,29 @@ def load_models(lemas_model_name, whisper_model_name, alignment_model_name, deno
 
     prosody_cfg = Path(CKPTS_ROOT) / "prosody_encoder" / "pretssel_cfg.json"
     prosody_ckpt = Path(CKPTS_ROOT) / "prosody_encoder" / "prosody_encoder_UnitY2.pt"
-    use_prosody = prosody_cfg.is_file() and prosody_ckpt.is_file()
+
+    # Decide whether to enable the prosody encoder:
+    #   - multilingual_prosody: True (if assets exist)
+    #   - multilingual_grl: False (GRL-only variant)
+    #   - others: fall back to presence of assets.
+    if lemas_model_name.endswith("prosody"):
+        use_prosody = prosody_cfg.is_file() and prosody_ckpt.is_file()
+    elif lemas_model_name.endswith("grl"):
+        use_prosody = False
+    else:
+        use_prosody = prosody_cfg.is_file() and prosody_ckpt.is_file()
 
     tts_edit_model = TTS(
         model=lemas_model_name,
         ckpt_file=ckpt_file,
         vocab_file=str(vocab_file),
         device=device,
-        use_ema=True,
-        frontend="phone",
         use_prosody_encoder=use_prosody,
         prosody_cfg_path=str(prosody_cfg) if use_prosody else "",
         prosody_ckpt_path=str(prosody_ckpt) if use_prosody else "",
+        ode_method="euler",
+        use_ema=True,
+        frontend="phone",
     )
     logging.info(f"Loaded LEMAS-TTS edit model from {ckpt_file}")
 
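The new branching reduces to a small pure function, which makes the three cases easy to unit-test in isolation; a standalone sketch (the function name is hypothetical):

```python
from pathlib import Path

def decide_use_prosody(model_name: str, cfg: Path, ckpt: Path) -> bool:
    """Mirror of the hunk above: GRL-only variants never use the encoder;
    prosody variants (and unknown names) require both asset files."""
    if model_name.endswith("grl"):
        return False
    return cfg.is_file() and ckpt.is_file()

# With missing assets, even the prosody variant stays disabled.
cfg, ckpt = Path("missing_cfg.json"), Path("missing_ckpt.pt")
assert decide_use_prosody("multilingual_prosody", cfg, ckpt) is False
assert decide_use_prosody("multilingual_grl", cfg, ckpt) is False
```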
@@ -819,19 +770,23 @@ def run(seed, nfe_step, speed, cfg_strength, sway_sampling_coef, ref_ratio,
 
     seed_val = None if seed == -1 else int(seed)
 
+    # Decide whether to use prosody encoder at inference based on how TTS was built
+    use_prosody_flag = bool(getattr(tts_edit_model, "use_prosody_encoder", False))
+
     wav_out, _ = gen_wav_multilingual(
         tts_edit_model,
         segment_audio,
         tts_edit_model.target_sample_rate,
         target_text,
         parts_to_edit,
+        speed=float(speed),
         nfe_step=int(nfe_step),
         cfg_strength=float(cfg_strength),
         sway_sampling_coef=float(sway_sampling_coef),
         ref_ratio=float(ref_ratio),
         no_ref_audio=False,
         use_acc_grl=False,
-        use_prosody_encoder_flag=True,
+        use_prosody_encoder_flag=use_prosody_flag,
         seed=seed_val,
     )
 
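The `getattr(..., False)` guard above means a `TTS` object built without the prosody encoder (or by an older code path that never sets the attribute) safely yields `False` rather than raising. The pattern in isolation (stand-in classes, not the real `TTS`):

```python
class LegacyTTS:
    pass  # predates the use_prosody_encoder attribute

class ProsodyTTS:
    use_prosody_encoder = True

for model in (LegacyTTS(), ProsodyTTS()):
    flag = bool(getattr(model, "use_prosody_encoder", False))
    print(type(model).__name__, flag)  # LegacyTTS False, ProsodyTTS True
```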
@@ -969,22 +924,46 @@ def get_app():
     with gr.Accordion("Select models", open=False) as models_selector:
         # For LEMAS-TTS editing, we expose a simple model selector
        # between the two multilingual variants.
-        lemas_model_choice = gr.Radio(
-            label="LEMAS-TTS Model",
-            choices=["multilingual_grl", "multilingual_prosody"],
-            value="multilingual_grl",
-            interactive=True,
-        )
         with gr.Row():
+            lemas_model_choice = gr.Radio(
+                label="Edit Model",
+                choices=["multilingual_grl", "multilingual_prosody"],
+                value="multilingual_grl",
+                interactive=True,
+                scale=3,
+            )
             denoise_model_choice = gr.Radio(label="Denoise Model", scale=2, value="UVR5", choices=["UVR5", "DeepFilterNet"])  # "830M", "330M_TTSEnhanced", "830M_TTSEnhanced"])
             # whisper_backend_choice = gr.Radio(label="Whisper backend", value="", choices=["whisperX", "whisper"])
             whisper_model_choice = gr.Radio(label="Whisper model", scale=3, value="medium", choices=["base", "small", "medium", "large"])
             align_model_choice = gr.Radio(label="Forced alignment model", scale=2, value="MMS", choices=["whisperX", "MMS"], visible=False)
-            # audiosr_choice = gr.Radio(label="AudioSR model", scale=2, value="None", choices=["basic", "speech", "None"])
 
     with gr.Row():
         with gr.Column(scale=2):
-            input_audio = gr.Audio(value=f"{DEMO_PATH}/V-00013_en-US.wav", label="Input Audio", interactive=True)
+            # Use a numpy waveform as default value to avoid Gradio's
+            # InvalidPathError with local filesystem paths.
+            _demo_value = None
+            demo_candidates = [
+                os.path.join(DEMO_PATH, "V-00013_en-US.wav"),
+                os.path.join(os.path.dirname(__file__), "..", "VoiceCraft", "demo", "V-00013_en-US.wav"),
+            ]
+            for demo_path in demo_candidates:
+                try:
+                    if not os.path.isfile(demo_path):
+                        continue
+                    _demo_wav, _demo_sr = torchaudio.load(demo_path)
+                    if _demo_wav.dim() > 1 and _demo_wav.shape[0] > 1:
+                        _demo_wav = _demo_wav.mean(dim=0, keepdim=True)
+                    _demo_value = (_demo_sr, _demo_wav.squeeze(0).numpy())
+                    break
+                except Exception:
+                    continue
+
+            input_audio = gr.Audio(
+                value=_demo_value,
+                label="Input Audio",
+                interactive=True,
+                type="numpy",
+            )
 
             with gr.Row():
                 transcribe_btn = gr.Button(value="Transcribe")
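Background on the `type="numpy"` switches in this hunk and the ones below: Gradio's `Audio` component then exchanges audio as a `(sample_rate, samples)` tuple of an `int` and a NumPy array, which is why the default value is assembled as a tuple rather than a file path. A self-contained sketch (the passthrough app is illustrative):

```python
import numpy as np
import gradio as gr

def passthrough(audio):
    # With type="numpy" the callback receives (sample_rate, samples)
    # and may return the same kind of tuple for the output widget.
    sr, samples = audio
    return sr, samples

# One second of a 440 Hz tone as the default value; no file path involved.
sr = 16000
t = np.linspace(0, 1, sr, endpoint=False)
tone = (0.3 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)

demo = gr.Interface(
    fn=passthrough,
    inputs=gr.Audio(value=(sr, tone), type="numpy"),
    outputs=gr.Audio(type="numpy"),
)

if __name__ == "__main__":
    demo.launch()
```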
@@ -1000,7 +979,7 @@ def get_app():
             with gr.Row():
                 denoise_btn = gr.Button(value="Denoise")
                 cancel_btn = gr.Button(value="Cancel Denoise")
-            denoise_audio = gr.Audio(label="Denoised Audio", value=None, interactive=False)
+            denoise_audio = gr.Audio(label="Denoised Audio", value=None, interactive=False, type="numpy")
 
         with gr.Column(scale=3):
             with gr.Group():
@@ -1035,20 +1014,25 @@ def get_app():
             with gr.Row():
                 edit_start_time = gr.Slider(label="Edit from time", minimum=0, maximum=7.614, step=0.001, value=4.022)
                 edit_end_time = gr.Slider(label="Edit to time", minimum=0, maximum=7.614, step=0.001, value=5.768)
+            # Put the button and audio in separate columns so that
+            # the tall audio widget does not overlap the clickable
+            # area of the button.
             with gr.Row():
-                check_btn = gr.Button(value="Check edit words", scale=1)
-                edit_audio = gr.Audio(label="Edit word(s)", scale=3)
+                with gr.Column(scale=1):
+                    check_btn = gr.Button(value="Check edit words")
+                with gr.Column(scale=3):
+                    edit_audio = gr.Audio(label="Edit word(s)", scale=3, type="numpy")
 
             run_btn = gr.Button(value="Run", variant="primary")
 
         with gr.Column(scale=2):
-            output_audio = gr.Audio(label="Output Audio")
+            output_audio = gr.Audio(label="Output Audio", type="numpy")
             with gr.Accordion("Inference transcript", open=True):
                 inference_transcript = gr.Textbox(label="Inference transcript", lines=5, interactive=False, info="Inference was performed on this transcript.")
             with gr.Group(visible=False) as long_tts_sentence_editor:
                 sentence_selector = gr.Dropdown(label="Sentence", value=None,
                                                 info="Select sentence you want to regenerate")
-                sentence_audio = gr.Audio(label="Sentence Audio", scale=2)
+                sentence_audio = gr.Audio(label="Sentence Audio", scale=2, type="numpy")
                 rerun_btn = gr.Button(value="Rerun")
 
     with gr.Row():
@@ -1064,7 +1048,7 @@ def get_app():
             label="Speed",
             minimum=0.5,
             maximum=1.5,
-            step=0.05,
+            step=0.1,
             value=1.0,
             info="Placeholder for future use; currently not applied.",
         )