Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.autograd import Variable | |
| import torch.optim as optim | |
| import kornia.augmentation as K | |
| from CLIP import clip | |
| from torchvision import transforms | |
| from PIL import Image | |
| import numpy as np | |
| import math | |
| from matplotlib import pyplot as plt | |
| from fastprogress.fastprogress import master_bar, progress_bar | |
| from IPython.display import HTML | |
| from base64 import b64encode | |
# Definitions
# Compute device shared by every tensor and model in this script: the first
# CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
def sinc(x):
    """Evaluate the normalized sinc function sin(pi*x) / (pi*x) elementwise.

    Args:
        x: Tensor of sample positions.

    Returns:
        Tensor shaped like ``x``; entries where ``x == 0`` are set to 1
        instead of the undefined 0/0 quotient.
    """
    scaled = math.pi * x
    quotient = torch.sin(scaled) / scaled
    # The quotient is NaN at x == 0; torch.where swaps in the limit value 1.
    one = x.new_ones([])
    return torch.where(x != 0, quotient, one)
def lanczos(x, a):
    """Build a normalized Lanczos kernel sampled at the positions ``x``.

    Args:
        x: Tensor of sample positions.
        a: Window half-width; positions outside the open interval (-a, a)
            get a weight of zero.

    Returns:
        Tensor shaped like ``x`` holding ``sinc(x) * sinc(x / a)`` inside the
        window, divided by the total so the weights sum to 1.
    """
    inside_window = torch.logical_and(-a < x, x < a)
    windowed_sinc = sinc(x) * sinc(x / a)
    weights = torch.where(inside_window, windowed_sinc, x.new_zeros([]))
    total = weights.sum()
    return weights / total
def ramp(ratio, width):
    """Return symmetric, evenly spaced sample positions centred on zero.

    Args:
        ratio: Spacing between consecutive positions.
        width: Extent used to decide how many positions are generated
            (``ceil(width / ratio + 1)`` non-negative steps before mirroring).

    Returns:
        1-D tensor made of the non-negative steps ``0, ratio, 2*ratio, ...``
        mirrored onto the negative side, with the outermost sample dropped
        from each end.
    """
    count = math.ceil(width / ratio + 1)

    # Accumulate by repeated addition (not multiplication) so the values are
    # bit-for-bit what a running sum produces.
    steps = []
    position = 0
    while len(steps) < count:
        steps.append(position)
        position += ratio

    half = torch.empty([count])
    for slot, value in enumerate(steps):
        half[slot] = value

    negative_side = -half[1:].flip([0])
    full = torch.cat([negative_side, half])
    return full[1:-1]
