trenden committed on
Commit
23bac1d
·
verified ·
1 Parent(s): 84ae99b

Upload sgmse/sdes.py

Browse files
Files changed (1) hide show
  1. sgmse/sdes.py +313 -0
sgmse/sdes.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
3
+
4
+ Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
5
+ """
6
+ import abc
7
+ import warnings
8
+
9
+ import numpy as np
10
+ from sgmse.util.tensors import batch_broadcast
11
+ import torch
12
+
13
+ from sgmse.util.registry import Registry
14
+
15
+
16
+ SDERegistry = Registry("SDE")
17
+
18
+
19
class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, y, t, *args):
        """Return the drift and diffusion coefficients of the SDE at state `x`, condition `y`, time `t`."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, y, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z: latent code
        Returns:
            log probability density
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def add_argparse_args(parent_parser):
        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """
        pass

    def discretize(self, x, y, t, stepsize):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
            x: a torch tensor
            t: a torch float representing the time step (from 0 to `self.T`)
            stepsize: discretization step size dt

        Returns:
            f, G
        """
        dt = stepsize
        drift, diffusion = self.sde(x, y, t)
        f = drift * dt
        # NOTE(review): torch.sqrt expects a tensor — assumes `stepsize` is passed
        # as a tensor by callers; confirm at call sites.
        G = diffusion * torch.sqrt(dt)
        return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
            score_model: A function that takes x, t and y and returns the score.
            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
        """
        # `oself` ("outer self") lets the nested class below capture the forward
        # SDE's attributes while RSDE methods keep their own `self`.
        N = oself.N
        T = oself.T
        sde_fn = oself.sde
        discretize_fn = oself.discretize

        # Build the class for reverse-time SDE.
        class RSDE(oself.__class__):
            def __init__(self):
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def sde(self, x, y, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                rsde_parts = self.rsde_parts(x, y, t, *args)
                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
                return total_drift, diffusion

            def rsde_parts(self, x, y, t, *args):
                # Reverse drift = forward drift - g(t)^2 * score,
                # with the score term halved for the probability-flow ODE.
                sde_drift, sde_diffusion = sde_fn(x, y, t, *args)
                score = score_model(x, y, t, *args)
                score_drift = -sde_diffusion[:, None, None, None]**2 * score * (0.5 if self.probability_flow else 1.)
                diffusion = torch.zeros_like(sde_diffusion) if self.probability_flow else sde_diffusion
                total_drift = sde_drift + score_drift
                return {
                    'total_drift': total_drift, 'diffusion': diffusion, 'sde_drift': sde_drift,
                    'sde_diffusion': sde_diffusion, 'score_drift': score_drift, 'score': score,
                }

            def discretize(self, x, y, t, stepsize):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, y, t, stepsize)
                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, y, t) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()

    @abc.abstractmethod
    def copy(self):
        """Return a new instance of this SDE with the same configuration."""
        pass
142
+
143
+
144
@SDERegistry.register("ouve")
class OUVESDE(SDE):
    """Ornstein-Uhlenbeck Variance Exploding SDE: dx = -theta (y-x) dt + sigma(t) dw."""

    @staticmethod
    def add_argparse_args(parser):
        """Register this SDE's hyperparameters on `parser` and return it."""
        parser.add_argument("--theta", type=float, default=1.5, help="The constant stiffness of the Ornstein-Uhlenbeck process. 1.5 by default.")
        parser.add_argument("--sigma-min", type=float, default=0.05, help="The minimum sigma to use. 0.05 by default.")
        parser.add_argument("--sigma-max", type=float, default=0.5, help="The maximum sigma to use. 0.5 by default.")
        parser.add_argument("--N", type=int, default=30, help="The number of timesteps in the SDE discretization. 30 by default")
        parser.add_argument("--sampler_type", type=str, default="pc", help="Type of sampler to use. 'pc' by default.")
        return parser

    def __init__(self, theta, sigma_min, sigma_max, N=30, sampler_type="pc", **ignored_kwargs):
        """Construct an Ornstein-Uhlenbeck Variance Exploding SDE.

        The "steady-state mean" `y` is not fixed at construction time; it is
        passed to the methods that need it (e.g. `sde`, `marginal_prob`):

            dx = -theta (y-x) dt + sigma(t) dw

        where

            sigma(t) = sigma_min (sigma_max/sigma_min)^t * sqrt(2 log(sigma_max/sigma_min))

        Args:
            theta: stiffness parameter.
            sigma_min: smallest sigma.
            sigma_max: largest sigma.
            N: number of discretization steps.
            sampler_type: identifier of the sampler to use with this SDE.
        """
        super().__init__(N)
        self.theta = theta
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # Cache log(sigma_max / sigma_min); it appears in both sde() and _std().
        self.logsig = np.log(self.sigma_max / self.sigma_min)
        self.N = N
        self.sampler_type = sampler_type

    def copy(self):
        """Return a fresh OUVESDE with the same configuration."""
        return OUVESDE(self.theta, self.sigma_min, self.sigma_max, N=self.N, sampler_type=self.sampler_type)

    @property
    def T(self):
        return 1

    def sde(self, x, y, t):
        # Drift pulls x toward the conditioning signal y at rate theta.
        drift = self.theta * (y - x)
        # The sqrt(2*logsig) factor cancels logsig out of the perturbation-kernel
        # standard deviation; this follows from solving the integral of
        # exp(2s) * g(s)^2 over s in [0, t] with g(t) = sigma(t) as below.
        sigma_t = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        return drift, sigma_t * np.sqrt(2 * self.logsig)

    def _mean(self, x0, y, t):
        """Mean of the perturbation kernel: exponential interpolation from x0 to y."""
        decay = torch.exp(-self.theta * t)[:, None, None, None]
        return decay * x0 + (1 - decay) * y

    def alpha(self, t):
        """Decay factor exp(-theta * t) of the initial state's contribution."""
        return torch.exp(-self.theta * t)

    def _std(self, t):
        """Standard deviation of the perturbation kernel at time `t`.

        Closed-form solution of the variance ODE after choosing g(s) as in
        `self.sde()`. (The two `torch.exp(... * t)` terms could in principle be
        replaced by cached bases raised to `t`.)
        """
        smin, th, ls = self.sigma_min, self.theta, self.logsig
        growth = torch.exp(2 * (th + ls) * t) - 1
        variance = smin**2 * torch.exp(-2 * th * t) * growth * ls / (th + ls)
        return torch.sqrt(variance)

    def marginal_prob(self, x0, y, t):
        """Mean and std of the Gaussian perturbation kernel p_t(x | x0, y)."""
        return self._mean(x0, y, t), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample x_T ~ N(y, std(T)^2), i.e. the prior centered on y."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        t_final = torch.ones((y.shape[0],), device=y.device)
        noise_scale = self._std(t_final)[:, None, None, None]
        return y + torch.randn_like(y) * noise_scale

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for OU SDE not yet implemented!")
233
+
234
+
235
@SDERegistry.register("sbve")
class SBVESDE(SDE):
    """Schrodinger Bridge with Variance Exploding diffusion coefficient."""

    @staticmethod
    def add_argparse_args(parser):
        """Register this SDE's hyperparameters on `parser` and return it."""
        parser.add_argument("--N", type=int, default=50, help="The number of timesteps in the SDE discretization. 50 by default")
        parser.add_argument("--k", type=float, default=2.6, help="Parameter of the diffusion coefficient. 2.6 by default.")
        parser.add_argument("--c", type=float, default=0.4, help="Parameter of the diffusion coefficient. 0.4 by default.")
        parser.add_argument("--eps", type=float, default=1e-8, help="Small constant to avoid numerical instability. 1e-8 by default.")
        parser.add_argument("--sampler_type", type=str, default="ode")
        return parser

    def __init__(self, k, c, N=50, eps=1e-8, sampler_type="ode", **ignored_kwargs):
        """Construct a Schrodinger Bridge with Variance Exploding SDE.

        As described in Jukić et al., „Schrödinger Bridge for Generative Speech Enhancement“, 2024.

        Args:
            k: stiffness parameter.
            c: diffusion parameter.
            N: number of discretization steps.
            eps: small constant guarding divisions/square roots against instability.
            sampler_type: identifier of the sampler to use with this SDE.
        """
        super().__init__(N)
        self.k = k
        self.c = c
        self.N = N
        self.eps = eps
        self.sampler_type = sampler_type

    def copy(self):
        """Return a fresh SBVESDE with the same configuration.

        Fix: previously `eps` and `sampler_type` were dropped, so copies
        silently reverted to the defaults. Preserve the full configuration,
        consistent with `OUVESDE.copy()`.
        """
        return SBVESDE(self.k, self.c, N=self.N, eps=self.eps, sampler_type=self.sampler_type)

    @property
    def T(self):
        return 1

    def sde(self, x, y, t):
        """Drift f and diffusion g of the forward SDE (Table 1)."""
        f = 0.0  # Table 1
        g = torch.sqrt(torch.tensor(self.c)) * self.k**(t)  # Table 1
        return f, g

    def _sigmas_alphas(self, t):
        """Compute the sigma/alpha schedule quantities at time `t` (Table 1 and below Eq. (9))."""
        alpha_t = torch.ones_like(t)
        alpha_T = torch.ones_like(t)
        sigma_t = torch.sqrt((self.c*(self.k**(2*t)-1.0)) \
            / (2*torch.log(torch.tensor(self.k))))  # Table 1
        sigma_T = torch.sqrt((self.c*(self.k**(2*self.T)-1.0)) \
            / (2*torch.log(torch.tensor(self.k))))  # Table 1

        alpha_bart = alpha_t / (alpha_T + self.eps)  # below Eq. (9)
        sigma_bart = torch.sqrt(sigma_T**2 - sigma_t**2 + self.eps)  # below Eq. (9)

        return sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart

    def _mean(self, x0, y, t):
        """Mean of the marginal: a time-dependent convex-like mix of x0 and y (Eq. (11))."""
        sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart = self._sigmas_alphas(t)

        w_xt = alpha_t * sigma_bart**2 / (sigma_T**2 + self.eps)  # below Eq. (11)
        w_yt = alpha_bart * sigma_t**2 / (sigma_T**2 + self.eps)  # below Eq. (11)

        mu = w_xt[:, None, None, None] * x0 + w_yt[:, None, None, None] * y  # Eq. (11)
        return mu

    def _std(self, t):
        """Standard deviation of the marginal at time `t`."""
        sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart = self._sigmas_alphas(t)

        sigma_xt = (alpha_t * sigma_bart * sigma_t) / (sigma_T + self.eps)
        return sigma_xt

    def marginal_prob(self, x0, y, t):
        """Mean and std of the Gaussian marginal p_t(x | x0, y)."""
        return self._mean(x0, y, t), self._std(t)

    def prior_sampling(self, shape, y):
        """At T the bridge terminates at y, so the prior "sample" is y itself."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        x_T = y
        return x_T

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for SBVE SDE not yet implemented!")