File size: 1,854 Bytes
e559386
 
 
 
 
 
 
 
83207f7
 
70ec849
e559386
 
 
 
 
 
70ec849
e559386
7b9afe1
e559386
 
 
70ec849
 
e559386
 
70ec849
 
e559386
70ec849
e559386
70ec849
e559386
70ec849
 
e559386
70ec849
 
 
 
 
 
 
 
 
 
 
 
 
 
e559386
70ec849
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Requires the packages imported below (gradio, diffusers, torch, transformers, Pillow, ...) to be installed

import gradio as gr
from diffusers import DiffusionPipeline
import torch
import matplotlib.pyplot as plt
import numpy as np
import io
from transformers.utils import move_cache
move_cache()
from PIL import Image

# Load the SDXL refiner model pipeline.
# Run on the GPU when one is available; otherwise fall back to the CPU.
# Half precision is only used on the GPU: float16 inference on CPU is not
# reliably supported by PyTorch kernels, so the CPU path runs in float32.
_refiner_device = "cuda" if torch.cuda.is_available() else "cpu"
_refiner_dtype = torch.float16 if _refiner_device == "cuda" else torch.float32
refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    torch_dtype=_refiner_dtype,  # dtype the pipeline's modules are loaded in
    use_safetensors=True,
    variant="fp16",  # still fetch the fp16 weight files regardless of run dtype
)
refiner.to(_refiner_device)

# Function to generate the image
def generate_image(prompt, n_steps=20, high_noise_frac=0.8):
    """Generate an image for ``prompt`` and return it with a downloadable copy.

    Args:
        prompt: Text prompt forwarded to the module-level ``refiner`` pipeline.
        n_steps: Forwarded to the pipeline as ``num_inference_steps``.
        high_noise_frac: Accepted so the UI slider stays wired to this
            function; it is not currently forwarded to the pipeline.

    Returns:
        A ``(image, path)`` tuple: the first image returned by the pipeline
        and the filesystem path of a PNG copy of it, suitable for a file
        download component.
    """
    import tempfile  # local import: only this function needs it

    # NOTE(review): the SDXL refiner checkpoint is normally an image-to-image
    # pipeline that expects an ``image`` input as well as a prompt -- confirm
    # that a prompt-only call works with the installed diffusers version.
    refined_image = refiner(
        prompt=prompt,
        num_inference_steps=n_steps,
        output_type="pil",
    ).images[0]

    # A file output component needs a path on disk, not an in-memory buffer,
    # so persist the PNG to a named temporary file. delete=False keeps the
    # file around after the handle closes so it can still be served.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        refined_image.save(tmp, format="PNG")
        download_path = tmp.name

    return refined_image, download_path

# Build the Gradio interface: prompt and sliders, a submit button, then outputs
with gr.Blocks() as app:
    gr.Markdown("## AI Image Generator with Refinement")

    # Input controls (declared in the order they should appear on the page)
    prompt = gr.Textbox(
        label="Enter a prompt",
        placeholder="e.g., A dragon flying",
        lines=2,
    )
    n_steps = gr.Slider(
        label="Inference Steps",
        minimum=10,
        maximum=50,
        step=1,
        value=20,
    )
    high_noise_frac = gr.Slider(
        label="High Noise Fraction",
        minimum=0.1,
        maximum=1.0,
        step=0.1,
        value=0.8,
    )
    generate_button = gr.Button("Generate Image")

    # Result widgets: the rendered picture and a file slot for downloading it
    image_output = gr.Image(label="Generated Image")
    download_output = gr.File(label="Download Image")

    # Generation only runs when the button is pressed
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, n_steps, high_noise_frac],
        outputs=[image_output, download_output],
    )

# Start serving the app
app.launch()