|
|
""" |
|
|
Convert Diffusers-format FLUX model to ComfyUI-compatible checkpoint. |
|
|
This handles the proper folder structure and key naming. |
|
|
""" |
|
|
|
|
|
from safetensors.torch import save_file, load_file |
|
|
import os |
|
|
import json |
|
|
from pathlib import Path |
|
|
|
|
|
def check_diffusers_structure(diffusers_folder):
    """
    Check and display the structure of a Diffusers model folder.

    Prints the top-level contents of the folder (previewing up to five
    entries of each subdirectory) and then reports, for each expected
    component subfolder, whether it exists and holds at least one
    ``.safetensors`` file.

    Args:
        diffusers_folder: Path (str or Path) of the folder to inspect.

    Returns:
        False if the folder does not exist or has no ``model_index.json``.
        True otherwise, even when individual components are reported as
        missing; the component report is informational only.
    """
    diffusers_folder = Path(diffusers_folder)
    preview_limit = 5  # maximum number of entries shown per subdirectory

    print("=" * 80)
    print("CHECKING DIFFUSERS FOLDER STRUCTURE")
    print("=" * 80)
    print(f"\nFolder: {diffusers_folder}")

    if not diffusers_folder.exists():
        print("❌ Folder does not exist!")
        return False

    model_index = diffusers_folder / "model_index.json"
    if not model_index.exists():
        print("❌ Not a Diffusers model (missing model_index.json)")
        return False

    print("✅ Found model_index.json")

    print("\nFolder contents:")
    for item in sorted(diffusers_folder.iterdir()):
        if item.is_dir():
            print(f" 📁 {item.name}/")

            # List the subdirectory once so the preview and the
            # "more files" count are derived from the same snapshot.
            children = sorted(item.iterdir())
            for child in children[:preview_limit]:
                size_mb = child.stat().st_size / (1024**2)
                print(f" - {child.name} ({size_mb:.1f} MB)")
            if len(children) > preview_limit:
                print(f" ... and {len(children) - preview_limit} more files")
        else:
            size_mb = item.stat().st_size / (1024**2)
            print(f" 📄 {item.name} ({size_mb:.1f} MB)")

    print("\n" + "=" * 80)
    print("Component Check")
    print("=" * 80)

    components = {
        "transformer": "Main FLUX transformer model",
        "vae": "VAE encoder/decoder",
        "text_encoder": "CLIP text encoder",
        "text_encoder_2": "T5 text encoder",
    }

    for folder_name, description in components.items():
        folder_path = diffusers_folder / folder_name
        if not folder_path.exists():
            print(f"❌ {folder_name}: missing")
            continue

        # Sorted so the file named in the report is deterministic.
        safetensors_files = sorted(folder_path.glob("*.safetensors"))
        if safetensors_files:
            print(f"✅ {folder_name}: {description}")
            print(f" Found: {safetensors_files[0].name}")
        else:
            print(f"⚠️ {folder_name}: folder exists but no .safetensors files")

    return True
|
|
|
|
|
|
|
|
def _component_weight_files(component_dir):
    """Return the .safetensors file(s) that hold one component's weights.

    If the folder contains a ``*.safetensors.index.json`` shard index, every
    shard named in its ``weight_map`` is returned (sorted by name). Otherwise
    only the alphabetically first ``*.safetensors`` file is returned. A
    missing folder, or one without safetensors files, yields an empty list.

    Raises:
        ValueError: If the shard index names a file that does not exist.
    """
    if not component_dir.is_dir():
        return []

    index_files = sorted(component_dir.glob("*.safetensors.index.json"))
    if index_files:
        with open(index_files[0], encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})
        shard_paths = [
            component_dir / shard_name
            for shard_name in sorted(set(weight_map.values()))
        ]
        missing = [path.name for path in shard_paths if not path.is_file()]
        if missing:
            raise ValueError(
                f"Shard index {index_files[0]} references missing files: "
                f"{', '.join(missing)}"
            )
        if shard_paths:
            return shard_paths

    # No shard index: fall back to a single file, chosen deterministically.
    return sorted(component_dir.glob("*.safetensors"))[:1]


def _merge_component(merged_state, label, weight_files, fp16):
    """Load one component from *weight_files* and add it to *merged_state*.

    Keys are copied unchanged. Floating-point tensors are cast to float16
    when *fp16* is true. Keys already present in *merged_state* are
    overwritten, and the number of such collisions is reported.
    """
    component_state = {}
    for weight_file in weight_files:
        component_state.update(load_file(str(weight_file)))
    print(f"Loaded {len(component_state)} {label} parameters")

    clashes = merged_state.keys() & component_state.keys()
    if clashes:
        print(
            f"⚠️ {len(clashes)} {label} keys overwrite keys loaded from "
            "earlier components"
        )

    for key, value in component_state.items():
        if fp16 and value.dtype.is_floating_point:
            value = value.half()
        merged_state[key] = value


def convert_diffusers_to_comfyui(
    diffusers_folder,
    output_path,
    fp16=False
):
    """
    Convert a Diffusers FLUX model folder to a single merged checkpoint.

    The transformer is required; the VAE, ``text_encoder`` (CLIP) and
    ``text_encoder_2`` (T5) are merged when their safetensors files are
    found and skipped with a warning otherwise. Sharded components are
    loaded in full when a ``*.safetensors.index.json`` index is present.

    Tensor keys are written exactly as stored in the Diffusers files; this
    function performs no renaming or prefixing.
    NOTE(review): whether ComfyUI accepts these key names is not established
    by this code - confirm against a known-good checkpoint.

    Args:
        diffusers_folder: Path to folder containing model_index.json
        output_path: Output path for the merged .safetensors file
        fp16: If True, convert floating-point tensors to float16 to save space

    Returns:
        ``output_path``, unchanged.

    Raises:
        ValueError: If ``model_index.json`` or the ``transformer`` folder is
            missing, if the transformer folder has no safetensors file, or
            if a shard index references a file that does not exist.
    """
    diffusers_folder = Path(diffusers_folder)

    model_index = diffusers_folder / "model_index.json"
    if not model_index.exists():
        raise ValueError(f"Not a Diffusers model folder. Missing: {model_index}")

    with open(model_index, encoding="utf-8") as f:
        config = json.load(f)

    print("=" * 80)
    print("DIFFUSERS TO COMFYUI CONVERTER")
    print("=" * 80)
    print(f"\nModel: {config.get('_name_or_path', 'Unknown')}")
    print(f"Format: {config.get('_class_name', 'Unknown')}")

    merged_state = {}

    # --- Transformer (required) -------------------------------------------
    print("\n" + "=" * 80)
    print("Loading Transformer...")
    print("=" * 80)

    transformer_path = diffusers_folder / "transformer"
    if not transformer_path.is_dir():
        raise ValueError(f"Transformer folder not found: {transformer_path}")

    transformer_files = _component_weight_files(transformer_path)
    if not transformer_files:
        print(f"\n❌ No safetensors file found in: {transformer_path}")
        print("\nFiles in transformer folder:")
        for file in transformer_path.iterdir():
            print(f" - {file.name}")
        raise ValueError(f"No safetensors file found in {transformer_path}")

    print(f"Found: {', '.join(f.name for f in transformer_files)}")
    _merge_component(merged_state, "transformer", transformer_files, fp16)

    # --- VAE (optional) ---------------------------------------------------
    print("\n" + "=" * 80)
    print("Loading VAE...")
    print("=" * 80)

    vae_files = _component_weight_files(diffusers_folder / "vae")
    if not vae_files:
        print("⚠️ No VAE file found, skipping...")
    else:
        print(f"Found: {', '.join(f.name for f in vae_files)}")
        _merge_component(merged_state, "VAE", vae_files, fp16)

    # --- Text encoders (optional) -----------------------------------------
    print("\n" + "=" * 80)
    print("Loading Text Encoders...")
    print("=" * 80)

    clip_files = _component_weight_files(diffusers_folder / "text_encoder")
    if clip_files:
        print(f"Found CLIP: {', '.join(f.name for f in clip_files)}")
        _merge_component(merged_state, "CLIP", clip_files, fp16)
    else:
        print("⚠️ No CLIP file found")

    t5_files = _component_weight_files(diffusers_folder / "text_encoder_2")
    if t5_files:
        print(f"Found T5: {', '.join(f.name for f in t5_files)}")
        print("⚠️ Loading T5 (this may take a while, it's large)...")
        _merge_component(merged_state, "T5", t5_files, fp16)
    else:
        print("⚠️ No T5 file found")

    # --- Save -------------------------------------------------------------
    print("\n" + "=" * 80)
    print("Saving merged checkpoint...")
    print("=" * 80)

    print(f"Total parameters: {len(merged_state):,}")
    print(f"Output: {output_path}")

    # Create the destination folder up front so a long load is not wasted
    # on a missing output directory.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    save_file(merged_state, output_path)

    size_gb = os.path.getsize(output_path) / (1024**3)
    print("\n✅ Conversion complete!")
    print(f"File size: {size_gb:.2f} GB")

    print("\n" + "=" * 80)
    print("Key Structure in Merged File")
    print("=" * 80)

    print("\nFirst 10 keys:")
    for key in list(merged_state)[:10]:
        print(f" {key}")

    return output_path
|
|
|
|
|
|
|
|
def _find_weight_files(component_dir):
    """Return the .safetensors file(s) that hold one component's weights.

    If the folder contains a ``*.safetensors.index.json`` shard index, every
    shard named in its ``weight_map`` is returned (sorted by name). Otherwise
    only the alphabetically first ``*.safetensors`` file is returned. A
    missing folder, or one without safetensors files, yields an empty list.

    Raises:
        ValueError: If the shard index names a file that does not exist.
    """
    if not component_dir.is_dir():
        return []

    index_files = sorted(component_dir.glob("*.safetensors.index.json"))
    if index_files:
        with open(index_files[0], encoding="utf-8") as f:
            weight_map = json.load(f).get("weight_map", {})
        shard_paths = [
            component_dir / shard_name
            for shard_name in sorted(set(weight_map.values()))
        ]
        missing = [path.name for path in shard_paths if not path.is_file()]
        if missing:
            raise ValueError(
                f"Shard index {index_files[0]} references missing files: "
                f"{', '.join(missing)}"
            )
        if shard_paths:
            return shard_paths

    # No shard index: fall back to a single file, chosen deterministically.
    return sorted(component_dir.glob("*.safetensors"))[:1]


def _replace_matching_keys(merged_state, candidate_keys, weight_files):
    """Overwrite entries of *merged_state* with same-named tensors from files.

    Only keys contained in *candidate_keys* are considered, and a key is
    replaced only when a tensor with exactly that name exists in one of
    *weight_files*. Returns the number of replaced entries.
    """
    replaced = 0
    for weight_file in weight_files:
        new_state = load_file(str(weight_file))
        for key in new_state.keys() & candidate_keys:
            merged_state[key] = new_state[key]
            replaced += 1
    return replaced


