Kyle Pearson commited on
Commit
8cdb001
Β·
1 Parent(s): 60d66bd
Files changed (4) hide show
  1. app.py +1 -1
  2. src/config.py +8 -2
  3. src/downloader.py +49 -28
  4. src/pipeline.py +23 -13
app.py CHANGED
@@ -372,7 +372,7 @@ def create_app():
372
  progress_text,
373
  gr.update(interactive=True)
374
  )
375
- elif "⚠️" in status_msg:
376
  return (
377
  '<div class="status-warning">⚠️ Download cancelled</div>',
378
  progress_text,
 
372
  progress_text,
373
  gr.update(interactive=True)
374
  )
375
+ elif "⚠️" in status_msg or "cancelled" in status_msg.lower():
376
  return (
377
  '<div class="status-warning">⚠️ Download cancelled</div>',
378
  progress_text,
src/config.py CHANGED
@@ -108,22 +108,28 @@ def get_cached_checkpoints():
108
 
109
 
110
  def get_cached_vaes():
111
- """Get list of cached VAE files (model_id_*_vae.safetensors)."""
112
  if not CACHE_DIR.exists():
113
  return []
114
 
115
  models = []
 
 
 
116
  for file in sorted(CACHE_DIR.glob("*_vae.safetensors")):
117
  models.append(str(file))
118
  return models
119
 
120
 
121
  def get_cached_loras():
122
- """Get list of cached LoRA files (model_id_*_lora.safetensors)."""
123
  if not CACHE_DIR.exists():
124
  return []
125
 
126
  models = []
 
 
 
127
  for file in sorted(CACHE_DIR.glob("*_lora.safetensors")):
128
  models.append(str(file))
129
  return models
 
108
 
109
 
110
  def get_cached_vaes():
111
+ """Get list of cached VAE files (model_id_vae.safetensors or model_id_*_vae.safetensors)."""
112
  if not CACHE_DIR.exists():
113
  return []
114
 
115
  models = []
116
+ # Match both patterns:
117
+ # - model_id_vae.safetensors
118
+ # - model_id_name_vae.safetensors (for backward compatibility)
119
  for file in sorted(CACHE_DIR.glob("*_vae.safetensors")):
120
  models.append(str(file))
121
  return models
122
 
123
 
124
  def get_cached_loras():
125
+ """Get list of cached LoRA files (model_id_lora.safetensors or model_id_*_lora.safetensors)."""
126
  if not CACHE_DIR.exists():
127
  return []
128
 
129
  models = []
130
+ # Match both patterns:
131
+ # - model_id_lora.safetensors
132
+ # - model_id_name_lora.safetensors (for backward compatibility)
133
  for file in sorted(CACHE_DIR.glob("*_lora.safetensors")):
134
  models.append(str(file))
135
  return models
src/downloader.py CHANGED
@@ -27,8 +27,8 @@ def get_safe_filename_from_url(
27
 
28
  Naming patterns:
29
  - Checkpoint (type_prefix='model'): 12345_model.safetensors or 12345_model_anime_style.safetensors
30
- - VAE (suffix='_vae'): 12345_vae.safetensors or 12345_anime_vae.safetensors
31
- - LoRA (suffix='_lora'): 12345_lora.safetensors or 12345_name_lora.safetensors
32
 
33
  For HuggingFace URLs without model IDs, attempts to extract name from path or uses suffix-based naming.
34
 
@@ -39,10 +39,10 @@ def get_safe_filename_from_url(
39
  type_prefix: Optional prefix after model_id (e.g., 'model' -> 12345_model.safetensors)
40
  """
41
  model_id = extract_model_id(url)
42
-
43
  # If no CivitAI model ID, try to generate a name from HuggingFace path
44
  if not model_id and "huggingface.co" in url:
45
- # Try to extract name from URL path (e.g., sdxl-vae-fp16-fix -> vae)
46
  try:
47
  parts = url.split("huggingface.co/")[1] if "huggingface.co/" in url else ""
48
  if parts:
@@ -56,17 +56,35 @@ def get_safe_filename_from_url(
56
  model_id = f"hf_{clean_repo}"
57
  except Exception:
58
  pass
59
-
60
  if not model_id:
61
  return default_name
62
 
63
- # Build the name portion: either clean name from URL or fallback
64
- name_part = ""
65
-
66
- # For VAE/LoRA types, prefer the suffix-based naming and skip Content-Disposition parsing
67
- # to avoid double naming (e.g., sdxlvae_vae instead of just vae)
68
  is_special_type = suffix in ("_vae", "_lora")
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  if not is_special_type:
71
  try:
72
  response = requests.head(url, timeout=10, allow_redirects=True)
@@ -90,19 +108,22 @@ def get_safe_filename_from_url(
90
  parts.append(type_prefix)
91
  if name_part:
92
  parts.append(name_part)
 
 
93
  if suffix:
94
- # Avoid double underscores: only add separator if needed
95
- if not suffix.startswith('_'):
96
- parts.append('_' + suffix.lstrip('_'))
97
  else:
98
- parts.append(suffix)
 
99
 
100
  return '_'.join(p for p in parts if p).replace('__', '_') + '.safetensors'
101
 
102
 
103
  class TqdmGradio(TqdmBase):
104
  """tqdm subclass that sends progress updates to Gradio's gr.Progress()"""
105
-
106
  def __init__(self, *args, gradio_prog=None, **kwargs):
107
  super().__init__(*args, **kwargs)
108
  self.gradio_prog = gradio_prog
@@ -130,7 +151,7 @@ def set_download_cancelled(value: bool):
130
  def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]:
131
  """
132
  Check if file exists in cache and matches expected size.
133
-
134
  Uses the same filename generation logic as download operations to find
135
  cached files by URL.
136
 
@@ -144,21 +165,21 @@ def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = N
144
  or (None, None) if no valid cache found.
145
  """
146
  from .config import CACHE_DIR
147
-
148
  # Generate the expected filename for this URL
149
  default_name = "vae.safetensors" if suffix == "_vae" else (
150
  "lora.safetensors" if suffix == "_lora" else "model.safetensors"
151
  )
152
-
153
  cached_filename = get_safe_filename_from_url(
154
- url,
155
  default_name=default_name,
156
  suffix=suffix,
157
  type_prefix=type_prefix
158
  )
159
-
160
  cached_path = CACHE_DIR / cached_filename
161
-
162
  if cached_path.exists() and cached_path.is_file():
163
  try:
164
  file_size = cached_path.stat().st_size
@@ -167,14 +188,14 @@ def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = N
167
  return cached_path, file_size
168
  except OSError:
169
  pass
170
-
171
  return None, None
172
 
173
 
174
  def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path:
175
  """
176
  Download a file with Gradio-synced progress bar + cancel support.
177
-
178
  Checks for existing cached files before downloading. If a valid cache
179
  exists (file exists with matching expected size), skips re-download.
180
 
@@ -215,7 +236,7 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
215
  expected_size = int(head.headers.get('content-length', 0))
216
  except Exception:
217
  pass # Skip header fetch on errors
218
-
219
  if output_path.exists() and expected_size is not None:
220
  try:
221
  cached_size = output_path.stat().st_size
@@ -267,7 +288,7 @@ def download_file_with_progress(url: str, output_path: Path, progress_bar=None)
267
  def clear_cache(cache_dir: Path = None, keep_extensions: list[str] = None):
268
  """
269
  Remove old cache files.
270
-
271
  Args:
272
  cache_dir: Cache directory path (defaults to config.CACHE_DIR)
273
  keep_extensions: File extensions to preserve (default: ['.safetensors'])
@@ -275,14 +296,14 @@ def clear_cache(cache_dir: Path = None, keep_extensions: list[str] = None):
275
  if cache_dir is None:
276
  from .config import CACHE_DIR
277
  cache_dir = CACHE_DIR
278
-
279
  if keep_extensions is None:
280
  keep_extensions = ['.safetensors']
281
-
282
  # Remove temp files
283
  for file in cache_dir.glob("*.tmp"):
284
  file.unlink()
285
-
286
  # Optional: age-based cleanup (7 days)
287
  # import time
288
  # cutoff = time.time() - 86400 * 7
 
27
 
28
  Naming patterns:
29
  - Checkpoint (type_prefix='model'): 12345_model.safetensors or 12345_model_anime_style.safetensors
30
+ - VAE (suffix='_vae'): 12345_vae.safetensors (no name extraction to avoid double suffix)
31
+ - LoRA (suffix='_lora'): 12345_lora.safetensors (no name extraction to avoid double suffix)
32
 
33
  For HuggingFace URLs without model IDs, attempts to extract name from path or uses suffix-based naming.
34
 
 
39
  type_prefix: Optional prefix after model_id (e.g., 'model' -> 12345_model.safetensors)
40
  """
41
  model_id = extract_model_id(url)
42
+
43
  # If no CivitAI model ID, try to generate a name from HuggingFace path
44
  if not model_id and "huggingface.co" in url:
45
+ # Try to extract name from URL path (e.g., sdxl-vae-fp16-fix -> fp16_fix)
46
  try:
47
  parts = url.split("huggingface.co/")[1] if "huggingface.co/" in url else ""
48
  if parts:
 
56
  model_id = f"hf_{clean_repo}"
57
  except Exception:
58
  pass
59
+
60
  if not model_id:
61
  return default_name
62
 
63
+ # Special handling for VAE/LoRA with HuggingFace URLs to avoid double suffix
 
 
 
 
64
  is_special_type = suffix in ("_vae", "_lora")
65
 
66
+ # Strip common suffixes from model_id when adding corresponding suffix
67
+ # (e.g., "sdxl_vae_fp16_fix" + "_vae" -> "sdxl_fp16_fix" + "_vae")
68
+ if is_special_type:
69
+ strip_suffix = suffix.lstrip('_') # "vae" or "lora"
70
+ model_id_lower = model_id.lower()
71
+ # Check if model_id contains the type (with underscore boundaries)
72
+ if f"_{strip_suffix}_" in model_id_lower or model_id_lower.endswith(f"_{strip_suffix}"):
73
+ # Remove the suffix from model_id
74
+ if model_id_lower.endswith(f"_{strip_suffix}"):
75
+ model_id = model_id[:-len(strip_suffix)-1]
76
+ else:
77
+ # Find and remove _suffix_ pattern
78
+ pattern = f"_{strip_suffix}_"
79
+ idx = model_id_lower.find(pattern)
80
+ if idx >= 0:
81
+ model_id = model_id[:idx] + model_id[idx+len(pattern):]
82
+
83
+ # Build the name portion: either clean name from URL or fallback
84
+ name_part = ""
85
+
86
+ # For VAE/LoRA types, skip Content-Disposition parsing to avoid double naming
87
+ # (e.g., sdxl_vae_vae instead of just vae)
88
  if not is_special_type:
89
  try:
90
  response = requests.head(url, timeout=10, allow_redirects=True)
 
108
  parts.append(type_prefix)
109
  if name_part:
110
  parts.append(name_part)
111
+
112
+ # Handle suffix - for VAE/LoRA we only add the suffix, not double naming
113
  if suffix:
114
+ if is_special_type:
115
+ # For _vae and _lora: just use model_id + suffix directly
116
+ return f"{model_id}{suffix}.safetensors"
117
  else:
118
+ # For other types (checkpoint), append suffix after name_part
119
+ parts.append(suffix.lstrip('_'))
120
 
121
  return '_'.join(p for p in parts if p).replace('__', '_') + '.safetensors'
122
 
123
 
124
  class TqdmGradio(TqdmBase):
125
  """tqdm subclass that sends progress updates to Gradio's gr.Progress()"""
126
+
127
  def __init__(self, *args, gradio_prog=None, **kwargs):
128
  super().__init__(*args, **kwargs)
129
  self.gradio_prog = gradio_prog
 
151
  def get_cached_file_size(url: str, suffix: str = "", type_prefix: str | None = None) -> tuple[Path | None, int | None]:
152
  """
153
  Check if file exists in cache and matches expected size.
154
+
155
  Uses the same filename generation logic as download operations to find
156
  cached files by URL.
157
 
 
165
  or (None, None) if no valid cache found.
166
  """
167
  from .config import CACHE_DIR
168
+
169
  # Generate the expected filename for this URL
170
  default_name = "vae.safetensors" if suffix == "_vae" else (
171
  "lora.safetensors" if suffix == "_lora" else "model.safetensors"
172
  )
173
+
174
  cached_filename = get_safe_filename_from_url(
175
+ url,
176
  default_name=default_name,
177
  suffix=suffix,
178
  type_prefix=type_prefix
179
  )
180
+
181
  cached_path = CACHE_DIR / cached_filename
182
+
183
  if cached_path.exists() and cached_path.is_file():
184
  try:
185
  file_size = cached_path.stat().st_size
 
188
  return cached_path, file_size
189
  except OSError:
190
  pass
191
+
192
  return None, None
193
 
194
 
195
  def download_file_with_progress(url: str, output_path: Path, progress_bar=None) -> Path:
196
  """
197
  Download a file with Gradio-synced progress bar + cancel support.
198
+
199
  Checks for existing cached files before downloading. If a valid cache
200
  exists (file exists with matching expected size), skips re-download.
201
 
 
236
  expected_size = int(head.headers.get('content-length', 0))
237
  except Exception:
238
  pass # Skip header fetch on errors
239
+
240
  if output_path.exists() and expected_size is not None:
241
  try:
242
  cached_size = output_path.stat().st_size
 
288
  def clear_cache(cache_dir: Path = None, keep_extensions: list[str] = None):
289
  """
290
  Remove old cache files.
291
+
292
  Args:
293
  cache_dir: Cache directory path (defaults to config.CACHE_DIR)
294
  keep_extensions: File extensions to preserve (default: ['.safetensors'])
 
296
  if cache_dir is None:
297
  from .config import CACHE_DIR
298
  cache_dir = CACHE_DIR
299
+
300
  if keep_extensions is None:
301
  keep_extensions = ['.safetensors']
302
+
303
  # Remove temp files
304
  for file in cache_dir.glob("*.tmp"):
305
  file.unlink()
306
+
307
  # Optional: age-based cleanup (7 days)
308
  # import time
309
  # cutoff = time.time() - 86400 * 7
src/pipeline.py CHANGED
@@ -83,7 +83,7 @@ def load_pipeline(
83
 
84
  # Check if checkpoint is already cached
85
  checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0
86
-
87
  # VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
88
  vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else "vae.safetensors"
89
  vae_path = CACHE_DIR / vae_filename
@@ -92,10 +92,10 @@ def load_pipeline(
92
  # Download checkpoint (skips if already cached)
93
  if progress:
94
  progress(0.1, desc="Downloading base model..." if not checkpoint_cached else "Loading base model...")
95
-
96
  status_msg = f"πŸ“₯ Downloading {checkpoint_path.name}..." if not checkpoint_cached else f"βœ… Using cached {checkpoint_path.name}"
97
  yield status_msg, "Starting download..."
98
-
99
  if not checkpoint_cached:
100
  download_file_with_progress(checkpoint_url, checkpoint_path)
101
 
@@ -104,25 +104,27 @@ def load_pipeline(
104
  status_msg = f"πŸ“₯ Downloading {vae_path.name}..." if not vae_cached else f"βœ… Using cached {vae_path.name}"
105
  if progress:
106
  progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
107
-
108
  yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
109
-
110
  if not vae_cached:
111
  download_file_with_progress(vae_url, vae_path)
112
- vae = AutoencoderKL.from_single_file(str(vae_path), torch_dtype=dtype)
113
- else:
114
- vae = None
115
 
116
- # Load base pipeline
 
117
  if progress:
118
  progress(0.4, desc="Loading SDXL pipeline...")
119
- yield f"βš™οΈ Loading pipeline...", f"Using device: {device_description}"
120
  global_pipe = StableDiffusionXLPipeline.from_single_file(
121
  str(checkpoint_path),
122
  torch_dtype=dtype,
123
  use_safetensors=True,
124
  safety_checker=None,
125
  )
 
 
 
 
126
  if vae:
127
  global_pipe.vae = vae.to(device=device, dtype=dtype)
128
 
@@ -141,7 +143,7 @@ def load_pipeline(
141
  if lora_urls:
142
  global_pipe = global_pipe.to(device=device, dtype=dtype)
143
  for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
144
- lora_filename = get_safe_filename_from_url(lora_url, f"lora_{i}.safetensors", suffix="_lora")
145
  lora_path = CACHE_DIR / lora_filename
146
  lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
147
 
@@ -159,17 +161,25 @@ def load_pipeline(
159
 
160
  if not lora_cached:
161
  download_file_with_progress(lora_url, lora_path)
 
 
 
 
 
162
  adapter_name = f"lora_{i}"
163
  global_pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
164
  global_pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
165
  global_pipe.unload_lora_weights()
166
 
167
  # Set scheduler and move to device (do this once at the end)
168
- yield "βš™οΈ Finalizing...", "Setting up scheduler..."
169
- # Use existing scheduler, just update algorithm_type for DPM++ SDE
 
 
170
  global_pipe.scheduler.config.algorithm_type = "sde-dpmsolver++"
171
  global_pipe = global_pipe.to(device=device, dtype=dtype)
172
 
 
173
  return ("βœ… Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
174
 
175
  except KeyboardInterrupt:
 
83
 
84
  # Check if checkpoint is already cached
85
  checkpoint_cached = checkpoint_path.exists() and checkpoint_path.stat().st_size > 0
86
+
87
  # VAE: Use suffix="_vae" and default to "vae.safetensors" for proper caching/dropdown matching
88
  vae_filename = get_safe_filename_from_url(vae_url, default_name="vae.safetensors", suffix="_vae") if vae_url.strip() else "vae.safetensors"
89
  vae_path = CACHE_DIR / vae_filename
 
92
  # Download checkpoint (skips if already cached)
93
  if progress:
94
  progress(0.1, desc="Downloading base model..." if not checkpoint_cached else "Loading base model...")
95
+
96
  status_msg = f"πŸ“₯ Downloading {checkpoint_path.name}..." if not checkpoint_cached else f"βœ… Using cached {checkpoint_path.name}"
97
  yield status_msg, "Starting download..."
98
+
99
  if not checkpoint_cached:
100
  download_file_with_progress(checkpoint_url, checkpoint_path)
101
 
 
104
  status_msg = f"πŸ“₯ Downloading {vae_path.name}..." if not vae_cached else f"βœ… Using cached {vae_path.name}"
105
  if progress:
106
  progress(0.2, desc="Downloading VAE..." if not vae_cached else "Loading VAE...")
107
+
108
  yield status_msg, f"Downloading VAE: {vae_path.name}" if not vae_cached else f"Using cached VAE: {vae_path.name}"
109
+
110
  if not vae_cached:
111
  download_file_with_progress(vae_url, vae_path)
 
 
 
112
 
113
+ # Load base pipeline (yield progress during this heavy operation)
114
+ yield "βš™οΈ Loading SDXL pipeline...", "Loading model weights into memory..."
115
  if progress:
116
  progress(0.4, desc="Loading SDXL pipeline...")
117
+
118
  global_pipe = StableDiffusionXLPipeline.from_single_file(
119
  str(checkpoint_path),
120
  torch_dtype=dtype,
121
  use_safetensors=True,
122
  safety_checker=None,
123
  )
124
+ yield "βš™οΈ Pipeline loaded, setting up VAE...", f"Using device: {device_description}"
125
+ if progress:
126
+ progress(0.6, desc="Setting up VAE...")
127
+
128
  if vae:
129
  global_pipe.vae = vae.to(device=device, dtype=dtype)
130
 
 
143
  if lora_urls:
144
  global_pipe = global_pipe.to(device=device, dtype=dtype)
145
  for i, (lora_url, strength) in enumerate(zip(lora_urls, strengths)):
146
+ lora_filename = get_safe_filename_from_url(lora_url, suffix="_lora")
147
  lora_path = CACHE_DIR / lora_filename
148
  lora_cached = lora_path.exists() and lora_path.stat().st_size > 0
149
 
 
161
 
162
  if not lora_cached:
163
  download_file_with_progress(lora_url, lora_path)
164
+
165
+ yield f"βš™οΈ Loading LoRA {i+1}/{len(lora_urls)}...", f"Fusing {lora_path.name}..."
166
+ if progress:
167
+ progress(0.7 + (0.2 * i / len(lora_urls)), desc=f"Loading LoRA {i+1}/{len(lora_urls)}...")
168
+
169
  adapter_name = f"lora_{i}"
170
  global_pipe.load_lora_weights(str(lora_path), adapter_name=adapter_name)
171
  global_pipe.fuse_lora(adapter_names=[adapter_name], lora_scale=strength)
172
  global_pipe.unload_lora_weights()
173
 
174
  # Set scheduler and move to device (do this once at the end)
175
+ yield "βš™οΈ Finalizing pipeline...", "Setting up scheduler and moving to device..."
176
+ if progress:
177
+ progress(0.9, desc="Finalizing...")
178
+
179
  global_pipe.scheduler.config.algorithm_type = "sde-dpmsolver++"
180
  global_pipe = global_pipe.to(device=device, dtype=dtype)
181
 
182
+ yield "βœ… Pipeline ready!", f"Ready! Loaded {len(lora_urls)} LoRA(s)"
183
  return ("βœ… Pipeline loaded successfully!", f"Ready! Loaded {len(lora_urls)} LoRA(s)")
184
 
185
  except KeyboardInterrupt: