trenden committed on
Commit
03d9cd8
·
verified ·
1 Parent(s): 9b9a995

Upload sgmse/sampling/__init__.py

Browse files
Files changed (1) hide show
  1. sgmse/sampling/__init__.py +249 -0
sgmse/sampling/__init__.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sampling.py
2
+ """Various sampling methods."""
3
+ from scipy import integrate
4
+ import torch
5
+
6
+ from .predictors import Predictor, PredictorRegistry, ReverseDiffusionPredictor
7
+ from .correctors import Corrector, CorrectorRegistry
8
+
9
+
10
+ __all__ = [
11
+ 'PredictorRegistry', 'CorrectorRegistry', 'Predictor', 'Corrector',
12
+ 'get_sampler'
13
+ ]
14
+
15
+
16
def to_flattened_numpy(x):
    """Return `x` as a flat 1-D numpy array (detached and moved to CPU first)."""
    host_array = x.detach().cpu().numpy()
    return host_array.reshape((-1,))
19
+
20
+
21
def from_flattened_numpy(x, shape):
    """Build a torch tensor of the given `shape` from the flat numpy array `x`."""
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)
24
+
25
+
26
def get_pc_sampler(
    predictor_name, corrector_name, sde, score_fn, y,
    denoise=True, eps=3e-2, snr=0.1, corrector_steps=1, probability_flow: bool = False,
    intermediate=False, **kwargs
):
    """Create a Predictor-Corrector (PC) sampler.

    Args:
        predictor_name: The name of a registered `sampling.Predictor`.
        corrector_name: The name of a registered `sampling.Corrector`.
        sde: An `sdes.SDE` object representing the forward SDE.
        score_fn: A function (typically learned model) that predicts the score.
        y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
        denoise: If `True`, return the final predictor mean (one-step denoised) instead of the last noisy sample.
        eps: A `float` number. The reverse-time SDE and ODE are integrated to `eps` to avoid numerical issues.
        snr: The SNR to use for the corrector. 0.1 by default, and ignored for `NoneCorrector`.
        corrector_steps: Number of corrector iterations to run per reverse step.
        probability_flow: If `True`, instantiate the predictor in probability-flow (ODE) mode.
        intermediate: Accepted for interface compatibility but currently unused.
            # NOTE(review): no code path reads `intermediate` — confirm whether
            # intermediate-sample collection was intended here.

    Returns:
        A sampling function that returns samples and the number of function evaluations during sampling.
    """
    predictor_cls = PredictorRegistry.get_by_name(predictor_name)
    corrector_cls = CorrectorRegistry.get_by_name(corrector_name)
    predictor = predictor_cls(sde, score_fn, probability_flow=probability_flow)
    corrector = corrector_cls(sde, score_fn, snr=snr, n_steps=corrector_steps)

    def pc_sampler():
        """The PC sampler function."""
        with torch.no_grad():
            # Draw the starting point from the SDE's prior, conditioned on y.
            xt = sde.prior_sampling(y.shape, y).to(y.device)
            timesteps = torch.linspace(sde.T, eps, sde.N, device=y.device)
            for i in range(sde.N):
                t = timesteps[i]
                if i != len(timesteps) - 1:
                    stepsize = t - timesteps[i+1]
                else:
                    stepsize = timesteps[-1]  # from eps to 0
                vec_t = torch.ones(y.shape[0], device=y.device) * t
                # Corrector (e.g. Langevin) step(s), then predictor step.
                xt, xt_mean = corrector.update_fn(xt, y, vec_t)
                xt, xt_mean = predictor.update_fn(xt, y, vec_t, stepsize)
            x_result = xt_mean if denoise else xt
            # One predictor call plus `n_steps` corrector calls per reverse step.
            ns = sde.N * (corrector.n_steps + 1)
            return x_result, ns

    return pc_sampler