def convert_with_working_template(
    diffusers_folder,
    working_checkpoint,
    output_path,
    replace_transformer_only=True
):
    """
    Use a working checkpoint as template, replacing components from Diffusers model.

    The template's keys define the output layout. A template tensor is
    replaced only when the Diffusers files contain a tensor with exactly the
    same key name; all other template tensors are kept as they are. If no
    transformer key matches, a warning is printed and the output is the
    template with unchanged transformer weights.

    The Diffusers folder is validated before the template is loaded so that
    a wrong path fails quickly.

    Args:
        diffusers_folder: Path to Diffusers model folder
        working_checkpoint: Path to a working ComfyUI checkpoint
        output_path: Output path for merged checkpoint
        replace_transformer_only: If True, only replace transformer, keep
            VAE/encoders from template. If False, VAE and text encoder
            tensors with matching key names are replaced as well.

    Returns:
        ``output_path``, unchanged.

    Raises:
        ValueError: If the ``transformer`` folder is missing or has no
            safetensors file, or if a shard index references a file that
            does not exist.
    """
    print("=" * 80)
    print("TEMPLATE-BASED CONVERSION")
    print("=" * 80)

    # Cheap path checks come first: the template can be many gigabytes.
    diffusers_folder = Path(diffusers_folder)
    transformer_path = diffusers_folder / "transformer"

    if not transformer_path.is_dir():
        raise ValueError(f"Transformer folder not found: {transformer_path}")

    transformer_files = _find_weight_files(transformer_path)
    if not transformer_files:
        print(f"\n❌ No safetensors file found in: {transformer_path}")
        print("\nFiles in transformer folder:")
        for file in transformer_path.iterdir():
            print(f" - {file.name}")
        raise ValueError("Could not find transformer safetensors file. See list above.")

    print("\nLoading template checkpoint...")
    template_state = load_file(working_checkpoint)
    print(f"Template has {len(template_state)} parameters")

    # Classify template keys by substring; a key may fall into more than
    # one group (e.g. a key containing both 'transformer' and 'clip').
    template_keys = set(template_state.keys())
    transformer_keys = {
        k for k in template_keys
        if 'transformer' in k or 'double_blocks' in k or 'single_blocks' in k
    }
    vae_keys = {k for k in template_keys if 'vae' in k.lower() or 'first_stage' in k}
    text_encoder_keys = {
        k for k in template_keys if 'text_encoder' in k or 'clip' in k.lower()
    }

    print("\nTemplate structure:")
    print(f" Transformer keys: {len(transformer_keys)}")
    print(f" VAE keys: {len(vae_keys)}")
    print(f" Text encoder keys: {len(text_encoder_keys)}")

    print(
        "\nLoading new transformer from: "
        f"{', '.join(f.name for f in transformer_files)}"
    )

    print("\nReplacing transformer weights...")
    merged_state = dict(template_state)

    replaced = _replace_matching_keys(merged_state, transformer_keys, transformer_files)
    print(f"Replaced {replaced} transformer parameters")
    if replaced == 0:
        print(
            "⚠️ No template key matched a key in the Diffusers transformer; "
            "the output will keep the template's transformer weights"
        )

    if not replace_transformer_only:
        print("\n⚠️ Also replacing VAE and text encoders...")

        vae_files = _find_weight_files(diffusers_folder / "vae")
        vae_replaced = _replace_matching_keys(merged_state, vae_keys, vae_files)
        print(f"Replaced {vae_replaced} VAE parameters")

        encoder_replaced = 0
        for folder_name in ("text_encoder", "text_encoder_2"):
            encoder_files = _find_weight_files(diffusers_folder / folder_name)
            encoder_replaced += _replace_matching_keys(
                merged_state, text_encoder_keys, encoder_files
            )
        print(f"Replaced {encoder_replaced} text encoder parameters")

    print(f"\nSaving to {output_path}...")
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    save_file(merged_state, output_path)

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"✅ Done! File size: {size_gb:.2f} GB")

    return output_path
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: inspect the Diffusers folder one level above the
    # current working directory, then run the template-based conversion.
    source_folder = "../"
    check_diffusers_structure(source_folder)

    print("\n\n")

    # The conversion runs regardless of the structure check's result.
    conversion_options = {
        "diffusers_folder": source_folder,
        "working_checkpoint": "../flux1-depth-dev.safetensors",
        "output_path": "../flux1-depth-dev_ComfyMerged.safetensors",
        "replace_transformer_only": False,
    }
    convert_with_working_template(**conversion_options)
|
|
|