|
|
import math |
|
|
|
|
|
import torch |
|
|
from einops import rearrange |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .utils import get_tensor_items |
|
|
|
|
|
|
|
|
def get_named_beta_schedule(schedule_name, timesteps):
    """Build the per-step noise schedule (betas) of a diffusion process.

    Args:
        schedule_name: ``"linear"`` or ``"cosine"``.
        timesteps: number of diffusion steps, i.e. the length of the result.

    Returns:
        A 1-D ``torch.float32`` tensor holding ``timesteps`` beta values.

    Raises:
        ValueError: if ``schedule_name`` is not one of the supported names.
    """
    if schedule_name == "linear":
        # The end points are defined for a 1000-step process (1e-4 .. 2e-2)
        # and rescaled by 1000 / timesteps for other step counts.
        scale = 1000 / timesteps
        beta_start = scale * 0.0001
        beta_end = scale * 0.02
        return torch.linspace(
            beta_start, beta_end, timesteps, dtype=torch.float32
        )

    if schedule_name == "cosine":
        def alpha_bar(t):
            # Squared-cosine cumulative signal level for t in [0, 1].
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at 0.999
        # so that no single step removes the whole signal.
        betas = [
            min(1 - alpha_bar((i + 1) / timesteps) / alpha_bar(i / timesteps), 0.999)
            for i in range(timesteps)
        ]
        return torch.tensor(betas, dtype=torch.float32)

    # Previously an unknown name fell through and returned None, which only
    # surfaced later as an unrelated AttributeError on ``betas.shape``.
    raise ValueError(f"unknown beta schedule: {schedule_name!r}")
