Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torchvision import transforms
|
| 9 |
+
from diffusers import StableDiffusionPipeline
|
| 10 |
+
|
| 11 |
+
# Define Loss Functions (same as in your code)
|
| 12 |
+
def edge_loss(image_tensor):
|
| 13 |
+
grayscale = image_tensor.mean(dim=0, keepdim=True)
|
| 14 |
+
grayscale = grayscale.unsqueeze(0)
|
| 15 |
+
sobel_x = torch.tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]], device=image_tensor.device).float().unsqueeze(0).unsqueeze(0)
|
| 16 |
+
sobel_y = sobel_x.transpose(2, 3)
|
| 17 |
+
gx = F.conv2d(grayscale, sobel_x, padding=1)
|
| 18 |
+
gy = F.conv2d(grayscale, sobel_y, padding=1)
|
| 19 |
+
return -torch.mean(torch.sqrt(gx ** 2 + gy ** 2))
|
| 20 |
+
|
| 21 |
+
def texture_loss(image_tensor):
|
| 22 |
+
return F.mse_loss(image_tensor, torch.rand_like(image_tensor, device=image_tensor.device))
|
| 23 |
+
|
| 24 |
+
def entropy_loss(image_tensor):
|
| 25 |
+
hist = torch.histc(image_tensor, bins=256, min=0, max=255)
|
| 26 |
+
hist = hist / hist.sum()
|
| 27 |
+
return -torch.sum(hist * torch.log(hist + 1e-7))
|
| 28 |
+
|
| 29 |
+
def symmetry_loss(image_tensor):
|
| 30 |
+
width = image_tensor.shape[-1]
|
| 31 |
+
left_half = image_tensor[:, :, :width // 2]
|
| 32 |
+
right_half = torch.flip(image_tensor[:, :, width // 2:], dims=[-1])
|
| 33 |
+
return F.mse_loss(left_half, right_half)
|
| 34 |
+
|
| 35 |
+
def contrast_loss(image_tensor):
|
| 36 |
+
min_val = image_tensor.min()
|
| 37 |
+
max_val = image_tensor.max()
|
| 38 |
+
return -torch.mean((image_tensor - min_val) / (max_val - min_val + 1e-7))
|
| 39 |
+
|
| 40 |
+
# Setup Stable Diffusion Pipeline
|
| 41 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 42 |
+
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)
|
| 43 |
+
|
| 44 |
+
# Image transform to tensor
|
| 45 |
+
transform = transforms.ToTensor()
|
| 46 |
+
|
| 47 |
+
# Loss functions dictionary
|
| 48 |
+
losses = {
|
| 49 |
+
"edge": edge_loss,
|
| 50 |
+
"texture": texture_loss,
|
| 51 |
+
"entropy": entropy_loss,
|
| 52 |
+
"symmetry": symmetry_loss,
|
| 53 |
+
"contrast": contrast_loss
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
# Define function to generate images for a given seed
|
| 57 |
+
def generate_images(seed):
|
| 58 |
+
generator = torch.Generator(device).manual_seed(seed)
|
| 59 |
+
output_image = pipe("A futuristic city skyline at sunset", generator=generator).images[0]
|
| 60 |
+
|
| 61 |
+
# Convert to tensor
|
| 62 |
+
image_tensor = transform(output_image).to(device)
|
| 63 |
+
|
| 64 |
+
loss_images = []
|
| 65 |
+
loss_values = []
|
| 66 |
+
|
| 67 |
+
# Compute losses and generate modified images
|
| 68 |
+
for loss_name, loss_fn in losses.items():
|
| 69 |
+
loss_value = loss_fn(image_tensor)
|
| 70 |
+
|
| 71 |
+
# Resize to thumbnail size
|
| 72 |
+
thumbnail_image = output_image.copy()
|
| 73 |
+
thumbnail_image.thumbnail((128, 128))
|
| 74 |
+
|
| 75 |
+
# Save loss image with thumbnail
|
| 76 |
+
loss_images.append(thumbnail_image)
|
| 77 |
+
loss_values.append(f"{loss_name}: {loss_value.item():.4f}")
|
| 78 |
+
|
| 79 |
+
return loss_images, loss_values
|
| 80 |
+
|
| 81 |
+
# Gradio Interface
|
| 82 |
+
def gradio_interface(seed):
|
| 83 |
+
loss_images, loss_values = generate_images(int(seed))
|
| 84 |
+
return loss_images, loss_values
|
| 85 |
+
|
| 86 |
+
# Set up Gradio UI
|
| 87 |
+
interface = gr.Interface(
|
| 88 |
+
fn=gradio_interface,
|
| 89 |
+
inputs=gr.inputs.Textbox(label="Enter Seed"),
|
| 90 |
+
outputs=[gr.outputs.Gallery(label="Loss Images"), gr.outputs.Textbox(label="Loss Values")]
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# Launch the interface
|
| 94 |
+
interface.launch()
|