File size: 7,839 Bytes
45131ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1bf81fc
45131ff
 
 
 
 
 
1bf81fc
 
45131ff
1bf81fc
 
45131ff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
from safetensors import safe_open
from collections import defaultdict
import os

def inspect_checkpoint(checkpoint_path, detailed=False):
    """
    Inspect the structure of a safetensors checkpoint file.

    Prints the file size, total parameter count, a component breakdown
    (VAE / Text Encoder / UNet-Transformer / Other), top-level key
    prefixes, sample tensor shapes per component, a heuristic model-type
    detection (FLUX vs. Stable Diffusion), and a completeness summary.

    Args:
        checkpoint_path: Path to the .safetensors file.
        detailed: If True, additionally lists every key with its shape,
            dtype, and element count.

    Returns:
        None. All output goes to stdout; returns early (after printing
        an error) if the file does not exist.
    """

    if not os.path.exists(checkpoint_path):
        print(f"❌ File not found: {checkpoint_path}")
        return

    print("=" * 80)
    print(f"INSPECTING: {os.path.basename(checkpoint_path)}")
    print("=" * 80)

    # File size
    size_bytes = os.path.getsize(checkpoint_path)
    size_gb = size_bytes / (1024**3)
    print(f"\nπŸ“¦ File Size: {size_gb:.2f} GB ({size_bytes:,} bytes)")

    with safe_open(checkpoint_path, framework="pt") as f:
        keys = list(f.keys())

        print(f"\nπŸ“Š Total Parameters: {len(keys):,}")

        # Categorize keys by component
        print("\n" + "=" * 80)
        print("COMPONENT BREAKDOWN")
        print("=" * 80)

        categories = defaultdict(list)

        for key in keys:
            # Lowercase once per key instead of up to three times in the
            # substring checks below.
            lowered = key.lower()
            # Order matters: 'transformer.text_model' must be claimed by
            # the Text Encoder bucket before the generic 'transformer'
            # test would pull it into UNet/Transformer.
            if any(x in lowered for x in ['vae', 'first_stage', 'decoder', 'encoder', 'quant_conv', 'post_quant']):
                categories['VAE'].append(key)
            elif any(x in lowered for x in ['text_encoder', 'cond_stage', 'clip', 'transformer.text_model']):
                categories['Text Encoder'].append(key)
            elif any(x in lowered for x in ['model.diffusion', 'diffusion_model', 'transformer', 'double_blocks', 'single_blocks']):
                categories['UNet/Transformer'].append(key)
            else:
                categories['Other'].append(key)

        # Print summary
        for category, cat_keys in sorted(categories.items()):
            print(f"\n{category}: {len(cat_keys)} parameters")

        # Analyze key patterns
        print("\n" + "=" * 80)
        print("KEY PATTERNS")
        print("=" * 80)

        # Group by top-level prefix. str.split always yields at least one
        # element, so no special-casing of keys without a '.' is needed.
        prefix_groups = defaultdict(int)
        for key in keys:
            prefix_groups[key.split('.')[0]] += 1

        print("\nTop-level prefixes:")
        for prefix, count in sorted(prefix_groups.items(), key=lambda x: -x[1]):
            print(f"  {prefix}: {count} parameters")

        # Show sample keys from each category
        print("\n" + "=" * 80)
        print("SAMPLE KEYS FROM EACH COMPONENT")
        print("=" * 80)

        for category, cat_keys in sorted(categories.items()):
            if cat_keys:
                print(f"\n{category} (showing first 5):")
                for key in cat_keys[:5]:
                    tensor = f.get_tensor(key)
                    print(f"  {key}")
                    print(f"    └─ shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}")

        if detailed:
            print("\n" + "=" * 80)
            print("ALL KEYS (DETAILED)")
            print("=" * 80)

            for i, key in enumerate(keys, 1):
                tensor = f.get_tensor(key)
                print(f"\n{i}. {key}")
                print(f"   Shape: {tuple(tensor.shape)}")
                print(f"   Dtype: {tensor.dtype}")
                print(f"   Size: {tensor.numel():,} elements")

        # Check for common FLUX/SD patterns
        print("\n" + "=" * 80)
        print("MODEL TYPE DETECTION")
        print("=" * 80)

        # Fix: the original compared 'first_stage'/'cond_stage' against the
        # raw key but 'vae'/'text_encoder' against key.lower() — lowercase
        # once and compare all substrings consistently (case-insensitively).
        lowered_keys = [k.lower() for k in keys]
        has_flux_blocks = any('double_blocks' in k or 'single_blocks' in k for k in keys)
        has_sd_unet = any('model.diffusion_model' in k for k in keys)
        has_vae = any('vae' in k or 'first_stage' in k for k in lowered_keys)
        has_text_encoder = any('text_encoder' in k or 'cond_stage' in k for k in lowered_keys)

        print(f"\nβœ“ FLUX-style blocks: {'βœ… YES' if has_flux_blocks else '❌ NO'}")
        print(f"βœ“ SD-style UNet: {'βœ… YES' if has_sd_unet else '❌ NO'}")
        print(f"βœ“ VAE included: {'βœ… YES' if has_vae else '❌ NO'}")
        print(f"βœ“ Text Encoder included: {'βœ… YES' if has_text_encoder else '❌ NO'}")

        if has_flux_blocks:
            print("\nπŸ” Likely model type: FLUX")
        elif has_sd_unet:
            print("\nπŸ” Likely model type: Stable Diffusion")
        else:
            print("\n⚠️  Could not determine model type")

        # Check if complete checkpoint
        print("\n" + "=" * 80)
        print("CHECKPOINT COMPLETENESS")
        print("=" * 80)

        if has_vae and has_text_encoder:
            print("\nβœ… This appears to be a COMPLETE checkpoint")
            print("   (Contains UNet/Transformer + VAE + Text Encoder)")
        else:
            print("\n⚠️  This appears to be a PARTIAL checkpoint")
            if not has_vae:
                print("   Missing: VAE")
            if not has_text_encoder:
                print("   Missing: Text Encoder")

    print("\n" + "=" * 80)
    print("INSPECTION COMPLETE")
    print("=" * 80)


