| import jax |
| import jax.numpy as jnp |
| from jax import random |
| import flax |
| import flax.linen as nn |
| from typing import Any, Tuple |
| import functools |
| import numpy as np |
| import torch |
| from torch.utils.data import TensorDataset |
|
|
# Global PRNG key (seed 0) for reproducibility.
# NOTE(review): this key appears unused below — the training section creates
# its own `rng = jax.random.PRNGKey(0)`; confirm before removing.
key = random.PRNGKey(0)
|
|
# Load every spectrogram array from the .npz archive into one list.
# NOTE(review): the filename "spectograms.npz" (sic) is kept as-is — renaming
# it here would break against the on-disk file.
with np.load('spectograms.npz') as data:
    # One array per archive entry; a comprehension replaces the manual
    # append loop (same iteration order as `data.files`).
    dataset = [data[file] for file in data.files]

# Stack into (N, H, W) and append a trailing channel axis -> (N, H, W, 1),
# the NHWC layout the Flax convolutions below operate on.
dataset = np.stack(dataset)
dataset = np.expand_dims(dataset, axis=3)
# Wrap as a torch TensorDataset so DataLoader can batch and shuffle it.
dataset = TensorDataset(torch.from_numpy(dataset))
|
|
|
|
| |
|
|
class GaussianFourierProjection(nn.Module):
    """Encode scalar time steps with fixed Gaussian random Fourier features.

    A random frequency vector W ~ N(0, scale^2) is created once as a module
    parameter and then frozen; each time step t is mapped to
    [sin(2*pi*t*W), cos(2*pi*t*W)], yielding an `embed_dim`-sized embedding.
    """
    embed_dim: int   # output embedding size (uses embed_dim // 2 frequencies)
    scale: float = 30.

    @nn.compact
    def __call__(self, x):
        # Frequencies are sampled at init time and excluded from training:
        # stop_gradient keeps them fixed.
        freqs = self.param('W', jax.nn.initializers.normal(stddev=self.scale),
                           (self.embed_dim // 2,))
        freqs = jax.lax.stop_gradient(freqs)
        angles = x[:, None] * freqs[None, :] * 2 * jnp.pi
        # Sine and cosine halves concatenated along the feature axis.
        return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
|
|
|
|
class Dense(nn.Module):
    """A fully connected layer whose output broadcasts over feature maps.

    Projects a (batch, features) input to (batch, 1, 1, output_dim) so the
    result can be added directly to NHWC convolution activations.
    """
    output_dim: int

    @nn.compact
    def __call__(self, x):
        projected = nn.Dense(self.output_dim)(x)
        # Insert singleton spatial axes: (B, D) -> (B, 1, 1, D).
        return projected[:, None, None, :]
|
|
|
|
class ScoreNet(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture.

    Args:
      marginal_prob_std: A function that takes time t and gives the standard
        deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
      channels: The number of channels for feature maps of each resolution.
      embed_dim: The dimensionality of Gaussian random feature embeddings.
    """
    marginal_prob_std: Any
    channels: Tuple[int] = (32, 64, 128, 256)
    embed_dim: int = 256

    @nn.compact
    def __call__(self, x, t):
        # Swish activation is used throughout the network.
        act = nn.swish
        # Time-conditioning vector: Gaussian Fourier features of t followed
        # by a learned projection.
        embed = act(nn.Dense(self.embed_dim)(
            GaussianFourierProjection(embed_dim=self.embed_dim)(t)))

        # --- Encoder: VALID convs, strided (2, 2) from h2 onward. ---
        h1 = nn.Conv(self.channels[0], (3, 3), (1, 1), padding='VALID',
                     use_bias=False)(x)
        # Inject the time embedding as a per-channel, spatially-broadcast bias.
        h1 += Dense(self.channels[0])(embed)
        h1 = nn.GroupNorm(4)(h1)
        h1 = act(h1)
        h2 = nn.Conv(self.channels[1], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h1)
        h2 += Dense(self.channels[1])(embed)
        h2 = nn.GroupNorm()(h2)
        h2 = act(h2)
        h3 = nn.Conv(self.channels[2], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h2)
        h3 += Dense(self.channels[2])(embed)
        h3 = nn.GroupNorm()(h3)
        h3 = act(h3)
        h4 = nn.Conv(self.channels[3], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h3)
        h4 += Dense(self.channels[3])(embed)
        h4 = nn.GroupNorm()(h4)
        h4 = act(h4)

        # --- Decoder: upsampling via input_dilation=(2, 2) (transposed-conv
        # style), with skip connections concatenated from the encoder. ---
        h = nn.Conv(self.channels[2], (3, 3), (1, 1), padding=((2, 2), (2, 2)),
                    input_dilation=(2, 2), use_bias=False)(h4)
        h += Dense(self.channels[2])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)
        # NOTE(review): the asymmetric padding ((2, 3), (2, 2)) appears tuned
        # to make decoder shapes line up with the encoder skips for the
        # 28x313 inputs used below — confirm before reusing on other shapes.
        h = nn.Conv(self.channels[1], (3, 3), (1, 1), padding=((2, 3), (2, 2)),
                    input_dilation=(2, 2), use_bias=False)(
            jnp.concatenate([h, h3], axis=-1)
        )
        h += Dense(self.channels[1])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)
        h = nn.Conv(self.channels[0], (3, 3), (1, 1), padding=((2, 3), (2, 2)),
                    input_dilation=(2, 2), use_bias=False)(
            jnp.concatenate([h, h2], axis=-1)
        )
        h += Dense(self.channels[0])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)
        # Final conv back to a single channel (bias enabled here).
        h = nn.Conv(1, (3, 3), (1, 1), padding=((2, 2), (2, 2)))(
            jnp.concatenate([h, h1], axis=-1)
        )

        # Scale the output by 1 / std(t) so the score magnitude matches the
        # perturbation kernel at time t.
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
|
|
|
|
def marginal_prob_std(t, sigma):
    r"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

    For this SDE the perturbation kernel at time t is Gaussian with
    std $\sqrt{(\sigma^{2t} - 1) / (2 \log \sigma)}$. (The docstring is a
    raw string so the LaTeX `\sigma` is not treated as an escape sequence.)

    Args:
      t: A vector of time steps.
      sigma: The $\sigma$ in our SDE.

    Returns:
      The standard deviation of the perturbation kernel at each time step.
    """
    return jnp.sqrt((sigma**(2 * t) - 1.) / 2. / jnp.log(sigma))
|
|
def diffusion_coeff(t, sigma):
    r"""Compute the diffusion coefficient $g(t) = \sigma^t$ of our SDE.

    (Raw docstring so the LaTeX `\sigma` is not an invalid escape sequence.)

    Args:
      t: A vector of time steps.
      sigma: The $\sigma$ in our SDE.

    Returns:
      The vector of diffusion coefficients.
    """
    return sigma**t
| |
# Noise scale for the SDE used by both helper functions above.
sigma = 25.0
# Bind sigma once so downstream code only has to pass t.
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
|
|
|
|
def loss_fn(rng, model, params, x, marginal_prob_std, eps=1e-5):
    """The denoising score-matching loss for score-based generative models.

    Samples one diffusion time per example, perturbs the batch with matching
    Gaussian noise, and penalizes the weighted denoising residual.

    Args:
      rng: A JAX PRNG key consumed for time and noise sampling.
      model: A `flax.linen.Module` object that represents the structure of
        the score-based model.
      params: A dictionary that contains all trainable parameters.
      x: A mini-batch of training data.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.
      eps: A tolerance value for numerical stability.

    Returns:
      The scalar loss averaged over the batch.
    """
    # One diffusion time per example, kept away from 0 for stability.
    rng, time_key = jax.random.split(rng)
    times = jax.random.uniform(time_key, (x.shape[0],), minval=eps, maxval=1.)
    # Fresh Gaussian noise with the same shape as the batch.
    rng, noise_key = jax.random.split(rng)
    noise = jax.random.normal(noise_key, x.shape)
    # Per-example std, broadcast over the spatial and channel axes.
    std = marginal_prob_std(times)[:, None, None, None]
    perturbed = x + noise * std
    score = model.apply(params, perturbed, times)
    # Squared norm of the weighted residual is the per-example loss.
    residual = score * std + noise
    return jnp.mean(jnp.sum(residual**2, axis=(1, 2, 3)))
|
|
def get_train_step_fn(model, marginal_prob_std):
    """Create a one-step training function.

    Args:
      model: A `flax.linen.Module` object that represents the structure of
        the score-based model.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.
    Returns:
      A function that runs one step of training.
    """
    # Differentiate loss_fn w.r.t. `params` (positional argument index 2).
    val_and_grad_fn = jax.value_and_grad(loss_fn, argnums=2)
    def step_fn(rng, x, optimizer):
        # NOTE(review): `optimizer.target` / `apply_gradient` belong to the
        # legacy `flax.optim` API (removed in newer Flax releases) — confirm
        # the pinned flax version before upgrading.
        params = optimizer.target
        loss, grad = val_and_grad_fn(rng, model, params, x, marginal_prob_std)
        # Average gradients and loss across devices inside the pmap.
        mean_grad = jax.lax.pmean(grad, axis_name='device')
        mean_loss = jax.lax.pmean(loss, axis_name='device')
        new_optimizer = optimizer.apply_gradient(mean_grad)
        return mean_loss, new_optimizer
    # Data-parallel step: callers must pass arrays with a leading device axis.
    return jax.pmap(step_fn, axis_name='device')
|
|
|
|
| |
| import torch |
| import functools |
| import flax |
| from flax.serialization import to_bytes, from_bytes |
| import tensorflow as tf |
| from torch.utils.data import DataLoader |
| import torchvision.transforms as transforms |
| from torchvision.datasets import MNIST |
| import tqdm |
|
|
# ---- Training hyperparameters ----
n_epochs = 500
batch_size = 512
lr=1e-3

# Initialize model parameters with a dummy batch matching the data shape
# (28 x 313 single-channel inputs, per the reshape below).
rng = jax.random.PRNGKey(0)
fake_input = jnp.ones((batch_size, 28, 313, 1))
fake_time = jnp.ones(batch_size)
score_model = ScoreNet(marginal_prob_std_fn)
params = score_model.init({'params': rng}, fake_input, fake_time)

# drop_last=True keeps every batch the same size so the per-device reshape
# below always succeeds.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# NOTE(review): `flax.optim` is the legacy optimizer API (removed in newer
# Flax releases) — confirm the pinned flax version.
optimizer = flax.optim.Adam(learning_rate=lr).create(params)
train_step_fn = get_train_step_fn(score_model, marginal_prob_std_fn)
# NOTE(review): tqdm.notebook assumes a Jupyter environment.
tqdm_epoch = tqdm.notebook.trange(n_epochs)

# Each global batch is split evenly across the local devices for pmap.
assert batch_size % jax.local_device_count() == 0
data_shape = (jax.local_device_count(), -1, 28, 313, 1)

# Copy the optimizer state to every device.
optimizer = flax.jax_utils.replicate(optimizer)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for x in data_loader:
        # TensorDataset yields 1-tuples; unwrap the tensor.
        x = x[0]
        # (devices, per-device batch, H, W, C) layout expected by pmap.
        x = x.numpy().reshape(data_shape)
        # One fresh PRNG key per device.
        rng, *step_rng = jax.random.split(rng, jax.local_device_count() + 1)
        step_rng = jnp.asarray(step_rng)
        loss, optimizer = train_step_fn(step_rng, x, optimizer)
        loss = flax.jax_utils.unreplicate(loss)
        # NOTE(review): after the reshape, x.shape[0] is the device count,
        # not the batch size — the weight is constant across batches
        # (drop_last=True) so the average is still correct, but the
        # weighting is misleading; verify if batch sizes ever vary.
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Checkpoint the (unreplicated) optimizer state after every epoch.
    with tf.io.gfile.GFile('ckpt.flax', 'wb') as fout:
        fout.write(to_bytes(flax.jax_utils.unreplicate(optimizer)))
|
|