| | import time
|
| | import json
|
| | import torch
|
| | from tqdm import tqdm
|
| | import sys
|
| | from model import PromptCondAudioDiffusion
|
| | from diffusers import DDIMScheduler, DDPMScheduler
|
| | import torchaudio
|
| | import librosa
|
| | import os
|
| | import math
|
| | import numpy as np
|
| | from tools.get_melvaehifigan48k import build_pretrained_models
|
| | import tools.torch_tools as torch_tools
|
| | from safetensors.torch import load_file
|
| | import subprocess
|
| |
|
def get_free_gpu() -> int:
    """Return the index of the GPU with the most free memory.

    Queries ``nvidia-smi`` for per-GPU free memory (MiB, CSV output) and
    picks the device with the largest free amount.

    Returns:
        Zero-based GPU index as reported by nvidia-smi.

    Raises:
        subprocess.CalledProcessError: if nvidia-smi exits non-zero.
        FileNotFoundError: if nvidia-smi is not installed.
        ValueError: if nvidia-smi reports no GPUs (empty ``max``).
    """
    cmd = "nvidia-smi --query-gpu=index,memory.free --format=csv,noheader,nounits"
    result = subprocess.check_output(cmd.split()).decode().strip().split("\n")

    free_list = []
    for line in result:
        # Each line looks like "0, 12345"; int() tolerates the padding space.
        idx, free_mem = line.split(",")
        free_list.append((int(idx), int(free_mem)))

    # max() with a key is O(n) and avoids sorting (and mutating) the whole
    # list just to read its first element.
    best_idx, _ = max(free_list, key=lambda item: item[1])
    return best_idx
|
| |
|
# Pick the least-loaded GPU once at import time.
# NOTE(review): this runs nvidia-smi as a module-level side effect and will
# fail on machines without NVIDIA GPUs — confirm that is acceptable.
DEVICE = f"cuda:{get_free_gpu()}"

