# NeuralODE_SDE / Score-SDE / train-score-sde.py
# Committed by ibraheemmoosa
# Commit c096304: "Add Score-SDE training script."
import jax
import jax.numpy as jnp
from jax import random
import flax
import flax.linen as nn
from typing import Any, Tuple
import functools
import numpy as np
import torch
from torch.utils.data import TensorDataset
key = random.PRNGKey(0)
# Load every array stored in the .npz archive and stack them along a new
# leading (example) axis. NOTE(review): each entry is presumably one 2-D
# spectrogram of size 28x313 (the shape ScoreNet is initialised with below)
# — confirm against the data producer.
with np.load('spectograms.npz') as archive:
    dataset = np.stack([archive[name] for name in archive.files])
# Add a trailing singleton channel axis so examples are NHWC, as flax
# convolutions expect.
dataset = np.expand_dims(dataset, axis=3)
dataset = TensorDataset(torch.from_numpy(dataset))
# The following code is copied with minor modifications from https://colab.research.google.com/drive/1SeXMpILhkJPjXUaesvzEhc3Ke6Zl_zxJ?usp=sharing
class GaussianFourierProjection(nn.Module):
    """Embed a batch of scalar time steps with random Fourier features.

    Attributes:
      embed_dim: Requested embedding size; the output has
        2 * (embed_dim // 2) features — sine features first, then cosine.
      scale: Standard deviation of the normal initializer used to draw the
        random frequencies.
    """
    embed_dim: int
    scale: float = 30.

    @nn.compact
    def __call__(self, x):
        # The frequencies are created as a parameter (so they are drawn at
        # init time and saved with the parameter tree), but stop_gradient
        # gives them a zero gradient, so they receive no training signal.
        freqs = jax.lax.stop_gradient(
            self.param('W', jax.nn.initializers.normal(stddev=self.scale),
                       (self.embed_dim // 2, )))
        # Outer product of the (batch,) time steps with the frequencies,
        # scaled to radians: shape (batch, embed_dim // 2).
        phase = x[:, None] * freqs[None, :] * 2 * jnp.pi
        return jnp.concatenate([jnp.sin(phase), jnp.cos(phase)], axis=-1)
class Dense(nn.Module):
    """Linear layer whose output is reshaped to broadcast over a feature map.

    A (batch, features) input becomes (batch, 1, 1, output_dim), so it can be
    added to an NHWC activation. In ScoreNet the input is the time embedding.

    Attributes:
      output_dim: Number of output features (the channel count of the
        feature map the result is added to).
    """
    output_dim: int

    @nn.compact
    def __call__(self, x):
        projected = nn.Dense(self.output_dim)(x)
        # Insert two singleton spatial axes after the batch axis.
        return jnp.expand_dims(projected, axis=(1, 2))
class ScoreNet(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture.

    The convolution paddings are tuned for single-channel NHWC inputs of
    spatial size 28x313 (the `fake_input` shape used at init in this script);
    other spatial sizes may produce mismatched shapes at the skip-connection
    concatenations in the decoder.

    Args:
      marginal_prob_std: A function that takes time t and gives the standard
        deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
      channels: The number of channels for feature maps of each resolution.
      embed_dim: The dimensionality of Gaussian random feature embeddings.
    """
    marginal_prob_std: Any
    channels: Tuple[int, ...] = (32, 64, 128, 256)
    embed_dim: int = 256

    @nn.compact
    def __call__(self, x, t):
        """Estimate the score of perturbed inputs `x` at time steps `t`.

        Args:
          x: Perturbed inputs; (batch, 28, 313, 1) in this script.
          t: Time steps, one per example, shape (batch,).

        Returns:
          The network output, same shape as `x`, divided per example by
          `marginal_prob_std(t)`.
        """
        # The swish activation function
        act = nn.swish
        # Obtain the Gaussian random feature embedding for t
        embed = act(nn.Dense(self.embed_dim)(
            GaussianFourierProjection(embed_dim=self.embed_dim)(t)))
        # Encoding path (shape comments assume a (B, 28, 313, 1) input).
        # Each stage: conv -> add projected time embedding -> GroupNorm -> swish.
        h1 = nn.Conv(self.channels[0], (3, 3), (1, 1), padding='VALID',
                     use_bias=False)(x)
        # h1: (B, 26, 311, channels[0])
        ## Incorporate information from t
        h1 += Dense(self.channels[0])(embed)
        ## Group normalization (4 groups here; later stages use Flax's
        ## default num_groups)
        h1 = nn.GroupNorm(4)(h1)
        h1 = act(h1)
        # Stride-2 VALID convs roughly halve the spatial size at each stage.
        h2 = nn.Conv(self.channels[1], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h1)
        # h2: (B, 12, 155, channels[1])
        h2 += Dense(self.channels[1])(embed)
        h2 = nn.GroupNorm()(h2)
        h2 = act(h2)
        h3 = nn.Conv(self.channels[2], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h2)
        # h3: (B, 5, 77, channels[2])
        h3 += Dense(self.channels[2])(embed)
        h3 = nn.GroupNorm()(h3)
        h3 = act(h3)
        h4 = nn.Conv(self.channels[3], (3, 3), (2, 2), padding='VALID',
                     use_bias=False)(h3)
        # h4: (B, 2, 38, channels[3])
        h4 += Dense(self.channels[3])(embed)
        h4 = nn.GroupNorm()(h4)
        h4 = act(h4)
        # Decoding path. input_dilation=(2, 2) makes each conv a stride-2
        # transposed convolution (upsampling). The explicit, sometimes
        # asymmetric, paddings make each output match the spatial size of the
        # encoder feature map it is concatenated with next.
        h = nn.Conv(self.channels[2], (3, 3), (1, 1), padding=((2, 2), (2, 2)),
                    input_dilation=(2, 2), use_bias=False)(h4)
        # h: (B, 5, 77, channels[2]) -- matches h3
        ## Incorporate information from t
        h += Dense(self.channels[2])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)
        ## Skip connection from the encoding path: concatenate with h3
        h = nn.Conv(self.channels[1], (3, 3), (1, 1), padding=((2, 3), (2, 2)),
                    input_dilation=(2, 2), use_bias=False)(
                        jnp.concatenate([h, h3], axis=-1)
                    )
        # h: (B, 12, 155, channels[1]) -- matches h2
        h += Dense(self.channels[1])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)
        ## Skip connection: concatenate with h2
        h = nn.Conv(self.channels[0], (3, 3), (1, 1), padding=((2, 3), (2, 2)),
                    input_dilation=(2, 2), use_bias=False)(
                        jnp.concatenate([h, h2], axis=-1)
                    )
        # h: (B, 26, 311, channels[0]) -- matches h1
        h += Dense(self.channels[0])(embed)
        h = nn.GroupNorm()(h)
        h = act(h)
        ## Skip connection with h1, then project back to one channel
        ## (this final conv keeps its bias, unlike the others).
        h = nn.Conv(1, (3, 3), (1, 1), padding=((2, 2), (2, 2)))(
            jnp.concatenate([h, h1], axis=-1)
        )
        # h: (B, 28, 313, 1) -- same shape as the input x
        # Normalize output: divide each example by its perturbation std.
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h
def marginal_prob_std(t, sigma):
    """Return the standard deviation of the perturbation kernel p_{0t}(x(t) | x(0)).

    The variance is the integral of the squared diffusion coefficient
    sigma**(2 s) over s in [0, t], i.e. (sigma**(2 t) - 1) / (2 log sigma).

    Args:
      t: Time steps (array or scalar).
      sigma: The noise-scale constant sigma of the SDE.

    Returns:
      The standard deviation, elementwise in t (0 at t = 0).
    """
    variance = (sigma ** (2 * t) - 1.) / 2. / jnp.log(sigma)
    return jnp.sqrt(variance)
def diffusion_coeff(t, sigma):
    """Return the SDE's diffusion coefficient g(t) = sigma**t.

    Args:
      t: Time steps (array or scalar).
      sigma: The noise-scale constant sigma of the SDE.

    Returns:
      sigma raised to the power t, elementwise.
    """
    return pow(sigma, t)
# Noise-scale constant shared by diffusion_coeff (g(t) = sigma**t) and
# marginal_prob_std.
sigma = 25.0  #@param {'type':'number'}
# Bind sigma once so downstream code can pass only the time steps t.
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
def loss_fn(rng, model, params, x, marginal_prob_std, eps=1e-5):
    """Denoising score-matching loss for a mini-batch.

    Each example is perturbed with Gaussian noise whose std is
    marginal_prob_std(t) at a random time t drawn from U[eps, 1). The model's
    score, rescaled by that std, is regressed onto the negated noise.

    Args:
      rng: PRNG key; split once for the time steps, then again for the noise.
      model: A `flax.linen.Module` object that represents the structure of
        the score-based model.
      params: A dictionary that contains all trainable parameters.
      x: A mini-batch of training data, batch axis first, indexed as rank 4.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.
      eps: Smallest sampled time step; keeps t away from 0, where this
        file's marginal_prob_std is zero.

    Returns:
      Scalar loss: per-example squared error summed over axes 1-3, averaged
      over the batch.
    """
    key_after_t, t_key = jax.random.split(rng)
    random_t = jax.random.uniform(t_key, (x.shape[0],), minval=eps, maxval=1.)
    _, z_key = jax.random.split(key_after_t)
    z = jax.random.normal(z_key, x.shape)
    std = marginal_prob_std(random_t)
    std_b = std[:, None, None, None]  # (batch, 1, 1, 1) for broadcasting
    score = model.apply(params, x + z * std_b, random_t)
    per_example = jnp.sum((score * std_b + z)**2, axis=(1, 2, 3))
    return jnp.mean(per_example)
def get_train_step_fn(model, marginal_prob_std):
    """Build a pmap-ed function that performs one data-parallel training step.

    Args:
      model: A `flax.linen.Module` object that represents the structure of
        the score-based model.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.

    Returns:
      `step_fn(rng, x, optimizer) -> (loss, new_optimizer)`, mapped with
      `jax.pmap` over the leading device axis (axis name 'device'). The loss
      and gradients are averaged across devices before the update is applied.
    """
    # Differentiate loss_fn with respect to its third argument, `params`.
    val_and_grad_fn = jax.value_and_grad(loss_fn, argnums=2)

    def step_fn(rng, x, optimizer):
        # NOTE(review): relies on the legacy flax.optim Optimizer API
        # (`.target`, `.apply_gradient`), which newer Flax releases no longer
        # ship — pin Flax or port to optax.
        loss, grads = val_and_grad_fn(rng, model, optimizer.target, x,
                                      marginal_prob_std)
        # All-reduce so every replica applies the same averaged update.
        grads = jax.lax.pmean(grads, axis_name='device')
        loss = jax.lax.pmean(loss, axis_name='device')
        return loss, optimizer.apply_gradient(grads)

    return jax.pmap(step_fn, axis_name='device')
#@title Training (double click to expand or collapse)
import torch
import functools
import flax
from flax.serialization import to_bytes, from_bytes
import tensorflow as tf
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import tqdm
n_epochs = 500  #@param {'type':'integer'}
## size of a mini-batch
batch_size = 512  #@param {'type':'integer'}
## learning rate
lr = 1e-3  #@param {'type':'number'}

# Each pmap-ed step splits a global batch evenly over the local devices, so
# validate that up front. An explicit raise, unlike `assert`, is not stripped
# under `python -O`.
n_devices = jax.local_device_count()
if batch_size % n_devices != 0:
    raise ValueError(
        f'batch_size ({batch_size}) must be divisible by the number of local '
        f'devices ({n_devices}).')

rng = jax.random.PRNGKey(0)
# Give parameter init its own key rather than reusing `rng`, which is split
# below to derive the per-step keys.
rng, init_rng = jax.random.split(rng)
# Parameter shapes do not depend on the batch size, so one dummy example is
# enough for init and avoids a full-batch forward pass. 28x313 is the spatial
# size ScoreNet's paddings are tuned for.
fake_input = jnp.ones((1, 28, 313, 1))
fake_time = jnp.ones(1)
score_model = ScoreNet(marginal_prob_std_fn)
params = score_model.init({'params': init_rng}, fake_input, fake_time)

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
if len(data_loader) == 0:
    # With drop_last=True a dataset smaller than one batch yields no batches,
    # and the epoch average below would divide by zero.
    raise ValueError(
        f'dataset has {len(dataset)} examples, fewer than batch_size '
        f'({batch_size}); no full batch can be formed.')

# NOTE(review): flax.optim is the legacy optimizer API (removed in newer
# Flax releases); get_train_step_fn depends on it, so port both together.
optimizer = flax.optim.Adam(learning_rate=lr).create(params)
train_step_fn = get_train_step_fn(score_model, marginal_prob_std_fn)
# Console progress bar: tqdm.notebook needs a Jupyter frontend and is not
# imported by a bare `import tqdm`, while this file runs as a script.
tqdm_epoch = tqdm.trange(n_epochs)
data_shape = (n_devices, -1, 28, 313, 1)
# Replicate the optimizer state onto every local device for pmap.
optimizer = flax.jax_utils.replicate(optimizer)
for epoch in tqdm_epoch:
    avg_loss = 0.
    num_items = 0
    for batch in data_loader:
        x = batch[0]
        # Count examples before the reshape: afterwards the leading axis is
        # the device axis, not the example axis.
        n_examples = x.shape[0]
        x = x.numpy().reshape(data_shape)
        rng, *step_rng = jax.random.split(rng, n_devices + 1)
        step_rng = jnp.asarray(step_rng)
        loss, optimizer = train_step_fn(step_rng, x, optimizer)
        loss = flax.jax_utils.unreplicate(loss)
        avg_loss += loss.item() * n_examples
        num_items += n_examples
    # Print the averaged training loss so far.
    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Update the checkpoint after each epoch of training.
    with tf.io.gfile.GFile('ckpt.flax', 'wb') as fout:
        fout.write(to_bytes(flax.jax_utils.unreplicate(optimizer)))