| import logging |
| import multiprocessing |
| from concurrent.futures import ProcessPoolExecutor |
| from random import shuffle |
| from tqdm import tqdm |
| import sys, wave |
| import torch, torchaudio |
| import hashlib |
| import time, os, psutil |
|
|
| |
| THIS_FILE = os.path.abspath(__file__) |
| UVR5_ROOT = os.path.dirname(THIS_FILE) |
| if UVR5_ROOT not in sys.path: |
| sys.path.append(UVR5_ROOT) |
|
|
| from gui_data.constants import * |
| from lib_v5.vr_network.model_param_init import ModelParameters |
| import argparse, json |
| import onnx |
| import onnxruntime as ort |
| import traceback |
| from datetime import datetime |
|
|
| logging.getLogger("numba").setLevel(logging.WARNING) |
| logging.getLogger("matplotlib").setLevel(logging.WARNING) |
|
|
class ModelData():
    """Configuration holder for one UVR5 separation model (VR or MDX arch).

    Resolves model metadata from the hash→settings JSON shipped with UVR5,
    derives the primary/secondary stems and (for MDX) the FFT/dim settings,
    and collects the list of input .wav files to process.  Extra per-run
    options (model_name, window_size, aggression_setting, batch_size, ...)
    are supplied through **kwargs and become attributes directly.
    """

    def __init__(self,
                 model_path,
                 audio_path,
                 result_path,
                 process_method,
                 device,
                 save_background=True,
                 is_pre_proc_model=False,
                 base_dir=None,
                 **kwargs):
        # Config-file options become attributes (model_name, window_size, ...).
        self.__dict__.update(kwargs)

        # Resolve the default lazily instead of as a def-time default so the
        # class does not depend on module import order for UVR5_ROOT.
        if base_dir is None:
            base_dir = UVR5_ROOT

        BASE_PATH = result_path
        VR_MODELS_DIR = os.path.join(base_dir, 'models', 'VR_Models')
        VR_HASH_JSON = os.path.join(VR_MODELS_DIR, 'model_data', 'model_data.json')
        VR_PARAM_DIR = os.path.join(base_dir, 'lib_v5', 'vr_network', 'modelparams')
        SAMPLE_CLIP_PATH = os.path.join(BASE_PATH, 'temp_sample_clips')

        MDX_MIXER_PATH = os.path.join(base_dir, 'lib_v5', 'mixer.ckpt')

        MDX_HASH_JSON = os.path.join(base_dir, 'model_data.json')
        MDX_MODEL_NAME_SELECT = os.path.join(base_dir, 'model_name_mapper.json')

        # Normalize the raw (string-typed) config values.
        self.aggression_setting = float(int(self.aggression_setting) / 100)
        self.window_size = int(self.window_size)
        self.batch_size = int(self.batch_size) if self.batch_size.isdigit() else 1
        self.mdx_batch_size = 1 if self.mdx_batch_size == DEF_OPT else int(self.mdx_batch_size)
        self.is_mdx_ckpt = False
        self.crop_size = int(self.crop_size)
        self.is_high_end_process = 'mirroring' if self.is_high_end_process else 'None'
        self.post_process_threshold = float(self.post_process_threshold)
        self.model_capacity = 32, 128  # (nout, nout_lstm) default; overridden for VR 5.1 models
        self.model_path = model_path
        self.result_path = result_path
        self.model_basename = os.path.splitext(os.path.basename(self.model_path))[0]
        self.mixer_path = MDX_MIXER_PATH
        self.process_method = process_method
        self.is_pre_proc_model = is_pre_proc_model
        self.vr_is_secondary_model = self.vr_is_secondary_model_activate
        self.mdx_is_secondary_model = self.mdx_is_secondary_model_activate
        self.is_ensemble_mode = False
        self.secondary_model = None
        self.primary_model_primary_stem = None
        self.primary_stem = None
        self.secondary_stem = None
        self.secondary_model_scale = None
        self.is_demucs_pre_proc_model_inst_mix = False
        self.device = device
        self.save_background = save_background

        # Collect the input file list: a directory of .wav files, a single
        # .wav path, or an explicit list of .wav paths.
        if isinstance(audio_path, str) and os.path.isdir(audio_path):
            self.inputPaths = [os.path.join(audio_path, x)
                               for x in os.listdir(audio_path) if x.endswith('.wav')]
        elif isinstance(audio_path, str) and audio_path.endswith('.wav'):
            self.inputPaths = [audio_path]
        elif isinstance(audio_path, list) and audio_path[0].endswith('.wav'):
            self.inputPaths = audio_path
        else:
            print(f"Invalid audio_path {audio_path}")

        self.get_model_hash()

        if self.process_method == VR_ARCH_TYPE:
            with open(VR_HASH_JSON, 'r', encoding='utf-8') as f:
                self.model_data = json.load(f)[self.model_hash]
            if self.model_data:
                vr_model_param = os.path.join(VR_PARAM_DIR, "{}.json".format(self.model_data["vr_model_param"]))
                self.primary_stem = self.model_data["primary_stem"]
                self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
                self.vr_model_param = ModelParameters(vr_model_param)
                self.model_samplerate = self.vr_model_param.param['sr']
                # 5.1-capacity VR models carry their layer widths in the JSON.
                if "nout" in self.model_data.keys() and "nout_lstm" in self.model_data.keys():
                    self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"]
                    self.is_vr_51_model = True
            else:
                self.model_status = False

        if self.process_method == MDX_ARCH_TYPE:
            self.is_vr_51_model = False
            self.margin = int(self.margin)
            # NOTE(review): the sample rate is taken from `margin` (typically
            # 44100 in MDX configs) — confirm this aliasing is intentional.
            self.model_samplerate = self.margin
            self.chunks = self.determine_auto_chunks(self.chunks) if self.is_chunk_mdxnet else 0
            with open(MDX_HASH_JSON, 'r', encoding='utf-8') as f:
                self.model_data = json.load(f)[self.model_hash]
            if self.model_data:
                self.is_secondary_model = self.mdx_is_secondary_model
                self.compensate = self.model_data["compensate"]
                self.mdx_dim_f_set = self.model_data["mdx_dim_f_set"]
                self.mdx_dim_t_set = self.model_data["mdx_dim_t_set"]
                self.mdx_n_fft_scale_set = self.model_data["mdx_n_fft_scale_set"]
                self.primary_stem = self.model_data["primary_stem"]
                self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
            else:
                self.model_status = False

    def determine_auto_chunks(self, chunks):
        """Determine an appropriate chunk size based on the user's hardware.

        ``chunks`` may be 'Full', 'Auto', '0', a numeric string, or BATCH_MODE.
        Returns the resolved chunk size as an int (0 disables chunking).
        """
        # gpu == 0 → use CUDA device 0; -1 → CPU (always forced on macOS).
        gpu = 0 if torch.cuda.device_count() > 0 else -1
        if OPERATING_SYSTEM == 'Darwin':
            gpu = -1

        if chunks == BATCH_MODE:
            chunks = 0

        if chunks == 'Full':
            chunk_set = 0
        elif chunks == 'Auto':
            if gpu == 0:
                # Total VRAM in GiB (1.074e9 bytes ≈ 1 GiB).
                gpu_mem = round(torch.cuda.get_device_properties(0).total_memory / 1.074e+9)
                if gpu_mem <= 6:
                    chunk_set = 5
                if gpu_mem in [7, 8, 9, 10, 11, 12, 13, 14, 15]:
                    chunk_set = 10
                if gpu_mem >= 16:
                    chunk_set = 40
            if gpu == -1:
                # System RAM in GiB.
                sys_mem = psutil.virtual_memory().total >> 30
                if sys_mem <= 4:
                    chunk_set = 1
                if sys_mem in [5, 6, 7, 8]:
                    chunk_set = 10
                if sys_mem in [9, 10, 11, 12, 13, 14, 15, 16]:
                    chunk_set = 25
                if sys_mem >= 17:
                    chunk_set = 60
        elif chunks == '0':
            chunk_set = 0
        else:
            chunk_set = int(chunks)
        # BUGFIX: the original printed `gpu_mem`, which is undefined on every
        # path except the 'Auto'+GPU branch and raised a NameError there.
        print("chunks: ", chunks, chunk_set)
        return chunk_set

    def get_model_hash(self):
        """Compute an MD5 fingerprint of the model file into self.model_hash.

        Hashes the last ~10 MB of the file for speed; falls back to hashing
        the whole file when it is smaller than that.  Leaves model_hash as
        None and sets model_status=False when the file does not exist.
        """
        self.model_hash = None

        if not os.path.isfile(self.model_path):
            self.model_status = False
        else:
            if not self.model_hash:
                try:
                    with open(self.model_path, 'rb') as f:
                        # Seek 10 MB back from the end; raises OSError on
                        # files smaller than that.
                        f.seek(-10000 * 1024, 2)
                        self.model_hash = hashlib.md5(f.read()).hexdigest()
                except OSError:
                    # Small file: hash its entire contents.
                    with open(self.model_path, 'rb') as f:
                        self.model_hash = hashlib.md5(f.read()).hexdigest()
