File size: 3,656 Bytes
fe01e72
dd0f699
fe01e72
 
 
 
a9c4647
 
 
fe01e72
137a5e1
fe01e72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
052501f
fe01e72
 
 
 
 
 
 
a9c4647
 
 
 
 
 
 
 
 
 
 
fe01e72
 
 
 
 
a9c4647
 
fe01e72
 
a9c4647
 
 
 
 
 
 
fe01e72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a39d55
052501f
fe01e72
 
 
 
 
 
052501f
 
fe01e72
 
 
 
 
a9c4647
 
 
 
fe01e72
a9c4647
fe01e72
5a39d55
 
 
fe01e72
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
import spaces
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
import imageio
import tempfile
import os

# Hugging Face model repo loaded by initialize_pipeline().
MODEL_ID = "stabilityai/stable-diffusion-2"

# Global pipeline singleton; stays None until initialize_pipeline() lazily
# loads the model on first use (avoids paying the download/load cost at import).
pipe = None


def initialize_pipeline():
    """Lazily build and cache the module-level StableDiffusionPipeline.

    On the first call, loads MODEL_ID onto CUDA with fp16 weights when a
    GPU is available, otherwise onto CPU with fp32. Later calls return
    the already-constructed pipeline unchanged.
    """
    global pipe
    if pipe is not None:
        return pipe

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    print(f"Initializing pipeline on device: {device}")

    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
    ).to(device)
    return pipe


@spaces.GPU
def generate_image(prompt, seed, num_inference_steps):
    """Generate an image from *prompt* and a video of the denoising steps.

    Args:
        prompt: Text description of the image to generate.
        seed: Integer seed used to make generation reproducible.
        num_inference_steps: Number of denoising steps to run.

    Returns:
        tuple: (final PIL image, path to an mp4 of the intermediate
        frames, or None if no frames were captured).
    """
    # Initialize (or fetch the cached) pipeline and match its device.
    pipeline = initialize_pipeline()
    device = pipeline.device

    # Seed a generator on the pipeline's device for reproducibility.
    generator = torch.Generator(device=device).manual_seed(int(seed))

    # Intermediate frames, one per denoising step.
    frames = []

    def callback(step: int, timestep: int, latents):
        # Decode the current latents to a PIL image, then store it as a
        # numpy array — imageio expects array-like frames.
        with torch.no_grad():
            image = pipeline.decode_latents(latents)
            image = pipeline.numpy_to_pil(image)[0]
            frames.append(np.asarray(image))

    # NOTE(review): `callback`/`callback_steps` are deprecated in newer
    # diffusers releases in favor of `callback_on_step_end`; kept as-is for
    # compatibility with the installed version — confirm before upgrading.
    with torch.no_grad():
        result = pipeline(
            prompt=prompt,
            num_inference_steps=int(num_inference_steps),
            generator=generator,
            callback=callback,
            callback_steps=1,
        )

    # Write the captured frames to a temporary mp4. delete=False so the file
    # outlives this function for Gradio to serve. Guard the empty case:
    # imageio.mimsave raises on a zero-length frame list.
    video_path = None
    if frames:
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            video_path = tmpfile.name
        imageio.mimsave(video_path, frames, fps=5)

    # Final image plus the step-by-step video path (None if no frames).
    return result.images[0], video_path


def create_interface():
    """Assemble and return the Gradio Interface wired to generate_image."""
    # Input widgets: free-text prompt, reproducibility seed, and step count.
    prompt_box = gr.Textbox(
        label="Prompt",
        placeholder="Enter a text description of the image you want to generate...",
        lines=3,
    )
    seed_slider = gr.Slider(
        minimum=0,
        maximum=1000000,
        randomize=True,
        step=1,
        label="Seed",
        info="Random seed for reproducibility",
    )
    steps_slider = gr.Slider(
        minimum=1,
        maximum=50,
        value=15,
        step=1,
        label="Diffusion Steps",
        info="Number of denoising steps (more steps = higher quality but slower)",
    )

    # Output widgets: the final image and the per-step diffusion video.
    image_output = gr.Image(label="Generated Image", type="pil")
    video_output = gr.Video(label="Diffusion Steps Video")

    return gr.Interface(
        fn=generate_image,
        inputs=[prompt_box, seed_slider, steps_slider],
        outputs=[image_output, video_output],
        title="Stable Diffusion Image Generator",
        description="Generate images from text using Stable Diffusion. Enter a prompt, set the seed for reproducibility, and adjust the number of diffusion steps. Watch the diffusion process as a video.",
        examples=[
            ["A beautiful sunset over mountains", 42213, 50],
            ["A dog wearing a space suit, floating in space, hand-drawn illustration", 83289, 20],
            ["Cyberpunk city at night, neon lights", 12056, 40],
        ],
        cache_examples=False,
    )


if __name__ == "__main__":
    # Script entry point: build the UI and serve it on all interfaces.
    create_interface().launch(share=False, server_name="0.0.0.0", server_port=7860)