| from functools import partial |
| from typing import Callable, Optional, Tuple, Union |
|
|
| import numpy as np |
| import torch |
| import torch.distributions as td |
| import torch.nn as nn |
| from torchdiffeq import odeint |
| from torchsde import sdeint |
| from torchtyping import TensorType |
|
|
| from sim_priors_pk.models.diffusion.noise import GaussianProcess, Normal, OrnsteinUhlenbeck |
|
|
|
|
class ContinuousDiffusion(nn.Module):
    """
    Continuous diffusion using SDEs (https://arxiv.org/abs/2011.13456)

    Args:
        dim: Dimension of data
        beta_fn: Scheduler for noise levels; must be callable (``beta_fn(i)``)
            and expose ``beta_fn.integral(i)``
        t1: Final diffusion time
        noise_fn: Type of noise; called as ``noise_fn(*num_samples, **kwargs)``
            to draw prior samples, and ``noise_fn.covariance(**kwargs)`` is
            used when ``is_time_series`` is set
        loss_weighting: Function returning loss weights given ``(beta_fn, i)``;
            defaults to a constant weight of 1
        is_time_series: Whether noise is correlated along the second-to-last
            (sequence) axis through ``noise_fn.covariance``
        predict_gaussian_noise: Whether to approximate score with unit normal
        **kwargs: Accepted and ignored
    """

    def __init__(
        self,
        dim: int,
        beta_fn: Callable,
        t1: float = 1.0,
        noise_fn: Optional[Callable] = None,
        loss_weighting: Optional[Callable] = None,
        is_time_series: bool = False,
        predict_gaussian_noise: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.t1 = t1
        self.predict_gaussian_noise = predict_gaussian_noise
        self.is_time_series = is_time_series

        self.beta_fn = beta_fn
        self.noise = noise_fn
        # The weighting function always receives the scheduler as first argument.
        self.loss_weighting = partial(loss_weighting or (lambda beta, i: 1), beta_fn)

    def forward(
        self,
        x: TensorType[..., "dim"],
        i: TensorType[..., 1],
        _return_all: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[TensorType[..., "dim"], TensorType[..., "dim"]]:
        """
        Diffuse clean data ``x`` to diffusion time ``i``.

        Args:
            x: Clean data
            i: Diffusion time, broadcastable against ``x``
            _return_all: If true, return ``(y, noise, mean, std, cov)`` instead
                of the training pair; ``cov`` is ``None`` unless
                ``is_time_series``
            **kwargs: Forwarded to ``self.noise.covariance`` for time series

        Returns:
            ``(y, target)`` where ``y`` is the noisy sample and ``target`` is
            the white noise if ``predict_gaussian_noise`` else the (possibly
            correlated) noise that was actually added.
        """
        noise_gaussian = torch.randn_like(x)

        cov = None
        if self.is_time_series:
            # Correlate white noise with the Cholesky factor of the covariance.
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
            noise = L @ noise_gaussian
        else:
            noise = noise_gaussian

        beta_int = self.beta_fn.integral(i)

        mean = x * torch.exp(-beta_int / 2)
        # Clamp keeps the variance strictly positive as i -> 0.
        std = (1 - torch.exp(-beta_int)).clamp(1e-5).sqrt()

        y = mean + std * noise

        if _return_all:
            return y, noise, mean, std, cov

        target = noise_gaussian if self.predict_gaussian_noise else noise
        return y, target

    def get_loss(
        self,
        model: Callable,
        x: TensorType[..., "dim"],
        **kwargs,
    ) -> TensorType[..., 1]:
        """
        Weighted squared error between predicted and true noise.

        One diffusion time is drawn uniformly from ``[0, t1)`` per element of
        the first axis of ``x`` and shared across the remaining axes.

        Args:
            model: Called as ``model(x_noisy, i=i, **kwargs)``
            x: Clean data
            **kwargs: Forwarded to ``forward`` and to ``model``

        Returns:
            Unreduced loss with the shape of ``(pred_noise - noise)``.
        """
        i = torch.rand(x.shape[0], *(1,) * len(x.shape[1:])).expand_as(x[..., :1]).to(x)
        i = i * self.t1

        x_noisy, noise = self.forward(x, i, **kwargs)

        pred_noise = model(x_noisy, i=i, **kwargs)
        loss = self.loss_weighting(i) * (pred_noise - noise) ** 2

        return loss

    def _get_score(self, model, x, i, L=None, **kwargs):
        """
        Returns score: ∇_xs log p(xs)

        Args:
            model: Noise predictor, called as ``model(x, i=i, **kwargs)``
            x: Current (noisy) state
            i: Diffusion time as a Python number or a tensor; a tensor whose
                leading shape differs from ``x`` must hold a single element
            L: Optional lower Cholesky factor of the noise covariance
        """
        if isinstance(i, (int, float)):
            i = torch.tensor([float(i)]).to(x)
        if i.shape[:-1] != x.shape[:-1]:
            i = i.view(*(1,) * len(x.shape)).expand_as(x[..., :1])

        beta_int = self.beta_fn.integral(i)
        std = (1 - torch.exp(-beta_int)).clamp(1e-5).sqrt()

        noise = model(x, i=i, **kwargs)

        if L is not None:
            # With covariance Sigma = L L^T and white noise eps, the added
            # noise is L eps and the score is
            #     -Sigma^{-1} (L eps) / std = -L^{-T} eps / std.
            # If the model predicts the correlated noise L eps, first undo L
            # to recover eps; the L^{-T} solve is needed in both cases.
            if not self.predict_gaussian_noise:
                noise = torch.linalg.solve_triangular(L, noise, upper=False)
            noise = torch.linalg.solve_triangular(L.transpose(-1, -2), noise, upper=True)

        score = -noise / std
        return score

    @torch.no_grad()
    def log_prob(
        self,
        model: Callable,
        x: Union[TensorType[..., "dim"], TensorType[..., "seq_len", "dim"]],
        num_samples: int = 1,
        **kwargs,
    ) -> TensorType[..., 1]:
        """
        Estimate ``log p(x)`` by integrating the probability-flow ODE from
        time 0 to ``t1`` with a stochastic divergence estimate.

        The model is switched to training mode for the computation (as the
        original implementation did; presumably because some backends only
        support backward passes in training mode — TODO confirm) and its
        previous mode is restored afterwards.

        Args:
            model: Noise predictor
            x: Data to evaluate
            num_samples: Number of probe vectors for the divergence estimate
            **kwargs: Forwarded to the model and to ``self.noise.covariance``

        Returns:
            Log-probability with a trailing singleton axis. For time series
            the value is divided by the sequence length ``x.shape[-2]``.
        """
        was_training = getattr(model, "training", None)
        model.train()
        try:
            return self._log_prob(model, x, num_samples, **kwargs)
        finally:
            if was_training is not None:
                model.train(was_training)

    def _log_prob(self, model, x, num_samples, **kwargs):
        """Probability-flow ODE likelihood; see ``log_prob``."""
        # Probe vectors are kept on the instance as in the original code.
        self._e = torch.randn(num_samples, *x.shape).to(x)

        if self.is_time_series:
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
        else:
            cov = None
            L = None

        def drift(i, state):
            y, _ = state
            # Gradients are needed for the divergence even under no_grad().
            with torch.set_grad_enabled(True):
                y = y.requires_grad_(True)
                score = self._get_score(model, y, i=i, L=L, **kwargs)
                if self.is_time_series:
                    # The drift uses Sigma @ score for correlated noise.
                    score = cov @ score
                dy = -0.5 * self.beta_fn(i) * (y + score)
                divergence = divergence_approx(dy, y, self._e, num_samples=num_samples)
            return dy, -divergence

        interval = torch.tensor([0.0, float(self.t1)]).to(x)

        # Second state component accumulates the negative divergence.
        states = odeint(
            drift,
            (x, torch.zeros_like(x)),
            interval,
            method="rk4",
            options={"step_size": 0.01},
        )
        y, div = states[0][-1], states[1][-1]

        if self.is_time_series:
            # One multivariate normal over the sequence axis per data channel.
            p0 = td.Independent(
                td.MultivariateNormal(
                    torch.zeros_like(y).transpose(-1, -2),
                    cov.unsqueeze(-3).repeat_interleave(self.dim, dim=-3),
                ),
                1,
            )
            log_prob = p0.log_prob(y.transpose(-1, -2)) - div.sum([-1, -2])
            log_prob = log_prob / x.shape[-2]
        else:
            p0 = td.Independent(td.Normal(torch.zeros_like(y), torch.ones_like(y)), 1)
            log_prob = p0.log_prob(y) - div.sum(-1)

        return log_prob.unsqueeze(-1)

    @torch.no_grad()
    def sample(
        self,
        model: Callable,
        num_samples: int,
        device: str = None,
        use_ode: bool = True,
        **kwargs,
    ) -> TensorType["num_samples", "dim"]:
        """
        Draw samples by solving the reverse process from ``t1`` to 0.

        Args:
            model: Noise predictor
            num_samples: Number of samples, or a tuple of leading sizes that
                is unpacked into ``self.noise(*num_samples, **kwargs)``
            device: Device the prior noise is moved to
            use_ode: Use the probability-flow ODE if true, else the reverse SDE
            **kwargs: Forwarded to the noise process and the model
        """
        if isinstance(num_samples, int):
            num_samples = (num_samples,)

        sampler = self.ode_sample if use_ode else self.sde_sample
        return sampler(model, num_samples, device, **kwargs)

    @torch.no_grad()
    def ode_sample(
        self,
        model: Callable,
        num_samples: int,
        device: str = None,
        **kwargs,
    ) -> TensorType["num_samples", "dim"]:
        """
        Sample with the probability-flow ODE (fixed-step RK4, step 0.01).

        Arguments are the same as for ``sample``.
        """
        if isinstance(num_samples, int):
            num_samples = (num_samples,)

        if self.is_time_series:
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
        else:
            cov = None
            L = None

        def drift(i, y):
            score = self._get_score(model, y, i=i, L=L, **kwargs)
            if self.is_time_series:
                # The drift uses Sigma @ score for correlated noise.
                score = cov @ score
            return -0.5 * self.beta_fn(i) * (y + score)

        x = self.noise(*num_samples, **kwargs).to(device)
        # Integrate backwards in time; index 1 is the solution at time 0.
        t = torch.tensor([float(self.t1), 0.0], device=x.device)
        y = odeint(drift, x, t, method="rk4", options={"step_size": 0.01})[1]

        return y

    @staticmethod
    def _flatten_sde_state(x, is_time_series):
        """
        Reshape a sample into the 2-D ``(batch, state)`` layout used by sdeint.

        Time series ``(..., seq_len, dim)`` become ``(prod(...) * dim, seq_len)``
        so that the sequence axis is the state axis; other data
        ``(..., dim)`` become ``(prod(...), dim)``.
        """
        if is_time_series:
            return x.transpose(-2, -1).flatten(0, -2)
        return x.reshape(-1, x.shape[-1])

    @staticmethod
    def _unflatten_sde_state(state, shape, is_time_series):
        """Exact inverse of ``_flatten_sde_state`` for a sample of ``shape``."""
        if is_time_series:
            restored = state.reshape(*shape[:-2], shape[-1], shape[-2])
            return restored.transpose(-2, -1).contiguous()
        return state.reshape(*shape)

    @torch.no_grad()
    def sde_sample(
        self,
        model: Callable,
        num_samples: int,
        device: str = None,
        **kwargs,
    ) -> TensorType["num_samples", "dim"]:
        """
        Sample with the reverse-time SDE, using 100 solver steps over ``t1``.

        The solver runs forward on the interval ``[-t1, 0]``, so drift and
        diffusion negate the solver time to obtain the diffusion time.

        Arguments are the same as for ``sample``.
        """
        if isinstance(num_samples, int):
            num_samples = (num_samples,)

        if self.is_time_series:
            cov = self.noise.covariance(**kwargs)
            L = torch.linalg.cholesky(cov)
        else:
            cov = None
            L = None

        is_time_series = self.is_time_series

        x = self.noise(*num_samples, **kwargs).to(device)
        shape = x.shape

        # Flatten and unflatten must be exact inverses: the original code
        # restored the state with a plain ``view(*shape)``, which scrambles
        # time-series data whenever dim > 1.
        flatten = partial(self._flatten_sde_state, is_time_series=is_time_series)
        unflatten = partial(
            self._unflatten_sde_state, shape=shape, is_time_series=is_time_series
        )

        class SDE(nn.Module):
            # "general": g returns one (seq_len x seq_len) matrix per row;
            # "diagonal": g returns one scale per state entry.
            noise_type = "general" if is_time_series else "diagonal"
            sde_type = "ito"

            def __init__(self, beta_fn, _get_score):
                super().__init__()
                self.beta_fn = beta_fn
                self._get_score = _get_score

            def f(self, i, inp):
                i = -i
                inp = unflatten(inp)

                score = self._get_score(model, inp, i=i, L=L, **kwargs)
                if is_time_series:
                    score = cov @ score

                dx = self.beta_fn(i) * (0.5 * inp + score)
                return flatten(dx)

            def g(self, i, inp):
                i = -i
                beta = -self.beta_fn(i).sqrt()

                if is_time_series:
                    # Row b * dim + d of the state belongs to batch element b.
                    return (beta * L).repeat_interleave(shape[-1], dim=0)
                return beta.view(1, 1).repeat(*inp.shape).to(inp.device)

        sde = SDE(self.beta_fn, self._get_score)
        state = flatten(x)
        interval = torch.tensor([-float(self.t1), 0.0], device=state.device)

        step_size = self.t1 / 100
        y = sdeint(sde, state, interval, dt=step_size)[-1]

        return unflatten(y)
|
|
|
|
class ContinuousGaussianDiffusion(ContinuousDiffusion):
    """Continuous diffusion using Gaussian noise.

    The ``predict_gaussian_noise`` argument is accepted but its value is
    ignored: the base class is always configured with
    ``predict_gaussian_noise=True``.
    """

    def __init__(self, dim: int, beta_fn: Callable, predict_gaussian_noise=None, **kwargs):
        gaussian_noise = Normal(dim)
        super().__init__(
            dim,
            beta_fn,
            noise_fn=gaussian_noise,
            predict_gaussian_noise=True,
            **kwargs,
        )
|
|
|
|
class ContinuousOUDiffusion(ContinuousDiffusion):
    """Continuous diffusion using noise coming from an OU process.

    Args:
        dim: Dimension of data
        beta_fn: Scheduler for noise levels
        predict_gaussian_noise: Whether the model predicts white noise rather
            than the correlated noise
        theta: Passed to ``OrnsteinUhlenbeck`` as its ``theta`` parameter
        **kwargs: Forwarded to ``ContinuousDiffusion``
    """

    def __init__(
        self,
        dim: int,
        beta_fn: Callable,
        predict_gaussian_noise: bool = False,
        theta: float = 0.5,
        **kwargs,
    ):
        ou_process = OrnsteinUhlenbeck(dim, theta=theta)
        # The base class is always put in time-series mode for this noise.
        super().__init__(
            dim=dim,
            beta_fn=beta_fn,
            noise_fn=ou_process,
            is_time_series=True,
            predict_gaussian_noise=predict_gaussian_noise,
            **kwargs,
        )
|
|
|
|
class ContinuousGPDiffusion(ContinuousDiffusion):
    """Continuous diffusion using noise coming from a Gaussian process.

    Args:
        dim: Dimension of data
        beta_fn: Scheduler for noise levels
        predict_gaussian_noise: Whether the model predicts white noise rather
            than the correlated noise
        sigma: Passed to ``GaussianProcess`` as its ``sigma`` parameter
        **kwargs: Forwarded to ``ContinuousDiffusion``
    """

    def __init__(
        self,
        dim: int,
        beta_fn: Callable,
        predict_gaussian_noise: bool = False,
        sigma: float = 0.1,
        **kwargs,
    ):
        gp_noise = GaussianProcess(dim, sigma=sigma)
        # The base class is always put in time-series mode for this noise.
        super().__init__(
            dim=dim,
            beta_fn=beta_fn,
            noise_fn=gp_noise,
            is_time_series=True,
            predict_gaussian_noise=predict_gaussian_noise,
            **kwargs,
        )
|
|
|
|
def divergence_approx(output, input, e, num_samples=1):
    """
    Average of ``(e_k^T d(output)/d(input)) * e_k`` over probe vectors.

    Summing the result over the data axes gives a stochastic estimate of the
    trace of the Jacobian of ``output`` with respect to ``input``.

    Args:
        output: Tensor computed from ``input`` with gradients enabled
        input: Tensor that requires grad
        e: Probe vectors; ``e[k]`` must have the shape of ``output``
        num_samples: How many of the leading entries of ``e`` to use

    Returns:
        Detached tensor with the shape of ``input``.
    """
    out = 0
    for k in range(num_samples):
        # retain_graph keeps the graph alive for the next probe vector; the
        # result is detached, so no second-order graph (create_graph) is
        # needed and the values are unchanged.
        (vjp,) = torch.autograd.grad(output, input, e[k], retain_graph=True)
        out += vjp.detach() * e[k]
    return out / num_samples
|
|