# -*- coding: utf-8 -*- """Generalas.ipynb Automatically generated by Colaboratory. Original file is located at https://colab.research.google.com/drive/1xauVtAe5BhUoYH2JItzQtqvFTDxKDSyZ """ # Commented out IPython magic to ensure Python compatibility. import gradio as gr import math from inspect import isfunction from functools import partial # %matplotlib inline import matplotlib.pyplot as plt from tqdm.auto import tqdm from einops import rearrange import torch from torch import nn, einsum import torch.nn.functional as F import numpy as np device = "cuda" if torch.cuda.is_available() else "cpu" def exists(x): return x is not None def default(val, d): if exists(val): return val return d() if isfunction(d) else d class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): return self.fn(x, *args, **kwargs) + x def Upsample(dim): return nn.ConvTranspose2d(dim, dim, 4, 2, 1) def Downsample(dim): return nn.Conv2d(dim, dim, 4, 2, 1) class SinusoidalPositionEmbeddings(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, time): device = time.device half_dim = self.dim // 2 embeddings = math.log(10000) / (half_dim - 1) embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) embeddings = time[:, None] * embeddings[None, :] embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) return embeddings #ITT class Block(nn.Module): def __init__(self, dim, dim_out, groups = 8): super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() def forward(self, x, scale_shift = None): x = self.proj(x) x = self.norm(x) if exists(scale_shift): scale, shift = scale_shift x = x * (scale + 1) + shift x = self.act(x) return x class ResnetBlock(nn.Module): """https://arxiv.org/abs/1512.03385""" def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): super().__init__() self.mlp = ( nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) if exists(time_emb_dim) else None ) self.block1 = Block(dim, dim_out, groups=groups) self.block2 = Block(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None): h = self.block1(x) if exists(self.mlp) and exists(time_emb): time_emb = self.mlp(time_emb) h = rearrange(time_emb, "b c -> b c 1 1") + h h = self.block2(h) return h + self.res_conv(x) class ConvNextBlock(nn.Module): """https://arxiv.org/abs/2201.03545""" def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True): super().__init__() self.mlp = ( nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim)) if exists(time_emb_dim) else None ) self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim) self.net = nn.Sequential( nn.GroupNorm(1, dim) if norm else nn.Identity(), nn.Conv2d(dim, dim_out * mult, 3, padding=1), nn.GELU(), nn.GroupNorm(1, dim_out * mult), nn.Conv2d(dim_out * mult, dim_out, 3, padding=1), ) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None): h = self.ds_conv(x) if exists(self.mlp) and exists(time_emb): assert exists(time_emb), "time embedding must be passed in" condition = self.mlp(time_emb) h = h + rearrange(condition, "b c -> b c 1 1") h = self.net(h) return h + self.res_conv(x) class Attention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.scale = dim_head**-0.5 self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map( lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv ) q = q * self.scale sim = einsum("b h d i, b h d j -> b h i j", q, k) sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) out = einsum("b h i j, b h d j -> b h i d", attn, v) out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) return self.to_out(out) class LinearAttention(nn.Module): def __init__(self, dim, heads=4, dim_head=32): super().__init__() self.scale = dim_head**-0.5 self.heads = heads hidden_dim = dim_head * heads self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim)) def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = map( lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv ) q = q.softmax(dim=-2) k = k.softmax(dim=-1) q = q * self.scale context = torch.einsum("b h d n, b h e n -> b h d e", k, v) out = torch.einsum("b h d e, b h d n -> b h e n", context, q) out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) return self.to_out(out) class PreNorm(nn.Module): def __init__(self, dim, fn): super().__init__() self.fn = fn self.norm = nn.GroupNorm(1, dim) def forward(self, x): x = self.norm(x) return self.fn(x) class Unet(nn.Module): def __init__( self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 3), channels=3, with_time_emb=True, resnet_block_groups=8, use_convnext=True, convnext_mult=2, ): super().__init__() # determine dimensions self.channels = channels init_dim = default(init_dim, dim // 3 * 2) self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) dims = [init_dim, *map(lambda m: dim * m, dim_mults)] in_out = list(zip(dims[:-1], dims[1:])) if use_convnext: block_klass = partial(ConvNextBlock, mult=convnext_mult) else: block_klass = partial(ResnetBlock, groups=resnet_block_groups) # time embeddings if with_time_emb: time_dim = dim * 4 self.time_mlp = nn.Sequential( SinusoidalPositionEmbeddings(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim), ) else: time_dim = None self.time_mlp = None # layers self.downs = nn.ModuleList([]) self.ups = nn.ModuleList([]) num_resolutions = len(in_out) for ind, (dim_in, dim_out) in enumerate(in_out): is_last = ind >= (num_resolutions - 1) self.downs.append( nn.ModuleList( [ block_klass(dim_in, dim_out, time_emb_dim=time_dim), block_klass(dim_out, dim_out, time_emb_dim=time_dim), Residual(PreNorm(dim_out, LinearAttention(dim_out))), Downsample(dim_out) if not is_last else nn.Identity(), ] ) ) mid_dim = dims[-1] self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): is_last = ind >= (num_resolutions - 1) self.ups.append( nn.ModuleList( [ block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), block_klass(dim_in, dim_in, time_emb_dim=time_dim), Residual(PreNorm(dim_in, LinearAttention(dim_in))), Upsample(dim_in) if not is_last else nn.Identity(), ] ) ) out_dim = default(out_dim, channels) self.final_conv = nn.Sequential( block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1) ) def forward(self, x, time): x = self.init_conv(x) t = self.time_mlp(time) if exists(self.time_mlp) else None h = [] # downsample for block1, block2, attn, downsample in self.downs: x = block1(x, t) x = block2(x, t) x = attn(x) h.append(x) x = downsample(x) # bottleneck x = self.mid_block1(x, t) x = self.mid_attn(x) x = self.mid_block2(x, t) # upsample for block1, block2, attn, upsample in self.ups: x = torch.cat((x, h.pop()), dim=1) x = block1(x, t) x = block2(x, t) x = attn(x) x = upsample(x) return self.final_conv(x) image_size = 64 channels = 3 batch_size = 32 best_model = Unet( dim=image_size, channels=channels, dim_mults=(1, 2, 3), ) best_model.load_state_dict(torch.load('1_2_3model.pth', map_location=torch.device('cpu'))) best_model.to(device) def cosine_beta_schedule(timesteps, s=0.008): """ cosine schedule as proposed in https://arxiv.org/abs/2102.09672 """ steps = timesteps + 1 x = torch.linspace(0, timesteps, steps) alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0.0001, 0.9999) def linear_beta_schedule(timesteps): beta_start = 0.0001 beta_end = 0.02 return torch.linspace(beta_start, beta_end, timesteps) def quadratic_beta_schedule(timesteps): beta_start = 0.0001 beta_end = 0.02 return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2 def sigmoid_beta_schedule(timesteps): beta_start = 0.0001 beta_end = 0.02 betas = torch.linspace(-6, 6, timesteps) return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start timesteps = 200 # define beta schedule betas = linear_beta_schedule(timesteps=timesteps) # define alphas alphas = 1. - betas alphas_cumprod = torch.cumprod(alphas, axis=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) sqrt_recip_alphas = torch.sqrt(1.0 / alphas) # calculations for diffusion q(x_t | x_{t-1}) and others sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod) # calculations for posterior q(x_{t-1} | x_t, x_0) posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) def extract(a, t, x_shape): batch_size = t.shape[0] out = a.gather(-1, t.cpu()) return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device) # forward diffusion def q_sample(x_start, t, noise=None): if noise is None: noise = torch.randn_like(x_start) sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape) sqrt_one_minus_alphas_cumprod_t = extract( sqrt_one_minus_alphas_cumprod, t, x_start.shape ) return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise # def get_noisy_image(x_start, t): # # add noise # x_noisy = q_sample(x_start, t=t) # # turn back into PIL image # noisy_image = reverse_transform(x_noisy.squeeze()) # return noisy_image import matplotlib.pyplot as plt # use seed for reproducability torch.manual_seed(0) # source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs): if not isinstance(imgs[0], list): # Make a 2d grid even if there's just 1 row imgs = [imgs] num_rows = len(imgs) num_cols = len(imgs[0]) + with_orig fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False) for row_idx, row in enumerate(imgs): row = [image] + row if with_orig else row for col_idx, img in enumerate(row): ax = axs[row_idx, col_idx] ax.imshow(np.asarray(img), **imshow_kwargs) ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) if with_orig: axs[0, 0].set(title='Original image') axs[0, 0].title.set_size(8) if row_title is not None: for row_idx in range(num_rows): axs[row_idx, 0].set(ylabel=row_title[row_idx]) plt.tight_layout() def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"): if noise is None: noise = torch.randn_like(x_start) x_noisy = q_sample(x_start=x_start, t=t, noise=noise) predicted_noise = denoise_model(x_noisy, t) if loss_type == 'l1': loss = F.l1_loss(noise, predicted_noise) elif loss_type == 'l2': loss = F.mse_loss(noise, predicted_noise) elif loss_type == "huber": loss = F.smooth_l1_loss(noise, predicted_noise) else: raise NotImplementedError() return loss @torch.no_grad() def p_sample(model, x, t, t_index): betas_t = extract(betas, t, x.shape) sqrt_one_minus_alphas_cumprod_t = extract( sqrt_one_minus_alphas_cumprod, t, x.shape ) sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape) # Equation 11 in the paper # Use our model (noise predictor) to predict the mean model_mean = sqrt_recip_alphas_t * ( x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t ) if t_index == 0: return model_mean else: posterior_variance_t = extract(posterior_variance, t, x.shape) noise = torch.randn_like(x) # Algorithm 2 line 4: return model_mean + torch.sqrt(posterior_variance_t) * noise # Algorithm 2 but save all images: @torch.no_grad() def p_sample_loop(model, shape): device = next(model.parameters()).device b = shape[0] # start from pure noise (for each example in the batch) img = torch.randn(shape, device=device) imgs = [] for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps): img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i) imgs.append(img.cpu().numpy()) return imgs @torch.no_grad() def sample(model, image_size, batch_size=16, channels=3): return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size)) # sample 64 images sample_size = 21 samples = sample(best_model, image_size=image_size, batch_size=sample_size, channels=channels) """UI""" def show_picture(random_index): image=(samples[-1][random_index].transpose(1, 2, 0) + 1.0) / 2.0 clipped_image = np.clip(image, 0.0, 1.0) return clipped_image demo = gr.Interface(fn=show_picture, inputs=gr.Slider(minimum=0, maximum=sample_size-1, step=1), outputs="image") if __name__ == "__main__": demo.launch(show_api=False, share=True) # random_index = 7 # image=(samples[-1][random_index].transpose(1, 2, 0) + 1.0) / 2.0 # clipped_image = np.clip(image, 0.0, 1.0) # plt.imshow(clipped_image) #clip/clamp