Spaces:
Runtime error
Runtime error
Commit
·
8811dd9
1
Parent(s):
e85c7e3
Update app.py
Browse files
app.py
CHANGED
|
@@ -235,14 +235,6 @@ def show_images(images_list):
|
|
| 235 |
axs[c].imshow(images_list[c])
|
| 236 |
plt.show()
|
| 237 |
|
| 238 |
-
|
| 239 |
-
def invert_loss(gen_image):
|
| 240 |
-
inverter = T.RandomInvert(p=1.0)
|
| 241 |
-
inverted_img = inverter(gen_image)
|
| 242 |
-
#loss = torch.abs(gen_image - inverted_img).sum()
|
| 243 |
-
loss = torch.nn.functional.mse_loss(gen_image[:,0], gen_image[:,2]) + torch.nn.functional.mse_loss(gen_image[:,2], gen_image[:,1]) + torch.nn.functional.mse_loss(gen_image[:,0], gen_image[:,1])
|
| 244 |
-
return loss
|
| 245 |
-
|
| 246 |
def brilliance_loss(image, target_brilliance=10):
|
| 247 |
# Calculate the standard deviation of color channels
|
| 248 |
std_dev = torch.std(image, dim=(2, 3))
|
|
@@ -252,6 +244,42 @@ def brilliance_loss(image, target_brilliance=10):
|
|
| 252 |
loss = torch.abs(mean_std_dev - target_brilliance)
|
| 253 |
return loss
|
| 254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
|
| 256 |
def display_images_in_rows(images_with_titles, titles):
|
| 257 |
num_images = len(images_with_titles)
|
|
@@ -280,41 +308,46 @@ def display_images_in_rows(images_with_titles, titles):
|
|
| 280 |
# plt.show()
|
| 281 |
|
| 282 |
|
| 283 |
-
def image_generator(prompt
|
| 284 |
-
|
| 285 |
-
|
| 286 |
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
if loss_function:
|
| 291 |
-
generated_img = generate_image_custom_style(prompt,style_num = i,random_seed = seed_values[i],custom_loss_fn = loss_function)
|
| 292 |
-
images_with_loss.append(generated_img)
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
if images_with_loss != []:
|
| 300 |
-
generated_sd_images.append((images_with_loss[i], titles[i]))
|
| 301 |
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
| 303 |
|
| 304 |
-
|
| 305 |
-
def image_generator_wrapper(prompt = "dog", loss_function=None):
|
| 306 |
-
if loss_function == "Yes":
|
| 307 |
-
loss_function = brilliance_loss
|
| 308 |
-
else:
|
| 309 |
-
loss_function = None
|
| 310 |
|
| 311 |
-
|
|
|
|
|
|
|
| 312 |
|
| 313 |
description = 'Stable Diffusion is a generative artificial intelligence (generative AI) model that produces unique photorealistic images from text and image prompts.'
|
| 314 |
title = 'Image Generation using Stable Diffusion'
|
| 315 |
|
| 316 |
demo = gr.Interface(image_generator_wrapper,
|
| 317 |
inputs=[gr.Textbox(label="Enter prompt for generation", type="text", value="A ballerina cat dancing in space"),
|
| 318 |
-
gr.Radio(["
|
| 319 |
-
outputs=gr.Plot(label="Generated Images"),
|
|
|
|
|
|
|
| 320 |
demo.launch()
|
|
|
|
| 235 |
axs[c].imshow(images_list[c])
|
| 236 |
plt.show()
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
def brilliance_loss(image, target_brilliance=10):
|
| 239 |
# Calculate the standard deviation of color channels
|
| 240 |
std_dev = torch.std(image, dim=(2, 3))
|
|
|
|
| 244 |
loss = torch.abs(mean_std_dev - target_brilliance)
|
| 245 |
return loss
|
| 246 |
|
| 247 |
+
import numpy as np
|
| 248 |
+
from PIL import Image
|
| 249 |
+
|
| 250 |
+
import torch
|
| 251 |
+
from scipy.stats import wasserstein_distance
|
| 252 |
+
|
| 253 |
+
def exposure_loss(image, target_exposure=3):
    """Penalize deviation of the image's overall brightness from a target.

    Brightness is taken as the mean over every element of the tensor; the
    loss is the absolute distance of that mean from `target_exposure`.
    """
    mean_brightness = torch.mean(image)
    return torch.abs(mean_brightness - target_exposure)
|
| 260 |
+
|
| 261 |
+
def color_diversity_loss(images):
    """Score color diversity as the summed per-channel spatial variance.

    Variance is computed over the spatial dims (H, W) for each channel,
    then summed across the channel dim, giving one score per sample.
    """
    # Variance over H and W for each (sample, channel) pair; keepdim
    # preserves singleton spatial dims, matching the original output shape.
    per_channel_variance = torch.var(images, dim=(2, 3), keepdim=True)
    return torch.sum(per_channel_variance, dim=1)
|
| 267 |
+
|
| 268 |
+
def sharpness_loss(images):
    """Measure sharpness as the variance of the Laplacian edge response.

    A 3x3 Laplacian kernel is convolved over the batch; the variance of
    the absolute response is returned (higher = more edge contrast).
    """
    # 3x3 Laplacian edge-detection kernel, moved to the input's device.
    kernel = torch.Tensor([[-1, -1, -1],
                           [-1, 8, -1],
                           [-1, -1, -1]]).view(1, 1, 3, 3).to(images.device)
    # Broadcast the single-channel kernel across all input channels so the
    # convolution sums the response over channels into one output map.
    kernel = kernel.expand(-1, images.shape[1], -1, -1)
    response = torch.abs(F.conv2d(images, kernel))
    return torch.var(response)
|
| 283 |
|
| 284 |
def display_images_in_rows(images_with_titles, titles):
|
| 285 |
num_images = len(images_with_titles)
|
|
|
|
| 308 |
# plt.show()
|
| 309 |
|
| 310 |
|
| 311 |
+
def image_generator(prompt="cat", loss_function=None):
    """Generate one image per style for `prompt`, plus an optional
    loss-guided variant per style.

    Args:
        prompt: text prompt passed to the diffusion pipeline.
        loss_function: name of the guidance loss ("exposure_loss",
            "color_diversity_loss", "sharpness_loss", "brilliance_loss"),
            or None / "None" for no guidance.

    Returns:
        List of (image, title) tuples; when guidance is active each style
        contributes two entries (unguided first, guided second).
    """
    # Map the UI's loss names onto the actual loss callables.
    loss_registry = {
        "exposure_loss": exposure_loss,
        "color_diversity_loss": color_diversity_loss,
        "sharpness_loss": sharpness_loss,
        "brilliance_loss": brilliance_loss,
    }
    # BUG FIX: the Gradio Radio sends the *string* "None" when no loss is
    # chosen, which is truthy, so the original `if loss_function:` fell
    # through its elif chain and appended the unguided image to
    # images_with_loss, duplicating every row. `.get()` maps both "None"
    # and any unrecognized name to no guidance instead.
    custom_loss_fn = loss_registry.get(loss_function)

    images_without_loss = []
    images_with_loss = []

    for i in range(num_styles):
        # Unguided generation for this style (deterministic per-style seed).
        generated_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=None)
        images_without_loss.append(generated_img)

        if custom_loss_fn is not None:
            guided_img = generate_image_custom_style(prompt, style_num=i, random_seed=seed_values[i], custom_loss_fn=custom_loss_fn)
            images_with_loss.append(guided_img)

    generated_sd_images = []
    titles = ["animal toy", "fft style", "mid journey", "oil style", "Space style"]

    for i in range(len(titles)):
        generated_sd_images.append((images_without_loss[i], titles[i]))
        if images_with_loss:
            generated_sd_images.append((images_with_loss[i], titles[i]))

    return generated_sd_images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
|
| 340 |
+
def image_generator_wrapper(prompt="dog", selected_loss="None"):
    """Gradio entry point: forward the prompt and the chosen loss name
    (a Radio string, "None" when no loss is selected) to image_generator.
    """
    return image_generator(prompt=prompt, loss_function=selected_loss)
|
| 343 |
|
| 344 |
# Text shown on the Gradio page header.
description = 'Stable Diffusion is a generative artificial intelligence (generative AI) model that produces unique photorealistic images from text and image prompts.'
title = 'Image Generation using Stable Diffusion'

# Wire the wrapper into a Gradio interface: a prompt textbox plus a radio
# selector whose string value names the guidance loss ("None" = no loss).
# The output is a matplotlib figure rendered by gr.Plot.
demo = gr.Interface(image_generator_wrapper,
                    inputs=[gr.Textbox(label="Enter prompt for generation", type="text", value="A ballerina cat dancing in space"),
                            gr.Radio(["None", "exposure_loss", "color_diversity_loss", "sharpness_loss", "brilliance_loss"], value="None", label="Select Loss")],
                    outputs=gr.Plot(label="Generated Images"),
                    title=title,
                    description=description)
demo.launch()
|