| | |
| | |
| |
|
| | |
| |
|
| | |
| |
|
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| | from torch.cuda.amp import autocast |
| |
|
| | import torchvision |
| | from torchvision.transforms import transforms |
| | from torch.utils.data import DataLoader |
| |
|
| | from torch.optim import Adam |
| |
|
| | from einops import rearrange, reduce, repeat |
| | import math |
| | from random import random |
| |
|
| | from collections import namedtuple |
| | from functools import partial |
| | from tqdm.auto import tqdm |
| | import logging |
| | import os |
| |
|
| | from PIL import Image |
| | from torchvision import utils |
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| | |
| |
|
| |
|
# Pair returned by GaussianDiffusion.model_predictions: the network output
# expressed both as a noise estimate and as an x_0 (clean sample) estimate.
ModelPrediction = namedtuple('ModelPrediction', 'pred_noise pred_x_start')
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
def exists(x):
    """Return True when *x* holds a value, i.e. is anything other than None."""
    if x is None:
        return False
    return True
| |
|
def default(val, d):
    """Return *val* unless it is None; otherwise fall back to *d*.

    A callable fallback is invoked to produce the value, so expensive defaults
    (e.g. freshly allocated tensors) are only built when actually needed.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
| |
|
| |
|
| | |
| |
|
| |
|
def cast_tuple(t, length = 1):
    """Return *t* unchanged if it is a tuple, else a tuple of *length* copies of it."""
    return t if isinstance(t, tuple) else (t,) * length
| |
|
| |
|
| | |
| |
|
| |
|
def divisible_by(numer, denom):
    """Return True when *numer* is an exact multiple of *denom*."""
    remainder = numer % denom
    return remainder == 0
| |
|
| |
|
| | |
| |
|
| |
|
def identity(t, *args, **kwargs):
    """No-op transform: hand back *t* untouched, ignoring any extra arguments.

    The extra positional/keyword parameters exist only so this function can be
    swapped in wherever a richer callable with more arguments is expected.
    """
    del args, kwargs
    return t
| |
|
| |
|
| | |
| |
|
| |
|
def cycle(dl):
    """Yield the items of *dl* endlessly, re-iterating it from the start each pass."""
    while True:
        yield from dl
| |
|
| |
|
| | |
| |
|
| |
|
def has_int_squareroot(num):
    """Return True if *num* is a perfect square, i.e. its square root is an integer.

    Uses exact integer arithmetic (``math.isqrt``). The float round-trip
    ``math.sqrt(num) ** 2 == num`` rejects large perfect squares such as
    ``(2**30 + 1) ** 2`` (both the int->float conversion and the squaring
    round) and accepts non-integral squares such as ``2.25``.

    Integral floats (e.g. ``16.0``) are still accepted. Negative integers
    raise ``ValueError`` (from ``math.isqrt``).
    """
    if isinstance(num, float):
        # a non-integral float can never be the square of an integer
        if not num.is_integer():
            return False
        num = int(num)
    root = math.isqrt(num)
    return root * root == num
| |
|
| |
|
| | |
| |
|
| |
|
def num_to_groups(num, divisor):
    """Split *num* into full chunks of size *divisor* plus one trailing partial chunk.

    For example ``num_to_groups(10, 4) == [4, 4, 2]``; no partial chunk is
    appended when *divisor* divides *num* evenly.
    """
    full_chunks, leftover = divmod(num, divisor)
    tail = [leftover] if leftover > 0 else []
    return [divisor] * full_chunks + tail
| |
|
| |
|
| | |
| |
|
| |
|
def convert_image_to_fn(img_type, image):
    """Return *image* in mode *img_type*, converting only when the mode differs."""
    if image.mode == img_type:
        return image
    return image.convert(img_type)
| |
|
| |
|
| | |
| |
|
| |
|
def extract(a, t, x_shape):
    """Pick ``a[t]`` for every batch element and shape it for broadcasting.

    ``a`` holds one value per timestep and ``t`` holds one timestep index per
    sample. The result has shape ``(batch, 1, 1, ...)`` with ``len(x_shape)``
    dimensions in total, so it multiplies a tensor of shape ``x_shape``
    element-wise.
    """
    batch_size = t.shape[0]
    picked = a.gather(-1, t)
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch_size, *trailing_ones)
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
def normalize_to_neg_one_to_one(img):
    """Affinely map values from [0, 1] to [-1, 1]."""
    return 2 * img - 1
| |
|
def unnormalize_to_zero_to_one(t):
    """Affinely map values from [-1, 1] back to [0, 1] (inverse of the normalizer)."""
    return (t + 1) / 2
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of scalar positions / timesteps.

    Maps a 1-D tensor of shape ``(batch,)`` to ``(batch, 2 * (dim // 2))``.
    The first half of the columns are sines and the second half cosines, over
    ``dim // 2`` geometrically spaced frequencies running from 1 down to
    ``1 / theta``.
    """

    def __init__(self, dim, theta = 10000):
        super().__init__()
        self.dim = dim
        self.theta = theta

    def forward(self, x):
        n_freqs = self.dim // 2
        log_step = math.log(self.theta) / (n_freqs - 1)
        freqs = torch.exp(torch.arange(n_freqs, device = x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]          # (batch, n_freqs)
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
| |
|
| |
|
| | |
| |
|
| |
|
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal time embedding with random (frozen) or learned frequencies.

    Follows @crowsonkb's random (learned optional) sinusoidal pos emb:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    Takes a 1-D tensor of shape ``(batch,)`` and returns ``(batch, dim + 1)``:
    the raw input, then sin and cos of ``2 * pi * x * w`` for ``dim // 2``
    frequencies ``w`` drawn from a standard normal. With ``is_random=True``
    the frequencies are frozen (``requires_grad=False``).
    """

    def __init__(self, dim, is_random = False):
        super().__init__()
        assert dim % 2 == 0
        self.weights = nn.Parameter(torch.randn(dim // 2), requires_grad = not is_random)

    def forward(self, x):
        times = x[:, None]                                   # (batch,) -> (batch, 1)
        freqs = times * self.weights[None, :] * 2 * math.pi  # (batch, dim // 2)
        return torch.cat((times, freqs.sin(), freqs.cos()), dim = -1)
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
def linear_beta_schedule(timesteps):
    """Linearly spaced betas, as proposed in the original DDPM paper.

    The endpoints 1e-4 and 0.02 refer to a 1000-step process. Both are scaled by
    ``1000 / timesteps``, so other step counts stay comparable to that setting.
    Returns a float64 tensor of length *timesteps*.
    """
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype = torch.float64)
| |
|
| |
|
| | |
| |
|
| |
|
def cosine_beta_schedule(timesteps, s = 0.008):
    """Cosine schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.

    Defines ``alpha_bar(t) ∝ cos^2(((t + s) / (1 + s)) * pi / 2)`` on
    ``timesteps + 1`` evenly spaced points in [0, 1]. It derives each beta
    from the ratio of consecutive ``alpha_bar`` values and clips betas to
    [0, 0.999]. Returns a float64 tensor of length *timesteps*.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype = torch.float64) / timesteps
    f = torch.cos((grid + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_bar = f / f[0]
    betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
    return betas.clip(0, 0.999)
| |
|
| |
|
| | |
| |
|
| |
|
def sigmoid_beta_schedule(timesteps, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
    """Sigmoid schedule, proposed in https://arxiv.org/abs/2212.11972 (Figure 8).

    Better for images > 64x64 when used during training. ``alpha_bar`` follows
    a rescaled, inverted sigmoid over ``[start, end] / tau``. Betas are derived
    from consecutive ratios and clipped to [0, 0.999]. Returns a float64
    tensor of length *timesteps*.

    Note: ``clamp_min`` is accepted but not used.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype = torch.float64) / timesteps
    v_start, v_end = (torch.tensor(bound / tau).sigmoid() for bound in (start, end))
    logits = (grid * (end - start) + start) / tau
    alphas_bar = (v_end - logits.sigmoid()) / (v_end - v_start)
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
    return betas.clip(0, 0.999)
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
class GaussianDiffusion(nn.Module):
    """Gaussian diffusion (DDPM training, DDPM/DDIM sampling) around a denoiser.

    ``model`` is called as ``model(x, t, x_self_cond)`` and must expose the
    attributes ``channels``, ``out_dim`` and ``self_condition``. Calling this
    module on a batch of square images returns the scalar training loss. With
    ``auto_normalize`` the inputs are expected in [0, 1] and are mapped to
    [-1, 1]. :meth:`sample` generates new images.

    Args:
        model: denoising network.
        image_size: side length of the square images.
        timesteps: number of diffusion steps T.
        sampling_timesteps: number of sampling steps. Fewer than ``timesteps``
            selects DDIM sampling; ``None`` means ancestral sampling over all steps.
        objective: network target, one of 'pred_noise', 'pred_x0', 'pred_v'.
        beta_schedule: 'linear', 'cosine' or 'sigmoid'.
        schedule_fn_kwargs: extra keyword arguments for the schedule function
            (``None`` means none).
        ddim_sampling_eta: DDIM noise scale (0 gives a deterministic sampler).
        auto_normalize: map inputs [0, 1] -> [-1, 1] and samples back.
        offset_noise_strength: scale of per-(sample, channel) offset noise
            added during training.
        min_snr_loss_weight: clip the SNR used for loss weighting at
            ``min_snr_gamma``.
        min_snr_gamma: clipping value for the min-SNR loss weighting.
    """

    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_noise',
        beta_schedule = 'linear',
        schedule_fn_kwargs = None,
        ddim_sampling_eta = 0.,
        auto_normalize = True,
        offset_noise_strength = 0.,
        min_snr_loss_weight = False,
        min_snr_gamma = 5
    ):
        super().__init__()
        # the base class needs a network whose output has as many channels as its input
        assert not (type(self) is GaussianDiffusion and model.channels != model.out_dim)
        assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond

        self.model = model

        self.channels = self.model.channels
        self.self_condition = self.model.self_condition

        self.image_size = image_size

        self.objective = objective

        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        # None default instead of a shared mutable `dict()` default
        schedule_fn_kwargs = default(schedule_fn_kwargs, dict)
        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 for the first step
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling: fewer sampling steps than training steps switches to DDIM
        self.sampling_timesteps = default(sampling_timesteps, timesteps)

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # schedules are computed in float64 (by the schedule functions) and stored as float32 buffers
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # coefficients for the forward process q(x_t | x_0) and its inversions
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        register_buffer('posterior_variance', posterior_variance)

        # posterior_variance[0] is 0 (alphas_cumprod_prev[0] == 1), so clamp before taking the log
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        self.offset_noise_strength = offset_noise_strength

        # loss weighting: SNR = alpha_bar / (1 - alpha_bar), optionally clipped at min_snr_gamma,
        # re-expressed for each prediction target (pred_noise gets weight 1 when unclipped)
        snr = alphas_cumprod / (1 - alphas_cumprod)

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)

        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

        self.normalize = normalize_to_neg_one_to_one if auto_normalize else identity
        self.unnormalize = unnormalize_to_zero_to_one if auto_normalize else identity

    @property
    def device(self):
        """Device holding the schedule buffers."""
        return self.betas.device

    def predict_start_from_noise(self, x_t, t, noise):
        """x_0 = sqrt(1 / a_bar) * x_t - sqrt(1 / a_bar - 1) * eps."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Inverse of predict_start_from_noise: solve for eps given x_t and x_0."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        """v target: v = sqrt(a_bar) * eps - sqrt(1 - a_bar) * x_0."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        """x_0 = sqrt(a_bar) * x_t - sqrt(1 - a_bar) * v."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
        """Run the network and express its output as (noise, x_0) estimates.

        With ``clip_x_start`` the x_0 estimate is clamped to [-1, 1]. For
        'pred_noise', ``rederive_pred_noise`` recomputes the noise from the
        clamped x_0 so both estimates stay consistent.
        """
        model_output = self.model(x, t, x_self_cond)
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
        """Parameters of p(x_{t-1} | x_t) plus the (optionally clamped) x_0 estimate."""
        preds = self.model_predictions(x, t, x_self_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.inference_mode()
    def p_sample(self, x, t: int, x_self_cond = None):
        """One ancestral step from noise level t to t - 1 (no noise added at t == 0)."""
        b, *_, device = *x.shape, self.device
        batched_times = torch.full((b,), t, device = device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
        noise = torch.randn_like(x) if t > 0 else 0.
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.inference_mode()
    def p_sample_loop(self, shape, return_all_timesteps = False):
        """Ancestral (DDPM) sampling over all ``num_timesteps`` steps from pure noise."""
        device = self.device

        img = torch.randn(shape, device = device)
        imgs = [img]

        x_start = None

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, t, self_cond)
            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

        ret = self.unnormalize(ret)
        return ret

    @torch.inference_mode()
    def ddim_sample(self, shape, return_all_timesteps = False):
        """DDIM sampling over ``sampling_timesteps`` evenly spaced steps."""
        batch, device = shape[0], self.device
        total_timesteps, sampling_timesteps, eta = self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta

        # times run from T - 1 down to -1; -1 marks the final jump straight to x_0
        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))

        img = torch.randn(shape, device = device)
        imgs = [img]

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            self_cond = x_start if self.self_condition else None
            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True, rederive_pred_noise = True)

            if time_next < 0:
                img = x_start
                imgs.append(img)
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

        ret = self.unnormalize(ret)
        return ret

    @torch.inference_mode()
    def sample(self, batch_size = 16, return_all_timesteps = False):
        """Generate ``batch_size`` images of shape (channels, image_size, image_size).

        With ``return_all_timesteps`` every intermediate state is stacked along dim 1.
        """
        image_size, channels = self.image_size, self.channels
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
        return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)

    @torch.inference_mode()
    def interpolate(self, x1, x2, t = None, lam = 0.5):
        """Noise x1 and x2 to level t, mix them with weight ``lam``, then denoise to level 0.

        ``t`` defaults to ``num_timesteps - 1``. Inputs are used as given
        (no normalization), and the result is not unnormalized.
        """
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.full((b,), t, device = device, dtype = torch.long)
        xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2

        x_start = None

        # img is at noise level t and p_sample(img, i) steps from level i to i - 1,
        # so denoising must start at i = t (not t - 1) and run down to i = 0
        for i in tqdm(reversed(range(0, t + 1)), desc = 'interpolation sample time step', total = t + 1):
            self_cond = x_start if self.self_condition else None
            img, x_start = self.p_sample(img, i, self_cond)

        return img

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        """Draw x_t ~ q(x_t | x_0) = N(sqrt(a_bar) x_0, (1 - a_bar) I), with autocast disabled."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
        """Per-timestep weighted MSE training loss, averaged into a scalar.

        ``x_start`` must be a (b, c, h, w) batch already in model space. A
        caller-supplied ``noise`` tensor is never modified.
        """
        b, c, h, w = x_start.shape  # also enforces a 4-D (b, c, h, w) input

        noise = default(noise, lambda: torch.randn_like(x_start))

        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)

        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device = self.device)
            # out-of-place on purpose: `noise` may be a tensor owned by the caller
            noise = noise + offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        x = self.q_sample(x_start = x_start, t = t, noise = noise)

        # self-conditioning: half of the time, feed the model its own (detached) x_0 estimate
        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.no_grad():
                x_self_cond = self.model_predictions(x, t).pred_x_start
                x_self_cond.detach_()

        model_out = self.model(x, t, x_self_cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img, *args, **kwargs):
        """Training step: sample a random timestep per image and return the scalar loss."""
        b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = self.normalize(img)
        return self.p_losses(img, t, *args, **kwargs)
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Stride-1 Conv2d padded by ``kernel_size // 2`` (size-preserving for odd kernels)."""
    pad = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad, bias=bias)
