saneshashank commited on
Commit
4a86ebc
·
verified ·
1 Parent(s): 37b2b46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -1
app.py CHANGED
@@ -25,6 +25,24 @@ def color_loss(images, target_color=(0.1, 0.9, 0.5)):
25
  error = torch.abs(images - target).mean() # Mean absolute difference between the image pixels and the target color
26
  return error
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # And the core function to generate an image given the relevant inputs
29
  def generate(color, guidance_loss_scale):
30
  target_color = ImageColor.getcolor(color, "RGB") # Target color as RGB
@@ -36,7 +54,8 @@ def generate(color, guidance_loss_scale):
36
  noise_pred = image_pipe.unet(model_input, t)["sample"]
37
  x = x.detach().requires_grad_()
38
  x0 = scheduler.step(noise_pred, t, x).pred_original_sample
39
- loss = color_loss(x0, target_color) * guidance_loss_scale
 
40
  cond_grad = -torch.autograd.grad(loss, x)[0]
41
  x = x.detach() + cond_grad
42
  x = scheduler.step(noise_pred, t, x).prev_sample
 
25
  error = torch.abs(images - target).mean() # Mean absolute difference between the image pixels and the target color
26
  return error
27
 
28
+ def monochromatic_loss(images, threshold=0.5, target_value=0.01):
29
+ # Convert images to grayscale (simple average of channels)
30
+ # We assume images are [N, C, H, W] where C=3 (RGB)
31
+ grayscale_images = (images[:,0,:,:] + images[:,1,:,:] + images[:,2,:,:]) / 3.0
32
+
33
+ # Penalize pixels that are not close to black or white
34
+ # Encourage values close to target_value (e.g., 0.01 for black) or 1.0 (for white)
35
+ # This creates a strong push towards high contrast
36
+ loss_black = torch.abs(grayscale_images - target_value)
37
+ loss_white = torch.abs(grayscale_images - (1.0 - target_value))
38
+
39
+ # For each pixel, take the minimum deviation from either black or white
40
+ min_deviation = torch.min(loss_black, loss_white)
41
+
42
+ # We want to minimize this deviation across the image
43
+ loss = min_deviation.mean()
44
+ return loss
45
+
46
  # And the core function to generate an image given the relevant inputs
47
  def generate(color, guidance_loss_scale):
48
  target_color = ImageColor.getcolor(color, "RGB") # Target color as RGB
 
54
  noise_pred = image_pipe.unet(model_input, t)["sample"]
55
  x = x.detach().requires_grad_()
56
  x0 = scheduler.step(noise_pred, t, x).pred_original_sample
57
+ # loss = color_loss(x0, target_color) * guidance_loss_scale
58
+ loss = monochromatic_loss(x0)
59
  cond_grad = -torch.autograd.grad(loss, x)[0]
60
  x = x.detach() + cond_grad
61
  x = scheduler.step(noise_pred, t, x).prev_sample