# Flux1-Depth-Dev / scripts / InspectSafetensors.py
# Uploaded by srcphag — commit 1bf81fc, "New Diffusers structure"
# (Hugging Face file-page header converted to comments so the file parses as Python.)
from safetensors import safe_open
from collections import defaultdict
import os
def inspect_checkpoint(checkpoint_path, detailed=False):
    """
    Inspect the structure of a safetensors checkpoint file.

    Prints the file size, a per-component breakdown of parameter keys,
    top-level key prefixes, sample tensor shapes, and a heuristic
    model-type / completeness report.

    Args:
        checkpoint_path: Path to the .safetensors file
        detailed: If True, shows more detailed information
    """
    bar = "=" * 80

    if not os.path.exists(checkpoint_path):
        print(f"❌ File not found: {checkpoint_path}")
        return

    print(bar)
    print(f"INSPECTING: {os.path.basename(checkpoint_path)}")
    print(bar)

    # File size
    size_bytes = os.path.getsize(checkpoint_path)
    size_gb = size_bytes / (1024 ** 3)
    print(f"\nπŸ“¦ File Size: {size_gb:.2f} GB ({size_bytes:,} bytes)")

    def classify(name):
        # Order matters: a key matching several patterns lands in the
        # first bucket that claims it.
        lowered = name.lower()
        if any(tag in lowered for tag in ('vae', 'first_stage', 'decoder', 'encoder', 'quant_conv', 'post_quant')):
            return 'VAE'
        if any(tag in lowered for tag in ('text_encoder', 'cond_stage', 'clip', 'transformer.text_model')):
            return 'Text Encoder'
        if any(tag in lowered for tag in ('model.diffusion', 'diffusion_model', 'transformer', 'double_blocks', 'single_blocks')):
            return 'UNet/Transformer'
        return 'Other'

    def verdict(flag):
        return 'βœ… YES' if flag else '❌ NO'

    with safe_open(checkpoint_path, framework="pt") as f:
        names = list(f.keys())
        print(f"\nπŸ“Š Total Parameters: {len(names):,}")

        # Categorize keys by component
        print("\n" + bar)
        print("COMPONENT BREAKDOWN")
        print(bar)
        buckets = defaultdict(list)
        for name in names:
            buckets[classify(name)].append(name)

        # Print summary
        for component, members in sorted(buckets.items()):
            print(f"\n{component}: {len(members)} parameters")

        # Analyze key patterns
        print("\n" + bar)
        print("KEY PATTERNS")
        print(bar)

        # Group by top-level prefix
        prefix_counts = defaultdict(int)
        for name in names:
            prefix_counts[name.split('.', 1)[0]] += 1

        print("\nTop-level prefixes:")
        for prefix, count in sorted(prefix_counts.items(), key=lambda item: -item[1]):
            print(f" {prefix}: {count} parameters")

        # Show sample keys from each category
        print("\n" + bar)
        print("SAMPLE KEYS FROM EACH COMPONENT")
        print(bar)
        for component, members in sorted(buckets.items()):
            if not members:
                continue
            print(f"\n{component} (showing first 5):")
            for name in members[:5]:
                tensor = f.get_tensor(name)
                print(f" {name}")
                print(f" └─ shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}")

        if detailed:
            print("\n" + bar)
            print("ALL KEYS (DETAILED)")
            print(bar)
            for idx, name in enumerate(names, 1):
                tensor = f.get_tensor(name)
                print(f"\n{idx}. {name}")
                print(f" Shape: {tuple(tensor.shape)}")
                print(f" Dtype: {tensor.dtype}")
                print(f" Size: {tensor.numel():,} elements")

        # Check for common FLUX/SD patterns
        print("\n" + bar)
        print("MODEL TYPE DETECTION")
        print(bar)
        has_flux_blocks = any('double_blocks' in k or 'single_blocks' in k for k in names)
        has_sd_unet = any('model.diffusion_model' in k for k in names)
        has_vae = any('vae' in k.lower() or 'first_stage' in k for k in names)
        has_text_encoder = any('text_encoder' in k.lower() or 'cond_stage' in k for k in names)

        print(f"\nβœ“ FLUX-style blocks: {verdict(has_flux_blocks)}")
        print(f"βœ“ SD-style UNet: {verdict(has_sd_unet)}")
        print(f"βœ“ VAE included: {verdict(has_vae)}")
        print(f"βœ“ Text Encoder included: {verdict(has_text_encoder)}")

        if has_flux_blocks:
            print("\nπŸ” Likely model type: FLUX")
        elif has_sd_unet:
            print("\nπŸ” Likely model type: Stable Diffusion")
        else:
            print("\n⚠️ Could not determine model type")

        # Check if complete checkpoint
        print("\n" + bar)
        print("CHECKPOINT COMPLETENESS")
        print(bar)
        if has_vae and has_text_encoder:
            print("\nβœ… This appears to be a COMPLETE checkpoint")
            print(" (Contains UNet/Transformer + VAE + Text Encoder)")
        else:
            print("\n⚠️ This appears to be a PARTIAL checkpoint")
            if not has_vae:
                print(" Missing: VAE")
            if not has_text_encoder:
                print(" Missing: Text Encoder")

        print("\n" + bar)
        print("INSPECTION COMPLETE")
        print(bar)
