# LEMAS-Edit / uvr5 / multiprocess_cuda_infer.py
# Uploaded via huggingface_hub (commit f36e46d, verified) by Approximetal.
import logging
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
from random import shuffle
from tqdm import tqdm
import sys, wave
import torch, torchaudio
import hashlib
import time, os, psutil
# Make sure we resolve imports relative to this bundled copy of uvr5
THIS_FILE = os.path.abspath(__file__)
UVR5_ROOT = os.path.dirname(THIS_FILE)
if UVR5_ROOT not in sys.path:
sys.path.append(UVR5_ROOT)
from gui_data.constants import *
from lib_v5.vr_network.model_param_init import ModelParameters
import argparse, json
import onnx
import onnxruntime as ort
import traceback
from datetime import datetime
logging.getLogger("numba").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
class ModelData():
    """Resolve per-model configuration (paths, stems, architecture params) for
    one UVR5 separation run.

    Extra configuration is injected through ``**kwargs`` (loaded from a JSON
    config file by the caller) and becomes instance attributes via
    ``__dict__.update``; the code below reads those attributes back, so the
    config must supply e.g. ``model_name``, ``aggression_setting``,
    ``window_size``, ``batch_size``, ``margin``.
    """

    def __init__(self,
                 model_path,
                 audio_path,
                 result_path,
                 process_method,
                 device,
                 save_background=True,
                 is_pre_proc_model=False,
                 base_dir=None,
                 **kwargs):
        # Apply the per-model JSON config first: everything below reads it back
        # as instance attributes.
        self.__dict__.update(kwargs)
        if base_dir is None:
            # Default to the directory holding this bundled copy of uvr5
            # (same value the module-level UVR5_ROOT resolves to).
            base_dir = os.path.dirname(os.path.abspath(__file__))

        BASE_PATH = result_path
        VR_MODELS_DIR = os.path.join(base_dir, 'models', 'VR_Models')
        VR_HASH_JSON = os.path.join(VR_MODELS_DIR, 'model_data', 'model_data.json')
        VR_PARAM_DIR = os.path.join(base_dir, 'lib_v5', 'vr_network', 'modelparams')
        SAMPLE_CLIP_PATH = os.path.join(BASE_PATH, 'temp_sample_clips')
        MDX_MIXER_PATH = os.path.join(base_dir, 'lib_v5', 'mixer.ckpt')
        MDX_HASH_JSON = os.path.join(base_dir, 'model_data.json')
        MDX_MODEL_NAME_SELECT = os.path.join(base_dir, 'model_name_mapper.json')

        # Reading the attribute back raises AttributeError early if the config
        # file did not provide model_name.
        self.model_name = self.model_name
        self.aggression_setting = float(int(self.aggression_setting)/100) # 1 - 20
        self.window_size = int(self.window_size)
        self.batch_size = int(self.batch_size) if self.batch_size.isdigit() else 1
        self.mdx_batch_size = 1 if self.mdx_batch_size == DEF_OPT else int(self.mdx_batch_size)
        self.is_mdx_ckpt = False
        self.crop_size = int(self.crop_size)
        # Note: downstream code expects the literal strings 'mirroring' / 'None'.
        self.is_high_end_process = 'mirroring' if self.is_high_end_process else 'None'
        self.post_process_threshold = float(self.post_process_threshold)
        self.model_capacity = 32, 128
        self.model_path = model_path
        self.result_path = result_path
        self.model_basename = os.path.splitext(os.path.basename(self.model_path))[0]
        self.mixer_path = MDX_MIXER_PATH
        self.process_method = process_method
        self.is_pre_proc_model = is_pre_proc_model
        self.vr_is_secondary_model = self.vr_is_secondary_model_activate
        self.mdx_is_secondary_model = self.mdx_is_secondary_model_activate
        self.is_ensemble_mode = False
        self.secondary_model = None
        self.primary_model_primary_stem = None
        self.primary_stem = None
        self.secondary_stem = None
        self.secondary_model_scale = None
        self.is_demucs_pre_proc_model_inst_mix = False
        self.device = device
        self.save_background = save_background

        # Normalise audio_path into a list of .wav file paths.
        if isinstance(audio_path, str) and os.path.isdir(audio_path):
            self.inputPaths = [os.path.join(audio_path, x)
                               for x in os.listdir(audio_path) if x.endswith('.wav')]
        elif isinstance(audio_path, str) and audio_path.endswith('.wav'):
            self.inputPaths = [audio_path]
        elif isinstance(audio_path, list) and audio_path[0].endswith('.wav'):
            self.inputPaths = audio_path
        else:
            print(f"Invalid audio_path {audio_path}")
            # FIX: leave a defined (empty) list instead of an unset attribute,
            # so later iteration fails soft rather than with AttributeError.
            self.inputPaths = []

        self.get_model_hash()
        if self.process_method == VR_ARCH_TYPE:
            with open(VR_HASH_JSON, 'r', encoding='utf-8') as fh:
                self.model_data = json.load(fh)[self.model_hash]
            if self.model_data:
                vr_model_param = os.path.join(VR_PARAM_DIR, "{}.json".format(self.model_data["vr_model_param"]))
                self.primary_stem = self.model_data["primary_stem"]
                self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
                self.vr_model_param = ModelParameters(vr_model_param)
                self.model_samplerate = self.vr_model_param.param['sr']
                if "nout" in self.model_data.keys() and "nout_lstm" in self.model_data.keys():
                    # 5.1-style VR models carry explicit capacity settings.
                    self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"]
                    self.is_vr_51_model = True
            else:
                self.model_status = False
        if self.process_method == MDX_ARCH_TYPE:
            self.is_vr_51_model = False
            self.margin = int(self.margin)
            # NOTE(review): samplerate is copied from `margin` here — looks like
            # the config uses 44100 for both; confirm against the config files.
            self.model_samplerate = self.margin
            self.chunks = self.determine_auto_chunks(self.chunks) if self.is_chunk_mdxnet else 0
            with open(MDX_HASH_JSON, 'r', encoding='utf-8') as fh:
                self.model_data = json.load(fh)[self.model_hash]
            if self.model_data:
                self.is_secondary_model = self.mdx_is_secondary_model
                self.compensate = self.model_data["compensate"]
                self.mdx_dim_f_set = self.model_data["mdx_dim_f_set"]
                self.mdx_dim_t_set = self.model_data["mdx_dim_t_set"]
                self.mdx_n_fft_scale_set = self.model_data["mdx_n_fft_scale_set"]
                self.primary_stem = self.model_data["primary_stem"]
                self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
            else:
                self.model_status = False

    def determine_auto_chunks(self, chunks):
        """Determines appropriate chunk size based on user computer specs."""
        gpu = 0 if torch.cuda.device_count() > 0 else -1
        if OPERATING_SYSTEM == 'Darwin':
            gpu = -1
        if chunks == BATCH_MODE:
            chunks = 0
        chunk_set = 0
        if chunks == 'Full':
            chunk_set = 0
        elif chunks == 'Auto':
            if gpu == 0:
                # Total VRAM in GiB (1.074e9 bytes per GiB, rounded).
                gpu_mem = round(torch.cuda.get_device_properties(0).total_memory/1.074e+9)
                if gpu_mem <= 6:
                    chunk_set = 5
                elif gpu_mem <= 15:
                    chunk_set = 10
                else:  # >= 16 GiB
                    chunk_set = 40
            if gpu == -1:
                # System RAM in GiB.
                sys_mem = psutil.virtual_memory().total >> 30
                if sys_mem <= 4:
                    chunk_set = 1
                elif sys_mem <= 8:
                    chunk_set = 10
                elif sys_mem <= 16:
                    chunk_set = 25
                else:  # >= 17 GiB
                    chunk_set = 60
        elif chunks == '0':
            chunk_set = 0
        else:
            chunk_set = int(chunks)
        # BUG FIX: the original printed `gpu_mem`, which is unbound unless the
        # 'Auto' GPU branch ran, raising NameError on every other input.
        print("chunks: ", chunk_set)
        return chunk_set

    def get_model_hash(self):
        """Fingerprint the model file with md5 (last ~10 MB when possible).

        Sets ``self.model_hash`` (or ``self.model_status = False`` when the
        file is missing).
        """
        self.model_hash = None
        if not os.path.isfile(self.model_path):
            # BUG FIX: the original had the no-op `self.model_hash is None`
            # comparison here; the assignment above already nulls the hash.
            self.model_status = False
        elif not self.model_hash:
            try:
                # Hash only the tail of large checkpoints for speed.
                with open(self.model_path, 'rb') as f:
                    f.seek(-10000 * 1024, 2)
                    self.model_hash = hashlib.md5(f.read()).hexdigest()
            except OSError:
                # File smaller than ~10 MB: seeking back from EOF fails,
                # so hash the whole file instead. (Was a bare `except:`.)
                with open(self.model_path, 'rb') as f:
                    self.model_hash = hashlib.md5(f.read()).hexdigest()