def compare_checkpoints(working_checkpoint, broken_checkpoint):
    """
    Compare the key sets of two safetensors checkpoints.

    Reports the key count of each file, keys unique to either side
    (first 20 of each), and a side-by-side tally of top-level key
    prefixes, flagging any prefix whose counts differ.

    Args:
        working_checkpoint: Path to checkpoint that works.
        broken_checkpoint: Path to checkpoint that doesn't work.
    """

    print("=" * 80)
    print("COMPARING CHECKPOINTS")
    print("=" * 80)

    def read_keys(path):
        # Open only long enough to collect key names; no tensor data is read.
        with safe_open(path, framework="pt") as handle:
            return set(handle.keys())

    working_keys = read_keys(working_checkpoint)
    broken_keys = read_keys(broken_checkpoint)

    print(f"\nWorking checkpoint: {len(working_keys)} keys")
    print(f"Broken checkpoint: {len(broken_keys)} keys")

    only_in_working = working_keys - broken_keys
    only_in_broken = broken_keys - working_keys
    shared = working_keys & broken_keys

    print(f"\nCommon keys: {len(shared)}")
    print(f"Only in working: {len(only_in_working)}")
    print(f"Only in broken: {len(only_in_broken)}")

    if only_in_working:
        print("\nπŸ” Keys present in WORKING but missing in BROKEN (first 20):")
        for missing in sorted(only_in_working)[:20]:
            print(f"  - {missing}")

    if only_in_broken:
        print("\nπŸ” Keys present in BROKEN but missing in WORKING (first 20):")
        for extra in sorted(only_in_broken)[:20]:
            print(f"  + {extra}")

    # Compare key patterns
    print("\n" + "=" * 80)
    print("KEY PATTERN COMPARISON")
    print("=" * 80)

    def tally_prefixes(key_set):
        # Count keys per top-level dotted prefix.
        counts = {}
        for name in key_set:
            head = name.split('.')[0]
            counts[head] = counts.get(head, 0) + 1
        return counts

    working_prefixes = tally_prefixes(working_keys)
    broken_prefixes = tally_prefixes(broken_keys)

    print(f"\n{'Prefix':<30} {'Working':<15} {'Broken':<15}")
    print("-" * 60)
    for head in sorted(working_prefixes.keys() | broken_prefixes.keys()):
        in_working = working_prefixes.get(head, 0)
        in_broken = broken_prefixes.get(head, 0)
        marker = "βœ…" if in_working == in_broken else "⚠️ "
        print(f"{marker} {head:<28} {in_working:<15} {in_broken:<15}")


# Example usage
if __name__ == "__main__":
    CHECKPOINT_PATH = "../flux1-depth-dev_ComfyMerged.safetensors"

    # Inspect a single checkpoint
    print("OPTION 1: Inspect your working checkpoint")
    print("-" * 80)
    inspect_checkpoint(
        CHECKPOINT_PATH,
        detailed=False,  # Set to True for full key listing
    )

    print("\n\n")

    # To diff a working checkpoint against a broken one, uncomment:
    # compare_checkpoints(
    #     "Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
    #     "flux1-depth-dev_fp4_merged_model.safetensors",
    # )