Kyle Pearson commited on
Commit
60d66bd
Β·
1 Parent(s): 6a07ce1

Update Gradio to 6.9.0, add cached file checks in LoRA loading, fix unloading errors, optimize download cancellation, improve device/dtype handling, update optimum dependency.

Browse files
Files changed (5) hide show
  1. app.py +17 -8
  2. requirements.txt +2 -1
  3. src/downloader.py +56 -13
  4. src/exporter.py +4 -1
  5. src/pipeline.py +47 -32
app.py CHANGED
@@ -83,7 +83,7 @@ def create_app():
83
  from src.exporter import export_merged_model
84
  from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras
85
 
86
- with gr.Blocks(title="SDXL Model Merger", css=header_css) as demo:
87
  # Header section
88
  with gr.Column(elem_classes=["feature-card"]):
89
  gr.HTML("""
@@ -393,7 +393,18 @@ def create_app():
393
  fn=load_pipeline,
394
  inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
395
  outputs=[load_status, load_progress],
396
- show_api=False,
 
 
 
 
 
 
 
 
 
 
 
397
  )
398
 
399
  def on_cached_checkpoint_change(cached_path):
@@ -423,10 +434,10 @@ def create_app():
423
  def on_cached_lora_change(cached_path, current_urls):
424
  """Add cached LoRA to the list."""
425
  if cached_path and cached_path != "(None found)":
426
- # Add new LoRA to existing URLs (avoid duplicate)
427
  urls_list = [u.strip() for u in current_urls.split("\n") if u.strip()]
428
- if cached_path not in urls_list:
429
- urls_list.append(cached_path)
 
430
  return gr.update(value="\n".join(urls_list))
431
  return gr.update()
432
 
@@ -470,7 +481,6 @@ def create_app():
470
  fn=generate_image,
471
  inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
472
  outputs=[image_output, gen_progress],
473
- show_api=False,
474
  ).then(
475
  fn=lambda img, msg: on_generate_complete(msg, "Done", img),
476
  inputs=[image_output, gen_progress],
@@ -515,7 +525,6 @@ def create_app():
515
  ),
516
  inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
517
  outputs=[download_link, export_progress],
518
- show_api=False,
519
  ).then(
520
  fn=lambda path, msg: on_export_complete(msg, "Exported", path),
521
  inputs=[download_link, export_progress],
@@ -534,7 +543,7 @@ def create_app():
534
  def main():
535
  """Create and launch the Gradio app."""
536
  app = create_app()
537
- # CSS is embedded in the Blocks, so we pass it to launch for Gradio 6+
538
  app.launch()
539
 
540
 
 
83
  from src.exporter import export_merged_model
84
  from src.config import get_cached_models, get_cached_checkpoints, get_cached_vaes, get_cached_loras
85
 
86
+ with gr.Blocks(title="SDXL Model Merger") as demo:
87
  # Header section
88
  with gr.Column(elem_classes=["feature-card"]):
89
  gr.HTML("""
 
393
  fn=load_pipeline,
394
  inputs=[checkpoint_url, vae_url, lora_urls, lora_strengths],
395
  outputs=[load_status, load_progress],
396
+ ).then(
397
+ fn=on_load_pipeline_complete,
398
+ inputs=[load_status, load_progress],
399
+ outputs=[load_status, load_progress, load_btn],
400
+ ).then(
401
+ fn=lambda: (
402
+ gr.update(choices=["(None found)"] + get_cached_checkpoints()),
403
+ gr.update(choices=["(None found)"] + get_cached_vaes()),
404
+ gr.update(choices=["(None found)"] + get_cached_loras()),
405
+ ),
406
+ inputs=[],
407
+ outputs=[cached_checkpoints, cached_vaes, cached_loras],
408
  )
409
 
410
  def on_cached_checkpoint_change(cached_path):
 
434
  def on_cached_lora_change(cached_path, current_urls):
435
  """Add cached LoRA to the list."""
436
  if cached_path and cached_path != "(None found)":
 
437
  urls_list = [u.strip() for u in current_urls.split("\n") if u.strip()]