class Inference():
    """MDX-Net ONNX inference engine.

    Splits a stereo mixture into fixed-size chunks, runs each chunk's
    spectrogram through the ONNX model, and stitches the trimmed outputs back
    into a full-length vocal (and optionally background) track.
    """

    def __init__(self, model_data: ModelData, device):
        # model_data: fully initialised ModelData carrying the MDX params.
        # device: torch device string tensors are staged on (e.g. "cuda:0").
        self.device = device
        self.n_fft = model_data.mdx_n_fft_scale_set
        self.is_normalization = model_data.is_normalization
        self.compensate = model_data.compensate
        # dim_t is stored as an exponent in the model config.
        self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set
        self.mdx_batch_size = model_data.mdx_batch_size
        self.is_denoise = model_data.is_denoise
        self.hop = 1024          # STFT hop length in samples
        self.dim_c = 4           # channels fed to the model: 2 (stereo) x 2 (re/im)
        self.chunks = model_data.chunks
        self.margin = model_data.margin
        self.adjust = 1          # global gain applied before / undone after the model
        self.progress_value = 0
        self.n_bins = self.n_fft//2+1
        self.trim = self.n_fft//2
        self.chunk_size = self.hop * (self.dim_t-1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(self.device)
        # Zero block used by istft() to restore the bins cropped above dim_f.
        self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins-self.dim_f, self.dim_t]).to(self.device)
        # Usable samples per chunk once the trim on both ends is removed.
        self.gen_size = self.chunk_size-2*self.trim
        self.save_background = model_data.save_background

    def stft(self, x):
        # Batched STFT: (B*2, chunk_size) waveform -> (B, dim_c, dim_f, dim_t)
        # with stereo channels and real/imag planes folded into dim_c.
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True,return_complex=True)
        x=torch.view_as_real(x)
        x = x.permute([0,3,1,2])
        x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t])
        # Crop to the dim_f bins the model was trained on.
        return x[:,:,:self.dim_f]

    def istft(self, x, freq_pad=None):
        # Inverse of stft(): re-pad the cropped high bins with zeros, unfold the
        # channel layout, and reconstruct (B, 2, chunk_size) waveforms.
        freq_pad = self.freq_pad.repeat([x.shape[0],1,1,1]) if freq_pad is None else freq_pad
        x = torch.cat([x, freq_pad], -2)
        x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
        x = x.permute([0,2,3,1])
        x=x.contiguous()
        x=torch.view_as_complex(x)
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
        return x.reshape([-1,2,self.chunk_size])

    def load_model(self, model_path, threads):
        # Build an onnxruntime session; on GPU, reuse PyTorch's current CUDA
        # stream so ORT and torch ops are serialised on the same stream.
        model = onnx.load_model(model_path)
        if torch.cuda.is_available():
            providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
                                                    "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
        else:
            providers = ["CPUExecutionProvider"]
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = threads
        # sess_options.enable_profiling = True  # enable when debugging
        self.ort_ = ort.InferenceSession(model.SerializeToString(), sess_options=sess_options, providers=providers)
        # The session consumes/produces numpy arrays; wrap it for run_model().
        self.model_run = lambda spek:self.ort_.run(None, {'input': spek.cpu().numpy()})[0]

    def initialize_mix(self, mix):
        # Pad the (2, n_samples) mixture with `trim` zeros on both ends (plus a
        # tail pad up to a multiple of gen_size) and cut it into overlapping
        # chunks of chunk_size, stepping by gen_size.
        mix_waves = []
        n_sample = mix.shape[1]
        pad = self.gen_size - n_sample%self.gen_size
        zero_pad = torch.zeros((2,self.trim), device=mix.device)
        mix_p = torch.cat((zero_pad, mix, torch.zeros((2,pad), device=mix.device), zero_pad), 1)
        i = 0
        while i < n_sample + pad:
            waves = mix_p[:, i:i+self.chunk_size]
            mix_waves.append(waves.unsqueeze(0))
            i += self.gen_size
        mix_waves = torch.cat(mix_waves, 0).to(self.device)
        return mix_waves, pad

    def run_model(self, mix, is_match_mix=False):
        # STFT the chunk, then either pass the spectrogram straight through
        # (match-mix mode) or run it through the ONNX model.
        spek = self.stft(mix.to(self.device))*self.adjust
        # Zero the 3 lowest frequency bins before inference.
        spek[:, :, :3, :] *= 0
        if is_match_mix:
            spec_pred = spek.to(self.device)
        else:
            # Denoise trick: average the model's response to spek and -spek.
            spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5 if self.is_denoise else self.model_run(spek)
            spec_pred = torch.from_numpy(spec_pred).to(self.device)
        # Trim the STFT warm-up regions and interleave back to (2, samples).
        return self.istft(spec_pred).to(self.device)[:,:,self.trim:-self.trim].transpose(0,1).reshape(2, -1)

    def demix_base(self, mix, is_match_mix=False, device='cpu'):
        # mix: dict {slice_index: (2, n_samples) tensor}. Each entry is chunked,
        # run through the model in mdx_batch_size batches, and the outputs are
        # concatenated along time with slice margins trimmed.
        chunked_sources = []
        for slice in mix:  # NOTE: loop variable shadows the builtin `slice`
            sources = []
            tar_waves_ = []
            mix_p = mix[slice]
            mix_waves, pad = self.initialize_mix(mix_p.to(device))
            mix_waves = mix_waves.split(self.mdx_batch_size)
            with torch.no_grad():
                for mix_wave in mix_waves:
                    tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
                    tar_waves_.append(tar_waves)
                # Drop the tail zero-padding added by initialize_mix.
                tar_waves = torch.cat(tar_waves_, axis=-1)[:, :-pad]
                # Keep margins only at the outer edges of the first/last slice.
                start = 0 if slice == 0 else self.margin
                end = None if slice == list(mix.keys())[::-1][0] or self.margin == 0 else -self.margin
                sources.append(tar_waves[:,start:end]*(1/self.adjust))
            chunked_sources = torch.cat(sources, axis=-1)
        sources = chunked_sources
        return sources

    def onnx_inference(self, wav_path, save_dir, device):
        # Separate one wav file. Writes "<name>_vocal.wav" (and, when
        # save_background is set, "<name>_background.wav") into save_dir and
        # returns (vocal_path, bg_path); bg_path is "" when not saved.
        start_time = time.time()
        input_audio, sr = torchaudio.load(wav_path, channels_first=True)
        input_audio = input_audio.to(device)
        if input_audio.shape[0] == 1:
            input_audio = torch.cat((input_audio, input_audio), 0) # mono to stereo
        if sr != 44100:
            # The model operates at 44.1 kHz; outputs are saved at that rate.
            input_audio = torchaudio.functional.resample(input_audio.squeeze(), sr, 44100)
        output_audio = self.demix_base({0:input_audio.squeeze()}, is_match_mix=False, device=device)
        torchaudio.save(
            os.path.join(save_dir, os.path.basename(wav_path).replace(".wav", "_vocal.wav")),
            output_audio.cpu(),
            44100,
        )
        if self.save_background:
            # Re-run in match-mix mode to get a reference mixture, then subtract
            # the (compensated) vocal estimate to obtain the background stem.
            raw_mix = self.demix_base({0:input_audio.squeeze()}, is_match_mix=True)
            secondary_source, raw_mix = normalize_two_stem(output_audio*self.compensate, raw_mix, self.is_normalization)
            secondary_source = (-secondary_source+raw_mix)
            torchaudio.save(
                # NOTE(review): .replace() is applied to the joined path here,
                # unlike the vocal save above — confirm this is intentional.
                os.path.join(save_dir, os.path.basename(wav_path)).replace(".wav", "_background.wav"),
                secondary_source.cpu(),
                44100,
            )
        process_time = time.time() - start_time
        print(f"{datetime.now()} {wav_path} denoised time: {process_time:.3f}s audio len: {output_audio.shape[-1]/44100:.3f}s RTF: {output_audio.shape[-1]/44100/process_time:.3f}")
        vocal_path = os.path.join(save_dir, os.path.basename(wav_path).replace(".wav", "_vocal.wav"))
        bg_path = os.path.join(save_dir, os.path.basename(wav_path).replace(".wav", "_background.wav")) if self.save_background else ""
        return vocal_path, bg_path
