ibraheemmoosa committed on
Commit
c2db598
·
1 Parent(s): 5acb367

Sampler script.

Browse files
Files changed (1) hide show
  1. Score-SDE/sample-from-score-sde.py +494 -0
Score-SDE/sample-from-score-sde.py ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import librosa
4
+ from torch.utils.data import TensorDataset
5
+ import matplotlib.pyplot as plt
6
+ import jax
7
+ import jax.tools.colab_tpu
8
+ import jax.numpy as jnp
9
+ import flax
10
+ import flax.linen as nn
11
+ from typing import Any, Tuple
12
+ import functools
13
+ import torch
14
+ from flax.serialization import to_bytes, from_bytes
15
+ import tensorflow as tf
16
+ from torch.utils.data import DataLoader
17
+ import torchvision.transforms as transforms
18
+ from torchvision.datasets import MNIST
19
+ import tqdm
20
+ from scipy import integrate
21
+ import matplotlib.pyplot as plt
22
+ from torchvision.utils import make_grid
23
+ import soundfile
24
+ import librosa.display
25
+ import IPython.display as ipd
26
+ import random
27
+ import argparse
28
+
29
# Hyper-parameters exposed as options: (flag, value type, default value).
_OPTION_TABLE = (
    ('--sigma', float, 25.0),
    ('--n_epochs', int, 500),
    ('--batch_size', int, 512),
    ('--lr', float, 1e-2),
    ('--num_steps', int, 500),
    ('--pc_num_steps', int, 500),
    ('--signal_to_noise_ratio', float, 0.16),
    ('--etol', float, 1e-5),
    ('--sample_batch_size', int, 64),
    ('--sample_no', int, 25),
)

parser = argparse.ArgumentParser()
for _flag, _value_type, _default in _OPTION_TABLE:
    parser.add_argument(_flag, type=_value_type, default=_default)

