# diffusion_model/diffusion_sampler.py
# Commit ea54e12 ("compile sampling step with autograph") by leowajda.
import math
from collections.abc import Callable

import numpy as np
import tqdm as tqdm
import tensorflow as tf
from tensorflow import keras
from keras.models import load_model
def as_float32(t: tf.Tensor) -> tf.Tensor:
    """Return *t* cast to a float32 tensor."""
    return tf.cast(t, tf.float32)
def batch_reshape(t: tf.Tensor, x: tf.Tensor) -> Callable[[tf.Tensor], tf.Tensor]:
    """Build a gather-and-reshape helper for per-timestep coefficients.

    The returned function takes a 1-D tensor of per-timestep coefficients,
    gathers the entries at indices ``t`` (one index per batch element), and
    reshapes the result to ``[batch, 1, 1, 1]`` so it broadcasts against the
    image batch ``x``.

    BUG FIX: the original return annotation claimed ``tf.Tensor``, but the
    function actually returns a callable.

    Args:
        t: integer timestep indices, shape ``[batch]``.
        x: image batch whose leading dimension sizes the reshape.
    """
    def gather_for_batch(coeff: tf.Tensor) -> tf.Tensor:
        batch_dim = tf.shape(x)[0]
        return tf.reshape(tf.gather(coeff, t), [batch_dim, 1, 1, 1])

    return gather_for_batch
