Spaces:
Running
Running
Commit
·
498d77b
1
Parent(s):
13c420f
Update app.py
Browse files
app.py
CHANGED
|
@@ -26,6 +26,11 @@ def show_images_save(x):
|
|
| 26 |
grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
|
| 27 |
return grid_im
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def generate():
|
| 31 |
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
|
|
@@ -38,6 +43,21 @@ def generate():
|
|
| 38 |
x = scheduler.step(noise_pred, t, x).prev_sample
|
| 39 |
return show_images_save(x)
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def crrop(file):
|
| 43 |
width, height = file.size
|
|
@@ -48,14 +68,17 @@ def crrop(file):
|
|
| 48 |
sav.append(file.crop(box))
|
| 49 |
return sav
|
| 50 |
|
| 51 |
-
|
| 52 |
def ex():
|
| 53 |
t = time()
|
| 54 |
print(ctime(t))
|
| 55 |
return crrop(generate())
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
demo = gr.Blocks(css=".container {max-width: 730px;margin: auto;} #gallery {margin-bottom: 15px; margin-left: auto; margin-right: auto; border-bottom-right-radius: .5rem !important; border-bottom-left-radius: .5rem !important;}")
|
| 59 |
|
| 60 |
with demo:
|
| 61 |
gr.HTML(
|
|
@@ -72,10 +95,18 @@ with demo:
|
|
| 72 |
with gr.Tabs():
|
| 73 |
with gr.TabItem("V1"):
|
| 74 |
with gr.Column():
|
| 75 |
-
with gr.Row().style(equal_height=True):
|
| 76 |
gall = gr.Gallery(elem_id='gallery').style(grid=[4])
|
| 77 |
greet_btn = gr.Button("Generate")
|
| 78 |
greet_btn.click(fn=ex, outputs=gall)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
gr.HTML(
|
| 80 |
"""
|
| 81 |
<div class="footer">
|
|
|
|
| 26 |
grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
|
| 27 |
return grid_im
|
| 28 |
|
| 29 |
+
def color_loss(images, target_color=(0.1, 0.9, 0.5, 1)):
|
| 30 |
+
target = (torch.tensor(target_color).to(images.device) * 2 - 1)
|
| 31 |
+
target = target[None, :, None, None]
|
| 32 |
+
error = torch.abs(images - target).mean()
|
| 33 |
+
return error
|
| 34 |
|
| 35 |
def generate():
|
| 36 |
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
|
|
|
|
| 43 |
x = scheduler.step(noise_pred, t, x).prev_sample
|
| 44 |
return show_images_save(x)
|
| 45 |
|
| 46 |
+
def generate_g(color, guidance_loss_scale):
|
| 47 |
+
target_color = ImageColor.getcolor(color, "RGBA") # Target color as RGB
|
| 48 |
+
target_color = [a / 255 for a in target_color] # Rescale from (0, 255) to (0, 1)
|
| 49 |
+
x = torch.randn(8, 4, 64, 64).to(device)
|
| 50 |
+
for i, t in tqdm(enumerate(scheduler.timesteps)):
|
| 51 |
+
model_input = scheduler.scale_model_input(x, t)
|
| 52 |
+
with torch.no_grad():
|
| 53 |
+
noise_pred = image_pipe.unet(model_input, t)["sample"]
|
| 54 |
+
x = x.detach().requires_grad_()
|
| 55 |
+
x0 = scheduler.step(noise_pred, t, x).pred_original_sample
|
| 56 |
+
loss = color_loss(x0, target_color) * guidance_loss_scale
|
| 57 |
+
cond_grad = -torch.autograd.grad(loss, x)[0]
|
| 58 |
+
x = x.detach() + cond_grad
|
| 59 |
+
x = scheduler.step(noise_pred, t, x).prev_sample
|
| 60 |
+
return show_images_save(x)
|
| 61 |
|
| 62 |
def crrop(file):
|
| 63 |
width, height = file.size
|
|
|
|
| 68 |
sav.append(file.crop(box))
|
| 69 |
return sav
|
| 70 |
|
|
|
|
| 71 |
def ex():
|
| 72 |
t = time()
|
| 73 |
print(ctime(t))
|
| 74 |
return crrop(generate())
|
| 75 |
|
| 76 |
+
def ex_g(picker, slider):
|
| 77 |
+
t = time()
|
| 78 |
+
print(ctime(t))
|
| 79 |
+
return crrop(generate_g(picker, slider))
|
| 80 |
|
| 81 |
+
demo = gr.Blocks(css="#row-1 {max-height: 300px !important;} .container {max-width: 730px;margin: auto;} #gallery {margin-bottom: 15px; margin-left: auto; margin-right: auto; border-bottom-right-radius: .5rem !important; border-bottom-left-radius: .5rem !important;}")
|
| 82 |
|
| 83 |
with demo:
|
| 84 |
gr.HTML(
|
|
|
|
| 95 |
with gr.Tabs():
|
| 96 |
with gr.TabItem("V1"):
|
| 97 |
with gr.Column():
|
| 98 |
+
with gr.Row(elem_id='row-1').style(equal_height=True):
|
| 99 |
gall = gr.Gallery(elem_id='gallery').style(grid=[4])
|
| 100 |
greet_btn = gr.Button("Generate")
|
| 101 |
greet_btn.click(fn=ex, outputs=gall)
|
| 102 |
+
with gr.TabItem("V2"):
|
| 103 |
+
with gr.Column():
|
| 104 |
+
with gr.Row(elem_id='row-1').style(equal_height=True):
|
| 105 |
+
picker = gr.ColorPicker(label="color", value="#55FFAA")
|
| 106 |
+
slide = gr.Slider(label="guidance_scale", minimum=0, maximum=100, value=50)
|
| 107 |
+
gall = gr.Gallery(elem_id='gallery').style(grid=[4])
|
| 108 |
+
greet_btn = gr.Button("Generate")
|
| 109 |
+
greet_btn.click(fn=ex_g, inputs=[picker, slide] outputs=gall)
|
| 110 |
gr.HTML(
|
| 111 |
"""
|
| 112 |
<div class="footer">
|