|
|
|
|
class Inference():
    """MDX-Net ONNX inference: STFT → model → iSTFT vocal/background separation."""

    def __init__(self, model_data: "ModelData", device):
        # (String annotation so this class can be loaded without ModelData
        # already bound in the namespace.)
        self.device = device
        self.n_fft = model_data.mdx_n_fft_scale_set
        self.is_normalization = model_data.is_normalization
        self.compensate = model_data.compensate
        self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2 ** model_data.mdx_dim_t_set
        self.mdx_batch_size = model_data.mdx_batch_size
        self.is_denoise = model_data.is_denoise
        self.hop = 1024
        self.dim_c = 4  # 2 channels × (real, imag)
        self.chunks = model_data.chunks
        self.margin = model_data.margin
        self.adjust = 1
        self.progress_value = 0

        # Derived STFT geometry.
        self.n_bins = self.n_fft//2+1
        self.trim = self.n_fft//2
        self.chunk_size = self.hop * (self.dim_t-1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(self.device)
        # Zero padding restoring the bins the model does not predict (dim_f..n_bins).
        self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins-self.dim_f, self.dim_t]).to(self.device)
        self.gen_size = self.chunk_size-2*self.trim
        self.save_background = model_data.save_background

    def stft(self, x):
        """(…, chunk_size) waveform → (batch, dim_c, dim_f, dim_t) real/imag spectrogram."""
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True,return_complex=True)
        x=torch.view_as_real(x)
        x = x.permute([0,3,1,2])
        # Fold (stereo × real/imag) into the dim_c channel axis.
        x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t])
        # Keep only the dim_f bins the model operates on.
        return x[:,:,:self.dim_f]

    def istft(self, x, freq_pad=None):
        """Inverse of stft(): (batch, dim_c, dim_f, dim_t) → (batch, 2, chunk_size)."""
        # Restore the high-frequency bins dropped in stft() with zeros.
        freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad
        x = torch.cat([x, freq_pad], -2)
        x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
        x = x.permute([0,2,3,1])
        x=x.contiguous()
        x=torch.view_as_complex(x)
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
        return x.reshape([-1,2,self.chunk_size])

    def load_model(self, model_path, threads, device='cpu'):
        """Create the ONNX Runtime session (CUDA provider when available).

        Args:
            model_path: path to the .onnx model file.
            threads: intra-op thread count for the CPU provider.
            device: 'cpu' forces the CPU provider even when CUDA exists.
        """
        model = onnx.load_model(model_path)
        if torch.cuda.is_available() and device != 'cpu':
            # Share torch's current CUDA stream with ORT so torch<->ORT
            # transfers need no extra synchronization.
            providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
                                                    "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
        else:
            providers = ["CPUExecutionProvider"]

        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = threads

        self.ort_ = ort.InferenceSession(model.SerializeToString(), sess_options=sess_options, providers=providers)

        # Model maps a (batch, dim_c, dim_f, dim_t) spectrogram to same shape.
        self.model_run = lambda spek: self.ort_.run(None, {'input': spek.cpu().numpy()})[0]

    def initialize_mix(self, mix):
        """Pad a (2, n) waveform and slice it into overlapping model-sized windows.

        Returns (mix_waves, pad): a (n_windows, 2, chunk_size) tensor on
        self.device and the number of synthetic samples appended at the end.
        """
        mix_waves = []
        n_sample = mix.shape[1]
        # Pad so the usable region divides evenly into gen_size windows.
        pad = self.gen_size - n_sample%self.gen_size
        zero_pad = torch.zeros((2,self.trim), device=mix.device)

        mix_p = torch.cat((zero_pad, mix, torch.zeros((2,pad), device=mix.device), zero_pad), 1)
        i = 0
        while i < n_sample + pad:
            waves = mix_p[:, i:i+self.chunk_size]
            mix_waves.append(waves.unsqueeze(0))
            i += self.gen_size

        mix_waves = torch.cat(mix_waves, 0).to(self.device)

        return mix_waves, pad

    def run_model(self, mix, is_match_mix=False):
        """Separate one batch of windows; returns a (2, samples) waveform.

        When is_match_mix is True the model is bypassed and the input
        spectrum is resynthesized (used to rebuild a time-aligned raw mix).
        """
        spek = self.stft(mix.to(self.device))*self.adjust
        # Zero the DC/lowest bins.
        spek[:, :, :3, :] *= 0

        if is_match_mix:
            spec_pred = spek.to(self.device)
        else:
            # Denoise trick: average the model output on spek and -spek.
            spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5 if self.is_denoise else self.model_run(spek)
            spec_pred = torch.from_numpy(spec_pred).to(self.device)

        # Drop the `trim` warm-up samples on both ends of every window and
        # stitch windows back into one continuous stereo waveform.
        return self.istft(spec_pred).to(self.device)[:,:,self.trim:-self.trim].transpose(0,1).reshape(2, -1)

    def demix_base(self, mix, is_match_mix=False, device='cpu'):
        """Run separation over a dict of waveform segments.

        Args:
            mix: dict mapping segment index -> (2, n_samples) waveform tensor.
            is_match_mix: bypass the model (see run_model).
            device: device used to stage the waveform before windowing.

        Returns:
            (2, n_samples) separated waveform on self.device.
        """
        chunked_sources = []

        for seg_key in mix:
            sources = []
            tar_waves_ = []
            mix_p = mix[seg_key]

            mix_waves, pad = self.initialize_mix(mix_p.to(device))
            mix_waves = mix_waves.split(self.mdx_batch_size)
            with torch.no_grad():
                for mix_wave in mix_waves:
                    tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
                    tar_waves_.append(tar_waves)

            # Concatenate batch outputs and drop the synthetic end padding.
            tar_waves = torch.cat(tar_waves_, axis=-1)[:, :-pad]
            # Trim overlap margins except at the outermost edges.
            start = 0 if seg_key == 0 else self.margin
            end = None if seg_key == list(mix.keys())[::-1][0] or self.margin == 0 else -self.margin
            sources.append(tar_waves[:,start:end]*(1/self.adjust))
            # NOTE(review): `sources` is reset each iteration, so with more
            # than one segment only the last would survive; in this file
            # demix_base is always called with a single-segment dict {0: ...}.
            chunked_sources = torch.cat(sources, axis=-1)

        sources = chunked_sources

        return sources

    def onnx_inference(self, wav_path, save_dir, device):
        """Separate one .wav file.

        Writes `<name>_vocal.wav` (and `<name>_background.wav` when
        save_background is set) into save_dir and returns their paths
        (bg path is "" when background saving is off).
        """
        start_time = time.time()
        input_audio, sr = torchaudio.load(wav_path, channels_first=True)
        input_audio = input_audio.to(device)

        # The model expects 44.1 kHz stereo input.
        if input_audio.shape[0] == 1:
            input_audio = torch.cat((input_audio, input_audio), 0)
        if sr != 44100:
            input_audio = torchaudio.functional.resample(input_audio.squeeze(), sr, 44100)

        output_audio = self.demix_base({0:input_audio.squeeze()}, is_match_mix=False, device=device)
        torchaudio.save(
            os.path.join(save_dir, os.path.basename(wav_path).replace(".wav", "_vocal.wav")),
            output_audio.cpu(),
            44100,
        )

        if self.save_background:
            # Rebuild the raw mix through the same STFT path so that
            # background = mix - compensated vocal stays time-aligned.
            raw_mix = self.demix_base({0:input_audio.squeeze()}, is_match_mix=True, device=device)
            secondary_source, raw_mix = normalize_two_stem(output_audio*self.compensate, raw_mix, self.is_normalization)
            secondary_source = (-secondary_source+raw_mix)
            # BUGFIX: apply the .wav -> _background.wav replacement to the
            # file name only; the original replaced on the joined path, which
            # produced a different path than the returned bg_path whenever
            # save_dir itself contained ".wav".
            torchaudio.save(
                os.path.join(save_dir, os.path.basename(wav_path).replace(".wav", "_background.wav")),
                secondary_source.cpu(),
                44100,
            )
        process_time = time.time() - start_time
        print(f"{datetime.now()} {wav_path} denoised time: {process_time:.3f}s audio len: {output_audio.shape[-1]/44100:.3f}s RTF: {output_audio.shape[-1]/44100/process_time:.3f}")

        vocal_path = os.path.join(save_dir, os.path.basename(wav_path).replace(".wav", "_vocal.wav"))
        bg_path = os.path.join(save_dir, os.path.basename(wav_path).replace(".wav", "_background.wav")) if self.save_background else ""

        return vocal_path, bg_path
