Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from io import BytesIO | |
| from typing import Literal | |
| from diffusers import StableDiffusionPipeline | |
| import torch | |
| import time | |
# Inference settings used by text2image().
NUM_ITERS_TO_RUN: int = 2  # pipeline invocations per text2image() call
NUM_INFERENCE_STEPS: int = 20  # denoising steps per pipeline invocation
NUM_IMAGES_PER_PROMPT: int = 1  # images produced per pipeline invocation

# Fixed seed. torch.manual_seed() seeds torch's global RNG as a side effect and
# returns the default torch.Generator, kept here at module level.
seed: int = 42
generator = torch.manual_seed(seed)
@st.cache_resource(max_entries=1)
def _load_pipeline(repo_id: str):
    """Load the Stable Diffusion pipeline for ``repo_id`` and cache it.

    ``st.cache_resource`` (rather than ``functools.lru_cache``) is required
    because Streamlit re-executes this script on every interaction, which
    would discard a plain module-level cache. ``max_entries=1`` keeps only the
    most recently used model resident, so switching between models in the
    selector does not accumulate several pipelines in memory.
    """
    if torch.cuda.is_available():
        print("Using GPU")
        return StableDiffusionPipeline.from_pretrained(
            repo_id,
            torch_dtype=torch.float16,
            use_safetensors=True,
        ).to("cuda")
    print("Using CPU")
    # Stay in float32 on CPU, where half-precision inference is generally
    # unsupported or slow.
    return StableDiffusionPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch.float32,
        use_safetensors=True,
    )


def text2image(
    prompt: str,
    repo_id: Literal[
        "dreamlike-art/dreamlike-photoreal-2.0",
        "hakurei/waifu-diffusion",
        "prompthero/openjourney",
        "stabilityai/stable-diffusion-2-1",
        "runwayml/stable-diffusion-v1-5",
        "nota-ai/bk-sdm-small",
        "CompVis/stable-diffusion-v1-4",
    ],
):
    """Generate an image for ``prompt`` with the selected diffusion model.

    Args:
        prompt: Text prompt passed to the pipeline.
        repo_id: Hugging Face Hub id of the model to run.

    Returns:
        ``(image, start, end)``: ``image`` is the first entry of ``.images``
        from the last pipeline run (a PIL image with diffusers' default output
        type); ``start`` and ``end`` are ``time.time()`` timestamps taken
        before the model is loaded (or fetched from the cache) and after the
        final run.
    """
    start = time.time()
    pipeline = _load_pipeline(repo_id)
    # Seed a fresh generator per call so that repeated requests with the same
    # prompt and model start from the same initial noise. The module-level
    # ``generator`` is seeded only once at import, and its state advances with
    # every request. A CPU generator is accepted by diffusers for GPU
    # pipelines as well.
    call_generator = torch.Generator(device="cpu").manual_seed(seed)
    images = []
    # Always run at least once so there is an image to return.
    for _ in range(max(1, NUM_ITERS_TO_RUN)):
        images = pipeline(
            prompt,
            num_inference_steps=NUM_INFERENCE_STEPS,
            generator=call_generator,
            num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
        ).images
    end = time.time()
    return images[0], start, end
def app():
    """Render the text-to-image UI and run generation when the form is submitted.

    Shows a prompt box and a model selector. When the form's submit button is
    pressed with a non-empty prompt, calls ``text2image`` and then displays the
    processing time (HH:MM:SS.ss), the generated image, and a PNG download
    button.
    """
    st.header("Text-to-image Web App")
    st.subheader("Powered by Hugging Face")
    user_input = st.text_area(
        "Enter your text prompt below and click the button to submit."
    )
    option = st.selectbox(
        "Select model (in order of processing time)",
        (
            "nota-ai/bk-sdm-small",
            "CompVis/stable-diffusion-v1-4",
            "runwayml/stable-diffusion-v1-5",
            "prompthero/openjourney",
            "hakurei/waifu-diffusion",
            "stabilityai/stable-diffusion-2-1",
            "dreamlike-art/dreamlike-photoreal-2.0",
        ),
    )
    with st.form("my_form"):
        submit = st.form_submit_button(label="Submit text prompt")

    # Results are rendered outside the form: Streamlit does not allow
    # st.download_button inside st.form.
    if not submit:
        return
    # Generation can take many minutes; do not start it for an empty prompt.
    if not user_input.strip():
        st.warning("Please enter a text prompt before submitting.")
        return

    with st.spinner(text="Generating image ... It may take up to 20 minutes."):
        im, start, end = text2image(prompt=user_input, repo_id=option)

    with BytesIO() as buf:
        im.save(buf, format="PNG")
        byte_im = buf.getvalue()

    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    st.success(
        f"Processing time: {int(hours):0>2}:{int(minutes):0>2}:{seconds:05.2f}."
    )
    st.image(im)
    st.download_button(
        label="Click here to download",
        data=byte_im,
        file_name="generated_image.png",
        mime="image/png",
    )
# Entry point: build the UI only when this file is executed as a script
# (e.g. via `streamlit run`), not when it is imported as a module.
if __name__ == "__main__":
    app()  # draws the widgets and handles a submitted prompt