Spaces:
Running
on
Zero
Running
on
Zero
| import logging | |
| import multiprocessing | |
| from concurrent.futures import ProcessPoolExecutor | |
| from random import shuffle | |
| from tqdm import tqdm | |
| import sys, wave | |
| import torch, torchaudio | |
| import hashlib | |
| import time, os, psutil | |
| # Make sure we resolve imports relative to this bundled copy of uvr5 | |
| THIS_FILE = os.path.abspath(__file__) | |
| UVR5_ROOT = os.path.dirname(THIS_FILE) | |
| if UVR5_ROOT not in sys.path: | |
| sys.path.append(UVR5_ROOT) | |
| from gui_data.constants import * | |
| from lib_v5.vr_network.model_param_init import ModelParameters | |
| import argparse, json | |
| import onnx | |
| import onnxruntime as ort | |
| import traceback | |
| from datetime import datetime | |
| logging.getLogger("numba").setLevel(logging.WARNING) | |
| logging.getLogger("matplotlib").setLevel(logging.WARNING) | |
class ModelData():
    """Settings and metadata needed to run one UVR5 separation model.

    Every entry of ``**kwargs`` (in this file: the parsed model JSON config)
    is copied onto the instance as an attribute, then a number of those
    attributes are normalised to the types the inference code expects.  The
    model file is hashed and the hash is used as the key into the matching
    ``model_data.json`` to obtain per-model parameters.

    Args:
        model_path: Path of the model file to hash and load.
        audio_path: A directory, a single ``.wav`` path, or a list of ``.wav``
            paths. The resolved list is stored in ``self.inputPaths``.
        result_path: Output directory, stored as ``self.result_path``.
        process_method: Compared against ``VR_ARCH_TYPE`` / ``MDX_ARCH_TYPE``
            to select which metadata is loaded.
        device: Stored as ``self.device``.
        save_background: Stored as ``self.save_background``.
        is_pre_proc_model: Stored as ``self.is_pre_proc_model``.
        base_dir: Root directory that holds the model metadata files.
        **kwargs: Settings copied verbatim onto the instance.

    Raises:
        AttributeError: If a setting that is always read is missing.
        KeyError: If the model hash has no entry in the metadata JSON.
    """

    # Settings that __init__ reads unconditionally from **kwargs.
    _REQUIRED_SETTINGS = (
        'model_name',
        'aggression_setting',
        'window_size',
        'batch_size',
        'mdx_batch_size',
        'crop_size',
        'is_high_end_process',
        'post_process_threshold',
        'vr_is_secondary_model_activate',
        'mdx_is_secondary_model_activate',
    )

    def __init__(self,
                 model_path,
                 audio_path,
                 result_path,
                 process_method,
                 device,
                 save_background=True,
                 is_pre_proc_model=False,
                 base_dir=UVR5_ROOT,
                 **kwargs):
        self.__dict__.update(kwargs)

        # Fail early, naming every missing setting at once, instead of
        # tripping over the first missing attribute further down.
        missing = [key for key in self._REQUIRED_SETTINGS if not hasattr(self, key)]
        if missing:
            raise AttributeError(
                f"ModelData is missing required settings: {', '.join(missing)}")

        vr_hash_json = os.path.join(base_dir, 'models', 'VR_Models', 'model_data', 'model_data.json')
        vr_param_dir = os.path.join(base_dir, 'lib_v5', 'vr_network', 'modelparams')
        mdx_mixer_path = os.path.join(base_dir, 'lib_v5', 'mixer.ckpt')
        mdx_hash_json = os.path.join(base_dir, 'model_data.json')

        self.aggression_setting = float(int(self.aggression_setting)/100)
        self.window_size = int(self.window_size)
        # str() so that a numeric batch_size from JSON is accepted as well.
        self.batch_size = int(self.batch_size) if str(self.batch_size).isdigit() else 1
        self.mdx_batch_size = 1 if self.mdx_batch_size == DEF_OPT else int(self.mdx_batch_size)
        self.is_mdx_ckpt = False
        self.crop_size = int(self.crop_size)
        self.is_high_end_process = 'mirroring' if self.is_high_end_process else 'None'
        self.post_process_threshold = float(self.post_process_threshold)
        self.model_capacity = 32, 128
        self.model_path = model_path
        self.result_path = result_path
        self.model_basename = os.path.splitext(os.path.basename(self.model_path))[0]
        self.mixer_path = mdx_mixer_path
        self.process_method = process_method
        self.is_pre_proc_model = is_pre_proc_model
        self.vr_is_secondary_model = self.vr_is_secondary_model_activate
        self.mdx_is_secondary_model = self.mdx_is_secondary_model_activate
        self.is_ensemble_mode = False
        self.secondary_model = None
        self.primary_model_primary_stem = None
        self.primary_stem = None
        self.secondary_stem = None
        self.secondary_model_scale = None
        self.is_demucs_pre_proc_model_inst_mix = False
        self.device = device
        self.save_background = save_background
        self.inputPaths = self._resolve_input_paths(audio_path)

        self.get_model_hash()

        if self.process_method == VR_ARCH_TYPE:
            self.model_data = self._load_model_entry(vr_hash_json)
            if self.model_data:
                vr_model_param = os.path.join(
                    vr_param_dir, "{}.json".format(self.model_data["vr_model_param"]))
                self.primary_stem = self.model_data["primary_stem"]
                self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
                self.vr_model_param = ModelParameters(vr_model_param)
                self.model_samplerate = self.vr_model_param.param['sr']
                if "nout" in self.model_data and "nout_lstm" in self.model_data:
                    self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"]
                    self.is_vr_51_model = True
            else:
                self.model_status = False

        if self.process_method == MDX_ARCH_TYPE:
            self.is_vr_51_model = False
            self.margin = int(self.margin)
            self.model_samplerate = self.margin
            self.chunks = self.determine_auto_chunks(self.chunks) if self.is_chunk_mdxnet else 0
            self.model_data = self._load_model_entry(mdx_hash_json)
            if self.model_data:
                self.is_secondary_model = self.mdx_is_secondary_model
                self.compensate = self.model_data["compensate"]
                self.mdx_dim_f_set = self.model_data["mdx_dim_f_set"]
                self.mdx_dim_t_set = self.model_data["mdx_dim_t_set"]
                self.mdx_n_fft_scale_set = self.model_data["mdx_n_fft_scale_set"]
                self.primary_stem = self.model_data["primary_stem"]
                self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
            else:
                self.model_status = False

    @staticmethod
    def _resolve_input_paths(audio_path):
        """Turn a directory, a .wav path or a list of .wav paths into a list of paths."""
        if isinstance(audio_path, str) and os.path.isdir(audio_path):
            return [os.path.join(audio_path, name)
                    for name in os.listdir(audio_path) if name.endswith('.wav')]
        if isinstance(audio_path, str) and audio_path.endswith('.wav'):
            return [audio_path]
        # Only the first entry of a list is inspected; the emptiness check
        # avoids an IndexError on an empty list.
        if isinstance(audio_path, list) and audio_path and audio_path[0].endswith('.wav'):
            return audio_path
        print(f"Invalid audio_path {audio_path}")
        return []

    def _load_model_entry(self, json_path):
        """Return the metadata entry stored under ``self.model_hash`` in ``json_path``."""
        with open(json_path, 'r', encoding='utf-8') as handle:
            # An unknown hash raises KeyError.
            return json.load(handle)[self.model_hash]

    def determine_auto_chunks(self, chunks):
        """Determines appropriate chunk size based on user computer specs.

        Args:
            chunks: ``BATCH_MODE``, ``'Full'``, ``'Auto'``, ``'0'`` or anything
                ``int()`` accepts.

        Returns:
            The chunk setting as an int (0 for ``'Full'``, ``'0'`` and
            ``BATCH_MODE``).
        """
        use_gpu = torch.cuda.device_count() > 0 and OPERATING_SYSTEM != 'Darwin'

        if chunks == BATCH_MODE:
            chunks = 0

        # Memory figure behind an 'Auto' decision; stays None otherwise so
        # the log line below can always be printed.
        mem = None

        if chunks == 'Full':
            chunk_set = 0
        elif chunks == 'Auto':
            if use_gpu:
                mem = round(torch.cuda.get_device_properties(0).total_memory/1.074e+9)
                if mem <= 6:
                    chunk_set = 5
                elif mem <= 15:
                    chunk_set = 10
                else:
                    chunk_set = 40
            else:
                mem = psutil.virtual_memory().total >> 30
                if mem <= 4:
                    chunk_set = 1
                elif mem <= 8:
                    chunk_set = 10
                elif mem <= 16:
                    chunk_set = 25
                else:
                    chunk_set = 60
        elif chunks == '0':
            chunk_set = 0
        else:
            chunk_set = int(chunks)

        print("chunks: ", mem, chunk_set)
        return chunk_set

    def get_model_hash(self):
        """Set ``self.model_hash`` to an MD5 digest identifying the model file.

        The last 10000 KiB of the file are hashed; when the file is too small
        for that seek, the whole file is hashed instead.  If ``model_path`` is
        not a file, ``model_hash`` stays ``None`` and ``model_status`` is set
        to ``False``.
        """
        self.model_hash = None

        if not os.path.isfile(self.model_path):
            self.model_status = False
            return

        with open(self.model_path, 'rb') as handle:
            try:
                handle.seek(-10000 * 1024, 2)
            except OSError:
                # File shorter than the tail window: hash it from the start.
                handle.seek(0)
            self.model_hash = hashlib.md5(handle.read()).hexdigest()
