tulsi0897 committed on
Commit
ef6397b
·
1 Parent(s): 9e4e89a

adding app.py

Browse files
Files changed (1) hide show
  1. app.py +240 -0
app.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import IPython.display as display
4
+ import matplotlib.pyplot as plt
5
+ from base64 import b64encode
6
+ import numpy
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
10
+ from huggingface_hub import notebook_login
11
+
12
+ # For video display:
13
+ from IPython.display import HTML
14
+ from matplotlib import pyplot as plt
15
+ from pathlib import Path
16
+ from PIL import Image
17
+ from torch import autocast
18
+ from torchvision import transforms as tfms
19
+ from tqdm.auto import tqdm
20
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
21
+ import os
22
+
23
# Fix the global torch RNG seed so model loading / any init noise is reproducible.
torch.manual_seed(1)

# Suppress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()

# Set device: prefer CUDA, then Apple MPS, then CPU.
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# MPS lacks some ops; let PyTorch fall back to CPU for them instead of erroring.
if "mps" == torch_device: os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

# Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

# Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The UNet model for generating (denoising) the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# The noise scheduler (K-LMS), configured with the SD v1 training beta schedule.
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# Move the three networks to the selected device.
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device);
49
+
50
def pil_to_latent(input_im):
    """Encode a single PIL image into a scaled VAE latent batch (1, 4, 64, 64)."""
    tensor_im = tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)
    scaled_im = tensor_im * 2 - 1  # map [0, 1] pixels into the VAE's [-1, 1] input range
    with torch.no_grad():
        posterior = vae.encode(scaled_im)
    # 0.18215 is the Stable Diffusion latent scaling factor.
    return 0.18215 * posterior.latent_dist.sample()
55
+
56
def latents_to_pil(latents):
    """Decode a batch of scaled VAE latents into a list of PIL images."""
    unscaled = (1 / 0.18215) * latents  # undo the SD latent scaling factor
    with torch.no_grad():
        decoded = vae.decode(unscaled).sample
    pixels = (decoded / 2 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
    arrays = pixels.detach().cpu().permute(0, 2, 3, 1).numpy()
    arrays = (arrays * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in arrays]
66
+
67
+ # Prep Scheduler
68
def set_timesteps(scheduler, num_inference_steps):
    """Configure the scheduler's timestep table for the given step count."""
    scheduler.set_timesteps(num_inference_steps)
    # Cast to float32 so timestep indexing works on MPS (upstreamed in diffusers PR 3925).
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)
71
+
72
def blue_loss(images):
    """Mean absolute deviation of the blue channel from the 0.9 target."""
    blue_channel = images[:, 2]  # every image in the batch, blue channel only
    return (blue_channel - 0.9).abs().mean()
76
+
77
def diversity_loss(images):
    """Guidance loss that rewards variety across the batch.

    Computes the mean pairwise L2 distance between all image pairs and
    returns its *negation*, so that minimizing this loss maximizes the
    average pairwise distance (i.e. actually encourages diversity).
    The original returned +mean distance, which — when minimized by the
    guidance step — pushed images toward each other.
    """
    # (B, 1, ...) - (1, B, ...) broadcasts to all B*B pairwise differences.
    pairwise_distances = torch.norm(images.unsqueeze(1) - images.unsqueeze(0), p=2, dim=3)
    # Negated: lower loss <=> larger average pairwise distance.
    return -torch.mean(pairwise_distances)
83
+
84
def red_loss(images):
    """Mean absolute deviation of the red channel from the 0.7 target."""
    red_channel = images[:, 0]  # every image in the batch, red channel only
    return (red_channel - 0.7).abs().mean()
88
+
89
def green_loss(images):
    """Mean absolute deviation of the green channel from the 0.8 target."""
    green_channel = images[:, 1]  # every image in the batch, green channel only
    return (green_channel - 0.8).abs().mean()
93
+
94
def saturation_loss(images, target_saturation=0.5):
    """Mean absolute deviation of per-pixel saturation from a target.

    Saturation is approximated as (max - min) over the *channel* axis of a
    (B, C, H, W) batch. The original reduced over dim=3 (image width), which
    measured horizontal intensity spread per row, not color saturation.

    Args:
        images: float tensor shaped (batch, channels, H, W), values in [0, 1].
        target_saturation: desired per-pixel saturation level.
    """
    # (B, H, W): per-pixel spread across the color channels.
    saturation = images.max(dim=1)[0] - images.min(dim=1)[0]
    return torch.abs(saturation - target_saturation).mean()
100
+
101
def brightness_loss(images, target_brightness=0.6):
    """MSE between per-image, per-channel mean intensity and a target brightness."""
    channel_means = images.mean(dim=(2, 3))  # (B, C): average over spatial dims
    squared_err = (channel_means - target_brightness) ** 2
    return squared_err.mean()
107
+
108
+ def edge_detection_loss(images):
109
+ # Use Sobel filters to compute image gradients in x and y directions
110
+ gradient_x = F.conv2d(images, torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=images.dtype).view(1, 1, 3, 3), padding=1)
111
+ gradient_y = F.conv2d(images, torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=images.dtype).view(1, 1, 3, 3), padding=1)
112
+ # Calculate the magnitude of the gradients
113
+ gradient_magnitude = torch.sqrt(gradient_x**2 + gradient_y**2)
114
+ # Encourage a specific level of edge presence
115
+ loss = gradient_magnitude.mean()
116
+ return loss
117
+
118
def noise_regularization_loss(images, noise_std=0.1):
    """Mean squared difference between *images* and a noisy copy of itself.

    NOTE(review): since noisy_images = images + noise, the difference is
    exactly the injected noise, so the returned value is ~noise_std**2
    regardless of `images`, and its gradient w.r.t. `images` is zero — as an
    additional-guidance loss this contributes nothing. It is also stochastic
    (fresh torch.randn each call). Intent unclear; confirm with the author
    before relying on it.
    """
    # Add i.i.d. Gaussian noise scaled by noise_std.
    noisy_images = images + noise_std * torch.randn_like(images)
    loss = torch.mean((images - noisy_images).pow(2))
    return loss
123
+
124
def image_generation(prompt, loss_fxn):
    """Generate loss-guided Stable Diffusion images for several fixed seeds.

    Args:
        prompt: text prompt to condition generation on.
        loss_fxn: display name of the guidance loss, matching the Gradio radio
            choices ("Blue Loss", "Diversity Loss", ...). Unknown names fall
            back to blue_loss (the original's hard-coded behavior).

    Returns:
        (generated_image, latents_values): one PIL image and one final latent
        tensor per seed, in seed order.
    """
    # Map the UI label to the actual loss callable.
    # (The original accepted `loss_fxn` but ignored it and always used blue_loss.)
    loss_functions = {
        "Diversity Loss": diversity_loss,
        "Saturation Loss": saturation_loss,
        "Brightness Loss": brightness_loss,
        "Edge Detection Loss": edge_detection_loss,
        "Noise Regularization Loss": noise_regularization_loss,
        "Blue Loss": blue_loss,
        "Red Loss": red_loss,
        "Green Loss": green_loss,
    }
    guidance_loss = loss_functions.get(loss_fxn, blue_loss)

    generated_image = []
    # Accumulate final latents across ALL seeds. (The original re-initialized
    # this list inside the seed loop, so only the last seed's latents survived.)
    latents_values = []
    seed_list = [8, 16, 32, 64, 128]

    # Generation hyperparameters (SD v1 defaults).
    height = 512
    width = 512
    num_inference_steps = 50
    guidance_scale = 8        # classifier-free guidance strength
    batch_size = 1
    loss_scale = 200          # weight of the additional guidance loss

    for seed in seed_list:
        generator = torch.manual_seed(seed)

        # Prep text: conditional embedding for the prompt...
        text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
        with torch.no_grad():
            text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

        # ...and the unconditional ("") embedding for classifier-free guidance.
        max_length = text_input.input_ids.shape[-1]
        uncond_input = tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        with torch.no_grad():
            uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Prep scheduler.
        set_timesteps(scheduler, num_inference_steps)

        # Prep latents: seeded Gaussian noise at 1/8 spatial resolution.
        latents = torch.randn(
            (batch_size, unet.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(torch_device)
        latents = latents * scheduler.init_noise_sigma

        # Denoising loop.
        for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
            # Expand latents so uncond + cond are predicted in one forward pass.
            latent_model_input = torch.cat([latents] * 2)
            sigma = scheduler.sigmas[i]
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # Predict the noise residual.
            with torch.no_grad():
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # Perform classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            #### ADDITIONAL GUIDANCE (every 5th step) ####
            if i % 5 == 0:
                # Require grad on the latents so we can backprop the guidance loss.
                latents = latents.detach().requires_grad_()

                # Predicted clean latents x0 for this step.
                latents_x0 = latents - sigma * noise_pred

                # Decode to image space, range (0, 1).
                denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

                # Selected guidance loss (was hard-coded to blue_loss).
                loss = guidance_loss(denoised_images) * loss_scale

                # Nudge latents against the loss gradient, scaled by sigma^2.
                cond_grad = torch.autograd.grad(loss, latents)[0]
                latents = latents.detach() - cond_grad * sigma**2

            # Step with the scheduler.
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        generated_image.append(latents_to_pil(latents)[0])
        latents_values.append(latents)

    return generated_image, latents_values
210
+
211
+
212
def _generate_for_ui(prompt, loss_fxn):
    """Gradio adapter.

    image_generation returns (list_of_PIL_images, latent_tensors), which does
    not match the interface's single Image output component — passing it
    directly made every submission fail at render time. Display the first
    generated image only.
    """
    images, _latents = image_generation(prompt, loss_fxn)
    return images[0]

# Create a Gradio interface
iface = gr.Interface(
    fn=_generate_for_ui,
    inputs=[
        gr.inputs.Textbox(label="Prompt Input"),
        gr.inputs.Radio(
            label="Loss Function",
            choices=[
                "Diversity Loss",
                "Saturation Loss",
                "Brightness Loss",
                "Edge Detection Loss",
                "Noise Regularization Loss",
                "Blue Loss",
                "Red Loss",
                "Green Loss"
            ],
        ),
    ],
    outputs=gr.outputs.Image(type="pil", label="Generated Images"),
    title="Stable Diffusion Guided by Loss Function Image Generation with Gradio",
    description="Enter parameters to generate images using Stable Diffusion with optional loss functions.",
)

# Launch the Gradio interface
iface.launch()