Approximetal committed
Commit 0c63007 · verified · 1 Parent(s): be1395e

Update inference_gradio.py

Files changed (1)
  1. inference_gradio.py +43 -18
inference_gradio.py CHANGED
@@ -75,14 +75,18 @@ class UVR5:
     """Small wrapper around the bundled uvr5 implementation for denoising."""
 
     def __init__(self, model_dir):
+        # Code directory is always the local `uvr5` folder in this repo
         code_dir = os.path.join(os.path.dirname(__file__), "uvr5")
         self.model = self.load_model(model_dir, code_dir)
 
     def load_model(self, model_dir, code_dir):
-        import sys, json
+        import sys, json, os
         if code_dir not in sys.path:
             sys.path.append(code_dir)
         from multiprocess_cuda_infer import ModelData, Inference
+        # In the minimal LEMAS-TTS layout, UVR5 weights live under:
+        #   <pretrained_models>/uvr5/models/MDX_Net_Models/model_data/
+        # Here `model_dir` points to that `model_data` directory.
         model_path = os.path.join(model_dir, "Kim_Vocal_1.onnx")
         config_path = os.path.join(model_dir, "MDX-Net-Kim-Vocal1.json")
         with open(config_path, "r", encoding="utf-8") as f:
@@ -93,7 +97,7 @@ class UVR5:
             result_path = model_dir,
             device = 'cpu',
             process_method = "MDX-Net",
-            base_dir=model_dir,
+            base_dir=code_dir,
             **configs
         )
 
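For context (illustrative only, not in this commit): given the layout spelled out in the new comments, a minimal sanity check before constructing the wrapper could look like the sketch below. The root assumes the default MODELS_PATH of ./pretrained_models; the file names are taken from this diff.

    import os

    # Assumed default root; MODELS_PATH may be configured differently at runtime.
    model_data_dir = os.path.join(
        "./pretrained_models", "uvr5", "models", "MDX_Net_Models", "model_data"
    )
    for fname in ("Kim_Vocal_1.onnx", "MDX-Net-Kim-Vocal1.json"):
        if not os.path.isfile(os.path.join(model_data_dir, fname)):
            raise FileNotFoundError(os.path.join(model_data_dir, fname))
    denoiser = UVR5(model_data_dir)  # UVR5 is the wrapper class shown above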
 
@@ -390,11 +394,12 @@ class MMSAlignModel:
 
 class WhisperxModel:
     def __init__(self, model_name):
-        from whisperx import load_model
         from pathlib import Path
+        import whisperx.vad as wx_vad
+        from whisperx import load_model
         prompt = None # "This might be a blend of Simplified Chinese and English speech, do not translate, only transcription be allowed."
 
-        # Prefer a local VAD model (to avoid network download / 301 issues)
+        # Prefer a local VAD model (to avoid network download / checksum issues)
         vad_fp = Path(MODELS_PATH) / "whisperx-vad-segmentation.bin"
         if not vad_fp.is_file():
             logging.warning(
@@ -402,6 +407,30 @@ class WhisperxModel:
                 vad_fp,
             )
             vad_fp = None
+        else:
+            # Monkey-patch whisperx.vad.load_vad_model so it loads our local
+            # segmentation model without enforcing the baked-in SHA256 check.
+            def _patched_load_vad_model(device, vad_onset=0.500, vad_offset=0.363, use_auth_token=None, model_fp=None):
+                import torch
+                from pyannote.audio import Model
+                from pyannote.audio.pipelines import VoiceActivitySegmentation
+
+                model_path = str(model_fp) if model_fp is not None else str(vad_fp)
+                model = Model.from_pretrained(model_path, use_auth_token=use_auth_token)
+                hyperparameters = {
+                    "onset": vad_onset,
+                    "offset": vad_offset,
+                    "min_duration_on": 0.1,
+                    "min_duration_off": 0.1,
+                }
+                vad_pipeline = VoiceActivitySegmentation(
+                    segmentation=model,
+                    device=torch.device(device),
+                )
+                vad_pipeline.instantiate(hyperparameters)
+                return vad_pipeline
+
+            wx_vad.load_vad_model = _patched_load_vad_model
 
         self.model = load_model(
             model_name,
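For context (illustrative only, not in this commit): the patch only takes effect once WhisperxModel(...) has been constructed, since __init__ is what assigns _patched_load_vad_model onto whisperx.vad. Assuming pyannote.audio is installed and the checkpoint sits under the default ./pretrained_models, the patched loader can then be exercised directly:

    import torch
    import whisperx.vad as wx_vad  # already patched by WhisperxModel.__init__

    vad = wx_vad.load_vad_model(
        "cuda" if torch.cuda.is_available() else "cpu",
        model_fp="./pretrained_models/whisperx-vad-segmentation.bin",
    )
    # Returns an instantiated pyannote VoiceActivitySegmentation pipeline,
    # bypassing whisperx's usual download-and-checksum path.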
@@ -515,21 +544,17 @@ def get_audio_slice(audio, words_info, start_time, end_time, max_len=10, sr=1600
 def load_models(lemas_model_name, whisper_model_name, alignment_model_name, denoise_model_name): # , audiosr_name):
 
     global transcribe_model, align_model, denoise_model, text_norm, tts_edit_model
-    # if voicecraft_model:
-    #     del denoise_model
-    #     del transcribe_model
-    #     del align_model
-    #     del voicecraft_model
-    #     del audiosr
     torch.cuda.empty_cache()
     gc.collect()
 
     if denoise_model_name == "UVR5":
-        # Prefer the generic MODELS_PATH root for denoiser assets so that
-        # HF Spaces (where pretrained models are often mounted separately)
-        # and local runs share the same layout.
-        denoise_root = MODELS_PATH # e.g. "./pretrained_models" or env override
-        denoise_model = UVR5(os.path.join(denoise_root, "uvr5"))
+        # Follow LEMAS-TTS layout but resolve from MODELS_PATH (./pretrained_models by default),
+        # so that only the main TTS checkpoints can live in hf:// mounts while all
+        # auxiliary models (UVR5, vocoder, prosody encoder, etc.) are loaded from
+        # the local `pretrained_models` folder.
+        from pathlib import Path
+        uv_root = Path(MODELS_PATH) / "uvr5" / "models" / "MDX_Net_Models" / "model_data"
+        denoise_model = UVR5(str(uv_root))
     elif denoise_model_name == "DeepFilterNet":
         denoise_model = DeepFilterNet("./pretrained_models/denoiser_model.onnx")
 
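Putting the pieces above together, the local layout this commit appears to expect under MODELS_PATH (./pretrained_models by default) is roughly:

    pretrained_models/
    ├── whisperx-vad-segmentation.bin
    ├── denoiser_model.onnx
    └── uvr5/
        └── models/
            └── MDX_Net_Models/
                └── model_data/
                    ├── Kim_Vocal_1.onnx
                    └── MDX-Net-Kim-Vocal1.json

The demo and tmp directories referenced by the argparse defaults in the last hunk sit alongside these.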
 
@@ -1177,10 +1202,10 @@ def get_app():
 if __name__ == "__main__":
     import argparse
 
-    parser = argparse.ArgumentParser(description="VoiceCraft gradio app.")
+    parser = argparse.ArgumentParser(description="LEMAS-Edit gradio app.")
 
-    parser.add_argument("--demo-path", default="./demo", help="Path to demo directory")
-    parser.add_argument("--tmp-path", default="./pretrained_models/demo/tmp", help="Path to tmp directory")
+    parser.add_argument("--demo-path", default="./pretrained_models/demo", help="Path to demo directory")
+    parser.add_argument("--tmp-path", default="./pretrained_models/tmp", help="Path to tmp directory")
     parser.add_argument("--port", default=41020, type=int, help="App port")
     parser.add_argument("--share", action="store_true", help="Launch with public url")
     parser.add_argument("--server_name", default="0.0.0.0", type=str, help="Server name for launching the app. 127.0.0.1 for localhost; 0.0.0.0 to allow access from other machines in the local network. Might also give access to external users depends on the firewall settings.")
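With these defaults, running the script without path arguments now resolves demo assets from ./pretrained_models/demo and writes temporary files to ./pretrained_models/tmp instead of the old ./demo and ./pretrained_models/demo/tmp locations; both can still be overridden via --demo-path and --tmp-path.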