Spaces:
Runtime error
Runtime error
File size: 6,991 Bytes
cb16212 f9fb4bc cb16212 1173e62 cb16212 1173e62 cb16212 1173e62 cb16212 1173e62 f9fb4bc cb16212 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import torch
from torchvision import transforms as tfms
import numpy as np
import cv2
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr
import os
# Hugging Face access token taken from the environment; None when API_TOKEN is unset.
auth_token = os.environ.get("API_TOKEN")
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_artifacts():
    """Load every Stable Diffusion component used by the DiffEdit helpers.

    Returns:
        A ``(vae, unet, tokenizer, text_encoder, scheduler)`` tuple. The VAE,
        U-Net and text encoder are loaded with ``torch_dtype=torch.float16``
        and moved to ``device``; the tokenizer is not moved, and the DDIM
        scheduler is constructed locally rather than downloaded.
    """
    sd_repo = "CompVis/stable-diffusion-v1-4"
    clip_repo = "openai/clip-vit-large-patch14"
    half = torch.float16

    vae = AutoencoderKL.from_pretrained(
        sd_repo, subfolder="vae", torch_dtype=half, use_auth_token=auth_token
    ).to(device)
    unet = UNet2DConditionModel.from_pretrained(
        sd_repo, subfolder="unet", torch_dtype=half, use_auth_token=auth_token
    ).to(device)
    # NOTE(review): torch_dtype is kept as in the original call; a tokenizer
    # holds no tensors, so it most likely has no effect here — confirm.
    tokenizer = CLIPTokenizer.from_pretrained(
        clip_repo, torch_dtype=half, use_auth_token=auth_token
    )
    text_encoder = CLIPTextModel.from_pretrained(
        clip_repo, torch_dtype=half, use_auth_token=auth_token
    ).to(device)
    # DDIM with a scaled-linear beta schedule, configured explicitly.
    scheduler = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
    )
    return vae, unet, tokenizer, text_encoder, scheduler
def load_image(p):
    """Open the image at path ``p``, convert it to RGB and resize it to 512x512.

    Returns:
        A new PIL image of size (512, 512).
    """
    rgb = Image.open(p).convert('RGB')
    return rgb.resize((512, 512))
def pil_to_latents(image):
    """Encode a PIL image into scaled latents with the global ``vae``.

    The image is converted to a tensor in [0, 1], rescaled to [-1, 1], given
    a leading batch dimension, cast to float16 on ``device`` and encoded. A
    latent is *sampled* from the encoder's distribution (so the result
    depends on the torch RNG state) and multiplied by 0.18215, the same
    factor ``latents_to_pil`` divides out.

    Args:
        image: a PIL image; presumably 512x512 RGB as produced by
            ``load_image`` / the Gradio input — TODO confirm for other callers.

    Returns:
        A batch-of-one latent tensor on ``device`` with no autograd graph attached.
    """
    pixels = tfms.ToTensor()(image).unsqueeze(0) * 2.0 - 1.0
    pixels = pixels.to(device=device, dtype=torch.float16)
    # Inference only: without no_grad the VAE encoder's activations would stay
    # alive in an autograd graph attached to the returned latents.
    with torch.no_grad():
        latents = vae.encode(pixels).latent_dist.sample() * 0.18215
    return latents
def latents_to_pil(latents):
    """Decode scaled latents back into a list of PIL images with the global ``vae``.

    Args:
        latents: a batch of latents scaled by 0.18215 (as returned by
            ``pil_to_latents``); presumably on ``device`` in the VAE's dtype — TODO confirm.

    Returns:
        A list with one ``PIL.Image.Image`` per item in the batch.
    """
    unscaled = (1 / 0.18215) * latents
    with torch.no_grad():
        decoded = vae.decode(unscaled).sample
    # Map the decoder output from [-1, 1] to [0, 1], then to channel-last uint8.
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    arrays = decoded.detach().cpu().permute(0, 2, 3, 1).numpy()
    arrays = (arrays * 255).round().astype("uint8")
    return list(map(Image.fromarray, arrays))
