"""
Convert Diffusers-format FLUX model to ComfyUI-compatible checkpoint.
This handles the proper folder structure and key naming.
"""
import json
import os
from pathlib import Path

from safetensors.torch import save_file, load_file
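
# NOTE: The loaders below pick the first *.safetensors file found in each
# component folder. Large Diffusers models are sometimes sharded
# (e.g. diffusion_pytorch_model-00001-of-00002.safetensors); sharded
# components would need their shards merged into a single file first.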


def convert_diffusers_to_comfyui(
    diffusers_folder,
    output_path,
    fp16=False
):
    """
    Convert a Diffusers FLUX model folder to a single ComfyUI checkpoint.

    Args:
        diffusers_folder: Path to folder containing model_index.json
        output_path: Output path for the merged .safetensors file
        fp16: If True, convert to float16 to save space
    """
    diffusers_folder = Path(diffusers_folder)

    # Verify it's a Diffusers model
    model_index = diffusers_folder / "model_index.json"
    if not model_index.exists():
        raise ValueError(f"Not a Diffusers model folder. Missing: {model_index}")

    with open(model_index) as f:
        config = json.load(f)

    print("=" * 80)
    print("DIFFUSERS TO COMFYUI CONVERTER")
    print("=" * 80)
    print(f"\nModel: {config.get('_name_or_path', 'Unknown')}")
    print(f"Format: {config.get('_class_name', 'Unknown')}")

    merged_state = {}

    # ========================================================================
    # 1. Load Transformer (main FLUX model)
    # ========================================================================
    print("\n" + "=" * 80)
    print("Loading Transformer...")
    print("=" * 80)

    transformer_path = diffusers_folder / "transformer"

    # Find the safetensors file (first match only)
    transformer_file = next(transformer_path.glob("*.safetensors"), None)
    if transformer_file is None:
        raise ValueError(f"No safetensors file found in {transformer_path}")

    print(f"Found: {transformer_file.name}")
    transformer_state = load_file(str(transformer_file))
    print(f"Loaded {len(transformer_state)} transformer tensors")

    # Add transformer weights (keep original keys or minimal prefix)
    for key, value in transformer_state.items():
        if fp16 and value.dtype.is_floating_point:
            value = value.half()
        merged_state[key] = value
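
    # Caveat (added note): Diffusers stores FLUX transformer weights under
    # transformer_blocks.* / single_transformer_blocks.* names, while ComfyUI
    # checkpoints typically keep the diffusion model under a
    # model.diffusion_model. prefix with BFL-style double_blocks/single_blocks
    # names. If ComfyUI rejects the merged file, the keys likely need
    # remapping; convert_with_working_template() below sidesteps this by
    # borrowing a known-good checkpoint's naming.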
    # ========================================================================
    # 2. Load VAE
    # ========================================================================
    print("\n" + "=" * 80)
    print("Loading VAE...")
    print("=" * 80)

    vae_path = diffusers_folder / "vae"
    vae_file = next(vae_path.glob("*.safetensors"), None)

    if vae_file is None:
        print("⚠️ No VAE file found, skipping...")
    else:
        print(f"Found: {vae_file.name}")
        vae_state = load_file(str(vae_file))
        print(f"Loaded {len(vae_state)} VAE tensors")

        # Add VAE weights (keep original Diffusers VAE key structure)
        for key, value in vae_state.items():
            if fp16 and value.dtype.is_floating_point:
                value = value.half()
            merged_state[key] = value

    # ========================================================================
    # 3. Load Text Encoders (CLIP + T5)
    # ========================================================================
    print("\n" + "=" * 80)
    print("Loading Text Encoders...")
    print("=" * 80)

    # CLIP (text_encoder)
    clip_path = diffusers_folder / "text_encoder"
    if clip_path.exists():
        clip_file = next(clip_path.glob("*.safetensors"), None)
        if clip_file:
            print(f"Found CLIP: {clip_file.name}")
            clip_state = load_file(str(clip_file))
            print(f"Loaded {len(clip_state)} CLIP tensors")
            for key, value in clip_state.items():
                if fp16 and value.dtype.is_floating_point:
                    value = value.half()
                # Keep original structure
                merged_state[key] = value
        else:
            print("⚠️ No CLIP file found")

    # T5 (text_encoder_2) - often the largest component
    t5_path = diffusers_folder / "text_encoder_2"
    if t5_path.exists():
        t5_file = next(t5_path.glob("*.safetensors"), None)
        if t5_file:
            print(f"Found T5: {t5_file.name}")
            print("⚠️ Loading T5 (this may take a while, it's large)...")
            t5_state = load_file(str(t5_file))
            print(f"Loaded {len(t5_state)} T5 tensors")
            for key, value in t5_state.items():
                if fp16 and value.dtype.is_floating_point:
                    value = value.half()
                merged_state[key] = value
        else:
            print("⚠️ No T5 file found")
    # ========================================================================
    # Save merged checkpoint
    # ========================================================================
    print("\n" + "=" * 80)
    print("Saving merged checkpoint...")
    print("=" * 80)
    print(f"Total tensors: {len(merged_state):,}")
    print(f"Output: {output_path}")

    save_file(merged_state, output_path)

    size_gb = os.path.getsize(output_path) / (1024**3)
    print("\n✅ Conversion complete!")
    print(f"File size: {size_gb:.2f} GB")

    # Show key structure
    print("\n" + "=" * 80)
    print("Key Structure in Merged File")
    print("=" * 80)
    sample_keys = list(merged_state.keys())[:10]
    print("\nFirst 10 keys:")
    for key in sample_keys:
        print(f"  {key}")

    return output_path
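

# Optional helper (a minimal sketch, not part of the original conversion flow):
# list key names in a .safetensors file without loading any tensor data, to
# compare a merged checkpoint's key prefixes against a known-working one.
def inspect_checkpoint_keys(path, limit=20):
    from safetensors import safe_open
    with safe_open(path, framework="pt") as f:
        keys = list(f.keys())
    print(f"{len(keys)} tensors in {path}")
    for key in keys[:limit]:
        print(f"  {key}")
    return keys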


def convert_with_working_template(
    diffusers_folder,
    working_checkpoint,
    output_path,
    replace_transformer_only=True
):
    """
    Use a working checkpoint as a template, replacing components with ones from
    the Diffusers model. This ensures the key naming matches what ComfyUI expects.

    Args:
        diffusers_folder: Path to Diffusers model folder
        working_checkpoint: Path to a working ComfyUI checkpoint
        output_path: Output path for merged checkpoint
        replace_transformer_only: If True, only replace the transformer and keep
            the VAE/text encoders from the template
    """
    print("=" * 80)
    print("TEMPLATE-BASED CONVERSION")
    print("=" * 80)

    # Load working checkpoint as template
    print("\nLoading template checkpoint...")
    template_state = load_file(working_checkpoint)
    print(f"Template has {len(template_state)} tensors")

    # Classify template keys by substring heuristics
    template_keys = set(template_state.keys())
    transformer_keys = {k for k in template_keys if 'transformer' in k or 'double_blocks' in k or 'single_blocks' in k}
    vae_keys = {k for k in template_keys if 'vae' in k.lower() or 'first_stage' in k}
    text_encoder_keys = {k for k in template_keys if 'text_encoder' in k or 'clip' in k.lower()}

    print("\nTemplate structure:")
    print(f"  Transformer keys: {len(transformer_keys)}")
    print(f"  VAE keys: {len(vae_keys)}")
    print(f"  Text encoder keys: {len(text_encoder_keys)}")

    # Load transformer from Diffusers
    diffusers_folder = Path(diffusers_folder)
    transformer_path = diffusers_folder / "transformer"
    transformer_file = next(transformer_path.glob("*.safetensors"), None)
    if transformer_file is None:
        raise ValueError(f"No safetensors file found in {transformer_path}")
    print(f"\nLoading new transformer from: {transformer_file.name}")
    new_transformer = load_file(str(transformer_file))

    # Replace transformer weights
    print("\nReplacing transformer weights...")
    merged_state = dict(template_state)  # Copy template

    # Replace keys that appear under the same name in both state dicts
    replaced = 0
    for key in transformer_keys:
        if key in new_transformer:
            merged_state[key] = new_transformer[key]
            replaced += 1
    print(f"Replaced {replaced} transformer tensors")

    if not replace_transformer_only:
        print("\n⚠️ Also replacing VAE and text encoders...")
        # Load and replace VAE
        vae_file = next((diffusers_folder / "vae").glob("*.safetensors"), None)
        if vae_file:
            vae_state = load_file(str(vae_file))
            for key in vae_keys:
                if key in vae_state:
                    merged_state[key] = vae_state[key]
        # Similar for text encoders...

    # Save
    print(f"\nSaving to {output_path}...")
    save_file(merged_state, output_path)

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"✅ Done! File size: {size_gb:.2f} GB")

    return output_path


# Example usage
if __name__ == "__main__":
    # Method 1: Direct conversion
    # convert_diffusers_to_comfyui(
    #     diffusers_folder="../",
    #     output_path="flux1-depth-dev_ComfyMerged.safetensors",
    #     fp16=True  # Set False to keep original precision
    # )

    # Method 2: Use a working checkpoint as template (RECOMMENDED)
    convert_with_working_template(
        diffusers_folder="../",
        working_checkpoint="../quantized/svdq-fp4_r32-flux.1-depth-dev.safetensors",
        output_path="svdq-fp4_r32-flux.1-depth-dev_ComfyMerged.safetensors",
        replace_transformer_only=True
    )
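
    # The merged file is written next to this script; for ComfyUI it would
    # typically go under ComfyUI/models/checkpoints before loading. (The
    # relative paths above come from the original example and are assumptions
    # about the surrounding folder layout, not fixed requirements.)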