438
+ file_url = f"file://{cached_path}"
439
+ if file_url not in urls_list:
440
+ urls_list.append(file_url)
441
  return gr.update(value="\n".join(urls_list))
442
  return gr.update()
443
 
 
481
  fn=generate_image,
482
  inputs=[prompt, negative_prompt, cfg, steps, height, width, tile_x, tile_y],
483
  outputs=[image_output, gen_progress],
 
484
  ).then(
485
  fn=lambda img, msg: on_generate_complete(msg, "Done", img),
486
  inputs=[image_output, gen_progress],
 
525
  ),
526
  inputs=[include_lora, quantize_toggle, qtype_dropdown, format_dropdown],
527
  outputs=[download_link, export_progress],
 
528
  ).then(
529
  fn=lambda path, msg: on_export_complete(msg, "Exported", path),
530
  inputs=[download_link, export_progress],
 
543
  def main():
544
  """Create and launch the Gradio app."""
545
  app = create_app()
546
+ # CSS is passed to launch() in Gradio 6+
547
  app.launch()
548
 
549
 
requirements.txt CHANGED
@@ -5,12 +5,13 @@ torch>=2.0.0
5
  diffusers>=0.24.0
6
  transformers>=4.35.0
7
  safetensors>=0.4.0
 
8
 
9
  # Image processing
10
  Pillow>=10.0.0
11
 
12
  # UI framework
13
- gradio>=4.0.0
14
 
15
  # Download utilities
16
  tqdm>=4.65.0
 
5
  diffusers>=0.24.0
6
  transformers>=4.35.0
7
  safetensors>=0.4.0
8
+ optimum>=1.0.0
9
 
10
  # Image processing
11
  Pillow>=10.0.0
12
 
13
  # UI framework
14
+ gradio>=6.9.0
15
 
16
  # Download utilities
17
  tqdm>=4.65.0
src/downloader.py CHANGED
@@ -127,18 +127,56 @@ def set_download_cancelled(value: bool):
127
  download_cancelled = value
128
 
129
 
130
- def get_cached_file_size(url: str) -> tuple[Path | None, int | None]:
131
  """
132
  Check if file exists in cache and matches expected size.
133
- Returns (path, expected_size) or (None, None) if no valid cache.
 
 
 
 
 
 
 
 
 
 
 
134
  """
135
- # Simple implementation - would need URL-to-filename mapping for production
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  return None, None
137
 
138
 
139
  def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path:
140
  """
141
  Download a file with Gradio-synced progress bar + cancel support.
 
 
 
142
 
143
  Args:
144
  url: File URL to download (http/https/file)
@@ -146,14 +184,13 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
146
  progress_bar: Optional gr.Progress() object for UI updates
147
 
148
  Returns:
149
- Path to the downloaded file
150
 
151
  Raises:
152
  KeyboardInterrupt: If download is cancelled
153
  requests.RequestException: If download fails
154
  """
155
  global download_cancelled
156
- download_cancelled = False
157
 
158
  # Handle local file:// URLs
159
  if url.startswith("file://"):
@@ -163,7 +200,7 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
163
  output_path.parent.mkdir(parents=True, exist_ok=True)
164
  # Copy the file to cache location
165
  shutil.copy2(str(local_path), str(output_path))
166
-
167
  # Update progress bar for cached files
168
  if progress_bar:
169
  progress_bar(1.0)
@@ -171,18 +208,24 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
171
  else:
172
  raise FileNotFoundError(f"Local file not found: {local_path}")
173
 
174
- # Cache check: if file exists and size matches URL's content-length, skip re-download
175
  expected_size = None
176
  try:
177
  head = requests.head(url, timeout=10)
178
  expected_size = int(head.headers.get('content-length', 0))
179
- if output_path.exists() and output_path.stat().st_size == expected_size:
180
- # Cache hit - still update progress to show completion
181
- if progress_bar:
182
- progress_bar(1.0)
183
- return output_path # Cache hit!
184
  except Exception:
185
- pass # Skip cache validation on errors
 
 
 
 
 
 
 
 
 
 
 
186
 
187
  output_path.parent.mkdir(parents=True, exist_ok=True)
188
 
 
127
  download_cancelled = value
128
 
129
 
130
+ def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]:
131
  """
132
  Check if file exists in cache and matches expected size.
133
+
134
+ Uses the same filename generation logic as download operations to find
135
+ cached files by URL.
136
+
137
+ Args:
138
+ url: The download URL to check for cached file
139
+ suffix: Optional suffix (e.g., '_vae', '_lora') for special file types
140
+ type_prefix: Optional prefix after model_id (e.g., 'model')
141
+
142
+ Returns:
143
+ Tuple of (cached_file_path, file_size) if valid cache exists,
144
+ or (None, None) if no valid cache found.
145
  """
