from typing import Any, Callable, Tuple, Union

import torch
import torch.distributions as td
import torch.nn as nn
from torchtyping import TensorType

from sim_priors_pk.models.diffusion.noise import GaussianProcess, Normal, OrnsteinUhlenbeck


class DiscreteDiffusion(nn.Module):
    """
    Discrete-time denoising diffusion (https://arxiv.org/abs/2006.11239).

    The forward process maps clean data ``x`` at step ``i`` to
    ``sqrt(alpha_i) * x + sqrt(1 - alpha_i) * noise`` where ``alpha_i`` is the
    cumulative product of ``1 - beta`` up to and including step ``i``.

    Args:
        dim: Dimension of data (stored on the module; not otherwise used here).
        num_steps: Number of diffusion steps.
        beta_fn: Noise-level scheduler; called once on
            ``torch.linspace(0, 1, num_steps)`` to produce the per-step betas.
        noise_fn: Noise source. Called as ``noise_fn(*shape, **kwargs)`` to draw
            samples and, when ``is_time_series`` is set, as
            ``noise_fn.covariance(**kwargs)`` to obtain the noise covariance.
        parallel_elbo: If True, ``log_prob`` evaluates all diffusion steps in one
            batched pass (memory-heavy); otherwise it loops over steps.
        is_time_series: If True, forward noise is correlated by multiplying white
            Gaussian noise with the Cholesky factor of ``noise_fn.covariance``
            (contracting the second-to-last axis of the data).
        predict_gaussian_noise: If True, the regression target is the white
            Gaussian draw, and model outputs are mapped through the Cholesky factor
            during sampling / ELBO evaluation. If False, the target is the noise
            actually added (correlated for time series).
        **kwargs: Ignored.
    """

    def __init__(
        self,
        dim: int,
        num_steps: int,
        beta_fn: Callable,
        noise_fn: Callable,
        parallel_elbo: bool = False,
        is_time_series: bool = False,
        predict_gaussian_noise: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.num_steps = num_steps
        self.parallel_elbo = parallel_elbo
        self.is_time_series = is_time_series
        self.predict_gaussian_noise = predict_gaussian_noise

        betas = beta_fn(torch.linspace(0, 1, num_steps))
        # alphas[i] = prod_{j <= i} (1 - betas[j]): fraction of signal variance kept.
        alphas = torch.cumprod(1 - betas, dim=0)
        # Buffers move with the module on .to(device) and are saved in state_dict.
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)

        self.noise = noise_fn

    def forward(
        self,
        x: TensorType[..., "dim"],  # noqa: F821
        i: TensorType[..., 1],
        **kwargs,
    ) -> Tuple[TensorType[..., "dim"], TensorType[..., "dim"]]:  # noqa: F821
        """
        Diffuse clean data ``x`` to step ``i`` in closed form.

        Args:
            x: Clean data.
            i: Diffusion step per element (cast to long to index ``self.alphas``);
                must broadcast against ``x``.
            **kwargs: Forwarded to ``self.noise.covariance`` for time series.

        Returns:
            ``(y, target)``. ``y = sqrt(alpha_i) * x + sqrt(1 - alpha_i) * noise``.
            ``target`` is the white Gaussian draw if ``predict_gaussian_noise`` is
            set, otherwise the (possibly correlated) ``noise`` that was added.
        """
        noise_gaussian = torch.randn_like(x)
        if self.is_time_series:
            # Colour white noise with the Cholesky factor: Cov(L z) = L L^T = cov.
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
            noise = L @ noise_gaussian
        else:
            noise = noise_gaussian

        alpha = self.alphas[i.long()].to(x)
        y = torch.sqrt(alpha) * x + torch.sqrt(1 - alpha) * noise

        if self.predict_gaussian_noise:
            return y, noise_gaussian
        else:
            return y, noise

    def get_loss(
        self,
        model: Callable,
        x: TensorType[..., "dim"],  # noqa: F821
        **kwargs,
    ) -> TensorType[..., "dim"]:  # noqa: F821
        """
        Per-element denoising loss at a uniformly sampled diffusion step.

        One step index is drawn per entry of the leading (batch) axis of ``x`` and
        broadcast over the remaining non-feature axes.

        Args:
            model: Called as ``model(x_noisy, i=i, **kwargs)``; returns a noise estimate.
            x: Clean data; ``x.shape[0]`` is treated as the batch axis.
            **kwargs: Forwarded to ``forward`` and ``model``.

        Returns:
            Unreduced squared error ``(pred_noise - target) ** 2``; the caller reduces.
        """
        i = torch.randint(0, self.num_steps, size=(x.shape[0],))
        # (B,) -> (B, 1, ..., 1) -> broadcast to x[..., :1]; .to(x) also moves it to x's device.
        i = i.view(-1, *(1,) * len(x.shape[1:])).expand_as(x[..., :1]).to(x)

        x_noisy, noise = self.forward(x, i, **kwargs)
        pred_noise = model(x_noisy, i=i, **kwargs)

        loss = (pred_noise - noise) ** 2
        return loss

    @torch.no_grad()
    def sample(
        self,
        model: Callable,
        num_samples: Union[int, Tuple],
        device: str = "cpu",
        **kwargs,
    ) -> TensorType["*num_samples", "dim"]:  # noqa: F821
        """
        Draw samples by ancestral (reverse) diffusion from step ``num_steps - 1`` to 0.

        Args:
            model: Noise-prediction network, called as ``model(x, i=i, **kwargs)``.
            num_samples: Leading shape passed as ``self.noise(*num_samples, **kwargs)``;
                an int is treated as a 1-tuple.
            device: Device the noise draws are moved to.
            **kwargs: Forwarded to the noise source and to the model.

        Returns:
            The tensor obtained after the final (step 0) update.
        """
        if isinstance(num_samples, int):
            num_samples = (num_samples,)

        x = self.noise(*num_samples, **kwargs).to(device)

        # The model predicts white noise here, so map its output to correlated noise.
        if self.is_time_series and self.predict_gaussian_noise:
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
        else:
            L = None

        for diff_step in reversed(range(0, self.num_steps)):
            alpha = self.alphas[diff_step]
            beta = self.betas[diff_step]

            # Reverse-step variance sigma^2 = beta. An alternative is the posterior variance:
            # alpha_prev = self.alphas[diff_step - 1]
            # sigma = beta * (1 - alpha_prev) / (1 - alpha)
            sigma = beta

            # No fresh noise is injected on the last (step 0) update.
            if diff_step == 0:
                z = 0
            else:
                z = self.noise(*num_samples, **kwargs).to(device)

            i = torch.Tensor([diff_step]).expand_as(x[..., :1]).to(device)
            pred_noise = model(x, i=i, **kwargs)
            if L is not None:
                pred_noise = L @ pred_noise

            mean = (x - beta * pred_noise / (1 - alpha).sqrt()) / (1 - beta).sqrt()
            x = mean + sigma.sqrt() * z

        return x

    @torch.no_grad()
    def log_prob(
        self,
        model: Callable,
        x: TensorType[..., "dim"],  # noqa: F821
        num_samples: int = 1,
        **kwargs,
    ) -> TensorType[..., 1]:
        """
        ELBO estimate of log p(x), dispatching to the parallel or sequential variant.

        Args:
            model: Noise-prediction network.
            x: Clean data; the leading axis is treated as the batch axis.
            num_samples: Number of Monte Carlo ELBO samples to average.
            **kwargs: Forwarded to the noise covariance and to the model.

        Returns:
            A ``(x.shape[0], 1)`` tensor (see ``reduce_elbo``).
        """
        if self.is_time_series and self.predict_gaussian_noise:
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
        else:
            L = None

        func = self._elbo_parallel if self.parallel_elbo else self._elbo_sequential
        return func(model, x, num_samples=num_samples, L=L, **kwargs)

    def _elbo_parallel(
        self,
        model: Callable,
        x: TensorType[..., "dim"],  # noqa: F821
        L: TensorType[..., "seq_len", "seq_len"],  # noqa: F821
        num_samples: int = 1,
        **kwargs,
    ) -> TensorType[..., 1]:
        """
        Computes ELBO over all diffusion steps in parallel, then averages over
        `num_samples` runs. If diffusion `num_steps` is large (and `num_samples` small)
        it will be heavy on the GPU memory.

        Args:
            model: Denoising diffusion model
            x: Clean input data
            L: Cholesky factor applied to model outputs, or None
            num_samples: How many times to compute ELBO, final result is averaged
                over all ELBO samples
            **kwargs: Can be time, latent etc. depending on a model
        """
        elbo = 0

        # Step indices shaped (num_steps, *x.shape[:-1], 1); schedules shaped
        # (num_steps, 1, ..., 1) so they broadcast over x.
        i = expand_to_x(torch.arange(self.num_steps), x).expand(-1, *x[..., :1].shape).contiguous()
        alphas = expand_to_x(self.alphas, x)
        betas = expand_to_x(self.betas, x)

        # Only the kwargs are tiled over steps; x broadcasts against `i` inside forward.
        _, kwargs = expand_x_and_kwargs(x, kwargs, self.num_steps)

        for _ in range(num_samples):
            # NOTE(review): forward draws randn_like(x) on the un-tiled x, so all steps
            # share one noise draw per sample; the per-step terms are correlated.
            xt, _ = self.forward(x, i, **kwargs)  # [num_steps, ..., dim]

            # Output predicted noise
            epsilon = model(xt, i=i, **kwargs)
            if L is not None:
                epsilon = L @ epsilon

            # p(x_{t-1} | x_t)
            p_mu = get_p_mu(xt, betas, alphas, epsilon)
            px = td.Independent(td.Normal(p_mu[1:], betas[1:].sqrt()), 1)

            # p(x_0 | x_1)
            log_prob_x0_x1 = td.Independent(td.Normal(p_mu[0], betas[0].sqrt()), 1).log_prob(x)
            assert log_prob_x0_x1.shape == x.shape[:-1]

            # q(x_{t-1} | x_0, x_t), t > 1
            qx = get_qx(x.unsqueeze(0), xt[1:], alphas[1:], alphas[:-1], betas[1:])

            # KL[q(x_{t-1} | x_0, x_t) || p(x_{t-1} | x_t)], summed over steps
            kl_q_p = td.kl_divergence(qx, px).sum(0)
            assert kl_q_p.shape == x.shape[:-1]

            # ELBO
            elbo_contribution = (log_prob_x0_x1 - kl_q_p) / num_samples
            elbo += elbo_contribution

        elbo = reduce_elbo(elbo, x)
        return elbo

    def _elbo_sequential(
        self,
        model: Callable,
        x: TensorType[..., "dim"],  # noqa: F821
        L: TensorType[..., "seq_len", "seq_len"],  # noqa: F821
        num_samples: int = 1,
        **kwargs,
    ) -> TensorType[..., 1]:
        """
        Computes ELBO as a sum of diffusion steps - sequentially.

        Args:
            model: Denoising diffusion model
            x: Clean input data
            L: Cholesky factor applied to model outputs, or None
            num_samples: How many times to compute ELBO, final result is averaged
                over all ELBO samples
            **kwargs: Can be time, latent etc. depending on a model
        """
        elbo = 0
        # Tile data/kwargs along a new leading sample axis. The caller's `x` is kept
        # for the final reduction: reduce_elbo picks its normalisation from x's rank,
        # and the tiled tensor has one extra axis (previously this divided the ELBO
        # of plain (batch, dim) data by the batch size, unlike the parallel path).
        x_rep, kwargs = expand_x_and_kwargs(x, kwargs, num_samples)

        for i in range(self.num_steps):
            # Prepare variables
            beta = self.betas[i].to(x_rep)
            alpha = self.alphas[i].to(x_rep)
            step = torch.Tensor([i]).expand_as(x_rep[..., :1]).to(x_rep)

            # Diffuse and predict noise
            xt, _ = self.forward(x_rep, i=step, **kwargs)
            epsilon = model(xt, i=step, **kwargs)
            if L is not None:
                epsilon = L @ epsilon
            assert xt.shape == x_rep.shape == epsilon.shape

            # p(x_{t-1} | x_t)
            p_mu = get_p_mu(xt, beta, alpha, epsilon)
            px = td.Independent(td.Normal(p_mu, beta.sqrt()), 1)

            if i == 0:
                # Reconstruction term log p(x_0 | x_1), averaged over samples.
                elbo = elbo + px.log_prob(x_rep).mean(0)
            else:
                prev_alpha = self.alphas[i - 1].to(x_rep)

                # q(x_{t-1} | x_0, x_t), t > 1
                qx = get_qx(x_rep, xt, alpha, prev_alpha, beta)

                # KL[q(x_{t-1} | x_0, x_t) || p(x_{t-1} | x_t)], averaged over samples
                kl = td.kl_divergence(qx, px).mean(0)
                elbo = elbo - kl

        elbo = reduce_elbo(elbo, x)
        return elbo


