import os

import torch
from safetensors.torch import load_file, save_file


def _merge_with_prefix(merged, state, prefix):
    """Copy *state* into *merged*, prepending *prefix* to keys that lack it.

    Warns (via print) when a destination key already exists, since a silent
    overwrite would corrupt the merged checkpoint.

    Args:
        merged: Destination dict of tensor name -> tensor (mutated in place).
        state: Source state dict to fold in.
        prefix: Key prefix such as ``'vae.'`` or ``'text_encoder.'``.
    """
    for key, value in state.items():
        dest_key = key if key.startswith(prefix) else f"{prefix}{key}"
        if dest_key in merged:
            print(f"⚠️  Key collision, overwriting: {dest_key}")
        merged[dest_key] = value


def merge_model_components(
    unet_path,
    vae_path,
    text_encoder_path,
    output_path
):
    """
    Merge UNet, VAE, and text encoder into a single safetensors file.

    UNet keys are copied as-is; VAE and text-encoder keys are namespaced
    under ``vae.`` and ``text_encoder.`` respectively (unless already
    prefixed) so the three components cannot trample each other.

    Args:
        unet_path: Path to the main model/unet safetensors file
        vae_path: Path to the VAE safetensors file
        text_encoder_path: Path to the text encoder/CLIP safetensors file
        output_path: Path where the merged file will be saved
    """
    print("Loading UNet/Model weights...")
    unet_state = load_file(unet_path)

    print("Loading VAE weights...")
    vae_state = load_file(vae_path)

    print("Loading Text Encoder weights...")
    text_encoder_state = load_file(text_encoder_path)

    # Merge all state dictionaries
    print("Merging state dictionaries...")
    merged_state = {}

    # UNet weights keep their original names.
    merged_state.update(unet_state)

    # Namespace the other components so their keys stay distinguishable.
    _merge_with_prefix(merged_state, vae_state, 'vae.')
    _merge_with_prefix(merged_state, text_encoder_state, 'text_encoder.')

    print(f"Total parameters in merged model: {len(merged_state)}")
    print(f"Saving merged model to {output_path}...")

    # Save the merged state dictionary
    save_file(merged_state, output_path)

    print("✅ Merge complete!")
    print(f"File saved to: {output_path}")

    # Report the size of the file we just wrote.
    size_gb = os.path.getsize(output_path) / (1024**3)
    print(f"File size: {size_gb:.2f} GB")


# Example usage
if __name__ == "__main__":
    merge_model_components(
        unet_path="flux1-depth-dev.safetensors",
        vae_path="vae/diffusion_pytorch_model.safetensors",
        text_encoder_path="text_encoder/model.safetensors",
        output_path="flux1-depth-dev_merged_model.safetensors"
    )