Approximetal committed on
Commit
66c9260
·
verified ·
1 Parent(s): 694583e

Update gradio_mix.py

Browse files
Files changed (1) hide show
  1. gradio_mix.py +99 -29
gradio_mix.py CHANGED
@@ -68,7 +68,15 @@ def _pick_device():
68
  return "cuda" if torch.cuda.is_available() else "cpu"
69
 
70
  device = _pick_device()
71
- ASR_DEVICE = "cpu" # force whisperx/pyannote to CPU to avoid cuDNN issues
 
 
 
 
 
 
 
 
72
  whisper_model, align_model = None, None
73
  tts_edit_model = None
74
 
@@ -97,51 +105,68 @@ class UVR5:
97
 
98
  def __init__(self, model_dir):
99
  # Code directory is always the local `uvr5` folder in this repo
100
- code_dir = os.path.join(os.path.dirname(__file__), "uvr5")
101
- self.model = self.load_model(model_dir, code_dir)
 
 
102
 
103
- def load_model(self, model_dir, code_dir):
104
  import sys, json, os, torch
105
- if code_dir not in sys.path:
106
- sys.path.append(code_dir)
 
 
 
 
 
107
  from multiprocess_cuda_infer import ModelData, Inference
108
  # In the minimal LEMAS-TTS layout, UVR5 weights live under:
109
- model_path = os.path.join(model_dir, "Kim_Vocal_1.onnx")
110
- config_path = os.path.join(model_dir, "MDX-Net-Kim-Vocal1.json")
111
  with open(config_path, "r", encoding="utf-8") as f:
112
  configs = json.load(f)
113
  model_data = ModelData(
114
  model_path=model_path,
115
- audio_path=model_dir,
116
- result_path=model_dir,
117
- device="cpu",
118
  process_method="MDX-Net",
119
  # Keep base_dir and model_dir the same so all UVR5 metadata
120
  # (model_data.json, model_name_mapper.json, etc.) are resolved
121
  # under `pretrained_models/uvr5`, matching LEMAS-TTS inference.
122
- base_dir=model_dir,
123
  **configs,
124
  )
125
 
126
- uvr5_model = Inference(model_data, "cpu")
127
  # On HF Spaces with stateless GPU, we must not initialize CUDA in the
128
- # main process. UVR5's internal `load_model` checks `torch.cuda.is_available()`
129
- # and may touch CUDA APIs. Temporarily spoof this to force CPU-only
130
- # providers during UVR5 init.
131
- orig_is_available = torch.cuda.is_available
132
- torch.cuda.is_available = lambda: False
133
- try:
 
 
 
 
 
134
  uvr5_model.load_model(model_path, 1)
135
- finally:
136
- torch.cuda.is_available = orig_is_available
137
- return uvr5_model
 
138
 
139
  def denoise(self, audio_info):
 
 
 
 
140
  input_audio = load_wav(audio_info, sr=44100, channel=2)
141
- output_audio = self.model.demix_base({0:input_audio.squeeze()}, is_match_mix=False)
142
  # transform = torchaudio.transforms.Resample(44100, 16000)
143
  # output_audio = transform(output_audio)
144
- return output_audio.squeeze().T.numpy(), 44100
145
 
146
 
147
  class DeepFilterNet:
@@ -424,14 +449,31 @@ class MMSAlignModel:
424
 
425
  class WhisperxModel:
426
  def __init__(self, model_name):
 
 
 
 
 
 
 
 
427
  from whisperx import load_model
 
428
  prompt = None # "This might be a blend of Simplified Chinese and English speech, do not translate, only transcription be allowed."
429
 
 
 
 
 
 
 
 
 
430
  # Use the lighter Silero VAD backend to avoid pyannote checkpoints
431
  # and their PyTorch 2.6 `weights_only` pickling issues.
