|
|
import time
|
|
|
import json
|
|
|
import torch
|
|
|
import multiprocessing
|
|
|
from tqdm import tqdm
|
|
|
import concurrent.futures
|
|
|
from model import PromptCondAudioDiffusion
|
|
|
from diffusers import DDIMScheduler, DDPMScheduler
|
|
|
import torchaudio
|
|
|
import librosa
|
|
|
import os
|
|
|
import math
|
|
|
import numpy as np
|
|
|
from tools.get_melvaehifigan48k import build_pretrained_models
|
|
|
import tools.torch_tools as torch_tools
|
|
|
from safetensors.torch import load_file
|
|
|
import subprocess
|
|
|
|
|
|
def get_free_gpu(query_output=None) -> int:
    """Return the index of the GPU that currently has the most free memory.

    Args:
        query_output: Optional pre-captured text in the format produced by
            ``nvidia-smi --query-gpu=index,memory.free
            --format=csv,noheader,nounits`` (one ``index, free_memory`` pair
            per line). When ``None`` (the default) ``nvidia-smi`` is invoked
            to obtain it.

    Returns:
        The GPU index with the largest free-memory value. Ties are resolved
        in favour of the GPU listed first.

    Raises:
        RuntimeError: If the output contains no GPU lines at all.
        ValueError: If a non-blank line is not an ``int, int`` pair.
        subprocess.CalledProcessError, OSError: If ``nvidia-smi`` fails or
            cannot be executed.
    """
    if query_output is None:
        cmd = [
            "nvidia-smi",
            "--query-gpu=index,memory.free",
            "--format=csv,noheader,nounits",
        ]
        query_output = subprocess.check_output(cmd).decode()

    free_list = []
    for line in query_output.splitlines():
        line = line.strip()
        if not line:
            # Tolerate blank lines instead of failing on tuple unpacking.
            continue
        idx, free_mem = line.split(",")
        free_list.append((int(idx), int(free_mem)))

    if not free_list:
        raise RuntimeError("nvidia-smi reported no GPUs")

    # max() returns the first maximal element, matching the original
    # stable descending sort followed by taking element 0.
    return max(free_list, key=lambda item: item[1])[0]
