Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision | |
| from datasets import load_dataset | |
| from diffusers import DDIMScheduler, DDPMPipeline | |
| from matplotlib import pyplot as plt | |
| from PIL import Image | |
| from torchvision import transforms | |
| from tqdm.auto import tqdm | |
# All models and tensors in this app live on this device.
device = "cpu"

# Hugging Face Hub id of the pretrained unconditional diffusion pipeline.
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"

# Load the pipeline, then move it onto `device`.
image_pipe = DDPMPipeline.from_pretrained(pipeline_name)
image_pipe = image_pipe.to(device)

# Sample with a DDIM scheduler configured from `pipeline_name`, compressing
# the sampling schedule to 40 denoising steps.
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
import open_clip

# CLIP model used to score generated images against the text prompt:
# ViT-B-32 architecture with the "openai" pretrained weights. Of the two
# transforms also returned, the second is bound to `preprocess`; neither is
# used by the code visible in this file.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="openai",
)
clip_model = clip_model.to(device)
# Per-channel (R, G, B) normalization statistics matching CLIP's training data.
# NOTE(review): these statistics assume input pixels in [0, 1]; confirm the
# range of the images handed to `tfms`.
_clip_mean = (0.48145466, 0.4578275, 0.40821073)
_clip_std = (0.26862954, 0.26130258, 0.27577711)

# Random augmentations, applied in this order on every call, so each CLIP
# evaluation sees a slightly different view ("cut") of the image.
_augmentations = [
    torchvision.transforms.RandomResizedCrop(224),  # random crop, resized to 224x224
    torchvision.transforms.RandomAffine(5),  # random rotation within +/-5 degrees
    torchvision.transforms.RandomHorizontalFlip(),  # further augmentations can be appended
]

# Full pipeline: augment first, then normalize to CLIP's statistics.
tfms = torchvision.transforms.Compose(
    _augmentations
    + [torchvision.transforms.Normalize(mean=_clip_mean, std=_clip_std)]
)
def clip_loss(image, text_features):
    """Return the mean squared great-circle distance between CLIP embeddings.

    ``image`` is randomly augmented and normalized with ``tfms``, embedded
    with ``clip_model.encode_image``, and every resulting image embedding is
    compared with every row of ``text_features`` after both sides are
    L2-normalized.

    Args:
        image: Images accepted by ``tfms`` and ``clip_model.encode_image``
            (presumably an (N, 3, H, W) batch -- confirm against callers).
        text_features: Text embeddings, presumably (M, D) as produced by
            ``clip_model.encode_text``.

    Returns:
        Scalar tensor: the mean over all (image, text) pairs of
        ``2 * arcsin(||u - v|| / 2) ** 2`` for unit vectors ``u`` and ``v``.
    """
    img_emb = clip_model.encode_image(tfms(image))
    # Singleton axes let (N, 1, D) and (1, M, D) broadcast pairwise.
    img_unit = F.normalize(img_emb.unsqueeze(1), dim=2)
    txt_unit = F.normalize(text_features.unsqueeze(0), dim=2)
    # Chord length between unit vectors -> squared great-circle distance.
    chord = (img_unit - txt_unit).norm(dim=2)
    return (2 * torch.arcsin(chord / 2) ** 2).mean()
| import gradio as gr | |
| from PIL import Image | |
def text2img(prompt, guidance_scale, n_cuts):
    """Generate one image for ``prompt`` with CLIP-guided DDIM sampling.

    Starting from Gaussian noise, for every timestep of ``scheduler``:
      1. the pretrained UNet predicts the noise,
      2. the latent ``x`` is nudged down the gradient of the CLIP loss
         between the predicted clean image and the prompt (averaged over
         ``n_cuts`` randomly augmented views), and
      3. a regular scheduler step is taken.

    Args:
        prompt: Text describing the desired image.
        guidance_scale: Multiplier on the CLIP loss; larger values push the
            sample harder toward the prompt. ``None`` (an empty Gradio number
            box) is treated as 0, i.e. no guidance.
        n_cuts: Number of augmented views whose gradients are averaged per
            step. Non-integers are truncated; ``None`` or values below 1
            disable guidance.

    Returns:
        A ``PIL.Image.Image`` of the final sample.
    """
    print(f"prompt: {prompt}")

    # Gradio "number" inputs arrive as floats, or None when left empty.
    guidance_scale = 0.0 if guidance_scale is None else float(guidance_scale)
    # Use one integer count both for the loop and for the average, so the
    # gradient is a true mean over the cuts that actually ran.
    n_cuts = max(int(n_cuts or 0), 0)

    # Embed the prompt once with CLIP; it is the fixed guidance target.
    # (No CUDA autocast here: everything runs on `device`, and
    # torch.cuda.amp.autocast is both deprecated and a no-op for CPU ops.)
    text = open_clip.tokenize([prompt]).to(device)
    with torch.no_grad():
        text_features = clip_model.encode_text(text)

    # RAM usage is high, so only one image is sampled at a time.
    x = torch.randn(1, 3, 256, 256).to(device)
    for t in tqdm(scheduler.timesteps):
        model_input = scheduler.scale_model_input(x, t)
        # Predict the noise residual; no gradients needed through the UNet.
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        cond_grad = torch.zeros_like(x)
        for _ in range(n_cuts):
            # Track gradients w.r.t. the current latent only.
            x = x.detach().requires_grad_()
            # Predicted clean image x0 at this timestep.
            # NOTE(review): x0 is in the model's [-1, 1] range, while the
            # CLIP normalization in `tfms` assumes [0, 1] pixels; confirm
            # whether a (x0 + 1) / 2 rescale is intended before clip_loss.
            x0 = scheduler.step(noise_pred, t, x).pred_original_sample
            loss = clip_loss(x0, text_features) * guidance_scale
            # Divide by n_cuts so the accumulated gradient is the mean.
            cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts

        # Scale the guidance by sqrt(alpha_bar) of the *current timestep*.
        # alphas_cumprod is indexed by training timestep, so use t, not the
        # loop position (which would only ever read its first entries).
        alpha_bar = scheduler.alphas_cumprod[t]
        x = x.detach() + cond_grad * alpha_bar.sqrt()
        # Regular scheduler step toward the next (less noisy) timestep.
        x = scheduler.step(noise_pred, t, x).prev_sample

    # Map the sample from [-1, 1] to [0, 1], then to 8-bit RGB.
    grid = torchvision.utils.make_grid(x, nrow=1)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    return Image.fromarray((im * 255).numpy().astype(np.uint8))
# Gradio widgets for text2img's parameters, in order:
# prompt (text), guidance_scale (number), n_cuts (number).
# See the Gradio docs for the other input/output types available.
inputs = ["text", "number", "number"]
outputs = gr.Image(label="text-guided image")

# Starter rows shown under the form: [prompt, guidance_scale, n_cuts].
example_inputs = [
    ["Red Rose (still life), red flower painting", 10, 4],
    ["Blue sky, pure and bright, positive mood", 8, 5],
]

# Minimal interface wiring the sampler to the widgets, then launch the app.
demo = gr.Interface(
    fn=text2img,
    inputs=inputs,
    outputs=outputs,
    examples=example_inputs,
)
demo.launch()