def normalize_two_stem(wave, mix, is_normalize=False):
    """Jointly rescale the primary stem and the mixture by the stem's peak.

    When the absolute peak of ``wave`` clips (> 1.0) and ``is_normalize`` is
    set, BOTH tensors are divided in place by that same peak so their relative
    level is preserved; otherwise both are returned unchanged.

    Note: the parameter name ``wave`` shadows the stdlib ``wave`` module
    imported at file level (kept for interface compatibility).

    :param wave: primary stem tensor (modified in place when normalizing)
    :param mix: mixture tensor (modified in place when normalizing)
    :param is_normalize: enable peak normalization on clipping
    :return: (wave, mix) tuple
    """
    # FIX: corrected the misleading docstring ("Save output music files") and
    # dropped the unused `max_mix` computation.
    maxv = torch.abs(wave).max()
    if maxv > 1.0 and is_normalize:
        wave /= maxv
        mix /= maxv
    return wave, mix
def get_wav_duration(file_path):
    """Return the duration of a WAV file in seconds."""
    with wave.open(file_path, 'rb') as wav_file:
        frames = wav_file.getnframes()    # total number of audio frames
        rate = wav_file.getframerate()    # sample rate in Hz
    # duration (seconds) = frames / sample rate
    return frames / float(rate)