class GaussianDiffusion(DiscreteDiffusion):
    """Discrete diffusion with Gaussian noise (``Normal(dim)`` as the noise source)."""

    def __init__(self, dim: int, num_steps: int, beta_fn: Callable, **kwargs):
        super().__init__(dim, num_steps, beta_fn, noise_fn=Normal(dim), **kwargs)


class OUDiffusion(DiscreteDiffusion):
    """
    Discrete diffusion with noise coming from an OU process.

    Uses ``OrnsteinUhlenbeck(dim, theta=theta)`` as the noise source and always
    runs in time-series mode.
    """

    def __init__(
        self,
        dim: int,
        num_steps: int,
        beta_fn: Callable,
        predict_gaussian_noise: bool,
        theta: float = 0.5,
        **kwargs,
    ):
        super().__init__(
            dim=dim,
            num_steps=num_steps,
            beta_fn=beta_fn,
            noise_fn=OrnsteinUhlenbeck(dim, theta=theta),
            is_time_series=True,
            predict_gaussian_noise=predict_gaussian_noise,
            **kwargs,
        )


class GPDiffusion(DiscreteDiffusion):
    """
    Discrete diffusion with noise coming from a Gaussian process.

    Uses ``GaussianProcess(dim, sigma=sigma)`` as the noise source and always runs
    in time-series mode.
    """

    def __init__(
        self,
        dim: int,
        num_steps: int,
        beta_fn: Callable,
        predict_gaussian_noise: bool,
        sigma: float = 0.1,
        **kwargs,
    ):
        super().__init__(
            dim=dim,
            num_steps=num_steps,
            beta_fn=beta_fn,
            noise_fn=GaussianProcess(dim, sigma=sigma),
            is_time_series=True,
            predict_gaussian_noise=predict_gaussian_noise,
            **kwargs,
        )


