Spaces: Running on Zero
| # https://github.com/ZFTurbo/Music-Source-Separation-Training | |
| # https://huggingface.co/becruily/mel-band-roformer-karaoke/blob/main/mel_band_roformer_karaoke_becruily.ckpt | |
| # https://huggingface.co/anvuew/dereverb_mel_band_roformer/blob/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, Optional, Tuple | |
| import librosa | |
| import sys | |
| import os | |
| import time | |
| import torch | |
| import numpy as np | |
| from .utils.audio_utils import normalize_audio, denormalize_audio | |
| from .utils.settings import get_model_from_config, parse_args_inference | |
| from .utils.model_utils import demix | |
| from .utils.model_utils import prefer_target_instrument, apply_tta, load_start_checkpoint | |
def process(mix, model, args, config, device):
    """Run one separation pass and return the target-instrument estimate.

    Args:
        mix: Audio as a numpy array, mono ``(samples,)`` or ``(channels, samples)``.
        model: Loaded separation (or dereverb) model.
        args: Inference args; ``model_type`` and ``disable_detailed_pbar`` are read.
        config: Model config exposing ``audio`` and ``inference`` mapping-like sections.
        device: Torch device to run inference on.

    Returns:
        np.ndarray estimate for ``'vocals'`` if configured, else the first
        configured instrument.
    """
    instruments = prefer_target_instrument(config)[:]

    # Mono input: add a channel axis; duplicate the channel when the model
    # expects stereo input.
    if len(mix.shape) == 1:
        mix = np.expand_dims(mix, axis=0)
        if 'num_channels' in config.audio and config.audio['num_channels'] == 2:
            mix = np.concatenate([mix, mix], axis=0)

    # Evaluate the normalize flag exactly once so the normalize/denormalize
    # pair can never get out of sync (the original re-tested the config twice).
    normalize = 'normalize' in config.inference and config.inference['normalize'] is True
    norm_params = None
    if normalize:
        mix, norm_params = normalize_audio(mix)

    waveforms_orig = demix(
        config, model, mix, device,
        model_type=args.model_type,
        pbar=not args.disable_detailed_pbar,
    )

    # Prefer the vocals stem when present; otherwise fall back to the first one.
    instr = 'vocals' if 'vocals' in instruments else instruments[0]
    estimates = waveforms_orig[instr]
    if normalize:
        estimates = denormalize_audio(estimates, norm_params)
    return estimates
def build_model(args):
    """Instantiate a model from its config and load its inference checkpoint.

    Args:
        args: Inference args; ``model_type``, ``config_path`` and the
            checkpoint fields consumed by ``load_start_checkpoint`` are read.

    Returns:
        Tuple of ``(model, config)``.
    """
    built_model, built_config = get_model_from_config(args.model_type, args.config_path)
    load_start_checkpoint(args, built_model, None, type_='inference')
    return built_model, built_config
def build_models(dict_args):
    """Build the separation and dereverb models from a dict of CLI-style args.

    Args:
        dict_args: Mapping matching the inference argument schema, including
            ``sep_config_path`` / ``sep_start_check_point`` and
            ``der_config_path`` / ``der_start_check_point``.

    Returns:
        Tuple of ``(sep_model, sep_config, dereverb_model, dereverb_config, args)``.
    """
    args = parse_args_inference(dict_args)

    ########## load model ##########
    # Let cuDNN pick the fastest kernels for repeated fixed-size inputs.
    torch.backends.cudnn.benchmark = True

    # Load the karaoke separation model.
    args.config_path = args.sep_config_path
    args.start_check_point = args.sep_start_check_point
    sep_model, sep_config = build_model(args)

    # Load the dereverb model (reuses the same mutable args object).
    args.config_path = args.der_config_path
    args.start_check_point = args.der_start_check_point
    dereverb_model, dereverb_config = build_model(args)

    # (Removed the original no-op self-assignments of sep_model/dereverb_model.)
    return sep_model, sep_config, dereverb_model, dereverb_config, args
def main(args, sep_model=None, sep_config=None, dereverb_model=None, dereverb_config=None, device=None):
    """Load ``args.input_path``, separate vocals, dereverb them, and derive accompaniment.

    Returns:
        Tuple ``(mix, vocals, dereverbed_vocals, accompaniment, sample_rate)``.
    """
    ######## process data ##########
    # Resolve the target sample rate from the separation config; 44100 as fallback.
    sample_rate = getattr(sep_config.audio, 'sample_rate', 44100)
    path = args.input_path
    # mono=False keeps multi-channel audio as (channels, samples).
    mix, _ = librosa.load(path, sr=sample_rate, mono=False)
    vocals = process(mix, sep_model, args, sep_config, device)
    # Dereverb runs on the mono mean of the separated vocals.
    dereverbed_vocals = process(vocals.mean(0), dereverb_model, args, dereverb_config, device)
    # NOTE(review): assumes dereverbed_vocals broadcasts against mix's
    # (channels, samples) shape — depends on the dereverb model's output
    # channel count; confirm for mono mixes / mono model output.
    accompaniment = mix - dereverbed_vocals
    return mix, vocals, dereverbed_vocals, accompaniment, sample_rate
@dataclass
class VocalSeparationOutputs:
    """Vocal extraction output container.

    Bug fix: the original declared annotated attributes on a plain class, but
    callers construct it with keyword arguments — without ``@dataclass`` there
    is no generated ``__init__`` and instantiation raises ``TypeError``.
    """

    # Original mixture as loaded from disk.
    mix: np.ndarray
    # Raw separated vocals.
    vocals: np.ndarray
    # Vocals after the dereverb pass.
    vocals_dereverbed: np.ndarray
    # mix minus dereverbed vocals.
    accompaniment: np.ndarray
    # Sample rate used for loading/processing.
    sample_rate: int
class VocalSeparator:
    """Vocal separation and dereverb wrapper.

    Wraps the karaoke separation and dereverb models from the
    ZFTurbo Music Source Separation project and exposes a simple
    :py:meth:`process` API that returns mix/vocals/dereverbed/accompaniment.
    """

    def __init__(
        self,
        sep_model_path: str,
        sep_config_path: str,
        der_model_path: str,
        der_config_path: str,
        *,
        model_type: str = "mel_band_roformer",
        disable_detailed_pbar: bool = True,
        device: str = "cuda",
        verbose: bool = True,
    ):
        """Initialize the vocal separator.

        Args:
            sep_model_path: Checkpoint path for the separation model.
            sep_config_path: Config path for the separation model.
            der_model_path: Checkpoint path for the dereverb model.
            der_config_path: Config path for the dereverb model.
            model_type: Separation model type key.
            disable_detailed_pbar: Disable detailed progress bars in underlying utils.
            device: Torch device string, e.g. ``"cuda:0"``.
            verbose: Whether to print verbose logs.
        """
        # Match the original script's args schema (the *_start_check_point
        # keys carry the checkpoint paths).
        args_dict: Dict[str, Any] = {
            "model_type": model_type,
            "disable_detailed_pbar": disable_detailed_pbar,
            "sep_config_path": sep_config_path,
            "sep_start_check_point": sep_model_path,
            "der_config_path": der_config_path,
            "der_start_check_point": der_model_path,
        }
        if verbose:
            print("[vocal extraction] init: start")
        sep_model, sep_config, dereverb_model, dereverb_config, args = build_models(args_dict)
        # Move both models to the requested device once, at construction time.
        self.sep_model = sep_model.to(device)
        self.sep_config = sep_config
        self.dereverb_model = dereverb_model.to(device)
        self.dereverb_config = dereverb_config
        self.device = device
        self.args = args
        self.verbose = verbose
        if verbose:
            print(
                "[vocal extraction] init success: sep=loaded, dereverb=loaded, device=",
                device,
            )

    def process(self, input_path: str, *, verbose: Optional[bool] = None) -> VocalSeparationOutputs:
        """Separate a single audio file into sources.

        Args:
            input_path: Path to the mixture wav.
            verbose: Override instance-level verbose flag for this call.

        Returns:
            :class:`VocalSeparationOutputs` containing mix, vocals,
            dereverbed vocals, accompaniment and sample rate.
        """
        verbose = self.verbose if verbose is None else verbose
        if verbose:
            print(f"[vocal extraction] process_file: start: {input_path}")
        t0 = time.time()
        # NOTE(review): the input path is written onto the shared args object,
        # so concurrent process() calls on one instance would race — confirm
        # single-threaded use.
        self.args.input_path = input_path
        # Accept either a device string or an existing torch.device.
        device = self.device if isinstance(self.device, torch.device) else torch.device(self.device)
        mix, vocals, dereverbed, accompaniment, sample_rate = main(
            self.args,
            self.sep_model,
            self.sep_config,
            self.dereverb_model,
            self.dereverb_config,
            device,
        )
        if verbose:
            dt = time.time() - t0
            print(
                "[vocal extraction] process_file: done:",
                f"sr={sample_rate}",
                f"mix={getattr(mix, 'shape', None)}",
                f"vocals={getattr(vocals, 'shape', None)}",
                f"dereverbed={getattr(dereverbed, 'shape', None)}",
                f"acc={getattr(accompaniment, 'shape', None)}",
                f"time={dt:.3f}s",
            )
        return VocalSeparationOutputs(
            mix=mix,
            vocals=vocals,
            vocals_dereverbed=dereverbed,
            accompaniment=accompaniment,
            sample_rate=sample_rate,
        )
| if __name__ == "__main__": | |
| m = VocalSeparator( | |
| sep_model_path="pretrained_models/mel-band-roformer-karaoke/mel_band_roformer_karaoke_becruily.ckpt", | |
| sep_config_path="pretrained_models/mel-band-roformer-karaoke/config_karaoke_becruily.yaml", | |
| der_model_path="pretrained_models/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt", | |
| der_config_path="pretrained_models/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew.yaml", | |
| device="cuda" | |
| ) | |
| out = m.process("example/test/separation_test.mp3") | |
| print(out.vocals_dereverbed.shape) | |