Spaces:
Running
Running
| """ This file contains the Separator class, to facilitate the separation of stems from audio. """ | |
| from importlib import metadata, resources | |
| import os | |
| import sys | |
| import platform | |
| import subprocess | |
| import time | |
| import logging | |
| import warnings | |
| import importlib | |
| import io | |
| from typing import Optional | |
| import hashlib | |
| import json | |
| import yaml | |
| import requests | |
| import torch | |
| import torch.amp.autocast_mode as autocast_mode | |
| import onnxruntime as ort | |
| from tqdm import tqdm | |
| class Separator: | |
| """ | |
| The Separator class is designed to facilitate the separation of audio sources from a given audio file. | |
| It supports various separation architectures and models, including MDX, VR, and Demucs. The class provides | |
| functionalities to configure separation parameters, load models, and perform audio source separation. | |
| It also handles logging, normalization, and output formatting of the separated audio stems. | |
| The actual separation task is handled by one of the architecture-specific classes in the `architectures` module; | |
| this class is responsible for initialising logging, configuring hardware acceleration, loading the model, | |
| initiating the separation process and passing outputs back to the caller. | |
| Common Attributes: | |
| log_level (int): The logging level. | |
| log_formatter (logging.Formatter): The logging formatter. | |
| model_file_dir (str): The directory where model files are stored. | |
| output_dir (str): The directory where output files will be saved. | |
| output_format (str): The format of the output audio file. | |
| output_bitrate (str): The bitrate of the output audio file. | |
| amplification_threshold (float): The threshold for audio amplification. | |
| normalization_threshold (float): The threshold for audio normalization. | |
| output_single_stem (str): Option to output a single stem. | |
| invert_using_spec (bool): Flag to invert using spectrogram. | |
| sample_rate (int): The sample rate of the audio. | |
| use_soundfile (bool): Use soundfile for audio writing, can solve OOM issues. | |
| use_autocast (bool): Flag to use PyTorch autocast for faster inference. | |
| MDX Architecture Specific Attributes: | |
| hop_length (int): The hop length for STFT. | |
| segment_size (int): The segment size for processing. | |
| overlap (float): The overlap between segments. | |
| batch_size (int): The batch size for processing. | |
| enable_denoise (bool): Flag to enable or disable denoising. | |
| VR Architecture Specific Attributes & Defaults: | |
| batch_size: 1 | |
| window_size: 512 | |
| aggression: 5 | |
| enable_tta: False | |
| enable_post_process: False | |
| post_process_threshold: 0.2 | |
| high_end_process: False | |
| Demucs Architecture Specific Attributes & Defaults: | |
| segment_size: "Default" | |
| shifts: 2 | |
| overlap: 0.25 | |
| segments_enabled: True | |
| MDXC Architecture Specific Attributes & Defaults: | |
| segment_size: 256 | |
| override_model_segment_size: False | |
| batch_size: 1 | |
| overlap: 8 | |
| pitch_shift: 0 | |
| """ | |
| def __init__( | |
| self, | |
| log_level=logging.INFO, | |
| log_formatter=None, | |
| model_file_dir="/tmp/audio-separator-models/", | |
| output_dir=None, | |
| output_format="WAV", | |
| output_bitrate=None, | |
| normalization_threshold=0.9, | |
| amplification_threshold=0.0, | |
| output_single_stem=None, | |
| invert_using_spec=False, | |
| sample_rate=44100, | |
| use_soundfile=False, | |
| use_autocast=False, | |
| use_directml=False, | |
| mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, | |
| vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}, | |
| demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}, | |
| mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0}, | |
| info_only=False, | |
| ): | |
| """Initialize the separator.""" | |
| self.logger = logging.getLogger(__name__) | |
| self.logger.setLevel(log_level) | |
| self.log_level = log_level | |
| self.log_formatter = log_formatter | |
| self.log_handler = logging.StreamHandler() | |
| if self.log_formatter is None: | |
| self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s") | |
| self.log_handler.setFormatter(self.log_formatter) | |
| if not self.logger.hasHandlers(): | |
| self.logger.addHandler(self.log_handler) | |
| # Filter out noisy warnings from PyTorch for users who don't care about them | |
| if log_level > logging.DEBUG: | |
| warnings.filterwarnings("ignore") | |
| # Skip initialization logs if info_only is True | |
| if not info_only: | |
| package_version = self.get_package_distribution("audio-separator").version | |
| self.logger.info(f"Separator version {package_version} instantiating with output_dir: {output_dir}, output_format: {output_format}") | |
| if output_dir is None: | |
| output_dir = os.getcwd() | |
| if not info_only: | |
| self.logger.info("Output directory not specified. Using current working directory.") | |
| self.output_dir = output_dir | |
| # Check for environment variable to override model_file_dir | |
| env_model_dir = os.environ.get("AUDIO_SEPARATOR_MODEL_DIR") | |
| if env_model_dir: | |
| self.model_file_dir = env_model_dir | |
| self.logger.info(f"Using model directory from AUDIO_SEPARATOR_MODEL_DIR env var: {self.model_file_dir}") | |
| if not os.path.exists(self.model_file_dir): | |
| raise FileNotFoundError(f"The specified model directory does not exist: {self.model_file_dir}") | |
| else: | |
| self.logger.info(f"Using model directory from model_file_dir parameter: {model_file_dir}") | |
| self.model_file_dir = model_file_dir | |
| # Create the model directory if it does not exist | |
| os.makedirs(self.model_file_dir, exist_ok=True) | |
| os.makedirs(self.output_dir, exist_ok=True) | |
| self.output_format = output_format | |
| self.output_bitrate = output_bitrate | |
| if self.output_format is None: | |
| self.output_format = "WAV" | |
| self.normalization_threshold = normalization_threshold | |
| if normalization_threshold <= 0 or normalization_threshold > 1: | |
| raise ValueError("The normalization_threshold must be greater than 0 and less than or equal to 1.") | |
| self.amplification_threshold = amplification_threshold | |
| if amplification_threshold < 0 or amplification_threshold > 1: | |
| raise ValueError("The amplification_threshold must be greater than or equal to 0 and less than or equal to 1.") | |
| self.output_single_stem = output_single_stem | |
| if output_single_stem is not None: | |
| self.logger.debug(f"Single stem output requested, so only one output file ({output_single_stem}) will be written") | |
| self.invert_using_spec = invert_using_spec | |
| if self.invert_using_spec: | |
| self.logger.debug(f"Secondary step will be inverted using spectogram rather than waveform. This may improve quality but is slightly slower.") | |
| try: | |
| self.sample_rate = int(sample_rate) | |
| if self.sample_rate <= 0: | |
| raise ValueError(f"The sample rate setting is {self.sample_rate} but it must be a non-zero whole number.") | |
| if self.sample_rate > 12800000: | |
| raise ValueError(f"The sample rate setting is {self.sample_rate}. Enter something less ambitious.") | |
| except ValueError: | |
| raise ValueError("The sample rate must be a non-zero whole number. Please provide a valid integer.") | |
| self.use_soundfile = use_soundfile | |
| self.use_autocast = use_autocast | |
| self.use_directml = use_directml | |
| # These are parameters which users may want to configure so we expose them to the top-level Separator class, | |
| # even though they are specific to a single model architecture | |
| self.arch_specific_params = {"MDX": mdx_params, "VR": vr_params, "Demucs": demucs_params, "MDXC": mdxc_params} | |
| self.torch_device = None | |
| self.torch_device_cpu = None | |
| self.torch_device_mps = None | |
| self.onnx_execution_provider = None | |
| self.model_instance = None | |
| self.model_is_uvr_vip = False | |
| self.model_friendly_name = None | |
| if not info_only: | |
| self.setup_accelerated_inferencing_device() | |
def setup_accelerated_inferencing_device(self):
    """
    Run the environment checks and pick the inferencing device.

    In order: logs system information, verifies FFmpeg, logs installed ONNX Runtime packages,
    then configures the Torch / ONNX Runtime device using the collected system information.
    """
    platform_details = self.get_system_info()
    self.check_ffmpeg_installed()
    self.log_onnxruntime_packages()
    self.setup_torch_device(platform_details)