# An explicit empty argument list is parsed (required for colab), so every
# option always takes its default value regardless of the real command line.
args = parser.parse_args(args=[])
41
+
42
+
43
class GaussianFourierProjection(nn.Module):
    """Random Fourier features used to embed scalar time steps.

    Attributes:
      embed_dim: Target embedding width; `embed_dim // 2` frequencies are
        sampled and each contributes one sine and one cosine feature.
      scale: Standard deviation of the normal initializer for the frequencies.
    """
    embed_dim: int
    scale: float = 30.

    @nn.compact
    def __call__(self, x):
        # The frequencies are drawn by the initializer and registered as the
        # parameter 'W'; stop_gradient keeps gradients from flowing into them.
        frequency_init = jax.nn.initializers.normal(stddev=self.scale)
        frequencies = self.param('W', frequency_init, (self.embed_dim // 2, ))
        frequencies = jax.lax.stop_gradient(frequencies)
        # Outer product of the time steps with the frequencies, in radians.
        angles = x[:, None] * frequencies[None, :] * 2 * jnp.pi
        features = [jnp.sin(angles), jnp.cos(angles)]
        return jnp.concatenate(features, axis=-1)
56
+
57
+
58
class Dense(nn.Module):
    """Linear projection whose output broadcasts over feature maps.

    Attributes:
      output_dim: Number of output features of the linear layer.
    """
    output_dim: int

    @nn.compact
    def __call__(self, x):
        projected = nn.Dense(self.output_dim)(x)
        # Insert two singleton axes before the feature axis so the result can
        # be added to a 4-D feature map.
        return projected[:, None, None, :]
65
+
66
+
67
class ScoreNet(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture.

    The network has four strided-convolution encoder stages and a mirrored
    decoder that upsamples via input dilation and concatenates the encoder
    activations as skip connections. The time embedding is added to the
    feature map of every stage. The output is divided by
    `marginal_prob_std(t)`.

    Args:
        marginal_prob_std: A function that takes time t and gives the standard
          deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
        channels: The number of channels for feature maps of each resolution.
        embed_dim: The dimensionality of Gaussian random feature embeddings.

    NOTE(review): no submodule below is given an explicit name, so parameter
    names presumably depend on creation order; reordering statements may make
    existing checkpoints unloadable -- confirm before refactoring.
    """
    marginal_prob_std: Any
    channels: Tuple[int] = (32, 64, 128, 256)
    embed_dim: int = 256

    @nn.compact
    def __call__(self, x, t):
        """Return the network output for inputs `x` at times `t`.

        Args:
            x: Input batch; the rest of this file uses shape (batch, 28, 313, 1).
            t: One time value per batch element (it is indexed as a 1-D vector).

        Returns:
            An array with a single output channel. For 28x313 inputs the
            spatial size works out to 28x313 again (see the shape notes below).
        """
        # The swish activation function
        act = nn.swish
        # Obtain the Gaussian random feature embedding for t
        embed = act(nn.Dense(self.embed_dim)(
            GaussianFourierProjection(embed_dim=self.embed_dim)(t)))

        # Encoding path
        # Shape notes assume a 28x313 input and 3x3 'VALID' convolutions.
        # Stage 1, stride 1: 28x313 -> 26x311.
        h1 = nn.Conv(self.channels[0], (3, 3), (1, 1), padding='VALID',
                     use_bias=False)(x)
        ## Incorporate information from t
        h1 += Dense(self.channels[0])(embed)
        ## Group normalization (4 groups here; later stages use the default)
        h1 = nn.GroupNorm(4)(h1)
        h1 = act(h1)
        # Stage 2, stride 2: 26x311 -> 12x155.
        h2 = nn.Conv(self.channels[1], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h1)
        h2 += Dense(self.channels[1])(embed)
        h2 = nn.GroupNorm()(h2)
        h2 = act(h2)
        # Stage 3, stride 2: 12x155 -> 5x77.
        h3 = nn.Conv(self.channels[2], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h2)
        h3 += Dense(self.channels[2])(embed)
        h3 = nn.GroupNorm()(h3)
        h3 = act(h3)
        # Stage 4, stride 2: 5x77 -> 2x38.
        h4 = nn.Conv(self.channels[3], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h3)
        h4 += Dense(self.channels[3])(embed)
        h4 = nn.GroupNorm()(h4)
        h4 = act(h4)

        # Decoding path
        # input_dilation=(2, 2) spaces the input pixels apart before the
        # convolution, which upsamples the map; the explicit padding then
        # restores the size of the matching encoder stage: 2x38 -> 5x77.
        h = nn.Conv(self.channels[2], (3, 3), (1, 1), padding=((2, 2), (2, 2)),
                    input_dilation=(2, 2), use_bias=False)(h4)
        ## Incorporate information from t
        h += Dense(self.channels[2])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)
        ## Skip connection from the encoding path (concatenate h3), then
        ## upsample 5x77 -> 12x155. The asymmetric (2, 3) padding supplies the
        ## extra row that the stride-2 encoder convolution rounded away.
        h = nn.Conv(self.channels[1], (3, 3), (1, 1), padding=((2, 3), (2, 2)),
                    input_dilation=(2, 2), use_bias=False)(
                        jnp.concatenate([h, h3], axis=-1)
                    )
        h += Dense(self.channels[1])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)
        ## Skip connection with h2, then upsample 12x155 -> 26x311.
        h = nn.Conv(self.channels[0], (3, 3), (1, 1), padding=((2, 3), (2, 2)),
                    input_dilation=(2, 2), use_bias=False)(
                        jnp.concatenate([h, h2], axis=-1)
                    )
        h += Dense(self.channels[0])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)
        ## Skip connection with h1; the final convolution has one output
        ## channel and pads 26x311 back to 28x313.
        h = nn.Conv(1, (3, 3), (1, 1), padding=((2, 2), (2, 2)))(
            jnp.concatenate([h, h1], axis=-1)
        )

        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
140
+
141
+
142
def marginal_prob_std(t, sigma):
    r"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

    The value returned is ``sqrt((sigma**(2 t) - 1) / (2 log(sigma)))``.
    Only the standard deviation is computed; no mean is returned.

    The docstring is a raw string because ``\sigma`` would otherwise be an
    invalid escape sequence.

    Args:
      t: A vector of time steps (a scalar such as ``1.`` is also accepted).
      sigma: The $\sigma$ in our SDE.

    Returns:
      The standard deviation.
    """
    variance = (sigma**(2 * t) - 1.) / 2. / jnp.log(sigma)
    return jnp.sqrt(variance)
153
+
154
def diffusion_coeff(t, sigma):
    r"""Compute the diffusion coefficient of our SDE.

    The coefficient is ``sigma**t``. The docstring is a raw string because
    ``\sigma`` would otherwise be an invalid escape sequence.

    Args:
      t: A vector of time steps.
      sigma: The $\sigma$ in our SDE.

    Returns:
      The vector of diffusion coefficients.
    """
    return sigma**t
165
+
166
+
167
def loss_fn(rng, model, params, x, marginal_prob_std, eps=1e-5):
    """Score-matching training objective for one mini-batch.

    Each example is perturbed with Gaussian noise whose scale is given by
    `marginal_prob_std` at a uniformly drawn time in [eps, 1), and the model
    output, rescaled by that same scale, is compared against the negated noise.

    Args:
      rng: A JAX random state; it is split twice (times first, then noise).
      model: A `flax.linen.Module` object that represents the structure of
        the score-based model.
      params: A dictionary that contains all trainable parameters.
      x: A mini-batch of training data with four axes (batch first).
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.
      eps: A tolerance value for numerical stability (smallest sampled time).

    Returns:
      The squared error summed over the non-batch axes and averaged over
      the batch.
    """
    rng, time_key = jax.random.split(rng)
    times = jax.random.uniform(time_key, (x.shape[0],), minval=eps, maxval=1.)
    rng, noise_key = jax.random.split(rng)
    noise = jax.random.normal(noise_key, x.shape)

    # One noise scale per example, shaped to broadcast over the other axes.
    scale = marginal_prob_std(times)[:, None, None, None]
    noisy_x = x + noise * scale

    predicted_score = model.apply(params, noisy_x, times)
    residual = predicted_score * scale + noise
    per_example = jnp.sum(residual**2, axis=(1, 2, 3))
    return jnp.mean(per_example)
189
+
190
def get_train_step_fn(model, marginal_prob_std):
    """Build the pmapped function that performs a single optimization step.

    Args:
      model: A `flax.linen.Module` object that represents the structure of
        the score-based model.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.

    Returns:
      A function `(rng, x, optimizer) -> (mean_loss, new_optimizer)` wrapped in
      `jax.pmap` with axis name 'device'.
    """
    # Differentiate the loss with respect to its third argument (params).
    loss_and_grad = jax.value_and_grad(loss_fn, argnums=2)

    def step_fn(rng, x, optimizer):
        loss, grad = loss_and_grad(
            rng, model, optimizer.target, x, marginal_prob_std)
        # Average gradient and loss over the 'device' axis so every replica
        # applies the same update.
        grad = jax.lax.pmean(grad, axis_name='device')
        loss = jax.lax.pmean(loss, axis_name='device')
        return loss, optimizer.apply_gradient(grad)

    return jax.pmap(step_fn, axis_name='device')
212
+
213
+
214
def score_fn(score_model, params, x, t):
    """Apply `score_model` with `params` to inputs `x` at times `t`."""
    score = score_model.apply(params, x, t)
    return score
216
+
217
def Euler_Maruyama_sampler(rng,
                           score_model,
                           params,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           num_steps=args.num_steps,
                           eps=1e-3):
    """Generate samples from score-based models with the Euler-Maruyama solver.

    Relies on the module-level `pmap_score_fn` (created in the `__main__`
    section) to evaluate the score on all local devices.

    Args:
      rng: A JAX random state.
      score_model: A `flax.linen.Module` object that represents the architecture
        of a score-based model.
      params: A dictionary that contains the model parameters.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.
      diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
      batch_size: The number of samples to generate by calling this function
        once. It is split evenly over the local devices with integer division.
      num_steps: The number of sampling steps.
        Equivalent to the number of discretized time steps. Must be at least 2.
      eps: The smallest time step for numerical stability.

    Returns:
      Samples of shape (devices, batch_size // devices, 28, 313, 1).

    Raises:
      ValueError: If `num_steps` is smaller than 2, because the step size is
        the gap between the first two time steps.
    """
    # A bare `import tqdm` does not load the `tqdm.notebook` submodule, so
    # `tqdm.notebook.tqdm` can raise AttributeError. `tqdm.auto` selects the
    # notebook or the console progress bar as appropriate.
    from tqdm.auto import tqdm as progress_bar

    if num_steps < 2:
        raise ValueError(f"num_steps must be at least 2, got {num_steps}")

    rng, step_rng = jax.random.split(rng)
    n_devices = jax.local_device_count()
    time_shape = (n_devices, batch_size // n_devices)
    sample_shape = time_shape + (28, 313, 1)
    init_x = jax.random.normal(step_rng, sample_shape) * marginal_prob_std(1.)
    time_steps = jnp.linspace(1., eps, num_steps)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    for time_step in progress_bar(time_steps):
        batch_time_step = jnp.ones(time_shape) * time_step
        g = diffusion_coeff(time_step)
        score = pmap_score_fn(score_model, params, x, batch_time_step)
        mean_x = x + (g**2) * score * step_size
        rng, step_rng = jax.random.split(rng)
        noise = jax.random.normal(step_rng, x.shape)
        x = mean_x + jnp.sqrt(step_size) * g * noise
    # Do not include any noise in the last sampling step.
    return mean_x
261
+
262
+
263
def pc_sampler(rng,
               score_model,
               params,
               marginal_prob_std,
               diffusion_coeff,
               batch_size=64,
               num_steps=args.pc_num_steps,
               snr=args.signal_to_noise_ratio,
               eps=1e-3):
    """Generate samples from score-based models with Predictor-Corrector method.

    Relies on the module-level `pmap_score_fn` (created in the `__main__`
    section) to evaluate the score on all local devices.

    Args:
      rng: A JAX random state.
      score_model: A `flax.linen.Module` that represents the
        architecture of the score-based model.
      params: A dictionary that contains the parameters of the score-based model.
      marginal_prob_std: A function that gives the standard deviation
        of the perturbation kernel.
      diffusion_coeff: A function that gives the diffusion coefficient
        of the SDE.
      batch_size: The number of samples to generate by calling this function
        once. It is split evenly over the local devices with integer division.
      num_steps: The number of sampling steps.
        Equivalent to the number of discretized time steps. Must be at least 2.
        Defaults to the `--pc_num_steps` option.
      snr: The signal-to-noise ratio that sets the Langevin step size.
      eps: The smallest time step for numerical stability.

    Returns:
      Samples of shape (devices, batch_size // devices, 28, 313, 1).

    Raises:
      ValueError: If `num_steps` is smaller than 2, because the step size is
        the gap between the first two time steps.
    """
    # A bare `import tqdm` does not load the `tqdm.notebook` submodule, so
    # `tqdm.notebook.tqdm` can raise AttributeError. `tqdm.auto` selects the
    # notebook or the console progress bar as appropriate.
    from tqdm.auto import tqdm as progress_bar

    if num_steps < 2:
        raise ValueError(f"num_steps must be at least 2, got {num_steps}")

    n_devices = jax.local_device_count()
    time_shape = (n_devices, batch_size // n_devices)
    sample_shape = time_shape + (28, 313, 1)
    rng, step_rng = jax.random.split(rng)
    init_x = jax.random.normal(step_rng, sample_shape) * marginal_prob_std(1.)
    time_steps = jnp.linspace(1., eps, num_steps)
    step_size = time_steps[0] - time_steps[1]
    # Expected norm of standard Gaussian noise for ONE sample. The leading two
    # axes are (device, per-device batch), so only the trailing axes count;
    # this matches `grad_norm` below, which is also a per-sample norm.
    noise_norm = np.sqrt(np.prod(sample_shape[2:]))
    x = init_x
    for time_step in progress_bar(time_steps):
        batch_time_step = jnp.ones(time_shape) * time_step
        # Corrector step (Langevin MCMC)
        grad = pmap_score_fn(score_model, params, x, batch_time_step)
        per_sample_grad = grad.reshape(sample_shape[0], sample_shape[1], -1)
        grad_norm = jnp.linalg.norm(per_sample_grad, axis=-1).mean()
        langevin_step_size = 2 * (snr * noise_norm / grad_norm)**2
        rng, step_rng = jax.random.split(rng)
        z = jax.random.normal(step_rng, x.shape)
        x = x + langevin_step_size * grad + jnp.sqrt(2 * langevin_step_size) * z

        # Predictor step (Euler-Maruyama)
        g = diffusion_coeff(time_step)
        score = pmap_score_fn(score_model, params, x, batch_time_step)
        x_mean = x + (g**2) * score * step_size
        rng, step_rng = jax.random.split(rng)
        z = jax.random.normal(step_rng, x.shape)
        x = x_mean + jnp.sqrt(g**2 * step_size) * z

    # The last step does not include any noise
    return x_mean
320
+
321
+
322
def ode_sampler(rng,
                score_model,
                params,
                marginal_prob_std,
                diffusion_coeff,
                batch_size=64,
                atol=args.etol,
                rtol=args.etol,
                z=None,
                eps=1e-3):
    """Generate samples from score-based models with black-box ODE solvers.

    The ODE is integrated from t=1 down to t=eps with SciPy's RK45 solver.
    Scores are evaluated through the module-level `pmap_score_fn` (created in
    the `__main__` section).

    Args:
      rng: A JAX random state; only used when `z` is None.
      score_model: A `flax.linen.Module` object that represents architecture
        of the score-based model.
      params: A dictionary that contains model parameters.
      marginal_prob_std: A function that returns the standard deviation
        of the perturbation kernel.
      diffusion_coeff: A function that returns the diffusion coefficient of the SDE.
      batch_size: The number of samples to generate by calling this function
        once. It is split evenly over the local devices with integer division.
      atol: Tolerance of absolute errors.
      rtol: Tolerance of relative errors.
      z: The latent code that governs the final sample. If None, we start from p_1;
        otherwise, we start from the given z, which is used as-is.
      eps: The smallest time step for numerical stability.

    Returns:
      The solver state at the final time, reshaped to the shape of the
      initial sample.
    """
    n_devices = jax.local_device_count()
    time_shape = (n_devices, batch_size // n_devices)
    sample_shape = time_shape + (28, 313, 1)

    # Create the latent code
    if z is not None:
        init_x = z
    else:
        rng, step_rng = jax.random.split(rng)
        z = jax.random.normal(step_rng, sample_shape)
        init_x = z * marginal_prob_std(1.)
    result_shape = init_x.shape

    def batched_score(flat_sample, times):
        """Evaluate the score on the solver's flat state vector."""
        sample = jnp.asarray(flat_sample, dtype=jnp.float32).reshape(sample_shape)
        times = jnp.asarray(times).reshape(time_shape)
        score = pmap_score_fn(score_model, params, sample, times)
        # The result is returned as a flat float64 NumPy vector.
        return np.asarray(score).reshape((-1,)).astype(np.float64)

    def ode_rhs(t, flat_x):
        """Right-hand side of the ODE handed to the solver."""
        times = np.ones(time_shape) * t
        g = diffusion_coeff(t)
        return -0.5 * (g**2) * batched_score(flat_x, times)

    # Run the black-box ODE solver.
    flat_init = np.asarray(init_x).reshape(-1)
    solution = integrate.solve_ivp(ode_rhs, (1., eps), flat_init,
                                   rtol=rtol, atol=atol, method='RK45')
    print(f"Number of function evaluations: {solution.nfev}")
    final_state = solution.y[:, -1]
    return jnp.asarray(final_state).reshape(result_shape)
382
+
383
+
384
def noise_removal(sample, threshold=-35.0):
    """Convert a sample to decibels and gate out quiet entries.

    Args:
      sample: Array-like of amplitudes; it is copied into a NumPy array.
      threshold: Decibel level (relative to the maximum of `sample`); entries
        at or below it are replaced by -80.

    Returns:
      A pair `(db, gated_db)`: the decibel-scaled sample and the same array
      with every entry not above `threshold` set to -80.
    """
    amplitudes = np.array(sample)
    db = librosa.amplitude_to_db(amplitudes, ref=np.max)
    gated_db = np.where(db > threshold, db, -80)
    return db, gated_db
394
+
395
def audio(sample, noise_threshold=-35.0):
    """Turn a generated sample into a waveform.

    The sample is noise-gated in the decibel domain by `noise_removal`,
    converted back to amplitudes, and inverted with librosa's `mel_to_audio`
    at a 16000 Hz sampling rate.

    Args:
      sample: Array-like passed to `noise_removal` and treated as a mel
        spectrogram by the inversion.
      noise_threshold: Decibel threshold forwarded to `noise_removal`.

    Returns:
      The reconstructed audio signal.
    """
    sampling_rate = 16000

    _, gated_db = noise_removal(sample, threshold=noise_threshold)
    gated_amplitude = librosa.db_to_amplitude(gated_db)
    return librosa.feature.inverse.mel_to_audio(gated_amplitude, sr=sampling_rate)
405
+
406
if __name__ == '__main__':
    # Sampling-only entry point: restore the checkpoint 'ckpt.flax', draw a
    # batch of samples, show them as an image grid, and render one sample as
    # a spectrogram and as audio.
    #
    # The leftover training scaffolding was removed: it built a DataLoader
    # from an undefined `dataset` (NameError) and was never used because the
    # training loop itself was commented out.

    sigma = args.sigma
    marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
    diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

    # The sampler functions look this name up as a module-level global.
    pmap_score_fn = jax.pmap(score_fn, static_broadcasted_argnums=(0, 1))

    sample_batch_size = args.sample_batch_size
    n_devices = jax.local_device_count()
    # The samplers split the batch over devices with integer division, so a
    # batch size that does not divide evenly would silently yield fewer samples.
    if sample_batch_size % n_devices != 0:
        raise ValueError(
            f"sample_batch_size ({sample_batch_size}) must be divisible by "
            f"the number of local devices ({n_devices})")
    sampler = ode_sampler

    ## Load the pre-trained checkpoint from disk.
    score_model = ScoreNet(marginal_prob_std_fn)
    fake_input = jnp.ones((sample_batch_size, 28, 313, 1))
    fake_time = jnp.ones((sample_batch_size, ))
    rng = jax.random.PRNGKey(0)
    params = score_model.init({'params': rng}, fake_input, fake_time)
    # The checkpoint stores a serialized optimizer, so an optimizer with the
    # same structure is needed as the template for `from_bytes`.
    # NOTE(review): `flax.optim` is presumably unavailable in recent Flax
    # releases -- confirm the pinned Flax version before upgrading.
    optimizer = flax.optim.Adam().create(params)
    with tf.io.gfile.GFile('ckpt.flax', 'rb') as fin:
        optimizer = from_bytes(optimizer, fin.read())

    ## Generate samples using the specified sampler.
    # The first half of the split is passed on, as before, so the generated
    # samples stay the same; the second half is unused.
    rng, _ = jax.random.split(rng)
    samples = sampler(rng,
                      score_model,
                      optimizer.target,
                      marginal_prob_std_fn,
                      diffusion_coeff_fn,
                      sample_batch_size)

    ## Sample visualization.
    # Merge the device axis into the batch axis and move channels first,
    # which is the layout `make_grid` receives below.
    samples = jnp.transpose(samples.reshape((-1, 28, 313, 1)), (0, 3, 1, 2))
    # `%matplotlib inline` is IPython-only syntax and a SyntaxError in a .py
    # file; request the inline backend only when running under IPython.
    try:
        get_ipython().run_line_magic('matplotlib', 'inline')
    except NameError:
        pass
    sample_grid = make_grid(torch.tensor(np.asarray(samples)),
                            nrow=int(np.sqrt(sample_batch_size)))

    plt.figure(figsize=(6, 6))
    plt.axis('off')
    plt.imshow(sample_grid.permute(1, 2, 0).cpu(), vmin=0., vmax=1.)
    plt.show()

    # Inspect a single sample: average over the channel axis to get a 2-D
    # spectrogram.
    j = 7
    spectrogram = jnp.mean(samples[j], 0)

    # The original called `viz`, which is not defined anywhere in this file.
    # Show the sample in decibels instead.
    spectrogram_db, _ = noise_removal(spectrogram)
    plt.figure(figsize=(10, 3))
    plt.imshow(spectrogram_db, origin='lower', aspect='auto')
    plt.colorbar()
    plt.show()

    # A bare `ipd.Audio(...)` expression inside this block is never rendered;
    # display it explicitly.
    waveform = audio(spectrogram, noise_threshold=-25.0)
    ipd.display(ipd.Audio(waveform, rate=16000))