rohithb commited on
Commit
1b341bf
·
1 Parent(s): da3f878

Uploaded bins and helper scripts.

Browse files
learned_embeds_doodle.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2514c146d7fedf9d73c8def774dd84ea0006ca71ac784b2fbcb58b55872aefba
3
+ size 3840
learned_embeds_kaleido.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:387c8f7ae79e8b4c224a5042aa52c337e15658440d9a374b356fc07706b0db37
3
+ size 3819
learned_embeds_oil_paint.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:754d7d9c1fcdc7e05fd273f21e77b05bc89a4ba25415d24de1286f1fbdf9e0c7
3
+ size 3840
learned_embeds_strip_style.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1642fabeef65edb374857553e75dae1f6ec8ab3aeba634d0a026f836e8cc4db
3
+ size 3840
learned_embeds_watercolor.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee35016acb95e43fae35b41bde2e46fdd0d9fbdcd3da8f13a638cf35c8bab2d4
3
+ size 3819
stable_diffusion.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.transforms import ToTensor
3
+ from utils import get_style_embeddings, get_EOS_pos_in_prompt, invert_loss
4
+ from base64 import b64encode
5
+ import numpy as np
6
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
7
+
8
+ from torch import autocast
9
+ from torchvision import transforms as tfms
10
+ from tqdm.auto import tqdm
11
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
12
+ from PIL import Image
13
+ import os
14
+ import torchvision.transforms as T
15
+
16
+
17
class StableDiffusion:
    """Stable Diffusion v1-4 text-to-image pipeline with support for custom
    style token embeddings (textual inversion .bin files) and an optional
    extra guidance loss applied periodically during denoising.

    Wires together the four SD components: a VAE (pixel <-> latent space),
    a CLIP tokenizer + text encoder (prompt -> embeddings), a UNet (noise
    prediction), and an LMS discrete scheduler.
    """

    def __init__(self, torch_device, num_inference_steps=30, height=512, width=512, guidance_scale=7.5):
        """Load all model components from the Hub and move them to *torch_device*.

        Args:
            torch_device: device (string or torch.device) the models run on.
            num_inference_steps: number of scheduler denoising steps.
            height, width: output image size in pixels (latents are size // 8).
            guidance_scale: classifier-free guidance strength.
        """
        # Load the autoencoder (maps images to/from the latent space)
        vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='vae')

        # Load tokenizer and text encoder to tokenize and encode the text
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

        # Unet model for generating latents
        unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder='unet')

        # Noise scheduler; these beta values are the SD v1 training settings
        self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

        # Move everything to the target device
        self.torch_device = torch_device
        self.vae = vae.to(self.torch_device)
        self.text_encoder = text_encoder.to(self.torch_device)
        self.unet = unet.to(self.torch_device)

        # additional properties
        self.num_inference_steps = num_inference_steps
        self.height = height  # default height of Stable Diffusion
        self.width = width  # default width of Stable Diffusion
        self.guidance_scale = guidance_scale  # Scale for classifier-free guidance

    # Prep Scheduler
    def set_timesteps(self):
        """Reset the scheduler's timestep table for a fresh generation run."""
        self.scheduler.set_timesteps(self.num_inference_steps)
        self.scheduler.timesteps = self.scheduler.timesteps.to(torch.float32)  # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925

    def additional_guidance(self, latents, noise_pred, t, sigma, custom_loss_fn, custom_loss_scale):
        """Nudge *latents* down the gradient of a custom image-space loss.

        Predicts x0 from the current noise estimate, decodes it to image
        space, evaluates ``custom_loss_fn`` on it, and shifts the latents
        against the loss gradient (scaled by sigma**2).

        Returns:
            (updated latents detached from the graph, scalar loss tensor).
        """
        #### ADDITIONAL GUIDANCE ###
        # Requires grad on the latents so we can backprop the loss to them
        latents = latents.detach().requires_grad_()

        # Get the predicted x0:
        latents_x0 = latents - sigma * noise_pred
        #print(f"latents: {latents.shape}, noise_pred:{noise_pred.shape}")
        #latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample

        # Decode to image space; 1/0.18215 undoes the SD latent scaling
        denoised_images = self.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)

        # Calculate loss
        loss = custom_loss_fn(denoised_images) * custom_loss_scale

        # Get gradient of the loss w.r.t. the latents
        cond_grad = torch.autograd.grad(loss, latents, allow_unused=False)[0]

        # Modify the latents based on this gradient
        latents = latents.detach() - cond_grad * sigma**2
        return latents, loss

    def generate_with_embs(self, text_embeddings, max_length, random_seed, custom_loss_fn, custom_loss_scale):
        """Run the full denoising loop from pre-computed text embeddings.

        Args:
            text_embeddings: conditional CLIP output embeddings for the prompt.
            max_length: token length used to build the matching unconditional
                (empty-prompt) embeddings for classifier-free guidance.
            random_seed: seed for the initial latent noise.
            custom_loss_fn: optional image-space loss; when given, extra
                guidance is applied every 10th step.
            custom_loss_scale: multiplier for that loss.

        Returns:
            A single PIL.Image.
        """
        generator = torch.manual_seed(random_seed)  # Seed generator to create the inital latent noise
        batch_size = 1

        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        with torch.no_grad():
            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.torch_device))[0]
        # [uncond, cond] ordering must match the chunk(2) split in the loop
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Prep Scheduler
        self.set_timesteps()

        # Prep latents
        # NOTE(review): unet.in_channels is deprecated in newer diffusers in
        # favor of unet.config.in_channels — fine for the pinned version.
        latents = torch.randn( (batch_size, self.unet.in_channels, self.height // 8, self.width // 8), generator=generator,)
        latents = latents.to(self.torch_device)
        latents = latents * self.scheduler.init_noise_sigma

        # Denoising loop
        for i, t in tqdm(enumerate(self.scheduler.timesteps), total=len(self.scheduler.timesteps)):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            sigma = self.scheduler.sigmas[i]
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            with torch.no_grad():
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform classifier-free guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
            if custom_loss_fn is not None:
                # only every 10th step, to keep the extra decode/backward cheap
                if i%10 == 0:
                    latents, custom_loss = self.additional_guidance(latents, noise_pred, t, sigma, custom_loss_fn, custom_loss_scale)
                    print(i, 'loss:', custom_loss.item())

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample
        return self.latents_to_pil(latents)[0]

    def get_output_embeds(self, input_embeddings):
        """Push raw (token + position) embeddings through the CLIP text
        transformer, returning the final hidden states the UNet conditions on.
        """
        # CLIP's text model uses causal mask, so we prepare it here:
        # NOTE(review): _build_causal_attention_mask is a private transformers
        # API that was removed in transformers >= 4.31 — pin the version or
        # port to the public mask utilities before upgrading.
        bsz, seq_len = input_embeddings.shape[:2]
        causal_attention_mask = self.text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)

        # Getting the output embeddings involves calling the model with passing output_hidden_states=True
        # so that it doesn't just return the pooled final predictions:
        encoder_outputs = self.text_encoder.text_model.encoder(
            inputs_embeds=input_embeddings,
            attention_mask=None,  # We aren't using an attention mask so that can be None
            causal_attention_mask=causal_attention_mask.to(self.torch_device),
            output_attentions=None,
            output_hidden_states=True,  # We want the output embs not the final output
            return_dict=None,
        )

        # We're interested in the output hidden state only
        output = encoder_outputs[0]

        # There is a final layer norm we need to pass these through
        output = self.text_encoder.text_model.final_layer_norm(output)

        # And now they're ready!
        return output

    def pil_to_latent(self, input_im):
        """Encode a PIL image into a scaled VAE latent.

        Single image -> single latent in a batch (so size 1, 4, 64, 64).
        """
        with torch.no_grad():
            latent = self.vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(self.torch_device)*2-1)  # Note scaling to [-1, 1]
        return 0.18215 * latent.latent_dist.sample()

    def latents_to_pil(self, latents):
        """Decode a batch of latents into a list of PIL images."""
        # undo the 0.18215 latent scaling before decoding
        latents = (1 / 0.18215) * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
        images = (image * 255).round().astype("uint8")
        pil_images = [Image.fromarray(image) for image in images]
        return pil_images

    def generate_image_with_custom_style(self, prompt, style_token_embedding=None, random_seed=41, custom_loss_fn = None, custom_loss_scale=None):
        """Generate an image for *prompt*, optionally injecting a learned
        style embedding and/or steering generation with a custom loss.

        Args:
            prompt: text prompt; the style token is assumed to be its last word.
            style_token_embedding: learned embedding that replaces the token
                embedding at the prompt's final word position.
            random_seed: seed for the initial latent noise.
            custom_loss_fn / custom_loss_scale: optional extra guidance loss
                forwarded to generate_with_embs.

        Returns:
            A single PIL.Image.
        """
        # NOTE(review): get_EOS_pos_in_prompt counts whitespace words + 1, which
        # assumes each word is exactly one CLIP token — multi-token words would
        # shift this position.
        eos_pos = get_EOS_pos_in_prompt(prompt)

        # tokenize
        text_input = self.tokenizer(prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        max_length = text_input.input_ids.shape[-1]
        input_ids = text_input.input_ids.to(self.torch_device)

        # get token embeddings
        token_emb_layer = self.text_encoder.text_model.embeddings.token_embedding
        token_embeddings = token_emb_layer(input_ids)

        # Overwrite the embedding at the prompt's last-word position with the
        # style embedding (batch size is 1, so [-1] is the only item)
        if style_token_embedding is not None:
            token_embeddings[-1, eos_pos, :] = style_token_embedding

        # combine with position embeddings (77 = CLIP's max sequence length)
        pos_emb_layer = self.text_encoder.text_model.embeddings.position_embedding
        position_ids = self.text_encoder.text_model.embeddings.position_ids[:, :77]
        position_embeddings = pos_emb_layer(position_ids)
        input_embeddings = token_embeddings + position_embeddings

        # Feed through to get final output embs
        modified_output_embeddings = self.get_output_embeds(input_embeddings)

        # And generate an image with this:
        generated_image = self.generate_with_embs(modified_output_embeddings, max_length, random_seed, custom_loss_fn, custom_loss_scale)
        return generated_image
utils.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from matplotlib import pyplot as plt
3
+
4
+
5
+
6
+
7
def get_style_embeddings(style_file):
    """Load a learned textual-inversion embedding from *style_file*.

    The .bin file is a torch-saved dict mapping a single placeholder token
    name to its embedding tensor; the tensor under the first key is returned.
    """
    saved = torch.load(style_file)
    first_token = next(iter(saved))
    return saved[first_token]
11
+
12
+
13
def get_EOS_pos_in_prompt(prompt):
    """Estimated index of the prompt's final word in the tokenized sequence.

    Assumes every whitespace-separated word maps to exactly one token, with
    a BOS token at position 0 — hence word count + 1.
    """
    word_count = sum(1 for _ in prompt.split())
    return word_count + 1
15
+
16
+
17
def invert_loss(gen_image):
    """Sum of pairwise MSE between the three color channels of *gen_image*.

    Low values mean the channels agree (a grayscale-looking image);
    expects a (batch, 3, H, W) tensor.
    """
    mse = torch.nn.functional.mse_loss
    channel_pairs = ((0, 2), (2, 1), (0, 1))
    return sum(mse(gen_image[:, a], gen_image[:, b]) for a, b in channel_pairs)
20
+
21
+
22
def blue_loss(images):
    """Mean absolute deviation of the blue channel from 0.9.

    Minimizing this pushes generated images toward strong blue; *images*
    is a (batch, 3, H, W) tensor and channel index 2 is blue.
    """
    blue_channel = images[:, 2]
    return (blue_channel - 0.9).abs().mean()
26
+
27
+
28
def show_images(images_list):
    """Display the images in *images_list* side by side in one matplotlib row.

    Args:
        images_list: sequence of array-like images accepted by Axes.imshow.
    """
    n = len(images_list)
    fig, axs = plt.subplots(1, n, figsize=(16, 4))
    # Bug fix: with a single image, plt.subplots squeezes the result to a
    # bare Axes (not an array), so indexing axs[c] raised TypeError.
    if n == 1:
        axs = [axs]
    for ax, img in zip(axs, images_list):
        ax.imshow(img)
    plt.show()