def get_system_info(self):
    """
    Log the operating system, CPU architecture, Python version and PyTorch version.

    Returns the `platform.uname()` result so callers can inspect the processor type.
    """
    self.logger.info(f"Operating System: {platform.system()} {platform.version()}")

    uname = platform.uname()
    self.logger.info(f"System: {uname.system} Node: {uname.node} Release: {uname.release} Machine: {uname.machine} Proc: {uname.processor}")

    self.logger.info(f"Python Version: {platform.python_version()}")
    self.logger.info(f"PyTorch Version: {torch.__version__}")

    return uname
def check_ffmpeg_installed(self):
    """
    Verify that the `ffmpeg` executable can be launched and log the first line of its version banner.

    If the executable is missing, an error is logged and the FileNotFoundError is re-raised,
    except when the PYTEST_CURRENT_TEST environment variable is present.
    """
    try:
        banner = subprocess.check_output(["ffmpeg", "-version"], text=True)
        self.logger.info(f"FFmpeg installed: {banner.splitlines()[0]}")
    except FileNotFoundError:
        self.logger.error("FFmpeg is not installed. Please install FFmpeg to use this package.")
        # ffmpeg is required for pydub to write audio, so real users get the exception;
        # unit tests running in CI are allowed to continue without it.
        running_under_pytest = "PYTEST_CURRENT_TEST" in os.environ
        if not running_under_pytest:
            raise
def log_onnxruntime_packages(self):
    """
    Log the version of each ONNX Runtime flavour (GPU, Silicon, CPU, DirectML) that is installed.

    Flavours that are not installed produce no info-level message.
    """
    # All lookups happen before any version is logged, in this fixed order.
    flavours = [
        ("GPU", self.get_package_distribution("onnxruntime-gpu")),
        ("Silicon", self.get_package_distribution("onnxruntime-silicon")),
        ("CPU", self.get_package_distribution("onnxruntime")),
        ("DirectML", self.get_package_distribution("onnxruntime-directml")),
    ]

    for label, distribution in flavours:
        if distribution is not None:
            self.logger.info(f"ONNX Runtime {label} package installed with version: {distribution.version}")