|
|
|
|
|
|
|
|
class BaseDiffusion:
    """Gaussian diffusion helper: forward noising, posterior terms and sampling.

    All schedule-derived coefficients are precomputed in ``__init__`` as 1-D
    tensors with one entry per timestep. Per-sample values are then looked up
    with ``get_tensor_items(values, t, shape)`` from ``.utils``.
    NOTE(review): that helper is not visible here; presumably it selects
    ``values[t]`` for every sample and reshapes it to broadcast against
    ``shape`` - confirm in ``.utils``.
    """

    def __init__(self, betas, percentile=None, gen_noise=torch.randn_like):
        """Precompute the coefficient tables for a beta schedule.

        Args:
            betas: 1-D tensor of per-step betas. At least two entries are
                needed because ``posterior_variance[1]`` is read below.
            percentile: quantile used by ``process_x_start`` for per-sample
                clipping; ``None`` selects a fixed clip to [-1, 1].
            gen_noise: callable ``tensor -> noise`` used by ``q_sample`` when
                no explicit noise is given.
        """
        self.betas = betas
        self.num_timesteps = betas.shape[0]

        alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Shifted by one step with a leading 1. Created on the same device as
        # ``betas`` so that the concatenation also works for non-CPU schedules.
        leading_one = torch.ones(1, dtype=betas.dtype, device=betas.device)
        self.alphas_cumprod_prev = torch.cat([leading_one, self.alphas_cumprod[:-1]])

        # Coefficients of x_t = sqrt(a_bar) * x_0 + sqrt(1 - a_bar) * noise.
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # Coefficients of the posterior q(x_{t-1} | x_t, x_0).
        one_minus_cumprod = 1. - self.alphas_cumprod
        one_minus_cumprod_prev = 1. - self.alphas_cumprod_prev
        self.posterior_mean_coef_1 = torch.sqrt(self.alphas_cumprod_prev) * betas / one_minus_cumprod
        self.posterior_mean_coef_2 = torch.sqrt(alphas) * one_minus_cumprod_prev / one_minus_cumprod
        self.posterior_variance = betas * one_minus_cumprod_prev / one_minus_cumprod
        # posterior_variance[0] is exactly 0 (alphas_cumprod_prev[0] == 1), so
        # its log would be -inf; entry 1 is reused for position 0 instead.
        self.posterior_log_variance = torch.log(
            torch.cat([self.posterior_variance[1].unsqueeze(0), self.posterior_variance[1:]])
        )

        self.percentile = percentile
        # Factor applied to timesteps before they are passed to the model in
        # ``text_guidance``. Integer division: it is 0 for more than 1000 steps.
        self.time_scale = 1000 // self.num_timesteps
        self.gen_noise = gen_noise
        self.jump_length = 3

    def process_x_start(self, x_start):
        """Clip a predicted clean sample to a bounded range.

        With ``percentile`` set, each sample is clipped to ``[-q, q]`` and
        divided by ``q``, where ``q`` is that sample's ``percentile`` quantile
        of absolute values, floored at 1. Otherwise values are clipped to
        [-1, 1].
        """
        if self.percentile is None:
            return torch.clip(x_start, -1., 1.)

        bs, ndims = x_start.shape[0], len(x_start.shape[1:])
        flat_abs = rearrange(x_start, 'b ... -> b (...)').abs()
        quantile = torch.quantile(flat_abs, self.percentile, dim=-1)
        # Never shrink the range below [-1, 1].
        quantile = torch.clip(quantile, min=1.)
        # One threshold per sample, broadcast over the remaining dims.
        quantile = quantile.reshape(bs, *((1,) * ndims))
        return torch.clip(x_start, -quantile, quantile) / quantile

    def get_x_start(self, x, t, noise):
        """Recover x_0 from ``x`` (x_t) and ``noise`` by inverting the forward equation."""
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, noise.shape)
        return (x - sqrt_one_minus_alphas_cumprod * noise) / sqrt_alphas_cumprod

    def get_noise(self, x, t, x_start):
        """Recover the noise from ``x`` (x_t) and ``x_start`` by inverting the forward equation."""
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        return (x - sqrt_alphas_cumprod * x_start) / sqrt_one_minus_alphas_cumprod

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to timestep ``t``.

        Computes ``sqrt(a_bar_t) * x_start + sqrt(1 - a_bar_t) * noise``;
        ``noise`` is drawn with ``self.gen_noise(x_start)`` when not supplied.
        """
        if noise is None:
            noise = self.gen_noise(x_start)
        sqrt_alphas_cumprod = get_tensor_items(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod = get_tensor_items(self.sqrt_one_minus_alphas_cumprod, t, noise.shape)
        return sqrt_alphas_cumprod * x_start + sqrt_one_minus_alphas_cumprod * noise

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return mean, variance and log-variance of the posterior at ``t``.

        The mean is ``posterior_mean_coef_1 * x_start + posterior_mean_coef_2 * x_t``;
        variance and log-variance are looked up from the precomputed tables.
        """
        posterior_mean_coef_1 = get_tensor_items(self.posterior_mean_coef_1, t, x_start.shape)
        posterior_mean_coef_2 = get_tensor_items(self.posterior_mean_coef_2, t, x_t.shape)
        posterior_mean = posterior_mean_coef_1 * x_start + posterior_mean_coef_2 * x_t

        posterior_variance = get_tensor_items(self.posterior_variance, t, x_start.shape)
        posterior_log_variance = get_tensor_items(self.posterior_log_variance, t, x_start.shape)
        return posterior_mean, posterior_variance, posterior_log_variance

    def q_posterior_variance(self, t, prev_t, shape, eta=1., ):
        """Return the sampling noise scale for a step from ``t`` to ``prev_t``.

        Despite the name, the result is a square root: callers multiply it by
        the noise directly and square it where a variance is needed.
        NOTE(review): ``eta`` sits inside the square root, so the scale grows
        with sqrt(eta) rather than eta - confirm this is intended.
        """
        alphas_cumprod = get_tensor_items(self.alphas_cumprod, t, shape)
        prev_alphas_cumprod = get_tensor_items(self.alphas_cumprod, prev_t, shape)

        step_ratio = 1. - alphas_cumprod / prev_alphas_cumprod
        return torch.sqrt(
            eta * step_ratio * (1. - prev_alphas_cumprod) / (1. - alphas_cumprod)
        )

    def text_guidance(
        self, model, x, t, context, context_mask, null_embedding, guidance_weight_text,
        uncondition_context=None, uncondition_context_mask=None, mask=None, masked_latent=None
    ):
        """Predict noise with text guidance using one doubled-batch model call.

        The batch is duplicated: the first half is paired with ``context``,
        the second half with the unconditional context. The two predictions
        are combined as ``(w + 1) * cond - w * uncond`` where ``w`` is
        ``guidance_weight_text``.

        Args:
            model: callable ``model(x, t, context, context_mask)``; its
                ``in_layer.in_channels`` is read to detect the 9-channel input
                variant that also receives ``mask`` and ``masked_latent``.
            x: current noisy sample, batch-first, any rank.
            t: 1-D tensor of timesteps, one per sample.
            context, context_mask: conditional embeddings and their mask.
            null_embedding: written to position 0 of the unconditional context
                when ``uncondition_context`` is not supplied.
            guidance_weight_text: guidance weight ``w``.
            uncondition_context, uncondition_context_mask: optional explicit
                unconditional (negative) context and mask.
            mask, masked_latent: extra inputs, concatenated along dim 1 for
                the 9-channel variant.

        Returns:
            The guided noise prediction, one entry per sample of ``x``.
        """
        # Concatenation doubles the batch for any rank; the previous
        # ``repeat(2, 1, 1, 1)`` only worked for exactly 4-D tensors.
        large_x = torch.cat([x, x])
        large_t = t.repeat(2).to(x.dtype)

        if uncondition_context is None:
            # Default unconditional context: all zeros with the null embedding
            # as the single visible token.
            uncondition_context = torch.zeros_like(context)
            uncondition_context_mask = torch.zeros_like(context_mask)
            uncondition_context[:, 0] = null_embedding
            uncondition_context_mask[:, 0] = 1
        large_context = torch.cat([context, uncondition_context])
        large_context_mask = torch.cat([context_mask, uncondition_context_mask])

        if mask is not None:
            mask = torch.cat([mask, mask])
        if masked_latent is not None:
            masked_latent = torch.cat([masked_latent, masked_latent])

        if model.in_layer.in_channels == 9:
            large_x = torch.cat([large_x, mask, masked_latent], dim=1)

        pred_large_noise = model(large_x, large_t * self.time_scale, large_context, large_context_mask.bool())
        pred_noise, uncond_pred_noise = torch.chunk(pred_large_noise, 2)
        return (guidance_weight_text + 1.) * pred_noise - guidance_weight_text * uncond_pred_noise

    def p_mean_variance(
        self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Return the mean and noise scale of the step from ``t`` to ``prev_t``.

        The guided noise prediction is converted to a clean-sample estimate,
        clipped with ``process_x_start``, and the noise is re-derived from the
        clipped estimate so that both terms of the mean are consistent.

        Returns:
            ``(pred_mean, pred_var)`` where ``pred_var`` is the value returned
            by ``q_posterior_variance`` (a square-rooted scale).
        """
        pred_noise = self.text_guidance(
            model, x, t, context, context_mask, null_embedding, guidance_weight_text,
            negative_context, negative_context_mask, mask, masked_latent
        )

        pred_x_start = self.process_x_start(self.get_x_start(x, t, pred_noise))
        pred_noise = self.get_noise(x, t, pred_x_start)
        pred_var = self.q_posterior_variance(t, prev_t, x.shape, eta)

        prev_alphas_cumprod = get_tensor_items(self.alphas_cumprod, prev_t, x.shape)
        # The radicand can become negative (e.g. eta > 1, or rounding), which
        # would turn the whole sample into NaN; clip it at zero first.
        noise_coef = torch.sqrt(torch.clip(1. - prev_alphas_cumprod - pred_var ** 2, min=0.))
        pred_mean = torch.sqrt(prev_alphas_cumprod) * pred_x_start + noise_coef * pred_noise
        return pred_mean, pred_var

    def p_sample(
        self, model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None
    ):
        """Draw the sample at ``prev_t`` given ``x`` at ``t``.

        Fresh Gaussian noise scaled by the predicted noise scale is added to
        the predicted mean, except for samples whose ``prev_t`` is 0.
        NOTE(review): the noise comes from ``torch.randn_like``, not from
        ``self.gen_noise`` - confirm that is intended.
        """
        bs = x.shape[0]
        ndims = len(x.shape[1:])
        pred_mean, pred_var = self.p_mean_variance(
            model, x, t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta,
            negative_context=negative_context, negative_context_mask=negative_context_mask,
            mask=mask, masked_latent=masked_latent
        )
        noise = torch.randn_like(x)
        # Kept in its own local so the ``mask`` argument is not overwritten.
        add_noise = (prev_t != 0).reshape(bs, *((1,) * ndims))
        return pred_mean + add_noise * pred_var * noise

    def p_sample_loop(
        self, model, shape, times, device, context, context_mask, null_embedding, guidance_weight_text, eta=1.,
        negative_context=None, negative_context_mask=None, mask=None, masked_latent=None, gan=False,
    ):
        """Run the full sampling loop starting from Gaussian noise.

        Args:
            model: noise-prediction model (see ``text_guidance``).
            shape: shape of the generated batch; ``shape[0]`` is the batch size.
            times: sequence of timesteps to visit, in visiting order. A final
                0 is appended, so each step goes from ``times[i]`` to
                ``times[i + 1]`` and the last one to 0.
            device: device for the initial noise and the timestep tensors.
            gan: if true, each step re-noises the current image to the step's
                timestep with ``q_sample`` and replaces it by the model's
                clean-sample estimate, without guidance or ``p_sample``.
            Remaining arguments are forwarded to ``p_sample``.

        Returns:
            The final sample tensor of shape ``shape``.
        """
        img = torch.randn(*shape, device=device)
        # ``list(...)`` accepts any sequence (list, tuple, range, ...).
        schedule = list(times) + [0]
        steps = list(zip(schedule[:-1], schedule[1:]))

        for step_t, step_prev_t in tqdm(steps):
            cur_t = torch.tensor([step_t] * shape[0], device=device)
            if gan:
                x_t = self.q_sample(img, cur_t)
                pred_noise = model(x_t, cur_t.type(x_t.dtype), context, context_mask.bool())
                img = self.get_x_start(x_t, cur_t, pred_noise)
            else:
                prev_t = torch.tensor([step_prev_t] * shape[0], device=device)
                img = self.p_sample(
                    model, img, cur_t, prev_t, context, context_mask, null_embedding, guidance_weight_text, eta,
                    negative_context=negative_context, negative_context_mask=negative_context_mask,
                    mask=mask, masked_latent=masked_latent
                )
        return img
|
|
|
|
|
|
|
|
def get_diffusion(conf):
    """Create a ``BaseDiffusion`` from a configuration object.

    ``conf.schedule_params`` is unpacked as keyword arguments into
    ``get_named_beta_schedule`` to build the betas, and
    ``conf.diffusion_params`` is unpacked as the remaining keyword arguments
    of the ``BaseDiffusion`` constructor.
    """
    return BaseDiffusion(
        get_named_beta_schedule(**conf.schedule_params),
        **conf.diffusion_params
    )
|
|
|