432
  self.model = load_model(
433
- model_name,
434
- ASR_DEVICE,
435
  compute_type="float32",
436
  asr_options={
437
  "suppress_numerals": False,
@@ -447,6 +489,9 @@ class WhisperxModel:
447
  )
448
 
449
  def transcribe(self, audio_info, lang=None):
 
 
 
450
  audio = load_wav(audio_info).numpy()
451
  if lang is None:
452
  lang = self.model.detect_language(audio)
@@ -541,7 +586,8 @@ def get_audio_slice(audio, words_info, start_time, end_time, max_len=10, sr=1600
541
  def load_models(lemas_model_name, whisper_model_name, alignment_model_name, denoise_model_name): # , audiosr_name):
542
 
543
  global transcribe_model, align_model, denoise_model, text_norm, tts_edit_model
544
- torch.cuda.empty_cache()
 
545
  gc.collect()
546
 
547
  if denoise_model_name == "UVR5":
@@ -701,9 +747,16 @@ def align(transcript, audio_info, state):
701
  ]
702
 
703
 
 
 
 
704
  def denoise(audio_info):
 
 
 
 
705
  denoised_audio, sr = denoise_model.denoise(audio_info)
706
- denoised_audio = denoised_audio # .squeeze().numpy()
707
  return (sr, denoised_audio)
708
 
709
  def cancel_denoise(audio_info):
@@ -742,6 +795,7 @@ def replace_numbers_with_words(sentence, lang="en"):
742
  return num # In case num2words fails (unlikely with digits but just to be safe)
743
  return re.sub(r'\b\d+\b', replace_with_words, sentence) # Regular expression that matches numbers
744
 
 
745
  @spaces.GPU
746
  @torch.no_grad()
747
  @torch.inference_mode()
@@ -754,6 +808,22 @@ def run(seed, nfe_step, speed, cfg_strength, sway_sampling_coef, ref_ratio,
754
  if smart_transcript and (transcribe_state is None):
755
  raise gr.Error("Can't use smart transcript: whisper transcript not found")
756
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
757
  # if mode == "Rerun":
758
  # colon_position = selected_sentence.find(':')
759
  # selected_sentence_idx = int(selected_sentence[:colon_position])
@@ -1259,4 +1329,4 @@ if __name__ == "__main__":
1259
  MODELS_PATH = args.models_path
1260
 
1261
  app = get_app()
1262
- app.queue().launch(share=args.share, server_name=args.server_name, server_port=args.port)
 
68
  return "cuda" if torch.cuda.is_available() else "cpu"
69
 
70
  device = _pick_device()
71
+ # For WhisperX ASR:
72
+ # - On Spaces we always construct the pipeline lazily inside @spaces.GPU
73
+ # functions, so keep the default "cpu" here to avoid touching CUDA in
74
+ # the main process.
75
+ # - Elsewhere prefer CUDA if available.
76
+ if IS_SPACES:
77
+ ASR_DEVICE = "cpu"
78
+ else:
79
+ ASR_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
80
  whisper_model, align_model = None, None
81
  tts_edit_model = None
82
 
 
105
 
106
  def __init__(self, model_dir):
107
  # Code directory is always the local `uvr5` folder in this repo
108
+ self.code_dir = os.path.join(os.path.dirname(__file__), "uvr5")
109
+ self.model_dir = model_dir
110
+ self.model = None
111
+ self.device = "cpu"
112
 
113
+ def load_model(self, device="cpu"):
114
  import sys, json, os, torch
115
+ if self.code_dir not in sys.path:
116
+ sys.path.append(self.code_dir)
117
+
118
+ # Reuse an already-loaded model if it matches the requested device.
119
+ if self.model is not None and self.device == device:
120
+ return self.model
121
+
122
  from multiprocess_cuda_infer import ModelData, Inference
123
  # In the minimal LEMAS-TTS layout, UVR5 weights live under:
124
+ model_path = os.path.join(self.model_dir, "Kim_Vocal_1.onnx")
125
+ config_path = os.path.join(self.model_dir, "MDX-Net-Kim-Vocal1.json")
126
  with open(config_path, "r", encoding="utf-8") as f:
127
  configs = json.load(f)
128
  model_data = ModelData(
129
  model_path=model_path,
130
+ audio_path=self.model_dir,
131
+ result_path=self.model_dir,
132
+ device=device,
133
  process_method="MDX-Net",
134
  # Keep base_dir and model_dir the same so all UVR5 metadata
135
  # (model_data.json, model_name_mapper.json, etc.) are resolved
136
  # under `pretrained_models/uvr5`, matching LEMAS-TTS inference.
137
+ base_dir=self.model_dir,
138
  **configs,
139
  )
140
 
141
+ uvr5_model = Inference(model_data, device)
142
  # On HF Spaces with stateless GPU, we must not initialize CUDA in the
143
+ # main process. The heavy UVR5 loading happens lazily inside
144
+ # @spaces.GPU functions; this guard is kept only for the CPU path to
145
+ # avoid any accidental CUDA init.
146
+ if IS_SPACES and device == "cpu":
147
+ orig_is_available = torch.cuda.is_available
148
+ torch.cuda.is_available = lambda: False
149
+ try:
150
+ uvr5_model.load_model(model_path, 1)
151
+ finally:
152
+ torch.cuda.is_available = orig_is_available
153
+ else:
154
  uvr5_model.load_model(model_path, 1)
155
+
156
+ self.model = uvr5_model
157
+ self.device = device
158
+ return self.model
159
 
160
  def denoise(self, audio_info):
161
+ # Prefer GPU if available; on Spaces this runs inside @spaces.GPU so
162
+ # CUDA can be safely initialized here.
163
+ device = "cuda" if torch.cuda.is_available() else "cpu"
164
+ model = self.load_model(device=device)
165
  input_audio = load_wav(audio_info, sr=44100, channel=2)
166
+ output_audio = model.demix_base({0:input_audio.squeeze()}, is_match_mix=False, device=device)
167
  # transform = torchaudio.transforms.Resample(44100, 16000)
168
  # output_audio = transform(output_audio)
169
+ return output_audio.squeeze().T.cpu().numpy(), 44100
170
 
171
 
172
  class DeepFilterNet:
 
449
 
450
  class WhisperxModel:
451
  def __init__(self, model_name):
452
+ # Lazily construct the WhisperX pipeline so that on Spaces we only
453
+ # touch CUDA inside @spaces.GPU workers.
454
+ self.model_name = model_name
455
+ self.model = None
456
+
457
+ def _ensure_model(self):
458
+ if self.model is not None:
459
+ return
460
  from whisperx import load_model
461
+
462
  prompt = None # "This might be a blend of Simplified Chinese and English speech, do not translate, only transcription be allowed."
463
 
464
+ # On Spaces, this will be called from within @spaces.GPU so we can
465
+ # safely move the ASR to CUDA if available. Locally we respect the
466
+ # ASR_DEVICE hint.
467
+ if IS_SPACES:
468
+ asr_device = "cuda" if torch.cuda.is_available() else "cpu"
469
+ else:
470
+ asr_device = ASR_DEVICE
471
+
472
  # Use the lighter Silero VAD backend to avoid pyannote checkpoints
473
  # and their PyTorch 2.6 `weights_only` pickling issues.
474
  self.model = load_model(
475
+ self.model_name,
476
+ asr_device,
477
  compute_type="float32",
478
  asr_options={
479
  "suppress_numerals": False,
 
489
  )
490
 
491
  def transcribe(self, audio_info, lang=None):
492
+ # Lazily init the underlying WhisperX pipeline.
493
+ self._ensure_model()
494
+
495
  audio = load_wav(audio_info).numpy()
496
  if lang is None:
497
  lang = self.model.detect_language(audio)
 
586
  def load_models(lemas_model_name, whisper_model_name, alignment_model_name, denoise_model_name): # , audiosr_name):
587
 
588
  global transcribe_model, align_model, denoise_model, text_norm, tts_edit_model
589
+ if not IS_SPACES:
590
+ torch.cuda.empty_cache()
591
  gc.collect()
592
 
593
  if denoise_model_name == "UVR5":
 
747
  ]
748
 
749
 
750
+ @spaces.GPU
751
+ @torch.no_grad()
752
+ @torch.inference_mode()
753
  def denoise(audio_info):
754
+ # Denoiser can be relatively heavy (especially UVR5), so schedule it on
755
+ # GPU workers when running on HF Spaces.
756
+ if denoise_model is None:
757
+ return audio_info
758
  denoised_audio, sr = denoise_model.denoise(audio_info)
759
+ denoised_audio = denoised_audio # already numpy
760
  return (sr, denoised_audio)
761
 
762
  def cancel_denoise(audio_info):
 
795
  return num # In case num2words fails (unlikely with digits but just to be safe)
796
  return re.sub(r'\b\d+\b', replace_with_words, sentence) # Regular expression that matches numbers
797
 
798
+
799
  @spaces.GPU
800
  @torch.no_grad()
801
  @torch.inference_mode()
 
808
  if smart_transcript and (transcribe_state is None):
809
  raise gr.Error("Can't use smart transcript: whisper transcript not found")
810
 
811
+ # On HF Spaces, keep CUDA usage inside this GPU worker: move the edit
812
+ # model and vocoder to GPU here (the weights were loaded on CPU).
813
+ if IS_SPACES and torch.cuda.is_available():
814
+ try:
815
+ if getattr(tts_edit_model, "device", "cpu") != "cuda":
816
+ if hasattr(tts_edit_model, "ema_model"):
817
+ tts_edit_model.ema_model.to("cuda")
818
+ if hasattr(tts_edit_model, "vocoder"):
819
+ try:
820
+ tts_edit_model.vocoder.to("cuda")
821
+ except Exception:
822
+ pass
823
+ tts_edit_model.device = "cuda"
824
+ except Exception as e:
825
+ logging.warning("Failed to move LEMAS-TTS model to CUDA: %s", e)
826
+
827
  # if mode == "Rerun":
828
  # colon_position = selected_sentence.find(':')
829
  # selected_sentence_idx = int(selected_sentence[:colon_position])
 
1329
  MODELS_PATH = args.models_path
1330
 
1331
  app = get_app()
1332
+ app.queue().launch(share=args.share, server_name=args.server_name, server_port=args.port)