WiNE-iNEFF commited on
Commit
498d77b
·
1 Parent(s): 13c420f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -3
app.py CHANGED
@@ -26,6 +26,11 @@ def show_images_save(x):
26
  grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
27
  return grid_im
28
 
 
 
 
 
 
29
 
30
  def generate():
31
  scheduler = DDIMScheduler.from_pretrained(pipeline_name)
@@ -38,6 +43,21 @@ def generate():
38
  x = scheduler.step(noise_pred, t, x).prev_sample
39
  return show_images_save(x)
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def crrop(file):
43
  width, height = file.size
@@ -48,14 +68,17 @@ def crrop(file):
48
  sav.append(file.crop(box))
49
  return sav
50
 
51
-
52
  def ex():
53
  t = time()
54
  print(ctime(t))
55
  return crrop(generate())
56
 
 
 
 
 
57
 
58
- demo = gr.Blocks(css=".container {max-width: 730px;margin: auto;} #gallery {margin-bottom: 15px; margin-left: auto; margin-right: auto; border-bottom-right-radius: .5rem !important; border-bottom-left-radius: .5rem !important;}")
59
 
60
  with demo:
61
  gr.HTML(
@@ -72,10 +95,18 @@ with demo:
72
  with gr.Tabs():
73
  with gr.TabItem("V1"):
74
  with gr.Column():
75
- with gr.Row().style(equal_height=True):
76
  gall = gr.Gallery(elem_id='gallery').style(grid=[4])
77
  greet_btn = gr.Button("Generate")
78
  greet_btn.click(fn=ex, outputs=gall)
 
 
 
 
 
 
 
 
79
  gr.HTML(
80
  """
81
  <div class="footer">
 
26
  grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
27
  return grid_im
28
 
29
+ def color_loss(images, target_color=(0.1, 0.9, 0.5, 1)):
30
+ target = (torch.tensor(target_color).to(images.device) * 2 - 1)
31
+ target = target[None, :, None, None]
32
+ error = torch.abs(images - target).mean()
33
+ return error
34
 
35
  def generate():
36
  scheduler = DDIMScheduler.from_pretrained(pipeline_name)
 
43
  x = scheduler.step(noise_pred, t, x).prev_sample
44
  return show_images_save(x)
45
 
46
+ def generate_g(color, guidance_loss_scale):
47
+ target_color = ImageColor.getcolor(color, "RGBA") # Target color as RGB
48
+ target_color = [a / 255 for a in target_color] # Rescale from (0, 255) to (0, 1)
49
+ x = torch.randn(8, 4, 64, 64).to(device)
50
+ for i, t in tqdm(enumerate(scheduler.timesteps)):
51
+ model_input = scheduler.scale_model_input(x, t)
52
+ with torch.no_grad():
53
+ noise_pred = image_pipe.unet(model_input, t)["sample"]
54
+ x = x.detach().requires_grad_()
55
+ x0 = scheduler.step(noise_pred, t, x).pred_original_sample
56
+ loss = color_loss(x0, target_color) * guidance_loss_scale
57
+ cond_grad = -torch.autograd.grad(loss, x)[0]
58
+ x = x.detach() + cond_grad
59
+ x = scheduler.step(noise_pred, t, x).prev_sample
60
+ return show_images_save(x)
61
 
62
  def crrop(file):
63
  width, height = file.size
 
68
  sav.append(file.crop(box))
69
  return sav
70
 
 
71
  def ex():
72
  t = time()
73
  print(ctime(t))
74
  return crrop(generate())
75
 
76
+ def ex_g(picker, slider):
77
+ t = time()
78
+ print(ctime(t))
79
+ return crrop(generate_g(picker, slider))
80
 
81
+ demo = gr.Blocks(css="#row-1 {max-height: 300px !important;} .container {max-width: 730px;margin: auto;} #gallery {margin-bottom: 15px; margin-left: auto; margin-right: auto; border-bottom-right-radius: .5rem !important; border-bottom-left-radius: .5rem !important;}")
82
 
83
  with demo:
84
  gr.HTML(
 
95
  with gr.Tabs():
96
  with gr.TabItem("V1"):
97
  with gr.Column():
98
+ with gr.Row(elem_id='row-1').style(equal_height=True):
99
  gall = gr.Gallery(elem_id='gallery').style(grid=[4])
100
  greet_btn = gr.Button("Generate")
101
  greet_btn.click(fn=ex, outputs=gall)
102
+ with gr.TabItem("V2"):
103
+ with gr.Column():
104
+ with gr.Row(elem_id='row-1').style(equal_height=True):
105
+ picker = gr.ColorPicker(label="color", value="#55FFAA")
106
+ slide = gr.Slider(label="guidance_scale", minimum=0, maximum=100, value=50)
107
+ gall = gr.Gallery(elem_id='gallery').style(grid=[4])
108
+ greet_btn = gr.Button("Generate")
109
+ greet_btn.click(fn=ex_g, inputs=[picker, slide] outputs=gall)
110
  gr.HTML(
111
  """
112
  <div class="footer">