"""Inspect and compare the tensor layout of ``.safetensors`` model checkpoints.

Run as a script with one path to inspect a checkpoint, or with two paths
(working, broken) to compare them. With no arguments the built-in example
path is inspected.
"""
import math
import os
import sys
from collections import Counter, defaultdict

from safetensors import safe_open

# Substring patterns matched against the lower-cased tensor key to bucket it
# into a model component. Text-encoder patterns are checked BEFORE the VAE
# patterns on purpose: text-encoder keys such as
# "cond_stage_model.transformer.text_model.encoder.layers.0..." contain the
# substring "encoder" and would otherwise be misfiled as VAE weights.
_TEXT_ENCODER_PATTERNS = (
    'text_encoder', 'cond_stage', 'conditioner', 'clip', 'transformer.text_model',
)
_VAE_PATTERNS = (
    'vae', 'first_stage', 'decoder', 'encoder', 'quant_conv', 'post_quant',
)
_DIFFUSION_PATTERNS = (
    'model.diffusion', 'diffusion_model', 'transformer', 'double_blocks', 'single_blocks',
)


def _print_section(title, leading_newline=True):
    """Print *title* framed above and below by 80-character '=' rules."""
    print(("\n" if leading_newline else "") + "=" * 80)
    print(title)
    print("=" * 80)


def _categorize_key(key):
    """Return the component bucket name ('Text Encoder', 'VAE', 'UNet/Transformer' or 'Other') for *key*."""
    lowered = key.lower()
    if any(pattern in lowered for pattern in _TEXT_ENCODER_PATTERNS):
        return 'Text Encoder'
    if any(pattern in lowered for pattern in _VAE_PATTERNS):
        return 'VAE'
    if any(pattern in lowered for pattern in _DIFFUSION_PATTERNS):
        return 'UNet/Transformer'
    return 'Other'


def _tensor_meta(handle, key):
    """Return ``(shape, dtype)`` for *key* from an open ``safe_open`` handle.

    ``get_slice`` returns a lazy slice object, so shape (and dtype, when the
    installed safetensors provides ``get_dtype``) come from the file header
    instead of materializing the whole tensor. Older safetensors releases
    without ``get_dtype`` fall back to loading the tensor to read its dtype.
    """
    tensor_slice = handle.get_slice(key)
    shape = tuple(tensor_slice.get_shape())
    get_dtype = getattr(tensor_slice, "get_dtype", None)
    dtype = get_dtype() if get_dtype is not None else handle.get_tensor(key).dtype
    return shape, dtype


def inspect_checkpoint(checkpoint_path, detailed=False):
    """
    Inspect the structure of a safetensors checkpoint file and print a report.

    The report covers file size, tensor count, a per-component breakdown
    (bucketed by key substrings), top-level key prefixes, the first five keys
    of each component with shape/dtype, a heuristic model-type guess
    (FLUX vs Stable Diffusion) and a completeness check.

    Shapes/dtypes are read from tensor metadata rather than by loading the
    tensors, so dtypes are shown in safetensors notation (e.g. ``BF16``) when
    the installed safetensors supports it.

    Args:
        checkpoint_path: Path to the .safetensors file
        detailed: If True, also lists every key with shape, dtype and
            element count
    """
    if not os.path.exists(checkpoint_path):
        print(f"❌ File not found: {checkpoint_path}")
        return

    _print_section(f"INSPECTING: {os.path.basename(checkpoint_path)}", leading_newline=False)

    size_bytes = os.path.getsize(checkpoint_path)
    size_gb = size_bytes / (1024**3)
    print(f"\n📦 File Size: {size_gb:.2f} GB ({size_bytes:,} bytes)")

    with safe_open(checkpoint_path, framework="pt") as f:
        keys = list(f.keys())
        # NOTE: "parameters" here counts tensors (keys), not scalar elements.
        print(f"\n📊 Total Parameters: {len(keys):,}")

        _print_section("COMPONENT BREAKDOWN")
        categories = defaultdict(list)
        for key in keys:
            categories[_categorize_key(key)].append(key)
        for category, cat_keys in sorted(categories.items()):
            print(f"\n{category}: {len(cat_keys)} parameters")

        _print_section("KEY PATTERNS")
        # First dot-separated segment; a key without a dot is its own prefix.
        prefix_groups = Counter(key.split('.')[0] for key in keys)
        print("\nTop-level prefixes:")
        for prefix, count in sorted(prefix_groups.items(), key=lambda item: -item[1]):
            print(f"  {prefix}: {count} parameters")

        _print_section("SAMPLE KEYS FROM EACH COMPONENT")
        # Only non-empty categories exist in the defaultdict, so no emptiness check is needed.
        for category, cat_keys in sorted(categories.items()):
            print(f"\n{category} (showing first 5):")
            for key in cat_keys[:5]:
                shape, dtype = _tensor_meta(f, key)
                print(f"  {key}")
                print(f"    └─ shape: {shape}, dtype: {dtype}")

        if detailed:
            _print_section("ALL KEYS (DETAILED)")
            for i, key in enumerate(keys, 1):
                shape, dtype = _tensor_meta(f, key)
                print(f"\n{i}. {key}")
                print(f"   Shape: {shape}")
                print(f"   Dtype: {dtype}")
                # math.prod(()) == 1, matching the element count of a 0-d tensor.
                print(f"   Size: {math.prod(shape):,} elements")

    # Heuristic detection of common FLUX / Stable Diffusion key patterns.
    _print_section("MODEL TYPE DETECTION")
    has_flux_blocks = any('double_blocks' in k or 'single_blocks' in k for k in keys)
    has_sd_unet = any('model.diffusion_model' in k for k in keys)
    has_vae = any('vae' in k.lower() or 'first_stage' in k for k in keys)
    # 'conditioner' covers SDXL-style "conditioner.embedders.*" text encoders.
    has_text_encoder = any(
        any(pattern in k.lower() for pattern in ('text_encoder', 'cond_stage', 'conditioner'))
        for k in keys
    )
    has_diffusion = has_flux_blocks or has_sd_unet or bool(categories.get('UNet/Transformer'))

    for label, present in (
        ("FLUX-style blocks", has_flux_blocks),
        ("SD-style UNet", has_sd_unet),
        ("VAE included", has_vae),
        ("Text Encoder included", has_text_encoder),
    ):
        prefix = "\n" if label == "FLUX-style blocks" else ""
        print(f"{prefix}✓ {label}: {'✅ YES' if present else '❌ NO'}")

    if has_flux_blocks:
        print("\n🔍 Likely model type: FLUX")
    elif has_sd_unet:
        print("\n🔍 Likely model type: Stable Diffusion")
    else:
        print("\n⚠️ Could not determine model type")

    _print_section("CHECKPOINT COMPLETENESS")
    missing = [
        name
        for name, present in (
            ("UNet/Transformer", has_diffusion),
            ("VAE", has_vae),
            ("Text Encoder", has_text_encoder),
        )
        if not present
    ]
    if not missing:
        print("\n✅ This appears to be a COMPLETE checkpoint")
        print("   (Contains UNet/Transformer + VAE + Text Encoder)")
    else:
        print("\n⚠️ This appears to be a PARTIAL checkpoint")
        for name in missing:
            print(f"   Missing: {name}")

    _print_section("INSPECTION COMPLETE")


