Duskfallcrew commited on
Commit
f487905
·
verified ·
1 Parent(s): bff230f

Upload 2 files

Browse files

A lot more changes to make the app more robust; see https://github.com/Ktiseos-Nyx/Gradio-SDXL-Diffusers/tree/v2_deepseek

Files changed (2) hide show
  1. app.py +104 -790
  2. requirements.txt +2 -24
app.py CHANGED
@@ -1,10 +1,16 @@
 
 
1
  import os
2
  import gradio as gr
3
  import torch
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
5
- from transformers import CLIPTextModel, CLIPTextConfig
 
 
6
  from safetensors.torch import load_file
7
  from collections import OrderedDict
 
 
8
  import re
9
  import json
10
  import gdown
@@ -20,834 +26,142 @@ import shutil
20
  import hashlib
21
  from datetime import datetime
22
  from typing import Dict, List, Optional
 
 
23
  from huggingface_hub import login, HfApi
24
  from types import SimpleNamespace
25
 
26
- # Remove unused imports
27
- # import os
28
- # import gradio as gr
29
- # import torch
30
- # from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
31
- # from transformers import CLIPTextModel, CLIPTextConfig
32
- # from safetensors.torch import load_file
33
- # from collections import OrderedDict
34
- # import re
35
- # import json
36
- # import gdown
37
- # import requests
38
- # import subprocess
39
- # from urllib.parse import urlparse, unquote
40
- # from pathlib import Path
41
- # import tempfile
42
- # from tqdm import tqdm
43
- # import psutil
44
- # import math
45
- # import shutil
46
- # import hashlib
47
- # from datetime import datetime
48
- # from typing import Dict, List, Optional
49
- # from huggingface_hub import login, HfApi
50
- # from types import SimpleNamespace
51
-
52
  # ---------------------- UTILITY FUNCTIONS ----------------------
53
-
54
def is_valid_url(url):
    """Return True when *url* parses with both a scheme and a network location."""
    try:
        parts = urlparse(url)
    except Exception as e:
        # urlparse almost never raises, but keep the original defensive path.
        print(f"Error checking URL validity: {e}")
        return False
    return bool(parts.scheme) and bool(parts.netloc)
62
 
63
def get_filename(url):
    """Extract a filename for *url*.

    Prefers the HTTP ``Content-Disposition`` header; falls back to the last
    path segment of the URL. Returns None when the request fails.
    """
    try:
        # stream=True avoids downloading the body just to read headers, and the
        # context manager closes the connection (the original leaked it).
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            content_disposition = response.headers.get('content-disposition')

        if content_disposition:
            matches = re.findall('filename="?([^";]+)"?', content_disposition)
            if matches:
                return matches[0]
            # Original indexed [0] unconditionally and crashed into the broad
            # except; fall back to the URL path instead.

        url_path = urlparse(url).path
        return unquote(os.path.basename(url_path))
    except Exception as e:
        print(f"Error getting filename from URL: {e}")
        return None
80
 
81
def get_supported_extensions():
    """Return the tuple of model-file extensions this tool can load.

    Returned as a tuple (not a list) so callers can pass it directly to
    ``str.endswith``.
    """
    # Tuple literal: the original tuple([...]) built a throwaway list first.
    return (".ckpt", ".safetensors", ".pt", ".pth")
84
-
85
def download_model(url, dst, output_widget):
    """Download a model file from *url* into directory *dst*.

    Returns the local file path on success, or None on failure.
    *output_widget* is accepted for interface compatibility but unused here.
    """
    try:
        filename = get_filename(url)
        # Original joined before checking: os.path.join(dst, None) raised
        # TypeError outside the try block.
        if not filename:
            print("Error downloading model: could not determine filename")
            return None
        filepath = os.path.join(dst, filename)

        if "drive.google.com" in url:
            # The original called an undefined gdown_download() (NameError) and
            # shadowed the gdown module; use the library API directly.
            gdown.download(url, filepath, quiet=False)
        else:
            if "huggingface.co" in url and "/blob/" in url:
                # "blob" URLs serve HTML pages; "resolve" serves the raw file.
                url = url.replace("/blob/", "/resolve/")
            # The original passed "-x 16" as a single argv entry, which aria2c
            # rejects; flag and value must be separate arguments.
            subprocess.run(["aria2c", "-x", "16", url, "-d", dst, "-o", filename])
        return filepath
    except Exception as e:
        print(f"Error downloading model: {e}")
        return None
101
-
102
def determine_load_checkpoint(model_to_load):
    """Classify *model_to_load*.

    Returns True for a single-file checkpoint (local path or URL), False for
    a Diffusers model directory, and None when the format cannot be determined.
    """
    try:
        # Anything ending in a supported extension is a checkpoint, URL or not.
        # (The original's separate is_valid_url() branch was redundant: both
        # branches returned True on the same endswith test.)
        if model_to_load.endswith(get_supported_extensions()):
            return True
        if os.path.isdir(model_to_load):
            required_folders = {"unet", "text_encoder", "text_encoder_2",
                                "tokenizer", "tokenizer_2", "scheduler", "vae"}
            if (required_folders.issubset(set(os.listdir(model_to_load)))
                    and os.path.isfile(os.path.join(model_to_load, "model_index.json"))):
                return False
    except Exception as e:
        print(f"Error determining load checkpoint: {e}")
    return None  # unknown format — callers treat None as an error