class Prompt(nn.Module):
    """Loss term comparing embeddings against one stored target embedding.

    The target, its weight and the stop threshold are registered as buffers,
    so they move with the module (e.g. on ``.to(device)``) but are not
    parameters.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        """Store the target embedding and its loss settings.

        Args:
            embed: Target embedding tensor.
            weight: Loss weight; its sign flips the distance and its
                magnitude scales the final loss.
            stop: Lower bound applied to the distances in the forward value.
        """
        super().__init__()
        buffers = (
            ('embed', embed),
            ('weight', torch.as_tensor(weight)),
            ('stop', torch.as_tensor(stop)),
        )
        for buffer_name, buffer_value in buffers:
            self.register_buffer(buffer_name, buffer_value)

    def forward(self, input):
        """Return the weighted mean distance between ``input`` and the target.

        Both sides are L2-normalized along their last axis after a broadcast
        axis is inserted, so every row of ``input`` is compared with every
        row of the stored embedding (assumes both are 2-D - TODO confirm).
        """
        input_unit = F.normalize(input.unsqueeze(1), dim=2)
        target_unit = F.normalize(self.embed.unsqueeze(0), dim=2)

        # Distance between the unit vectors, converted to 2 * arcsin(d / 2)**2.
        chord = input_unit.sub(target_unit).norm(dim=2)
        half_angle = chord.div(2).arcsin()
        dists = half_angle.pow(2).mul(2)

        dists = dists * self.weight.sign()
        # Forward value is the raw distance; gradients flow through the
        # version clamped from below at `stop`.
        clamped = torch.maximum(dists, self.stop)
        blended = replace_grad(dists, clamped)
        return self.weight.abs() * blended.mean()
class MakeCutouts(nn.Module):
    """Cut random square crops from an image batch, augment them and add noise."""

    def __init__(self, cut_size, cutn, cut_pow=1.):
        """Configure the cutout sampler.

        Args:
            cut_size: Side length every crop is resampled to.
            cutn: Number of crops taken per forward call.
            cut_pow: Exponent applied to the uniform random draw that picks
                each crop's size.
        """
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomSharpness(0.3, p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2, p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))
        # Upper bound of the per-sample noise scale; a falsy value disables noise.
        self.noise_fac = 0.1

    def forward(self, input):
        """Return the augmented crops stacked along the batch axis.

        Args:
            input: 4-D image tensor; axes 2 and 3 are read as height and width.

        Returns:
            Tensor holding ``cutn`` crops per input sample, each resampled to
            ``cut_size`` x ``cut_size``, augmented and (optionally) noised.
        """
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            # Random crop side between min_size and max_size.
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            # One noise scale per row of the concatenated batch. Sizing this
            # from the batch itself (rather than from cutn) keeps the
            # broadcast valid when the input holds more than one image.
            facs = batch.new_empty([batch.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
def resample(input, size, align_corners=True):
    """Resize a 4-D image tensor, low-pass filtering before any downscale.

    Along each spatial axis where the target is smaller than the source, the
    image is first convolved with a Lanczos kernel (a=2) using reflect
    padding; the result is then resized with bicubic interpolation.

    Args:
        input: Tensor of shape (n, c, h, w). It does not need to be
            contiguous.
        size: Target ``(height, width)``.
        align_corners: Forwarded to ``F.interpolate``.

    Returns:
        Tensor of shape (n, c, size[0], size[1]).
    """
    n, c, h, w = input.shape
    dh, dw = size

    # Fold channels into the batch axis so one single-channel kernel filters
    # every channel. reshape (unlike view) also accepts inputs whose batch and
    # channel axes cannot be merged without a copy, e.g. permuted tensors.
    input = input.reshape([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.reshape([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
class ReplaceGrad(torch.autograd.Function):
    """Return one tensor in the forward pass but route gradients to another.

    ``replace_grad(x_forward, x_backward)`` evaluates to ``x_forward`` while
    the incoming gradient is delivered only to ``x_backward``.
    """

    # torch.autograd.Function expects forward/backward to be static methods
    # that receive the context object as their first argument.
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Remember the gradient target's shape so backward can undo broadcasting.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        # No gradient for x_forward; x_backward gets the incoming gradient
        # summed down to its own shape.
        return None, grad_in.sum_to_size(ctx.shape)


replace_grad = ReplaceGrad.apply
# Set up CLIP
# clip.load returns a sequence whose first element is the model; it is frozen
# (eval mode, no gradients) and moved to the shared device.
perceptor = clip.load('ViT-B/32', jit=False)[0]
perceptor = perceptor.eval().requires_grad_(False).to(device)

# Per-channel statistics applied to images before they are encoded.
normalize = transforms.Normalize(
    mean=[0.48145466, 0.4578275, 0.40821073],
    std=[0.26862954, 0.26130258, 0.27577711],
)

# Cutouts are resampled to the resolution the visual encoder reports.
cut_size = perceptor.visual.input_resolution
cutn = 8  # 64 but using less to save on CPU
cut_pow = 1
make_cutouts = MakeCutouts(cut_size, cutn, cut_pow=cut_pow)
# ImStack
class ImStack(nn.Module):
    """Represent an image as a stack of arrays of increasing resolution.

    Each layer's side length is ``scale`` times that of the previous one
    (with the default ``scale=2`` every layer is 1/2 the resolution of the
    next). This is useful e.g. when trying to create an image to minimise
    some loss - parameters in the early (small) layers can have an effect on
    the overall structure and shapes while those in later layers act as
    residuals and fill in fine detail.

    The layers are plain tensors kept in the list ``self.layers`` (they are
    not registered as module parameters); hand that list to the optimizer.
    """

    def __init__(self, n_layers=3, base_size=32, scale=2,
                 init_image=None, out_size=256, decay=0.7):
        """Construct the image stack.

        Args:
            n_layers: Number of layers in the stack.
            base_size: Side length of the first (smallest) layer.
            scale: Factor by which the side length grows between layers.
            init_image: Optional PIL image. When given, the layers are
                zeroed and refilled so the composed output approximates the
                image. As a side effect one ``residual{i}.png`` file per
                layer is written to the current directory.
            out_size: Side length of the image returned by ``forward``.
            decay: Layer ``i`` starts as Gaussian noise scaled by ``decay**i``.
        """
        super().__init__()
        self.n_layers = n_layers
        self.base_size = base_size
        self.sig = nn.Sigmoid()
        self.layers = []
        for i in range(n_layers):
            side = base_size * (scale**i)
            tim = torch.randn((3, side, side)).to(device)*(decay**i)
            self.layers.append(tim)
        # One upsampler per layer to the output size and to the 224px preview size.
        self.scalers = [nn.Upsample(scale_factor=out_size/(layer.shape[1]), mode='bilinear', align_corners=False)
                        for layer in self.layers]
        self.preview_scalers = [nn.Upsample(scale_factor=224/(layer.shape[1]), mode='bilinear', align_corners=False)
                                for layer in self.layers]

        if init_image is not None:  # Given a PIL image, decompose it into a stack
            downscalers = [nn.Upsample(scale_factor=(layer.shape[1]/out_size), mode='bilinear', align_corners=False)
                           for layer in self.layers]
            # Force three channels so the permute below and the 3-channel
            # layers also work for greyscale / RGBA inputs.
            pixels = np.array(init_image.convert('RGB').resize((out_size, out_size))) / 255
            # float32 keeps the decomposed layers in the same dtype as the
            # randomly initialised ones; clip keeps logit() finite.
            im = torch.tensor(pixels, dtype=torch.float32).clip(1e-03, 1-1e-3)  # Between 0 and 1 (non-inclusive)
            im = im.permute(2, 0, 1).unsqueeze(0).to(device)
            for layer in self.layers:
                layer.zero_()  # Zero out the layers
            for i in range(n_layers):
                # What the stack built so far still misses, in logit space.
                out = self.forward()
                residual = torch.logit(im) - torch.logit(out)
                # The residual lives in logit space, so squash it through the
                # sigmoid (as plot_layers does) before writing the debug image.
                Image.fromarray(self._to_uint8(self.sig(residual))).save(f'residual{i}.png')
                self.layers[i] = downscalers[i](residual).squeeze(0)

        for layer in self.layers:
            layer.requires_grad = True

    @staticmethod
    def _to_uint8(image):
        """Convert a (1, 3, H, W) tensor to an (H, W, 3) uint8 array via ``* 255``."""
        # squeeze(0) drops only the batch axis, so size-1 spatial axes survive.
        hwc = image.detach().cpu().squeeze(0).permute([1, 2, 0])
        return (hwc * 255).numpy().astype(np.uint8)

    def _compose(self, scalers, count):
        """Sum the first ``count`` layers after upsampling, then apply the sigmoid."""
        im = scalers[0](self.layers[0].unsqueeze(0))
        for i in range(1, count):
            im = im + scalers[i](self.layers[i].unsqueeze(0))
        return self.sig(im)

    def forward(self):
        """Return the full image: all layers upsampled to ``out_size`` and summed."""
        return self._compose(self.scalers, self.n_layers)

    def preview(self, n_preview=2):
        """Return a 224px image built from only the first ``n_preview`` layers.

        ``n_preview`` is capped at the number of layers in the stack.
        """
        return self._compose(self.preview_scalers, min(n_preview, self.n_layers))

    def to_pil(self):
        """Return the full image as a PIL image."""
        return Image.fromarray(self._to_uint8(self.forward()))

    def preview_pil(self):
        """Return the default preview as a PIL image."""
        return Image.fromarray(self._to_uint8(self.preview()))

    def save(self, fn):
        """Save the full image to the path ``fn``."""
        self.to_pil().save(fn)

    def plot_layers(self):
        """Show every layer (after the sigmoid) side by side with matplotlib."""
        # squeeze=False keeps axs two-dimensional even for a single layer.
        fig, axs = plt.subplots(1, self.n_layers, figsize=(15, 5), squeeze=False)
        for i in range(self.n_layers):
            im = self._to_uint8(self.sig(self.layers[i].unsqueeze(0)))
            axs[0][i].imshow(im)
def generate(text, n_steps):
    """Optimise an ImStack against CLIP for ``text`` and return the image.

    Args:
        text: Prompt the generated image should match.
        n_steps: Number of optimisation steps; converted with ``int()``.

    Returns:
        The final image as a NumPy array built from ``ImStack.to_pil()``.
    """
    # Fixed settings for the demo.
    learning_rate = 0.25
    step_count = int(n_steps)
    decay_weight = 1e-5
    output_side = 180
    smallest_side = 20
    layer_count = 3
    layer_scale = 3
    warmup_steps = 15

    def build_prompt(phrase, weight):
        # Encode a phrase with CLIP and wrap it as a weighted loss term.
        tokens = clip.tokenize(phrase).to(device)
        embedding = perceptor.encode_text(tokens).float()
        return Prompt(embedding, weight, float('-inf')).to(device)

    # The user's text pulls the image towards it (weight 1) ...
    p_prompts = [build_prompt(text, 1)]
    # ... while some negative prompts push it away (weight 0.5 each).
    n_prompts = [
        build_prompt(phrase, 0.5)
        for phrase in ("Random noise", 'saturated rainbow RGB deep dream')
    ]

    # The ImageStack - trying a different scale and n_layers
    stack = ImStack(base_size=smallest_side, scale=layer_scale, n_layers=layer_count,
                    out_size=output_side, decay=0.4)
    optimizer = optim.Adam(stack.layers, lr=learning_rate, weight_decay=decay_weight)

    loss_history = []  # Recorded per step; not part of the return value.
    for step in range(step_count):
        optimizer.zero_grad()
        if step >= warmup_steps:
            image = stack()
            features = perceptor.encode_image(normalize(make_cutouts(image))).float()
        else:
            # Save time by skipping the cutouts and focusing on the lower layers
            image = stack.preview(n_preview=1 + step // 20)
            features = perceptor.encode_image(normalize(image)).float()

        loss = 0
        for prompt in p_prompts:
            loss = loss + prompt(features)
        for prompt in n_prompts:
            loss = loss - prompt(features)
        loss_history.append(float(loss.detach().cpu()))

        loss.backward()  # Backprop
        optimizer.step()  # Update

    return np.array(stack.to_pil())
# Gradio front end: a text prompt and a step count go in, the generated image
# comes out. The interface is built first and launched straight away; `iface`
# ends up bound to whatever launch() returns.
_interface = gr.Interface(
    fn=generate,
    description="Attempt at a Gradio demo for https://colab.research.google.com/drive/1dBPXIspuMocqfcJqfjCn_PFeUfr36KGu?usp=sharing. A little slow on CPU so check out the colab for higher res generation.",
    inputs=[
        gr.inputs.Textbox(label="Text Input"),
        gr.inputs.Number(default=64, label="N Steps"),
    ],
    outputs=[
        gr.outputs.Image(type="numpy", label="Output Image"),
    ],
)
iface = _interface.launch(enable_queue=True)