# LEMAS-TTS / uvr5/multiprocess_cuda_infer.py
# Approximetal's picture
# Update uvr5/multiprocess_cuda_infer.py
# b742966 verified
import logging
import multiprocessing
from concurrent.futures import ProcessPoolExecutor
from random import shuffle
from tqdm import tqdm
import sys, wave
import torch, torchaudio
import hashlib
import time, os, psutil
# Register the directory holding this bundled uvr5 copy on sys.path so that
# the imports that follow are resolved relative to it.
THIS_FILE = os.path.abspath(__file__)
UVR5_ROOT = os.path.split(THIS_FILE)[0]
if sys.path.count(UVR5_ROOT) == 0:
    sys.path.append(UVR5_ROOT)
from gui_data.constants import *
from lib_v5.vr_network.model_param_init import ModelParameters
import argparse, json
import onnx
import onnxruntime as ort
import traceback
from datetime import datetime
# Raise the threshold of these two third-party loggers to WARNING.
for _noisy_logger in ("numba", "matplotlib"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
class ModelData():
    """Settings for one UVR5 separation model, resolved from a config dict.

    The keyword arguments (normally the parsed JSON config) are copied onto
    the instance first and then converted to the types the rest of this file
    reads. The model file is fingerprinted with MD5 and that hash is used as
    the key into a ``model_data.json`` file to obtain per-model parameters.
    """

    def __init__(self,
                 model_path,
                 audio_path,
                 result_path,
                 process_method,
                 device,
                 save_background=True,
                 is_pre_proc_model=False,
                 base_dir=UVR5_ROOT,
                 **kwargs):
        """Build the model description.

        Args:
            model_path: Path of the model file; its MD5 selects the entry in
                the hash JSON.
            audio_path: A directory (its ``*.wav`` entries are used), a single
                ``.wav`` path, or a list of ``.wav`` paths.
            result_path: Output directory, stored as ``self.result_path``.
            process_method: Compared against ``VR_ARCH_TYPE`` and
                ``MDX_ARCH_TYPE`` to pick which parameters are loaded.
            device: Stored as ``self.device``.
            save_background: Stored as ``self.save_background``.
            is_pre_proc_model: Stored as ``self.is_pre_proc_model``.
            base_dir: Root under which the model JSON files are looked up.
            **kwargs: Config values copied verbatim onto the instance. The
                constructor reads ``model_name``, ``aggression_setting``,
                ``window_size``, ``batch_size``, ``mdx_batch_size``,
                ``crop_size``, ``is_high_end_process``,
                ``post_process_threshold``,
                ``vr_is_secondary_model_activate`` and
                ``mdx_is_secondary_model_activate``; the MDX branch also
                reads ``margin``, ``chunks`` and ``is_chunk_mdxnet``.

        Raises:
            AttributeError: A config key read above is missing from kwargs.
            KeyError: The model hash has no entry in the hash JSON.
        """
        # Config keys become attributes; the lines below re-type them.
        self.__dict__.update(kwargs)

        VR_MODELS_DIR = os.path.join(base_dir, 'models', 'VR_Models')
        VR_HASH_JSON = os.path.join(VR_MODELS_DIR, 'model_data', 'model_data.json')
        VR_PARAM_DIR = os.path.join(base_dir, 'lib_v5', 'vr_network', 'modelparams')
        MDX_MIXER_PATH = os.path.join(base_dir, 'lib_v5', 'mixer.ckpt')
        MDX_HASH_JSON = os.path.join(base_dir, 'model_data.json')

        # Self-assignment kept on purpose: it raises AttributeError right
        # away when the config does not provide ``model_name``.
        self.model_name = self.model_name
        self.aggression_setting = float(int(self.aggression_setting)/100)
        self.window_size = int(self.window_size)
        # str() lets the config supply batch_size as either "4" or 4.
        self.batch_size = int(self.batch_size) if str(self.batch_size).isdigit() else 1
        self.mdx_batch_size = 1 if self.mdx_batch_size == DEF_OPT else int(self.mdx_batch_size)
        self.is_mdx_ckpt = False
        self.crop_size = int(self.crop_size)
        self.is_high_end_process = 'mirroring' if self.is_high_end_process else 'None'
        self.post_process_threshold = float(self.post_process_threshold)
        self.model_capacity = 32, 128
        self.model_path = model_path
        self.result_path = result_path
        self.model_basename = os.path.splitext(os.path.basename(self.model_path))[0]
        self.mixer_path = MDX_MIXER_PATH
        self.process_method = process_method
        self.is_pre_proc_model = is_pre_proc_model
        self.vr_is_secondary_model = self.vr_is_secondary_model_activate
        self.mdx_is_secondary_model = self.mdx_is_secondary_model_activate
        self.is_ensemble_mode = False
        self.secondary_model = None
        self.primary_model_primary_stem = None
        self.primary_stem = None
        self.secondary_stem = None
        self.secondary_model_scale = None
        self.is_demucs_pre_proc_model_inst_mix = False
        self.device = device
        self.save_background = save_background

        if isinstance(audio_path, str) and os.path.isdir(audio_path):
            self.inputPaths = [os.path.join(audio_path, x)
                               for x in os.listdir(audio_path) if x[-4:] == '.wav']
        elif isinstance(audio_path, str) and audio_path[-4:] == '.wav':
            self.inputPaths = [audio_path]
        elif isinstance(audio_path, list) and (not audio_path or audio_path[0][-4:] == '.wav'):
            # An empty list is a valid (empty) batch; indexing [0] on it
            # used to raise IndexError.
            self.inputPaths = audio_path
        else:
            print(f"Invalid audio_path {audio_path}")
            self.inputPaths = []

        self.get_model_hash()

        if self.process_method == VR_ARCH_TYPE:
            self.model_data = self._load_hash_entry(VR_HASH_JSON)
            if self.model_data:
                vr_model_param = os.path.join(VR_PARAM_DIR, "{}.json".format(self.model_data["vr_model_param"]))
                self.primary_stem = self.model_data["primary_stem"]
                self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
                self.vr_model_param = ModelParameters(vr_model_param)
                self.model_samplerate = self.vr_model_param.param['sr']
                if "nout" in self.model_data.keys() and "nout_lstm" in self.model_data.keys():
                    self.model_capacity = self.model_data["nout"], self.model_data["nout_lstm"]
                    self.is_vr_51_model = True
            else:
                self.model_status = False

        if self.process_method == MDX_ARCH_TYPE:
            self.is_vr_51_model = False
            self.margin = int(self.margin)
            self.model_samplerate = self.margin
            self.chunks = self.determine_auto_chunks(self.chunks) if self.is_chunk_mdxnet else 0
            self.model_data = self._load_hash_entry(MDX_HASH_JSON)
            if self.model_data:
                self.is_secondary_model = self.mdx_is_secondary_model
                self.compensate = self.model_data["compensate"]
                self.mdx_dim_f_set = self.model_data["mdx_dim_f_set"]
                self.mdx_dim_t_set = self.model_data["mdx_dim_t_set"]
                self.mdx_n_fft_scale_set = self.model_data["mdx_n_fft_scale_set"]
                self.primary_stem = self.model_data["primary_stem"]
                self.secondary_stem = STEM_PAIR_MAPPER[self.primary_stem]
            else:
                self.model_status = False

    def _load_hash_entry(self, json_path):
        """Return the settings stored under ``self.model_hash`` in json_path.

        Raises:
            KeyError: The hash (or ``None`` when the model file is missing)
                is not a key of the JSON document.
        """
        with open(json_path, 'r', encoding='utf-8') as handle:
            return json.load(handle)[self.model_hash]

    def determine_auto_chunks(self, chunks):
        """Determines appropriate chunk size based on user computer specs"""
        gpu = 0 if torch.cuda.device_count() > 0 else -1
        if OPERATING_SYSTEM == 'Darwin':
            gpu = -1

        if chunks == BATCH_MODE:
            chunks = 0

        # Only measured in the 'Auto' + GPU branch; initialised here so the
        # summary print below no longer raises UnboundLocalError elsewhere.
        gpu_mem = None

        if chunks == 'Full':
            chunk_set = 0
        elif chunks == 'Auto':
            if gpu == 0:
                gpu_mem = round(torch.cuda.get_device_properties(0).total_memory/1.074e+9)
                if gpu_mem <= 6:
                    chunk_set = 5
                elif gpu_mem <= 15:
                    chunk_set = 10
                else:
                    chunk_set = 40
            else:
                # Whole GiB of system RAM.
                sys_mem = psutil.virtual_memory().total >> 30
                if sys_mem <= 4:
                    chunk_set = 1
                elif sys_mem <= 8:
                    chunk_set = 10
                elif sys_mem <= 16:
                    chunk_set = 25
                else:
                    chunk_set = 60
        elif chunks == '0':
            chunk_set = 0
        else:
            chunk_set = int(chunks)

        print("chunks: ", gpu_mem, chunk_set)
        return chunk_set

    def get_model_hash(self):
        """Set ``self.model_hash`` to an MD5 fingerprint of the model file.

        The digest covers the last 10000 * 1024 bytes of the file, or the
        whole file when it is shorter than that. When ``self.model_path`` is
        not a file, ``model_hash`` stays ``None`` and ``model_status`` is set
        to ``False``.
        """
        self.model_hash = None

        if not os.path.isfile(self.model_path):
            self.model_status = False
            return

        with open(self.model_path, 'rb') as f:
            try:
                f.seek(-10000 * 1024, 2)
            except OSError:
                # File is shorter than the tail window: hash all of it.
                f.seek(0)
            self.model_hash = hashlib.md5(f.read()).hexdigest()
class Inference():
    """Runs an MDX-Net ONNX model over audio files, chunk by chunk.

    Audio is cut into overlapping windows of ``chunk_size`` samples, turned
    into spectrograms with ``stft``, passed through the ONNX session, turned
    back into waveforms with ``istft`` and stitched together.
    """

    def __init__(self, model_data: "ModelData", device):
        """Copy the model parameters from model_data and derive STFT sizes.

        Args:
            model_data: Source of ``mdx_n_fft_scale_set``, ``mdx_dim_f_set``,
                ``mdx_dim_t_set``, ``compensate``, ``mdx_batch_size``,
                ``is_normalization``, ``is_denoise``, ``chunks``, ``margin``
                and ``save_background``.
            device: Torch device on which the window, padding and
                spectrograms are placed.
        """
        self.device = device
        self.n_fft = model_data.mdx_n_fft_scale_set
        self.is_normalization = model_data.is_normalization
        self.compensate = model_data.compensate
        self.dim_f, self.dim_t = model_data.mdx_dim_f_set, 2**model_data.mdx_dim_t_set
        self.mdx_batch_size = model_data.mdx_batch_size
        self.is_denoise = model_data.is_denoise
        self.hop = 1024
        self.dim_c = 4
        self.chunks = model_data.chunks
        self.margin = model_data.margin
        self.adjust = 1
        self.progress_value = 0
        self.n_bins = self.n_fft//2+1
        # Samples dropped from each side of every processed chunk.
        self.trim = self.n_fft//2
        self.chunk_size = self.hop * (self.dim_t-1)
        self.window = torch.hann_window(window_length=self.n_fft, periodic=False).to(self.device)
        # Zero bins appended in istft to go from dim_f back to n_bins.
        self.freq_pad = torch.zeros([1, self.dim_c, self.n_bins-self.dim_f, self.dim_t]).to(self.device)
        # Samples each chunk contributes once both trims are removed.
        self.gen_size = self.chunk_size-2*self.trim
        self.save_background = model_data.save_background

    @staticmethod
    def _output_path(save_dir, wav_path, suffix):
        """Return ``save_dir/<stem of wav_path><suffix>.wav``."""
        stem = os.path.splitext(os.path.basename(wav_path))[0]
        return os.path.join(save_dir, f"{stem}{suffix}.wav")

    def stft(self, x):
        """Convert waveform chunks to a real-valued spectrogram tensor.

        The input is flattened to rows of ``chunk_size`` samples; the result
        has shape ``[-1, dim_c, dim_f, dim_t]`` (real and imaginary parts are
        folded into the channel axis, and only the first ``dim_f`` bins are
        kept).
        """
        x = x.reshape([-1, self.chunk_size])
        x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window,
                       center=True, return_complex=True)
        x = torch.view_as_real(x)
        x = x.permute([0, 3, 1, 2])
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, self.dim_c, self.n_bins, self.dim_t])
        return x[:, :, :self.dim_f]

    def istft(self, x, freq_pad=None):
        """Invert ``stft``: spectrogram tensor back to ``[-1, 2, chunk_size]``.

        Args:
            x: Tensor shaped like the output of ``stft``.
            freq_pad: Optional bins to append along the frequency axis;
                defaults to zeros sized for the batch.
        """
        freq_pad = self.freq_pad.repeat([x.shape[0], 1, 1, 1]) if freq_pad is None else freq_pad
        x = torch.cat([x, freq_pad], -2)
        x = x.reshape([-1, 2, 2, self.n_bins, self.dim_t]).reshape([-1, 2, self.n_bins, self.dim_t])
        x = x.permute([0, 2, 3, 1])
        # view_as_complex needs the trailing (real, imag) pair contiguous.
        x = x.contiguous()
        x = torch.view_as_complex(x)
        x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
        return x.reshape([-1, 2, self.chunk_size])

    def load_model(self, model_path, threads, device='cpu'):
        """Create the ONNX Runtime session and the ``model_run`` callable.

        Args:
            model_path: Path of the ``.onnx`` file.
            threads: Value for ``intra_op_num_threads``.
            device: ``'cpu'`` or a CUDA device such as ``'cuda:1'``. The CUDA
                provider is used only when CUDA is available; the index in
                the device selects the GPU (the current device is used when
                no index is given).
        """
        model = onnx.load_model(model_path)
        target = torch.device(device)
        if torch.cuda.is_available() and target.type != 'cpu':
            # Honour the requested index instead of always taking the
            # process-wide current device.
            device_id = target.index if target.index is not None else torch.cuda.current_device()
            stream = torch.cuda.current_stream(device_id).cuda_stream
            providers = [("CUDAExecutionProvider", {"device_id": device_id,
                                                    "user_compute_stream": str(stream)})]
        else:
            providers = ["CPUExecutionProvider"]
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = threads
        # sess_options.enable_profiling = True  # enable when debugging
        self.ort_ = ort.InferenceSession(model.SerializeToString(), sess_options=sess_options, providers=providers)
        # The session is fed a CPU numpy copy under the input name 'input';
        # only its first output is used.
        self.model_run = lambda spek: self.ort_.run(None, {'input': spek.cpu().numpy()})[0]

    def initialize_mix(self, mix):
        """Pad a ``(2, n)`` waveform and cut it into overlapping chunks.

        Returns:
            ``(mix_waves, pad)`` where ``mix_waves`` has shape
            ``[num_chunks, 2, chunk_size]`` on ``self.device`` and ``pad`` is
            the number of zero samples appended after the signal (always in
            ``1..gen_size``).
        """
        n_sample = mix.shape[1]
        pad = self.gen_size - n_sample % self.gen_size
        zero_pad = torch.zeros((2, self.trim), device=mix.device)
        mix_p = torch.cat((zero_pad, mix, torch.zeros((2, pad), device=mix.device), zero_pad), 1)
        # Windows start every gen_size samples and span chunk_size samples.
        mix_waves = [mix_p[:, i:i+self.chunk_size].unsqueeze(0)
                     for i in range(0, n_sample + pad, self.gen_size)]
        mix_waves = torch.cat(mix_waves, 0).to(self.device)
        return mix_waves, pad

    def run_model(self, mix, is_match_mix=False):
        """Process one batch of chunks and return a ``(2, samples)`` waveform.

        Args:
            mix: Batch of chunks as produced by ``initialize_mix``.
            is_match_mix: When true the model is skipped and the spectrogram
                is inverted as-is, so the output is the mixture after the
                same STFT round trip as the model output.
        """
        spek = self.stft(mix.to(self.device))*self.adjust
        # The three lowest frequency bins are zeroed before inference.
        spek[:, :, :3, :] *= 0
        if is_match_mix:
            spec_pred = spek.to(self.device)
        else:
            if self.is_denoise:
                # Average of the prediction and the negated prediction on
                # the negated input.
                spec_pred = -self.model_run(-spek)*0.5+self.model_run(spek)*0.5
            else:
                spec_pred = self.model_run(spek)
            spec_pred = torch.from_numpy(spec_pred).to(self.device)
        return self.istft(spec_pred).to(self.device)[:, :, self.trim:-self.trim].transpose(0, 1).reshape(2, -1)

    def demix_base(self, mix, is_match_mix=False, device='cpu'):
        """Separate every chunk of ``mix`` and concatenate the results.

        Args:
            mix: Dict mapping chunk keys to ``(2, n)`` waveforms, in playback
                order. The chunk with key ``0`` keeps its start; the last
                chunk keeps its end; other edges lose ``self.margin``
                samples.
            is_match_mix: Forwarded to ``run_model``.
            device: Device the chunk is moved to before it is windowed.

        Returns:
            A ``(2, samples)`` tensor covering all chunks, or an empty list
            when ``mix`` is empty.
        """
        chunked_sources = []
        if not mix:
            return chunked_sources
        last_key = list(mix.keys())[-1]
        for key, mix_p in mix.items():
            mix_waves, pad = self.initialize_mix(mix_p.to(device))
            batches = mix_waves.split(self.mdx_batch_size)
            with torch.no_grad():
                tar_waves_ = [self.run_model(batch, is_match_mix=is_match_mix) for batch in batches]
            # Drop the zero padding that initialize_mix appended.
            tar_waves = torch.cat(tar_waves_, dim=-1)[:, :-pad]
            start = 0 if key == 0 else self.margin
            end = None if key == last_key or self.margin == 0 else -self.margin
            # Accumulate across chunks; previously each iteration replaced
            # the result so only the final chunk was returned.
            chunked_sources.append(tar_waves[:, start:end]*(1/self.adjust))
        return torch.cat(chunked_sources, dim=-1)

    def onnx_inference(self, wav_path, save_dir, device):
        """Separate one WAV file and write the result(s) into save_dir.

        Writes ``<stem>_vocal.wav`` and, when ``self.save_background`` is
        set, ``<stem>_background.wav`` (mixture minus compensated vocals),
        both at 44100 Hz.

        Returns:
            ``(vocal_path, bg_path)``; ``bg_path`` is ``""`` when no
            background file is written.
        """
        start_time = time.time()
        # Both paths are derived once from the file stem, so the files
        # written and the paths returned always agree.
        vocal_path = self._output_path(save_dir, wav_path, "_vocal")
        bg_path = self._output_path(save_dir, wav_path, "_background") if self.save_background else ""
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        input_audio, sr = torchaudio.load(wav_path, channels_first=True)
        input_audio = input_audio.to(device)
        if input_audio.shape[0] == 1:
            input_audio = torch.cat((input_audio, input_audio), 0)  # mono to stereo
        if sr != 44100:
            input_audio = torchaudio.functional.resample(input_audio.squeeze(), sr, 44100)

        output_audio = self.demix_base({0: input_audio.squeeze()}, is_match_mix=False, device=device)
        torchaudio.save(vocal_path, output_audio.cpu(), 44100)

        if self.save_background:
            raw_mix = self.demix_base({0: input_audio.squeeze()}, is_match_mix=True, device=device)
            secondary_source, raw_mix = normalize_two_stem(output_audio*self.compensate, raw_mix, self.is_normalization)
            secondary_source = (-secondary_source+raw_mix)
            torchaudio.save(bg_path, secondary_source.cpu(), 44100)

        process_time = time.time() - start_time
        print(f"{datetime.now()} {wav_path} denoised time: {process_time:.3f}s audio len: {output_audio.shape[-1]/44100:.3f}s RTF: {output_audio.shape[-1]/44100/process_time:.3f}")
        return vocal_path, bg_path
