ASesYusuf1 committed on
Commit
5db82ef
·
verified ·
1 Parent(s): deb8a70

Update audio_separator/separator/separator.py

Browse files
audio_separator/separator/separator.py CHANGED
@@ -1,27 +1,3 @@
1
- """ This file contains the Separator class, to facilitate the separation of stems from audio. """
2
-
3
- from importlib import metadata, resources
4
- import os
5
- import sys
6
- import platform
7
- import subprocess
8
- import time
9
- import logging
10
- import warnings
11
- import importlib
12
- import io
13
- from typing import Optional
14
-
15
- import hashlib
16
- import json
17
- import yaml
18
- import requests
19
- import torch
20
- import torch.amp.autocast_mode as autocast_mode
21
- import onnxruntime as ort
22
- from tqdm import tqdm
23
-
24
-
25
  import os
26
  import logging
27
  import requests
@@ -44,7 +20,7 @@ class Separator:
44
  """
45
  Optimized Separator class for audio source separation on Hugging Face Zero GPU.
46
  Supports MDX, VR, Demucs, and MDXC architectures with ONNX Runtime and PyTorch.
47
- Optimized for memory efficiency, fast inference, and serverless environments.
48
  """
49
  def __init__(
50
  self,
@@ -167,52 +143,41 @@ class Separator:
167
  raise RuntimeError(f"Failed to download {url}: {response.status_code}")
168
 
169
  def list_supported_model_files(self):
170
- """Fetch supported model files from predefined sources."""
171
  download_checks_url = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
172
  download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
173
  self.download_file_if_not_exists(download_checks_url, download_checks_path)
174
  model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
175
 
176
- # Mock model scores for simplicity (replace with actual model-scores.json if available)
177
- model_scores = {
178
- "UVR-MDX-NET-Inst_full_292.onnx": {
179
- "median_scores": {
180
- "vocals": {"SDR": 10.6497, "SIR": 20.3786, "SAR": 10.692, "ISR": 14.848},
181
- "instrumental": {"SDR": 15.2149, "SIR": 25.6075, "SAR": 17.1363, "ISR": 17.7893}
182
- },
183
  "stems": ["vocals", "instrumental"],
184
- "target_stem": "vocals"
185
- },
186
- "htdemucs_ft.yaml": {
187
- "median_scores": {
188
- "vocals": {"SDR": 11.2685, "SIR": 21.257, "SAR": 11.0359, "ISR": 19.3753},
189
- "drums": {"SDR": 13.235, "SIR": 23.3053, "SAR": 13.0313, "ISR": 17.2889},
190
- "bass": {"SDR": 9.72743, "SIR": 19.5435, "SAR": 9.20801, "ISR": 13.5037}
191
- },
192
- "stems": ["vocals", "drums", "bass"],
193
- "target_stem": "vocals"
194
  },
195
- "MDX23C-8KFFT-InstVoc_HQ.ckpt": {
196
- "median_scores": {
197
- "vocals": {"SDR": 11.9504, "SIR": 23.1166, "SAR": 12.093, "ISR": 15.4782},
198
- "instrumental": {"SDR": 16.3035, "SIR": 26.6161, "SAR": 18.5167, "ISR": 18.3939}
199
- },
200
  "stems": ["vocals", "instrumental"],
201
- "target_stem": "vocals"
202
- }
 
 
203
  }
204
 
205
  public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
206
  audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
207
 
208
- # Simplified model list for MDX, VR, Demucs, MDXC
209
  model_files_grouped_by_type = {
210
  "MDX": {
211
  "MDX-Net Model: UVR-MDX-NET-Inst_full_292": {
212
  "filename": "UVR-MDX-NET-Inst_full_292.onnx",
213
- "scores": model_scores.get("UVR-MDX-NET-Inst_full_292.onnx", {}).get("median_scores", {}),
214
- "stems": model_scores.get("UVR-MDX-NET-Inst_full_292.onnx", {}).get("stems", []),
215
- "target_stem": model_scores.get("UVR-MDX-NET-Inst_full_292.onnx", {}).get("target_stem"),
216
  "download_files": ["UVR-MDX-NET-Inst_full_292.onnx"]
217
  }
218
  },
@@ -228,27 +193,16 @@ class Separator:
228
  "Demucs": {
229
  "Demucs v4: htdemucs_ft": {
230
  "filename": "htdemucs_ft.yaml",
231
- "scores": model_scores.get("htdemucs_ft.yaml", {}).get("median_scores", {}),
232
- "stems": model_scores.get("htdemucs_ft.yaml", {}).get("stems", []),
233
- "target_stem": model_scores.get("htdemucs_ft.yaml", {}).get("target_stem"),
234
  "download_files": [
235
  f"{public_model_repo_url_prefix}/htdemucs_ft.yaml",
236
  "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th"
237
  ]
238
  }
239
  },
240
- "MDXC": {
241
- "MDX23C Model: MDX23C-InstVoc HQ": {
242
- "filename": "MDX23C-8KFFT-InstVoc_HQ.ckpt",
243
- "scores": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("median_scores", {}),
244
- "stems": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("stems", []),
245
- "target_stem": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("target_stem"),
246
- "download_files": [
247
- "MDX23C-8KFFT-InstVoc_HQ.ckpt",
248
- f"{audio_separator_models_repo_url_prefix}/model_2_stem_full_band_8k.yaml"
249
- ]
250
- }
251
- }
252
  }
253
  return model_files_grouped_by_type
254
 
@@ -289,6 +243,7 @@ class Separator:
289
  if file_to_download.endswith(".yaml"):
290
  yaml_config_filename = file_to_download
291
  return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
 
292
  raise ValueError(f"Model file {model_filename} not found")
293
 
294
  def load_model_data_from_yaml(self, yaml_config_filename):
@@ -321,6 +276,7 @@ class Separator:
321
  if model_hash in model_data:
322
  self.logger.debug(f"Model data loaded for hash {model_hash}")
323
  return model_data[model_hash]
 
324
  raise ValueError(f"No model data for hash {model_hash}")
325
 
326
  def load_model(self, model_filename="UVR-MDX-NET-Inst_full_292.onnx"):
@@ -328,9 +284,14 @@ class Separator:
328
  self.logger.info(f"Loading model {model_filename}")
329
  start_time = time.perf_counter()
330
 
331
- model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
332
- model_name = model_filename.split(".")[0]
 
 
 
 
333
 
 
334
  model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
335
 
336
  common_params = {
@@ -461,9 +422,13 @@ class Separator:
461
  def download_model_and_data(self, model_filename):
462
  """Download model files without loading into memory."""
463
  self.logger.info(f"Downloading model {model_filename}")
464
- model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
465
- model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
466
- self.logger.info(f"Model downloaded: {model_friendly_name}, type: {model_type}, path: {model_path}, data items: {len(model_data)}")
 
 
 
 
467
 
468
  def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
469
  """Return a simplified list of models."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import logging
3
  import requests
 
20
  """
21
  Optimized Separator class for audio source separation on Hugging Face Zero GPU.
22
  Supports MDX, VR, Demucs, and MDXC architectures with ONNX Runtime and PyTorch.
23
+ Handles MelBand Roformer models and ensures robust model downloading.
24
  """
25
  def __init__(
26
  self,
 
143
  raise RuntimeError(f"Failed to download {url}: {response.status_code}")
144
 
145
  def list_supported_model_files(self):
146
+ """Fetch supported model files, including MelBand Roformer models."""
147
  download_checks_url = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
148
  download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
149
  self.download_file_if_not_exists(download_checks_url, download_checks_path)
150
  model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
151
 
152
+ # Custom model list from the Gradio log
153
+ roformer_models = {
154
+ "MelBand Roformer Kim | Inst V1 (E) Plus by Unwa": {
155
+ "filename": "melband_roformer_inst_v1e_plus.ckpt",
156
+ "scores": {"vocals": {"SDR": 11.95}, "instrumental": {"SDR": 16.30}},
 
 
157
  "stems": ["vocals", "instrumental"],
158
+ "target_stem": "vocals",
159
+ "download_files": ["melband_roformer_inst_v1e_plus.ckpt", "model_2_stem_full_band_8k.yaml"]
 
 
 
 
 
 
 
 
160
  },
161
+ "MelBand Roformer Kim | Inst V1 Plus by Unwa": {
162
+ "filename": "melband_roformer_inst_v1_plus.ckpt",
163
+ "scores": {"vocals": {"SDR": 11.80}, "instrumental": {"SDR": 16.20}},
 
 
164
  "stems": ["vocals", "instrumental"],
165
+ "target_stem": "vocals",
166
+ "download_files": ["melband_roformer_inst_v1_plus.ckpt", "model_2_stem_full_band_8k.yaml"]
167
+ },
168
+ # Add other models from the log as needed
169
  }
170
 
171
  public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
172
  audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
173
 
 
174
  model_files_grouped_by_type = {
175
  "MDX": {
176
  "MDX-Net Model: UVR-MDX-NET-Inst_full_292": {
177
  "filename": "UVR-MDX-NET-Inst_full_292.onnx",
178
+ "scores": {"vocals": {"SDR": 10.6497}, "instrumental": {"SDR": 15.2149}},
179
+ "stems": ["vocals", "instrumental"],
180
+ "target_stem": "vocals",
181
  "download_files": ["UVR-MDX-NET-Inst_full_292.onnx"]
182
  }
183
  },
 
193
  "Demucs": {
194
  "Demucs v4: htdemucs_ft": {
195
  "filename": "htdemucs_ft.yaml",
196
+ "scores": {"vocals": {"SDR": 11.2685}, "drums": {"SDR": 13.235}, "bass": {"SDR": 9.72743}},
197
+ "stems": ["vocals", "drums", "bass"],
198
+ "target_stem": "vocals",
199
  "download_files": [
200
  f"{public_model_repo_url_prefix}/htdemucs_ft.yaml",
201
  "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th"
202
  ]
203
  }
204
  },
205
+ "MDXC": roformer_models
 
 
 
 
 
 
 
 
 
 
 
206
  }
207
  return model_files_grouped_by_type
208
 
 
243
  if file_to_download.endswith(".yaml"):
244
  yaml_config_filename = file_to_download
245
  return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
246
+ self.logger.error(f"Model {model_filename} not found in supported models")
247
  raise ValueError(f"Model file {model_filename} not found")
248
 
249
  def load_model_data_from_yaml(self, yaml_config_filename):
 
276
  if model_hash in model_data:
277
  self.logger.debug(f"Model data loaded for hash {model_hash}")
278
  return model_data[model_hash]
279
+ self.logger.error(f"No model data for hash {model_hash}")
280
  raise ValueError(f"No model data for hash {model_hash}")
281
 
282
  def load_model(self, model_filename="UVR-MDX-NET-Inst_full_292.onnx"):
 
284
  self.logger.info(f"Loading model {model_filename}")
285
  start_time = time.perf_counter()
286
 
287
+ try:
288
+ model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
289
+ except ValueError as e:
290
+ self.logger.error(f"Failed to load model: {e}")
291
+ self.logger.info("Falling back to default model: UVR-MDX-NET-Inst_full_292.onnx")
292
+ model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files("UVR-MDX-NET-Inst_full_292.onnx")
293
 
294
+ model_name = model_filename.split(".")[0]
295
  model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
296
 
297
  common_params = {
 
422
  def download_model_and_data(self, model_filename):
423
  """Download model files without loading into memory."""
424
  self.logger.info(f"Downloading model {model_filename}")
425
+ try:
426
+ model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
427
+ model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
428
+ self.logger.info(f"Model downloaded: {model_friendly_name}, type: {model_type}, path: {model_path}, data items: {len(model_data)}")
429
+ except ValueError as e:
430
+ self.logger.error(f"Failed to download model: {e}")
431
+ raise
432
 
433
  def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
434
  """Return a simplified list of models."""