146
+ from .config import CACHE_DIR
147
+
148
+ # Generate the expected filename for this URL
149
+ default_name = "vae.safetensors" if suffix == "_vae" else (
150
+ "lora.safetensors" if suffix == "_lora" else "model.safetensors"
151
+ )
152
+
153
+ cached_filename = get_safe_filename_from_url(
154
+ url,
155
+ default_name=default_name,
156
+ suffix=suffix,
157
+ type_prefix=type_prefix
158
+ )
159
+
160
+ cached_path = CACHE_DIR / cached_filename
161
+
162
+ if cached_path.exists() and cached_path.is_file():
163
+ try:
164
+ file_size = cached_path.stat().st_size
165
+ # Only return valid cache if file has content
166
+ if file_size > 0:
167
+ return cached_path, file_size
168
+ except OSError:
169
+ pass
170
+
171
  return None, None
172
 
173
 
174
  def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path:
175
  """
176
  Download a file with Gradio-synced progress bar + cancel support.
177
+
178
+ Checks for existing cached files before downloading. If a valid cache
179
+ exists (file exists with matching expected size), skips re-download.
180
 
181
  Args:
182
  url: File URL to download (http/https/file)
 
184
  progress_bar: Optional gr.Progress() object for UI updates
185
 
186
  Returns:
187
+ Path to the downloaded (or cached) file
188
 
189
  Raises:
190
  KeyboardInterrupt: If download is cancelled
191
  requests.RequestException: If download fails
192
  """
193
  global download_cancelled
 
194
 
195
  # Handle local file:// URLs
196
  if url.startswith("file://"):
 
200
  output_path.parent.mkdir(parents=True, exist_ok=True)
201
  # Copy the file to cache location
202
  shutil.copy2(str(local_path), str(output_path))
203
+
204
  # Update progress bar for cached files
205
  if progress_bar:
206
  progress_bar(1.0)
 
208
  else:
209
  raise FileNotFoundError(f"Local file not found: {local_path}")
210
 
211
+ # Early cache check: if file exists and size matches URL's content-length, skip re-download
212
  expected_size = None
213
  try:
214
  head = requests.head(url, timeout=10)
215
  expected_size = int(head.headers.get('content-length', 0))
 
 
 
 
 
216
  except Exception:
217
+ pass # Skip header fetch on errors
218
+
219
+ if output_path.exists() and expected_size is not None:
220
+ try:
221
+ cached_size = output_path.stat().st_size
222
+ if cached_size == expected_size:
223
+ # Cache hit - file exists with correct size
224
+ if progress_bar:
225
+ progress_bar(1.0)
226
+ return output_path # Skip re-download!
227
+ except OSError:
228
+ pass # File access error, proceed with download
229
 
230
  output_path.parent.mkdir(parents=True, exist_ok=True)
231
 
