NingKanae and Linoy Tsaban committed
Commit c771a6a · 0 parents

Duplicate from LinoyTsaban/edit_friendly_ddpm_inversion


Co-authored-by: Linoy Tsaban <LinoyTsaban@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ Examples/source_a_ceramic_vase_with_yellow_flowers.jpeg filter=lfs diff=lfs merge=lfs -text
+ Examples/source_a_model_on_a_runway.jpeg filter=lfs diff=lfs merge=lfs -text
Examples/ddpm_a_bronze_statue_of_an_old_man.png ADDED
Examples/ddpm_a_pink_ceramic_vase_with_a_wheat_bouquet.png ADDED
Examples/ddpm_a_zebra_on_the_run_way.png ADDED
Examples/gnochi_mirror.jpeg ADDED
Examples/gnochi_mirror_reconstrcution.png ADDED
Examples/gnochi_mirror_watercolor_painting.png ADDED
Examples/source_a_ceramic_vase_with_yellow_flowers.jpeg ADDED

Git LFS Details

  • SHA256: 0f5ecbc8fedf38fc285d4c07a4905648b9b8542ed10d101c223eaf6cd0c8f125
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
Examples/source_a_model_on_a_runway.jpeg ADDED

Git LFS Details

  • SHA256: 95e7e0f6b6deafec8dd4e755a5239723d970feb5291c1139ca44758f41bde2ce
  • Pointer size: 132 Bytes
  • Size of remote file: 3.46 MB
Examples/source_an_old_man.png ADDED
README.md ADDED
@@ -0,0 +1,22 @@
+ ---
+ title: Edit Friendly Ddpm Inversion
+ emoji: 🖼️
+ colorFrom: pink
+ colorTo: orange
+ sdk: gradio
+ sdk_version: 3.32.0
+ app_file: app.py
+ pinned: false
+ duplicated_from: LinoyTsaban/edit_friendly_ddpm_inversion
+ ---
+
+ ## BibTeX
+
+ ```
+ @article{HubermanSpiegelglas2023,
+   title   = {An Edit Friendly DDPM Noise Space: Inversion and Manipulations},
+   author  = {Huberman-Spiegelglas, Inbar and Kulikov, Vladimir and Michaeli, Tomer},
+   journal = {arXiv preprint arXiv:2304.06140},
+   year    = {2023}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,246 @@
+ import gradio as gr
+ import torch
+ import random
+ import numpy as np
+ from diffusers import StableDiffusionPipeline
+ from diffusers import DDIMScheduler
+ from utils import *
+ from inversion_utils import *
+ from torch import autocast, inference_mode
+
+
+ def randomize_seed_fn(seed, randomize_seed):
+     if randomize_seed:
+         seed = random.randint(0, np.iinfo(np.int32).max)
+     torch.manual_seed(seed)
+     return seed
+
+
+ def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
+     # Inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
+     # based on the code in https://github.com/inbarhub/DDPM_inversion
+     #
+     # returns zs, wts:
+     #   wts - intermediate inverted latents
+     #   zs  - noise maps
+     sd_pipe.scheduler.set_timesteps(num_diffusion_steps)
+
+     # VAE-encode the image
+     with autocast("cuda"), inference_mode():
+         w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()
+
+     # find zs and wts - forward process
+     _, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=False, num_inference_steps=num_diffusion_steps)
+     return zs, wts
+
+
+ def sample(zs, wts, prompt_tar="", skip=36, cfg_scale_tar=15, eta=1):
+     # reverse process (via zs and w_T)
+     w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=False, zs=zs[skip:])
+
+     # VAE-decode the latent
+     with autocast("cuda"), inference_mode():
+         x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
+     if x0_dec.dim() < 4:
+         x0_dec = x0_dec[None, :, :, :]
+     img = image_grid(x0_dec)
+     return img
+
+
+ # load pipeline
+ sd_model_id = "runwayml/stable-diffusion-v1-5"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id).to(device)
+ sd_pipe.scheduler = DDIMScheduler.from_pretrained(sd_model_id, subfolder="scheduler")
+
+
+ def get_example():
+     case = [
+         [
+             'Examples/gnochi_mirror.jpeg',
+             'Watercolor painting of a cat sitting next to a mirror',
+             'Examples/gnochi_mirror_watercolor_painting.png',
+             '',
+             100,
+             3.5,
+             36,
+             15,
+         ],
+         [
+             'Examples/source_an_old_man.png',
+             'A bronze statue of an old man',
+             'Examples/ddpm_a_bronze_statue_of_an_old_man.png',
+             '',
+             100,
+             3.5,
+             36,
+             15,
+         ],
+         [
+             'Examples/source_a_ceramic_vase_with_yellow_flowers.jpeg',
+             'A pink ceramic vase with a wheat bouquet',
+             'Examples/ddpm_a_pink_ceramic_vase_with_a_wheat_bouquet.png',
+             '',
+             100,
+             3.5,
+             36,
+             15,
+         ],
+         [
+             'Examples/source_a_model_on_a_runway.jpeg',
+             'A zebra on the runway',
+             'Examples/ddpm_a_zebra_on_the_run_way.png',
+             '',
+             100,
+             3.5,
+             36,
+             15,
+         ],
+     ]
+     return case
+
+
+ ########
+ # demo #
+ ########
+
+ intro = """
+ <h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
+    Edit Friendly DDPM Inversion
+ </h1>
+ <p style="font-size: 0.9rem; text-align: center; margin: 0rem; line-height: 1.2em; margin-top:1em">
+ Based on the work introduced in:
+ <a href="https://arxiv.org/abs/2304.06140" style="text-decoration: underline;" target="_blank">An Edit Friendly DDPM Noise Space:
+ Inversion and Manipulations</a>
+ </p>
+ <p style="font-size: 0.9rem; margin: 0rem; line-height: 1.2em; margin-top:1em">
+ For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings.
+ <a href="https://huggingface.co/spaces/LinoyTsaban/edit_friendly_ddpm_inversion?duplicate=true">
+ <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>
+ </p>"""
+
+ with gr.Blocks(css='style.css') as demo:
+
+     def reset_do_inversion():
+         do_inversion = True
+         return do_inversion
+
+     def edit(input_image,
+              do_inversion,
+              wts, zs,
+              src_prompt="",
+              tar_prompt="",
+              steps=100,
+              cfg_scale_src=3.5,
+              cfg_scale_tar=15,
+              skip=36,
+              seed=0,
+              randomize_seed=True):
+
+         x0 = load_512(input_image, device=device)
+
+         # re-invert whenever the input or source prompt changed, or the seed was randomized;
+         # gr.State values are passed in and returned as plain values
+         if do_inversion or randomize_seed:
+             zs, wts = invert(x0=x0, prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=cfg_scale_src)
+             do_inversion = False
+
+         output = sample(zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=cfg_scale_tar)
+         return output, wts, zs, do_inversion
+
+     gr.HTML(intro)
+     wts = gr.State()
+     zs = gr.State()
+     do_inversion = gr.State(value=True)
+     with gr.Row():
+         input_image = gr.Image(label="Input Image", interactive=True)
+         input_image.style(height=365, width=365)
+         output_image = gr.Image(label="Edited Image", interactive=False)
+         output_image.style(height=365, width=365)
+
+     with gr.Row():
+         tar_prompt = gr.Textbox(lines=1, label="Describe your desired edited output", interactive=True)
+
+     with gr.Row():
+         with gr.Column(scale=1, min_width=100):
+             edit_button = gr.Button("Run")
+
+     with gr.Accordion("Advanced Options", open=False):
+         with gr.Row():
+             with gr.Column():
+                 # inversion
+                 src_prompt = gr.Textbox(lines=1, label="Source Prompt", interactive=True, placeholder="describe the original image")
+                 steps = gr.Number(value=100, precision=0, label="Num Diffusion Steps", interactive=True)
+                 cfg_scale_src = gr.Slider(minimum=1, maximum=15, value=3.5, label="Source Guidance Scale", interactive=True)
+             with gr.Column():
+                 # reconstruction
+                 skip = gr.Slider(minimum=0, maximum=60, value=36, step=1, label="Skip Steps", interactive=True)
+                 cfg_scale_tar = gr.Slider(minimum=7, maximum=18, value=15, label="Target Guidance Scale", interactive=True)
+                 seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
+                 randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
+
+     edit_button.click(
+         fn=randomize_seed_fn,
+         inputs=[seed, randomize_seed],
+         outputs=[seed], queue=False).then(
+         fn=edit,
+         inputs=[input_image,
+                 do_inversion, wts, zs,
+                 src_prompt,
+                 tar_prompt,
+                 steps,
+                 cfg_scale_src,
+                 cfg_scale_tar,
+                 skip,
+                 seed, randomize_seed
+                 ],
+         outputs=[output_image, wts, zs, do_inversion],
+     )
+
+     input_image.change(
+         fn=reset_do_inversion,
+         outputs=[do_inversion]
+     )
+
+     src_prompt.change(
+         fn=reset_do_inversion,
+         outputs=[do_inversion]
+     )
+
+     gr.Examples(
+         label='Examples',
+         examples=get_example(),
+         inputs=[input_image, tar_prompt, output_image, src_prompt,
+                 steps,
+                 cfg_scale_src,
+                 skip,
+                 cfg_scale_tar
+                 ],
+         outputs=[output_image],
+     )
+
+ demo.queue()
+ demo.launch(share=False)
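
For context, the `invert` → `sample` pair above can also be exercised without the Gradio UI. The following is a minimal sketch, assuming the modules added in this commit (`inversion_utils.py`, `utils.py`) are importable and a CUDA GPU is available; the input path and prompts are borrowed from `get_example()`:

```python
# Standalone sketch of the invert -> sample flow (no UI). Assumes CUDA and the
# modules from this commit; image path and prompts come from get_example().
import torch
from torch import autocast, inference_mode
from diffusers import StableDiffusionPipeline, DDIMScheduler
from inversion_utils import load_512, inversion_forward_process, inversion_reverse_process
from utils import image_grid

model_id = "runwayml/stable-diffusion-v1-5"
device = torch.device("cuda")
pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)
pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe.scheduler.set_timesteps(100)

# encode the input image into the VAE latent space
x0 = load_512("Examples/gnochi_mirror.jpeg", device=device)
with autocast("cuda"), inference_mode():
    w0 = (pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()

# forward (inversion) pass: intermediate latents wts and noise maps zs
_, zs, wts = inversion_forward_process(pipe, w0, etas=1, prompt="", cfg_scale=3.5,
                                       prog_bar=True, num_inference_steps=100)

# reverse pass from an intermediate latent (skip=36) under the target prompt
skip = 36
w0_edit, _ = inversion_reverse_process(pipe, xT=wts[skip], etas=1,
                                       prompts=["Watercolor painting of a cat sitting next to a mirror"],
                                       cfg_scales=[15], prog_bar=True, zs=zs[skip:])

# decode and save
with autocast("cuda"), inference_mode():
    img = pipe.vae.decode(1 / 0.18215 * w0_edit).sample
image_grid(img).save("edited.png")
```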
inversion_utils.py ADDED
@@ -0,0 +1,295 @@
+ import os
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from PIL import Image
+
+ # This file was copied from the DDPM inversion repo - https://github.com/inbarhub/DDPM_inversion
+
+
+ def load_512(image_path, left=0, right=0, top=0, bottom=0, device=None):
+     # load an image (path or array), center-crop to square, resize to 512x512,
+     # and normalize to [-1, 1] as a (1, 3, 512, 512) tensor
+     if type(image_path) is str:
+         image = np.array(Image.open(image_path).convert('RGB'))[:, :, :3]
+     else:
+         image = image_path
+     h, w, c = image.shape
+     left = min(left, w - 1)
+     right = min(right, w - left - 1)
+     top = min(top, h - 1)
+     bottom = min(bottom, h - top - 1)
+     image = image[top:h - bottom, left:w - right]
+     h, w, c = image.shape
+     if h < w:
+         offset = (w - h) // 2
+         image = image[:, offset:offset + h]
+     elif w < h:
+         offset = (h - w) // 2
+         image = image[offset:offset + w]
+     image = np.array(Image.fromarray(image).resize((512, 512)))
+     image = torch.from_numpy(image).float() / 127.5 - 1
+     image = image.permute(2, 0, 1).unsqueeze(0).to(device)
+
+     return image
+
+
+ def load_real_image(folder="data/", img_name=None, idx=0, img_size=512, device='cuda'):
+     from glob import glob
+     from utils import pil_to_tensor  # helper defined in utils.py
+     if img_name is not None:
+         path = os.path.join(folder, img_name)
+     else:
+         path = glob(folder + "*")[idx]
+
+     img = Image.open(path).resize((img_size, img_size))
+
+     img = pil_to_tensor(img).to(device)
+
+     if img.shape[1] == 4:
+         img = img[:, :3, :, :]
+     return img
+
+
+ def mu_tilde(model, xt, x0, timestep):
+     "mu_tilde(x_t, x_0), DDPM paper eq. 7"
+     prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+     alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+     alpha_t = model.scheduler.alphas[timestep]
+     beta_t = 1 - alpha_t
+     alpha_bar = model.scheduler.alphas_cumprod[timestep]
+     return ((alpha_prod_t_prev ** 0.5 * beta_t) / (1 - alpha_bar)) * x0 + ((alpha_t ** 0.5 * (1 - alpha_prod_t_prev)) / (1 - alpha_bar)) * xt
+
+
+ def sample_xts_from_x0(model, x0, num_inference_steps=50):
+     """
+     Samples from P(x_1:T|x_0) -- each x_t is drawn independently given x_0
+     """
+     alpha_bar = model.scheduler.alphas_cumprod
+     sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
+     variance_noise_shape = (
+         num_inference_steps,
+         model.unet.in_channels,
+         model.unet.sample_size,
+         model.unet.sample_size)
+
+     timesteps = model.scheduler.timesteps.to(model.device)
+     t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+     xts = torch.zeros(variance_noise_shape).to(x0.device)
+     for t in reversed(timesteps):
+         idx = t_to_idx[int(t)]
+         xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
+     xts = torch.cat([xts, x0], dim=0)
+
+     return xts
+
+
+ def encode_text(model, prompts):
+     text_input = model.tokenizer(
+         prompts,
+         padding="max_length",
+         max_length=model.tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+     with torch.no_grad():
+         text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
+     return text_encoding
+
+
+ def forward_step(model, model_output, timestep, sample):
+     next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
+                         timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
+
+     # 2. compute alphas, betas
+     alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+     beta_prod_t = 1 - alpha_prod_t
+
+     # 3. compute predicted original sample from predicted noise, also called
+     #    "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+     pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+     # 5. TODO: simple noising implementation
+     next_sample = model.scheduler.add_noise(pred_original_sample,
+                                             model_output,
+                                             torch.LongTensor([next_timestep]))
+     return next_sample
+
+
+ def get_variance(model, timestep):
+     prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+     alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+     alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+     beta_prod_t = 1 - alpha_prod_t
+     beta_prod_t_prev = 1 - alpha_prod_t_prev
+     variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+     return variance
+
+
+ def inversion_forward_process(model, x0,
+                               etas=None,
+                               prog_bar=False,
+                               prompt="",
+                               cfg_scale=3.5,
+                               num_inference_steps=50, eps=None):
+
+     if prompt != "":
+         text_embeddings = encode_text(model, prompt)
+     uncond_embedding = encode_text(model, "")
+     timesteps = model.scheduler.timesteps.to(model.device)
+     variance_noise_shape = (
+         num_inference_steps,
+         model.unet.in_channels,
+         model.unet.sample_size,
+         model.unet.sample_size)
+     if etas is None or (type(etas) in [int, float] and etas == 0):
+         eta_is_zero = True
+         zs = None
+     else:
+         eta_is_zero = False
+         if type(etas) in [int, float]:
+             etas = [etas] * model.scheduler.num_inference_steps
+         xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
+         alpha_bar = model.scheduler.alphas_cumprod
+         zs = torch.zeros(size=variance_noise_shape, device=model.device)
+
+     t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
+     xt = x0
+     op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
+
+     for t in op:
+         idx = t_to_idx[int(t)]
+         # 1. predict noise residual
+         if not eta_is_zero:
+             xt = xts[idx][None]
+
+         with torch.no_grad():
+             out = model.unet.forward(xt, timestep=t, encoder_hidden_states=uncond_embedding)
+             if prompt != "":
+                 cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states=text_embeddings)
+
+         if prompt != "":
+             # classifier free guidance
+             noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
+         else:
+             noise_pred = out.sample
+
+         if eta_is_zero:
+             # 2. compute more noisy image and set x_t -> x_t+1
+             xt = forward_step(model, noise_pred, t, xt)
+
+         else:
+             xtm1 = xts[idx + 1][None]
+             # pred of x0
+             pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
+
+             # direction to xt
+             prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+             alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+
+             variance = get_variance(model, t)
+             pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
+
+             mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+             z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
+             zs[idx] = z
+
+             # correction to avoid error accumulation
+             xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
+             xts[idx + 1] = xtm1
+
+     if zs is not None:
+         zs[-1] = torch.zeros_like(zs[-1])
+
+     return xt, zs, xts
+
+
+ def reverse_step(model, model_output, timestep, sample, eta=0, variance_noise=None):
+     # 1. get previous step value (=t-1)
+     prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
+     # 2. compute alphas, betas
+     alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
+     alpha_prod_t_prev = model.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
+     beta_prod_t = 1 - alpha_prod_t
+     # 3. compute predicted original sample from predicted noise, also called
+     #    "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+     pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+     # 5. compute variance: "sigma_t(eta)" -> see formula (16)
+     #    sigma_t = sqrt((1 - alpha_{t-1}) / (1 - alpha_t)) * sqrt(1 - alpha_t / alpha_{t-1})
+     variance = get_variance(model, timestep)
+     std_dev_t = eta * variance ** (0.5)
+     # take care of asymmetric reverse process (asyrp)
+     model_output_direction = model_output
+     # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+     pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
+     # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+     prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+     # 8. add noise if eta > 0
+     if eta > 0:
+         if variance_noise is None:
+             variance_noise = torch.randn(model_output.shape, device=model.device)
+         sigma_z = eta * variance ** (0.5) * variance_noise
+         prev_sample = prev_sample + sigma_z
+
+     return prev_sample
+
+
+ def inversion_reverse_process(model,
+                               xT,
+                               etas=0,
+                               prompts="",
+                               cfg_scales=None,
+                               prog_bar=False,
+                               zs=None,
+                               controller=None,
+                               asyrp=False):
+
+     batch_size = len(prompts)
+
+     cfg_scales_tensor = torch.Tensor(cfg_scales).view(-1, 1, 1, 1).to(model.device)
+
+     text_embeddings = encode_text(model, prompts)
+     uncond_embedding = encode_text(model, [""] * batch_size)
+
+     if etas is None:
+         etas = 0
+     if type(etas) in [int, float]:
+         etas = [etas] * model.scheduler.num_inference_steps
+     assert len(etas) == model.scheduler.num_inference_steps
+     timesteps = model.scheduler.timesteps.to(model.device)
+
+     xt = xT.expand(batch_size, -1, -1, -1)
+     op = tqdm(timesteps[-zs.shape[0]:]) if prog_bar else timesteps[-zs.shape[0]:]
+
+     t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs.shape[0]:])}
+
+     for t in op:
+         idx = t_to_idx[int(t)]
+         # unconditional embedding
+         with torch.no_grad():
+             uncond_out = model.unet.forward(xt, timestep=t,
+                                             encoder_hidden_states=uncond_embedding)
+
+         # conditional embedding
+         if prompts:
+             with torch.no_grad():
+                 cond_out = model.unet.forward(xt, timestep=t,
+                                               encoder_hidden_states=text_embeddings)
+
+         z = zs[idx] if zs is not None else None
+         z = z.expand(batch_size, -1, -1, -1)
+         if prompts:
+             # classifier free guidance
+             noise_pred = uncond_out.sample + cfg_scales_tensor * (cond_out.sample - uncond_out.sample)
+         else:
+             noise_pred = uncond_out.sample
+         # 2. compute less noisy image and set x_t -> x_t-1
+         xt = reverse_step(model, noise_pred, t, xt, eta=etas[idx], variance_noise=z)
+         if controller is not None:
+             xt = controller.step_callback(xt)
+     return xt, zs
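
Written out, the quantities computed per timestep in the `eta > 0` branch of `inversion_forward_process` (with $\bar\alpha$ taken from `alphas_cumprod` and $\sigma_t^2$ from `get_variance`) are, as a sketch:

```latex
% per-step computation in inversion_forward_process (eta > 0 branch)
\hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\bar\alpha_t}}, \qquad
\mu_t = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t-1}-\eta_t\sigma_t^2}\,\epsilon_\theta(x_t),
\qquad
z_t = \frac{x_{t-1}-\mu_t}{\eta_t\,\sigma_t}, \qquad
x_{t-1} \leftarrow \mu_t + \eta_t\,\sigma_t\, z_t
```

Because `sample_xts_from_x0` draws each $x_t$ independently given $x_0$, the extracted noise maps $z_t$ are decorrelated across timesteps; this is the property the paper (arXiv:2304.06140) refers to as making the noise space edit friendly.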
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ diffusers
+ accelerate
+ transformers
+ torch
+ torchvision
style.css ADDED
@@ -0,0 +1,121 @@
+ /*
+ This CSS file is modified from:
+ https://huggingface.co/spaces/DeepFloyd/IF/blob/main/style.css
+ */
+
+ h1 {
+   text-align: center;
+ }
+
+ .gradio-container {
+   font-family: 'IBM Plex Sans', sans-serif;
+ }
+
+ .gr-button {
+   color: white;
+   border-color: black;
+   background: black;
+ }
+
+ input[type='range'] {
+   accent-color: black;
+ }
+
+ .dark input[type='range'] {
+   accent-color: #dfdfdf;
+ }
+
+ .container {
+   max-width: 730px;
+   margin: auto;
+   padding-top: 1.5rem;
+ }
+
+ .gr-button:focus {
+   border-color: rgb(147 197 253 / var(--tw-border-opacity));
+   outline: none;
+   box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000);
+   --tw-border-opacity: 1;
+   --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color);
+   --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color);
+   --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity));
+   --tw-ring-opacity: .5;
+ }
+
+ /* .footer {
+   margin-bottom: 45px;
+   margin-top: 35px;
+   text-align: center;
+   border-bottom: 1px solid #e5e5e5;
+ }
+
+ .footer>p {
+   font-size: .8rem;
+   display: inline-block;
+   padding: 0 10px;
+   transform: translateY(10px);
+   background: white;
+ }
+
+ .dark .footer {
+   border-color: #303030;
+ }
+
+ .dark .footer>p {
+   background: #0b0f19;
+ }
+
+ .acknowledgments h4 {
+   margin: 1.25em 0 .25em 0;
+   font-weight: bold;
+   font-size: 115%;
+ }
+
+ .animate-spin {
+   animation: spin 1s linear infinite;
+ } */
+ /*
+ @keyframes spin {
+   from {
+     transform: rotate(0deg);
+   }
+
+   to {
+     transform: rotate(360deg);
+   }
+ } */
+
+ .gr-form {
+   flex: 1 1 50%;
+   border-top-right-radius: 0;
+   border-bottom-right-radius: 0;
+ }
+
+ #prompt-container {
+   gap: 0;
+ }
+
+ #prompt-text-input,
+ #negative-prompt-text-input {
+   padding: .45rem 0.625rem;
+ }
+
+ #component-16 {
+   border-top-width: 1px !important;
+   margin-top: 1em;
+ }
+
+ .image_duplication {
+   position: absolute;
+   width: 100px;
+   left: 50px;
+ }
+
+ #component-0 {
+   max-width: 730px;
+   margin: auto;
+   padding-top: 1.5rem;
+ }
utils.py ADDED
@@ -0,0 +1,116 @@
+ import os
+ import PIL
+ import torch
+ import yaml
+ import torchvision.transforms as T
+ from PIL import Image, ImageDraw, ImageFont
+ from matplotlib import pyplot as plt
+
+ # This file was copied from the DDPM inversion repo - https://github.com/inbarhub/DDPM_inversion
+
+
+ def show_torch_img(img):
+     img = to_np_image(img)
+     plt.imshow(img)
+     plt.axis("off")
+
+
+ def to_np_image(all_images):
+     all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()[0]
+     return all_images
+
+
+ def tensor_to_pil(tensor_imgs):
+     if type(tensor_imgs) == list:
+         tensor_imgs = torch.cat(tensor_imgs)
+     tensor_imgs = (tensor_imgs / 2 + 0.5).clamp(0, 1)
+     to_pil = T.ToPILImage()
+     pil_imgs = [to_pil(img) for img in tensor_imgs]
+     return pil_imgs
+
+
+ def pil_to_tensor(pil_imgs):
+     to_torch = T.ToTensor()
+     if type(pil_imgs) == PIL.Image.Image:
+         tensor_imgs = to_torch(pil_imgs).unsqueeze(0) * 2 - 1
+     elif type(pil_imgs) == list:
+         tensor_imgs = torch.cat([to_torch(img).unsqueeze(0) * 2 - 1 for img in pil_imgs])
+     else:
+         raise Exception("Input needs to be PIL.Image or a list of PIL.Image")
+     return tensor_imgs
+
+
+ ## TODO implement this
+ # n = 10
+ # num_rows = 4
+ # num_col = n // num_rows
+ # num_col = num_col + 1 if n % num_rows else num_col
+ # num_col
+ def add_margin(pil_img, top=0, right=0, bottom=0,
+                left=0, color=(255, 255, 255)):
+     width, height = pil_img.size
+     new_width = width + right + left
+     new_height = height + top + bottom
+     result = Image.new(pil_img.mode, (new_width, new_height), color)
+
+     result.paste(pil_img, (left, top))
+     return result
+
+
+ def image_grid(imgs, rows=1, cols=None,
+                size=None,
+                titles=None, text_pos=(0, 0)):
+     if type(imgs) == list and type(imgs[0]) == torch.Tensor:
+         imgs = torch.cat(imgs)
+     if type(imgs) == torch.Tensor:
+         imgs = tensor_to_pil(imgs)
+
+     if size is not None:
+         imgs = [img.resize((size, size)) for img in imgs]
+     if cols is None:
+         cols = len(imgs)
+     assert len(imgs) >= rows * cols
+
+     top = 20
+     w, h = imgs[0].size
+     delta = 0
+     if len(imgs) > 1 and not imgs[1].size[1] == h:
+         delta = top
+         h = imgs[1].size[1]
+     if titles is not None:
+         font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf",
+                                   size=20, encoding="unic")
+         h = top + h
+     grid = Image.new('RGB', size=(cols * w, rows * h + delta))
+     for i, img in enumerate(imgs):
+         if titles is not None:
+             img = add_margin(img, top=top, bottom=0, left=0)
+             draw = ImageDraw.Draw(img)
+             draw.text(text_pos, titles[i], (0, 0, 0),
+                       font=font)
+         if not delta == 0 and i > 0:
+             grid.paste(img, box=(i % cols * w, i // cols * h + delta))
+         else:
+             grid.paste(img, box=(i % cols * w, i // cols * h))
+
+     return grid
+
+
+ def load_dataset(input_folder):
+     """input_folder - dataset folder"""
+     class_names = next(os.walk(input_folder))[1]
+     class_names[:] = [d for d in class_names if not d[0] == '.']
+     file_names = []
+     for class_name in class_names:
+         cur_path = os.path.join(input_folder, class_name)
+         filenames = next(os.walk(cur_path), (None, None, []))[2]
+         filenames = [f for f in filenames if not f[0] == '.']
+         file_names.append(filenames)
+     return class_names, file_names
+
+
+ def dataset_from_yaml(yaml_location):
+     with open(yaml_location, 'r') as stream:
+         data_loaded = yaml.safe_load(stream)
+
+     return data_loaded
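
As a quick smoke test of the helpers above, a hypothetical snippet (the image path is reused from the Examples folder; `titles` would additionally require the hard-coded FreeMono font path to exist):

```python
# Hypothetical usage of the utils.py helpers; paths are illustrative.
from PIL import Image
from utils import pil_to_tensor, tensor_to_pil, image_grid

img = Image.open("Examples/gnochi_mirror.jpeg").convert("RGB").resize((512, 512))
t = pil_to_tensor(img)                      # (1, 3, 512, 512), values in [-1, 1]
back = tensor_to_pil(t)[0]                  # round-trip back to PIL
grid = image_grid([t, t], rows=1, cols=2)   # side-by-side 2-up grid
grid.save("grid_preview.png")
```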