import os

import torch  # NOTE(review): unused in this script; kept in case other code relies on it
from safetensors import safe_open
from safetensors.torch import load_file, save_file


def _print_key_summary(label, keys, max_keys):
    """Print the key count for *label* followed by its first *max_keys* keys."""
    print(f"\n{label} - Total keys: {len(keys)}")
    print(f"First {max_keys} keys:")
    for key in keys[:max_keys]:
        print(f" {key}")


def inspect_keys(file_path, max_keys=10):
    """Print a summary of the tensor names stored in a safetensors file.

    Only the file header is read through ``safe_open``; tensor data is not
    materialised, so this stays cheap even for multi-gigabyte checkpoints.

    Args:
        file_path: Path to a ``.safetensors`` file.
        max_keys: Number of keys to print.

    Returns:
        list[str]: Every tensor name in the file.
    """
    with safe_open(file_path, framework="pt") as handle:
        keys = list(handle.keys())
    _print_key_summary(os.path.basename(file_path), keys, max_keys)
    return keys


def _unet_key(key):
    """Return the merged-checkpoint key for a UNet/transformer tensor."""
    if key.startswith("model."):
        return key
    if key.startswith("diffusion_model."):
        # Only the leading "model." is missing; prefixing the full
        # namespace would yield "model.diffusion_model.diffusion_model.*".
        return f"model.{key}"
    return f"model.diffusion_model.{key}"


def _vae_key(key):
    """Return the merged-checkpoint key for a VAE tensor."""
    if key.startswith(("first_stage_model.", "vae.")):
        return key
    # encoder.*, decoder.* and top-level autoencoder modules (e.g.
    # quant_conv / post_quant_conv in AutoencoderKL-style VAEs) all sit
    # directly under the autoencoder, so none of them is nested under
    # "decoder.".
    return f"first_stage_model.{key}"


def _text_encoder_key(key, model_type):
    """Return the merged-checkpoint key for a text-encoder tensor.

    ``model_type == "flux"`` (case-insensitive) selects the
    ``text_encoders.`` namespace; every other value selects
    ``cond_stage_model.transformer.``.
    """
    # "text_encoders." is the FLUX output prefix below; without it here an
    # already-prefixed file would be re-prefixed a second time.
    if key.startswith(("cond_stage_model.", "text_encoder.", "text_encoders.")):
        return key
    if model_type.lower() == "flux":
        return f"text_encoders.{key}"
    return f"cond_stage_model.transformer.{key}"


def _merge_state_dicts(components):
    """Combine ``(state_dict, rename)`` pairs into a single state dict.

    Later entries overwrite earlier ones on a key clash (the same rule as
    ``dict.update``); the clashing keys are returned so callers can report
    them instead of silently losing tensors.

    Args:
        components: Iterable of ``(state_dict, rename)`` pairs, where
            ``rename`` maps an original key to its merged key.

    Returns:
        tuple[dict, list[str]]: The merged state dict and the list of keys
        that were overwritten.
    """
    merged = {}
    overwritten = []
    for state, rename in components:
        for key, value in state.items():
            new_key = rename(key)
            if new_key in merged:
                overwritten.append(new_key)
            merged[new_key] = value
    return merged, overwritten


def _report_overwrites(overwritten, max_shown=10):
    """Print a warning listing keys that a later component overwrote."""
    if not overwritten:
        return
    print(
        f"\nWARNING: {len(overwritten)} key(s) were defined more than once; "
        "the later component's tensor was kept:"
    )
    for key in overwritten[:max_shown]:
        print(f" {key}")