class Inference():
    """Runs an MDX-Net ONNX model over audio: chunking, STFT, model call, iSTFT.

    Args:
        model_data: Source of the STFT/model parameters (``mdx_n_fft_scale_set``,
            ``mdx_dim_f_set``, ``mdx_dim_t_set``, ``compensate``, ``chunks``,
            ``margin``, ``mdx_batch_size``, ``is_denoise``, ``is_normalization``,
            ``save_background``).
        device: Torch device on which the STFT window and tensors are placed.
    """

    def __init__(self, model_data: "ModelData", device):
        self.device = device
        self.n_fft = model_data.mdx_n_fft_scale_set
        self.is_normalization = model_data.is_normalization
        self.compensate = model_data.compensate
        self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set
        self.mdx_batch_size = model_data.mdx_batch_size
        self.is_denoise = model_data.is_denoise
        self.hop = 1024
        self.dim_c = 4
        self.chunks = model_data.chunks
        self.margin = model_data.margin
        self.adjust = 1
        self.progress_value = 0
        self.n_bins = self.n_fft//2+1
        self.trim = self.n_fft//2
        self.chunk_size = self.hop * (self.dim_t-1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(self.device)
        # Zeros appended in istft() to restore the bins that stft() cut off.
        self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins-self.dim_f, self.dim_t]).to(self.device)
        # Samples kept per chunk once `trim` is removed from both ends.
        self.gen_size = self.chunk_size-2*self.trim
        self.save_background = model_data.save_background

    def stft(self, x):
        """Return the spectrogram of ``x`` as ``[-1, dim_c, dim_f, dim_t]`` (real/imag as channels)."""
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window,
                       center=True, return_complex=True)
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t])
        return x[:, :, :self.dim_f]

    def istft(self, x, freq_pad=None):
        """Invert :meth:`stft`; returns waveforms shaped ``[-1, 2, chunk_size]``."""
        freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
        x = torch.cat([x, freq_pad], -2)
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
        x = x.permute([0, 2, 3, 1])
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
        return x.reshape([-1, 2, self.chunk_size])

    def load_model(self, model_path, threads, device='cpu'):
        """Create the ONNX Runtime session and bind ``self.model_run``.

        Args:
            model_path: Path of the ONNX model.
            threads: Value for ``intra_op_num_threads``.
            device: ``'cpu'`` or a CUDA device such as ``'cuda:1'``.  The CUDA
                provider is used only when CUDA is available and a CUDA device
                is requested; the device index is honoured.
        """
        model = onnx.load_model(model_path)
        target = torch.device(device)
        if torch.cuda.is_available() and target.type == 'cuda':
            device_id = target.index if target.index is not None else torch.cuda.current_device()
            providers = [("CUDAExecutionProvider", {
                "device_id": device_id,
                "user_compute_stream": str(torch.cuda.current_stream(device_id).cuda_stream),
            })]
        else:
            providers = ["CPUExecutionProvider"]
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = threads
        # sess_options.enable_profiling = True  # enable when debugging
        self.ort_ = ort.InferenceSession(model.SerializeToString(), sess_options=sess_options, providers=providers)
        self.model_run = lambda spek: self.ort_.run(None, {'input': spek.cpu().numpy()})[0]

    def initialize_mix(self, mix):
        """Pad ``mix`` and cut it into overlapping windows of ``chunk_size`` samples.

        Args:
            mix: Tensor whose dimension 1 is the sample axis; the padding
                tensors are built with 2 rows.

        Returns:
            ``(mix_waves, pad)``: the stacked windows moved to ``self.device``
            and the number of zero samples appended to reach a multiple of
            ``gen_size``.
        """
        n_sample = mix.shape[1]
        pad = self.gen_size - n_sample % self.gen_size
        zero_pad = torch.zeros((2, self.trim), device=mix.device)
        mix_p = torch.cat((zero_pad, mix, torch.zeros((2, pad), device=mix.device), zero_pad), 1)
        mix_waves = [
            mix_p[:, offset:offset+self.chunk_size].unsqueeze(0)
            for offset in range(0, n_sample + pad, self.gen_size)
        ]
        return torch.cat(mix_waves, 0).to(self.device), pad

    def run_model(self, mix, is_match_mix=False):
        """Process one batch of windows and return a ``(2, samples)`` waveform.

        With ``is_match_mix`` the model is skipped and the spectrogram is
        inverted directly, so the result has gone through the same STFT
        round trip as a model output.
        """
        spek = self.stft(mix.to(self.device))*self.adjust
        spek[:, :, :3, :] *= 0  # zero the three lowest frequency bins
        if is_match_mix:
            spec_pred = spek.to(self.device)
        else:
            if self.is_denoise:
                # Average the predictions for the signal and its negation.
                spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5
            else:
                spec_pred = self.model_run(spek)
            spec_pred = torch.from_numpy(spec_pred).to(self.device)
        return self.istft(spec_pred).to(self.device)[:, :, self.trim:-self.trim].transpose(0, 1).reshape(2, -1)

    def demix_base(self, mix, is_match_mix=False, device='cpu'):
        """Separate every slice of ``mix`` and join the results along time.

        Args:
            mix: Dict mapping a slice key to a waveform tensor; key ``0`` is
                treated as the first slice and the last key as the final one
                when trimming ``self.margin``.
            is_match_mix: Forwarded to :meth:`run_model`.
            device: Device each slice is moved to before chunking.

        Returns:
            The processed slices concatenated on the last axis, or an empty
            list when ``mix`` is empty.
        """
        if not mix:
            return []

        last_key = list(mix.keys())[-1]
        chunked_sources = []
        for key, mix_p in mix.items():
            mix_waves, pad = self.initialize_mix(mix_p.to(device))
            tar_waves_ = []
            with torch.no_grad():
                for mix_wave in mix_waves.split(self.mdx_batch_size):
                    tar_waves_.append(self.run_model(mix_wave, is_match_mix=is_match_mix))
            # Drop the zero padding added by initialize_mix (pad is never 0).
            tar_waves = torch.cat(tar_waves_, dim=-1)[:, :-pad]
            start = 0 if key == 0 else self.margin
            end = None if key == last_key or self.margin == 0 else -self.margin
            chunked_sources.append(tar_waves[:, start:end]*(1/self.adjust))
        # Join all slices; previously only the last slice was returned.
        return torch.cat(chunked_sources, dim=-1)

    def onnx_inference(self, wav_path, save_dir, device):
        """Separate one audio file and write the result(s) into ``save_dir``.

        Writes ``<stem>_vocal.wav`` and, when ``self.save_background`` is set,
        ``<stem>_background.wav`` (mix minus compensated vocal), both at
        44100 Hz.  ``save_dir`` is created if needed.

        Returns:
            ``(vocal_path, bg_path)``; ``bg_path`` is ``""`` when no
            background file is written.
        """
        start_time = time.time()

        # Build both paths once so the files written and the paths returned
        # cannot diverge (e.g. when save_dir itself contains ".wav").
        stem = os.path.splitext(os.path.basename(wav_path))[0]
        vocal_path = os.path.join(save_dir, f"{stem}_vocal.wav")
        bg_path = os.path.join(save_dir, f"{stem}_background.wav") if self.save_background else ""
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        input_audio, sr = torchaudio.load(wav_path, channels_first=True)
        input_audio = input_audio.to(device)
        if input_audio.shape[0] == 1:
            input_audio = torch.cat((input_audio, input_audio), 0)  # mono to stereo
        if sr != 44100:
            input_audio = torchaudio.functional.resample(input_audio, sr, 44100)

        output_audio = self.demix_base({0: input_audio}, is_match_mix=False, device=device)
        torchaudio.save(vocal_path, output_audio.cpu(), 44100)

        if self.save_background:
            raw_mix = self.demix_base({0: input_audio}, is_match_mix=True, device=device)
            secondary_source, raw_mix = normalize_two_stem(output_audio*self.compensate, raw_mix, self.is_normalization)
            secondary_source = (-secondary_source+raw_mix)
            torchaudio.save(bg_path, secondary_source.cpu(), 44100)

        process_time = time.time() - start_time
        print(f"{datetime.now()} {wav_path} denoised time: {process_time:.3f}s audio len: {output_audio.shape[-1]/44100:.3f}s RTF: {output_audio.shape[-1]/44100/process_time:.3f}")
        return vocal_path, bg_path