def walkFile(data_dir, save_dir):
    """Recursively collect .wav files under `data_dir` that do not yet have a
    separated output in `save_dir`."""
    res_wavs = []
    for root, _dirs, files in tqdm(os.walk(data_dir)):
        for name in files:
            if not name.endswith('.wav'):
                continue
            # NOTE(review): this skip-check looks for "<name>_Vocals.wav" while
            # Inference.onnx_inference writes "<name>_vocal.wav" — confirm
            # which naming scheme is current; as-is, nothing is ever skipped.
            done_marker = os.path.join(save_dir, f'{name[:-4]}_Vocals.wav')
            if not os.path.exists(done_marker):
                res_wavs.append(os.path.join(root, name))
    return res_wavs
def process_batch(files, args, device='cpu'):
    """Worker entry point: load the UVR5 model once on `device`, then separate
    every file in `files` sequentially."""
    with open(args.config_path, 'r', encoding='utf-8') as cfg:
        configs = json.load(cfg)
    model_data = ModelData(
        model_path=args.model_path,
        audio_path=files,
        result_path=args.result_path,
        process_method=args.process_method,
        device=device,
        save_background=args.save_background,
        **configs,
    )
    uvr5_model = Inference(model_data, device)
    uvr5_model.load_model(args.model_path, args.num_processes)
    print(f"Loaded UVR5 model in {device}.")
    for file in files:
        out_dir = os.path.join(args.result_path, os.path.basename(file))
        uvr5_model.onnx_inference(file, out_dir, device)
