"""Pipeline management for SDXL Model Merger."""

from collections.abc import Generator

import torch
from diffusers import (
    StableDiffusionXLPipeline,
    AutoencoderKL,
    DPMSolverSDEScheduler,
)

from . import config
from .config import device, dtype, CACHE_DIR, device_description, is_running_on_spaces, set_download_cancelled
from .downloader import get_safe_filename_from_url, download_file_with_progress
from .gpu_decorator import GPU


@GPU(duration=300)
def _load_and_setup_pipeline(checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs):
    """GPU-decorated helper that performs all GPU-intensive pipeline setup."""
    _pipe = StableDiffusionXLPipeline.from_single_file(
        str(checkpoint_path),
        **load_kwargs,
    )
    print(" ✅ Text encoders loaded")

    # Move to device (unless using device_map='auto', which handles this automatically)
    if not is_running_on_spaces() or device != "cpu":
        print(f" ⚙️ Moving pipeline to device: {device_description}...")
        _pipe = _pipe.to(device=device, dtype=dtype)

    # Load custom VAE if provided
    if vae_path is not None:
        print(" ⚙️ Loading VAE weights...")
        vae = AutoencoderKL.from_single_file(
            str(vae_path),
            torch_dtype=dtype,
        )
        print(" ⚙️ Setting custom VAE...")
        _pipe.vae = vae.to(device=device, dtype=torch.float32)
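        # The custom VAE is deliberately cast to float32 here: the stock SDXL VAE
        # is known to produce NaNs/noisy output in float16 (see the matching cast
        # at the end of this function).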

    # Load and fuse each LoRA
    if lora_paths_and_strengths:
        # Ensure pipeline is on device for LoRA fusion
        _pipe = _pipe.to(device=device, dtype=dtype)
        for i, (lora_path, strength) in enumerate(lora_paths_and_strengths):
            adapter_name = f"lora_{i}"
            print(f" ⚙️ Loading LoRA {i+1}/{len(lora_paths_and_strengths)}...")
            _pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
            print(f" ⚙️ Fusing LoRA {i+1} with strength={strength}...")
            _pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
            _pipe.unload_lora_weights()
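            # fuse_lora() bakes the adapter into the base weights, so the LoRA
            # modules can be unloaded right away to free memory while keeping
            # the merged effect.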
    else:
        # Move pipeline to device even without LoRAs
        _pipe = _pipe.to(device=device, dtype=dtype)

    # Set scheduler
    print(" ⚙️ Configuring scheduler...")
    _pipe.scheduler = DPMSolverSDEScheduler.from_config(
        _pipe.scheduler.config,
        algorithm_type="sde-dpmsolver++",
        use_karras_sigmas=False,
    )

    # Keep VAE in float32 to prevent colorful static output
    _pipe.vae.to(dtype=torch.float32)
    return _pipe
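
# Doing all of the above in a single @GPU-decorated call matters on ZeroGPU
# Spaces: the decorator (presumably wrapping spaces.GPU) requests a GPU for the
# duration of one call, so splitting setup across several decorated functions
# would queue for a GPU at every step.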


def load_pipeline(
    checkpoint_url: str,
    vae_url: str,
    lora_urls_str: str,
    lora_strengths_str: str,
    progress=None,
) -> Generator[tuple[str, str], None, None]:
    """
    Load the SDXL pipeline with checkpoint, VAE, and LoRAs.

    Args:
        checkpoint_url: URL to the base model .safetensors file
        vae_url: Optional URL to a VAE .safetensors file
        lora_urls_str: Newline-separated URLs for LoRA models
        lora_strengths_str: Comma-separated strength values, one per LoRA
        progress: Optional gr.Progress() object for UI updates

    Yields:
        Tuple of (status_message, progress_text) at each loading stage;
        the final tuple reports success, cancellation, or failure.
    """
    # Clear any previously loaded pipeline so the UI reflects the loading state
    config.set_pipe(None)
    try:
        set_download_cancelled(False)
        print("=" * 60)
        print("🚀 Loading SDXL Pipeline...")
        print("=" * 60)

        checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
        checkpoint_path = CACHE_DIR / checkpoint_filename

        # Check if the checkpoint is already cached
        checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0

        # Validate the cache file before using it
        if checkpoint_cached:
            is_valid, msg = config.validate_cache_file(checkpoint_path)
            if not is_valid:
                print(f" ⚠️ Cache invalid: {msg}")
                checkpoint_path.unlink(missing_ok=True)
                checkpoint_cached = False

        # VAE: use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
        vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else None
        vae_path = CACHE_DIR / vae_filename if vae_filename else None
        vae_cached = bool(vae_url.strip()) and vae_path is not None and vae_path.exists() and vae_path.stat().st_size > 0

        # Validate the VAE cache file before using it
        if vae_cached:
            is_valid, msg = config.validate_cache_file(vae_path)
            if not is_valid:
                print(f" ⚠️ VAE cache invalid: {msg}")
                vae_path.unlink(missing_ok=True)
                vae_cached = False

        # Download the checkpoint (skipped if already cached)
        if progress:
            progress(0.1, desc="Downloading base model..." if not checkpoint_cached else "Loading base model...")
        if not checkpoint_cached:
            status_msg = f"📥 Downloading {checkpoint_path.name}..."
            print(f" 📥 Downloading: {checkpoint_path.name}")
        else:
            status_msg = f"✅ Using cached {checkpoint_path.name}"
            print(f" ✅ Using cached: {checkpoint_path.name}")
        yield status_msg, "Starting download..."
        if not checkpoint_cached:
            download_file_with_progress(checkpoint_url, checkpoint_path)

        # Download the VAE if provided (loading happens in _load_and_setup_pipeline)
        if vae_url and vae_url.strip():
            if vae_path:
                status_msg = f"📥 Downloading {vae_path.name}..." if not vae_cached else f"✅ Using cached {vae_path.name}"
                print(f" 📥 VAE: {vae_path.name}" if not vae_cached else f" ✅ VAE (cached): {vae_path.name}")
                if progress:
                    progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
                yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
                if not vae_cached:
                    download_file_with_progress(vae_url, vae_path)

        # For CPU/low-memory environments on Spaces, use device_map for better RAM management
        load_kwargs = {
            "torch_dtype": dtype,
            "use_safetensors": True,
        }
        if is_running_on_spaces() and device == "cpu":
            print(" ℹ️ CPU mode detected: enabling device_map='auto' for better RAM management")
            load_kwargs["device_map"] = "auto"

        # Parse LoRA URLs; give each one a strength, defaulting to 1.0 when missing or invalid
        lora_urls = [u.strip() for u in lora_urls_str.split("\n") if u.strip()]
        strengths_raw = [s.strip() for s in lora_strengths_str.split(",")]
        strengths = []
        for i in range(len(lora_urls)):
            try:
                val = float(strengths_raw[i]) if i < len(strengths_raw) else 1.0
                strengths.append(val)
            except ValueError:
                strengths.append(1.0)

        # Download LoRAs (CPU-bound downloads, before GPU work)
        lora_paths_and_strengths = []
        if lora_urls:
            for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
                lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
                lora_path = CACHE_DIR / lora_filename
                lora_cached = lora_path.exists() and lora_path.stat().st_size > 0

                # Validate the LoRA cache file before using it
                if lora_cached:
                    is_valid, msg = config.validate_cache_file(lora_path)
                    if not is_valid:
                        print(f" ⚠️ LoRA cache invalid: {msg}")
                        lora_path.unlink(missing_ok=True)
                        lora_cached = False

                if not lora_cached:
                    print(f" 📥 LoRA {i+1}/{len(lora_urls)}: Downloading {lora_path.name}...")
                    status_msg = f"📥 Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
                else:
                    print(f" ✅ LoRA {i+1}/{len(lora_urls)}: Using cached {lora_path.name}")
                    status_msg = f"✅ Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
                yield (
                    status_msg,
                    f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..." if not lora_cached
                    else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})"
                )
                if not lora_cached:
                    download_file_with_progress(lora_url, lora_path)
                lora_paths_and_strengths.append((lora_path, strength))

        # All downloads complete; now do the GPU-intensive setup in one decorated call
        yield "⚙️ Loading SDXL pipeline...", "Loading model weights into memory..."
        if progress:
            progress(0.5, desc="Loading pipeline...")
        _pipe = _load_and_setup_pipeline(
            checkpoint_path, vae_path, lora_paths_and_strengths, load_kwargs
        )
        if progress:
            progress(0.95, desc="Finalizing...")

        # ✅ Only publish the pipeline globally AFTER all steps succeed
        config.set_pipe(_pipe)
        print(" ✅ Pipeline ready!")
        yield "✅ Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
    except KeyboardInterrupt:
        set_download_cancelled(False)
        config.set_pipe(None)
        print("\n⚠️ Download cancelled by user")
        # This is a generator, so yield (rather than return) the message for the UI
        yield "⚠️ Download cancelled by user", "Cancelled"
    except Exception as e:
        import traceback

        config.set_pipe(None)
        error_msg = f"❌ Error loading pipeline: {str(e)}"
        print(f"\n{error_msg}")
        print(traceback.format_exc())
        yield error_msg, f"Error: {str(e)}"


def cancel_download():
    """Set the global cancellation flag to stop any ongoing downloads."""
    set_download_cancelled(True)


def get_pipeline() -> StableDiffusionXLPipeline | None:
    """Get the currently loaded pipeline."""
    return config.get_pipe()
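

# A minimal wiring sketch, assuming a Gradio Blocks UI; the component names and
# the `modules.pipeline` import path below are illustrative, not part of this
# module:
#
#   import gradio as gr
#   from modules.pipeline import load_pipeline, cancel_download
#
#   with gr.Blocks() as demo:
#       ckpt = gr.Textbox(label="Checkpoint URL")
#       vae = gr.Textbox(label="VAE URL (optional)")
#       loras = gr.Textbox(label="LoRA URLs (one per line)")
#       strengths = gr.Textbox(label="LoRA strengths (comma-separated)", value="1.0")
#       status = gr.Textbox(label="Status")
#       detail = gr.Textbox(label="Progress")
#       gr.Button("Load").click(load_pipeline, [ckpt, vae, loras, strengths], [status, detail])
#       gr.Button("Cancel").click(cancel_download)
#   demo.launch()
#
# Because load_pipeline is a generator, Gradio streams each yielded
# (status_message, progress_text) pair into the two output textboxes.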