def normalize_two_stem(wave, mix, is_normalize=False):
    """Scale a stem and its mixture by the stem's peak when the stem would clip.

    Both tensors are divided by ``max(|wave|)`` only if that peak exceeds 1.0
    and ``is_normalize`` is true; otherwise they are returned unchanged.  The
    inputs are never modified in place.

    Args:
        wave: Tensor of the primary stem.
        mix: Tensor of the mixture, scaled by the same factor as ``wave``.
        is_normalize: Whether scaling is allowed.

    Returns:
        ``(wave, mix)`` after optional scaling.
    """
    if wave.numel() == 0:
        # .max() is undefined for an empty tensor; nothing to normalise.
        return wave, mix

    maxv = torch.abs(wave).max()
    if is_normalize and maxv > 1.0:
        return wave / maxv, mix / maxv
    return wave, mix
def get_wav_duration(file_path):
    """Return the duration of a WAV file in seconds (frame count / frame rate)."""
    wav_file = wave.open(file_path, 'rb')
    try:
        frame_total = wav_file.getnframes()   # number of audio frames
        sample_rate = wav_file.getframerate()  # frames per second
    finally:
        wav_file.close()
    return frame_total / float(sample_rate)
def walkFile(data_dir, save_dir):
    """Collect the ``.wav`` files under ``data_dir`` that still need processing.

    A file is considered done, and skipped, when either of these exists:

    * ``save_dir/<stem>_Vocals.wav`` (the name this function originally
      looked for), or
    * ``save_dir/<file name>/<stem>_vocal.wav`` (where ``process_batch`` and
      ``Inference.onnx_inference`` in this file write the vocal stem).

    Args:
        data_dir: Directory that is walked recursively.
        save_dir: Output directory checked for existing results.

    Returns:
        List of paths of the ``.wav`` files without an existing result.
    """
    pending = []
    for root, _dirs, files in tqdm(os.walk(data_dir)):
        for name in files:
            if not name.endswith('.wav'):
                continue
            stem = name[:-4]
            done_markers = (
                os.path.join(save_dir, f'{stem}_Vocals.wav'),
                os.path.join(save_dir, name, f'{stem}_vocal.wav'),
            )
            if not any(os.path.exists(marker) for marker in done_markers):
                pending.append(os.path.join(root, name))
    return pending
