ASesYusuf1 commited on
Commit
6de14d9
·
verified ·
1 Parent(s): 01f8b5b

Update audio_separator/separator/separator.py

Browse files
Files changed (1) hide show
  1. audio_separator/separator/separator.py +290 -736
audio_separator/separator/separator.py CHANGED
@@ -22,68 +22,34 @@ import onnxruntime as ort
22
  from tqdm import tqdm
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  class Separator:
26
  """
27
- The Separator class is designed to facilitate the separation of audio sources from a given audio file.
28
- It supports various separation architectures and models, including MDX, VR, and Demucs. The class provides
29
- functionalities to configure separation parameters, load models, and perform audio source separation.
30
- It also handles logging, normalization, and output formatting of the separated audio stems.
31
-
32
- The actual separation task is handled by one of the architecture-specific classes in the `architectures` module;
33
- this class is responsible for initialising logging, configuring hardware acceleration, loading the model,
34
- initiating the separation process and passing outputs back to the caller.
35
-
36
- Common Attributes:
37
- log_level (int): The logging level.
38
- log_formatter (logging.Formatter): The logging formatter.
39
- model_file_dir (str): The directory where model files are stored.
40
- output_dir (str): The directory where output files will be saved.
41
- output_format (str): The format of the output audio file.
42
- output_bitrate (str): The bitrate of the output audio file.
43
- amplification_threshold (float): The threshold for audio amplification.
44
- normalization_threshold (float): The threshold for audio normalization.
45
- output_single_stem (str): Option to output a single stem.
46
- invert_using_spec (bool): Flag to invert using spectrogram.
47
- sample_rate (int): The sample rate of the audio.
48
- use_soundfile (bool): Use soundfile for audio writing, can solve OOM issues.
49
- use_autocast (bool): Flag to use PyTorch autocast for faster inference.
50
-
51
- MDX Architecture Specific Attributes:
52
- hop_length (int): The hop length for STFT.
53
- segment_size (int): The segment size for processing.
54
- overlap (float): The overlap between segments.
55
- batch_size (int): The batch size for processing.
56
- enable_denoise (bool): Flag to enable or disable denoising.
57
-
58
- VR Architecture Specific Attributes & Defaults:
59
- batch_size: 16
60
- window_size: 512
61
- aggression: 5
62
- enable_tta: False
63
- enable_post_process: False
64
- post_process_threshold: 0.2
65
- high_end_process: False
66
-
67
- Demucs Architecture Specific Attributes & Defaults:
68
- segment_size: "Default"
69
- shifts: 2
70
- overlap: 0.25
71
- segments_enabled: True
72
-
73
- MDXC Architecture Specific Attributes & Defaults:
74
- segment_size: 256
75
- override_model_segment_size: False
76
- batch_size: 1
77
- overlap: 8
78
- pitch_shift: 0
79
  """
80
-
81
  def __init__(
82
  self,
83
  log_level=logging.INFO,
84
- log_formatter=None,
85
  model_file_dir="/tmp/audio-separator-models/",
86
- output_dir=None,
87
  output_format="WAV",
88
  output_bitrate=None,
89
  normalization_threshold=0.9,
@@ -91,8 +57,8 @@ class Separator:
91
  output_single_stem=None,
92
  invert_using_spec=False,
93
  sample_rate=44100,
94
- use_soundfile=False,
95
- use_autocast=False,
96
  use_directml=False,
97
  mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
98
  vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
@@ -100,639 +66,271 @@ class Separator:
100
  mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
101
  info_only=False,
102
  ):
103
- """Initialize the separator."""
104
  self.logger = logging.getLogger(__name__)
105
  self.logger.setLevel(log_level)
106
- self.log_level = log_level
107
- self.log_formatter = log_formatter
108
-
109
- self.log_handler = logging.StreamHandler()
110
-
111
- if self.log_formatter is None:
112
- self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
113
-
114
- self.log_handler.setFormatter(self.log_formatter)
115
-
116
  if not self.logger.hasHandlers():
117
- self.logger.addHandler(self.log_handler)
118
-
119
- # Filter out noisy warnings from PyTorch for users who don't care about them
120
- if log_level > logging.DEBUG:
121
- warnings.filterwarnings("ignore")
122
 
123
- # Skip initialization logs if info_only is True
124
- if not info_only:
125
- package_version = self.get_package_distribution("audio-separator").version
126
- self.logger.info(f"Separator version {package_version} instantiating with output_dir: {output_dir}, output_format: {output_format}")
127
-
128
- if output_dir is None:
129
- output_dir = os.getcwd()
130
- if not info_only:
131
- self.logger.info("Output directory not specified. Using current working directory.")
132
-
133
- self.output_dir = output_dir
134
-
135
- # Check for environment variable to override model_file_dir
136
- env_model_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR")
137
- if env_model_dir:
138
- self.model_file_dir = env_model_dir
139
- self.logger.info(f"Using model directory from AUDIO_SEPARATOR_MODEL_DIR env var: {self.model_file_dir}")
140
- if not os.path.exists(self.model_file_dir):
141
- raise FileNotFoundError(f"The specified model directory does not exist: {self.model_file_dir}")
142
- else:
143
- self.logger.info(f"Using model directory from model_file_dir parameter: {model_file_dir}")
144
- self.model_file_dir = model_file_dir
145
-
146
- # Create the model directory if it does not exist
147
- os.makedirs(self.model_file_dir, exist_ok=True)
148
- os.makedirs(self.output_dir, exist_ok=True)
149
-
150
- self.output_format = output_format
151
  self.output_bitrate = output_bitrate
152
-
153
- if self.output_format is None:
154
- self.output_format = "WAV"
155
-
156
  self.normalization_threshold = normalization_threshold
157
- if normalization_threshold <= 0 or normalization_threshold > 1:
158
- raise ValueError("The normalization_threshold must be greater than 0 and less than or equal to 1.")
159
-
160
  self.amplification_threshold = amplification_threshold
161
- if amplification_threshold < 0 or amplification_threshold > 1:
162
- raise ValueError("The amplification_threshold must be greater than or equal to 0 and less than or equal to 1.")
163
-
164
  self.output_single_stem = output_single_stem
165
- if output_single_stem is not None:
166
- self.logger.debug(f"Single stem output requested, so only one output file ({output_single_stem}) will be written")
167
-
168
  self.invert_using_spec = invert_using_spec
169
- if self.invert_using_spec:
170
- self.logger.debug(f"Secondary step will be inverted using spectogram rather than waveform. This may improve quality but is slightly slower.")
171
-
172
- try:
173
- self.sample_rate = int(sample_rate)
174
- if self.sample_rate <= 0:
175
- raise ValueError(f"The sample rate setting is {self.sample_rate} but it must be a non-zero whole number.")
176
- if self.sample_rate > 12800000:
177
- raise ValueError(f"The sample rate setting is {self.sample_rate}. Enter something less ambitious.")
178
- except ValueError:
179
- raise ValueError("The sample rate must be a non-zero whole number. Please provide a valid integer.")
180
-
181
  self.use_soundfile = use_soundfile
182
  self.use_autocast = use_autocast
183
  self.use_directml = use_directml
184
-
185
- # These are parameters which users may want to configure so we expose them to the top-level Separator class,
186
- # even though they are specific to a single model architecture
187
  self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
188
 
189
- self.torch_device = None
190
- self.torch_device_cpu = None
191
- self.torch_device_mps = None
 
 
 
 
192
 
193
- self.onnx_execution_provider = None
194
- self.model_instance = None
 
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  self.model_is_uvr_vip = False
197
  self.model_friendly_name = None
198
 
199
  if not info_only:
200
- self.setup_accelerated_inferencing_device()
201
 
202
  def setup_accelerated_inferencing_device(self):