71
+
72
+
73
def get_ode_sampler(
    sde, score_fn, y, inverse_scaler=None,
    denoise=True, rtol=1e-5, atol=1e-5,
    method='RK45', eps=3e-2, device='cuda', **kwargs
):
    """Probability flow ODE sampler with the black-box ODE solver.

    Args:
        sde: An `sdes.SDE` object representing the forward SDE.
        score_fn: A function (typically learned model) that predicts the score.
        y: A `torch.Tensor`, representing the (non-white-)noisy starting point(s) to condition the prior on.
        inverse_scaler: The inverse data normalizer. If `None`, no inverse scaling is applied.
        denoise: If `True`, add one-step denoising to final samples.
        rtol: A `float` number. The relative tolerance level of the ODE solver.
        atol: A `float` number. The absolute tolerance level of the ODE solver.
        method: A `str`. The algorithm used for the black-box ODE solver.
            See the documentation of `scipy.integrate.solve_ivp`.
        eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
        device: PyTorch device.

    Returns:
        A sampling function that returns samples and the number of function evaluations during sampling.
    """
    # Predictor is only used for the optional final one-step denoising.
    predictor = ReverseDiffusionPredictor(sde, score_fn, probability_flow=False)
    # Reverse-time SDE in probability-flow (deterministic ODE) form.
    rsde = sde.reverse(score_fn, probability_flow=True)

    def denoise_update_fn(x):
        """One reverse-diffusion step at t=eps; the returned mean is the denoised estimate."""
        vec_eps = torch.ones(x.shape[0], device=x.device) * eps
        _, x = predictor.update_fn(x, y, vec_eps)
        return x

    def drift_fn(x, y, t):
        """Get the drift function of the reverse-time SDE."""
        return rsde.sde(x, y, t)[0]

    def ode_sampler(z=None, **kwargs):
        """The probability flow ODE sampler with black-box ODE solver.

        Args:
            z: If present, generate samples from latent code `z`.
                # NOTE(review): `z` is accepted but never read below — the prior
                # sample is always drawn fresh; confirm intended behavior.
        Returns:
            samples, number of function evaluations.
        """
        with torch.no_grad():
            # Sample the latent code from the prior distribution of the SDE, conditioned on y.
            x = sde.prior_sampling(y.shape, y).to(device)

            def ode_func(t, x):
                # solve_ivp works on flat float64 numpy arrays; round-trip through
                # torch complex64 for the score evaluation.
                x = from_flattened_numpy(x, y.shape).to(device).type(torch.complex64)
                vec_t = torch.ones(y.shape[0], device=x.device) * t
                drift = drift_fn(x, y, vec_t)
                return to_flattened_numpy(drift)

            # Black-box ODE solver for the probability flow ODE, integrating
            # backward in time from sde.T down to eps.
            solution = integrate.solve_ivp(
                ode_func, (sde.T, eps), to_flattened_numpy(x),
                rtol=rtol, atol=atol, method=method, **kwargs
            )
            nfe = solution.nfev  # number of drift evaluations used by the solver
            # Final state of the trajectory, restored to tensor shape/dtype.
            x = torch.tensor(solution.y[:, -1]).reshape(y.shape).to(device).type(torch.complex64)

            # Denoising is equivalent to running one predictor step without adding noise
            if denoise:
                x = denoise_update_fn(x)

            if inverse_scaler is not None:
                x = inverse_scaler(x)
            return x, nfe

    return ode_sampler
