import warnings
from enum import Enum

import torch
import torch.nn as nn

try:
    from tqdm import trange
except ImportError:  # tqdm only drives an optional progress bar
    trange = None

Schedule = Enum('Schedule', ['LINEAR', 'COSINE'])


class DiffusionManager(nn.Module):
    """Wrap a noise-prediction model with DDPM-style noising and sampling.

    The wrapped ``model`` is called as ``model(images, timesteps, condition)``
    and its output is used as the predicted noise for ``images``.

    Args:
        model: Noise-prediction network.
        noise_steps: Number of diffusion steps (length of the beta schedule).
        start: First beta of the linear schedule; also used as the offset
            ``s`` of the cosine schedule.
        end: Last beta of the linear schedule.
        device: Device for the schedule, sampled noise and moved inputs.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, model: nn.Module, noise_steps=1000, start=0.0001, end=0.02,
                 device="cpu", **kwargs) -> None:
        super().__init__(**kwargs)
        self.model = model
        self.noise_steps = noise_steps
        self.start = start
        self.end = end
        self.device = device
        self.schedule = None
        self.set_schedule()

    def _get_schedule(self, schedule_type: Schedule = Schedule.LINEAR) -> torch.Tensor:
        """Return the beta schedule (``noise_steps`` values) for ``schedule_type``.

        Raises:
            ValueError: If ``schedule_type`` is not a known ``Schedule`` member.
        """
        if schedule_type == Schedule.LINEAR:
            return torch.linspace(self.start, self.end, self.noise_steps)
        if schedule_type == Schedule.COSINE:
            # Cosine schedule from https://arxiv.org/pdf/2102.09672 (page 4), cf.
            # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
            # NOTE(review): the paper uses s = 0.008; this reuses self.start.
            s = self.start

            def f(t):
                return torch.cos((t / self.noise_steps + s) / (1 + s) * torch.pi / 2) ** 2

            def alpha_hat_at(t):
                return f(t) / f(torch.zeros_like(t))

            t = torch.arange(self.noise_steps, dtype=torch.get_default_dtype())
            betas = 1 - alpha_hat_at(t + 1) / alpha_hat_at(t)
            # "In practice, we clip beta_t to be no larger than 0.999 to prevent
            # singularities at the end of the diffusion process."
            return betas.clamp(max=0.999)
        raise ValueError(f"Unknown schedule type: {schedule_type!r}")

    def set_schedule(self, schedule: Schedule = Schedule.LINEAR) -> None:
        """Rebuild ``self.schedule`` (the betas) and move it to ``self.device``."""
        self.schedule = self._get_schedule(schedule).to(self.device)

    def get_schedule_at(self, step):
        """Return ``(beta, alpha, alpha_hat)`` at ``step``, each viewed as ``(-1, 1, 1, 1)``.

        ``step`` may be a Python int or an integer index tensor. ``alpha`` is
        ``1 - beta`` and ``alpha_hat`` is the cumulative product of ``alpha``.
        """
        beta = self.schedule
        alpha = 1 - beta
        alpha_hat = torch.cumprod(alpha, dim=0)
        return (
            self._unsqueezify(beta[step]),
            self._unsqueezify(alpha[step]),
            self._unsqueezify(alpha_hat[step]),
        )

    @staticmethod
    def _unsqueezify(value):
        """Reshape ``value`` to ``(-1, 1, 1, 1)`` so it broadcasts over 4-D image batches."""
        return value.view(-1, 1, 1, 1)

    def noise_image(self, image, step):
        """Diffuse ``image`` to ``step``: ``sqrt(a_hat) * x + sqrt(1 - a_hat) * eps``.

        Returns:
            ``(noised_image, epsilon)`` where ``epsilon`` is the standard-normal
            noise that was mixed in.
        """
        image = image.to(self.device)
        _, _, alpha_hat = self.get_schedule_at(step)
        epsilon = torch.randn_like(image)
        noised_img = torch.sqrt(alpha_hat) * image + torch.sqrt(1 - alpha_hat) * epsilon
        return noised_img, epsilon

    def random_timesteps(self, amt=1):
        """Draw ``amt`` integer timesteps uniformly from ``[1, noise_steps)``.

        Step 0 is never drawn, matching the sampler, which stops at step 1.
        """
        return torch.randint(low=1, high=self.noise_steps, size=(amt,))

    def _step_range(self, use_tqdm):
        """Iterate timesteps ``noise_steps - 1`` down to 1, with a progress bar if available."""
        if use_tqdm and trange is not None:
            return trange(self.noise_steps - 1, 0, -1, leave=False, dynamic_ncols=True)
        return range(self.noise_steps - 1, 0, -1)

    def _reverse_diffusion(self, img_size, condition, amt, use_tqdm, channels=3):
        """Run the reverse process from pure noise and return ``amt`` images.

        The model is put in eval mode for the loop and its previous
        train/eval mode is restored afterwards, even if the loop raises.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                cur_img = torch.randn((amt, channels, img_size, img_size), device=self.device)
                for i in self._step_range(use_tqdm):
                    timestep = torch.full((amt,), float(i), device=self.device)
                    predicted_noise = self.model(cur_img, timestep, condition)
                    beta, alpha, alpha_hat = self.get_schedule_at(i)
                    noise_coef = beta / torch.sqrt(1 - alpha_hat)
                    cur_img = (1 / torch.sqrt(alpha)) * (cur_img - noise_coef * predicted_noise)
                    # No fresh noise is injected on the final (i == 1) step.
                    if i > 1:
                        cur_img = cur_img + torch.sqrt(beta) * torch.randn_like(cur_img)
        finally:
            self.model.train(was_training)
        return cur_img

    def sample(self, img_size, condition, amt=5, use_tqdm=True, channels=3):
        """Generate ``amt`` images of size ``img_size`` x ``img_size`` for ``condition``.

        If ``condition`` has fewer than ``amt`` entries along dim 0, it is
        tiled along that dim and truncated to exactly ``amt`` entries, so a
        single condition is broadcast to every sample.

        Raises:
            ValueError: If ``condition`` has no entries along dim 0.
        """
        n_cond = condition.shape[0]
        if n_cond == 0:
            raise ValueError("condition must contain at least one entry")
        if n_cond < amt:
            reps = -(-amt // n_cond)  # ceil(amt / n_cond)
            condition = condition.repeat(reps, *([1] * (condition.dim() - 1)))[:amt]
        condition = condition.to(self.device)
        return self._reverse_diffusion(img_size, condition, amt, use_tqdm, channels)

    def sample_multicond(self, img_size, condition, use_tqdm=True, channels=3):
        """Generate one image per entry of ``condition`` along dim 0."""
        condition = condition.to(self.device)
        return self._reverse_diffusion(img_size, condition, condition.shape[0], use_tqdm, channels)

    def training_loop_iteration(self, optimizer, batch, label, criterion):
        """Run one optimisation step of the noise-prediction objective.

        Draws random timesteps, noises ``batch`` accordingly, has the model
        predict that noise given ``label``, and backpropagates
        ``criterion(real_noise, pred_noise)``. Gradients are cleared first so
        they do not accumulate across calls. NaNs are reported via
        ``warnings.warn`` but do not stop the step.

        Returns:
            The loss as a Python float.
        """
        optimizer.zero_grad()
        batch = batch.to(self.device)
        # For an nn.Embedding conditioner, labels may need `label = label.long()` here.
        label = label.to(self.device)
        timesteps = self.random_timesteps(batch.shape[0]).to(self.device)
        noisy_batch, real_noise = self.noise_image(batch, timesteps)
        if torch.isnan(noisy_batch).any() or torch.isnan(real_noise).any():
            warnings.warn("NaNs detected in the noisy batch or real noise", stacklevel=2)
        pred_noise = self.model(noisy_batch, timesteps, label)
        if torch.isnan(pred_noise).any():
            warnings.warn("NaNs detected in the predicted noise", stacklevel=2)
        # Argument order kept as (target, prediction); symmetric for MSE/L1 losses.
        loss = criterion(real_noise, pred_noise)
        if torch.isnan(loss).any():
            warnings.warn("NaNs detected in the loss", stacklevel=2)
        loss.backward()
        optimizer.step()
        return loss.item()