|
|
from safetensors import safe_open |
|
|
from collections import defaultdict |
|
|
import os |
|
|
|
|
|
def inspect_checkpoint(checkpoint_path, detailed=False):
    """
    Inspect the structure of a safetensors checkpoint file.

    Prints, in order: file size, total parameter count, a component
    breakdown (VAE / Text Encoder / UNet-Transformer / Other, classified
    by substring heuristics on the key names), top-level key prefixes,
    sample tensor shapes per component, a model-type guess
    (FLUX vs. Stable Diffusion), and a completeness check.

    Args:
        checkpoint_path: Path to the .safetensors file.
        detailed: If True, additionally lists every key with its
            shape, dtype and element count.

    Returns:
        None. Output goes to stdout; returns early if the file is missing.
    """
    if not os.path.exists(checkpoint_path):
        print(f"File not found: {checkpoint_path}")
        return

    print("=" * 80)
    print(f"INSPECTING: {os.path.basename(checkpoint_path)}")
    print("=" * 80)

    size_bytes = os.path.getsize(checkpoint_path)
    size_gb = size_bytes / (1024 ** 3)
    print(f"\nFile Size: {size_gb:.2f} GB ({size_bytes:,} bytes)")

    # Keep the file handle open for the whole inspection: the sample-key
    # and detailed sections below read tensors lazily via f.get_tensor().
    with safe_open(checkpoint_path, framework="pt") as f:
        keys = list(f.keys())

        print(f"\nTotal Parameters: {len(keys):,}")

        print("\n" + "=" * 80)
        print("COMPONENT BREAKDOWN")
        print("=" * 80)

        # Classify each key by substring match; order matters — VAE and
        # text-encoder patterns are checked before the broader
        # UNet/Transformer patterns (e.g. 'transformer' alone).
        categories = defaultdict(list)
        for key in keys:
            lowered = key.lower()
            if any(x in lowered for x in ('vae', 'first_stage', 'decoder', 'encoder', 'quant_conv', 'post_quant')):
                categories['VAE'].append(key)
            elif any(x in lowered for x in ('text_encoder', 'cond_stage', 'clip', 'transformer.text_model')):
                categories['Text Encoder'].append(key)
            elif any(x in lowered for x in ('model.diffusion', 'diffusion_model', 'transformer', 'double_blocks', 'single_blocks')):
                categories['UNet/Transformer'].append(key)
            else:
                categories['Other'].append(key)

        for category, cat_keys in sorted(categories.items()):
            print(f"\n{category}: {len(cat_keys)} parameters")

        print("\n" + "=" * 80)
        print("KEY PATTERNS")
        print("=" * 80)

        # Group by top-level dotted prefix to reveal the checkpoint layout.
        prefix_groups = defaultdict(int)
        for key in keys:
            prefix = key.split('.')[0] if '.' in key else key
            prefix_groups[prefix] += 1

        print("\nTop-level prefixes:")
        for prefix, count in sorted(prefix_groups.items(), key=lambda x: -x[1]):
            print(f"  {prefix}: {count} parameters")

        print("\n" + "=" * 80)
        print("SAMPLE KEYS FROM EACH COMPONENT")
        print("=" * 80)

        for category, cat_keys in sorted(categories.items()):
            if cat_keys:
                print(f"\n{category} (showing first 5):")
                for key in cat_keys[:5]:
                    tensor = f.get_tensor(key)
                    print(f"  {key}")
                    print(f"    shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}")

        if detailed:
            print("\n" + "=" * 80)
            print("ALL KEYS (DETAILED)")
            print("=" * 80)

            for i, key in enumerate(keys, 1):
                tensor = f.get_tensor(key)
                print(f"\n{i}. {key}")
                print(f"   Shape: {tuple(tensor.shape)}")
                print(f"   Dtype: {tensor.dtype}")
                print(f"   Size: {tensor.numel():,} elements")

        print("\n" + "=" * 80)
        print("MODEL TYPE DETECTION")
        print("=" * 80)

        has_flux_blocks = any('double_blocks' in k or 'single_blocks' in k for k in keys)
        has_sd_unet = any('model.diffusion_model' in k for k in keys)
        has_vae = any('vae' in k.lower() or 'first_stage' in k for k in keys)
        has_text_encoder = any('text_encoder' in k.lower() or 'cond_stage' in k for k in keys)

        print(f"\nFLUX-style blocks: {'YES' if has_flux_blocks else 'NO'}")
        print(f"SD-style UNet: {'YES' if has_sd_unet else 'NO'}")
        print(f"VAE included: {'YES' if has_vae else 'NO'}")
        print(f"Text Encoder included: {'YES' if has_text_encoder else 'NO'}")

        if has_flux_blocks:
            print("\nLikely model type: FLUX")
        elif has_sd_unet:
            print("\nLikely model type: Stable Diffusion")
        else:
            print("\nCould not determine model type")

        print("\n" + "=" * 80)
        print("CHECKPOINT COMPLETENESS")
        print("=" * 80)

        if has_vae and has_text_encoder:
            print("\nThis appears to be a COMPLETE checkpoint")
            print("  (Contains UNet/Transformer + VAE + Text Encoder)")
        else:
            print("\nThis appears to be a PARTIAL checkpoint")
            if not has_vae:
                print("  Missing: VAE")
            if not has_text_encoder:
                print("  Missing: Text Encoder")

        print("\n" + "=" * 80)
        print("INSPECTION COMPLETE")
        print("=" * 80)