def merge_for_comfyui(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
    model_type="flux",  # "flux", "sd15", "sdxl"
    metadata=None,
):
    """
    Merge components into a ComfyUI-compatible safetensors checkpoint.

    Keys are remapped into these namespaces:
    ``model.diffusion_model.`` for the UNet/transformer,
    ``first_stage_model.`` for the VAE (``vae.`` keys are kept as-is), and
    ``text_encoders.`` (FLUX) or ``cond_stage_model.transformer.`` (other
    model types) for the text encoder. Keys that already carry one of the
    recognised prefixes are kept unchanged.

    Args:
        unet_path: Path to the main model/transformer safetensors
        vae_path: Path to the VAE safetensors
        text_encoder_path: Path to the text encoder/CLIP safetensors
        output_path: Path for the merged checkpoint
        model_type: Type of model (flux, sd15, sdxl). Only "flux"
            (case-insensitive) changes behaviour, selecting the
            ``text_encoders.`` prefix for text-encoder keys.
        metadata: Optional ``dict[str, str]`` written into the safetensors
            header of the output file.

    Raises:
        ValueError: If any input file contains no tensors.
    """
    print("=" * 60)
    print("STEP 1: Inspecting input files...")
    print("=" * 60)

    # Header-only inspection: cheap, and rejects empty inputs before
    # spending time loading the full weights.
    unet_keys = inspect_keys(unet_path)
    vae_keys = inspect_keys(vae_path)
    text_encoder_keys = inspect_keys(text_encoder_path)
    for label, keys, path in (
        ("UNet", unet_keys, unet_path),
        ("VAE", vae_keys, vae_path),
        ("Text encoder", text_encoder_keys, text_encoder_path),
    ):
        if not keys:
            raise ValueError(f"{label} file contains no tensors: {path}")

    print("\n" + "=" * 60)
    print("STEP 2: Loading weights...")
    print("=" * 60)

    unet_state = load_file(unet_path)
    vae_state = load_file(vae_path)
    text_encoder_state = load_file(text_encoder_path)

    print("\n" + "=" * 60)
    print("STEP 3: Merging with proper key structure...")
    print("=" * 60)

    print("\nDetected key patterns:")
    print(f" UNet: {unet_keys[0]}")
    print(f" VAE: {vae_keys[0]}")
    print(f" Text Encoder: {text_encoder_keys[0]}")

    merged_state, overwritten = _merge_state_dicts([
        (unet_state, _unet_key),
        (vae_state, _vae_key),
        (text_encoder_state, lambda key: _text_encoder_key(key, model_type)),
    ])
    _report_overwrites(overwritten)

    print(f"\nMerged state contains {len(merged_state)} parameters")

    print("\n" + "=" * 60)
    print("STEP 4: Saving merged checkpoint...")
    print("=" * 60)

    save_file(merged_state, output_path, metadata=metadata)

    print("\n✅ Merge complete!")
    print(f"File saved to: {output_path}")
    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")

    print("\n" + "=" * 60)
    print("STEP 5: Verifying merged file...")
    print("=" * 60)
    inspect_keys(output_path, max_keys=20)


def simple_merge_keep_structure(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path,
    metadata=None,
):
    """
    Simple merge that preserves original key structure.

    Use this if the files already have proper ComfyUI keys. On a key clash
    the text encoder wins over the VAE, which wins over the UNet; clashes
    are reported rather than dropped silently.

    Args:
        unet_path: Path to the main model/transformer safetensors
        vae_path: Path to the VAE safetensors
        text_encoder_path: Path to the text encoder/CLIP safetensors
        output_path: Path for the merged checkpoint
        metadata: Optional ``dict[str, str]`` written into the safetensors
            header of the output file.
    """
    print("Loading all components...")
    unet_state = load_file(unet_path)
    vae_state = load_file(vae_path)
    text_encoder_state = load_file(text_encoder_path)

    print("Merging...")
    merged_state, overwritten = _merge_state_dicts(
        [(state, lambda key: key) for state in (unet_state, vae_state, text_encoder_state)]
    )
    _report_overwrites(overwritten)

    print(f"Saving {len(merged_state)} parameters...")
    save_file(merged_state, output_path, metadata=metadata)

    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"✅ Done! File size: {size_gb:.2f} GB")


# Example usage
if __name__ == "__main__":
    # Option 1: Smart merge with key detection
    merge_for_comfyui(
        unet_path="../flux1-depth-dev.safetensors",
        vae_path="../vae/diffusion_pytorch_model.safetensors",
        text_encoder_path="../text_encoder/model.safetensors",
        output_path="../flux1-depth-dev_merged_model.safetensors",
        model_type="flux",
    )

    # Option 2: Simple merge (if keys are already correct)
    # simple_merge_keep_structure(
    #     unet_path="path/to/model.safetensors",
    #     vae_path="path/to/vae.safetensors",
    #     text_encoder_path="path/to/text_encoder.safetensors",
    #     output_path="merged_checkpoint.safetensors"
    # )