| """Adapted from https://github.com/zhenye234/CoMoSpeech""" |
|
|
| import torch |
| import torch.nn as nn |
| import copy |
| import numpy as np |
| import math |
| from tqdm.auto import tqdm |
|
|
| from utils.ssim import SSIM |
|
|
| from models.svc.transformer.conformer import Conformer, BaseModule |
| from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper |
| from models.svc.comosvc.utils import slice_segments, rand_ids_segments |
|
|
|
|
class Consistency(nn.Module):
    def __init__(self, cfg, distill=False):
        super().__init__()
        self.cfg = cfg
        self.denoise_fn = DiffusionWrapper(self.cfg)
        # The wrapper needs the full config; afterwards only the comosvc
        # sub-config is kept around.
        self.cfg = cfg.model.comosvc
        self.teacher = not distill
        self.P_mean = self.cfg.P_mean
        self.P_std = self.cfg.P_std
        self.sigma_data = self.cfg.sigma_data
        self.sigma_min = self.cfg.sigma_min
        self.sigma_max = self.cfg.sigma_max
        self.rho = self.cfg.rho
        self.N = self.cfg.n_timesteps
        self.ssim_loss = SSIM()

        # Karras noise-level schedule: N sigmas spaced uniformly in sigma^(1/rho)
        # between sigma_min and sigma_max, with sigma = 0 prepended so that
        # self.t_steps[0] = 0 and self.t_steps[self.N] = sigma_max.
        step_indices = torch.arange(self.N)
        t_steps = (
            self.sigma_min ** (1 / self.rho)
            + step_indices
            / (self.N - 1)
            * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
        ) ** self.rho
        self.t_steps = torch.cat(
            [torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)]
        )

    def init_consistency_training(self):
        # EMA copy used as the target network, plus a frozen copy of the
        # pretrained (teacher) denoiser used for the distillation ODE step.
        self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
        self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)

    def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None):
        """
        Preconditioned denoiser used in the Karras (EDM) reverse process.

        Args:
            x: noisy mel-spectrogram [B x n_mel x L]
            sigma: noise level [B x 1 x 1]
            cond: output of conformer encoder [B x n_mel x L]
            denoise_fn: denoiser network (here a DiffusionWrapper instance)
            mask: mask of padded frames [B x n_mel x L]
            spk: optional speaker embedding (unused here)

        Returns:
            denoised mel-spectrogram [B x n_mel x L]
        """
        sigma = sigma.reshape(-1, 1, 1)

        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.log() / 4

        # The denoiser network expects [B x L x n_mel], so transpose in and out.
        x_in = c_in * x
        x_in = x_in.transpose(1, 2)
        x = x.transpose(1, 2)
        cond = cond.transpose(1, 2)
        F_x = denoise_fn(x_in, c_noise.squeeze(), cond)

        D_x = c_skip * x + c_out * F_x
        D_x = D_x.transpose(1, 2)
        return D_x

    def EDMLoss(self, x_start, cond, mask):
        """
        Compute the EDM denoising loss for the teacher model.

        Args:
            x_start: ground truth mel-spectrogram [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            mask: mask of padded frames [B x n_mel x L]
        """
        rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2

        # Noise is centered on the encoder output so the terminal distribution of
        # the diffusion matches the prior drawn at inference time.
        noise = (torch.randn_like(x_start) + cond) * sigma
        D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn, mask)
        loss = weight * ((D_yn - x_start) ** 2)
        loss = torch.sum(loss * mask) / torch.sum(mask)
        return loss

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

    def edm_sampler(
        self,
        latents,
        cond,
        nonpadding,
        num_steps=50,
        sigma_min=0.002,
        sigma_max=80,
        rho=7,
        S_churn=0,
        S_min=0,
        S_max=float("inf"),
        S_noise=1,
    ):
        """
        Karras (EDM) stochastic sampler.

        Args:
            latents: noisy mel-spectrogram [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            nonpadding: mask of padded frames [B x n_mel x L]
            num_steps: number of steps for diffusion inference
            sigma_min, sigma_max, rho: parameters of the noise-level schedule
            S_churn, S_min, S_max, S_noise: stochastic "churn" parameters;
                the defaults make the sampler deterministic

        Returns:
            denoised mel-spectrogram [B x n_mel x L]
        """
        # Time-step discretization: num_steps + 1 sigmas from sigma_max down to
        # sigma_min, followed by an appended sigma = 0.
        num_steps = num_steps + 1
        step_indices = torch.arange(num_steps, device=latents.device)
        t_steps = (
            sigma_max ** (1 / rho)
            + step_indices
            / (num_steps - 1)
            * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho
        t_steps = torch.cat(
            [self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]
        )

        # Start from the (cond-shifted) Gaussian prior scaled to the largest sigma.
        x_next = latents * t_steps[0]

        bar = tqdm(enumerate(zip(t_steps[:-1], t_steps[1:])))
        for i, (t_cur, t_next) in bar:
            x_cur = x_next

            # Optionally raise the noise level ("churn") before the ODE step.
            gamma = (
                min(S_churn / num_steps, np.sqrt(2) - 1)
                if S_min <= t_cur <= S_max
                else 0
            )
            t_hat = self.round_sigma(t_cur + gamma * t_cur)
            t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
            t[:, 0, 0] = t_hat
            t_hat = t
            x_hat = x_cur + (
                t_hat**2 - t_cur**2
            ).sqrt() * S_noise * torch.randn_like(x_cur)

            # Single Euler step of the probability-flow ODE from t_hat to t_next.
            denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn, nonpadding)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur

        return x_next

    def CTLoss_D(self, y, cond, mask):
        """
        Compute the consistency-distillation loss.

        Args:
            y: ground truth mel-spectrogram [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            mask: mask of padded frames [B x n_mel x L]
        """
        with torch.no_grad():
            mu = 0.95
            for p, ema_p in zip(
                self.denoise_fn.parameters(), self.denoise_fn_ema.parameters()
            ):
                ema_p.mul_(mu).add_(p, alpha=1 - mu)

        n = torch.randint(1, self.N, (y.shape[0],))
        z = torch.randn_like(y) + cond

        tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device)
        f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn, mask)

        with torch.no_grad():
            tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device)

            # One Euler (probability-flow ODE) step from t_{n+1} to t_n using the
            # frozen pretrained teacher.
            x_hat = y + tn_1 * z
            denoised = self.EDMPrecond(
                x_hat, tn_1, cond, self.denoise_fn_pretrained, mask
            )
            d_cur = (x_hat - denoised) / tn_1
            y_tn = x_hat + (tn - tn_1) * d_cur

            f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema, mask)

        loss = self.ssim_loss(f_theta, f_theta_ema.detach())
        loss = torch.sum(loss * mask) / torch.sum(mask)

        return loss

    def get_t_steps(self, N):
        """
        Karras noise-level schedule for N sampling steps, returned from the
        largest to the smallest sigma.
        """
        N = N + 1
        step_indices = torch.arange(N)
        t_steps = (
            self.sigma_min ** (1 / self.rho)
            + step_indices
            / (N - 1)
            * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
        ) ** self.rho

        return t_steps.flip(0)

    def CT_sampler(self, latents, cond, nonpadding, t_steps=1):
        """
        Consistency-model sampler (single-step or multi-step).

        Args:
            latents: noisy mel-spectrogram [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            nonpadding: mask of padded frames [B x n_mel x L]
            t_steps: number of steps for consistency-model inference

        Returns:
            denoised mel-spectrogram [B x n_mel x L]
        """
        if t_steps == 1:
            # Single-step generation starts directly from sigma_max.
            t_steps = [80]
        else:
            t_steps = self.get_t_steps(t_steps)

        t_steps = torch.as_tensor(t_steps).to(latents.device)
        latents = latents * t_steps[0]
        _t = torch.zeros((latents.shape[0], 1, 1), device=latents.device)
        _t[:, 0, 0] = t_steps[0]
        x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema, nonpadding)

        for t in t_steps[1:-1]:
            z = torch.randn_like(x) + cond
            x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z
            _t = torch.zeros((x.shape[0], 1, 1), device=x.device)
            _t[:, 0, 0] = t
            x = self.EDMPrecond(x_tn, _t, cond, self.denoise_fn_ema, nonpadding)
        return x

    def forward(self, x, nonpadding, cond, t_steps=1, infer=False):
        """
        Compute the training loss or sample a mel-spectrogram.

        Args:
            x:
                training: ground truth mel-spectrogram [B x n_mel x L]
                inference: output of encoder [B x n_mel x L]
            nonpadding: mask of padded frames [B x n_mel x L]
            cond: output of conformer encoder [B x n_mel x L]
            t_steps: number of reverse steps used at inference
            infer: True for sampling, False for loss computation
        """
        if self.teacher:
            if not infer:
                loss = self.EDMLoss(x, cond, nonpadding)
                return loss
            else:
                shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
                x = torch.randn(shape, device=x.device) + cond
                x = self.edm_sampler(x, cond, nonpadding, t_steps)
            return x
        else:
            if not infer:
                loss = self.CTLoss_D(x, cond, nonpadding)
                return loss
            else:
                shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
                x = torch.randn(shape, device=x.device) + cond
                x = self.CT_sampler(x, cond, nonpadding, t_steps=1)
            return x