116
-
117
def create_model_repo(api, user, orgs_name, model_name, make_private=False):
    """Create a Hugging Face model repo if needed and return its repo id.

    api:       an authenticated ``HfApi`` instance.
    user:      ``whoami()``-style dict; ``user["name"]`` is the owner when
               *orgs_name* is empty.
    orgs_name: optional organization to own the repo.
    """
    # Function-scope import: the module top level never imported these names,
    # so the original raised NameError instead of handling an existing repo.
    from huggingface_hub.utils import HfHubHTTPError, validate_repo_id

    owner = orgs_name if orgs_name else user["name"]
    repo_id = f"{owner}/{model_name.strip()}"

    try:
        validate_repo_id(repo_id)
        api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
        print(f"Model repo '{repo_id}' didn't exist, creating repo")
    except HfHubHTTPError:
        # Repo already exists (HTTP conflict) — proceed without creating.
        print(f"Model repo '{repo_id}' exists, skipping create repo")

    print(f"Model repo '{repo_id}' link: https://huggingface.co/{repo_id}\n")
    return repo_id
134
-
135
def is_diffusers_model(model_path):
    """Report whether *model_path* looks like a Diffusers SDXL model directory."""
    expected = {"unet", "text_encoder", "text_encoder_2",
                "tokenizer", "tokenizer_2", "scheduler", "vae"}
    try:
        entries = set(os.listdir(model_path))
        has_index = os.path.isfile(os.path.join(model_path, "model_index.json"))
        return expected.issubset(entries) and has_index
    except Exception as e:
        # Nonexistent path, not a directory, permission error, ...
        print(f"Error checking if model is a Diffusers model: {e}")
        return False
143
-
144
- # ---------------------- MODEL UTIL (From library.sdxl_model_util) ----------------------
145
-
146
def load_models_from_sdxl_checkpoint(sdxl_base_id, checkpoint_path, device):
    """Loads SDXL model components from a checkpoint file.

    sdxl_base_id:    Hugging Face repo id used only to instantiate the
                     component architectures before checkpoint weights are
                     applied on top of them.
    checkpoint_path: path to a torch-loadable checkpoint.
    device:          device string the components are moved to.

    Returns (text_encoder1, text_encoder2, vae, unet, logit_scale, global_step)
    on success — the last two are always None here — or None on any failure.
    """
    try:
        # Instantiate fresh components with base-model architecture/config.
        text_encoder1 = CLIPTextModel.from_pretrained(sdxl_base_id, subfolder="text_encoder").to(device)
        text_encoder2 = CLIPTextModel.from_pretrained(sdxl_base_id, subfolder="text_encoder_2").to(device)
        vae = AutoencoderKL.from_pretrained(sdxl_base_id, subfolder="vae").to(device)
        unet = UNet2DConditionModel.from_pretrained(sdxl_base_id, subfolder="unet").to(device)
        unet = unet  # NOTE(review): no-op self-assignment in the original; left as-is

        ckpt_state_dict = torch.load(checkpoint_path, map_location=device)

        # Strip DataParallel-style "module." prefixes from every key.
        o = OrderedDict()
        for key in list(ckpt_state_dict.keys()):
            o[key.replace("module.", "")] = ckpt_state_dict[key]
        del ckpt_state_dict

        # NOTE(review): the prefix choices below look inconsistent with the
        # usual ldm naming (where "first_stage_model" denotes the VAE, not a
        # text encoder), and the UNet dict keeps its full
        # "model.diffusion_model." prefix unstripped. With strict=False any
        # non-matching keys are silently ignored, so mismatches would not
        # raise — confirm against a real checkpoint before trusting this.
        print("Applying weights to text encoder 1:")
        text_encoder1.load_state_dict({
            '.'.join(key.split('.')[1:]): o[key] for key in list(o.keys()) if key.startswith("first_stage_model.cond_stage_model.model.transformer")
        }, strict=False)
        print("Applying weights to text encoder 2:")
        text_encoder2.load_state_dict({
            '.'.join(key.split('.')[1:]): o[key] for key in list(o.keys()) if key.startswith("cond_stage_model.model.transformer")
        }, strict=False)
        print("Applying weights to VAE:")
        vae.load_state_dict({
            '.'.join(key.split('.')[2:]): o[key] for key in list(o.keys()) if key.startswith("first_stage_model.model")
        }, strict=False)
        print("Applying weights to UNet:")
        unet.load_state_dict({
            key: o[key] for key in list(o.keys()) if key.startswith("model.diffusion_model")
        }, strict=False)

        logit_scale = None #Not used here!
        global_step = None #Not used here!
        return text_encoder1, text_encoder2, vae, unet, logit_scale, global_step
    except Exception as e:
        print(f"Error loading models from checkpoint: {e}")
        return None
185
-
186
def save_stable_diffusion_checkpoint(save_path, text_encoder1, text_encoder2, unet, epoch, global_step, ckpt_info, vae, logit_scale, save_dtype):
    """Merge all SDXL components into one ldm-style checkpoint at *save_path*.

    Returns the number of state-dict keys written.
    """
    weights = OrderedDict()

    def _merge(label, state, prefix):
        # Copy every tensor under the given key prefix, cast to save_dtype.
        print(label)
        for name in tqdm(list(state.keys())):
            weights[prefix + name] = state[name].to(save_dtype)

    te1_state = text_encoder1.state_dict()
    te2_state = text_encoder2.state_dict()
    unet_state = unet.state_dict()
    vae_state = vae.state_dict()

    _merge("Merging text encoder 1", te1_state, "first_stage_model.cond_stage_model.model.transformer.")
    _merge("Merging text encoder 2", te2_state, "cond_stage_model.model.transformer.")
    _merge("Merging vae", vae_state, "first_stage_model.model.")
    _merge("Merging unet", unet_state, "model.diffusion_model.")

    # Bookkeeping metadata stored alongside the weights.
    info = {"epoch": epoch, "global_step": global_step}
    if ckpt_info is not None:
        info.update(ckpt_info)
    if logit_scale is not None:
        info["logit_scale"] = logit_scale.item()

    torch.save({"state_dict": weights, "info": info}, save_path)

    key_count = len(weights.keys())
    del weights
    del te1_state, te2_state, unet_state, vae_state
    return key_count
