Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision | |
| from datasets import load_dataset | |
| from diffusers import DDIMScheduler,DDPMPipeline | |
| from matplotlib import pyplot as plt | |
| from PIL import Image | |
| from torchvision import transforms | |
| from tqdm.auto import tqdm | |
| import open_clip | |
# Backend setup: pick the compute device, then load the pretrained diffusion
# pipeline, the CLIP model used for guidance losses, and a DDIM scheduler.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the pretrained (unconditional) DDPM pipeline and move it to the device.
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)

# CLIP ViT-B-32 with OpenAI weights. open_clip hands the model back in train
# mode, so switch it to eval mode explicitly: it is used only to score images,
# never trained here.
clip_model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="openai"
)
clip_model.to(device)
clip_model.eval()

# DDIM scheduler built from the same repo, configured for 40 inference steps.
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Return the mean absolute distance between ``images`` and a target color.

    Args:
        images: tensor whose dim 1 is the channel axis (the target is reshaped
            to broadcast as ``(1, C, 1, 1)``); presumably ``(B, C, H, W)`` with
            pixel values in [-1, 1] -- TODO confirm against the caller.
        target_color: per-channel color in [0, 1]; it is rescaled to [-1, 1]
            before comparison. Defaults to a light teal ``(0.1, 0.9, 0.5)``.

    Returns:
        0-dim tensor: mean of ``|images - target|`` over every element.
    """
    # Map the [0, 1] color onto the [-1, 1] scale, on the same device as images.
    rescaled = torch.tensor(target_color).to(images.device) * 2 - 1
    # Insert singleton batch/height/width axes so it broadcasts per channel.
    reference = rescaled[None, :, None, None]
    deviation = (images - reference).abs()
    return deviation.mean()
| tfms = torchvision.transforms.Compose( | |
| [ | |
| torchvision.transforms.RandomResizedCrop(224), | |
| torchvision.transforms.RandomAffine(5), | |
| torchvision.transforms.RandomHorizontalFlip(), | |
| torchvision.transforms.Normalize( | |
| mean=(0.48145466, 0.4578275, 0.40821073), | |
| std=(0.26862954, 0.26130258, 0.27577711), | |
| ) | |
| ] | |
| ) | |
def clip_loss(image, text_features):
    """Return the mean squared great-circle distance between CLIP embeddings.

    The image batch is first passed through the module-level ``tfms``
    augmentation/normalisation pipeline and encoded with ``clip_model``.
    Every image embedding is then compared against every text embedding.

    Args:
        image: image tensor accepted by ``tfms`` and ``clip_model.encode_image``.
        text_features: text embeddings; must be at least 2-D (presumably
            ``(num_prompts, dim)`` -- TODO confirm) since the code normalises
            along axis 2 after adding a leading axis.

    Returns:
        0-dim tensor: mean over all image/text pairs of
        ``2 * arcsin(|a - b| / 2) ** 2`` for unit vectors ``a`` and ``b``.
    """
    augmented = tfms(image)
    img_emb = clip_model.encode_image(augmented)

    # (N_img, 1, D) against (1, N_txt, D): broadcasting yields all pairs.
    img_unit = F.normalize(img_emb.unsqueeze(1), dim=2)
    txt_unit = F.normalize(text_features.unsqueeze(0), dim=2)

    # Chord length between unit vectors -> squared great-circle distance.
    chord = (img_unit - txt_unit).norm(dim=2)
    pairwise = 2 * torch.arcsin(chord / 2) ** 2
    return pairwise.mean()
| n_cuts = 4 # @param |