| |
|
| |
|
| | |
| |
|
| |
|
class Swish(nn.Module):
    """Swish activation: scales each input by its own sigmoid, x * sigmoid(x)."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
| |
|
| |
|
| | |
| |
|
| |
|
class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Every pixel of the (B, C, H, W) feature map attends to every other pixel.
    It uses 1x1-conv query/key/value projections of the GroupNorm-ed input.
    The attended features pass through a final 1x1 conv and are added back to
    the input. ``in_ch`` must be divisible by 32 (GroupNorm with 32 groups).
    """

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        # built and registered in the order q, k, v, output projection
        self.proj_q, self.proj_k, self.proj_v, self.proj = (
            nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) for _ in range(4)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        normed = self.group_norm(x)
        query = self.proj_q(normed).permute(0, 2, 3, 1).view(B, H * W, C)  # (B, HW, C)
        key = self.proj_k(normed).view(B, C, H * W)                        # (B, C, HW)
        value = self.proj_v(normed).permute(0, 2, 3, 1).view(B, H * W, C)  # (B, HW, C)

        scores = torch.bmm(query, key) * (int(C) ** (-0.5))
        assert list(scores.shape) == [B, H * W, H * W]
        weights = F.softmax(scores, dim=-1)

        attended = torch.bmm(weights, value)
        assert list(attended.shape) == [B, H * W, C]
        attended = attended.view(B, H, W, C).permute(0, 3, 1, 2)

        return x + self.proj(attended)