def parallel_process(filenames, args):
    """Shard `filenames` across num_processes workers per visible GPU and run
    process_batch on each shard in a separate process."""
    total_gpu = torch.cuda.device_count()
    print(f'Total GPUs: {total_gpu}')
    with ProcessPoolExecutor(max_workers=args.num_processes*total_gpu) as executor:
        futures = []
        n_files = len(filenames)
        for proc_idx in range(args.num_processes):
            # Even split of the whole list over the worker processes.
            start = int(proc_idx * n_files / args.num_processes)
            end = int((proc_idx + 1) * n_files / args.num_processes)
            proc_files = filenames[start:end]
            n_proc_files = len(proc_files)
            for gpu_idx in range(total_gpu):
                # Further split each worker's share evenly over the GPUs.
                shard = proc_files[int(gpu_idx*n_proc_files/total_gpu): int((gpu_idx+1)*n_proc_files/total_gpu)]
                device = f"cuda:{gpu_idx}" if torch.cuda.is_available() else "cpu"
                print("load model in devices: ", args.num_processes, total_gpu, proc_idx, gpu_idx, device)
                futures.append(executor.submit(process_batch, shard, args, device))
        # Block until every shard finished; .result() re-raises worker errors.
        for fut in tqdm(futures):
            fut.result()