def setup_torch_device(self, system_info):
    """
    Choose the Torch device and ONNX Runtime execution provider.

    Preference order: CUDA, then Apple MPS (only when the processor reports "arm"), then DirectML
    (only when use_directml is set and the torch_directml package is installed and available).
    If none of these can be configured, Torch runs on the CPU with the CPUExecutionProvider.
    """
    available_providers = ort.get_available_providers()
    torch_dml_distribution = self.get_package_distribution("torch_directml")

    self.torch_device_cpu = torch.device("cpu")

    if torch.cuda.is_available():
        self.configure_cuda(available_providers)
        return

    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
        self.configure_mps(available_providers)
        return

    if self.use_directml and torch_dml_distribution:
        import torch_directml

        if torch_directml.is_available():
            self.configure_dml(available_providers)
            return

    # Nothing accelerated could be configured: fall back to the CPU
    self.logger.info("No hardware acceleration could be configured, running in CPU mode")
    self.torch_device = self.torch_device_cpu
    self.onnx_execution_provider = ["CPUExecutionProvider"]
def configure_cuda(self, ort_providers):
    """
    Point Torch at the CUDA device and, when ONNX Runtime offers it, select the CUDAExecutionProvider.

    If the provider is missing from ort_providers a warning is logged and the ONNX execution
    provider is left unchanged.
    """
    self.logger.info("CUDA is available in Torch, setting Torch device to CUDA")
    self.torch_device = torch.device("cuda")

    if "CUDAExecutionProvider" not in ort_providers:
        self.logger.warning("CUDAExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
        return

    self.logger.info("ONNXruntime has CUDAExecutionProvider available, enabling acceleration")
    self.onnx_execution_provider = ["CUDAExecutionProvider"]
def configure_mps(self, ort_providers):
    """
    Point Torch at the Apple Silicon MPS device and, when ONNX Runtime offers it, select the CoreMLExecutionProvider.

    If the provider is missing from ort_providers a warning is logged and the ONNX execution
    provider is left unchanged.
    """
    self.logger.info("Apple Silicon MPS/CoreML is available in Torch and processor is ARM, setting Torch device to MPS")
    mps_device = torch.device("mps")
    self.torch_device_mps = mps_device
    self.torch_device = mps_device

    if "CoreMLExecutionProvider" not in ort_providers:
        self.logger.warning("CoreMLExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
        return

    self.logger.info("ONNXruntime has CoreMLExecutionProvider available, enabling acceleration")
    self.onnx_execution_provider = ["CoreMLExecutionProvider"]
def configure_dml(self, ort_providers):
    """
    Point Torch at the DirectML device and, when ONNX Runtime offers it, select the DmlExecutionProvider.

    If the provider is missing from ort_providers a warning is logged and the ONNX execution
    provider is left unchanged.
    """
    # Imported lazily: torch_directml is an optional dependency
    import torch_directml

    self.logger.info("DirectML is available in Torch, setting Torch device to DirectML")
    dml_device = torch_directml.device()
    self.torch_device_dml = dml_device
    self.torch_device = dml_device

    if "DmlExecutionProvider" not in ort_providers:
        self.logger.warning("DmlExecutionProvider not available in ONNXruntime, so acceleration will NOT be enabled")
        return

    self.logger.info("ONNXruntime has DmlExecutionProvider available, enabling acceleration")
    self.onnx_execution_provider = ["DmlExecutionProvider"]
def get_package_distribution(self, package_name):
    """
    Look up the installed distribution metadata for package_name.

    Returns the distribution object, or None (after a debug log) when the package is not installed.
    """
    try:
        distribution = metadata.distribution(package_name)
    except metadata.PackageNotFoundError:
        self.logger.debug(f"Python package: {package_name} not installed")
        distribution = None
    return distribution
def get_model_hash(self, model_path):
    """
    Return the MD5 hex digest used to identify a model file.

    Files smaller than 10,240,000 bytes are hashed in full; for larger files only the final
    10,240,000 bytes are hashed. Any error while reading the file is logged and re-raised.
    """
    self.logger.debug(f"Calculating hash of model file {model_path}")

    tail_length = 10000 * 1024  # 10,240,000 bytes

    try:
        total_bytes = os.path.getsize(model_path)

        with open(model_path, "rb") as model_file:
            if total_bytes >= tail_length:
                # Position the cursor so that exactly the trailing window remains to be read
                offset = total_bytes - tail_length
                self.logger.debug(f"File size {total_bytes} >= {tail_length}, seeking to {offset} and hashing remaining bytes.")
                model_file.seek(offset, io.SEEK_SET)
            else:
                self.logger.debug(f"File size {total_bytes} < {tail_length}, hashing entire file.")

            digest = hashlib.md5(model_file.read()).hexdigest()

        self.logger.info(f"Hash of model file {model_path} is {digest}")
        return digest

    except FileNotFoundError:
        self.logger.error(f"Model file not found at {model_path}")
        raise
    except Exception as err:
        # Permissions problems and other I/O errors are reported, then propagated unchanged
        self.logger.error(f"Error calculating hash for {model_path}: {err}")
        raise
def download_file_if_not_exists(self, url, output_path):
    """
    Download url to output_path unless a file already exists there.

    The payload is streamed into "<output_path>.part" and only moved into place once the transfer
    has finished. Previously an interrupted transfer left a truncated file at output_path, which
    every later call then treated as a complete download and skipped.

    Raises:
        RuntimeError: if the server answers with any status code other than 200.
    """
    if os.path.isfile(output_path):
        self.logger.debug(f"File already exists at {output_path}, skipping download")
        return

    self.logger.debug(f"Downloading file from {url} to {output_path} with timeout 300s")
    response = requests.get(url, stream=True, timeout=300)

    try:
        if response.status_code != 200:
            raise RuntimeError(f"Failed to download file from {url}, response code: {response.status_code}")

        total_size_in_bytes = int(response.headers.get("content-length", 0))
        partial_path = f"{output_path}.part"
        progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)

        try:
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    progress_bar.update(len(chunk))
                    f.write(chunk)
            # Publish the file only after every byte has been written
            os.replace(partial_path, output_path)
        except BaseException:
            # Never leave a half-written file behind (also covers Ctrl-C); the error is re-raised
            try:
                os.remove(partial_path)
            except OSError:
                pass
            raise
        finally:
            progress_bar.close()
    finally:
        # Release the streamed connection back to the pool on every path
        response.close()
def list_supported_model_files(self):
    """
    This method lists the supported model files for audio-separator, by fetching the same file UVR uses to list these
    (cached as download_checks.json in the model directory) and merging it with the models.json bundled in this package.
    Also includes model performance scores where available (from the bundled models-scores.json).

    The result is grouped by architecture ("VR", "MDX", "Demucs", "MDXC"), each group mapping a friendly
    model name to an entry of this shape:

    {
        "MDXC": {
            "MDX23C Model: MDX23C-InstVoc HQ": {
                "filename": "MDX23C-8KFFT-InstVoc_HQ.ckpt",
                "scores": {
                    "vocals": {"SDR": 11.9504, "SIR": 23.1166, "SAR": 12.093, "ISR": 15.4782},
                    "instrumental": {"SDR": 16.3035, "SIR": 26.6161, "SAR": 18.5167, "ISR": 18.3939}
                },
                "stems": [...],
                "target_stem": ...,
                "download_files": [
                    "MDX23C-8KFFT-InstVoc_HQ.ckpt",
                    "model_2_stem_full_band_8k.yaml"
                ]
            }
        }
    }

    "scores", "stems" and "target_stem" fall back to {}, [] and None when no score data exists for the file.
    For Demucs entries "filename" is the model's YAML file and "download_files" holds the download URLs/filenames.
    """
    download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")

    self.download_file_if_not_exists("https://raw.githubusercontent.com/TRvlvr/application_data/main/filelists/download_checks.json", download_checks_path)

    # Use a context manager so the file handle is closed deterministically
    with open(download_checks_path, encoding="utf-8") as download_checks_file:
        model_downloads_list = json.load(download_checks_file)
    self.logger.debug(f"UVR model download list loaded")

    # Load the model scores with error handling
    model_scores = {}
    try:
        with resources.open_text("audio_separator", "models-scores.json") as f:
            model_scores = json.load(f)
        self.logger.debug(f"Model scores loaded")
    except json.JSONDecodeError as e:
        self.logger.warning(f"Failed to load model scores: {str(e)}")
        self.logger.warning("Continuing without model scores")

    def build_entry(filename, download_files):
        """Assemble one model entry, attaching score data when it exists for filename."""
        score_data = model_scores.get(filename, {})
        return {
            "filename": filename,
            "scores": score_data.get("median_scores", {}),
            "stems": score_data.get("stems", []),
            "target_stem": score_data.get("target_stem"),
            "download_files": download_files,
        }

    # Load the JSON file using importlib.resources
    with resources.open_text("audio_separator", "models.json") as f:
        audio_separator_models_list = json.load(f)
    self.logger.debug(f"Audio-Separator model list loaded")

    # Later sources override earlier ones when they share a friendly name
    vr_files = {**model_downloads_list["vr_download_list"], **audio_separator_models_list["vr_download_list"]}
    mdx_files = {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"], **audio_separator_models_list["mdx_download_list"]}
    mdxc_files = {
        **model_downloads_list["mdx23c_download_list"],
        **model_downloads_list["mdx23c_download_vip_list"],
        **model_downloads_list["roformer_download_list"],
        **audio_separator_models_list["mdx23c_download_list"],
        **audio_separator_models_list["roformer_download_list"],
    }

    # Only show Demucs v4 models as we've only implemented support for v4.
    # Demucs models are identified by their YAML file; entries without one are omitted.
    demucs_models = {}
    for name, files in model_downloads_list["demucs_download_list"].items():
        if not name.startswith("Demucs v4"):
            continue
        yaml_file = next((filename for filename in files if filename.endswith(".yaml")), None)
        if yaml_file:
            demucs_models[name] = build_entry(yaml_file, list(files.values()))

    model_files_grouped_by_type = {
        # VR and MDX models consist of a single file
        "VR": {name: build_entry(filename, [filename]) for name, filename in vr_files.items()},
        "MDX": {name: build_entry(filename, [filename]) for name, filename in mdx_files.items()},
        "Demucs": demucs_models,
        # MDXC: the first key is the model file; keys and values together are the model and config filenames
        "MDXC": {name: build_entry(next(iter(files)), list(files.keys()) + list(files.values())) for name, files in mdxc_files.items()},
    }

    return model_files_grouped_by_type
def print_uvr_vip_message(self):
    """
    Log two warnings reminding the user to support Anjok07 on Patreon when the current model is a VIP model.

    Does nothing when the current model is not flagged as VIP.
    """
    if not self.model_is_uvr_vip:
        return

    warn = self.logger.warning
    warn(f"The model: '{self.model_friendly_name}' is a VIP model, intended by Anjok07 for access by paying subscribers only.")
    warn("If you are not already subscribed, please consider supporting the developer of UVR, Anjok07 by subscribing here: https://patreon.com/uvr")
def download_model_files(self, model_filename):
    """
    This method downloads the model files for a given model filename, if they are not already present.

    If the model file already exists in the model directory, the registry lookup is bypassed entirely:
    the model is reported as type "MDXC" and a YAML config is guessed from the model directory
    (a name-related YAML first, otherwise the most recently modified one).

    Returns tuple of (model_filename, model_type, model_friendly_name, model_path, yaml_config_filename)

    Raises:
        ValueError: if the file is not present locally and is not in the supported model list.
    """
    model_path = os.path.join(self.model_file_dir, f"{model_filename}")

    # BYPASS: If the model file already exists locally (pre-downloaded), skip the registry check entirely.
    # This allows custom/non-registered models to be loaded without being in UVR's supported model list.
    # NOTE(review): this also applies to registry models downloaded by an earlier run, which are then
    # always reported as "MDXC" - confirm that is intended.
    if os.path.isfile(model_path):
        self.logger.info(f"Model file already exists at {model_path}, skipping registry lookup (bypass active)")

        yaml_config_filename = None
        model_basename = model_filename.rsplit('.', 1)[0]

        try:
            # List the directory once; the UVR "model_data*" files are not model configs
            yaml_candidates = [f for f in os.listdir(self.model_file_dir) if f.endswith('.yaml') and not f.startswith('model_data')]

            # Prefer the first YAML file that shares a name fragment with the model
            yaml_config_filename = next(
                (f for f in yaml_candidates if model_basename in f or f.replace('.yaml', '') in model_basename),
                None,
            )

            # Otherwise fall back to the most recently modified YAML (likely downloaded together)
            if yaml_config_filename is None and yaml_candidates:
                yaml_config_filename = max(yaml_candidates, key=lambda x: os.path.getmtime(os.path.join(self.model_file_dir, x)))
        except OSError as e:
            # Best effort only: continue without a YAML config if the directory cannot be scanned
            self.logger.warning(f"Error scanning for YAML config: {e}")

        self.model_friendly_name = model_filename
        model_type = "MDXC"  # MDXC type uses YAML-based config loading, bypassing hash check
        self.logger.info(f"Bypass: Using model_type={model_type}, yaml_config={yaml_config_filename}")
        return model_filename, model_type, model_filename, model_path, yaml_config_filename

    supported_model_files_grouped = self.list_supported_model_files()
    public_model_repo_url_prefix = "https://github.com/TRvlvr/model_repo/releases/download/all_public_uvr_models"
    vip_model_repo_url_prefix = "https://github.com/Anjok0109/ai_magic/releases/download/v5"
    audio_separator_models_repo_url_prefix = "https://github.com/nomadkaraoke/python-audio-separator/releases/download/model-configs"

    yaml_config_filename = None

    self.logger.debug(f"Searching for model_filename {model_filename} in supported_model_files_grouped")

    # Iterate through model types (VR, MDX, Demucs, MDXC) and each model within them
    for model_type, models in supported_model_files_grouped.items():
        for model_friendly_name, model_info in models.items():
            # Skip anything that is not the requested model
            if model_info["filename"] != model_filename and model_filename not in model_info["download_files"]:
                continue

            self.logger.debug(f"Found matching model: {model_friendly_name}")

            # The VIP flag is only updated for the model actually selected; previously it was
            # overwritten for every registry entry scanned, leaving stale state when nothing matched.
            self.model_is_uvr_vip = "VIP" in model_friendly_name
            model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix

            self.model_friendly_name = model_friendly_name
            self.print_uvr_vip_message()

            # Download each required file for this model
            for file_to_download in model_info["download_files"]:
                # For URLs, extract just the filename portion
                if file_to_download.startswith("http"):
                    filename = file_to_download.split("/")[-1]
                    download_path = os.path.join(self.model_file_dir, filename)
                    self.download_file_if_not_exists(file_to_download, download_path)
                    continue

                download_path = os.path.join(self.model_file_dir, file_to_download)

                # MDXC YAML configs live in a sub-folder of the UVR repo; everything else sits at its root
                if model_type == "MDXC" and file_to_download.endswith(".yaml"):
                    yaml_config_filename = file_to_download
                    primary_url = f"{model_repo_url_prefix}/mdx_model_data/mdx_c_configs/{file_to_download}"
                    fallback_notice = "YAML config not found in UVR repo, trying audio-separator models repo..."
                else:
                    primary_url = f"{model_repo_url_prefix}/{file_to_download}"
                    fallback_notice = "Model not found in UVR repo, trying audio-separator models repo..."

                # Try the UVR repo first, then the audio-separator models repo
                try:
                    self.download_file_if_not_exists(primary_url, download_path)
                except RuntimeError:
                    self.logger.debug(fallback_notice)
                    fallback_url = f"{audio_separator_models_repo_url_prefix}/{file_to_download}"
                    self.download_file_if_not_exists(fallback_url, download_path)

            return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename

    raise ValueError(f"Model file {model_filename} not found in supported model files")
def load_model_data_from_yaml(self, yaml_config_filename):
    """
    This method loads model-specific parameters from the YAML file for that model.
    The parameters in the YAML are critical to inferencing, as they need to match whatever was used during training.

    yaml_config_filename may be a path that exists as given, or a bare filename which is then resolved
    against the model directory. Returns the parsed YAML data, with "is_roformer" set to True when the
    resolved path contains "roformer".
    """
    # Verify if the YAML filename includes a full path or just the filename
    if os.path.exists(yaml_config_filename):
        model_data_yaml_filepath = yaml_config_filename
    else:
        model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename)

    self.logger.debug(f"Loading model data from YAML at path {model_data_yaml_filepath}")

    # Read through a context manager so the file handle is closed deterministically.
    # NOTE(review): FullLoader is applied to config files that may have been downloaded;
    # consider yaml.safe_load if the configs never rely on Python-specific tags.
    with open(model_data_yaml_filepath, encoding="utf-8") as yaml_file:
        model_data = yaml.load(yaml_file, Loader=yaml.FullLoader)
    self.logger.debug(f"Model data loaded from YAML file: {model_data}")

    # The check covers the whole resolved path, not just the filename
    if "roformer" in model_data_yaml_filepath:
        model_data["is_roformer"] = True

    return model_data
def load_model_data_using_hash(self, model_path):
    """
    This method loads model-specific parameters from UVR model data files.
    These parameters are critical to inferencing using a given model, as they need to match whatever was used during training.
    The correct parameters are identified by calculating the hash of the model file and looking up the hash in the UVR data files.

    The UVR data files are downloaded into the model directory if missing and merged with the
    model-data.json bundled in this package (bundled entries win). MDX data is consulted before VR data.

    Raises:
        ValueError: if the model's hash is present in neither the MDX nor the VR model data.
    """
    # Model data and configuration sources from UVR
    model_data_url_prefix = "https://raw.githubusercontent.com/TRvlvr/application_data/main"
    vr_model_data_url = f"{model_data_url_prefix}/vr_model_data/model_data_new.json"
    mdx_model_data_url = f"{model_data_url_prefix}/mdx_model_data/model_data_new.json"

    # Calculate hash for the downloaded model
    self.logger.debug("Calculating MD5 hash for model file to identify model parameters from UVR data...")
    model_hash = self.get_model_hash(model_path)
    self.logger.debug(f"Model {model_path} has hash {model_hash}")

    # Setting up the path for model data and checking its existence
    vr_model_data_path = os.path.join(self.model_file_dir, "vr_model_data.json")
    self.logger.debug(f"VR model data path set to {vr_model_data_path}")
    self.download_file_if_not_exists(vr_model_data_url, vr_model_data_path)

    mdx_model_data_path = os.path.join(self.model_file_dir, "mdx_model_data.json")
    self.logger.debug(f"MDX model data path set to {mdx_model_data_path}")
    self.download_file_if_not_exists(mdx_model_data_url, mdx_model_data_path)

    # Loading model data from UVR; context managers close the file handles deterministically
    self.logger.debug("Loading MDX and VR model parameters from UVR model data files...")
    with open(vr_model_data_path, encoding="utf-8") as vr_model_data_file:
        vr_model_data_object = json.load(vr_model_data_file)
    with open(mdx_model_data_path, encoding="utf-8") as mdx_model_data_file:
        mdx_model_data_object = json.load(mdx_model_data_file)

    # Load additional model data from audio-separator
    self.logger.debug("Loading additional model parameters from audio-separator model data file...")
    with resources.open_text("audio_separator", "model-data.json") as f:
        audio_separator_model_data = json.load(f)

    # Merge the model data objects, with audio-separator data taking precedence
    vr_model_data_object = {**vr_model_data_object, **audio_separator_model_data.get("vr_model_data", {})}
    mdx_model_data_object = {**mdx_model_data_object, **audio_separator_model_data.get("mdx_model_data", {})}

    if model_hash in mdx_model_data_object:
        model_data = mdx_model_data_object[model_hash]
    elif model_hash in vr_model_data_object:
        model_data = vr_model_data_object[model_hash]
    else:
        raise ValueError(f"Unsupported Model File: parameters for MD5 hash {model_hash} could not be found in UVR model data file for MDX or VR arch.")

    self.logger.debug(f"Model data loaded using hash {model_hash}: {model_data}")

    return model_data
def load_model(self, model_filename="model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt"):
    """
    Resolve the given model file, gather its parameters and build the matching
    architecture-specific separator, which is stored on `self.model_instance`.

    The model files are obtained through `download_model_files`. Parameters come from a YAML
    config when one is available (or when the model path itself is a `.yaml` file), otherwise
    from a hash lookup via `load_model_data_using_hash`.

    Parameters:
    - model_filename (str): Name of the model file to load.

    Raises:
    - ValueError: If the model type has no entry in `self.arch_specific_params` or no known separator class.
    - Exception: If a Demucs model is requested on a Python version older than 3.10.
    """
    self.logger.info(f"Loading model {model_filename}...")
    started_at = time.perf_counter()

    # Resolve the model files on disk
    resolved = self.download_model_files(model_filename)
    model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = resolved
    # NOTE: keeps only the text before the FIRST dot, so a dotted filename is cut at its first dot.
    model_name = model_filename.split(".")[0]
    self.logger.debug(f"Model downloaded, friendly name: {model_friendly_name}, model_path: {model_path}")

    # A model path that is itself a YAML file doubles as the config
    if model_path.lower().endswith(".yaml"):
        yaml_config_filename = model_path

    if yaml_config_filename is None:
        model_data = self.load_model_data_using_hash(model_path)
    else:
        model_data = self.load_model_data_from_yaml(yaml_config_filename)

    # Settings shared by every architecture; key order mirrors the constructor contract
    hardware_keys = ("logger", "log_level", "torch_device", "torch_device_cpu", "torch_device_mps", "onnx_execution_provider")
    output_keys = (
        "output_format",
        "output_bitrate",
        "output_dir",
        "normalization_threshold",
        "amplification_threshold",
        "output_single_stem",
        "invert_using_spec",
        "sample_rate",
        "use_soundfile",
    )
    common_params = {key: getattr(self, key) for key in hardware_keys}
    common_params["model_name"] = model_name
    common_params["model_path"] = model_path
    common_params["model_data"] = model_data
    for key in output_keys:
        common_params[key] = getattr(self, key)

    # Map each model type to "<module>.<class>" inside the architectures package
    separator_classes = {
        "MDX": "mdx_separator.MDXSeparator",
        "VR": "vr_separator.VRSeparator",
        "Demucs": "demucs_separator.DemucsSeparator",
        "MDXC": "mdxc_separator.MDXCSeparator",
    }

    if not (model_type in self.arch_specific_params and model_type in separator_classes):
        raise ValueError(f"Model type not supported (yet): {model_type}")

    if model_type == "Demucs" and sys.version_info < (3, 10):
        raise Exception("Demucs models require Python version 3.10 or newer.")

    qualified_name = separator_classes[model_type]
    self.logger.debug(f"Importing module for model type {model_type}: {qualified_name}")

    # Import lazily so only the requested architecture module gets loaded
    module_name, class_name = qualified_name.split(".")
    module = importlib.import_module(f"audio_separator.separator.architectures.{module_name}")
    separator_class = getattr(module, class_name)

    self.logger.debug(f"Instantiating separator class for model type {model_type}: {separator_class}")
    arch_config = self.arch_specific_params[model_type]
    self.model_instance = separator_class(common_config=common_params, arch_config=arch_config)

    # Report completion and the elapsed time, truncated to whole seconds
    self.logger.debug("Loading model completed.")
    elapsed_seconds = int(time.perf_counter() - started_at)
    elapsed_text = time.strftime("%H:%M:%S", time.gmtime(elapsed_seconds))
    self.logger.info(f"Load model duration: {elapsed_text}")
def separate(self, audio_file_path, custom_output_names=None):
    """
    Separates the audio file(s) into different stems (e.g., vocals, instruments) using the loaded model.

    This method takes the path to an audio file or a directory containing audio files, processes them through
    the loaded separation model, and returns the paths to the output files containing the separated audio stems.
    Directories are walked recursively and only files with a known audio extension are processed; the extension
    check is case-insensitive, so e.g. `SONG.WAV` is picked up as well as `song.wav`.

    A failure on one file is logged and does not stop the remaining files from being processed.

    Parameters:
    - audio_file_path (str, os.PathLike or list): The path to the audio file or directory, or a list of paths.
    - custom_output_names (dict, optional): Custom names for the output files. Defaults to None.

    Returns:
    - output_files (list of str): A list containing the paths to the separated audio stem files.

    Raises:
    - ValueError: If no torch device is configured or no model has been loaded.
    """
    # Check if the model and device are properly initialized
    if not (self.torch_device and self.model_instance):
        raise ValueError("Initialization failed or model not loaded. Please load a model before attempting to separate.")

    # A single path (plain string or path-like object) is wrapped in a list for uniform processing.
    # Without the PathLike check a pathlib.Path would reach the loop below and fail as "not iterable".
    if isinstance(audio_file_path, (str, os.PathLike)):
        audio_file_path = [audio_file_path]

    # Extensions are compared in lower case; add other formats here if needed
    audio_extensions = (".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aiff", ".ac3")

    # Paths of all output files, across every input
    output_files = []

    def process_file(file_path):
        # Best effort per file: log the failure and carry on with the next one
        self.logger.info(f"Processing file: {file_path}")
        try:
            files_output = self._separate_file(file_path, custom_output_names)
            output_files.extend(files_output)
        except Exception as e:
            self.logger.error(f"Failed to process file {file_path}: {e}")

    for path in audio_file_path:
        # Normalise path-like objects so the separation code always receives a plain path
        if isinstance(path, os.PathLike):
            path = os.fspath(path)

        if os.path.isdir(path):
            # If the path is a directory, recursively search for all audio files
            for root, _dirs, files in os.walk(path):
                for file in files:
                    if file.lower().endswith(audio_extensions):
                        process_file(os.path.join(root, file))
        else:
            # If the path is a file, process it directly
            process_file(path)

    return output_files
| def _separate_file(self, audio_file_path, custom_output_names=None): | |
| """ | |
| Internal method to handle separation for a single audio file. | |
| This method performs the actual separation process for a single audio file. It logs the start and end of the process, | |
| handles autocast if enabled, and ensures GPU cache is cleared after processing. | |
| Parameters: | |
| - audio_file_path (str): The path to the audio file. | |
| - custom_output_names (dict, optional): Custom names for the output files. Defaults to None. | |
| Returns: | |
| - output_files (list of str): A list containing the paths to the separated audio stem files. | |
| """ | |
| # Log the start of the separation process | |
| self.logger.info(f"Starting separation process for audio_file_path: {audio_file_path}") | |
| separate_start_time = time.perf_counter() | |
| # Log normalization and amplification thresholds | |
| self.logger.debug(f"Normalization threshold set to {self.normalization_threshold}, waveform will be lowered to this max amplitude to avoid clipping.") | |
| self.logger.debug(f"Amplification threshold set to {self.amplification_threshold}, waveform will be scaled up to this max amplitude if below it.") | |
| # Run separation method for the loaded model with autocast enabled if supported by the device | |
| output_files = None | |
| if self.use_autocast and autocast_mode.is_autocast_available(self.torch_device.type): | |
| self.logger.debug("Autocast available.") | |
| with autocast_mode.autocast(self.torch_device.type): | |
| output_files = self.model_instance.separate(audio_file_path, custom_output_names) | |
| else: | |
| self.logger.debug("Autocast unavailable.") | |
| output_files = self.model_instance.separate(audio_file_path, custom_output_names) | |
| # Clear GPU cache to free up memory | |
| self.model_instance.clear_gpu_cache() | |
| # Unset separation parameters to prevent accidentally re-using the wrong source files or output paths | |
| self.model_instance.clear_file_specific_paths() | |
| # Remind the user one more time if they used a VIP model, so the message doesn't get lost in the logs | |
| self.print_uvr_vip_message() | |
| # Log the completion of the separation process | |
| self.logger.debug("Separation process completed.") | |
| self.logger.info(f'Separation duration: {time.strftime("%H:%M:%S", time.gmtime(int(time.perf_counter() - separate_start_time)))}') | |
| return output_files | |
def download_model_and_data(self, model_filename):
    """
    Fetch a model's files and resolve its parameter data without instantiating the model.

    The parameter data is read from a YAML config when one is available (or when the model path
    itself is a `.yaml` file) and otherwise looked up by hash. Only a summary is logged; nothing is returned.

    Parameters:
    - model_filename (str): Name of the model file to download.
    """
    self.logger.info(f"Downloading model {model_filename}...")

    _resolved_name, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)

    # A model path that is itself a YAML file takes the place of any separate config file
    config_path = model_path if model_path.lower().endswith(".yaml") else yaml_config_filename

    if config_path is None:
        model_data = self.load_model_data_using_hash(model_path)
    else:
        model_data = self.load_model_data_from_yaml(config_path)

    item_count = len(model_data)
    self.logger.info(f"Model downloaded, type: {model_type}, friendly name: {model_friendly_name}, model_path: {model_path}, model_data: {item_count} items")
