Spaces:
Runtime error
Runtime error
Merge pull request #55 from abidlabs/main
Browse files
- app/{app_gradio.py → gradio/app_gradio.py} +11 -47
- app/{app_gradio_ngrok.py → gradio/app_gradio_ngrok.py} +5 -15
- app/gradio/dalle_mini +1 -0
- app/gradio/requirements.txt +4 -0
- app/sample_images/image_0.jpg +0 -0
- app/sample_images/image_1.jpg +0 -0
- app/sample_images/image_2.jpg +0 -0
- app/sample_images/image_3.jpg +0 -0
- app/sample_images/image_4.jpg +0 -0
- app/sample_images/image_5.jpg +0 -0
- app/sample_images/image_6.jpg +0 -0
- app/sample_images/image_7.jpg +0 -0
- app/sample_images/readme.txt +0 -1
- app/ui_gradio.py +0 -91
- requirements.txt +1 -1
app/{app_gradio.py → gradio/app_gradio.py}
RENAMED
|
@@ -18,12 +18,16 @@ from PIL import Image
|
|
| 18 |
import numpy as np
|
| 19 |
import matplotlib.pyplot as plt
|
| 20 |
|
| 21 |
-
|
| 22 |
from vqgan_jax.modeling_flax_vqgan import VQModel
|
| 23 |
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
| 24 |
|
|
|
|
|
|
|
|
|
|
| 25 |
import gradio as gr
|
| 26 |
|
|
|
|
|
|
|
| 27 |
|
| 28 |
DALLE_REPO = 'flax-community/dalle-mini'
|
| 29 |
DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
|
|
@@ -58,34 +62,12 @@ def generate(input, rng, params):
|
|
| 58 |
def get_images(indices, params):
|
| 59 |
return vqgan.decode_code(indices, params=params)
|
| 60 |
|
| 61 |
-
def plot_images(images):
|
| 62 |
-
fig = plt.figure(figsize=(40, 20))
|
| 63 |
-
columns = 4
|
| 64 |
-
rows = 2
|
| 65 |
-
plt.subplots_adjust(hspace=0, wspace=0)
|
| 66 |
-
|
| 67 |
-
for i in range(1, columns*rows +1):
|
| 68 |
-
fig.add_subplot(rows, columns, i)
|
| 69 |
-
plt.imshow(images[i-1])
|
| 70 |
-
plt.gca().axes.get_yaxis().set_visible(False)
|
| 71 |
-
plt.show()
|
| 72 |
-
|
| 73 |
-
def stack_reconstructions(images):
|
| 74 |
-
w, h = images[0].size[0], images[0].size[1]
|
| 75 |
-
img = Image.new("RGB", (len(images)*w, h))
|
| 76 |
-
for i, img_ in enumerate(images):
|
| 77 |
-
img.paste(img_, (i*w,0))
|
| 78 |
-
return img
|
| 79 |
-
|
| 80 |
p_generate = jax.pmap(generate, "batch")
|
| 81 |
p_get_images = jax.pmap(get_images, "batch")
|
| 82 |
|
| 83 |
bart_params = replicate(model.params)
|
| 84 |
vqgan_params = replicate(vqgan.params)
|
| 85 |
|
| 86 |
-
# ## CLIP Scoring
|
| 87 |
-
from transformers import CLIPProcessor, FlaxCLIPModel
|
| 88 |
-
|
| 89 |
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
| 90 |
print("Initialize FlaxCLIPModel")
|
| 91 |
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
|
@@ -137,48 +119,30 @@ def top_k_predictions(prompt, num_candidates=32, k=8):
|
|
| 137 |
|
| 138 |
def run_inference(prompt, num_images=32, num_preds=8):
|
| 139 |
images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
|
| 140 |
-
predictions =
|
| 141 |
output_title = f"""
|
| 142 |
-
<p style="font-size:22px; font-style:bold">Best predictions</p>
|
| 143 |
-
<p>We asked our model to generate 32 candidates for your prompt:</p>
|
| 144 |
-
|
| 145 |
-
<pre>
|
| 146 |
-
|
| 147 |
<b>{prompt}</b>
|
| 148 |
-
</pre>
|
| 149 |
-
<p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
|
| 150 |
-
similarity of the text and the image representations.</p>
|
| 151 |
-
|
| 152 |
-
<p>This is the result:</p>
|
| 153 |
"""
|
| 154 |
-
output_description = """
|
| 155 |
-
<p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.<p>
|
| 156 |
-
<p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALLE·mini</a></p>
|
| 157 |
-
"""
|
| 158 |
-
return (output_title, predictions, output_description)
|
| 159 |
|
| 160 |
outputs = [
|
| 161 |
gr.outputs.HTML(label=""), # To be used as title
|
| 162 |
gr.outputs.Image(label=''),
|
| 163 |
-
gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
|
| 164 |
]
|
| 165 |
|
| 166 |
description = """
|
| 167 |
-
|
| 168 |
-
It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.
|
| 169 |
-
|
| 170 |
-
Please, write what you would like the model to generate, or select one of the examples below.
|
| 171 |
"""
|
| 172 |
gr.Interface(run_inference,
|
| 173 |
-
inputs=[gr.inputs.Textbox(label='
|
| 174 |
outputs=outputs,
|
| 175 |
title='DALL·E mini',
|
| 176 |
description=description,
|
| 177 |
-
article="<p style='text-align: center'>
|
| 178 |
layout='vertical',
|
| 179 |
theme='huggingface',
|
| 180 |
examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
|
| 181 |
allow_flagging=False,
|
| 182 |
live=False,
|
| 183 |
# server_port=8999
|
| 184 |
-
).launch()
|
|
|
|
| 18 |
import numpy as np
|
| 19 |
import matplotlib.pyplot as plt
|
| 20 |
|
|
|
|
| 21 |
from vqgan_jax.modeling_flax_vqgan import VQModel
|
| 22 |
from dalle_mini.model import CustomFlaxBartForConditionalGeneration
|
| 23 |
|
| 24 |
+
# ## CLIP Scoring
|
| 25 |
+
from transformers import CLIPProcessor, FlaxCLIPModel
|
| 26 |
+
|
| 27 |
import gradio as gr
|
| 28 |
|
| 29 |
+
from dalle_mini.helpers import captioned_strip
|
| 30 |
+
|
| 31 |
|
| 32 |
DALLE_REPO = 'flax-community/dalle-mini'
|
| 33 |
DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
|
|
|
|
| 62 |
def get_images(indices, params):
|
| 63 |
return vqgan.decode_code(indices, params=params)
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
p_generate = jax.pmap(generate, "batch")
|
| 66 |
p_get_images = jax.pmap(get_images, "batch")
|
| 67 |
|
| 68 |
bart_params = replicate(model.params)
|
| 69 |
vqgan_params = replicate(vqgan.params)
|
| 70 |
|
|
|
|
|
|
|
|
|
|
| 71 |
clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
| 72 |
print("Initialize FlaxCLIPModel")
|
| 73 |
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
|
|
|
| 119 |
|
| 120 |
def run_inference(prompt, num_images=32, num_preds=8):
|
| 121 |
images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
|
| 122 |
+
predictions = captioned_strip(images)
|
| 123 |
output_title = f"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
<b>{prompt}</b>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
"""
|
| 126 |
+
return (output_title, predictions)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
outputs = [
|
| 129 |
gr.outputs.HTML(label=""), # To be used as title
|
| 130 |
gr.outputs.Image(label=''),
|
|
|
|
| 131 |
]
|
| 132 |
|
| 133 |
description = """
|
| 134 |
+
DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
|
|
|
|
|
|
|
|
|
|
| 135 |
"""
|
| 136 |
gr.Interface(run_inference,
|
| 137 |
+
inputs=[gr.inputs.Textbox(label='What do you want to see?')],
|
| 138 |
outputs=outputs,
|
| 139 |
title='DALL·E mini',
|
| 140 |
description=description,
|
| 141 |
+
article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
|
| 142 |
layout='vertical',
|
| 143 |
theme='huggingface',
|
| 144 |
examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
|
| 145 |
allow_flagging=False,
|
| 146 |
live=False,
|
| 147 |
# server_port=8999
|
| 148 |
+
).launch(share=True)
|
app/{app_gradio_ngrok.py → gradio/app_gradio_ngrok.py}
RENAMED
|
@@ -7,25 +7,15 @@ import numpy as np
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
from io import BytesIO
|
| 9 |
import base64
|
|
|
|
| 10 |
|
| 11 |
import gradio as gr
|
| 12 |
|
| 13 |
-
|
| 14 |
-
import os
|
| 15 |
-
backend_url = os.environ["BACKEND_SERVER"]
|
| 16 |
|
| 17 |
-
def compose_predictions(images, caption=None):
|
| 18 |
-
increased_h = 0 if caption is None else 48
|
| 19 |
-
w, h = images[0].size[0], images[0].size[1]
|
| 20 |
-
img = Image.new("RGB", (len(images)*w, h + increased_h))
|
| 21 |
-
for i, img_ in enumerate(images):
|
| 22 |
-
img.paste(img_, (i*w, increased_h))
|
| 23 |
|
| 24 |
-
if caption is not None:
|
| 25 |
-
draw = ImageDraw.Draw(img)
|
| 26 |
-
font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
|
| 27 |
-
draw.text((20, 3), caption, (255,255,255), font=font)
|
| 28 |
-
return img
|
| 29 |
|
| 30 |
class ServiceError(Exception):
|
| 31 |
def __init__(self, status_code):
|
|
@@ -46,7 +36,7 @@ def get_images_from_ngrok(prompt):
|
|
| 46 |
def run_inference(prompt):
|
| 47 |
try:
|
| 48 |
images = get_images_from_ngrok(prompt)
|
| 49 |
-
predictions =
|
| 50 |
output_title = f"""
|
| 51 |
<p style="font-size:22px; font-style:bold">Best predictions</p>
|
| 52 |
<p>We asked our model to generate 128 candidates for your prompt:</p>
|
|
|
|
| 7 |
import matplotlib.pyplot as plt
|
| 8 |
from io import BytesIO
|
| 9 |
import base64
|
| 10 |
+
import os
|
| 11 |
|
| 12 |
import gradio as gr
|
| 13 |
|
| 14 |
+
from dalle_mini.helpers import captioned_strip
|
|
|
|
|
|
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
backend_url = os.environ["BACKEND_SERVER"]
|
| 18 |
+
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
class ServiceError(Exception):
|
| 21 |
def __init__(self, status_code):
|
|
|
|
| 36 |
def run_inference(prompt):
|
| 37 |
try:
|
| 38 |
images = get_images_from_ngrok(prompt)
|
| 39 |
+
predictions = captioned_strip(images)
|
| 40 |
output_title = f"""
|
| 41 |
<p style="font-size:22px; font-style:bold">Best predictions</p>
|
| 42 |
<p>We asked our model to generate 128 candidates for your prompt:</p>
|
app/gradio/dalle_mini
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
../../dalle_mini/
|
app/gradio/requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Requirements for huggingface spaces
|
| 2 |
+
gradio>=2.2.3
|
| 3 |
+
flax
|
| 4 |
+
transformers
|
app/sample_images/image_0.jpg
DELETED
|
Binary file (9.02 kB)
|
|
|
app/sample_images/image_1.jpg
DELETED
|
Binary file (9.71 kB)
|
|
|
app/sample_images/image_2.jpg
DELETED
|
Binary file (14.1 kB)
|
|
|
app/sample_images/image_3.jpg
DELETED
|
Binary file (9.38 kB)
|
|
|
app/sample_images/image_4.jpg
DELETED
|
Binary file (9.97 kB)
|
|
|
app/sample_images/image_5.jpg
DELETED
|
Binary file (15.3 kB)
|
|
|
app/sample_images/image_6.jpg
DELETED
|
Binary file (11.1 kB)
|
|
|
app/sample_images/image_7.jpg
DELETED
|
Binary file (8.55 kB)
|
|
|
app/sample_images/readme.txt
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
These images were generated by one of our checkpoints, as responses to the prompt "snowy mountains by the sea".
|
|
|
|
|
|
app/ui_gradio.py
DELETED
|
@@ -1,91 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env python
|
| 2 |
-
# coding: utf-8
|
| 3 |
-
|
| 4 |
-
from PIL import Image
|
| 5 |
-
import gradio as gr
|
| 6 |
-
|
| 7 |
-
def compose_predictions(images, caption=None):
|
| 8 |
-
increased_h = 0 if caption is None else 48
|
| 9 |
-
w, h = images[0].size[0], images[0].size[1]
|
| 10 |
-
img = Image.new("RGB", (len(images)*w, h + increased_h))
|
| 11 |
-
for i, img_ in enumerate(images):
|
| 12 |
-
img.paste(img_, (i*w, increased_h))
|
| 13 |
-
|
| 14 |
-
if caption is not None:
|
| 15 |
-
draw = ImageDraw.Draw(img)
|
| 16 |
-
font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
|
| 17 |
-
draw.text((20, 3), caption, (255,255,255), font=font)
|
| 18 |
-
return img
|
| 19 |
-
|
| 20 |
-
def compose_predictions_grid(images):
|
| 21 |
-
cols = 4
|
| 22 |
-
rows = len(images) // cols
|
| 23 |
-
w, h = images[0].size[0], images[0].size[1]
|
| 24 |
-
img = Image.new("RGB", (w * cols, h * rows))
|
| 25 |
-
for i, img_ in enumerate(images):
|
| 26 |
-
row = i // cols
|
| 27 |
-
col = i % cols
|
| 28 |
-
img.paste(img_, (w * col, h * row))
|
| 29 |
-
return img
|
| 30 |
-
|
| 31 |
-
def top_k_predictions_real(prompt, num_candidates=32, k=8):
|
| 32 |
-
images = hallucinate(prompt, num_images=num_candidates)
|
| 33 |
-
images = clip_top_k(prompt, images, k=num_preds)
|
| 34 |
-
return images
|
| 35 |
-
|
| 36 |
-
def top_k_predictions(prompt, num_candidates=32, k=8):
|
| 37 |
-
images = []
|
| 38 |
-
for i in range(k):
|
| 39 |
-
image = Image.open(f"sample_images/image_{i}.jpg")
|
| 40 |
-
images.append(image)
|
| 41 |
-
return images
|
| 42 |
-
|
| 43 |
-
def run_inference(prompt, num_images=32, num_preds=8):
|
| 44 |
-
images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
|
| 45 |
-
predictions = compose_predictions(images)
|
| 46 |
-
output_title = f"""
|
| 47 |
-
<p style="font-size:22px; font-style:bold">Best predictions</p>
|
| 48 |
-
<p>We asked our model to generate 32 candidates for your prompt:</p>
|
| 49 |
-
|
| 50 |
-
<pre>
|
| 51 |
-
|
| 52 |
-
<b>{prompt}</b>
|
| 53 |
-
</pre>
|
| 54 |
-
<p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
|
| 55 |
-
similarity of the text and the image representations.</p>
|
| 56 |
-
|
| 57 |
-
<p>This is the result:</p>
|
| 58 |
-
"""
|
| 59 |
-
output_description = """
|
| 60 |
-
<p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.<p>
|
| 61 |
-
<p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALLE·mini</a></p>
|
| 62 |
-
"""
|
| 63 |
-
return (output_title, predictions, output_description)
|
| 64 |
-
|
| 65 |
-
outputs = [
|
| 66 |
-
gr.outputs.HTML(label=""), # To be used as title
|
| 67 |
-
gr.outputs.Image(label=''),
|
| 68 |
-
gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
|
| 69 |
-
]
|
| 70 |
-
|
| 71 |
-
description = """
|
| 72 |
-
Welcome to our demo of DALL·E-mini. This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
|
| 73 |
-
It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.
|
| 74 |
-
|
| 75 |
-
Please, write what you would like the model to generate, or select one of the examples below.
|
| 76 |
-
"""
|
| 77 |
-
gr.Interface(run_inference,
|
| 78 |
-
inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
|
| 79 |
-
outputs=outputs,
|
| 80 |
-
title='DALL·E mini',
|
| 81 |
-
description=description,
|
| 82 |
-
article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
|
| 83 |
-
layout='vertical',
|
| 84 |
-
theme='huggingface',
|
| 85 |
-
examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
|
| 86 |
-
allow_flagging=False,
|
| 87 |
-
live=False,
|
| 88 |
-
server_port=8999
|
| 89 |
-
).launch(
|
| 90 |
-
share=True # Creates temporary public link if true
|
| 91 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,2 +1,2 @@
|
|
| 1 |
# Requirements for huggingface spaces
|
| 2 |
-
streamlit>=0.84.2
|
|
|
|
| 1 |
# Requirements for huggingface spaces
|
| 2 |
+
streamlit>=0.84.2
|