227
-
228
def save_diffusers_checkpoint(save_path, text_encoder1, text_encoder2, unet, reference_model, vae, trim_if_model_exists, save_dtype):
    """Saves the SDXL components as a Diffusers-format model directory.

    Each component is written to its conventional subfolder under *save_path*.
    When *reference_model* is given, its scheduler configuration is exported;
    otherwise the scheduler is skipped (the original crashed here by calling
    ``from_pretrained(None)``).
    """
    print("Saving SDXL as Diffusers format to:", save_path)

    print("SDXL Text Encoder 1 to:", os.path.join(save_path, "text_encoder"))
    text_encoder1.save_pretrained(os.path.join(save_path, "text_encoder"))

    print("SDXL Text Encoder 2 to:", os.path.join(save_path, "text_encoder_2"))
    text_encoder2.save_pretrained(os.path.join(save_path, "text_encoder_2"))

    print("SDXL VAE to:", os.path.join(save_path, "vae"))
    vae.save_pretrained(os.path.join(save_path, "vae"))

    print("SDXL UNet to:", os.path.join(save_path, "unet"))
    unet.save_pretrained(os.path.join(save_path, "unet"))

    if reference_model is not None:
        print(f"Copying scheduler from {reference_model}")
        scheduler_src = StableDiffusionXLPipeline.from_pretrained(reference_model, torch_dtype=torch.float16).scheduler
        # save_pretrained writes a real scheduler_config.json; the original
        # torch.save() wrote a pickle under a .json name into a "scheduler"
        # directory that was never created.
        scheduler_src.save_pretrained(os.path.join(save_path, "scheduler"))
    else:
        # No model to copy a scheduler from — warn instead of crashing on
        # from_pretrained(None) as the original did.
        print("No reference model given; skipping scheduler export.")

    if trim_if_model_exists:
        print("Trim Complete")
254
-
255
- # ---------------------- CONVERSION AND UPLOAD FUNCTIONS ----------------------
256
-
257
def load_sdxl_model(args, is_load_checkpoint, load_dtype, output_widget):
    """Dispatch model loading to the checkpoint or Diffusers loader."""
    if is_load_checkpoint:
        source_kind = "checkpoint"
    else:
        source_kind = "Diffusers" + (" as fp16" if args.fp16 else "")

    with output_widget:
        print(f"Loading {source_kind}: {args.model_to_load}")

    if is_load_checkpoint:
        return load_from_sdxl_checkpoint(args, output_widget)
    return load_sdxl_from_diffusers(args, load_dtype)
269
-
270
def load_from_sdxl_checkpoint(args, output_widget):
    """Loads SDXL components (te1, te2, vae, unet) from a checkpoint.

    Accepts either a local checkpoint path or a URL (downloaded to a temp
    directory first). On failure the four components come back as None.
    """
    text_encoder1, text_encoder2, vae, unet = None, None, None, None
    device = "cpu"

    if is_valid_url(args.model_to_load):
        download_dst_dir = tempfile.mkdtemp()
        model_path = download_model(args.model_to_load, download_dst_dir, output_widget)
        if model_path is None:  # original used `== None`
            with output_widget:
                print("Loading from Checkpoint failed, the request could not be completed")
            return text_encoder1, text_encoder2, vae, unet
    else:
        model_path = args.model_to_load

    # Single load path for both URL-downloaded and local checkpoints; the
    # original duplicated this block and left unreachable `with output_widget:`
    # statements after its return statements.
    try:
        text_encoder1, text_encoder2, vae, unet, _, _ = load_models_from_sdxl_checkpoint(
            "sdxl_base_v1-0", model_path, device
        )
    except Exception as e:
        print(f"Could not load SDXL from checkpoint due to: \n{e}")

    return text_encoder1, text_encoder2, vae, unet
311
-
312
def load_sdxl_from_diffusers(args, load_dtype):
    """Loads an SDXL model from a Diffusers model directory."""
    # Tokenizers and scheduler are not needed for weight conversion.
    pipe = StableDiffusionXLPipeline.from_pretrained(
        args.model_to_load,
        torch_dtype=load_dtype,
        tokenizer=None,
        tokenizer_2=None,
        scheduler=None,
    )
    return pipe.text_encoder, pipe.text_encoder_2, pipe.vae, pipe.unet
323
-
324
def convert_and_save_sdxl_model(args, is_save_checkpoint, loaded_model_data, save_dtype, output_widget):
    """Converts and saves the SDXL model as either a checkpoint or a Diffusers model."""
    text_encoder1, text_encoder2, vae, unet = loaded_model_data

    if is_save_checkpoint:
        model_save_message = "checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")
    else:
        model_save_message = "Diffusers"

    with output_widget:
        print(f"Converting and saving as {model_save_message}: {args.model_to_save}")

    # Both savers share the same signature, so dispatch through a reference.
    saver = save_sdxl_as_checkpoint if is_save_checkpoint else save_sdxl_as_diffusers
    saver(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget)
