New merged file with labels
Browse files
- CheckSafetensors.py +10 -0
- InspectSafetensors.py +220 -0
- MergeSafetensors.py +2 -2
- MergeSafetensors2.py +163 -0
- ae.safetensors +0 -3
- flux1-depth-dev_fp4_merged_model.safetensors +2 -2
CheckSafetensors.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from safetensors import safe_open

import sys

# Default target kept identical to the original hard-coded filename so the
# script's zero-argument behavior is unchanged.
_DEFAULT_PATH = "flux1-depth-dev_fp4_merged_model.safetensors"


def check_components(checkpoint_path=_DEFAULT_PATH):
    """Report whether a safetensors checkpoint contains VAE/CLIP/UNet weights.

    Only the tensor names are read (safe_open parses just the header), so this
    is cheap even for multi-gigabyte checkpoints.

    Args:
        checkpoint_path: Path to the .safetensors file to check.

    Returns:
        Tuple ``(has_vae, has_clip, has_unet)`` of booleans.
    """
    with safe_open(checkpoint_path, framework="pt") as f:
        keys = f.keys()

        # Heuristic, substring-based detection of each component.
        # NOTE(review): matches common diffusers/ComfyUI naming conventions;
        # unusual key layouts may not be detected — confirm against the file.
        has_vae = any('vae' in k or 'decoder' in k for k in keys)
        has_clip = any('text_encoder' in k or 'clip' in k for k in keys)
        has_unet = any('unet' in k or 'transformer' in k for k in keys)

    print(f"VAE: {has_vae}, CLIP: {has_clip}, UNet: {has_unet}")
    return has_vae, has_clip, has_unet


if __name__ == "__main__":
    # Optional CLI argument generalizes the original hard-coded path.
    check_components(sys.argv[1] if len(sys.argv) > 1 else _DEFAULT_PATH)
