File size: 9,113 Bytes
39dc0dd
1ba3b32
 
 
 
 
 
 
39dc0dd
1ba3b32
 
 
 
 
 
 
 
 
 
 
 
39dc0dd
1ba3b32
 
39dc0dd
1ba3b32
 
 
 
 
 
39dc0dd
1ba3b32
 
 
 
 
 
 
 
 
 
 
 
 
39dc0dd
1ba3b32
 
 
39dc0dd
1ba3b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39dc0dd
1ba3b32
 
 
 
 
 
 
 
 
 
39dc0dd
1ba3b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39dc0dd
1ba3b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39dc0dd
1ba3b32
 
 
 
 
39dc0dd
1ba3b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39dc0dd
1ba3b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39dc0dd
1ba3b32
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import gradio as gr
import torch
from diffusers import DiffusionPipeline
import numpy as np
from PIL import Image
import random
from typing import List, Tuple, Optional
import time

# Initialize the model
def load_model():
    """Load the Stable Diffusion XL base pipeline.

    The fp16 weight variant is downloaded in all cases. When CUDA is
    available the pipeline runs in float16 and is moved to the GPU; otherwise
    it is kept in float32 on CPU, because half-precision kernels are commonly
    slow or unsupported there.

    Returns:
        The loaded ``DiffusionPipeline``.
    """
    use_cuda = torch.cuda.is_available()
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        # float16 halves GPU memory; upcast on CPU where fp16 ops may be missing.
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        use_safetensors=True,
        variant="fp16",
    )
    if use_cuda:
        pipe = pipe.to("cuda")
    return pipe

# Process-wide pipeline cache; stays None until the first get_model() call.
model = None

def get_model():
    """Return the shared pipeline, loading it on first use and reusing it afterwards."""
    global model
    if model is not None:
        return model
    model = load_model()
    return model

def _to_rgb_image(item) -> Image.Image:
    """Open one uploaded file as an RGB PIL image.

    Accepts a PIL image, a filesystem path, or a tempfile-like wrapper that
    exposes ``.name``. Which of these ``gr.File`` yields depends on the
    installed Gradio version (confirm against the pinned release).
    """
    if isinstance(item, Image.Image):
        return item.convert("RGB")
    path = item if isinstance(item, str) else getattr(item, "name", item)
    # convert() fully decodes the pixels into a new image, so the file handle
    # can be closed here; RGB also drops any alpha channel (e.g. RGBA PNGs).
    with Image.open(path) as img:
        return img.convert("RGB")


def generate_images(
    prompt: str,
    negative_prompt: str = "",
    uploaded_images: Optional[List] = None,
    guidance_scale: float = 7.5,
    num_inference_steps: int = 50,
    strength: float = 0.8
) -> Tuple[List[Image.Image], str]:
    """
    Generate 4 images using Stable Diffusion XL.

    Without an upload this runs text-to-image. When at least one file is
    uploaded, only the first one is used as the init image for
    image-to-image.

    Args:
        prompt: Positive prompt; a default landscape prompt is substituted
            when it is empty, whitespace-only, or None.
        negative_prompt: Concepts to steer the generation away from.
        uploaded_images: Files from the upload widget (may be None or empty).
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps (coerced to int).
        strength: Img2img strength; unused for text-to-image.

    Returns:
        ``(images, status_message)``. Errors are not raised: on failure the
        image list is empty and the message carries the error text.
    """
    num_images = 4
    try:
        pipe = get_model()

        # Default prompt if empty; a cleared textbox may also arrive as None.
        if not prompt or not prompt.strip():
            prompt = "a beautiful landscape, professional photography, high quality"

        # Slider values can arrive as floats; the scheduler needs an int step count.
        steps = int(num_inference_steps)

        if uploaded_images:
            init_image = _to_rgb_image(uploaded_images[0])

            # The SDXL base pipeline from get_model() is text-to-image and does
            # not take `image`/`strength` as inputs. from_pipe builds an
            # img2img pipeline that reuses the already-loaded components.
            from diffusers import AutoPipelineForImage2Image
            img2img = AutoPipelineForImage2Image.from_pipe(pipe)

            # One image per call, as before, keeps each batch at a single image.
            images = [
                img2img(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=init_image,
                    num_images_per_prompt=1,
                    guidance_scale=guidance_scale,
                    num_inference_steps=steps,
                    strength=strength,
                ).images[0]
                for _ in range(num_images)
            ]
        else:
            result = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_images_per_prompt=num_images,
                guidance_scale=guidance_scale,
                num_inference_steps=steps,
            )
            images = result.images

        return images, "✅ Images generated successfully!"

    except Exception as e:
        # Report the failure to the UI instead of raising, but keep the
        # traceback in the server log so it can be diagnosed.
        import traceback
        traceback.print_exc()
        error_msg = f"❌ Error generating images: {str(e)}"
        return [], error_msg

