File size: 5,333 Bytes
45131ff 1bf81fc 45131ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
from safetensors.torch import save_file, load_file
import torch
import os
def inspect_keys(file_path, max_keys=10):
    """Print a short summary of the tensor names stored in a safetensors file.

    Only the file header is read (via ``safetensors.safe_open``), so no tensor
    data is loaded into memory just to list the names.

    Args:
        file_path: Path to a ``.safetensors`` file.
        max_keys: Maximum number of key names to print.

    Returns:
        list[str]: Every tensor name found in the file.
    """
    # Local import: ``safe_open`` lives in the top-level safetensors package,
    # which the file already depends on through ``safetensors.torch``.
    from safetensors import safe_open

    with safe_open(file_path, framework="pt") as handle:
        keys = list(handle.keys())

    print(f"\n{os.path.basename(file_path)} - Total keys: {len(keys)}")
    # Don't claim more keys than the file actually has.
    print(f"First {min(max_keys, len(keys))} keys:")
    for k in keys[:max_keys]:
        print(f" {k}")
    return keys
def _banner(title, leading_newline=True):
    """Print ``title`` framed by two 60-character '=' rules."""
    print(("\n" if leading_newline else "") + "=" * 60)
    print(title)
    print("=" * 60)


def _comfy_unet_key(key):
    """Return the merged-checkpoint name for a UNet/transformer key."""
    # Keys that already carry a model prefix are kept verbatim.
    if key.startswith(("model.", "diffusion_model.")):
        return key
    return f"model.diffusion_model.{key}"


def _comfy_vae_key(key):
    """Return the merged-checkpoint name for a VAE key."""
    if key.startswith(("first_stage_model.", "vae.")):
        return key
    # encoder.*, decoder.* and top-level autoencoder layers such as
    # quant_conv / post_quant_conv all sit directly under first_stage_model;
    # forcing the latter under "decoder." (as before) mis-placed them.
    return f"first_stage_model.{key}"


def _comfy_text_encoder_key(key, model_type):
    """Return the merged-checkpoint name for a text-encoder key."""
    if key.startswith(("cond_stage_model.", "text_encoder.")):
        return key
    # NOTE(review): FLUX keys are only given a "text_encoders." prefix; no
    # per-encoder sub-namespace is added — confirm this matches the loader.
    if model_type.lower() == "flux":
        return f"text_encoders.{key}"
    return f"cond_stage_model.transformer.{key}"


def _add_component(merged_state, state, remap, component):
    """Copy ``state`` into ``merged_state`` under ``remap(key)``.

    Raises:
        ValueError: If a remapped key is already present, since storing it
            would silently overwrite a previously merged tensor.
    """
    for key, value in state.items():
        new_key = remap(key)
        if new_key in merged_state:
            raise ValueError(
                f"{component} key {key!r} maps to {new_key!r}, which is "
                "already present in the merged checkpoint"
            )
        merged_state[new_key] = value


def merge_for_comfyui(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
    model_type="flux"  # "flux", "sd15", "sdxl"
):
    """
    Merge components into ComfyUI-compatible safetensors checkpoint.

    Each component's keys are given a namespace prefix (UNet ->
    ``model.diffusion_model.``, VAE -> ``first_stage_model.``, text encoder ->
    ``text_encoders.`` for FLUX or ``cond_stage_model.transformer.``
    otherwise) unless they already carry a recognised prefix. Only prefixes
    are changed; inner layer names are not converted, so the inputs must
    already use the layer naming the consumer expects.

    Args:
        unet_path: Path to the main model/transformer safetensors
        vae_path: Path to the VAE safetensors
        text_encoder_path: Path to the text encoder/CLIP safetensors
        output_path: Path for the merged checkpoint
        model_type: Type of model (flux, sd15, sdxl); only "flux"
            (case-insensitive) changes the text-encoder prefix.

    Raises:
        ValueError: If an input file contains no tensors, or if two source
            keys map to the same merged key.
    """
    _banner("STEP 1: Inspecting input files...", leading_newline=False)
    # Inspect each file to understand structure
    unet_keys = inspect_keys(unet_path)
    vae_keys = inspect_keys(vae_path)
    text_encoder_keys = inspect_keys(text_encoder_path)
    # Fail fast with a clear message instead of an IndexError at keys[0].
    for label, path, keys in (
        ("UNet", unet_path, unet_keys),
        ("VAE", vae_path, vae_keys),
        ("Text encoder", text_encoder_path, text_encoder_keys),
    ):
        if not keys:
            raise ValueError(f"{label} file {path!r} contains no tensors")

    _banner("STEP 2: Loading weights...")
    unet_state = load_file(unet_path)
    vae_state = load_file(vae_path)
    text_encoder_state = load_file(text_encoder_path)

    _banner("STEP 3: Merging with proper key structure...")
    print("\nDetected key patterns:")
    print(f" UNet: {unet_keys[0]}")
    print(f" VAE: {vae_keys[0]}")
    print(f" Text Encoder: {text_encoder_keys[0]}")

    merged_state = {}
    _add_component(merged_state, unet_state, _comfy_unet_key, "UNet")
    _add_component(merged_state, vae_state, _comfy_vae_key, "VAE")
    _add_component(
        merged_state,
        text_encoder_state,
        lambda key: _comfy_text_encoder_key(key, model_type),
        "Text encoder",
    )
    print(f"\nMerged state contains {len(merged_state)} parameters")

    _banner("STEP 4: Saving merged checkpoint...")
    save_file(merged_state, output_path)
    print("\n✅ Merge complete!")
    print(f"File saved to: {output_path}")
    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")

    # Re-read the written file to confirm the resulting key layout.
    _banner("STEP 5: Verifying merged file...")
    inspect_keys(output_path, max_keys=20)