class ComoSVC(BaseModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.cfg.model.comosvc.n_mel = self.cfg.preprocess.n_mel
        self.distill = self.cfg.model.comosvc.distill
        self.encoder = Conformer(self.cfg.model.comosvc)
        self.decoder = Consistency(self.cfg, distill=self.distill)
        self.ssim_loss = SSIM()

    @torch.no_grad()
    def forward(self, x_mask, x, n_timesteps, temperature=1.0):
        """
        Generates a mel-spectrogram from pitch, content vector, and energy. Returns:
        1. encoder outputs (from the conformer)
        2. decoder outputs (from the diffusion-based decoder)

        Args:
            x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
            x : output of encoder framework. [B x L x d_condition]
            n_timesteps : number of steps to use for reverse diffusion in decoder.
            temperature : controls variance of terminal distribution.
        """
        mu_x = self.encoder(x, x_mask)
        encoder_outputs = mu_x

        # Transpose to [B x n_mel x L] for the diffusion/consistency decoder.
        mu_x = mu_x.transpose(1, 2)
        x_mask = x_mask.transpose(1, 2)

        decoder_outputs = self.decoder(
            mu_x, x_mask, mu_x, t_steps=n_timesteps, infer=True
        )
        decoder_outputs = decoder_outputs.transpose(1, 2)
        return encoder_outputs, decoder_outputs

    def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False):
        """
        Computes 3 losses:
        1. SSIM loss between mel-spectrogram and encoder outputs.
        2. prior loss: loss between mel-spectrogram and encoder outputs.
        3. diffusion loss: loss between gaussian noise and its reconstruction by the diffusion-based decoder.

        Args:
            x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
            x : output of encoder framework. [B x L x d_condition]
            mel : ground truth mel-spectrogram. [B x L x n_mel]
            out_size : unused here.
            skip_diff : if True (teacher training only), return a zero diffusion loss.
        """
        mu_x = self.encoder(x, x_mask)

        prior_loss = torch.sum(
            0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask
        )
        prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel)

        ssim_loss = self.ssim_loss(mu_x, mel)
        ssim_loss = torch.sum(ssim_loss * x_mask) / torch.sum(x_mask)

        # Transpose to [B x n_mel x L] for the diffusion/consistency decoder.
        x_mask = x_mask.transpose(1, 2)
        mu_x = mu_x.transpose(1, 2)
        mel = mel.transpose(1, 2)

        if not self.distill and skip_diff:
            # Optionally skip the diffusion loss (teacher training only).
            diff_loss = prior_loss.clone()
            diff_loss.fill_(0)
        else:
            # During distillation, stop gradients from the diffusion loss into
            # the encoder.
            if self.distill:
                mu_y = mu_x.detach()
            else:
                mu_y = mu_x
            mask_y = x_mask

            diff_loss = self.decoder(mel, mask_y, mu_y, infer=False)

        return ssim_loss, prior_loss, diff_loss