Approximetal committed on
Commit
694583e
·
verified ·
1 Parent(s): bd908ee

Update gradio_mix.py

Browse files
Files changed (1) hide show
  1. gradio_mix.py +29 -19
gradio_mix.py CHANGED
@@ -51,17 +51,24 @@ MODELS_PATH = os.getenv("MODELS_PATH", "./pretrained_models")
51
  # Mirrors LEMAS-TTS `inference_gradio.py`.
52
  HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/pretrained_models"
53
 
 
 
 
54
  # Pick device for the TTS editing model.
55
- # - Default: "cuda" if available, else "cpu"
56
- # - You can override via LEMAS_DEVICE env (e.g. "cpu" or "cuda").
 
 
57
  def _pick_device():
 
 
58
  forced = os.getenv("LEMAS_DEVICE")
59
  if forced:
60
  return forced
61
  return "cuda" if torch.cuda.is_available() else "cpu"
62
 
63
  device = _pick_device()
64
- ASR_DEVICE = "cuda" # force whisperx/pyannote to CPU to avoid cuDNN issues
65
  whisper_model, align_model = None, None
66
  tts_edit_model = None
67
 
@@ -94,32 +101,39 @@ class UVR5:
94
  self.model = self.load_model(model_dir, code_dir)
95
 
96
  def load_model(self, model_dir, code_dir):
97
- import sys, json, os
98
  if code_dir not in sys.path:
99
  sys.path.append(code_dir)
100
  from multiprocess_cuda_infer import ModelData, Inference
101
  # In the minimal LEMAS-TTS layout, UVR5 weights live under:
102
- # <pretrained_models>/uvr5/models/MDX_Net_Models/model_data/
103
- # Here `model_dir` points to that `model_data` directory.
104
  model_path = os.path.join(model_dir, "Kim_Vocal_1.onnx")
105
  config_path = os.path.join(model_dir, "MDX-Net-Kim-Vocal1.json")
106
  with open(config_path, "r", encoding="utf-8") as f:
107
  configs = json.load(f)
108
  model_data = ModelData(
109
  model_path=model_path,
110
- audio_path = model_dir,
111
- result_path = model_dir,
112
- device = _pick_device(),
113
- process_method = "MDX-Net",
114
  # Keep base_dir and model_dir the same so all UVR5 metadata
115
  # (model_data.json, model_name_mapper.json, etc.) are resolved
116
  # under `pretrained_models/uvr5`, matching LEMAS-TTS inference.
117
  base_dir=model_dir,
118
- **configs
119
  )
120
 
121
- uvr5_model = Inference(model_data, _pick_device())
122
- uvr5_model.load_model(model_path, 1)
 
 
 
 
 
 
 
 
 
123
  return uvr5_model
124
 
125
  def denoise(self, audio_info):
@@ -322,11 +336,9 @@ class TextNorm():
322
  def chunk_text(text, max_chars=135):
323
  """
324
  Splits the input text into chunks, each with a maximum number of characters.
325
-
326
  Args:
327
  text (str): The text to be split.
328
  max_chars (int): The maximum number of characters per chunk.
329
-
330
  Returns:
331
  List[str]: A list of text chunks.
332
  """
@@ -688,9 +700,7 @@ def align(transcript, audio_info, state):
688
  state
689
  ]
690
 
691
- @spaces.GPU
692
- @torch.no_grad()
693
- @torch.inference_mode()
694
  def denoise(audio_info):
695
  denoised_audio, sr = denoise_model.denoise(audio_info)
696
  denoised_audio = denoised_audio # .squeeze().numpy()
@@ -1249,4 +1259,4 @@ if __name__ == "__main__":
1249
  MODELS_PATH = args.models_path
1250
 
1251
  app = get_app()
1252
- app.queue().launch(share=args.share, server_name=args.server_name, server_port=args.port)
 