def normalize_two_stem(wave, mix, is_normalize=False):
    """Scale a stem and its mixture by the stem's peak when it would clip.

    When ``is_normalize`` is true and the absolute peak of ``wave`` exceeds
    1.0, both tensors are divided in place by that peak. Otherwise they are
    returned untouched.

    Args:
        wave: Tensor whose peak decides the scaling.
        mix: Tensor scaled by the same factor as ``wave``.
        is_normalize: Enable the scaling.

    Returns:
        The ``(wave, mix)`` pair (the same tensor objects that were passed).
    """
    # An empty tensor has no peak; .max() would raise on it.
    if not is_normalize or wave.numel() == 0:
        return wave, mix
    maxv = torch.abs(wave).max()
    if maxv > 1.0:
        wave /= maxv
        mix /= maxv
    return wave, mix
def get_wav_duration(file_path):
    """Return the playing time, in seconds, of the WAV file at file_path."""
    with wave.open(file_path, 'rb') as reader:
        # duration = number of audio frames / frames per second
        return reader.getnframes() / float(reader.getframerate())
def walkFile(data_dir, save_dir):
    """Collect the ``.wav`` files under data_dir that still need processing.

    A file is skipped when ``save_dir`` already holds a result for its stem:
    either ``<stem>_vocal.wav`` (the name ``Inference.onnx_inference`` writes)
    or the previously checked ``<stem>_Vocals.wav``.

    Args:
        data_dir: Directory walked recursively.
        save_dir: Directory checked for existing results.

    Returns:
        List of paths of the pending ``.wav`` files.
    """
    done_suffixes = ('_vocal.wav', '_Vocals.wav')
    res_wavs = []
    for root, _dirs, files in tqdm(os.walk(data_dir)):
        for f in files:
            if f[-4:] != '.wav':
                continue
            stem = f[:-4]
            already_done = any(
                os.path.exists(os.path.join(save_dir, stem + suffix))
                for suffix in done_suffixes
            )
            if not already_done:
                res_wavs.append(os.path.join(root, f))
    return res_wavs