def _gallery_item_to_image(item) -> Optional[Image.Image]:
    """Turn one gallery entry into a PIL image, or return None if it is unrecognized.

    The entry shape depends on the Gradio version (confirm against the pinned
    release). Recent versions send ``(image_or_path, caption)`` pairs and
    older ones send dicts with a ``"name"`` path. Plain paths, PIL images and
    numpy arrays are also accepted.
    """
    if isinstance(item, (tuple, list)):
        item = item[0] if item else None
    if isinstance(item, dict):
        item = item.get("name") or item.get("path")
    if isinstance(item, Image.Image):
        return item
    if isinstance(item, np.ndarray):
        return Image.fromarray(item)
    if isinstance(item, str):
        # copy() forces a full decode, so the file handle can be closed here.
        with Image.open(item) as img:
            return img.copy()
    return None


def select_image_for_detail(gallery_data: List, evt: gr.SelectData) -> Tuple[Optional[Image.Image], str]:
    """
    Handle image selection from gallery for detailed view.

    Args:
        gallery_data: Current value of the gallery component.
        evt: Gradio selection event; ``evt.index`` is the clicked position.

    Returns:
        ``(image, message)``. Returns ``(None, "No image selected")`` when the
        gallery is empty, the index is out of range, or the entry cannot be
        converted to an image.
    """
    index = evt.index
    if not gallery_data or not isinstance(index, int) or not 0 <= index < len(gallery_data):
        return None, "No image selected"
    image = _gallery_item_to_image(gallery_data[index])
    if image is None:
        return None, "No image selected"
    return image, f"📸 Selected image {evt.index + 1} for detailed view"

# Custom CSS passed to gr.Blocks(css=...) below. Selectors target classes that
# components opt into via `elem_classes` ("gallery-container",
# "selected-image-container", "generate-btn").
# NOTE(review): ".main-container" is not attached to any component in this
# file; confirm whether it is still needed.
custom_css = """
.main-container {
    max-width: 1200px;
    margin: 0 auto;
}
.gallery-container {
    border: 2px solid #e0e0e0;
    border-radius: 8px;
    padding: 10px;
}
.selected-image-container {
    border: 3px solid #4CAF50;
    border-radius: 8px;
    padding: 10px;
}
.generate-btn {
    background: linear-gradient(45deg, #667eea 0%, #764ba2 100%) !important;
}
"""