def process_batch(files, args, device='cpu'):
    """Load the model once and separate every file in ``files``.

    Each result is written below ``args.result_path/<file name>/``; that
    directory is created when missing.

    Args:
        files: List of audio paths.  An empty list is a no-op, so workers
            that received no files do not build a model.
        args: Parsed CLI namespace; reads ``config_path``, ``model_path``,
            ``result_path``, ``process_method``, ``save_background`` and
            ``num_processes`` (passed to ``load_model`` as the thread count).
        device: Device used for tensors and for the ONNX session.
    """
    if not files:
        return

    with open(args.config_path, 'r', encoding='utf-8') as config_file:
        configs = json.load(config_file)

    model_data = ModelData(
        model_path=args.model_path,
        audio_path=files,
        result_path=args.result_path,
        process_method=args.process_method,
        device=device,
        save_background=args.save_background,
        **configs
    )

    uvr5_model = Inference(model_data, device)
    # Forward the device so the session is not silently created on the CPU.
    uvr5_model.load_model(args.model_path, args.num_processes, device)
    print(f"Loaded UVR5 model in {device}.")

    for file in files:
        save_dir = os.path.join(args.result_path, os.path.basename(file))
        os.makedirs(save_dir, exist_ok=True)
        uvr5_model.onnx_inference(file, save_dir, device)
