Spaces:
Runtime error
Runtime error
| # Install Gradio if not already installed | |
| import gradio as gr | |
| from diffusers import DiffusionPipeline | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import io | |
| from transformers.utils import move_cache | |
# NOTE(review): transformers' cache-migration helper, called on every start;
# presumably only needed once after a transformers upgrade — confirm it is
# still required with the pinned transformers version.
move_cache()
from PIL import Image

# Use the GPU when one is present. Half precision is only used there:
# float16 support on CPU is incomplete/slow in many PyTorch builds, so the
# CPU path loads the weights in float32 instead.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Load the SDXL refiner model pipeline. `variant="fp16"` only selects which
# weight files are downloaded; `torch_dtype` decides the precision in memory.
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    torch_dtype=DTYPE,
    use_safetensors=True,
    variant="fp16",
)
refiner.to(DEVICE)
import functools
import tempfile

# Hub id of the SDXL base model that produces the latents the refiner finishes.
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"


@functools.lru_cache(maxsize=1)
def _get_base_pipeline():
    """Load the SDXL base pipeline once, on the refiner's device and dtype.

    The refiner checkpoint is an image-to-image pipeline: it needs starting
    latents (or an image) and cannot create a picture from a prompt alone.
    The base model supplies those latents. Loading is deferred to the first
    request so app start-up cost is unchanged. The second text encoder and
    the VAE are taken from the already-loaded refiner rather than loaded a
    second time.
    """
    base = DiffusionPipeline.from_pretrained(
        BASE_MODEL_ID,
        text_encoder_2=refiner.text_encoder_2,
        vae=refiner.vae,
        torch_dtype=refiner.dtype,
        use_safetensors=True,
        variant="fp16",
    )
    return base.to(refiner.device)


def generate_image(prompt, n_steps=20, high_noise_frac=0.8):
    """Generate an image for *prompt* with the SDXL base + refiner ensemble.

    Args:
        prompt: Text description of the desired image.
        n_steps: Total number of denoising steps in the schedule shared by
            the base and refiner stages.
        high_noise_frac: Fraction of the schedule run by the base model (the
            high-noise part); the refiner runs the remaining
            ``1 - high_noise_frac``. At 1.0 or above the refiner is skipped
            and the base model's own output is returned.

    Returns:
        A ``(image, path)`` tuple: the PIL image for display, and the path
        of a PNG copy on disk for the download component (``gr.File``
        expects a file path, not an in-memory buffer).
    """
    # Gradio sliders may deliver floats (e.g. 20.0); the pipelines expect an
    # integer step count.
    n_steps = int(n_steps)
    high_noise_frac = float(high_noise_frac)

    base = _get_base_pipeline()
    if high_noise_frac >= 1.0:
        # No steps would be left for the refiner; let the base finish and decode.
        image = base(
            prompt=prompt,
            num_inference_steps=n_steps,
            output_type="pil",
        ).images[0]
    else:
        # Stage 1: the base model denoises the first part of the schedule and
        # hands over un-decoded latents.
        latents = base(
            prompt=prompt,
            num_inference_steps=n_steps,
            denoising_end=high_noise_frac,
            output_type="latent",
        ).images
        # Stage 2: the refiner resumes from the same point in the schedule.
        image = refiner(
            prompt=prompt,
            num_inference_steps=n_steps,
            denoising_start=high_noise_frac,
            image=latents,
            output_type="pil",
        ).images[0]

    # Write the PNG to a real file so the download component can serve it.
    # delete=False keeps the file after closing so Gradio can still read it.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
    return image, tmp.name
# Gradio UI: a prompt plus generation settings, an explicit trigger button,
# and two outputs (the picture itself and a downloadable copy).
with gr.Blocks() as app:
    gr.Markdown("## AI Image Generator with Refinement")

    # Inputs, rendered top to bottom in this order.
    prompt = gr.Textbox(
        label="Enter a prompt",
        placeholder="e.g., A dragon flying",
        lines=2,
    )
    n_steps = gr.Slider(
        minimum=10,
        maximum=50,
        step=1,
        value=20,
        label="Inference Steps",
    )
    high_noise_frac = gr.Slider(
        minimum=0.1,
        maximum=1.0,
        step=0.1,
        value=0.8,
        label="High Noise Fraction",
    )
    generate_button = gr.Button("Generate Image")

    # Outputs: the rendered picture and a file the user can download.
    image_output = gr.Image(label="Generated Image")
    download_output = gr.File(label="Download Image")

    # Generation runs only when the button is clicked; argument order of
    # the input list must match generate_image's parameters.
    _event_inputs = [prompt, n_steps, high_noise_frac]
    _event_outputs = [image_output, download_output]
    generate_button.click(generate_image, _event_inputs, _event_outputs)

# Start the web server.
app.launch()