print(f"Use {DEVICE}")
|
| |
|
class MuCodec:
    """48 kHz stereo neural audio codec built on a prompt-conditioned
    diffusion model.

    `sound2code` encodes waveforms into discrete codes; `code2sound`
    reconstructs waveforms from codes via a mel VAE + vocoder with
    sliding-window diffusion and overlap-add; `sound2sound` chains both.
    """

    def __init__(self, \
        model_path, \
        layer_num, \
        load_main_model=True, \
        device=DEVICE):
        """Load the diffusion model and (optionally) the mel VAE / STFT front end.

        Args:
            model_path: checkpoint file; ``.safetensors`` is loaded via
                safetensors, anything else via ``torch.load``.
            layer_num: codebook layer selector; stored zero-based
                (``layer_num - 1``) and passed to ``fetch_codes_batch``.
            load_main_model: when True, also build the pretrained AudioLDM-48k
                VAE/STFT and the transformer2D UNet config needed for decoding;
                otherwise a config-less (encode-only, presumably) model is built.
            device: torch device string; defaults to the least-loaded GPU.
        """

        self.layer_num = layer_num - 1
        # The codec operates at a fixed 48 kHz sample rate.
        self.sample_rate = 48000
        self.device = device

        # Maximum duration in seconds; not enforced anywhere in this class.
        self.MAX_DURATION = 360
        if load_main_model:
            # Pretrained AudioLDM-48k mel VAE + STFT, frozen in eval mode.
            audio_ldm_path = os.path.dirname(os.path.abspath(__file__)) + "/tools/audioldm_48k.pth"
            self.vae, self.stft = build_pretrained_models(audio_ldm_path)
            self.vae, self.stft = self.vae.eval().to(device), self.stft.eval().to(device)
            main_config = {
                "num_channels":32,
                "unet_model_name":None,
                "unet_model_config_path":os.path.dirname(os.path.abspath(__file__)) + "/configs/models/transformer2D.json",
                "snr_gamma":None,
            }
            self.model = PromptCondAudioDiffusion(**main_config)
            if model_path.endswith('.safetensors'):
                main_weights = load_file(model_path)
            else:
                main_weights = torch.load(model_path, map_location='cpu')
            # strict=False: the checkpoint may omit/extend keys vs. the model.
            self.model.load_state_dict(main_weights, strict=False)
            self.model = self.model.to(device)
            print ("Successfully loaded checkpoint from:", model_path)
        else:
            # No UNet config: decoding via inference_codes is presumably
            # unavailable on this path — confirm against PromptCondAudioDiffusion.
            main_config = {
                "num_channels":32,
                "unet_model_name":None,
                "unet_model_config_path":None,
                "snr_gamma":None,
            }
            self.model = PromptCondAudioDiffusion(**main_config).to(device)
            main_weights = torch.load(model_path, map_location='cpu')
            self.model.load_state_dict(main_weights, strict=False)
            self.model = self.model.to(device)
            print ("Successfully loaded checkpoint from:", model_path)

        self.model.eval()
        self.model.init_device_dtype(torch.device(device), torch.float32)
        print("scaling factor: ", self.model.normfeat.std)

    def file2code(self, fname):
        """Load an audio file, convert to 48 kHz stereo, and encode to codes."""
        orig_samples, fs = torchaudio.load(fname)
        if(fs!=self.sample_rate):
            orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate)
            fs = self.sample_rate
        # Duplicate mono to two channels; the encoder expects stereo input.
        if orig_samples.shape[0] == 1:
            orig_samples = torch.cat([orig_samples, orig_samples], 0)
        return self.sound2code(orig_samples)

    @torch.no_grad()
    @torch.autocast(device_type="cuda", dtype=torch.float32)
    def sound2code(self, orig_samples, batch_size=3):
        """Encode a stereo waveform into discrete codes.

        Args:
            orig_samples: (2, T) or (1, 2, T) waveform tensor at 48 kHz.
            batch_size: number of ~41 s segments encoded per model call.

        Returns:
            Code tensor of shape (1, 1, output_len). Codes appear to be
            produced at 25 per second of audio — TODO confirm against the
            model config.
        """
        if(orig_samples.ndim == 2):
            audios = orig_samples.unsqueeze(0).to(self.device)
        elif(orig_samples.ndim == 3):
            audios = orig_samples.to(self.device)
        else:
            # Unreachable branch used purely to raise on bad rank.
            assert orig_samples.ndim in (2,3), orig_samples.shape
        audios = self.preprocess_audio(audios)
        audios = audios.squeeze(0)
        orig_length = audios.shape[-1]
        # Encoder window: 40.96 s of audio per segment.
        min_samples = int(40.96 * self.sample_rate)
        # 25 codes per second of original input, plus one.
        output_len = int(orig_length / float(self.sample_rate) * 25) + 1
        print("output_len: ", output_len)

        # Tile the audio until it covers at least one window (+480-sample pad).
        while(audios.shape[-1] < min_samples + 480):
            audios = torch.cat([audios, audios], -1)
        int_max_len=audios.shape[-1]//min_samples+1

        # Tile once more so the trim below always has enough material.
        audios = torch.cat([audios, audios], -1)

        # Trim to a whole number of (window + 480) segments.
        audios=audios[:,:int(int_max_len*(min_samples+480))]
        codes_list=[]

        # (2, n_seg, win) -> (n_seg, 2, win): one stereo segment per row.
        audio_input = audios.reshape(2, -1, min_samples+480).permute(1, 0, 2).reshape(-1, 2, min_samples+480)

        for audio_inx in range(0, audio_input.shape[0], batch_size):

            codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num)
            codes_list.append(torch.cat(codes, 1))

        # Flatten per-segment codes into one stream and trim the tiling padding.
        codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(1, -1)[None]
        codes=codes[:,:,:output_len]

        return codes

    @torch.no_grad()
    def code2sound(self, codes, prompt=None, duration=40.96, guidance_scale=1.5, num_steps=20, disable_progress=False):
        """Decode discrete codes back into a stereo waveform.

        Args:
            codes: (B, 1, L) code tensor as produced by `sound2code`.
            prompt: optional waveform tensor used as an in-context acoustic
                reference; its latent and codes are prepended to the generation.
            duration: window length in seconds for the vocoder overlap-add.
            guidance_scale: NOTE(review): currently ignored — both
                `inference_codes` calls below hard-code 1.5; confirm whether
                this parameter should be passed through.
            num_steps: diffusion sampling steps per window.
            disable_progress: silence the per-window progress bar.

        Returns:
            (2, target_len) waveform tensor.
        """
        codes = codes.to(self.device)
        # Random initial latent; a prompt (if given) overwrites its prefix.
        first_latent = torch.randn(codes.shape[0], 32, 512, 32).to(self.device)
        first_latent_length = 0
        first_latent_codes_length = 0
        if(isinstance(prompt, torch.Tensor)):
            prompt = prompt.to(self.device)
            # Normalize the prompt to shape (2, T) stereo.
            if(prompt.ndim == 3):
                assert prompt.shape[0] == 1, prompt.shape
                prompt = prompt[0]
            elif(prompt.ndim == 1):
                prompt = prompt.unsqueeze(0).repeat(2,1)
            elif(prompt.ndim == 2):
                if(prompt.shape[0] == 1):
                    prompt = prompt.repeat(2,1)

            # Take a ~10 s excerpt: the head for short prompts, otherwise the
            # 20.48–30.72 s span. NOTE(review): threshold 30.76 vs slice end
            # 30.72 looks like a typo — confirm the intended boundary.
            if(prompt.shape[-1] < int(30.76 * self.sample_rate)):
                prompt = prompt[:,:int(10.24*self.sample_rate)]
            else:
                prompt = prompt[:,int(20.48*self.sample_rate):int(30.72*self.sample_rate)]

            # Encode the prompt into VAE latents, one channel at a time.
            true_mel , _, _ = torch_tools.wav_to_fbank2(prompt, -1, fn_STFT=self.stft)
            true_mel = true_mel.unsqueeze(1).to(self.device)
            true_latent = torch.cat([self.vae.get_first_stage_encoding(self.vae.encode_first_stage(true_mel[[m]])) for m in range(true_mel.shape[0])],0)
            # Fold the stereo pair into the latent channel dimension.
            true_latent = true_latent.reshape(true_latent.shape[0]//2, -1, true_latent.shape[2], true_latent.shape[3]).detach()

            # Prepend the prompt's latent and codes as in-context conditioning.
            first_latent[:,:,0:true_latent.shape[2],:] = true_latent
            first_latent_length = true_latent.shape[2]
            first_latent_codes = self.sound2code(prompt)[:,:,0:first_latent_length*2]
            first_latent_codes_length = first_latent_codes.shape[-1]
            codes = torch.cat([first_latent_codes, codes], -1)

        # Sliding-window decode: 1024-code windows, 3/4 hop, 1/4 overlap.
        min_samples = 1024
        hop_samples = min_samples // 4 * 3
        ovlp_samples = min_samples - hop_samples
        # Latent frames are half the code rate (2 codes per latent frame).
        hop_frames = hop_samples // 2
        ovlp_frames = ovlp_samples // 2

        codes_len= codes.shape[-1]
        # Codes run at 25 Hz, so /100*4 == /25 converts code count to seconds.
        target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate)

        # Tile codes until they fill at least one full window.
        if(codes_len < min_samples):
            while(codes.shape[-1] < min_samples):
                codes = torch.cat([codes, codes], -1)
            codes = codes[:,:,0:min_samples]
        codes_len = codes.shape[-1]
        # Pad (by tiling) so the length fits a whole number of hops.
        # NOTE(review): the test uses ovlp_frames while len_codes below uses
        # ovlp_samples — possibly a bug; confirm which unit was intended.
        if((codes_len - ovlp_frames) % hop_samples > 0):
            len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples
            while(codes.shape[-1] < len_codes):
                codes = torch.cat([codes, codes], -1)
            codes = codes[:,:,0:len_codes]
        latent_length = 512
        latent_list = []
        # Placeholder (all-zero) speaker embedding.
        spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples):
                codes_input=[]
                codes_input.append(codes[:,:,sinx:sinx+min_samples])
                if(sinx == 0):
                    # First window: condition on the prompt latent (if any).
                    incontext_length = first_latent_length
                    latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
                    latent_list.append(latents)
                else:
                    # Later windows: condition on the overlap tail of the
                    # previous window, padded with noise up to 512 frames.
                    true_latent = latent_list[-1][:,:,-ovlp_frames:,:]
                    len_add_to_512 = 512 - true_latent.shape[-2]
                    incontext_length = true_latent.shape[-2]
                    true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], true_latent.shape[1], len_add_to_512, true_latent.shape[-1]).to(self.device)], -2)
                    latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg')
                    latent_list.append(latents)

        latent_list = [l.float() for l in latent_list]
        # Drop the prompt's in-context prefix from the first window.
        latent_list[0] = latent_list[0][:,:,first_latent_length:,:]
        # Switch units from codes to audio samples for the overlap-add below.
        min_samples = int(duration * self.sample_rate)
        hop_samples = min_samples // 4 * 3
        ovlp_samples = min_samples - hop_samples
        with torch.no_grad():
            output = None
            for i in range(len(latent_list)):
                latent = latent_list[i]
                bsz , ch, t, f = latent.shape
                # Unfold stereo back out of the channel dim before VAE decode.
                latent = latent.reshape(bsz*2, ch//2, t, f)
                mel = self.vae.decode_first_stage(latent)
                cur_output = self.vae.decode_to_waveform(mel)
                cur_output = torch.from_numpy(cur_output)[:, 0:min_samples]

                if output is None:
                    output = cur_output
                else:
                    # Linear cross-fade over the overlap region, then append
                    # the non-overlapping remainder of the current window.
                    ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :])
                    ov_win = torch.cat([ov_win, 1 - ov_win], -1)
                    output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples]
                    output = torch.cat([output, cur_output[:, ovlp_samples:]], -1)
        # Trim the tiling padding back to the true requested duration.
        output = output[:, 0:target_len]
        return output

    @torch.no_grad()
    def preprocess_audio(self, input_audios, threshold=0.8):
        """Peak-limit a batch of waveforms.

        Clips whose absolute peak (taken across all channels) exceeds
        ``threshold`` are rescaled so their peak equals ``threshold``;
        quieter clips pass through unchanged.

        Args:
            input_audios: (B, C, T) waveform tensor.
            threshold: peak amplitude ceiling.

        Returns:
            (B, C, T) tensor with per-clip normalization applied.
        """
        assert len(input_audios.shape) == 3, input_audios.shape
        nchan = input_audios.shape[1]
        # Flatten channels so the peak is measured over the whole clip.
        input_audios = input_audios.reshape(input_audios.shape[0], -1)
        norm_value = torch.ones_like(input_audios[:,0])
        max_volume = input_audios.abs().max(dim=-1)[0]
        norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold
        return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1)

    @torch.no_grad()
    def sound2sound(self, sound, prompt=None, min_duration=40.96, steps=50, disable_progress=False):
        """Round trip: encode `sound` to codes, then decode back to a waveform.

        Prints wall-clock timing for each stage and returns the reconstruction.
        """
        start_time = time.time()
        codes = self.sound2code(sound)
        mid_time = time.time()
        elapsed_1 = mid_time - start_time
        print(f"sound2code: {elapsed_1:.3f}s")
        wave = self.code2sound(codes, prompt, duration=min_duration, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress)
        end_time = time.time()
        elapsed_2 = end_time - mid_time
        print(f"code2sound: step-{steps}, {elapsed_2:.3f}s")
        return wave