# Create the Gradio interface. Components are created inside nested layout
# contexts (Row/Column/Group), so the statement order below defines the
# on-page layout.
with gr.Blocks(css=custom_css, title="Stable Diffusion XL Demo") as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 20px;">
        <h1>🎨 Stable Diffusion XL Image Generator</h1>
        <p>Generate amazing images with AI - Upload up to 4 images for img2img or use text prompts</p>
        <p style="font-size: 0.9em; color: #666;">
            Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank" style="color: #667eea;">anycoder</a>
        </p>
    </div>
    """)
    
    # Two-column layout: controls on the left (scale 2), outputs on the right (scale 3).
    with gr.Row():
        with gr.Column(scale=2):
            # Input controls
            with gr.Group():
                gr.Markdown("### 📝 Text Prompts")
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=3,
                    value="a beautiful landscape, professional photography, high quality"
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="What you don't want in the image...",
                    lines=2,
                    value="blurry, low quality, distorted, deformed"
                )
            
            with gr.Group():
                gr.Markdown("### 🖼️ Image Upload (Optional)")
                # NOTE(review): the label says "up to 4", but nothing here limits
                # the file count, and generate_images only uses the first file.
                uploaded_images = gr.File(
                    label="Upload images for img2img (up to 4)",
                    file_count="multiple",
                    file_types=["image"],
                    height=120
                )
            
            with gr.Group():
                gr.Markdown("### ⚙️ Generation Settings")
                with gr.Row():
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=20.0,
                        value=7.5,
                        step=0.5,
                        info="Higher = more prompt adherence"
                    )
                    num_inference_steps = gr.Slider(
                        label="Steps",
                        minimum=10,
                        maximum=100,
                        value=50,
                        step=5,
                        info="More steps = higher quality"
                    )
                
                # Only used when an image is uploaded (img2img path).
                strength = gr.Slider(
                    label="Img2Img Strength",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.8,
                    step=0.1,
                    info="How much to transform the input image"
                )
            
            generate_btn = gr.Button(
                "🚀 Generate Images",
                variant="primary",
                size="lg",
                elem_classes=["generate-btn"]
            )
            
            # Read-only box that receives the status string from generate_images.
            status_msg = gr.Textbox(label="Status", interactive=False)
        
        with gr.Column(scale=3):
            # Output area
            with gr.Group(elem_classes=["gallery-container"]):
                gr.Markdown("### 🎯 Generated Images (Click to select)")
                # 2x2 grid sized for the 4 images generate_images returns.
                gallery = gr.Gallery(
                    label="Generated Images",
                    columns=2,
                    rows=2,
                    height=400,
                    object_fit="cover",
                    allow_preview=True,
                    show_label=False
                )
            
            with gr.Group(elem_classes=["selected-image-container"]):
                gr.Markdown("### 🔍 Selected Image (Close-up View)")
                selected_image = gr.Image(
                    label="Selected Image",
                    height=400,
                    show_label=False
                )
                selection_info = gr.Textbox(label="Selection Info", interactive=False)
    
    # Event handlers
    # The inputs list is passed positionally, so its order must match the
    # parameter order of generate_images.
    # NOTE(review): newer Gradio releases document show_progress as
    # "full"/"minimal"/"hidden"; confirm the pinned version accepts True.
    generate_btn.click(
        fn=generate_images,
        inputs=[
            prompt,
            negative_prompt,
            uploaded_images,
            guidance_scale,
            num_inference_steps,
            strength
        ],
        outputs=[gallery, status_msg],
        show_progress=True
    )
    
    # Clicking a thumbnail sends the current gallery value plus the selection
    # event (carrying the clicked index) to select_image_for_detail.
    gallery.select(
        fn=select_image_for_detail,
        inputs=[gallery],
        outputs=[selected_image, selection_info]
    )
    
    # Examples
    # Each row fills (prompt, negative_prompt, uploaded_images). No fn is
    # attached and caching is off, so examples only populate the inputs.
    gr.Examples(
        examples=[
            ["a majestic dragon flying over mountains, fantasy art, highly detailed", "cartoon, blurry", None],
            ["a futuristic city skyline at sunset, cyberpunk, neon lights", "daylight, rural", None],
            ["a portrait of a wise old wizard with a long beard, fantasy", "young, modern", None],
            ["a serene japanese garden with cherry blossoms, peaceful", "chaotic, urban", None]
        ],
        inputs=[prompt, negative_prompt, uploaded_images],
        cache_examples=False
    )
    
    # Footer
    gr.HTML("""
    <div style="text-align: center; margin-top: 30px; padding-top: 20px; border-top: 1px solid #e0e0e0;">
        <p style="color: #666; font-size: 0.9em;">
            Powered by <a href="https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0" target="_blank">Stability AI</a> • 
            Model: SDXL Base 1.0 • 
            <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">Built with anycoder</a>
        </p>
    </div>
    """)

if __name__ == "__main__":
    import os

    # A public share link stays on by default (the original behavior).
    # Set GRADIO_SHARE to 0/false/no to serve locally only.
    share = os.environ.get("GRADIO_SHARE", "1").strip().lower() not in {"0", "false", "no"}
    # queue() routes requests through Gradio's queue. Long SDXL generations are
    # then handled as queued jobs rather than one long-lived HTTP request, and
    # Gradio's default concurrency limit runs them one at a time on the GPU.
    demo.queue().launch(share=share)