"""Pipeline management for SDXL Model Merger."""

from collections.abc import Iterator
import traceback

import torch
from diffusers import (
    StableDiffusionXLPipeline,
    AutoencoderKL,
    DPMSolverSDEScheduler,
)

from . import config
from .config import (
    device,
    dtype,
    CACHE_DIR,
    device_description,
    is_running_on_spaces,
    set_download_cancelled,
)
from .downloader import get_safe_filename_from_url, download_file_with_progress
from .gpu_decorator import GPU


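# The GPU decorator requests GPU time only for the duration of this call, so
# all CPU-bound downloads happen beforehand in load_pipeline() and only the
# GPU-intensive setup is concentrated here.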
@GPU(duration=300)
def _load_and_setup_pipeline(checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs):
    """GPU-decorated helper that performs all GPU-intensive pipeline setup."""
    _pipe = StableDiffusionXLPipeline.from_single_file(
        str(checkpoint_path),
        **load_kwargs,
    )
    print("  βœ… Text encoders loaded")

    # Move to device (unless using device_map='auto' which handles this automatically)
    if not is_running_on_spaces() or device != "cpu":
        print(f"  βš™οΈ Moving pipeline to device: {device_description}...")
        _pipe = _pipe.to(device=device, dtype=dtype)

    # Load custom VAE if provided
    if vae_path is not None:
        print("  βš™οΈ Loading VAE weights...")
        vae = AutoencoderKL.from_single_file(
            str(vae_path),
            torch_dtype=dtype,
        )
        print("  βš™οΈ Setting custom VAE...")
        _pipe.vae = vae.to(device=device, dtype=torch.float32)

    # Load and fuse each LoRA
    if lora_paths_and_strengths:
        # Ensure pipeline is on device for LoRA fusion
        _pipe = _pipe.to(device=device, dtype=dtype)

        for i, (lora_path, strength) in enumerate(lora_paths_and_strengths):
            adapter_name = f"lora_{i}"
            print(f"  βš™οΈ Loading LoRA {i+1}/{len(lora_paths_and_strengths)}...")
            _pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
            print(f"  βš™οΈ Fusing LoRA {i+1} with strength={strength}...")
            _pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
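            # Unloading after fuse_lora() removes the standalone adapter modules
            # while keeping the fused deltas baked into the base weights, freeing
            # memory before the next LoRA is loaded.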
            _pipe.unload_lora_weights()
    else:
        # Move pipeline to device even without LoRAs
        _pipe = _pipe.to(device=device, dtype=dtype)

    # Set scheduler
    print("  βš™οΈ Configuring scheduler...")
    _pipe.scheduler = DPMSolverSDEScheduler.from_config(
        _pipe.scheduler.config,
        algorithm_type="sde-dpmsolver++",
        use_karras_sigmas=False,
    )

    # Keep VAE in float32 to prevent colorful static output
    _pipe.vae.to(dtype=torch.float32)

    return _pipe


def load_pipeline(
    checkpoint_url: str,
    vae_url: str,
    lora_urls_str: str,
    lora_strengths_str: str,
    progress=None
) -> Iterator[tuple[str, str]]:
    """
    Load SDXL pipeline with checkpoint, VAE, and LoRAs.

    Args:
        checkpoint_url: URL to base model .safetensors file
        vae_url: Optional URL to VAE .safetensors file
        lora_urls_str: Newline-separated URLs for LoRA models
        lora_strengths_str: Comma-separated strength values for each LoRA
        progress: Optional gr.Progress() object for UI updates

    Yields:
        Tuple of (status_message, progress_text) at each loading stage; the
        final tuple reports success, cancellation, or the error encountered.
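
    Example (hypothetical URL; nothing runs until the generator is iterated):
        for status, detail in load_pipeline(
            "https://example.com/model.safetensors", "", "", "1.0"
        ):
            print(status, detail)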
    """
    # Clear any previously loaded pipeline so the UI reflects loading state
    config.set_pipe(None)

    try:
        set_download_cancelled(False)

        print("=" * 60)
        print("πŸ”„ Loading SDXL Pipeline...")
        print("=" * 60)

        checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
        checkpoint_path = CACHE_DIR / checkpoint_filename

        # Check if checkpoint is already cached
        checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0

        # Validate cache file before using it
        if checkpoint_cached:
            is_valid, msg = config.validate_cache_file(checkpoint_path)
            if not is_valid:
                print(f"  ⚠️ Cache invalid: {msg}")
                checkpoint_path.unlink(missing_ok=True)
                checkpoint_cached = False

        # VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
        vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else None
        vae_path = CACHE_DIR / vae_filename if vae_filename else None
        vae_cached = vae_url.strip() and vae_path and vae_path.exists() and vae_path.stat().st_size > 0

        # Validate VAE cache file before using it
        if vae_cached:
            is_valid, msg = config.validate_cache_file(vae_path)
            if not is_valid:
                print(f"  ⚠️ VAE Cache invalid: {msg}")
                vae_path.unlink(missing_ok=True)
                vae_cached = False

        # Download checkpoint (skips if already cached)
        if progress:
            progress(0.1, desc="Downloading base model..." if not checkpoint_cached else "Loading base model...")

        if not checkpoint_cached:
            status_msg = f"πŸ“₯ Downloading {checkpoint_path.name}..."
            print(f"  πŸ“₯ Downloading: {checkpoint_path.name}")
        else:
            status_msg = f"βœ… Using cached {checkpoint_path.name}"
            print(f"  βœ… Using cached: {checkpoint_path.name}")

        yield status_msg, "Starting download..."

        if not checkpoint_cached:
            download_file_with_progress(checkpoint_url, checkpoint_path)

        # Download VAE if provided (loading happens in _load_and_setup_pipeline)
        if vae_url and vae_url.strip():
            if vae_path:
                status_msg = f"πŸ“₯ Downloading {vae_path.name}..." if not vae_cached else f"βœ… Using cached {vae_path.name}"
                print(f"  πŸ“₯ VAE: {vae_path.name}" if not vae_cached else f"  βœ… VAE (cached): {vae_path.name}")

                if progress:
                    progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")

                yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"

                if not vae_cached:
                    download_file_with_progress(vae_url, vae_path)

        # For CPU/low-memory environments on Spaces, use device_map for better RAM management
        load_kwargs = {
            "torch_dtype": dtype,
            "use_safetensors": True,
        }

        if is_running_on_spaces() and device == "cpu":
            print("  ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
            load_kwargs["device_map"] = "auto"

        # Parse LoRA URLs & ensure strengths list matches
        lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
        strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
        strengths = []
        for i, url in enumerate(lora_urls):
            try:
                val = float(strengths_raw[i]) if i < len(strengths_raw) else 1.0
                strengths.append(val)
            except ValueError:
                strengths.append(1.0)

        # Download LoRAs (CPU-bound downloads, before GPU work)
        lora_paths_and_strengths = []
        if lora_urls:
            for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
                lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
                lora_path = CACHE_DIR / lora_filename
                lora_cached = lora_path.exists() and lora_path.stat().st_size > 0

                # Validate LoRA cache file before using it
                if lora_cached:
                    is_valid, msg = config.validate_cache_file(lora_path)
                    if not is_valid:
                        print(f"  ⚠️ LoRA Cache invalid: {msg}")
                        lora_path.unlink(missing_ok=True)
                        lora_cached = False

                if not lora_cached:
                    print(f"  πŸ“₯ LoRA {i+1}/{len(lora_urls)}: Downloading {lora_path.name}...")
                    status_msg = f"πŸ“₯ Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
                else:
                    print(f"  βœ… LoRA {i+1}/{len(lora_urls)}: Using cached {lora_path.name}")
                    status_msg = f"βœ… Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"

                yield (
                    status_msg,
                    f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..." if not lora_cached
                    else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})"
                )

                if not lora_cached:
                    download_file_with_progress(lora_url, lora_path)

                lora_paths_and_strengths.append((lora_path, strength))

        # All downloads complete; now do the GPU-intensive setup in one decorated call
        yield "βš™οΈ Loading SDXL pipeline...", "Loading model weights into memory..."

        if progress:
            progress(0.5, desc="Loading pipeline...")

        _pipe = _load_and_setup_pipeline(
            checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs
        )

        if progress:
            progress(0.95, desc="Finalizing...")

        # ✅ Only publish the pipeline globally AFTER all steps succeed
        config.set_pipe(_pipe)

        print("  βœ… Pipeline ready!")
        yield "βœ… Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"

    except KeyboardInterrupt:
        set_download_cancelled(False)
        config.set_pipe(None)
        print("\n⚠️ Download cancelled by user")
        # Yield rather than return: a generator's return value is discarded by
        # consumers such as a Gradio UI, so the message must be yielded.
        yield "⚠️ Download cancelled by user", "Cancelled"
    except Exception as e:
        config.set_pipe(None)
        error_msg = f"❌ Error loading pipeline: {e}"
        print(f"\n{error_msg}")
        print(traceback.format_exc())
        yield error_msg, f"Error: {e}"


def cancel_download():
    """Set the global cancellation flag to stop any ongoing downloads."""
    set_download_cancelled(True)


def get_pipeline() -> StableDiffusionXLPipeline | None:
    """Get the currently loaded pipeline."""
    return config.get_pipe()
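

# Minimal usage sketch. `load_pipeline` is a generator: nothing happens until
# it is iterated, and each step yields a (status, detail) pair that a UI such
# as Gradio can display. The URL below is a placeholder, and because of the
# relative imports above this guard only runs when the module is executed as
# part of its package (python -m <package>.pipeline).
if __name__ == "__main__":
    DEMO_CHECKPOINT_URL = "https://example.com/sdxl-base.safetensors"  # placeholder

    for status, detail in load_pipeline(DEMO_CHECKPOINT_URL, "", "", "1.0"):
        print(f"{status} | {detail}")

    pipe = get_pipeline()
    if pipe is not None:
        # 25 sampler steps is an arbitrary demo value, not a project default.
        image = pipe("a lighthouse at dawn", num_inference_steps=25).images[0]
        image.save("out.png")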