def parallel_process(filenames, args):
    """Split ``filenames`` across worker processes and the visible GPUs.

    The list is cut into ``num_processes`` parts and each part into one
    sub-part per GPU; every non-empty sub-part is handed to ``process_batch``
    in its own task.

    Args:
        filenames: List of audio paths.
        args: Parsed CLI namespace; ``num_processes <= 0`` falls back to the
            CPU count.
    """
    total_gpu = torch.cuda.device_count()
    print(f'Total GPUs: {total_gpu}')

    num_processes = args.num_processes if args.num_processes > 0 else (os.cpu_count() or 1)
    # With no GPU, use a single CPU slot per process instead of creating a
    # pool with zero workers (which raises ValueError).
    slots = max(total_gpu, 1)
    use_cuda = total_gpu > 0 and torch.cuda.is_available()

    with ProcessPoolExecutor(max_workers=num_processes*slots) as executor:
        tasks = []
        for i in range(num_processes):
            start = int(i * len(filenames) / num_processes)
            end = int((i + 1) * len(filenames) / num_processes)
            file_chunk = filenames[start:end]
            for n in range(slots):
                chunk = file_chunk[int(n*len(file_chunk)/slots): int((n+1)*len(file_chunk)/slots)]
                if not chunk:
                    # Fewer files than workers: nothing to do for this slot.
                    continue
                device = f"cuda:{n}" if use_cuda else "cpu"
                print("load model in devices: ", num_processes, total_gpu, i, n, device)
                tasks.append(executor.submit(process_batch, chunk, args, device))
        for task in tqdm(tasks):
            task.result()