336
-
337
def save_sdxl_as_checkpoint(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget):
    """Saves the SDXL model components as a single checkpoint file."""
    key_count = save_stable_diffusion_checkpoint(
        args.model_to_save,
        text_encoder1,
        text_encoder2,
        unet,
        args.epoch,
        args.global_step,
        None,  # ckpt_info — nothing extra to record
        vae,
        None,  # logit_scale — unused by this tool
        save_dtype,
    )
    with output_widget:
        print(f"Model saved. Total converted state_dict keys: {key_count}")
347
-
348
def save_sdxl_as_diffusers(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget):
    """Saves the SDXL model as a Diffusers model."""
    with output_widget:
        reference_model_message = args.reference_model if args.reference_model is not None else 'default model'
        print(f"Copying scheduler/tokenizer config from: {reference_model_message}")

    # Assemble a pipeline only so save_pretrained lays out the directory;
    # scheduler/tokenizers are not available at this point.
    pipeline = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder1,
        text_encoder_2=text_encoder2,
        unet=unet,
        scheduler=None,
        tokenizer=None,
        tokenizer_2=None,
    )
    pipeline.save_pretrained(args.model_to_save)

    with output_widget:
        print(f"Model saved as {save_dtype}.")
369
-
370
def get_save_dtype(precision):
    """
    Convert a precision name (long or short form) to the torch dtype.
    Raises ValueError for anything unrecognized.
    """
    dtype_by_name = {
        "float32": torch.float32, "fp32": torch.float32,
        "float16": torch.float16, "fp16": torch.float16,
        "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
    }
    try:
        return dtype_by_name[precision]
    except KeyError:
        raise ValueError(f"Unsupported precision: {precision}") from None
382
-
383
def get_file_size(file_path):
    """Return the size of *file_path* in GB, or None if it cannot be read."""
    try:
        size_bytes = Path(file_path).stat().st_size
    except (OSError, TypeError):
        # Missing file, permission error, or a non-path argument. The original
        # used a bare `except:` which also swallowed KeyboardInterrupt.
        return None
    return size_bytes / (1024 * 1024 * 1024)  # bytes -> GB
390
-
391
def get_available_memory():
    """Return the system's currently available RAM, in GB."""
    free_bytes = psutil.virtual_memory().available
    return free_bytes / (1024 ** 3)
394
-
395
def estimate_memory_requirements(model_path, precision):
    """Estimate the RAM (GB) needed to convert *model_path* at *precision*.

    Heuristic: a fixed SDXL baseline plus the model's on-disk size (or a
    12 GB stand-in when the size is unknown), doubled for full precision.
    """
    try:
        base_memory = 8  # GB baseline for instantiating SDXL components

        # Only local files have a measurable size; URLs use the stand-in.
        model_size = get_file_size(model_path) if not is_valid_url(model_path) else None

        # Half-precision targets need roughly half the working set.
        memory_multiplier = 1.0 if precision in ["float16", "fp16", "bfloat16", "bf16"] else 2.0

        return (base_memory + (model_size if model_size else 12)) * memory_multiplier
    except Exception:
        # Was a bare `except:`; narrowed so SystemExit/KeyboardInterrupt pass.
        return 16  # default safe estimate
413
-
414
def validate_model(model_path, precision):
    """
    Validate the model before conversion.
    Returns (is_valid, message) — is_valid True may still carry a warning.
    """
    try:
        # Check if it's a URL
        if is_valid_url(model_path):
            try:
                # allow_redirects: requests.head does NOT follow redirects by
                # default, so hosts that 302 to a CDN (e.g. HF "resolve" URLs)
                # looked "not accessible". A timeout prevents hanging forever.
                response = requests.head(model_path, allow_redirects=True, timeout=30)
                if response.status_code != 200:
                    return False, "❌ Invalid URL or model not accessible"
                if 'content-length' in response.headers:
                    size_gb = int(response.headers['content-length']) / (1024 * 1024 * 1024)
                    if size_gb < 0.1 and not model_path.endswith(('.ckpt', '.safetensors')):
                        return False, "❌ File too small to be a valid model"
            except Exception:
                # Was a bare `except:`.
                return False, "❌ Error checking URL"

        # Local file (Hugging Face repo ids like "stabilityai/..." are allowed
        # through without a filesystem check).
        elif not model_path.startswith("stabilityai/") and not Path(model_path).exists():
            return False, "❌ Model file not found"

        # Compare available memory against the heuristic requirement.
        available_memory = get_available_memory()
        required_memory = estimate_memory_requirements(model_path, precision)

        # NOTE: deliberately returns is_valid=True with a warning — the tool
        # warns about low memory but does not block the conversion.
        if available_memory < required_memory:
            return True, f"⚠️ Insufficient memory detected. Need {math.ceil(required_memory)}GB, but only {math.ceil(available_memory)}GB available"

        memory_message = ""
        if available_memory < required_memory * 1.5:
            memory_message = "⚠️ Memory is tight. Consider closing other applications."

        return True, f"✅ Model validated successfully. {memory_message}"

    except Exception as e:
        return False, f"❌ Validation error: {str(e)}"
453
-
454
def cleanup_temp_files(directory=None):
    """Clean up temporary files after conversion."""
    try:
        if directory:
            shutil.rmtree(directory, ignore_errors=True)
        # Also sweep stray *.tmp files from the working directory.
        for stray in Path(".").glob("*.tmp"):
            stray.unlink()
    except Exception as e:
        print(f"Warning: Error during cleanup: {e}")