def parallel_process_cpu(filenames, args):
    """CPU-only variant: shard `filenames` evenly over num_processes worker
    processes, each running process_batch on device "cpu"."""
    with ProcessPoolExecutor(max_workers=args.num_processes) as executor:
        futures = []
        n_files = len(filenames)
        for worker in range(args.num_processes):
            lo = int(worker * n_files / args.num_processes)
            hi = int((worker + 1) * n_files / args.num_processes)
            shard = filenames[lo:hi]
            print("load model in devices: ", args.num_processes, worker, "cpu")
            futures.append(executor.submit(process_batch, shard, args, "cpu"))
        # Block until all shards finish; .result() re-raises worker exceptions.
        for fut in tqdm(futures):
            fut.result()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_path', type=str, default="models/MDX_Net_Models/model_data/Kim_Vocal_1.onnx", help='模型路径')
parser.add_argument('-c', '--config_path', type=str, default="models/MDX_Net_Models/model_data/MDX-Net-Kim-Vocal1.json", help='配置文件路径')
parser.add_argument('-a', '--audio_path', type=str, default="", help='wav文件名列表,放在raw文件夹下')
parser.add_argument('-r', '--result_path', type=str, default="", help='结果存储路径')
parser.add_argument('-p', '--process_method', type=str, default="MDX-Net", help='可选方法:["VR Arc", "MDX-Net"]')
parser.add_argument('-b', '--save_background', type=bool, default=True, help='True:保存人声和背景音,False:只保存人声')
parser.add_argument('-w', '--num_processes', type=int, default=4, help='You are advised to set the number of processes to the same as the number of CPU cores')
args = parser.parse_args()
if not os.path.exists(args.result_path):
os.makedirs(args.result_path, exist_ok=True)
if args.save_background:
os.makedirs(os.path.join(os.path.dirname(args.result_path), "bg_music"), exist_ok=True)
if os.path.isdir(args.audio_path):
filenames = walkFile(args.audio_path, args.result_path)
elif args.audio_path.endswith(".wav"):
filenames = [args.audio_path]
# shuffle(filenames)
print(len(filenames))
# process_batch(filenames, args, "cpu")
multiprocessing.set_start_method("spawn", force=True)
num_processes = args.num_processes
if num_processes == 0:
num_processes = os.cpu_count()
if torch.cuda.is_available():
parallel_process(filenames, args)
else:
parallel_process_cpu(filenames, args)