def process_batch(files, args, device='cpu'):
    """Worker entry point: load the model once and separate every file.

    Args:
        files: List of ``.wav`` paths for this worker. An empty list is a
            no-op, so no model is loaded for it.
        args: Parsed command-line namespace; ``config_path``, ``model_path``,
            ``result_path``, ``process_method``, ``save_background`` and
            ``num_processes`` are read.
        device: ``'cpu'`` or a CUDA device string such as ``'cuda:0'``.
    """
    if not files:
        return

    with open(args.config_path, 'r', encoding='utf-8') as config_file:
        configs = json.load(config_file)

    # Make the requested GPU this process's current device so that every
    # CUDA call made by the worker targets it.
    target = torch.device(device)
    if target.type == 'cuda' and target.index is not None and torch.cuda.is_available():
        torch.cuda.set_device(target)

    model_data = ModelData(
        model_path=args.model_path,
        audio_path=files,
        result_path=args.result_path,
        process_method=args.process_method,
        device=device,
        save_background=args.save_background,
        **configs
    )
    uvr5_model = Inference(model_data, device)
    # Forward the device: without it load_model always picks the CPU
    # execution provider.
    uvr5_model.load_model(args.model_path, args.num_processes, device)
    print(f"Loaded UVR5 model in {device}.")
    for file in files:
        # Results go straight into result_path, which is where walkFile
        # looks for finished files.
        uvr5_model.onnx_inference(file, args.result_path, device)
