File size: 21,882 Bytes
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
57a52e9
 
 
 
 
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570384a
6a07ce1
 
 
 
60d66bd
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570384a
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570384a
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
570384a
 
 
 
 
 
 
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570384a
6a07ce1
 
 
 
 
 
 
 
570384a
6a07ce1
8cdb001
6a07ce1
 
 
570384a
6a07ce1
 
 
 
 
570384a
6a07ce1
 
 
 
 
570384a
6a07ce1
 
 
 
b1e7bdb
60d66bd
 
 
570384a
60d66bd
 
 
 
 
 
 
 
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60d66bd
 
 
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570384a
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570384a
6a07ce1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d723e62
6a07ce1
 
d723e62
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
#!/usr/bin/env python3
"""
SDXL Model Merger - Modernized with modular architecture and improved UI/UX.

This application allows you to:
- Load SDXL checkpoints with optional VAE and multiple LoRAs
- Generate images with seamless tiling support
- Export merged models with quantization options

Author: Qwen Code Assistant
"""

# `spaces` is optional (presumably the Hugging Face Spaces package); when it is
# not installed the ImportError is silently ignored and the app runs without it.
try:
    import spaces  # noqa: F401 — must be imported before torch/CUDA packages
except ImportError:
    pass

import gradio as gr


# Placeholder entry shown first in every "cached files" dropdown.
_NONE_FOUND = "(None found)"


def _cached_dropdown_kwargs(paths):
    """Build the ``choices``/``value`` kwargs for a cached-file dropdown.

    Args:
        paths: Iterable of cached file paths; ``None`` is treated as empty.

    Returns:
        A dict with ``choices`` (the placeholder first, then the paths) and
        ``value`` (the placeholder when nothing is cached, otherwise ``None``).
    """
    # Materialize once so the cache directory is scanned a single time.
    paths = list(paths or [])
    return {
        "choices": [_NONE_FOUND] + paths,
        "value": None if paths else _NONE_FOUND,
    }


def _is_cached_selection(path):
    """Return True if *path* is a real dropdown selection (not empty/placeholder)."""
    return bool(path) and path != _NONE_FOUND


def _error_html(message, fallback):
    """Wrap a plain-text error *message* in an HTML-escaped status-error div.

    Args:
        message: Plain-text error message (may be ``None`` or empty).
        fallback: Text shown when *message* is empty or ``None``.

    Returns:
        An HTML snippet suitable for a ``gr.HTML`` status component.
    """
    import html

    text = str(message) if message else fallback
    return f'<div class="status-error">{html.escape(text)}</div>'