def get_simplified_model_list(self, filter_sort_by: Optional[str] = None):
    """
    Build a user-friendly summary of the supported models, keyed by model filename.

    Each entry holds the model's display name, its type, a list of stem labels (the target stem is
    marked with an asterisk and stems with a known SDR carry the rounded score in brackets) and a
    mapping from lower-cased stem name to SDR (None when no score is known).

    :param filter_sort_by: Optional ordering/filter criterion. "name" orders by model name,
        "filename" orders by filename, and any other value is treated as a stem name: only models
        listing that stem are kept, ordered by its SDR from highest to lowest with unscored models last.
    :return: dict mapping filename to {"Name", "Type", "Stems", "SDR"}.
    """
    catalogue = {}

    for model_type, models in self.list_supported_model_files().items():
        for name, data in models.items():
            filename = data["filename"]
            score_table = data.get("scores") or {}
            stem_names = data.get("stems") or []
            target_stem = data.get("target_stem")

            labels = []
            sdr_by_stem = {}

            for stem in stem_names:
                entry = score_table.get(stem, {})
                # The target stem is flagged with a trailing asterisk
                label = f"{stem}*" if stem == target_stem else stem
                has_sdr = isinstance(entry, dict) and "SDR" in entry
                sdr_value = round(entry["SDR"], 1) if has_sdr else None
                labels.append(f"{label} ({sdr_value})" if has_sdr else label)
                sdr_by_stem[stem.lower()] = sdr_value

            # Models without any listed stems get a placeholder entry
            if not labels:
                labels = ["Unknown"]
                sdr_by_stem["unknown"] = None

            catalogue[filename] = {"Name": name, "Type": model_type, "Stems": labels, "SDR": sdr_by_stem}

    # No criterion given: hand back the entries in discovery order
    if not filter_sort_by:
        return catalogue

    if filter_sort_by == "name":
        by_name = sorted(catalogue.items(), key=lambda entry: entry[1]["Name"])
        return dict(by_name)

    if filter_sort_by == "filename":
        return {fname: catalogue[fname] for fname in sorted(catalogue)}

    # Anything else is a stem name, matched case-insensitively
    wanted_stem = filter_sort_by.lower()
    matching = [(fname, info) for fname, info in catalogue.items() if wanted_stem in info["SDR"]]

    def ranking(entry):
        # Scored models rank above unscored ones; within the scored group a higher SDR ranks higher
        score = entry[1]["SDR"][wanted_stem]
        if score is None:
            return (0, float("-inf"))
        return (1, score)

    matching.sort(key=ranking, reverse=True)
    return dict(matching)