def parallel_process(filenames, args):
    """Split filenames over ``num_processes`` groups x GPUs and run them.

    The list is first cut into ``args.num_processes`` consecutive groups and
    each group is then cut into one slice per visible GPU. Every non-empty
    slice becomes one ``process_batch`` task in its own worker process. When
    no GPU is visible the slices run on the CPU instead.

    Args:
        filenames: Paths to process; nothing is started when it is empty.
        args: Parsed command-line namespace, forwarded to ``process_batch``.
    """
    total_gpu = torch.cuda.device_count()
    print(f'Total GPUs: {total_gpu}')
    if not filenames:
        return

    if torch.cuda.is_available() and total_gpu > 0:
        devices = [f"cuda:{n}" for n in range(total_gpu)]
    else:
        devices = ["cpu"]
    groups = max(1, args.num_processes)

    jobs = []
    for i in range(groups):
        start = int(i * len(filenames) / groups)
        end = int((i + 1) * len(filenames) / groups)
        file_chunk = filenames[start:end]
        for n, device in enumerate(devices):
            lo = int(n * len(file_chunk) / len(devices))
            hi = int((n + 1) * len(file_chunk) / len(devices))
            chunk = file_chunk[lo:hi]
            if not chunk:
                # Fewer files than slots: do not start a worker for nothing.
                continue
            print("load model in devices: ", args.num_processes, total_gpu, i, n, device)
            jobs.append((chunk, device))

    with ProcessPoolExecutor(max_workers=len(jobs)) as executor:
        tasks = [executor.submit(process_batch, chunk, args, device) for chunk, device in jobs]
        for task in tqdm(tasks):
            # Re-raises any exception from the worker.
            task.result()