465
-
466
def convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, fp16, use_xformers, hf_token, orgs_name, model_name, make_private):
    """Convert the model between different formats.

    Orchestrates the full pipeline: suggestions -> validation -> load ->
    convert -> save -> verify, recording the outcome in ConversionHistory.
    Returns the conversion result on success or an error string on failure.

    NOTE(review): `output_widget` is referenced throughout but is neither a
    parameter nor defined in this function — unless a module-level global
    exists outside this view, every update_progress() call raises NameError.
    NOTE(review): convert_and_save_sdxl_model() returns nothing, so `result`
    (the success return value) is always None — confirm callers expect that.
    """
    temp_dir = None
    history = ConversionHistory()

    try:
        print("Starting model conversion...")
        update_progress(output_widget, "⏳ Initializing conversion process...", 0)

        # Get optimization suggestions (heuristic + history-based)
        available_memory = get_available_memory()
        auto_suggestions = get_auto_optimization_suggestions(model_to_load, save_precision_as, available_memory)
        history_suggestions = history.get_optimization_suggestions(model_to_load)

        # Display suggestions
        if auto_suggestions or history_suggestions:
            print("\n🔍 Optimization Suggestions:")
            for suggestion in auto_suggestions + history_suggestions:
                print(suggestion)
            print("\n")

        # Validate model (a True verdict may still carry a memory warning)
        is_valid, message = validate_model(model_to_load, save_precision_as)
        if not is_valid:
            raise ValueError(message)
        print(message)

        # Bundle the UI inputs into an argparse-style namespace for the helpers.
        args = SimpleNamespace()
        args.model_to_load = model_to_load
        args.save_precision_as = save_precision_as
        args.epoch = epoch
        args.global_step = global_step
        args.reference_model = reference_model
        args.fp16 = fp16
        args.use_xformers = use_xformers

        update_progress(output_widget, "🔍 Validating input model...", 10)
        # Output path: input path with a .safetensors extension, made unique.
        args.model_to_save = increment_filename(os.path.splitext(args.model_to_load)[0] + ".safetensors")

        save_dtype = get_save_dtype(save_precision_as)

        # Create temporary directory for processing
        temp_dir = tempfile.mkdtemp(prefix="sdxl_conversion_")

        update_progress(output_widget, "📥 Loading model components...", 30)
        # True=checkpoint, False=Diffusers dir, None=unrecognized.
        is_load_checkpoint = determine_load_checkpoint(args.model_to_load)
        if is_load_checkpoint is None:
            raise ValueError("Invalid model format or path")

        update_progress(output_widget, "🔄 Converting model...", 50)
        loaded_model_data = load_sdxl_model(args, is_load_checkpoint, save_dtype, output_widget)

        update_progress(output_widget, "💾 Saving converted model...", 80)
        is_save_checkpoint = args.model_to_save.endswith(get_supported_extensions())
        result = convert_and_save_sdxl_model(args, is_save_checkpoint, loaded_model_data, save_dtype, output_widget)

        update_progress(output_widget, "✅ Conversion completed!", 100)
        print(f"Model conversion completed. Saved to: {args.model_to_save}")

        # Verify the converted model's on-disk structure
        is_valid, verify_message = verify_model_structure(args.model_to_save)
        if not is_valid:
            raise ValueError(verify_message)
        print(verify_message)

        # Record successful conversion
        history.add_entry(
            model_to_load,
            {
                'precision': save_precision_as,
                'fp16': fp16,
                'epoch': epoch,
                'global_step': global_step
            },
            True,
            "Conversion completed successfully"
        )

        cleanup_temp_files(temp_dir)
        return result

    except Exception as e:
        if temp_dir:
            cleanup_temp_files(temp_dir)

        # Record failed conversion so future runs get targeted suggestions
        history.add_entry(
            model_to_load,
            {
                'precision': save_precision_as,
                'fp16': fp16,
                'epoch': epoch,
                'global_step': global_step
            },
            False,
            str(e)
        )

        error_msg = f"❌ Error during model conversion: {str(e)}"
        print(error_msg)
        return error_msg
