ASesYusuf1 commited on
Commit
6c32ddc
·
verified ·
1 Parent(s): 5db82ef

Update audio_separator/separator/separator.py

Browse files
Files changed (1) hide show
  1. audio_separator/separator/separator.py +758 -277
audio_separator/separator/separator.py CHANGED
@@ -1,32 +1,90 @@
 
 
 
1
  import os
 
 
 
 
2
  import logging
 
 
 
 
 
 
 
 
3
  import requests
4
  import torch
5
  import torch.amp.autocast_mode as autocast_mode
6
  import onnxruntime as ort
7
- import numpy as np
8
- import soundfile as sf
9
- import json
10
- import yaml
11
- import importlib
12
- import hashlib
13
- import time
14
- from typing import Optional
15
- from io import BytesIO
16
  from tqdm import tqdm
17
  import spaces
18
 
 
19
  class Separator:
20
  """
21
- Optimized Separator class for audio source separation on Hugging Face Zero GPU.
22
- Supports MDX, VR, Demucs, and MDXC architectures with ONNX Runtime and PyTorch.
23
- Handles MelBand Roformer models and ensures robust model downloading.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  """
 
25
  def __init__(
26
  self,
27
  log_level=logging.INFO,
 
28
  model_file_dir="/tmp/audio-separator-models/",
29
- output_dir="/tmp/audio_output/",
30
  output_format="WAV",
31
  output_bitrate=None,
32
  normalization_threshold=0.9,
@@ -34,8 +92,8 @@ class Separator:
34
  output_single_stem=None,
35
  invert_using_spec=False,
36
  sample_rate=44100,
37
- use_soundfile=True,
38
- use_autocast=True,
39
  use_directml=False,
40
  mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
41
  vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
@@ -43,256 +101,639 @@ class Separator:
43
  mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
44
  info_only=False,
45
  ):
46
- """Initialize the separator for Zero GPU."""
47
  self.logger = logging.getLogger(__name__)
48
  self.logger.setLevel(log_level)
49
- handler = logging.StreamHandler()
50
- handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
 
 
 
 
 
 
 
 
51
  if not self.logger.hasHandlers():
52
- self.logger.addHandler(handler)
53
 
54
- # Configuration
55
- self.model_file_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR", model_file_dir)
56
- self.output_dir = output_dir or os.getcwd()
57
- self.output_format = output_format.upper()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  self.output_bitrate = output_bitrate
 
 
 
 
59
  self.normalization_threshold = normalization_threshold
 
 
 
60
  self.amplification_threshold = amplification_threshold
 
 
 
61
  self.output_single_stem = output_single_stem
 
 
 
62
  self.invert_using_spec = invert_using_spec
63
- self.sample_rate = int(sample_rate)
 
 
 
 
 
 
 
 
 
 
 
64
  self.use_soundfile = use_soundfile
65
  self.use_autocast = use_autocast
66
  self.use_directml = use_directml
67
- self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
68
 
69
- # Validation
70
- if not (0 < normalization_threshold <= 1):
71
- raise ValueError("normalization_threshold must be in (0, 1]")
72
- if not (0 <= amplification_threshold <= 1):
73
- raise ValueError("amplification_threshold must be in [0, 1]")
74
- if self.sample_rate <= 0 or self.sample_rate > 12800000:
75
- raise ValueError("sample_rate must be a positive integer <= 12800000")
76
 
77
- # Create directories
78
- os.makedirs(self.model_file_dir, exist_ok=True)
79
- os.makedirs(self.output_dir, exist_ok=True)
80
 
81
- # Setup device
82
- self.torch_device_cpu = torch.device("cpu")
83
- self.torch_device_mps = torch.device("mps") if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else None
84
- self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
- self.onnx_execution_provider = ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
86
-
87
- if self.use_directml:
88
- try:
89
- import torch_directml
90
- if torch_directml.is_available():
91
- self.torch_device = torch_directml.device()
92
- self.onnx_execution_provider = ["DmlExecutionProvider"] if "DmlExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
93
- except ImportError:
94
- self.logger.warning("torch_directml not installed, falling back to CPU")
95
- self.torch_device = self.torch_device_cpu
96
-
97
- self.logger.info(f"Using device: {self.torch_device}, ONNX provider: {self.onnx_execution_provider}")
98
  self.model_instance = None
 
99
  self.model_is_uvr_vip = False
100
  self.model_friendly_name = None
101
 
102
  if not info_only:
103
- self.logger.info(f"Initialized Separator with model_dir: {self.model_file_dir}, output_dir: {self.output_dir}")
104
 
105
  def setup_accelerated_inferencing_device(self):
106
- """Configure hardware acceleration."""
107
- if self.torch_device.type == "cuda":
108
- self.logger.info("CUDA available, using GPU acceleration")
109
- elif self.torch_device_mps and "arm" in platform.machine().lower():
110
- self.torch_device = self.torch_device_mps
111
- self.logger.info("MPS available, using Apple Silicon acceleration")
112
- else:
113
- self.logger.info("No GPU acceleration available, using CPU")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  self.torch_device = self.torch_device_cpu
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  def get_model_hash(self, model_path):
117
- """Calculate MD5 hash of a model file."""
118
- BYTES_TO_HASH = 10000 * 1024
 
 
 
 
 
119
  try:
 
 
120
  with open(model_path, "rb") as f:
121
- file_size = os.path.getsize(model_path)
122
  if file_size < BYTES_TO_HASH:
123
- return hashlib.md5(f.read()).hexdigest()
124
- f.seek(file_size - BYTES_TO_HASH)
125
- return hashlib.md5(f.read()).hexdigest()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  except Exception as e:
 
127
  self.logger.error(f"Error calculating hash for {model_path}: {e}")
128
- raise
129
 
130
  def download_file_if_not_exists(self, url, output_path):
131
- """Download file from URL if it doesn't exist."""
132
- if os.path.exists(output_path):
133
- self.logger.debug(f"File exists: {output_path}")
 
 
 
134
  return
135
- self.logger.info(f"Downloading {url} to {output_path}")
136
- response = requests.get(url, stream=True, timeout=60)
 
 
137
  if response.status_code == 200:
138
- with open(output_path, "wb") as f, tqdm(total=int(response.headers.get("content-length", 0)), unit="B", unit_scale=True) as pbar:
 
 
 
139
  for chunk in response.iter_content(chunk_size=8192):
 
140
  f.write(chunk)
141
- pbar.update(len(chunk))
142
  else:
143
- raise RuntimeError(f"Failed to download {url}: {response.status_code}")
144
 
145
  def list_supported_model_files(self):
146
- """Fetch supported model files, including MelBand Roformer models."""
147
- download_checks_url = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
148
- download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
149
- self.download_file_if_not_exists(download_checks_url, download_checks_path)
150
- model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
151
 
