shashnk commited on
Commit
757129e
·
1 Parent(s): f56e0f0

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. app.py +234 -0
  3. config.py +20 -0
  4. requirements.txt +0 -0
  5. screenshot.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ screenshot.png filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from base64 import b64encode
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
6
+ from huggingface_hub import notebook_login
7
+
8
+ # For video display:
9
+ from matplotlib import pyplot as plt
10
+ from pathlib import Path
11
+ from PIL import Image
12
+ from torch import autocast
13
+ from torchvision import transforms as tfms
14
+ from tqdm.auto import tqdm
15
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
16
+ import os
17
+
18
+ from config import RADIO_OPTIONS, MAPPING
19
+
20
+ import streamlit as st
21
+
22
+
23
# Seed torch's global RNG so repeated runs start from the same random state.
torch.manual_seed(1)

# Ask for a Hugging Face login only when no cached token file is present.
if not (Path.home()/'.cache/huggingface'/'token').exists():
    notebook_login()

# Suppress some unnecessary warnings when loading the CLIPTextModel.
logging.set_verbosity_error()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): Streamlit re-executes the script on every widget interaction,
# so the loads below presumably repeat on each rerun -- consider caching them
# (e.g. st.cache_resource) once the Streamlit version is pinned.

# Autoencoder used to decode latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

# Tokenizer and text encoder used to tokenize and encode the prompt.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# UNet that predicts the noise residual for the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# Noise scheduler.
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)

# Move the three models onto the selected device.
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)
51
+
52
+
53
+
54
+ import streamlit as st
55
+
56
# Page title, rendered as centered raw HTML.
st.markdown('<h1 style="text-align: center;">Dreamstream</h1>', unsafe_allow_html=True)

# Wide column for the free-text prompt, narrow column for the style selector.
col1, col2 = st.columns([3, 1])
prompt = col1.text_input("Imagine...")
dropdown_value = col2.selectbox("Style", RADIO_OPTIONS, index=0)

# Append the style suffix exactly once, then lowercase the whole prompt.
# (The previous `prompt += prompt + ...` repeated the user's text, turning
# "cat" into "catcat in style of ...".)
prompt = f"{prompt} in style of {dropdown_value}".lower()

generate = st.button("Generate")
66
+
67
+ if generate:
68
+
69
def pil_to_latent(input_im):
    """Encode a single image into a scaled VAE latent.

    The image is converted to a tensor, given a leading batch dimension,
    moved to ``torch_device`` and rescaled with ``* 2 - 1`` before being
    passed to ``vae.encode``.  The encoding runs without gradient tracking.

    Args:
        input_im: Image accepted by ``torchvision.transforms.ToTensor``
            (presumably a PIL image, going by the function name).

    Returns:
        A sample drawn from the encoder's latent distribution, multiplied
        by 0.18215.  The original author's note gives the size as
        (1, 4, 64, 64) -- presumably for a 512x512 input.
    """
    pixels = tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)
    with torch.no_grad():
        latent = vae.encode(pixels * 2 - 1)  # note the rescaling
    return 0.18215 * latent.latent_dist.sample()
74
+
75
+
76
def latents_to_pil(latents):
    """Decode a batch of latents into a list of PIL images.

    Args:
        latents: Batch of latents; they are multiplied by ``1 / 0.18215``
            before being handed to ``vae.decode``.

    Returns:
        One ``PIL.Image`` per item in the batch.
    """
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        decoded = vae.decode(latents).sample
    # Shift/scale with `/ 2 + 0.5` and clamp the result into [0, 1].
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # Move axis 1 to the end and hand the data over to NumPy on the CPU.
    arrays = decoded.detach().cpu().permute(0, 2, 3, 1).numpy()
    arrays = (arrays * 255).round().astype("uint8")
    return [Image.fromarray(array) for array in arrays]
86
+
87
+
88
def set_timesteps(scheduler, num_inference_steps):
    """Configure ``scheduler`` for a run and store its timesteps as float32.

    Args:
        scheduler: Object exposing ``set_timesteps(n)`` and a ``timesteps``
            tensor attribute; it is modified in place.
        num_inference_steps: Step count forwarded to
            ``scheduler.set_timesteps``.
    """
    scheduler.set_timesteps(num_inference_steps)
    # Re-assign the timesteps converted to float32.
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)
91
+
92
def get_output_embeds(input_embeddings):
    """Run pre-built input embeddings through the CLIP text encoder.

    This calls ``text_encoder.text_model.encoder`` directly on embeddings
    (skipping the embedding lookup) and then applies the text model's
    final layer norm.

    Args:
        input_embeddings: Tensor whose first two dimensions are
            (batch, sequence length) -- presumably token embeddings plus
            position embeddings, as built by the caller.

    Returns:
        The first element of the encoder output after
        ``final_layer_norm``.
    """
    # CLIP's text model uses a causal mask, so build it here.
    # NOTE(review): `_build_causal_attention_mask` is a private transformers
    # helper and requirements.txt pins no version -- confirm it exists in
    # the transformers release this app is deployed with.
    bsz, seq_len = input_embeddings.shape[:2]
    causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(
        bsz, seq_len, dtype=input_embeddings.dtype
    )

    # Call the encoder with output_hidden_states=True so that it does not
    # just return the pooled final predictions.
    encoder_outputs = text_encoder.text_model.encoder(
        inputs_embeds=input_embeddings,
        attention_mask=None,  # no padding mask is used
        causal_attention_mask=causal_attention_mask.to(torch_device),
        output_attentions=None,
        output_hidden_states=True,
        return_dict=None,
    )

    # Only the output hidden state is needed.
    output = encoder_outputs[0]

    # Pass it through the final layer norm before returning.
    return text_encoder.text_model.final_layer_norm(output)
116
+
117
def saturation_loss(images):
    """Return the negated sum of the variances of channels 0, 1 and 2.

    Args:
        images: Tensor indexed as ``images[:, channel]``; the names used by
            the original author treat channels 0/1/2 as red/green/blue.

    Returns:
        A scalar tensor.  The sign is negative because minimising this
        value maximises the summed variance.
    """
    red_variance = images[:, 0].var()
    green_variance = images[:, 1].var()
    blue_variance = images[:, 2].var()
    return -(red_variance + green_variance + blue_variance)
122
+
123
+
124
def generate_with_embs(
    text_embeddings,
    use_saturation_loss=True,
    *,
    height=512,
    width=512,
    num_inference_steps=50,
    guidance_scale=8,
    seed=32,
    saturation_loss_scale=200,
):
    """Generate one image from text-encoder output embeddings.

    Runs a classifier-free-guidance denoising loop with the module-level
    ``unet`` and ``scheduler``.  When ``use_saturation_loss`` is true, every
    fifth step also nudges the latents along the gradient of
    ``saturation_loss`` evaluated on the decoded x0 estimate.

    Args:
        text_embeddings: Conditional embeddings for a single prompt; they
            are concatenated after the embeddings of an empty prompt.
        use_saturation_loss: Enable the extra saturation guidance.
        height: Output height; latents use ``height // 8``.
        width: Output width; latents use ``width // 8``.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Scale for classifier-free guidance.
        seed: Value passed to ``torch.manual_seed``; this reseeds torch's
            global generator, which is then used for the initial noise.
        saturation_loss_scale: Multiplier applied to the saturation loss.

    Returns:
        The first image returned by ``latents_to_pil`` for the final latents.
    """
    batch_size = 1
    generator = torch.manual_seed(seed)

    # Embeddings of the empty prompt for the unconditional half of CFG.
    uncond_input = tokenizer(
        [""] * batch_size,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    cfg_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Prep scheduler
    set_timesteps(scheduler, num_inference_steps)

    # Prep latents.  The channel count is read from the model config; direct
    # `unet.in_channels` access is deprecated in diffusers.
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Duplicate the latents so the unconditional and conditional
        # predictions come out of a single forward pass.
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=cfg_embeddings)["sample"]

        print("Shape of noise_pred:", noise_pred.shape)

        # Perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Additional guidance, applied on every fifth step only.
        if use_saturation_loss and i % 5 == 0:
            # The gradient is taken with respect to the latents.
            latents = latents.detach().requires_grad_()

            # Predicted x0.  noise_pred was computed under no_grad, so the
            # gradient flows through this subtraction and the VAE decode only.
            latents_x0 = latents - sigma * noise_pred

            # Decode to image space, range (0, 1)
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

            loss = saturation_loss(denoised_images) * saturation_loss_scale

            # Occasionally print it out
            if i % 10 == 0:
                print(i, 'loss:', loss.item())

            cond_grad = torch.autograd.grad(loss, latents)[0]

            # Modify the latents based on this gradient
            latents = latents.detach() - cond_grad * sigma**2

        # Now step with the scheduler
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents_to_pil(latents)[0]
198
+
199
+
200
+
201
# NOTE(review): indentation was flattened in this extract; these statements
# presumably belong inside the `if generate:` branch -- re-indent accordingly
# when splicing back into app.py.

