Spaces:
Sleeping
Sleeping
import io
import time

import streamlit as st
import torch
from diffusers import StableDiffusionPipeline
# Set up the page. Page-level config goes before any other Streamlit output.
st.set_page_config(page_title="Stable Diffusion Image Generator", layout="wide")
st.title("Stable Diffusion Image Generator")
st.write("Generate images using Stable Diffusion v1-5")

# Sidebar for settings: the prompt and the button that triggers generation.
with st.sidebar:
    st.header("Settings")
    prompt = st.text_area(
        "Enter your prompt",
        value="a photo of an astronaut riding a horse on mars",
        height=100,
    )
    # True only on the rerun triggered by clicking the button.
    generate_button = st.button("Generate Image")
# Load the model once and reuse it. Streamlit reruns this whole script on every
# widget interaction, so without st.cache_resource the multi-GB pipeline would
# be rebuilt each time.

# The original runwayml/stable-diffusion-v1-5 repo is no longer hosted on the
# Hugging Face Hub; this maintained mirror serves the same v1-5 weights.
DEFAULT_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"


@st.cache_resource(show_spinner="Loading Stable Diffusion model...")
def _load_pipeline(model_id, device):
    """Build a StableDiffusionPipeline for ``model_id`` on ``device``.

    Exceptions are allowed to propagate: st.cache_resource only stores values
    that are actually returned, so a failed load is retried on the next rerun
    instead of being cached forever.
    """
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        # float16 on GPU to save memory; float32 on CPU, where half
        # precision is slow or unsupported.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        safety_checker=None,
    )
    return pipe.to(device)


def load_model(model_id=DEFAULT_MODEL_ID):
    """Return the (cached) Stable Diffusion pipeline, or None on failure.

    Uses CUDA when ``torch.cuda.is_available()``, otherwise the CPU.

    Args:
        model_id: Hugging Face Hub repo id (or local path) of the model.

    Returns:
        The pipeline moved to the selected device, or ``None`` if loading
        failed (the error is shown in the app).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        return _load_pipeline(model_id, device)
    except Exception as e:  # UI boundary: report the failure instead of crashing
        st.error(f"Failed to load model: {e}")
        return None


pipe = load_model()
# Helper function to convert a PIL image to PNG bytes. It must be defined
# BEFORE the generation code below: Streamlit executes this script top to
# bottom, so a function defined after its first call is not yet bound when
# that call runs.
def image_to_bytes(image):
    """Encode ``image`` as PNG and return the raw bytes.

    Args:
        image: A PIL image (any object whose ``save(fp, format=...)`` writes
            to a binary file-like object).

    Returns:
        The PNG-encoded image as ``bytes``.
    """
    with io.BytesIO() as buf:
        image.save(buf, format="PNG")
        return buf.getvalue()


# Generate and display the image
if generate_button:
    if not prompt.strip():
        # Previously an empty prompt made the button silently do nothing.
        st.warning("Please enter a prompt first.")
    elif pipe is None:
        st.error("Model failed to load. Check logs for details.")
    else:
        try:
            with st.spinner("Generating image (this may take a while...) ⏳"):
                # perf_counter is monotonic, so it is the right clock for durations.
                start_time = time.perf_counter()
                image = pipe(prompt).images[0]
                generation_time = time.perf_counter() - start_time
        except Exception as e:  # surface any pipeline failure in the UI
            st.error(f"Error during generation: {e}")
        else:
            st.image(image, caption=f"Generated in {generation_time:.2f} seconds")
            st.success("Image generated successfully!")
            # Option to download
            st.download_button(
                label="Download Image",
                data=image_to_bytes(image),
                file_name="generated_image.png",
                mime="image/png",
            )