|
|
def normalize_two_stem(wave, mix, is_normalize=False):
    """Peak-normalize a separated stem and its raw mix by the stem's peak.

    Both tensors are scaled IN PLACE by ``1 / wave.abs().max()`` — but only
    when ``is_normalize`` is True and the stem actually clips (peak > 1.0).
    Scaling both by the same factor preserves their relative levels so the
    background (mix - stem) subtraction stays consistent.

    Args:
        wave: separated stem tensor (modified in place when normalizing).
        mix: raw mixture tensor (modified in place when normalizing).
        is_normalize: enable the normalization step.

    Returns:
        (wave, mix) — the same tensor objects that were passed in.
    """
    peak = torch.abs(wave).max()
    # (The original also computed mix's peak but never used it — removed.)
    if is_normalize and peak > 1.0:
        wave /= peak
        mix /= peak

    return wave, mix
|
|
|
|
def get_wav_duration(file_path):
    """Return the duration of a WAV file, in seconds, as a float."""
    with wave.open(file_path, 'rb') as handle:
        # frames / frames-per-second = seconds
        return handle.getnframes() / float(handle.getframerate())
|
|
|
|
def walkFile(data_dir, save_dir):
    """Recursively collect .wav files under data_dir that still need processing.

    A file is skipped when ``<name>_Vocals.wav`` already exists in save_dir.
    NOTE(review): onnx_inference writes ``*_vocal.wav`` (lowercase, no 's'),
    so this resume check may never match — confirm the intended output naming
    before relying on it to skip already-processed files.

    Args:
        data_dir: root directory to scan for input .wav files.
        save_dir: directory checked for existing results.

    Returns:
        List of absolute/relative .wav paths still to be processed.
    """
    res_wavs = []
    # (Removed the unused `res_txts` accumulator from the original.)
    for root, dirs, files in tqdm(os.walk(data_dir)):
        for f in files:
            if f.endswith('.wav'):
                if not os.path.exists(os.path.join(save_dir, f'{f[:-4]}_Vocals.wav')):
                    res_wavs.append(os.path.join(root, f))
    return res_wavs