def text_enc(prompts, maxlen=None):
    """Encode text prompts with the global CLIP ``tokenizer`` and ``text_encoder``.

    Args:
        prompts: a string or list of strings to tokenize.
        maxlen: padding/truncation length; defaults to
            ``tokenizer.model_max_length``.

    Returns:
        The first element of the text encoder's output (presumably its last
        hidden state), cast to float16, with no autograd graph attached.
    """
    if maxlen is None:
        maxlen = tokenizer.model_max_length
    tokens = tokenizer(
        prompts,
        padding="max_length",
        max_length=maxlen,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: do not record an autograd graph for the text encoder.
    with torch.no_grad():
        hidden = text_encoder(tokens.input_ids.to(device))[0]
    return hidden.half()
def prompt_2_img_i2i_fast(prompts, init_img, g=7.5, seed=100, strength=0.5, steps=50, dim=512):
    """Predict the clean latent of ``init_img`` from a single guided U-Net step.

    The image is encoded to latents and noised to the timestep selected by
    ``strength``. It is passed once through the U-Net with classifier-free
    guidance. The scheduler's ``pred_original_sample`` is returned instead of
    running a full denoising loop.

    Args:
        prompts: text prompt(s) conditioning the U-Net.
        init_img: PIL image to encode (see ``pil_to_latents``).
        g: classifier-free guidance scale.
        seed: value for ``torch.manual_seed``; ``None`` leaves the RNG
            untouched. ``0`` is a valid seed and is applied.
        strength: fraction of the ``steps`` schedule used to choose the noise
            level; clamped so between 1 and ``steps`` steps are used.
        steps: number of scheduler timesteps.
        dim: unused; kept for interface compatibility.

    Returns:
        The predicted clean latents, detached and moved to the CPU, with a
        leading batch dimension (callers index ``[0]``).
    """
    text = text_enc(prompts)
    # Unconditional embedding for classifier-free guidance, padded to the same length.
    uncond = text_enc([""], text.shape[1])
    emb = torch.cat([uncond, text])

    # `if seed:` would silently skip seeding for seed=0 (used by create_mask_fast).
    if seed is not None:
        torch.manual_seed(seed)

    scheduler.set_timesteps(steps)
    init_latents = pil_to_latents(init_img)

    # Count back from the end of the schedule. Clamp so that int(steps * strength) == 0
    # does not become index -0 == 0 (the first, presumably noisiest, timestep),
    # and values above the schedule length do not raise IndexError.
    init_timestep = min(max(int(steps * strength), 1), len(scheduler.timesteps))
    t = scheduler.timesteps[-init_timestep]
    timesteps = t.reshape(1).to(device)

    noise = torch.randn(init_latents.shape, device=device, dtype=init_latents.dtype)
    latents = scheduler.add_noise(init_latents, noise, timesteps)

    # scale_model_input keeps the U-Net input consistent with the scheduler's expectations.
    inp = scheduler.scale_model_input(torch.cat([latents] * 2), timesteps)
    with torch.no_grad():
        uncond_pred, cond_pred = unet(inp, timesteps, encoder_hidden_states=emb).sample.chunk(2)
    pred = uncond_pred + g * (cond_pred - uncond_pred)

    # Pass the scalar timestep: the scheduler indexes its own tables with it,
    # which presumably live on the CPU, so a device tensor could mismatch.
    latents = scheduler.step(pred, t, latents).pred_original_sample
    return latents.detach().cpu()
def create_mask_fast(init_img, rp, qp, n=20, s=0.5):
    """Estimate a binary edit mask from prompt-conditioned noise differences.

    For each of ``n`` seeds (``0, 100, 200, ...``), the single-step predicted
    latents for the reference prompt ``rp`` and the query prompt ``qp`` are
    computed with the same seed. Their absolute differences are summed over
    iterations and averaged over the first (channel) axis. The result is
    standardized and thresholded at 0.

    Args:
        init_img: image to edit.
        rp: reference prompt describing the current image.
        qp: query prompt describing the target image.
        n: number of seeds to accumulate; must be at least 1.
        s: strength passed to ``prompt_2_img_i2i_fast``.

    Returns:
        A ``uint8`` array of 0/1 values over the latent's spatial positions.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")

    total = None
    for idx in range(n):
        seed = 100 * idx
        orig = prompt_2_img_i2i_fast(prompts=rp, init_img=init_img, strength=s, seed=seed)[0]
        query = prompt_2_img_i2i_fast(prompts=qp, init_img=init_img, strength=s, seed=seed)[0]
        # np.abs is the key step: signed differences would cancel across seeds.
        # Work in float32 because the latents are presumably float16.
        diff = np.abs(np.asarray(orig, dtype=np.float32) - np.asarray(query, dtype=np.float32))
        total = diff if total is None else total + diff

    mask = total.mean(0)
    std = mask.std()
    if std == 0:
        # Identical predictions everywhere: nothing differs, so nothing is masked.
        return np.zeros(mask.shape, dtype="uint8")
    mask = (mask - mask.mean()) / std
    return (mask > 0).astype("uint8")
def improve_mask(mask):
    """Expand a binary mask into neighboring pixels.

    The 0/1 mask is scaled to 0/255 and blurred with a 3x3 Gaussian
    (sigma 1). Every pixel the blur leaves non-zero is kept.

    Args:
        mask: a 0/1 ``uint8`` array, e.g. from ``create_mask_fast``.

    Returns:
        A 0/1 ``uint8`` array.
    """
    blurred = cv2.GaussianBlur(mask * 255, (3, 3), 1)
    return (blurred > 0).astype('uint8')
# Load the DiffEdit components once at import time; the helper functions read
# vae, unet, tokenizer, text_encoder and scheduler as module globals.
vae, unet, tokenizer, text_encoder, scheduler = load_artifacts()

# Separate Stable Diffusion inpainting pipeline used for the final masked edit.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    torch_dtype=torch.float16,
    revision="fp16",
    use_auth_token=auth_token,
)
pipe = pipe.to(device)
def fastDiffEdit(init_img, reference_prompt, query_prompt, g=7.5, seed=100, strength=0.7, steps=20, dim=512):
    """Replace the content described by ``reference_prompt`` with ``query_prompt``.

    Step 1 builds a mask with ``create_mask_fast`` (20 seeds) and expands it
    with ``improve_mask``. Steps 2-3 run the inpainting ``pipe`` on the
    masked region.

    Args:
        init_img: PIL image to edit (the Gradio input is configured as 512x512).
        reference_prompt: text describing the current image.
        query_prompt: text describing the desired result.
        g: guidance scale passed to the inpainting pipeline.
        seed: seed for the inpainting generator; ``None`` means unseeded.
        strength: currently unused (the mask uses ``create_mask_fast``'s
            default strength); kept for interface compatibility.
        steps: number of inpainting inference steps.
        dim: currently unused; kept for interface compatibility.

    Returns:
        The first image produced by the inpainting pipeline.

    Raises:
        ValueError: if no image was supplied.
    """
    if init_img is None:
        raise ValueError("Please upload an image before running DiffEdit.")

    mask = create_mask_fast(init_img=init_img, rp=reference_prompt, qp=query_prompt, n=20)
    mask = improve_mask(mask)
    # The mask is at latent resolution (presumably 64x64): scale 0/1 to 0/255
    # and upsample it to the 512x512 image size expected by the pipeline.
    mask_image = Image.fromarray(mask * 255).resize((512, 512))

    # g and seed were previously ignored (hard-coded seed 100, pipeline-default
    # guidance); their defaults here reproduce that behavior.
    generator = torch.Generator(device).manual_seed(seed) if seed is not None else None
    output = pipe(
        prompt=query_prompt,
        image=init_img,
        mask_image=mask_image,
        guidance_scale=g,
        generator=generator,
        num_inference_steps=steps,
    ).images
    return output[0]
# Gradio UI: the three inputs map positionally to fastDiffEdit's
# (init_img, reference_prompt, query_prompt); all other arguments use their defaults.
demo = gr.Interface(
    fn=fastDiffEdit,
    inputs=[
        # gr.Image instead of the deprecated gr.inputs namespace, consistent
        # with the gr.Textbox components below.
        gr.Image(shape=(512, 512), type="pil", label="Upload your image photo"),
        gr.Textbox(label="Describe your image. Ex: a horse image"),
        gr.Textbox(label="Retype the description with target output. Ex: a zebra image"),
    ],
    outputs="image",
    title="DiffEdit demo",
    description="DiffEdit paper demo. Upload an image, pass reference prompt describing the image, pass query prompt to replace the object with target object",
    examples=[
        ["fruitbowl.jpg", "a bowl of fruit", "a bowl of grapes"],
        ["horse.jpg", "a horse image", "a zebra image"],
    ],
    enable_queue=True,
)
demo.launch()