# checkpoint inspection utilities (extraction artifact header removed)
from safetensors import safe_open
from collections import defaultdict
import os
def inspect_checkpoint(checkpoint_path, detailed=False):
    """
    Inspect the structure of a safetensors checkpoint file.

    Prints the file size, a component breakdown (VAE / Text Encoder /
    UNet-Transformer), top-level key-prefix statistics, sample tensor
    shapes, a model-type guess (FLUX vs. Stable Diffusion), and whether
    the checkpoint appears complete.

    Args:
        checkpoint_path: Path to the .safetensors file.
        detailed: If True, additionally prints every key with its
            shape, dtype, and element count.

    Returns:
        None. All output goes to stdout; returns early if the file
        does not exist.
    """
    if not os.path.exists(checkpoint_path):
        print(f"❌ File not found: {checkpoint_path}")
        return

    print("=" * 80)
    print(f"INSPECTING: {os.path.basename(checkpoint_path)}")
    print("=" * 80)

    # File size
    size_bytes = os.path.getsize(checkpoint_path)
    size_gb = size_bytes / (1024**3)
    print(f"\n📦 File Size: {size_gb:.2f} GB ({size_bytes:,} bytes)")

    with safe_open(checkpoint_path, framework="pt") as f:
        keys = list(f.keys())
        print(f"\n📊 Total Parameters: {len(keys):,}")

        # Categorize keys by component
        print("\n" + "=" * 80)
        print("COMPONENT BREAKDOWN")
        print("=" * 80)

        categories = defaultdict(list)
        for key in keys:
            # Categorize by substring patterns. The Text Encoder branch is
            # tested BEFORE the VAE branch on purpose: the VAE patterns
            # include the broad substring 'encoder', which would otherwise
            # swallow every 'text_encoder' key.
            if any(x in key.lower() for x in ['text_encoder', 'cond_stage', 'clip', 'transformer.text_model']):
                categories['Text Encoder'].append(key)
            elif any(x in key.lower() for x in ['vae', 'first_stage', 'decoder', 'encoder', 'quant_conv', 'post_quant']):
                categories['VAE'].append(key)
            elif any(x in key.lower() for x in ['model.diffusion', 'diffusion_model', 'transformer', 'double_blocks', 'single_blocks']):
                categories['UNet/Transformer'].append(key)
            else:
                categories['Other'].append(key)

        # Print summary
        for category, cat_keys in sorted(categories.items()):
            print(f"\n{category}: {len(cat_keys)} parameters")

        # Analyze key patterns
        print("\n" + "=" * 80)
        print("KEY PATTERNS")
        print("=" * 80)

        # Group by top-level (dot-delimited) prefix
        prefix_groups = defaultdict(int)
        for key in keys:
            prefix = key.split('.')[0] if '.' in key else key
            prefix_groups[prefix] += 1

        print("\nTop-level prefixes:")
        for prefix, count in sorted(prefix_groups.items(), key=lambda x: -x[1]):
            print(f"  {prefix}: {count} parameters")

        # Show sample keys from each category
        print("\n" + "=" * 80)
        print("SAMPLE KEYS FROM EACH COMPONENT")
        print("=" * 80)

        for category, cat_keys in sorted(categories.items()):
            if cat_keys:
                print(f"\n{category} (showing first 5):")
                for key in cat_keys[:5]:
                    tensor = f.get_tensor(key)
                    print(f"  {key}")
                    print(f"    └─ shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}")

        if detailed:
            print("\n" + "=" * 80)
            print("ALL KEYS (DETAILED)")
            print("=" * 80)
            for i, key in enumerate(keys, 1):
                tensor = f.get_tensor(key)
                print(f"\n{i}. {key}")
                print(f"   Shape: {tuple(tensor.shape)}")
                print(f"   Dtype: {tensor.dtype}")
                print(f"   Size: {tensor.numel():,} elements")

    # The remaining analysis only needs the key names, so the file
    # handle is released before this point.
    # Check for common FLUX/SD patterns
    print("\n" + "=" * 80)
    print("MODEL TYPE DETECTION")
    print("=" * 80)

    has_flux_blocks = any('double_blocks' in k or 'single_blocks' in k for k in keys)
    has_sd_unet = any('model.diffusion_model' in k for k in keys)
    has_vae = any('vae' in k.lower() or 'first_stage' in k for k in keys)
    has_text_encoder = any('text_encoder' in k.lower() or 'cond_stage' in k for k in keys)

    # NOTE(review): the original emoji were mojibake-corrupted; restored to
    # the most plausible originals (✅/❌ status markers).
    print(f"\n• FLUX-style blocks: {'✅ YES' if has_flux_blocks else '❌ NO'}")
    print(f"• SD-style UNet: {'✅ YES' if has_sd_unet else '❌ NO'}")
    print(f"• VAE included: {'✅ YES' if has_vae else '❌ NO'}")
    print(f"• Text Encoder included: {'✅ YES' if has_text_encoder else '❌ NO'}")

    if has_flux_blocks:
        print("\n🎯 Likely model type: FLUX")
    elif has_sd_unet:
        print("\n🎯 Likely model type: Stable Diffusion")
    else:
        print("\n⚠️ Could not determine model type")

    # Check if complete checkpoint
    print("\n" + "=" * 80)
    print("CHECKPOINT COMPLETENESS")
    print("=" * 80)

    if has_vae and has_text_encoder:
        print("\n✅ This appears to be a COMPLETE checkpoint")
        print("   (Contains UNet/Transformer + VAE + Text Encoder)")
    else:
        print("\n⚠️ This appears to be a PARTIAL checkpoint")
        if not has_vae:
            print("   Missing: VAE")
        if not has_text_encoder:
            print("   Missing: Text Encoder")

    print("\n" + "=" * 80)
    print("INSPECTION COMPLETE")
    print("=" * 80)
