|
|
|
|
|
__author__ = 'Roman Solovyev (ZFTurbo): https://github.com/ZFTurbo/'
|
|
|
|
|
|
import time
|
|
|
import librosa
|
|
|
import sys
|
|
|
import os
|
|
|
import glob
|
|
|
import torch
|
|
|
import soundfile as sf
|
|
|
import numpy as np
|
|
|
from tqdm.auto import tqdm
|
|
|
import torch.nn as nn
|
|
|
|
|
|
|
|
|
# Make the sibling "utils" package importable when this script is run directly
# (i.e. not installed as a package): prepend the script's own directory.
current_dir = os.path.dirname(os.path.abspath(__file__))

sys.path.append(current_dir)


from utils.audio_utils import normalize_audio, denormalize_audio, draw_spectrogram

from utils.settings import get_model_from_config, parse_args_inference

from utils.model_utils import demix

from utils.model_utils import prefer_target_instrument, apply_tta, load_start_checkpoint


import warnings

# Silence third-party warning chatter (librosa/torch deprecations) for a clean CLI log.
warnings.filterwarnings("ignore")
|
|
|
|
|
|
|
|
|
def run_folder(model, args, config, device, verbose: bool = False):
    """
    Process a folder of audio files for source separation.

    Parameters
    ----------
    model : torch.nn.Module
        Pre-trained model for source separation.
    args : Namespace
        Arguments containing input folder, output folder, and processing options.
    config : Dict
        Configuration object with audio and inference settings.
    device : torch.device
        Device for model inference (CPU or CUDA).
    verbose : bool, optional
        If True, prints detailed information during processing. Default is False.
    """
    start_time = time.time()
    model.eval()

    mixture_paths = sorted(glob.glob(os.path.join(args.input_folder, '*.*')))
    sample_rate = getattr(config.audio, 'sample_rate', 44100)

    print(f"Total files found: {len(mixture_paths)}. Using sample rate: {sample_rate}")

    # Copy so that appending 'instrumental' below never mutates shared config state.
    instruments = prefer_target_instrument(config)[:]
    os.makedirs(args.store_dir, exist_ok=True)

    # Without per-file verbosity, show a single overall progress bar instead.
    if not verbose:
        mixture_paths = tqdm(mixture_paths, desc="Total progress")

    detailed_pbar = not args.disable_detailed_pbar

    # Output settings depend only on args — hoisted out of the per-track loops.
    codec = 'flac' if getattr(args, 'flac_file', False) else 'wav'
    subtype = args.pcm_type

    for path in mixture_paths:
        print(f"Processing track: {path}")
        try:
            mix, sr = librosa.load(path, sr=sample_rate, mono=False)
        except Exception as e:
            print(f'Cannot read track: {path}')
            print(f'Error message: {e}')
            continue

        # Mono input: promote to (1, samples); duplicate channel if the model
        # expects stereo.
        if len(mix.shape) == 1:
            mix = np.expand_dims(mix, axis=0)
            if 'num_channels' in config.audio:
                if config.audio['num_channels'] == 2:
                    print('Convert mono track to stereo...')
                    mix = np.concatenate([mix, mix], axis=0)

        mix_orig = mix.copy()

        # norm_params stays None unless normalization was actually applied,
        # so denormalization below is guarded by the same fact, not a
        # re-check of the config.
        norm_params = None
        if 'normalize' in config.inference:
            if config.inference['normalize'] is True:
                mix, norm_params = normalize_audio(mix)

        waveforms_orig = demix(config, model, mix, device, model_type=args.model_type, pbar=detailed_pbar)

        if args.use_tta:
            waveforms_orig = apply_tta(config, model, mix, waveforms_orig, device, args.model_type)

        # Derive the instrumental as the residual of the (un-normalized) mix
        # minus the separated vocals (or first instrument as fallback).
        if args.extract_instrumental:
            instr = 'vocals' if 'vocals' in instruments else instruments[0]
            waveforms_orig['instrumental'] = mix_orig - waveforms_orig[instr]
            if 'instrumental' not in instruments:
                instruments.append('instrumental')

        file_name = os.path.splitext(os.path.basename(path))[0]

        for instr in instruments:
            estimates = waveforms_orig[instr]
            if norm_params is not None:
                estimates = denormalize_audio(estimates, norm_params)

            dirnames, fname = format_filename(
                args.filename_template,
                instr=instr,
                start_time=int(start_time),
                file_name=file_name,
                dir_name=os.path.dirname(path),
                model_type=args.model_type,
                model=os.path.splitext(os.path.basename(args.start_check_point))[0]
            )

            output_dir = os.path.join(args.store_dir, *dirnames)
            os.makedirs(output_dir, exist_ok=True)

            # NOTE(review): `fname` from the template is discarded; the stem
            # file name below is hard-coded. Looks unintentional — confirm
            # whether the template should also control the file name.
            stem_fname = f"{file_name}_{instr}_stem"
            output_path = os.path.join(output_dir, f"{stem_fname}.{codec}")
            # estimates is (channels, samples); soundfile expects (samples, channels).
            sf.write(output_path, estimates.T, sr, subtype=subtype)
            print("Wrote file:", output_path)

            if args.draw_spectro > 0:
                output_img_path = os.path.join(output_dir, f"{stem_fname}.jpg")
                draw_spectrogram(estimates.T, sr, args.draw_spectro, output_img_path)
                print("Wrote file:", output_img_path)

    print(f"Elapsed time: {time.time() - start_time:.2f} seconds.")
