| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | Functions for Noise Schedule, defines diffusion process, reverse process and data processor. |
| | """ |
| |
|
| | from collections import namedtuple |
| | import random |
| | import typing as tp |
| | import julius |
| | import torch |
| |
|
# One training example for the diffusion model: the noised signal, the noise
# that was injected (the regression target), and the diffusion step index.
TrainingItem = namedtuple("TrainingItem", ["noisy", "noise", "step"])
| |
|
| |
|
def betas_from_alpha_bar(alpha_bar):
    """Recover the per-step ``beta`` values from a cumulative product ``alpha_bar``.

    Inverse of ``(1 - betas).cumprod(...)``: ``alpha_i = alpha_bar_i / alpha_bar_{i-1}``
    (with ``alpha_0 = alpha_bar_0``) and ``beta_i = 1 - alpha_i``.

    Args:
        alpha_bar (torch.Tensor): 1D tensor of cumulative products of ``1 - beta``.
    Returns:
        torch.Tensor: the betas, with the same shape, device and dtype as ``alpha_bar``.
    """
    # Use the slice `alpha_bar[:1]` rather than `torch.Tensor([alpha_bar[0]])`:
    # the latter always builds a float32 CPU tensor, which made `torch.cat` fail
    # for CUDA inputs and silently downcast float64 inputs.
    alphas = torch.cat([alpha_bar[:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - alphas
| |
|
| |
|
class SampleProcessor(torch.nn.Module):
    """Identity sample processor.

    Subclasses may override the two projections to map samples into a space
    better suited for diffusion and back; this base class changes nothing.
    """

    def project_sample(self, x: torch.Tensor):
        """Map a raw sample into the space where diffusion happens (identity here)."""
        return x

    def return_sample(self, z: torch.Tensor):
        """Map a point in diffusion space back to the sample space (identity here)."""
        return z
| |
|
| |
|
class MultiBandProcessor(SampleProcessor):
    """
    MultiBand sample processor. The input audio is split across
    frequency bands evenly distributed in mel-scale.

    Each band will be rescaled to match the power distribution
    of Gaussian noise in that band, using online metrics
    computed on the first few samples.

    Args:
        n_bands (int): Number of mel-bands to split the signal over.
        sample_rate (float): Sample rate of the audio.
        num_samples (int): Number of samples to use to fit the rescaling
            for each band. The processor won't be stable
            until it has seen that many samples.
        power_std (float or list/tensor): The rescaling factor computed to match the
            power of Gaussian noise in each band is taken to
            that power, i.e. `1.` means full correction of the energy
            in each band, and values less than `1` means only partial
            correction. Can be used to balance the relative importance
            of low vs. high freq in typical audio signals.
    """
    def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
                 num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
        super().__init__()
        self.n_bands = n_bands
        self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
        self.num_samples = num_samples
        if isinstance(power_std, list):
            assert len(power_std) == n_bands
            power_std = torch.tensor(power_std)
        # Bug fix: store power_std *after* the list -> tensor conversion above.
        # Previously the attribute was assigned first, so the converted tensor
        # was discarded and a list `power_std` stayed a list, breaking the
        # `** self.power_std` broadcasting in project/return_sample.
        self.power_std = power_std
        # Online statistics, accumulated over the first `num_samples` items
        # (registered as buffers so they travel with the module's state_dict).
        self.register_buffer('counts', torch.zeros(1))
        self.register_buffer('sum_x', torch.zeros(n_bands))
        self.register_buffer('sum_x2', torch.zeros(n_bands))
        self.register_buffer('sum_target_x2', torch.zeros(n_bands))
        self.counts: torch.Tensor
        self.sum_x: torch.Tensor
        self.sum_x2: torch.Tensor
        self.sum_target_x2: torch.Tensor

    @property
    def mean(self):
        """Running per-band mean of the observed samples."""
        mean = self.sum_x / self.counts
        return mean

    @property
    def std(self):
        """Running per-band standard deviation (clamped at 0 for numerical safety)."""
        std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
        return std

    @property
    def target_std(self):
        """Running per-band statistic of reference Gaussian noise.

        NOTE(review): despite the name this is the mean power E[x^2], not its
        square root; for Gaussian noise each band's power is close to 1 so the
        scale is comparable — confirm this asymmetry vs. `std` is intended.
        """
        target_std = self.sum_target_x2 / self.counts
        return target_std

    def project_sample(self, x: torch.Tensor):
        """Split `x` (B, C, T) into bands, update running stats, and rescale each band."""
        assert x.dim() == 3
        bands = self.split_bands(x)
        if self.counts.item() < self.num_samples:
            # Still fitting: accumulate stats from this batch, and from Gaussian
            # noise of the same shape split over the same bands as the target.
            ref_bands = self.split_bands(torch.randn_like(x))
            self.counts += len(x)
            self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
            self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
            self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
        rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std
        bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
        return bands.sum(dim=0)

    def return_sample(self, x: torch.Tensor):
        """Invert `project_sample`: undo the per-band rescaling and re-add the mean."""
        assert x.dim() == 3
        bands = self.split_bands(x)
        rescale = (self.std / self.target_std) ** self.power_std
        bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
        return bands.sum(dim=0)
| |
|
| |
|
class NoiseSchedule:
    """Noise schedule for diffusion.

    Args:
        beta_t0 (float): Variance of the first diffusion step.
        beta_t1 (float): Variance of the last diffusion step.
        beta_exp (float): Power schedule exponent (1 gives a linear schedule).
        num_steps (int): Number of diffusion steps.
        variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
        clip (float): clipping value for the denoising steps
        rescale (float): rescaling value to avoid vanishing signals, unused by default (i.e 1)
        repartition (str): shape of the schedule, only the power schedule is supported
        sample_processor (SampleProcessor): Module that normalizes data to better match a Gaussian distribution
        noise_scale (float): Scaling factor for the noise
    """
    def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
                 clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
                 repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
                 sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
        # NOTE(review): `alpha_sigmoid: dict = {}` is a mutable default and is never
        # read in this class, and the default `SampleProcessor()` instance is shared
        # across all NoiseSchedule instances — harmless here (the base processor is
        # stateless) but worth confirming before reuse.
        self.beta_t0 = beta_t0
        self.beta_t1 = beta_t1
        self.variance = variance
        self.num_steps = num_steps
        self.clip = clip
        self.sample_processor = sample_processor
        self.rescale = rescale
        self.n_bands = n_bands
        self.noise_scale = noise_scale
        # Multi-band schedules are not supported by this implementation.
        assert n_bands is None
        if repartition == "power":
            # Interpolate linearly in the (1 / beta_exp)-power domain, then raise
            # back: beta_exp == 1 reduces to the standard linear beta schedule.
            self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
                                        device=device, dtype=torch.float) ** beta_exp
        else:
            raise RuntimeError('Not implemented')
        # Dedicated RNG with a fixed seed so step sampling is reproducible.
        self.rng = random.Random(1234)

    def get_beta(self, step: tp.Union[int, torch.Tensor]):
        """Return beta for the given step (or tensor of steps)."""
        if self.n_bands is None:
            return self.betas[step]
        else:
            # Unreachable given the `assert n_bands is None` in __init__.
            return self.betas[:, step]

    def get_initial_noise(self, x: torch.Tensor):
        """Draw the starting noise for the reverse process, shaped like `x`."""
        if self.n_bands is None:
            return torch.randn_like(x)
        # Unreachable given the `assert n_bands is None` in __init__.
        return torch.randn((x.size(0), self.n_bands, x.size(2)))

    def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
        """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
        if step is None:
            return (1 - self.betas).cumprod(dim=-1)
        if type(step) is int:
            # Product over an empty slice (step == -1) yields 1, which is the
            # correct alpha_bar "before" the first step.
            return (1 - self.betas[:step + 1]).prod()
        else:
            # Per-element steps: reshape to broadcast against (B, C, T) audio.
            return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)

    def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
        """Create a noisy data item for diffusion model training:

        Args:
            x (torch.Tensor): clean audio data, torch.tensor(bs, 1, T)
            tensor_step (bool): If False, a single step t is sampled, the whole
                batch is diffused to that step, and t is an int.
                If True, t is a tensor of size (x.size(0),) and every element of
                the batch is diffused to an independently sampled step.
        """
        step: tp.Union[int, torch.Tensor]
        if tensor_step:
            bs = x.size(0)
            step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
        else:
            step = self.rng.randrange(self.num_steps)
        alpha_bar = self.get_alpha_bar(step)

        x = self.sample_processor.project_sample(x)
        noise = torch.randn_like(x)
        # Standard DDPM forward step: sqrt(a_bar) * x + sqrt(1 - a_bar) * eps,
        # with optional rescale/noise_scale factors (both 1 by default).
        noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
        return TrainingItem(noisy, noise, step)

    def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
                 condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Full ddpm reverse process.

        Args:
            model (nn.Module): Diffusion model.
            initial (tensor): Initial noise. NOTE(review): despite the Optional
                annotation, passing None would fail at the first model call —
                callers presumably always provide it; confirm.
            condition (tensor): Input conditioning tensor (e.g. encodec compressed representation).
            return_list (bool): Whether to return the whole process or only the sampled point.
        """
        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
        current = initial
        iterates = [initial]
        # Walk the Markov chain backwards from the last step down to 0.
        for step in range(self.num_steps)[::-1]:
            with torch.no_grad():
                # `.sample` suggests the model returns a structured output
                # (e.g. a diffusers-style ModelOutput) — confirm against the model API.
                estimate = model(current, step, condition=condition).sample
            alpha = 1 - self.betas[step]
            # Posterior mean of x_{t-1} given x_t and the predicted noise.
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            previous_alpha_bar = self.get_alpha_bar(step=step - 1)
            # Choose the variance of the sampling noise added at this step.
            if step == 0:
                sigma2 = 0
            elif self.variance == 'beta':
                sigma2 = 1 - alpha
            elif self.variance == 'beta_tilde':
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            elif self.variance == 'none':
                sigma2 = 0
            else:
                raise ValueError(f'Invalid variance type {self.variance}')

            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if step == 0:
                # Undo the 1/rescale applied during training (no-op when rescale == 1).
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())

        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)

    def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
                            condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
        """Reverse process that only goes through Markov chain states in step_list."""
        if step_list is None:
            # Default: every 50th step from 950 down, plus the final step 0.
            # NOTE(review): hard-codes 1000 rather than self.num_steps — confirm
            # this is only used with the default schedule length.
            step_list = list(range(1000))[::-50] + [0]
        alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
        # Effective betas for the coarser chain, recovered from the subsampled
        # alpha_bar curve (step_list is descending, hence the reversal).
        alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
        betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
        current = initial * self.noise_scale
        iterates = [current]
        for idx, step in enumerate(step_list[:-1]):
            with torch.no_grad():
                estimate = model(current, step, condition=condition).sample * self.noise_scale
            # betas_subsampled is in ascending-step order; index from the end to
            # match the descending iteration over step_list.
            alpha = 1 - betas_subsampled[-1 - idx]
            previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
            previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
            if step == step_list[-2]:
                # Last iteration (next state is step 0): no sampling noise.
                sigma2 = 0
                previous_alpha_bar = torch.tensor(1.0)
            else:
                sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
            if sigma2 > 0:
                previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
            if self.clip:
                previous = previous.clamp(-self.clip, self.clip)
            current = previous
            alpha_bar = previous_alpha_bar
            if step == 0:
                previous *= self.rescale
            if return_list:
                iterates.append(previous.cpu())
        if return_list:
            return iterates
        else:
            return self.sample_processor.return_sample(previous)
| |
|