|
|
|
|
|
|
|
|
def compare_checkpoints(working_checkpoint, broken_checkpoint):
    """
    Compare the key sets of two safetensors checkpoints.

    Reports the key counts of each file, the keys unique to either side
    (first 20 of each), and a per-prefix count comparison table so that
    structural differences between a known-good and a failing checkpoint
    stand out.

    Args:
        working_checkpoint: Path to the checkpoint that works.
        broken_checkpoint: Path to the checkpoint that doesn't work.

    Returns:
        None. Output goes to stdout.
    """
    print("=" * 80)
    print("COMPARING CHECKPOINTS")
    print("=" * 80)

    # Only the key names are needed, so each file is opened briefly and
    # no tensor data is read.
    with safe_open(working_checkpoint, framework="pt") as f1:
        keys1 = set(f1.keys())

    with safe_open(broken_checkpoint, framework="pt") as f2:
        keys2 = set(f2.keys())

    print(f"\nWorking checkpoint: {len(keys1)} keys")
    print(f"Broken checkpoint: {len(keys2)} keys")

    only_in_working = keys1 - keys2
    only_in_broken = keys2 - keys1
    common = keys1 & keys2

    print(f"\nCommon keys: {len(common)}")
    print(f"Only in working: {len(only_in_working)}")
    print(f"Only in broken: {len(only_in_broken)}")

    if only_in_working:
        print("\nKeys present in WORKING but missing in BROKEN (first 20):")
        for key in sorted(only_in_working)[:20]:
            print(f"  - {key}")

    if only_in_broken:
        print("\nKeys present in BROKEN but missing in WORKING (first 20):")
        for key in sorted(only_in_broken)[:20]:
            print(f"  + {key}")

    print("\n" + "=" * 80)
    print("KEY PATTERN COMPARISON")
    print("=" * 80)

    def get_prefixes(keys):
        # Count keys per top-level dotted prefix.
        prefixes = defaultdict(int)
        for key in keys:
            prefixes[key.split('.')[0]] += 1
        return prefixes

    prefixes1 = get_prefixes(keys1)
    prefixes2 = get_prefixes(keys2)

    all_prefixes = set(prefixes1.keys()) | set(prefixes2.keys())

    print(f"\n{'Prefix':<30} {'Working':<15} {'Broken':<15}")
    print("-" * 60)
    for prefix in sorted(all_prefixes):
        count1 = prefixes1.get(prefix, 0)
        count2 = prefixes2.get(prefix, 0)
        # Flag prefixes whose counts differ between the two checkpoints.
        status = "==" if count1 == count2 else "!="
        print(f"{status} {prefix:<28} {count1:<15} {count2:<15}")
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo usage: inspect the merged checkpoint located one directory up.
    # Edit the path below to point at your own .safetensors file.
    print("OPTION 1: Inspect your working checkpoint")
    print("-" * 80)
    checkpoint_to_inspect = "../flux1-depth-dev_ComfyMerged.safetensors"
    inspect_checkpoint(checkpoint_to_inspect, detailed=False)
    print("\n\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|