InspectSafetensors.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from safetensors import safe_open
|
| 2 |
+
from collections import defaultdict
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
def inspect_checkpoint(checkpoint_path, detailed=False):
    """
    Inspect the structure of a safetensors checkpoint file.

    Prints the file size, a per-component parameter breakdown, top-level key
    prefixes, sample tensor shapes/dtypes, a heuristic model-type guess
    (FLUX vs Stable Diffusion), and a completeness verdict.

    Args:
        checkpoint_path: Path to the .safetensors file
        detailed: If True, shows more detailed information
    """

    # Fail softly with a readable message instead of letting safe_open raise.
    if not os.path.exists(checkpoint_path):
        print(f"❌ File not found: {checkpoint_path}")
        return

    print("=" * 80)
    print(f"INSPECTING: {os.path.basename(checkpoint_path)}")
    print("=" * 80)

    # File size
    size_bytes = os.path.getsize(checkpoint_path)
    size_gb = size_bytes / (1024**3)
    print(f"\n📦 File Size: {size_gb:.2f} GB ({size_bytes:,} bytes)")

    # safe_open only parses the header here; tensors are fetched lazily via
    # get_tensor in the sample/detailed sections below.
    with safe_open(checkpoint_path, framework="pt") as f:
        keys = list(f.keys())

        print(f"\n📊 Total Parameters: {len(keys):,}")

        # Categorize keys by component
        print("\n" + "=" * 80)
        print("COMPONENT BREAKDOWN")
        print("=" * 80)

        categories = defaultdict(list)

        for key in keys:
            # Categorize by prefix
            # NOTE(review): substring matching is heuristic and branch order
            # matters — a key containing both 'encoder' and 'transformer'
            # lands in 'VAE' because that branch is tested first.
            if any(x in key.lower() for x in ['vae', 'first_stage', 'decoder', 'encoder', 'quant_conv', 'post_quant']):
                categories['VAE'].append(key)
            elif any(x in key.lower() for x in ['text_encoder', 'cond_stage', 'clip', 'transformer.text_model']):
                categories['Text Encoder'].append(key)
            elif any(x in key.lower() for x in ['model.diffusion', 'diffusion_model', 'transformer', 'double_blocks', 'single_blocks']):
                categories['UNet/Transformer'].append(key)
            else:
                categories['Other'].append(key)

        # Print summary
        for category, cat_keys in sorted(categories.items()):
            print(f"\n{category}: {len(cat_keys)} parameters")

        # Analyze key patterns
        print("\n" + "=" * 80)
        print("KEY PATTERNS")
        print("=" * 80)

        # Group by top-level prefix
        prefix_groups = defaultdict(int)
        for key in keys:
            prefix = key.split('.')[0] if '.' in key else key
            prefix_groups[prefix] += 1

        print("\nTop-level prefixes:")
        # Sort by descending count so the dominant component is listed first.
        for prefix, count in sorted(prefix_groups.items(), key=lambda x: -x[1]):
            print(f"  {prefix}: {count} parameters")

        # Show sample keys from each category
        print("\n" + "=" * 80)
        print("SAMPLE KEYS FROM EACH COMPONENT")
        print("=" * 80)

        for category, cat_keys in sorted(categories.items()):
            if cat_keys:
                print(f"\n{category} (showing first 5):")
                for key in cat_keys[:5]:
                    # Loads just this tensor to report its shape/dtype.
                    tensor = f.get_tensor(key)
                    print(f"  {key}")
                    print(f"    └─ shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}")

        if detailed:
            print("\n" + "=" * 80)
            print("ALL KEYS (DETAILED)")
            print("=" * 80)

            # Materializes every tensor one at a time — slow on large files.
            for i, key in enumerate(keys, 1):
                tensor = f.get_tensor(key)
                print(f"\n{i}. {key}")
                print(f"   Shape: {tuple(tensor.shape)}")
                print(f"   Dtype: {tensor.dtype}")
                print(f"   Size: {tensor.numel():,} elements")

        # Check for common FLUX/SD patterns
        print("\n" + "=" * 80)
        print("MODEL TYPE DETECTION")
        print("=" * 80)

        has_flux_blocks = any('double_blocks' in k or 'single_blocks' in k for k in keys)
        has_sd_unet = any('model.diffusion_model' in k for k in keys)
        has_vae = any('vae' in k.lower() or 'first_stage' in k for k in keys)
        has_text_encoder = any('text_encoder' in k.lower() or 'cond_stage' in k for k in keys)

        print(f"\n✓ FLUX-style blocks: {'✅ YES' if has_flux_blocks else '❌ NO'}")
        print(f"✓ SD-style UNet: {'✅ YES' if has_sd_unet else '❌ NO'}")
        print(f"✓ VAE included: {'✅ YES' if has_vae else '❌ NO'}")
        print(f"✓ Text Encoder included: {'✅ YES' if has_text_encoder else '❌ NO'}")

        if has_flux_blocks:
            print("\n🔍 Likely model type: FLUX")
        elif has_sd_unet:
            print("\n🔍 Likely model type: Stable Diffusion")
        else:
            print("\n⚠️ Could not determine model type")

        # Check if complete checkpoint
        print("\n" + "=" * 80)
        print("CHECKPOINT COMPLETENESS")
        print("=" * 80)

        # NOTE(review): "complete" only checks VAE + text encoder presence;
        # the UNet/Transformer itself is assumed to be there.
        if has_vae and has_text_encoder:
            print("\n✅ This appears to be a COMPLETE checkpoint")
            print("   (Contains UNet/Transformer + VAE + Text Encoder)")
        else:
            print("\n⚠️ This appears to be a PARTIAL checkpoint")
            if not has_vae:
                print("   Missing: VAE")
            if not has_text_encoder:
                print("   Missing: Text Encoder")

    print("\n" + "=" * 80)
    print("INSPECTION COMPLETE")
    print("=" * 80)
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def compare_checkpoints(working_checkpoint, broken_checkpoint):
    """
    Compare two checkpoints to see the differences.

    Reports key counts, keys exclusive to each file, and a side-by-side
    count of parameters per top-level key prefix.

    Args:
        working_checkpoint: Path to checkpoint that works
        broken_checkpoint: Path to checkpoint that doesn't work
    """

    print("=" * 80)
    print("COMPARING CHECKPOINTS")
    print("=" * 80)

    # Only key names are needed, so neither file's tensor data is loaded.
    with safe_open(working_checkpoint, framework="pt") as f1:
        keys1 = set(f1.keys())

    with safe_open(broken_checkpoint, framework="pt") as f2:
        keys2 = set(f2.keys())

    print(f"\nWorking checkpoint: {len(keys1)} keys")
    print(f"Broken checkpoint: {len(keys2)} keys")

    # Set algebra yields the three interesting partitions.
    only_in_working = keys1 - keys2
    only_in_broken = keys2 - keys1
    common = keys1 & keys2

    print(f"\nCommon keys: {len(common)}")
    print(f"Only in working: {len(only_in_working)}")
    print(f"Only in broken: {len(only_in_broken)}")

    if only_in_working:
        print("\n🔍 Keys present in WORKING but missing in BROKEN (first 20):")
        for key in sorted(only_in_working)[:20]:
            print(f"  - {key}")

    if only_in_broken:
        print("\n🔍 Keys present in BROKEN but missing in WORKING (first 20):")
        for key in sorted(only_in_broken)[:20]:
            print(f"  + {key}")

    # Compare key patterns
    print("\n" + "=" * 80)
    print("KEY PATTERN COMPARISON")
    print("=" * 80)

    # Count keys per top-level dotted prefix (e.g. 'model', 'vae').
    def get_prefixes(keys):
        prefixes = defaultdict(int)
        for key in keys:
            prefix = key.split('.')[0]
            prefixes[prefix] += 1
        return prefixes

    prefixes1 = get_prefixes(keys1)
    prefixes2 = get_prefixes(keys2)

    all_prefixes = set(prefixes1.keys()) | set(prefixes2.keys())

    # Aligned three-column table; ⚠️ flags prefixes whose counts differ.
    print(f"\n{'Prefix':<30} {'Working':<15} {'Broken':<15}")
    print("-" * 60)
    for prefix in sorted(all_prefixes):
        count1 = prefixes1.get(prefix, 0)
        count2 = prefixes2.get(prefix, 0)
        status = "✅" if count1 == count2 else "⚠️ "
        print(f"{status} {prefix:<28} {count1:<15} {count2:<15}")
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# Example usage
if __name__ == "__main__":
    # Inspect a single checkpoint
    print("OPTION 1: Inspect your working checkpoint")
    print("-" * 80)
    # NOTE(review): hard-coded local path — adjust before running.
    inspect_checkpoint(
        "test/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
        detailed=False  # Set to True for full key listing
    )

    print("\n\n")

    # Compare two checkpoints
    print("OPTION 2: Compare working vs broken checkpoint")
    print("-" * 80)
    # compare_checkpoints(
    #     "path/to/working_checkpoint.safetensors",
    #     "path/to/broken_checkpoint.safetensors"
    # )
