Update app.py
Browse files
app.py
CHANGED
|
@@ -25,6 +25,24 @@ def color_loss(images, target_color=(0.1, 0.9, 0.5)):
|
|
| 25 |
error = torch.abs(images - target).mean() # Mean absolute difference between the image pixels and the target color
|
| 26 |
return error
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# And the core function to generate an image given the relevant inputs
|
| 29 |
def generate(color, guidance_loss_scale):
|
| 30 |
target_color = ImageColor.getcolor(color, "RGB") # Target color as RGB
|
|
@@ -36,7 +54,8 @@ def generate(color, guidance_loss_scale):
|
|
| 36 |
noise_pred = image_pipe.unet(model_input, t)["sample"]
|
| 37 |
x = x.detach().requires_grad_()
|
| 38 |
x0 = scheduler.step(noise_pred, t, x).pred_original_sample
|
| 39 |
-
loss = color_loss(x0, target_color) * guidance_loss_scale
|
|
|
|
| 40 |
cond_grad = -torch.autograd.grad(loss, x)[0]
|
| 41 |
x = x.detach() + cond_grad
|
| 42 |
x = scheduler.step(noise_pred, t, x).prev_sample
|
|
|
|
| 25 |
error = torch.abs(images - target).mean() # Mean absolute difference between the image pixels and the target color
|
| 26 |
return error
|
| 27 |
|
| 28 |
+
def monochromatic_loss(images, threshold=0.5, target_value=0.01):
    """Return a scalar loss that is low when every pixel is near-black or near-white.

    Each image is reduced to a single grayscale plane. For every pixel the
    distance to the "black" target (``target_value``) and to the "white"
    target (``1.0 - target_value``) is computed, and the smaller of the two
    is kept. The loss is the mean of those per-pixel minima, so mid-gray
    pixels are penalized most.

    Args:
        images: Tensor indexed as ``[N, C, H, W]``. With ``C == 1`` the single
            channel is used directly as the grayscale plane; otherwise the
            first three channels are averaged (any further channels are
            ignored).
        threshold: Currently unused. Retained only so existing callers that
            pass it positionally or by keyword keep working.
        target_value: Value treated as "black"; "white" is the mirrored
            value ``1.0 - target_value``.

    Returns:
        A zero-dimensional tensor holding the mean per-pixel deviation.

    Raises:
        IndexError: If ``images`` cannot be indexed as ``[N, C, H, W]`` or has
            exactly two channels (same failure mode as before for
            unsupported layouts).
    """
    # NOTE(review): the black/white targets assume pixel values roughly in
    # [0, 1]. If the caller passes images scaled to another range
    # (e.g. [-1, 1]), these targets need to be remapped -- confirm against
    # the caller.
    if images.shape[1] == 1:
        # Already grayscale: nothing to average.
        grayscale_images = images[:, 0, :, :]
    else:
        # Simple (unweighted) average of the first three channels.
        grayscale_images = (
            images[:, 0, :, :] + images[:, 1, :, :] + images[:, 2, :, :]
        ) / 3.0

    # Per-pixel distance to each extreme.
    loss_black = torch.abs(grayscale_images - target_value)
    loss_white = torch.abs(grayscale_images - (1.0 - target_value))

    # A pixel only needs to be close to ONE of the extremes, so keep the
    # smaller deviation; this pushes each pixel toward whichever is nearer.
    min_deviation = torch.min(loss_black, loss_white)

    # Average over batch and spatial dimensions to get a scalar loss.
    return min_deviation.mean()
|
| 45 |
+
|
| 46 |
# And the core function to generate an image given the relevant inputs
|
| 47 |
def generate(color, guidance_loss_scale):
|
| 48 |
target_color = ImageColor.getcolor(color, "RGB") # Target color as RGB
|
|
|
|
| 54 |
noise_pred = image_pipe.unet(model_input, t)["sample"]
|
| 55 |
x = x.detach().requires_grad_()
|
| 56 |
x0 = scheduler.step(noise_pred, t, x).pred_original_sample
|
| 57 |
+
# loss = color_loss(x0, target_color) * guidance_loss_scale
|
| 58 |
+
loss = monochromatic_loss(x0)
|
| 59 |
cond_grad = -torch.autograd.grad(loss, x)[0]
|
| 60 |
x = x.detach() + cond_grad
|
| 61 |
x = scheduler.step(noise_pred, t, x).prev_sample
|