|
| |
|
def decode_one(mucodec, path, out_path="./origin.wav", steps=50):
    """Decode one saved code tensor back into a wav file.

    Args:
        mucodec: a MuCodec instance providing ``code2sound``.
        path: path to a torch-saved code tensor (.pt).
        out_path: destination wav file (default "./origin.wav", matching the
            previous hard-coded behavior).
        steps: diffusion sampling steps (default 50, as before).
    """
    # SECURITY: torch.load unpickles arbitrary objects — only load trusted files.
    codes = torch.load(path)

    wave = mucodec.code2sound(
        codes,
        None,
        duration=40.96,
        guidance_scale=1.5,
        num_steps=steps,
        disable_progress=False,
    )
    # Save at the codec's native 48 kHz rate instead of a hard-coded constant.
    torchaudio.save(out_path, wave.detach().cpu(), mucodec.sample_rate)
|
| |
|
if __name__=="__main__":
    # Resolve the default checkpoint relative to this file's directory.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    ckpt_path = os.path.join(base_dir, "ckpt/mucodec.pt")
    mucodec = MuCodec(model_path=ckpt_path, layer_num=7, load_main_model=True)

    # Example code tensor to decode (placeholder path).
    code_path = "xxx/suno_cn_009326_1.pt"
    decode_one(mucodec, code_path)
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | |