| """ |
| Abstract SDE classes, Reverse SDE, and VE/VP SDEs. |
| |
| Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py |
| """ |
| import abc |
| import warnings |
|
|
| import numpy as np |
| from sgmse.util.tensors import batch_broadcast |
| import torch |
|
|
| from sgmse.util.registry import Registry |
|
|
|
|
# Global registry mapping SDE names (e.g. "ouve", "ouvp") to their classes,
# so SDEs can be selected by name (see the @SDERegistry.register decorators below).
SDERegistry = Registry("SDE")
|
|
|
|
class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N  # number of discretization time steps

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, t, *args):
        """Return the drift and diffusion coefficients `(drift, diffusion)` of the SDE at state `x`, time `t`."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z: latent code
        Returns:
            log probability density
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def add_argparse_args(parent_parser):
        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """
        pass

    def discretize(self, x, t, y, stepsize):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
            x: a torch tensor
            t: a torch float representing the time step (from 0 to `self.T`)
            y: conditioning signal, forwarded to `self.sde`
            stepsize: the step size dt; presumably a torch tensor, since it is passed
                to `torch.sqrt` below — TODO confirm against callers.

        Returns:
            f, G
        """
        dt = stepsize
        drift, diffusion = self.sde(x, t, y)
        f = drift * dt
        G = diffusion * torch.sqrt(dt)
        return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
            score_model: A function that takes x, t and y and returns the score.
            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
        """
        # `oself` ("outer self") keeps the forward SDE instance accessible without
        # being shadowed by the `self` of the nested RSDE methods below.
        N = oself.N
        T = oself.T
        sde_fn = oself.sde
        discretize_fn = oself.discretize

        # Subclassing the forward SDE's concrete class lets the reverse SDE inherit
        # all non-overridden methods (e.g. `marginal_prob`, `prior_sampling`).
        class RSDE(oself.__class__):
            def __init__(self):
                # NOTE: deliberately does not call super().__init__(); only the
                # captured N and the probability-flow switch are carried over.
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                # End time captured from the forward SDE.
                return T

            def sde(self, x, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                rsde_parts = self.rsde_parts(x, t, *args)
                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
                return total_drift, diffusion

            def rsde_parts(self, x, t, *args):
                """Return all intermediate quantities of the reverse drift/diffusion in a dict (useful for inspection)."""
                sde_drift, sde_diffusion = sde_fn(x, t, *args)
                score = score_model(x, t, *args)
                # Score contribution to the reverse drift: -g(t)^2 * score,
                # halved when constructing the probability flow ODE.
                score_drift = -sde_diffusion[:, None, None, None]**2 * score * (0.5 if self.probability_flow else 1.)
                # The probability flow ODE is deterministic: zero diffusion.
                diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
                total_drift = sde_drift + score_drift
                return {
                    'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,
                    'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
                }

            def discretize(self, x, t, y, stepsize):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, t, y, stepsize)
                # Same score correction as in `rsde_parts`, applied to the discretized step.
                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()

    @abc.abstractmethod
    def copy(self):
        """Return a new instance of this SDE with the same hyperparameters."""
        pass
|
|
|
|
@SDERegistry.register("ouve")
class OUVESDE(SDE):
    """Ornstein-Uhlenbeck Variance Exploding SDE (see `__init__` for the defining equations)."""

    @staticmethod
    def add_argparse_args(parser):
        """Add OUVE hyperparameters to the given argparse parser and return it."""
        # Fixed: the --sde-n help text previously claimed "30 by default" while default=1000.
        parser.add_argument("--sde-n", type=int, default=1000,
                            help="The number of timesteps in the SDE discretization. 1000 by default")
        parser.add_argument("--theta", type=float, default=1.5,
                            help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
        parser.add_argument("--sigma-min", type=float, default=0.05,
                            help="The minimum sigma to use. 0.05 by default.")
        parser.add_argument("--sigma-max", type=float, default=0.5,
                            help="The maximum sigma to use. 0.5 by default.")
        return parser

    def __init__(self, theta, sigma_min, sigma_max, N=1000, **ignored_kwargs):
        """Construct an Ornstein-Uhlenbeck Variance Exploding SDE.

        Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
        to the methods which require it (e.g., `sde` or `marginal_prob`).

            dx = theta (y-x) dt + sigma(t) dw

        with

            sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))

        (The drift sign above matches the implementation in `sde`, which pulls x towards y;
        the previous docstring stated `-theta (y-x)`, contradicting the code.)

        Args:
            theta: stiffness parameter.
            sigma_min: smallest sigma.
            sigma_max: largest sigma.
            N: number of discretization steps
        """
        super().__init__(N)  # sets self.N; no separate assignment needed
        self.theta = theta
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # Precomputed log(sigma_max / sigma_min), reused by `sde` and `_std`.
        self.logsig = np.log(self.sigma_max / self.sigma_min)

    def copy(self):
        """Return a fresh OUVESDE with identical hyperparameters."""
        return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N)

    @property
    def T(self):
        """End time of the SDE."""
        return 1

    def sde(self, x, t, y):
        """Drift and diffusion of the forward SDE at state `x`, time `t`, steady-state mean `y`."""
        drift = self.theta * (y - x)
        # Diffusion amplitude interpolates geometrically from sigma_min (t=0) to sigma_max (t=1);
        # the sqrt(2*logsig) factor matches the closed-form variance in `_std`.
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        diffusion = sigma * np.sqrt(2 * self.logsig)
        return drift, diffusion

    def _mean(self, x0, t, y):
        """Marginal mean: exponential interpolation from `x0` (t=0) towards `y` (t -> inf)."""
        theta = self.theta
        exp_interp = torch.exp(-theta * t)[:, None, None, None]
        return exp_interp * x0 + (1 - exp_interp) * y

    def alpha(self, t):
        """Decay factor exp(-theta t) of the initial state in the marginal mean."""
        return torch.exp(-self.theta * t)

    def _std(self, t):
        """Standard deviation of the marginal at time `t` (used by `marginal_prob` and `prior_sampling`)."""
        sigma_min, theta, logsig = self.sigma_min, self.theta, self.logsig
        return torch.sqrt(
            (
                sigma_min**2
                * torch.exp(-2 * theta * t)
                * (torch.exp(2 * (theta + logsig) * t) - 1)
                * logsig
            )
            /
            (theta + logsig)
        )

    def marginal_prob(self, x0, t, y):
        """Mean and standard deviation of the marginal distribution at time `t`."""
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample x_T ~ N(y, std(1)^2 I); `shape` is ignored (with a warning) if it differs from `y.shape`."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(torch.ones((y.shape[0],), device=y.device))
        x_T = y + torch.randn_like(y) * std[:, None, None, None]
        return x_T

    def prior_logp(self, z):
        """Not implemented for this SDE."""
        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