def compare_checkpoints(working_checkpoint, broken_checkpoint):
    """
    Compare two safetensors checkpoints and report their key differences.

    Prints key counts, the keys unique to each file (first 20 of each),
    and a per-prefix count table highlighting mismatches.

    Args:
        working_checkpoint: Path to the checkpoint that works.
        broken_checkpoint: Path to the checkpoint that doesn't work.

    Returns:
        None. Results are printed to stdout.
    """
    print("=" * 80)
    print("COMPARING CHECKPOINTS")
    print("=" * 80)

    # Only the key names are needed, so each file is opened briefly.
    with safe_open(working_checkpoint, framework="pt") as f1:
        keys1 = set(f1.keys())
    with safe_open(broken_checkpoint, framework="pt") as f2:
        keys2 = set(f2.keys())

    print(f"\nWorking checkpoint: {len(keys1)} keys")
    print(f"Broken checkpoint: {len(keys2)} keys")

    only_in_working = keys1 - keys2
    only_in_broken = keys2 - keys1
    common = keys1 & keys2

    print(f"\nCommon keys: {len(common)}")
    print(f"Only in working: {len(only_in_working)}")
    print(f"Only in broken: {len(only_in_broken)}")

    if only_in_working:
        print("\n🔍 Keys present in WORKING but missing in BROKEN (first 20):")
        for key in sorted(only_in_working)[:20]:
            print(f"  - {key}")

    if only_in_broken:
        print("\n🔍 Keys present in BROKEN but missing in WORKING (first 20):")
        for key in sorted(only_in_broken)[:20]:
            print(f"  + {key}")

    # Compare key patterns
    print("\n" + "=" * 80)
    print("KEY PATTERN COMPARISON")
    print("=" * 80)

    def get_prefixes(keys):
        # Count keys per top-level (dot-delimited) prefix.
        prefixes = defaultdict(int)
        for key in keys:
            prefix = key.split('.')[0]
            prefixes[prefix] += 1
        return prefixes

    prefixes1 = get_prefixes(keys1)
    prefixes2 = get_prefixes(keys2)
    all_prefixes = set(prefixes1.keys()) | set(prefixes2.keys())

    print(f"\n{'Prefix':<30} {'Working':<15} {'Broken':<15}")
    print("-" * 60)
    for prefix in sorted(all_prefixes):
        count1 = prefixes1.get(prefix, 0)
        count2 = prefixes2.get(prefix, 0)
        # NOTE(review): original status emoji were mojibake-corrupted and
        # split mid-f-string; restored as ✅ (match) / ⚠️ (mismatch).
        status = "✅ " if count1 == count2 else "⚠️ "
        print(f"{status} {prefix:<28} {count1:<15} {count2:<15}")
# Example usage
if __name__ == "__main__":
    # Inspect a single checkpoint
    print("OPTION 1: Inspect your working checkpoint")
    print("-" * 80)
    inspect_checkpoint(
        "../flux1-depth-dev_ComfyMerged.safetensors",
        detailed=False  # Set to True for full key listing
    )
    print("\n\n")

    # Compare two checkpoints
    # print("OPTION 2: Compare working vs broken checkpoint")
    # print("-" * 80)
    # compare_checkpoints(
    #     "Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
    #     "flux1-depth-dev_fp4_merged_model.safetensors"
    # )
# end of file