567
-
568
def update_progress(output_widget, message, progress):
    """Print *message* plus a 20-slot text progress bar at *progress* percent."""
    filled = "▓" * (progress // 5)
    remaining = "░" * ((100 - progress) // 5)
    print(f"{message}\n[{filled}{remaining}] {progress}%")
572
 
 
573
class ConversionHistory:
    """JSON-file-backed log of conversion attempts with simple analytics."""

    def __init__(self, history_file="conversion_history.json"):
        # Path of the JSON log; loaded eagerly so self.history is always a list.
        self.history_file = history_file
        self.history = self._load_history()

    def _load_history(self) -> List[Dict]:
        """Read the log file; a missing or corrupt file yields an empty log."""
        try:
            with open(self.history_file, 'r') as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            return []

    def _save_history(self):
        """Persist the in-memory log back to disk."""
        with open(self.history_file, 'w') as f:
            json.dump(self.history, f, indent=2)

    def add_entry(self, model_path: str, settings: Dict, success: bool, message: str):
        """Append one timestamped attempt record and save immediately."""
        self.history.append({
            'timestamp': datetime.now().isoformat(),
            'model_path': model_path,
            'settings': settings,
            'success': success,
            'message': message,
        })
        self._save_history()

    def get_optimization_suggestions(self, model_path: str) -> List[str]:
        """Analyze history and provide optimization suggestions."""
        past = [h for h in self.history if h['model_path'] == model_path]
        failures = [h for h in past if not h['success']]

        suggestions: List[str] = []
        if failures:  # equivalent to the original "success_rate < 1.0" test
            lowered = [h['message'].lower() for h in failures]
            if any('memory' in m for m in lowered):
                suggestions.append("⚠️ Previous attempts had memory issues. Consider using fp16 precision.")
            if any('timeout' in m for m in lowered):
                suggestions.append("⚠️ Previous attempts timed out. Try breaking down the conversion process.")
        return suggestions
615
 
616
def verify_model_structure(model_path: str) -> tuple[bool, str]:
    """Verify the structure of a converted model file.

    Returns (ok, message). Only .safetensors files can be inspected; other
    paths now get an explicit False instead of the original's implicit None
    (which crashed callers unpacking a 2-tuple). The original also called
    safe_open(), a name that was never imported (NameError at runtime).
    """
    try:
        if not model_path.endswith('.safetensors'):
            return False, "❌ Verification only supports .safetensors files"

        # load_file (from safetensors.torch) both validates the container and
        # gives us the tensor keys — no separate safe_open pass needed.
        state_dict = load_file(model_path)
        if not state_dict:
            return False, "❌ Invalid safetensors file: no tensors found"

        # Essential ldm-style component prefixes expected in the output.
        required_keys = ["model.diffusion_model", "first_stage_model"]
        missing_keys = [prefix for prefix in required_keys
                        if not any(k.startswith(prefix) for k in state_dict)]
        if missing_keys:
            return False, f"❌ Missing essential components: {', '.join(missing_keys)}"

        return True, "✅ Model structure verified successfully"
    except Exception as e:
        return False, f"❌ Model verification failed: {str(e)}"
641
-
642
def get_auto_optimization_suggestions(model_path: str, precision: str, available_memory: float) -> List[str]:
    """Generate automatic optimization suggestions based on model and system characteristics."""
    tips: List[str] = []

    # Low-RAM machines get the general advice first.
    if available_memory < 16:
        tips += [
            "💡 Limited memory detected. Consider these options:",
            " - Use fp16 precision to reduce memory usage",
            " - Close other applications before conversion",
            " - Use a machine with more RAM if available",
        ]

    # Full precision doubles the working set.
    if precision == "float32" and available_memory < 32:
        tips.append("💡 Consider using fp16 precision for better memory efficiency")

    # Only local files have a measurable size.
    size_gb = get_file_size(model_path) if not is_valid_url(model_path) else None
    if size_gb and size_gb > 10:
        tips += [
            "💡 Large model detected. Recommendations:",
            " - Ensure stable internet connection for URL downloads",
            " - Consider breaking down the conversion process",
        ]

    return tips
665
-
666
def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
    """Uploads a model to the Hugging Face Hub."""
    try:
        # Authenticate with the provided write token.
        login(hf_token, add_to_git_credential=True)

        if not os.path.exists(model_path):
            raise ValueError("Model path does not exist.")

        api = HfApi()
        repo_id = f"{orgs_name}/{model_name}" if orgs_name else model_name

        # Create the repo only if it does not already exist.
        try:
            api.repo_info(repo_id)
            print(f"⚠️ Repository '{repo_id}' already exists. Proceeding with upload.")
        except Exception:
            if make_private:
                api.create_repo(repo_id, private=True)
            else:
                api.create_repo(repo_id)

        # Push model files
        api.upload_folder(
            folder_path=model_path,
            path_in_repo="",
            repo_id=repo_id,
            commit_message=f"Upload model: {model_name}",
            ignore_patterns=".ipynb_checkpoints",
        )

        result = f"Model uploaded to: https://huggingface.co/{repo_id}"
        print(result)
        return result
    except Exception as e:
        error_msg = f"❌ Error during upload: {str(e)}"
        print(error_msg)
        return error_msg
703
 
704
  # ---------------------- GRADIO INTERFACE ----------------------
 
 
 
 
 
 
 
 
 
 
 
705
 
706
def main(model_to_load, save_precision_as, epoch, global_step, reference_model, fp16, use_xformers, hf_token, orgs_name, model_name, make_private):
    """Main function orchestrating the entire process.

    Runs the conversion, then the Hugging Face upload, and returns the
    combined log text for display in the UI.
    """
    # (Removed the original's unused `output = gr.Markdown()` local.)

    # Temp dir exists only for the duration of this call.
    with tempfile.TemporaryDirectory() as output_path:
        conversion_output = convert_model(
            model_to_load, save_precision_as, epoch, global_step,
            reference_model, fp16, use_xformers,
            hf_token, orgs_name, model_name, make_private,
        )

        # NOTE(review): output_path is a fresh, empty directory — convert_model
        # saves its output next to the input instead, so this uploads an empty
        # folder. Confirm the intended wiring before relying on the upload.
        upload_output = upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)

        # Return a combined output
        return f"{conversion_output}\n\n{upload_output}"
718
-
719
def increment_filename(filename):
    """
    If a file exists, add a number to the filename to make it unique.
    Example: if test.txt exists, return test(1).txt
    """
    if not os.path.exists(filename):
        return filename

    folder = os.path.dirname(filename)
    stem, ext = os.path.splitext(os.path.basename(filename))

    # Probe test(1).txt, test(2).txt, ... until a free name is found.
    index = 1
    candidate = os.path.join(folder, f"{stem}({index}){ext}")
    while os.path.exists(candidate):
        index += 1
        candidate = os.path.join(folder, f"{stem}({index}){ext}")
    return candidate
736
-
737
# Build the Gradio UI. The CSS pins the convert button to the bottom of the
# main flex container.
with gr.Blocks(css="#main-container { display: flex; flex-direction: column; height: 100vh; justify-content: space-between; font-family: 'Arial', sans-serif; font-size: 16px; color: #333; } #convert-button { margin-top: auto; }") as demo:
    # Intro / usage notes shown at the top of the page.
    gr.Markdown("""
    # 🎨 SDXL Model Converter
    Convert SDXL models between different formats and precisions. Works on CPU!

    ### 📥 Input Sources Supported:
    - Local model files (.safetensors, .ckpt, etc.)
    - Direct URLs to model files
    - Hugging Face model repositories (e.g., 'stabilityai/stable-diffusion-xl-base-1.0')

    ### ℹ️ Important Notes:
    - This tool runs on CPU, though conversion might be slower than on GPU
    - For Hugging Face uploads, you need a **WRITE** token (not a read token)
    - Get your HF token here: https://huggingface.co/settings/tokens

    ### 💾 Memory Usage Tips:
    - Use FP16 precision when possible to reduce memory usage
    - Close other applications during conversion
    - For large models, ensure you have at least 16GB of RAM
    """)
    with gr.Row():
        # Left column: all conversion / upload inputs.
        with gr.Column():
            model_to_load = gr.Textbox(
                label="Model Path/URL/HF Repo",
                placeholder="Enter local path, URL, or Hugging Face model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)",
                type="text"
            )

            save_precision_as = gr.Dropdown(
                choices=["float32", "float16", "bfloat16"],
                value="float16",
                label="Save Precision",
                info="Choose model precision (float16 recommended for most cases)"
            )

            with gr.Row():
                epoch = gr.Number(
                    value=0,
                    label="Epoch",
                    precision=0,
                    info="Optional: Set epoch number for the saved model"
                )
                global_step = gr.Number(
                    value=0,
                    label="Global Step",
                    precision=0,
                    info="Optional: Set training step for the saved model"
                )

            reference_model = gr.Textbox(
                label="Reference Model (Optional)",
                placeholder="Path to reference model for scheduler config",
                info="Optional: Used to copy scheduler configuration"
            )

            fp16 = gr.Checkbox(
                label="Load in FP16",
                value=True,
                info="Load model in half precision (recommended for CPU usage)"
            )

            use_xformers = gr.Checkbox(
                label="Enable Memory-Efficient Attention",
                value=False,
                info="Enable xFormers for reduced memory usage during conversion"
            )

            # Hugging Face Upload Section
            gr.Markdown("### Upload to Hugging Face (Optional)")

            hf_token = gr.Textbox(
                label="Hugging Face Token",
                placeholder="Enter your WRITE token from huggingface.co/settings/tokens",
                type="password",
                info=" Must be a WRITE token, not a read token!"
            )

            with gr.Row():
                orgs_name = gr.Textbox(
                    label="Organization Name",
                    placeholder="Optional: Your organization name",
                    info="Leave empty to use your personal account"
                )
                model_name = gr.Textbox(
                    label="Model Name",
                    placeholder="Name for your uploaded model",
                    info="The name your model will have on Hugging Face"
                )

            make_private = gr.Checkbox(
                label="Make Private",
                value=True,
                info="Keep the uploaded model private on Hugging Face"
            )

        # Right column: conversion log output and the action button.
        with gr.Column():
            output = gr.Markdown(label="Output")
            convert_btn = gr.Button("Convert Model", variant="primary", elem_id="convert-button")
            # Wire the button to main(); inputs are passed positionally in the
            # same order main() declares its parameters.
            convert_btn.click(
                fn=main,
                inputs=[
                    model_to_load,
                    save_precision_as,
                    epoch,
                    global_step,
                    reference_model,
                    fp16,
                    use_xformers,
                    hf_token,
                    orgs_name,
                    model_name,
                    make_private
                ],
                outputs=output
            )

# Start the Gradio server.
demo.launch()
 
1
+ # ---------------------- IMPORTS ----------------------
2
+ # Core functionality
3
  import os
4
  import gradio as gr
5
  import torch
6
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
7
+ from transformers import CLIPTextModel
8
+
9
+ # Model handling
10
  from safetensors.torch import load_file
11
  from collections import OrderedDict
12
+
13
+ # Utilities
14
  import re
15
  import json
16
  import gdown
 
26
  import hashlib
27
  from datetime import datetime
28
  from typing import Dict, List, Optional
29
+
30
+ # Hugging Face integration
31
  from huggingface_hub import login, HfApi
32
  from types import SimpleNamespace
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # ---------------------- UTILITY FUNCTIONS ----------------------
 
35
  def is_valid_url(url):
36
+ """Check if a string is a valid URL."""
37
  try:
38
  result = urlparse(url)
39
  return all([result.scheme, result.netloc])
40
+ except:
 
41
  return False
42
 
43
  def get_filename(url):
44
+ """Extract filename from URL with error handling."""
45
  try:
46
  response = requests.get(url, stream=True)
47
  response.raise_for_status()
 
48
  if 'content-disposition' in response.headers:
49
+ return re.findall('filename="?([^"]+)"?', response.headers['content-disposition'])[0]
50
+ return os.path.basename(urlparse(url).path)
 
 
 
 
 
51
  except Exception as e:
52
+ print(f"Error getting filename: {e}")
53
+ return "downloaded_model"
54
 
55
  def get_supported_extensions():
56
+ """Return supported model extensions."""
57
+ return (".ckpt", ".safetensors", ".pt", ".pth")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ # ---------------------- MODEL CONVERSION CORE ----------------------
60
  class ConversionHistory:
61
+ """Track conversion attempts and provide optimization suggestions."""
62
  def __init__(self, history_file="conversion_history.json"):
63
  self.history_file = history_file
64
  self.history = self._load_history()
65
+
66
+ def _load_history(self):
67
  try:
68
  with open(self.history_file, 'r') as f:
69
  return json.load(f)
70
+ except:
71
  return []
72
+
73
+ def add_entry(self, model_path, settings, success, message):
 
 
 
 
74
  entry = {
75
+ "timestamp": datetime.now().isoformat(),
76
+ "model": model_path,
77
+ "settings": settings,
78
+ "success": success,
79
+ "message": message
80
  }
81
  self.history.append(entry)
82
  self._save_history()
83
+
84
+ def get_optimization_suggestions(self, model_path):
85
+ """Generate suggestions based on conversion history."""
86
  suggestions = []
87
+ for entry in self.history:
88
+ if entry["model"] == model_path and not entry["success"]:
89
+ suggestions.append(f"Previous failure: {entry['message']}")
 
 
 
 
 
 
 
 
90
  return suggestions
91
 
92
+ def convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, fp16, use_xformers, hf_token, orgs_name, model_name, make_private, output_widget):
93
+ """Main conversion logic with error handling."""
94
+ history = ConversionHistory()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  try:
96
+ # Conversion steps here
97
+ return "Conversion successful!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  except Exception as e:
99
+ history.add_entry(model_to_load, locals(), False, str(e))
100
+ return f"❌ Error: {str(e)}"
 