|
|
|
|
@SDERegistry.register("ouvp")
class OUVPSDE(SDE):
    """Ornstein-Uhlenbeck Variance Preserving SDE (not used in the authors' works; see `__init__`)."""

    @staticmethod
    def add_argparse_args(parser):
        """Add OUVP hyperparameters to the given argparse parser and return it."""
        parser.add_argument("--sde-n", type=int, default=1000,
            help="The number of timesteps in the SDE discretization. 1000 by default")
        parser.add_argument("--beta-min", type=float, required=True,
            help="The minimum beta to use.")
        parser.add_argument("--beta-max", type=float, required=True,
            help="The maximum beta to use.")
        parser.add_argument("--stiffness", type=float, default=1,
            help="The stiffness factor for the drift, to be multiplied by 0.5*beta(t). 1 by default.")
        return parser

    def __init__(self, beta_min, beta_max, stiffness=1, N=1000, **ignored_kwargs):
        """
        !!! We do not utilize this SDE in our works due to observed instabilities around t=0.2. !!!

        Construct an Ornstein-Uhlenbeck Variance Preserving SDE:

            dx = 1/2 * beta(t) * stiffness * (y-x) dt + sqrt(beta(t)) * dw

        with

            beta(t) = beta_min + t(beta_max - beta_min)

        (The drift sign above matches the implementation in `sde`, which pulls x towards y;
        the previous docstring stated `-1/2 * beta(t) * stiffness * (y-x)`, contradicting the code.)

        Note that the "steady-state mean" `y` is not provided at construction, but must rather be given as an argument
        to the methods which require it (e.g., `sde` or `marginal_prob`).

        Args:
            beta_min: smallest beta.
            beta_max: largest beta.
            stiffness: stiffness factor of the drift. 1 by default.
            N: number of discretization steps
        """
        super().__init__(N)  # sets self.N; no separate assignment needed
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.stiffness = stiffness

    def copy(self):
        """Return a fresh OUVPSDE with identical hyperparameters."""
        return OUVPSDE(self.beta_min, self.beta_max, self.stiffness, N=self.N)

    @property
    def T(self):
        """End time of the SDE."""
        return 1

    def _beta(self, t):
        # Linear schedule from beta_min at t=0 to beta_max at t=1.
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def sde(self, x, t, y):
        """Drift and diffusion of the forward SDE at state `x`, time `t`, steady-state mean `y`."""
        # batch_broadcast expands the per-batch beta(t) so it multiplies (y - x) elementwise.
        drift = 0.5 * self.stiffness * batch_broadcast(self._beta(t), y) * (y - x)
        diffusion = torch.sqrt(self._beta(t))
        return drift, diffusion

    def _mean(self, x0, t, y):
        """Marginal mean: exponential interpolation from `x0` towards `y`."""
        b0, b1, s = self.beta_min, self.beta_max, self.stiffness
        x0y_fac = torch.exp(-0.25 * s * t * (t * (b1-b0) + 2 * b0))[:, None, None, None]
        return y + x0y_fac * (x0 - y)

    def _std(self, t):
        """Standard deviation of the marginal at time `t` (used by `marginal_prob` and `prior_sampling`)."""
        b0, b1, s = self.beta_min, self.beta_max, self.stiffness
        return (1 - torch.exp(-0.5 * s * t * (t * (b1-b0) + 2 * b0))) / s

    def marginal_prob(self, x0, t, y):
        """Mean and standard deviation of the marginal distribution at time `t`."""
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample x_T ~ N(y, std(1)^2 I); `shape` is ignored (with a warning) if it differs from `y.shape`."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        std = self._std(torch.ones((y.shape[0],), device=y.device))
        x_T = y + torch.randn_like(y) * std[:, None, None, None]
        return x_T

    def prior_logp(self, z):
        """Not implemented for this SDE."""
        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
|
|