203
- """
204
- This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
205
- """
206
- system_info = self.get_system_info()
207
- self.check_ffmpeg_installed()
208
- self.log_onnxruntime_packages()
209
- self.setup_torch_device(system_info)
210
-
211
- def get_system_info(self):
212
- """
213
- This method logs the system information, including the operating system, CPU archutecture and Python version
214
- """
215
- os_name = platform.system()
216
- os_version = platform.version()
217
- self.logger.info(f"Operating System: {os_name} {os_version}")
218
-
219
- system_info = platform.uname()
220
- self.logger.info(f"System: {system_info.system} Node: {system_info.node} Release: {system_info.release} Machine: {system_info.machine} Proc: {system_info.processor}")
221
-
222
- python_version = platform.python_version()
223
- self.logger.info(f"Python Version: {python_version}")
224
-
225
- pytorch_version = torch.__version__
226
- self.logger.info(f"PyTorch Version: {pytorch_version}")
227
- return system_info
228
-
229
- def check_ffmpeg_installed(self):
230
- """
231
- This method checks if ffmpeg is installed and logs its version.
232
- """
233
- try:
234
- ffmpeg_version_output = subprocess.check_output(["ffmpeg", "-version"], text=True)
235
- first_line = ffmpeg_version_output.splitlines()[0]
236
- self.logger.info(f"FFmpeg installed: {first_line}")
237
- except FileNotFoundError:
238
- self.logger.error("FFmpeg is not installed. Please install FFmpeg to use this package.")
239
- # Raise an exception if this is being run by a user, as ffmpeg is required for pydub to write audio
240
- # but if we're just running unit tests in CI, no reason to throw
241
- if "PYTEST_CURRENT_TEST" not in os.environ:
242
- raise
243
-
244
- def log_onnxruntime_packages(self):
245
- """
246
- This method logs the ONNX Runtime package versions, including the GPU and Silicon packages if available.
247
- """
248
- onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
249
- onnxruntime_silicon_package = self.get_package_distribution("onnxruntime-silicon")
250
- onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
251
- onnxruntime_dml_package = self.get_package_distribution("onnxruntime-directml")
252
-
253
- if onnxruntime_gpu_package is not None:
254
- self.logger.info(f"ONNX Runtime GPU package installed with version: {onnxruntime_gpu_package.version}")
255
- if onnxruntime_silicon_package is not None:
256
- self.logger.info(f"ONNX Runtime Silicon package installed with version: {onnxruntime_silicon_package.version}")
257
- if onnxruntime_cpu_package is not None:
258
- self.logger.info(f"ONNX Runtime CPU package installed with version: {onnxruntime_cpu_package.version}")
259
- if onnxruntime_dml_package is not None:
260
- self.logger.info(f"ONNX Runtime DirectML package installed with version: {onnxruntime_dml_package.version}")
261
-
262
- def setup_torch_device(self, system_info):
263
- """
264
- This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
265
- """
266
- hardware_acceleration_enabled = False
267
- ort_providers = ort.get_available_providers()
268
- has_torch_dml_installed = self.get_package_distribution("torch_directml")
269
-
270
- self.torch_device_cpu = torch.device("cpu")
271
-
272
- if torch.cuda.is_available():
273
- self.configure_cuda(ort_providers)
274
- hardware_acceleration_enabled = True
275
- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
276
- self.configure_mps(ort_providers)
277
- hardware_acceleration_enabled = True
278
- elif self.use_directml and has_torch_dml_installed:
279
- import torch_directml
280
- if torch_directml.is_available():
281
- self.configure_dml(ort_providers)
282
- hardware_acceleration_enabled = True
283
-
284
- if not hardware_acceleration_enabled:
285
- self.logger.info("No hardware acceleration could be configured, running in CPU mode")
286
- self.torch_device = self.torch_device_cpu
287
- self.onnx_execution_provider = ["CPUExecutionProvider"]
288
-
289
- def configure_cuda(self, ort_providers):
290
- """
291
- This method configures the CUDA device for PyTorch and ONNX Runtime, if available.
292
- """
293
- self.logger.info("CUDA is available in Torch, setting Torch device to CUDA")
294
- self.torch_device = torch.device("cuda")
295
- if "CUDAExecutionProvider" in ort_providers:
296
- self.logger.info("ONNXruntime has CUDAExecutionProvider available, enabling acceleration")
297
- self.onnx_execution_provider = ["CUDAExecutionProvider"]
298
- else:
299
- self.logger.warning("CUDAExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
300
-
301
- def configure_mps(self, ort_providers):
302
- """
303
- This method configures the Apple Silicon MPS/CoreML device for PyTorch and ONNX Runtime, if available.
304
- """
305
- self.logger.info("Apple Silicon MPS/CoreML is available in Torch and processor is ARM, setting Torch device to MPS")
306
- self.torch_device_mps = torch.device("mps")
307
-
308
- self.torch_device = self.torch_device_mps
309
-
310
- if "CoreMLExecutionProvider" in ort_providers:
311
- self.logger.info("ONNXruntime has CoreMLExecutionProvider available, enabling acceleration")
312
- self.onnx_execution_provider = ["CoreMLExecutionProvider"]
313
  else:
314
- self.logger.warning("CoreMLExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
315
-
316
- def configure_dml(self, ort_providers):
317
- """
318
- This method configures the DirectML device for PyTorch and ONNX Runtime, if available.
319
- """
320
- import torch_directml
321
- self.logger.info("DirectML is available in Torch, setting Torch device to DirectML")
322
- self.torch_device_dml = torch_directml.device()
323
- self.torch_device = self.torch_device_dml
324
-
325
- if "DmlExecutionProvider" in ort_providers:
326
- self.logger.info("ONNXruntime has DmlExecutionProvider available, enabling acceleration")
327
- self.onnx_execution_provider = ["DmlExecutionProvider"]
328
- else:
329
- self.logger.warning("DmlExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
330
-
331
- def get_package_distribution(self, package_name):
332
- """
333
- This method returns the package distribution for a given package name if installed, or None otherwise.
334
- """
335
- try:
336
- return metadata.distribution(package_name)
337
- except metadata.PackageNotFoundError:
338
- self.logger.debug(f"Python package: {package_name} not installed")
339
- return None
340
 
341
  def get_model_hash(self, model_path):
342
- """
343
- This method returns the MD5 hash of a given model file.
344
- """
345
- self.logger.debug(f"Calculating hash of model file {model_path}")
346
- # Use the specific byte count from the original logic
347
- BYTES_TO_HASH = 10000 * 1024 # 10,240,000 bytes
348
-
349
  try:
350
- file_size = os.path.getsize(model_path)
351
-
352
  with open(model_path, "rb") as f:
 
353
  if file_size < BYTES_TO_HASH:
354
- # Hash the entire file if smaller than the target byte count
355
- self.logger.debug(f"File size {file_size} < {BYTES_TO_HASH}, hashing entire file.")
356
- hash_value = hashlib.md5(f.read()).hexdigest()
357
- else:
358
- # Seek to the specific position before the end (from the beginning) and hash
359
- seek_pos = file_size - BYTES_TO_HASH
360
- self.logger.debug(f"File size {file_size} >= {BYTES_TO_HASH}, seeking to {seek_pos} and hashing remaining bytes.")
361
- f.seek(seek_pos, io.SEEK_SET)
362
- hash_value = hashlib.md5(f.read()).hexdigest()
363
-
364
- # Log the calculated hash
365
- self.logger.info(f"Hash of model file {model_path} is {hash_value}")
366
- return hash_value
367
-
368
- except FileNotFoundError:
369
- self.logger.error(f"Model file not found at {model_path}")
370
- raise # Re-raise the specific error
371
  except Exception as e:
372
- # Catch other potential errors (e.g., permissions, other IOErrors)
373
  self.logger.error(f"Error calculating hash for {model_path}: {e}")
374
- raise # Re-raise other errors
375
 
376
  def download_file_if_not_exists(self, url, output_path):
377
- """
378
- This method downloads a file from a given URL to a given output path, if the file does not already exist.
379
- """
380
-
381
- if os.path.isfile(output_path):
382
- self.logger.debug(f"File already exists at {output_path}, skipping download")
383
  return
384
-
385
- self.logger.debug(f"Downloading file from {url} to {output_path} with timeout 300s")
386
- response = requests.get(url, stream=True, timeout=300)
387
-
388
  if response.status_code == 200:
389
- total_size_in_bytes = int(response.headers.get("content-length", 0))
390
- progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
391
-
392
- with open(output_path, "wb") as f:
393
  for chunk in response.iter_content(chunk_size=8192):
394
- progress_bar.update(len(chunk))
395
  f.write(chunk)
396
- progress_bar.close()
397
  else:
398
- raise RuntimeError(f"Failed to download file from {url}, response code: {response.status_code}")
399
 
400
  def list_supported_model_files(self):
401
- """
402
- This method lists the supported model files for audio-separator, by fetching the same file UVR uses to list these.
403
- Also includes model performance scores where available.
404
-
405
- Example response object:
406
 
407
- {
408
- "MDX": {
409
- "MDX-Net Model VIP: UVR-MDX-NET-Inst_full_292": {
410
- "filename": "UVR-MDX-NET-Inst_full_292.onnx",
411
- "scores": {
412
- "vocals": {
413
- "SDR": 10.6497,
414
- "SIR": 20.3786,
415
- "SAR": 10.692,
416
- "ISR": 14.848
417
- },
418
- "instrumental": {
419
- "SDR": 15.2149,
420
- "SIR": 25.6075,
421
- "SAR": 17.1363,
422
- "ISR": 17.7893
423
- }
424
  },
425
- "download_files": [
426
- "UVR-MDX-NET-Inst_full_292.onnx"
427
- ]
428
- }
429
  },
430
- "Demucs": {
431
- "Demucs v4: htdemucs_ft": {
432
- "filename": "htdemucs_ft.yaml",
433
- "scores": {
434
- "vocals": {
435
- "SDR": 11.2685,
436
- "SIR": 21.257,
437
- "SAR": 11.0359,
438
- "ISR": 19.3753
439
- },
440
- "drums": {
441
- "SDR": 13.235,
442
- "SIR": 23.3053,
443
- "SAR": 13.0313,
444
- "ISR": 17.2889
445
- },
446
- "bass": {
447
- "SDR": 9.72743,
448
- "SIR": 19.5435,
449
- "SAR": 9.20801,
450
- "ISR": 13.5037
451
- }
452
  },
453
- "download_files": [
454
- "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th",
455
- "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th",
456
- "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th",
457
- "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th",
458
- "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_ft.yaml"
459
- ]
460
- }
461
  },
462
- "MDXC": {
463
- "MDX23C Model: MDX23C-InstVoc HQ": {
464
- "filename": "MDX23C-8KFFT-InstVoc_HQ.ckpt",
465
- "scores": {
466
- "vocals": {
467
- "SDR": 11.9504,
468
- "SIR": 23.1166,
469
- "SAR": 12.093,
470
- "ISR": 15.4782
471
- },
472
- "instrumental": {
473
- "SDR": 16.3035,
474
- "SIR": 26.6161,
475
- "SAR": 18.5167,
476
- "ISR": 18.3939
477
- }
478
  },
479
- "download_files": [
480
- "MDX23C-8KFFT-InstVoc_HQ.ckpt",
481
- "model_2_stem_full_band_8k.yaml"
482
- ]
483
- }
484
  }
485
  }
486
- """
487
- download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
488
-
489
- self.download_file_if_not_exists("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json", download_checks_path)
490
-
491
- model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
492
- self.logger.debug(f"UVR model download list loaded")
493
 
494
- # Load the model scores with error handling
495
- model_scores = {}
496
- try:
497
- with resources.open_text("audio_separator", "models-scores.json") as f:
498
- model_scores = json.load(f)
499
- self.logger.debug(f"Model scores loaded")
500
- except json.JSONDecodeError as e:
501
- self.logger.warning(f"Failed to load model scores: {str(e)}")
502
- self.logger.warning("Continuing without model scores")
503
-
504
- # Only show Demucs v4 models as we've only implemented support for v4
505
- filtered_demucs_v4 = {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}
506
-
507
- # Modified Demucs handling to use YAML files as identifiers and include download files
508
- demucs_models = {}
509
- for name, files in filtered_demucs_v4.items():
510
- # Find the YAML file in the model files
511
- yaml_file = next((filename for filename in files.keys() if filename.endswith(".yaml")), None)
512
- if yaml_file:
513
- model_score_data = model_scores.get(yaml_file, {})
514
- demucs_models[name] = {
515
- "filename": yaml_file,
516
- "scores": model_score_data.get("median_scores", {}),
517
- "stems": model_score_data.get("stems", []),
518
- "target_stem": model_score_data.get("target_stem"),
519
- "download_files": list(files.values()), # List of all download URLs/filenames
520
- }
521
-
522
- # Load the JSON file using importlib.resources
523
- with resources.open_text("audio_separator", "models.json") as f:
524
- audio_separator_models_list = json.load(f)
525
- self.logger.debug(f"Audio-Separator model list loaded")
526
 
527
- # Return object with list of model names
528
  model_files_grouped_by_type = {
 
 
 
 
 
 
 
 
 
529
  "VR": {
530
- name: {
531
- "filename": filename,
532
- "scores": model_scores.get(filename, {}).get("median_scores", {}),
533
- "stems": model_scores.get(filename, {}).get("stems", []),
534
- "target_stem": model_scores.get(filename, {}).get("target_stem"),
535
- "download_files": [filename],
536
- } # Just the filename for VR models
537
- for name, filename in {**model_downloads_list["vr_download_list"], **audio_separator_models_list["vr_download_list"]}.items()
538
  },
539
- "MDX": {
540
- name: {
541
- "filename": filename,
542
- "scores": model_scores.get(filename, {}).get("median_scores", {}),
543
- "stems": model_scores.get(filename, {}).get("stems", []),
544
- "target_stem": model_scores.get(filename, {}).get("target_stem"),
545
- "download_files": [filename],
546
- } # Just the filename for MDX models
547
- for name, filename in {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"], **audio_separator_models_list["mdx_download_list"]}.items()
 
 
548
  },
549
- "Demucs": demucs_models,
550
  "MDXC": {
551
- name: {
552
- "filename": next(iter(files.keys())),
553
- "scores": model_scores.get(next(iter(files.keys())), {}).get("median_scores", {}),
554
- "stems": model_scores.get(next(iter(files.keys())), {}).get("stems", []),
555
- "target_stem": model_scores.get(next(iter(files.keys())), {}).get("target_stem"),
556
- "download_files": list(files.keys()) + list(files.values()), # List of both model filenames and config filenames
 
 
 
557
  }
558
- for name, files in {
559
- **model_downloads_list["mdx23c_download_list"],
560
- **model_downloads_list["mdx23c_download_vip_list"],
561
- **model_downloads_list["roformer_download_list"],
562
- **audio_separator_models_list["mdx23c_download_list"],
563
- **audio_separator_models_list["roformer_download_list"],
564
- }.items()
565
- },
566
  }
567
-
568
  return model_files_grouped_by_type
569
 
570
  def print_uvr_vip_message(self):
571
- """
572
- This method prints a message to the user if they have downloaded a VIP model, reminding them to support Anjok07 on Patreon.
573
- """
574
  if self.model_is_uvr_vip:
575
- self.logger.warning(f"The model: '{self.model_friendly_name}' is a VIP model, intended by Anjok07 for access by paying subscribers only.")
576
- self.logger.warning("If you are not already subscribed, please consider supporting the developer of UVR, Anjok07 by subscribing here: https://patreon.com/uvr")
577
 
578
  def download_model_files(self, model_filename):
579
- """
580
- This method downloads the model files for a given model filename, if they are not already present.
581
- Returns tuple of (model_filename, model_type, model_friendly_name, model_path, yaml_config_filename)
582
- """
583
- model_path = os.path.join(self.model_file_dir, f"{model_filename}")
584
-
585
- supported_model_files_grouped = self.list_supported_model_files()
586
  public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
587
  vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
588
  audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
589
 
590
  yaml_config_filename = None
591
-
592
- self.logger.debug(f"Searching for model_filename {model_filename} in supported_model_files_grouped")
593
-
594
- # Iterate through model types (MDX, Demucs, MDXC)
595
- for model_type, models in supported_model_files_grouped.items():
596
- # Iterate through each model in this type
597
  for model_friendly_name, model_info in models.items():
598
  self.model_is_uvr_vip = "VIP" in model_friendly_name
599
  model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix
600
-
601
- # Check if this model matches our target filename
602
  if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
603
- self.logger.debug(f"Found matching model: {model_friendly_name}")
604
  self.model_friendly_name = model_friendly_name
605
  self.print_uvr_vip_message()
606
-
607
- # Download each required file for this model
608
  for file_to_download in model_info["download_files"]:
609
- # For URLs, extract just the filename portion
610
  if file_to_download.startswith("http"):
611
  filename = file_to_download.split("/")[-1]
612
  download_path = os.path.join(self.model_file_dir, filename)
613
  self.download_file_if_not_exists(file_to_download, download_path)
 
 
614
  continue
615
-
616
  download_path = os.path.join(self.model_file_dir, file_to_download)
617
-
618
- # For MDXC models, handle YAML config files specially
619
- if model_type == "MDXC" and file_to_download.endswith(".yaml"):
620
- yaml_config_filename = file_to_download
621
- try:
622
- yaml_url = f"{model_repo_url_prefix}/mdx_model_data/mdx_c_configs/{file_to_download}"
623
- self.download_file_if_not_exists(yaml_url, download_path)
624
- except RuntimeError:
625
- self.logger.debug("YAML config not found in UVR repo, trying audio-separator models repo...")
626
- yaml_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
627
- self.download_file_if_not_exists(yaml_url, download_path)
628
- continue
629
-
630
- # For regular model files, try UVR repo first, then audio-separator repo
631
  try:
632
- download_url = f"{model_repo_url_prefix}/{file_to_download}"
633
- self.download_file_if_not_exists(download_url, download_path)
634
  except RuntimeError:
635
- self.logger.debug("Model not found in UVR repo, trying audio-separator models repo...")
636
- download_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
637
- self.download_file_if_not_exists(download_url, download_path)
638
-
639
  return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
640
-
641
- raise ValueError(f"Model file {model_filename} not found in supported model files")
642
 
643
  def load_model_data_from_yaml(self, yaml_config_filename):
644
- """
645
- This method loads model-specific parameters from the YAML file for that model.
646
- The parameters in the YAML are critical to inferencing, as they need to match whatever was used during training.
647
- """
648
- # Verify if the YAML filename includes a full path or just the filename
649
- if not os.path.exists(yaml_config_filename):
650
- model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
651
- else:
652
- model_data_yaml_filepath = yaml_config_filename
653
-
654
- self.logger.debug(f"Loading model data from YAML at path {model_data_yaml_filepath}")
655
-
656
- model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
657
- self.logger.debug(f"Model data loaded from YAML file: {model_data}")
658
-
659
- if "roformer" in model_data_yaml_filepath:
660
- model_data["is_roformer"] = True
661
-
662
- return model_data
663
 
664
  def load_model_data_using_hash(self, model_path):
665
- """
666
- This method loads model-specific parameters from UVR model data files.
667
- These parameters are critical to inferencing using a given model, as they need to match whatever was used during training.
668
- The correct parameters are identified by calculating the hash of the model file and looking up the hash in the UVR data files.
669
- """
670
- # Model data and configuration sources from UVR
671
- model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
672
-
673
- vr_model_data_url = f"{model_data_url_prefix}/vr_model_data/model_data_new.json"
674
- mdx_model_data_url = f"{model_data_url_prefix}/mdx_model_data/model_data_new.json"
675
-
676
- # Calculate hash for the downloaded model
677
- self.logger.debug("Calculating MD5 hash for model file to identify model parameters from UVR data...")
678
  model_hash = self.get_model_hash(model_path)
679
- self.logger.debug(f"Model {model_path} has hash {model_hash}")
680
-
681
- # Setting up the path for model data and checking its existence
682
- vr_model_data_path = os.path.join(self.model_file_dir, "vr_model_data.json")
683
- self.logger.debug(f"VR model data path set to {vr_model_data_path}")
684
- self.download_file_if_not_exists(vr_model_data_url, vr_model_data_path)
685
-
686
- mdx_model_data_path = os.path.join(self.model_file_dir, "mdx_model_data.json")
687
- self.logger.debug(f"MDX model data path set to {mdx_model_data_path}")
688
- self.download_file_if_not_exists(mdx_model_data_url, mdx_model_data_path)
689
-
690
- # Loading model data from UVR
691
- self.logger.debug("Loading MDX and VR model parameters from UVR model data files...")
692
- vr_model_data_object = json.load(open(vr_model_data_path, encoding="utf-8"))
693
- mdx_model_data_object = json.load(open(mdx_model_data_path, encoding="utf-8"))
694
-
695
- # Load additional model data from audio-separator
696
- self.logger.debug("Loading additional model parameters from audio-separator model data file...")
697
- with resources.open_text("audio_separator", "model-data.json") as f:
698
- audio_separator_model_data = json.load(f)
699
-
700
- # Merge the model data objects, with audio-separator data taking precedence
701
- vr_model_data_object = {**vr_model_data_object, **audio_separator_model_data.get("vr_model_data", {})}
702
- mdx_model_data_object = {**mdx_model_data_object, **audio_separator_model_data.get("mdx_model_data", {})}
703
-
704
- if model_hash in mdx_model_data_object:
705
- model_data = mdx_model_data_object[model_hash]
706
- elif model_hash in vr_model_data_object:
707
- model_data = vr_model_data_object[model_hash]
708
- else:
709
- raise ValueError(f"Unsupported Model File: parameters for MD5 hash {model_hash} could not be found in UVR model data file for MDX or VR arch.")
710
-
711
- self.logger.debug(f"Model data loaded using hash {model_hash}: {model_data}")
712
 
713
- return model_data
714
-
715
- def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"):
716
- """
717
- This method instantiates the architecture-specific separation class,
718
- loading the separation model into memory, downloading it first if necessary.
719
- """
720
- self.logger.info(f"Loading model {model_filename}...")
721
-
722
- load_model_start_time = time.perf_counter()
723
-
724
- # Setting up the model path
725
  model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
726
  model_name = model_filename.split(".")[0]
727
- self.logger.debug(f"Model downloaded, friendly name: {model_friendly_name}, model_path: {model_path}")
728
 
729
- if model_path.lower().endswith(".yaml"):
730
- yaml_config_filename = model_path
731
-
732
- if yaml_config_filename is not None:
733
- model_data = self.load_model_data_from_yaml(yaml_config_filename)
734
- else:
735
- model_data = self.load_model_data_using_hash(model_path)
736
 
737
  common_params = {
738
  "logger": self.logger,
@@ -755,205 +353,161 @@ class Separator:
755
  "use_soundfile": self.use_soundfile,
756
  }
757
 
758
- # Instantiate the appropriate separator class depending on the model type
759
- separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"}
760
-
761
- if model_type not in self.arch_specific_params or model_type not in separator_classes:
762
- raise ValueError(f"Model type not supported (yet): {model_type}")
 
763
 
 
 
764
  if model_type == "Demucs" and sys.version_info < (3, 10):
765
- raise Exception("Demucs models require Python version 3.10 or newer.")
766
-
767
- self.logger.debug(f"Importing module for model type {model_type}: {separator_classes[model_type]}")
768
 
769
  module_name, class_name = separator_classes[model_type].split(".")
770
- module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
771
- separator_class = getattr(module, class_name)
772
-
773
- self.logger.debug(f"Instantiating separator class for model type {model_type}: {separator_class}")
774
- self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
775
-
776
- # Log the completion of the model load process
777
- self.logger.debug("Loading model completed.")
778
- self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779
 
780
  def separate(self, audio_file_path, custom_output_names=None):
781
- """
782
- Separates the audio file(s) into different stems (e.g., vocals, instruments) using the loaded model.
783
-
784
- This method takes the path to an audio file or a directory containing audio files, processes them through
785
- the loaded separation model, and returns the paths to the output files containing the separated audio stems.
786
- It handles the entire flow from loading the audio, running the separation, clearing up resources, and logging the process.
787
 
788
- Parameters:
789
- - audio_file_path (str or list): The path to the audio file or directory, or a list of paths.
790
- - custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
791
 
792
- Returns:
793
- - output_files (list of str): A list containing the paths to the separated audio stem files.
794
- """
795
- # Check if the model and device are properly initialized
796
- if not (self.torch_device and self.model_instance):
797
- raise ValueError("Initialization failed or model not loaded. Please load a model before attempting to separate.")
798
-
799
- # If audio_file_path is a string, convert it to a list for uniform processing
800
  if isinstance(audio_file_path, str):
801
  audio_file_path = [audio_file_path]
802
 
803
- # Initialize a list to store paths of all output files
804
  output_files = []
805
-
806
- # Process each path in the list
807
  for path in audio_file_path:
808
  if os.path.isdir(path):
809
- # If the path is a directory, recursively search for all audio files
810
- for root, dirs, files in os.walk(path):
811
  for file in files:
812
- # Check the file extension to ensure it's an audio file
813
- if file.endswith((".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")): # Add other formats if needed
814
  full_path = os.path.join(root, file)
815
- self.logger.info(f"Processing file: {full_path}")
816
- try:
817
- # Perform separation for each file
818
- files_output = self._separate_file(full_path, custom_output_names)
819
- output_files.extend(files_output)
820
- except Exception as e:
821
- self.logger.error(f"Failed to process file {full_path}: {e}")
822
  else:
823
- # If the path is a file, process it directly
824
- self.logger.info(f"Processing file: {path}")
825
- try:
826
- files_output = self._separate_file(path, custom_output_names)
827
- output_files.extend(files_output)
828
- except Exception as e:
829
- self.logger.error(f"Failed to process file {path}: {e}")
830
 
 
 
831
  return output_files
832
 
833
  def _separate_file(self, audio_file_path, custom_output_names=None):
834
- """
835
- Internal method to handle separation for a single audio file.
836
- This method performs the actual separation process for a single audio file. It logs the start and end of the process,
837
- handles autocast if enabled, and ensures GPU cache is cleared after processing.
838
- Parameters:
839
- - audio_file_path (str): The path to the audio file.
840
- - custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
841
- Returns:
842
- - output_files (list of str): A list containing the paths to the separated audio stem files.
843
- """
844
- # Log the start of the separation process
845
- self.logger.info(f"Starting separation process for audio_file_path: {audio_file_path}")
846
- separate_start_time = time.perf_counter()
847
-
848
- # Log normalization and amplification thresholds
849
- self.logger.debug(f"Normalization threshold set to {self.normalization_threshold}, waveform will be lowered to this max amplitude to avoid clipping.")
850
- self.logger.debug(f"Amplification threshold set to {self.amplification_threshold}, waveform will be scaled up to this max amplitude if below it.")
851
-
852
- # Run separation method for the loaded model with autocast enabled if supported by the device
853
- output_files = None
854
- if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
855
- self.logger.debug("Autocast available.")
856
- with autocast_mode.autocast(self.torch_device.type):
857
- output_files = self.model_instance.separate(audio_file_path, custom_output_names)
858
- else:
859
- self.logger.debug("Autocast unavailable.")
860
- output_files = self.model_instance.separate(audio_file_path, custom_output_names)
861
-
862
- # Clear GPU cache to free up memory
863
- self.model_instance.clear_gpu_cache()
864
-
865
- # Unset separation parameters to prevent accidentally re-using the wrong source files or output paths
866
- self.model_instance.clear_file_specific_paths()
867
 
868
- # Remind the user one more time if they used a VIP model, so the message doesn't get lost in the logs
869
- self.print_uvr_vip_message()
870
 
871
- # Log the completion of the separation process
872
- self.logger.debug("Separation process completed.")
873
- self.logger.info(f'Separation duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - separate_start_time)))}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
874
 
 
 
875
  return output_files
876
 
877
  def download_model_and_data(self, model_filename):
878
- """
879
- Downloads the model file without loading it into memory.
880
- """
881
- self.logger.info(f"Downloading model {model_filename}...")
882
-
883
  model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
884
-
885
- if model_path.lower().endswith(".yaml"):
886
- yaml_config_filename = model_path
887
-
888
- if yaml_config_filename is not None:
889
- model_data = self.load_model_data_from_yaml(yaml_config_filename)
890
- else:
891
- model_data = self.load_model_data_using_hash(model_path)
892
-
893
- model_data_dict_size = len(model_data)
894
-
895
- self.logger.info(f"Model downloaded, type: {model_type}, friendly name: {model_friendly_name}, model_path: {model_path}, model_data: {model_data_dict_size} items")
896
 
897
  def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
898
- """
899
- Returns a simplified, user-friendly list of models with their key metrics.
900
- Optionally sorts the list based on the specified criteria.
901
-
902
- :param sort_by: Criteria to sort by. Can be "name", "filename", or any stem name
903
- """
904
  model_files = self.list_supported_model_files()
905
  simplified_list = {}
906
 
907
  for model_type, models in model_files.items():
908
  for name, data in models.items():
909
  filename = data["filename"]
910
- scores = data.get("scores") or {}
911
- stems = data.get("stems") or []
912
  target_stem = data.get("target_stem")
913
-
914
- # Format stems with their SDR scores where available
915
  stems_with_scores = []
916
  stem_sdr_dict = {}
917
-
918
- # Process each stem from the model's stem list
919
  for stem in stems:
920
- stem_scores = scores.get(stem, {})
921
- # Add asterisk if this is the target stem
922
  stem_display = f"{stem}*" if stem == target_stem else stem
923
-
924
- if isinstance(stem_scores, dict) and "SDR" in stem_scores:
925
- sdr = round(stem_scores["SDR"], 1)
926
  stems_with_scores.append(f"{stem_display} ({sdr})")
927
  stem_sdr_dict[stem.lower()] = sdr
928
  else:
929
- # Include stem without SDR score
930
  stems_with_scores.append(stem_display)
931
  stem_sdr_dict[stem.lower()] = None
932
 
933
- # If no stems listed, mark as Unknown
934
  if not stems_with_scores:
935
  stems_with_scores = ["Unknown"]
936
  stem_sdr_dict["unknown"] = None
937
 
938
- simplified_list[filename] = {"Name": name, "Type": model_type, "Stems": stems_with_scores, "SDR": stem_sdr_dict}
 
 
 
 
 
939
 
940
- # Sort and filter the list if a sort_by parameter is provided
941
  if filter_sort_by:
942
  if filter_sort_by == "name":
943
  return dict(sorted(simplified_list.items(), key=lambda x: x[1]["Name"]))
944
  elif filter_sort_by == "filename":
945
  return dict(sorted(simplified_list.items()))
946
  else:
947
- # Convert sort_by to lowercase for case-insensitive comparison
948
  sort_by_lower = filter_sort_by.lower()
949
- # Filter out models that don't have the specified stem
950
  filtered_list = {k: v for k, v in simplified_list.items() if sort_by_lower in v["SDR"]}
951
-
952
- # Sort by SDR score if available, putting None values last
953
  def sort_key(item):
954
- sdr = item[1]["SDR"][sort_by_lower]
955
  return (0 if sdr is None else 1, sdr if sdr is not None else float("-inf"))
956
-
957
  return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))
958
 
959
- return simplified_list
 
22
  from tqdm import tqdm
23
 
24
 
25
+ import os
26
+ import logging
27
+ import requests
28
+ import torch
29
+ import torch.amp.autocast_mode as autocast_mode
30
+ import onnxruntime as ort
31
+ import numpy as np
32
+ import soundfile as sf
33
+ import json
34
+ import yaml
35
+ import importlib
36
+ import hashlib
37
+ import time
38
+ from typing import Optional
39
+ from io import BytesIO
40
+ from tqdm import tqdm
41
+
42
  class Separator:
43
  """
44
+ Optimized Separator class for audio source separation on Hugging Face Zero GPU.
45
+ Supports MDX, VR, Demucs, and MDXC architectures with ONNX Runtime and PyTorch.
46
+ Optimized for memory efficiency, fast inference, and serverless environments.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  """
 
48
  def __init__(
49
  self,
50
  log_level=logging.INFO,
 
51
  model_file_dir="/tmp/audio-separator-models/",
52
+ output_dir="/tmp/audio_output/",
53
  output_format="WAV",
54
  output_bitrate=None,
55
  normalization_threshold=0.9,
 
57
  output_single_stem=None,
58
  invert_using_spec=False,
59
  sample_rate=44100,
60
+ use_soundfile=True,
61
+ use_autocast=True,
62
  use_directml=False,
63
  mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
64
  vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
 
66
  mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
67
  info_only=False,
68
  ):
69
+ """Initialize the separator for Zero GPU."""
70
  self.logger = logging.getLogger(__name__)
71
  self.logger.setLevel(log_level)
72
+ handler = logging.StreamHandler()
73
+ handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
 
 
 
 
 
 
 
 
74
  if not self.logger.hasHandlers():
75
+ self.logger.addHandler(handler)
 
 
 
 
76
 
77
+ # Configuration
78
+ self.model_file_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR", model_file_dir)
79
+ self.output_dir = output_dir or os.getcwd()
80
+ self.output_format = output_format.upper()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  self.output_bitrate = output_bitrate
 
 
 
 
82
  self.normalization_threshold = normalization_threshold
 
 
 
83
  self.amplification_threshold = amplification_threshold
 
 
 
84
  self.output_single_stem = output_single_stem
 
 
 
85
  self.invert_using_spec = invert_using_spec
86
+ self.sample_rate = int(sample_rate)
 
 
 
 
 
 
 
 
 
 
 
87
  self.use_soundfile = use_soundfile
88
  self.use_autocast = use_autocast
89
  self.use_directml = use_directml
 
 
 
90
  self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
91
 
92
+ # Validation
93
+ if not (0 < normalization_threshold <= 1):
94
+ raise ValueError("normalization_threshold must be in (0, 1]")
95
+ if not (0 <= amplification_threshold <= 1):
96
+ raise ValueError("amplification_threshold must be in [0, 1]")
97
+ if self.sample_rate <= 0 or self.sample_rate > 12800000:
98
+ raise ValueError("sample_rate must be a positive integer <= 12800000")
99
 
100
+ # Create directories
101
+ os.makedirs(self.model_file_dir, exist_ok=True)
102
+ os.makedirs(self.output_dir, exist_ok=True)
103
 
104
+ # Setup device
105
+ self.torch_device_cpu = torch.device("cpu")
106
+ self.torch_device_mps = torch.device("mps") if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else None
107
+ self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
108
+ self.onnx_execution_provider = ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
109
+
110
+ if self.use_directml:
111
+ try:
112
+ import torch_directml
113
+ if torch_directml.is_available():
114
+ self.torch_device = torch_directml.device()
115
+ self.onnx_execution_provider = ["DmlExecutionProvider"] if "DmlExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
116
+ except ImportError:
117
+ self.logger.warning("torch_directml not installed, falling back to CPU")
118
+ self.torch_device = self.torch_device_cpu
119
+
120
+ self.logger.info(f"Using device: {self.torch_device}, ONNX provider: {self.onnx_execution_provider}")
121
+ self.model_instance = None
122
  self.model_is_uvr_vip = False
123
  self.model_friendly_name = None
124
 
125
  if not info_only:
126
+ self.logger.info(f"Initialized Separator with model_dir: {self.model_file_dir}, output_dir: {self.output_dir}")
127
 
128
  def setup_accelerated_inferencing_device(self):
129
+ """Configure hardware acceleration."""
130
+ if self.torch_device.type == "cuda":
131
+ self.logger.info("CUDA available, using GPU acceleration")
132
+ elif self.torch_device_mps and "arm" in platform.machine().lower():
133
+ self.torch_device = self.torch_device_mps
134
+ self.logger.info("MPS available, using Apple Silicon acceleration")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  else:
136
+ self.logger.info("No GPU acceleration available, using CPU")
137
+ self.torch_device = self.torch_device_cpu
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
  def get_model_hash(self, model_path):
140
+ """Calculate MD5 hash of a model file."""
141
+ BYTES_TO_HASH = 10000 * 1024
 
 
 
 
 
142
  try:
 
 
143
  with open(model_path, "rb") as f:
144
+ file_size = os.path.getsize(model_path)
145
  if file_size < BYTES_TO_HASH:
146
+ return hashlib.md5(f.read()).hexdigest()
147
+ f.seek(file_size - BYTES_TO_HASH)
148
+ return hashlib.md5(f.read()).hexdigest()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  except Exception as e:
 
150
  self.logger.error(f"Error calculating hash for {model_path}: {e}")
151
+ raise
152
 
153
  def download_file_if_not_exists(self, url, output_path):
154
+ """Download file from URL if it doesn't exist."""
155
+ if os.path.exists(output_path):
156
+ self.logger.debug(f"File exists: {output_path}")
 
 
 
157
  return
158
+ self.logger.info(f"Downloading {url} to {output_path}")
159
+ response = requests.get(url, stream=True, timeout=60)
 
 
160
  if response.status_code == 200:
161
+ with open(output_path, "wb") as f, tqdm(total=int(response.headers.get("content-length", 0)), unit="B", unit_scale=True) as pbar:
 
 
 
162
  for chunk in response.iter_content(chunk_size=8192):
 
163
  f.write(chunk)
164
+ pbar.update(len(chunk))
165
  else:
166
+ raise RuntimeError(f"Failed to download {url}: {response.status_code}")
167
 
168
  def list_supported_model_files(self):
169
+ """Fetch supported model files from predefined sources."""
170
+ download_checks_url = "https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json"
171
+ download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
172
+ self.download_file_if_not_exists(download_checks_url, download_checks_path)
173
+ model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
174
 
175
+ # Mock model scores for simplicity (replace with actual model-scores.json if available)
176
+ model_scores = {
177
+ "UVR-MDX-NET-Inst_full_292.onnx": {
178
+ "median_scores": {
179
+ "vocals": {"SDR": 10.6497, "SIR": 20.3786, "SAR": 10.692, "ISR": 14.848},
180
+ "instrumental": {"SDR": 15.2149, "SIR": 25.6075, "SAR": 17.1363, "ISR": 17.7893}
 
 
 
 
 
 
 
 
 
 
 
181
  },
182
+ "stems": ["vocals", "instrumental"],
183
+ "target_stem": "vocals"
 
 
184
  },
185
+ "htdemucs_ft.yaml": {
186
+ "median_scores": {
187
+ "vocals": {"SDR": 11.2685, "SIR": 21.257, "SAR": 11.0359, "ISR": 19.3753},
188
+ "drums": {"SDR": 13.235, "SIR": 23.3053, "SAR": 13.0313, "ISR": 17.2889},
189
+ "bass": {"SDR": 9.72743, "SIR": 19.5435, "SAR": 9.20801, "ISR": 13.5037}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  },
191
+ "stems": ["vocals", "drums", "bass"],
192
+ "target_stem": "vocals"
 
 
 
 
 
 
193
  },
194
+ "MDX23C-8KFFT-InstVoc_HQ.ckpt": {
195
+ "median_scores": {
196
+ "vocals": {"SDR": 11.9504, "SIR": 23.1166, "SAR": 12.093, "ISR": 15.4782},
197
+ "instrumental": {"SDR": 16.3035, "SIR": 26.6161, "SAR": 18.5167, "ISR": 18.3939}
 
 
 
 
 
 
 
 
 
 
 
 
198
  },
199
+ "stems": ["vocals", "instrumental"],
200
+ "target_stem": "vocals"
 
 
 
201
  }
202
  }
 
 
 
 
 
 
 
203
 
204
+ public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
205
+ audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ # Simplified model list for MDX, VR, Demucs, MDXC
208
  model_files_grouped_by_type = {
209
+ "MDX": {
210
+ "MDX-Net Model: UVR-MDX-NET-Inst_full_292": {
211
+ "filename": "UVR-MDX-NET-Inst_full_292.onnx",
212
+ "scores": model_scores.get("UVR-MDX-NET-Inst_full_292.onnx", {}).get("median_scores", {}),
213
+ "stems": model_scores.get("UVR-MDX-NET-Inst_full_292.onnx", {}).get("stems", []),
214
+ "target_stem": model_scores.get("UVR-MDX-NET-Inst_full_292.onnx", {}).get("target_stem"),
215
+ "download_files": ["UVR-MDX-NET-Inst_full_292.onnx"]
216
+ }
217
+ },
218
  "VR": {
219
+ "VR Model: UVR-VR-Model": {
220
+ "filename": "UVR-VR-Model.onnx",
221
+ "scores": {},
222
+ "stems": ["vocals", "instrumental"],
223
+ "target_stem": "vocals",
224
+ "download_files": ["UVR-VR-Model.onnx"]
225
+ }
 
226
  },
227
+ "Demucs": {
228
+ "Demucs v4: htdemucs_ft": {
229
+ "filename": "htdemucs_ft.yaml",
230
+ "scores": model_scores.get("htdemucs_ft.yaml", {}).get("median_scores", {}),
231
+ "stems": model_scores.get("htdemucs_ft.yaml", {}).get("stems", []),
232
+ "target_stem": model_scores.get("htdemucs_ft.yaml", {}).get("target_stem"),
233
+ "download_files": [
234
+ f"{public_model_repo_url_prefix}/htdemucs_ft.yaml",
235
+ "https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th"
236
+ ]
237
+ }
238
  },
 
239
  "MDXC": {
240
+ "MDX23C Model: MDX23C-InstVoc HQ": {
241
+ "filename": "MDX23C-8KFFT-InstVoc_HQ.ckpt",
242
+ "scores": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("median_scores", {}),
243
+ "stems": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("stems", []),
244
+ "target_stem": model_scores.get("MDX23C-8KFFT-InstVoc_HQ.ckpt", {}).get("target_stem"),
245
+ "download_files": [
246
+ "MDX23C-8KFFT-InstVoc_HQ.ckpt",
247
+ f"{audio_separator_models_repo_url_prefix}/model_2_stem_full_band_8k.yaml"
248
+ ]
249
  }
250
+ }
 
 
 
 
 
 
 
251
  }
 
252
  return model_files_grouped_by_type
253
 
254
  def print_uvr_vip_message(self):
255
+ """Print message for VIP models."""
 
 
256
  if self.model_is_uvr_vip:
257
+ self.logger.warning(f"Model '{self.model_friendly_name}' is a VIP model. Consider supporting UVR at https://patreon.com/uvr")
 
258
 
259
    def download_model_files(self, model_filename):
        """Download all files needed for a model and return its metadata.

        Scans the supported-model catalogue for an entry whose filename (or one
        of its download files) matches ``model_filename``, downloads every file
        it lists, and returns a tuple of
        ``(model_filename, model_type, model_friendly_name, model_path, yaml_config_filename)``
        where ``yaml_config_filename`` is the model's YAML config name, or None
        if the model has none.

        Side effects: sets ``self.model_is_uvr_vip`` and
        ``self.model_friendly_name`` while scanning.

        Raises:
            ValueError: if no catalogue entry matches ``model_filename``.
        """
        model_path = os.path.join(self.model_file_dir, model_filename)
        supported_models = self.list_supported_model_files()
        public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
        vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
        audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"

        yaml_config_filename = None
        for model_type, models in supported_models.items():
            for model_friendly_name, model_info in models.items():
                # NOTE(review): these flags are updated for EVERY scanned entry,
                # not just the match — so after a miss they reflect the last
                # entry visited; confirm this is intentional.
                self.model_is_uvr_vip = "VIP" in model_friendly_name
                # VIP models live in a separate (paid) release repo
                model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix
                if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
                    self.model_friendly_name = model_friendly_name
                    self.print_uvr_vip_message()
                    for file_to_download in model_info["download_files"]:
                        if file_to_download.startswith("http"):
                            # Absolute URL: save under its basename in the model dir
                            filename = file_to_download.split("/")[-1]
                            download_path = os.path.join(self.model_file_dir, filename)
                            self.download_file_if_not_exists(file_to_download, download_path)
                            if file_to_download.endswith(".yaml"):
                                yaml_config_filename = filename
                            continue
                        # Bare filename: try the UVR repo first, then the
                        # audio-separator model-configs repo as a fallback
                        download_path = os.path.join(self.model_file_dir, file_to_download)
                        try:
                            self.download_file_if_not_exists(f"{model_repo_url_prefix}/{file_to_download}", download_path)
                        except RuntimeError:
                            self.download_file_if_not_exists(f"{audio_separator_models_repo_url_prefix}/{file_to_download}", download_path)
                        if file_to_download.endswith(".yaml"):
                            yaml_config_filename = file_to_download
                    return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
        raise ValueError(f"Model file {model_filename} not found")
292
 
293
  def load_model_data_from_yaml(self, yaml_config_filename):
294
+ """Load model data from YAML file."""
295
+ yaml_path = os.path.join(self.model_file_dir, yaml_config_filename)
296
+ try:
297
+ with open(yaml_path, encoding="utf-8") as f:
298
+ model_data = yaml.load(f, Loader=yaml.FullLoader)
299
+ self.logger.debug(f"Model data loaded from YAML: {model_data}")
300
+ if "roformer" in yaml_config_filename.lower():
301
+ model_data["is_roformer"] = True
302
+ return model_data
303
+ except Exception as e:
304
+ self.logger.error(f"Failed to load YAML {yaml_config_filename}: {e}")
305
+ raise
 
 
 
 
 
 
 
306
 
307
  def load_model_data_using_hash(self, model_path):
308
+ """Load model data using file hash."""
309
+ model_data_urls = [
310
+ "https://raw.githubusercontent.com/TRvlvr/application_data/main/vr_model_data/model_data_new.json",
311
+ "https://raw.githubusercontent.com/TRvlvr/application_data/main/mdx_model_data/model_data_new.json"
312
+ ]
313
+ model_data = {}
 
 
 
 
 
 
 
314
  model_hash = self.get_model_hash(model_path)
315
+ for url in model_data_urls:
316
+ model_data_path = os.path.join(self.model_file_dir, os.path.basename(url))
317
+ self.download_file_if_not_exists(url, model_data_path)
318
+ with open(model_data_path, encoding="utf-8") as f:
319
+ model_data.update(json.load(f))
320
+ if model_hash in model_data:
321
+ self.logger.debug(f"Model data loaded for hash {model_hash}")
322
+ return model_data[model_hash]
323
+ raise ValueError(f"No model data for hash {model_hash}")
324
+
325
+ def load_model(self, model_filename="UVR-MDX-NET-Inst_full_292.onnx"):
326
+ """Load model based on architecture."""
327
+ self.logger.info(f"Loading model {model_filename}")
328
+ start_time = time.perf_counter()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
331
  model_name = model_filename.split(".")[0]
 
332
 
333
+ model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
 
 
 
 
 
 
334
 
335
  common_params = {
336
  "logger": self.logger,
 
353
  "use_soundfile": self.use_soundfile,
354
  }
355
 
356
+ separator_classes = {
357
+ "MDX": "mdx_separator.MDXSeparator",
358
+ "VR": "vr_separator.VRSeparator",
359
+ "Demucs": "demucs_separator.DemucsSeparator",
360
+ "MDXC": "mdxc_separator.MDXCSeparator"
361
+ }
362
 
363
+ if model_type not in separator_classes:
364
+ raise ValueError(f"Unsupported model type: {model_type}")
365
  if model_type == "Demucs" and sys.version_info < (3, 10):
366
+ raise Exception("Demucs requires Python 3.10 or newer")
 
 
367
 
368
  module_name, class_name = separator_classes[model_type].split(".")
369
+ try:
370
+ module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
371
+ separator_class = getattr(module, class_name)
372
+ self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
373
+ except ImportError as e:
374
+ self.logger.error(f"Failed to load module {module_name}: {e}")
375
+ raise
376
+
377
+ self.logger.info(f"Model loaded in {time.perf_counter() - start_time:.2f} seconds")
378
+
379
+ def preprocess_audio(self, audio_data, sample_rate):
380
+ """Preprocess audio: resample, normalize, and convert to tensor."""
381
+ if sample_rate != self.sample_rate:
382
+ self.logger.debug(f"Resampling from {sample_rate} to {self.sample_rate} Hz")
383
+ audio_data = np.interp(
384
+ np.linspace(0, len(audio_data), int(len(audio_data) * self.sample_rate / sample_rate)),
385
+ np.arange(len(audio_data)),
386
+ audio_data
387
+ )
388
+ max_amplitude = np.max(np.abs(audio_data))
389
+ if max_amplitude > 0:
390
+ audio_data = audio_data * (self.normalization_threshold / max_amplitude)
391
+ if max_amplitude < self.amplification_threshold:
392
+ audio_data = audio_data * (self.amplification_threshold / max_amplitude)
393
+ return torch.tensor(audio_data, dtype=torch.float32, device=self.torch_device)
394
 
395
  def separate(self, audio_file_path, custom_output_names=None):
396
+ """Separate audio file into stems."""
397
+ if not self.model_instance:
398
+ raise ValueError("Model not loaded")
 
 
 
399
 
400
+ self.logger.info(f"Separating audio: {audio_file_path}")
401
+ start_time = time.perf_counter()
 
402
 
 
 
 
 
 
 
 
 
403
  if isinstance(audio_file_path, str):
404
  audio_file_path = [audio_file_path]
405
 
 
406
  output_files = []
 
 
407
  for path in audio_file_path:
408
  if os.path.isdir(path):
409
+ for root, _, files in os.walk(path):
 
410
  for file in files:
411
+ if file.endswith((".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff")):
 
412
  full_path = os.path.join(root, file)
413
+ output_files.extend(self._separate_file(full_path, custom_output_names))
 
 
 
 
 
 
414
  else:
415
+ output_files.extend(self._separate_file(path, custom_output_names))
 
 
 
 
 
 
416
 
417
+ self.print_uvr_vip_message()
418
+ self.logger.info(f"Separation completed in {time.perf_counter() - start_time:.2f} seconds")
419
  return output_files
420
 
421
    def _separate_file(self, audio_file_path, custom_output_names=None):
        """Internal method to separate a single audio file.

        Reads the file, downmixes to mono, preprocesses it into a tensor on
        the torch device, and runs the loaded model (under autocast when
        available). Returns the list of written stem file paths.

        Raises whatever the audio read or the model's separate() raises,
        after logging the error.
        """
        self.logger.debug(f"Processing file: {audio_file_path}")
        try:
            audio_data, input_sample_rate = sf.read(audio_file_path)
            # Downmix multi-channel audio to mono by averaging channels
            if len(audio_data.shape) > 1:
                audio_data = np.mean(audio_data, axis=1)
        except Exception as e:
            self.logger.error(f"Failed to read audio: {e}")
            raise

        # NOTE(review): a tensor is passed to model_instance.separate() here,
        # while the upstream architecture classes take a file path — confirm
        # the architecture classes used with this build accept a tensor.
        audio_tensor = self.preprocess_audio(audio_data, input_sample_rate)

        output_files = []
        try:
            # Autocast (mixed precision) when the device supports it
            if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
                with autocast_mode.autocast(self.torch_device.type):
                    output_files = self.model_instance.separate(audio_tensor, custom_output_names)
            else:
                output_files = self.model_instance.separate(audio_tensor, custom_output_names)
        except Exception as e:
            self.logger.error(f"Separation failed: {e}")
            raise

        # Mock output for architectures not implemented (replace with actual logic in separator classes)
        # NOTE(review): this fallback writes the UNSEPARATED input as every
        # stem — placeholder behavior only.
        if not output_files:
            stem_names = ["vocals", "instrumental"]  # Adjust based on model
            output_files = []
            for stem in stem_names:
                output_path = os.path.join(self.output_dir, f"{os.path.splitext(os.path.basename(audio_file_path))[0]}_{stem}.{self.output_format.lower()}")
                sf.write(output_path, audio_data, self.sample_rate)
                output_files.append(output_path)

        # Free GPU memory and reset per-file state on the model instance
        self.model_instance.clear_gpu_cache()
        self.model_instance.clear_file_specific_paths()
        return output_files
458
  def download_model_and_data(self, model_filename):
459
+ """Download model files without loading into memory."""
460
+ self.logger.info(f"Downloading model {model_filename}")
 
 
 
461
  model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
462
+ model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename else self.load_model_data_using_hash(model_path)
463
+ self.logger.info(f"Model downloaded: {model_friendly_name}, type: {model_type}, path: {model_path}, data items: {len(model_data)}")
 
 
 
 
 
 
 
 
 
 
464
 
465
  def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
466
+ """Return a simplified list of models."""
 
 
 
 
 
467
  model_files = self.list_supported_model_files()
468
  simplified_list = {}
469
 
470
  for model_type, models in model_files.items():
471
  for name, data in models.items():
472
  filename = data["filename"]
473
+ scores = data.get("scores", {})
474
+ stems = data.get("stems", [])
475
  target_stem = data.get("target_stem")
 
 
476
  stems_with_scores = []
477
  stem_sdr_dict = {}
 
 
478
  for stem in stems:
 
 
479
  stem_display = f"{stem}*" if stem == target_stem else stem
480
+ sdr = scores.get(stem, {}).get("SDR")
481
+ if sdr is not None:
482
+ sdr = round(sdr, 1)
483
  stems_with_scores.append(f"{stem_display} ({sdr})")
484
  stem_sdr_dict[stem.lower()] = sdr
485
  else:
 
486
  stems_with_scores.append(stem_display)
487
  stem_sdr_dict[stem.lower()] = None
488
 
 
489
  if not stems_with_scores:
490
  stems_with_scores = ["Unknown"]
491
  stem_sdr_dict["unknown"] = None
492
 
493
+ simplified_list[filename] = {
494
+ "Name": name,
495
+ "Type": model_type,
496
+ "Stems": stems_with_scores,
497
+ "SDR": stem_sdr_dict
498
+ }
499
 
 
500
  if filter_sort_by:
501
  if filter_sort_by == "name":
502
  return dict(sorted(simplified_list.items(), key=lambda x: x[1]["Name"]))
503
  elif filter_sort_by == "filename":
504
  return dict(sorted(simplified_list.items()))
505
  else:
 
506
  sort_by_lower = filter_sort_by.lower()
 
507
  filtered_list = {k: v for k, v in simplified_list.items() if sort_by_lower in v["SDR"]}
 
 
508
  def sort_key(item):
509
+ sdr = item[1]["SDR"].get(sort_by_lower)
510
  return (0 if sdr is None else 1, sdr if sdr is not None else float("-inf"))
 
511
  return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))
512
 
513
+ return simplified_list