|
|
|
|
|
|
def format_filename(template, **kwargs):
    '''
    Render *template* by substituting "{key}" placeholders with the values
    given in **kwargs**, e.g. "{file_name}/{instr}".

    Slashes ('/') in the template separate directory components (which the
    caller creates) from the final file name. Placeholders without a matching
    keyword are left untouched.

    Returns [dirnames, fname]: a list of directory names plus the file name.
    '''
    rendered = template
    for key, value in kwargs.items():
        rendered = rendered.replace("{" + key + "}", str(value))
    parts = rendered.split("/")
    return parts[:-1], parts[-1]
|
|
|
|
|
|
def proc_folder(dict_args):
    """
    Entry point: parse arguments, load the model and checkpoint, and run
    separation over every file in the input folder.

    Parameters
    ----------
    dict_args : dict or None
        Pre-built argument mapping; None means parse from the command line.
    """
    args = parse_args_inference(dict_args)

    # Device priority: forced CPU > CUDA > Apple MPS > CPU fallback.
    if args.force_cpu:
        device = "cpu"
    elif torch.cuda.is_available():
        print('CUDA is available, use --force_cpu to disable it.')
        first_id = args.device_ids[0] if isinstance(args.device_ids, list) else args.device_ids
        device = f'cuda:{first_id}'
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    print("Using device: ", device)

    model_load_start_time = time.time()
    torch.backends.cudnn.benchmark = True

    model, config = get_model_from_config(args.model_type, args.config_path)

    # The config may override the CLI-provided model type.
    if 'model_type' in config.training:
        args.model_type = config.training.model_type

    if args.start_check_point:
        # NOTE: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(args.start_check_point, weights_only=False, map_location='cpu')
        load_start_checkpoint(args, model, checkpoint, type_='inference')

    print("Instruments: {}".format(config.training.instruments))

    # Wrap in DataParallel only for a genuine multi-GPU request.
    if isinstance(args.device_ids, list) and len(args.device_ids) > 1 and not args.force_cpu:
        model = nn.DataParallel(model, device_ids=args.device_ids)

    model = model.to(device)

    print("Model load time: {:.2f} sec".format(time.time() - model_load_start_time))

    run_folder(model, args, config, device, verbose=True)
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # None -> arguments are parsed from the command line inside proc_folder.
    proc_folder(None)
|
|
|
|