Spaces: Runtime error
Commit c771a6a
Duplicate from LinoyTsaban/edit_friendly_ddpm_inversion
Co-authored-by: Linoy Tsaban <LinoyTsaban@users.noreply.huggingface.co>
- .gitattributes +36 -0
- Examples/ddpm_a_bronze_statue_of_an_old_man.png +0 -0
- Examples/ddpm_a_pink_ceramic_vase_with_a_wheat_bouquet.png +0 -0
- Examples/ddpm_a_zebra_on_the_run_way.png +0 -0
- Examples/gnochi_mirror.jpeg +0 -0
- Examples/gnochi_mirror_reconstrcution.png +0 -0
- Examples/gnochi_mirror_watercolor_painting.png +0 -0
- Examples/source_a_ceramic_vase_with_yellow_flowers.jpeg +3 -0
- Examples/source_a_model_on_a_runway.jpeg +3 -0
- Examples/source_an_old_man.png +0 -0
- README.md +22 -0
- app.py +246 -0
- inversion_utils.py +295 -0
- requirements.txt +5 -0
- style.css +121 -0
- utils.py +116 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+Examples/source_a_ceramic_vase_with_yellow_flowers.jpeg filter=lfs diff=lfs merge=lfs -text
+Examples/source_a_model_on_a_runway.jpeg filter=lfs diff=lfs merge=lfs -text
Examples/ddpm_a_bronze_statue_of_an_old_man.png ADDED
Examples/ddpm_a_pink_ceramic_vase_with_a_wheat_bouquet.png ADDED
Examples/ddpm_a_zebra_on_the_run_way.png ADDED
Examples/gnochi_mirror.jpeg ADDED
Examples/gnochi_mirror_reconstrcution.png ADDED
Examples/gnochi_mirror_watercolor_painting.png ADDED
Examples/source_a_ceramic_vase_with_yellow_flowers.jpeg ADDED (Git LFS)
Examples/source_a_model_on_a_runway.jpeg ADDED (Git LFS)
Examples/source_an_old_man.png ADDED
README.md ADDED
@@ -0,0 +1,22 @@
+---
+title: Edit Friendly Ddpm Inversion
+emoji: 🖼️
+colorFrom: pink
+colorTo: orange
+sdk: gradio
+sdk_version: 3.32.0
+app_file: app.py
+pinned: false
+duplicated_from: LinoyTsaban/edit_friendly_ddpm_inversion
+---
+
+## BibTeX
+
+```
+@article{HubermanSpiegelglas2023,
+  title  = {An Edit Friendly DDPM Noise Space: Inversion and Manipulations},
+  author = {Huberman-Spiegelglas, Inbar and Kulikov, Vladimir and Michaeli, Tomer},
+  journal = {arXiv preprint arXiv:2304.06140},
+  year   = {2023}
+}
+```
app.py ADDED
@@ -0,0 +1,246 @@
+import gradio as gr
+import torch
+import random
+import requests
+import numpy as np  # fixed: np is used below but was only reachable through the star imports
+from io import BytesIO
+from diffusers import StableDiffusionPipeline
+from diffusers import DDIMScheduler
+from utils import *
+from inversion_utils import *
+from torch import autocast, inference_mode
+import re
+
+def randomize_seed_fn(seed, randomize_seed):
+    # draw a fresh seed when requested, then seed torch for reproducibility
+    if randomize_seed:
+        seed = random.randint(0, np.iinfo(np.int32).max)
+    torch.manual_seed(seed)
+    return seed
+
+def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
+
+    # Inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
+    # based on the code in https://github.com/inbarhub/DDPM_inversion
+
+    # returns zs, wts:
+    #   zs  - noise maps
+    #   wts - intermediate inverted latents
+    # (the fully inverted latent wt is computed but not returned)
+
+    sd_pipe.scheduler.set_timesteps(num_diffusion_steps)
+
+    # vae encode image
+    with autocast("cuda"), inference_mode():
+        w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()
+
+    # find Zs and wts - forward process
+    wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=False, num_inference_steps=num_diffusion_steps)
+    return zs, wts
+
+def sample(zs, wts, prompt_tar="", skip=36, cfg_scale_tar=15, eta=1):
+
+    # reverse process (via Zs and wT)
+    w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=False, zs=zs[skip:])
+
+    # vae decode image
+    with autocast("cuda"), inference_mode():
+        x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
+    if x0_dec.dim() < 4:
+        x0_dec = x0_dec[None, :, :, :]
+    img = image_grid(x0_dec)
+    return img
+
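Taken together, `invert` and `sample` form the whole editing round trip: the image is VAE-encoded, Algorithm 1 extracts per-step noise maps `zs` and intermediate latents `wts`, and the reverse process re-denoises from `wts[skip]` under the target prompt. A minimal sketch of driving them outside the Gradio UI (hypothetical usage, not part of the app; it assumes the `sd_pipe` and `device` globals defined just below and one of the bundled example images):

```python
# Hypothetical usage sketch (not in the original Space).
from inversion_utils import load_512

src = load_512('Examples/gnochi_mirror.jpeg', device=device)  # 1x3x512x512, values in [-1, 1]
zs, wts = invert(src, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5)

# skip trades edit strength for fidelity: sampling restarts at wts[skip],
# so a larger skip preserves more of the source image.
edited = sample(zs, wts,
                prompt_tar="Watercolor painting of a cat sitting next to a mirror",
                skip=36, cfg_scale_tar=15)
edited.save("edited.png")  # sample() returns a PIL image via image_grid()
```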
+# load pipelines
+sd_model_id = "runwayml/stable-diffusion-v1-5"
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
+sd_pipe.scheduler = DDIMScheduler.from_config(sd_model_id, subfolder="scheduler")
+
+def get_example():
+    # each row: [input image, target prompt, output image, source prompt,
+    #            steps, source guidance, skip, target guidance]
+    case = [
+        [
+            'Examples/gnochi_mirror.jpeg',
+            'Watercolor painting of a cat sitting next to a mirror',
+            'Examples/gnochi_mirror_watercolor_painting.png',
+            '',
+            100,
+            3.5,
+            36,
+            15,
+        ],
+        [
+            'Examples/source_an_old_man.png',
+            'A bronze statue of an old man',
+            'Examples/ddpm_a_bronze_statue_of_an_old_man.png',
+            '',
+            100,
+            3.5,
+            36,
+            15,
+        ],
+        [
+            'Examples/source_a_ceramic_vase_with_yellow_flowers.jpeg',
+            'A pink ceramic vase with a wheat bouquet',
+            'Examples/ddpm_a_pink_ceramic_vase_with_a_wheat_bouquet.png',
+            '',
+            100,
+            3.5,
+            36,
+            15,
+        ],
+        [
+            'Examples/source_a_model_on_a_runway.jpeg',
+            'A zebra on the runway',
+            'Examples/ddpm_a_zebra_on_the_run_way.png',
+            '',
+            100,
+            3.5,
+            36,
+            15,
+        ],
+    ]
+    return case
+
+########
+# demo #
+########
+
+intro = """
+<h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
+   Edit Friendly DDPM Inversion
+</h1>
+<p style="font-size: 0.9rem; text-align: center; margin: 0rem; line-height: 1.2em; margin-top:1em">
+Based on the work introduced in:
+<a href="https://arxiv.org/abs/2304.06140" style="text-decoration: underline;" target="_blank">An Edit Friendly DDPM Noise Space:
+Inversion and Manipulations</a>
+</p>
+<p style="font-size: 0.9rem; margin: 0rem; line-height: 1.2em; margin-top:1em">
+For faster inference without waiting in a queue, you may duplicate the Space and upgrade to a GPU in the settings.
+<a href="https://huggingface.co/spaces/LinoyTsaban/edit_friendly_ddpm_inversion?duplicate=true">
+<img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+</p>"""
+with gr.Blocks(css='style.css') as demo:
+
+    def reset_do_inversion():
+        # any change to the input image or source prompt invalidates the cached latents
+        do_inversion = True
+        return do_inversion
+
+    def edit(input_image,
+             do_inversion,
+             wts, zs,
+             src_prompt="",
+             tar_prompt="",
+             steps=100,
+             cfg_scale_src=3.5,
+             cfg_scale_tar=15,
+             skip=36,
+             seed=0,
+             randomize_seed=True):
+
+        x0 = load_512(input_image, device=device)
+
+        # invert only when needed; otherwise reuse the cached noise maps and latents
+        if do_inversion or randomize_seed:
+            zs_tensor, wts_tensor = invert(x0=x0, prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src)
+            wts = gr.State(value=wts_tensor)
+            zs = gr.State(value=zs_tensor)
+            do_inversion = False
+
+        output = sample(zs.value, wts.value, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=cfg_scale_tar)
+        return output, wts, zs, do_inversion
+
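The `do_inversion` flag implements the app's caching: the expensive inversion runs only when the input image or source prompt changes, while a new target prompt reuses the stored `wts` and `zs`. The same pattern without Gradio state, as a hedged sketch (the `cache` dict and `edit_cached` name are illustrative, not from the source):

```python
# Hypothetical sketch of the caching pattern used by edit() above.
cache = {"zs": None, "wts": None, "dirty": True}

def edit_cached(x0, src_prompt, tar_prompt, skip=36, cfg_scale_tar=15):
    if cache["dirty"]:  # re-invert only when the source image/prompt changed
        cache["zs"], cache["wts"] = invert(x0, prompt_src=src_prompt)
        cache["dirty"] = False
    return sample(cache["zs"], cache["wts"], prompt_tar=tar_prompt,
                  skip=skip, cfg_scale_tar=cfg_scale_tar)
```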
+    gr.HTML(intro)
+    wts = gr.State()
+    zs = gr.State()
+    do_inversion = gr.State(value=True)
+    with gr.Row():
+        input_image = gr.Image(label="Input Image", interactive=True)
+        input_image.style(height=365, width=365)
+        output_image = gr.Image(label="Edited Image", interactive=False)
+        output_image.style(height=365, width=365)
+
+    with gr.Row():
+        tar_prompt = gr.Textbox(lines=1, label="Describe your desired edited output", interactive=True)
+
+    with gr.Row():
+        with gr.Column(scale=1, min_width=100):
+            edit_button = gr.Button("Run")
+
+    with gr.Accordion("Advanced Options", open=False):
+        with gr.Row():
+            with gr.Column():
+                # inversion
+                src_prompt = gr.Textbox(lines=1, label="Source Prompt", interactive=True, placeholder="describe the original image")
+                steps = gr.Number(value=100, precision=0, label="Num Diffusion Steps", interactive=True)
+                cfg_scale_src = gr.Slider(minimum=1, maximum=15, value=3.5, label="Source Guidance Scale", interactive=True)
+            with gr.Column():
+                # reconstruction
+                skip = gr.Slider(minimum=0, maximum=60, value=36, step=1, label="Skip Steps", interactive=True)
+                cfg_scale_tar = gr.Slider(minimum=7, maximum=18, value=15, label="Target Guidance Scale", interactive=True)
+                seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
+                randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
+
+    edit_button.click(
+        fn=randomize_seed_fn,
+        inputs=[seed, randomize_seed],
+        outputs=[seed], queue=False).then(
+        fn=edit,
+        inputs=[input_image,
+                do_inversion, wts, zs,
+                src_prompt,
+                tar_prompt,
+                steps,
+                cfg_scale_src,
+                cfg_scale_tar,
+                skip,
+                seed, randomize_seed
+                ],
+        outputs=[output_image, wts, zs, do_inversion],
+    )
+
+    input_image.change(
+        fn=reset_do_inversion,
+        outputs=[do_inversion]
+    )
+
+    src_prompt.change(
+        fn=reset_do_inversion,
+        outputs=[do_inversion]
+    )
+
+    gr.Examples(
+        label='Examples',
+        examples=get_example(),
+        # fixed: cfg_scale_tar appeared twice; per the example rows the sixth
+        # input is the source guidance scale (3.5), then skip (36), then target (15)
+        inputs=[input_image, tar_prompt, output_image, src_prompt, steps,
+                cfg_scale_src,
+                skip,
+                cfg_scale_tar
+                ],
+        outputs=[output_image],
+    )
+
+demo.queue()
+demo.launch(share=False)
inversion_utils.py ADDED
@@ -0,0 +1,295 @@
+import torch
+import os
+from tqdm import tqdm
+from PIL import Image, ImageDraw, ImageFont
+from matplotlib import pyplot as plt
+import torchvision.transforms as T
+import yaml
+import numpy as np
+import gradio as gr
+
+# This file was copied from the DDPM inversion repo - https://github.com/inbarhub/DDPM_inversion
+
+def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
+    # load an image (path or array), crop the requested margins,
+    # center-crop to square, resize to 512x512, and scale to [-1, 1]
+    if type(image_path) is str:
+        image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
+    else:
+        image = image_path
+    h, w, c = image.shape
+    left = min(left, w - 1)
+    right = min(right, w - left - 1)
+    top = min(top, h - 1)  # fixed: the original clamped top against `left` instead of the height
+    bottom = min(bottom, h - top - 1)
+    image = image[top:h - bottom, left:w - right]
+    h, w, c = image.shape
+    if h < w:
+        offset = (w - h) // 2
+        image = image[:, offset:offset + h]
+    elif w < h:
+        offset = (h - w) // 2
+        image = image[offset:offset + w]
+    image = np.array(Image.fromarray(image).resize((512, 512)))
+    image = torch.from_numpy(image).float() / 127.5 - 1
+    image = image.permute(2, 0, 1).unsqueeze(0).to(device)
+
+    return image
+
+def load_real_image(folder="data/", img_name=None, idx=0, img_size=512, device='cuda'):
+    from PIL import Image
+    from glob import glob
+    from utils import pil_to_tensor  # fixed: pil_to_tensor is defined in utils.py, not here
+    if img_name is not None:
+        path = os.path.join(folder, img_name)
+    else:
+        path = glob(folder + "*")[idx]
+
+    img = Image.open(path).resize((img_size, img_size))
+    img = pil_to_tensor(img).to(device)
+
+    if img.shape[1] == 4:  # drop an alpha channel if present
+        img = img[:, :3, :, :]
+    return img
+
+def mu_tilde(model, xt, x0, timestep):
+    "mu_tilde(x_t, x_0) DDPM paper eq. 7"
+    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+    alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+    alpha_t = model.scheduler.alphas[timestep]
+    beta_t = 1 - alpha_t
+    alpha_bar = model.scheduler.alphas_cumprod[timestep]
+    return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1 - alpha_bar)) * x0 + ((alpha_t ** 0.5 * (1 - alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
+
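For reference, `mu_tilde` computes the posterior mean of the diffusion forward process (eq. 7 of the DDPM paper), in the code's cumulative-product notation:

$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1 - \bar\alpha_t}\, x_0 + \frac{\sqrt{\alpha_t}\,\left(1 - \bar\alpha_{t-1}\right)}{1 - \bar\alpha_t}\, x_t$$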
+def sample_xts_from_x0(model, x0, num_inference_steps=50):
+    """
+    Samples from P(x_1:T|x_0)
+    """
+    # torch.manual_seed(43256465436)
+    alpha_bar = model.scheduler.alphas_cumprod
+    sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
+    alphas = model.scheduler.alphas
+    betas = 1 - alphas
+    variance_noise_shape = (
+        num_inference_steps,
+        model.unet.in_channels,
+        model.unet.sample_size,
+        model.unet.sample_size)
+
+    timesteps = model.scheduler.timesteps.to(model.device)
+    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+    xts = torch.zeros(variance_noise_shape).to(x0.device)
+    for t in reversed(timesteps):
+        idx = t_to_idx[int(t)]
+        # each xt is drawn directly (and independently) from q(x_t | x_0)
+        xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
+    xts = torch.cat([xts, x0], dim=0)
+
+    return xts
+
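This sampler is the heart of the edit-friendly construction (arXiv:2304.06140): each latent is drawn directly from the marginal

$$x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\;\tilde\epsilon_t, \qquad \tilde\epsilon_t \sim \mathcal{N}(0, I),$$

with an independent $\tilde\epsilon_t$ per timestep, so consecutive latents are statistically independent rather than strongly correlated as in the regular DDPM forward chain.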
+def encode_text(model, prompts):
+    text_input = model.tokenizer(
+        prompts,
+        padding="max_length",
+        max_length=model.tokenizer.model_max_length,
+        truncation=True,
+        return_tensors="pt",
+    )
+    with torch.no_grad():
+        text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
+    return text_encoding
+
+def forward_step(model, model_output, timestep, sample):
+    next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
+                        timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
+
+    # 2. compute alphas, betas
+    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+    # alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] if next_timestep >= 0 else self.scheduler.final_alpha_cumprod
+
+    beta_prod_t = 1 - alpha_prod_t
+
+    # 3. compute predicted original sample from predicted noise, also called
+    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+
+    # 5. TODO: simple noising implementation
+    next_sample = model.scheduler.add_noise(pred_original_sample,
+                                            model_output,
+                                            torch.LongTensor([next_timestep]))
+    return next_sample
+
+
+def get_variance(model, timestep):
+    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+    alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+    beta_prod_t = 1 - alpha_prod_t
+    beta_prod_t_prev = 1 - alpha_prod_t_prev
+    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+    return variance
+
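`get_variance` is the per-step posterior variance of DDIM (formula 16 in arXiv:2010.02502), rewritten with the $\bar\alpha$ products the scheduler exposes:

$$\sigma_t^2 = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\left(1 - \frac{\bar\alpha_t}{\bar\alpha_{t-1}}\right)$$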
+def inversion_forward_process(model, x0,
+                              etas=None,
+                              prog_bar=False,
+                              prompt="",
+                              cfg_scale=3.5,
+                              num_inference_steps=50, eps=None):
+
+    if not prompt == "":
+        text_embeddings = encode_text(model, prompt)
+    uncond_embedding = encode_text(model, "")
+    timesteps = model.scheduler.timesteps.to(model.device)
+    variance_noise_shape = (
+        num_inference_steps,
+        model.unet.in_channels,
+        model.unet.sample_size,
+        model.unet.sample_size)
+    if etas is None or (type(etas) in [int, float] and etas == 0):
+        eta_is_zero = True
+        zs = None
+        xts = None  # fixed: keeps the final `return xt, zs, xts` valid on this path
+    else:
+        eta_is_zero = False
+        if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
+        xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
+        alpha_bar = model.scheduler.alphas_cumprod
+        zs = torch.zeros(size=variance_noise_shape, device=model.device)
+
+    t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+    xt = x0
+    op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
+
+    for t in op:
+        idx = t_to_idx[int(t)]
+        # 1. predict noise residual
+        if not eta_is_zero:
+            xt = xts[idx][None]
+
+        with torch.no_grad():
+            out = model.unet.forward(xt, timestep=t, encoder_hidden_states=uncond_embedding)
+            if not prompt == "":
+                cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states=text_embeddings)
+
+        if not prompt == "":
+            # classifier-free guidance
+            noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
+        else:
+            noise_pred = out.sample
+
+        if eta_is_zero:
+            # 2. compute more noisy image and set x_t -> x_t+1
+            xt = forward_step(model, noise_pred, t, xt)
+
+        else:
+            xtm1 = xts[idx + 1][None]
+            # pred of x0
+            pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
+
+            # direction to xt
+            prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+            alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+
+            variance = get_variance(model, t)
+            pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** 0.5 * noise_pred
+
+            mu_xt = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
+
+            z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
+            zs[idx] = z
+
+            # correction to avoid error accumulation
+            xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
+            xts[idx + 1] = xtm1
+
+    if zs is not None:
+        zs[-1] = torch.zeros_like(zs[-1])
+
+    return xt, zs, xts
+
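In the $\eta > 0$ branch the noise map is recovered by solving the reverse-step update for its stochastic term. With $\hat{x}_0$ the predicted clean latent and $\hat\epsilon$ the guided noise prediction, the loop computes

$$\mu_t = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_{t-1} - \eta_t\,\sigma_t^2}\;\hat\epsilon, \qquad z_t = \frac{x_{t-1} - \mu_t}{\eta_t\,\sigma_t},$$

and then rewrites $x_{t-1} = \mu_t + \eta_t\,\sigma_t\, z_t$ so that later steps see exactly the value the stored $z_t$ reproduces, which is the error-accumulation correction noted in the code.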
+def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
+    # 1. get previous step value (=t-1)
+    prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+    # 2. compute alphas, betas
+    alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+    alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+    beta_prod_t = 1 - alpha_prod_t
+    # 3. compute predicted original sample from predicted noise, also called
+    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+    # 5. compute variance: "sigma_t(eta)" -> see formula (16)
+    # sigma_t = sqrt((1 - alpha_{t-1}) / (1 - alpha_t)) * sqrt(1 - alpha_t / alpha_{t-1})
+    # variance = self.scheduler._get_variance(timestep, prev_timestep)
+    variance = get_variance(model, timestep)
+    std_dev_t = eta * variance ** 0.5
+    # take care of the asymmetric reverse process (asyrp)
+    model_output_direction = model_output
+    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
+    pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** 0.5 * model_output_direction
+    # 7. compute x_{t-1} without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+    prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
+    # 8. add noise if eta > 0
+    if eta > 0:
+        if variance_noise is None:
+            variance_noise = torch.randn(model_output.shape, device=model.device)
+        sigma_z = eta * variance ** 0.5 * variance_noise
+        prev_sample = prev_sample + sigma_z
+
+    return prev_sample
+
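`reverse_step` is formula (12) of DDIM with the stochastic term made explicit; with this file's $\eta$-weighting it computes

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\;\hat{x}_0 + \sqrt{1 - \bar\alpha_{t-1} - \eta\,\sigma_t^2}\;\epsilon_\theta(x_t) + \eta\,\sigma_t\, z,$$

where $z$ is `variance_noise`, i.e. the stored noise map during editing or fresh Gaussian noise otherwise. Note the intentional departure from plain DDIM, which places $\eta^2\sigma_t^2$ under the square root, as the commented-out `std_dev_t**2` line records.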
+def inversion_reverse_process(model,
+                              xT,
+                              etas=0,
+                              prompts="",
+                              cfg_scales=None,
+                              prog_bar=False,
+                              zs=None,
+                              controller=None,
+                              asyrp=False):
+
+    batch_size = len(prompts)
+
+    cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1, 1, 1, 1).to(model.device)
+
+    text_embeddings = encode_text(model, prompts)
+    uncond_embedding = encode_text(model, [""] * batch_size)
+
+    if etas is None: etas = 0
+    if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
+    assert len(etas) == model.scheduler.num_inference_steps
+    timesteps = model.scheduler.timesteps.to(model.device)
+
+    xt = xT.expand(batch_size, -1, -1, -1)
+    op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
+
+    t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
+
+    for t in op:
+        idx = t_to_idx[int(t)]
+        # unconditional embedding
+        with torch.no_grad():
+            uncond_out = model.unet.forward(xt, timestep=t,
+                                            encoder_hidden_states=uncond_embedding)
+
+        # conditional embedding
+        if prompts:
+            with torch.no_grad():
+                cond_out = model.unet.forward(xt, timestep=t,
+                                              encoder_hidden_states=text_embeddings)
+
+        z = zs[idx] if zs is not None else None
+        z = z.expand(batch_size, -1, -1, -1)
+        if prompts:
+            # classifier-free guidance
+            noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
+        else:
+            noise_pred = uncond_out.sample
+        # 2. compute less noisy image and set x_t -> x_{t-1}
+        xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z)
+        if controller is not None:
+            xt = controller.step_callback(xt)
+    return xt, zs
requirements.txt ADDED
@@ -0,0 +1,5 @@
+diffusers
+accelerate
+transformers
+torch
+torchvision
style.css ADDED
@@ -0,0 +1,121 @@
+/*
+This CSS file is modified from:
+https://huggingface.co/spaces/DeepFloyd/IF/blob/main/style.css
+*/
+
+h1 {
+  text-align: center;
+}
+
+.gradio-container {
+  font-family: 'IBM Plex Sans', sans-serif;
+}
+
+.gr-button {
+  color: white;
+  border-color: black;
+  background: black;
+}
+
+input[type='range'] {
+  accent-color: black;
+}
+
+.dark input[type='range'] {
+  accent-color: #dfdfdf;
+}
+
+.container {
+  max-width: 730px;
+  margin: auto;
+  padding-top: 1.5rem;
+}
+
+.gr-button:focus {
+  border-color: rgb(147 197 253 / var(--tw-border-opacity));
+  outline: none;
+  box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
+  --tw-border-opacity: 1;
+  --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
+  --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px + var(--tw-ring-offset-width)) var(--tw-ring-color); /* fixed: calc() was missing the "+" */
+  --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
+  --tw-ring-opacity: .5;
+}
+
+/* .footer {
+  margin-bottom: 45px;
+  margin-top: 35px;
+  text-align: center;
+  border-bottom: 1px solid #e5e5e5;
+}
+
+.footer>p {
+  font-size: .8rem;
+  display: inline-block;
+  padding: 0 10px;
+  transform: translateY(10px);
+  background: white;
+}
+
+.dark .footer {
+  border-color: #303030;
+}
+
+.dark .footer>p {
+  background: #0b0f19;
+}
+
+.acknowledgments h4 {
+  margin: 1.25em 0 .25em 0;
+  font-weight: bold;
+  font-size: 115%;
+}
+
+.animate-spin {
+  animation: spin 1s linear infinite;
+} */
+/*
+@keyframes spin {
+  from {
+    transform: rotate(0deg);
+  }
+
+  to {
+    transform: rotate(360deg);
+  }
+} */
+
+.gr-form {
+  flex: 1 1 50%;
+  border-top-right-radius: 0;
+  border-bottom-right-radius: 0;
+}
+
+#prompt-container {
+  gap: 0;
+}
+
+#prompt-text-input,
+#negative-prompt-text-input {
+  padding: .45rem 0.625rem;
+}
+
+#component-16 {
+  border-top-width: 1px !important;
+  margin-top: 1em;
+}
+
+.image_duplication {
+  position: absolute;
+  width: 100px;
+  left: 50px;
+}
+
+#component-0 {
+  max-width: 730px;
+  margin: auto;
+  padding-top: 1.5rem;
+}
utils.py ADDED
@@ -0,0 +1,116 @@
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+from matplotlib import pyplot as plt
+import torchvision.transforms as T
+import os
+import torch
+import yaml
+
+# This file was copied from the DDPM inversion repo - https://github.com/inbarhub/DDPM_inversion
+
+def show_torch_img(img):
+    img = to_np_image(img)
+    plt.imshow(img)
+    plt.axis("off")
+
+def to_np_image(all_images):
+    all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()[0]
+    return all_images
+
+def tensor_to_pil(tensor_imgs):
+    if type(tensor_imgs) == list:
+        tensor_imgs = torch.cat(tensor_imgs)
+    tensor_imgs = (tensor_imgs / 2 + 0.5).clamp(0, 1)
+    to_pil = T.ToPILImage()
+    pil_imgs = [to_pil(img) for img in tensor_imgs]
+    return pil_imgs
+
+def pil_to_tensor(pil_imgs):
+    to_torch = T.ToTensor()
+    if type(pil_imgs) == PIL.Image.Image:
+        tensor_imgs = to_torch(pil_imgs).unsqueeze(0) * 2 - 1
+    elif type(pil_imgs) == list:
+        # fixed: the comprehension converted the whole list on every iteration and
+        # referenced an undefined `device`; convert each image and let the caller place it
+        tensor_imgs = torch.cat([to_torch(img).unsqueeze(0) * 2 - 1 for img in pil_imgs])
+    else:
+        raise Exception("Input must be a PIL.Image or a list of PIL.Image")
+    return tensor_imgs
+
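`pil_to_tensor` and `tensor_to_pil` are near-inverses bridging PIL images and the [-1, 1] tensors the diffusion code works with. A small hypothetical round-trip sketch (the variable names are illustrative):

```python
# Hypothetical round-trip sketch.
from PIL import Image

img = Image.open('Examples/gnochi_mirror.jpeg').convert('RGB')
t = pil_to_tensor(img)      # 1x3xHxW float tensor in [-1, 1]
back = tensor_to_pil(t)[0]  # PIL image again, up to uint8 rounding
```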
+## TODO implement this
+# n = 10
+# num_rows = 4
+# num_col = n // num_rows
+# num_col = num_col + 1 if n % num_rows else num_col
+# num_col
+def add_margin(pil_img, top=0, right=0, bottom=0,
+               left=0, color=(255, 255, 255)):
+    width, height = pil_img.size
+    new_width = width + right + left
+    new_height = height + top + bottom
+    result = Image.new(pil_img.mode, (new_width, new_height), color)
+
+    result.paste(pil_img, (left, top))
+    return result
+
+def image_grid(imgs, rows=1, cols=None,
+               size=None,
+               titles=None, text_pos=(0, 0)):
+    if type(imgs) == list and type(imgs[0]) == torch.Tensor:
+        imgs = torch.cat(imgs)
+    if type(imgs) == torch.Tensor:
+        imgs = tensor_to_pil(imgs)
+
+    if size is not None:
+        imgs = [img.resize((size, size)) for img in imgs]
+    if cols is None:
+        cols = len(imgs)
+    assert len(imgs) >= rows * cols
+
+    top = 20
+    w, h = imgs[0].size
+    delta = 0
+    if len(imgs) > 1 and not imgs[1].size[1] == h:
+        delta = top
+        h = imgs[1].size[1]
+    if titles is not None:
+        font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf",
+                                  size=20, encoding="unic")
+        h = top + h
+    grid = Image.new('RGB', size=(cols * w, rows * h + delta))
+    for i, img in enumerate(imgs):
+        if titles is not None:
+            img = add_margin(img, top=top, bottom=0, left=0)
+            draw = ImageDraw.Draw(img)
+            draw.text(text_pos, titles[i], (0, 0, 0),
+                      font=font)
+        if not delta == 0 and i > 0:
+            grid.paste(img, box=(i % cols * w, i // cols * h + delta))
+        else:
+            grid.paste(img, box=(i % cols * w, i // cols * h))
+
+    return grid
+
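`image_grid` is the helper `sample` in app.py relies on to turn the decoded batch into one PIL image; given `titles` it also stamps a caption strip above each cell (assuming the FreeMono font path above exists on the host). A hypothetical side-by-side usage:

```python
# Hypothetical usage sketch.
from PIL import Image

pair = [Image.open(p).convert('RGB').resize((512, 512))
        for p in ('Examples/gnochi_mirror.jpeg',
                  'Examples/gnochi_mirror_watercolor_painting.png')]
grid = image_grid(pair, rows=1, cols=2)  # one 1024x512 comparison image
grid.save("comparison.png")
```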
+"""
+input_folder - dataset folder
+"""
+def load_dataset(input_folder):
+    # full_file_names = glob.glob(input_folder)
+    # class_names = [x[0] for x in os.walk(input_folder)]
+    class_names = next(os.walk(input_folder))[1]
+    class_names[:] = [d for d in class_names if not d[0] == '.']
+    file_names = []
+    for class_name in class_names:
+        cur_path = os.path.join(input_folder, class_name)
+        filenames = next(os.walk(cur_path), (None, None, []))[2]
+        filenames = [f for f in filenames if not f[0] == '.']
+        file_names.append(filenames)
+    return class_names, file_names
+
+
+def dataset_from_yaml(yaml_location):
+    with open(yaml_location, 'r') as stream:
+        data_loaded = yaml.safe_load(stream)
+
+    return data_loaded