def expand_to_x(inputs, x):
    """Reshape 1-D ``inputs`` to ``(len, 1, ..., 1)`` with ``x.ndim`` trailing ones,
    cast to ``x``'s dtype and device."""
    return inputs.view(-1, *(1,) * len(x.shape)).to(x)


def expand_x_and_kwargs(x, kwargs, N):
    """
    Tile ``x`` and every tensor in ``kwargs`` ``N`` times along a new leading axis.

    NOTE: ``kwargs`` is modified in place (and also returned); non-tensor values
    are left untouched.

    Returns:
        ``(x_tiled, kwargs)`` where ``x_tiled.shape == (N, *x.shape)``.
    """
    # Expand dimensions
    x = x.unsqueeze(0).repeat_interleave(N, dim=0)

    # A hacky solution to repeat dimensions in all kwargs (latent, t, etc.)
    for key, value in kwargs.items():
        if torch.is_tensor(value):
            kwargs[key] = value.unsqueeze(0).repeat_interleave(N, dim=0)

    return x, kwargs


def reduce_elbo(
    elbo: TensorType["batch", Any],  # noqa: F821
    x: TensorType[Any],
) -> TensorType["batch", 1]:  # noqa: F821
    """
    Sum the ELBO over all but the leading (batch) axis: ``(B, ...) -> (B, 1)``.

    If ``x`` has more than two axes, the sum is divided by ``x.shape[-2]``
    (presumably the sequence length for ``(batch, seq_len, dim)`` data — confirm
    with callers). ``x`` must therefore be the un-tiled input data.
    """
    # reshape (not view) also accepts non-contiguous inputs.
    elbo = elbo.reshape(elbo.shape[0], -1).sum(1)
    if len(x.shape) > 2:
        elbo = elbo / x.shape[-2]
    return elbo.unsqueeze(1)


def get_p_mu(xt, beta, alpha, epsilon):
    """Mean of p(x_{t-1} | x_t) from a noise estimate:
    ``(x_t - beta / sqrt(1 - alpha) * eps) / sqrt(1 - beta)``."""
    mu = 1 / (1 - beta).sqrt() * (xt - beta / (1 - alpha).sqrt() * epsilon)
    return mu


def get_qx(x, xt, alpha, prev_alpha, beta):
    """
    Gaussian posterior q(x_{t-1} | x_0, x_t) with the last axis as the event axis.

    mean = sqrt(prev_alpha) * beta / (1 - alpha) * x_0
           + sqrt(1 - beta) * (1 - prev_alpha) / (1 - alpha) * x_t
    var  = beta * (1 - prev_alpha) / (1 - alpha)
    """
    q_mu_1 = torch.sqrt(prev_alpha) * beta / (1 - alpha) * x
    q_mu_2 = torch.sqrt(1 - beta) * (1 - prev_alpha) / (1 - alpha) * xt
    q_mu = q_mu_1 + q_mu_2

    q_sigma = beta * (1 - prev_alpha) / (1 - alpha)

    qx = td.Independent(td.Normal(q_mu, q_sigma.expand_as(q_mu).sqrt()), 1)
    return qx