| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from dataclasses import dataclass |
| | from typing import Optional, Tuple, Union |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| |
|
| | from ..configuration_utils import ConfigMixin, register_to_config |
| | from ..utils import BaseOutput |
| | from .scheduling_utils import SchedulerMixin |
| |
|
| |
|
@dataclass
class VQDiffusionSchedulerOutput(BaseOutput):
    """
    Container returned by the scheduler's `step` method.

    Args:
        prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
            The sampled classes x_{t-1} for the previous timestep. Feed `prev_sample` back in as the model input
            on the next iteration of the denoising loop.
    """

    # Class index chosen for every latent pixel at timestep t-1.
    prev_sample: torch.LongTensor
| |
|
| |
|
def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.Tensor:
    """
    Turn a batch of class-index vectors into a batch of log one-hot vectors.

    Args:
        x (`torch.LongTensor` of shape `(batch size, vector length)`):
            Batch of class indices.

        num_classes (`int`):
            Number of classes, i.e. the size of the one-hot axis.

    Returns:
        `torch.Tensor` of shape `(batch size, num classes, vector length)`:
            Log one-hot vectors: `0` at the selected class and `log(1e-30)` everywhere else.
    """
    # `F.one_hot` appends the class axis last; move it to position 1.
    one_hot = F.one_hot(x, num_classes).permute(0, 2, 1).float()
    # Clamp before the log so that the zero entries become a large negative number instead of -inf.
    return torch.log(one_hot.clamp(min=1e-30))
| |
|
| |
|
def gumbel_noised(logits: torch.Tensor, generator: Optional[torch.Generator]) -> torch.Tensor:
    """
    Return `logits` perturbed with Gumbel noise.

    Args:
        logits (`torch.Tensor`):
            Values to perturb. The noise is drawn with the same shape and on the same device.
        generator (`torch.Generator`, or `None`):
            Random number generator handed to `torch.rand` for the uniform draw.

    Returns:
        `torch.Tensor`:
            `logits` plus the noise `-log(-log(u))` with `u` uniform on `[0, 1)`.
    """
    u = torch.rand(logits.shape, device=logits.device, generator=generator)
    # The small offsets keep both logarithms finite when `u` (or the inner log) is exactly zero.
    neg_log_u = -torch.log(u + 1e-30)
    gumbel = -torch.log(neg_log_u + 1e-30)
    return gumbel + logits