def compare_checkpoints(working_checkpoint, broken_checkpoint):
    """
    Compare two checkpoints and print their differences.

    Reports keys present in only one file, keys present in both whose shape
    or dtype differ (read from metadata, not by loading tensors), and
    per-top-level-prefix key counts side by side.

    Args:
        working_checkpoint: Path to checkpoint that works
        broken_checkpoint: Path to checkpoint that doesn't work
    """
    missing_files = [
        path for path in (working_checkpoint, broken_checkpoint) if not os.path.exists(path)
    ]
    if missing_files:
        for path in missing_files:
            print(f"❌ File not found: {path}")
        return

    _print_section("COMPARING CHECKPOINTS", leading_newline=False)

    with safe_open(working_checkpoint, framework="pt") as f1, \
            safe_open(broken_checkpoint, framework="pt") as f2:
        keys1 = set(f1.keys())
        keys2 = set(f2.keys())
        common = keys1 & keys2
        # A shared key with a different shape/dtype can break loading even
        # when the key sets match, so check metadata for every common key.
        mismatched = []
        for key in sorted(common):
            meta1 = _tensor_meta(f1, key)
            meta2 = _tensor_meta(f2, key)
            if meta1 != meta2:
                mismatched.append((key, meta1, meta2))

    print(f"\nWorking checkpoint: {len(keys1)} keys")
    print(f"Broken checkpoint: {len(keys2)} keys")

    only_in_working = keys1 - keys2
    only_in_broken = keys2 - keys1

    print(f"\nCommon keys: {len(common)}")
    print(f"Only in working: {len(only_in_working)}")
    print(f"Only in broken: {len(only_in_broken)}")
    print(f"Shape/dtype mismatches among common keys: {len(mismatched)}")

    if only_in_working:
        print("\n🔍 Keys present in WORKING but missing in BROKEN (first 20):")
        for key in sorted(only_in_working)[:20]:
            print(f"  - {key}")

    if only_in_broken:
        print("\n🔍 Keys present in BROKEN but missing in WORKING (first 20):")
        for key in sorted(only_in_broken)[:20]:
            print(f"  + {key}")

    if mismatched:
        print("\n🔍 Common keys whose shape or dtype differ (first 20):")
        for key, (shape1, dtype1), (shape2, dtype2) in mismatched[:20]:
            print(f"  ~ {key}")
            print(f"      working: shape {shape1}, dtype {dtype1}")
            print(f"      broken:  shape {shape2}, dtype {dtype2}")

    _print_section("KEY PATTERN COMPARISON")
    prefixes1 = Counter(key.split('.')[0] for key in keys1)
    prefixes2 = Counter(key.split('.')[0] for key in keys2)
    all_prefixes = set(prefixes1) | set(prefixes2)

    print(f"\n{'Prefix':<30} {'Working':<15} {'Broken':<15}")
    print("-" * 60)
    for prefix in sorted(all_prefixes):
        count1 = prefixes1[prefix]  # Counter returns 0 for absent prefixes
        count2 = prefixes2[prefix]
        status = "✅" if count1 == count2 else "⚠️ "
        print(f"{status} {prefix:<28} {count1:<15} {count2:<15}")


if __name__ == "__main__":
    # Usage:
    #   python <this script> [CHECKPOINT]        -> inspect one checkpoint
    #   python <this script> WORKING BROKEN      -> compare two checkpoints
    # With no arguments the example path below is inspected.
    cli_args = sys.argv[1:]
    if len(cli_args) >= 2:
        print("OPTION 2: Compare working vs broken checkpoint")
        print("-" * 80)
        compare_checkpoints(cli_args[0], cli_args[1])
    else:
        print("OPTION 1: Inspect your working checkpoint")
        print("-" * 80)
        inspect_checkpoint(
            cli_args[0] if cli_args else "../flux1-depth-dev_ComfyMerged.safetensors",
            detailed=False,  # Set to True for full key listing
        )
        print("\n\n")