Spaces:
Running on Zero
Running on Zero
Kyle Pearson commited on
Commit Β·
60d66bd
1
Parent(s): 6a07ce1
Update Gradio to 6.9.0, add cached file checks in LoRA loading, fix unloading errors, optimize download cancellation, improve device/dtype handling, update optimum dependency.
Browse files- app.py +17 -8
- requirements.txt +2 -1
- src/downloader.py +56 -13
- src/exporter.py +4 -1
- src/pipeline.py +47 -32
app.py
CHANGED
|
@@ -83,7 +83,7 @@ def create_app():
|
|
| 83 |
from src.exporter import export_merged_model
|
| 84 |
from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras
|
| 85 |
|
| 86 |
-
with gr.Blocks(title="SDXL Model Merger"
|
| 87 |
# Header section
|
| 88 |
with gr.Column(elem_classes=["feature-card"]):
|
| 89 |
gr.HTML("""
|
|
@@ -393,7 +393,18 @@ def create_app():
|
|
| 393 |
fn=load_pipeline,
|
| 394 |
inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
|
| 395 |
outputs=[load_status, load_progress],
|
| 396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 397 |
)
|
| 398 |
|
| 399 |
def on_cached_checkpoint_change(cached_path):
|
|
@@ -423,10 +434,10 @@ def create_app():
|
|
| 423 |
def on_cached_lora_change(cached_path, current_urls):
|
| 424 |
"""Add cached LoRA to the list."""
|
| 425 |
if cached_path and cached_path != "(None found)":
|
| 426 |
-
# Add new LoRA to existing URLs (avoid duplicate)
|
| 427 |
urls_list = [u.strip() for u in current_urls.split("\n") if u.strip()]
|
| 428 |
-
|
| 429 |
-
|
|
|
|
| 430 |
return gr.update(value="\n".join(urls_list))
|
| 431 |
return gr.update()
|
| 432 |
|
|
@@ -470,7 +481,6 @@ def create_app():
|
|
| 470 |
fn=generate_image,
|
| 471 |
inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
|
| 472 |
outputs=[image_output, gen_progress],
|
| 473 |
-
show_api=False,
|
| 474 |
).then(
|
| 475 |
fn=lambda img, msg: on_generate_complete(msg, "Done", img),
|
| 476 |
inputs=[image_output, gen_progress],
|
|
@@ -515,7 +525,6 @@ def create_app():
|
|
| 515 |
),
|
| 516 |
inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
|
| 517 |
outputs=[download_link, export_progress],
|
| 518 |
-
show_api=False,
|
| 519 |
).then(
|
| 520 |
fn=lambda path, msg: on_export_complete(msg, "Exported", path),
|
| 521 |
inputs=[download_link, export_progress],
|
|
@@ -534,7 +543,7 @@ def create_app():
|
|
| 534 |
def main():
|
| 535 |
"""Create and launch the Gradio app."""
|
| 536 |
app = create_app()
|
| 537 |
-
# CSS is
|
| 538 |
app.launch()
|
| 539 |
|
| 540 |
|
|
|
|
| 83 |
from src.exporter import export_merged_model
|
| 84 |
from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras
|
| 85 |
|
| 86 |
+
with gr.Blocks(title="SDXL Model Merger") as demo:
|
| 87 |
# Header section
|
| 88 |
with gr.Column(elem_classes=["feature-card"]):
|
| 89 |
gr.HTML("""
|
|
|
|
| 393 |
fn=load_pipeline,
|
| 394 |
inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
|
| 395 |
outputs=[load_status, load_progress],
|
| 396 |
+
).then(
|
| 397 |
+
fn=on_load_pipeline_complete,
|
| 398 |
+
inputs=[load_status, load_progress],
|
| 399 |
+
outputs=[load_status, load_progress, load_btn],
|
| 400 |
+
).then(
|
| 401 |
+
fn=lambda: (
|
| 402 |
+
gr.update(choices=["(None found)"] + get_cached_checkpoints()),
|
| 403 |
+
gr.update(choices=["(None found)"] + get_cached_vaes()),
|
| 404 |
+
gr.update(choices=["(None found)"] + get_cached_loras()),
|
| 405 |
+
),
|
| 406 |
+
inputs=[],
|
| 407 |
+
outputs=[cached_checkpoints, cached_vaes, cached_loras],
|
| 408 |
)
|
| 409 |
|
| 410 |
def on_cached_checkpoint_change(cached_path):
|
|
|
|
| 434 |
def on_cached_lora_change(cached_path, current_urls):
|
| 435 |
"""Add cached LoRA to the list."""
|
| 436 |
if cached_path and cached_path != "(None found)":
|
|
|
|
| 437 |
urls_list = [u.strip() for u in current_urls.split("\n") if u.strip()]
|
| 438 |
+
file_url = f"file://{cached_path}"
|
| 439 |
+
if file_url not in urls_list:
|
| 440 |
+
urls_list.append(file_url)
|
| 441 |
return gr.update(value="\n".join(urls_list))
|
| 442 |
return gr.update()
|
| 443 |
|
|
|
|
| 481 |
fn=generate_image,
|
| 482 |
inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
|
| 483 |
outputs=[image_output, gen_progress],
|
|
|
|
| 484 |
).then(
|
| 485 |
fn=lambda img, msg: on_generate_complete(msg, "Done", img),
|
| 486 |
inputs=[image_output, gen_progress],
|
|
|
|
| 525 |
),
|
| 526 |
inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
|
| 527 |
outputs=[download_link, export_progress],
|
|
|
|
| 528 |
).then(
|
| 529 |
fn=lambda path, msg: on_export_complete(msg, "Exported", path),
|
| 530 |
inputs=[download_link, export_progress],
|
|
|
|
| 543 |
def main():
|
| 544 |
"""Create and launch the Gradio app."""
|
| 545 |
app = create_app()
|
| 546 |
+
# CSS is passed to launch() in Gradio 6+
|
| 547 |
app.launch()
|
| 548 |
|
| 549 |
|
requirements.txt
CHANGED
|
@@ -5,12 +5,13 @@ torch>=2.0.0
|
|
| 5 |
diffusers>=0.24.0
|
| 6 |
transformers>=4.35.0
|
| 7 |
safetensors>=0.4.0
|
|
|
|
| 8 |
|
| 9 |
# Image processing
|
| 10 |
Pillow>=10.0.0
|
| 11 |
|
| 12 |
# UI framework
|
| 13 |
-
gradio>=
|
| 14 |
|
| 15 |
# Download utilities
|
| 16 |
tqdm>=4.65.0
|
|
|
|
| 5 |
diffusers>=0.24.0
|
| 6 |
transformers>=4.35.0
|
| 7 |
safetensors>=0.4.0
|
| 8 |
+
optimum>=1.0.0
|
| 9 |
|
| 10 |
# Image processing
|
| 11 |
Pillow>=10.0.0
|
| 12 |
|
| 13 |
# UI framework
|
| 14 |
+
gradio>=6.9.0
|
| 15 |
|
| 16 |
# Download utilities
|
| 17 |
tqdm>=4.65.0
|
src/downloader.py
CHANGED
|
@@ -127,18 +127,56 @@ def set_download_cancelled(value: bool):
|
|
| 127 |
download_cancelled = value
|
| 128 |
|
| 129 |
|
| 130 |
-
def get_cached_file_size(url: str) -> tuple[Path | None, int | None]:
|
| 131 |
"""
|
| 132 |
Check if file exists in cache and matches expected size.
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
"""
|
| 135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
return None, None
|
| 137 |
|
| 138 |
|
| 139 |
def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path:
|
| 140 |
"""
|
| 141 |
Download a file with Gradio-synced progress bar + cancel support.
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
Args:
|
| 144 |
url: File URL to download (http/https/file)
|
|
@@ -146,14 +184,13 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
|
|
| 146 |
progress_bar: Optional gr.Progress() object for UI updates
|
| 147 |
|
| 148 |
Returns:
|
| 149 |
-
Path to the downloaded file
|
| 150 |
|
| 151 |
Raises:
|
| 152 |
KeyboardInterrupt: If download is cancelled
|
| 153 |
requests.RequestException: If download fails
|
| 154 |
"""
|
| 155 |
global download_cancelled
|
| 156 |
-
download_cancelled = False
|
| 157 |
|
| 158 |
# Handle local file:// URLs
|
| 159 |
if url.startswith("file://"):
|
|
@@ -163,7 +200,7 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
|
|
| 163 |
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 164 |
# Copy the file to cache location
|
| 165 |
shutil.copy2(str(local_path), str(output_path))
|
| 166 |
-
|
| 167 |
# Update progress bar for cached files
|
| 168 |
if progress_bar:
|
| 169 |
progress_bar(1.0)
|
|
@@ -171,18 +208,24 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
|
|
| 171 |
else:
|
| 172 |
raise FileNotFoundError(f"Local file not found: {local_path}")
|
| 173 |
|
| 174 |
-
#
|
| 175 |
expected_size = None
|
| 176 |
try:
|
| 177 |
head = requests.head(url, timeout=10)
|
| 178 |
expected_size = int(head.headers.get('content-length', 0))
|
| 179 |
-
if output_path.exists() and output_path.stat().st_size == expected_size:
|
| 180 |
-
# Cache hit - still update progress to show completion
|
| 181 |
-
if progress_bar:
|
| 182 |
-
progress_bar(1.0)
|
| 183 |
-
return output_path # Cache hit!
|
| 184 |
except Exception:
|
| 185 |
-
pass # Skip
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 188 |
|
|
|
|
| 127 |
download_cancelled = value
|
| 128 |
|
| 129 |
|
| 130 |
+
def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]:
|
| 131 |
"""
|
| 132 |
Check if file exists in cache and matches expected size.
|
| 133 |
+
|
| 134 |
+
Uses the same filename generation logic as download operations to find
|
| 135 |
+
cached files by URL.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
url: The download URL to check for cached file
|
| 139 |
+
suffix: Optional suffix (e.g., '_vae', '_lora') for special file types
|
| 140 |
+
type_prefix: Optional prefix after model_id (e.g., 'model')
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
Tuple of (cached_file_path, file_size) if valid cache exists,
|
| 144 |
+
or (None, None) if no valid cache found.
|
| 145 |
"""
|
| 146 |
+
from .config import CACHE_DIR
|
| 147 |
+
|
| 148 |
+
# Generate the expected filename for this URL
|
| 149 |
+
default_name = "vae.safetensors" if suffix == "_vae" else (
|
| 150 |
+
"lora.safetensors" if suffix == "_lora" else "model.safetensors"
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
cached_filename = get_safe_filename_from_url(
|
| 154 |
+
url,
|
| 155 |
+
default_name=default_name,
|
| 156 |
+
suffix=suffix,
|
| 157 |
+
type_prefix=type_prefix
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
cached_path = CACHE_DIR / cached_filename
|
| 161 |
+
|
| 162 |
+
if cached_path.exists() and cached_path.is_file():
|
| 163 |
+
try:
|
| 164 |
+
file_size = cached_path.stat().st_size
|
| 165 |
+
# Only return valid cache if file has content
|
| 166 |
+
if file_size > 0:
|
| 167 |
+
return cached_path, file_size
|
| 168 |
+
except OSError:
|
| 169 |
+
pass
|
| 170 |
+
|
| 171 |
return None, None
|
| 172 |
|
| 173 |
|
| 174 |
def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path:
|
| 175 |
"""
|
| 176 |
Download a file with Gradio-synced progress bar + cancel support.
|
| 177 |
+
|
| 178 |
+
Checks for existing cached files before downloading. If a valid cache
|
| 179 |
+
exists (file exists with matching expected size), skips re-download.
|
| 180 |
|
| 181 |
Args:
|
| 182 |
url: File URL to download (http/https/file)
|
|
|
|
| 184 |
progress_bar: Optional gr.Progress() object for UI updates
|
| 185 |
|
| 186 |
Returns:
|
| 187 |
+
Path to the downloaded (or cached) file
|
| 188 |
|
| 189 |
Raises:
|
| 190 |
KeyboardInterrupt: If download is cancelled
|
| 191 |
requests.RequestException: If download fails
|
| 192 |
"""
|
| 193 |
global download_cancelled
|
|
|
|
| 194 |
|
| 195 |
# Handle local file:// URLs
|
| 196 |
if url.startswith("file://"):
|
|
|
|
| 200 |
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 201 |
# Copy the file to cache location
|
| 202 |
shutil.copy2(str(local_path), str(output_path))
|
| 203 |
+
|
| 204 |
# Update progress bar for cached files
|
| 205 |
if progress_bar:
|
| 206 |
progress_bar(1.0)
|
|
|
|
| 208 |
else:
|
| 209 |
raise FileNotFoundError(f"Local file not found: {local_path}")
|
| 210 |
|
| 211 |
+
# Early cache check: if file exists and size matches URL's content-length, skip re-download
|
| 212 |
expected_size = None
|
| 213 |
try:
|
| 214 |
head = requests.head(url, timeout=10)
|
| 215 |
expected_size = int(head.headers.get('content-length', 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
except Exception:
|
| 217 |
+
pass # Skip header fetch on errors
|
| 218 |
+
|
| 219 |
+
if output_path.exists() and expected_size is not None:
|
| 220 |
+
try:
|
| 221 |
+
cached_size = output_path.stat().st_size
|
| 222 |
+
if cached_size == expected_size:
|
| 223 |
+
# Cache hit - file exists with correct size
|
| 224 |
+
if progress_bar:
|
| 225 |
+
progress_bar(1.0)
|
| 226 |
+
return output_path # Skip re-download!
|
| 227 |
+
except OSError:
|
| 228 |
+
pass # File access error, proceed with download
|
| 229 |
|
| 230 |
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 231 |
|
src/exporter.py
CHANGED
|
@@ -34,7 +34,10 @@ def export_merged_model(
|
|
| 34 |
# Step 1: Unload LoRAs
|
| 35 |
yield "πΎ Exporting model...", "Unloading LoRAs..."
|
| 36 |
if include_lora:
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
merged_state_dict = {}
|
| 40 |
|
|
|
|
| 34 |
# Step 1: Unload LoRAs
|
| 35 |
yield "πΎ Exporting model...", "Unloading LoRAs..."
|
| 36 |
if include_lora:
|
| 37 |
+
try:
|
| 38 |
+
global_pipe.unload_lora_weights()
|
| 39 |
+
except Exception:
|
| 40 |
+
pass
|
| 41 |
|
| 42 |
merged_state_dict = {}
|
| 43 |
|
src/pipeline.py
CHANGED
|
@@ -9,8 +9,8 @@ from diffusers import (
|
|
| 9 |
DPMSolverSDEScheduler,
|
| 10 |
)
|
| 11 |
|
| 12 |
-
from .config import device, dtype, pipe as global_pipe, CACHE_DIR,
|
| 13 |
-
from .downloader import download_file_with_progress, get_safe_filename_from_url
|
| 14 |
|
| 15 |
|
| 16 |
def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
|
|
@@ -74,28 +74,41 @@ def load_pipeline(
|
|
| 74 |
Returns:
|
| 75 |
Tuple of (final_status_message, progress_text)
|
| 76 |
"""
|
| 77 |
-
global global_pipe
|
| 78 |
|
| 79 |
try:
|
|
|
|
| 80 |
checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
|
| 81 |
checkpoint_path = CACHE_DIR / checkpoint_filename
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
# VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
|
| 84 |
vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else "vae.safetensors"
|
| 85 |
vae_path = CACHE_DIR / vae_filename
|
|
|
|
| 86 |
|
| 87 |
-
# Download checkpoint
|
| 88 |
if progress:
|
| 89 |
-
progress(0.1, desc="Downloading base model...")
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
# Download VAE if provided
|
| 94 |
if vae_url.strip():
|
|
|
|
| 95 |
if progress:
|
| 96 |
-
progress(0.2, desc="Downloading VAE...")
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
| 99 |
vae = AutoencoderKL.from_single_file(str(vae_path), torch_dtype=dtype)
|
| 100 |
else:
|
| 101 |
vae = None
|
|
@@ -126,27 +139,30 @@ def load_pipeline(
|
|
| 126 |
|
| 127 |
# Load and fuse each LoRA sequentially (only if URLs exist)
|
| 128 |
if lora_urls:
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
download_file_with_progress(lora_urls[0], first_lora_path)
|
| 133 |
-
|
| 134 |
-
global_pipe.load_lora_weights(str(first_lora_path), adapter_name="main_lora")
|
| 135 |
-
global_pipe.fuse_lora(adapter_names=["main_lora"], lora_scale=strengths[0])
|
| 136 |
-
|
| 137 |
-
for i in range(1, len(lora_urls)):
|
| 138 |
-
lora_filename = get_safe_filename_from_url(lora_urls[i], f"lora_{i}.safetensors", suffix="_lora")
|
| 139 |
lora_path = CACHE_DIR / lora_filename
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
global_pipe.fuse_lora(
|
| 147 |
-
adapter_names=["main_lora"] + [f"lora_{j}" for j in range(1, i+1)],
|
| 148 |
-
lora_scale=strengths[i]
|
| 149 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
# Set scheduler and move to device (do this once at the end)
|
| 152 |
yield "βοΈ Finalizing...", "Setting up scheduler..."
|
|
@@ -157,7 +173,7 @@ def load_pipeline(
|
|
| 157 |
return ("β
Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
|
| 158 |
|
| 159 |
except KeyboardInterrupt:
|
| 160 |
-
|
| 161 |
return ("β οΈ Download cancelled by user", "Cancelled")
|
| 162 |
except Exception as e:
|
| 163 |
return (f"β Error loading pipeline: {str(e)}", f"Error: {str(e)}")
|
|
@@ -165,8 +181,7 @@ def load_pipeline(
|
|
| 165 |
|
| 166 |
def cancel_download():
|
| 167 |
"""Set the global cancellation flag to stop any ongoing downloads."""
|
| 168 |
-
|
| 169 |
-
download_cancelled = True
|
| 170 |
|
| 171 |
|
| 172 |
def get_pipeline() -> StableDiffusionXLPipeline | None:
|
|
|
|
| 9 |
DPMSolverSDEScheduler,
|
| 10 |
)
|
| 11 |
|
| 12 |
+
from .config import device, dtype, pipe as global_pipe, CACHE_DIR, device_description
|
| 13 |
+
from .downloader import download_file_with_progress, get_safe_filename_from_url, set_download_cancelled
|
| 14 |
|
| 15 |
|
| 16 |
def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
|
|
|
|
| 74 |
Returns:
|
| 75 |
Tuple of (final_status_message, progress_text)
|
| 76 |
"""
|
| 77 |
+
global global_pipe
|
| 78 |
|
| 79 |
try:
|
| 80 |
+
set_download_cancelled(False)
|
| 81 |
checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
|
| 82 |
checkpoint_path = CACHE_DIR / checkpoint_filename
|
| 83 |
+
|
| 84 |
+
# Check if checkpoint is already cached
|
| 85 |
+
checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0
|
| 86 |
|
| 87 |
# VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
|
| 88 |
vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else "vae.safetensors"
|
| 89 |
vae_path = CACHE_DIR / vae_filename
|
| 90 |
+
vae_cached = vae_url.strip() and vae_path.exists() and vae_path.stat().st_size > 0
|
| 91 |
|
| 92 |
+
# Download checkpoint (skips if already cached)
|
| 93 |
if progress:
|
| 94 |
+
progress(0.1, desc="Downloading base model..." if not checkpoint_cached else "Loading base model...")
|
| 95 |
+
|
| 96 |
+
status_msg = f"π₯ Downloading {checkpoint_path.name}..." if not checkpoint_cached else f"β
Using cached {checkpoint_path.name}"
|
| 97 |
+
yield status_msg, "Starting download..."
|
| 98 |
+
|
| 99 |
+
if not checkpoint_cached:
|
| 100 |
+
download_file_with_progress(checkpoint_url, checkpoint_path)
|
| 101 |
|
| 102 |
# Download VAE if provided
|
| 103 |
if vae_url.strip():
|
| 104 |
+
status_msg = f"π₯ Downloading {vae_path.name}..." if not vae_cached else f"β
Using cached {vae_path.name}"
|
| 105 |
if progress:
|
| 106 |
+
progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
|
| 107 |
+
|
| 108 |
+
yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
|
| 109 |
+
|
| 110 |
+
if not vae_cached:
|
| 111 |
+
download_file_with_progress(vae_url, vae_path)
|
| 112 |
vae = AutoencoderKL.from_single_file(str(vae_path), torch_dtype=dtype)
|
| 113 |
else:
|
| 114 |
vae = None
|
|
|
|
| 139 |
|
| 140 |
# Load and fuse each LoRA sequentially (only if URLs exist)
|
| 141 |
if lora_urls:
|
| 142 |
+
global_pipe = global_pipe.to(device=device, dtype=dtype)
|
| 143 |
+
for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
|
| 144 |
+
lora_filename = get_safe_filename_from_url(lora_url, f"lora_{i}.safetensors", suffix="_lora")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
lora_path = CACHE_DIR / lora_filename
|
| 146 |
+
lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
|
| 147 |
+
|
| 148 |
+
status_msg = (
|
| 149 |
+
f"π₯ Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
|
| 150 |
+
if not lora_cached else
|
| 151 |
+
f"β
Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
|
|
|
|
|
|
|
|
|
|
| 152 |
)
|
| 153 |
+
|
| 154 |
+
yield (
|
| 155 |
+
status_msg,
|
| 156 |
+
f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..." if not lora_cached
|
| 157 |
+
else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})"
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
if not lora_cached:
|
| 161 |
+
download_file_with_progress(lora_url, lora_path)
|
| 162 |
+
adapter_name = f"lora_{i}"
|
| 163 |
+
global_pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
|
| 164 |
+
global_pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
|
| 165 |
+
global_pipe.unload_lora_weights()
|
| 166 |
|
| 167 |
# Set scheduler and move to device (do this once at the end)
|
| 168 |
yield "βοΈ Finalizing...", "Setting up scheduler..."
|
|
|
|
| 173 |
return ("β
Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
|
| 174 |
|
| 175 |
except KeyboardInterrupt:
|
| 176 |
+
set_download_cancelled(False)
|
| 177 |
return ("β οΈ Download cancelled by user", "Cancelled")
|
| 178 |
except Exception as e:
|
| 179 |
return (f"β Error loading pipeline: {str(e)}", f"Error: {str(e)}")
|
|
|
|
| 181 |
|
| 182 |
def cancel_download():
|
| 183 |
"""Set the global cancellation flag to stop any ongoing downloads."""
|
| 184 |
+
set_download_cancelled(True)
|
|
|
|
| 185 |
|
| 186 |
|
| 187 |
def get_pipeline() -> StableDiffusionXLPipeline | None:
|