Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from qdhf_things import run_qdhf, many_pictures | |
| from generate_examples import EXAMPLE_PROMPTS | |
| import os | |
| import io | |
# Absolute location of the pre-rendered example assets; the relative part is
# resolved against the current working directory at import time.
EXAMPLES_DIR = os.path.abspath(os.path.join(os.curdir, "examples"))
def generate_images(prompt, init_pop, total_itrs):
    """Run QDHF for *prompt* and return paths to the two resulting PNG files.

    Args:
        prompt: Text prompt. A missing or whitespace-only prompt falls back
            to the same placeholder text shown in the prompt textbox.
        init_pop: Initial population size; coerced with ``int()``.
        total_itrs: Total number of iterations; coerced with ``int()``.

    Returns:
        A ``(archive_plot_path, images_path)`` tuple of PNG file paths. The
        first is the last archive plot yielded by ``run_qdhf``; the second is
        the figure returned by ``many_pictures`` for the final archive.

    Raises:
        gr.Error: If ``run_qdhf`` yields nothing, so there is no archive to
            plot or to generate pictures from.
    """
    # Local import so this block brings its own dependency into scope.
    import tempfile

    init_pop = int(init_pop)
    total_itrs = int(total_itrs)

    # Use placeholder if prompt is empty (or was not provided at all)
    if not prompt or not prompt.strip():
        prompt = "a duck crossing the street"

    archive = None
    final_archive_plot = None
    for archive, plt_fig in run_qdhf(prompt, init_pop, total_itrs):
        # Render every yielded figure, but keep only the most recent PNG
        # bytes: nothing downstream uses the earlier frames, so holding all
        # of them just grows memory with the iteration count.
        buf = io.BytesIO()
        plt_fig.savefig(buf, format='png')
        final_archive_plot = buf.getvalue()

    if final_archive_plot is None:
        # Previously this surfaced as an IndexError on an empty list.
        raise gr.Error("QDHF produced no results; nothing to display.")

    generated_images = many_pictures(archive, prompt)

    # Write into a fresh per-call directory so that concurrent requests do
    # not overwrite each other's output files.
    out_dir = tempfile.mkdtemp(prefix="qdhf_")
    temp_archive_file = os.path.join(out_dir, "temp_archive_plot.png")
    temp_images_file = os.path.join(out_dir, "temp_generated_images.png")

    with open(temp_archive_file, 'wb') as f:
        f.write(final_archive_plot)
    generated_images.savefig(temp_images_file)

    # NOTE(review): the first return value is a PNG, but the UI wires it to a
    # gr.Video output while the bundled examples supply .mp4 files there --
    # confirm whether this output should be a video or the component an image.
    return temp_archive_file, temp_images_file
def show_example(prompt):
    """Return the pre-rendered assets that belong to an example prompt.

    The prompt's position in ``EXAMPLE_PROMPTS`` selects the files
    ``archive_<i>.mp4`` and ``archive_pics_<i>.png`` under ``EXAMPLES_DIR``.

    Returns:
        ``(prompt, archive_video_path, pictures_path)``.

    Raises:
        ValueError: If *prompt* is not one of ``EXAMPLE_PROMPTS``.
    """
    position = EXAMPLE_PROMPTS.index(prompt)
    video_path, pictures_path = (
        os.path.join(EXAMPLES_DIR, filename)
        for filename in (f"archive_{position}.mp4", f"archive_pics_{position}.png")
    )
    return prompt, video_path, pictures_path
# Build the page: header, a two-column row (controls on the left, results on
# the right), and a list of clickable example prompts underneath.
with gr.Blocks() as demo:
    gr.Markdown("# Quality Diversity through Human Feedback")
    gr.Markdown("[Paper](https://arxiv.org/abs/2310.12103) | [Project Website](https://liding.info/qdhf/)")

    with gr.Row():
        # Left column: prompt and run parameters.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Enter your prompt here",
                placeholder="a duck crossing the street",
            )
            init_pop = gr.Slider(
                label="Initial Population",
                minimum=10,
                maximum=300,
                step=10,
                value=200,
            )
            total_itrs = gr.Slider(
                label="Total Iterations",
                minimum=10,
                maximum=300,
                step=10,
                value=200,
            )
            generate_button = gr.Button("Generate", variant="primary")

        # Right column: twice as wide, holds both result components.
        with gr.Column(scale=2):
            archive_output = gr.Video(label="Archive Plot")
            images_output = gr.Image(label="Generated Pictures")

    # Clicking "Generate" runs generate_images on the three inputs.
    generate_button.click(
        fn=generate_images,
        inputs=[prompt_input, init_pop, total_itrs],
        outputs=[archive_output, images_output],
    )

    # Selecting an example fills the prompt box and both result components
    # through show_example.
    gr.Examples(
        label="Example Prompts",
        examples=EXAMPLE_PROMPTS,
        fn=show_example,
        inputs=prompt_input,
        outputs=[prompt_input, archive_output, images_output],
        cache_examples=True,
    )

if __name__ == "__main__":
    demo.launch()