def parallel_process_cpu(filenames, args):
    """Split filenames into ``num_processes`` groups and run them on the CPU.

    Every non-empty group becomes one ``process_batch`` task in its own
    worker process.

    Args:
        filenames: Paths to process; nothing is started when it is empty.
        args: Parsed command-line namespace, forwarded to ``process_batch``.
    """
    if not filenames:
        return

    groups = max(1, args.num_processes)
    jobs = []
    for i in range(groups):
        start = int(i * len(filenames) / groups)
        end = int((i + 1) * len(filenames) / groups)
        chunk = filenames[start:end]
        if not chunk:
            # Fewer files than processes: skip the empty share.
            continue
        print("load model in devices: ", args.num_processes, i, "cpu")
        jobs.append(chunk)

    with ProcessPoolExecutor(max_workers=len(jobs)) as executor:
        tasks = [executor.submit(process_batch, chunk, args, "cpu") for chunk in jobs]
        for task in tqdm(tasks):
            # Re-raises any exception from the worker.
            task.result()
if __name__ == "__main__":
    def _str2bool(value):
        """argparse converter for booleans.

        ``type=bool`` treats every non-empty string (including "False") as
        True, so the text is parsed explicitly instead.
        """
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ("1", "true", "t", "yes", "y", "on"):
            return True
        if text in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    # Model file path.
    parser.add_argument('-m', '--model_path', type=str, default="models/MDX_Net_Models/model_data/Kim_Vocal_1.onnx", help='模型路径')
    # Config file path.
    parser.add_argument('-c', '--config_path', type=str, default="models/MDX_Net_Models/model_data/MDX-Net-Kim-Vocal1.json", help='配置文件路径')
    # A .wav file, or a directory that is searched recursively for .wav files.
    parser.add_argument('-a', '--audio_path', type=str, default="", help='wav文件名列表,放在raw文件夹下')
    # Directory that receives the results.
    parser.add_argument('-r', '--result_path', type=str, default="", help='结果存储路径')
    # Separation method.
    parser.add_argument('-p', '--process_method', type=str, default="MDX-Net", help='可选方法:["VR Arc", "MDX-Net"]')
    # True: save vocals and background; False: save vocals only.
    parser.add_argument('-b', '--save_background', type=_str2bool, default=True, help='True:保存人声和背景音,False:只保存人声')
    parser.add_argument('-w', '--num_processes', type=int, default=4, help='You are advised to set the number of processes to the same as the number of CPU cores')
    args = parser.parse_args()

    if not args.result_path:
        # os.makedirs("") would fail with an unhelpful FileNotFoundError.
        parser.error("--result_path is required")
    os.makedirs(args.result_path, exist_ok=True)
    if args.save_background:
        os.makedirs(os.path.join(os.path.dirname(args.result_path), "bg_music"), exist_ok=True)

    if os.path.isdir(args.audio_path):
        filenames = walkFile(args.audio_path, args.result_path)
    elif args.audio_path.endswith(".wav"):
        filenames = [args.audio_path]
    else:
        parser.error(f"--audio_path must be a directory or a .wav file, got {args.audio_path!r}")
    print(len(filenames))

    multiprocessing.set_start_method("spawn", force=True)

    # 0 (or less) means "one process per CPU core"; store it on args because
    # that is what the parallel_* functions and the workers read.
    if args.num_processes <= 0:
        args.num_processes = os.cpu_count() or 1

    if filenames:
        if torch.cuda.is_available():
            parallel_process(filenames, args)
        else:
            parallel_process_cpu(filenames, args)