srcphag commited on
Commit
45131ff
·
1 Parent(s): 9d2d5e5

New merged file with labels

Browse files
CheckSafetensors.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from safetensors import safe_open
2
+
3
+ with safe_open("flux1-depth-dev_fp4_merged_model.safetensors", framework="pt") as f:
4
+ keys = f.keys()
5
+
6
+ has_vae = any('vae' in k or 'decoder' in k for k in keys)
7
+ has_clip = any('text_encoder' in k or 'clip' in k for k in keys)
8
+ has_unet = any('unet' in k or 'transformer' in k for k in keys)
9
+
10
+ print(f"VAE: {has_vae}, CLIP: {has_clip}, UNet: {has_unet}")
InspectSafetensors.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from safetensors import safe_open
2
+ from collections import defaultdict
3
+ import os
4
+
5
+ def inspect_checkpoint(checkpoint_path, detailed=False):
6
+ """
7
+ Inspect the structure of a safetensors checkpoint file.
8
+
9
+ Args:
10
+ checkpoint_path: Path to the .safetensors file
11
+ detailed: If True, shows more detailed information
12
+ """
13
+
14
+ if not os.path.exists(checkpoint_path):
15
+ print(f"❌ File not found: {checkpoint_path}")
16
+ return
17
+
18
+ print("=" * 80)
19
+ print(f"INSPECTING: {os.path.basename(checkpoint_path)}")
20
+ print("=" * 80)
21
+
22
+ # File size
23
+ size_bytes = os.path.getsize(checkpoint_path)
24
+ size_gb = size_bytes / (1024**3)
25
+ print(f"\n📦 File Size: {size_gb:.2f} GB ({size_bytes:,} bytes)")
26
+
27
+ with safe_open(checkpoint_path, framework="pt") as f:
28
+ keys = list(f.keys())
29
+
30
+ print(f"\n📊 Total Parameters: {len(keys):,}")
31
+
32
+ # Categorize keys by component
33
+ print("\n" + "=" * 80)
34
+ print("COMPONENT BREAKDOWN")
35
+ print("=" * 80)
36
+
37
+ categories = defaultdict(list)
38
+
39
+ for key in keys:
40
+ # Categorize by prefix
41
+ if any(x in key.lower() for x in ['vae', 'first_stage', 'decoder', 'encoder', 'quant_conv', 'post_quant']):
42
+ categories['VAE'].append(key)
43
+ elif any(x in key.lower() for x in ['text_encoder', 'cond_stage', 'clip', 'transformer.text_model']):
44
+ categories['Text Encoder'].append(key)
45
+ elif any(x in key.lower() for x in ['model.diffusion', 'diffusion_model', 'transformer', 'double_blocks', 'single_blocks']):
46
+ categories['UNet/Transformer'].append(key)
47
+ else:
48
+ categories['Other'].append(key)
49
+
50
+ # Print summary
51
+ for category, cat_keys in sorted(categories.items()):
52
+ print(f"\n{category}: {len(cat_keys)} parameters")
53
+
54
+ # Analyze key patterns
55
+ print("\n" + "=" * 80)
56
+ print("KEY PATTERNS")
57
+ print("=" * 80)
58
+
59
+ # Group by top-level prefix
60
+ prefix_groups = defaultdict(int)
61
+ for key in keys:
62
+ prefix = key.split('.')[0] if '.' in key else key
63
+ prefix_groups[prefix] += 1
64
+
65
+ print("\nTop-level prefixes:")
66
+ for prefix, count in sorted(prefix_groups.items(), key=lambda x: -x[1]):
67
+ print(f" {prefix}: {count} parameters")
68
+
69
+ # Show sample keys from each category
70
+ print("\n" + "=" * 80)
71
+ print("SAMPLE KEYS FROM EACH COMPONENT")
72
+ print("=" * 80)
73
+
74
+ for category, cat_keys in sorted(categories.items()):
75
+ if cat_keys:
76
+ print(f"\n{category} (showing first 5):")
77
+ for key in cat_keys[:5]:
78
+ tensor = f.get_tensor(key)
79
+ print(f" {key}")
80
+ print(f" └─ shape: {tuple(tensor.shape)}, dtype: {tensor.dtype}")
81
+
82
+ if detailed:
83
+ print("\n" + "=" * 80)
84
+ print("ALL KEYS (DETAILED)")
85
+ print("=" * 80)
86
+
87
+ for i, key in enumerate(keys, 1):
88
+ tensor = f.get_tensor(key)
89
+ print(f"\n{i}. {key}")
90
+ print(f" Shape: {tuple(tensor.shape)}")
91
+ print(f" Dtype: {tensor.dtype}")
92
+ print(f" Size: {tensor.numel():,} elements")
93
+
94
+ # Check for common FLUX/SD patterns
95
+ print("\n" + "=" * 80)
96
+ print("MODEL TYPE DETECTION")
97
+ print("=" * 80)
98
+
99
+ has_flux_blocks = any('double_blocks' in k or 'single_blocks' in k for k in keys)
100
+ has_sd_unet = any('model.diffusion_model' in k for k in keys)
101
+ has_vae = any('vae' in k.lower() or 'first_stage' in k for k in keys)
102
+ has_text_encoder = any('text_encoder' in k.lower() or 'cond_stage' in k for k in keys)
103
+
104
+ print(f"\n✓ FLUX-style blocks: {'✅ YES' if has_flux_blocks else '❌ NO'}")
105
+ print(f"✓ SD-style UNet: {'✅ YES' if has_sd_unet else '❌ NO'}")
106
+ print(f"✓ VAE included: {'✅ YES' if has_vae else '❌ NO'}")
107
+ print(f"✓ Text Encoder included: {'✅ YES' if has_text_encoder else '❌ NO'}")
108
+
109
+ if has_flux_blocks:
110
+ print("\n🔍 Likely model type: FLUX")
111
+ elif has_sd_unet:
112
+ print("\n🔍 Likely model type: Stable Diffusion")
113
+ else:
114
+ print("\n⚠️ Could not determine model type")
115
+
116
+ # Check if complete checkpoint
117
+ print("\n" + "=" * 80)
118
+ print("CHECKPOINT COMPLETENESS")
119
+ print("=" * 80)
120
+
121
+ if has_vae and has_text_encoder:
122
+ print("\n✅ This appears to be a COMPLETE checkpoint")
123
+ print(" (Contains UNet/Transformer + VAE + Text Encoder)")
124
+ else:
125
+ print("\n⚠️ This appears to be a PARTIAL checkpoint")
126
+ if not has_vae:
127
+ print(" Missing: VAE")
128
+ if not has_text_encoder:
129
+ print(" Missing: Text Encoder")
130
+
131
+ print("\n" + "=" * 80)
132
+ print("INSPECTION COMPLETE")
133
+ print("=" * 80)
134
+
135
+
136
+ def compare_checkpoints(working_checkpoint, broken_checkpoint):
137
+ """
138
+ Compare two checkpoints to see the differences.
139
+
140
+ Args:
141
+ working_checkpoint: Path to checkpoint that works
142
+ broken_checkpoint: Path to checkpoint that doesn't work
143
+ """
144
+
145
+ print("=" * 80)
146
+ print("COMPARING CHECKPOINTS")
147
+ print("=" * 80)
148
+
149
+ with safe_open(working_checkpoint, framework="pt") as f1:
150
+ keys1 = set(f1.keys())
151
+
152
+ with safe_open(broken_checkpoint, framework="pt") as f2:
153
+ keys2 = set(f2.keys())
154
+
155
+ print(f"\nWorking checkpoint: {len(keys1)} keys")
156
+ print(f"Broken checkpoint: {len(keys2)} keys")
157
+
158
+ only_in_working = keys1 - keys2
159
+ only_in_broken = keys2 - keys1
160
+ common = keys1 & keys2
161
+
162
+ print(f"\nCommon keys: {len(common)}")
163
+ print(f"Only in working: {len(only_in_working)}")
164
+ print(f"Only in broken: {len(only_in_broken)}")
165
+
166
+ if only_in_working:
167
+ print("\n🔍 Keys present in WORKING but missing in BROKEN (first 20):")
168
+ for key in sorted(only_in_working)[:20]:
169
+ print(f" - {key}")
170
+
171
+ if only_in_broken:
172
+ print("\n🔍 Keys present in BROKEN but missing in WORKING (first 20):")
173
+ for key in sorted(only_in_broken)[:20]:
174
+ print(f" + {key}")
175
+
176
+ # Compare key patterns
177
+ print("\n" + "=" * 80)
178
+ print("KEY PATTERN COMPARISON")
179
+ print("=" * 80)
180
+
181
+ def get_prefixes(keys):
182
+ prefixes = defaultdict(int)
183
+ for key in keys:
184
+ prefix = key.split('.')[0]
185
+ prefixes[prefix] += 1
186
+ return prefixes
187
+
188
+ prefixes1 = get_prefixes(keys1)
189
+ prefixes2 = get_prefixes(keys2)
190
+
191
+ all_prefixes = set(prefixes1.keys()) | set(prefixes2.keys())
192
+
193
+ print(f"\n{'Prefix':<30} {'Working':<15} {'Broken':<15}")
194
+ print("-" * 60)
195
+ for prefix in sorted(all_prefixes):
196
+ count1 = prefixes1.get(prefix, 0)
197
+ count2 = prefixes2.get(prefix, 0)
198
+ status = "✅" if count1 == count2 else "⚠️ "
199
+ print(f"{status} {prefix:<28} {count1:<15} {count2:<15}")
200
+
201
+
202
+ # Example usage
203
+ if __name__ == "__main__":
204
+ # Inspect a single checkpoint
205
+ print("OPTION 1: Inspect your working checkpoint")
206
+ print("-" * 80)
207
+ inspect_checkpoint(
208
+ "test/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors",
209
+ detailed=False # Set to True for full key listing
210
+ )
211
+
212
+ print("\n\n")
213
+
214
+ # Compare two checkpoints
215
+ print("OPTION 2: Compare working vs broken checkpoint")
216
+ print("-" * 80)
217
+ # compare_checkpoints(
218
+ # "path/to/working_checkpoint.safetensors",
219
+ # "path/to/broken_checkpoint.safetensors"
220
+ # )
MergeSafetensors.py CHANGED
@@ -67,8 +67,8 @@ def merge_model_components(
67
  # Example usage
68
  if __name__ == "__main__":
69
  merge_model_components(
70
- unet_path="svdq-fp4_r32-flux.1-depth-dev.safetensors",
71
  vae_path="vae/diffusion_pytorch_model.safetensors",
72
  text_encoder_path="text_encoder/model.safetensors",
73
- output_path="flux1-depth-dev_fp4_merged_model.safetensors"
74
  )
 
67
  # Example usage
68
  if __name__ == "__main__":
69
  merge_model_components(
70
+ unet_path="flux1-depth-dev.safetensors",
71
  vae_path="vae/diffusion_pytorch_model.safetensors",
72
  text_encoder_path="text_encoder/model.safetensors",
73
+ output_path="flux1-depth-dev_merged_model.safetensors"
74
  )
MergeSafetensors2.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from safetensors.torch import save_file, load_file
2
+ import torch
3
+ import os
4
+
5
+ def inspect_keys(file_path, max_keys=10):
6
+ """Helper function to inspect the structure of a safetensors file."""
7
+ state = load_file(file_path)
8
+ keys = list(state.keys())
9
+ print(f"\n{os.path.basename(file_path)} - Total keys: {len(keys)}")
10
+ print(f"First {max_keys} keys:")
11
+ for k in keys[:max_keys]:
12
+ print(f" {k}")
13
+ return keys
14
+
15
+ def merge_for_comfyui(
16
+ unet_path,
17
+ vae_path,
18
+ text_encoder_path,
19
+ output_path,
20
+ model_type="flux" # "flux", "sd15", "sdxl"
21
+ ):
22
+ """
23
+ Merge components into ComfyUI-compatible safetensors checkpoint.
24
+
25
+ Args:
26
+ unet_path: Path to the main model/transformer safetensors
27
+ vae_path: Path to the VAE safetensors
28
+ text_encoder_path: Path to the text encoder/CLIP safetensors
29
+ output_path: Path for the merged checkpoint
30
+ model_type: Type of model (flux, sd15, sdxl)
31
+ """
32
+
33
+ print("=" * 60)
34
+ print("STEP 1: Inspecting input files...")
35
+ print("=" * 60)
36
+
37
+ # Inspect each file to understand structure
38
+ unet_keys = inspect_keys(unet_path)
39
+ vae_keys = inspect_keys(vae_path)
40
+ text_encoder_keys = inspect_keys(text_encoder_path)
41
+
42
+ print("\n" + "=" * 60)
43
+ print("STEP 2: Loading weights...")
44
+ print("=" * 60)
45
+
46
+ unet_state = load_file(unet_path)
47
+ vae_state = load_file(vae_path)
48
+ text_encoder_state = load_file(text_encoder_path)
49
+
50
+ print("\n" + "=" * 60)
51
+ print("STEP 3: Merging with proper key structure...")
52
+ print("=" * 60)
53
+
54
+ merged_state = {}
55
+
56
+ # Determine key prefixes based on existing structure
57
+ sample_unet_key = unet_keys[0]
58
+ sample_vae_key = vae_keys[0]
59
+ sample_te_key = text_encoder_keys[0]
60
+
61
+ print(f"\nDetected key patterns:")
62
+ print(f" UNet: {sample_unet_key}")
63
+ print(f" VAE: {sample_vae_key}")
64
+ print(f" Text Encoder: {sample_te_key}")
65
+
66
+ # Add UNet/Transformer weights
67
+ for key, value in unet_state.items():
68
+ # Keep original keys or add model prefix if needed
69
+ if key.startswith('model.') or key.startswith('diffusion_model.'):
70
+ merged_state[key] = value
71
+ else:
72
+ # Add ComfyUI-expected prefix
73
+ merged_state[f'model.diffusion_model.{key}'] = value
74
+
75
+ # Add VAE weights with proper structure
76
+ for key, value in vae_state.items():
77
+ if key.startswith('first_stage_model.') or key.startswith('vae.'):
78
+ merged_state[key] = value
79
+ elif key.startswith('decoder.') or key.startswith('encoder.'):
80
+ merged_state[f'first_stage_model.{key}'] = value
81
+ else:
82
+ merged_state[f'first_stage_model.decoder.{key}'] = value
83
+
84
+ # Add text encoder weights
85
+ for key, value in text_encoder_state.items():
86
+ if key.startswith('cond_stage_model.') or key.startswith('text_encoder.'):
87
+ merged_state[key] = value
88
+ else:
89
+ # For FLUX, might need different structure
90
+ if model_type.lower() == "flux":
91
+ merged_state[f'text_encoders.{key}'] = value
92
+ else:
93
+ merged_state[f'cond_stage_model.transformer.{key}'] = value
94
+
95
+ print(f"\nMerged state contains {len(merged_state)} parameters")
96
+
97
+ # Add metadata for ComfyUI recognition
98
+ print("\n" + "=" * 60)
99
+ print("STEP 4: Saving merged checkpoint...")
100
+ print("=" * 60)
101
+
102
+ save_file(merged_state, output_path)
103
+
104
+ print("\n✅ Merge complete!")
105
+ print(f"File saved to: {output_path}")
106
+
107
+ size_gb = os.path.getsize(output_path) / (1024**3)
108
+ print(f"File size: {size_gb:.2f} GB")
109
+
110
+ # Verify the merged file
111
+ print("\n" + "=" * 60)
112
+ print("STEP 5: Verifying merged file...")
113
+ print("=" * 60)
114
+ inspect_keys(output_path, max_keys=20)
115
+
116
+
117
+ def simple_merge_keep_structure(
118
+ unet_path,
119
+ vae_path,
120
+ text_encoder_path,
121
+ output_path
122
+ ):
123
+ """
124
+ Simple merge that preserves original key structure.
125
+ Use this if the files already have proper ComfyUI keys.
126
+ """
127
+ print("Loading all components...")
128
+
129
+ unet_state = load_file(unet_path)
130
+ vae_state = load_file(vae_path)
131
+ text_encoder_state = load_file(text_encoder_path)
132
+
133
+ print("Merging...")
134
+ merged_state = {}
135
+ merged_state.update(unet_state)
136
+ merged_state.update(vae_state)
137
+ merged_state.update(text_encoder_state)
138
+
139
+ print(f"Saving {len(merged_state)} parameters...")
140
+ save_file(merged_state, output_path)
141
+
142
+ size_gb = os.path.getsize(output_path) / (1024**3)
143
+ print(f"✅ Done! File size: {size_gb:.2f} GB")
144
+
145
+
146
+ # Example usage
147
+ if __name__ == "__main__":
148
+ # Option 1: Smart merge with key detection
149
+ merge_for_comfyui(
150
+ unet_path="svdq-fp4_r32-flux.1-depth-dev.safetensors",
151
+ vae_path="vae/diffusion_pytorch_model.safetensors",
152
+ text_encoder_path="text_encoder/model.safetensors",
153
+ output_path="flux1-depth-dev_fp4_merged_model.safetensors",
154
+ model_type="flux"
155
+ )
156
+
157
+ # Option 2: Simple merge (if keys are already correct)
158
+ # simple_merge_keep_structure(
159
+ # unet_path="path/to/model.safetensors",
160
+ # vae_path="path/to/vae.safetensors",
161
+ # text_encoder_path="path/to/text_encoder.safetensors",
162
+ # output_path="merged_checkpoint.safetensors"
163
+ # )
ae.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:afc8e28272cd15db3919bacdb6918ce9c1ed22e96cb12c4d5ed0fba823529e38
3
- size 335304388
 
 
 
 
flux1-depth-dev_fp4_merged_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:44d0a7a1c2353131a545f3a62ceff6a5affa2b5cfcc28f334a075bb58df8819b
3
- size 7866673892
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7eff742003708c77b291d7ca416bf695dc7c01a3a26250febac7c79ee6f390d
3
+ size 7866741316