| |
|
| |
|
def alpha_schedules(
    num_diffusion_timesteps: int, alpha_cum_start: float = 0.99999, alpha_cum_end: float = 0.000009
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Cumulative and non-cumulative alpha schedules.

    The cumulative schedule is a linear ramp from `alpha_cum_start` to `alpha_cum_end`; the per-step schedule is
    the ratio of consecutive cumulative values.

    See section 4.1.

    Args:
        num_diffusion_timesteps (`int`):
            Number of diffusion steps in the schedule.
        alpha_cum_start (`float`, defaults to 0.99999):
            Cumulative alpha at the first timestep.
        alpha_cum_end (`float`, defaults to 0.000009):
            Cumulative alpha at the last timestep.

    Returns:
        `tuple` of two `np.ndarray`:
            `at` of length `num_diffusion_timesteps` (per-step alphas) and `att` of length
            `num_diffusion_timesteps + 1` (cumulative alphas followed by a trailing `1`).
    """
    # With a single timestep the ramp has no extent; guard the denominator so that the schedule collapses to
    # `[alpha_cum_start]` instead of evaluating 0 / 0 and filling the arrays with NaN.
    num_intervals = max(num_diffusion_timesteps - 1, 1)
    att = np.arange(0, num_diffusion_timesteps) / num_intervals * (alpha_cum_end - alpha_cum_start) + alpha_cum_start
    # Prepend the cumulative value before any diffusion (1) so that consecutive ratios give the per-step alphas.
    att = np.concatenate(([1], att))
    at = att[1:] / att[:-1]
    # Drop the leading 1 and append a trailing 1; the extra last entry is what an index of -1 selects.
    att = np.concatenate((att[1:], [1]))
    return at, att
| |
|
| |
|
def gamma_schedules(
    num_diffusion_timesteps: int, gamma_cum_start: float = 0.000009, gamma_cum_end: float = 0.99999
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Cumulative and non-cumulative gamma schedules.

    The cumulative schedule is a linear ramp from `gamma_cum_start` to `gamma_cum_end`; the per-step schedule is
    derived from the ratio of consecutive `1 - cumulative gamma` values.

    See section 4.1.

    Args:
        num_diffusion_timesteps (`int`):
            Number of diffusion steps in the schedule.
        gamma_cum_start (`float`, defaults to 0.000009):
            Cumulative gamma at the first timestep.
        gamma_cum_end (`float`, defaults to 0.99999):
            Cumulative gamma at the last timestep.

    Returns:
        `tuple` of two `np.ndarray`:
            `ct` of length `num_diffusion_timesteps` (per-step gammas) and `ctt` of length
            `num_diffusion_timesteps + 1` (cumulative gammas followed by a trailing `0`).
    """
    # With a single timestep the ramp has no extent; guard the denominator so that the schedule collapses to
    # `[gamma_cum_start]` instead of evaluating 0 / 0 and filling the arrays with NaN.
    num_intervals = max(num_diffusion_timesteps - 1, 1)
    ctt = np.arange(0, num_diffusion_timesteps) / num_intervals * (gamma_cum_end - gamma_cum_start) + gamma_cum_start
    # Prepend the cumulative value before any diffusion (0) so that consecutive ratios can be formed.
    ctt = np.concatenate(([0], ctt))
    one_minus_ctt = 1 - ctt
    one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
    ct = 1 - one_minus_ct
    # Drop the leading 0 and append a trailing 0; the extra last entry is what an index of -1 selects.
    ctt = np.concatenate((ctt[1:], [0]))
    return ct, ctt
| |
|
| |
|
class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
    """
    A scheduler for vector quantized diffusion.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_vec_classes (`int`):
            The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked
            latent pixel.
        num_train_timesteps (`int`, defaults to 100):
            The number of diffusion steps to train the model.
        alpha_cum_start (`float`, defaults to 0.99999):
            The starting cumulative alpha value.
        alpha_cum_end (`float`, defaults to 0.000009):
            The ending cumulative alpha value.
        gamma_cum_start (`float`, defaults to 0.000009):
            The starting cumulative gamma value.
        gamma_cum_end (`float`, defaults to 0.99999):
            The ending cumulative gamma value.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_vec_classes: int,
        num_train_timesteps: int = 100,
        alpha_cum_start: float = 0.99999,
        alpha_cum_end: float = 0.000009,
        gamma_cum_start: float = 0.000009,
        gamma_cum_end: float = 0.99999,
    ):
        self.num_embed = num_vec_classes

        # By convention, the index of the masked class is the last class index.
        self.mask_class = self.num_embed - 1

        at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end)
        ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end)

        # Whatever probability mass is left after alpha and gamma is split evenly over the non-masked classes.
        num_non_mask_classes = self.num_embed - 1
        bt = (1 - at - ct) / num_non_mask_classes
        btt = (1 - att - ctt) / num_non_mask_classes

        # The logarithms are taken in float64 and only afterwards stored as float32.
        at = torch.tensor(at.astype("float64"))
        bt = torch.tensor(bt.astype("float64"))
        ct = torch.tensor(ct.astype("float64"))
        log_at = torch.log(at)
        log_bt = torch.log(bt)
        log_ct = torch.log(ct)

        att = torch.tensor(att.astype("float64"))
        btt = torch.tensor(btt.astype("float64"))
        ctt = torch.tensor(ctt.astype("float64"))
        log_cumprod_at = torch.log(att)
        log_cumprod_bt = torch.log(btt)
        log_cumprod_ct = torch.log(ctt)

        self.log_at = log_at.float()
        self.log_bt = log_bt.float()
        self.log_ct = log_ct.float()
        self.log_cumprod_at = log_cumprod_at.float()
        self.log_cumprod_bt = log_cumprod_bt.float()
        self.log_cumprod_ct = log_cumprod_ct.float()

        # Setable values: until `set_timesteps` is called, the timesteps cover the full training range in
        # descending order.
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps and diffusion process parameters (alpha, beta, gamma) should be moved
                to.

        Raises:
            `ValueError`: If `num_inference_steps` exceeds the number of timesteps the schedules were built for.
        """
        # The per-step schedules hold one entry per training timestep and are indexed by the timestep in `step`.
        # Reject a step count they cannot serve here rather than failing later with an out-of-range index.
        num_train_timesteps = self.log_at.shape[0]
        if num_inference_steps > num_train_timesteps:
            raise ValueError(
                f"`num_inference_steps` ({num_inference_steps}) cannot be larger than the number of training"
                f" timesteps ({num_train_timesteps}) the transition schedules were built for."
            )

        self.num_inference_steps = num_inference_steps
        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps).to(device)

        self.log_at = self.log_at.to(device)
        self.log_bt = self.log_bt.to(device)
        self.log_ct = self.log_ct.to(device)
        self.log_cumprod_at = self.log_cumprod_at.to(device)
        self.log_cumprod_bt = self.log_cumprod_bt.to(device)
        self.log_cumprod_ct = self.log_cumprod_ct.to(device)

    def step(
        self,
        model_output: torch.Tensor,
        timestep: torch.long,
        sample: torch.LongTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[VQDiffusionSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by the reverse transition distribution. See
        [`~VQDiffusionScheduler.q_posterior`] for more details about how the distribution is computed.

        Args:
            model_output (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
                The log probabilities for the predicted classes of the initial latent pixels. Does not include a
                prediction for the masked class as the initial unnoised image cannot be masked.
            timestep (`torch.long`):
                The timestep that determines which transition matrices are used.
            sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
                The classes of each latent pixel at time `timestep`.
            generator (`torch.Generator`, or `None`):
                A random number generator for the noise applied to `p(x_{t-1} | x_t)` before it is sampled from.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] or
                `tuple`.

        Returns:
            [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_vq_diffusion.VQDiffusionSchedulerOutput`] is
                returned, otherwise a tuple is returned where the first element is the sample tensor.
        """
        # At the final step the model's prediction of x_0 is used directly; there is no posterior to form.
        if timestep == 0:
            log_p_x_t_min_1 = model_output
        else:
            log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep)

        # Adding Gumbel noise and taking the argmax over the class axis draws a class for every latent pixel.
        log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator)

        x_t_min_1 = log_p_x_t_min_1.argmax(dim=1)

        if not return_dict:
            return (x_t_min_1,)

        return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1)

    def q_posterior(self, log_p_x_0: torch.Tensor, x_t: torch.LongTensor, t) -> torch.Tensor:
        """
        Calculates the log probabilities for the predicted classes of the image at timestep `t-1`:

        ```
        p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) )
        ```

        Args:
            log_p_x_0 (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
                The log probabilities for the predicted classes of the initial latent pixels. Does not include a
                prediction for the masked class as the initial unnoised image cannot be masked.
            x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
                The classes of each latent pixel at time `t`.
            t (`torch.Long`):
                The timestep that determines which transition matrix is used.

        Returns:
            `torch.Tensor` of shape `(batch size, num classes, num latent pixels)`:
                The log probabilities for the predicted classes of the image at timestep `t-1`.
        """
        log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)

        # log q(x_t | x_0) for every non-masked candidate x_0: `num classes - 1` rows.
        log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class(
            t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True
        )

        # log q(x_t | x_{t-1}) for every candidate x_{t-1}, masked class included: `num classes` rows.
        log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class(
            t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False
        )

        # p(x_0) / q(x_t | x_0) in log space, one row per non-masked x_0 class.
        q = log_p_x_0 - log_q_x_t_given_x_0

        # Log of the sum of that ratio over the x_0 classes, kept with a singleton class axis for broadcasting.
        q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)

        # Normalize over the class axis before applying the transitions; the normalizer is added back in the
        # final sum below.
        q = q - q_log_sum_exp

        # Apply the cumulative transitions up to `t - 1`. This also appends the row for the masked class, so `q`
        # now has `num classes` rows and lines up with `log_q_t_given_x_t_min_1`.
        q = self.apply_cumulative_transitions(q, t - 1)

        # Multiply (add in log space) by q(x_t | x_{t-1}) and restore the normalizer removed above.
        log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp

        return log_p_x_t_min_1

    def log_Q_t_transitioning_to_known_class(
        self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.Tensor, cumulative: bool
    ):
        """
        Calculates the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
        latent pixel in `x_t`.

        Args:
            t (`torch.Long`):
                The timestep that determines which transition matrix is used.
            x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
                The classes of each latent pixel at time `t`.
            log_onehot_x_t (`torch.Tensor` of shape `(batch size, num classes, num latent pixels)`):
                The log one-hot vectors of `x_t`.
            cumulative (`bool`):
                If cumulative is `False`, the single step transition matrix `t-1`->`t` is used. If cumulative is
                `True`, the cumulative transition matrix `0`->`t` is used.

        Returns:
            `torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)` when `cumulative` is `True`,
            and of shape `(batch size, num classes, num latent pixels)` when it is `False`:
                Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
                transition matrix.

                When cumulative, returns `num classes - 1` rows because the initial latent pixel cannot be
                masked.

                Where:
                - `q_n` is the probability distribution for the forward process of the `n`th latent pixel.
                - C_0 is a class of a latent pixel embedding
                - C_k is the class of the masked latent pixel

                non-cumulative result (omitting logarithms):
                ```
                q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0)
                          .      .                     .
                          .               .            .
                          .                      .     .
                q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k)
                ```

                cumulative result (omitting logarithms):
                ```
                q_0_cumulative(x_t | x_0 = C_0)    ...  q_n_cumulative(x_t | x_0 = C_0)
                          .               .                          .
                          .                        .                 .
                          .                               .          .
                q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1})
                ```
        """
        if cumulative:
            a = self.log_cumprod_at[t]
            b = self.log_cumprod_bt[t]
            c = self.log_cumprod_ct[t]
        else:
            a = self.log_at[t]
            b = self.log_bt[t]
            c = self.log_ct[t]

        if not cumulative:
            # The masked-class row of the log one-hot matrix is 0 where `x_t` is masked and log(1e-30) elsewhere.
            # It is reused as the row for `x_{t-1}` being masked, i.e. a masked pixel is treated as staying
            # masked. Take it before the row is sliced off below.
            log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1)

        # Drop the masked-class row so that the remaining rows correspond to the non-masked source classes.
        log_onehot_x_t = log_onehot_x_t[:, :-1, :]

        # In probability space this is `onehot * a + b`: roughly `a + b` where the source class equals `x_t` and
        # `b` everywhere else (the off entries of the one-hot are 1e-30, not exactly zero).
        log_Q_t = (log_onehot_x_t + a).logaddexp(b)

        # Where `x_t` is the masked class, every non-masked source class gets the gamma term instead.
        mask_class_mask = x_t == self.mask_class
        mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1)
        log_Q_t[mask_class_mask] = c

        if not cumulative:
            log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1)

        return log_Q_t

    def apply_cumulative_transitions(self, q: torch.Tensor, t) -> torch.Tensor:
        """
        Applies the cumulative transition for timestep `t` to log probabilities over the non-masked classes and
        appends a row for the masked class.

        Args:
            q (`torch.Tensor` of shape `(batch size, num classes - 1, num latent pixels)`):
                Log probabilities over the non-masked classes.
            t (`torch.Long`):
                The timestep whose cumulative alpha, beta and gamma values are used.

        Returns:
            `torch.Tensor` of shape `(batch size, num classes, num latent pixels)`:
                `log(exp(q) * alpha_cum + beta_cum)` for the non-masked classes, followed by one row filled with the
                log cumulative gamma for the masked class.
        """
        bsz = q.shape[0]
        a = self.log_cumprod_at[t]
        b = self.log_cumprod_bt[t]
        c = self.log_cumprod_ct[t]

        # Broadcast the scalar log cumulative gamma into a full masked-class row.
        num_latent_pixels = q.shape[2]
        c = c.expand(bsz, 1, num_latent_pixels)

        q = (q + a).logaddexp(b)
        q = torch.cat((q, c), dim=1)

        return q
| |
|