Spaces · runtime status: Runtime error

Commit cb16212 · Parent: 5233387
Uploading diffedit app

Files changed:
- .gitattributes +4 -0
- Gradio Demo.ipynb +0 -0
- app.py +179 -0
- fruitbowl.jpg +0 -0
- horse.jpg +3 -0
- packages.txt +1 -0
- requirements.txt +10 -0
.gitattributes
CHANGED
@@ -32,3 +32,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+horse.jpg filter=lfs diff=lfs merge=lfs -text
+fruitbowl.jpg filter=lfs diff=lfs merge=lfs -text
+*jpg filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
Gradio Demo.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
app.py
ADDED
import torch
from torchvision import transforms as tfms
import numpy as np
import cv2
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_artifacts():
    '''
    Load all diffusion artifacts: VAE, U-Net, tokenizer, text encoder, and scheduler.
    '''
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16).to(device)
    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to(device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to(device)
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
    return vae, unet, tokenizer, text_encoder, scheduler

def load_image(p):
    '''
    Load an image from a path, convert it to RGB, and resize it to 512x512.
    '''
    return Image.open(p).convert('RGB').resize((512, 512))

def pil_to_latents(image):
    '''
    Convert a PIL image to scaled VAE latents.
    '''
    init_image = tfms.ToTensor()(image).unsqueeze(0) * 2.0 - 1.0  # scale pixel values to [-1, 1]
    init_image = init_image.to(device=device, dtype=torch.float16)
    init_latent_dist = vae.encode(init_image).latent_dist.sample() * 0.18215  # SD v1 latent scaling factor
    return init_latent_dist
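As a quick sanity check (separate from the app itself), the encode/decode pair can be exercised on one of the bundled images; latents_to_pil is defined just below, and the 4x64x64 latent shape is a property of the Stable Diffusion v1 VAE at 512x512 resolution:

# Hypothetical round-trip check, assuming a CUDA device and the bundled horse.jpg:
img = load_image("horse.jpg")
lat = pil_to_latents(img)
print(lat.shape)            # torch.Size([1, 4, 64, 64])
latents_to_pil(lat)[0]      # decoded image; lossy but visually close to the input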
def latents_to_pil(latents):
    '''
    Convert VAE latents back to a list of PIL images.
    '''
    latents = (1 / 0.18215) * latents  # undo the latent scaling factor
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)  # map [-1, 1] back to [0, 1]
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images

def text_enc(prompts, maxlen=None):
    '''
    Convert a textual prompt into CLIP text embeddings.
    '''
    if maxlen is None: maxlen = tokenizer.model_max_length
    inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
    return text_encoder(inp.input_ids.to(device))[0].half()
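For reference, the CLIP ViT-L/14 tokenizer pads every prompt to 77 tokens and the encoder returns 768-dimensional embeddings, so a single prompt yields a (1, 77, 768) tensor; a minimal check:

# Hypothetical shape check:
emb = text_enc(["a horse image"])
print(emb.shape)            # torch.Size([1, 77, 768])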
def prompt_2_img_i2i_fast(prompts, init_img, g=7.5, seed=100, strength=0.5, steps=50, dim=512):
    """
    Fast image-to-image diffusion: noise the init image to an intermediate
    timestep, then estimate the fully denoised latents in a single U-Net pass.
    """
    # Converting the textual prompts to embeddings
    text = text_enc(prompts)

    # Adding an unconditional (empty) prompt, needed for classifier-free guidance
    uncond = text_enc([""], text.shape[1])
    emb = torch.cat([uncond, text])

    # Setting the seed
    if seed: torch.manual_seed(seed)

    # Setting the number of steps in the scheduler
    scheduler.set_timesteps(steps)

    # Converting the seed image to latents
    init_latents = pil_to_latents(init_img)

    # Figuring out the initial timestep based on strength
    init_timestep = int(steps * strength)
    timesteps = scheduler.timesteps[-init_timestep]
    timesteps = torch.tensor([timesteps], device=device)

    # Adding noise to the latents
    noise = torch.randn(init_latents.shape, generator=None, device=device, dtype=init_latents.dtype)
    init_latents = scheduler.add_noise(init_latents, noise, timesteps)
    latents = init_latents

    # Scaling the input latents to match the variance expected by the scheduler
    inp = scheduler.scale_model_input(torch.cat([latents] * 2), timesteps)
    # Predicting the noise residual with the U-Net (unconditional and conditioned halves)
    with torch.no_grad(): u, t = unet(inp, timesteps, encoder_hidden_states=emb).sample.chunk(2)

    # Performing classifier-free guidance
    pred = u + g * (t - u)

    # One-shot estimate of the fully denoised sample
    latents = scheduler.step(pred, timesteps, latents).pred_original_sample

    # Returning the latent representation, a 4x64x64 array per image
    return latents.detach().cpu()
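The single U-Net pass above trades quality for speed: instead of iterating the scheduler, it jumps straight to the scheduler's estimate of the fully denoised sample, which is good enough for differencing two prompts. The estimate can be previewed by decoding it; a hypothetical snippet (the latents come back on the CPU, so they are moved to the model device first):

# Hypothetical preview of the one-pass denoised estimate, assuming a CUDA device:
img = load_image("horse.jpg")
approx = prompt_2_img_i2i_fast("a horse image", img, strength=0.5)
latents_to_pil(approx.to(device))[0].save("preview.jpg")   # output name is illustrative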
def create_mask_fast(init_img, rp, qp, n=20, s=0.5):
    '''
    Derive an edit mask by contrasting denoised predictions for the
    reference prompt (rp) against the query prompt (qp).
    '''
    ## Initialize a dictionary to store the n difference maps
    diff = {}

    ## Repeating the differencing process n times with different seeds
    for idx in range(n):
        ## Creating a denoised sample using the reference / original text
        orig_noise = prompt_2_img_i2i_fast(prompts=rp, init_img=init_img, strength=s, seed=100*idx)[0]
        ## Creating a denoised sample using the query / target text
        query_noise = prompt_2_img_i2i_fast(prompts=qp, init_img=init_img, strength=s, seed=100*idx)[0]
        ## Taking the difference
        diff[idx] = (np.array(orig_noise) - np.array(query_noise))

    ## Creating a mask placeholder
    mask = np.zeros_like(diff[0])

    ## Summing the absolute differences over the n iterations
    for idx in range(n):
        ## Note: np.abs is a key step
        mask += np.abs(diff[idx])

    ## Averaging across the latent channels
    mask = mask.mean(0)

    ## Normalizing to zero mean and unit variance
    mask = (mask - mask.mean()) / np.std(mask)

    ## Binarizing and returning the mask
    return (mask > 0).astype("uint8")

def improve_mask(mask):
    '''
    Smooth and re-binarize the mask with a light Gaussian blur.
    '''
    mask = cv2.GaussianBlur(mask*255, (3, 3), 1) > 0
    return mask.astype('uint8')
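Before inpainting, the 64x64 binary mask can be inspected by upscaling it to image resolution and overlaying it on the input; a hypothetical visualization sketch:

# Hypothetical mask overlay (output name is illustrative):
img = load_image("horse.jpg")
m = improve_mask(create_mask_fast(init_img=img, rp="a horse image", qp="a zebra image", n=20))
mask_img = Image.fromarray(m * 255).resize((512, 512)).convert("RGB")
Image.blend(img, mask_img, alpha=0.5).save("mask_overlay.jpg")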
vae, unet, tokenizer, text_encoder, scheduler = load_artifacts()
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
).to(device)

def fastDiffEdit(init_img, reference_prompt, query_prompt, g=7.5, seed=100, strength=0.7, steps=20, dim=512):
    '''
    Full DiffEdit flow: derive a mask from the two prompts, then inpaint
    the masked region with the query prompt.
    '''
    ## Step 1: Create the mask
    mask = create_mask_fast(init_img=init_img, rp=reference_prompt, qp=query_prompt, n=20)

    ## Improve the mask using an OpenCV smoothing trick
    mask = improve_mask(mask)

    ## Steps 2 and 3: Run the masked inpainting diffusion process
    output = pipe(
        prompt=query_prompt,
        image=init_img,
        mask_image=Image.fromarray(mask*255).resize((512, 512)),
        generator=torch.Generator(device).manual_seed(seed),  # use the seed argument rather than a hard-coded value
        num_inference_steps=steps
    ).images
    return output[0]

demo = gr.Interface(
    fn=fastDiffEdit,
    inputs=[
        gr.inputs.Image(shape=(512, 512), type="pil", label="Upload your photo"),
        gr.Textbox(label="Describe your image. Ex: a horse image"),
        gr.Textbox(label="Retype the description with the target object. Ex: a zebra image")],
    outputs="image",
    title="DiffEdit demo",
    description="Demo of the DiffEdit paper. Upload an image, give a reference prompt describing it, and a query prompt that replaces the object with the target object.",
    examples=[
        ["fruitbowl.jpg", "a bowl of fruit", "a bowl of grapes"],
        ["horse.jpg", "a horse image", "a zebra image"]],
    enable_queue=True
)

demo.launch()
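Outside the Gradio UI, the same edit can be run directly; a minimal hypothetical driver using the bundled example image:

# Hypothetical direct invocation, bypassing the interface:
img = load_image("horse.jpg")
edited = fastDiffEdit(img, "a horse image", "a zebra image", steps=20)
edited.save("zebra.jpg")    # output name is illustrative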
fruitbowl.jpg
ADDED

horse.jpg
ADDED (stored with Git LFS)
packages.txt
ADDED
python3-opencv
requirements.txt
ADDED
--extra-index-url https://download.pytorch.org/whl/cu113
torch
torchvision

Pillow
opencv-python
ftfy
transformers==4.23.1
diffusers==0.6.0