|
|
|
|
def process_batch(files, args, device='cpu'):
    """Load the UVR5 model once, then separate every file in ``files`` on ``device``.

    Args:
        files: list of .wav paths this worker should process.
        args: parsed CLI namespace (model_path, config_path, result_path, ...).
        device: torch device string, e.g. 'cpu' or 'cuda:0'.
    """
    # Model/config settings (window_size, aggression_setting, ...) come from
    # the JSON config and are forwarded to ModelData as keyword arguments.
    # BUGFIX: close the config file (the original leaked the handle).
    with open(args.config_path, 'r', encoding='utf-8') as f:
        configs = json.load(f)
    model_data = ModelData(
        model_path=args.model_path,
        audio_path=files,
        result_path=args.result_path,
        process_method=args.process_method,
        device=device,
        save_background=args.save_background,
        **configs
    )

    uvr5_model = Inference(model_data, device)
    uvr5_model.load_model(args.model_path, args.num_processes)
    print(f"Loaded UVR5 model in {device}.")

    for file in files:
        # NOTE(review): save_dir is result_path/<basename> — a directory named
        # after the input .wav file; confirm it exists (torchaudio.save does
        # not create parent directories).
        vocal_path, bg_path = uvr5_model.onnx_inference(file, os.path.join(args.result_path, os.path.basename(file)), device)