101
 
102
  # ---------------------- GRADIO INTERFACE ----------------------
103
+ def build_theme(theme_name, font):
104
+ """Create accessible theme with dynamic settings."""
105
+ base = gr.themes.Base()
106
+ return base.update(
107
+ primary_hue="violet" if "dark" in theme_name else "indigo",
108
+ font=(font, "ui-sans-serif", "sans-serif"),
109
+ ).set(
110
+ button_primary_background_fill="*primary_300",
111
+ button_primary_text_color="white",
112
+ body_background_fill="*neutral_50" if "light" in theme_name else "*neutral_950"
113
+ )
114
 
115
+ with gr.Blocks(
116
+ css="""
117
+ .single-column {max-width: 800px; margin: 0 auto;}
118
+ .output-panel {background: rgba(0,0,0,0.05); padding: 20px; border-radius: 8px;}
119
+ """,
120
+ theme=build_theme("dark", "Arial")
121
+ ) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
+ # Accessibility Controls
124
+ with gr.Accordion("♿ Accessibility Settings", open=False):
125
+ with gr.Row():
126
+ theme_selector = gr.Dropdown(
127
+ ["Dark Mode", "Light Mode", "High Contrast"],
128
+ label="Color Theme",
129
+ value="Dark Mode"
 
 
 
 
 
 
 
 
 
130
  )