51
  # Mirrors LEMAS-TTS `inference_gradio.py`.
52
  HF_PRETRAINED_ROOT = "hf://LEMAS-Project/LEMAS-TTS/pretrained_models"
53
 
54
+ # Detect whether we are running inside a HF Space with stateless GPU.
55
+ IS_SPACES = os.getenv("SYSTEM") == "spaces"
56
+
57
  # Pick device for the TTS editing model.
58
+ # - On Spaces (SYSTEM=spaces): always use CPU in the main process to respect
59
+ # stateless GPU constraints.
60
+ # - Elsewhere: "cuda" if available, else "cpu", unless overridden via
61
+ # LEMAS_DEVICE env (e.g. "cpu" or "cuda").
62
  def _pick_device():
63
+ if IS_SPACES:
64
+ return "cpu"
65
  forced = os.getenv("LEMAS_DEVICE")
66
  if forced:
67
  return forced
68
  return "cuda" if torch.cuda.is_available() else "cpu"
69
 
70
  device = _pick_device()
71
+ ASR_DEVICE = "cpu" # force whisperx/pyannote to CPU to avoid cuDNN issues
72
  whisper_model, align_model = None, None
73
  tts_edit_model = None
74
 
 
101
  self.model = self.load_model(model_dir, code_dir)
102
 
103
  def load_model(self, model_dir, code_dir):
104
+ import sys, json, os, torch
105
  if code_dir not in sys.path:
106
  sys.path.append(code_dir)
107
  from multiprocess_cuda_infer import ModelData, Inference
108
  # In the minimal LEMAS-TTS layout, UVR5 weights live directly under `model_dir`:
 
 
109
  model_path = os.path.join(model_dir, "Kim_Vocal_1.onnx")
110
  config_path = os.path.join(model_dir, "MDX-Net-Kim-Vocal1.json")
111
  with open(config_path, "r", encoding="utf-8") as f:
112
  configs = json.load(f)
113
  model_data = ModelData(
114
  model_path=model_path,
115
+ audio_path=model_dir,
116
+ result_path=model_dir,
117
+ device="cpu",
118
+ process_method="MDX-Net",
119
  # Keep base_dir and model_dir the same so all UVR5 metadata
120
  # (model_data.json, model_name_mapper.json, etc.) are resolved
121
  # under `pretrained_models/uvr5`, matching LEMAS-TTS inference.
122
  base_dir=model_dir,
123
+ **configs,
124
  )
125
 
126
+ uvr5_model = Inference(model_data, "cpu")
127
+ # On HF Spaces with stateless GPU, we must not initialize CUDA in the
128
+ # main process. UVR5's internal `load_model` checks `torch.cuda.is_available()`
129
+ # and may touch CUDA APIs. Temporarily spoof this to force CPU-only
130
+ # providers during UVR5 init.
131
+ orig_is_available = torch.cuda.is_available
132
+ torch.cuda.is_available = lambda: False
133
+ try:
134
+ uvr5_model.load_model(model_path, 1)
135
+ finally:
136
+ torch.cuda.is_available = orig_is_available
137
  return uvr5_model
138
 
139
  def denoise(self, audio_info):
 
336
  def chunk_text(text, max_chars=135):
337
  """
338
  Splits the input text into chunks, each with a maximum number of characters.
 
339
  Args:
340
  text (str): The text to be split.
341
  max_chars (int): The maximum number of characters per chunk.
 
342
  Returns:
343
  List[str]: A list of text chunks.
344
  """
 
700
  state
701
  ]
702
 
703
+
 
 
704
  def denoise(audio_info):
705
  denoised_audio, sr = denoise_model.denoise(audio_info)
706
  denoised_audio = denoised_audio # .squeeze().numpy()
 
1259
  MODELS_PATH = args.models_path
1260
 
1261
  app = get_app()
1262
+ app.queue().launch(share=args.share, server_name=args.server_name, server_port=args.port)