Spaces:
Runtime error
Runtime error
Fix examples
Browse files
app.py
CHANGED
|
@@ -18,7 +18,7 @@ MIN_SIZE = 0.01
|
|
| 18 |
WHITE = 255
|
| 19 |
COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
|
| 20 |
|
| 21 |
-
PROMPT1 = "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
|
| 22 |
PROMPT2 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
|
| 23 |
PROMPT3 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"
|
| 24 |
EXAMPLE_BOXES = {
|
|
@@ -146,7 +146,7 @@ def generate(
|
|
| 146 |
|
| 147 |
filter_token_indices = convert_token_indices(filter_token_indices) if len(filter_token_indices.strip()) > 0 else None
|
| 148 |
num_tokens = int(num_tokens) if len(num_tokens.strip()) > 0 else None
|
| 149 |
-
prompts = [prompt.strip(
|
| 150 |
|
| 151 |
images = inference(
|
| 152 |
boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
|
|
@@ -158,14 +158,14 @@ def generate(
|
|
| 158 |
|
| 159 |
def convert_token_indices(token_indices, nested=False):
|
| 160 |
if nested:
|
| 161 |
-
return [convert_token_indices(indices, nested=False) for indices in token_indices.split(
|
| 162 |
|
| 163 |
-
return [int(index.strip()) for index in token_indices.split(
|
| 164 |
|
| 165 |
|
| 166 |
def draw(sketchpad):
|
| 167 |
boxes = []
|
| 168 |
-
for i, layer in enumerate(sketchpad[
|
| 169 |
non_zeros = layer.nonzero()
|
| 170 |
x1 = x2 = y1 = y2 = 0
|
| 171 |
if len(non_zeros[0]) > 0:
|
|
@@ -177,7 +177,7 @@ def draw(sketchpad):
|
|
| 177 |
y2 = y1y2.max()
|
| 178 |
|
| 179 |
if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
|
| 180 |
-
raise gr.Error(f
|
| 181 |
|
| 182 |
boxes.append((x1, y1, x2, y2))
|
| 183 |
|
|
@@ -185,15 +185,16 @@ def draw(sketchpad):
|
|
| 185 |
return [boxes, layout_image]
|
| 186 |
|
| 187 |
|
| 188 |
-
def draw_boxes(boxes):
|
| 189 |
if len(boxes) == 0:
|
| 190 |
return None
|
| 191 |
|
| 192 |
boxes = np.array(boxes) * RESOLUTION
|
| 193 |
-
image = Image.new(
|
| 194 |
drawing = ImageDraw.Draw(image)
|
| 195 |
for i, box in enumerate(boxes.astype(int).tolist()):
|
| 196 |
-
|
|
|
|
| 197 |
|
| 198 |
return image
|
| 199 |
|
|
@@ -202,35 +203,11 @@ def clear(batch_size):
|
|
| 202 |
return [[], None, None, None]
|
| 203 |
|
| 204 |
|
| 205 |
-
def
|
| 206 |
-
prompt,
|
| 207 |
-
subject_token_indices,
|
| 208 |
-
filter_token_indices,
|
| 209 |
-
num_tokens,
|
| 210 |
-
init_step_size,
|
| 211 |
-
final_step_size,
|
| 212 |
-
num_clusters_per_subject,
|
| 213 |
-
cross_loss_scale,
|
| 214 |
-
self_loss_scale,
|
| 215 |
-
classifier_free_guidance_scale,
|
| 216 |
-
batch_size,
|
| 217 |
-
num_iterations,
|
| 218 |
-
loss_threshold,
|
| 219 |
-
num_guidance_steps,
|
| 220 |
-
seed,
|
| 221 |
-
):
|
| 222 |
-
layers = []
|
| 223 |
boxes = EXAMPLE_BOXES[prompt]
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
sketchpad = {'layers': layers}
|
| 228 |
-
layout_images = draw_boxes(boxes)
|
| 229 |
-
out_images = generate(prompt, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
|
| 230 |
-
final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale, classifier_free_guidance_scale,
|
| 231 |
-
batch_size, num_iterations, loss_threshold, num_guidance_steps, seed, boxes)
|
| 232 |
-
|
| 233 |
-
return boxes, sketchpad, layout_image, out_images
|
| 234 |
|
| 235 |
|
| 236 |
def main():
|
|
@@ -274,7 +251,7 @@ def main():
|
|
| 274 |
}
|
| 275 |
"""
|
| 276 |
|
| 277 |
-
nltk.download(
|
| 278 |
|
| 279 |
with gr.Blocks(
|
| 280 |
css=css,
|
|
@@ -301,13 +278,13 @@ def main():
|
|
| 301 |
)
|
| 302 |
|
| 303 |
with gr.Row():
|
| 304 |
-
sketchpad = gr.Sketchpad(label="Sketch Pad", width=RESOLUTION, height=RESOLUTION)
|
| 305 |
layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False, width=RESOLUTION, height=RESOLUTION, scale=1)
|
| 306 |
|
| 307 |
with gr.Row():
|
| 308 |
-
clear_button = gr.Button(value=
|
| 309 |
-
generate_layout_button = gr.Button(value=
|
| 310 |
-
generate_image_button = gr.Button(value=
|
| 311 |
|
| 312 |
with gr.Row():
|
| 313 |
out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
|
|
@@ -392,29 +369,29 @@ def main():
|
|
| 392 |
gr.Examples(
|
| 393 |
examples=[
|
| 394 |
[
|
| 395 |
-
|
| 396 |
"7,8,17;11,12,17;15,16,17", "5,6,9,10,13,14,18,19", "21",
|
| 397 |
25, 10, 3, 1, 1,
|
| 398 |
7.5, 1, 5, 0.2, 15,
|
| 399 |
286,
|
| 400 |
],
|
| 401 |
[
|
| 402 |
-
|
| 403 |
"7;10;13,14;17", "5,6,8,9,11,12,15,16", "17",
|
| 404 |
18, 5, 3, 1, 1,
|
| 405 |
7.5, 1, 5, 0.2, 15,
|
| 406 |
216,
|
| 407 |
],
|
| 408 |
[
|
| 409 |
-
|
| 410 |
"2,3;6,7;10,11;14,15;18,19", "1,4,5,8,9,12,13,16,17,20,21", "22",
|
| 411 |
18, 5, 3, 1, 1,
|
| 412 |
7.5, 1, 5, 0.2, 15,
|
| 413 |
156,
|
| 414 |
],
|
| 415 |
],
|
| 416 |
-
fn=generate_example,
|
| 417 |
inputs=[
|
|
|
|
| 418 |
prompt, subject_token_indices, filter_token_indices, num_tokens,
|
| 419 |
init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
|
| 420 |
classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
|
|
@@ -427,5 +404,5 @@ def main():
|
|
| 427 |
|
| 428 |
demo.launch(show_api=False, show_error=True)
|
| 429 |
|
| 430 |
-
if __name__ ==
|
| 431 |
main()
|
|
|
|
| 18 |
WHITE = 255
|
| 19 |
COLORS = ["red", "blue", "green", "orange", "purple", "turquoise", "olive"]
|
| 20 |
|
| 21 |
+
PROMPT1 = "3D Pixar animation of a cute unicorn and a pink hedgehog and a nerdy owl traveling in a magical forest"
|
| 22 |
PROMPT2 = "science fiction movie poster with an astronaut and a robot and a green alien and a spaceship"
|
| 23 |
PROMPT3 = "a golden retriever and a german shepherd and a boston terrier and an english bulldog and a border collie in a pool"
|
| 24 |
EXAMPLE_BOXES = {
|
|
|
|
| 146 |
|
| 147 |
filter_token_indices = convert_token_indices(filter_token_indices) if len(filter_token_indices.strip()) > 0 else None
|
| 148 |
num_tokens = int(num_tokens) if len(num_tokens.strip()) > 0 else None
|
| 149 |
+
prompts = [prompt.strip(".").strip(",").strip()] * batch_size
|
| 150 |
|
| 151 |
images = inference(
|
| 152 |
boxes, prompts, subject_token_indices, filter_token_indices, num_tokens, init_step_size,
|
|
|
|
| 158 |
|
| 159 |
def convert_token_indices(token_indices, nested=False):
|
| 160 |
if nested:
|
| 161 |
+
return [convert_token_indices(indices, nested=False) for indices in token_indices.split(";")]
|
| 162 |
|
| 163 |
+
return [int(index.strip()) for index in token_indices.split(",") if len(index.strip()) > 0]
|
| 164 |
|
| 165 |
|
| 166 |
def draw(sketchpad):
|
| 167 |
boxes = []
|
| 168 |
+
for i, layer in enumerate(sketchpad["layers"]):
|
| 169 |
non_zeros = layer.nonzero()
|
| 170 |
x1 = x2 = y1 = y2 = 0
|
| 171 |
if len(non_zeros[0]) > 0:
|
|
|
|
| 177 |
y2 = y1y2.max()
|
| 178 |
|
| 179 |
if (x2 - x1 < MIN_SIZE) or (y2 - y1 < MIN_SIZE):
|
| 180 |
+
raise gr.Error(f"Box in layer {i} is too small")
|
| 181 |
|
| 182 |
boxes.append((x1, y1, x2, y2))
|
| 183 |
|
|
|
|
| 185 |
return [boxes, layout_image]
|
| 186 |
|
| 187 |
|
| 188 |
+
def draw_boxes(boxes, is_sketch=False):
|
| 189 |
if len(boxes) == 0:
|
| 190 |
return None
|
| 191 |
|
| 192 |
boxes = np.array(boxes) * RESOLUTION
|
| 193 |
+
image = Image.new("RGB", (RESOLUTION, RESOLUTION), (WHITE, WHITE, WHITE))
|
| 194 |
drawing = ImageDraw.Draw(image)
|
| 195 |
for i, box in enumerate(boxes.astype(int).tolist()):
|
| 196 |
+
color = "black" if is_sketch else COLORS[i % len(COLORS)]
|
| 197 |
+
drawing.rectangle(box, outline=color, width=4)
|
| 198 |
|
| 199 |
return image
|
| 200 |
|
|
|
|
| 203 |
return [[], None, None, None]
|
| 204 |
|
| 205 |
|
| 206 |
+
def make_example_inputs(prompt):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
boxes = EXAMPLE_BOXES[prompt]
|
| 208 |
+
sketchpad = draw_boxes(boxes, is_sketch=True)
|
| 209 |
+
layout_image = draw_boxes(boxes)
|
| 210 |
+
return sketchpad, layout_image, prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
|
| 213 |
def main():
|
|
|
|
| 251 |
}
|
| 252 |
"""
|
| 253 |
|
| 254 |
+
nltk.download("averaged_perceptron_tagger")
|
| 255 |
|
| 256 |
with gr.Blocks(
|
| 257 |
css=css,
|
|
|
|
| 278 |
)
|
| 279 |
|
| 280 |
with gr.Row():
|
| 281 |
+
sketchpad = gr.Sketchpad(label="Sketch Pad (draw each bounding box in a different layer)", width=RESOLUTION, height=RESOLUTION)
|
| 282 |
layout_image = gr.Image(type="pil", label="Bounding Boxes", interactive=False, width=RESOLUTION, height=RESOLUTION, scale=1)
|
| 283 |
|
| 284 |
with gr.Row():
|
| 285 |
+
clear_button = gr.Button(value="Clear")
|
| 286 |
+
generate_layout_button = gr.Button(value="Generate layout")
|
| 287 |
+
generate_image_button = gr.Button(value="Generate image")
|
| 288 |
|
| 289 |
with gr.Row():
|
| 290 |
out_images = gr.Gallery(type="pil", label="Generated Images", interactive=False)
|
|
|
|
| 369 |
gr.Examples(
|
| 370 |
examples=[
|
| 371 |
[
|
| 372 |
+
*make_example_inputs(PROMPT1),
|
| 373 |
"7,8,17;11,12,17;15,16,17", "5,6,9,10,13,14,18,19", "21",
|
| 374 |
25, 10, 3, 1, 1,
|
| 375 |
7.5, 1, 5, 0.2, 15,
|
| 376 |
286,
|
| 377 |
],
|
| 378 |
[
|
| 379 |
+
*make_example_inputs(PROMPT2),
|
| 380 |
"7;10;13,14;17", "5,6,8,9,11,12,15,16", "17",
|
| 381 |
18, 5, 3, 1, 1,
|
| 382 |
7.5, 1, 5, 0.2, 15,
|
| 383 |
216,
|
| 384 |
],
|
| 385 |
[
|
| 386 |
+
*make_example_inputs(PROMPT3),
|
| 387 |
"2,3;6,7;10,11;14,15;18,19", "1,4,5,8,9,12,13,16,17,20,21", "22",
|
| 388 |
18, 5, 3, 1, 1,
|
| 389 |
7.5, 1, 5, 0.2, 15,
|
| 390 |
156,
|
| 391 |
],
|
| 392 |
],
|
|
|
|
| 393 |
inputs=[
|
| 394 |
+
sketchpad, layout_image,
|
| 395 |
prompt, subject_token_indices, filter_token_indices, num_tokens,
|
| 396 |
init_step_size, final_step_size, num_clusters_per_subject, cross_loss_scale, self_loss_scale,
|
| 397 |
classifier_free_guidance_scale, batch_size, num_iterations, loss_threshold, num_guidance_steps,
|
|
|
|
| 404 |
|
| 405 |
demo.launch(show_api=False, show_error=True)
|
| 406 |
|
| 407 |
+
if __name__ == "__main__":
|
| 408 |
main()
|