class DiffusionSampler(keras.Model):
    """Generates images from trained diffusion noise-prediction models.

    Precomputes the beta schedule and every derived per-timestep coefficient,
    and exposes DDPM and DDIM reverse-process sampling plus the forward
    (noising) process.
    """

    def __init__(
        self,
        model: keras.Model | str,
        ema_model: keras.Model | str,
        timesteps: int | None = 1_000,
        beta_start: float | None = 1e-4,
        beta_end: float | None = 0.02,
        noise_scheduler: str = "linear",
        ema: float = 0.999,
        **kwargs,
    ):
        """Build the sampler and precompute schedule coefficients.

        Args:
            model: noise-prediction model, or a filepath to load one from.
            ema_model: EMA copy of the model, or a filepath to load one from.
            timesteps: number of diffusion timesteps T.
            beta_start: first beta of the linear schedule.
            beta_end: last beta of the linear schedule.
            noise_scheduler: schedule name, "linear" or "cosine".
            ema: EMA decay rate (stored on the instance; not used here).
        """
        super().__init__(**kwargs)
        self.noise_predictor = load_model(filepath=model, safe_mode=False) if isinstance(model, str) else model
        # BUG FIX: the original checked isinstance(model, str) on this line,
        # so an ema_model given as a filepath was never loaded (and an
        # in-memory ema_model paired with a filepath model crashed load_model).
        self.ema_noise_predictor = (
            load_model(filepath=ema_model, safe_mode=False) if isinstance(ema_model, str) else ema_model
        )
        self.ema = ema
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps
        # Schedule quantities are computed in float64 for accuracy, then
        # stored as float32 to match the model's activations.
        betas = self.noise_scheduler(noise_scheduler)
        alphas = 1.0 - betas
        alphas_cum_prod = tf.math.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
        alphas_cum_prod_prev = tf.concat([tf.constant([1.0], dtype=tf.float64), alphas_cum_prod[:-1]], axis=0)
        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        posterior_variances = betas * (1.0 - alphas_cum_prod_prev) / (1.0 - alphas_cum_prod)
        self.betas = as_float32(betas)
        self.posterior_variances = as_float32(posterior_variances)
        self.alphas_cum_prod_prev = as_float32(alphas_cum_prod_prev)
        self.one_minus_alphas_cum_prod = as_float32(1.0 - alphas_cum_prod)
        self.one_minus_alphas_cum_prod_prev = as_float32(1.0 - alphas_cum_prod_prev)
        self.sqrt_one_minus_alphas_cum_prod = as_float32(tf.sqrt(1.0 - alphas_cum_prod))
        self.sqrt_alphas_cum_prod_prev = as_float32(tf.sqrt(alphas_cum_prod_prev))
        self.sqrt_alphas_cum_prod = as_float32(tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas_cum_prod = as_float32(1.0 / tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas = as_float32(tf.sqrt(1.0 / alphas))

    def ddpm_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor) -> tf.Tensor:
        """One DDPM reverse step: sample x_{t-1} from x_t and predicted noise.

        Args:
            pred_noise: the model's noise estimate for x_t, same shape as x_t.
            x_t: current noisy batch, shape [batch, H, W, C].
            t: per-example integer timesteps, shape [batch].
        """
        batch_dim = tf.shape(x_t)[0]
        at_timestep = batch_reshape(t, x_t)
        beta = at_timestep(self.betas)
        rev_sqrt_alpha = at_timestep(self.rev_sqrt_alphas)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        posterior_variance = at_timestep(self.posterior_variances)
        # Posterior mean under the epsilon parameterization.
        mean = rev_sqrt_alpha * (
            x_t - (beta / sqrt_one_minus_alpha_cum_prod) * pred_noise
        )
        # No noise is added on the final step (t == 0).
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), dtype=tf.float32), [batch_dim, 1, 1, 1]
        )
        random_noise = tf.random.normal(shape=x_t.shape, dtype=x_t.dtype)
        return mean + nonzero_mask * tf.sqrt(posterior_variance) * random_noise

    def ddim_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor, eta: float = 0.0) -> tf.Tensor:
        """One DDIM reverse step (deterministic when ``eta == 0``).

        Args:
            pred_noise: the model's noise estimate for x_t, same shape as x_t.
            x_t: current noisy batch, shape [batch, H, W, C].
            t: per-example integer timesteps, shape [batch].
            eta: stochasticity knob; 0 gives the deterministic DDIM update.
        """
        at_timestep = batch_reshape(t, x_t)
        sqrt_alpha_cum_prod_prev = at_timestep(self.sqrt_alphas_cum_prod_prev)
        rev_sqrt_alpha_cum_prod = at_timestep(self.rev_sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        alpha_cum_prod_prev = at_timestep(self.alphas_cum_prod_prev)
        one_minus_alpha_cum_prod = at_timestep(self.one_minus_alphas_cum_prod)
        one_minus_alpha_cum_prod_prev = at_timestep(self.one_minus_alphas_cum_prod_prev)
        # Predicted x_0 recovered from x_t and the noise estimate.
        x0_t = (
            (x_t - (sqrt_one_minus_alpha_cum_prod * pred_noise)) * rev_sqrt_alpha_cum_prod
        )
        # NOTE(review): the second factor below differs from the standard
        # DDIM sigma, eta * sqrt((1-a_prev)/(1-a)) * sqrt(1 - a/a_prev);
        # it is dead code at the default eta == 0 — confirm before relying
        # on eta > 0.
        c1 = eta * tf.sqrt(
            (one_minus_alpha_cum_prod_prev / one_minus_alpha_cum_prod) * (
                one_minus_alpha_cum_prod / alpha_cum_prod_prev)
        )
        x_t_dir = tf.sqrt(one_minus_alpha_cum_prod_prev - tf.square(c1))
        random_noise = tf.random.normal(shape=x_t.shape, dtype=x_t.dtype)
        return sqrt_alpha_cum_prod_prev * x0_t + x_t_dir * pred_noise + c1 * random_noise

    def noise_scheduler(self, scheduler: str, max_beta: float = 0.02) -> tf.Tensor:
        """Return the per-timestep beta schedule as a float64 tensor.

        Args:
            scheduler: "linear" or "cosine".
            max_beta: clamp applied to the cosine schedule's betas.
                (BUG FIX: was annotated ``int`` with a float default.)

        Raises:
            ValueError: if ``scheduler`` is not a known schedule name.
                (BUG FIX: the original silently returned None, deferring the
                failure to an opaque TypeError in __init__.)
        """
        if scheduler == "linear":
            betas = tf.linspace(start=self.beta_start, stop=self.beta_end, num=self.timesteps)
            return tf.cast(betas, dtype=tf.float64)
        if scheduler == "cosine":
            # Cosine schedule: beta_t = 1 - alpha_bar(t+1)/alpha_bar(t),
            # clamped to max_beta.
            alpha_bar = lambda t: tf.math.cos((t + 0.008) / 1.008 * tf.constant(math.pi, dtype=tf.float64) / 2) ** 2
            cosine_beta = lambda i: tf.minimum(
                1 - alpha_bar((i + 1) / tf.cast(self.timesteps, dtype=tf.float64))
                / alpha_bar(i / tf.cast(self.timesteps, dtype=tf.float64)),
                max_beta,
            )
            betas = tf.vectorized_map(fn=cosine_beta, elems=tf.range(self.timesteps, dtype=tf.float64))
            return tf.cast(betas, dtype=tf.float64)
        raise ValueError(f"unknown noise scheduler: {scheduler!r}")

    def x_t(self, x_start: tf.Tensor, t: tf.Tensor, noise: tf.Tensor) -> tf.Tensor:
        """Forward process: diffuse ``x_start`` to timestep ``t`` using ``noise``.

        Implements q(x_t | x_0) = sqrt(alpha_bar_t) * x_0
        + sqrt(1 - alpha_bar_t) * noise.
        """
        at_timestep = batch_reshape(t, x_start)
        sqrt_alpha_cum_prod = at_timestep(self.sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise

    @tf.function()
    def generate_images(
        self,
        num_images: int,
        steps: int,
        sample_strategy: str = "ddim",
        step_strategy: str = "uniform",
        ema: bool = True,
    ):
        """Generate ``num_images`` 64x64 RGB images, values in [0, 255].

        Args:
            num_images: batch size of generated images.
            steps: number of DDIM steps (ignored by DDPM, which always runs
                the full ``timesteps`` chain).
            sample_strategy: "ddpm" or "ddim".
            step_strategy: "uniform"/"linear" for evenly spaced timesteps,
                or "quadratic" for quadratically spaced DDIM timesteps.
            ema: use the EMA model when True, the raw model otherwise.
        """
        full_schedule = tf.range(self.timesteps, dtype=tf.float64)
        ddim_uniform = tf.range(steps, dtype=tf.float64)
        ddim_quadratic = tf.cast(
            tf.linspace(start=0.0, stop=tf.sqrt(self.timesteps * 0.8), num=steps) ** 2,
            dtype=tf.float64,
        )
        # BUG FIX: the default step_strategy "uniform" was not a key of this
        # table in the original, so calling with defaults raised KeyError.
        # "uniform" is now accepted as an alias of "linear".
        sampling_strategies = {
            ("ddpm", "linear"): (self.ddpm_sample, full_schedule),
            ("ddpm", "uniform"): (self.ddpm_sample, full_schedule),
            ("ddpm", "quadratic"): (self.ddpm_sample, full_schedule),
            ("ddim", "linear"): (self.ddim_sample, ddim_uniform),
            ("ddim", "uniform"): (self.ddim_sample, ddim_uniform),
            ("ddim", "quadratic"): (self.ddim_sample, ddim_quadratic),
        }
        noise_predictor = self.ema_noise_predictor if ema else self.noise_predictor
        sampler, seq = sampling_strategies[(sample_strategy, step_strategy)]
        # Start from pure Gaussian noise and denoise from high t to low t.
        samples = tf.random.normal(shape=(num_images, 64, 64, 3), dtype=tf.float32)
        for t in tf.reverse(seq, axis=[0]):
            tt = tf.cast(tf.fill(dims=(num_images,), value=t), dtype=tf.int64)
            pred_noise = noise_predictor([samples, tt], training=False)
            samples = sampler(pred_noise, samples, tt)
        # Map from the model's [-1, 1] range to pixel values in [0, 255].
        return tf.clip_by_value(samples * 127.5 + 127.5, 0.0, 255.0)