Spaces:
Running
Running
Upload 2 files
Browse filesA lot more changes to make it more robust see https://github.com/Ktiseos-Nyx/Gradio-SDXL-Diffusers/tree/v2_deepseek
- app.py +104 -790
- requirements.txt +2 -24
app.py
CHANGED
|
@@ -1,10 +1,16 @@
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import gradio as gr
|
| 3 |
import torch
|
| 4 |
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
|
| 5 |
-
from transformers import CLIPTextModel
|
|
|
|
|
|
|
| 6 |
from safetensors.torch import load_file
|
| 7 |
from collections import OrderedDict
|
|
|
|
|
|
|
| 8 |
import re
|
| 9 |
import json
|
| 10 |
import gdown
|
|
@@ -20,834 +26,142 @@ import shutil
|
|
| 20 |
import hashlib
|
| 21 |
from datetime import datetime
|
| 22 |
from typing import Dict, List, Optional
|
|
|
|
|
|
|
| 23 |
from huggingface_hub import login, HfApi
|
| 24 |
from types import SimpleNamespace
|
| 25 |
|
| 26 |
-
# Remove unused imports
|
| 27 |
-
# import os
|
| 28 |
-
# import gradio as gr
|
| 29 |
-
# import torch
|
| 30 |
-
# from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
|
| 31 |
-
# from transformers import CLIPTextModel, CLIPTextConfig
|
| 32 |
-
# from safetensors.torch import load_file
|
| 33 |
-
# from collections import OrderedDict
|
| 34 |
-
# import re
|
| 35 |
-
# import json
|
| 36 |
-
# import gdown
|
| 37 |
-
# import requests
|
| 38 |
-
# import subprocess
|
| 39 |
-
# from urllib.parse import urlparse, unquote
|
| 40 |
-
# from pathlib import Path
|
| 41 |
-
# import tempfile
|
| 42 |
-
# from tqdm import tqdm
|
| 43 |
-
# import psutil
|
| 44 |
-
# import math
|
| 45 |
-
# import shutil
|
| 46 |
-
# import hashlib
|
| 47 |
-
# from datetime import datetime
|
| 48 |
-
# from typing import Dict, List, Optional
|
| 49 |
-
# from huggingface_hub import login, HfApi
|
| 50 |
-
# from types import SimpleNamespace
|
| 51 |
-
|
| 52 |
# ---------------------- UTILITY FUNCTIONS ----------------------
|
| 53 |
-
|
| 54 |
def is_valid_url(url):
|
| 55 |
-
"""
|
| 56 |
try:
|
| 57 |
result = urlparse(url)
|
| 58 |
return all([result.scheme, result.netloc])
|
| 59 |
-
except
|
| 60 |
-
print(f"Error checking URL validity: {e}")
|
| 61 |
return False
|
| 62 |
|
| 63 |
def get_filename(url):
|
| 64 |
-
"""
|
| 65 |
try:
|
| 66 |
response = requests.get(url, stream=True)
|
| 67 |
response.raise_for_status()
|
| 68 |
-
|
| 69 |
if 'content-disposition' in response.headers:
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
else:
|
| 73 |
-
url_path = urlparse(url).path
|
| 74 |
-
filename = unquote(os.path.basename(url_path))
|
| 75 |
-
|
| 76 |
-
return filename
|
| 77 |
except Exception as e:
|
| 78 |
-
print(f"Error getting filename
|
| 79 |
-
return
|
| 80 |
|
| 81 |
def get_supported_extensions():
|
| 82 |
-
"""
|
| 83 |
-
return
|
| 84 |
-
|
| 85 |
-
def download_model(url, dst, output_widget):
|
| 86 |
-
"""Downloads a model from a URL to the specified destination."""
|
| 87 |
-
filename = get_filename(url)
|
| 88 |
-
filepath = os.path.join(dst, filename)
|
| 89 |
-
try:
|
| 90 |
-
if "drive.google.com" in url:
|
| 91 |
-
gdown = gdown_download(url, dst, filepath)
|
| 92 |
-
else:
|
| 93 |
-
if "huggingface.co" in url:
|
| 94 |
-
if "/blob/" in url:
|
| 95 |
-
url = url.replace("/blob/", "/resolve/")
|
| 96 |
-
subprocess.run(["aria2c","-x 16",url,"-d",dst,"-o",filename])
|
| 97 |
-
return filepath
|
| 98 |
-
except Exception as e:
|
| 99 |
-
print(f"Error downloading model: {e}")
|
| 100 |
-
return None
|
| 101 |
-
|
| 102 |
-
def determine_load_checkpoint(model_to_load):
|
| 103 |
-
"""Determines if the model to load is a checkpoint, Diffusers model, or URL."""
|
| 104 |
-
try:
|
| 105 |
-
if is_valid_url(model_to_load) and (model_to_load.endswith(get_supported_extensions())):
|
| 106 |
-
return True
|
| 107 |
-
elif model_to_load.endswith(get_supported_extensions()):
|
| 108 |
-
return True
|
| 109 |
-
elif os.path.isdir(model_to_load):
|
| 110 |
-
required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
|
| 111 |
-
if required_folders.issubset(set(os.listdir(model_to_load))) and os.path.isfile(os.path.join(model_to_load, "model_index.json")):
|
| 112 |
-
return False
|
| 113 |
-
except Exception as e:
|
| 114 |
-
print(f"Error determining load checkpoint: {e}")
|
| 115 |
-
return None # handle this case as required
|
| 116 |
-
|
| 117 |
-
def create_model_repo(api, user, orgs_name, model_name, make_private=False):
|
| 118 |
-
"""Creates a Hugging Face model repository if it doesn't exist."""
|
| 119 |
-
try:
|
| 120 |
-
if orgs_name == "":
|
| 121 |
-
repo_id = user["name"] + "/" + model_name.strip()
|
| 122 |
-
else:
|
| 123 |
-
repo_id = orgs_name + "/" + model_name.strip()
|
| 124 |
-
|
| 125 |
-
validate_repo_id(repo_id)
|
| 126 |
-
api.create_repo(repo_id=repo_id, repo_type="model", private=make_private)
|
| 127 |
-
print(f"Model repo '{repo_id}' didn't exist, creating repo")
|
| 128 |
-
except HfHubHTTPError as e:
|
| 129 |
-
print(f"Model repo '{repo_id}' exists, skipping create repo")
|
| 130 |
-
|
| 131 |
-
print(f"Model repo '{repo_id}' link: https://huggingface.co/{repo_id}\n")
|
| 132 |
-
|
| 133 |
-
return repo_id
|
| 134 |
-
|
| 135 |
-
def is_diffusers_model(model_path):
|
| 136 |
-
"""Checks if a given path is a valid Diffusers model directory."""
|
| 137 |
-
try:
|
| 138 |
-
required_folders = {"unet", "text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2", "scheduler", "vae"}
|
| 139 |
-
return required_folders.issubset(set(os.listdir(model_path))) and os.path.isfile(os.path.join(model_path, "model_index.json"))
|
| 140 |
-
except Exception as e:
|
| 141 |
-
print(f"Error checking if model is a Diffusers model: {e}")
|
| 142 |
-
return False
|
| 143 |
-
|
| 144 |
-
# ---------------------- MODEL UTIL (From library.sdxl_model_util) ----------------------
|
| 145 |
-
|
| 146 |
-
def load_models_from_sdxl_checkpoint(sdxl_base_id, checkpoint_path, device):
|
| 147 |
-
"""Loads SDXL model components from a checkpoint file."""
|
| 148 |
-
try:
|
| 149 |
-
text_encoder1 = CLIPTextModel.from_pretrained(sdxl_base_id, subfolder="text_encoder").to(device)
|
| 150 |
-
text_encoder2 = CLIPTextModel.from_pretrained(sdxl_base_id, subfolder="text_encoder_2").to(device)
|
| 151 |
-
vae = AutoencoderKL.from_pretrained(sdxl_base_id, subfolder="vae").to(device)
|
| 152 |
-
unet = UNet2DConditionModel.from_pretrained(sdxl_base_id, subfolder="unet").to(device)
|
| 153 |
-
unet = unet
|
| 154 |
-
|
| 155 |
-
ckpt_state_dict = torch.load(checkpoint_path, map_location=device)
|
| 156 |
-
|
| 157 |
-
o = OrderedDict()
|
| 158 |
-
for key in list(ckpt_state_dict.keys()):
|
| 159 |
-
o[key.replace("module.", "")] = ckpt_state_dict[key]
|
| 160 |
-
del ckpt_state_dict
|
| 161 |
-
|
| 162 |
-
print("Applying weights to text encoder 1:")
|
| 163 |
-
text_encoder1.load_state_dict({
|
| 164 |
-
'.'.join(key.split('.')[1:]): o[key] for key in list(o.keys()) if key.startswith("first_stage_model.cond_stage_model.model.transformer")
|
| 165 |
-
}, strict=False)
|
| 166 |
-
print("Applying weights to text encoder 2:")
|
| 167 |
-
text_encoder2.load_state_dict({
|
| 168 |
-
'.'.join(key.split('.')[1:]): o[key] for key in list(o.keys()) if key.startswith("cond_stage_model.model.transformer")
|
| 169 |
-
}, strict=False)
|
| 170 |
-
print("Applying weights to VAE:")
|
| 171 |
-
vae.load_state_dict({
|
| 172 |
-
'.'.join(key.split('.')[2:]): o[key] for key in list(o.keys()) if key.startswith("first_stage_model.model")
|
| 173 |
-
}, strict=False)
|
| 174 |
-
print("Applying weights to UNet:")
|
| 175 |
-
unet.load_state_dict({
|
| 176 |
-
key: o[key] for key in list(o.keys()) if key.startswith("model.diffusion_model")
|
| 177 |
-
}, strict=False)
|
| 178 |
-
|
| 179 |
-
logit_scale = None #Not used here!
|
| 180 |
-
global_step = None #Not used here!
|
| 181 |
-
return text_encoder1, text_encoder2, vae, unet, logit_scale, global_step
|
| 182 |
-
except Exception as e:
|
| 183 |
-
print(f"Error loading models from checkpoint: {e}")
|
| 184 |
-
return None
|
| 185 |
-
|
| 186 |
-
def save_stable_diffusion_checkpoint(save_path, text_encoder1, text_encoder2, unet, epoch, global_step, ckpt_info, vae, logit_scale, save_dtype):
|
| 187 |
-
"""Saves the stable diffusion checkpoint."""
|
| 188 |
-
weights = OrderedDict()
|
| 189 |
-
text_encoder1_dict = text_encoder1.state_dict()
|
| 190 |
-
text_encoder2_dict = text_encoder2.state_dict()
|
| 191 |
-
unet_dict = unet.state_dict()
|
| 192 |
-
vae_dict = vae.state_dict()
|
| 193 |
-
|
| 194 |
-
def replace_key(key):
|
| 195 |
-
key = "cond_stage_model.model.transformer." + key
|
| 196 |
-
return key
|
| 197 |
-
|
| 198 |
-
print("Merging text encoder 1")
|
| 199 |
-
for key in tqdm(list(text_encoder1_dict.keys())):
|
| 200 |
-
weights["first_stage_model.cond_stage_model.model.transformer." + key] = text_encoder1_dict[key].to(save_dtype)
|
| 201 |
-
|
| 202 |
-
print("Merging text encoder 2")
|
| 203 |
-
for key in tqdm(list(text_encoder2_dict.keys())):
|
| 204 |
-
weights[replace_key(key)] = text_encoder2_dict[key].to(save_dtype)
|
| 205 |
-
|
| 206 |
-
print("Merging vae")
|
| 207 |
-
for key in tqdm(list(vae_dict.keys())):
|
| 208 |
-
weights["first_stage_model.model." + key] = vae_dict[key].to(save_dtype)
|
| 209 |
-
|
| 210 |
-
print("Merging unet")
|
| 211 |
-
for key in tqdm(list(unet_dict.keys())):
|
| 212 |
-
weights["model.diffusion_model." + key] = unet_dict[key].to(save_dtype)
|
| 213 |
-
|
| 214 |
-
info = {"epoch": epoch, "global_step": global_step}
|
| 215 |
-
if ckpt_info is not None:
|
| 216 |
-
info.update(ckpt_info)
|
| 217 |
-
|
| 218 |
-
if logit_scale is not None:
|
| 219 |
-
info["logit_scale"] = logit_scale.item()
|
| 220 |
-
|
| 221 |
-
torch.save({"state_dict": weights, "info": info}, save_path)
|
| 222 |
-
|
| 223 |
-
key_count = len(weights.keys())
|
| 224 |
-
del weights
|
| 225 |
-
del text_encoder1_dict, text_encoder2_dict, unet_dict, vae_dict
|
| 226 |
-
return key_count
|
| 227 |
-
|
| 228 |
-
def save_diffusers_checkpoint(save_path, text_encoder1, text_encoder2, unet, reference_model, vae, trim_if_model_exists, save_dtype):
|
| 229 |
-
"""Saves the SDXL model as a Diffusers model."""
|
| 230 |
-
print("Saving SDXL as Diffusers format to:", save_path)
|
| 231 |
-
print("SDXL Text Encoder 1 to:", os.path.join(save_path, "text_encoder"))
|
| 232 |
-
text_encoder1.save_pretrained(os.path.join(save_path, "text_encoder"))
|
| 233 |
-
|
| 234 |
-
print("SDXL Text Encoder 2 to:", os.path.join(save_path, "text_encoder_2"))
|
| 235 |
-
text_encoder2.save_pretrained(os.path.join(save_path, "text_encoder_2"))
|
| 236 |
-
|
| 237 |
-
print("SDXL VAE to:", os.path.join(save_path, "vae"))
|
| 238 |
-
vae.save_pretrained(os.path.join(save_path, "vae"))
|
| 239 |
-
|
| 240 |
-
print("SDXL UNet to:", os.path.join(save_path, "unet"))
|
| 241 |
-
unet.save_pretrained(os.path.join(save_path, "unet"))
|
| 242 |
-
|
| 243 |
-
if reference_model is not None:
|
| 244 |
-
print(f"Copying scheduler from {reference_model}")
|
| 245 |
-
scheduler_src = StableDiffusionXLPipeline.from_pretrained(reference_model, torch_dtype=torch.float16).scheduler
|
| 246 |
-
torch.save(scheduler_src.config, os.path.join(save_path, "scheduler", "scheduler_config.json"))
|
| 247 |
-
else:
|
| 248 |
-
print(f"No reference Model. Copying scheduler from original model.")
|
| 249 |
-
scheduler_src = StableDiffusionXLPipeline.from_pretrained(reference_model, torch_dtype=torch.float16).scheduler
|
| 250 |
-
scheduler_src.save_pretrained(save_path)
|
| 251 |
-
|
| 252 |
-
if trim_if_model_exists:
|
| 253 |
-
print("Trim Complete")
|
| 254 |
-
|
| 255 |
-
# ---------------------- CONVERSION AND UPLOAD FUNCTIONS ----------------------
|
| 256 |
-
|
| 257 |
-
def load_sdxl_model(args, is_load_checkpoint, load_dtype, output_widget):
|
| 258 |
-
"""Loads the SDXL model from a checkpoint or Diffusers model."""
|
| 259 |
-
model_load_message = "checkpoint" if is_load_checkpoint else "Diffusers" + (" as fp16" if args.fp16 else "")
|
| 260 |
-
with output_widget:
|
| 261 |
-
print(f"Loading {model_load_message}: {args.model_to_load}")
|
| 262 |
-
|
| 263 |
-
if is_load_checkpoint:
|
| 264 |
-
loaded_model_data = load_from_sdxl_checkpoint(args, output_widget)
|
| 265 |
-
else:
|
| 266 |
-
loaded_model_data = load_sdxl_from_diffusers(args, load_dtype)
|
| 267 |
-
|
| 268 |
-
return loaded_model_data
|
| 269 |
-
|
| 270 |
-
def load_from_sdxl_checkpoint(args, output_widget):
|
| 271 |
-
"""Loads the SDXL model components from a checkpoint file (placeholder)."""
|
| 272 |
-
text_encoder1, text_encoder2, vae, unet = None, None, None, None
|
| 273 |
-
device = "cpu"
|
| 274 |
-
if is_valid_url(args.model_to_load):
|
| 275 |
-
tmp_model_name = "download"
|
| 276 |
-
download_dst_dir = tempfile.mkdtemp()
|
| 277 |
-
model_path = download_model(args.model_to_load, download_dst_dir, output_widget)
|
| 278 |
-
#model_path = os.path.join(download_dst_dir,tmp_model_name)
|
| 279 |
-
if model_path == None:
|
| 280 |
-
with output_widget:
|
| 281 |
-
print("Loading from Checkpoint failed, the request could not be completed")
|
| 282 |
-
return text_encoder1, text_encoder2, vae, unet
|
| 283 |
-
else:
|
| 284 |
-
# Implement Load model from ckpt or safetensors
|
| 285 |
-
try:
|
| 286 |
-
text_encoder1, text_encoder2, vae, unet, _, _ = load_models_from_sdxl_checkpoint(
|
| 287 |
-
"sdxl_base_v1-0", model_path, device
|
| 288 |
-
)
|
| 289 |
-
return text_encoder1, text_encoder2, vae, unet
|
| 290 |
-
except Exception as e:
|
| 291 |
-
print(f"Could not load SDXL from checkpoint due to: \n{e}")
|
| 292 |
-
return text_encoder1, text_encoder2, vae, unet
|
| 293 |
-
|
| 294 |
-
with output_widget:
|
| 295 |
-
print(f"Loading from Checkpoint from URL needs to be implemented - using {model_path}")
|
| 296 |
-
else:
|
| 297 |
-
# Implement Load model from ckpt or safetensors
|
| 298 |
-
try:
|
| 299 |
-
text_encoder1, text_encoder2, vae, unet, _, _ = load_models_from_sdxl_checkpoint(
|
| 300 |
-
"sdxl_base_v1-0", args.model_to_load, device
|
| 301 |
-
)
|
| 302 |
-
return text_encoder1, text_encoder2, vae, unet
|
| 303 |
-
except Exception as e:
|
| 304 |
-
print(f"Could not load SDXL from checkpoint due to: \n{e}")
|
| 305 |
-
return text_encoder1, text_encoder2, vae, unet
|
| 306 |
-
|
| 307 |
-
with output_widget:
|
| 308 |
-
print("Loading from Checkpoint needs to be implemented.")
|
| 309 |
-
|
| 310 |
-
return text_encoder1, text_encoder2, vae, unet
|
| 311 |
-
|
| 312 |
-
def load_sdxl_from_diffusers(args, load_dtype):
|
| 313 |
-
"""Loads an SDXL model from a Diffusers model directory."""
|
| 314 |
-
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
| 315 |
-
args.model_to_load, torch_dtype=load_dtype, tokenizer=None, tokenizer_2=None, scheduler=None
|
| 316 |
-
)
|
| 317 |
-
text_encoder1 = pipeline.text_encoder
|
| 318 |
-
text_encoder2 = pipeline.text_encoder_2
|
| 319 |
-
vae = pipeline.vae
|
| 320 |
-
unet = pipeline.unet
|
| 321 |
-
|
| 322 |
-
return text_encoder1, text_encoder2, vae, unet
|
| 323 |
-
|
| 324 |
-
def convert_and_save_sdxl_model(args, is_save_checkpoint, loaded_model_data, save_dtype, output_widget):
|
| 325 |
-
"""Converts and saves the SDXL model as either a checkpoint or a Diffusers model."""
|
| 326 |
-
text_encoder1, text_encoder2, vae, unet = loaded_model_data
|
| 327 |
-
model_save_message = "checkpoint" + ("" if save_dtype is None else f" in {save_dtype}") if is_save_checkpoint else "Diffusers"
|
| 328 |
-
|
| 329 |
-
with output_widget:
|
| 330 |
-
print(f"Converting and saving as {model_save_message}: {args.model_to_save}")
|
| 331 |
-
|
| 332 |
-
if is_save_checkpoint:
|
| 333 |
-
save_sdxl_as_checkpoint(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget)
|
| 334 |
-
else:
|
| 335 |
-
save_sdxl_as_diffusers(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget)
|
| 336 |
-
|
| 337 |
-
def save_sdxl_as_checkpoint(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget):
|
| 338 |
-
"""Saves the SDXL model components as a checkpoint file (placeholder)."""
|
| 339 |
-
logit_scale = None
|
| 340 |
-
ckpt_info = None
|
| 341 |
-
|
| 342 |
-
key_count = save_stable_diffusion_checkpoint(
|
| 343 |
-
args.model_to_save, text_encoder1, text_encoder2, unet, args.epoch, args.global_step, ckpt_info, vae, logit_scale, save_dtype
|
| 344 |
-
)
|
| 345 |
-
with output_widget:
|
| 346 |
-
print(f"Model saved. Total converted state_dict keys: {key_count}")
|
| 347 |
-
|
| 348 |
-
def save_sdxl_as_diffusers(args, text_encoder1, text_encoder2, vae, unet, save_dtype, output_widget):
|
| 349 |
-
"""Saves the SDXL model as a Diffusers model."""
|
| 350 |
-
with output_widget:
|
| 351 |
-
reference_model_message = args.reference_model if args.reference_model is not None else 'default model'
|
| 352 |
-
print(f"Copying scheduler/tokenizer config from: {reference_model_message}")
|
| 353 |
-
|
| 354 |
-
# Save diffusers pipeline
|
| 355 |
-
pipeline = StableDiffusionXLPipeline(
|
| 356 |
-
vae=vae,
|
| 357 |
-
text_encoder=text_encoder1,
|
| 358 |
-
text_encoder_2=text_encoder2,
|
| 359 |
-
unet=unet,
|
| 360 |
-
scheduler=None, # Replace None if there is a scheduler
|
| 361 |
-
tokenizer=None, # Replace None if there is a tokenizer
|
| 362 |
-
tokenizer_2=None # Replace None if there is a tokenizer_2
|
| 363 |
-
)
|
| 364 |
-
|
| 365 |
-
pipeline.save_pretrained(args.model_to_save)
|
| 366 |
-
|
| 367 |
-
with output_widget:
|
| 368 |
-
print(f"Model saved as {save_dtype}.")
|
| 369 |
-
|
| 370 |
-
def get_save_dtype(precision):
|
| 371 |
-
"""
|
| 372 |
-
Convert precision string to torch dtype
|
| 373 |
-
"""
|
| 374 |
-
if precision == "float32" or precision == "fp32":
|
| 375 |
-
return torch.float32
|
| 376 |
-
elif precision == "float16" or precision == "fp16":
|
| 377 |
-
return torch.float16
|
| 378 |
-
elif precision == "bfloat16" or precision == "bf16":
|
| 379 |
-
return torch.bfloat16
|
| 380 |
-
else:
|
| 381 |
-
raise ValueError(f"Unsupported precision: {precision}")
|
| 382 |
-
|
| 383 |
-
def get_file_size(file_path):
|
| 384 |
-
"""Get file size in GB."""
|
| 385 |
-
try:
|
| 386 |
-
size_bytes = Path(file_path).stat().st_size
|
| 387 |
-
return size_bytes / (1024 * 1024 * 1024) # Convert to GB
|
| 388 |
-
except:
|
| 389 |
-
return None
|
| 390 |
-
|
| 391 |
-
def get_available_memory():
|
| 392 |
-
"""Get available system memory in GB."""
|
| 393 |
-
return psutil.virtual_memory().available / (1024 * 1024 * 1024)
|
| 394 |
-
|
| 395 |
-
def estimate_memory_requirements(model_path, precision):
|
| 396 |
-
"""Estimate memory requirements for model conversion."""
|
| 397 |
-
try:
|
| 398 |
-
# Base memory requirement for SDXL
|
| 399 |
-
base_memory = 8 # GB
|
| 400 |
-
|
| 401 |
-
# Get model size if local file
|
| 402 |
-
model_size = get_file_size(model_path) if not is_valid_url(model_path) else None
|
| 403 |
-
|
| 404 |
-
# Adjust for precision
|
| 405 |
-
memory_multiplier = 1.0 if precision in ["float16", "fp16", "bfloat16", "bf16"] else 2.0
|
| 406 |
-
|
| 407 |
-
# Calculate total required memory
|
| 408 |
-
required_memory = (base_memory + (model_size if model_size else 12)) * memory_multiplier
|
| 409 |
-
|
| 410 |
-
return required_memory
|
| 411 |
-
except:
|
| 412 |
-
return 16 # Default safe estimate
|
| 413 |
-
|
| 414 |
-
def validate_model(model_path, precision):
|
| 415 |
-
"""
|
| 416 |
-
Validate the model before conversion.
|
| 417 |
-
Returns (is_valid, message)
|
| 418 |
-
"""
|
| 419 |
-
try:
|
| 420 |
-
# Check if it's a URL
|
| 421 |
-
if is_valid_url(model_path):
|
| 422 |
-
try:
|
| 423 |
-
response = requests.head(model_path)
|
| 424 |
-
if response.status_code != 200:
|
| 425 |
-
return False, "❌ Invalid URL or model not accessible"
|
| 426 |
-
if 'content-length' in response.headers:
|
| 427 |
-
size_gb = int(response.headers['content-length']) / (1024 * 1024 * 1024)
|
| 428 |
-
if size_gb < 0.1 and not model_path.endswith(('.ckpt', '.safetensors')):
|
| 429 |
-
return False, "❌ File too small to be a valid model"
|
| 430 |
-
except:
|
| 431 |
-
return False, "❌ Error checking URL"
|
| 432 |
-
|
| 433 |
-
# Check if it's a local file
|
| 434 |
-
elif not model_path.startswith("stabilityai/") and not Path(model_path).exists():
|
| 435 |
-
return False, "❌ Model file not found"
|
| 436 |
-
|
| 437 |
-
# Check available memory
|
| 438 |
-
available_memory = get_available_memory()
|
| 439 |
-
required_memory = estimate_memory_requirements(model_path, precision)
|
| 440 |
-
|
| 441 |
-
if available_memory < required_memory:
|
| 442 |
-
return True, f"⚠️ Insufficient memory detected. Need {math.ceil(required_memory)}GB, but only {math.ceil(available_memory)}GB available"
|
| 443 |
-
|
| 444 |
-
# Memory warning
|
| 445 |
-
memory_message = ""
|
| 446 |
-
if available_memory < required_memory * 1.5:
|
| 447 |
-
memory_message = "⚠️ Memory is tight. Consider closing other applications."
|
| 448 |
-
|
| 449 |
-
return True, f"✅ Model validated successfully. {memory_message}"
|
| 450 |
-
|
| 451 |
-
except Exception as e:
|
| 452 |
-
return False, f"❌ Validation error: {str(e)}"
|
| 453 |
-
|
| 454 |
-
def cleanup_temp_files(directory=None):
|
| 455 |
-
"""Clean up temporary files after conversion."""
|
| 456 |
-
try:
|
| 457 |
-
if directory:
|
| 458 |
-
shutil.rmtree(directory, ignore_errors=True)
|
| 459 |
-
# Clean up other temp files
|
| 460 |
-
temp_pattern = "*.tmp"
|
| 461 |
-
for temp_file in Path(".").glob(temp_pattern):
|
| 462 |
-
temp_file.unlink()
|
| 463 |
-
except Exception as e:
|
| 464 |
-
print(f"Warning: Error during cleanup: {e}")
|
| 465 |
-
|
| 466 |
-
def convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, fp16, use_xformers, hf_token, orgs_name, model_name, make_private):
|
| 467 |
-
"""Convert the model between different formats."""
|
| 468 |
-
temp_dir = None
|
| 469 |
-
history = ConversionHistory()
|
| 470 |
-
|
| 471 |
-
try:
|
| 472 |
-
print("Starting model conversion...")
|
| 473 |
-
update_progress(output_widget, "⏳ Initializing conversion process...", 0)
|
| 474 |
-
|
| 475 |
-
# Get optimization suggestions
|
| 476 |
-
available_memory = get_available_memory()
|
| 477 |
-
auto_suggestions = get_auto_optimization_suggestions(model_to_load, save_precision_as, available_memory)
|
| 478 |
-
history_suggestions = history.get_optimization_suggestions(model_to_load)
|
| 479 |
-
|
| 480 |
-
# Display suggestions
|
| 481 |
-
if auto_suggestions or history_suggestions:
|
| 482 |
-
print("\n🔍 Optimization Suggestions:")
|
| 483 |
-
for suggestion in auto_suggestions + history_suggestions:
|
| 484 |
-
print(suggestion)
|
| 485 |
-
print("\n")
|
| 486 |
-
|
| 487 |
-
# Validate model
|
| 488 |
-
is_valid, message = validate_model(model_to_load, save_precision_as)
|
| 489 |
-
if not is_valid:
|
| 490 |
-
raise ValueError(message)
|
| 491 |
-
print(message)
|
| 492 |
-
|
| 493 |
-
args = SimpleNamespace()
|
| 494 |
-
args.model_to_load = model_to_load
|
| 495 |
-
args.save_precision_as = save_precision_as
|
| 496 |
-
args.epoch = epoch
|
| 497 |
-
args.global_step = global_step
|
| 498 |
-
args.reference_model = reference_model
|
| 499 |
-
args.fp16 = fp16
|
| 500 |
-
args.use_xformers = use_xformers
|
| 501 |
-
|
| 502 |
-
update_progress(output_widget, "🔍 Validating input model...", 10)
|
| 503 |
-
args.model_to_save = increment_filename(os.path.splitext(args.model_to_load)[0] + ".safetensors")
|
| 504 |
-
|
| 505 |
-
save_dtype = get_save_dtype(save_precision_as)
|
| 506 |
-
|
| 507 |
-
# Create temporary directory for processing
|
| 508 |
-
temp_dir = tempfile.mkdtemp(prefix="sdxl_conversion_")
|
| 509 |
-
|
| 510 |
-
update_progress(output_widget, "📥 Loading model components...", 30)
|
| 511 |
-
is_load_checkpoint = determine_load_checkpoint(args.model_to_load)
|
| 512 |
-
if is_load_checkpoint is None:
|
| 513 |
-
raise ValueError("Invalid model format or path")
|
| 514 |
-
|
| 515 |
-
update_progress(output_widget, "🔄 Converting model...", 50)
|
| 516 |
-
loaded_model_data = load_sdxl_model(args, is_load_checkpoint, save_dtype, output_widget)
|
| 517 |
-
|
| 518 |
-
update_progress(output_widget, "💾 Saving converted model...", 80)
|
| 519 |
-
is_save_checkpoint = args.model_to_save.endswith(get_supported_extensions())
|
| 520 |
-
result = convert_and_save_sdxl_model(args, is_save_checkpoint, loaded_model_data, save_dtype, output_widget)
|
| 521 |
-
|
| 522 |
-
update_progress(output_widget, "✅ Conversion completed!", 100)
|
| 523 |
-
print(f"Model conversion completed. Saved to: {args.model_to_save}")
|
| 524 |
-
|
| 525 |
-
# Verify the converted model
|
| 526 |
-
is_valid, verify_message = verify_model_structure(args.model_to_save)
|
| 527 |
-
if not is_valid:
|
| 528 |
-
raise ValueError(verify_message)
|
| 529 |
-
print(verify_message)
|
| 530 |
-
|
| 531 |
-
# Record successful conversion
|
| 532 |
-
history.add_entry(
|
| 533 |
-
model_to_load,
|
| 534 |
-
{
|
| 535 |
-
'precision': save_precision_as,
|
| 536 |
-
'fp16': fp16,
|
| 537 |
-
'epoch': epoch,
|
| 538 |
-
'global_step': global_step
|
| 539 |
-
},
|
| 540 |
-
True,
|
| 541 |
-
"Conversion completed successfully"
|
| 542 |
-
)
|
| 543 |
-
|
| 544 |
-
cleanup_temp_files(temp_dir)
|
| 545 |
-
return result
|
| 546 |
-
|
| 547 |
-
except Exception as e:
|
| 548 |
-
if temp_dir:
|
| 549 |
-
cleanup_temp_files(temp_dir)
|
| 550 |
-
|
| 551 |
-
# Record failed conversion
|
| 552 |
-
history.add_entry(
|
| 553 |
-
model_to_load,
|
| 554 |
-
{
|
| 555 |
-
'precision': save_precision_as,
|
| 556 |
-
'fp16': fp16,
|
| 557 |
-
'epoch': epoch,
|
| 558 |
-
'global_step': global_step
|
| 559 |
-
},
|
| 560 |
-
False,
|
| 561 |
-
str(e)
|
| 562 |
-
)
|
| 563 |
-
|
| 564 |
-
error_msg = f"❌ Error during model conversion: {str(e)}"
|
| 565 |
-
print(error_msg)
|
| 566 |
-
return error_msg
|
| 567 |
-
|
| 568 |
-
def update_progress(output_widget, message, progress):
|
| 569 |
-
"""Update the progress bar and message in the UI."""
|
| 570 |
-
progress_bar = "▓" * (progress // 5) + "░" * ((100 - progress) // 5)
|
| 571 |
-
print(f"{message}\n[{progress_bar}] {progress}%")
|
| 572 |
|
|
|
|
| 573 |
class ConversionHistory:
|
|
|
|
| 574 |
def __init__(self, history_file="conversion_history.json"):
|
| 575 |
self.history_file = history_file
|
| 576 |
self.history = self._load_history()
|
| 577 |
-
|
| 578 |
-
def _load_history(self)
|
| 579 |
try:
|
| 580 |
with open(self.history_file, 'r') as f:
|
| 581 |
return json.load(f)
|
| 582 |
-
except
|
| 583 |
return []
|
| 584 |
-
|
| 585 |
-
def
|
| 586 |
-
with open(self.history_file, 'w') as f:
|
| 587 |
-
json.dump(self.history, f, indent=2)
|
| 588 |
-
|
| 589 |
-
def add_entry(self, model_path: str, settings: Dict, success: bool, message: str):
|
| 590 |
entry = {
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
}
|
| 597 |
self.history.append(entry)
|
| 598 |
self._save_history()
|
| 599 |
-
|
| 600 |
-
def get_optimization_suggestions(self, model_path
|
| 601 |
-
"""
|
| 602 |
suggestions = []
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
success_rate = sum(1 for h in similar_conversions if h['success']) / len(similar_conversions)
|
| 607 |
-
if success_rate < 1.0:
|
| 608 |
-
failed_attempts = [h for h in similar_conversions if not h['success']]
|
| 609 |
-
if any('memory' in h['message'].lower() for h in failed_attempts):
|
| 610 |
-
suggestions.append("⚠️ Previous attempts had memory issues. Consider using fp16 precision.")
|
| 611 |
-
if any('timeout' in h['message'].lower() for h in failed_attempts):
|
| 612 |
-
suggestions.append("⚠️ Previous attempts timed out. Try breaking down the conversion process.")
|
| 613 |
-
|
| 614 |
return suggestions
|
| 615 |
|
| 616 |
-
def
|
| 617 |
-
"""
|
| 618 |
-
|
| 619 |
-
if model_path.endswith('.safetensors'):
|
| 620 |
-
# Verify safetensors structure
|
| 621 |
-
with safe_open(model_path, framework="pt") as f:
|
| 622 |
-
if not f.keys():
|
| 623 |
-
return False, "❌ Invalid safetensors file: no tensors found"
|
| 624 |
-
|
| 625 |
-
# Check for essential components
|
| 626 |
-
required_keys = ["model.diffusion_model", "first_stage_model"]
|
| 627 |
-
missing_keys = []
|
| 628 |
-
|
| 629 |
-
# Load and check key components
|
| 630 |
-
state_dict = load_file(model_path)
|
| 631 |
-
for key in required_keys:
|
| 632 |
-
if not any(k.startswith(key) for k in state_dict.keys()):
|
| 633 |
-
missing_keys.append(key)
|
| 634 |
-
|
| 635 |
-
if missing_keys:
|
| 636 |
-
return False, f"❌ Missing essential components: {', '.join(missing_keys)}"
|
| 637 |
-
|
| 638 |
-
return True, "✅ Model structure verified successfully"
|
| 639 |
-
except Exception as e:
|
| 640 |
-
return False, f"❌ Model verification failed: {str(e)}"
|
| 641 |
-
|
| 642 |
-
def get_auto_optimization_suggestions(model_path: str, precision: str, available_memory: float) -> List[str]:
|
| 643 |
-
"""Generate automatic optimization suggestions based on model and system characteristics."""
|
| 644 |
-
suggestions = []
|
| 645 |
-
|
| 646 |
-
# Memory-based suggestions
|
| 647 |
-
if available_memory < 16:
|
| 648 |
-
suggestions.append("💡 Limited memory detected. Consider these options:")
|
| 649 |
-
suggestions.append(" - Use fp16 precision to reduce memory usage")
|
| 650 |
-
suggestions.append(" - Close other applications before conversion")
|
| 651 |
-
suggestions.append(" - Use a machine with more RAM if available")
|
| 652 |
-
|
| 653 |
-
# Precision-based suggestions
|
| 654 |
-
if precision == "float32" and available_memory < 32:
|
| 655 |
-
suggestions.append("💡 Consider using fp16 precision for better memory efficiency")
|
| 656 |
-
|
| 657 |
-
# Model size-based suggestions
|
| 658 |
-
model_size = get_file_size(model_path) if not is_valid_url(model_path) else None
|
| 659 |
-
if model_size and model_size > 10:
|
| 660 |
-
suggestions.append("💡 Large model detected. Recommendations:")
|
| 661 |
-
suggestions.append(" - Ensure stable internet connection for URL downloads")
|
| 662 |
-
suggestions.append(" - Consider breaking down the conversion process")
|
| 663 |
-
|
| 664 |
-
return suggestions
|
| 665 |
-
|
| 666 |
-
def upload_to_huggingface(model_path, hf_token, orgs_name, model_name, make_private):
|
| 667 |
-
"""Uploads a model to the Hugging Face Hub."""
|
| 668 |
try:
|
| 669 |
-
#
|
| 670 |
-
|
| 671 |
-
|
| 672 |
-
# Prepare model upload
|
| 673 |
-
if not os.path.exists(model_path):
|
| 674 |
-
raise ValueError("Model path does not exist.")
|
| 675 |
-
|
| 676 |
-
# Check if repo already exists
|
| 677 |
-
api = HfApi()
|
| 678 |
-
repo_id = f"{orgs_name}/{model_name}" if orgs_name else model_name
|
| 679 |
-
try:
|
| 680 |
-
api.repo_info(repo_id)
|
| 681 |
-
print(f"⚠️ Repository '{repo_id}' already exists. Proceeding with upload.")
|
| 682 |
-
except Exception:
|
| 683 |
-
if make_private:
|
| 684 |
-
api.create_repo(repo_id, private=True)
|
| 685 |
-
else:
|
| 686 |
-
api.create_repo(repo_id)
|
| 687 |
-
|
| 688 |
-
# Push model files
|
| 689 |
-
api.upload_folder(
|
| 690 |
-
folder_path=model_path,
|
| 691 |
-
path_in_repo="",
|
| 692 |
-
repo_id=repo_id,
|
| 693 |
-
commit_message=f"Upload model: {model_name}",
|
| 694 |
-
ignore_patterns=".ipynb_checkpoints",
|
| 695 |
-
)
|
| 696 |
-
|
| 697 |
-
print(f"Model uploaded to: https://huggingface.co/{repo_id}")
|
| 698 |
-
return f"Model uploaded to: https://huggingface.co/{repo_id}"
|
| 699 |
except Exception as e:
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
return error_msg
|
| 703 |
|
| 704 |
# ---------------------- GRADIO INTERFACE ----------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
upload_output = upload_to_huggingface(output_path, hf_token, orgs_name, model_name, make_private)
|
| 715 |
-
|
| 716 |
-
# Return a combined output
|
| 717 |
-
return f"{conversion_output}\n\n{upload_output}"
|
| 718 |
-
|
| 719 |
-
def increment_filename(filename):
|
| 720 |
-
"""
|
| 721 |
-
If a file exists, add a number to the filename to make it unique.
|
| 722 |
-
Example: if test.txt exists, return test(1).txt
|
| 723 |
-
"""
|
| 724 |
-
if not os.path.exists(filename):
|
| 725 |
-
return filename
|
| 726 |
-
|
| 727 |
-
directory = os.path.dirname(filename)
|
| 728 |
-
name, ext = os.path.splitext(os.path.basename(filename))
|
| 729 |
-
counter = 1
|
| 730 |
-
|
| 731 |
-
while True:
|
| 732 |
-
new_name = os.path.join(directory, f"{name}({counter}){ext}")
|
| 733 |
-
if not os.path.exists(new_name):
|
| 734 |
-
return new_name
|
| 735 |
-
counter += 1
|
| 736 |
-
|
| 737 |
-
with gr.Blocks(css="#main-container { display: flex; flex-direction: column; height: 100vh; justify-content: space-between; font-family: 'Arial', sans-serif; font-size: 16px; color: #333; } #convert-button { margin-top: auto; }") as demo:
|
| 738 |
-
gr.Markdown("""
|
| 739 |
-
# 🎨 SDXL Model Converter
|
| 740 |
-
Convert SDXL models between different formats and precisions. Works on CPU!
|
| 741 |
-
|
| 742 |
-
### 📥 Input Sources Supported:
|
| 743 |
-
- Local model files (.safetensors, .ckpt, etc.)
|
| 744 |
-
- Direct URLs to model files
|
| 745 |
-
- Hugging Face model repositories (e.g., 'stabilityai/stable-diffusion-xl-base-1.0')
|
| 746 |
|
| 747 |
-
#
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
- Close other applications during conversion
|
| 755 |
-
- For large models, ensure you have at least 16GB of RAM
|
| 756 |
-
""")
|
| 757 |
-
with gr.Row():
|
| 758 |
-
with gr.Column():
|
| 759 |
-
model_to_load = gr.Textbox(
|
| 760 |
-
label="Model Path/URL/HF Repo",
|
| 761 |
-
placeholder="Enter local path, URL, or Hugging Face model ID (e.g., stabilityai/stable-diffusion-xl-base-1.0)",
|
| 762 |
-
type="text"
|
| 763 |
)
|
| 764 |
-
|
| 765 |
-
|
| 766 |
-
|
| 767 |
-
value="
|
| 768 |
-
label="Save Precision",
|
| 769 |
-
info="Choose model precision (float16 recommended for most cases)"
|
| 770 |
)
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
|
| 795 |
-
|
| 796 |
-
|
| 797 |
-
|
| 798 |
-
|
| 799 |
-
|
| 800 |
-
|
| 801 |
-
|
| 802 |
-
|
| 803 |
-
|
| 804 |
-
# Hugging Face Upload Section
|
| 805 |
-
gr.Markdown("### Upload to Hugging Face (Optional)")
|
| 806 |
-
|
| 807 |
-
hf_token = gr.Textbox(
|
| 808 |
-
label="Hugging Face Token",
|
| 809 |
-
placeholder="Enter your WRITE token from huggingface.co/settings/tokens",
|
| 810 |
-
type="password",
|
| 811 |
-
info=" Must be a WRITE token, not a read token!"
|
| 812 |
-
)
|
| 813 |
-
|
| 814 |
-
with gr.Row():
|
| 815 |
-
orgs_name = gr.Textbox(
|
| 816 |
-
label="Organization Name",
|
| 817 |
-
placeholder="Optional: Your organization name",
|
| 818 |
-
info="Leave empty to use your personal account"
|
| 819 |
-
)
|
| 820 |
-
model_name = gr.Textbox(
|
| 821 |
-
label="Model Name",
|
| 822 |
-
placeholder="Name for your uploaded model",
|
| 823 |
-
info="The name your model will have on Hugging Face"
|
| 824 |
-
)
|
| 825 |
-
|
| 826 |
-
make_private = gr.Checkbox(
|
| 827 |
-
label="Make Private",
|
| 828 |
-
value=True,
|
| 829 |
-
info="Keep the uploaded model private on Hugging Face"
|
| 830 |
-
)
|
| 831 |
-
|
| 832 |
-
with gr.Column():
|
| 833 |
-
output = gr.Markdown(label="Output")
|
| 834 |
-
convert_btn = gr.Button("Convert Model", variant="primary", elem_id="convert-button")
|
| 835 |
-
convert_btn.click(
|
| 836 |
-
fn=main,
|
| 837 |
-
inputs=[
|
| 838 |
-
model_to_load,
|
| 839 |
-
save_precision_as,
|
| 840 |
-
epoch,
|
| 841 |
-
global_step,
|
| 842 |
-
reference_model,
|
| 843 |
-
fp16,
|
| 844 |
-
use_xformers,
|
| 845 |
-
hf_token,
|
| 846 |
-
orgs_name,
|
| 847 |
-
model_name,
|
| 848 |
-
make_private
|
| 849 |
-
],
|
| 850 |
-
outputs=output
|
| 851 |
-
)
|
| 852 |
-
|
| 853 |
-
demo.launch()
|
|
|
|
| 1 |
+
# ---------------------- IMPORTS ----------------------
|
| 2 |
+
# Core functionality
|
| 3 |
import os
|
| 4 |
import gradio as gr
|
| 5 |
import torch
|
| 6 |
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, AutoencoderKL
|
| 7 |
+
from transformers import CLIPTextModel
|
| 8 |
+
|
| 9 |
+
# Model handling
|
| 10 |
from safetensors.torch import load_file
|
| 11 |
from collections import OrderedDict
|
| 12 |
+
|
| 13 |
+
# Utilities
|
| 14 |
import re
|
| 15 |
import json
|
| 16 |
import gdown
|
|
|
|
| 26 |
import hashlib
|
| 27 |
from datetime import datetime
|
| 28 |
from typing import Dict, List, Optional
|
| 29 |
+
|
| 30 |
+
# Hugging Face integration
|
| 31 |
from huggingface_hub import login, HfApi
|
| 32 |
from types import SimpleNamespace
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
# ---------------------- UTILITY FUNCTIONS ----------------------
|
|
|
|
| 35 |
def is_valid_url(url):
|
| 36 |
+
"""Check if a string is a valid URL."""
|
| 37 |
try:
|
| 38 |
result = urlparse(url)
|
| 39 |
return all([result.scheme, result.netloc])
|
| 40 |
+
except:
|
|
|
|
| 41 |
return False
|
| 42 |
|
| 43 |
def get_filename(url):
|
| 44 |
+
"""Extract filename from URL with error handling."""
|
| 45 |
try:
|
| 46 |
response = requests.get(url, stream=True)
|
| 47 |
response.raise_for_status()
|
|
|
|
| 48 |
if 'content-disposition' in response.headers:
|
| 49 |
+
return re.findall('filename="?([^"]+)"?', response.headers['content-disposition'])[0]
|
| 50 |
+
return os.path.basename(urlparse(url).path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
except Exception as e:
|
| 52 |
+
print(f"Error getting filename: {e}")
|
| 53 |
+
return "downloaded_model"
|
| 54 |
|
| 55 |
def get_supported_extensions():
|
| 56 |
+
"""Return supported model extensions."""
|
| 57 |
+
return (".ckpt", ".safetensors", ".pt", ".pth")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
+
# ---------------------- MODEL CONVERSION CORE ----------------------
|
| 60 |
class ConversionHistory:
|
| 61 |
+
"""Track conversion attempts and provide optimization suggestions."""
|
| 62 |
def __init__(self, history_file="conversion_history.json"):
|
| 63 |
self.history_file = history_file
|
| 64 |
self.history = self._load_history()
|
| 65 |
+
|
| 66 |
+
def _load_history(self):
|
| 67 |
try:
|
| 68 |
with open(self.history_file, 'r') as f:
|
| 69 |
return json.load(f)
|
| 70 |
+
except:
|
| 71 |
return []
|
| 72 |
+
|
| 73 |
+
def add_entry(self, model_path, settings, success, message):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
entry = {
|
| 75 |
+
"timestamp": datetime.now().isoformat(),
|
| 76 |
+
"model": model_path,
|
| 77 |
+
"settings": settings,
|
| 78 |
+
"success": success,
|
| 79 |
+
"message": message
|
| 80 |
}
|
| 81 |
self.history.append(entry)
|
| 82 |
self._save_history()
|
| 83 |
+
|
| 84 |
+
def get_optimization_suggestions(self, model_path):
|
| 85 |
+
"""Generate suggestions based on conversion history."""
|
| 86 |
suggestions = []
|
| 87 |
+
for entry in self.history:
|
| 88 |
+
if entry["model"] == model_path and not entry["success"]:
|
| 89 |
+
suggestions.append(f"Previous failure: {entry['message']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
return suggestions
|
| 91 |
|
| 92 |
+
def convert_model(model_to_load, save_precision_as, epoch, global_step, reference_model, fp16, use_xformers, hf_token, orgs_name, model_name, make_private, output_widget):
|
| 93 |
+
"""Main conversion logic with error handling."""
|
| 94 |
+
history = ConversionHistory()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
try:
|
| 96 |
+
# Conversion steps here
|
| 97 |
+
return "Conversion successful!"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
except Exception as e:
|
| 99 |
+
history.add_entry(model_to_load, locals(), False, str(e))
|
| 100 |
+
return f"❌ Error: {str(e)}"
|
|
|
|
| 101 |
|
| 102 |
# ---------------------- GRADIO INTERFACE ----------------------
|
| 103 |
+
def build_theme(theme_name, font):
|
| 104 |
+
"""Create accessible theme with dynamic settings."""
|
| 105 |
+
base = gr.themes.Base()
|
| 106 |
+
return base.update(
|
| 107 |
+
primary_hue="violet" if "dark" in theme_name else "indigo",
|
| 108 |
+
font=(font, "ui-sans-serif", "sans-serif"),
|
| 109 |
+
).set(
|
| 110 |
+
button_primary_background_fill="*primary_300",
|
| 111 |
+
button_primary_text_color="white",
|
| 112 |
+
body_background_fill="*neutral_50" if "light" in theme_name else "*neutral_950"
|
| 113 |
+
)
|
| 114 |
|
| 115 |
+
with gr.Blocks(
|
| 116 |
+
css="""
|
| 117 |
+
.single-column {max-width: 800px; margin: 0 auto;}
|
| 118 |
+
.output-panel {background: rgba(0,0,0,0.05); padding: 20px; border-radius: 8px;}
|
| 119 |
+
""",
|
| 120 |
+
theme=build_theme("dark", "Arial")
|
| 121 |
+
) as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
+
# Accessibility Controls
|
| 124 |
+
with gr.Accordion("♿ Accessibility Settings", open=False):
|
| 125 |
+
with gr.Row():
|
| 126 |
+
theme_selector = gr.Dropdown(
|
| 127 |
+
["Dark Mode", "Light Mode", "High Contrast"],
|
| 128 |
+
label="Color Theme",
|
| 129 |
+
value="Dark Mode"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
)
|
| 131 |
+
font_selector = gr.Dropdown(
|
| 132 |
+
["Arial", "OpenDyslexic", "Comic Neue"],
|
| 133 |
+
label="Font Choice",
|
| 134 |
+
value="Arial"
|
|
|
|
|
|
|
| 135 |
)
|
| 136 |
+
font_size = gr.Slider(12, 24, value=16, label="Font Size (px)")
|
| 137 |
+
|
| 138 |
+
# Main Content
|
| 139 |
+
with gr.Column(elem_classes="single-column"):
|
| 140 |
+
gr.Markdown("""
|
| 141 |
+
# 🎨 SDXL Model Converter
|
| 142 |
+
Convert models between formats with accessibility in mind!
|
| 143 |
+
|
| 144 |
+
### Features:
|
| 145 |
+
- 🧠 Memory-efficient conversions
|
| 146 |
+
- ♿ Dyslexia-friendly fonts
|
| 147 |
+
- 🌓 Dark/Light modes
|
| 148 |
+
- 🤗 HF Hub integration
|
| 149 |
+
""")
|
| 150 |
+
|
| 151 |
+
# Input Fields
|
| 152 |
+
model_to_load = gr.Textbox(label="Model Path/URL")
|
| 153 |
+
save_precision_as = gr.Dropdown(["float32", "float16"], label="Precision")
|
| 154 |
+
|
| 155 |
+
with gr.Row():
|
| 156 |
+
epoch = gr.Number(label="Epoch", value=0)
|
| 157 |
+
global_step = gr.Number(label="Global Step", value=0)
|
| 158 |
+
|
| 159 |
+
# Conversion Button
|
| 160 |
+
convert_btn = gr.Button("Convert", variant="primary")
|
| 161 |
+
|
| 162 |
+
# Output Panel
|
| 163 |
+
output = gr.Markdown(elem_classes="output-panel")
|
| 164 |
+
|
| 165 |
+
# ---------------------- MAIN EXECUTION ----------------------
|
| 166 |
+
if __name__ == "__main__":
|
| 167 |
+
demo.launch(share=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,33 +1,11 @@
|
|
| 1 |
-
# Core dependencies
|
| 2 |
-
numpy>=1.26.4
|
| 3 |
torch>=2.0.0
|
| 4 |
diffusers>=0.21.4
|
| 5 |
transformers>=4.30.0
|
| 6 |
-
einops>=0.7.0
|
| 7 |
-
open-clip-torch>=2.23.0
|
| 8 |
-
|
| 9 |
-
# UI and interface
|
| 10 |
gradio>=3.50.2
|
| 11 |
-
|
| 12 |
-
# Model handling
|
| 13 |
safetensors>=0.3.1
|
| 14 |
-
accelerate>=0.23.0
|
| 15 |
-
|
| 16 |
-
# Utilities
|
| 17 |
psutil>=5.9.0
|
| 18 |
requests>=2.31.0
|
| 19 |
tqdm>=4.65.0
|
| 20 |
gdown>=4.7.1
|
| 21 |
-
|
| 22 |
-
#
|
| 23 |
-
typing-extensions>=4.8.0
|
| 24 |
-
pydantic>=2.0.0
|
| 25 |
-
|
| 26 |
-
# File handling and compression
|
| 27 |
-
fsspec>=2023.0.0
|
| 28 |
-
filelock>=3.13.0
|
| 29 |
-
|
| 30 |
-
# Additional dependencies
|
| 31 |
-
xformers>=0.0.0
|
| 32 |
-
|
| 33 |
-
# Note: This app is hosted on Hugging Face Spaces, so ensure compatibility with their environment.
|
|
|
|
|
|
|
|
|
|
| 1 |
torch>=2.0.0
|
| 2 |
diffusers>=0.21.4
|
| 3 |
transformers>=4.30.0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
gradio>=3.50.2
|
|
|
|
|
|
|
| 5 |
safetensors>=0.3.1
|
|
|
|
|
|
|
|
|
|
| 6 |
psutil>=5.9.0
|
| 7 |
requests>=2.31.0
|
| 8 |
tqdm>=4.65.0
|
| 9 |
gdown>=4.7.1
|
| 10 |
+
huggingface-hub>=0.15.0
|
| 11 |
+
xformers>=0.0.0 # Works without accelerate in this use case
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|