File size: 12,806 Bytes
1bf81fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 |
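"""
Convert a Diffusers-format FLUX model folder into a single ComfyUI checkpoint.

Reads the Diffusers folder layout (model_index.json plus one subfolder per
component) and writes one .safetensors file, either by merging all component
weights directly or by overwriting matching tensors in a working ComfyUI
checkpoint that is used as a template.
"""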
from safetensors.torch import save_file, load_file
import os
import json
from pathlib import Path
def check_diffusers_structure(diffusers_folder):
    """
    Print an overview of a Diffusers model folder and the FLUX components it holds.

    Args:
        diffusers_folder: Path (str or Path) of the folder expected to contain
            model_index.json.

    Returns:
        False if the folder does not exist or has no model_index.json; True otherwise
        (missing components are reported but do not make the check fail).
    """
    root = Path(diffusers_folder)
    banner = "=" * 80

    print(banner)
    print("CHECKING DIFFUSERS FOLDER STRUCTURE")
    print(banner)
    print(f"\nFolder: {root}")

    # Guard clauses: stop before walking the tree if this cannot be a Diffusers folder.
    if not root.exists():
        print(f"❌ Folder does not exist!")
        return False
    if not (root / "model_index.json").exists():
        print(f"❌ Not a Diffusers model (missing model_index.json)")
        return False
    print(f"✅ Found model_index.json")

    print("\nFolder contents:")
    for entry in sorted(root.iterdir()):
        if not entry.is_dir():
            print(f" 📄 {entry.name} ({entry.stat().st_size / (1024**2):.1f} MB)")
            continue
        print(f" 📁 {entry.name}/")
        children = sorted(entry.iterdir())
        # Only the first five children are listed; the rest are summarized as a count.
        for child in children[:5]:
            print(f" - {child.name} ({child.stat().st_size / (1024**2):.1f} MB)")
        hidden = len(children) - 5
        if hidden > 0:
            print(f" ... and {hidden} more files")

    print("\n" + banner)
    print("Component Check")
    print(banner)
    expected_components = (
        ("transformer", "Main FLUX transformer model"),
        ("vae", "VAE encoder/decoder"),
        ("text_encoder", "CLIP text encoder"),
        ("text_encoder_2", "T5 text encoder"),
    )
    for name, description in expected_components:
        component_dir = root / name
        if not component_dir.exists():
            print(f"❌ {name}: missing")
            continue
        weight_files = list(component_dir.glob("*.safetensors"))
        if not weight_files:
            print(f"⚠️ {name}: folder exists but no .safetensors files")
            continue
        print(f"✅ {name}: {description}")
        print(f" Found: {weight_files[0].name}")
    return True
def _find_component_weight_files(component_dir):
    """
    Return the .safetensors files that make up one Diffusers component folder.

    If the folder contains a ``*.safetensors.index.json`` file (sharded-checkpoint
    index whose ``weight_map`` maps tensor names to shard files), exactly the shards
    it names are returned, so every shard is loaded and unrelated .safetensors files
    are ignored. Otherwise all ``*.safetensors`` files are returned, sorted by name so
    the result does not depend on directory-listing order.

    Returns:
        A list of Paths; empty if ``component_dir`` is not a directory or holds no weights.

    Raises:
        FileNotFoundError: If an index references a shard file that does not exist.
    """
    component_dir = Path(component_dir)
    if not component_dir.is_dir():
        return []
    for index_file in sorted(component_dir.glob("*.safetensors.index.json")):
        with open(index_file, encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})
        shard_names = sorted(set(weight_map.values()))
        if not shard_names:
            continue  # empty/odd index: fall back to globbing
        shards = [component_dir / name for name in shard_names]
        missing = [shard.name for shard in shards if not shard.is_file()]
        if missing:
            raise FileNotFoundError(
                f"Shard file(s) listed in {index_file} are missing: {', '.join(missing)}"
            )
        return shards
    return sorted(component_dir.glob("*.safetensors"))


def _load_component_state(files, fp16):
    """
    Load every tensor from ``files`` into one dict, optionally casting to float16.

    Args:
        files: .safetensors paths belonging to a single component.
        fp16: If True, floating-point tensors are converted with ``.half()``.

    Raises:
        ValueError: If the same tensor name is stored in more than one file (for
            example a full-precision and a variant copy of the weights side by side).
    """
    state = {}
    for path in files:
        for key, value in load_file(str(path)).items():
            if key in state:
                raise ValueError(
                    f"Tensor '{key}' is stored in more than one file in {path.parent}; "
                    "remove duplicate/variant weight files or add a .safetensors.index.json"
                )
            if fp16 and value.dtype.is_floating_point:
                value = value.half()
            state[key] = value
    return state


def _merge_into(merged_state, component_state, label):
    """
    Add ``component_state`` to ``merged_state`` without overwriting existing tensors.

    Component tensors keep their Diffusers names (no prefix is added), so a name shared
    by two components would otherwise silently replace the earlier component's tensor.

    Raises:
        ValueError: If any key of ``component_state`` is already in ``merged_state``.
    """
    clashes = merged_state.keys() & component_state.keys()
    if clashes:
        examples = ", ".join(sorted(clashes)[:5])
        raise ValueError(
            f"{label} has {len(clashes)} tensor name(s) already used by another component "
            f"(e.g. {examples}); refusing to overwrite them"
        )
    merged_state.update(component_state)


def convert_diffusers_to_comfyui(
    diffusers_folder,
    output_path,
    fp16=False
):
    """
    Merge the weights of a Diffusers FLUX model folder into a single .safetensors file.

    Loads the ``transformer`` component (required) and, when present, ``vae``,
    ``text_encoder`` (CLIP) and ``text_encoder_2`` (T5). Sharded components are loaded
    in full. Tensor names are written exactly as stored in the Diffusers files; no
    component prefixes are added.
    NOTE(review): confirm the consuming loader accepts Diffusers-style key names.

    Args:
        diffusers_folder: Path to the folder containing model_index.json.
        output_path: Output path for the merged .safetensors file; missing parent
            directories are created.
        fp16: If True, cast floating-point tensors to float16 to save space (values
            outside the float16 range become inf).

    Returns:
        output_path, unchanged.

    Raises:
        ValueError: If model_index.json, the transformer folder or its weights are
            missing, or if a tensor name occurs twice (within or across components).
        FileNotFoundError: If a shard listed in a ``*.safetensors.index.json`` is missing.
    """
    diffusers_folder = Path(diffusers_folder)

    # Verify it's a Diffusers model
    model_index = diffusers_folder / "model_index.json"
    if not model_index.exists():
        raise ValueError(f"Not a Diffusers model folder. Missing: {model_index}")
    with open(model_index, encoding="utf-8") as f:
        config = json.load(f)

    print("=" * 80)
    print("DIFFUSERS TO COMFYUI CONVERTER")
    print("=" * 80)
    print(f"\nModel: {config.get('_name_or_path', 'Unknown')}")
    print(f"Format: {config.get('_class_name', 'Unknown')}")

    merged_state = {}

    # 1. Transformer (main FLUX model) - required.
    print("\n" + "=" * 80)
    print("Loading Transformer...")
    print("=" * 80)
    transformer_path = diffusers_folder / "transformer"
    if not transformer_path.is_dir():
        raise ValueError(f"Transformer folder not found: {transformer_path}")
    transformer_files = _find_component_weight_files(transformer_path)
    if not transformer_files:
        print(f"\n❌ No safetensors file found in: {transformer_path}")
        print("\nFiles in transformer folder:")
        for file in transformer_path.iterdir():
            print(f" - {file.name}")
        raise ValueError(f"No safetensors file found in {transformer_path}")
    print(f"Found: {', '.join(p.name for p in transformer_files)}")
    transformer_state = _load_component_state(transformer_files, fp16)
    print(f"Loaded {len(transformer_state)} transformer parameters")
    _merge_into(merged_state, transformer_state, "Transformer")

    # 2. VAE - optional; skipped with a warning when no weights are found.
    print("\n" + "=" * 80)
    print("Loading VAE...")
    print("=" * 80)
    vae_files = _find_component_weight_files(diffusers_folder / "vae")
    if not vae_files:
        print("⚠️ No VAE file found, skipping...")
    else:
        print(f"Found: {', '.join(p.name for p in vae_files)}")
        vae_state = _load_component_state(vae_files, fp16)
        print(f"Loaded {len(vae_state)} VAE parameters")
        _merge_into(merged_state, vae_state, "VAE")

    # 3. Text encoders (CLIP + T5) - optional; an absent folder is skipped silently,
    # a folder without weights is reported.
    print("\n" + "=" * 80)
    print("Loading Text Encoders...")
    print("=" * 80)
    for subfolder, label in (("text_encoder", "CLIP"), ("text_encoder_2", "T5")):
        encoder_path = diffusers_folder / subfolder
        if not encoder_path.exists():
            continue
        encoder_files = _find_component_weight_files(encoder_path)
        if not encoder_files:
            print(f"⚠️ No {label} file found")
            continue
        print(f"Found {label}: {', '.join(p.name for p in encoder_files)}")
        if label == "T5":
            # T5 is often the largest component.
            print("⚠️ Loading T5 (this may take a while, it's large)...")
        encoder_state = _load_component_state(encoder_files, fp16)
        print(f"Loaded {len(encoder_state)} {label} parameters")
        _merge_into(merged_state, encoder_state, label)

    # Save merged checkpoint
    print("\n" + "=" * 80)
    print("Saving merged checkpoint...")
    print("=" * 80)
    print(f"Total parameters: {len(merged_state):,}")
    print(f"Output: {output_path}")
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    save_file(merged_state, str(output_file))
    size_gb = output_file.stat().st_size / (1024**3)
    print(f"\n✅ Conversion complete!")
    print(f"File size: {size_gb:.2f} GB")

    # Show key structure
    print("\n" + "=" * 80)
    print("Key Structure in Merged File")
    print("=" * 80)
    print("\nFirst 10 keys:")
    for key in list(merged_state)[:10]:
        print(f" {key}")
    return output_path
def convert_with_working_template(
    diffusers_folder,
    working_checkpoint,
    output_path,
    replace_transformer_only=True
):
    """
    Copy a working ComfyUI checkpoint and overwrite tensors with Diffusers weights.

    A template tensor is replaced only when the Diffusers component contains a tensor
    under exactly the same key. If the template uses a different naming scheme than
    the Diffusers files, nothing matches: a warning is printed and the template
    tensors are kept.

    Template keys are grouped by name heuristics: text-encoder keys contain
    "text_encoder" or "clip", VAE keys contain "vae" or "first_stage", and every
    remaining key is treated as a transformer key.

    Args:
        diffusers_folder: Path to the Diffusers model folder (needs a ``transformer``
            subfolder; ``vae``, ``text_encoder`` and ``text_encoder_2`` are optional).
        working_checkpoint: Path to a working ComfyUI .safetensors checkpoint.
        output_path: Output path for the merged checkpoint; missing parent
            directories are created.
        replace_transformer_only: If True, only transformer tensors are replaced and
            VAE/text-encoder tensors are kept from the template. If False, VAE, CLIP
            and T5 tensors are replaced as well.

    Returns:
        output_path, unchanged.

    Raises:
        ValueError: If the transformer folder is missing or contains no .safetensors weights.
        FileNotFoundError: If a shard listed in a ``*.safetensors.index.json`` is missing.
    """

    def find_weight_files(component_dir):
        """Return the .safetensors files of one component folder (index-aware, sorted)."""
        if not component_dir.is_dir():
            return []
        # A *.safetensors.index.json "weight_map" names every shard of a sharded
        # component; taking one arbitrary *.safetensors file would drop weights.
        for index_file in sorted(component_dir.glob("*.safetensors.index.json")):
            with open(index_file, encoding="utf-8") as f:
                shard_names = sorted(set(json.load(f).get("weight_map", {}).values()))
            if not shard_names:
                continue
            shards = [component_dir / name for name in shard_names]
            missing = [shard.name for shard in shards if not shard.is_file()]
            if missing:
                raise FileNotFoundError(
                    f"Shard file(s) listed in {index_file} are missing: {', '.join(missing)}"
                )
            return shards
        return sorted(component_dir.glob("*.safetensors"))

    def load_component(files):
        """Load all tensors from ``files`` into a single dict."""
        state = {}
        for path in files:
            state.update(load_file(str(path)))
        return state

    def replace_matching(merged_state, candidate_keys, new_state, label):
        """Overwrite template tensors whose exact key is in ``new_state``; report counts."""
        replaced = 0
        for key in candidate_keys:
            if key in new_state:
                merged_state[key] = new_state[key]
                replaced += 1
        print(f"Replaced {replaced} {label} parameters")
        unused = len(new_state) - replaced
        if replaced == 0:
            print(
                f"⚠️ No {label} tensor names matched the template; its {label} weights were "
                "kept unchanged (the template probably uses a different key naming scheme)."
            )
        elif unused:
            print(f"⚠️ {unused} of {len(new_state)} new {label} tensors had no matching template key")
        return replaced

    print("=" * 80)
    print("TEMPLATE-BASED CONVERSION")
    print("=" * 80)

    # Validate the Diffusers side first so a bad folder fails before the (large)
    # template checkpoint is loaded.
    diffusers_folder = Path(diffusers_folder)
    transformer_path = diffusers_folder / "transformer"
    if not transformer_path.is_dir():
        raise ValueError(f"Transformer folder not found: {transformer_path}")
    transformer_files = find_weight_files(transformer_path)
    if not transformer_files:
        print(f"\n❌ No safetensors file found in: {transformer_path}")
        print("\nFiles in transformer folder:")
        for file in transformer_path.iterdir():
            print(f" - {file.name}")
        raise ValueError("Could not find transformer safetensors file. See list above.")

    # Load working checkpoint as template
    print("\nLoading template checkpoint...")
    template_state = load_file(str(working_checkpoint))
    print(f"Template has {len(template_state)} parameters")

    template_keys = set(template_state)
    # Classify text-encoder and VAE keys first, so a key matching the text-encoder
    # pattern that also contains "transformer" is not counted as a transformer key.
    text_encoder_keys = {k for k in template_keys if 'text_encoder' in k or 'clip' in k.lower()}
    vae_keys = {k for k in template_keys - text_encoder_keys if 'vae' in k.lower() or 'first_stage' in k}
    # Everything else is transformer, including non-block tensors (embedders, output
    # projection) that a "blocks"-only name pattern would miss.
    transformer_keys = template_keys - text_encoder_keys - vae_keys
    print(f"\nTemplate structure:")
    print(f" Transformer keys: {len(transformer_keys)}")
    print(f" VAE keys: {len(vae_keys)}")
    print(f" Text encoder keys: {len(text_encoder_keys)}")

    print(f"\nLoading new transformer from: {', '.join(p.name for p in transformer_files)}")
    new_transformer = load_component(transformer_files)

    print("\nReplacing transformer weights...")
    merged_state = dict(template_state)  # shallow copy; template tensors are reused
    replace_matching(merged_state, transformer_keys, new_transformer, "transformer")
    del new_transformer  # drop unused new tensors before loading further components

    if not replace_transformer_only:
        print("\n⚠️ Also replacing VAE and text encoders...")
        for subfolder, label, candidate_keys in (
            ("vae", "VAE", vae_keys),
            ("text_encoder", "CLIP text encoder", text_encoder_keys),
            ("text_encoder_2", "T5 text encoder", text_encoder_keys),
        ):
            component_files = find_weight_files(diffusers_folder / subfolder)
            if not component_files:
                print(f"⚠️ No {label} weights found in {subfolder}/, keeping template weights")
                continue
            component_state = load_component(component_files)
            replace_matching(merged_state, candidate_keys, component_state, label)
            del component_state

    # Save
    print(f"\nSaving to {output_path}...")
    output_file = Path(output_path)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    save_file(merged_state, str(output_file))
    size_gb = output_file.stat().st_size / (1024**3)
    print(f"✅ Done! File size: {size_gb:.2f} GB")
    return output_path
# Command-line entry point. Running without arguments reproduces the original
# defaults: template-based conversion of "../" using "../flux1-depth-dev.safetensors".
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a Diffusers-format FLUX model folder to a ComfyUI checkpoint."
    )
    parser.add_argument(
        "--diffusers-folder",
        default="../",
        help="Diffusers model folder containing model_index.json (default: %(default)s)",
    )
    parser.add_argument(
        "--method",
        choices=("template", "direct"),
        default="template",
        help="'template' (recommended) copies a working checkpoint and replaces matching "
             "tensors; 'direct' merges all Diffusers component weights (default: %(default)s)",
    )
    parser.add_argument(
        "--working-checkpoint",
        default="../flux1-depth-dev.safetensors",
        help="Working ComfyUI checkpoint used by --method template (default: %(default)s)",
    )
    parser.add_argument(
        "--output",
        default="../flux1-depth-dev_ComfyMerged.safetensors",
        help="Output .safetensors path (default: %(default)s)",
    )
    parser.add_argument(
        "--transformer-only",
        action="store_true",
        help="--method template: replace only the transformer, keep VAE/text encoders "
             "from the template",
    )
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="--method direct: cast floating-point weights to float16 to save space",
    )
    args = parser.parse_args()

    # Step 0: check the folder structure first and stop if it is not a Diffusers model.
    if not check_diffusers_structure(args.diffusers_folder):
        raise SystemExit(1)
    print("\n\n")

    if args.method == "direct":
        convert_diffusers_to_comfyui(
            diffusers_folder=args.diffusers_folder,
            output_path=args.output,
            fp16=args.fp16,
        )
    else:
        convert_with_working_template(
            diffusers_folder=args.diffusers_folder,
            working_checkpoint=args.working_checkpoint,
            output_path=args.output,
            replace_transformer_only=args.transformer_only,
        )
|