src/exporter.py CHANGED
@@ -34,7 +34,10 @@ def export_merged_model(
34
  # Step 1: Unload LoRAs
35
  yield "πŸ’Ύ Exporting model...", "Unloading LoRAs..."
36
  if include_lora:
37
- global_pipe.unload_lora_weights()
 
 
 
38
 
39
  merged_state_dict = {}
40
 
 
34
  # Step 1: Unload LoRAs
35
  yield "πŸ’Ύ Exporting model...", "Unloading LoRAs..."
36
  if include_lora:
37
+ try:
38
+ global_pipe.unload_lora_weights()
39
+ except Exception:
40
+ pass
41
 
42
  merged_state_dict = {}
43
 
src/pipeline.py CHANGED
@@ -9,8 +9,8 @@ from diffusers import (
9
  DPMSolverSDEScheduler,
10
  )
11
 
12
- from .config import device, dtype, pipe as global_pipe, CACHE_DIR, download_cancelled, device_description
13
- from .downloader import download_file_with_progress, get_safe_filename_from_url
14
 
15
 
16
  def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
@@ -74,28 +74,41 @@ def load_pipeline(
74
  Returns:
75
  Tuple of (final_status_message, progress_text)
76
  """
77
- global global_pipe, download_cancelled
78
 
79
  try:
 
80
  checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
81
  checkpoint_path = CACHE_DIR / checkpoint_filename
 
 
 
82
 
83
  # VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
84
  vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else "vae.safetensors"
85
  vae_path = CACHE_DIR / vae_filename
 
86
 
87
- # Download checkpoint
88
  if progress:
89
- progress(0.1, desc="Downloading base model...")
90
- yield f"πŸ“₯ Downloading {checkpoint_path.name}...", "Starting download..."
91
- download_file_with_progress(checkpoint_url, checkpoint_path)
 
 
 
 
92
 
93
  # Download VAE if provided
94
  if vae_url.strip():
 
95
  if progress:
96
- progress(0.2, desc="Downloading VAE...")
97
- yield f"πŸ“₯ Downloading {vae_path.name}...", f"Downloading VAE: {vae_path.name}"
98
- download_file_with_progress(vae_url, vae_path)
 
 
 
99
  vae = AutoencoderKL.from_single_file(str(vae_path), torch_dtype=dtype)
100
  else:
101
  vae = None
@@ -126,27 +139,30 @@ def load_pipeline(
126
 
127
  # Load and fuse each LoRA sequentially (only if URLs exist)
128
  if lora_urls:
129
- first_lora_filename = get_safe_filename_from_url(lora_urls[0], "lora_0.safetensors", suffix="_lora")
130
- first_lora_path = CACHE_DIR / first_lora_filename
131
- yield f"πŸ“₯ Downloading LoRA: {first_lora_path.name}...", f"Downloading LoRA 1/... ({first_lora_path.name})..."
132
- download_file_with_progress(lora_urls[0], first_lora_path)
133
-
134
- global_pipe.load_lora_weights(str(first_lora_path), adapter_name="main_lora")
135
- global_pipe.fuse_lora(adapter_names=["main_lora"], lora_scale=strengths[0])
136
-
137
- for i in range(1, len(lora_urls)):
138
- lora_filename = get_safe_filename_from_url(lora_urls[i], f"lora_{i}.safetensors", suffix="_lora")
139
  lora_path = CACHE_DIR / lora_filename
140
- yield f"πŸ“₯ Downloading LoRA {i+1}...", f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..."
141
- download_file_with_progress(lora_urls[i], lora_path)
142
-
143
- global_pipe.unload_lora_weights()
144
- global_pipe.load_lora_weights(str(lora_path), adapter_name=f"lora_{i}")
145
- # Fuse all loaded adapters so far
146
- global_pipe.fuse_lora(
147
- adapter_names=["main_lora"] + [f"lora_{j}" for j in range(1, i+1)],
148
- lora_scale=strengths[i]
149
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  # Set scheduler and move to device (do this once at the end)
152
  yield "βš™οΈ Finalizing...", "Setting up scheduler..."
@@ -157,7 +173,7 @@ def load_pipeline(
157
  return ("βœ… Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
158
 
159
  except KeyboardInterrupt:
160
- download_cancelled = False
161
  return ("⚠️ Download cancelled by user", "Cancelled")
162
  except Exception as e:
163
  return (f"❌ Error loading pipeline: {str(e)}", f"Error: {str(e)}")
@@ -165,8 +181,7 @@ def load_pipeline(
165
 
166
  def cancel_download():
167
  """Set the global cancellation flag to stop any ongoing downloads."""
168
- global download_cancelled
169
- download_cancelled = True
170
 
171
 
172
  def get_pipeline() -> StableDiffusionXLPipeline | None:
 
9
  DPMSolverSDEScheduler,
10
  )
11
 
12
+ from .config import device, dtype, pipe as global_pipe, CACHE_DIR, device_description
13
+ from .downloader import download_file_with_progress, get_safe_filename_from_url, set_download_cancelled
14
 
15
 
16
  def _make_asymmetric_forward(module, pad_h: int, pad_w: int, tile_x: bool, tile_y: bool):
 
74
  Returns:
75
  Tuple of (final_status_message, progress_text)
76
  """
77
+ global global_pipe
78
 
79
  try:
80
+ set_download_cancelled(False)
81
  checkpoint_filename = get_safe_filename_from_url(checkpoint_url, type_prefix="model")
82
  checkpoint_path = CACHE_DIR / checkpoint_filename
83
+
84
+ # Check if checkpoint is already cached
85
+ checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0
86
 
87
  # VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
88
  vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else "vae.safetensors"
89
  vae_path = CACHE_DIR / vae_filename
90
+ vae_cached = vae_url.strip() and vae_path.exists() and vae_path.stat().st_size > 0
91
 
92
+ # Download checkpoint (skips if already cached)
93
  if progress:
94
+ progress(0.1, desc="Downloading base model..." if not checkpoint_cached else "Loading base model...")
95
+
96
+ status_msg = f"πŸ“₯ Downloading {checkpoint_path.name}..." if not checkpoint_cached else f"βœ… Using cached {checkpoint_path.name}"
97
+ yield status_msg, "Starting download..."
98
+
99
+ if not checkpoint_cached:
100
+ download_file_with_progress(checkpoint_url, checkpoint_path)
101
 
102
  # Download VAE if provided
103
  if vae_url.strip():
104
+ status_msg = f"πŸ“₯ Downloading {vae_path.name}..." if not vae_cached else f"βœ… Using cached {vae_path.name}"
105
  if progress:
106
+ progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
107
+
108
+ yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
109
+
110
+ if not vae_cached:
111
+ download_file_with_progress(vae_url, vae_path)
112
  vae = AutoencoderKL.from_single_file(str(vae_path), torch_dtype=dtype)
113
  else:
114
  vae = None
 
139
 
140
  # Load and fuse each LoRA sequentially (only if URLs exist)
141
  if lora_urls:
142
+ global_pipe = global_pipe.to(device=device, dtype=dtype)
143
+ for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
144
+ lora_filename = get_safe_filename_from_url(lora_url, f"lora_{i}.safetensors", suffix="_lora")
 
 
 
 
 
 
 
145
  lora_path = CACHE_DIR / lora_filename
146
+ lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
147
+
148
+ status_msg = (
149
+ f"πŸ“₯ Downloading LoRA {i+1}/{len(lora_urls)}: {lora_path.name}..."
150
+ if not lora_cached else
151
+ f"βœ… Using cached LoRA {i+1}/{len(lora_urls)}: {lora_path.name}"
 
 
 
152
  )
153
+
154
+ yield (
155
+ status_msg,
156
+ f"Downloading LoRA {i+1}/{len(lora_urls)} ({lora_path.name})..." if not lora_cached
157
+ else f"Using cached LoRA {i+1}/{len(lora_urls)} ({lora_path.name})"
158
+ )
159
+
160
+ if not lora_cached:
161
+ download_file_with_progress(lora_url, lora_path)
162
+ adapter_name = f"lora_{i}"
163
+ global_pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
164
+ global_pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
165
+ global_pipe.unload_lora_weights()
166
 
167
  # Set scheduler and move to device (do this once at the end)
168
  yield "βš™οΈ Finalizing...", "Setting up scheduler..."
 
173
  return ("βœ… Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
174
 
175
  except KeyboardInterrupt:
176
+ set_download_cancelled(False)
177
  return ("⚠️ Download cancelled by user", "Cancelled")
178
  except Exception as e:
179
  return (f"❌ Error loading pipeline: {str(e)}", f"Error: {str(e)}")
 
181
 
182
  def cancel_download():
183
  """Set the global cancellation flag to stop any ongoing downloads."""
184
+ set_download_cancelled(True)
 
185
 
186
 
187
  def get_pipeline() -> StableDiffusionXLPipeline | None: