# https://github.com/ZFTurbo/Music-Source-Separation-Training
# https://huggingface.co/becruily/mel-band-roformer-karaoke/blob/main/mel_band_roformer_karaoke_becruily.ckpt
# https://huggingface.co/anvuew/dereverb_mel_band_roformer/blob/main/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Any, Dict, Optional

import librosa
import numpy as np
import torch

from .utils.audio_utils import normalize_audio, denormalize_audio
from .utils.settings import get_model_from_config, parse_args_inference
from .utils.model_utils import demix, load_start_checkpoint, prefer_target_instrument


def process(mix, model, args, config, device):
    """Run one separation model over a mixture and return the target stem."""
    instruments = prefer_target_instrument(config)[:]
    # If the input is mono, reshape it to match the model's expected channel count.
    if len(mix.shape) == 1:
        mix = np.expand_dims(mix, axis=0)
        if 'num_channels' in config.audio and config.audio['num_channels'] == 2:
            # Duplicate the mono channel to produce a pseudo-stereo input.
            mix = np.concatenate([mix, mix], axis=0)
    norm_params = None
    if 'normalize' in config.inference and config.inference['normalize'] is True:
        mix, norm_params = normalize_audio(mix)
    waveforms_orig = demix(
        config, model, mix, device,
        model_type=args.model_type,
        pbar=not args.disable_detailed_pbar,
    )
    instr = 'vocals' if 'vocals' in instruments else instruments[0]
    estimates = waveforms_orig[instr]
    if norm_params is not None:
        estimates = denormalize_audio(estimates, norm_params)
    return estimates
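

# Note on the normalization round-trip above: normalize_audio/denormalize_audio
# come from the bundled utils. A minimal sketch of the contract assumed here
# (an illustration, not the helpers' exact implementation):
#
#   def normalize_audio(audio):
#       mono = audio.mean(0) if audio.ndim == 2 else audio
#       mean, std = float(mono.mean()), float(mono.std())
#       return (audio - mean) / std, {"mean": mean, "std": std}
#
#   def denormalize_audio(audio, params):
#       return audio * params["std"] + params["mean"]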


def build_model(args):
    """Instantiate a model from its config and load its inference checkpoint."""
    model, config = get_model_from_config(args.model_type, args.config_path)
    load_start_checkpoint(args, model, None, type_='inference')
    return model, config


def build_models(dict_args):
    """Build both the separation and dereverb models from a single args dict."""
    args = parse_args_inference(dict_args)
    ########## load model ##########
    torch.backends.cudnn.benchmark = True
    # The same args namespace is reused for both models: point it at the
    # separation config/checkpoint first, then at the dereverb pair.
    args.config_path = args.sep_config_path
    args.start_check_point = args.sep_start_check_point
    sep_model, sep_config = build_model(args)
    args.config_path = args.der_config_path
    args.start_check_point = args.der_start_check_point
    dereverb_model, dereverb_config = build_model(args)
    return sep_model, sep_config, dereverb_model, dereverb_config, args


def main(args, sep_model=None, sep_config=None, dereverb_model=None, dereverb_config=None, device=None):
    """Separate vocals, dereverb them, and derive the accompaniment."""
    ######## process data ##########
    sample_rate = getattr(sep_config.audio, 'sample_rate', 44100)
    mix, _ = librosa.load(args.input_path, sr=sample_rate, mono=False)
    vocals = process(mix, sep_model, args, sep_config, device)
    # The dereverb model runs on a mono downmix of the separated vocals.
    dereverbed_vocals = process(vocals.mean(0), dereverb_model, args, dereverb_config, device)
    # The accompaniment is whatever remains after subtracting the dry vocals.
    accompaniment = mix - dereverbed_vocals
    return mix, vocals, dereverbed_vocals, accompaniment, sample_rate


@dataclass
class VocalSeparationOutputs:
    """Vocal extraction output container."""

    mix: np.ndarray
    vocals: np.ndarray
    vocals_dereverbed: np.ndarray
    accompaniment: np.ndarray
    sample_rate: int
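

# Convenience helper for persisting results; this is an addition for
# illustration, not part of the upstream script. It assumes the optional
# `soundfile` package (pip install soundfile) is available.
def save_outputs(outputs: VocalSeparationOutputs, out_dir: str) -> None:
    """Write each stem of a :class:`VocalSeparationOutputs` to a WAV file."""
    import os
    import soundfile as sf  # assumed optional dependency

    os.makedirs(out_dir, exist_ok=True)
    for name, audio in (
        ("mix", outputs.mix),
        ("vocals", outputs.vocals),
        ("vocals_dereverbed", outputs.vocals_dereverbed),
        ("accompaniment", outputs.accompaniment),
    ):
        # soundfile expects (num_samples, num_channels); the models emit
        # (num_channels, num_samples) for stereo and (num_samples,) for mono.
        data = audio.T if audio.ndim == 2 else audio
        sf.write(os.path.join(out_dir, f"{name}.wav"), data, outputs.sample_rate)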


class VocalSeparator:
    """Vocal separation and dereverb wrapper.

    Wraps the karaoke separation and dereverb models from the
    ZFTurbo Music-Source-Separation-Training project and exposes a simple
    :py:meth:`process` API that returns mix/vocals/dereverbed/accompaniment.
    """

    def __init__(
        self,
        sep_model_path: str,
        sep_config_path: str,
        der_model_path: str,
        der_config_path: str,
        *,
        model_type: str = "mel_band_roformer",
        disable_detailed_pbar: bool = True,
        device: str = "cuda",
        verbose: bool = True,
    ):
        """Initialize the vocal separator.

        Args:
            sep_model_path: Checkpoint path for the separation model.
            sep_config_path: Config path for the separation model.
            der_model_path: Checkpoint path for the dereverb model.
            der_config_path: Config path for the dereverb model.
            model_type: Separation model type key, e.g. ``"mel_band_roformer"``.
            disable_detailed_pbar: Disable detailed progress bars in the underlying utils.
            device: Torch device string, e.g. ``"cuda:0"``.
            verbose: Whether to print verbose logs.
        """
        # Match the original script's args schema.
        args_dict: Dict[str, Any] = {
            "model_type": model_type,
            "disable_detailed_pbar": disable_detailed_pbar,
            "sep_config_path": sep_config_path,
            "sep_start_check_point": sep_model_path,
            "der_config_path": der_config_path,
            "der_start_check_point": der_model_path,
        }
        if verbose:
            print("[vocal extraction] init: start")
        sep_model, sep_config, dereverb_model, dereverb_config, args = build_models(args_dict)
        self.sep_model = sep_model.to(device)
        self.sep_config = sep_config
        self.dereverb_model = dereverb_model.to(device)
        self.dereverb_config = dereverb_config
        self.device = device
        self.args = args
        self.verbose = verbose
        if verbose:
            print(f"[vocal extraction] init success: sep=loaded, dereverb=loaded, device={device}")

    def process(self, input_path: str, *, verbose: Optional[bool] = None) -> VocalSeparationOutputs:
        """Separate a single audio file into sources.

        Args:
            input_path: Path to the input mixture audio file.
            verbose: Override the instance-level verbose flag for this call.

        Returns:
            :class:`VocalSeparationOutputs` containing mix, vocals,
            dereverbed vocals, accompaniment and sample rate.
        """
        verbose = self.verbose if verbose is None else verbose
        if verbose:
            print(f"[vocal extraction] process_file: start: {input_path}")
        t0 = time.time()
        self.args.input_path = input_path
        device = self.device if isinstance(self.device, torch.device) else torch.device(self.device)
        mix, vocals, dereverbed, accompaniment, sample_rate = main(
            self.args,
            self.sep_model,
            self.sep_config,
            self.dereverb_model,
            self.dereverb_config,
            device,
        )
        if verbose:
            dt = time.time() - t0
            print(
                "[vocal extraction] process_file: done:",
                f"sr={sample_rate}",
                f"mix={getattr(mix, 'shape', None)}",
                f"vocals={getattr(vocals, 'shape', None)}",
                f"dereverbed={getattr(dereverbed, 'shape', None)}",
                f"acc={getattr(accompaniment, 'shape', None)}",
                f"time={dt:.3f}s",
            )
        return VocalSeparationOutputs(
            mix=mix,
            vocals=vocals,
            vocals_dereverbed=dereverbed,
            accompaniment=accompaniment,
            sample_rate=sample_rate,
        )
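
    def process_many(self, input_paths, *, verbose: Optional[bool] = None) -> list[VocalSeparationOutputs]:
        """Separate several files sequentially.

        A thin convenience wrapper added for illustration (not part of the
        original script): the models are loaded once in ``__init__``, so the
        per-file cost here is inference only.
        """
        return [self.process(p, verbose=verbose) for p in input_paths]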


if __name__ == "__main__":
    m = VocalSeparator(
        sep_model_path="pretrained_models/mel-band-roformer-karaoke/mel_band_roformer_karaoke_becruily.ckpt",
        sep_config_path="pretrained_models/mel-band-roformer-karaoke/config_karaoke_becruily.yaml",
        der_model_path="pretrained_models/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew_sdr_19.1729.ckpt",
        der_config_path="pretrained_models/dereverb_mel_band_roformer/dereverb_mel_band_roformer_anvuew.yaml",
        device="cuda",
    )
    out = m.process("example/test/separation_test.mp3")
    print(out.vocals_dereverbed.shape)
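    # The stems can also be written to disk with the save_outputs helper
    # sketched above (assumes the optional soundfile dependency):
    #   save_outputs(out, "example/test/separated")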