def compare_checkpoints(working_checkpoint, broken_checkpoint):
    """
    Compare two checkpoints to see the differences.

    Prints key counts, keys unique to each file (first 20 of each),
    and a side-by-side table of top-level key-prefix counts.

    Args:
        working_checkpoint: Path to checkpoint that works
        broken_checkpoint: Path to checkpoint that doesn't work
    """
    bar = "=" * 80

    print(bar)
    print("COMPARING CHECKPOINTS")
    print(bar)

    def read_key_set(path):
        # safe_open only parses the header, so no tensor data is loaded.
        with safe_open(path, framework="pt") as handle:
            return set(handle.keys())

    working_keys = read_key_set(working_checkpoint)
    broken_keys = read_key_set(broken_checkpoint)

    print(f"\nWorking checkpoint: {len(working_keys)} keys")
    print(f"Broken checkpoint: {len(broken_keys)} keys")

    missing_from_broken = working_keys - broken_keys
    extra_in_broken = broken_keys - working_keys
    shared = working_keys & broken_keys

    print(f"\nCommon keys: {len(shared)}")
    print(f"Only in working: {len(missing_from_broken)}")
    print(f"Only in broken: {len(extra_in_broken)}")

    if missing_from_broken:
        print("\nπŸ” Keys present in WORKING but missing in BROKEN (first 20):")
        for name in sorted(missing_from_broken)[:20]:
            print(f" - {name}")

    if extra_in_broken:
        print("\nπŸ” Keys present in BROKEN but missing in WORKING (first 20):")
        for name in sorted(extra_in_broken)[:20]:
            print(f" + {name}")

    # Compare key patterns
    print("\n" + bar)
    print("KEY PATTERN COMPARISON")
    print(bar)

    def count_prefixes(names):
        tally = defaultdict(int)
        for name in names:
            tally[name.split('.')[0]] += 1
        return tally

    working_prefixes = count_prefixes(working_keys)
    broken_prefixes = count_prefixes(broken_keys)

    print(f"\n{'Prefix':<30} {'Working':<15} {'Broken':<15}")
    print("-" * 60)
    for prefix in sorted(working_prefixes.keys() | broken_prefixes.keys()):
        lhs = working_prefixes.get(prefix, 0)
        rhs = broken_prefixes.get(prefix, 0)
        marker = "βœ…" if lhs == rhs else "⚠️ "
        print(f"{marker} {prefix:<28} {lhs:<15} {rhs:<15}")
# Example usage
if __name__ == "__main__":
    # OPTION 1: inspect a single checkpoint
    print("OPTION 1: Inspect your working checkpoint")
    print("-" * 80)
    inspect_checkpoint(
        "../flux1-depth-dev_ComfyMerged.safetensors",
        detailed=False,  # Set to True for full key listing
    )
    print("\n\n")

    # OPTION 2: compare working vs broken checkpoint (uncomment to use)
    # print("OPTION 2: Compare working vs broken checkpoint")
    # print("-" * 80)
    # compare_checkpoints(
    #     "Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
    #     "flux1-depth-dev_fp4_merged_model.safetensors"
    # )