Spaces:
Build error
Build error
| import matplotlib.pyplot as plt | |
| import torch | |
| import torch.nn as nn | |
def plot_images(figure, imgs):
    """Draw ``imgs`` on a matplotlib grid of subplots with axes hidden.

    Args:
        figure: ``(h, w)`` pair. The grid has ``w`` rows and ``h`` columns and
            is filled column by column: image ``k`` goes to row ``k % w``,
            column ``k // w``.
        imgs: indexable collection with a ``.shape`` attribute whose first
            dimension must equal ``h * w``. Each ``imgs[k]`` is passed straight
            to ``Axes.imshow`` (presumably an ``(H, W)`` or ``(H, W, C)``
            array — confirm with callers).

    Raises:
        AssertionError: if the grid size does not match the number of images.
    """
    h, w = figure
    assert h * w == imgs.shape[0], "figure grid doesn't match imgs amount"
    # squeeze=False keeps ``axs`` a 2-D array even when h or w is 1; with the
    # default squeeze=True a single row/column collapses to 1-D (or a bare
    # Axes) and ``axs[row, col]`` would fail.
    _, axs = plt.subplots(w, h, squeeze=False)
    for img_index in range(h * w):
        col, row = divmod(img_index, w)
        ax = axs[row, col]
        ax.imshow(imgs[img_index])
        ax.axis('off')
def denoise_image(noised_image, predicted_noise, t, betas, alphas, alpha_bar, add_noise=None):
    """Perform one reverse (denoising) DDPM step from ``x_t`` to ``x_{t-1}``.

    Computes::

        mean = (x_t - eps * (1 - alpha_t) / sqrt(1 - alpha_bar_t)) / sqrt(alpha_t)
        x_{t-1} = mean + sqrt(beta_t) * z,   z ~ N(0, I)

    Args:
        noised_image: current sample ``x_t``.
        predicted_noise: the model's noise estimate ``eps`` for ``x_t``.
        t: step index used to index ``betas``, ``alphas`` and ``alpha_bar``
            (a scalar; a per-sample batch of steps is not supported).
        betas, alphas, alpha_bar: 1-D schedule tensors.
        add_noise: whether to add the ``sqrt(beta_t) * z`` term. ``None``
            (default) adds it only when ``t > 1``. As in DDPM sampling, the
            final step (t == 1, where the sampling loop in this module ends)
            returns the mean without fresh noise. Pass True/False to force.

    Returns:
        The sample ``x_{t-1}``.
    """
    mean = (noised_image - predicted_noise * ((1 - alphas[t]) / (1 - alpha_bar[t]).sqrt())) / alphas[t].sqrt()
    if add_noise is None:
        # NOTE: assumes a schedule indexed so that t == 1 is the last step.
        add_noise = t > 1
    if not add_noise:
        return mean
    z = torch.randn_like(noised_image)
    noise = betas.sqrt()[t] * z
    return mean + noise
class DDPM(nn.Module):
    """Forward (noising) process of a DDPM with a fixed beta schedule.

    Given a clean batch ``x`` and step indices ``t``, ``forward`` returns
    ``sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise`` together with
    the Gaussian ``noise`` that was used.
    """

    def __init__(self, betas):
        """Build the schedule from ``betas`` (a 1-D tensor of per-step betas)."""
        super(DDPM, self).__init__()
        alphas = 1.0 - betas
        # Registered as buffers so ``module.to(device)`` / ``.cuda()`` moves the
        # schedule with the module. persistent=False keeps them out of
        # state_dict(), so checkpoints look exactly as before (empty).
        self.register_buffer("betas", betas, persistent=False)
        self.register_buffer("alphas", alphas, persistent=False)
        self.register_buffer("alpha_bars", torch.cumprod(alphas, dim=0), persistent=False)

    def forward(self, x, t):
        """Noise ``x`` to step ``t``.

        Args:
            x: batch of images; ``alpha_bar_t`` is reshaped to ``(-1, 1, 1, 1)``,
                so ``x`` is expected to be 4-D (presumably ``(B, C, H, W)``).
            t: step index — an int, a sequence of ints, or an integer tensor of
                per-sample steps (on any device).

        Returns:
            ``(noised_image, noise)``, both shaped like ``x``.
        """
        device = x.device
        # Normalise t to a long tensor on the schedule's device. Indexing a CPU
        # tensor with CUDA indices raises, and int32 indices are not accepted
        # by every torch version.
        idx = torch.as_tensor(t, dtype=torch.long, device=self.alpha_bars.device)
        alpha_bar_t = self.alpha_bars[idx].view(-1, 1, 1, 1).to(device)
        noise = torch.randn_like(x)
        noised_image = torch.sqrt(alpha_bar_t) * x + torch.sqrt(1 - alpha_bar_t) * noise
        return noised_image, noise
def generate_img(model, sampler, betas, alpha, alpha_bar, batch_size, sampling_count,
                 context=None, device=None, img_shape=(3, 16, 16)):
    """Generate a batch of images by iteratively denoising with ``model``.

    The procedure is:

    1. Draw a uniform random image batch of shape ``(batch_size, *img_shape)``.
    2. Push it through ``sampler(x, t)`` at step ``sampling_count`` and take
       the first element of the result (e.g. ``DDPM.forward``, which returns
       ``(noised_image, noise)``).
    3. For ``t = sampling_count, ..., 1``, predict the noise with
       ``model(x, t_col, context)`` and apply ``denoise_image``.

    Args:
        model: noise-prediction network called as ``model(x, t_col, context)``.
            ``t_col`` is a ``(batch, 1)`` float32 tensor holding the raw
            (unnormalised) step. The model is left in eval mode.
        sampler: forward-noising callable, as described above.
        betas, alpha, alpha_bar: schedule tensors forwarded to ``denoise_image``.
        batch_size: number of images to generate.
        sampling_count: number of denoising steps. It is also the noise level
            the starting image is pushed to.
        context: per-sample integer labels. Defaults to all zeros.
        device: torch device. Defaults to CPU.
        img_shape: per-image shape. Defaults to ``(3, 16, 16)``.

    Returns:
        The final denoised batch.
    """
    if device is None:
        device = torch.device("cpu")
    model.eval()
    if context is None:
        context = [0] * batch_size
    # as_tensor avoids the copy-construct warning when context is already a tensor.
    context = torch.as_tensor(context, dtype=torch.int).to(device)
    with torch.no_grad():
        start_img = torch.rand((batch_size, *img_shape)).to(device)
        # Noise the start image to the same level the loop below starts
        # denoising from. Previously this was a hard-coded 200, which
        # mislabelled the noise level whenever sampling_count != 200.
        start_t = torch.full((batch_size,), sampling_count, dtype=torch.int)
        noised_img = sampler(start_img, start_t)[0]
        for t in range(sampling_count, 0, -1):
            t_col = torch.full((noised_img.shape[0], 1), float(t), dtype=torch.float32, device=device)
            predicted_noise = model(noised_img, t_col, context)
            noised_img = denoise_image(noised_img, predicted_noise, t, betas, alpha, alpha_bar)
    return noised_img