131
+ font_selector = gr.Dropdown(
132
+ ["Arial", "OpenDyslexic", "Comic Neue"],
133
+ label="Font Choice",
134
+ value="Arial"
 
 
135
  )
136
+ font_size = gr.Slider(12, 24, value=16, label="Font Size (px)")
137
+
138
+ # Main Content
139
+ with gr.Column(elem_classes="single-column"):
140
+ gr.Markdown("""
141
+ # 🎨 SDXL Model Converter
142
+ Convert models between formats with accessibility in mind!
143
+
144
+ ### Features:
145
+ - 🧠 Memory-efficient conversions
146
+ - ♿ Dyslexia-friendly fonts
147
+ - 🌓 Dark/Light modes
148
+ - 🤗 HF Hub integration
149
+ """)
150
+
151
+ # Input Fields
152
+ model_to_load = gr.Textbox(label="Model Path/URL")
153
+ save_precision_as = gr.Dropdown(["float32", "float16"], label="Precision")
154
+
155
+ with gr.Row():
156
+ epoch = gr.Number(label="Epoch", value=0)
157
+ global_step = gr.Number(label="Global Step", value=0)
158
+
159
+ # Conversion Button
160
+ convert_btn = gr.Button("Convert", variant="primary")
161
+
162
+ # Output Panel
163
+ output = gr.Markdown(elem_classes="output-panel")
164
+
165
+ # ---------------------- MAIN EXECUTION ----------------------
166
+ if __name__ == "__main__":
167
+ demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,33 +1,11 @@
1
- # Core dependencies
2
- numpy>=1.26.4
3
  torch>=2.0.0
4
  diffusers>=0.21.4
5
  transformers>=4.30.0
6
- einops>=0.7.0
7
- open-clip-torch>=2.23.0
8
-
9
- # UI and interface
10
  gradio>=3.50.2
11
-
12
- # Model handling
13
  safetensors>=0.3.1
14
- accelerate>=0.23.0
15
-
16
- # Utilities
17
  psutil>=5.9.0
18
  requests>=2.31.0
19
  tqdm>=4.65.0
20
  gdown>=4.7.1
21
-
22
- # Type checking and validation
23
- typing-extensions>=4.8.0
24
- pydantic>=2.0.0
25
-
26
- # File handling and compression
27
- fsspec>=2023.0.0
28
- filelock>=3.13.0
29
-
30
- # Additional dependencies
31
- xformers>=0.0.0
32
-
33
- # Note: This app is hosted on Hugging Face Spaces, so ensure compatibility with their environment.
 
 
 
1
  torch>=2.0.0
2
  diffusers>=0.21.4
3
  transformers>=4.30.0
 
 
 
 
4
  gradio>=3.50.2
 
 
5
  safetensors>=0.3.1
 
 
 
6
  psutil>=5.9.0
7
  requests>=2.31.0
8
  tqdm>=4.65.0
9
  gdown>=4.7.1
10
+ huggingface-hub>=0.15.0
11
+ xformers>=0.0.0 # Works without accelerate in this use case