|
|
|
|
|
|
# Directory that receives one "<name>.pt" code tensor per encoded audio file.
SAVE_DIR = "xxx/suno_mucodec_en"
# JSONL manifest of inputs; each record is read through its 'path' key.
DATA_PATH = "xxx/paths_en.jsonl"
# Device string for the GPU chosen by get_free_gpu(). This runs at import
# time, so every process that imports the module performs its own query.
DEVICE = f"cuda:{get_free_gpu()}"
print(f"Use {DEVICE}")
|
|
|
|
|
|
class MuCodec:
    """Audio <-> discrete code converter built on PromptCondAudioDiffusion.

    ``sound2code`` / ``file2code`` turn stereo 48 kHz audio into a code
    tensor of shape ``(1, 1, T)`` at 25 codes per second of audio.
    ``code2sound`` decodes such codes back to a waveform through the
    diffusion model and the VAE/vocoder loaded in ``__init__``.
    """

    def __init__(self, model_path, layer_num, load_main_model=True):
        """Build the model, load its checkpoint and move it to ``DEVICE``.

        Args:
            model_path: Checkpoint path. Files ending in ``.safetensors`` are
                read with ``safetensors``; anything else with ``torch.load``.
            layer_num: 1-based layer number; stored as ``layer_num - 1`` and
                passed to ``fetch_codes_batch`` as ``layer``.
            load_main_model: When true, also load the VAE/STFT pair from
                ``tools/audioldm_48k.pth`` and use the transformer2D UNet
                config. When false neither ``self.vae`` nor ``self.stft`` is
                created, so ``code2sound`` cannot be used.
        """
        self.layer_num = layer_num - 1
        self.sample_rate = 48000
        self.device = DEVICE

        self.MAX_DURATION = 360

        base_dir = os.path.dirname(os.path.abspath(__file__))
        unet_config_path = None
        if load_main_model:
            audio_ldm_path = base_dir + "/tools/audioldm_48k.pth"
            self.vae, self.stft = build_pretrained_models(audio_ldm_path)
            self.vae = self.vae.eval().to(self.device)
            self.stft = self.stft.eval().to(self.device)
            unet_config_path = base_dir + "/configs/models/transformer2D.json"

        main_config = {
            "num_channels": 32,
            "unet_model_name": None,
            "unet_model_config_path": unet_config_path,
            "snr_gamma": None,
        }
        self.model = PromptCondAudioDiffusion(**main_config)
        # strict=False: keys missing from / unexpected in the checkpoint are
        # tolerated, exactly as before.
        self.model.load_state_dict(self._load_weights(model_path), strict=False)
        self.model = self.model.to(self.device)
        print("Successfully loaded checkpoint from:", model_path)

        self.model.eval()
        self.model.init_device_dtype(torch.device(self.device), torch.float32)
        print("scaling factor: ", self.model.normfeat.std)

    @staticmethod
    def _load_weights(model_path):
        """Read a checkpoint's state dict onto the CPU."""
        if model_path.endswith('.safetensors'):
            return load_file(model_path)
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        return torch.load(model_path, map_location='cpu')

    def file2code(self, fname):
        """Load an audio file and encode it with ``sound2code``.

        The audio is resampled to ``self.sample_rate`` when needed and a
        single-channel file is duplicated into two identical channels.
        """
        orig_samples, fs = torchaudio.load(fname)
        if fs != self.sample_rate:
            orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
        if orig_samples.shape[0] == 1:
            orig_samples = torch.cat([orig_samples, orig_samples], 0)
        return self.sound2code(orig_samples)

    @torch.no_grad()
    @torch.autocast(device_type="cuda", dtype=torch.float32)
    def sound2code(self, orig_samples, batch_size=3):
        """Encode a waveform into discrete codes.

        Args:
            orig_samples: Tensor of rank 2 ``(2, samples)`` or rank 3
                ``(1, 2, samples)`` at ``self.sample_rate``. The reshape
                below fixes the channel count at 2.
            batch_size: Number of fixed-length windows sent to the model per
                ``fetch_codes_batch`` call.

        Returns:
            Tensor of shape ``(1, 1, output_len)`` where ``output_len`` is
            ``int(seconds * 25) + 1``.

        Raises:
            ValueError: If ``orig_samples`` is neither rank 2 nor rank 3.
        """
        if orig_samples.ndim == 2:
            audios = orig_samples.unsqueeze(0).to(self.device)
        elif orig_samples.ndim == 3:
            audios = orig_samples.to(self.device)
        else:
            # Explicit raise: an assert here would vanish under ``python -O``
            # and surface later as an unbound-variable error.
            raise ValueError(
                f"orig_samples must have 2 or 3 dimensions, got shape {tuple(orig_samples.shape)}"
            )

        audios = self.preprocess_audio(audios)
        audios = audios.squeeze(0)
        orig_length = audios.shape[-1]
        min_samples = int(40.96 * self.sample_rate)
        # 25 codes per second of the *original* (unpadded) audio.
        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
        print("output_len: ", output_len)

        # Each model window is min_samples + 480 samples long. Short inputs
        # are padded by repeating the audio until one window fits.
        window = min_samples + 480
        while audios.shape[-1] < window:
            audios = torch.cat([audios, audios], -1)
        int_max_len = audios.shape[-1] // min_samples + 1

        # Repeat once more so the slice below always has enough samples to
        # fill int_max_len whole windows.
        audios = torch.cat([audios, audios], -1)
        audios = audios[:, :int(int_max_len * window)]

        # (2, n_windows * window) -> (n_windows, 2, window)
        audio_input = audios.reshape(2, -1, window).permute(1, 0, 2).reshape(-1, 2, window)

        codes_list = []
        for audio_inx in range(0, audio_input.shape[0], batch_size):
            codes, _, _ = self.model.fetch_codes_batch(
                audio_input[audio_inx:audio_inx + batch_size],
                additional_feats=[],
                layer=self.layer_num,
            )
            codes_list.append(torch.cat(codes, 1))

        # Concatenate the per-window codes along time and drop the part that
        # only exists because of the padding above.
        codes = torch.cat(codes_list, 0).permute(1, 0, 2).reshape(1, -1)[None]
        codes = codes[:, :, :output_len]
        return codes

    @torch.no_grad()
    def code2sound(self, codes, prompt=None, duration=40.96, guidance_scale=1.5, num_steps=20, disable_progress=False):
        """Decode codes back to a waveform.

        Codes are processed in windows of 1024 codes that overlap by 256;
        each window after the first is conditioned on the last 128 latent
        frames of the previous one, and the decoded waveforms are
        cross-faded over the overlap.

        Args:
            codes: Code tensor indexed as ``codes[:, :, t]`` (the layout
                returned by ``sound2code``).
            prompt: Optional waveform tensor (rank 1, 2 or 3) used as
                in-context conditioning for the first window. Ignored unless
                it is a ``torch.Tensor``.
            duration: Seconds of audio kept from each decoded window
                (40.96 s corresponds to 1024 codes at 25 codes/s).
            guidance_scale: Guidance scale forwarded to
                ``model.inference_codes``.
            num_steps: Number of sampling steps forwarded to the model.
            disable_progress: Forwarded to the model.

        Returns:
            CPU tensor ``(channels, samples)`` truncated to the length
            implied by the number of input codes.
        """
        codes = codes.to(self.device)

        # Latent canvas for the first window: pure noise unless a prompt
        # overwrites its leading frames below.
        first_latent = torch.randn(codes.shape[0], 32, 512, 32).to(self.device)
        first_latent_length = 0
        first_latent_codes_length = 0

        if isinstance(prompt, torch.Tensor):
            prompt = prompt.to(self.device)
            if prompt.ndim == 3:
                assert prompt.shape[0] == 1, prompt.shape
                prompt = prompt[0]
            elif prompt.ndim == 1:
                prompt = prompt.unsqueeze(0).repeat(2, 1)
            elif prompt.ndim == 2:
                if prompt.shape[0] == 1:
                    prompt = prompt.repeat(2, 1)

            # Use a 10.24 s excerpt: the start of a short prompt, otherwise
            # the 20.48 s - 30.72 s span.
            if prompt.shape[-1] < int(30.76 * self.sample_rate):
                prompt = prompt[:, :int(10.24 * self.sample_rate)]
            else:
                prompt = prompt[:, int(20.48 * self.sample_rate):int(30.72 * self.sample_rate)]

            true_mel, _, _ = torch_tools.wav_to_fbank2(prompt, -1, fn_STFT=self.stft)
            true_mel = true_mel.unsqueeze(1).to(self.device)
            true_latent = torch.cat(
                [
                    self.vae.get_first_stage_encoding(self.vae.encode_first_stage(true_mel[[m]]))
                    for m in range(true_mel.shape[0])
                ],
                0,
            )
            # Fold the two audio channels into the latent channel axis.
            true_latent = true_latent.reshape(
                true_latent.shape[0] // 2, -1, true_latent.shape[2], true_latent.shape[3]
            ).detach()

            first_latent[:, :, 0:true_latent.shape[2], :] = true_latent
            first_latent_length = true_latent.shape[2]
            # Two codes per latent frame: prepend the prompt's own codes.
            first_latent_codes = self.sound2code(prompt)[:, :, 0:first_latent_length * 2]
            first_latent_codes_length = first_latent_codes.shape[-1]
            codes = torch.cat([first_latent_codes, codes], -1)

        # Windowing, measured in codes.
        min_samples = 1024
        hop_samples = min_samples // 4 * 3
        ovlp_samples = min_samples - hop_samples
        ovlp_frames = ovlp_samples // 2  # overlap measured in latent frames

        codes_len = codes.shape[-1]
        # 25 codes per second, excluding the codes contributed by the prompt.
        target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)

        if codes_len < min_samples:
            while codes.shape[-1] < min_samples:
                codes = torch.cat([codes, codes], -1)
            codes = codes[:, :, 0:min_samples]
        codes_len = codes.shape[-1]

        # Pad (by repetition) so the codes split into whole windows. The
        # check uses ovlp_samples, the same quantity as the padded length;
        # checking ovlp_frames let some lengths through unpadded and left
        # the final window short.
        if (codes_len - ovlp_samples) % hop_samples > 0:
            len_codes = math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
            while codes.shape[-1] < len_codes:
                codes = torch.cat([codes, codes], -1)
            codes = codes[:, :, 0:len_codes]

        latent_length = 512
        latent_list = []
        spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)

        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for sinx in range(0, codes.shape[-1] - hop_samples, hop_samples):
                codes_input = [codes[:, :, sinx:sinx + min_samples]]
                if sinx == 0:
                    context_latent = first_latent
                    incontext_length = first_latent_length
                else:
                    # Condition on the tail of the previous window and fill
                    # the rest of the 512-frame canvas with fresh noise.
                    tail = latent_list[-1][:, :, -ovlp_frames:, :]
                    incontext_length = tail.shape[-2]
                    len_add_to_512 = 512 - incontext_length
                    noise = torch.randn(
                        tail.shape[0], tail.shape[1], len_add_to_512, tail.shape[-1]
                    ).to(self.device)
                    context_latent = torch.cat([tail, noise], -2)

                latents = self.model.inference_codes(
                    codes_input,
                    spk_embeds,
                    context_latent,
                    latent_length,
                    incontext_length,
                    additional_feats=[],
                    # Honour the caller's value (was hard-coded to 1.5).
                    guidance_scale=guidance_scale,
                    num_steps=num_steps,
                    disable_progress=disable_progress,
                    scenario='other_seg',
                )
                latent_list.append(latents)

        latent_list = [latent.float() for latent in latent_list]
        # Drop the frames that merely reproduce the prompt.
        latent_list[0] = latent_list[0][:, :, first_latent_length:, :]

        # From here on the window sizes are measured in audio samples.
        min_samples = int(duration * self.sample_rate)
        hop_samples = min_samples // 4 * 3
        ovlp_samples = min_samples - hop_samples

        # Crossfade window: a 0 -> 1 ramp followed by its mirror image.
        # It is loop-invariant, so it is built once, on first use.
        ov_win = None
        output = None
        for latent in latent_list:
            bsz, ch, t, f = latent.shape
            # Unfold the channel axis back into two audio channels.
            latent = latent.reshape(bsz * 2, ch // 2, t, f)
            mel = self.vae.decode_first_stage(latent)
            cur_output = self.vae.decode_to_waveform(mel)
            cur_output = torch.from_numpy(cur_output)[:, 0:min_samples]

            if output is None:
                output = cur_output
                continue

            if ov_win is None:
                ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
                ov_win = torch.cat([ov_win, 1 - ov_win], -1)
            # Fade the existing tail out while fading the new head in.
            output[:, -ovlp_samples:] = (
                output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:]
                + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
            )
            output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)

        output = output[:, 0:target_len]
        return output

    @torch.no_grad()
    def preprocess_audio(self, input_audios, threshold=0.8):
        """Scale down any batch item whose peak amplitude exceeds ``threshold``.

        Args:
            input_audios: Rank-3 tensor ``(batch, channels, samples)``.
            threshold: Peak amplitude above which an item is rescaled so
                that its peak equals ``threshold``.

        Returns:
            Tensor of the same shape; items at or below the threshold are
            returned unchanged (divided by 1).
        """
        assert len(input_audios.shape) == 3, input_audios.shape
        nchan = input_audios.shape[1]
        # Flatten channels so the peak is taken over the whole item.
        input_audios = input_audios.reshape(input_audios.shape[0], -1)
        norm_value = torch.ones_like(input_audios[:, 0])
        max_volume = input_audios.abs().max(dim=-1)[0]
        too_loud = max_volume > threshold
        norm_value[too_loud] = max_volume[too_loud] / threshold
        return input_audios.reshape(input_audios.shape[0], nchan, -1) / norm_value.unsqueeze(-1).unsqueeze(-1)

    @torch.no_grad()
    def sound2sound(self, sound, prompt=None, min_duration=40.96, steps=50, disable_progress=False):
        """Encode ``sound`` to codes and decode it again, printing timings.

        Args:
            sound: Waveform accepted by ``sound2code``.
            prompt: Optional prompt waveform forwarded to ``code2sound``.
            min_duration: Forwarded to ``code2sound`` as ``duration``.
            steps: Number of sampling steps for decoding.
            disable_progress: Forwarded to ``code2sound``.

        Returns:
            The reconstructed waveform from ``code2sound``.
        """
        start_time = time.time()
        codes = self.sound2code(sound)
        mid_time = time.time()
        elapsed_1 = mid_time - start_time
        print(f"sound2code: {elapsed_1:.3f}s")
        wave = self.code2sound(
            codes,
            prompt,
            duration=min_duration,
            guidance_scale=1.5,
            num_steps=steps,
            disable_progress=disable_progress,
        )
        end_time = time.time()
        elapsed_2 = end_time - mid_time
        print(f"code2sound: step-{steps}, {elapsed_2:.3f}s")
        return wave
