Spaces:
Running
Running
Update audio_separator/separator/separator.py
Browse files- audio_separator/separator/separator.py +758 -277
audio_separator/separator/separator.py
CHANGED
|
@@ -1,32 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import logging
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import requests
|
| 4 |
import torch
|
| 5 |
import torch.amp.autocast_mode as autocast_mode
|
| 6 |
import onnxruntime as ort
|
| 7 |
-
import numpy as np
|
| 8 |
-
import soundfile as sf
|
| 9 |
-
import json
|
| 10 |
-
import yaml
|
| 11 |
-
import importlib
|
| 12 |
-
import hashlib
|
| 13 |
-
import time
|
| 14 |
-
from typing import Optional
|
| 15 |
-
from io import BytesIO
|
| 16 |
from tqdm import tqdm
|
| 17 |
import spaces
|
| 18 |
|
|
|
|
| 19 |
class Separator:
|
| 20 |
"""
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
"""
|
|
|
|
| 25 |
def __init__(
|
| 26 |
self,
|
| 27 |
log_level=logging.INFO,
|
|
|
|
| 28 |
model_file_dir="/tmp/audio-separator-models/",
|
| 29 |
-
output_dir=
|
| 30 |
output_format="WAV",
|
| 31 |
output_bitrate=None,
|
| 32 |
normalization_threshold=0.9,
|
|
@@ -34,8 +92,8 @@ class Separator:
|
|
| 34 |
output_single_stem=None,
|
| 35 |
invert_using_spec=False,
|
| 36 |
sample_rate=44100,
|
| 37 |
-
use_soundfile=
|
| 38 |
-
use_autocast=
|
| 39 |
use_directml=False,
|
| 40 |
mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
|
| 41 |
vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
|
|
@@ -43,256 +101,639 @@ class Separator:
|
|
| 43 |
mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
|
| 44 |
info_only=False,
|
| 45 |
):
|
| 46 |
-
"""Initialize the separator
|
| 47 |
self.logger = logging.getLogger(__name__)
|
| 48 |
self.logger.setLevel(log_level)
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
if not self.logger.hasHandlers():
|
| 52 |
-
self.logger.addHandler(
|
| 53 |
|
| 54 |
-
#
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
self.output_bitrate = output_bitrate
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
self.normalization_threshold = normalization_threshold
|
|
|
|
|
|
|
|
|
|
| 60 |
self.amplification_threshold = amplification_threshold
|
|
|
|
|
|
|
|
|
|
| 61 |
self.output_single_stem = output_single_stem
|
|
|
|
|
|
|
|
|
|
| 62 |
self.invert_using_spec = invert_using_spec
|
| 63 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
self.use_soundfile = use_soundfile
|
| 65 |
self.use_autocast = use_autocast
|
| 66 |
self.use_directml = use_directml
|
| 67 |
-
self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
|
| 68 |
|
| 69 |
-
#
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
if not (0 <= amplification_threshold <= 1):
|
| 73 |
-
raise ValueError("amplification_threshold must be in [0, 1]")
|
| 74 |
-
if self.sample_rate <= 0 or self.sample_rate > 12800000:
|
| 75 |
-
raise ValueError("sample_rate must be a positive integer <= 12800000")
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
|
| 81 |
-
|
| 82 |
-
self.torch_device_cpu = torch.device("cpu")
|
| 83 |
-
self.torch_device_mps = torch.device("mps") if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else None
|
| 84 |
-
self.torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 85 |
-
self.onnx_execution_provider = ["CUDAExecutionProvider"] if "CUDAExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
|
| 86 |
-
|
| 87 |
-
if self.use_directml:
|
| 88 |
-
try:
|
| 89 |
-
import torch_directml
|
| 90 |
-
if torch_directml.is_available():
|
| 91 |
-
self.torch_device = torch_directml.device()
|
| 92 |
-
self.onnx_execution_provider = ["DmlExecutionProvider"] if "DmlExecutionProvider" in ort.get_available_providers() else ["CPUExecutionProvider"]
|
| 93 |
-
except ImportError:
|
| 94 |
-
self.logger.warning("torch_directml not installed, falling back to CPU")
|
| 95 |
-
self.torch_device = self.torch_device_cpu
|
| 96 |
-
|
| 97 |
-
self.logger.info(f"Using device: {self.torch_device}, ONNX provider: {self.onnx_execution_provider}")
|
| 98 |
self.model_instance = None
|
|
|
|
| 99 |
self.model_is_uvr_vip = False
|
| 100 |
self.model_friendly_name = None
|
| 101 |
|
| 102 |
if not info_only:
|
| 103 |
-
self.
|
| 104 |
|
| 105 |
def setup_accelerated_inferencing_device(self):
|
| 106 |
-
"""
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
self.torch_device = self.torch_device_cpu
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
def get_model_hash(self, model_path):
|
| 117 |
-
"""
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
try:
|
|
|
|
|
|
|
| 120 |
with open(model_path, "rb") as f:
|
| 121 |
-
file_size = os.path.getsize(model_path)
|
| 122 |
if file_size < BYTES_TO_HASH:
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
except Exception as e:
|
|
|
|
| 127 |
self.logger.error(f"Error calculating hash for {model_path}: {e}")
|
| 128 |
-
raise
|
| 129 |
|
| 130 |
def download_file_if_not_exists(self, url, output_path):
|
| 131 |
-
"""
|
| 132 |
-
if
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
| 134 |
return
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
| 137 |
if response.status_code == 200:
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
| 139 |
for chunk in response.iter_content(chunk_size=8192):
|
|
|
|
| 140 |
f.write(chunk)
|
| 141 |
-
|
| 142 |
else:
|
| 143 |
-
raise RuntimeError(f"Failed to download {url}: {response.status_code}")
|
| 144 |
|
| 145 |
def list_supported_model_files(self):
|
| 146 |
-
"""
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
"filename": "
|
| 156 |
-
"scores": {
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
},
|
| 161 |
-
"
|
| 162 |
-
"
|
| 163 |
-
"
|
| 164 |
-
"
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
},
|
| 168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
}
|
|
|
|
|
|
|
| 170 |
|
| 171 |
-
|
| 172 |
-
audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
|
| 173 |
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
}
|
| 183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
"VR": {
|
| 185 |
-
|
| 186 |
-
"filename":
|
| 187 |
-
"scores": {},
|
| 188 |
-
"stems":
|
| 189 |
-
"target_stem": "
|
| 190 |
-
"download_files": [
|
| 191 |
-
}
|
|
|
|
| 192 |
},
|
| 193 |
-
"
|
| 194 |
-
|
| 195 |
-
"filename":
|
| 196 |
-
"scores":
|
| 197 |
-
"stems":
|
| 198 |
-
"target_stem": "
|
| 199 |
-
"download_files": [
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
},
|
| 205 |
-
"MDXC": roformer_models
|
| 206 |
}
|
|
|
|
| 207 |
return model_files_grouped_by_type
|
| 208 |
|
| 209 |
def print_uvr_vip_message(self):
|
| 210 |
-
"""
|
|
|
|
|
|
|
| 211 |
if self.model_is_uvr_vip:
|
| 212 |
-
self.logger.warning(f"
|
|
|
|
| 213 |
|
| 214 |
def download_model_files(self, model_filename):
|
| 215 |
-
"""
|
| 216 |
-
|
| 217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
|
| 219 |
vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
|
| 220 |
audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"
|
| 221 |
|
| 222 |
yaml_config_filename = None
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
for model_friendly_name, model_info in models.items():
|
| 225 |
self.model_is_uvr_vip = "VIP" in model_friendly_name
|
| 226 |
model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix
|
|
|
|
|
|
|
| 227 |
if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
|
|
|
|
| 228 |
self.model_friendly_name = model_friendly_name
|
| 229 |
self.print_uvr_vip_message()
|
|
|
|
|
|
|
| 230 |
for file_to_download in model_info["download_files"]:
|
|
|
|
| 231 |
if file_to_download.startswith("http"):
|
| 232 |
filename = file_to_download.split("/")[-1]
|
| 233 |
download_path = os.path.join(self.model_file_dir, filename)
|
| 234 |
self.download_file_if_not_exists(file_to_download, download_path)
|
| 235 |
-
if file_to_download.endswith(".yaml"):
|
| 236 |
-
yaml_config_filename = filename
|
| 237 |
continue
|
|
|
|
| 238 |
download_path = os.path.join(self.model_file_dir, file_to_download)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
try:
|
| 240 |
-
|
|
|
|
| 241 |
except RuntimeError:
|
| 242 |
-
self.
|
| 243 |
-
|
| 244 |
-
|
|
|
|
| 245 |
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
|
| 246 |
-
|
| 247 |
-
raise ValueError(f"Model file {model_filename} not found")
|
| 248 |
|
| 249 |
def load_model_data_from_yaml(self, yaml_config_filename):
|
| 250 |
-
"""
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
def load_model_data_using_hash(self, model_path):
|
| 264 |
-
"""
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
model_hash = self.get_model_hash(model_path)
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
self.logger.
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
"
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
model_name = model_filename.split(".")[0]
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
common_params = {
|
| 298 |
"logger": self.logger,
|
|
@@ -315,167 +756,207 @@ class Separator:
|
|
| 315 |
"use_soundfile": self.use_soundfile,
|
| 316 |
}
|
| 317 |
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
"
|
| 323 |
-
}
|
| 324 |
|
| 325 |
-
if model_type not in separator_classes:
|
| 326 |
-
raise ValueError(f"Unsupported model type: {model_type}")
|
| 327 |
if model_type == "Demucs" and sys.version_info < (3, 10):
|
| 328 |
-
raise Exception("Demucs
|
|
|
|
|
|
|
| 329 |
|
| 330 |
module_name, class_name = separator_classes[model_type].split(".")
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
self.logger.info(f
|
| 340 |
-
|
| 341 |
-
def preprocess_audio(self, audio_data, sample_rate):
|
| 342 |
-
"""Preprocess audio: resample, normalize, and convert to tensor."""
|
| 343 |
-
if sample_rate != self.sample_rate:
|
| 344 |
-
self.logger.debug(f"Resampling from {sample_rate} to {self.sample_rate} Hz")
|
| 345 |
-
audio_data = np.interp(
|
| 346 |
-
np.linspace(0, len(audio_data), int(len(audio_data) * self.sample_rate / sample_rate)),
|
| 347 |
-
np.arange(len(audio_data)),
|
| 348 |
-
audio_data
|
| 349 |
-
)
|
| 350 |
-
max_amplitude = np.max(np.abs(audio_data))
|
| 351 |
-
if max_amplitude > 0:
|
| 352 |
-
audio_data = audio_data * (self.normalization_threshold / max_amplitude)
|
| 353 |
-
if max_amplitude < self.amplification_threshold:
|
| 354 |
-
audio_data = audio_data * (self.amplification_threshold / max_amplitude)
|
| 355 |
-
return torch.tensor(audio_data, dtype=torch.float32, device=self.torch_device)
|
| 356 |
|
| 357 |
@spaces.GPU
|
| 358 |
def separate(self, audio_file_path, custom_output_names=None):
|
| 359 |
-
"""
|
| 360 |
-
|
| 361 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
|
| 363 |
-
|
| 364 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
|
|
|
|
| 366 |
if isinstance(audio_file_path, str):
|
| 367 |
audio_file_path = [audio_file_path]
|
| 368 |
|
|
|
|
| 369 |
output_files = []
|
|
|
|
|
|
|
| 370 |
for path in audio_file_path:
|
| 371 |
if os.path.isdir(path):
|
| 372 |
-
|
|
|
|
| 373 |
for file in files:
|
| 374 |
-
|
|
|
|
| 375 |
full_path = os.path.join(root, file)
|
| 376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
else:
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
-
self.print_uvr_vip_message()
|
| 381 |
-
self.logger.info(f"Separation completed in {time.perf_counter() - start_time:.2f} seconds")
|
| 382 |
return output_files
|
| 383 |
|
| 384 |
@spaces.GPU
|
| 385 |
def _separate_file(self, audio_file_path, custom_output_names=None):
|
| 386 |
-
"""
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
output_files =
|
| 413 |
-
for stem in stem_names:
|
| 414 |
-
output_path = os.path.join(self.output_dir, f"{os.path.splitext(os.path.basename(audio_file_path))[0]}_{stem}.{self.output_format.lower()}")
|
| 415 |
-
sf.write(output_path, audio_data, self.sample_rate)
|
| 416 |
-
output_files.append(output_path)
|
| 417 |
|
|
|
|
| 418 |
self.model_instance.clear_gpu_cache()
|
|
|
|
|
|
|
| 419 |
self.model_instance.clear_file_specific_paths()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
return output_files
|
| 421 |
|
| 422 |
def download_model_and_data(self, model_filename):
|
| 423 |
-
"""
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
|
| 433 |
def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
|
| 434 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
model_files = self.list_supported_model_files()
|
| 436 |
simplified_list = {}
|
| 437 |
|
| 438 |
for model_type, models in model_files.items():
|
| 439 |
for name, data in models.items():
|
| 440 |
filename = data["filename"]
|
| 441 |
-
scores = data.get("scores"
|
| 442 |
-
stems = data.get("stems"
|
| 443 |
target_stem = data.get("target_stem")
|
|
|
|
|
|
|
| 444 |
stems_with_scores = []
|
| 445 |
stem_sdr_dict = {}
|
|
|
|
|
|
|
| 446 |
for stem in stems:
|
|
|
|
|
|
|
| 447 |
stem_display = f"{stem}*" if stem == target_stem else stem
|
| 448 |
-
|
| 449 |
-
if
|
| 450 |
-
sdr = round(
|
| 451 |
stems_with_scores.append(f"{stem_display} ({sdr})")
|
| 452 |
stem_sdr_dict[stem.lower()] = sdr
|
| 453 |
else:
|
|
|
|
| 454 |
stems_with_scores.append(stem_display)
|
| 455 |
stem_sdr_dict[stem.lower()] = None
|
| 456 |
|
|
|
|
| 457 |
if not stems_with_scores:
|
| 458 |
stems_with_scores = ["Unknown"]
|
| 459 |
stem_sdr_dict["unknown"] = None
|
| 460 |
|
| 461 |
-
simplified_list[filename] = {
|
| 462 |
-
"Name": name,
|
| 463 |
-
"Type": model_type,
|
| 464 |
-
"Stems": stems_with_scores,
|
| 465 |
-
"SDR": stem_sdr_dict
|
| 466 |
-
}
|
| 467 |
|
|
|
|
| 468 |
if filter_sort_by:
|
| 469 |
if filter_sort_by == "name":
|
| 470 |
return dict(sorted(simplified_list.items(), key=lambda x: x[1]["Name"]))
|
| 471 |
elif filter_sort_by == "filename":
|
| 472 |
return dict(sorted(simplified_list.items()))
|
| 473 |
else:
|
|
|
|
| 474 |
sort_by_lower = filter_sort_by.lower()
|
|
|
|
| 475 |
filtered_list = {k: v for k, v in simplified_list.items() if sort_by_lower in v["SDR"]}
|
|
|
|
|
|
|
| 476 |
def sort_key(item):
|
| 477 |
-
sdr = item[1]["SDR"]
|
| 478 |
return (0 if sdr is None else 1, sdr if sdr is not None else float("-inf"))
|
|
|
|
| 479 |
return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))
|
| 480 |
|
| 481 |
return simplified_list
|
|
|
|
| 1 |
+
""" This file contains the Separator class, to facilitate the separation of stems from audio. """
|
| 2 |
+
|
| 3 |
+
from importlib import metadata, resources
|
| 4 |
import os
|
| 5 |
+
import sys
|
| 6 |
+
import platform
|
| 7 |
+
import subprocess
|
| 8 |
+
import time
|
| 9 |
import logging
|
| 10 |
+
import warnings
|
| 11 |
+
import importlib
|
| 12 |
+
import io
|
| 13 |
+
from typing import Optional
|
| 14 |
+
|
| 15 |
+
import hashlib
|
| 16 |
+
import json
|
| 17 |
+
import yaml
|
| 18 |
import requests
|
| 19 |
import torch
|
| 20 |
import torch.amp.autocast_mode as autocast_mode
|
| 21 |
import onnxruntime as ort
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
from tqdm import tqdm
|
| 23 |
import spaces
|
| 24 |
|
| 25 |
+
|
| 26 |
class Separator:
|
| 27 |
"""
|
| 28 |
+
The Separator class is designed to facilitate the separation of audio sources from a given audio file.
|
| 29 |
+
It supports various separation architectures and models, including MDX, VR, and Demucs. The class provides
|
| 30 |
+
functionalities to configure separation parameters, load models, and perform audio source separation.
|
| 31 |
+
It also handles logging, normalization, and output formatting of the separated audio stems.
|
| 32 |
+
|
| 33 |
+
The actual separation task is handled by one of the architecture-specific classes in the `architectures` module;
|
| 34 |
+
this class is responsible for initialising logging, configuring hardware acceleration, loading the model,
|
| 35 |
+
initiating the separation process and passing outputs back to the caller.
|
| 36 |
+
|
| 37 |
+
Common Attributes:
|
| 38 |
+
log_level (int): The logging level.
|
| 39 |
+
log_formatter (logging.Formatter): The logging formatter.
|
| 40 |
+
model_file_dir (str): The directory where model files are stored.
|
| 41 |
+
output_dir (str): The directory where output files will be saved.
|
| 42 |
+
output_format (str): The format of the output audio file.
|
| 43 |
+
output_bitrate (str): The bitrate of the output audio file.
|
| 44 |
+
amplification_threshold (float): The threshold for audio amplification.
|
| 45 |
+
normalization_threshold (float): The threshold for audio normalization.
|
| 46 |
+
output_single_stem (str): Option to output a single stem.
|
| 47 |
+
invert_using_spec (bool): Flag to invert using spectrogram.
|
| 48 |
+
sample_rate (int): The sample rate of the audio.
|
| 49 |
+
use_soundfile (bool): Use soundfile for audio writing, can solve OOM issues.
|
| 50 |
+
use_autocast (bool): Flag to use PyTorch autocast for faster inference.
|
| 51 |
+
|
| 52 |
+
MDX Architecture Specific Attributes:
|
| 53 |
+
hop_length (int): The hop length for STFT.
|
| 54 |
+
segment_size (int): The segment size for processing.
|
| 55 |
+
overlap (float): The overlap between segments.
|
| 56 |
+
batch_size (int): The batch size for processing.
|
| 57 |
+
enable_denoise (bool): Flag to enable or disable denoising.
|
| 58 |
+
|
| 59 |
+
VR Architecture Specific Attributes & Defaults:
|
| 60 |
+
batch_size: 16
|
| 61 |
+
window_size: 512
|
| 62 |
+
aggression: 5
|
| 63 |
+
enable_tta: False
|
| 64 |
+
enable_post_process: False
|
| 65 |
+
post_process_threshold: 0.2
|
| 66 |
+
high_end_process: False
|
| 67 |
+
|
| 68 |
+
Demucs Architecture Specific Attributes & Defaults:
|
| 69 |
+
segment_size: "Default"
|
| 70 |
+
shifts: 2
|
| 71 |
+
overlap: 0.25
|
| 72 |
+
segments_enabled: True
|
| 73 |
+
|
| 74 |
+
MDXC Architecture Specific Attributes & Defaults:
|
| 75 |
+
segment_size: 256
|
| 76 |
+
override_model_segment_size: False
|
| 77 |
+
batch_size: 1
|
| 78 |
+
overlap: 8
|
| 79 |
+
pitch_shift: 0
|
| 80 |
"""
|
| 81 |
+
|
| 82 |
def __init__(
|
| 83 |
self,
|
| 84 |
log_level=logging.INFO,
|
| 85 |
+
log_formatter=None,
|
| 86 |
model_file_dir="/tmp/audio-separator-models/",
|
| 87 |
+
output_dir=None,
|
| 88 |
output_format="WAV",
|
| 89 |
output_bitrate=None,
|
| 90 |
normalization_threshold=0.9,
|
|
|
|
| 92 |
output_single_stem=None,
|
| 93 |
invert_using_spec=False,
|
| 94 |
sample_rate=44100,
|
| 95 |
+
use_soundfile=False,
|
| 96 |
+
use_autocast=False,
|
| 97 |
use_directml=False,
|
| 98 |
mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
|
| 99 |
vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
|
|
|
|
| 101 |
mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0},
|
| 102 |
info_only=False,
|
| 103 |
):
|
| 104 |
+
"""Initialize the separator."""
|
| 105 |
self.logger = logging.getLogger(__name__)
|
| 106 |
self.logger.setLevel(log_level)
|
| 107 |
+
self.log_level = log_level
|
| 108 |
+
self.log_formatter = log_formatter
|
| 109 |
+
|
| 110 |
+
self.log_handler = logging.StreamHandler()
|
| 111 |
+
|
| 112 |
+
if self.log_formatter is None:
|
| 113 |
+
self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
|
| 114 |
+
|
| 115 |
+
self.log_handler.setFormatter(self.log_formatter)
|
| 116 |
+
|
| 117 |
if not self.logger.hasHandlers():
|
| 118 |
+
self.logger.addHandler(self.log_handler)
|
| 119 |
|
| 120 |
+
# Filter out noisy warnings from PyTorch for users who don't care about them
|
| 121 |
+
if log_level > logging.DEBUG:
|
| 122 |
+
warnings.filterwarnings("ignore")
|
| 123 |
+
|
| 124 |
+
# Skip initialization logs if info_only is True
|
| 125 |
+
if not info_only:
|
| 126 |
+
package_version = self.get_package_distribution("audio-separator").version
|
| 127 |
+
self.logger.info(f"Separator version {package_version} instantiating with output_dir: {output_dir}, output_format: {output_format}")
|
| 128 |
+
|
| 129 |
+
if output_dir is None:
|
| 130 |
+
output_dir = os.getcwd()
|
| 131 |
+
if not info_only:
|
| 132 |
+
self.logger.info("Output directory not specified. Using current working directory.")
|
| 133 |
+
|
| 134 |
+
self.output_dir = output_dir
|
| 135 |
+
|
| 136 |
+
# Check for environment variable to override model_file_dir
|
| 137 |
+
env_model_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR")
|
| 138 |
+
if env_model_dir:
|
| 139 |
+
self.model_file_dir = env_model_dir
|
| 140 |
+
self.logger.info(f"Using model directory from AUDIO_SEPARATOR_MODEL_DIR env var: {self.model_file_dir}")
|
| 141 |
+
if not os.path.exists(self.model_file_dir):
|
| 142 |
+
raise FileNotFoundError(f"The specified model directory does not exist: {self.model_file_dir}")
|
| 143 |
+
else:
|
| 144 |
+
self.logger.info(f"Using model directory from model_file_dir parameter: {model_file_dir}")
|
| 145 |
+
self.model_file_dir = model_file_dir
|
| 146 |
+
|
| 147 |
+
# Create the model directory if it does not exist
|
| 148 |
+
os.makedirs(self.model_file_dir, exist_ok=True)
|
| 149 |
+
os.makedirs(self.output_dir, exist_ok=True)
|
| 150 |
+
|
| 151 |
+
self.output_format = output_format
|
| 152 |
self.output_bitrate = output_bitrate
|
| 153 |
+
|
| 154 |
+
if self.output_format is None:
|
| 155 |
+
self.output_format = "WAV"
|
| 156 |
+
|
| 157 |
self.normalization_threshold = normalization_threshold
|
| 158 |
+
if normalization_threshold <= 0 or normalization_threshold > 1:
|
| 159 |
+
raise ValueError("The normalization_threshold must be greater than 0 and less than or equal to 1.")
|
| 160 |
+
|
| 161 |
self.amplification_threshold = amplification_threshold
|
| 162 |
+
if amplification_threshold < 0 or amplification_threshold > 1:
|
| 163 |
+
raise ValueError("The amplification_threshold must be greater than or equal to 0 and less than or equal to 1.")
|
| 164 |
+
|
| 165 |
self.output_single_stem = output_single_stem
|
| 166 |
+
if output_single_stem is not None:
|
| 167 |
+
self.logger.debug(f"Single stem output requested, so only one output file ({output_single_stem}) will be written")
|
| 168 |
+
|
| 169 |
self.invert_using_spec = invert_using_spec
|
| 170 |
+
if self.invert_using_spec:
|
| 171 |
+
self.logger.debug(f"Secondary step will be inverted using spectogram rather than waveform. This may improve quality but is slightly slower.")
|
| 172 |
+
|
| 173 |
+
try:
|
| 174 |
+
self.sample_rate = int(sample_rate)
|
| 175 |
+
if self.sample_rate <= 0:
|
| 176 |
+
raise ValueError(f"The sample rate setting is {self.sample_rate} but it must be a non-zero whole number.")
|
| 177 |
+
if self.sample_rate > 12800000:
|
| 178 |
+
raise ValueError(f"The sample rate setting is {self.sample_rate}. Enter something less ambitious.")
|
| 179 |
+
except ValueError:
|
| 180 |
+
raise ValueError("The sample rate must be a non-zero whole number. Please provide a valid integer.")
|
| 181 |
+
|
| 182 |
self.use_soundfile = use_soundfile
|
| 183 |
self.use_autocast = use_autocast
|
| 184 |
self.use_directml = use_directml
|
|
|
|
| 185 |
|
| 186 |
+
# These are parameters which users may want to configure so we expose them to the top-level Separator class,
|
| 187 |
+
# even though they are specific to a single model architecture
|
| 188 |
+
self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
+
self.torch_device = None
|
| 191 |
+
self.torch_device_cpu = None
|
| 192 |
+
self.torch_device_mps = None
|
| 193 |
|
| 194 |
+
self.onnx_execution_provider = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
self.model_instance = None
|
| 196 |
+
|
| 197 |
self.model_is_uvr_vip = False
|
| 198 |
self.model_friendly_name = None
|
| 199 |
|
| 200 |
if not info_only:
|
| 201 |
+
self.setup_accelerated_inferencing_device()
|
| 202 |
|
| 203 |
def setup_accelerated_inferencing_device(self):
|
| 204 |
+
"""
|
| 205 |
+
This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
|
| 206 |
+
"""
|
| 207 |
+
system_info = self.get_system_info()
|
| 208 |
+
self.check_ffmpeg_installed()
|
| 209 |
+
self.log_onnxruntime_packages()
|
| 210 |
+
self.setup_torch_device(system_info)
|
| 211 |
+
|
| 212 |
+
def get_system_info(self):
|
| 213 |
+
"""
|
| 214 |
+
This method logs the system information, including the operating system, CPU archutecture and Python version
|
| 215 |
+
"""
|
| 216 |
+
os_name = platform.system()
|
| 217 |
+
os_version = platform.version()
|
| 218 |
+
self.logger.info(f"Operating System: {os_name} {os_version}")
|
| 219 |
+
|
| 220 |
+
system_info = platform.uname()
|
| 221 |
+
self.logger.info(f"System: {system_info.system} Node: {system_info.node} Release: {system_info.release} Machine: {system_info.machine} Proc: {system_info.processor}")
|
| 222 |
+
|
| 223 |
+
python_version = platform.python_version()
|
| 224 |
+
self.logger.info(f"Python Version: {python_version}")
|
| 225 |
+
|
| 226 |
+
pytorch_version = torch.__version__
|
| 227 |
+
self.logger.info(f"PyTorch Version: {pytorch_version}")
|
| 228 |
+
return system_info
|
| 229 |
+
|
| 230 |
+
def check_ffmpeg_installed(self):
|
| 231 |
+
"""
|
| 232 |
+
This method checks if ffmpeg is installed and logs its version.
|
| 233 |
+
"""
|
| 234 |
+
try:
|
| 235 |
+
ffmpeg_version_output = subprocess.check_output(["ffmpeg", "-version"], text=True)
|
| 236 |
+
first_line = ffmpeg_version_output.splitlines()[0]
|
| 237 |
+
self.logger.info(f"FFmpeg installed: {first_line}")
|
| 238 |
+
except FileNotFoundError:
|
| 239 |
+
self.logger.error("FFmpeg is not installed. Please install FFmpeg to use this package.")
|
| 240 |
+
# Raise an exception if this is being run by a user, as ffmpeg is required for pydub to write audio
|
| 241 |
+
# but if we're just running unit tests in CI, no reason to throw
|
| 242 |
+
if "PYTEST_CURRENT_TEST" not in os.environ:
|
| 243 |
+
raise
|
| 244 |
+
|
| 245 |
+
def log_onnxruntime_packages(self):
|
| 246 |
+
"""
|
| 247 |
+
This method logs the ONNX Runtime package versions, including the GPU and Silicon packages if available.
|
| 248 |
+
"""
|
| 249 |
+
onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
|
| 250 |
+
onnxruntime_silicon_package = self.get_package_distribution("onnxruntime-silicon")
|
| 251 |
+
onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
|
| 252 |
+
onnxruntime_dml_package = self.get_package_distribution("onnxruntime-directml")
|
| 253 |
+
|
| 254 |
+
if onnxruntime_gpu_package is not None:
|
| 255 |
+
self.logger.info(f"ONNX Runtime GPU package installed with version: {onnxruntime_gpu_package.version}")
|
| 256 |
+
if onnxruntime_silicon_package is not None:
|
| 257 |
+
self.logger.info(f"ONNX Runtime Silicon package installed with version: {onnxruntime_silicon_package.version}")
|
| 258 |
+
if onnxruntime_cpu_package is not None:
|
| 259 |
+
self.logger.info(f"ONNX Runtime CPU package installed with version: {onnxruntime_cpu_package.version}")
|
| 260 |
+
if onnxruntime_dml_package is not None:
|
| 261 |
+
self.logger.info(f"ONNX Runtime DirectML package installed with version: {onnxruntime_dml_package.version}")
|
| 262 |
+
|
| 263 |
+
def setup_torch_device(self, system_info):
|
| 264 |
+
"""
|
| 265 |
+
This method sets up the PyTorch and/or ONNX Runtime inferencing device, using GPU hardware acceleration if available.
|
| 266 |
+
"""
|
| 267 |
+
hardware_acceleration_enabled = False
|
| 268 |
+
ort_providers = ort.get_available_providers()
|
| 269 |
+
has_torch_dml_installed = self.get_package_distribution("torch_directml")
|
| 270 |
+
|
| 271 |
+
self.torch_device_cpu = torch.device("cpu")
|
| 272 |
+
|
| 273 |
+
if torch.cuda.is_available():
|
| 274 |
+
self.configure_cuda(ort_providers)
|
| 275 |
+
hardware_acceleration_enabled = True
|
| 276 |
+
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
|
| 277 |
+
self.configure_mps(ort_providers)
|
| 278 |
+
hardware_acceleration_enabled = True
|
| 279 |
+
elif self.use_directml and has_torch_dml_installed:
|
| 280 |
+
import torch_directml
|
| 281 |
+
if torch_directml.is_available():
|
| 282 |
+
self.configure_dml(ort_providers)
|
| 283 |
+
hardware_acceleration_enabled = True
|
| 284 |
+
|
| 285 |
+
if not hardware_acceleration_enabled:
|
| 286 |
+
self.logger.info("No hardware acceleration could be configured, running in CPU mode")
|
| 287 |
self.torch_device = self.torch_device_cpu
|
| 288 |
+
self.onnx_execution_provider = ["CPUExecutionProvider"]
|
| 289 |
+
|
| 290 |
+
def configure_cuda(self, ort_providers):
|
| 291 |
+
"""
|
| 292 |
+
This method configures the CUDA device for PyTorch and ONNX Runtime, if available.
|
| 293 |
+
"""
|
| 294 |
+
self.logger.info("CUDA is available in Torch, setting Torch device to CUDA")
|
| 295 |
+
self.torch_device = torch.device("cuda")
|
| 296 |
+
if "CUDAExecutionProvider" in ort_providers:
|
| 297 |
+
self.logger.info("ONNXruntime has CUDAExecutionProvider available, enabling acceleration")
|
| 298 |
+
self.onnx_execution_provider = ["CUDAExecutionProvider"]
|
| 299 |
+
else:
|
| 300 |
+
self.logger.warning("CUDAExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
|
| 301 |
+
|
| 302 |
+
def configure_mps(self, ort_providers):
|
| 303 |
+
"""
|
| 304 |
+
This method configures the Apple Silicon MPS/CoreML device for PyTorch and ONNX Runtime, if available.
|
| 305 |
+
"""
|
| 306 |
+
self.logger.info("Apple Silicon MPS/CoreML is available in Torch and processor is ARM, setting Torch device to MPS")
|
| 307 |
+
self.torch_device_mps = torch.device("mps")
|
| 308 |
+
|
| 309 |
+
self.torch_device = self.torch_device_mps
|
| 310 |
+
|
| 311 |
+
if "CoreMLExecutionProvider" in ort_providers:
|
| 312 |
+
self.logger.info("ONNXruntime has CoreMLExecutionProvider available, enabling acceleration")
|
| 313 |
+
self.onnx_execution_provider = ["CoreMLExecutionProvider"]
|
| 314 |
+
else:
|
| 315 |
+
self.logger.warning("CoreMLExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
|
| 316 |
+
|
| 317 |
+
def configure_dml(self, ort_providers):
|
| 318 |
+
"""
|
| 319 |
+
This method configures the DirectML device for PyTorch and ONNX Runtime, if available.
|
| 320 |
+
"""
|
| 321 |
+
import torch_directml
|
| 322 |
+
self.logger.info("DirectML is available in Torch, setting Torch device to DirectML")
|
| 323 |
+
self.torch_device_dml = torch_directml.device()
|
| 324 |
+
self.torch_device = self.torch_device_dml
|
| 325 |
+
|
| 326 |
+
if "DmlExecutionProvider" in ort_providers:
|
| 327 |
+
self.logger.info("ONNXruntime has DmlExecutionProvider available, enabling acceleration")
|
| 328 |
+
self.onnx_execution_provider = ["DmlExecutionProvider"]
|
| 329 |
+
else:
|
| 330 |
+
self.logger.warning("DmlExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
|
| 331 |
+
|
| 332 |
+
def get_package_distribution(self, package_name):
|
| 333 |
+
"""
|
| 334 |
+
This method returns the package distribution for a given package name if installed, or None otherwise.
|
| 335 |
+
"""
|
| 336 |
+
try:
|
| 337 |
+
return metadata.distribution(package_name)
|
| 338 |
+
except metadata.PackageNotFoundError:
|
| 339 |
+
self.logger.debug(f"Python package: {package_name} not installed")
|
| 340 |
+
return None
|
| 341 |
|
| 342 |
def get_model_hash(self, model_path):
|
| 343 |
+
"""
|
| 344 |
+
This method returns the MD5 hash of a given model file.
|
| 345 |
+
"""
|
| 346 |
+
self.logger.debug(f"Calculating hash of model file {model_path}")
|
| 347 |
+
# Use the specific byte count from the original logic
|
| 348 |
+
BYTES_TO_HASH = 10000 * 1024 # 10,240,000 bytes
|
| 349 |
+
|
| 350 |
try:
|
| 351 |
+
file_size = os.path.getsize(model_path)
|
| 352 |
+
|
| 353 |
with open(model_path, "rb") as f:
|
|
|
|
| 354 |
if file_size < BYTES_TO_HASH:
|
| 355 |
+
# Hash the entire file if smaller than the target byte count
|
| 356 |
+
self.logger.debug(f"File size {file_size} < {BYTES_TO_HASH}, hashing entire file.")
|
| 357 |
+
hash_value = hashlib.md5(f.read()).hexdigest()
|
| 358 |
+
else:
|
| 359 |
+
# Seek to the specific position before the end (from the beginning) and hash
|
| 360 |
+
seek_pos = file_size - BYTES_TO_HASH
|
| 361 |
+
self.logger.debug(f"File size {file_size} >= {BYTES_TO_HASH}, seeking to {seek_pos} and hashing remaining bytes.")
|
| 362 |
+
f.seek(seek_pos, io.SEEK_SET)
|
| 363 |
+
hash_value = hashlib.md5(f.read()).hexdigest()
|
| 364 |
+
|
| 365 |
+
# Log the calculated hash
|
| 366 |
+
self.logger.info(f"Hash of model file {model_path} is {hash_value}")
|
| 367 |
+
return hash_value
|
| 368 |
+
|
| 369 |
+
except FileNotFoundError:
|
| 370 |
+
self.logger.error(f"Model file not found at {model_path}")
|
| 371 |
+
raise # Re-raise the specific error
|
| 372 |
except Exception as e:
|
| 373 |
+
# Catch other potential errors (e.g., permissions, other IOErrors)
|
| 374 |
self.logger.error(f"Error calculating hash for {model_path}: {e}")
|
| 375 |
+
raise # Re-raise other errors
|
| 376 |
|
| 377 |
def download_file_if_not_exists(self, url, output_path):
|
| 378 |
+
"""
|
| 379 |
+
This method downloads a file from a given URL to a given output path, if the file does not already exist.
|
| 380 |
+
"""
|
| 381 |
+
|
| 382 |
+
if os.path.isfile(output_path):
|
| 383 |
+
self.logger.debug(f"File already exists at {output_path}, skipping download")
|
| 384 |
return
|
| 385 |
+
|
| 386 |
+
self.logger.debug(f"Downloading file from {url} to {output_path} with timeout 300s")
|
| 387 |
+
response = requests.get(url, stream=True, timeout=300)
|
| 388 |
+
|
| 389 |
if response.status_code == 200:
|
| 390 |
+
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
| 391 |
+
progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
| 392 |
+
|
| 393 |
+
with open(output_path, "wb") as f:
|
| 394 |
for chunk in response.iter_content(chunk_size=8192):
|
| 395 |
+
progress_bar.update(len(chunk))
|
| 396 |
f.write(chunk)
|
| 397 |
+
progress_bar.close()
|
| 398 |
else:
|
| 399 |
+
raise RuntimeError(f"Failed to download file from {url}, response code: {response.status_code}")
|
| 400 |
|
| 401 |
def list_supported_model_files(self):
|
| 402 |
+
"""
|
| 403 |
+
This method lists the supported model files for audio-separator, by fetching the same file UVR uses to list these.
|
| 404 |
+
Also includes model performance scores where available.
|
| 405 |
+
|
| 406 |
+
Example response object:
|
| 407 |
|
| 408 |
+
{
|
| 409 |
+
"MDX": {
|
| 410 |
+
"MDX-Net Model VIP: UVR-MDX-NET-Inst_full_292": {
|
| 411 |
+
"filename": "UVR-MDX-NET-Inst_full_292.onnx",
|
| 412 |
+
"scores": {
|
| 413 |
+
"vocals": {
|
| 414 |
+
"SDR": 10.6497,
|
| 415 |
+
"SIR": 20.3786,
|
| 416 |
+
"SAR": 10.692,
|
| 417 |
+
"ISR": 14.848
|
| 418 |
+
},
|
| 419 |
+
"instrumental": {
|
| 420 |
+
"SDR": 15.2149,
|
| 421 |
+
"SIR": 25.6075,
|
| 422 |
+
"SAR": 17.1363,
|
| 423 |
+
"ISR": 17.7893
|
| 424 |
+
}
|
| 425 |
+
},
|
| 426 |
+
"download_files": [
|
| 427 |
+
"UVR-MDX-NET-Inst_full_292.onnx"
|
| 428 |
+
]
|
| 429 |
+
}
|
| 430 |
},
|
| 431 |
+
"Demucs": {
|
| 432 |
+
"Demucs v4: htdemucs_ft": {
|
| 433 |
+
"filename": "htdemucs_ft.yaml",
|
| 434 |
+
"scores": {
|
| 435 |
+
"vocals": {
|
| 436 |
+
"SDR": 11.2685,
|
| 437 |
+
"SIR": 21.257,
|
| 438 |
+
"SAR": 11.0359,
|
| 439 |
+
"ISR": 19.3753
|
| 440 |
+
},
|
| 441 |
+
"drums": {
|
| 442 |
+
"SDR": 13.235,
|
| 443 |
+
"SIR": 23.3053,
|
| 444 |
+
"SAR": 13.0313,
|
| 445 |
+
"ISR": 17.2889
|
| 446 |
+
},
|
| 447 |
+
"bass": {
|
| 448 |
+
"SDR": 9.72743,
|
| 449 |
+
"SIR": 19.5435,
|
| 450 |
+
"SAR": 9.20801,
|
| 451 |
+
"ISR": 13.5037
|
| 452 |
+
}
|
| 453 |
+
},
|
| 454 |
+
"download_files": [
|
| 455 |
+
"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/f7e0c4bc-ba3fe64a.th",
|
| 456 |
+
"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/d12395a8-e57c48e6.th",
|
| 457 |
+
"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/92cfc3b6-ef3bcb9c.th",
|
| 458 |
+
"https://dl.fbaipublicfiles.com/demucs/hybrid_transformer/04573f0d-f3cf25b2.th",
|
| 459 |
+
"https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models/htdemucs_ft.yaml"
|
| 460 |
+
]
|
| 461 |
+
}
|
| 462 |
},
|
| 463 |
+
"MDXC": {
|
| 464 |
+
"MDX23C Model: MDX23C-InstVoc HQ": {
|
| 465 |
+
"filename": "MDX23C-8KFFT-InstVoc_HQ.ckpt",
|
| 466 |
+
"scores": {
|
| 467 |
+
"vocals": {
|
| 468 |
+
"SDR": 11.9504,
|
| 469 |
+
"SIR": 23.1166,
|
| 470 |
+
"SAR": 12.093,
|
| 471 |
+
"ISR": 15.4782
|
| 472 |
+
},
|
| 473 |
+
"instrumental": {
|
| 474 |
+
"SDR": 16.3035,
|
| 475 |
+
"SIR": 26.6161,
|
| 476 |
+
"SAR": 18.5167,
|
| 477 |
+
"ISR": 18.3939
|
| 478 |
+
}
|
| 479 |
+
},
|
| 480 |
+
"download_files": [
|
| 481 |
+
"MDX23C-8KFFT-InstVoc_HQ.ckpt",
|
| 482 |
+
"model_2_stem_full_band_8k.yaml"
|
| 483 |
+
]
|
| 484 |
+
}
|
| 485 |
+
}
|
| 486 |
}
|
| 487 |
+
"""
|
| 488 |
+
download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
|
| 489 |
|
| 490 |
+
self.download_file_if_not_exists("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json", download_checks_path)
|
|
|
|
| 491 |
|
| 492 |
+
model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
|
| 493 |
+
self.logger.debug(f"UVR model download list loaded")
|
| 494 |
+
|
| 495 |
+
# Load the model scores with error handling
|
| 496 |
+
model_scores = {}
|
| 497 |
+
try:
|
| 498 |
+
with resources.open_text("audio_separator", "models-scores.json") as f:
|
| 499 |
+
model_scores = json.load(f)
|
| 500 |
+
self.logger.debug(f"Model scores loaded")
|
| 501 |
+
except json.JSONDecodeError as e:
|
| 502 |
+
self.logger.warning(f"Failed to load model scores: {str(e)}")
|
| 503 |
+
self.logger.warning("Continuing without model scores")
|
| 504 |
+
|
| 505 |
+
# Only show Demucs v4 models as we've only implemented support for v4
|
| 506 |
+
filtered_demucs_v4 = {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}
|
| 507 |
+
|
| 508 |
+
# Modified Demucs handling to use YAML files as identifiers and include download files
|
| 509 |
+
demucs_models = {}
|
| 510 |
+
for name, files in filtered_demucs_v4.items():
|
| 511 |
+
# Find the YAML file in the model files
|
| 512 |
+
yaml_file = next((filename for filename in files.keys() if filename.endswith(".yaml")), None)
|
| 513 |
+
if yaml_file:
|
| 514 |
+
model_score_data = model_scores.get(yaml_file, {})
|
| 515 |
+
demucs_models[name] = {
|
| 516 |
+
"filename": yaml_file,
|
| 517 |
+
"scores": model_score_data.get("median_scores", {}),
|
| 518 |
+
"stems": model_score_data.get("stems", []),
|
| 519 |
+
"target_stem": model_score_data.get("target_stem"),
|
| 520 |
+
"download_files": list(files.values()), # List of all download URLs/filenames
|
| 521 |
}
|
| 522 |
+
|
| 523 |
+
# Load the JSON file using importlib.resources
|
| 524 |
+
with resources.open_text("audio_separator", "models.json") as f:
|
| 525 |
+
audio_separator_models_list = json.load(f)
|
| 526 |
+
self.logger.debug(f"Audio-Separator model list loaded")
|
| 527 |
+
|
| 528 |
+
# Return object with list of model names
|
| 529 |
+
model_files_grouped_by_type = {
|
| 530 |
"VR": {
|
| 531 |
+
name: {
|
| 532 |
+
"filename": filename,
|
| 533 |
+
"scores": model_scores.get(filename, {}).get("median_scores", {}),
|
| 534 |
+
"stems": model_scores.get(filename, {}).get("stems", []),
|
| 535 |
+
"target_stem": model_scores.get(filename, {}).get("target_stem"),
|
| 536 |
+
"download_files": [filename],
|
| 537 |
+
} # Just the filename for VR models
|
| 538 |
+
for name, filename in {**model_downloads_list["vr_download_list"], **audio_separator_models_list["vr_download_list"]}.items()
|
| 539 |
},
|
| 540 |
+
"MDX": {
|
| 541 |
+
name: {
|
| 542 |
+
"filename": filename,
|
| 543 |
+
"scores": model_scores.get(filename, {}).get("median_scores", {}),
|
| 544 |
+
"stems": model_scores.get(filename, {}).get("stems", []),
|
| 545 |
+
"target_stem": model_scores.get(filename, {}).get("target_stem"),
|
| 546 |
+
"download_files": [filename],
|
| 547 |
+
} # Just the filename for MDX models
|
| 548 |
+
for name, filename in {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"], **audio_separator_models_list["mdx_download_list"]}.items()
|
| 549 |
+
},
|
| 550 |
+
"Demucs": demucs_models,
|
| 551 |
+
"MDXC": {
|
| 552 |
+
name: {
|
| 553 |
+
"filename": next(iter(files.keys())),
|
| 554 |
+
"scores": model_scores.get(next(iter(files.keys())), {}).get("median_scores", {}),
|
| 555 |
+
"stems": model_scores.get(next(iter(files.keys())), {}).get("stems", []),
|
| 556 |
+
"target_stem": model_scores.get(next(iter(files.keys())), {}).get("target_stem"),
|
| 557 |
+
"download_files": list(files.keys()) + list(files.values()), # List of both model filenames and config filenames
|
| 558 |
}
|
| 559 |
+
for name, files in {
|
| 560 |
+
**model_downloads_list["mdx23c_download_list"],
|
| 561 |
+
**model_downloads_list["mdx23c_download_vip_list"],
|
| 562 |
+
**model_downloads_list["roformer_download_list"],
|
| 563 |
+
**audio_separator_models_list["mdx23c_download_list"],
|
| 564 |
+
**audio_separator_models_list["roformer_download_list"],
|
| 565 |
+
}.items()
|
| 566 |
},
|
|
|
|
| 567 |
}
|
| 568 |
+
|
| 569 |
return model_files_grouped_by_type
|
| 570 |
|
| 571 |
def print_uvr_vip_message(self):
|
| 572 |
+
"""
|
| 573 |
+
This method prints a message to the user if they have downloaded a VIP model, reminding them to support Anjok07 on Patreon.
|
| 574 |
+
"""
|
| 575 |
if self.model_is_uvr_vip:
|
| 576 |
+
self.logger.warning(f"The model: '{self.model_friendly_name}' is a VIP model, intended by Anjok07 for access by paying subscribers only.")
|
| 577 |
+
self.logger.warning("If you are not already subscribed, please consider supporting the developer of UVR, Anjok07 by subscribing here: https://patreon.com/uvr")
|
| 578 |
|
| 579 |
    def download_model_files(self, model_filename):
        """
        This method downloads the model files for a given model filename, if they are not already present.
        Returns tuple of (model_filename, model_type, model_friendly_name, model_path, yaml_config_filename)

        The search walks every model type and model entry from list_supported_model_files();
        a model "matches" if its primary filename equals model_filename, or model_filename
        appears among its download_files. All files the matching model needs are downloaded,
        trying the UVR repo first and falling back to the audio-separator repo on failure.

        Raises:
            ValueError: if model_filename is not found in any supported model list.
        """
        model_path = os.path.join(self.model_file_dir, f"{model_filename}")

        supported_model_files_grouped = self.list_supported_model_files()
        public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
        vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
        audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"

        yaml_config_filename = None

        self.logger.debug(f"Searching for model_filename {model_filename} in supported_model_files_grouped")

        # Iterate through model types (MDX, Demucs, MDXC)
        for model_type, models in supported_model_files_grouped.items():
            # Iterate through each model in this type
            for model_friendly_name, model_info in models.items():
                # NOTE: VIP status is (re)assigned on every candidate, so after the loop
                # it reflects the matched model (or the last candidate examined on failure).
                self.model_is_uvr_vip = "VIP" in model_friendly_name
                # VIP models live in a separate repository from the public ones.
                model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix

                # Check if this model matches our target filename
                if model_info["filename"] == model_filename or model_filename in model_info["download_files"]:
                    self.logger.debug(f"Found matching model: {model_friendly_name}")
                    self.model_friendly_name = model_friendly_name
                    self.print_uvr_vip_message()

                    # Download each required file for this model
                    for file_to_download in model_info["download_files"]:
                        # For URLs (e.g. Demucs weights hosted on fbaipublicfiles),
                        # extract just the filename portion for the local path.
                        if file_to_download.startswith("http"):
                            filename = file_to_download.split("/")[-1]
                            download_path = os.path.join(self.model_file_dir, filename)
                            self.download_file_if_not_exists(file_to_download, download_path)
                            continue

                        download_path = os.path.join(self.model_file_dir, file_to_download)

                        # For MDXC models, handle YAML config files specially:
                        # they live under a dedicated configs path in the UVR repo,
                        # and the filename is returned so the caller can load it.
                        if model_type == "MDXC" and file_to_download.endswith(".yaml"):
                            yaml_config_filename = file_to_download
                            try:
                                yaml_url = f"{model_repo_url_prefix}/mdx_model_data/mdx_c_configs/{file_to_download}"
                                self.download_file_if_not_exists(yaml_url, download_path)
                            except RuntimeError:
                                # Fall back to the audio-separator repo when UVR's repo lacks the config.
                                self.logger.debug("YAML config not found in UVR repo, trying audio-separator models repo...")
                                yaml_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
                                self.download_file_if_not_exists(yaml_url, download_path)
                            continue

                        # For regular model files, try UVR repo first, then audio-separator repo
                        try:
                            download_url = f"{model_repo_url_prefix}/{file_to_download}"
                            self.download_file_if_not_exists(download_url, download_path)
                        except RuntimeError:
                            self.logger.debug("Model not found in UVR repo, trying audio-separator models repo...")
                            download_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
                            self.download_file_if_not_exists(download_url, download_path)

                    # Return on the first match; model lists are assumed not to
                    # contain duplicate filenames across types.
                    return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename

        raise ValueError(f"Model file {model_filename} not found in supported model files")
def load_model_data_from_yaml(self, yaml_config_filename):
|
| 645 |
+
"""
|
| 646 |
+
This method loads model-specific parameters from the YAML file for that model.
|
| 647 |
+
The parameters in the YAML are critical to inferencing, as they need to match whatever was used during training.
|
| 648 |
+
"""
|
| 649 |
+
# Verify if the YAML filename includes a full path or just the filename
|
| 650 |
+
if not os.path.exists(yaml_config_filename):
|
| 651 |
+
model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
|
| 652 |
+
else:
|
| 653 |
+
model_data_yaml_filepath = yaml_config_filename
|
| 654 |
+
|
| 655 |
+
self.logger.debug(f"Loading model data from YAML at path {model_data_yaml_filepath}")
|
| 656 |
+
|
| 657 |
+
model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
|
| 658 |
+
self.logger.debug(f"Model data loaded from YAML file: {model_data}")
|
| 659 |
+
|
| 660 |
+
if "roformer" in model_data_yaml_filepath:
|
| 661 |
+
model_data["is_roformer"] = True
|
| 662 |
+
|
| 663 |
+
return model_data
|
| 664 |
|
| 665 |
def load_model_data_using_hash(self, model_path):
|
| 666 |
+
"""
|
| 667 |
+
This method loads model-specific parameters from UVR model data files.
|
| 668 |
+
These parameters are critical to inferencing using a given model, as they need to match whatever was used during training.
|
| 669 |
+
The correct parameters are identified by calculating the hash of the model file and looking up the hash in the UVR data files.
|
| 670 |
+
"""
|
| 671 |
+
# Model data and configuration sources from UVR
|
| 672 |
+
model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
|
| 673 |
+
|
| 674 |
+
vr_model_data_url = f"{model_data_url_prefix}/vr_model_data/model_data_new.json"
|
| 675 |
+
mdx_model_data_url = f"{model_data_url_prefix}/mdx_model_data/model_data_new.json"
|
| 676 |
+
|
| 677 |
+
# Calculate hash for the downloaded model
|
| 678 |
+
self.logger.debug("Calculating MD5 hash for model file to identify model parameters from UVR data...")
|
| 679 |
model_hash = self.get_model_hash(model_path)
|
| 680 |
+
self.logger.debug(f"Model {model_path} has hash {model_hash}")
|
| 681 |
+
|
| 682 |
+
# Setting up the path for model data and checking its existence
|
| 683 |
+
vr_model_data_path = os.path.join(self.model_file_dir, "vr_model_data.json")
|
| 684 |
+
self.logger.debug(f"VR model data path set to {vr_model_data_path}")
|
| 685 |
+
self.download_file_if_not_exists(vr_model_data_url, vr_model_data_path)
|
| 686 |
+
|
| 687 |
+
mdx_model_data_path = os.path.join(self.model_file_dir, "mdx_model_data.json")
|
| 688 |
+
self.logger.debug(f"MDX model data path set to {mdx_model_data_path}")
|
| 689 |
+
self.download_file_if_not_exists(mdx_model_data_url, mdx_model_data_path)
|
| 690 |
+
|
| 691 |
+
# Loading model data from UVR
|
| 692 |
+
self.logger.debug("Loading MDX and VR model parameters from UVR model data files...")
|
| 693 |
+
vr_model_data_object = json.load(open(vr_model_data_path, encoding="utf-8"))
|
| 694 |
+
mdx_model_data_object = json.load(open(mdx_model_data_path, encoding="utf-8"))
|
| 695 |
+
|
| 696 |
+
# Load additional model data from audio-separator
|
| 697 |
+
self.logger.debug("Loading additional model parameters from audio-separator model data file...")
|
| 698 |
+
with resources.open_text("audio_separator", "model-data.json") as f:
|
| 699 |
+
audio_separator_model_data = json.load(f)
|
| 700 |
+
|
| 701 |
+
# Merge the model data objects, with audio-separator data taking precedence
|
| 702 |
+
vr_model_data_object = {**vr_model_data_object, **audio_separator_model_data.get("vr_model_data", {})}
|
| 703 |
+
mdx_model_data_object = {**mdx_model_data_object, **audio_separator_model_data.get("mdx_model_data", {})}
|
| 704 |
+
|
| 705 |
+
if model_hash in mdx_model_data_object:
|
| 706 |
+
model_data = mdx_model_data_object[model_hash]
|
| 707 |
+
elif model_hash in vr_model_data_object:
|
| 708 |
+
model_data = vr_model_data_object[model_hash]
|
| 709 |
+
else:
|
| 710 |
+
raise ValueError(f"Unsupported Model File: parameters for MD5 hash {model_hash} could not be found in UVR model data file for MDX or VR arch.")
|
| 711 |
|
| 712 |
+
self.logger.debug(f"Model data loaded using hash {model_hash}: {model_data}")
|
| 713 |
+
|
| 714 |
+
return model_data
|
| 715 |
+
|
| 716 |
+
def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"):
|
| 717 |
+
"""
|
| 718 |
+
This method instantiates the architecture-specific separation class,
|
| 719 |
+
loading the separation model into memory, downloading it first if necessary.
|
| 720 |
+
"""
|
| 721 |
+
self.logger.info(f"Loading model {model_filename}...")
|
| 722 |
|
| 723 |
+
load_model_start_time = time.perf_counter()
|
| 724 |
+
|
| 725 |
+
# Setting up the model path
|
| 726 |
+
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
| 727 |
model_name = model_filename.split(".")[0]
|
| 728 |
+
self.logger.debug(f"Model downloaded, friendly name: {model_friendly_name}, model_path: {model_path}")
|
| 729 |
+
|
| 730 |
+
if model_path.lower().endswith(".yaml"):
|
| 731 |
+
yaml_config_filename = model_path
|
| 732 |
+
|
| 733 |
+
if yaml_config_filename is not None:
|
| 734 |
+
model_data = self.load_model_data_from_yaml(yaml_config_filename)
|
| 735 |
+
else:
|
| 736 |
+
model_data = self.load_model_data_using_hash(model_path)
|
| 737 |
|
| 738 |
common_params = {
|
| 739 |
"logger": self.logger,
|
|
|
|
| 756 |
"use_soundfile": self.use_soundfile,
|
| 757 |
}
|
| 758 |
|
| 759 |
+
# Instantiate the appropriate separator class depending on the model type
|
| 760 |
+
separator_classes = {"MDX": "mdx_separator.MDXSeparator", "VR": "vr_separator.VRSeparator", "Demucs": "demucs_separator.DemucsSeparator", "MDXC": "mdxc_separator.MDXCSeparator"}
|
| 761 |
+
|
| 762 |
+
if model_type not in self.arch_specific_params or model_type not in separator_classes:
|
| 763 |
+
raise ValueError(f"Model type not supported (yet): {model_type}")
|
|
|
|
| 764 |
|
|
|
|
|
|
|
| 765 |
if model_type == "Demucs" and sys.version_info < (3, 10):
|
| 766 |
+
raise Exception("Demucs models require Python version 3.10 or newer.")
|
| 767 |
+
|
| 768 |
+
self.logger.debug(f"Importing module for model type {model_type}: {separator_classes[model_type]}")
|
| 769 |
|
| 770 |
module_name, class_name = separator_classes[model_type].split(".")
|
| 771 |
+
module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
|
| 772 |
+
separator_class = getattr(module, class_name)
|
| 773 |
+
|
| 774 |
+
self.logger.debug(f"Instantiating separator class for model type {model_type}: {separator_class}")
|
| 775 |
+
self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
|
| 776 |
+
|
| 777 |
+
# Log the completion of the model load process
|
| 778 |
+
self.logger.debug("Loading model completed.")
|
| 779 |
+
self.logger.info(f'Load model duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - load_model_start_time)))}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 780 |
|
| 781 |
@spaces.GPU
|
| 782 |
def separate(self, audio_file_path, custom_output_names=None):
|
| 783 |
+
"""
|
| 784 |
+
Separates the audio file(s) into different stems (e.g., vocals, instruments) using the loaded model.
|
| 785 |
+
|
| 786 |
+
This method takes the path to an audio file or a directory containing audio files, processes them through
|
| 787 |
+
the loaded separation model, and returns the paths to the output files containing the separated audio stems.
|
| 788 |
+
It handles the entire flow from loading the audio, running the separation, clearing up resources, and logging the process.
|
| 789 |
+
|
| 790 |
+
Parameters:
|
| 791 |
+
- audio_file_path (str or list): The path to the audio file or directory, or a list of paths.
|
| 792 |
+
- custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
|
| 793 |
|
| 794 |
+
Returns:
|
| 795 |
+
- output_files (list of str): A list containing the paths to the separated audio stem files.
|
| 796 |
+
"""
|
| 797 |
+
# Check if the model and device are properly initialized
|
| 798 |
+
if not (self.torch_device and self.model_instance):
|
| 799 |
+
raise ValueError("Initialization failed or model not loaded. Please load a model before attempting to separate.")
|
| 800 |
|
| 801 |
+
# If audio_file_path is a string, convert it to a list for uniform processing
|
| 802 |
if isinstance(audio_file_path, str):
|
| 803 |
audio_file_path = [audio_file_path]
|
| 804 |
|
| 805 |
+
# Initialize a list to store paths of all output files
|
| 806 |
output_files = []
|
| 807 |
+
|
| 808 |
+
# Process each path in the list
|
| 809 |
for path in audio_file_path:
|
| 810 |
if os.path.isdir(path):
|
| 811 |
+
# If the path is a directory, recursively search for all audio files
|
| 812 |
+
for root, dirs, files in os.walk(path):
|
| 813 |
for file in files:
|
| 814 |
+
# Check the file extension to ensure it's an audio file
|
| 815 |
+
if file.endswith((".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")): # Add other formats if needed
|
| 816 |
full_path = os.path.join(root, file)
|
| 817 |
+
self.logger.info(f"Processing file: {full_path}")
|
| 818 |
+
try:
|
| 819 |
+
# Perform separation for each file
|
| 820 |
+
files_output = self._separate_file(full_path, custom_output_names)
|
| 821 |
+
output_files.extend(files_output)
|
| 822 |
+
except Exception as e:
|
| 823 |
+
self.logger.error(f"Failed to process file {full_path}: {e}")
|
| 824 |
else:
|
| 825 |
+
# If the path is a file, process it directly
|
| 826 |
+
self.logger.info(f"Processing file: {path}")
|
| 827 |
+
try:
|
| 828 |
+
files_output = self._separate_file(path, custom_output_names)
|
| 829 |
+
output_files.extend(files_output)
|
| 830 |
+
except Exception as e:
|
| 831 |
+
self.logger.error(f"Failed to process file {path}: {e}")
|
| 832 |
|
|
|
|
|
|
|
| 833 |
return output_files
|
| 834 |
|
| 835 |
@spaces.GPU
|
| 836 |
def _separate_file(self, audio_file_path, custom_output_names=None):
|
| 837 |
+
"""
|
| 838 |
+
Internal method to handle separation for a single audio file.
|
| 839 |
+
This method performs the actual separation process for a single audio file. It logs the start and end of the process,
|
| 840 |
+
handles autocast if enabled, and ensures GPU cache is cleared after processing.
|
| 841 |
+
Parameters:
|
| 842 |
+
- audio_file_path (str): The path to the audio file.
|
| 843 |
+
- custom_output_names (dict, optional): Custom names for the output files. Defaults to None.
|
| 844 |
+
Returns:
|
| 845 |
+
- output_files (list of str): A list containing the paths to the separated audio stem files.
|
| 846 |
+
"""
|
| 847 |
+
# Log the start of the separation process
|
| 848 |
+
self.logger.info(f"Starting separation process for audio_file_path: {audio_file_path}")
|
| 849 |
+
separate_start_time = time.perf_counter()
|
| 850 |
+
|
| 851 |
+
# Log normalization and amplification thresholds
|
| 852 |
+
self.logger.debug(f"Normalization threshold set to {self.normalization_threshold}, waveform will be lowered to this max amplitude to avoid clipping.")
|
| 853 |
+
self.logger.debug(f"Amplification threshold set to {self.amplification_threshold}, waveform will be scaled up to this max amplitude if below it.")
|
| 854 |
+
|
| 855 |
+
# Run separation method for the loaded model with autocast enabled if supported by the device
|
| 856 |
+
output_files = None
|
| 857 |
+
if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type):
|
| 858 |
+
self.logger.debug("Autocast available.")
|
| 859 |
+
with autocast_mode.autocast(self.torch_device.type):
|
| 860 |
+
output_files = self.model_instance.separate(audio_file_path, custom_output_names)
|
| 861 |
+
else:
|
| 862 |
+
self.logger.debug("Autocast unavailable.")
|
| 863 |
+
output_files = self.model_instance.separate(audio_file_path, custom_output_names)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 864 |
|
| 865 |
+
# Clear GPU cache to free up memory
|
| 866 |
self.model_instance.clear_gpu_cache()
|
| 867 |
+
|
| 868 |
+
# Unset separation parameters to prevent accidentally re-using the wrong source files or output paths
|
| 869 |
self.model_instance.clear_file_specific_paths()
|
| 870 |
+
|
| 871 |
+
# Remind the user one more time if they used a VIP model, so the message doesn't get lost in the logs
|
| 872 |
+
self.print_uvr_vip_message()
|
| 873 |
+
|
| 874 |
+
# Log the completion of the separation process
|
| 875 |
+
self.logger.debug("Separation process completed.")
|
| 876 |
+
self.logger.info(f'Separation duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - separate_start_time)))}')
|
| 877 |
+
|
| 878 |
return output_files
|
| 879 |
|
| 880 |
def download_model_and_data(self, model_filename):
|
| 881 |
+
"""
|
| 882 |
+
Downloads the model file without loading it into memory.
|
| 883 |
+
"""
|
| 884 |
+
self.logger.info(f"Downloading model {model_filename}...")
|
| 885 |
+
|
| 886 |
+
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
| 887 |
+
|
| 888 |
+
if model_path.lower().endswith(".yaml"):
|
| 889 |
+
yaml_config_filename = model_path
|
| 890 |
+
|
| 891 |
+
if yaml_config_filename is not None:
|
| 892 |
+
model_data = self.load_model_data_from_yaml(yaml_config_filename)
|
| 893 |
+
else:
|
| 894 |
+
model_data = self.load_model_data_using_hash(model_path)
|
| 895 |
+
|
| 896 |
+
model_data_dict_size = len(model_data)
|
| 897 |
+
|
| 898 |
+
self.logger.info(f"Model downloaded, type: {model_type}, friendly name: {model_friendly_name}, model_path: {model_path}, model_data: {model_data_dict_size} items")
|
| 899 |
|
| 900 |
def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
    """
    Returns a simplified, user-friendly list of models with their key metrics.
    Optionally filters and sorts the list based on the specified criteria.

    :param filter_sort_by: Criteria to filter/sort by. "name" sorts by friendly
        model name, "filename" sorts by model filename; any other value is
        treated as a stem name (case-insensitive, e.g. "vocals") — the list is
        then filtered to models producing that stem and sorted by its SDR score
        (highest first, models without a score last).
    :return: Dict keyed by model filename, each value holding "Name", "Type",
        "Stems" (display strings, target stem marked with "*") and "SDR"
        (lowercased stem name -> SDR score or None).
    """
    model_files = self.list_supported_model_files()
    simplified_list = {}

    for model_type, models in model_files.items():
        for name, data in models.items():
            filename = data["filename"]
            scores = data.get("scores") or {}
            stems = data.get("stems") or []
            target_stem = data.get("target_stem")

            # Format stems with their SDR scores where available
            stems_with_scores = []
            stem_sdr_dict = {}

            # Process each stem from the model's stem list
            for stem in stems:
                stem_scores = scores.get(stem, {})
                # Add asterisk if this is the target stem
                stem_display = f"{stem}*" if stem == target_stem else stem

                if isinstance(stem_scores, dict) and "SDR" in stem_scores:
                    sdr = round(stem_scores["SDR"], 1)
                    stems_with_scores.append(f"{stem_display} ({sdr})")
                    stem_sdr_dict[stem.lower()] = sdr
                else:
                    # Include stem without SDR score
                    stems_with_scores.append(stem_display)
                    stem_sdr_dict[stem.lower()] = None

            # If no stems listed, mark as Unknown
            if not stems_with_scores:
                stems_with_scores = ["Unknown"]
                stem_sdr_dict["unknown"] = None

            simplified_list[filename] = {"Name": name, "Type": model_type, "Stems": stems_with_scores, "SDR": stem_sdr_dict}

    # Sort and filter the list if a filter_sort_by parameter is provided
    if filter_sort_by:
        if filter_sort_by == "name":
            return dict(sorted(simplified_list.items(), key=lambda x: x[1]["Name"]))
        elif filter_sort_by == "filename":
            return dict(sorted(simplified_list.items()))
        else:
            # Convert to lowercase for case-insensitive stem comparison
            sort_by_lower = filter_sort_by.lower()
            # Filter out models that don't have the specified stem
            filtered_list = {k: v for k, v in simplified_list.items() if sort_by_lower in v["SDR"]}

            # Sort by SDR score if available, putting None values last
            def sort_key(item):
                sdr = item[1]["SDR"][sort_by_lower]
                return (0 if sdr is None else 1, sdr if sdr is not None else float("-inf"))

            return dict(sorted(filtered_list.items(), key=sort_key, reverse=True))

    return simplified_list
|