152
- # Custom model list from the Gradio log
153
- roformer_models = {
154
- "MelBand Roformer Kim | Inst V1 (E) Plus by Unwa": {
155
- "filename": "melband_roformer_inst_v1e_plus.ckpt",
156
- "scores": {"vocals": {"SDR": 11.95}, "instrumental": {"SDR": 16.30}},
157
- "stems": ["vocals", "instrumental"],
158
- "target_stem": "vocals",
159
- "download_files": ["melband_roformer_inst_v1e_plus.ckpt", "model_2_stem_full_band_8k.yaml"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  },
161
- "MelBand Roformer Kim | Inst V1 Plus by Unwa": {
162
- "filename": "melband_roformer_inst_v1_plus.ckpt",
163
- "scores": {"vocals": {"SDR": 11.80}, "instrumental": {"SDR": 16.20}},
164
- "stems": ["vocals", "instrumental"],
165
- "target_stem": "vocals",
166
- "download_files": ["melband_roformer_inst_v1_plus.ckpt", "model_2_stem_full_band_8k.yaml"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  },
168
- # Add other models from the log as needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  }
 
 
170
 
171
- public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
172
- audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
173
 
174
- model_files_grouped_by_type = {
175
- "MDX": {
176
- "MDX-Net Model: UVR-MDX-NET-Inst_full_292": {
177
- "filename": "UVR-MDX-NET-Inst_full_292.onnx",
178
- "scores": {"vocals": {"SDR": 10.6497}, "instrumental": {"SDR": 15.2149}},
179
- "stems": ["vocals", "instrumental"],
180
- "target_stem": "vocals",
181
- "download_files": ["UVR-MDX-NET-Inst_full_292.onnx"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  }
183
- },
 
 
 
 
 
 
 
184
  "VR": {
185
- "VR Model: UVR-VR-Model": {
186
- "filename": "UVR-VR-Model.onnx",
187
- "scores": {},
188
- "stems": ["vocals", "instrumental"],
189
- "target_stem": "vocals",
190
- "download_files": ["UVR-VR-Model.onnx"]
191
- }
 
192
  },
193
- "Demucs": {
194
- "Demucs v4: htdemucs_ft": {
195
- "filename": "htdemucs_ft.yaml",
196
- "scores": {"vocals": {"SDR": 11.2685}, "drums": {"SDR": 13.235}, "bass": {"SDR": 9.72743}},
197
- "stems": ["vocals", "drums", "bass"],
198
- "target_stem": "vocals",
199
- "download_files": [
200
- f"{public_model_repo_url_prefix}/htdemucs_ft.yaml",
201
- "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th"
202
- ]
 
 
 
 
 
 
 
 
203
  }
 
 
 
 
 
 
 
204
  },
205
- "MDXC": roformer_models
206
  }
 
207
  return model_files_grouped_by_type
208
 
209
  def print_uvr_vip_message(self):
210
- """Print message for VIP models."""
 
 
211
  if self.model_is_uvr_vip:
212
- self.logger.warning(f"Model '{self.model_friendly_name}' is a VIP model. Consider supporting UVR at https://patreon.com/uvr")
 
213
 
214
  def download_model_files(self, model_filename):
215
- """Download model files and return metadata."""
216
- model_path = os.path.join(self.model_file_dir, model_filename)
217
- supported_models = self.list_supported_model_files()
 
 
 
 
218
  public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
219
  vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
220
  audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
221
 
222
  yaml_config_filename = None
223
- for model_type, models in supported_models.items():
 
 
 
 
 
224
  for model_friendly_name, model_info in models.items():
225
  self.model_is_uvr_vip = "VIP" in model_friendly_name
226
  model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix
 
 
227
  if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
 
228
  self.model_friendly_name = model_friendly_name
229
  self.print_uvr_vip_message()
 
 
230
  for file_to_download in model_info["download_files"]:
 
231
  if file_to_download.startswith("http"):
232
  filename = file_to_download.split("/")[-1]
233
  download_path = os.path.join(self.model_file_dir, filename)
234
  self.download_file_if_not_exists(file_to_download, download_path)
235
- if file_to_download.endswith(".yaml"):
236
- yaml_config_filename = filename
237
  continue
 
238
  download_path = os.path.join(self.model_file_dir, file_to_download)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  try:
240
- self.download_file_if_not_exists(f"{model_repo_url_prefix}/{file_to_download}", download_path)
 
241
  except RuntimeError:
242
- self.download_file_if_not_exists(f"{audio_separator_models_repo_url_prefix}/{file_to_download}", download_path)
243
- if file_to_download.endswith(".yaml"):
244
- yaml_config_filename = file_to_download
 
245
  return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
246
- self.logger.error(f"Model {model_filename} not found in supported models")
247
- raise ValueError(f"Model file {model_filename} not found")
248
 
249
  def load_model_data_from_yaml(self, yaml_config_filename):
250
- """Load model data from YAML file."""
251
- yaml_path = os.path.join(self.model_file_dir, yaml_config_filename)
252
- try:
253
- with open(yaml_path, encoding="utf-8") as f:
254
- model_data = yaml.load(f, Loader=yaml.FullLoader)
255
- self.logger.debug(f"Model data loaded from YAML: {model_data}")
256
- if "roformer" in yaml_config_filename.lower():
257
- model_data["is_roformer"] = True
258
- return model_data
259
- except Exception as e:
260
- self.logger.error(f"Failed to load YAML {yaml_config_filename}: {e}")
261
- raise
 
 
 
 
 
 
 
262
 
263
  def load_model_data_using_hash(self, model_path):
264
- """Load model data using file hash."""
265
- model_data_urls = [
266
- "https://raw.githubusercontent.com/TRvlvr/application_data/main/vr_model_data/model_data_new.json",
267
- "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data_new.json"
268
- ]
269
- model_data = {}
 
 
 
 
 
 
 
270
  model_hash = self.get_model_hash(model_path)
271
- for url in model_data_urls:
272
- model_data_path = os.path.join(self.model_file_dir, os.path.basename(url))
273
- self.download_file_if_not_exists(url, model_data_path)
274
- with open(model_data_path, encoding="utf-8") as f:
275
- model_data.update(json.load(f))
276
- if model_hash in model_data:
277
- self.logger.debug(f"Model data loaded for hash {model_hash}")
278
- return model_data[model_hash]
279
- self.logger.error(f"No model data for hash {model_hash}")
280
- raise ValueError(f"No model data for hash {model_hash}")
281
-
282
- def load_model(self, model_filename="UVR-MDX-NET-Inst_full_292.onnx"):
283
- """Load model based on architecture."""
284
- self.logger.info(f"Loading model {model_filename}")
285
- start_time = time.perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
- try:
288
- model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
289
- except ValueError as e:
290
- self.logger.error(f"Failed to load model: {e}")
291
- self.logger.info("Falling back to default model: UVR-MDX-NET-Inst_full_292.onnx")
292
- model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self\nSystem: .download_model_files("UVR-MDX-NET-Inst_full_292.onnx")
 
 
 
 
293
 
 
 
 
 
294
  model_name = model_filename.split(".")[0]
295
- model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
 
 
 
 
 
 
 
 
296
 
297
  common_params = {
298
  "logger": self.logger,
@@ -315,167 +756,207 @@ class Separator:
315
  "use_soundfile": self.use_soundfile,
316
  }
317
 
318
- separator_classes = {
319
- "MDX": "mdx_separator.MDXSeparator",
320
- "VR": "vr_separator.VRSeparator",
321
- "Demucs": "demucs_separator.DemucsSeparator",
322
- "MDXC": "mdxc_separator.MDXCSeparator"
323
- }
324
 
325
- if model_type not in separator_classes:
326
- raise ValueError(f"Unsupported model type: {model_type}")
327
  if model_type == "Demucs" and sys.version_info < (3, 10):
328
- raise Exception("Demucs requires Python 3.10 or newer")
 
 
329
 
330
  module_name, class_name = separator_classes[model_type].split(".")
331
- try:
332
- module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
333
- separator_class = getattr(module, class_name)
334
- self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
335
- except ImportError as e:
336
- self.logger.error(f"Failed to load module {module_name}: {e}")
337
- raise
338
-
339
- self.logger.info(f"Model loaded in {time.perf_counter() - start_time:.2f} seconds")
340
-
341
- def preprocess_audio(self, audio_data, sample_rate):
342
- """Preprocess audio: resample, normalize, and convert to tensor."""
343
- if sample_rate != self.sample_rate:
344
- self.logger.debug(f"Resampling from {sample_rate} to {self.sample_rate} Hz")
345
- audio_data = np.interp(
346
- np.linspace(0, len(audio_data), int(len(audio_data) * self.sample_rate / sample_rate)),
347
- np.arange(len(audio_data)),
348
- audio_data
349
- )
350
- max_amplitude = np.max(np.abs(audio_data))
351
- if max_amplitude > 0:
352
- audio_data = audio_data * (self.normalization_threshold / max_amplitude)
353
- if max_amplitude < self.amplification_threshold:
354
- audio_data = audio_data * (self.amplification_threshold / max_amplitude)
355
- return torch.tensor(audio_data, dtype=torch.float32, device=self.torch_device)
356
 
357
  @spaces.GPU
358
  def separate(self, audio_file_path, custom_output_names=None):
359
- """Separate audio file into stems."""
360
- if not self.model_instance:
361
- raise ValueError("Model not loaded")
 
 
 
 
 
 
 
362
 
363
- self.logger.info(f"Separating audio: {audio_file_path}")
364
- start_time = time.perf_counter()
 
 
 
 
365
 
 
366
  if isinstance(audio_file_path, str):
367
  audio_file_path = [audio_file_path]
368
 
 
369
  output_files = []
 
 
370
  for path in audio_file_path:
371
  if os.path.isdir(path):
372
- for root, _, files in os.walk(path):
 
373
  for file in files:
374
- if file.endswith((".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff")):
 
375
  full_path = os.path.join(root, file)
376
- output_files.extend(self._separate_file(full_path, custom_output_names))
 
 
 
 
 
 
377
  else:
378
- output_files.extend(self._separate_file(path, custom_output_names))
 
 
 
 
 
 
379
 
380
- self.print_uvr_vip_message()
381
- self.logger.info(f"Separation completed in {time.perf_counter() - start_time:.2f} seconds")
382
  return output_files
383
 
384
  @spaces.GPU
385
  def _separate_file(self, audio_file_path, custom_output_names=None):
386
- """Internal method to separate a single audio file."""
387
- self.logger.debug(f"Processing file: {audio_file_path}")
388
- try:
389
- audio_data, input_sample_rate = sf.read(audio_file_path)
390
- if len(audio_data.shape) > 1:
391
- audio_data = np.mean(audio_data, axis=1)
392
- except Exception as e:
393
- self.logger.error(f"Failed to read audio: {e}")
394
- raise
395
-
396
- audio_tensor = self.preprocess_audio(audio_data, input_sample_rate)
397
-
398
- output_files = []
399
- try:
400
- if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
401
- with autocast_mode.autocast(self.torch_device.type):
402
- output_files = self.model_instance.separate(audio_tensor, custom_output_names)
403
- else:
404
- output_files = self.model_instance.separate(audio_tensor, custom_output_names)
405
- except Exception as e:
406
- self.logger.error(f"Separation failed: {e}")
407
- raise
408
-
409
- # Mock output for architectures not implemented (replace with actual logic in separator classes)
410
- if not output_files:
411
- stem_names = ["vocals", "instrumental"] # Adjust based on model
412
- output_files = []
413
- for stem in stem_names:
414
- output_path = os.path.join(self.output_dir, f"{os.path.splitext(os.path.basename(audio_file_path))[0]}_{stem}.{self.output_format.lower()}")
415
- sf.write(output_path, audio_data, self.sample_rate)
416
- output_files.append(output_path)
417
 
 
418
  self.model_instance.clear_gpu_cache()
 
 
419
  self.model_instance.clear_file_specific_paths()
 
 
 
 
 
 
 
 
420
  return output_files
421
 
422
  def download_model_and_data(self, model_filename):
423
- """Download model files without loading into memory."""
424
- self.logger.info(f"Downloading model {model_filename}")
425
- try:
426
- model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
427
- model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
428
- self.logger.info(f"Model downloaded: {model_friendly_name}, type: {model_type}, path: {model_path}, data items: {len(model_data)}")
429
- except ValueError as e:
430
- self.logger.error(f"Failed to download model: {e}")
431
- raise
 
 
 
 
 
 
 
 
 
432
 
433
  def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
434
- """Return a simplified list of models."""
 
 
 
 
 
435
  model_files = self.list_supported_model_files()
436
  simplified_list = {}
437
 
438
  for model_type, models in model_files.items():
439
  for name, data in models.items():
440
  filename = data["filename"]
441
- scores = data.get("scores", {})
442
- stems = data.get("stems", [])
443
  target_stem = data.get("target_stem")
 
 
444
  stems_with_scores = []
445
  stem_sdr_dict = {}
 
 
446
  for stem in stems:
 
 
447
  stem_display = f"{stem}*" if stem == target_stem else stem
448
- sdr = scores.get(stem, {}).get("SDR")
449
- if sdr is not None:
450
- sdr = round(sdr, 1)
451
  stems_with_scores.append(f"{stem_display} ({sdr})")
452
  stem_sdr_dict[stem.lower()] = sdr
453
  else:
 
454
  stems_with_scores.append(stem_display)
455
  stem_sdr_dict[stem.lower()] = None
456
 
 
457
  if not stems_with_scores:
458
  stems_with_scores = ["Unknown"]
459
  stem_sdr_dict["unknown"] = None
460
 
461
- simplified_list[filename] = {
462
- "Name": name,
463
- "Type": model_type,
464
- "Stems": stems_with_scores,
465
- "SDR": stem_sdr_dict
466
- }
467
 
 
468
  if filter_sort_by:
469
  if filter_sort_by == "name":
470
  return dict(sorted(simplified_list.items(), key=lambda x: x[1]["Name"]))
471
  elif filter_sort_by == "filename":
472
  return dict(sorted(simplified_list.items()))
473
  else:
 
474
  sort_by_lower = filter_sort_by.lower()
 
475
  filtered_list = {k: v for k, v in simplified_list.items() if sort_by_lower in v["SDR"]}
 
 
476
  def sort_key(item):
477
- sdr = item[1]["SDR"].get(sort_by_lower)
478
  return (0 if sdr is None else 1, sdr if sdr is not None else float("-inf"))
 
479
  return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))
480
 
481
  return simplified_list
 
1
+ """ This file contains the Separator class, to facilitate the separation of stems from audio. """
2
+
3
+ from importlib import metadata, resources
4
  import os
5
+ import sys
6
+ import platform
7
+ import subprocess
8
+ import time
9
  import logging
10
+ import warnings
11
+ import importlib
12
+ import io
13
+ from typing import Optional
14
+
15
+ import hashlib
16
+ import json
17
+ import yaml
18
  import requests
19
  import torch
20
  import torch.amp.autocast_mode as autocast_mode
21
  import onnxruntime as ort
 
 
 
 
 
 
 
 
 
22
  from tqdm import tqdm
23
  import spaces
24
 
25
+
26
  class Separator:
27
  """
28
+ The Separator class is designed to facilitate the separation of audio sources from a given audio file.
29
+ It supports various separation architectures and models, including MDX, VR, and Demucs. The class provides
30
+ functionalities to configure separation parameters, load models, and perform audio source separation.
31
+ It also handles logging, normalization, and output formatting of the separated audio stems.
32
+
33
+ The actual separation task is handled by one of the architecture-specific classes in the `architectures` module;
34
+ this class is responsible for initialising logging, configuring hardware acceleration, loading the model,
35
+ initiating the separation process and passing outputs back to the caller.
36
+
37
+ Common Attributes:
38
+ log_level (int): The logging level.
39
+ log_formatter (logging.Formatter): The logging formatter.
40
+ model_file_dir (str): The directory where model files are stored.
41
+ output_dir (str): The directory where output files will be saved.
42
+ output_format (str): The format of the output audio file.
43
+ output_bitrate (str): The bitrate of the output audio file.
44
+ amplification_threshold (float): The threshold for audio amplification.
45
+ normalization_threshold (float): The threshold for audio normalization.
46
+ output_single_stem (str): Option to output a single stem.
47
+ invert_using_spec (bool): Flag to invert using spectrogram.
48
+ sample_rate (int): The sample rate of the audio.
49
+ use_soundfile (bool): Use soundfile for audio writing, can solve OOM issues.
50
+ use_autocast (bool): Flag to use PyTorch autocast for faster inference.
51
+
52
+ MDX Architecture Specific Attributes:
53
+ hop_length (int): The hop length for STFT.
54
+ segment_size (int): The segment size for processing.
55
+ overlap (float): The overlap between segments.
56
+ batch_size (int): The batch size for processing.
57
+ enable_denoise (bool): Flag to enable or disable denoising.
58
+
59
+ VR Architecture Specific Attributes & Defaults:
60
+ batch_size: 16
61
+ window_size: 512
62
+ aggression: 5
63
+ enable_tta: False
64
+ enable_post_process: False
65
+ post_process_threshold: 0.2
66
+ high_end_process: False
67
+
68
+ Demucs Architecture Specific Attributes & Defaults:
69
+ segment_size: "Default"
70
+ shifts: 2
71
+ overlap: 0.25
72
+ segments_enabled: True
73
+
74
+ MDXC Architecture Specific Attributes & Defaults:
75
+ segment_size: 256
76
+ override_model_segment_size: False
77
+ batch_size: 1
78
+ overlap: 8
79
+ pitch_shift: 0
80
  """
81
+
82
  def __init__(
83
  self,
84
  log_level=logging.INFO,
85
+ log_formatter=None,
86
  model_file_dir="/tmp/audio-separator-models/",
87
+ output_dir=None,
88
  output_format="WAV",
89
  output_bitrate=None,
90
  normalization_threshold=0.9,
 
92
  output_single_stem=None,
93
  invert_using_spec=False,
94
  sample_rate=44100,
95
+ use_soundfile=False,
96
+ use_autocast=False,
97
  use_directml=False,
98
  mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
99
  vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
 
101
  mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
102
  info_only=False,
103
  ):
104
+ """Initialize the separator."""
105
  self.logger = logging.getLogger(__name__)
106
  self.logger.setLevel(log_level)
107
+ self.log_level = log_level
108
+ self.log_formatter = log_formatter
109
+
110
+ self.log_handler = logging.StreamHandler()
111
+
112
+ if self.log_formatter is None:
113
+ self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
114
+
115
+ self.log_handler.setFormatter(self.log_formatter)
116
+
117
  if not self.logger.hasHandlers():
118
+ self.logger.addHandler(self.log_handler)
119
 
120
+ # Filter out noisy warnings from PyTorch for users who don't care about them
121
+ if log_level > logging.DEBUG:
122
+ warnings.filterwarnings("ignore")
123
+
124
+ # Skip initialization logs if info_only is True
125
+ if not info_only:
126
+ package_version = self.get_package_distribution("audio-separator").version
127
+ self.logger.info(f"Separator version {package_version} instantiating with output_dir: {output_dir}, output_format: {output_format}")
128
+
129
+ if output_dir is None:
130
+ output_dir = os.getcwd()
131
+ if not info_only:
132
+ self.logger.info("Output directory not specified. Using current working directory.")
133
+
134
+ self.output_dir = output_dir
135
+
136
+ # Check for environment variable to override model_file_dir
137
+ env_model_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR")
138
+ if env_model_dir:
139
+ self.model_file_dir = env_model_dir
140
+ self.logger.info(f"Using model directory from AUDIO_SEPARATOR_MODEL_DIR env var: {self.model_file_dir}")
141
+ if not os.path.exists(self.model_file_dir):
142
+ raise FileNotFoundError(f"The specified model directory does not exist: {self.model_file_dir}")
143
+ else:
144
+ self.logger.info(f"Using model directory from model_file_dir parameter: {model_file_dir}")
145
+ self.model_file_dir = model_file_dir
146
+
147
+ # Create the model directory if it does not exist
148
+ os.makedirs(self.model_file_dir, exist_ok=True)
149
+ os.makedirs(self.output_dir, exist_ok=True)
150
+
151
+ self.output_format = output_format
152
  self.output_bitrate = output_bitrate
153
+
154
+ if self.output_format is None:
155
+ self.output_format = "WAV"
156
+
157
  self.normalization_threshold = normalization_threshold
158
+ if normalization_threshold <= 0 or normalization_threshold > 1:
159
+ raise ValueError("The normalization_threshold must be greater than 0 and less than or equal to 1.")
160
+
161
  self.amplification_threshold = amplification_threshold
162
+ if amplification_threshold < 0 or amplification_threshold > 1:
163
+ raise ValueError("The amplification_threshold must be greater than or equal to 0 and less than or equal to 1.")
164
+
165
  self.output_single_stem = output_single_stem
166
+ if output_single_stem is not None:
167
+ self.logger.debug(f"Single stem output requested, so only one output file ({output_single_stem}) will be written")
168
+
169
  self.invert_using_spec = invert_using_spec
170
+ if self.invert_using_spec:
171
+ self.logger.debug(f"Secondary step will be inverted using spectogram rather than waveform. This may improve quality but is slightly slower.")
172
+
173
+ try:
174
+ self.sample_rate = int(sample_rate)
175
+ if self.sample_rate <= 0:
176
+ raise ValueError(f"The sample rate setting is {self.sample_rate} but it must be a non-zero whole number.")
177
+ if self.sample_rate > 12800000:
178
+ raise ValueError(f"The sample rate setting is {self.sample_rate}. Enter something less ambitious.")
179
+ except ValueError:
180
+ raise ValueError("The sample rate must be a non-zero whole number. Please provide a valid integer.")
181
+
182
  self.use_soundfile = use_soundfile
183
  self.use_autocast = use_autocast
184
  self.use_directml = use_directml
 
185
 
186
+ # These are parameters which users may want to configure so we expose them to the top-level Separator class,
187
+ # even though they are specific to a single model architecture
188
+ self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
 
 
 
 
189
 
190
+ self.torch_device = None
191
+ self.torch_device_cpu = None
192
+ self.torch_device_mps = None
193
 
194
+ self.onnx_execution_provider = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
  self.model_instance = None
196
+
197
  self.model_is_uvr_vip = False
198
  self.model_friendly_name = None
199
 
200
  if not info_only:
201
+ self.setup_accelerated_inferencing_device()
202
 
203
  def setup_accelerated_inferencing_device(self):
204
+ """
205
+ This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
206
+ """
207
+ system_info = self.get_system_info()
208
+ self.check_ffmpeg_installed()
209
+ self.log_onnxruntime_packages()
210
+ self.setup_torch_device(system_info)
211
+
212
+ def get_system_info(self):
213
+ """
214
+ This method logs the system information, including the operating system, CPU archutecture and Python version
215
+ """
216
+ os_name = platform.system()
217
+ os_version = platform.version()
218
+ self.logger.info(f"Operating System: {os_name} {os_version}")
219
+
220
+ system_info = platform.uname()
221
+ self.logger.info(f"System: {system_info.system} Node: {system_info.node} Release: {system_info.release} Machine: {system_info.machine} Proc: {system_info.processor}")
222
+
223
+ python_version = platform.python_version()
224
+ self.logger.info(f"Python Version: {python_version}")
225
+
226
+ pytorch_version = torch.__version__
227
+ self.logger.info(f"PyTorch Version: {pytorch_version}")
228
+ return system_info
229
+
230
+ def check_ffmpeg_installed(self):
231
+ """
232
+ This method checks if ffmpeg is installed and logs its version.
233
+ """
234
+ try:
235
+ ffmpeg_version_output = subprocess.check_output(["ffmpeg", "-version"], text=True)
236
+ first_line = ffmpeg_version_output.splitlines()[0]
237
+ self.logger.info(f"FFmpeg installed: {first_line}")
238
+ except FileNotFoundError:
239
+ self.logger.error("FFmpeg is not installed. Please install FFmpeg to use this package.")
240
+ # Raise an exception if this is being run by a user, as ffmpeg is required for pydub to write audio
241
+ # but if we're just running unit tests in CI, no reason to throw
242
+ if "PYTEST_CURRENT_TEST" not in os.environ:
243
+ raise
244
+
245
+ def log_onnxruntime_packages(self):
246
+ """
247
+ This method logs the ONNX Runtime package versions, including the GPU and Silicon packages if available.
248
+ """
249
+ onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
250
+ onnxruntime_silicon_package = self.get_package_distribution("onnxruntime-silicon")
251
+ onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
252
+ onnxruntime_dml_package = self.get_package_distribution("onnxruntime-directml")
253
+
254
+ if onnxruntime_gpu_package is not None:
255
+ self.logger.info(f"ONNX Runtime GPU package installed with version: {onnxruntime_gpu_package.version}")
256
+ if onnxruntime_silicon_package is not None:
257
+ self.logger.info(f"ONNX Runtime Silicon package installed with version: {onnxruntime_silicon_package.version}")
258
+ if onnxruntime_cpu_package is not None:
259
+ self.logger.info(f"ONNX Runtime CPU package installed with version: {onnxruntime_cpu_package.version}")
260
+ if onnxruntime_dml_package is not None:
261
+ self.logger.info(f"ONNX Runtime DirectML package installed with version: {onnxruntime_dml_package.version}")
262
+
263
    def setup_torch_device(self, system_info):
        """
        Select the inference device for PyTorch and ONNX Runtime, preferring hardware acceleration.

        Probes in priority order: CUDA, then Apple Silicon MPS (only on ARM processors),
        then DirectML (only when use_directml was requested and torch_directml is installed).
        Falls back to CPU when none of those are usable.

        Parameters:
        - system_info: result of platform.uname(); its .processor field gates MPS on ARM.
        """
        hardware_acceleration_enabled = False
        # Execution providers reported by the installed onnxruntime build
        ort_providers = ort.get_available_providers()
        # Truthy distribution object when torch_directml is pip-installed, None otherwise
        has_torch_dml_installed = self.get_package_distribution("torch_directml")

        self.torch_device_cpu = torch.device("cpu")

        if torch.cuda.is_available():
            self.configure_cuda(ort_providers)
            hardware_acceleration_enabled = True
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
            self.configure_mps(ort_providers)
            hardware_acceleration_enabled = True
        elif self.use_directml and has_torch_dml_installed:
            # Imported lazily so the package is only required when DirectML is actually requested
            import torch_directml
            if torch_directml.is_available():
                self.configure_dml(ort_providers)
                hardware_acceleration_enabled = True

        if not hardware_acceleration_enabled:
            self.logger.info("No hardware acceleration could be configured, running in CPU mode")
            self.torch_device = self.torch_device_cpu
            self.onnx_execution_provider = ["CPUExecutionProvider"]
290
+ def configure_cuda(self, ort_providers):
291
+ """
292
+ This method configures the CUDA device for PyTorch and ONNX Runtime, if available.
293
+ """
294
+ self.logger.info("CUDA is available in Torch, setting Torch device to CUDA")
295
+ self.torch_device = torch.device("cuda")
296
+ if "CUDAExecutionProvider" in ort_providers:
297
+ self.logger.info("ONNXruntime has CUDAExecutionProvider available, enabling acceleration")
298
+ self.onnx_execution_provider = ["CUDAExecutionProvider"]
299
+ else:
300
+ self.logger.warning("CUDAExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
301
+
302
+ def configure_mps(self, ort_providers):
303
+ """
304
+ This method configures the Apple Silicon MPS/CoreML device for PyTorch and ONNX Runtime, if available.
305
+ """
306
+ self.logger.info("Apple Silicon MPS/CoreML is available in Torch and processor is ARM, setting Torch device to MPS")
307
+ self.torch_device_mps = torch.device("mps")
308
+
309
+ self.torch_device = self.torch_device_mps
310
+
311
+ if "CoreMLExecutionProvider" in ort_providers:
312
+ self.logger.info("ONNXruntime has CoreMLExecutionProvider available, enabling acceleration")
313
+ self.onnx_execution_provider = ["CoreMLExecutionProvider"]
314
+ else:
315
+ self.logger.warning("CoreMLExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
316
+
317
+ def configure_dml(self, ort_providers):
318
+ """
319
+ This method configures the DirectML device for PyTorch and ONNX Runtime, if available.
320
+ """
321
+ import torch_directml
322
+ self.logger.info("DirectML is available in Torch, setting Torch device to DirectML")
323
+ self.torch_device_dml = torch_directml.device()
324
+ self.torch_device = self.torch_device_dml
325
+
326
+ if "DmlExecutionProvider" in ort_providers:
327
+ self.logger.info("ONNXruntime has DmlExecutionProvider available, enabling acceleration")
328
+ self.onnx_execution_provider = ["DmlExecutionProvider"]
329
+ else:
330
+ self.logger.warning("DmlExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
331
+
332
+ def get_package_distribution(self, package_name):
333
+ """
334
+ This method returns the package distribution for a given package name if installed, or None otherwise.
335
+ """
336
+ try:
337
+ return metadata.distribution(package_name)
338
+ except metadata.PackageNotFoundError:
339
+ self.logger.debug(f"Python package: {package_name} not installed")
340
+ return None
341
 
342
  def get_model_hash(self, model_path):
343
+ """
344
+ This method returns the MD5 hash of a given model file.
345
+ """
346
+ self.logger.debug(f"Calculating hash of model file {model_path}")
347
+ # Use the specific byte count from the original logic
348
+ BYTES_TO_HASH = 10000 * 1024 # 10,240,000 bytes
349
+
350
  try:
351
+ file_size = os.path.getsize(model_path)
352
+
353
  with open(model_path, "rb") as f:
 
354
  if file_size < BYTES_TO_HASH:
355
+ # Hash the entire file if smaller than the target byte count
356
+ self.logger.debug(f"File size {file_size} < {BYTES_TO_HASH}, hashing entire file.")
357
+ hash_value = hashlib.md5(f.read()).hexdigest()
358
+ else:
359
+ # Seek to the specific position before the end (from the beginning) and hash
360
+ seek_pos = file_size - BYTES_TO_HASH
361
+ self.logger.debug(f"File size {file_size} >= {BYTES_TO_HASH}, seeking to {seek_pos} and hashing remaining bytes.")
362
+ f.seek(seek_pos, io.SEEK_SET)
363
+ hash_value = hashlib.md5(f.read()).hexdigest()
364
+
365
+ # Log the calculated hash
366
+ self.logger.info(f"Hash of model file {model_path} is {hash_value}")
367
+ return hash_value
368
+
369
+ except FileNotFoundError:
370
+ self.logger.error(f"Model file not found at {model_path}")
371
+ raise # Re-raise the specific error
372
  except Exception as e:
373
+ # Catch other potential errors (e.g., permissions, other IOErrors)
374
  self.logger.error(f"Error calculating hash for {model_path}: {e}")
375
+ raise # Re-raise other errors
376
 
377
  def download_file_if_not_exists(self, url, output_path):
378
+ """
379
+ This method downloads a file from a given URL to a given output path, if the file does not already exist.
380
+ """
381
+
382
+ if os.path.isfile(output_path):
383
+ self.logger.debug(f"File already exists at {output_path}, skipping download")
384
  return
385
+
386
+ self.logger.debug(f"Downloading file from {url} to {output_path} with timeout 300s")
387
+ response = requests.get(url, stream=True, timeout=300)
388
+
389
  if response.status_code == 200:
390
+ total_size_in_bytes = int(response.headers.get("content-length", 0))
391
+ progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
392
+
393
+ with open(output_path, "wb") as f:
394
  for chunk in response.iter_content(chunk_size=8192):
395
+ progress_bar.update(len(chunk))
396
  f.write(chunk)
397
+ progress_bar.close()
398
  else:
399
+ raise RuntimeError(f"Failed to download file from {url}, response code: {response.status_code}")
400
 
401
  def list_supported_model_files(self):
402
+ """
403
+ This method lists the supported model files for audio-separator, by fetching the same file UVR uses to list these.
404
+ Also includes model performance scores where available.
405
+
406
+ Example response object:
407
 
408
+ {
409
+ "MDX": {
410
+ "MDX-Net Model VIP: UVR-MDX-NET-Inst_full_292": {
411
+ "filename": "UVR-MDX-NET-Inst_full_292.onnx",
412
+ "scores": {
413
+ "vocals": {
414
+ "SDR": 10.6497,
415
+ "SIR": 20.3786,
416
+ "SAR": 10.692,
417
+ "ISR": 14.848
418
+ },
419
+ "instrumental": {
420
+ "SDR": 15.2149,
421
+ "SIR": 25.6075,
422
+ "SAR": 17.1363,
423
+ "ISR": 17.7893
424
+ }
425
+ },
426
+ "download_files": [
427
+ "UVR-MDX-NET-Inst_full_292.onnx"
428
+ ]
429
+ }
430
  },
431
+ "Demucs": {
432
+ "Demucs v4: htdemucs_ft": {
433
+ "filename": "htdemucs_ft.yaml",
434
+ "scores": {
435
+ "vocals": {
436
+ "SDR": 11.2685,
437
+ "SIR": 21.257,
438
+ "SAR": 11.0359,
439
+ "ISR": 19.3753
440
+ },
441
+ "drums": {
442
+ "SDR": 13.235,
443
+ "SIR": 23.3053,
444
+ "SAR": 13.0313,
445
+ "ISR": 17.2889
446
+ },
447
+ "bass": {
448
+ "SDR": 9.72743,
449
+ "SIR": 19.5435,
450
+ "SAR": 9.20801,
451
+ "ISR": 13.5037
452
+ }
453
+ },
454
+ "download_files": [
455
+ "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th",
456
+ "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th",
457
+ "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th",
458
+ "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th",
459
+ "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_ft.yaml"
460
+ ]
461
+ }
462
  },
463
+ "MDXC": {
464
+ "MDX23C Model: MDX23C-InstVoc HQ": {
465
+ "filename": "MDX23C-8KFFT-InstVoc_HQ.ckpt",
466
+ "scores": {
467
+ "vocals": {
468
+ "SDR": 11.9504,
469
+ "SIR": 23.1166,
470
+ "SAR": 12.093,
471
+ "ISR": 15.4782
472
+ },
473
+ "instrumental": {
474
+ "SDR": 16.3035,
475
+ "SIR": 26.6161,
476
+ "SAR": 18.5167,
477
+ "ISR": 18.3939
478
+ }
479
+ },
480
+ "download_files": [
481
+ "MDX23C-8KFFT-InstVoc_HQ.ckpt",
482
+ "model_2_stem_full_band_8k.yaml"
483
+ ]
484
+ }
485
+ }
486
  }
487
+ """
488
+ download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
489
 
490
+ self.download_file_if_not_exists("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json", download_checks_path)
 
491
 
492
+ model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
493
+ self.logger.debug(f"UVR model download list loaded")
494
+
495
+ # Load the model scores with error handling
496
+ model_scores = {}
497
+ try:
498
+ with resources.open_text("audio_separator", "models-scores.json") as f:
499
+ model_scores = json.load(f)
500
+ self.logger.debug(f"Model scores loaded")
501
+ except json.JSONDecodeError as e:
502
+ self.logger.warning(f"Failed to load model scores: {str(e)}")
503
+ self.logger.warning("Continuing without model scores")
504
+
505
+ # Only show Demucs v4 models as we've only implemented support for v4
506
+ filtered_demucs_v4 = {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}
507
+
508
+ # Modified Demucs handling to use YAML files as identifiers and include download files
509
+ demucs_models = {}
510
+ for name, files in filtered_demucs_v4.items():
511
+ # Find the YAML file in the model files
512
+ yaml_file = next((filename for filename in files.keys() if filename.endswith(".yaml")), None)
513
+ if yaml_file:
514
+ model_score_data = model_scores.get(yaml_file, {})
515
+ demucs_models[name] = {
516
+ "filename": yaml_file,
517
+ "scores": model_score_data.get("median_scores", {}),
518
+ "stems": model_score_data.get("stems", []),
519
+ "target_stem": model_score_data.get("target_stem"),
520
+ "download_files": list(files.values()), # List of all download URLs/filenames
521
  }
522
+
523
+ # Load the JSON file using importlib.resources
524
+ with resources.open_text("audio_separator", "models.json") as f:
525
+ audio_separator_models_list = json.load(f)
526
+ self.logger.debug(f"Audio-Separator model list loaded")
527
+
528
+ # Return object with list of model names
529
+ model_files_grouped_by_type = {
530
  "VR": {
531
+ name: {
532
+ "filename": filename,
533
+ "scores": model_scores.get(filename, {}).get("median_scores", {}),
534
+ "stems": model_scores.get(filename, {}).get("stems", []),
535
+ "target_stem": model_scores.get(filename, {}).get("target_stem"),
536
+ "download_files": [filename],
537
+ } # Just the filename for VR models
538
+ for name, filename in {**model_downloads_list["vr_download_list"], **audio_separator_models_list["vr_download_list"]}.items()
539
  },
540
+ "MDX": {
541
+ name: {
542
+ "filename": filename,
543
+ "scores": model_scores.get(filename, {}).get("median_scores", {}),
544
+ "stems": model_scores.get(filename, {}).get("stems", []),
545
+ "target_stem": model_scores.get(filename, {}).get("target_stem"),
546
+ "download_files": [filename],
547
+ } # Just the filename for MDX models
548
+ for name, filename in {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"], **audio_separator_models_list["mdx_download_list"]}.items()
549
+ },
550
+ "Demucs": demucs_models,
551
+ "MDXC": {
552
+ name: {
553
+ "filename": next(iter(files.keys())),
554
+ "scores": model_scores.get(next(iter(files.keys())), {}).get("median_scores", {}),
555
+ "stems": model_scores.get(next(iter(files.keys())), {}).get("stems", []),
556
+ "target_stem": model_scores.get(next(iter(files.keys())), {}).get("target_stem"),
557
+ "download_files": list(files.keys()) + list(files.values()), # List of both model filenames and config filenames
558
  }
559
+ for name, files in {
560
+ **model_downloads_list["mdx23c_download_list"],
561
+ **model_downloads_list["mdx23c_download_vip_list"],
562
+ **model_downloads_list["roformer_download_list"],
563
+ **audio_separator_models_list["mdx23c_download_list"],
564
+ **audio_separator_models_list["roformer_download_list"],
565
+ }.items()
566
  },
 
567
  }
568
+
569
  return model_files_grouped_by_type
570
 
571
  def print_uvr_vip_message(self):
572
+ """
573
+ This method prints a message to the user if they have downloaded a VIP model, reminding them to support Anjok07 on Patreon.
574
+ """
575
  if self.model_is_uvr_vip:
576
+ self.logger.warning(f"The model: '{self.model_friendly_name}' is a VIP model, intended by Anjok07 for access by paying subscribers only.")
577
+ self.logger.warning("If you are not already subscribed, please consider supporting the developer of UVR, Anjok07 by subscribing here: https://patreon.com/uvr")
578
 
579
    def download_model_files(self, model_filename):
        """
        Download all files required for a given model, skipping any already present locally.

        Searches the supported-model catalogue for an entry whose primary filename or
        download-file list matches model_filename, flags VIP models, then fetches each
        required file — trying the UVR repo first and falling back to the audio-separator
        release repo when the UVR download fails.

        Parameters:
        - model_filename (str): filename of the model to fetch, e.g. "UVR-MDX-NET-Inst_HQ_3.onnx".

        Returns:
        - tuple: (model_filename, model_type, model_friendly_name, model_path, yaml_config_filename),
          where yaml_config_filename is None unless the model required a separate MDXC YAML config.

        Raises:
        - ValueError: if model_filename is not found in the supported model lists.
        """
        model_path = os.path.join(self.model_file_dir, f"{model_filename}")

        supported_model_files_grouped = self.list_supported_model_files()
        public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
        vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
        audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"

        yaml_config_filename = None

        self.logger.debug(f"Searching for model_filename {model_filename} in supported_model_files_grouped")

        # Iterate through model types (MDX, VR, Demucs, MDXC)
        for model_type, models in supported_model_files_grouped.items():
            # Iterate through each model in this type
            for model_friendly_name, model_info in models.items():
                # VIP models live in a different release repo; flag is also used by print_uvr_vip_message
                self.model_is_uvr_vip = "VIP" in model_friendly_name
                model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix

                # Check if this model matches our target filename (primary file or any listed download file)
                if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
                    self.logger.debug(f"Found matching model: {model_friendly_name}")
                    self.model_friendly_name = model_friendly_name
                    self.print_uvr_vip_message()

                    # Download each required file for this model
                    for file_to_download in model_info["download_files"]:
                        # For full URLs (e.g. Demucs weights), extract just the filename portion
                        if file_to_download.startswith("http"):
                            filename = file_to_download.split("/")[-1]
                            download_path = os.path.join(self.model_file_dir, filename)
                            self.download_file_if_not_exists(file_to_download, download_path)
                            continue

                        download_path = os.path.join(self.model_file_dir, file_to_download)

                        # For MDXC models, handle YAML config files specially: they live under
                        # mdx_c_configs in the UVR repo, with the audio-separator repo as fallback
                        if model_type == "MDXC" and file_to_download.endswith(".yaml"):
                            yaml_config_filename = file_to_download
                            try:
                                yaml_url = f"{model_repo_url_prefix}/mdx_model_data/mdx_c_configs/{file_to_download}"
                                self.download_file_if_not_exists(yaml_url, download_path)
                            except RuntimeError:
                                self.logger.debug("YAML config not found in UVR repo, trying audio-separator models repo...")
                                yaml_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
                                self.download_file_if_not_exists(yaml_url, download_path)
                            continue

                        # For regular model files, try UVR repo first, then audio-separator repo
                        try:
                            download_url = f"{model_repo_url_prefix}/{file_to_download}"
                            self.download_file_if_not_exists(download_url, download_path)
                        except RuntimeError:
                            self.logger.debug("Model not found in UVR repo, trying audio-separator models repo...")
                            download_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
                            self.download_file_if_not_exists(download_url, download_path)

                    return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename

        raise ValueError(f"Model file {model_filename} not found in supported model files")
 
644
  def load_model_data_from_yaml(self, yaml_config_filename):
645
+ """
646
+ This method loads model-specific parameters from the YAML file for that model.
647
+ The parameters in the YAML are critical to inferencing, as they need to match whatever was used during training.
648
+ """
649
+ # Verify if the YAML filename includes a full path or just the filename
650
+ if not os.path.exists(yaml_config_filename):
651
+ model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
652
+ else:
653
+ model_data_yaml_filepath = yaml_config_filename
654
+
655
+ self.logger.debug(f"Loading model data from YAML at path {model_data_yaml_filepath}")
656
+
657
+ model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
658
+ self.logger.debug(f"Model data loaded from YAML file: {model_data}")
659
+
660
+ if "roformer" in model_data_yaml_filepath:
661
+ model_data["is_roformer"] = True
662
+
663
+ return model_data
664
 
665
  def load_model_data_using_hash(self, model_path):
666
+ """
667
+ This method loads model-specific parameters from UVR model data files.
668
+ These parameters are critical to inferencing using a given model, as they need to match whatever was used during training.
669
+ The correct parameters are identified by calculating the hash of the model file and looking up the hash in the UVR data files.
670
+ """
671
+ # Model data and configuration sources from UVR
672
+ model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
673
+
674
+ vr_model_data_url = f"{model_data_url_prefix}/vr_model_data/model_data_new.json"
675
+ mdx_model_data_url = f"{model_data_url_prefix}/mdx_model_data/model_data_new.json"
676
+
677
+ # Calculate hash for the downloaded model
678
+ self.logger.debug("Calculating MD5 hash for model file to identify model parameters from UVR data...")
679
  model_hash = self.get_model_hash(model_path)
680
+ self.logger.debug(f"Model {model_path} has hash {model_hash}")
681
+
682
+ # Setting up the path for model data and checking its existence
683
+ vr_model_data_path = os.path.join(self.model_file_dir, "vr_model_data.json")
684
+ self.logger.debug(f"VR model data path set to {vr_model_data_path}")
685
+ self.download_file_if_not_exists(vr_model_data_url, vr_model_data_path)
686
+
687
+ mdx_model_data_path = os.path.join(self.model_file_dir, "mdx_model_data.json")
688
+ self.logger.debug(f"MDX model data path set to {mdx_model_data_path}")
689
+ self.download_file_if_not_exists(mdx_model_data_url, mdx_model_data_path)
690
+
691
+ # Loading model data from UVR
692
+ self.logger.debug("Loading MDX and VR model parameters from UVR model data files...")
693
+ vr_model_data_object = json.load(open(vr_model_data_path, encoding="utf-8"))
694
+ mdx_model_data_object = json.load(open(mdx_model_data_path, encoding="utf-8"))
695
+
696
+ # Load additional model data from audio-separator
697
+ self.logger.debug("Loading additional model parameters from audio-separator model data file...")
698
+ with resources.open_text("audio_separator", "model-data.json") as f:
699
+ audio_separator_model_data = json.load(f)
700
+
701
+ # Merge the model data objects, with audio-separator data taking precedence
702
+ vr_model_data_object = {**vr_model_data_object, **audio_separator_model_data.get("vr_model_data", {})}
703
+ mdx_model_data_object = {**mdx_model_data_object, **audio_separator_model_data.get("mdx_model_data", {})}
704
+
705
+ if model_hash in mdx_model_data_object:
706
+ model_data = mdx_model_data_object[model_hash]
707
+ elif model_hash in vr_model_data_object:
708
+ model_data = vr_model_data_object[model_hash]
709
+ else:
710
+ raise ValueError(f"Unsupported Model File: parameters for MD5 hash {model_hash} could not be found in UVR model data file for MDX or VR arch.")
711
 
712
+ self.logger.debug(f"Model data loaded using hash {model_hash}: {model_data}")
713
+
714
+ return model_data
715
+
716
+ def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"):
717
+ """
718
+ This method instantiates the architecture-specific separation class,
719
+ loading the separation model into memory, downloading it first if necessary.
720
+ """
721
+ self.logger.info(f"Loading model {model_filename}...")
722
 
723
+ load_model_start_time = time.perf_counter()
724
+
725
+ # Setting up the model path
726
+ model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
727
  model_name = model_filename.split(".")[0]
728
+ self.logger.debug(f"Model downloaded, friendly name: {model_friendly_name}, model_path: {model_path}")
729
+
730
+ if model_path.lower().endswith(".yaml"):
731
+ yaml_config_filename = model_path
732
+
733
+ if yaml_config_filename is not None:
734
+ model_data = self.load_model_data_from_yaml(yaml_config_filename)
735
+ else:
736
+ model_data = self.load_model_data_using_hash(model_path)
737
 
738
  common_params = {
739
  "logger": self.logger,
 
756
  "use_soundfile": self.use_soundfile,
757
  }
758
 
759
+ # Instantiate the appropriate separator class depending on the model type
760
+ separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"}
761
+
762
+ if model_type not in self.arch_specific_params or model_type not in separator_classes:
763
+ raise ValueError(f"Model type not supported (yet): {model_type}")
 
764
 
 
 
765
  if model_type == "Demucs" and sys.version_info < (3, 10):
766
+ raise Exception("Demucs models require Python version 3.10 or newer.")
767
+
768
+ self.logger.debug(f"Importing module for model type {model_type}: {separator_classes[model_type]}")
769
 
770
  module_name, class_name = separator_classes[model_type].split(".")
771
+ module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
772
+ separator_class = getattr(module, class_name)
773
+
774
+ self.logger.debug(f"Instantiating separator class for model type {model_type}: {separator_class}")
775
+ self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
776
+
777
+ # Log the completion of the model load process
778
+ self.logger.debug("Loading model completed.")
779
+ self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780
 
781
  @spaces.GPU
782
  def separate(self, audio_file_path, custom_output_names=None):
783
+ """
784
+ Separates the audio file(s) into different stems (e.g., vocals, instruments) using the loaded model.
785
+
786
+ This method takes the path to an audio file or a directory containing audio files, processes them through
787
+ the loaded separation model, and returns the paths to the output files containing the separated audio stems.
788
+ It handles the entire flow from loading the audio, running the separation, clearing up resources, and logging the process.
789
+
790
+ Parameters:
791
+ - audio_file_path (str or list): The path to the audio file or directory, or a list of paths.
792
+ - custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
793
 
794
+ Returns:
795
+ - output_files (list of str): A list containing the paths to the separated audio stem files.
796
+ """
797
+ # Check if the model and device are properly initialized
798
+ if not (self.torch_device and self.model_instance):
799
+ raise ValueError("Initialization failed or model not loaded. Please load a model before attempting to separate.")
800
 
801
+ # If audio_file_path is a string, convert it to a list for uniform processing
802
  if isinstance(audio_file_path, str):
803
  audio_file_path = [audio_file_path]
804
 
805
+ # Initialize a list to store paths of all output files
806
  output_files = []
807
+
808
+ # Process each path in the list
809
  for path in audio_file_path:
810
  if os.path.isdir(path):
811
+ # If the path is a directory, recursively search for all audio files
812
+ for root, dirs, files in os.walk(path):
813
  for file in files:
814
+ # Check the file extension to ensure it's an audio file
815
+ if file.endswith((".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")): # Add other formats if needed
816
  full_path = os.path.join(root, file)
817
+ self.logger.info(f"Processing file: {full_path}")
818
+ try:
819
+ # Perform separation for each file
820
+ files_output = self._separate_file(full_path, custom_output_names)
821
+ output_files.extend(files_output)
822
+ except Exception as e:
823
+ self.logger.error(f"Failed to process file {full_path}: {e}")
824
  else:
825
+ # If the path is a file, process it directly
826
+ self.logger.info(f"Processing file: {path}")
827
+ try:
828
+ files_output = self._separate_file(path, custom_output_names)
829
+ output_files.extend(files_output)
830
+ except Exception as e:
831
+ self.logger.error(f"Failed to process file {path}: {e}")
832
 
 
 
833
  return output_files
834
 
835
  @spaces.GPU
836
  def _separate_file(self, audio_file_path, custom_output_names=None):
837
+ """
838
+ Internal method to handle separation for a single audio file.
839
+ This method performs the actual separation process for a single audio file. It logs the start and end of the process,
840
+ handles autocast if enabled, and ensures GPU cache is cleared after processing.
841
+ Parameters:
842
+ - audio_file_path (str): The path to the audio file.
843
+ - custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
844
+ Returns:
845
+ - output_files (list of str): A list containing the paths to the separated audio stem files.
846
+ """
847
+ # Log the start of the separation process
848
+ self.logger.info(f"Starting separation process for audio_file_path: {audio_file_path}")
849
+ separate_start_time = time.perf_counter()
850
+
851
+ # Log normalization and amplification thresholds
852
+ self.logger.debug(f"Normalization threshold set to {self.normalization_threshold}, waveform will be lowered to this max amplitude to avoid clipping.")
853
+ self.logger.debug(f"Amplification threshold set to {self.amplification_threshold}, waveform will be scaled up to this max amplitude if below it.")
854
+
855
+ # Run separation method for the loaded model with autocast enabled if supported by the device
856
+ output_files = None
857
+ if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
858
+ self.logger.debug("Autocast available.")
859
+ with autocast_mode.autocast(self.torch_device.type):
860
+ output_files = self.model_instance.separate(audio_file_path, custom_output_names)
861
+ else:
862
+ self.logger.debug("Autocast unavailable.")
863
+ output_files = self.model_instance.separate(audio_file_path, custom_output_names)
 
 
 
 
864
 
865
+ # Clear GPU cache to free up memory
866
  self.model_instance.clear_gpu_cache()
867
+
868
+ # Unset separation parameters to prevent accidentally re-using the wrong source files or output paths
869
  self.model_instance.clear_file_specific_paths()
870
+
871
+ # Remind the user one more time if they used a VIP model, so the message doesn't get lost in the logs
872
+ self.print_uvr_vip_message()
873
+
874
+ # Log the completion of the separation process
875
+ self.logger.debug("Separation process completed.")
876
+ self.logger.info(f'Separation duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - separate_start_time)))}')
877
+
878
  return output_files
879
 
880
  def download_model_and_data(self, model_filename):
881
+ """
882
+ Downloads the model file without loading it into memory.
883
+ """
884
+ self.logger.info(f"Downloading model {model_filename}...")
885
+
886
+ model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
887
+
888
+ if model_path.lower().endswith(".yaml"):
889
+ yaml_config_filename = model_path
890
+
891
+ if yaml_config_filename is not None:
892
+ model_data = self.load_model_data_from_yaml(yaml_config_filename)
893
+ else:
894
+ model_data = self.load_model_data_using_hash(model_path)
895
+
896
+ model_data_dict_size = len(model_data)
897
+
898
+ self.logger.info(f"Model downloaded, type: {model_type}, friendly name: {model_friendly_name}, model_path: {model_path}, model_data: {model_data_dict_size} items")
899
 
900
    def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
        """
        Return a simplified, user-friendly mapping of model filename -> summary info.

        Each entry contains the friendly Name, architecture Type, a human-readable Stems
        list (with SDR scores appended and the target stem starred), and a raw SDR dict
        keyed by lowercase stem name.

        :param filter_sort_by: Optional criteria. "name" sorts by friendly name,
            "filename" sorts by filename; any other value is treated as a stem name —
            the list is filtered to models offering that stem and sorted by its SDR
            score (highest first, unscored models last).
        """
        model_files = self.list_supported_model_files()
        simplified_list = {}

        for model_type, models in model_files.items():
            for name, data in models.items():
                filename = data["filename"]
                # Normalize missing/None metadata to empty containers
                scores = data.get("scores") or {}
                stems = data.get("stems") or []
                target_stem = data.get("target_stem")

                # Format stems with their SDR scores where available
                stems_with_scores = []
                stem_sdr_dict = {}

                # Process each stem from the model's stem list
                for stem in stems:
                    stem_scores = scores.get(stem, {})
                    # Add asterisk if this is the target stem
                    stem_display = f"{stem}*" if stem == target_stem else stem

                    if isinstance(stem_scores, dict) and "SDR" in stem_scores:
                        sdr = round(stem_scores["SDR"], 1)
                        stems_with_scores.append(f"{stem_display} ({sdr})")
                        stem_sdr_dict[stem.lower()] = sdr
                    else:
                        # Include stem without SDR score
                        stems_with_scores.append(stem_display)
                        stem_sdr_dict[stem.lower()] = None

                # If no stems listed, mark as Unknown
                if not stems_with_scores:
                    stems_with_scores = ["Unknown"]
                    stem_sdr_dict["unknown"] = None

                simplified_list[filename] = {"Name": name, "Type": model_type, "Stems": stems_with_scores, "SDR": stem_sdr_dict}

        # Sort and filter the list if a filter_sort_by parameter is provided
        if filter_sort_by:
            if filter_sort_by == "name":
                return dict(sorted(simplified_list.items(), key=lambda x: x[1]["Name"]))
            elif filter_sort_by == "filename":
                return dict(sorted(simplified_list.items()))
            else:
                # Convert to lowercase for case-insensitive stem comparison
                sort_by_lower = filter_sort_by.lower()
                # Filter out models that don't have the specified stem
                filtered_list = {k: v for k, v in simplified_list.items() if sort_by_lower in v["SDR"]}

                # Sort by SDR score if available, putting None values last
                def sort_key(item):
                    sdr = item[1]["SDR"][sort_by_lower]
                    return (0 if sdr is None else 1, sdr if sdr is not None else float("-inf"))

                return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))

        return simplified_list