|
|
|
|
|
|
def load_jsonl(path:str) -> list[dict]:
    """Load a JSON Lines file into a list with one parsed record per line.

    Args:
        path: Path of a UTF-8 encoded file holding one JSON document per line.

    Returns:
        The parsed records in file order. Blank lines (for example a
        trailing empty line) are skipped instead of raising a decode error.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as file:
        return [
            json.loads(line)
            for line in tqdm(file, desc=f"Loading {path}")
            if line.strip()
        ]
|
|
|
|
|
|
# The checkpoint is resolved relative to this file's directory.
ckpt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ckpt/mucodec.pt")
# Module-level codec instance used by music_encode(). It is constructed at
# import time, so each process that imports this module builds its own copy.
# layer_num=7 is stored by MuCodec as layer index 6 (layer_num - 1).
mucodec = MuCodec(model_path=ckpt_path,layer_num=7,load_main_model=True)
|
|
|
|
|
|
def pure_name(path:str):
    """Return the final path component with its last extension removed.

    Only the text after the final dot is dropped, so ``a.tar.gz`` becomes
    ``a.tar``. A name without any dot is returned unchanged, and a name that
    starts with its only dot (``.hidden``) yields an empty string.
    """
    filename = os.path.basename(path)
    stem, dot, _ = filename.rpartition('.')
    # rpartition reports an empty separator when the name has no dot.
    return stem if dot else filename
|
|
|
|
|
|
def music_encode(path:str):
    """Encode one audio file and save its codes under ``SAVE_DIR``.

    Loading, resampling to 48 kHz and mono-to-stereo duplication are
    delegated to ``MuCodec.file2code``, which performs the same steps this
    function used to repeat inline.

    Args:
        path: Path of the audio file to encode.

    Returns:
        Path of the written ``<name>.pt`` file, where ``<name>`` is
        ``pure_name(path)``. Inputs from different directories that share a
        file name therefore map to the same output file.
    """
    name = pure_name(path)
    code = mucodec.file2code(path)

    # Make sure the output directory exists before writing.
    os.makedirs(SAVE_DIR, exist_ok=True)
    save_path = os.path.join(SAVE_DIR, f"{name}.pt")
    torch.save(code, save_path)
    return save_path
|
|
|
|
|
|
|
|
|
def remove_finish(dataset:list[dict]) -> list[dict]:
    """Drop records whose output already exists in ``SAVE_DIR``.

    A record counts as finished when ``pure_name(record['path'])`` matches
    the extension-less name of any entry in ``SAVE_DIR``.

    Args:
        dataset: Records that each carry a ``'path'`` key.

    Returns:
        The records that still need encoding, in their original order.
    """
    # On a first run the directory may not exist yet; create it so listing
    # it yields "nothing finished" instead of raising.
    os.makedirs(SAVE_DIR, exist_ok=True)
    # A set makes each membership test O(1) instead of scanning a list.
    finished = {pure_name(entry) for entry in os.listdir(SAVE_DIR)}

    new_dataset = []
    dup_num = 0
    for line in tqdm(dataset):
        if pure_name(line['path']) in finished:
            dup_num += 1
        else:
            new_dataset.append(line)
    print(f"Remove dup num: {dup_num}")
    return new_dataset
|
|
|
|
|
|
if __name__=="__main__":
    # Encode every not-yet-encoded file from DATA_PATH in worker processes.
    # 'spawn' starts each worker as a fresh interpreter that re-imports this
    # module rather than forking the parent.
    multiprocessing.set_start_method('spawn', force=True)
    dataset = load_jsonl(DATA_PATH)
    dataset = remove_finish(dataset)

    with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
        # Remember which input each future belongs to so that a failure can
        # be attributed to a file.
        future_to_path = {
            executor.submit(music_encode, line['path']): line['path']
            for line in dataset
        }

        for future in tqdm(concurrent.futures.as_completed(future_to_path),
                           total=len(dataset),
                           desc="Encoding audio count",
                           unit="files",
                           ncols=100,
                           bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]"):
            try:
                future.result()
            except Exception as e:
                # Best effort: report the failing file and keep going.
                print(f"Task failed ({future_to_path[future]}): {e}")