import yaml
import random
import inspect
import numpy as np
import typing as tp
from abc import ABC

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from einops import repeat
from tools.torch_tools import wav_to_fbank

from diffusers.utils.torch_utils import randn_tensor
from transformers import HubertModel
from libs.rvq.descript_quantize3 import ResidualVectorQuantize

from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model
from models_gpt.models.gpt2_config import GPT2Config

from our_MERT_BESTRQ.mert_fairseq.models.musicfm.musicfm_model import MusicFMModel, MusicFMConfig
from torch.cuda.amp import autocast


class HubertModelWithFinalProj(HubertModel):
    def __init__(self, config):
        super().__init__(config)

        # The final projection layer is only used for backward compatibility.
        # Following https://github.com/auspicious3000/contentvec/issues/6
        # Removing this layer is necessary to achieve the desired outcome.
        print("hidden_size:", config.hidden_size)
        print("classifier_proj_size:", config.classifier_proj_size)
        self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)


class SampleProcessor(torch.nn.Module):
    def project_sample(self, x: torch.Tensor):
        """Project the original sample to the 'space' where the diffusion will happen."""
        return x

    def return_sample(self, z: torch.Tensor):
        """Project back from diffusion space to the actual sample space."""
        return z


class Feature1DProcessor(SampleProcessor):
    def __init__(self, dim: int = 100, power_std = 1., num_samples: int = 100_000, cal_num_frames: int = 600):
        super().__init__()

        self.num_samples = num_samples
        self.dim = dim
        self.power_std = power_std
        self.cal_num_frames = cal_num_frames
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(dim))
        self.register_buffer('sum_x2', torch.zeros(dim))
        self.register_buffer('sum_target_x2', torch.zeros(dim))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor

    @property
    def mean(self):
        mean = self.sum_x / self.counts
        if(self.counts < 10):
            mean = torch.zeros_like(mean)
        return mean

    @property
    def std(self):
        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
        if(self.counts < 10):
            std = torch.ones_like(std)
        return std

    @property
    def target_std(self):
        return 1

    def project_sample(self, x: torch.Tensor):
        assert x.dim() == 3
        if self.counts.item() < self.num_samples:
            self.counts += len(x)
            self.sum_x += x[:, :, 0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0)
            self.sum_x2 += x[:, :, 0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0)
        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std  # same output size
        x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1)
        return x

    def return_sample(self, x: torch.Tensor):
        assert x.dim() == 3
        rescale = (self.std / self.target_std) ** self.power_std
        x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1)
        return x


def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77):
    # NOTE: everything after the length check below is a minimal pad-or-truncate
    # reconstruction, implied by the function name, the `len_size` default and the
    # call site in `extract_lyric_feats`; the exact original body is not recoverable.
    if(prior_text_encoder_hidden_states.shape[1] < len_size):
        pad_len = len_size - prior_text_encoder_hidden_states.shape[1]
        prior_text_encoder_hidden_states = F.pad(prior_text_encoder_hidden_states, (0, 0, 0, pad_len))
        prior_text_mask = F.pad(prior_text_mask, (0, pad_len))
    else:
        prior_text_encoder_hidden_states = prior_text_encoder_hidden_states[:, 0:len_size]
        prior_text_mask = prior_text_mask[:, 0:len_size]
    return prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds
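
# Illustrative sketch (not part of the original file): `Feature1DProcessor` keeps running
# per-channel mean/std over the first `num_samples` items and uses them to whiten latents
# before diffusion (`project_sample`) and to undo the whitening afterwards (`return_sample`).
# The batch size, channel count and frame count below are assumptions for the demo.
def _feature1d_processor_demo():
    proc = Feature1DProcessor(dim=64)
    latents = torch.randn(16, 64, 600) * 3.0 + 1.5     # (B, C, T) un-normalized latents
    normalized = proc.project_sample(latents)           # first call also updates the running stats
    restored = proc.return_sample(normalized)           # inverse mapping back to latent space
    assert torch.allclose(restored, latents, atol=1e-3)
    return normalized.mean(), normalized.std()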
class BASECFM(torch.nn.Module):
    # NOTE: the class signature, __init__ and the opening of `solve_euler` are reconstructed
    # from how `BASECFM(unet, mlp)` is instantiated below and how `solve_euler` is called in
    # `PromptCondAudioDiffusion.inference_codes`; the `sigma_min` default and the t/dt
    # schedule are assumptions.
    def __init__(self, estimator, mlp, sigma_min=1e-4):
        super().__init__()
        self.estimator = estimator   # GPT2-style transformer predicting the flow field
        self.mlp = mlp               # projection head for the SSL alignment loss
        self.sigma_min = sigma_min

    def solve_euler(self, x, latent_mask_input, incontext_x, incontext_length, t_span, mu, attention_mask, guidance_scale):
        # Fixed-step Euler ODE solver for the learned flow, with optional classifier-free
        # guidance (conditional/unconditional batches stacked along dim 0).
        t, dt = t_span[:-1], t_span[1:] - t_span[:-1]
        B = x.shape[0]

        if guidance_scale > 1.0:
            def double(z):
                return torch.cat([z, z], 0) if z is not None else None
            attention_mask = double(attention_mask)

        x_next = x.clone()
        noise = x.clone()
        for i in range(len(dt)):
            ti = t[i]
            # Keep the in-context (prompt) frames on the conditional flow-matching path
            # by re-noising them to the current time step.
            x_next[:, :incontext_length] = (
                (1 - (1 - self.sigma_min) * ti) * noise[:, :incontext_length]
                + ti * incontext_x[:, :incontext_length]
            )
            if guidance_scale > 1.0:
                model_input = torch.cat([
                    double(latent_mask_input),
                    double(incontext_x),
                    torch.cat([torch.zeros_like(mu), mu], 0),
                    double(x_next),
                ], dim=2)
                timestep = ti.expand(2 * B)
            else:
                model_input = torch.cat([
                    latent_mask_input,
                    incontext_x,
                    mu,
                    x_next,
                ], dim=2)
                timestep = ti.expand(B)
            v = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask, time_step=timestep).last_hidden_state
            v = v[..., -x.shape[2]:]
            if guidance_scale > 1.0:
                v_uncond, v_cond = v.chunk(2, 0)
                v = v_uncond + guidance_scale * (v_cond - v_uncond)
            x_next = x_next + dt[i] * v
        return x_next

    def projection_loss(self, hidden_proj, bestrq_emb):
        bsz = hidden_proj.shape[0]

        hidden_proj_normalized = F.normalize(hidden_proj, dim=-1)
        bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1)

        # Negative cosine similarity, shifted so a perfect match gives 0.
        proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1)
        proj_loss = 1 + proj_loss.mean()
        return proj_loss

    def compute_loss(self, x1, mu, latent_masks, attention_mask, wav2vec_embeds, validation_mode=False):
        """Computes the conditional flow-matching loss.

        Args:
            x1 (torch.Tensor): target latents, shape (batch_size, num_frames, latent_dim)
            mu (list[torch.Tensor]): conditioning streams, concatenated with the noisy
                latents along the feature axis before being fed to the estimator

        Returns:
            loss: total loss (flow-matching + weighted projection loss)
            loss_re: flow-matching MSE loss
            loss_cos: projection (cosine) loss against the SSL embeddings
        """
        b = mu[0].shape[0]
        len_x = x1.shape[2]

        # random timestep
        if(validation_mode):
            t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5
        else:
            t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype)

        # sample noise p(x_0)
        z = torch.randn_like(x1)

        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        model_input = torch.cat([*mu, y], 2)
        t = t.squeeze(-1).squeeze(-1)
        out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask, time_step=t, output_hidden_states=True)
        hidden_layer_7 = out.hidden_states[7]
        hidden_proj = self.mlp(hidden_layer_7)

        out = out.last_hidden_state
        out = out[:, :, -len_x:]

        # Strongly supervise frames to be generated (mask == 2), lightly supervise masked-out frames (mask == 0).
        weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() \
            + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01
        loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum()
        loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds)
        loss = loss_re + loss_cos * 0.5
        return loss, loss_re, loss_cos
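
# Illustrative sketch (not part of the original file): the conditional flow-matching quantities
# used in `BASECFM.compute_loss`. With y_t = (1 - (1 - sigma_min) * t) * z + t * x1, the path
# velocity dy_t/dt = x1 - (1 - sigma_min) * z is exactly the regression target `u`, which is
# also the field integrated by `solve_euler`. Shapes below are assumptions for the demo.
def _flow_matching_target_demo(sigma_min=1e-4):
    x1 = torch.randn(2, 8, 4)          # "clean" target latents
    z = torch.randn_like(x1)           # noise sample x_0
    t = torch.rand(2, 1, 1)            # per-item timestep in [0, 1)
    dt = 1e-2
    y_t = (1 - (1 - sigma_min) * t) * z + t * x1
    y_t_dt = (1 - (1 - sigma_min) * (t + dt)) * z + (t + dt) * x1
    u = x1 - (1 - sigma_min) * z
    # The finite difference along the interpolation path matches the analytic velocity u.
    assert torch.allclose((y_t_dt - y_t) / dt, u, atol=1e-3)
    return u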
class PromptCondAudioDiffusion(nn.Module):
    def __init__(
        self,
        num_channels,
        unet_model_name=None,
        unet_model_config_path=None,
        snr_gamma=None,
        uncondition=True,
    ):
        super().__init__()

        assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required"

        self.unet_model_name = unet_model_name
        self.unet_model_config_path = unet_model_config_path
        self.snr_gamma = snr_gamma
        self.uncondition = uncondition
        self.num_channels = num_channels

        # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview
        self.normfeat = Feature1DProcessor(dim=64)

        self.sample_rate = 48000
        self.num_samples_perseg = self.sample_rate * 20 // 1000
        self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000)
        self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000)
        # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
        # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True)
        self.bestrq = MusicFMModel(MusicFMConfig())
        self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000)
        self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000)

        self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim=1024, n_codebooks=1, codebook_size=16_384, codebook_dim=32, quantizer_dropout=0.0, stale_tolerance=200)
        self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim=1024, n_codebooks=1, codebook_size=16_384, codebook_dim=32, quantizer_dropout=0.0, stale_tolerance=200)
        # self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68")
        # for v in self.hubert.parameters():v.requires_grad = False
        self.zero_cond_embedding1 = nn.Parameter(torch.randn(32 * 32,))
        # self.xvecmodel = XVECModel()
        config = GPT2Config(n_positions=1000, n_layer=16, n_head=20, n_embd=2200, n_inner=4400)
        unet = GPT2Model(config)
        mlp = nn.Sequential(
            nn.Linear(2200, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
            nn.SiLU(),
            nn.Linear(1024, 768)
        )
        self.set_from = "random"
        self.cfm_wrapper = BASECFM(unet, mlp)
        self.mask_emb = torch.nn.Embedding(3, 24)
        print("Transformer initialized from pretrain.")
        torch.cuda.empty_cache()
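
    # Illustrative sketch (not part of the original model): how the transformer input width of
    # 2200 (= GPT2Config.n_embd above) decomposes on the feature axis, matching the
    # concatenations in `BASECFM.compute_loss` and `BASECFM.solve_euler`. The parameter names
    # here are hypothetical; only the dimensions come from this file.
    def _example_model_input_layout(self, latent_masks, incontext_latents, quantized_vocal, quantized_bgm, noisy_latents):
        model_input = torch.cat([
            self.mask_emb(latent_masks),   # (B, T, 24)   mask-state embedding
            incontext_latents,             # (B, T, 64)   in-context (prompt) latents
            quantized_vocal,               # (B, T, 1024) vocal RVQ conditioning
            quantized_bgm,                 # (B, T, 1024) BGM RVQ conditioning
            noisy_latents,                 # (B, T, 64)   noisy latents being denoised
        ], dim=2)
        assert model_input.shape[-1] == 2200
        return model_input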
    def compute_snr(self, timesteps):
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
        """
        # NOTE: relies on `self.noise_scheduler`, which is not initialised anywhere in this file.
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        # Compute SNR.
        snr = (alpha / sigma) ** 2
        return snr

    def preprocess_audio(self, input_audios, threshold=0.9):
        assert len(input_audios.shape) == 2, input_audios.shape
        norm_value = torch.ones_like(input_audios[:, 0])
        max_volume = input_audios.abs().max(dim=-1)[0]
        norm_value[max_volume > threshold] = max_volume[max_volume > threshold] / threshold
        return input_audios / norm_value.unsqueeze(-1)

    def extract_wav2vec_embeds(self, input_audios, output_len):
        wav2vec_stride = 2

        # NOTE: uses `self.hubert`, whose initialisation is commented out in __init__.
        wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states  # 1, 4096, 1024
        wav2vec_embeds_last = wav2vec_embeds[-1]
        wav2vec_embeds_last = torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1)
        return wav2vec_embeds_last

    def extract_mert_embeds(self, input_audios):
        prompt_stride = 3
        inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt")
        input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype=input_audios.dtype)
        prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states  # batch_size, Time steps, 1024

        mert_emb = prompt_embeds[-1]
        mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=375, mode='linear', align_corners=False).permute(0, 2, 1)
        return mert_emb

    def extract_bestrq_embeds(self, input_audio_vocal_0, input_audio_vocal_1, layer):
        input_wav_mean = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0
        input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only=True)
        layer_results = input_wav_mean['layer_results']
        bestrq_emb = layer_results[layer]
        bestrq_emb = bestrq_emb.permute(0, 2, 1).contiguous()
        return bestrq_emb

    def extract_spk_embeds(self, input_audios):
        spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios))
        spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32)
        return spk_embeds

    def extract_lyric_feats(self, lyric):
        with torch.no_grad():
            try:
                text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts=lyric, return_one=False)
            except:
                text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts=[""] * len(lyric), return_one=False)
            text_encoder_hidden_states = text_encoder_hidden_states.to(self.device)
            text_mask = text_mask.to(self.device)
            text_encoder_hidden_states, text_mask, text_prompt_embeds = \
                pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds)
            text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1).contiguous()
            return text_encoder_hidden_states, text_mask

    def extract_energy_bar(self, input_audios):
        if(input_audios.shape[-1] % self.num_samples_perseg > 0):
            energy_bar = input_audios[:, :-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0], -1, self.num_samples_perseg)
        else:
            energy_bar = input_audios.reshape(input_audios.shape[0], -1, self.num_samples_perseg)
        energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20  # B T
        energy_bar = (energy_bar / 2.0 + 16).clamp(0, 16).int()
        energy_embedding = self.energy_embedding(energy_bar)
        energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0, 2, 1)  # b 128 t
        return energy_embedding
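
    # Illustrative sketch (not part of the original model): `preprocess_audio` only attenuates
    # waveforms whose peak exceeds `threshold`, scaling them so the new peak equals the
    # threshold, and leaves quieter audio untouched. The waveforms below are placeholders.
    def _example_preprocess_audio(self):
        loud = torch.linspace(-1.0, 1.0, 16000).unsqueeze(0)   # peak 1.0, above the 0.9 threshold
        quiet = loud * 0.5                                      # peak 0.5, below the threshold
        batch = torch.cat([loud, quiet], dim=0)
        out = self.preprocess_audio(batch, threshold=0.9)
        assert torch.isclose(out[0].abs().max(), torch.tensor(0.9))   # loud clip scaled down to the threshold
        assert torch.allclose(out[1], quiet[0])                       # quiet clip is unchanged
        return out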
    def forward(self, input_audios_vocal, input_audios_bgm, lyric, latents, latent_masks, validation_mode=False,
                additional_feats=['spk', 'lyric'],
                train_rvq=True, train_ssl=False, layer_vocal=7, layer_bgm=7):
        if not hasattr(self, "device"):
            self.device = input_audios_vocal.device
        if not hasattr(self, "dtype"):
            self.dtype = input_audios_vocal.dtype
        device = self.device

        input_audio_vocal_0 = input_audios_vocal[:, 0, :]
        input_audio_vocal_1 = input_audios_vocal[:, 1, :]
        input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
        input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
        input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0

        input_audio_bgm_0 = input_audios_bgm[:, 0, :]
        input_audio_bgm_1 = input_audios_bgm[:, 1, :]
        input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
        input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
        input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0

        if(train_ssl):
            # NOTE: this branch depends on `self.wav2vec`, `self.clap_embd_extractor` and
            # `self.xvecmodel`, which are not initialised in __init__ in this file, and on an
            # `input_audios` variable that is not defined in this scope.
            self.wav2vec.train()
            wav2vec_embeds = self.extract_wav2vec_embeds(input_audios)
            self.clap_embd_extractor.train()
            prompt_embeds = self.extract_mert_embeds(input_audios)
            if('spk' in additional_feats):
                self.xvecmodel.train()
                spk_embeds = self.extract_spk_embeds(input_audios).repeat(1, 1, prompt_embeds.shape[-1] // 2, 1)
        else:
            with torch.no_grad():
                with autocast(enabled=False):
                    bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0, input_audio_vocal_1, layer_vocal)
                    bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0, input_audio_bgm_1, layer_bgm)
                    # mert_emb = self.extract_mert_embeds(input_audios_mert)
                    output_len = bestrq_emb.shape[2]
                    wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_vocal_wav2vec + input_audios_bgm_wav2vec, output_len)
                bestrq_emb = bestrq_emb.detach()
                bestrq_emb_bgm = bestrq_emb_bgm.detach()

        if('lyric' in additional_feats):
            text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric)
        else:
            text_encoder_hidden_states, text_mask = None, None

        if(train_rvq):
            quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb, _ = self.rvq_bestrq_emb(bestrq_emb)  # b,d,t
            quantized_bestrq_emb_bgm, _, _, commitment_loss_bestrq_emb_bgm, codebook_loss_bestrq_emb_bgm, _ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm)  # b,d,t
        else:
            bestrq_emb = bestrq_emb.float()
            self.rvq_bestrq_emb.eval()
            # with autocast(enabled=False):
            quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb, _ = self.rvq_bestrq_emb(bestrq_emb)  # b,d,t
            commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach()
            codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach()
            quantized_bestrq_emb = quantized_bestrq_emb.detach()
            # Apply the same eval-mode quantisation to the BGM stream; its commitment/codebook
            # losses and quantized features are needed below.
            bestrq_emb_bgm = bestrq_emb_bgm.float()
            self.rvq_bestrq_bgm_emb.eval()
            quantized_bestrq_emb_bgm, _, _, commitment_loss_bestrq_emb_bgm, codebook_loss_bestrq_emb_bgm, _ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm)  # b,d,t
            commitment_loss_bestrq_emb_bgm = commitment_loss_bestrq_emb_bgm.detach()
            codebook_loss_bestrq_emb_bgm = codebook_loss_bestrq_emb_bgm.detach()
            quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.detach()

        commitment_loss = commitment_loss_bestrq_emb + commitment_loss_bestrq_emb_bgm
        codebook_loss = codebook_loss_bestrq_emb + codebook_loss_bestrq_emb_bgm

        alpha = 1
        quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1 - alpha)
        quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm * alpha + bestrq_emb_bgm * (1 - alpha)

        scenario = np.random.choice(['start_seg', 'other_seg'])
        if(scenario == 'other_seg'):
            for binx in range(input_audios_vocal.shape[0]):
                # latent_masks[binx,0:64] = 1
                latent_masks[binx, 0:random.randint(64, 128)] = 1

        quantized_bestrq_emb = quantized_bestrq_emb.permute(0, 2, 1).contiguous()
        quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0, 2, 1).contiguous()

        quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
            + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1, 1, 1024)
        quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \
            + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1, 1, 1024)

        if self.uncondition:
            # Randomly drop the conditioning for roughly 10% of the batch (classifier-free guidance training).
            mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1]
            if len(mask_indices) > 0:
                quantized_bestrq_emb[mask_indices] = 0
                quantized_bestrq_emb_bgm[mask_indices] = 0

        latents = latents.permute(0, 2, 1).contiguous()
        latents = self.normfeat.project_sample(latents)
        latents = latents.permute(0, 2, 1).contiguous()
        incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()

        attention_mask = (latent_masks > 0.5)
        B, L = attention_mask.size()
        attention_mask = attention_mask.view(B, 1, L)
        attention_mask = attention_mask * attention_mask.transpose(-1, -2)
        attention_mask = attention_mask.unsqueeze(1)
        latent_mask_input = self.mask_emb(latent_masks)

        loss, loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input, incontext_latents, quantized_bestrq_emb, quantized_bestrq_emb_bgm], latent_masks, attention_mask, wav2vec_embeds, validation_mode=validation_mode)
        return loss, loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean()

    def init_device_dtype(self, device, dtype):
        self.device = device
        self.dtype = dtype

    @torch.no_grad()
    def fetch_codes(self, input_audios_vocal, input_audios_bgm, additional_feats, layer_vocal=7, layer_bgm=7):
        input_audio_vocal_0 = input_audios_vocal[[0], :]
        input_audio_vocal_1 = input_audios_vocal[[1], :]
        input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
        input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
        input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0

        input_audio_bgm_0 = input_audios_bgm[[0], :]
        input_audio_bgm_1 = input_audios_bgm[[1], :]
        input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
        input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
        input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0

        self.bestrq.eval()
        # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
        # bestrq_middle = bestrq_middle.detach()
        # bestrq_last = bestrq_last.detach()
        bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0, input_audio_vocal_1, layer_vocal)
        bestrq_emb = bestrq_emb.detach()
        bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0, input_audio_bgm_1, layer_bgm)
        bestrq_emb_bgm = bestrq_emb_bgm.detach()

        self.rvq_bestrq_emb.eval()
        quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)  # b,d,t
        self.rvq_bestrq_bgm_emb.eval()
        quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm)  # b,d,t

        if('spk' in additional_feats):
            self.xvecmodel.eval()
            spk_embeds = self.extract_spk_embeds(input_audios)
        else:
            spk_embeds = None

        # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
        # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
        # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
        return [codes_bestrq_emb, codes_bestrq_emb_bgm], [bestrq_emb, bestrq_emb_bgm], spk_embeds
        # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds
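
    # Illustrative sketch (not part of the original model): the discrete/continuous round trip
    # used by `fetch_codes` and `inference_codes`. The quantizer forward returns the quantized
    # features together with their codebook indices; `from_codes` rebuilds the quantized
    # features from the indices alone, which is what `inference_codes` conditions on.
    # The (B, 1024, T) input shape is an assumption for the demo.
    def _example_rvq_roundtrip(self, bestrq_emb):
        self.rvq_bestrq_emb.eval()
        quantized, codes, *_ = self.rvq_bestrq_emb(bestrq_emb)     # continuous + discrete views
        rebuilt, _, _ = self.rvq_bestrq_emb.from_codes(codes)      # decode indices back to features
        return quantized, rebuilt, codes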
    @torch.no_grad()
    def fetch_codes_batch(self, input_audios_vocal, input_audios_bgm, additional_feats, layer_vocal=7, layer_bgm=7):
        input_audio_vocal_0 = input_audios_vocal[:, 0, :]
        input_audio_vocal_1 = input_audios_vocal[:, 1, :]
        input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0)
        input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1)
        input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0

        input_audio_bgm_0 = input_audios_bgm[:, 0, :]
        input_audio_bgm_1 = input_audios_bgm[:, 1, :]
        input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0)
        input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1)
        input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0

        self.bestrq.eval()
        # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios)
        # bestrq_middle = bestrq_middle.detach()
        # bestrq_last = bestrq_last.detach()
        bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0, input_audio_vocal_1, layer_vocal)
        bestrq_emb = bestrq_emb.detach()
        bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0, input_audio_bgm_1, layer_bgm)
        bestrq_emb_bgm = bestrq_emb_bgm.detach()

        self.rvq_bestrq_emb.eval()
        quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb)  # b,d,t
        self.rvq_bestrq_bgm_emb.eval()
        quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm)  # b,d,t

        if('spk' in additional_feats):
            self.xvecmodel.eval()
            spk_embeds = self.extract_spk_embeds(input_audios)
        else:
            spk_embeds = None

        # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds
        # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds
        # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds
        return [codes_bestrq_emb, codes_bestrq_emb_bgm], [bestrq_emb, bestrq_emb_bgm], spk_embeds
        # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds

    @torch.no_grad()
    def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127,
                        guidance_scale=2, num_steps=20,
                        disable_progress=True, scenario='start_seg'):
        classifier_free_guidance = guidance_scale > 1.0
        device = self.device
        dtype = self.dtype

        # codes_bestrq_middle, codes_bestrq_last = codes
        codes_bestrq_emb, codes_bestrq_emb_bgm = codes

        batch_size = codes_bestrq_emb.shape[0]

        quantized_bestrq_emb, _, _ = self.rvq_bestrq_emb.from_codes(codes_bestrq_emb)
        quantized_bestrq_emb_bgm, _, _ = self.rvq_bestrq_bgm_emb.from_codes(codes_bestrq_emb_bgm)
        quantized_bestrq_emb = quantized_bestrq_emb.permute(0, 2, 1).contiguous()
        quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0, 2, 1).contiguous()

        if('spk' in additional_feats):
            spk_embeds = spk_embeds.repeat(1, 1, quantized_bestrq_emb.shape[-2], 1).detach()

        num_frames = quantized_bestrq_emb.shape[1]

        num_channels_latents = self.num_channels
        shape = (batch_size, num_frames, 64)
        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)

        latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device)
        latent_masks[:, 0:latent_length] = 2
        if(scenario == 'other_seg'):
            latent_masks[:, 0:incontext_length] = 1

        quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \
            + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1, 1, 1024)
        quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \
            + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1, 1, 1024)
        true_latents = true_latents.permute(0, 2, 1).contiguous()
        true_latents = self.normfeat.project_sample(true_latents)
        true_latents = true_latents.permute(0, 2, 1).contiguous()
        incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float()
        incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0]

        attention_mask = (latent_masks > 0.5)
        B, L = attention_mask.size()
        attention_mask = attention_mask.view(B, 1, L)
        attention_mask = attention_mask * attention_mask.transpose(-1, -2)
        attention_mask = attention_mask.unsqueeze(1)
        latent_mask_input = self.mask_emb(latent_masks)

        if('spk' in additional_feats):
            # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1)
            additional_model_input = torch.cat([quantized_bestrq_emb, quantized_bestrq_emb_bgm, spk_embeds], 2)
        else:
            # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1)
            additional_model_input = torch.cat([quantized_bestrq_emb, quantized_bestrq_emb_bgm], 2)

        temperature = 1.0
        t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device)
        latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input, incontext_latents, incontext_length, t_span, additional_model_input, attention_mask, guidance_scale)
        latents[:, 0:incontext_length, :] = incontext_latents[:, 0:incontext_length, :]
        latents = latents.permute(0, 2, 1).contiguous()
        latents = self.normfeat.return_sample(latents)
        # latents = latents.permute(0,2,1).contiguous()
        return latents

    @torch.no_grad()
    def inference(self, input_audios_vocal, input_audios_bgm, lyric, true_latents, latent_length, additional_feats,
                  guidance_scale=2, num_steps=20,
                  disable_progress=True, layer_vocal=7, layer_bgm=3, scenario='start_seg'):
        codes, embeds, spk_embeds = self.fetch_codes(input_audios_vocal, input_audios_bgm, additional_feats, layer_vocal, layer_bgm)

        latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats,
                                       guidance_scale=guidance_scale, num_steps=num_steps,
                                       disable_progress=disable_progress, scenario=scenario)
        return latents

    def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device):
        divisor = 4
        shape = (batch_size, num_channels_latents, num_frames, 32)
        if(num_frames % divisor > 0):
            num_frames = round(num_frames / float(divisor)) * divisor
            shape = (batch_size, num_channels_latents, num_frames, 32)
        latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
        return latents
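
# Illustrative usage sketch (not part of the original file): how fetch_codes / inference_codes /
# inference fit together. Everything marked "placeholder" is an assumption: the real pipeline
# supplies 48 kHz stereo vocal/BGM stems, VAE latents whose time axis matches the BEST-RQ frame
# count for that audio length, and the surrounding config/checkpoint loading.
def _example_inference():
    model = PromptCondAudioDiffusion(
        num_channels=64,
        unet_model_config_path="placeholder_unet_config.json",  # placeholder; only checked for non-None here
    )
    model.init_device_dtype(torch.device("cpu"), torch.float32)
    model.eval()

    vocal = torch.randn(2, 10 * 48000)   # stereo vocal stem: (channels, samples), 10 s at 48 kHz
    bgm = torch.randn(2, 10 * 48000)     # stereo background stem
    # (B, T, 64) target latents; T (250 here is a placeholder) must match the conditioning frame count.
    true_latents = torch.randn(1, 250, 64)
    latent_length = 250                   # number of valid latent frames (placeholder)

    latents = model.inference(vocal, bgm, lyric=None, true_latents=true_latents,
                              latent_length=latent_length, additional_feats=[],
                              guidance_scale=2, num_steps=20)
    return latents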