|
MergeSafetensors.py
CHANGED
|
@@ -67,8 +67,8 @@ def merge_model_components(
|
|
| 67 |
# Example usage
|
| 68 |
if __name__ == "__main__":
|
| 69 |
merge_model_components(
|
| 70 |
-
unet_path="
|
| 71 |
vae_path="vae/diffusion_pytorch_model.safetensors",
|
| 72 |
text_encoder_path="text_encoder/model.safetensors",
|
| 73 |
-
output_path="flux1-depth-
|
| 74 |
)
|
|
|
|
| 67 |
# Example usage
if __name__ == "__main__":
    # NOTE(review): paths are relative to the working directory; the FLUX
    # depth-dev transformer file is combined with diffusers-layout VAE and
    # text-encoder files — confirm the files exist before running.
    merge_model_components(
        unet_path="flux1-depth-dev.safetensors",
        vae_path="vae/diffusion_pytorch_model.safetensors",
        text_encoder_path="text_encoder/model.safetensors",
        output_path="flux1-depth-dev_merged_model.safetensors"
    )
|
MergeSafetensors2.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from safetensors.torch import save_file, load_file
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
def inspect_keys(file_path, max_keys=10):
    """Print a short summary of a safetensors file's tensor names.

    Reads only the file header via safe_open; the previous implementation
    called load_file, which materialized every tensor in memory just to
    list the key names — prohibitively expensive for multi-GB checkpoints.

    Args:
        file_path: Path to the .safetensors file.
        max_keys: How many key names to print (all keys are still returned).

    Returns:
        The full list of tensor names in the file.
    """
    # Local import: the module header only brings in safetensors.torch helpers.
    from safetensors import safe_open

    with safe_open(file_path, framework="pt") as f:
        keys = list(f.keys())
    print(f"\n{os.path.basename(file_path)} - Total keys: {len(keys)}")
    print(f"First {max_keys} keys:")
    for k in keys[:max_keys]:
        print(f"  {k}")
    return keys
|
| 14 |
+
|
| 15 |
+
def merge_for_comfyui(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
    model_type="flux"  # "flux", "sd15", "sdxl"
):
    """
    Merge components into ComfyUI-compatible safetensors checkpoint.

    Remaps each component's keys onto the prefixes a monolithic checkpoint
    uses ('model.diffusion_model.', 'first_stage_model.', and either
    'text_encoders.' or 'cond_stage_model.transformer.' depending on
    model_type), then saves and re-inspects the result.

    Args:
        unet_path: Path to the main model/transformer safetensors
        vae_path: Path to the VAE safetensors
        text_encoder_path: Path to the text encoder/CLIP safetensors
        output_path: Path for the merged checkpoint
        model_type: Type of model (flux, sd15, sdxl)
    """

    print("=" * 60)
    print("STEP 1: Inspecting input files...")
    print("=" * 60)

    # Inspect each file to understand structure
    unet_keys = inspect_keys(unet_path)
    vae_keys = inspect_keys(vae_path)
    text_encoder_keys = inspect_keys(text_encoder_path)

    print("\n" + "=" * 60)
    print("STEP 2: Loading weights...")
    print("=" * 60)

    # NOTE(review): load_file materializes every tensor; all three state
    # dicts are resident at once, so peak memory is roughly the sum of the
    # three file sizes.
    unet_state = load_file(unet_path)
    vae_state = load_file(vae_path)
    text_encoder_state = load_file(text_encoder_path)

    print("\n" + "=" * 60)
    print("STEP 3: Merging with proper key structure...")
    print("=" * 60)

    merged_state = {}

    # Determine key prefixes based on existing structure
    sample_unet_key = unet_keys[0]
    sample_vae_key = vae_keys[0]
    sample_te_key = text_encoder_keys[0]

    print(f"\nDetected key patterns:")
    print(f"  UNet: {sample_unet_key}")
    print(f"  VAE: {sample_vae_key}")
    print(f"  Text Encoder: {sample_te_key}")

    # Add UNet/Transformer weights
    for key, value in unet_state.items():
        # Keep original keys or add model prefix if needed
        if key.startswith('model.') or key.startswith('diffusion_model.'):
            merged_state[key] = value
        else:
            # Add ComfyUI-expected prefix
            merged_state[f'model.diffusion_model.{key}'] = value

    # Add VAE weights with proper structure
    for key, value in vae_state.items():
        if key.startswith('first_stage_model.') or key.startswith('vae.'):
            merged_state[key] = value
        elif key.startswith('decoder.') or key.startswith('encoder.'):
            merged_state[f'first_stage_model.{key}'] = value
        else:
            # NOTE(review): assumes any otherwise-unprefixed VAE key belongs
            # to the decoder; encoder-side keys without a prefix would be
            # mis-filed — verify against the VAE file's actual layout.
            merged_state[f'first_stage_model.decoder.{key}'] = value

    # Add text encoder weights
    for key, value in text_encoder_state.items():
        if key.startswith('cond_stage_model.') or key.startswith('text_encoder.'):
            merged_state[key] = value
        else:
            # For FLUX, might need different structure
            if model_type.lower() == "flux":
                merged_state[f'text_encoders.{key}'] = value
            else:
                merged_state[f'cond_stage_model.transformer.{key}'] = value

    print(f"\nMerged state contains {len(merged_state)} parameters")

    # Add metadata for ComfyUI recognition
    print("\n" + "=" * 60)
    print("STEP 4: Saving merged checkpoint...")
    print("=" * 60)

    save_file(merged_state, output_path)

    print("\n✅ Merge complete!")
    print(f"File saved to: {output_path}")

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")

    # Verify the merged file
    print("\n" + "=" * 60)
    print("STEP 5: Verifying merged file...")
    print("=" * 60)
    inspect_keys(output_path, max_keys=20)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def simple_merge_keep_structure(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path
):
    """
    Simple merge that preserves original key structure.

    Concatenates the three state dicts as-is (later files win on key
    collisions, in the order unet < vae < text encoder) and writes the
    result. Use this if the files already have proper ComfyUI keys.
    """
    print("Loading all components...")

    # Load in a fixed order so collision precedence matches the save order.
    component_states = [
        load_file(path)
        for path in (unet_path, vae_path, text_encoder_path)
    ]

    print("Merging...")
    merged_state = {}
    for state in component_states:
        merged_state.update(state)

    print(f"Saving {len(merged_state)} parameters...")
    save_file(merged_state, output_path)

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"✅ Done! File size: {size_gb:.2f} GB")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
# Example usage
if __name__ == "__main__":
    # Option 1: Smart merge with key detection
    # NOTE(review): hard-coded local paths (SVDQuant fp4 FLUX transformer +
    # diffusers-layout VAE/text encoder) — adjust before running.
    merge_for_comfyui(
        unet_path="svdq-fp4_r32-flux.1-depth-dev.safetensors",
        vae_path="vae/diffusion_pytorch_model.safetensors",
        text_encoder_path="text_encoder/model.safetensors",
        output_path="flux1-depth-dev_fp4_merged_model.safetensors",
        model_type="flux"
    )

    # Option 2: Simple merge (if keys are already correct)
    # simple_merge_keep_structure(
    #     unet_path="path/to/model.safetensors",
    #     vae_path="path/to/vae.safetensors",
    #     text_encoder_path="path/to/text_encoder.safetensors",
    #     output_path="merged_checkpoint.safetensors"
    # )
|
ae.safetensors
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:afc8e28272cd15db3919bacdb6918ce9c1ed22e96cb12c4d5ed0fba823529e38
|
| 3 |
-
size 335304388
|
|
|
|
|
|
|
|
|
|
|
|
flux1-depth-dev_fp4_merged_model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a7eff742003708c77b291d7ca416bf695dc7c01a3a26250febac7c79ee6f390d
|
| 3 |
+
size 7866741316
|