def simple_merge_keep_structure(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path
):
    """
    Simple merge that preserves original key structure.

    Use this if the files already have proper ComfyUI keys. Tensors are
    written exactly under the names they have in their source files.

    Args:
        unet_path: Path to the main model/transformer safetensors.
        vae_path: Path to the VAE safetensors.
        text_encoder_path: Path to the text encoder/CLIP safetensors.
        output_path: Path for the merged checkpoint.

    Raises:
        ValueError: If two components contain the same key. A plain dict
            update would otherwise keep only the later tensor and silently
            drop the other.
    """
    print("Loading all components...")
    components = (
        ("UNet", load_file(unet_path)),
        ("VAE", load_file(vae_path)),
        ("Text encoder", load_file(text_encoder_path)),
    )

    print("Merging...")
    merged_state = {}
    for component, state in components:
        overlap = merged_state.keys() & state.keys()
        if overlap:
            examples = ", ".join(sorted(overlap)[:5])
            raise ValueError(
                f"{component} shares {len(overlap)} key(s) with an earlier "
                f"component (e.g. {examples}); refusing to overwrite them"
            )
        merged_state.update(state)

    print(f"Saving {len(merged_state)} parameters...")
    save_file(merged_state, output_path)
    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"✅ Done! File size: {size_gb:.2f} GB")
# Command-line entry point. Running with no arguments reproduces the original
# example: a smart FLUX merge of the depth-dev checkpoint.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Merge transformer/UNet, VAE and text-encoder safetensors "
        "files into a single checkpoint."
    )
    parser.add_argument(
        "--unet",
        default="../flux1-depth-dev.safetensors",
        help="main model / transformer safetensors",
    )
    parser.add_argument(
        "--vae",
        default="../vae/diffusion_pytorch_model.safetensors",
        help="VAE safetensors",
    )
    parser.add_argument(
        "--text-encoder",
        default="../text_encoder/model.safetensors",
        help="text encoder / CLIP safetensors",
    )
    parser.add_argument(
        "--output",
        default="../flux1-depth-dev_merged_model.safetensors",
        help="path of the merged checkpoint to write",
    )
    parser.add_argument(
        "--model-type",
        type=str.lower,
        choices=("flux", "sd15", "sdxl"),
        default="flux",
        help="model family (ignored with --simple)",
    )
    parser.add_argument(
        "--simple",
        action="store_true",
        help="keep the original keys instead of remapping them "
        "(use when the files already have proper ComfyUI keys)",
    )
    args = parser.parse_args()

    if args.simple:
        # Option 2: simple merge (if keys are already correct)
        simple_merge_keep_structure(
            unet_path=args.unet,
            vae_path=args.vae,
            text_encoder_path=args.text_encoder,
            output_path=args.output,
        )
    else:
        # Option 1: smart merge with key detection
        merge_for_comfyui(
            unet_path=args.unet,
            vae_path=args.vae,
            text_encoder_path=args.text_encoder,
            output_path=args.output,
            model_type=args.model_type,
        )
|