Spaces:
Runtime error
Runtime error
| import open_clip | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| from tqdm.auto import tqdm | |
| from PIL import Image, ImageColor | |
| from torchvision import transforms | |
| from diffusers import DDIMScheduler, DDPMPipeline | |
# Pick the compute backend: prefer Apple MPS, then CUDA, and fall back to CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained pipeline and move it to the selected device
pipeline_name = "alkzar90/sd-class-ukiyo-e-256"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

# DDIM scheduler built from the same checkpoint config; it starts out with
# 40 inference steps (generate() sets the step count again on every call).
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
# Color guidance
#-------------------------------------------------------------------------------
# Color guidance function
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Return the mean absolute distance between the image pixels and a color.

    Args:
        images: Image batch laid out as (b, c, h, w) with values in (-1, 1).
        target_color: (R, G, B) components in (0, 1). The default is a light
            teal: (0.1, 0.9, 0.5).

    Returns:
        A scalar tensor: the absolute difference between every pixel and the
        target color, averaged over the whole batch.
    """
    rgb = torch.tensor(target_color).to(images.device)
    # Rescale from (0, 1) to (-1, 1) so the color lives in the image value range.
    rgb = rgb * 2 - 1
    # Add batch/height/width axes so the color broadcasts against (b, c, h, w).
    rgb = rgb[None, :, None, None]
    return (images - rgb).abs().mean()
# CLIP guidance
#-------------------------------------------------------------------------------
# Load CLIP (ViT-B/32 with the OpenAI weights) and move it to the selected device.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32",
    pretrained="openai",
)
clip_model.to(device)

# Augmentation pipeline applied before every CLIP image encoding: a fresh
# random crop/affine/flip on each call, then normalization with the
# statistics of CLIP's training data.
tfms = transforms.Compose([
    # Random CROP each time, resized to 224
    transforms.RandomResizedCrop(224),
    # One possible random augmentation: skews the image
    transforms.RandomAffine(5),
    # You can add additional augmentations if you like
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])
# CLIP guidance function
def clip_loss(image, text_features):
    """Return the mean squared great-circle distance between image and text embeddings.

    Args:
        image: Image batch; it is passed through the module-level ``tfms``
            augmentations before being encoded by ``clip_model``.
        text_features: Text embeddings to compare against, one row per prompt.

    Returns:
        A scalar tensor averaging the distance over every image/text pair.
    """
    # A new random augmentation is drawn on every call.
    augmented = tfms(image)
    image_features = clip_model.encode_image(augmented)

    # Unit-normalize both sides and lay them out as (n_images, 1, d) and
    # (1, n_texts, d) so the subtraction below forms every pair.
    img_unit = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    txt_unit = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)

    # Chord length between the unit vectors -> squared great-circle distance.
    chord = (img_unit - txt_unit).norm(dim=2)
    dists = (chord / 2).arcsin().pow(2) * 2
    return dists.mean()
# Sample generator loop
#-------------------------------------------------------------------------------
def generate(color,
             color_loss_scale,
             num_examples=4,
             seed=None,
             prompt=None,
             prompt_loss_scale=None,
             prompt_n_cuts=None,
             inference_steps=50,
             ):
    """Sample a grid of images guided toward a color and, optionally, a text prompt.

    Args:
        color: Target color in a format ``PIL.ImageColor.getcolor`` accepts
            (the UI sends a hex string such as "#DF5C16").
        color_loss_scale: Multiplier on the color loss; larger values pull the
            samples harder toward ``color``.
        num_examples: Number of images to sample. Coerced to ``int`` because
            UI widgets may deliver floats.
        seed: Seed for ``torch.manual_seed``. ``None`` leaves the RNG state
            untouched; ``0`` is a valid seed.
        prompt: Optional text prompt. ``None`` or an empty string disables
            CLIP guidance.
        prompt_loss_scale: Multiplier on the CLIP loss. Falls back to 10 when
            a prompt is given but no scale is provided.
        prompt_n_cuts: Number of random augmented views averaged per step for
            the CLIP gradient. Falls back to 4 when a prompt is given but no
            value is provided; coerced to ``int``.
        inference_steps: Number of scheduler steps. Coerced to ``int``.

    Returns:
        A ``PIL.Image.Image`` holding all samples in one grid, four per row.
    """
    # Slider/number widgets can hand over floats; these values are used as counts.
    num_examples = int(num_examples)
    inference_steps = int(inference_steps)

    # NOTE: the scheduler is shared module state, reconfigured on every call.
    scheduler.set_timesteps(num_inference_steps=inference_steps)

    # `is not None` rather than truthiness so that seed 0 is honoured.
    if seed is not None:
        torch.manual_seed(int(seed))

    use_prompt = bool(prompt)
    if use_prompt:
        # Without these fallbacks a prompt combined with the signature
        # defaults (None) would fail with a TypeError inside the loop.
        if prompt_loss_scale is None:
            prompt_loss_scale = 10
        prompt_n_cuts = 4 if prompt_n_cuts is None else int(prompt_n_cuts)

        text = open_clip.tokenize([prompt]).to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            text_features = clip_model.encode_text(text)

    target_color = ImageColor.getcolor(color, "RGB")  # Target color as RGB
    target_color = [a / 255 for a in target_color]  # Rescale from (0, 255) to (0, 1)

    x = torch.randn(num_examples, 3, 256, 256).to(device)

    timesteps = scheduler.timesteps
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
        model_input = scheduler.scale_model_input(x, t)
        with torch.no_grad():
            noise_pred = image_pipe.unet(model_input, t)["sample"]

        # Color guidance: differentiate the loss on the predicted x0 w.r.t. x.
        x = x.detach().requires_grad_()
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample
        loss = color_loss(x0, target_color) * color_loss_scale
        cond_color_grad = -torch.autograd.grad(loss, x)[0]

        # Modify x based solely on the color gradient -> x_cond
        x_cond = x.detach() + cond_color_grad

        # Prompt guidance: the gradient is taken at the original x (not the
        # color-adjusted x_cond) and then added on top of x_cond.
        if use_prompt:
            cond_prompt_grad = 0
            for _ in range(prompt_n_cuts):
                # Fresh leaf each cut: the graph is rebuilt every time, so
                # there is no need to retain it between cuts.
                x = x.detach().requires_grad_()
                # Get the predicted x0:
                x0 = scheduler.step(noise_pred, t, x).pred_original_sample
                # clip_loss draws a new random augmentation on each call
                prompt_loss = clip_loss(x0, text_features) * prompt_loss_scale
                # Divide by n_cuts since we want the average gradient
                cond_prompt_grad -= torch.autograd.grad(prompt_loss, x)[0] / prompt_n_cuts

            # NOTE(review): this indexes alphas_cumprod with the loop position
            # `i`, not the timestep `t`. Kept as-is because the guidance scales
            # are tuned against it; confirm whether `t` was intended.
            alpha_bar = scheduler.alphas_cumprod[i]
            x_cond = x_cond + cond_prompt_grad * alpha_bar.sqrt()

        # Take the actual scheduler step from the guided sample.
        x = scheduler.step(noise_pred, t, x_cond).prev_sample

    # Tile the batch, map from (-1, 1) to (0, 255) and convert to a PIL image.
    grid = torchvision.utils.make_grid(x.detach(), nrow=4)
    im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
    im = Image.fromarray(np.array(im * 255).astype(np.uint8))
    return im
# GRADIO Interface
#-------------------------------------------------------------------------------
TITLE="Ukiyo-e postal generator service 🎴!"
DESCRIPTION="This model is a diffusion model for unconditional image generation of Ukiyo-e images ✍ 🎨. \nThe model was train using fine-tuning with the google/ddpm-celebahq-256 pretrain-model and the dataset: https://huggingface.co/datasets/huggan/ukiyoe2photo"
CSS = ".output-image, .input-image, .image-preview {height: 250px !important}"

# Input widgets, in the same order as the positional parameters of generate().
# See the gradio docs for the types of inputs and outputs available.
inputs = [
    gr.ColorPicker(label="color (click on the square to pick the color)", value="#DF5C16"),
    gr.Slider(label="color_guidance_scale (how strong to blend the color)", minimum=0, maximum=30, value=6.7),
    # step=2 so every even count from 2 up to the maximum of 12 is selectable;
    # the previous step of 4 starting at 2 could only reach 2, 6 and 10.
    gr.Slider(label="num_examples (# images generated)", minimum=2, maximum=12, value=2, step=2),
    gr.Number(label="seed (reproducibility and experimentation)", value=666),
    gr.Text(label="Text prompt (optional)", value=None),
    gr.Slider(label="prompt_guidance_scale (...)", minimum=0, maximum=1000, value=10),
    gr.Slider(label="prompt_n_cuts", minimum=4, maximum=12, step=4),
    gr.Slider(label="Number of inference steps (+ steps -> + guidance effect)", minimum=40, maximum=60, value=40, step=1),
]
outputs = gr.Image(label="result")

# And the minimal interface
demo = gr.Interface(
    fn=generate,
    inputs=inputs,
    outputs=outputs,
    css=CSS,
    #examples=[
        #["#DF5C16", 6.7, 12, 666, None, None, None, 40],
        #["#C01660", 13.5, 12, 1990, None, None, None, 40],
        #["#44CCAA", 8.9, 12, 1512, None, None, None, 40],
        #["#39A291", 5.0, 2, 666, "A sakura tree", 60, 4, 40],
        #["#0E0907", 0.0, 12, 666, "A big whale in the ocean", 60, 8, 52],
        #["#19A617", 4.6, 12, 666, "An island with sunset at background", 140, 4, 47],
    #],
    title=TITLE,
    description=DESCRIPTION,
)

if __name__ == "__main__":
    # `launch(enable_queue=True)` is deprecated and is rejected by newer
    # Gradio releases; queuing is enabled through `queue()` instead.
    demo.queue()
    demo.launch()