144
+
145
def get_sb_sampler(sde, model, y, eps=1e-4, n_steps=50, sampler_type="ode", **kwargs):
    """Create a Schrödinger-bridge sampler (first-order SDE or ODE discretization).

    Args:
        sde: A bridge SDE object exposing `T`, `N`, `eps` and `_sigmas_alphas(t)`.
        model: The DNN estimator, called as `model(xt, y, time)`.
        y: Conditioning tensor, viewed as [B, C, D, T].
        eps: Final (smallest) integration time.
        n_steps: Reported step count in the return value.
            # NOTE(review): the loops below always run `sde.N` steps; `n_steps`
            # never controls the schedule — confirm whether they should agree.
        sampler_type: Either "sde" or "ode".

    Returns:
        A zero-argument sampler function returning `(sample, n_steps)`.

    Raises:
        ValueError: If `sampler_type` is neither "ode" nor "sde".
    """
    # adapted from https://github.com/NVIDIA/NeMo/blob/78357ae99ff2cf9f179f53fbcb02c88a5a67defb/nemo/collections/audio/parts/submodules/schroedinger_bridge.py#L382
    def sde_sampler():
        """The SB-SDE sampler function."""
        with torch.no_grad():
            xt = y[:, [0], :, :]  # special case for storm_2ch
            # sde.N + 1 grid points from T down to eps -> sde.N update steps.
            time_steps = torch.linspace(sde.T, eps, sde.N + 1, device=y.device)

            # Initial values
            time_prev = time_steps[0] * torch.ones(xt.shape[0], device=xt.device)
            sigma_prev, sigma_T, sigma_bar_prev, alpha_prev, alpha_T, alpha_bar_prev = sde._sigmas_alphas(time_prev)

            for t in time_steps[1:]:
                # Prepare time steps for the whole batch
                time = t * torch.ones(xt.shape[0], device=xt.device)

                # Get noise schedule for current time
                sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart = sde._sigmas_alphas(time)

                # Run DNN
                current_estimate = model(xt, y, time)

                # Calculate scaling for the first-order discretization from the paper
                # (sde.eps guards the divisions against zero denominators).
                weight_prev = alpha_t * sigma_t**2 / (alpha_prev * sigma_prev**2 + sde.eps)
                tmp = 1 - sigma_t**2 / (sigma_prev**2 + sde.eps)
                weight_estimate = alpha_t * tmp
                weight_z = alpha_t * sigma_t * torch.sqrt(tmp)

                # View as [B, C, D, T]
                weight_prev = weight_prev[:, None, None, None]
                weight_estimate = weight_estimate[:, None, None, None]
                weight_z = weight_z[:, None, None, None]

                # Random sample
                z_norm = torch.randn_like(xt)

                # No noise injected on the final step.
                if t == time_steps[-1]:
                    weight_z = 0.0

                # Update state: weighted sum of previous state, current estimate and noise
                xt = weight_prev * xt + weight_estimate * current_estimate + weight_z * z_norm

                # Save previous values
                time_prev = time
                alpha_prev = alpha_t
                sigma_prev = sigma_t
                sigma_bar_prev = sigma_bart

        # NOTE(review): returns `n_steps` although `sde.N` iterations were run.
        return xt, n_steps

    def ode_sampler():
        """The SB-ODE sampler function."""
        with torch.no_grad():
            xt = y
            # sde.N + 1 grid points from T down to eps -> sde.N update steps.
            time_steps = torch.linspace(sde.T, eps, sde.N + 1, device=y.device)

            # Initial values
            time_prev = time_steps[0] * torch.ones(xt.shape[0], device=xt.device)
            sigma_prev, sigma_T, sigma_bar_prev, alpha_prev, alpha_T, alpha_bar_prev = sde._sigmas_alphas(time_prev)

            for t in time_steps[1:]:
                # Prepare time steps for the whole batch
                time = t * torch.ones(xt.shape[0], device=xt.device)

                # Get noise schedule for current time
                sigma_t, sigma_T, sigma_bart, alpha_t, alpha_T, alpha_bart = sde._sigmas_alphas(time)

                # Run DNN
                current_estimate = model(xt, y, time)

                # Calculate scaling for the first-order discretization from the paper
                # (deterministic variant: noise term replaced by a prior-mean term).
                weight_prev = alpha_t * sigma_t * sigma_bart / (alpha_prev * sigma_prev * sigma_bar_prev + sde.eps)
                weight_estimate = (
                    alpha_t
                    / (sigma_T**2 + sde.eps)
                    * (sigma_bart**2 - sigma_bar_prev * sigma_t * sigma_bart / (sigma_prev + sde.eps))
                )
                weight_prior_mean = (
                    alpha_t
                    / (alpha_T * sigma_T**2 + sde.eps)
                    * (sigma_t**2 - sigma_prev * sigma_t * sigma_bart / (sigma_bar_prev + sde.eps))
                )

                # View as [B, C, D, T]
                weight_prev = weight_prev[:, None, None, None]
                weight_estimate = weight_estimate[:, None, None, None]
                weight_prior_mean = weight_prior_mean[:, None, None, None]

                # Update state: weighted sum of previous state, current estimate and prior
                xt = weight_prev * xt + weight_estimate * current_estimate + weight_prior_mean * y

                # Save previous values
                time_prev = time
                alpha_prev = alpha_t
                sigma_prev = sigma_t
                sigma_bar_prev = sigma_bart

        # NOTE(review): returns `n_steps` although `sde.N` iterations were run.
        return xt, n_steps

    if sampler_type == "sde":
        return sde_sampler
    elif sampler_type == "ode":
        return ode_sampler
    else:
        raise ValueError("Invalid type. Choose 'ode' or 'sde'.")