| |
|
| |
|
| | |
| |
|
| |
|
class ResBlock(nn.Module):
    """Timestep-conditioned residual block with optional self-attention.

    Computes ``attn(block2(block1(x) + proj(temb)) + shortcut(x))``. The
    timestep embedding ``temb`` of width ``tdim`` is projected to ``out_ch``
    channels and added as a per-channel bias. The shortcut is a 1x1 conv when
    the channel count changes. Both channel counts must be divisible by 32
    (GroupNorm with 32 groups).
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(Swish(), nn.Linear(tdim, out_ch))
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) if in_ch != out_ch else nn.Identity()
        )
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()

    def forward(self, x, temb):
        time_bias = self.temb_proj(temb)[:, :, None, None]  # (B, out_ch) -> (B, out_ch, 1, 1)
        hidden = self.block2(self.block1(x) + time_bias)
        return self.attn(hidden + self.shortcut(x))
| |
|
| |
|
| | |
| |
|
| |
|
class EDSR(nn.Module):
    """EDSR-style residual denoiser conditioned on the diffusion timestep.

    A conv head lifts the input to ``n_feats`` channels. A body of
    timestep-conditioned ResBlocks, followed by one conv, refines it under a
    global skip connection. A conv tail maps back to ``out_dim`` channels.
    Spatial size is preserved throughout.

    Args:
        resblocks: sequence of block names. 'ResBlock' is a plain residual
            block; 'AttnBlock' is a residual block followed by self-attention.
        n_feats: feature width; must be divisible by 32 (GroupNorm groups).
        t_dim: width of the timestep embedding MLP.
        dropout: dropout probability inside each residual block.
        channels: input image channels.
        out_dim: output channels.
        self_condition: if True, ``forward`` concatenates ``cond`` (zeros when
            absent) to the input along the channel axis, so the head takes
            ``2 * channels`` input channels.
        learned_sinusoidal_cond / random_fourier_features / learned_sinusoidal_dim:
            use a random or learned Fourier time embedding instead of the fixed
            sinusoidal one.
        sinusoidal_pos_emb_theta: base of the fixed sinusoidal embedding.
        conv: factory ``conv(in_ch, out_ch, kernel_size)`` for plain convolutions.

    Raises:
        NotImplementedError: if ``resblocks`` contains an unknown block name.
    """

    def __init__(self,
                 resblocks=('ResBlock', 'ResBlock', 'ResBlock', 'AttnBlock', 'AttnBlock', 'ResBlock', 'ResBlock', 'ResBlock'),
                 n_feats=128,
                 t_dim=256,
                 dropout=0.1,
                 channels=1,
                 out_dim=1,
                 self_condition = False,
                 learned_sinusoidal_cond=False,
                 random_fourier_features=False,
                 learned_sinusoidal_dim=16,
                 sinusoidal_pos_emb_theta=10000,
                 conv=default_conv):
        super().__init__()

        # immutable default + private copy: the stored config never aliases a shared list
        self.resblocks = list(resblocks)
        self.n_feats = n_feats
        self.t_dim = t_dim
        self.dropout = dropout
        self.channels = channels
        self.out_dim = out_dim
        self.self_condition = self_condition
        self.kernel_size = 3

        # timestep embedding: sinusoidal / Fourier features followed by a 2-layer MLP
        if learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
            fourier_dim = learned_sinusoidal_dim + 1  # the raw time value is concatenated to the features
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim=self.n_feats, theta=sinusoidal_pos_emb_theta)
            fourier_dim = self.n_feats

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, self.t_dim),
            nn.GELU(),
            nn.Linear(self.t_dim, self.t_dim)
        )

        # with self-conditioning the conditioning image is stacked onto the input channels
        input_channels = self.channels * (2 if self_condition else 1)
        self.head = conv(input_channels, self.n_feats, self.kernel_size)

        self.body = nn.ModuleList()
        for block in self.resblocks:
            if block == "ResBlock":
                self.body.append(
                    ResBlock(in_ch=self.n_feats,
                             out_ch=self.n_feats,
                             tdim=self.t_dim,
                             dropout=self.dropout,
                             attn=False))
            elif block == "AttnBlock":
                self.body.append(
                    ResBlock(in_ch=self.n_feats,
                             out_ch=self.n_feats,
                             tdim=self.t_dim,
                             dropout=self.dropout,
                             attn=True))
            else:
                raise NotImplementedError("Model currently doesn't support this kind of block!")
        self.body.append(conv(self.n_feats, self.n_feats, self.kernel_size))

        self.tail = conv(self.n_feats, self.out_dim, self.kernel_size)

    def forward(self, x, t, cond=None):
        """Denoise ``x`` at timesteps ``t``.

        ``cond`` is the self-conditioning input (e.g. a previous x_0 estimate).
        It is only used when ``self_condition`` is True, with zeros substituted
        when it is None.
        """
        if self.self_condition:
            cond = default(cond, lambda: torch.zeros_like(x))
            x = torch.cat((cond, x), dim=1)

        t = self.time_mlp(t)

        x = self.head(x)

        res = x
        for block in self.body:
            if isinstance(block, ResBlock):
                res = block(res, t)
            else:
                res = block(res)
        res += x  # global residual over the whole body

        return self.tail(res)
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| |
|
| | |
# Output locations: checkpoints go to save_path; the text log and sample grids go to log_path.
save_path = '/content/DDPM_ResNet_Unet/resnet/model'
log_path = '/content/DDPM_ResNet_Unet/resnet/log'

# makedirs also creates missing parent directories (os.mkdir raised FileNotFoundError
# when /content/DDPM_ResNet_Unet/resnet did not exist); exist_ok keeps reruns working.
os.makedirs(log_path, exist_ok=True)
os.makedirs(save_path, exist_ok=True)
| |
|
| |
|
| | |
| |
|
| |
|
| | |
| |
|
| | |
# Root logger: DEBUG and above go to <log_path>/info.log, truncated on every run.
# force=True replaces any handlers installed on the root logger earlier.
logging.basicConfig(
    level=logging.DEBUG,
    format='[%(asctime)s] %(levelname)s - %(message)s',
    datefmt='%H:%M:%S',
    filename=os.path.join(log_path, 'info.log'),
    filemode="w",
    force=True,
)

# Raise the 'PIL' logger's threshold to INFO so its DEBUG records stay out of the log file.
pil_logger = logging.getLogger('PIL')
pil_logger.setLevel(logging.INFO)

# Mirror INFO-and-above records to the console (a StreamHandler writes to stderr).
console = logging.StreamHandler()
console.setLevel(logging.INFO)
root_logger = logging.getLogger()
root_logger.addHandler(console)

logger = logging.getLogger('Diffusion_Resnet')
| |
|
| |
|
| | |
| |
|
| |
|
| | |
# Denoising backbone: 10 timestep-conditioned residual blocks (4 followed by
# self-attention) at 256 channels, for single-channel images.
edsr_config = dict(
    resblocks=['ResBlock', 'ResBlock', 'ResBlock',
               'AttnBlock', 'AttnBlock', 'AttnBlock', 'AttnBlock',
               'ResBlock', 'ResBlock', 'ResBlock'],
    n_feats=256,
    t_dim=512,
    dropout=0.1,
    channels=1,
    out_dim=1,
    learned_sinusoidal_cond=False,
    random_fourier_features=False,
    learned_sinusoidal_dim=16,
    sinusoidal_pos_emb_theta=10000,
)
model = EDSR(**edsr_config)

# 1000-step diffusion with a linear beta schedule on 28x28 images.
# sampling_timesteps=None means ancestral sampling over all 1000 steps (no DDIM).
diffusion_config = dict(
    image_size=28,
    timesteps=1000,
    sampling_timesteps=None,
    objective='pred_noise',
    beta_schedule='linear',
    schedule_fn_kwargs=dict(),
    ddim_sampling_eta=0.,
    auto_normalize=True,
    offset_noise_strength=0.,
    min_snr_loss_weight=False,
    min_snr_gamma=5,
)
diffusion_model = GaussianDiffusion(model, **diffusion_config)
| |
|
| |
|
| | |
| |
|
| |
|
| | |
# MNIST training split; ToTensor is the only preprocessing step.
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root='.',
    train=True,
    download=True,
    transform=transform,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)
| | |
| |
|
| |
|
| | |
| |
|
| |
|
| | |
# Adam over every parameter of the diffusion wrapper (including the wrapped network).
train_lr = 1e-4
adam_betas = (0.9, 0.99)
optimizer = Adam(
    params=diffusion_model.parameters(),
    betas=adam_betas,
    lr=train_lr,
)
| |
|
| |
|
| | |
| |
|
| |
|
| | |
# Train on CUDA when a GPU is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
| |
|
| |
|
| | |
| |
|
| |
|
| | |
max_epoches = 50
iter_print = 100     # log the loss every `iter_print` optimizer steps
iter_sample = 1000   # write a 4x4 grid of samples every `iter_sample` steps
save_each = 1        # write a checkpoint every `save_each` epochs

diffusion_model = diffusion_model.to(device)

# Set to a checkpoint written below (e.g. os.path.join(save_path, "epoch_3.pth")) to resume.
last_trained_path = None
if last_trained_path:
    # map_location lets a checkpoint saved on one device be resumed on another (e.g. GPU -> CPU)
    data = torch.load(last_trained_path, map_location=device)
    diffusion_model.load_state_dict(data['model'])
    optimizer.load_state_dict(data['opt'])
    count = data['step']
    # the checkpoint is written after epoch data['epoch'] has finished, so continue with the next one
    start_epoch = data['epoch'] + 1
    log_loss = data['loss']
else:
    count = 0
    start_epoch = 1
    log_loss = []

for epoch in range(start_epoch, max_epoches + 1):
    diffusion_model.train()
    for img, _ in train_dataloader:
        img = img.to(device)

        loss = diffusion_model(img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # the diffusion loss is already a scalar mean; read it once for logging and history
        loss_value = loss.item()
        if count % iter_print == 0:
            logger.info('Epoch {}/{}, Iter {}: Loss = {}, lr = {}'.format(
                epoch,
                max_epoches,
                count,
                loss_value,
                train_lr,
            ))

        log_loss.append(loss_value)

        # release the loss tensor before the next forward pass
        loss = None

        count += 1

        if count % iter_sample == 0:
            diffusion_model.eval()

            sample_imgs = diffusion_model.sample(batch_size=16)

            utils.save_image(sample_imgs,
                             os.path.join(log_path, f"iter_{count}.png"),
                             nrow = int(math.sqrt(16)))

            # train() is otherwise only called at the start of each epoch: without this,
            # the rest of the epoch would train in eval mode with dropout disabled
            diffusion_model.train()

    if epoch % save_each == 0:
        data = {
            'model': diffusion_model.state_dict(),
            'opt': optimizer.state_dict(),
            'step': count,
            'epoch': epoch,
            'loss': log_loss,
        }

        torch.save(data, os.path.join(save_path, f"epoch_{epoch}.pth"))
| |
|
| |
|