# Each MAPPING entry holds one {placeholder token: embedding file} pair; take
# the first pair once instead of rebuilding key/value lists twice.
style_token, embedding_path = next(iter(MAPPING[dropdown_value].items()))

# Load onto the CPU first so a file saved from a GPU session still loads on a
# CPU-only machine; the tensor is moved to torch_device below.
illustration_embed = torch.load(embedding_path, map_location="cpu")

# Tokenize
text_input = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
input_ids = text_input.input_ids.to(torch_device)

# Get token embeddings
token_emb_layer = text_encoder.text_model.embeddings.token_embedding
token_embeddings = token_emb_layer(input_ids)

# The learned embedding stored under the selected style's placeholder token.
replacement_token_embedding = illustration_embed[style_token].to(torch_device)

# Insert it at every position whose token id is 6829.
# NOTE(review): 6829 is a hard-coded token id.  Nothing in this file puts that
# token into `prompt`, so unless the user happens to type the matching word
# this assignment selects no positions and the style embedding is never used.
# Presumably a leftover from the tutorial this was adapted from -- confirm the
# intended placeholder word and add it to the prompt.
token_embeddings[0, torch.where(input_ids[0] == 6829)[0]] = replacement_token_embedding

# Combine with position embeddings, sliced to the tokenized sequence length
# (previously a literal 77).
pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
position_ids = text_encoder.text_model.embeddings.position_ids[:, :input_ids.shape[1]]
position_embeddings = pos_emb_layer(position_ids)
input_embeddings = token_embeddings + position_embeddings

# Feed through the encoder to get the final output embeddings
modified_output_embeddings = get_output_embeds(input_embeddings)

# Show both variants side by side.
col7, col8 = st.columns([1, 1])

# Generate an image with saturation_loss
with_loss_image = generate_with_embs(modified_output_embeddings, use_saturation_loss=True)
col7.image(with_loss_image, caption="With Saturation Loss", use_column_width=True, channels="RGB")

# Generate an image without saturation_loss
without_loss_image = generate_with_embs(modified_output_embeddings, use_saturation_loss=False)
col8.image(without_loss_image, caption="Without Saturation Loss", use_column_width=True, channels="RGB")
234
+
config.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Style name offered in the app's selector -> a single
# {placeholder token: path of the learned-embedding file} pair.
# NOTE(review): the file names contain a space ("learned embeddings.bin");
# confirm they match the files actually present in each style folder.
MAPPING = {
    "MIDJOURNEY-STYLE": {
        "<midjourney-style>": "./midjourney-style/learned embeddings.bin",
    },
    "FAIRY-TALE": {
        "<fairy-tale-painting-style>": "./fairy-tale-painting-style/learned embeddings.bin",
    },
    "ILLUSTRATION": {
        "<illustration_style>": "./illustration_style/learned embeddings.bin",
    },
    "KUVSHINOV": {
        "<kuvshinov>": "./kuvshinov/learned embeddings.bin",
    },
    "MARC-ALLANTE": {
        "<Marc Allante>": "./style-of-marc-allante/learned embeddings.bin",
    },
}

# Selector choices, derived from MAPPING (dicts keep insertion order) so the
# two can no longer drift apart.
RADIO_OPTIONS = list(MAPPING)
requirements.txt ADDED
File without changes
screenshot.png ADDED

Git LFS Details

  • SHA256: a493336cf605092ab72e0c87eda4620ad934e5ff225cb497f1f99cf95a85af4c
  • Pointer size: 132 Bytes
  • Size of remote file: 2.04 MB