def parallel_process_cpu(filenames, args):
    """Split ``filenames`` across CPU worker processes.

    The list is cut into ``num_processes`` parts and every non-empty part is
    handed to ``process_batch`` on ``"cpu"``.

    Args:
        filenames: List of audio paths.
        args: Parsed CLI namespace; ``num_processes <= 0`` falls back to the
            CPU count.
    """
    num_processes = args.num_processes if args.num_processes > 0 else (os.cpu_count() or 1)

    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        tasks = []
        for i in range(num_processes):
            start = int(i * len(filenames) / num_processes)
            end = int((i + 1) * len(filenames) / num_processes)
            chunk = filenames[start:end]
            if not chunk:
                # Fewer files than workers: nothing to do for this worker.
                continue
            print("load model in devices: ", num_processes, i, "cpu")
            tasks.append(executor.submit(process_batch, chunk, args, "cpu"))
        for task in tqdm(tasks):
            task.result()
if __name__ == "__main__":
    # Command-line entry point: parse the options, collect the input files
    # and fan the work out over GPU or CPU worker processes.

    def _str2bool(value):
        """Parse a CLI boolean; ``type=bool`` would turn "False" into True."""
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model_path', type=str, default="models/MDX_Net_Models/model_data/Kim_Vocal_1.onnx", help='模型路径')
    parser.add_argument('-c', '--config_path', type=str, default="models/MDX_Net_Models/model_data/MDX-Net-Kim-Vocal1.json", help='配置文件路径')
    parser.add_argument('-a', '--audio_path', type=str, default="", help='wav文件名列表,放在raw文件夹下')
    parser.add_argument('-r', '--result_path', type=str, default="", help='结果存储路径')
    parser.add_argument('-p', '--process_method', type=str, default="MDX-Net", help='可选方法:["VR Arc", "MDX-Net"]')
    parser.add_argument('-b', '--save_background', type=_str2bool, default=True, help='True:保存人声和背景音,False:只保存人声')
    parser.add_argument('-w', '--num_processes', type=int, default=4, help='You are advised to set the number of processes to the same as the number of CPU cores')
    args = parser.parse_args()

    # os.makedirs("") raises, so an empty result path (the default) is skipped.
    if args.result_path:
        os.makedirs(args.result_path, exist_ok=True)
    if args.save_background:
        os.makedirs(os.path.join(os.path.dirname(args.result_path), "bg_music"), exist_ok=True)

    if os.path.isdir(args.audio_path):
        filenames = walkFile(args.audio_path, args.result_path)
    elif args.audio_path.endswith(".wav"):
        filenames = [args.audio_path]
    else:
        # Exit with a usage error instead of a NameError on `filenames`.
        parser.error(f"--audio_path must be a directory or a .wav file, got {args.audio_path!r}")
    print(len(filenames))

    multiprocessing.set_start_method("spawn", force=True)

    # 0 means "one process per CPU core"; store it on args because the
    # workers read args.num_processes.
    if args.num_processes == 0:
        args.num_processes = os.cpu_count() or 1

    if torch.cuda.is_available():
        parallel_process(filenames, args)
    else:
        parallel_process_cpu(filenames, args)