|
|
|
|
|
|
def parallel_process(filenames, args):
    """Fan the file list out over every GPU, num_processes workers per GPU.

    The list is first split into num_processes slices, and each slice is
    split again across the visible GPUs, giving num_processes * gpu_count
    worker tasks in total.  Blocks until every task finishes.
    """
    total_gpu = torch.cuda.device_count()
    print(f'Total GPUs: {total_gpu}')
    n_proc = args.num_processes
    n_files = len(filenames)
    with ProcessPoolExecutor(max_workers=n_proc * total_gpu) as executor:
        tasks = []
        for i in range(n_proc):
            file_chunk = filenames[int(i * n_files / n_proc):int((i + 1) * n_files / n_proc)]
            for n in range(total_gpu):
                lo = int(n * len(file_chunk) / total_gpu)
                hi = int((n + 1) * len(file_chunk) / total_gpu)
                device = f"cuda:{n}" if torch.cuda.is_available() else "cpu"
                print("load model in devices: ", n_proc, total_gpu, i, n, device)
                tasks.append(executor.submit(process_batch, file_chunk[lo:hi], args, device))

        # Wait for completion; .result() re-raises any worker exception.
        for task in tqdm(tasks):
            task.result()
|
|
|
|
def parallel_process_cpu(filenames, args):
    """Split the file list evenly across num_processes CPU worker processes."""
    n_proc = args.num_processes
    total = len(filenames)
    with ProcessPoolExecutor(max_workers=n_proc) as executor:
        futures = []
        for i in range(n_proc):
            chunk = filenames[int(i * total / n_proc):int((i + 1) * total / n_proc)]
            print("load model in devices: ", n_proc, i, "cpu")
            futures.append(executor.submit(process_batch, chunk, args, "cpu"))
        # Wait for completion; .result() re-raises any worker exception.
        for fut in tqdm(futures):
            fut.result()
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('-m', '--model_path', type=str, default="models/MDX_Net_Models/model_data/Kim_Vocal_1.onnx", help='模型路径')
    parser.add_argument('-c', '--config_path', type=str, default="models/MDX_Net_Models/model_data/MDX-Net-Kim-Vocal1.json", help='配置文件路径')
    parser.add_argument('-a', '--audio_path', type=str, default="", help='wav文件名列表,放在raw文件夹下')
    parser.add_argument('-r', '--result_path', type=str, default="", help='结果存储路径')
    parser.add_argument('-p', '--process_method', type=str, default="MDX-Net", help='可选方法:["VR Arc", "MDX-Net"]')
    # BUGFIX: argparse's `type=bool` returns True for ANY non-empty string
    # (including "False"), so parse the flag value explicitly.
    parser.add_argument('-b', '--save_background', type=lambda s: str(s).lower() not in ('false', '0', 'no'), default=True, help='True:保存人声和背景音,False:只保存人声')
    parser.add_argument('-w', '--num_processes', type=int, default=4, help='You are advised to set the number of processes to the same as the number of CPU cores')

    args = parser.parse_args()

    if not os.path.exists(args.result_path):
        os.makedirs(args.result_path, exist_ok=True)
    if args.save_background:
        os.makedirs(os.path.join(os.path.dirname(args.result_path), "bg_music"), exist_ok=True)

    # Build the work list: a directory tree of .wav files or one .wav file.
    if os.path.isdir(args.audio_path):
        filenames = walkFile(args.audio_path, args.result_path)
    elif args.audio_path.endswith(".wav"):
        filenames = [args.audio_path]
    else:
        # BUGFIX: `filenames` was left unbound here and crashed below with a
        # NameError; fail fast with a clear message instead.
        raise SystemExit(f"audio_path must be a directory or a .wav file: {args.audio_path!r}")

    print(len(filenames))

    # Spawn (not fork) so each worker initializes its own CUDA context.
    multiprocessing.set_start_method("spawn", force=True)

    # BUGFIX: 0 means "use all CPU cores"; the original computed a local
    # num_processes but never used it, so the pools still received 0 and
    # ProcessPoolExecutor raised.  Write the resolved value back onto args.
    if args.num_processes == 0:
        args.num_processes = os.cpu_count()

    if torch.cuda.is_available():
        parallel_process(filenames, args)
    else:
        parallel_process_cpu(filenames, args)
|
|