File size: 1,823 Bytes
d5b07c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# Import necessary libraries
from diffusers import StableDiffusionPipeline  # Import StableDiffusionPipeline for image generation
import torch  # Import PyTorch for deep learning operations
import gradio as gr  # Import Gradio for creating a web interface

class CFG:
    """Static settings shared by model loading and image generation."""

    # Checkpoint identifier handed to StableDiffusionPipeline.from_pretrained.
    image_gen_model_id: str = "stabilityai/stable-diffusion-2"
    # num_inference_steps used for every pipeline call.
    image_gen_steps: int = 35
    # guidance_scale used for every pipeline call.
    image_gen_guidance_scale: float = 9
    # Final size each generated image is resized to before being returned.
    image_gen_size: tuple = (400, 400)

# Load the StableDiffusion model.
# guidance_scale is a sampling option of the pipeline *call* (generate_image
# passes CFG.image_gen_guidance_scale there); it is not a constructor argument
# of StableDiffusionPipeline, so it is not passed to from_pretrained.
# The pipeline is placed on the GPU when one is available. Half precision is
# only requested on CUDA, since float16 inference is generally unsupported or
# very slow on CPU.
image_gen_device = "cuda" if torch.cuda.is_available() else "cpu"
image_gen_model = StableDiffusionPipeline.from_pretrained(
    CFG.image_gen_model_id,
    revision="fp16",
    torch_dtype=torch.float16 if image_gen_device == "cuda" else torch.float32,
).to(image_gen_device)

# Text-to-image callback used by the Gradio interface.
def generate_image(prompt):
    """Generate one image for ``prompt`` and return it resized.

    The pipeline is sampled with ``CFG.image_gen_steps`` inference steps and
    ``CFG.image_gen_guidance_scale`` guidance. Only the first image of the
    pipeline output is kept, and it is resized to ``CFG.image_gen_size``.
    """
    pipeline_output = image_gen_model(
        prompt,
        num_inference_steps=CFG.image_gen_steps,
        guidance_scale=CFG.image_gen_guidance_scale,
    )
    first_image = pipeline_output.images[0]
    return first_image.resize(CFG.image_gen_size)

# Gradio web UI: a single text box in, a single generated image out.
iface = gr.Interface(
    title="StableDiffusion Image Generation",
    description="Generate images from text prompts using StableDiffusion model.",
    fn=generate_image,  # invoked with the submitted prompt text
    inputs="text",
    outputs="image",
    live=False,  # no real-time updates; generation runs on explicit submit
)

# Start the Gradio interface only when this file is executed as a script, so
# that importing the module does not start the web server as a side effect.
# debug=True is kept from the original launch call.
if __name__ == "__main__":
    iface.launch(debug=True)