def create_app():
    """Create and configure the Gradio app.

    Builds three tabs (Load Pipeline, Generate Image, Export Model) and wires
    their buttons to ``src.pipeline.load_pipeline``,
    ``src.generator.generate_image`` and ``src.exporter.export_merged_model``.

    Returns:
        The configured ``gr.Blocks`` instance (not yet launched).
    """
    import inspect

    header_css = """
    .header-gradient {
        background: linear-gradient(135deg, #10b981 0%, #7c3aed 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
    }

    .feature-card {
        border-radius: 12px;
        padding: 20px;
        margin-bottom: 16px;
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
        transition: transform 0.2s ease;
    }

    .feature-card:hover {
        transform: translateY(-2px);
        box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1);
    }

    .gradio-container .label {
        font-weight: 600;
        color: #374151;
        margin-bottom: 8px;
    }

    .status-success { color: #059669 !important; font-weight: 600; }
    .status-error { color: #dc2626 !important; font-weight: 600; }
    .status-warning { color: #d97706 !important; font-weight: 600; }

    .gradio-container .btn {
        border-radius: 8px;
        padding: 12px 24px;
        font-weight: 600;
    }

    .gradio-container textarea,
    .gradio-container input[type="number"],
    .gradio-container input[type="text"] {
        border-radius: 8px;
        border-color: #d1d5db;
    }

    .gradio-container textarea:focus,
    .gradio-container input:focus {
        outline: none;
        border-color: #6366f1;
        box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1);
    }

    .gradio-container .tabitem {
        background: transparent;
        border-radius: 12px;
    }

    .progress-text {
        font-weight: 500;
        color: #6b7280 !important;
    }
    """

    from src.pipeline import load_pipeline
    from src.generator import generate_image
    from src.exporter import export_merged_model
    from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras

    blocks_kwargs = {"title": "SDXL Model Merger"}
    # header_css defines the status-*/feature-card/header-gradient classes used
    # below; without it those classes are unstyled. Only pass it when this
    # Gradio version's Blocks constructor still accepts `css` (newer releases
    # may expect it on launch() instead).
    if "css" in inspect.signature(gr.Blocks.__init__).parameters:
        blocks_kwargs["css"] = header_css

    with gr.Blocks(**blocks_kwargs) as demo:
        # Header section
        with gr.Column(elem_classes=["feature-card"]):
            gr.HTML("""
                <div style="text-align: center; margin-bottom: 24px;">
                    <h1 style="font-size: 2.5em; margin: 0; line-height: 1.2;">
                        <span class="header-gradient">SDXL Model Merger</span>
                    </h1>
                    <p style="color: #6b7280; font-size: 1.1em; max-width: 600px; margin: 16px auto;">
                        Merge checkpoints, LoRAs, and VAEs - then bake LoRAs into a single exportable
                        checkpoint with optional quantization.
                    </p>
                </div>
            """)

            # Feature highlights
            with gr.Row():
                with gr.Column(scale=1):
                    gr.HTML("""
                        <div style="text-align: center; padding: 16px;">
                            <div style="font-size: 2.5em; margin-bottom: 8px;">🚀</div>
                            <strong>Fast Loading</strong>
                            <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">With progress tracking & cache</p>
                        </div>
                    """)
                with gr.Column(scale=1):
                    gr.HTML("""
                        <div style="text-align: center; padding: 16px;">
                            <div style="font-size: 2.5em; margin-bottom: 8px;">🎨</div>
                            <strong>Panorama Gen</strong>
                            <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Seamless tiling support</p>
                        </div>
                    """)
                with gr.Column(scale=1):
                    gr.HTML("""
                        <div style="text-align: center; padding: 16px;">
                            <div style="font-size: 2.5em; margin-bottom: 8px;">📦</div>
                            <strong>Export Ready</strong>
                            <p style="font-size: 0.85em; color: #6b7280; margin-top: 4px;">Quantization & format options</p>
                        </div>
                    """)

        gr.Markdown("---")

        with gr.Tab("Load Pipeline"):
            gr.Markdown("### Load SDXL Pipeline with Checkpoint, VAE, and LoRAs")

            # Progress indicator for pipeline loading
            load_progress = gr.Textbox(
                label="Loading Progress",
                placeholder="Ready to start...",
                show_label=True,
                info="Real-time status of model downloads and pipeline setup"
            )

            with gr.Row():
                with gr.Column(scale=2):
                    # Checkpoint URL with cached models dropdown
                    checkpoint_url = gr.Textbox(
                        label="Base Model (.safetensors) URL",
                        value="https://civitai.com/api/download/models/354657?type=Model&format=SafeTensor&size=full&fp=fp16",
                        placeholder="e.g., https://civitai.com/api/download/models/...",
                        info="Download link for the base SDXL checkpoint"
                    )

                    # Dropdown of cached checkpoints (cache scanned once)
                    cached_checkpoints = gr.Dropdown(
                        label="Cached Checkpoints",
                        info="Models already downloaded to .cache/",
                        **_cached_dropdown_kwargs(get_cached_checkpoints()),
                    )

                    # VAE URL
                    vae_url = gr.Textbox(
                        label="VAE (.safetensors) URL",
                        value="https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/resolve/main/sdxl.vae.safetensors?download=true",
                        placeholder="Leave blank to use model's built-in VAE",
                        info="Optional custom VAE for improved quality"
                    )

                    # Dropdown of cached VAEs (cache scanned once)
                    cached_vaes = gr.Dropdown(
                        label="Cached VAEs",
                        info="Select a VAE to load",
                        **_cached_dropdown_kwargs(get_cached_vaes()),
                    )

                with gr.Column(scale=1):
                    # LoRA URLs input
                    lora_urls = gr.Textbox(
                        label="LoRA URLs (one per line)",
                        lines=5,
                        value="https://civitai.com/api/download/models/143197?type=Model&format=SafeTensor",
                        placeholder="https://civit.ai/...\nhttps://huggingface.co/...",
                        info="Multiple LoRAs can be loaded and fused together"
                    )

                    # Dropdown of cached LoRAs (cache scanned once)
                    cached_loras = gr.Dropdown(
                        label="Cached LoRAs",
                        info="Select a LoRA to add to the list below",
                        **_cached_dropdown_kwargs(get_cached_loras()),
                    )

                    lora_strengths = gr.Textbox(
                        label="LoRA Strengths",
                        value="1.0",
                        placeholder="e.g., 0.8,1.0,0.5",
                        info="Comma-separated strength values for each LoRA"
                    )

            with gr.Row():
                load_btn = gr.Button("🚀 Load Pipeline", variant="primary", size="lg")

            # Detailed status display
            load_status = gr.HTML(
                label="Status",
                value='<div class="status-success">✅ Ready to load pipeline</div>',
            )

        with gr.Tab("Generate Image"):
            gr.Markdown("### Generate Panorama Images with Seamless Tiling")

            # Progress indicator for image generation
            gen_progress = gr.Textbox(
                label="Generation Progress",
                placeholder="Ready to generate...",
                show_label=True,
                info="Real-time status of image generation"
            )

            with gr.Row():
                with gr.Column(scale=1):
                    prompt = gr.Textbox(
                        label="Positive Prompt",
                        value="Glowing mushrooms around pyramids amidst a cosmic backdrop, equirectangular, 360 panorama, cinematic",
                        lines=4,
                        placeholder="Describe the image you want to generate..."
                    )

                    cfg = gr.Slider(
                        minimum=1.0, maximum=20.0, value=3.0, step=0.5,
                        label="CFG Scale",
                        info="Higher values make outputs match prompt more strictly"
                    )

                    height = gr.Number(
                        value=1024, precision=0,
                        label="Height (pixels)",
                        info="Output image height"
                    )

                with gr.Column(scale=1):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value="boring, text, signature, watermark, low quality, bad quality",
                        lines=4,
                        placeholder="Elements to avoid in generation..."
                    )

                    steps = gr.Slider(
                        minimum=1, maximum=100, value=8, step=1,
                        label="Inference Steps",
                        info="More steps = better quality but slower"
                    )

                    width = gr.Number(
                        value=2048, precision=0,
                        label="Width (pixels)",
                        info="Output image width"
                    )

            with gr.Row():
                tile_x = gr.Checkbox(True, label="X-axis Seamless Tiling")
                tile_y = gr.Checkbox(False, label="Y-axis Seamless Tiling")

            seed = gr.Number(
                value=80484030936239,
                precision=0,
                label="Seed",
                info="Random seed for reproducible generation"
            )

            with gr.Row():
                gen_btn = gr.Button("✨ Generate Image", variant="secondary", size="lg")

            with gr.Row():
                image_output = gr.Image(
                    label="Result",
                    height=400,
                    show_label=True
                )
                with gr.Column():
                    gen_status = gr.HTML(
                        label="Generation Status",
                        value='<div class="status-success">✅ Ready to generate</div>',
                    )

                    gr.HTML("""
                        <div style="margin-top: 16px; padding: 12px; background-color: #e5e7eb !important; border-radius: 8px;">
                            <strong style="color: #1f2937 !important;">💡 Tips:</strong>
                            <ul style="margin: 8px 0; padding-left: 20px; font-size: 0.9em; color: #1f2937 !important;">
                                <li>Use wide aspect ratios (e.g., 1024x2048) for panoramas</li>
                                <li>Enable seamless tiling for texture-like outputs</li>
                                <li>Lower CFG (3-5) for more creative results</li>
                            </ul>
                        </div>
                    """)

        with gr.Tab("Export Model"):
            gr.Markdown("### Export Merged Checkpoint with Quantization Options")

            # Progress indicator for export
            export_progress = gr.Textbox(
                label="Export Progress",
                placeholder="Ready to export...",
                show_label=True,
                info="Real-time status of model export and quantization"
            )

            with gr.Row():
                include_lora = gr.Checkbox(
                    True,
                    label="Include Fused LoRAs",
                    info="Bake the loaded LoRAs into the exported model"
                )

                quantize_toggle = gr.Checkbox(
                    False,
                    label="Apply Quantization",
                    info="Reduce model size with quantization"
                )

            # Hidden initially to match quantize_toggle's default (False);
            # quantize_toggle.change below shows/hides it.
            qtype_row = gr.Row(visible=False)
            with qtype_row:
                qtype_dropdown = gr.Dropdown(
                    choices=["none", "int8", "int4", "float8"],
                    value="int8",
                    label="Quantization Method",
                    info="Trade quality for smaller file size"
                )

            with gr.Row():
                format_dropdown = gr.Dropdown(
                    choices=["safetensors", "bin"],
                    value="safetensors",
                    label="Export Format",
                    info="safetensors is recommended for safety"
                )

            with gr.Row():
                export_btn = gr.Button("💾 Save Merged Checkpoint", variant="primary", size="lg")

            with gr.Row():
                download_link = gr.File(
                    label="Download Merged File",
                    show_label=True,
                )

                with gr.Column():
                    export_status = gr.HTML(
                        label="Export Status",
                        value='<div class="status-success">✅ Ready to export</div>',
                    )

                    gr.HTML("""
                        <div style="margin-top: 16px; padding: 12px; background: #e0f2fe; border-radius: 8px;">
                            <strong>ℹ️ About Quantization:</strong>
                            <p style="font-size: 0.9em; margin: 8px 0;">
                                Reduces model size by lowering precision. Int8 is typically
                                lossless for inference while cutting size in half.
                            </p>
                        </div>
                    """)

        # Event handlers - all inside Blocks context

        def on_load_pipeline_start():
            """Show a 'loading' status and disable the button while loading."""
            return (
                '<div class="status-warning">⏳ Loading started...</div>',
                "Starting download...",
                gr.update(interactive=False)
            )

        def on_load_pipeline_complete(status_msg, progress_text):
            """Map load_pipeline's status text to a styled status and re-enable the button."""
            # Guard against an empty/None status so .lower() below cannot fail.
            status_msg = status_msg or ""
            if "✅" in status_msg:
                return (
                    '<div class="status-success">✅ Pipeline loaded successfully!</div>',
                    progress_text,
                    gr.update(interactive=True)
                )
            if "⚠️" in status_msg or "cancelled" in status_msg.lower():
                return (
                    '<div class="status-warning">⚠️ Download cancelled</div>',
                    progress_text,
                    gr.update(interactive=True)
                )
            # Read back from an HTML component whose content comes from
            # load_pipeline; it is wrapped but deliberately not escaped here.
            return (
                f'<div class="status-error">{status_msg}</div>',
                progress_text,
                gr.update(interactive=True)
            )

        def refresh_cached_dropdowns():
            """Re-scan the cache so newly downloaded files appear in the dropdowns."""
            return tuple(
                gr.update(choices=_cached_dropdown_kwargs(getter())["choices"])
                for getter in (get_cached_checkpoints, get_cached_vaes, get_cached_loras)
            )

        load_btn.click(
            fn=on_load_pipeline_start,
            inputs=[],
            outputs=[load_status, load_progress, load_btn],
        ).then(
            fn=load_pipeline,
            inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
            outputs=[load_status, load_progress],
            show_progress="full",
        ).then(
            fn=on_load_pipeline_complete,
            inputs=[load_status, load_progress],
            outputs=[load_status, load_progress, load_btn],
        ).then(
            fn=refresh_cached_dropdowns,
            inputs=[],
            outputs=[cached_checkpoints, cached_vaes, cached_loras],
        )

        def on_cached_checkpoint_change(cached_path):
            """Point the checkpoint URL at a selected cached file; leave it alone otherwise."""
            if _is_cached_selection(cached_path):
                return gr.update(value=f"file://{cached_path}")
            return gr.update()

        cached_checkpoints.change(
            fn=on_cached_checkpoint_change,
            inputs=cached_checkpoints,
            outputs=checkpoint_url,
        )

        def on_cached_vae_change(cached_path):
            """Point the VAE URL at a selected cached file; leave it alone otherwise."""
            if _is_cached_selection(cached_path):
                return gr.update(value=f"file://{cached_path}")
            return gr.update()

        cached_vaes.change(
            fn=on_cached_vae_change,
            inputs=cached_vaes,
            outputs=vae_url,
        )

        def on_cached_lora_change(cached_path, current_urls):
            """Append a selected cached LoRA to the URL list unless already present."""
            if not _is_cached_selection(cached_path):
                return gr.update()
            urls_list = [u.strip() for u in (current_urls or "").split("\n") if u.strip()]
            file_url = f"file://{cached_path}"
            if file_url in urls_list:
                return gr.update()
            urls_list.append(file_url)
            return gr.update(value="\n".join(urls_list))

        cached_loras.change(
            fn=on_cached_lora_change,
            inputs=[cached_loras, lora_urls],
            outputs=lora_urls,
        )

        def on_generate_start():
            """Show a 'generating' status and disable the button."""
            return (
                '<div class="status-warning">⏳ Generating image...</div>',
                "Starting generation...",
                gr.update(interactive=False)
            )

        def on_generate_complete(status_msg, progress_text, image):
            """Show success or the (escaped) error text and re-enable the button."""
            if image is None:
                # status_msg is the plain-text progress Textbox value.
                return (
                    _error_html(status_msg, "❌ Image generation failed"),
                    "",
                    gr.update(interactive=True),
                    gr.update()
                )
            return (
                '<div class="status-success">✅ Generation complete!</div>',
                "Done",
                gr.update(interactive=True),
                gr.update(value=image)
            )

        gen_btn.click(
            fn=on_generate_start,
            inputs=[],
            outputs=[gen_status, gen_progress, gen_btn],
        ).then(
            fn=generate_image,
            inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y, seed],
            outputs=[image_output, gen_progress],
        ).then(
            fn=lambda img, msg: on_generate_complete(msg, "Done", img),
            inputs=[image_output, gen_progress],
            outputs=[gen_status, gen_progress, gen_btn, image_output],
        )

        def on_export_start():
            """Show an 'exporting' status and disable the button."""
            return (
                '<div class="status-warning">⏳ Export started...</div>',
                "Starting export...",
                gr.update(interactive=False)
            )

        def on_export_complete(status_msg, progress_text, file_path):
            """Show success or the (escaped) error text and re-enable the button."""
            if file_path is None:
                # status_msg is the plain-text progress Textbox value.
                return (
                    _error_html(status_msg, "❌ Export failed"),
                    "",
                    gr.update(interactive=True),
                    gr.update(value=None)
                )
            return (
                '<div class="status-success">✅ Export complete!</div>',
                "Exported successfully",
                gr.update(interactive=True),
                gr.update(value=file_path)
            )

        export_btn.click(
            fn=on_export_start,
            inputs=[],
            outputs=[export_status, export_progress, export_btn],
        ).then(
            fn=lambda inc, q, qt, fmt: export_merged_model(
                include_lora=inc,
                quantize=q and (qt != "none"),
                qtype=qt,  # always pass the string value; exporter handles "none" correctly
                save_format=fmt,
            ),
            inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
            outputs=[download_link, export_progress],
        ).then(
            fn=lambda path, msg: on_export_complete(msg, "Exported", path),
            inputs=[download_link, export_progress],
            outputs=[export_status, export_progress, export_btn, download_link],
        )

        quantize_toggle.change(
            fn=lambda checked: gr.update(visible=checked),
            inputs=[quantize_toggle],
            outputs=qtype_row,
        )

    return demo


# Build the app at import time so it is exposed as the module attribute `demo`
# (importers of this module get the configured, not-yet-launched Blocks app).
demo = create_app()

# Launch the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()